lint: lintfix

Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
Christopher Phillips 2026-05-29 04:02:41 -04:00
parent 9644340981
commit b731aa4f33
No known key found for this signature in database
4 changed files with 16 additions and 16 deletions

View File

@ -35,36 +35,37 @@ type safeTensorsEntry struct {
DataOffsets []int64 `json:"data_offsets"`
}
// readSafeTensorsHeader reads and parses the JSON header from a .safetensors file.
// It returns the decoded header plus the on-disk size of the header JSON in bytes.
func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, uint64, error) {
// readSafeTensorsHeader reads and parses the JSON header from a .safetensors
// file (the leading `[8-byte LE length] [length bytes of JSON]` block) and
// returns the decoded header.
func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, error) {
var lenBuf [8]byte
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
return nil, 0, fmt.Errorf("failed to read header length: %w", err)
return nil, fmt.Errorf("failed to read header length: %w", err)
}
headerLen := binary.LittleEndian.Uint64(lenBuf[:])
if headerLen == 0 {
return nil, 0, fmt.Errorf("safetensors header length is zero")
return nil, fmt.Errorf("safetensors header length is zero")
}
if headerLen > maxSafeTensorsHeaderSize {
return nil, 0, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize)
return nil, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize)
}
body := make([]byte, headerLen)
if _, err := io.ReadFull(r, body); err != nil {
return nil, 0, fmt.Errorf("failed to read header body: %w", err)
return nil, fmt.Errorf("failed to read header body: %w", err)
}
var raw map[string]json.RawMessage
if err := json.Unmarshal(body, &raw); err != nil {
return nil, 0, fmt.Errorf("failed to decode safetensors header JSON: %w", err)
return nil, fmt.Errorf("failed to decode safetensors header JSON: %w", err)
}
h := &safeTensorsHeader{tensors: make(map[string]safeTensorsEntry, len(raw))}
for key, val := range raw {
if key == "__metadata__" {
if err := json.Unmarshal(val, &h.metadata); err != nil {
return nil, 0, fmt.Errorf("failed to decode __metadata__: %w", err)
return nil, fmt.Errorf("failed to decode __metadata__: %w", err)
}
continue
}
@ -76,7 +77,7 @@ func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, uint64, error) {
h.tensors[key] = entry
}
return h, headerLen, nil
return h, nil
}
// parameterCount sums the element counts across all tensors in the header.

View File

@ -21,7 +21,7 @@ import (
func parseSafeTensorsFile(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
defer internal.CloseAndLogError(reader, reader.Path())
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
header, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
if err != nil {
return nil, nil, fmt.Errorf("failed to read safetensors header: %w", err)
}

View File

@ -95,7 +95,7 @@ func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.En
func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
defer internal.CloseAndLogError(reader, reader.Path())
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
header, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
if err != nil {
return nil, nil, fmt.Errorf("failed to read safetensors layer header: %w", err)
}

View File

@ -641,23 +641,22 @@ func TestReadSafeTensorsHeader(t *testing.T) {
data := buildSafeTensorsFile(t, map[string]string{"format": "pt"}, map[string]safeTensorsEntry{
"w": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}},
})
h, n, err := readSafeTensorsHeader(bytes.NewReader(data))
h, err := readSafeTensorsHeader(bytes.NewReader(data))
require.NoError(t, err)
assert.Equal(t, uint64(len(data)-8), n)
assert.Len(t, h.tensors, 1)
assert.Equal(t, "pt", h.metadata["format"])
})
t.Run("zero-length header", func(t *testing.T) {
var buf [8]byte // length prefix of 0
_, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
_, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
require.Error(t, err)
})
t.Run("truncated body", func(t *testing.T) {
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], 100) // claims 100 bytes but supplies none
_, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
_, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
require.Error(t, err)
})
}