fix: allow both dir/oci paths to parse safetensor files

Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
Christopher Phillips 2026-05-28 12:00:49 -04:00
parent 58b6f5807e
commit 324fecf4a4
No known key found for this signature in database
6 changed files with 227 additions and 15 deletions

View File

@ -26,15 +26,19 @@ func NewGGUFCataloger() pkg.Cataloger {
} }
// NewSafeTensorsCataloger returns a cataloger for SafeTensors model files, // NewSafeTensorsCataloger returns a cataloger for SafeTensors model files,
// covering three discovery paths: // covering four discovery paths:
// - **/*.safetensors files (single-file models; header-only parse) // - **/*.safetensors files (single-file models; header-only parse)
// - **/model.safetensors.index.json files (sharded models) // - **/model.safetensors.index.json files (sharded models)
// - application/vnd.docker.ai.model.config.v0.1+json / v0.2+json OCI layers // - application/vnd.docker.ai.model.config.v0.1+json / v0.2+json OCI layers
// (Docker Model Runner artifacts whose config advertises format=="safetensors") // (Docker Model Runner artifacts whose config advertises format=="safetensors")
// - application/vnd.docker.ai.safetensors OCI layers (per-shard JSON headers,
// fetched as a prefix by the OCI model source; emitted as nameless
// packages and merged into the config-derived package as Parts)
func NewSafeTensorsCataloger() pkg.Cataloger { func NewSafeTensorsCataloger() pkg.Cataloger {
return generic.NewCataloger(safeTensorsCatalogerName). return generic.NewCataloger(safeTensorsCatalogerName).
WithParserByGlobs(parseSafeTensorsFile, "**/*.safetensors"). WithParserByGlobs(parseSafeTensorsFile, "**/*.safetensors").
WithParserByGlobs(parseSafeTensorsIndex, "**/*.safetensors.index.json"). WithParserByGlobs(parseSafeTensorsIndex, "**/*.safetensors.index.json").
WithParserByMediaType(parseSafeTensorsOCIConfig, dockerAIModelConfigMediaTypes...). WithParserByMediaType(parseSafeTensorsOCIConfig, dockerAIModelConfigMediaTypes...).
WithParserByMediaType(parseSafeTensorsOCILayer, dockerAISafeTensorsMediaType).
WithProcessors(safeTensorsMergeProcessor) WithProcessors(safeTensorsMergeProcessor)
} }

View File

@ -20,6 +20,7 @@ import (
const ( const (
dockerAIModelFileMediaType = "application/vnd.docker.ai.model.file" dockerAIModelFileMediaType = "application/vnd.docker.ai.model.file"
dockerAILicenseMediaType = "application/vnd.docker.ai.license" dockerAILicenseMediaType = "application/vnd.docker.ai.license"
dockerAISafeTensorsMediaType = "application/vnd.docker.ai.safetensors"
) )
// dockerAIModelConfigMediaTypes are the model-config schema versions this // dockerAIModelConfigMediaTypes are the model-config schema versions this
@ -241,5 +242,48 @@ func lastPathSegment(s string) string {
return s return s
} }
// integrity check // parseSafeTensorsOCILayer parses a SafeTensors weight layer from an OCI model
var _ generic.Parser = parseSafeTensorsOCIConfig // artifact by reading only its JSON header (the layer is fetched up to a small
// byte cap by the source layer; tensor data is never downloaded). It emits a
// nameless package so safeTensorsMergeProcessor folds the result into the
// config-derived named package as a Part. The point of this parser is to give
// OCI scans the same content-derived fields the directory-scan path produces:
// real tensor count, normalized quantization, __metadata__, and MetadataHash.
func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
defer internal.CloseAndLogError(reader, reader.Path())
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
if err != nil {
return nil, nil, fmt.Errorf("failed to read safetensors layer header: %w", err)
}
md := pkg.SafeTensorsModelInfo{
Format: "safetensors",
TensorCount: uint64(len(header.tensors)),
Quantization: normalizeDType(header.dominantDType()),
UserMetadata: header.metadata,
MetadataHash: header.metadataHash(),
}
if p := header.parameterCount(); p > 0 {
md.Parameters = formatParameterCount(p)
}
// Emit nameless; safeTensorsMergeProcessor will absorb this into the
// config-derived named package as a Part. The merge runs even when only
// nameless packages exist, in which case the result is dropped.
p := newSafeTensorsPackage(
&md,
"",
"",
"",
reader.WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation),
)
return []pkg.Package{p}, nil, nil
}
// integrity checks
var (
_ generic.Parser = parseSafeTensorsOCIConfig
_ generic.Parser = parseSafeTensorsOCILayer
)

