From 16d0449cc8fb330fe8fc32e640389df5d153cd56 Mon Sep 17 00:00:00 2001 From: Christopher Phillips <32073428+spiffcs@users.noreply.github.com> Date: Fri, 5 Jun 2026 11:09:15 -0400 Subject: [PATCH] pr: review Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com> --- syft/pkg/cataloger/ai/merge_test.go | 166 ++++++++++++++++++ syft/pkg/cataloger/ai/naming.go | 7 +- syft/pkg/cataloger/ai/package.go | 5 +- .../pkg/cataloger/ai/parse_safetensors_oci.go | 12 +- syft/pkg/cataloger/ai/processor.go | 15 +- syft/pkg/safetensors.go | 22 +-- .../source/ocimodelsource/oci_model_source.go | 32 ++-- syft/source/ocimodelsource/registry_client.go | 33 ++-- 8 files changed, 198 insertions(+), 94 deletions(-) create mode 100644 syft/pkg/cataloger/ai/merge_test.go diff --git a/syft/pkg/cataloger/ai/merge_test.go b/syft/pkg/cataloger/ai/merge_test.go new file mode 100644 index 000000000..06631db53 --- /dev/null +++ b/syft/pkg/cataloger/ai/merge_test.go @@ -0,0 +1,166 @@ +package ai + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/anchore/syft/syft/file" + "github.com/anchore/syft/syft/pkg" +) + +// stPkg builds a model package carrying the given metadata, with each path +// recorded as a primary-evidence location. +func stPkg(md pkg.SafeTensorsModelInfo, paths ...string) pkg.Package { + locs := make([]file.Location, 0, len(paths)) + for _, p := range paths { + locs = append(locs, file.NewLocation(p).WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation)) + } + return pkg.Package{Type: pkg.ModelPkg, Metadata: md, Locations: file.NewLocationSet(locs...)} +} + +// shardMeta is a content-derived shard entry: it carries a MetadataHash, which is +// what marks a group member as a shard (vs. a hash-less aggregate config blob). +func shardMeta(hash string, tensorCount uint64) pkg.SafeTensorsModelInfo { + return pkg.SafeTensorsModelInfo{ + Format: "safetensors", + TensorCount: tensorCount, + Quantization: "BF16", + Parameters: "1.00K", + MetadataHash: hash, + } +} + +// TestMergeSafeTensorsGroup exercises the rollup contract directly (the cataloger +// tests cover it only as a side effect of the merge processor). It locks how a +// group's per-member metadata folds into one package: tensor-count summing, +// aggregate-over-shard field precedence, UserMetadata dedup + sorting, Parts +// rollup, ShardCount derivation, and the content MetadataHash rollup. +func TestMergeSafeTensorsGroup(t *testing.T) { + t.Run("single shard: hash passes through, ShardCount 1, no Parts", func(t *testing.T) { + out := mergeSafeTensorsGroup([]pkg.Package{stPkg(shardMeta("aaaa", 5), "/m/a.safetensors")}) + + md := out.Metadata.(pkg.SafeTensorsModelInfo) + assert.Equal(t, pkg.ModelPkg, out.Type) + assert.Equal(t, 1, md.ShardCount) + assert.Equal(t, uint64(5), md.TensorCount) + assert.Equal(t, "aaaa", md.MetadataHash, "a single shard's hash passes through unchanged") + assert.Nil(t, md.Parts, "single-shard models do not populate Parts") + }) + + t.Run("multi-shard: tensors summed, Parts sorted by hash, rollup is order-independent", func(t *testing.T) { + in := []pkg.Package{ + stPkg(shardMeta("cccc", 3), "/m/c.safetensors"), + stPkg(shardMeta("aaaa", 3), "/m/a.safetensors"), + stPkg(shardMeta("bbbb", 3), "/m/b.safetensors"), + } + out := mergeSafeTensorsGroup(in) + + md := out.Metadata.(pkg.SafeTensorsModelInfo) + assert.Equal(t, 3, md.ShardCount) + assert.Equal(t, uint64(9), md.TensorCount, "tensor counts are summed across shards") + require.Len(t, md.Parts, 3) + assert.Equal(t, + []string{"aaaa", "bbbb", "cccc"}, + []string{md.Parts[0].MetadataHash, md.Parts[1].MetadataHash, md.Parts[2].MetadataHash}, + "Parts are sorted by metadata hash", + ) + assert.Equal(t, rollupHash([]string{"aaaa", "bbbb", "cccc"}), md.MetadataHash) + + // the rollup hash must not depend on the order members arrive in + shuffled := []pkg.Package{ + stPkg(shardMeta("bbbb", 3), "/m/b.safetensors"), + stPkg(shardMeta("cccc", 3), "/m/c.safetensors"), + stPkg(shardMeta("aaaa", 3), "/m/a.safetensors"), + } + out2 := mergeSafeTensorsGroup(shuffled) + assert.Equal(t, md.MetadataHash, out2.Metadata.(pkg.SafeTensorsModelInfo).MetadataHash) + }) + + t.Run("aggregate fields win over shard-derived fields", func(t *testing.T) { + // an aggregate (OCI config blob) carries no MetadataHash but declares the + // authoritative totals. + aggregate := pkg.SafeTensorsModelInfo{ + Format: "safetensors", + TensorCount: 999, + TotalSize: "5.00GB", + Parameters: "2.68B", + Quantization: "Q4_K_M", + } + in := []pkg.Package{ + stPkg(aggregate, "/"), + stPkg(shardMeta("aaaa", 3), "/"), + stPkg(shardMeta("bbbb", 3), "/"), + } + out := mergeSafeTensorsGroup(in) + + md := out.Metadata.(pkg.SafeTensorsModelInfo) + assert.Equal(t, uint64(999), md.TensorCount, "aggregate TensorCount is authoritative; shard counts are not summed in") + assert.Equal(t, "5.00GB", md.TotalSize) + assert.Equal(t, "2.68B", md.Parameters) + assert.Equal(t, "Q4_K_M", md.Quantization, "aggregate quantization wins over the shard dtype") + assert.Equal(t, 2, md.ShardCount, "ShardCount comes from the number of shards, not the aggregate") + assert.Equal(t, rollupHash([]string{"aaaa", "bbbb"}), md.MetadataHash, "the content hash still rolls up the shard hashes") + }) + + t.Run("aggregate-only group: ShardCount 1, empty hash, no Parts", func(t *testing.T) { + aggregate := pkg.SafeTensorsModelInfo{Format: "safetensors", TensorCount: 42, TotalSize: "1.00GB"} + out := mergeSafeTensorsGroup([]pkg.Package{stPkg(aggregate, "/")}) + + md := out.Metadata.(pkg.SafeTensorsModelInfo) + assert.Equal(t, uint64(42), md.TensorCount) + assert.Equal(t, 1, md.ShardCount, "a group with no shards still reports a single shard") + assert.Equal(t, "", md.MetadataHash, "there are no shard hashes to roll up") + assert.Nil(t, md.Parts) + }) + + t.Run("UserMetadata: keys merged and sorted, first value wins on conflict", func(t *testing.T) { + // keys are intentionally unsorted within each shard so the assertion proves + // the merge re-sorts globally; "format" appears in both shards so dedup + // precedence (first wins) is exercised too. + s1 := shardMeta("aaaa", 1) + s1.UserMetadata = pkg.KeyValues{{Key: "format", Value: "pt"}, {Key: "author", Value: "alice"}} + s2 := shardMeta("bbbb", 1) + s2.UserMetadata = pkg.KeyValues{{Key: "format", Value: "gguf"}, {Key: "license", Value: "mit"}} + + out := mergeSafeTensorsGroup([]pkg.Package{stPkg(s1, "/m/a.safetensors"), stPkg(s2, "/m/b.safetensors")}) + + md := out.Metadata.(pkg.SafeTensorsModelInfo) + assert.Equal(t, pkg.KeyValues{ + {Key: "author", Value: "alice"}, + {Key: "format", Value: "pt"}, // first shard's value wins over s2's "gguf" + {Key: "license", Value: "mit"}, + }, md.UserMetadata) + }) + + t.Run("members without safetensors metadata are ignored in the rollup", func(t *testing.T) { + notST := pkg.Package{ + Type: pkg.ModelPkg, + Metadata: pkg.GGUFFileHeader{}, + Locations: file.NewLocationSet(file.NewLocation("/m/x.gguf")), + } + out := mergeSafeTensorsGroup([]pkg.Package{stPkg(shardMeta("aaaa", 2), "/m/a.safetensors"), notST}) + + md := out.Metadata.(pkg.SafeTensorsModelInfo) + assert.Equal(t, uint64(2), md.TensorCount, "only the safetensors shard contributes") + assert.Equal(t, 1, md.ShardCount) + assert.Equal(t, "aaaa", md.MetadataHash) + }) +} + +// TestRollupHash locks the cross-source content-fingerprint rollup: empty input +// yields no hash, a lone shard's hash passes through unchanged (so a single-shard +// model fingerprints identically across directory and OCI sources), and multiple +// shards fold into one order-independent digest. +func TestRollupHash(t *testing.T) { + assert.Equal(t, "", rollupHash(nil), "no hashes → empty") + assert.Equal(t, "solo", rollupHash([]string{"solo"}), "a single hash passes through unchanged") + + ab := rollupHash([]string{"a", "b"}) + ba := rollupHash([]string{"b", "a"}) + assert.Equal(t, ab, ba, "the rollup is independent of input order") + assert.Len(t, ab, 16, "a multi-hash rollup is a 16-char xxhash") + assert.NotEqual(t, "a", ab) + assert.NotEqual(t, "b", ab) +} diff --git a/syft/pkg/cataloger/ai/naming.go b/syft/pkg/cataloger/ai/naming.go index b9e766003..ed2571dcb 100644 --- a/syft/pkg/cataloger/ai/naming.go +++ b/syft/pkg/cataloger/ai/naming.go @@ -3,9 +3,6 @@ package ai import "path" // pickSafeTensorsName implements the documented naming precedence chain: -// - config.json _name_or_path (path.Base, so "org/Model" → "Model"; -// applies to both dir-scan and OCI groups) -// - fallback name — the group's source-specific positional identifier func pickSafeTensorsName(nameOrPath, fallbackName string) string { if nameOrPath != "" { return path.Base(nameOrPath) @@ -15,8 +12,8 @@ func pickSafeTensorsName(nameOrPath, fallbackName string) string { // safeTensorsDirName returns the directory-scan naming fallback: the base name // of the group's parent directory (the group key is already that directory). -func safeTensorsDirName(groupKey string) string { - base := path.Base(groupKey) +func safeTensorsDirName(directory string) string { + base := path.Base(directory) switch base { case "/", ".", "": return "" diff --git a/syft/pkg/cataloger/ai/package.go b/syft/pkg/cataloger/ai/package.go index c82ab494c..3abb42343 100644 --- a/syft/pkg/cataloger/ai/package.go +++ b/syft/pkg/cataloger/ai/package.go @@ -22,10 +22,7 @@ func newGGUFPackage(metadata *pkg.GGUFFileHeader, modelName, version, license st } // newSafeTensorsPackage creates a SafeTensors package with the given metadata -// and locations. Name and Licenses are intentionally not set here — the -// safetensors cataloger emits nameless packages from every parser, and the -// merge processor is the single owner of naming, license resolution, and -// supporting-evidence attachment. +// and locations. Name and Licenses are intentionally not set here and done at the processor level func newSafeTensorsPackage(metadata *pkg.SafeTensorsModelInfo, locations ...file.Location) pkg.Package { p := pkg.Package{ Locations: file.NewLocationSet(locations...), diff --git a/syft/pkg/cataloger/ai/parse_safetensors_oci.go b/syft/pkg/cataloger/ai/parse_safetensors_oci.go index f3fce0403..a20988f79 100644 --- a/syft/pkg/cataloger/ai/parse_safetensors_oci.go +++ b/syft/pkg/cataloger/ai/parse_safetensors_oci.go @@ -47,12 +47,7 @@ type dockerAIModelConfig struct { } `json:"config"` } -// parseSafeTensorsOCIConfig decodes the Docker AI model-config blob and emits -// a nameless package whose metadata mirrors the producer-declared aggregate -// fields (Format, Quantization, Parameters, Size, TensorCount). For any -// format other than "safetensors" it emits nothing so the GGUF cataloger can -// claim the artifact. Naming, license, and HF-companion enrichment all run -// once per group in safeTensorsMergeProcessor. +// parseSafeTensorsOCIConfig decodes the Docker AI model-config blob func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) { defer internal.CloseAndLogError(reader, reader.Path()) @@ -88,10 +83,7 @@ func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.En } // parseSafeTensorsOCILayer decodes the JSON header of a SafeTensors weight -// layer fetched from an OCI model artifact (the source layer caps each layer -// at a small prefix; tensor data is never downloaded). It emits a nameless -// package; safeTensorsMergeProcessor folds it into the artifact's group and -// rolls per-shard fields up into the final merged package. +// layer fetched from an OCI model artifact func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) { defer internal.CloseAndLogError(reader, reader.Path()) diff --git a/syft/pkg/cataloger/ai/processor.go b/syft/pkg/cataloger/ai/processor.go index c685aeb6e..2a0ea0fee 100644 --- a/syft/pkg/cataloger/ai/processor.go +++ b/syft/pkg/cataloger/ai/processor.go @@ -15,12 +15,6 @@ import ( // assembly. SafeTensors packages reach it nameless from the parsers; it groups // them per model, merges the per-shard metadata, resolves a name + licenses, and // drops any model it cannot name. -// -// There are exactly two sources, each handled by its own path: -// - an OCI model artifact, where the source presents every layer at the -// virtual path "/" and the whole scan is a single model (mergeOCIModel) -// - a filesystem scan, where models are grouped by the directory their files -// live in (mergeDirModels) func safeTensorsMergeProcessor(ctx context.Context, resolver file.Resolver, pkgs []pkg.Package, rels []artifact.Relationship, err error) ([]pkg.Package, []artifact.Relationship, error) { if err != nil || len(pkgs) == 0 { return pkgs, rels, err @@ -55,12 +49,6 @@ func partitionSafeTensorsPackages(pkgs []pkg.Package) (safeTensors, other []pkg. // That source (the ContainerImageModel resolver) presents every layer at the // virtual path "/", whereas a filesystem scan always carries a real file path. A // single scan is one source, so the first package is representative of the rest. -// -// This deliberately keys off the path signal rather than type-asserting the -// resolver to file.OCIMediaTypeResolver: the test harness wraps resolvers in an -// ObservingResolver that implements that interface unconditionally, so an -// interface check would misclassify directory scans as OCI. The "/" path is the -// genuine, testable signal the OCI model source produces. func fromOCIArtifact(pkgs []pkg.Package) bool { loc := primaryEvidenceLocation(pkgs[0]) return loc != nil && loc.RealPath == "/" @@ -83,8 +71,7 @@ func mergeOCIModel(ctx context.Context, resolver file.Resolver, pkgs []pkg.Packa } // mergeDirModels groups filesystem-scanned files by their parent directory and -// emits one model per directory, named from a sibling config.json/README or the -// directory itself. +// emits one model per directory func mergeDirModels(ctx context.Context, resolver file.Resolver, pkgs []pkg.Package) []pkg.Package { groups := groupByParentDir(pkgs) diff --git a/syft/pkg/safetensors.go b/syft/pkg/safetensors.go index d0ca8c2dc..0e769472b 100644 --- a/syft/pkg/safetensors.go +++ b/syft/pkg/safetensors.go @@ -2,24 +2,16 @@ package pkg // SafeTensorsModelInfo holds the model details extracted from SafeTensors content. // SafeTensors is a simple, safe serialization format for storing tensors, used -// as the default weight format for Hugging Face transformer models. Syft may -// populate this struct from these sources: -// - a single .safetensors file (header-only parse) -// - the per-shard headers of a multi-shard model, merged into one package -// - a Docker AI OCI model artifact: the config blob -// (vnd.docker.ai.model.config.v0.1+json) plus each weight layer's header -// -// Model name, license, and version live on the enclosing syft Package rather -// than in this struct. +// as the default weight format for Hugging Face transformer models. +// Model name, license, and version live on the syft Package type SafeTensorsModelInfo struct { // Format is the source format label (always "safetensors" for this metadata type). // Present because the Docker AI model config blob carries an explicit format field - // that can also be "gguf", and recording it here makes the origin explicit. Format string `json:"format,omitempty" cyclonedx:"format"` // Architecture is the model architecture (e.g., "LlamaForCausalLM", // "Qwen3MoeForConditionalGeneration"). It is not present in the SafeTensors - // header itself; it is enriched from the companion Hugging Face config.json + // header itself; it is enriched from the companion config.json // "architectures" array when one is found alongside the model. Architecture string `json:"architecture,omitempty" cyclonedx:"architecture"` @@ -42,15 +34,11 @@ type SafeTensorsModelInfo struct { ShardCount int `json:"shardCount,omitempty" cyclonedx:"shardCount"` // UserMetadata is the optional "__metadata__" map from a .safetensors file header - // (string-to-string key/values set by the producer). Stored as a sorted KeyValues - // slice rather than a Go map so SBOM output is stable across runs. + // (string-to-string key/values set by the producer). UserMetadata KeyValues `json:"userMetadata,omitempty" cyclonedx:"userMetadata"` // MetadataHash is an xxhash over the on-disk SafeTensors header (sorted tensor - // entries + __metadata__). It is derived ONLY from the safetensors file bytes — - // never from OCI manifest, layer descriptor, or config-blob fields — so the same - // model content scanned via a directory source and via an OCI image produces the - // same value. Treat this as the cross-source content fingerprint. + // entries + __metadata__). It is derived ONLY from the safetensors file bytes. MetadataHash string `json:"metadataHash,omitempty" cyclonedx:"metadataHash"` // Parts contains metadata from additional SafeTensors shards or OCI layers that diff --git a/syft/source/ocimodelsource/oci_model_source.go b/syft/source/ocimodelsource/oci_model_source.go index 0fc697396..c93e9c3fa 100644 --- a/syft/source/ocimodelsource/oci_model_source.go +++ b/syft/source/ocimodelsource/oci_model_source.go @@ -86,14 +86,7 @@ func validateAndFetchArtifact(ctx context.Context, client *registryClient, refer // fetchAndStoreModelHeaders fetches the blobs needed to catalog a Docker AI // model artifact and stores them on disk so the ContainerImageModel resolver -// can serve them by media type: -// -// - For GGUF: the first maxHeaderBytes of each weight layer. -// - For SafeTensors: the model-config blob (already in memory as RawConfig), -// each companion layer in full, and the first maxHeaderBytes of each -// weight layer. The weight-layer prefix is enough to read the JSON header -// (tensor map + __metadata__), which is what the cataloger hashes for -// cross-source identity. +// can serve them by media type func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, artifact *modelArtifact) (string, *fileresolver.ContainerImageModel, error) { tempDir, err := os.MkdirTemp("", "syft-oci-model") if err != nil { @@ -119,8 +112,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti } // For SafeTensors artifacts, expose the model-config blob to the resolver - // so parseSafeTensorsOCIConfig can match it by media type. RawConfig was - // already fetched as part of the manifest walk. + // so parseSafeTensorsOCIConfig can match it by media type. if artifact.Format == modelFormatSafeTensors && len(artifact.RawConfig) > 0 { li, err := storeConfigBlobAsLayer(artifact, tempDir) if err != nil { @@ -144,9 +136,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti } // SafeTensors weight-layer headers. We only pull the leading prefix (same - // budget as a GGUF header) because the JSON header lives at the very start - // of the file — multi-GB tensor data that follows is intentionally not - // downloaded. + // budget as a GGUF header) if artifact.Format == modelFormatSafeTensors { for _, layer := range artifact.SafeTensorsLayers { li, err := fetchSafeTensorsLayerHeader(ctx, client, artifact.Reference, layer, tempDir) @@ -168,7 +158,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti } // storeConfigBlobAsLayer writes the already-fetched raw config bytes to a temp -// file so the resolver can serve them via media type. +// file so the resolver can serve them via media type func storeConfigBlobAsLayer(artifact *modelArtifact, tempDir string) (fileresolver.LayerInfo, error) { digest := artifact.Manifest.Config.Digest.String() safeDigest := strings.ReplaceAll(digest, ":", "-") @@ -182,9 +172,7 @@ func storeConfigBlobAsLayer(artifact *modelArtifact, tempDir string) (fileresolv }, nil } -// fetchCompanionLayer downloads a companion (non-weight) layer to a temp file. -// Unlike weight layers we fetch up to maxCompanionBytes, which comfortably -// covers READMEs, HF config.json, tokenizer.json, and LICENSE text. +// fetchCompanionLayer downloads a companion (non-weight) layer to a temp file func fetchCompanionLayer(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) { data, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxCompanionBytes) if err != nil { @@ -201,9 +189,9 @@ func fetchCompanionLayer(ctx context.Context, client *registryClient, ref name.R }, nil } -// fetchSingleGGUFHeader fetches a single GGUF layer header and writes it to a temp file. +// fetchSingleGGUFHeader fetches a single GGUF layer header and writes it to a temp file func fetchSingleGGUFHeader(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) { - headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxHeaderBytes) + headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxWeightHeaderBytes) if err != nil { return fileresolver.LayerInfo{}, fmt.Errorf("failed to fetch GGUF layer header: %w", err) } @@ -222,9 +210,9 @@ func fetchSingleGGUFHeader(ctx context.Context, client *registryClient, ref name } // fetchSafeTensorsLayerHeader fetches the leading bytes of a SafeTensors weight -// layer (enough to cover the JSON header) and writes them to a temp file. +// layer (enough to cover the JSON header) and writes them to a temp file func fetchSafeTensorsLayerHeader(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) { - headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxHeaderBytes) + headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxWeightHeaderBytes) if err != nil { return fileresolver.LayerInfo{}, fmt.Errorf("failed to fetch safetensors layer header: %w", err) } @@ -241,7 +229,7 @@ func fetchSafeTensorsLayerHeader(ctx context.Context, client *registryClient, re }, nil } -// buildMetadata constructs OCIModelMetadata from a modelArtifact. +// buildMetadata constructs OCIModelMetadata from a modelArtifact func buildMetadata(artifact *modelArtifact) source.OCIModelMetadata { // layers layers := make([]source.LayerMetadata, len(artifact.Manifest.Layers)) diff --git a/syft/source/ocimodelsource/registry_client.go b/syft/source/ocimodelsource/registry_client.go index 43adfd5e2..1f011f5ab 100644 --- a/syft/source/ocimodelsource/registry_client.go +++ b/syft/source/ocimodelsource/registry_client.go @@ -37,10 +37,13 @@ const ( modelFormatGGUF = "gguf" modelFormatSafeTensors = "safetensors" - // Maximum bytes to read/return for weight-layer headers (GGUF + safetensors). - maxHeaderBytes = 8 * 1024 * 1024 // 8 MB - // Maximum bytes to fetch for a companion metadata layer (README, config.json, license). - // These blobs are small by convention; cap well below a safetensors header. + // maxWeightHeaderBytes is the leading slice we range-GET from a (multi-GB) + // weight layer — enough to cover the GGUF/safetensors header. + maxWeightHeaderBytes = 8 * 1024 * 1024 // 8 MB + + // maxCompanionBytes caps a whole companion blob (README, config.json, + // license); these are small by convention. Matches the 4 MB read cap in + // classifyOCIModelFileLayer. maxCompanionBytes = 4 * 1024 * 1024 // 4 MB ) @@ -129,17 +132,13 @@ type modelArtifact struct { Format string // GGUFLayers are descriptors for layers carrying GGUF-format weights. - // We fetch the first few MB of each to read the header. + // We fetch the first few MB of each to read the header data GGUFLayers []v1.Descriptor // SafeTensorsLayers are descriptors for layers carrying SafeTensors-format weights. - // We fetch the first maxHeaderBytes of each so the cataloger can read the JSON - // header (tensor map + __metadata__) without pulling the multi-GB tensor data. SafeTensorsLayers []v1.Descriptor - // CompanionLayers are non-weight layers (README, config.json, license) that - // we do fetch (in full, given their small size) so companion-file parsing - // in the safetensors cataloger can find them via media type. + // CompanionLayers are non-weight layers (README, config.json, license) CompanionLayers []v1.Descriptor } @@ -199,10 +198,7 @@ func (c *registryClient) fetchModelArtifact(ctx context.Context, refStr string) } // detectModelFormat returns a single format string when either GGUF or -// SafeTensors weight layers are present. When both appear (not expected in -// practice for Docker Model Runner artifacts), GGUF wins because the GGUF -// cataloger is the more established path. Empty result means the manifest has -// no recognized weight layers. +// SafeTensors weight layers are present. func detectModelFormat(ggufCount, safetensorsCount int) string { switch { case ggufCount > 0: @@ -271,14 +267,6 @@ func (c *registryClient) fetchBlobRange(ctx context.Context, ref name.Reference, if err != nil { return nil, fmt.Errorf("failed to get layer reader: %w", err) } - // this defer is what causes the download to stop - // 1. io.ReadFull(reader, data) reads exactly 8MB into the buffer - // 2. The function returns with data[:n] - // 3. defer reader.Close() executes, closing the HTTP response body - // 4. Closing the response body closes the underlying TCP connection - // 5. The server receives TCP FIN/RST and stops sending - // note: some data is already in flight when we close so we will see > 8mb over the wire - // the full image will not download given we terminate the reader early here defer reader.Close() // Note: this is not some arbitrary number picked out of the blue. @@ -286,6 +274,7 @@ func (c *registryClient) fetchBlobRange(ctx context.Context, ref name.Reference, // https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#file-structure data := make([]byte, maxBytes) n, err := io.ReadFull(reader, data) + // ErrUnexpectedEOF means the layer is smaller than maxBytes; EOF means it is // empty. Both mean we read everything there was, not a failure. if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {