mirror of
https://github.com/anchore/syft.git
synced 2026-07-05 02:28:25 +02:00
lint: lintfix
Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
parent
9644340981
commit
b731aa4f33
@ -35,36 +35,37 @@ type safeTensorsEntry struct {
|
||||
DataOffsets []int64 `json:"data_offsets"`
|
||||
}
|
||||
|
||||
// readSafeTensorsHeader reads and parses the JSON header from a .safetensors file.
|
||||
// It returns the decoded header plus the on-disk size of the header JSON in bytes.
|
||||
func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, uint64, error) {
|
||||
// readSafeTensorsHeader reads and parses the JSON header from a .safetensors
|
||||
// file (the leading `[8-byte LE length] [length bytes of JSON]` block) and
|
||||
// returns the decoded header.
|
||||
func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, error) {
|
||||
var lenBuf [8]byte
|
||||
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to read header length: %w", err)
|
||||
return nil, fmt.Errorf("failed to read header length: %w", err)
|
||||
}
|
||||
headerLen := binary.LittleEndian.Uint64(lenBuf[:])
|
||||
if headerLen == 0 {
|
||||
return nil, 0, fmt.Errorf("safetensors header length is zero")
|
||||
return nil, fmt.Errorf("safetensors header length is zero")
|
||||
}
|
||||
if headerLen > maxSafeTensorsHeaderSize {
|
||||
return nil, 0, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize)
|
||||
return nil, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize)
|
||||
}
|
||||
|
||||
body := make([]byte, headerLen)
|
||||
if _, err := io.ReadFull(r, body); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to read header body: %w", err)
|
||||
return nil, fmt.Errorf("failed to read header body: %w", err)
|
||||
}
|
||||
|
||||
var raw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &raw); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to decode safetensors header JSON: %w", err)
|
||||
return nil, fmt.Errorf("failed to decode safetensors header JSON: %w", err)
|
||||
}
|
||||
|
||||
h := &safeTensorsHeader{tensors: make(map[string]safeTensorsEntry, len(raw))}
|
||||
for key, val := range raw {
|
||||
if key == "__metadata__" {
|
||||
if err := json.Unmarshal(val, &h.metadata); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to decode __metadata__: %w", err)
|
||||
return nil, fmt.Errorf("failed to decode __metadata__: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@ -76,7 +77,7 @@ func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, uint64, error) {
|
||||
h.tensors[key] = entry
|
||||
}
|
||||
|
||||
return h, headerLen, nil
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// parameterCount sums the element counts across all tensors in the header.
|
||||
|
||||
@ -21,7 +21,7 @@ import (
|
||||
func parseSafeTensorsFile(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
defer internal.CloseAndLogError(reader, reader.Path())
|
||||
|
||||
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
|
||||
header, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to read safetensors header: %w", err)
|
||||
}
|
||||
|
||||
@ -95,7 +95,7 @@ func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.En
|
||||
func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
defer internal.CloseAndLogError(reader, reader.Path())
|
||||
|
||||
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
|
||||
header, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to read safetensors layer header: %w", err)
|
||||
}
|
||||
|
||||
@ -641,23 +641,22 @@ func TestReadSafeTensorsHeader(t *testing.T) {
|
||||
data := buildSafeTensorsFile(t, map[string]string{"format": "pt"}, map[string]safeTensorsEntry{
|
||||
"w": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}},
|
||||
})
|
||||
h, n, err := readSafeTensorsHeader(bytes.NewReader(data))
|
||||
h, err := readSafeTensorsHeader(bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint64(len(data)-8), n)
|
||||
assert.Len(t, h.tensors, 1)
|
||||
assert.Equal(t, "pt", h.metadata["format"])
|
||||
})
|
||||
|
||||
t.Run("zero-length header", func(t *testing.T) {
|
||||
var buf [8]byte // length prefix of 0
|
||||
_, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
|
||||
_, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated body", func(t *testing.T) {
|
||||
var buf [8]byte
|
||||
binary.LittleEndian.PutUint64(buf[:], 100) // claims 100 bytes but supplies none
|
||||
_, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
|
||||
_, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user