mirror of
https://github.com/anchore/syft.git
synced 2026-07-05 02:28:25 +02:00
fix: update userMetadata to use KeyValue
Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
parent
1a1f2af92b
commit
d12cf9a3e2
@ -4129,11 +4129,8 @@
|
|||||||
"description": "ShardCount is the number of .safetensors shards for a sharded model (1 for a\nsingle-file model)."
|
"description": "ShardCount is the number of .safetensors shards for a sharded model (1 for a\nsingle-file model)."
|
||||||
},
|
},
|
||||||
"userMetadata": {
|
"userMetadata": {
|
||||||
"additionalProperties": {
|
"$ref": "#/$defs/KeyValues",
|
||||||
"type": "string"
|
"description": "UserMetadata is the optional \"__metadata__\" map from a .safetensors file header\n(string-to-string key/values set by the producer). Stored as a sorted KeyValues\nslice rather than a Go map so SBOM output is stable across runs."
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"description": "UserMetadata is the optional \"__metadata__\" map from a .safetensors file header\n(string-to-string key/values set by the producer)."
|
|
||||||
},
|
},
|
||||||
"metadataHash": {
|
"metadataHash": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
|||||||
@ -9,6 +9,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
|
|
||||||
|
"github.com/anchore/syft/syft/pkg"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SafeTensors file format: [8 bytes u64 LE header size] [N bytes JSON header] [tensor data].
|
// SafeTensors file format: [8 bytes u64 LE header size] [N bytes JSON header] [tensor data].
|
||||||
@ -144,6 +146,27 @@ func (h *safeTensorsHeader) metadataHash() string {
|
|||||||
return fmt.Sprintf("%016x", xxhash.Sum64(b))
|
return fmt.Sprintf("%016x", xxhash.Sum64(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// userMetadataKeyValues converts the safetensors __metadata__ map into a
|
||||||
|
// KeyValues slice sorted by key, so the result does not depend on Go's map
|
||||||
|
// iteration order. A nil or empty input returns nil rather than an empty
|
||||||
|
// (length-0, non-nil) KeyValues; the field it populates is tagged
|
||||||
|
// `omitempty`, so JSON serialization omits it either way.
|
||||||
|
func userMetadataKeyValues(m map[string]string) pkg.KeyValues {
|
||||||
|
if len(m) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
keys := make([]string, 0, len(m))
|
||||||
|
for k := range m {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
out := make(pkg.KeyValues, 0, len(keys))
|
||||||
|
for _, k := range keys {
|
||||||
|
out = append(out, pkg.KeyValue{Key: k, Value: m[k]})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// normalizeDType maps a safetensors/torch dtype label to an uppercase quantization
|
// normalizeDType maps a safetensors/torch dtype label to an uppercase quantization
|
||||||
// shorthand matching conventions used elsewhere in syft (e.g., BF16, F16, I8).
|
// shorthand matching conventions used elsewhere in syft (e.g., BF16, F16, I8).
|
||||||
func normalizeDType(dtype string) string {
|
func normalizeDType(dtype string) string {
|
||||||
|
|||||||
@ -38,7 +38,7 @@ func parseSafeTensorsFile(_ context.Context, resolver file.Resolver, _ *generic.
|
|||||||
TensorCount: uint64(len(header.tensors)),
|
TensorCount: uint64(len(header.tensors)),
|
||||||
Quantization: normalizeDType(header.dominantDType()),
|
Quantization: normalizeDType(header.dominantDType()),
|
||||||
ShardCount: 1,
|
ShardCount: 1,
|
||||||
UserMetadata: header.metadata,
|
UserMetadata: userMetadataKeyValues(header.metadata),
|
||||||
MetadataHash: header.metadataHash(),
|
MetadataHash: header.metadataHash(),
|
||||||
}
|
}
|
||||||
if p := header.parameterCount(); p > 0 {
|
if p := header.parameterCount(); p > 0 {
|
||||||
|
|||||||
@ -261,7 +261,7 @@ func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Env
|
|||||||
Format: "safetensors",
|
Format: "safetensors",
|
||||||
TensorCount: uint64(len(header.tensors)),
|
TensorCount: uint64(len(header.tensors)),
|
||||||
Quantization: normalizeDType(header.dominantDType()),
|
Quantization: normalizeDType(header.dominantDType()),
|
||||||
UserMetadata: header.metadata,
|
UserMetadata: userMetadataKeyValues(header.metadata),
|
||||||
MetadataHash: header.metadataHash(),
|
MetadataHash: header.metadataHash(),
|
||||||
}
|
}
|
||||||
if p := header.parameterCount(); p > 0 {
|
if p := header.parameterCount(); p > 0 {
|
||||||
|
|||||||
@ -74,7 +74,7 @@ func TestSafeTensorsCataloger_singleFile(t *testing.T) {
|
|||||||
TorchDtype: "bfloat16",
|
TorchDtype: "bfloat16",
|
||||||
TransformersVersion: "4.40.0",
|
TransformersVersion: "4.40.0",
|
||||||
ShardCount: 1,
|
ShardCount: 1,
|
||||||
UserMetadata: userMeta,
|
UserMetadata: pkg.KeyValues{{Key: "format", Value: "pt"}},
|
||||||
MetadataHash: wantHash,
|
MetadataHash: wantHash,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -276,6 +276,7 @@ func TestParseSafeTensorsOCILayer(t *testing.T) {
|
|||||||
"layer.1.weight": {DType: "BF16", Shape: []int64{16, 16}, DataOffsets: []int64{32768, 33280}},
|
"layer.1.weight": {DType: "BF16", Shape: []int64{16, 16}, DataOffsets: []int64{32768, 33280}},
|
||||||
}
|
}
|
||||||
userMeta := map[string]string{"format": "pt"}
|
userMeta := map[string]string{"format": "pt"}
|
||||||
|
wantUserMetadata := pkg.KeyValues{{Key: "format", Value: "pt"}}
|
||||||
blob := buildSafeTensorsFile(t, userMeta, tensors)
|
blob := buildSafeTensorsFile(t, userMeta, tensors)
|
||||||
wantHash := (&safeTensorsHeader{metadata: userMeta, tensors: tensors}).metadataHash()
|
wantHash := (&safeTensorsHeader{metadata: userMeta, tensors: tensors}).metadataHash()
|
||||||
|
|
||||||
@ -291,7 +292,7 @@ func TestParseSafeTensorsOCILayer(t *testing.T) {
|
|||||||
assert.Equal(t, "safetensors", md.Format)
|
assert.Equal(t, "safetensors", md.Format)
|
||||||
assert.Equal(t, uint64(2), md.TensorCount)
|
assert.Equal(t, uint64(2), md.TensorCount)
|
||||||
assert.Equal(t, "BF16", md.Quantization)
|
assert.Equal(t, "BF16", md.Quantization)
|
||||||
assert.Equal(t, userMeta, md.UserMetadata)
|
assert.Equal(t, wantUserMetadata, md.UserMetadata)
|
||||||
assert.Equal(t, wantHash, md.MetadataHash)
|
assert.Equal(t, wantHash, md.MetadataHash)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -327,7 +328,7 @@ func TestParseSafeTensorsOCILayer(t *testing.T) {
|
|||||||
// MetadataHash is cleared on absorbed parts by the existing merge processor.
|
// MetadataHash is cleared on absorbed parts by the existing merge processor.
|
||||||
// What survives is the rest of the per-shard metadata (UserMetadata, TensorCount,
|
// What survives is the rest of the per-shard metadata (UserMetadata, TensorCount,
|
||||||
// header-derived Quantization). Confirm those are intact.
|
// header-derived Quantization). Confirm those are intact.
|
||||||
assert.Equal(t, userMeta, md.Parts[0].UserMetadata)
|
assert.Equal(t, wantUserMetadata, md.Parts[0].UserMetadata)
|
||||||
assert.Equal(t, uint64(2), md.Parts[0].TensorCount)
|
assert.Equal(t, uint64(2), md.Parts[0].TensorCount)
|
||||||
assert.Equal(t, "BF16", md.Parts[0].Quantization, "part keeps the normalized header dtype")
|
assert.Equal(t, "BF16", md.Parts[0].Quantization, "part keeps the normalized header dtype")
|
||||||
})
|
})
|
||||||
@ -357,7 +358,7 @@ func TestParseSafeTensorsOCILayer_realFixture(t *testing.T) {
|
|||||||
assert.Equal(t, uint64(148), md.TensorCount, "nomic-embed-v2-moe 475M ships 148 tensor entries in this shard")
|
assert.Equal(t, uint64(148), md.TensorCount, "nomic-embed-v2-moe 475M ships 148 tensor entries in this shard")
|
||||||
assert.Equal(t, "F32", md.Quantization, "every tensor in the captured shard is F32")
|
assert.Equal(t, "F32", md.Quantization, "every tensor in the captured shard is F32")
|
||||||
assert.Equal(t, "475.29M", md.Parameters)
|
assert.Equal(t, "475.29M", md.Parameters)
|
||||||
assert.Equal(t, map[string]string{"format": "pt"}, md.UserMetadata)
|
assert.Equal(t, pkg.KeyValues{{Key: "format", Value: "pt"}}, md.UserMetadata)
|
||||||
// MetadataHash is locked to the exact value the parser produces for this
|
// MetadataHash is locked to the exact value the parser produces for this
|
||||||
// captured input. The fixture is immutable on disk; if this value changes
|
// captured input. The fixture is immutable on disk; if this value changes
|
||||||
// either the hash algorithm or the canonicalization changed, both of which
|
// either the hash algorithm or the canonicalization changed, both of which
|
||||||
|
|||||||
@ -46,8 +46,9 @@ type SafeTensorsModelInfo struct {
|
|||||||
ShardCount int `json:"shardCount,omitempty" cyclonedx:"shardCount"`
|
ShardCount int `json:"shardCount,omitempty" cyclonedx:"shardCount"`
|
||||||
|
|
||||||
// UserMetadata is the optional "__metadata__" map from a .safetensors file header
|
// UserMetadata is the optional "__metadata__" map from a .safetensors file header
|
||||||
// (string-to-string key/values set by the producer).
|
// (string-to-string key/values set by the producer). Stored as a sorted KeyValues
|
||||||
UserMetadata map[string]string `json:"userMetadata,omitempty" cyclonedx:"userMetadata"`
|
// slice rather than a Go map so SBOM output is stable across runs.
|
||||||
|
UserMetadata KeyValues `json:"userMetadata,omitempty" cyclonedx:"userMetadata"`
|
||||||
|
|
||||||
// MetadataHash is an xxhash of the normalized header metadata, providing a stable
|
// MetadataHash is an xxhash of the normalized header metadata, providing a stable
|
||||||
// identifier for identical model content across repositories or filenames.
|
// identifier for identical model content across repositories or filenames.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user