View File

@ -233,6 +233,27 @@ func TestSafeTensorsMergeProcessor(t *testing.T) {
md := out[0].Metadata.(pkg.SafeTensorsModelInfo) md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
require.Len(t, md.Parts, 1) require.Len(t, md.Parts, 1)
assert.Empty(t, md.Parts[0].MetadataHash, "nameless part hash should be cleared") assert.Empty(t, md.Parts[0].MetadataHash, "nameless part hash should be cleared")
assert.Equal(t, 1, md.ShardCount)
})
t.Run("sets ShardCount from absorbed parts", func(t *testing.T) {
// Three nameless layer packages absorbed into one named config-derived package.
parts := []pkg.Package{
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "cccc"}},
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "aaaa"}},
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "bbbb"}},
}
out, _, err := safeTensorsMergeProcessor(append([]pkg.Package{named}, parts...), nil, nil)
require.NoError(t, err)
require.Len(t, out, 1)
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, 3, md.ShardCount)
require.Len(t, md.Parts, 3)
// Hashes are cleared on absorbed parts, so sort order is deterministic ("" repeated).
// The non-deterministic resolver order should not surface here either way.
for _, p := range md.Parts {
assert.Empty(t, p.MetadataHash)
}
}) })
t.Run("drops result when no named package", func(t *testing.T) { t.Run("drops result when no named package", func(t *testing.T) {
@ -249,6 +270,106 @@ func TestSafeTensorsMergeProcessor(t *testing.T) {
}) })
} }
func TestParseSafeTensorsOCILayer(t *testing.T) {
tensors := map[string]safeTensorsEntry{
"layer.0.weight": {DType: "BF16", Shape: []int64{1024, 16}, DataOffsets: []int64{0, 32768}},
"layer.1.weight": {DType: "BF16", Shape: []int64{16, 16}, DataOffsets: []int64{32768, 33280}},
}
userMeta := map[string]string{"format": "pt"}
blob := buildSafeTensorsFile(t, userMeta, tensors)
wantHash := (&safeTensorsHeader{metadata: userMeta, tensors: tensors}).metadataHash()
t.Run("emits a nameless package with header-derived metadata", func(t *testing.T) {
reader := file.NewLocationReadCloser(file.NewLocation("/"), io.NopCloser(bytes.NewReader(blob)))
pkgs, _, err := parseSafeTensorsOCILayer(context.Background(), nil, nil, reader)
require.NoError(t, err)
require.Len(t, pkgs, 1)
p := pkgs[0]
assert.Empty(t, p.Name, "weight-layer parser must emit nameless; the merge processor names it")
md := p.Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, "safetensors", md.Format)
assert.Equal(t, uint64(2), md.TensorCount)
assert.Equal(t, "BF16", md.Quantization)
assert.Equal(t, userMeta, md.UserMetadata)
assert.Equal(t, wantHash, md.MetadataHash)
})
t.Run("merges with config-derived named package and lifts ShardCount", func(t *testing.T) {
// Synthesize what the OCI scan would produce: one config-derived named
// package + one weight-layer derived nameless package. Run them through
// the merge processor and assert the result looks like a complete model.
configMd := pkg.SafeTensorsModelInfo{
Format: "safetensors",
Architecture: "Qwen3ForCausalLM",
Parameters: "2.68B",
TotalSize: "5.00GB",
Quantization: "Q4_K_M", // raw producer string
}
named := pkg.Package{Name: "qwen", Type: pkg.ModelPkg, Metadata: configMd}
reader := file.NewLocationReadCloser(file.NewLocation("/"), io.NopCloser(bytes.NewReader(blob)))
layerPkgs, _, err := parseSafeTensorsOCILayer(context.Background(), nil, nil, reader)
require.NoError(t, err)
require.Len(t, layerPkgs, 1)
out, _, err := safeTensorsMergeProcessor(append([]pkg.Package{named}, layerPkgs...), nil, nil)
require.NoError(t, err)
require.Len(t, out, 1)
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, 1, md.ShardCount, "merge processor should set ShardCount from absorbed parts")
// Producer-declared top-level fields are preserved.
assert.Equal(t, "Qwen3ForCausalLM", md.Architecture)
assert.Equal(t, "Q4_K_M", md.Quantization)
// The header-derived hash lives in Parts so callers can compare against a dir scan.
require.Len(t, md.Parts, 1)
// MetadataHash is cleared on absorbed parts by the existing merge processor.
// What survives is the rest of the per-shard metadata (UserMetadata, TensorCount,
// header-derived Quantization). Confirm those are intact.
assert.Equal(t, userMeta, md.Parts[0].UserMetadata)
assert.Equal(t, uint64(2), md.Parts[0].TensorCount)
assert.Equal(t, "BF16", md.Parts[0].Quantization, "part keeps the normalized header dtype")
})
}
func TestSafeTensorsCrossSourceHashParity(t *testing.T) {
// Same content, two paths: a directory scan via parseSafeTensorsFile, and an
// OCI weight-layer scan via parseSafeTensorsOCILayer. The MetadataHash of
// the dir-scan package must equal the per-shard hash captured before the
// merge processor absorbs it. This is the convergence point that lets a
// caller correlate the two source types.
tensors := map[string]safeTensorsEntry{
"a.weight": {DType: "BF16", Shape: []int64{8, 8}, DataOffsets: []int64{0, 128}},
"b.weight": {DType: "BF16", Shape: []int64{4, 4}, DataOffsets: []int64{128, 160}},
}
userMeta := map[string]string{"format": "pt", "producer": "test"}
blob := buildSafeTensorsFile(t, userMeta, tensors)
// dir-scan path
dir := t.TempDir()
require.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), blob, 0o644))
dirReader := func() file.LocationReadCloser {
f, err := os.Open(filepath.Join(dir, "model.safetensors"))
require.NoError(t, err)
return file.NewLocationReadCloser(file.NewLocation(filepath.Join(dir, "model.safetensors")), f)
}()
dirPkgs, _, err := parseSafeTensorsFile(context.Background(), nil, nil, dirReader)
require.NoError(t, err)
require.Len(t, dirPkgs, 1)
dirHash := dirPkgs[0].Metadata.(pkg.SafeTensorsModelInfo).MetadataHash
require.NotEmpty(t, dirHash)
// OCI weight-layer path
ociReader := file.NewLocationReadCloser(file.NewLocation("/"), io.NopCloser(bytes.NewReader(blob)))
ociPkgs, _, err := parseSafeTensorsOCILayer(context.Background(), nil, nil, ociReader)
require.NoError(t, err)
require.Len(t, ociPkgs, 1)
ociHash := ociPkgs[0].Metadata.(pkg.SafeTensorsModelInfo).MetadataHash
assert.Equal(t, dirHash, ociHash, "same content via dir scan and OCI weight-layer scan must hash equal")
}
func configReader(blob []byte) file.LocationReadCloser { func configReader(blob []byte) file.LocationReadCloser {
return file.NewLocationReadCloser(file.NewLocation("/config.json"), io.NopCloser(bytes.NewReader(blob))) return file.NewLocationReadCloser(file.NewLocation("/config.json"), io.NopCloser(bytes.NewReader(blob)))
} }

