fix: make MetadataHash consistent across OCI and directory sources

Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
Christopher Phillips 2026-05-28 13:06:40 -04:00
parent d12cf9a3e2
commit 69b7c5e3d0
No known key found for this signature in database
3 changed files with 55 additions and 19 deletions

View File

@ -225,35 +225,53 @@ func TestSafeTensorsMergeProcessor(t *testing.T) {
named := pkg.Package{Name: "model-a", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "aaaa"}} named := pkg.Package{Name: "model-a", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "aaaa"}}
nameless := pkg.Package{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "bbbb"}} nameless := pkg.Package{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "bbbb"}}
t.Run("merges nameless into named parts", func(t *testing.T) { t.Run("preserves the part's MetadataHash when the named package already has one", func(t *testing.T) {
out, _, err := safeTensorsMergeProcessor([]pkg.Package{named, nameless}, nil, nil) out, _, err := safeTensorsMergeProcessor([]pkg.Package{named, nameless}, nil, nil)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, out, 1) require.Len(t, out, 1)
assert.Equal(t, "model-a", out[0].Name) assert.Equal(t, "model-a", out[0].Name)
md := out[0].Metadata.(pkg.SafeTensorsModelInfo) md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
require.Len(t, md.Parts, 1) require.Len(t, md.Parts, 1)
assert.Empty(t, md.Parts[0].MetadataHash, "nameless part hash should be cleared") assert.Equal(t, "bbbb", md.Parts[0].MetadataHash, "part hash must survive: it is the cross-source fingerprint")
assert.Equal(t, "aaaa", md.MetadataHash, "named package's own hash is not overwritten")
assert.Equal(t, 1, md.ShardCount) assert.Equal(t, 1, md.ShardCount)
}) })
t.Run("sets ShardCount from absorbed parts", func(t *testing.T) { t.Run("lifts the single part's MetadataHash to top-level when named has none", func(t *testing.T) {
// This is the OCI single-shard shape: the config-blob parser produces a
// named package with no hash; the weight-layer parser produces a nameless
// part with the real header hash. Top-level should land in the same field
// a dir-scan single-file would populate, so callers can correlate them.
namedNoHash := pkg.Package{Name: "model-b", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors"}}
part := pkg.Package{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "deadbeef"}}
out, _, err := safeTensorsMergeProcessor([]pkg.Package{namedNoHash, part}, nil, nil)
require.NoError(t, err)
require.Len(t, out, 1)
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, "deadbeef", md.MetadataHash, "single-shard lift makes OCI top-level match dir-scan top-level")
require.Len(t, md.Parts, 1)
assert.Equal(t, "deadbeef", md.Parts[0].MetadataHash, "part also retains its hash")
})
t.Run("multi-shard preserves per-part hashes and sorts deterministically", func(t *testing.T) {
// Three nameless layer packages absorbed into one named config-derived package. // Three nameless layer packages absorbed into one named config-derived package.
// Top-level MetadataHash stays empty (no canonical single hash for a sharded
// model — callers must combine the per-shard hashes themselves).
namedNoHash := pkg.Package{Name: "model-c", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors"}}
parts := []pkg.Package{ parts := []pkg.Package{
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "cccc"}}, {Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "cccc"}},
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "aaaa"}}, {Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "aaaa"}},
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "bbbb"}}, {Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "bbbb"}},
} }
out, _, err := safeTensorsMergeProcessor(append([]pkg.Package{named}, parts...), nil, nil) out, _, err := safeTensorsMergeProcessor(append([]pkg.Package{namedNoHash}, parts...), nil, nil)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, out, 1) require.Len(t, out, 1)
md := out[0].Metadata.(pkg.SafeTensorsModelInfo) md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, 3, md.ShardCount) assert.Equal(t, 3, md.ShardCount)
assert.Empty(t, md.MetadataHash, "multi-shard leaves top-level hash unset")
require.Len(t, md.Parts, 3) require.Len(t, md.Parts, 3)
// Hashes are cleared on absorbed parts, so sort order is deterministic ("" repeated). // Parts sorted by MetadataHash for deterministic SBOM output regardless of resolver order.
// The non-deterministic resolver order should not surface here either way. assert.Equal(t, []string{"aaaa", "bbbb", "cccc"}, []string{md.Parts[0].MetadataHash, md.Parts[1].MetadataHash, md.Parts[2].MetadataHash})
for _, p := range md.Parts {
assert.Empty(t, p.MetadataHash)
}
}) })
t.Run("drops result when no named package", func(t *testing.T) { t.Run("drops result when no named package", func(t *testing.T) {
@ -323,11 +341,12 @@ func TestParseSafeTensorsOCILayer(t *testing.T) {
// Producer-declared top-level fields are preserved. // Producer-declared top-level fields are preserved.
assert.Equal(t, "Qwen3ForCausalLM", md.Architecture) assert.Equal(t, "Qwen3ForCausalLM", md.Architecture)
assert.Equal(t, "Q4_K_M", md.Quantization) assert.Equal(t, "Q4_K_M", md.Quantization)
// The header-derived hash lives in Parts so callers can compare against a dir scan. // Single-shard: the header-derived MetadataHash is lifted to top-level so
// it matches the field a dir-scan would populate.
assert.Equal(t, wantHash, md.MetadataHash, "single-shard OCI scan must expose the hash at the same field as a dir scan")
// The full per-shard breakdown is also preserved under Parts.
require.Len(t, md.Parts, 1) require.Len(t, md.Parts, 1)
// MetadataHash is cleared on absorbed parts by the existing merge processor. assert.Equal(t, wantHash, md.Parts[0].MetadataHash)
// What survives is the rest of the per-shard metadata (UserMetadata, TensorCount,
// header-derived Quantization). Confirm those are intact.
assert.Equal(t, wantUserMetadata, md.Parts[0].UserMetadata) assert.Equal(t, wantUserMetadata, md.Parts[0].UserMetadata)
assert.Equal(t, uint64(2), md.Parts[0].TensorCount) assert.Equal(t, uint64(2), md.Parts[0].TensorCount)
assert.Equal(t, "BF16", md.Parts[0].Quantization, "part keeps the normalized header dtype") assert.Equal(t, "BF16", md.Parts[0].Quantization, "part keeps the normalized header dtype")

View File

@ -62,9 +62,16 @@ func ggufMergeProcessor(pkgs []pkg.Package, rels []artifact.Relationship, err er
// safeTensorsMergeProcessor mirrors ggufMergeProcessor for SafeTensors packages. // safeTensorsMergeProcessor mirrors ggufMergeProcessor for SafeTensors packages.
// When scanning an OCI AI artifact, the model-config blob produces one named // When scanning an OCI AI artifact, the model-config blob produces one named
// package and individual .safetensors shard layers (if we ever decide to parse // package and each safetensors weight layer produces a nameless package. The
// them directly) would produce nameless packages. Any nameless SafeTensors // nameless packages are absorbed into the named one's Parts slice.
// packages are collapsed into the named one's Parts slice. //
// MetadataHash is intentionally preserved on absorbed parts: it is derived
// purely from the on-disk safetensors header (see SafeTensorsModelInfo doc),
// so it acts as the cross-source content fingerprint. For a single-shard
// model we also copy it up to the named package's top-level MetadataHash so
// that an OCI scan and a directory scan of the same single .safetensors file
// expose the hash at the same field — `md.MetadataHash` — without callers
// having to inspect Parts.
func safeTensorsMergeProcessor(pkgs []pkg.Package, rels []artifact.Relationship, err error) ([]pkg.Package, []artifact.Relationship, error) { func safeTensorsMergeProcessor(pkgs []pkg.Package, rels []artifact.Relationship, err error) ([]pkg.Package, []artifact.Relationship, error) {
if err != nil { if err != nil {
return pkgs, rels, err return pkgs, rels, err
@ -81,7 +88,6 @@ func safeTensorsMergeProcessor(pkgs []pkg.Package, rels []artifact.Relationship,
continue continue
} }
if md, ok := p.Metadata.(pkg.SafeTensorsModelInfo); ok { if md, ok := p.Metadata.(pkg.SafeTensorsModelInfo); ok {
md.MetadataHash = ""
namelessParts = append(namelessParts, md) namelessParts = append(namelessParts, md)
} }
} }
@ -101,6 +107,14 @@ func safeTensorsMergeProcessor(pkgs []pkg.Package, rels []artifact.Relationship,
md.Parts = namelessParts md.Parts = namelessParts
// Trust per-shard headers over the producer-declared shard count. // Trust per-shard headers over the producer-declared shard count.
md.ShardCount = len(namelessParts) md.ShardCount = len(namelessParts)
// Single-shard: lift the part's content fingerprint to the top
// level so the field placement matches a dir-scan single file.
// Only lift when the named package has no hash of its own (the
// OCI config-blob parser never sets one; dir-scan paths never
// produce nameless parts, so they don't reach this branch).
if len(namelessParts) == 1 && md.MetadataHash == "" {
md.MetadataHash = namelessParts[0].MetadataHash
}
winner.Metadata = md winner.Metadata = md
} }
} }

View File

@ -50,8 +50,11 @@ type SafeTensorsModelInfo struct {
// slice rather than a Go map so SBOM output is stable across runs. // slice rather than a Go map so SBOM output is stable across runs.
UserMetadata KeyValues `json:"userMetadata,omitempty" cyclonedx:"userMetadata"` UserMetadata KeyValues `json:"userMetadata,omitempty" cyclonedx:"userMetadata"`
// MetadataHash is an xxhash of the normalized header metadata, providing a stable // MetadataHash is an xxhash over the on-disk SafeTensors header (sorted tensor
// identifier for identical model content across repositories or filenames. // entries + __metadata__). It is derived ONLY from the safetensors file bytes —
// never from OCI manifest, layer descriptor, or config-blob fields — so the same
// model content scanned via a directory source and via an OCI image produces the
// same value. Treat this as the cross-source content fingerprint.
MetadataHash string `json:"metadataHash,omitempty" cyclonedx:"metadataHash"` MetadataHash string `json:"metadataHash,omitempty" cyclonedx:"metadataHash"`
// Parts contains metadata from additional SafeTensors shards or OCI layers that // Parts contains metadata from additional SafeTensors shards or OCI layers that