diff --git a/syft/pkg/cataloger/ai/parse_safetensors.go b/syft/pkg/cataloger/ai/parse_safetensors.go index 6615d7ee6..52e48a19e 100644 --- a/syft/pkg/cataloger/ai/parse_safetensors.go +++ b/syft/pkg/cataloger/ai/parse_safetensors.go @@ -35,36 +35,37 @@ type safeTensorsEntry struct { DataOffsets []int64 `json:"data_offsets"` } -// readSafeTensorsHeader reads and parses the JSON header from a .safetensors file. -// It returns the decoded header plus the on-disk size of the header JSON in bytes. -func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, uint64, error) { +// readSafeTensorsHeader reads and parses the JSON header from a .safetensors +// file (the leading `[8-byte LE length] [length bytes of JSON]` block) and +// returns the decoded header. +func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, error) { var lenBuf [8]byte if _, err := io.ReadFull(r, lenBuf[:]); err != nil { - return nil, 0, fmt.Errorf("failed to read header length: %w", err) + return nil, fmt.Errorf("failed to read header length: %w", err) } headerLen := binary.LittleEndian.Uint64(lenBuf[:]) if headerLen == 0 { - return nil, 0, fmt.Errorf("safetensors header length is zero") + return nil, fmt.Errorf("safetensors header length is zero") } if headerLen > maxSafeTensorsHeaderSize { - return nil, 0, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize) + return nil, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize) } body := make([]byte, headerLen) if _, err := io.ReadFull(r, body); err != nil { - return nil, 0, fmt.Errorf("failed to read header body: %w", err) + return nil, fmt.Errorf("failed to read header body: %w", err) } var raw map[string]json.RawMessage if err := json.Unmarshal(body, &raw); err != nil { - return nil, 0, fmt.Errorf("failed to decode safetensors header JSON: %w", err) + return nil, fmt.Errorf("failed to decode safetensors header JSON: %w", err) } h := &safeTensorsHeader{tensors: make(map[string]safeTensorsEntry, len(raw))} for key, val := range raw { if key == "__metadata__" { if err := json.Unmarshal(val, &h.metadata); err != nil { - return nil, 0, fmt.Errorf("failed to decode __metadata__: %w", err) + return nil, fmt.Errorf("failed to decode __metadata__: %w", err) } continue } @@ -76,7 +77,7 @@ func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, uint64, error) { h.tensors[key] = entry } - return h, headerLen, nil + return h, nil } // parameterCount sums the element counts across all tensors in the header. diff --git a/syft/pkg/cataloger/ai/parse_safetensors_model.go b/syft/pkg/cataloger/ai/parse_safetensors_model.go index 28bf11b78..d7288f1ad 100644 --- a/syft/pkg/cataloger/ai/parse_safetensors_model.go +++ b/syft/pkg/cataloger/ai/parse_safetensors_model.go @@ -21,7 +21,7 @@ import ( func parseSafeTensorsFile(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) { defer internal.CloseAndLogError(reader, reader.Path()) - header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8}) + header, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8}) if err != nil { return nil, nil, fmt.Errorf("failed to read safetensors header: %w", err) } diff --git a/syft/pkg/cataloger/ai/parse_safetensors_oci.go b/syft/pkg/cataloger/ai/parse_safetensors_oci.go index 6449c2aca..f3fce0403 100644 --- a/syft/pkg/cataloger/ai/parse_safetensors_oci.go +++ b/syft/pkg/cataloger/ai/parse_safetensors_oci.go @@ -95,7 +95,7 @@ func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.En func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) { defer internal.CloseAndLogError(reader, reader.Path()) - header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8}) + header, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8}) if err != nil { return nil, nil, fmt.Errorf("failed to read safetensors layer header: %w", err) } diff --git a/syft/pkg/cataloger/ai/parse_safetensors_test.go b/syft/pkg/cataloger/ai/parse_safetensors_test.go index f58c4cf61..e669c6a43 100644 --- a/syft/pkg/cataloger/ai/parse_safetensors_test.go +++ b/syft/pkg/cataloger/ai/parse_safetensors_test.go @@ -641,23 +641,22 @@ func TestReadSafeTensorsHeader(t *testing.T) { data := buildSafeTensorsFile(t, map[string]string{"format": "pt"}, map[string]safeTensorsEntry{ "w": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}}, }) - h, n, err := readSafeTensorsHeader(bytes.NewReader(data)) + h, err := readSafeTensorsHeader(bytes.NewReader(data)) require.NoError(t, err) - assert.Equal(t, uint64(len(data)-8), n) assert.Len(t, h.tensors, 1) assert.Equal(t, "pt", h.metadata["format"]) }) t.Run("zero-length header", func(t *testing.T) { var buf [8]byte // length prefix of 0 - _, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:])) + _, err := readSafeTensorsHeader(bytes.NewReader(buf[:])) require.Error(t, err) }) t.Run("truncated body", func(t *testing.T) { var buf [8]byte binary.LittleEndian.PutUint64(buf[:], 100) // claims 100 bytes but supplies none - _, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:])) + _, err := readSafeTensorsHeader(bytes.NewReader(buf[:])) require.Error(t, err) }) }