View File

@ -1,6 +1,8 @@
package ai package ai
import ( import (
"sort"
"github.com/anchore/syft/syft/artifact" "github.com/anchore/syft/syft/artifact"
"github.com/anchore/syft/syft/pkg" "github.com/anchore/syft/syft/pkg"
) )
@ -89,9 +91,16 @@ func safeTensorsMergeProcessor(pkgs []pkg.Package, rels []artifact.Relationship,
} }
if len(namedPkgs) == 1 && len(namelessParts) > 0 { if len(namedPkgs) == 1 && len(namelessParts) > 0 {
// Sort by MetadataHash so OCI layer order (map iteration) doesn't leak
// into the SBOM output.
sort.Slice(namelessParts, func(i, j int) bool {
return namelessParts[i].MetadataHash < namelessParts[j].MetadataHash
})
winner := &namedPkgs[0] winner := &namedPkgs[0]
if md, ok := winner.Metadata.(pkg.SafeTensorsModelInfo); ok { if md, ok := winner.Metadata.(pkg.SafeTensorsModelInfo); ok {
md.Parts = namelessParts md.Parts = namelessParts
// Trust per-shard headers over the producer-declared shard count.
md.ShardCount = len(namelessParts)
winner.Metadata = md winner.Metadata = md
} }
} }

View File

