lint: lintfix

Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
Christopher Phillips 2026-05-29 04:02:41 -04:00
parent 9644340981
commit b731aa4f33
No known key found for this signature in database
4 changed files with 16 additions and 16 deletions

View File

@ -35,36 +35,37 @@ type safeTensorsEntry struct {
DataOffsets []int64 `json:"data_offsets"` DataOffsets []int64 `json:"data_offsets"`
} }
// readSafeTensorsHeader reads and parses the JSON header from a .safetensors file. // readSafeTensorsHeader reads and parses the JSON header from a .safetensors
// It returns the decoded header plus the on-disk size of the header JSON in bytes. // file (the leading `[8-byte LE length] [length bytes of JSON]` block) and
func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, uint64, error) { // returns the decoded header.
func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, error) {
var lenBuf [8]byte var lenBuf [8]byte
if _, err := io.ReadFull(r, lenBuf[:]); err != nil { if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
return nil, 0, fmt.Errorf("failed to read header length: %w", err) return nil, fmt.Errorf("failed to read header length: %w", err)
} }
headerLen := binary.LittleEndian.Uint64(lenBuf[:]) headerLen := binary.LittleEndian.Uint64(lenBuf[:])
if headerLen == 0 { if headerLen == 0 {
return nil, 0, fmt.Errorf("safetensors header length is zero") return nil, fmt.Errorf("safetensors header length is zero")
} }
if headerLen > maxSafeTensorsHeaderSize { if headerLen > maxSafeTensorsHeaderSize {
return nil, 0, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize) return nil, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize)
} }
body := make([]byte, headerLen) body := make([]byte, headerLen)
if _, err := io.ReadFull(r, body); err != nil { if _, err := io.ReadFull(r, body); err != nil {
return nil, 0, fmt.Errorf("failed to read header body: %w", err) return nil, fmt.Errorf("failed to read header body: %w", err)
} }
var raw map[string]json.RawMessage var raw map[string]json.RawMessage
if err := json.Unmarshal(body, &raw); err != nil { if err := json.Unmarshal(body, &raw); err != nil {
return nil, 0, fmt.Errorf("failed to decode safetensors header JSON: %w", err) return nil, fmt.Errorf("failed to decode safetensors header JSON: %w", err)
} }
h := &safeTensorsHeader{tensors: make(map[string]safeTensorsEntry, len(raw))} h := &safeTensorsHeader{tensors: make(map[string]safeTensorsEntry, len(raw))}
for key, val := range raw { for key, val := range raw {
if key == "__metadata__" { if key == "__metadata__" {
if err := json.Unmarshal(val, &h.metadata); err != nil { if err := json.Unmarshal(val, &h.metadata); err != nil {
return nil, 0, fmt.Errorf("failed to decode __metadata__: %w", err) return nil, fmt.Errorf("failed to decode __metadata__: %w", err)
} }
continue continue
} }
@ -76,7 +77,7 @@ func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, uint64, error) {
h.tensors[key] = entry h.tensors[key] = entry
} }
return h, headerLen, nil return h, nil
} }
// parameterCount sums the element counts across all tensors in the header. // parameterCount sums the element counts across all tensors in the header.

View File

@ -21,7 +21,7 @@ import (
func parseSafeTensorsFile(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) { func parseSafeTensorsFile(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
defer internal.CloseAndLogError(reader, reader.Path()) defer internal.CloseAndLogError(reader, reader.Path())
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8}) header, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to read safetensors header: %w", err) return nil, nil, fmt.Errorf("failed to read safetensors header: %w", err)
} }

View File

@ -95,7 +95,7 @@ func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.En
func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) { func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
defer internal.CloseAndLogError(reader, reader.Path()) defer internal.CloseAndLogError(reader, reader.Path())
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8}) header, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to read safetensors layer header: %w", err) return nil, nil, fmt.Errorf("failed to read safetensors layer header: %w", err)
} }

View File

@ -641,23 +641,22 @@ func TestReadSafeTensorsHeader(t *testing.T) {
data := buildSafeTensorsFile(t, map[string]string{"format": "pt"}, map[string]safeTensorsEntry{ data := buildSafeTensorsFile(t, map[string]string{"format": "pt"}, map[string]safeTensorsEntry{
"w": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}}, "w": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}},
}) })
h, n, err := readSafeTensorsHeader(bytes.NewReader(data)) h, err := readSafeTensorsHeader(bytes.NewReader(data))
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint64(len(data)-8), n)
assert.Len(t, h.tensors, 1) assert.Len(t, h.tensors, 1)
assert.Equal(t, "pt", h.metadata["format"]) assert.Equal(t, "pt", h.metadata["format"])
}) })
t.Run("zero-length header", func(t *testing.T) { t.Run("zero-length header", func(t *testing.T) {
var buf [8]byte // length prefix of 0 var buf [8]byte // length prefix of 0
_, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:])) _, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
require.Error(t, err) require.Error(t, err)
}) })
t.Run("truncated body", func(t *testing.T) { t.Run("truncated body", func(t *testing.T) {
var buf [8]byte var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], 100) // claims 100 bytes but supplies none binary.LittleEndian.PutUint64(buf[:], 100) // claims 100 bytes but supplies none
_, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:])) _, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
require.Error(t, err) require.Error(t, err)
}) })
} }