mirror of
https://github.com/anchore/syft.git
synced 2026-07-05 02:28:25 +02:00
fix: allow both dir/oci paths to parse safetensor files
Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
parent
58b6f5807e
commit
324fecf4a4
@ -26,15 +26,19 @@ func NewGGUFCataloger() pkg.Cataloger {
|
||||
}
|
||||
|
||||
// NewSafeTensorsCataloger returns a cataloger for SafeTensors model files,
|
||||
// covering three discovery paths:
|
||||
// covering four discovery paths:
|
||||
// - **/*.safetensors files (single-file models; header-only parse)
|
||||
// - **/model.safetensors.index.json files (sharded models)
|
||||
// - application/vnd.docker.ai.model.config.v0.1+json / v0.2+json OCI layers
|
||||
// (Docker Model Runner artifacts whose config advertises format=="safetensors")
|
||||
// - application/vnd.docker.ai.safetensors OCI layers (per-shard JSON headers,
|
||||
// fetched as a prefix by the OCI model source; emitted as nameless
|
||||
// packages and merged into the config-derived package as Parts)
|
||||
func NewSafeTensorsCataloger() pkg.Cataloger {
|
||||
return generic.NewCataloger(safeTensorsCatalogerName).
|
||||
WithParserByGlobs(parseSafeTensorsFile, "**/*.safetensors").
|
||||
WithParserByGlobs(parseSafeTensorsIndex, "**/*.safetensors.index.json").
|
||||
WithParserByMediaType(parseSafeTensorsOCIConfig, dockerAIModelConfigMediaTypes...).
|
||||
WithParserByMediaType(parseSafeTensorsOCILayer, dockerAISafeTensorsMediaType).
|
||||
WithProcessors(safeTensorsMergeProcessor)
|
||||
}
|
||||
|
||||
@ -18,8 +18,9 @@ import (
|
||||
|
||||
// Docker AI OCI media types used by Docker Model Runner artifacts.
|
||||
const (
|
||||
dockerAIModelFileMediaType = "application/vnd.docker.ai.model.file"
|
||||
dockerAILicenseMediaType = "application/vnd.docker.ai.license"
|
||||
dockerAIModelFileMediaType = "application/vnd.docker.ai.model.file"
|
||||
dockerAILicenseMediaType = "application/vnd.docker.ai.license"
|
||||
dockerAISafeTensorsMediaType = "application/vnd.docker.ai.safetensors"
|
||||
)
|
||||
|
||||
// dockerAIModelConfigMediaTypes are the model-config schema versions this
|
||||
@ -241,5 +242,48 @@ func lastPathSegment(s string) string {
|
||||
return s
|
||||
}
|
||||
|
||||
// integrity check
|
||||
var _ generic.Parser = parseSafeTensorsOCIConfig
|
||||
// parseSafeTensorsOCILayer parses a SafeTensors weight layer from an OCI model
|
||||
// artifact by reading only its JSON header (the layer is fetched up to a small
|
||||
// byte cap by the source layer; tensor data is never downloaded). It emits a
|
||||
// nameless package so safeTensorsMergeProcessor folds the result into the
|
||||
// config-derived named package as a Part. The point of this parser is to give
|
||||
// OCI scans the same content-derived fields the directory-scan path produces:
|
||||
// real tensor count, normalized quantization, __metadata__, and MetadataHash.
|
||||
func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
defer internal.CloseAndLogError(reader, reader.Path())
|
||||
|
||||
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to read safetensors layer header: %w", err)
|
||||
}
|
||||
|
||||
md := pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
TensorCount: uint64(len(header.tensors)),
|
||||
Quantization: normalizeDType(header.dominantDType()),
|
||||
UserMetadata: header.metadata,
|
||||
MetadataHash: header.metadataHash(),
|
||||
}
|
||||
if p := header.parameterCount(); p > 0 {
|
||||
md.Parameters = formatParameterCount(p)
|
||||
}
|
||||
|
||||
// Emit nameless; safeTensorsMergeProcessor will absorb this into the
|
||||
// config-derived named package as a Part. The merge runs even when only
|
||||
// nameless packages exist, in which case the result is dropped.
|
||||
p := newSafeTensorsPackage(
|
||||
&md,
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
reader.WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation),
|
||||
)
|
||||
|
||||
return []pkg.Package{p}, nil, nil
|
||||
}
|
||||
|
||||
// integrity checks
|
||||
var (
|
||||
_ generic.Parser = parseSafeTensorsOCIConfig
|
||||
_ generic.Parser = parseSafeTensorsOCILayer
|
||||
)
|
||||
|
||||
@ -233,6 +233,27 @@ func TestSafeTensorsMergeProcessor(t *testing.T) {
|
||||
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
|
||||
require.Len(t, md.Parts, 1)
|
||||
assert.Empty(t, md.Parts[0].MetadataHash, "nameless part hash should be cleared")
|
||||
assert.Equal(t, 1, md.ShardCount)
|
||||
})
|
||||
|
||||
t.Run("sets ShardCount from absorbed parts", func(t *testing.T) {
|
||||
// Three nameless layer packages absorbed into one named config-derived package.
|
||||
parts := []pkg.Package{
|
||||
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "cccc"}},
|
||||
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "aaaa"}},
|
||||
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "bbbb"}},
|
||||
}
|
||||
out, _, err := safeTensorsMergeProcessor(append([]pkg.Package{named}, parts...), nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, 3, md.ShardCount)
|
||||
require.Len(t, md.Parts, 3)
|
||||
// Hashes are cleared on absorbed parts, so sort order is deterministic ("" repeated).
|
||||
// The non-deterministic resolver order should not surface here either way.
|
||||
for _, p := range md.Parts {
|
||||
assert.Empty(t, p.MetadataHash)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("drops result when no named package", func(t *testing.T) {
|
||||
@ -249,6 +270,106 @@ func TestSafeTensorsMergeProcessor(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseSafeTensorsOCILayer(t *testing.T) {
|
||||
tensors := map[string]safeTensorsEntry{
|
||||
"layer.0.weight": {DType: "BF16", Shape: []int64{1024, 16}, DataOffsets: []int64{0, 32768}},
|
||||
"layer.1.weight": {DType: "BF16", Shape: []int64{16, 16}, DataOffsets: []int64{32768, 33280}},
|
||||
}
|
||||
userMeta := map[string]string{"format": "pt"}
|
||||
blob := buildSafeTensorsFile(t, userMeta, tensors)
|
||||
wantHash := (&safeTensorsHeader{metadata: userMeta, tensors: tensors}).metadataHash()
|
||||
|
||||
t.Run("emits a nameless package with header-derived metadata", func(t *testing.T) {
|
||||
reader := file.NewLocationReadCloser(file.NewLocation("/"), io.NopCloser(bytes.NewReader(blob)))
|
||||
pkgs, _, err := parseSafeTensorsOCILayer(context.Background(), nil, nil, reader)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pkgs, 1)
|
||||
|
||||
p := pkgs[0]
|
||||
assert.Empty(t, p.Name, "weight-layer parser must emit nameless; the merge processor names it")
|
||||
md := p.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, "safetensors", md.Format)
|
||||
assert.Equal(t, uint64(2), md.TensorCount)
|
||||
assert.Equal(t, "BF16", md.Quantization)
|
||||
assert.Equal(t, userMeta, md.UserMetadata)
|
||||
assert.Equal(t, wantHash, md.MetadataHash)
|
||||
})
|
||||
|
||||
t.Run("merges with config-derived named package and lifts ShardCount", func(t *testing.T) {
|
||||
// Synthesize what the OCI scan would produce: one config-derived named
|
||||
// package + one weight-layer derived nameless package. Run them through
|
||||
// the merge processor and assert the result looks like a complete model.
|
||||
configMd := pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
Architecture: "Qwen3ForCausalLM",
|
||||
Parameters: "2.68B",
|
||||
TotalSize: "5.00GB",
|
||||
Quantization: "Q4_K_M", // raw producer string
|
||||
}
|
||||
named := pkg.Package{Name: "qwen", Type: pkg.ModelPkg, Metadata: configMd}
|
||||
|
||||
reader := file.NewLocationReadCloser(file.NewLocation("/"), io.NopCloser(bytes.NewReader(blob)))
|
||||
layerPkgs, _, err := parseSafeTensorsOCILayer(context.Background(), nil, nil, reader)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, layerPkgs, 1)
|
||||
|
||||
out, _, err := safeTensorsMergeProcessor(append([]pkg.Package{named}, layerPkgs...), nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
|
||||
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, 1, md.ShardCount, "merge processor should set ShardCount from absorbed parts")
|
||||
// Producer-declared top-level fields are preserved.
|
||||
assert.Equal(t, "Qwen3ForCausalLM", md.Architecture)
|
||||
assert.Equal(t, "Q4_K_M", md.Quantization)
|
||||
// The header-derived hash lives in Parts so callers can compare against a dir scan.
|
||||
require.Len(t, md.Parts, 1)
|
||||
// MetadataHash is cleared on absorbed parts by the existing merge processor.
|
||||
// What survives is the rest of the per-shard metadata (UserMetadata, TensorCount,
|
||||
// header-derived Quantization). Confirm those are intact.
|
||||
assert.Equal(t, userMeta, md.Parts[0].UserMetadata)
|
||||
assert.Equal(t, uint64(2), md.Parts[0].TensorCount)
|
||||
assert.Equal(t, "BF16", md.Parts[0].Quantization, "part keeps the normalized header dtype")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSafeTensorsCrossSourceHashParity(t *testing.T) {
|
||||
// Same content, two paths: a directory scan via parseSafeTensorsFile, and an
|
||||
// OCI weight-layer scan via parseSafeTensorsOCILayer. The MetadataHash of
|
||||
// the dir-scan package must equal the per-shard hash captured before the
|
||||
// merge processor absorbs it. This is the convergence point that lets a
|
||||
// caller correlate the two source types.
|
||||
tensors := map[string]safeTensorsEntry{
|
||||
"a.weight": {DType: "BF16", Shape: []int64{8, 8}, DataOffsets: []int64{0, 128}},
|
||||
"b.weight": {DType: "BF16", Shape: []int64{4, 4}, DataOffsets: []int64{128, 160}},
|
||||
}
|
||||
userMeta := map[string]string{"format": "pt", "producer": "test"}
|
||||
blob := buildSafeTensorsFile(t, userMeta, tensors)
|
||||
|
||||
// dir-scan path
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), blob, 0o644))
|
||||
dirReader := func() file.LocationReadCloser {
|
||||
f, err := os.Open(filepath.Join(dir, "model.safetensors"))
|
||||
require.NoError(t, err)
|
||||
return file.NewLocationReadCloser(file.NewLocation(filepath.Join(dir, "model.safetensors")), f)
|
||||
}()
|
||||
dirPkgs, _, err := parseSafeTensorsFile(context.Background(), nil, nil, dirReader)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, dirPkgs, 1)
|
||||
dirHash := dirPkgs[0].Metadata.(pkg.SafeTensorsModelInfo).MetadataHash
|
||||
require.NotEmpty(t, dirHash)
|
||||
|
||||
// OCI weight-layer path
|
||||
ociReader := file.NewLocationReadCloser(file.NewLocation("/"), io.NopCloser(bytes.NewReader(blob)))
|
||||
ociPkgs, _, err := parseSafeTensorsOCILayer(context.Background(), nil, nil, ociReader)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, ociPkgs, 1)
|
||||
ociHash := ociPkgs[0].Metadata.(pkg.SafeTensorsModelInfo).MetadataHash
|
||||
|
||||
assert.Equal(t, dirHash, ociHash, "same content via dir scan and OCI weight-layer scan must hash equal")
|
||||
}
|
||||
|
||||
func configReader(blob []byte) file.LocationReadCloser {
|
||||
return file.NewLocationReadCloser(file.NewLocation("/config.json"), io.NopCloser(bytes.NewReader(blob)))
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/anchore/syft/syft/artifact"
|
||||
"github.com/anchore/syft/syft/pkg"
|
||||
)
|
||||
@ -89,9 +91,16 @@ func safeTensorsMergeProcessor(pkgs []pkg.Package, rels []artifact.Relationship,
|
||||
}
|
||||
|
||||
if len(namedPkgs) == 1 && len(namelessParts) > 0 {
|
||||
// Sort by MetadataHash so OCI layer order (map iteration) doesn't leak
|
||||
// into the SBOM output.
|
||||
sort.Slice(namelessParts, func(i, j int) bool {
|
||||
return namelessParts[i].MetadataHash < namelessParts[j].MetadataHash
|
||||
})
|
||||
winner := &namedPkgs[0]
|
||||
if md, ok := winner.Metadata.(pkg.SafeTensorsModelInfo); ok {
|
||||
md.Parts = namelessParts
|
||||
// Trust per-shard headers over the producer-declared shard count.
|
||||
md.ShardCount = len(namelessParts)
|
||||
winner.Metadata = md
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,12 +88,12 @@ func validateAndFetchArtifact(ctx context.Context, client *registryClient, refer
|
||||
// model artifact and stores them on disk so the ContainerImageModel resolver
|
||||
// can serve them by media type:
|
||||
//
|
||||
// - For GGUF: the first maxHeaderBytes of each weight layer (existing behavior).
|
||||
// - For SafeTensors: the model-config blob (already in memory as RawConfig)
|
||||
// plus each companion layer in full. We deliberately skip the multi-GB
|
||||
// safetensors weight layers — the config blob carries aggregate metadata
|
||||
// (format, quantization, parameter count, tensor count, total size) that
|
||||
// the cataloger needs, and individual shard headers are not yet used.
|
||||
// - For GGUF: the first maxHeaderBytes of each weight layer.
|
||||
// - For SafeTensors: the model-config blob (already in memory as RawConfig),
|
||||
// each companion layer in full, and the first maxHeaderBytes of each
|
||||
// weight layer. The weight-layer prefix is enough to read the JSON header
|
||||
// (tensor map + __metadata__), which is what the cataloger hashes for
|
||||
// cross-source identity.
|
||||
func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, artifact *modelArtifact) (string, *fileresolver.ContainerImageModel, error) {
|
||||
tempDir, err := os.MkdirTemp("", "syft-oci-model")
|
||||
if err != nil {
|
||||
@ -108,7 +108,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti
|
||||
|
||||
layerFiles := make(map[string]fileresolver.LayerInfo)
|
||||
|
||||
// GGUF weight-layer headers (unchanged).
|
||||
// GGUF weight-layer headers.
|
||||
for _, layer := range artifact.GGUFLayers {
|
||||
li, err := fetchSingleGGUFHeader(ctx, client, artifact.Reference, layer, tempDir)
|
||||
if err != nil {
|
||||
@ -143,6 +143,21 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti
|
||||
}
|
||||
}
|
||||
|
||||
// SafeTensors weight-layer headers. We only pull the leading prefix (same
|
||||
// budget as a GGUF header) because the JSON header lives at the very start
|
||||
// of the file — multi-GB tensor data that follows is intentionally not
|
||||
// downloaded.
|
||||
if artifact.Format == modelFormatSafeTensors {
|
||||
for _, layer := range artifact.SafeTensorsLayers {
|
||||
li, err := fetchSafeTensorsLayerHeader(ctx, client, artifact.Reference, layer, tempDir)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return "", nil, err
|
||||
}
|
||||
layerFiles[layer.Digest.String()] = li
|
||||
}
|
||||
}
|
||||
|
||||
resolver := fileresolver.NewContainerImageModel(tempDir, layerFiles)
|
||||
|
||||
return tempDir, resolver, nil
|
||||
@ -202,6 +217,26 @@ func fetchSingleGGUFHeader(ctx context.Context, client *registryClient, ref name
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fetchSafeTensorsLayerHeader fetches the leading bytes of a SafeTensors weight
|
||||
// layer (enough to cover the JSON header) and writes them to a temp file.
|
||||
func fetchSafeTensorsLayerHeader(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) {
|
||||
headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxHeaderBytes)
|
||||
if err != nil {
|
||||
return fileresolver.LayerInfo{}, fmt.Errorf("failed to fetch safetensors layer header: %w", err)
|
||||
}
|
||||
|
||||
safeDigest := strings.ReplaceAll(layer.Digest.String(), ":", "-")
|
||||
tempPath := filepath.Join(tempDir, safeDigest+".safetensors")
|
||||
if err := os.WriteFile(tempPath, headerData, 0600); err != nil {
|
||||
return fileresolver.LayerInfo{}, fmt.Errorf("failed to write temp file: %w", err)
|
||||
}
|
||||
|
||||
return fileresolver.LayerInfo{
|
||||
TempPath: tempPath,
|
||||
MediaType: string(layer.MediaType),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildMetadata constructs OCIModelMetadata from a modelArtifact.
|
||||
func buildMetadata(artifact *modelArtifact) source.OCIModelMetadata {
|
||||
// layers
|
||||
|
||||
@ -133,9 +133,8 @@ type modelArtifact struct {
|
||||
GGUFLayers []v1.Descriptor
|
||||
|
||||
// SafeTensorsLayers are descriptors for layers carrying SafeTensors-format weights.
|
||||
// For safetensors we do NOT fetch these layers — the model-config blob already
|
||||
// contains the aggregate metadata we need — but we record them here for counting
|
||||
// and for future per-shard parsing.
|
||||
// We fetch the first maxHeaderBytes of each so the cataloger can read the JSON
|
||||
// header (tensor map + __metadata__) without pulling the multi-GB tensor data.
|
||||
SafeTensorsLayers []v1.Descriptor
|
||||
|
||||
// CompanionLayers are non-weight layers (README, config.json, license) that
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user