pr: review

Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
Christopher Phillips 2026-06-05 11:09:15 -04:00
parent b216dad4a7
commit 16d0449cc8
No known key found for this signature in database
8 changed files with 198 additions and 94 deletions

View File

@ -0,0 +1,166 @@
package ai
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/anchore/syft/syft/file"
"github.com/anchore/syft/syft/pkg"
)
// stPkg builds a model package carrying the given metadata, with each path
// recorded as a primary-evidence location.
func stPkg(md pkg.SafeTensorsModelInfo, paths ...string) pkg.Package {
locs := make([]file.Location, 0, len(paths))
for _, p := range paths {
locs = append(locs, file.NewLocation(p).WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation))
}
return pkg.Package{Type: pkg.ModelPkg, Metadata: md, Locations: file.NewLocationSet(locs...)}
}
// shardMeta is a content-derived shard entry: it carries a MetadataHash, which is
// what marks a group member as a shard (vs. a hash-less aggregate config blob).
func shardMeta(hash string, tensorCount uint64) pkg.SafeTensorsModelInfo {
return pkg.SafeTensorsModelInfo{
Format: "safetensors",
TensorCount: tensorCount,
Quantization: "BF16",
Parameters: "1.00K",
MetadataHash: hash,
}
}
// TestMergeSafeTensorsGroup exercises the rollup contract directly (the cataloger
// tests cover it only as a side effect of the merge processor). It locks how a
// group's per-member metadata folds into one package: tensor-count summing,
// aggregate-over-shard field precedence, UserMetadata dedup + sorting, Parts
// rollup, ShardCount derivation, and the content MetadataHash rollup.
func TestMergeSafeTensorsGroup(t *testing.T) {
t.Run("single shard: hash passes through, ShardCount 1, no Parts", func(t *testing.T) {
out := mergeSafeTensorsGroup([]pkg.Package{stPkg(shardMeta("aaaa", 5), "/m/a.safetensors")})
md := out.Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, pkg.ModelPkg, out.Type)
assert.Equal(t, 1, md.ShardCount)
assert.Equal(t, uint64(5), md.TensorCount)
assert.Equal(t, "aaaa", md.MetadataHash, "a single shard's hash passes through unchanged")
assert.Nil(t, md.Parts, "single-shard models do not populate Parts")
})
t.Run("multi-shard: tensors summed, Parts sorted by hash, rollup is order-independent", func(t *testing.T) {
in := []pkg.Package{
stPkg(shardMeta("cccc", 3), "/m/c.safetensors"),
stPkg(shardMeta("aaaa", 3), "/m/a.safetensors"),
stPkg(shardMeta("bbbb", 3), "/m/b.safetensors"),
}
out := mergeSafeTensorsGroup(in)
md := out.Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, 3, md.ShardCount)
assert.Equal(t, uint64(9), md.TensorCount, "tensor counts are summed across shards")
require.Len(t, md.Parts, 3)
assert.Equal(t,
[]string{"aaaa", "bbbb", "cccc"},
[]string{md.Parts[0].MetadataHash, md.Parts[1].MetadataHash, md.Parts[2].MetadataHash},
"Parts are sorted by metadata hash",
)
assert.Equal(t, rollupHash([]string{"aaaa", "bbbb", "cccc"}), md.MetadataHash)
// the rollup hash must not depend on the order members arrive in
shuffled := []pkg.Package{
stPkg(shardMeta("bbbb", 3), "/m/b.safetensors"),
stPkg(shardMeta("cccc", 3), "/m/c.safetensors"),
stPkg(shardMeta("aaaa", 3), "/m/a.safetensors"),
}
out2 := mergeSafeTensorsGroup(shuffled)
assert.Equal(t, md.MetadataHash, out2.Metadata.(pkg.SafeTensorsModelInfo).MetadataHash)
})
t.Run("aggregate fields win over shard-derived fields", func(t *testing.T) {
// an aggregate (OCI config blob) carries no MetadataHash but declares the
// authoritative totals.
aggregate := pkg.SafeTensorsModelInfo{
Format: "safetensors",
TensorCount: 999,
TotalSize: "5.00GB",
Parameters: "2.68B",
Quantization: "Q4_K_M",
}
in := []pkg.Package{
stPkg(aggregate, "/"),
stPkg(shardMeta("aaaa", 3), "/"),
stPkg(shardMeta("bbbb", 3), "/"),
}
out := mergeSafeTensorsGroup(in)
md := out.Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, uint64(999), md.TensorCount, "aggregate TensorCount is authoritative; shard counts are not summed in")
assert.Equal(t, "5.00GB", md.TotalSize)
assert.Equal(t, "2.68B", md.Parameters)
assert.Equal(t, "Q4_K_M", md.Quantization, "aggregate quantization wins over the shard dtype")
assert.Equal(t, 2, md.ShardCount, "ShardCount comes from the number of shards, not the aggregate")
assert.Equal(t, rollupHash([]string{"aaaa", "bbbb"}), md.MetadataHash, "the content hash still rolls up the shard hashes")
})
t.Run("aggregate-only group: ShardCount 1, empty hash, no Parts", func(t *testing.T) {
aggregate := pkg.SafeTensorsModelInfo{Format: "safetensors", TensorCount: 42, TotalSize: "1.00GB"}
out := mergeSafeTensorsGroup([]pkg.Package{stPkg(aggregate, "/")})
md := out.Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, uint64(42), md.TensorCount)
assert.Equal(t, 1, md.ShardCount, "a group with no shards still reports a single shard")
assert.Equal(t, "", md.MetadataHash, "there are no shard hashes to roll up")
assert.Nil(t, md.Parts)
})
t.Run("UserMetadata: keys merged and sorted, first value wins on conflict", func(t *testing.T) {
// keys are intentionally unsorted within each shard so the assertion proves
// the merge re-sorts globally; "format" appears in both shards so dedup
// precedence (first wins) is exercised too.
s1 := shardMeta("aaaa", 1)
s1.UserMetadata = pkg.KeyValues{{Key: "format", Value: "pt"}, {Key: "author", Value: "alice"}}
s2 := shardMeta("bbbb", 1)
s2.UserMetadata = pkg.KeyValues{{Key: "format", Value: "gguf"}, {Key: "license", Value: "mit"}}
out := mergeSafeTensorsGroup([]pkg.Package{stPkg(s1, "/m/a.safetensors"), stPkg(s2, "/m/b.safetensors")})
md := out.Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, pkg.KeyValues{
{Key: "author", Value: "alice"},
{Key: "format", Value: "pt"}, // first shard's value wins over s2's "gguf"
{Key: "license", Value: "mit"},
}, md.UserMetadata)
})
t.Run("members without safetensors metadata are ignored in the rollup", func(t *testing.T) {
notST := pkg.Package{
Type: pkg.ModelPkg,
Metadata: pkg.GGUFFileHeader{},
Locations: file.NewLocationSet(file.NewLocation("/m/x.gguf")),
}
out := mergeSafeTensorsGroup([]pkg.Package{stPkg(shardMeta("aaaa", 2), "/m/a.safetensors"), notST})
md := out.Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, uint64(2), md.TensorCount, "only the safetensors shard contributes")
assert.Equal(t, 1, md.ShardCount)
assert.Equal(t, "aaaa", md.MetadataHash)
})
}
// TestRollupHash locks the cross-source content-fingerprint rollup: empty input
// yields no hash, a lone shard's hash passes through unchanged (so a single-shard
// model fingerprints identically across directory and OCI sources), and multiple
// shards fold into one order-independent digest.
func TestRollupHash(t *testing.T) {
assert.Equal(t, "", rollupHash(nil), "no hashes → empty")
assert.Equal(t, "solo", rollupHash([]string{"solo"}), "a single hash passes through unchanged")
ab := rollupHash([]string{"a", "b"})
ba := rollupHash([]string{"b", "a"})
assert.Equal(t, ab, ba, "the rollup is independent of input order")
assert.Len(t, ab, 16, "a multi-hash rollup is a 16-char xxhash")
assert.NotEqual(t, "a", ab)
assert.NotEqual(t, "b", ab)
}

