mirror of
https://github.com/anchore/syft.git
synced 2026-07-05 02:28:25 +02:00
lint: lintfix
Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
parent
9644340981
commit
b731aa4f33
@ -35,36 +35,37 @@ type safeTensorsEntry struct {
|
|||||||
DataOffsets []int64 `json:"data_offsets"`
|
DataOffsets []int64 `json:"data_offsets"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// readSafeTensorsHeader reads and parses the JSON header from a .safetensors file.
|
// readSafeTensorsHeader reads and parses the JSON header from a .safetensors
|
||||||
// It returns the decoded header plus the on-disk size of the header JSON in bytes.
|
// file (the leading `[8-byte LE length] [length bytes of JSON]` block) and
|
||||||
func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, uint64, error) {
|
// returns the decoded header.
|
||||||
|
func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, error) {
|
||||||
var lenBuf [8]byte
|
var lenBuf [8]byte
|
||||||
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
|
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to read header length: %w", err)
|
return nil, fmt.Errorf("failed to read header length: %w", err)
|
||||||
}
|
}
|
||||||
headerLen := binary.LittleEndian.Uint64(lenBuf[:])
|
headerLen := binary.LittleEndian.Uint64(lenBuf[:])
|
||||||
if headerLen == 0 {
|
if headerLen == 0 {
|
||||||
return nil, 0, fmt.Errorf("safetensors header length is zero")
|
return nil, fmt.Errorf("safetensors header length is zero")
|
||||||
}
|
}
|
||||||
if headerLen > maxSafeTensorsHeaderSize {
|
if headerLen > maxSafeTensorsHeaderSize {
|
||||||
return nil, 0, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize)
|
return nil, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
body := make([]byte, headerLen)
|
body := make([]byte, headerLen)
|
||||||
if _, err := io.ReadFull(r, body); err != nil {
|
if _, err := io.ReadFull(r, body); err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to read header body: %w", err)
|
return nil, fmt.Errorf("failed to read header body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var raw map[string]json.RawMessage
|
var raw map[string]json.RawMessage
|
||||||
if err := json.Unmarshal(body, &raw); err != nil {
|
if err := json.Unmarshal(body, &raw); err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to decode safetensors header JSON: %w", err)
|
return nil, fmt.Errorf("failed to decode safetensors header JSON: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h := &safeTensorsHeader{tensors: make(map[string]safeTensorsEntry, len(raw))}
|
h := &safeTensorsHeader{tensors: make(map[string]safeTensorsEntry, len(raw))}
|
||||||
for key, val := range raw {
|
for key, val := range raw {
|
||||||
if key == "__metadata__" {
|
if key == "__metadata__" {
|
||||||
if err := json.Unmarshal(val, &h.metadata); err != nil {
|
if err := json.Unmarshal(val, &h.metadata); err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to decode __metadata__: %w", err)
|
return nil, fmt.Errorf("failed to decode __metadata__: %w", err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -76,7 +77,7 @@ func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, uint64, error) {
|
|||||||
h.tensors[key] = entry
|
h.tensors[key] = entry
|
||||||
}
|
}
|
||||||
|
|
||||||
return h, headerLen, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parameterCount sums the element counts across all tensors in the header.
|
// parameterCount sums the element counts across all tensors in the header.
|
||||||
|
|||||||
@ -21,7 +21,7 @@ import (
|
|||||||
func parseSafeTensorsFile(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
func parseSafeTensorsFile(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||||
defer internal.CloseAndLogError(reader, reader.Path())
|
defer internal.CloseAndLogError(reader, reader.Path())
|
||||||
|
|
||||||
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
|
header, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("failed to read safetensors header: %w", err)
|
return nil, nil, fmt.Errorf("failed to read safetensors header: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -95,7 +95,7 @@ func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.En
|
|||||||
func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||||
defer internal.CloseAndLogError(reader, reader.Path())
|
defer internal.CloseAndLogError(reader, reader.Path())
|
||||||
|
|
||||||
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
|
header, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("failed to read safetensors layer header: %w", err)
|
return nil, nil, fmt.Errorf("failed to read safetensors layer header: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -641,23 +641,22 @@ func TestReadSafeTensorsHeader(t *testing.T) {
|
|||||||
data := buildSafeTensorsFile(t, map[string]string{"format": "pt"}, map[string]safeTensorsEntry{
|
data := buildSafeTensorsFile(t, map[string]string{"format": "pt"}, map[string]safeTensorsEntry{
|
||||||
"w": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}},
|
"w": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}},
|
||||||
})
|
})
|
||||||
h, n, err := readSafeTensorsHeader(bytes.NewReader(data))
|
h, err := readSafeTensorsHeader(bytes.NewReader(data))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint64(len(data)-8), n)
|
|
||||||
assert.Len(t, h.tensors, 1)
|
assert.Len(t, h.tensors, 1)
|
||||||
assert.Equal(t, "pt", h.metadata["format"])
|
assert.Equal(t, "pt", h.metadata["format"])
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("zero-length header", func(t *testing.T) {
|
t.Run("zero-length header", func(t *testing.T) {
|
||||||
var buf [8]byte // length prefix of 0
|
var buf [8]byte // length prefix of 0
|
||||||
_, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
|
_, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("truncated body", func(t *testing.T) {
|
t.Run("truncated body", func(t *testing.T) {
|
||||||
var buf [8]byte
|
var buf [8]byte
|
||||||
binary.LittleEndian.PutUint64(buf[:], 100) // claims 100 bytes but supplies none
|
binary.LittleEndian.PutUint64(buf[:], 100) // claims 100 bytes but supplies none
|
||||||
_, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
|
_, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user