fix: update userMetadata to use KeyValue

Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
Christopher Phillips 2026-05-28 13:00:47 -04:00
parent 1a1f2af92b
commit d12cf9a3e2
No known key found for this signature in database
6 changed files with 35 additions and 13 deletions

View File

@ -4129,11 +4129,8 @@
"description": "ShardCount is the number of .safetensors shards for a sharded model (1 for a\nsingle-file model)." "description": "ShardCount is the number of .safetensors shards for a sharded model (1 for a\nsingle-file model)."
}, },
"userMetadata": { "userMetadata": {
"additionalProperties": { "$ref": "#/$defs/KeyValues",
"type": "string" "description": "UserMetadata is the optional \"__metadata__\" map from a .safetensors file header\n(string-to-string key/values set by the producer). Stored as a sorted KeyValues\nslice rather than a Go map so SBOM output is stable across runs."
},
"type": "object",
"description": "UserMetadata is the optional \"__metadata__\" map from a .safetensors file header\n(string-to-string key/values set by the producer)."
}, },
"metadataHash": { "metadataHash": {
"type": "string", "type": "string",

View File

@ -9,6 +9,8 @@ import (
"strings" "strings"
"github.com/cespare/xxhash/v2" "github.com/cespare/xxhash/v2"
"github.com/anchore/syft/syft/pkg"
) )
// SafeTensors file format: [8 bytes u64 LE header size] [N bytes JSON header] [tensor data]. // SafeTensors file format: [8 bytes u64 LE header size] [N bytes JSON header] [tensor data].
@ -144,6 +146,27 @@ func (h *safeTensorsHeader) metadataHash() string {
return fmt.Sprintf("%016x", xxhash.Sum64(b)) return fmt.Sprintf("%016x", xxhash.Sum64(b))
} }
// userMetadataKeyValues converts the safetensors __metadata__ map into a
// KeyValues slice sorted by key, so the resulting order does not depend on
// Go's randomized map iteration order.
//
// A nil or empty input yields a nil slice rather than an empty, non-nil one.
// This follows the usual Go convention for "no elements"; the struct field
// this value is assigned to is tagged `omitempty`, which omits nil and
// zero-length slices alike, so the serialized JSON is the same either way.
func userMetadataKeyValues(m map[string]string) pkg.KeyValues {
	if len(m) == 0 {
		return nil
	}

	// Collect and sort the keys first; ranging over the map directly would
	// produce a different order on every run.
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	out := make(pkg.KeyValues, 0, len(keys))
	for _, k := range keys {
		out = append(out, pkg.KeyValue{Key: k, Value: m[k]})
	}
	return out
}
// normalizeDType maps a safetensors/torch dtype label to an uppercase quantization // normalizeDType maps a safetensors/torch dtype label to an uppercase quantization
// shorthand matching conventions used elsewhere in syft (e.g., BF16, F16, I8). // shorthand matching conventions used elsewhere in syft (e.g., BF16, F16, I8).
func normalizeDType(dtype string) string { func normalizeDType(dtype string) string {

View File

@ -38,7 +38,7 @@ func parseSafeTensorsFile(_ context.Context, resolver file.Resolver, _ *generic.
TensorCount: uint64(len(header.tensors)), TensorCount: uint64(len(header.tensors)),
Quantization: normalizeDType(header.dominantDType()), Quantization: normalizeDType(header.dominantDType()),
ShardCount: 1, ShardCount: 1,
UserMetadata: header.metadata, UserMetadata: userMetadataKeyValues(header.metadata),
MetadataHash: header.metadataHash(), MetadataHash: header.metadataHash(),
} }
if p := header.parameterCount(); p > 0 { if p := header.parameterCount(); p > 0 {

View File

@ -261,7 +261,7 @@ func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Env
Format: "safetensors", Format: "safetensors",
TensorCount: uint64(len(header.tensors)), TensorCount: uint64(len(header.tensors)),
Quantization: normalizeDType(header.dominantDType()), Quantization: normalizeDType(header.dominantDType()),
UserMetadata: header.metadata, UserMetadata: userMetadataKeyValues(header.metadata),
MetadataHash: header.metadataHash(), MetadataHash: header.metadataHash(),
} }
if p := header.parameterCount(); p > 0 { if p := header.parameterCount(); p > 0 {

View File

@ -74,7 +74,7 @@ func TestSafeTensorsCataloger_singleFile(t *testing.T) {
TorchDtype: "bfloat16", TorchDtype: "bfloat16",
TransformersVersion: "4.40.0", TransformersVersion: "4.40.0",
ShardCount: 1, ShardCount: 1,
UserMetadata: userMeta, UserMetadata: pkg.KeyValues{{Key: "format", Value: "pt"}},
MetadataHash: wantHash, MetadataHash: wantHash,
}, },
}, },
@ -276,6 +276,7 @@ func TestParseSafeTensorsOCILayer(t *testing.T) {
"layer.1.weight": {DType: "BF16", Shape: []int64{16, 16}, DataOffsets: []int64{32768, 33280}}, "layer.1.weight": {DType: "BF16", Shape: []int64{16, 16}, DataOffsets: []int64{32768, 33280}},
} }
userMeta := map[string]string{"format": "pt"} userMeta := map[string]string{"format": "pt"}
wantUserMetadata := pkg.KeyValues{{Key: "format", Value: "pt"}}
blob := buildSafeTensorsFile(t, userMeta, tensors) blob := buildSafeTensorsFile(t, userMeta, tensors)
wantHash := (&safeTensorsHeader{metadata: userMeta, tensors: tensors}).metadataHash() wantHash := (&safeTensorsHeader{metadata: userMeta, tensors: tensors}).metadataHash()
@ -291,7 +292,7 @@ func TestParseSafeTensorsOCILayer(t *testing.T) {
assert.Equal(t, "safetensors", md.Format) assert.Equal(t, "safetensors", md.Format)
assert.Equal(t, uint64(2), md.TensorCount) assert.Equal(t, uint64(2), md.TensorCount)
assert.Equal(t, "BF16", md.Quantization) assert.Equal(t, "BF16", md.Quantization)
assert.Equal(t, userMeta, md.UserMetadata) assert.Equal(t, wantUserMetadata, md.UserMetadata)
assert.Equal(t, wantHash, md.MetadataHash) assert.Equal(t, wantHash, md.MetadataHash)
}) })
@ -327,7 +328,7 @@ func TestParseSafeTensorsOCILayer(t *testing.T) {
// MetadataHash is cleared on absorbed parts by the existing merge processor. // MetadataHash is cleared on absorbed parts by the existing merge processor.
// What survives is the rest of the per-shard metadata (UserMetadata, TensorCount, // What survives is the rest of the per-shard metadata (UserMetadata, TensorCount,
// header-derived Quantization). Confirm those are intact. // header-derived Quantization). Confirm those are intact.
assert.Equal(t, userMeta, md.Parts[0].UserMetadata) assert.Equal(t, wantUserMetadata, md.Parts[0].UserMetadata)
assert.Equal(t, uint64(2), md.Parts[0].TensorCount) assert.Equal(t, uint64(2), md.Parts[0].TensorCount)
assert.Equal(t, "BF16", md.Parts[0].Quantization, "part keeps the normalized header dtype") assert.Equal(t, "BF16", md.Parts[0].Quantization, "part keeps the normalized header dtype")
}) })
@ -357,7 +358,7 @@ func TestParseSafeTensorsOCILayer_realFixture(t *testing.T) {
assert.Equal(t, uint64(148), md.TensorCount, "nomic-embed-v2-moe 475M ships 148 tensor entries in this shard") assert.Equal(t, uint64(148), md.TensorCount, "nomic-embed-v2-moe 475M ships 148 tensor entries in this shard")
assert.Equal(t, "F32", md.Quantization, "every tensor in the captured shard is F32") assert.Equal(t, "F32", md.Quantization, "every tensor in the captured shard is F32")
assert.Equal(t, "475.29M", md.Parameters) assert.Equal(t, "475.29M", md.Parameters)
assert.Equal(t, map[string]string{"format": "pt"}, md.UserMetadata) assert.Equal(t, pkg.KeyValues{{Key: "format", Value: "pt"}}, md.UserMetadata)
// MetadataHash is locked to the exact value the parser produces for this // MetadataHash is locked to the exact value the parser produces for this
// captured input. The fixture is immutable on disk; if this value changes // captured input. The fixture is immutable on disk; if this value changes
// either the hash algorithm or the canonicalization changed, both of which // either the hash algorithm or the canonicalization changed, both of which

View File

@ -46,8 +46,9 @@ type SafeTensorsModelInfo struct {
ShardCount int `json:"shardCount,omitempty" cyclonedx:"shardCount"` ShardCount int `json:"shardCount,omitempty" cyclonedx:"shardCount"`
// UserMetadata is the optional "__metadata__" map from a .safetensors file header // UserMetadata is the optional "__metadata__" map from a .safetensors file header
// (string-to-string key/values set by the producer). // (string-to-string key/values set by the producer). Stored as a sorted KeyValues
UserMetadata map[string]string `json:"userMetadata,omitempty" cyclonedx:"userMetadata"` // slice rather than a Go map so SBOM output is stable across runs.
UserMetadata KeyValues `json:"userMetadata,omitempty" cyclonedx:"userMetadata"`
// MetadataHash is an xxhash of the normalized header metadata, providing a stable // MetadataHash is an xxhash of the normalized header metadata, providing a stable
// identifier for identical model content across repositories or filenames. // identifier for identical model content across repositories or filenames.