View File

@ -3,9 +3,6 @@ package ai
import "path" import "path"
// pickSafeTensorsName implements the documented naming precedence chain: // pickSafeTensorsName implements the documented naming precedence chain:
// - config.json _name_or_path (path.Base, so "org/Model" → "Model";
// applies to both dir-scan and OCI groups)
// - fallback name — the group's source-specific positional identifier
func pickSafeTensorsName(nameOrPath, fallbackName string) string { func pickSafeTensorsName(nameOrPath, fallbackName string) string {
if nameOrPath != "" { if nameOrPath != "" {
return path.Base(nameOrPath) return path.Base(nameOrPath)
@ -15,8 +12,8 @@ func pickSafeTensorsName(nameOrPath, fallbackName string) string {
// safeTensorsDirName returns the directory-scan naming fallback: the base name // safeTensorsDirName returns the directory-scan naming fallback: the base name
// of the group's parent directory (the group key is already that directory). // of the group's parent directory (the group key is already that directory).
func safeTensorsDirName(groupKey string) string { func safeTensorsDirName(directory string) string {
base := path.Base(groupKey) base := path.Base(directory)
switch base { switch base {
case "/", ".", "": case "/", ".", "":
return "" return ""

View File

@ -22,10 +22,7 @@ func newGGUFPackage(metadata *pkg.GGUFFileHeader, modelName, version, license st
} }
// newSafeTensorsPackage creates a SafeTensors package with the given metadata // newSafeTensorsPackage creates a SafeTensors package with the given metadata
// and locations. Name and Licenses are intentionally not set here — the // and locations. Name and Licenses are intentionally not set here and done at the processor level
// safetensors cataloger emits nameless packages from every parser, and the
// merge processor is the single owner of naming, license resolution, and
// supporting-evidence attachment.
func newSafeTensorsPackage(metadata *pkg.SafeTensorsModelInfo, locations ...file.Location) pkg.Package { func newSafeTensorsPackage(metadata *pkg.SafeTensorsModelInfo, locations ...file.Location) pkg.Package {
p := pkg.Package{ p := pkg.Package{
Locations: file.NewLocationSet(locations...), Locations: file.NewLocationSet(locations...),

View File

@ -47,12 +47,7 @@ type dockerAIModelConfig struct {
} `json:"config"` } `json:"config"`
} }
// parseSafeTensorsOCIConfig decodes the Docker AI model-config blob and emits // parseSafeTensorsOCIConfig decodes the Docker AI model-config blob
// a nameless package whose metadata mirrors the producer-declared aggregate
// fields (Format, Quantization, Parameters, Size, TensorCount). For any
// format other than "safetensors" it emits nothing so the GGUF cataloger can
// claim the artifact. Naming, license, and HF-companion enrichment all run
// once per group in safeTensorsMergeProcessor.
func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) { func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
defer internal.CloseAndLogError(reader, reader.Path()) defer internal.CloseAndLogError(reader, reader.Path())
@ -88,10 +83,7 @@ func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.En
} }
// parseSafeTensorsOCILayer decodes the JSON header of a SafeTensors weight // parseSafeTensorsOCILayer decodes the JSON header of a SafeTensors weight
// layer fetched from an OCI model artifact (the source layer caps each layer // layer fetched from an OCI model artifact
// at a small prefix; tensor data is never downloaded). It emits a nameless
// package; safeTensorsMergeProcessor folds it into the artifact's group and
// rolls per-shard fields up into the final merged package.
func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) { func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
defer internal.CloseAndLogError(reader, reader.Path()) defer internal.CloseAndLogError(reader, reader.Path())

View File

@ -15,12 +15,6 @@ import (
// assembly. SafeTensors packages reach it nameless from the parsers; it groups // assembly. SafeTensors packages reach it nameless from the parsers; it groups
// them per model, merges the per-shard metadata, resolves a name + licenses, and // them per model, merges the per-shard metadata, resolves a name + licenses, and
// drops any model it cannot name. // drops any model it cannot name.
//
// There are exactly two sources, each handled by its own path:
// - an OCI model artifact, where the source presents every layer at the
// virtual path "/" and the whole scan is a single model (mergeOCIModel)
// - a filesystem scan, where models are grouped by the directory their files
// live in (mergeDirModels)
func safeTensorsMergeProcessor(ctx context.Context, resolver file.Resolver, pkgs []pkg.Package, rels []artifact.Relationship, err error) ([]pkg.Package, []artifact.Relationship, error) { func safeTensorsMergeProcessor(ctx context.Context, resolver file.Resolver, pkgs []pkg.Package, rels []artifact.Relationship, err error) ([]pkg.Package, []artifact.Relationship, error) {
if err != nil || len(pkgs) == 0 { if err != nil || len(pkgs) == 0 {
return pkgs, rels, err return pkgs, rels, err
@ -55,12 +49,6 @@ func partitionSafeTensorsPackages(pkgs []pkg.Package) (safeTensors, other []pkg.
// That source (the ContainerImageModel resolver) presents every layer at the // That source (the ContainerImageModel resolver) presents every layer at the
// virtual path "/", whereas a filesystem scan always carries a real file path. A // virtual path "/", whereas a filesystem scan always carries a real file path. A
// single scan is one source, so the first package is representative of the rest. // single scan is one source, so the first package is representative of the rest.
//
// This deliberately keys off the path signal rather than type-asserting the
// resolver to file.OCIMediaTypeResolver: the test harness wraps resolvers in an
// ObservingResolver that implements that interface unconditionally, so an
// interface check would misclassify directory scans as OCI. The "/" path is the
// genuine, testable signal the OCI model source produces.
func fromOCIArtifact(pkgs []pkg.Package) bool { func fromOCIArtifact(pkgs []pkg.Package) bool {
loc := primaryEvidenceLocation(pkgs[0]) loc := primaryEvidenceLocation(pkgs[0])
return loc != nil && loc.RealPath == "/" return loc != nil && loc.RealPath == "/"
@ -83,8 +71,7 @@ func mergeOCIModel(ctx context.Context, resolver file.Resolver, pkgs []pkg.Packa
} }
// mergeDirModels groups filesystem-scanned files by their parent directory and // mergeDirModels groups filesystem-scanned files by their parent directory and
// emits one model per directory, named from a sibling config.json/README or the // emits one model per directory
// directory itself.
func mergeDirModels(ctx context.Context, resolver file.Resolver, pkgs []pkg.Package) []pkg.Package { func mergeDirModels(ctx context.Context, resolver file.Resolver, pkgs []pkg.Package) []pkg.Package {
groups := groupByParentDir(pkgs) groups := groupByParentDir(pkgs)

View File

@ -2,24 +2,16 @@ package pkg
// SafeTensorsModelInfo holds the model details extracted from SafeTensors content. // SafeTensorsModelInfo holds the model details extracted from SafeTensors content.
// SafeTensors is a simple, safe serialization format for storing tensors, used // SafeTensors is a simple, safe serialization format for storing tensors, used
// as the default weight format for Hugging Face transformer models. Syft may // as the default weight format for Hugging Face transformer models.
// populate this struct from these sources: // Model name, license, and version live on the syft Package
// - a single .safetensors file (header-only parse)
// - the per-shard headers of a multi-shard model, merged into one package
// - a Docker AI OCI model artifact: the config blob
// (vnd.docker.ai.model.config.v0.1+json) plus each weight layer's header
//
// Model name, license, and version live on the enclosing syft Package rather
// than in this struct.
type SafeTensorsModelInfo struct { type SafeTensorsModelInfo struct {
// Format is the source format label (always "safetensors" for this metadata type). // Format is the source format label (always "safetensors" for this metadata type).
// Present because the Docker AI model config blob carries an explicit format field // Present because the Docker AI model config blob carries an explicit format field
// that can also be "gguf", and recording it here makes the origin explicit.
Format string `json:"format,omitempty" cyclonedx:"format"` Format string `json:"format,omitempty" cyclonedx:"format"`
// Architecture is the model architecture (e.g., "LlamaForCausalLM", // Architecture is the model architecture (e.g., "LlamaForCausalLM",
// "Qwen3MoeForConditionalGeneration"). It is not present in the SafeTensors // "Qwen3MoeForConditionalGeneration"). It is not present in the SafeTensors
// header itself; it is enriched from the companion Hugging Face config.json // header itself; it is enriched from the companion config.json
// "architectures" array when one is found alongside the model. // "architectures" array when one is found alongside the model.
Architecture string `json:"architecture,omitempty" cyclonedx:"architecture"` Architecture string `json:"architecture,omitempty" cyclonedx:"architecture"`
@ -42,15 +34,11 @@ type SafeTensorsModelInfo struct {
ShardCount int `json:"shardCount,omitempty" cyclonedx:"shardCount"` ShardCount int `json:"shardCount,omitempty" cyclonedx:"shardCount"`
// UserMetadata is the optional "__metadata__" map from a .safetensors file header // UserMetadata is the optional "__metadata__" map from a .safetensors file header
// (string-to-string key/values set by the producer). Stored as a sorted KeyValues // (string-to-string key/values set by the producer).
// slice rather than a Go map so SBOM output is stable across runs.
UserMetadata KeyValues `json:"userMetadata,omitempty" cyclonedx:"userMetadata"` UserMetadata KeyValues `json:"userMetadata,omitempty" cyclonedx:"userMetadata"`
// MetadataHash is an xxhash over the on-disk SafeTensors header (sorted tensor // MetadataHash is an xxhash over the on-disk SafeTensors header (sorted tensor
// entries + __metadata__). It is derived ONLY from the safetensors file bytes — // entries + __metadata__). It is derived ONLY from the safetensors file bytes.
// never from OCI manifest, layer descriptor, or config-blob fields — so the same
// model content scanned via a directory source and via an OCI image produces the
// same value. Treat this as the cross-source content fingerprint.
MetadataHash string `json:"metadataHash,omitempty" cyclonedx:"metadataHash"` MetadataHash string `json:"metadataHash,omitempty" cyclonedx:"metadataHash"`
// Parts contains metadata from additional SafeTensors shards or OCI layers that // Parts contains metadata from additional SafeTensors shards or OCI layers that

View File

@ -86,14 +86,7 @@ func validateAndFetchArtifact(ctx context.Context, client *registryClient, refer
// fetchAndStoreModelHeaders fetches the blobs needed to catalog a Docker AI // fetchAndStoreModelHeaders fetches the blobs needed to catalog a Docker AI
// model artifact and stores them on disk so the ContainerImageModel resolver // model artifact and stores them on disk so the ContainerImageModel resolver
// can serve them by media type: // can serve them by media type
//
// - For GGUF: the first maxHeaderBytes of each weight layer.
// - For SafeTensors: the model-config blob (already in memory as RawConfig),
// each companion layer in full, and the first maxHeaderBytes of each
// weight layer. The weight-layer prefix is enough to read the JSON header
// (tensor map + __metadata__), which is what the cataloger hashes for
// cross-source identity.
func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, artifact *modelArtifact) (string, *fileresolver.ContainerImageModel, error) { func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, artifact *modelArtifact) (string, *fileresolver.ContainerImageModel, error) {
tempDir, err := os.MkdirTemp("", "syft-oci-model") tempDir, err := os.MkdirTemp("", "syft-oci-model")
if err != nil { if err != nil {
@ -119,8 +112,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti
} }
// For SafeTensors artifacts, expose the model-config blob to the resolver // For SafeTensors artifacts, expose the model-config blob to the resolver
// so parseSafeTensorsOCIConfig can match it by media type. RawConfig was // so parseSafeTensorsOCIConfig can match it by media type.
// already fetched as part of the manifest walk.
if artifact.Format == modelFormatSafeTensors && len(artifact.RawConfig) > 0 { if artifact.Format == modelFormatSafeTensors && len(artifact.RawConfig) > 0 {
li, err := storeConfigBlobAsLayer(artifact, tempDir) li, err := storeConfigBlobAsLayer(artifact, tempDir)
if err != nil { if err != nil {
@ -144,9 +136,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti
} }
// SafeTensors weight-layer headers. We only pull the leading prefix (same // SafeTensors weight-layer headers. We only pull the leading prefix (same
// budget as a GGUF header) because the JSON header lives at the very start // budget as a GGUF header)
// of the file — multi-GB tensor data that follows is intentionally not
// downloaded.
if artifact.Format == modelFormatSafeTensors { if artifact.Format == modelFormatSafeTensors {
for _, layer := range artifact.SafeTensorsLayers { for _, layer := range artifact.SafeTensorsLayers {
li, err := fetchSafeTensorsLayerHeader(ctx, client, artifact.Reference, layer, tempDir) li, err := fetchSafeTensorsLayerHeader(ctx, client, artifact.Reference, layer, tempDir)
@ -168,7 +158,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti
} }
// storeConfigBlobAsLayer writes the already-fetched raw config bytes to a temp // storeConfigBlobAsLayer writes the already-fetched raw config bytes to a temp
// file so the resolver can serve them via media type. // file so the resolver can serve them via media type
func storeConfigBlobAsLayer(artifact *modelArtifact, tempDir string) (fileresolver.LayerInfo, error) { func storeConfigBlobAsLayer(artifact *modelArtifact, tempDir string) (fileresolver.LayerInfo, error) {
digest := artifact.Manifest.Config.Digest.String() digest := artifact.Manifest.Config.Digest.String()
safeDigest := strings.ReplaceAll(digest, ":", "-") safeDigest := strings.ReplaceAll(digest, ":", "-")
@ -182,9 +172,7 @@ func storeConfigBlobAsLayer(artifact *modelArtifact, tempDir string) (fileresolv
}, nil }, nil
} }
// fetchCompanionLayer downloads a companion (non-weight) layer to a temp file. // fetchCompanionLayer downloads a companion (non-weight) layer to a temp file
// Unlike weight layers we fetch up to maxCompanionBytes, which comfortably
// covers READMEs, HF config.json, tokenizer.json, and LICENSE text.
func fetchCompanionLayer(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) { func fetchCompanionLayer(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) {
data, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxCompanionBytes) data, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxCompanionBytes)
if err != nil { if err != nil {
@ -201,9 +189,9 @@ func fetchCompanionLayer(ctx context.Context, client *registryClient, ref name.R
}, nil }, nil
} }
// fetchSingleGGUFHeader fetches a single GGUF layer header and writes it to a temp file. // fetchSingleGGUFHeader fetches a single GGUF layer header and writes it to a temp file
func fetchSingleGGUFHeader(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) { func fetchSingleGGUFHeader(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) {
headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxHeaderBytes) headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxWeightHeaderBytes)
if err != nil { if err != nil {
return fileresolver.LayerInfo{}, fmt.Errorf("failed to fetch GGUF layer header: %w", err) return fileresolver.LayerInfo{}, fmt.Errorf("failed to fetch GGUF layer header: %w", err)
} }
@ -222,9 +210,9 @@ func fetchSingleGGUFHeader(ctx context.Context, client *registryClient, ref name
} }
// fetchSafeTensorsLayerHeader fetches the leading bytes of a SafeTensors weight // fetchSafeTensorsLayerHeader fetches the leading bytes of a SafeTensors weight
// layer (enough to cover the JSON header) and writes them to a temp file. // layer (enough to cover the JSON header) and writes them to a temp file
func fetchSafeTensorsLayerHeader(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) { func fetchSafeTensorsLayerHeader(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) {
headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxHeaderBytes) headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxWeightHeaderBytes)
if err != nil { if err != nil {
return fileresolver.LayerInfo{}, fmt.Errorf("failed to fetch safetensors layer header: %w", err) return fileresolver.LayerInfo{}, fmt.Errorf("failed to fetch safetensors layer header: %w", err)
} }
@ -241,7 +229,7 @@ func fetchSafeTensorsLayerHeader(ctx context.Context, client *registryClient, re
}, nil }, nil
} }
// buildMetadata constructs OCIModelMetadata from a modelArtifact. // buildMetadata constructs OCIModelMetadata from a modelArtifact
func buildMetadata(artifact *modelArtifact) source.OCIModelMetadata { func buildMetadata(artifact *modelArtifact) source.OCIModelMetadata {
// layers // layers
layers := make([]source.LayerMetadata, len(artifact.Manifest.Layers)) layers := make([]source.LayerMetadata, len(artifact.Manifest.Layers))

View File

@ -37,10 +37,13 @@ const (
modelFormatGGUF = "gguf" modelFormatGGUF = "gguf"
modelFormatSafeTensors = "safetensors" modelFormatSafeTensors = "safetensors"
// Maximum bytes to read/return for weight-layer headers (GGUF + safetensors). // maxWeightHeaderBytes is the leading slice we range-GET from a (multi-GB)
maxHeaderBytes = 8 * 1024 * 1024 // 8 MB // weight layer — enough to cover the GGUF/safetensors header.
// Maximum bytes to fetch for a companion metadata layer (README, config.json, license). maxWeightHeaderBytes = 8 * 1024 * 1024 // 8 MB
// These blobs are small by convention; cap well below a safetensors header.
// maxCompanionBytes caps a whole companion blob (README, config.json,
// license); these are small by convention. Matches the 4 MB read cap in
// classifyOCIModelFileLayer.
maxCompanionBytes = 4 * 1024 * 1024 // 4 MB maxCompanionBytes = 4 * 1024 * 1024 // 4 MB
) )
@ -129,17 +132,13 @@ type modelArtifact struct {
Format string Format string
// GGUFLayers are descriptors for layers carrying GGUF-format weights. // GGUFLayers are descriptors for layers carrying GGUF-format weights.
// We fetch the first few MB of each to read the header. // We fetch the first few MB of each to read the header data
GGUFLayers []v1.Descriptor GGUFLayers []v1.Descriptor
// SafeTensorsLayers are descriptors for layers carrying SafeTensors-format weights. // SafeTensorsLayers are descriptors for layers carrying SafeTensors-format weights.
// We fetch the first maxHeaderBytes of each so the cataloger can read the JSON
// header (tensor map + __metadata__) without pulling the multi-GB tensor data.
SafeTensorsLayers []v1.Descriptor SafeTensorsLayers []v1.Descriptor
// CompanionLayers are non-weight layers (README, config.json, license) that // CompanionLayers are non-weight layers (README, config.json, license)
// we do fetch (in full, given their small size) so companion-file parsing
// in the safetensors cataloger can find them via media type.
CompanionLayers []v1.Descriptor CompanionLayers []v1.Descriptor
} }
@ -199,10 +198,7 @@ func (c *registryClient) fetchModelArtifact(ctx context.Context, refStr string)
} }
// detectModelFormat returns a single format string when either GGUF or // detectModelFormat returns a single format string when either GGUF or
// SafeTensors weight layers are present. When both appear (not expected in // SafeTensors weight layers are present.
// practice for Docker Model Runner artifacts), GGUF wins because the GGUF
// cataloger is the more established path. Empty result means the manifest has
// no recognized weight layers.
func detectModelFormat(ggufCount, safetensorsCount int) string { func detectModelFormat(ggufCount, safetensorsCount int) string {
switch { switch {
case ggufCount > 0: case ggufCount > 0:
@ -271,14 +267,6 @@ func (c *registryClient) fetchBlobRange(ctx context.Context, ref name.Reference,
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get layer reader: %w", err) return nil, fmt.Errorf("failed to get layer reader: %w", err)
} }
// this defer is what causes the download to stop
// 1. io.ReadFull(reader, data) reads exactly 8MB into the buffer
// 2. The function returns with data[:n]
// 3. defer reader.Close() executes, closing the HTTP response body
// 4. Closing the response body closes the underlying TCP connection
// 5. The server receives TCP FIN/RST and stops sending
// note: some data is already in flight when we close so we will see > 8mb over the wire
// the full image will not download given we terminate the reader early here
defer reader.Close() defer reader.Close()
// Note: this is not some arbitrary number picked out of the blue. // Note: this is not some arbitrary number picked out of the blue.
@ -286,6 +274,7 @@ func (c *registryClient) fetchBlobRange(ctx context.Context, ref name.Reference,
// https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#file-structure // https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#file-structure
data := make([]byte, maxBytes) data := make([]byte, maxBytes)
n, err := io.ReadFull(reader, data) n, err := io.ReadFull(reader, data)
// ErrUnexpectedEOF means the layer is smaller than maxBytes; EOF means it is // ErrUnexpectedEOF means the layer is smaller than maxBytes; EOF means it is
// empty. Both mean we read everything there was, not a failure. // empty. Both mean we read everything there was, not a failure.
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {