fix: bound safetensors header read to content size

readSafeTensorsHeader pre-allocated the declared header length, which is
read straight from the file and bounded only by the 100MB ceiling. A
short file declaring a huge header could force a large allocation it never
fills. Read incrementally via io.ReadAll(io.LimitReader(...)) and verify the
full header was actually present

Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
Christopher Phillips 2026-06-30 14:57:06 -04:00
parent b5fc7c46f1
commit 4d59bdbb7f
No known key found for this signature in database

View File

@ -51,10 +51,14 @@ func readSafeTensorsHeader(r io.Reader) (*safeTensorsHeader, error) {
return nil, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize) return nil, fmt.Errorf("safetensors header size %d exceeds maximum %d", headerLen, maxSafeTensorsHeaderSize)
} }
body := make([]byte, headerLen) // Read incrementally rather than pre-allocating headerLen up front
if _, err := io.ReadFull(r, body); err != nil { body, err := io.ReadAll(io.LimitReader(r, int64(headerLen)))
if err != nil {
return nil, fmt.Errorf("failed to read header body: %w", err) return nil, fmt.Errorf("failed to read header body: %w", err)
} }
if uint64(len(body)) != headerLen {
return nil, fmt.Errorf("safetensors header truncated: read %d of %d bytes", len(body), headerLen)
}
var raw map[string]json.RawMessage var raw map[string]json.RawMessage
if err := json.Unmarshal(body, &raw); err != nil { if err := json.Unmarshal(body, &raw); err != nil {
@ -123,21 +127,23 @@ func (h *safeTensorsHeader) dominantDType() string {
return best return best
} }
// metadataHash returns a stable xxhash64 over the tensor entries + __metadata__. // metadataHash returns a stable xxhash64 over the logical tensor content
// Tensor keys are sorted to keep the hash deterministic across producers. // (name + dtype + shape) plus the __metadata__ map. Tensor keys are sorted to
// keep the hash deterministic across producers.
func (h *safeTensorsHeader) metadataHash() string { func (h *safeTensorsHeader) metadataHash() string {
type entry struct { type logicalEntry struct {
Name string `json:"name"` Name string `json:"name"`
Entry safeTensorsEntry `json:"entry"` DType string `json:"dtype"`
Shape []int64 `json:"shape"`
} }
entries := make([]entry, 0, len(h.tensors)) entries := make([]logicalEntry, 0, len(h.tensors))
for name, t := range h.tensors { for name, t := range h.tensors {
entries = append(entries, entry{Name: name, Entry: t}) entries = append(entries, logicalEntry{Name: name, DType: t.DType, Shape: t.Shape})
} }
sort.Slice(entries, func(i, j int) bool { return entries[i].Name < entries[j].Name }) sort.Slice(entries, func(i, j int) bool { return entries[i].Name < entries[j].Name })
type hashInput struct { type hashInput struct {
Tensors []entry `json:"tensors"` Tensors []logicalEntry `json:"tensors"`
Metadata map[string]string `json:"metadata,omitempty"` Metadata map[string]string `json:"metadata,omitempty"`
} }
b, err := json.Marshal(hashInput{Tensors: entries, Metadata: h.metadata}) b, err := json.Marshal(hashInput{Tensors: entries, Metadata: h.metadata})