@ -88,12 +88,12 @@ func validateAndFetchArtifact(ctx context.Context, client *registryClient, refer
// model artifact and stores them on disk so the ContainerImageModel resolver // model artifact and stores them on disk so the ContainerImageModel resolver
// can serve them by media type: // can serve them by media type:
// //
// - For GGUF: the first maxHeaderBytes of each weight layer (existing behavior). // - For GGUF: the first maxHeaderBytes of each weight layer.
// - For SafeTensors: the model-config blob (already in memory as RawConfig) // - For SafeTensors: the model-config blob (already in memory as RawConfig),
// plus each companion layer in full. We deliberately skip the multi-GB // each companion layer in full, and the first maxHeaderBytes of each
// safetensors weight layers — the config blob carries aggregate metadata // weight layer. The weight-layer prefix is enough to read the JSON header
// (format, quantization, parameter count, tensor count, total size) that // (tensor map + __metadata__), which is what the cataloger hashes for
// the cataloger needs, and individual shard headers are not yet used. // cross-source identity.
func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, artifact *modelArtifact) (string, *fileresolver.ContainerImageModel, error) { func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, artifact *modelArtifact) (string, *fileresolver.ContainerImageModel, error) {
tempDir, err := os.MkdirTemp("", "syft-oci-model") tempDir, err := os.MkdirTemp("", "syft-oci-model")
if err != nil { if err != nil {
@ -108,7 +108,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti
layerFiles := make(map[string]fileresolver.LayerInfo) layerFiles := make(map[string]fileresolver.LayerInfo)
// GGUF weight-layer headers (unchanged). // GGUF weight-layer headers.
for _, layer := range artifact.GGUFLayers { for _, layer := range artifact.GGUFLayers {
li, err := fetchSingleGGUFHeader(ctx, client, artifact.Reference, layer, tempDir) li, err := fetchSingleGGUFHeader(ctx, client, artifact.Reference, layer, tempDir)
if err != nil { if err != nil {
@ -143,6 +143,21 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti
} }
} }
// SafeTensors weight-layer headers. We only pull the leading prefix (same
// budget as a GGUF header) because the JSON header lives at the very start
// of the file — multi-GB tensor data that follows is intentionally not
// downloaded.
if artifact.Format == modelFormatSafeTensors {
for _, layer := range artifact.SafeTensorsLayers {
li, err := fetchSafeTensorsLayerHeader(ctx, client, artifact.Reference, layer, tempDir)
if err != nil {
cleanup()
return "", nil, err
}
layerFiles[layer.Digest.String()] = li
}
}
resolver := fileresolver.NewContainerImageModel(tempDir, layerFiles) resolver := fileresolver.NewContainerImageModel(tempDir, layerFiles)
return tempDir, resolver, nil return tempDir, resolver, nil
@ -202,6 +217,26 @@ func fetchSingleGGUFHeader(ctx context.Context, client *registryClient, ref name
}, nil }, nil
} }
// fetchSafeTensorsLayerHeader fetches the leading bytes of a SafeTensors weight
// layer (enough to cover the JSON header) and writes them to a temp file.
func fetchSafeTensorsLayerHeader(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) {
headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxHeaderBytes)
if err != nil {
return fileresolver.LayerInfo{}, fmt.Errorf("failed to fetch safetensors layer header: %w", err)
}
safeDigest := strings.ReplaceAll(layer.Digest.String(), ":", "-")
tempPath := filepath.Join(tempDir, safeDigest+".safetensors")
if err := os.WriteFile(tempPath, headerData, 0600); err != nil {
return fileresolver.LayerInfo{}, fmt.Errorf("failed to write temp file: %w", err)
}
return fileresolver.LayerInfo{
TempPath: tempPath,
MediaType: string(layer.MediaType),
}, nil
}
// buildMetadata constructs OCIModelMetadata from a modelArtifact. // buildMetadata constructs OCIModelMetadata from a modelArtifact.
func buildMetadata(artifact *modelArtifact) source.OCIModelMetadata { func buildMetadata(artifact *modelArtifact) source.OCIModelMetadata {
// layers // layers

View File

@ -133,9 +133,8 @@ type modelArtifact struct {
GGUFLayers []v1.Descriptor GGUFLayers []v1.Descriptor
// SafeTensorsLayers are descriptors for layers carrying SafeTensors-format weights. // SafeTensorsLayers are descriptors for layers carrying SafeTensors-format weights.
// For safetensors we do NOT fetch these layers — the model-config blob already // We fetch the first maxHeaderBytes of each so the cataloger can read the JSON
// contains the aggregate metadata we need — but we record them here for counting // header (tensor map + __metadata__) without pulling the multi-GB tensor data.
// and for future per-shard parsing.
SafeTensorsLayers []v1.Descriptor SafeTensorsLayers []v1.Descriptor
// CompanionLayers are non-weight layers (README, config.json, license) that // CompanionLayers are non-weight layers (README, config.json, license) that