mirror of
https://github.com/anchore/syft.git
synced 2026-07-05 02:28:25 +02:00
pr: review
Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
parent
b216dad4a7
commit
16d0449cc8
166
syft/pkg/cataloger/ai/merge_test.go
Normal file
166
syft/pkg/cataloger/ai/merge_test.go
Normal file
@ -0,0 +1,166 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/anchore/syft/syft/file"
|
||||
"github.com/anchore/syft/syft/pkg"
|
||||
)
|
||||
|
||||
// stPkg builds a model package carrying the given metadata, with each path
|
||||
// recorded as a primary-evidence location.
|
||||
func stPkg(md pkg.SafeTensorsModelInfo, paths ...string) pkg.Package {
|
||||
locs := make([]file.Location, 0, len(paths))
|
||||
for _, p := range paths {
|
||||
locs = append(locs, file.NewLocation(p).WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation))
|
||||
}
|
||||
return pkg.Package{Type: pkg.ModelPkg, Metadata: md, Locations: file.NewLocationSet(locs...)}
|
||||
}
|
||||
|
||||
// shardMeta is a content-derived shard entry: it carries a MetadataHash, which is
|
||||
// what marks a group member as a shard (vs. a hash-less aggregate config blob).
|
||||
func shardMeta(hash string, tensorCount uint64) pkg.SafeTensorsModelInfo {
|
||||
return pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
TensorCount: tensorCount,
|
||||
Quantization: "BF16",
|
||||
Parameters: "1.00K",
|
||||
MetadataHash: hash,
|
||||
}
|
||||
}
|
||||
|
||||
// TestMergeSafeTensorsGroup exercises the rollup contract directly (the cataloger
|
||||
// tests cover it only as a side effect of the merge processor). It locks how a
|
||||
// group's per-member metadata folds into one package: tensor-count summing,
|
||||
// aggregate-over-shard field precedence, UserMetadata dedup + sorting, Parts
|
||||
// rollup, ShardCount derivation, and the content MetadataHash rollup.
|
||||
func TestMergeSafeTensorsGroup(t *testing.T) {
|
||||
t.Run("single shard: hash passes through, ShardCount 1, no Parts", func(t *testing.T) {
|
||||
out := mergeSafeTensorsGroup([]pkg.Package{stPkg(shardMeta("aaaa", 5), "/m/a.safetensors")})
|
||||
|
||||
md := out.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, pkg.ModelPkg, out.Type)
|
||||
assert.Equal(t, 1, md.ShardCount)
|
||||
assert.Equal(t, uint64(5), md.TensorCount)
|
||||
assert.Equal(t, "aaaa", md.MetadataHash, "a single shard's hash passes through unchanged")
|
||||
assert.Nil(t, md.Parts, "single-shard models do not populate Parts")
|
||||
})
|
||||
|
||||
t.Run("multi-shard: tensors summed, Parts sorted by hash, rollup is order-independent", func(t *testing.T) {
|
||||
in := []pkg.Package{
|
||||
stPkg(shardMeta("cccc", 3), "/m/c.safetensors"),
|
||||
stPkg(shardMeta("aaaa", 3), "/m/a.safetensors"),
|
||||
stPkg(shardMeta("bbbb", 3), "/m/b.safetensors"),
|
||||
}
|
||||
out := mergeSafeTensorsGroup(in)
|
||||
|
||||
md := out.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, 3, md.ShardCount)
|
||||
assert.Equal(t, uint64(9), md.TensorCount, "tensor counts are summed across shards")
|
||||
require.Len(t, md.Parts, 3)
|
||||
assert.Equal(t,
|
||||
[]string{"aaaa", "bbbb", "cccc"},
|
||||
[]string{md.Parts[0].MetadataHash, md.Parts[1].MetadataHash, md.Parts[2].MetadataHash},
|
||||
"Parts are sorted by metadata hash",
|
||||
)
|
||||
assert.Equal(t, rollupHash([]string{"aaaa", "bbbb", "cccc"}), md.MetadataHash)
|
||||
|
||||
// the rollup hash must not depend on the order members arrive in
|
||||
shuffled := []pkg.Package{
|
||||
stPkg(shardMeta("bbbb", 3), "/m/b.safetensors"),
|
||||
stPkg(shardMeta("cccc", 3), "/m/c.safetensors"),
|
||||
stPkg(shardMeta("aaaa", 3), "/m/a.safetensors"),
|
||||
}
|
||||
out2 := mergeSafeTensorsGroup(shuffled)
|
||||
assert.Equal(t, md.MetadataHash, out2.Metadata.(pkg.SafeTensorsModelInfo).MetadataHash)
|
||||
})
|
||||
|
||||
t.Run("aggregate fields win over shard-derived fields", func(t *testing.T) {
|
||||
// an aggregate (OCI config blob) carries no MetadataHash but declares the
|
||||
// authoritative totals.
|
||||
aggregate := pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
TensorCount: 999,
|
||||
TotalSize: "5.00GB",
|
||||
Parameters: "2.68B",
|
||||
Quantization: "Q4_K_M",
|
||||
}
|
||||
in := []pkg.Package{
|
||||
stPkg(aggregate, "/"),
|
||||
stPkg(shardMeta("aaaa", 3), "/"),
|
||||
stPkg(shardMeta("bbbb", 3), "/"),
|
||||
}
|
||||
out := mergeSafeTensorsGroup(in)
|
||||
|
||||
md := out.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, uint64(999), md.TensorCount, "aggregate TensorCount is authoritative; shard counts are not summed in")
|
||||
assert.Equal(t, "5.00GB", md.TotalSize)
|
||||
assert.Equal(t, "2.68B", md.Parameters)
|
||||
assert.Equal(t, "Q4_K_M", md.Quantization, "aggregate quantization wins over the shard dtype")
|
||||
assert.Equal(t, 2, md.ShardCount, "ShardCount comes from the number of shards, not the aggregate")
|
||||
assert.Equal(t, rollupHash([]string{"aaaa", "bbbb"}), md.MetadataHash, "the content hash still rolls up the shard hashes")
|
||||
})
|
||||
|
||||
t.Run("aggregate-only group: ShardCount 1, empty hash, no Parts", func(t *testing.T) {
|
||||
aggregate := pkg.SafeTensorsModelInfo{Format: "safetensors", TensorCount: 42, TotalSize: "1.00GB"}
|
||||
out := mergeSafeTensorsGroup([]pkg.Package{stPkg(aggregate, "/")})
|
||||
|
||||
md := out.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, uint64(42), md.TensorCount)
|
||||
assert.Equal(t, 1, md.ShardCount, "a group with no shards still reports a single shard")
|
||||
assert.Equal(t, "", md.MetadataHash, "there are no shard hashes to roll up")
|
||||
assert.Nil(t, md.Parts)
|
||||
})
|
||||
|
||||
t.Run("UserMetadata: keys merged and sorted, first value wins on conflict", func(t *testing.T) {
|
||||
// keys are intentionally unsorted within each shard so the assertion proves
|
||||
// the merge re-sorts globally; "format" appears in both shards so dedup
|
||||
// precedence (first wins) is exercised too.
|
||||
s1 := shardMeta("aaaa", 1)
|
||||
s1.UserMetadata = pkg.KeyValues{{Key: "format", Value: "pt"}, {Key: "author", Value: "alice"}}
|
||||
s2 := shardMeta("bbbb", 1)
|
||||
s2.UserMetadata = pkg.KeyValues{{Key: "format", Value: "gguf"}, {Key: "license", Value: "mit"}}
|
||||
|
||||
out := mergeSafeTensorsGroup([]pkg.Package{stPkg(s1, "/m/a.safetensors"), stPkg(s2, "/m/b.safetensors")})
|
||||
|
||||
md := out.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, pkg.KeyValues{
|
||||
{Key: "author", Value: "alice"},
|
||||
{Key: "format", Value: "pt"}, // first shard's value wins over s2's "gguf"
|
||||
{Key: "license", Value: "mit"},
|
||||
}, md.UserMetadata)
|
||||
})
|
||||
|
||||
t.Run("members without safetensors metadata are ignored in the rollup", func(t *testing.T) {
|
||||
notST := pkg.Package{
|
||||
Type: pkg.ModelPkg,
|
||||
Metadata: pkg.GGUFFileHeader{},
|
||||
Locations: file.NewLocationSet(file.NewLocation("/m/x.gguf")),
|
||||
}
|
||||
out := mergeSafeTensorsGroup([]pkg.Package{stPkg(shardMeta("aaaa", 2), "/m/a.safetensors"), notST})
|
||||
|
||||
md := out.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, uint64(2), md.TensorCount, "only the safetensors shard contributes")
|
||||
assert.Equal(t, 1, md.ShardCount)
|
||||
assert.Equal(t, "aaaa", md.MetadataHash)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRollupHash locks the cross-source content-fingerprint rollup: empty input
|
||||
// yields no hash, a lone shard's hash passes through unchanged (so a single-shard
|
||||
// model fingerprints identically across directory and OCI sources), and multiple
|
||||
// shards fold into one order-independent digest.
|
||||
func TestRollupHash(t *testing.T) {
|
||||
assert.Equal(t, "", rollupHash(nil), "no hashes → empty")
|
||||
assert.Equal(t, "solo", rollupHash([]string{"solo"}), "a single hash passes through unchanged")
|
||||
|
||||
ab := rollupHash([]string{"a", "b"})
|
||||
ba := rollupHash([]string{"b", "a"})
|
||||
assert.Equal(t, ab, ba, "the rollup is independent of input order")
|
||||
assert.Len(t, ab, 16, "a multi-hash rollup is a 16-char xxhash")
|
||||
assert.NotEqual(t, "a", ab)
|
||||
assert.NotEqual(t, "b", ab)
|
||||
}
|
||||
@ -3,9 +3,6 @@ package ai
|
||||
import "path"
|
||||
|
||||
// pickSafeTensorsName implements the documented naming precedence chain:
|
||||
// - config.json _name_or_path (path.Base, so "org/Model" → "Model";
|
||||
// applies to both dir-scan and OCI groups)
|
||||
// - fallback name — the group's source-specific positional identifier
|
||||
func pickSafeTensorsName(nameOrPath, fallbackName string) string {
|
||||
if nameOrPath != "" {
|
||||
return path.Base(nameOrPath)
|
||||
@ -15,8 +12,8 @@ func pickSafeTensorsName(nameOrPath, fallbackName string) string {
|
||||
|
||||
// safeTensorsDirName returns the directory-scan naming fallback: the base name
|
||||
// of the group's parent directory (the group key is already that directory).
|
||||
func safeTensorsDirName(groupKey string) string {
|
||||
base := path.Base(groupKey)
|
||||
func safeTensorsDirName(directory string) string {
|
||||
base := path.Base(directory)
|
||||
switch base {
|
||||
case "/", ".", "":
|
||||
return ""
|
||||
|
||||
@ -22,10 +22,7 @@ func newGGUFPackage(metadata *pkg.GGUFFileHeader, modelName, version, license st
|
||||
}
|
||||
|
||||
// newSafeTensorsPackage creates a SafeTensors package with the given metadata
|
||||
// and locations. Name and Licenses are intentionally not set here — the
|
||||
// safetensors cataloger emits nameless packages from every parser, and the
|
||||
// merge processor is the single owner of naming, license resolution, and
|
||||
// supporting-evidence attachment.
|
||||
// and locations. Name and Licenses are intentionally not set here and done at the processor level
|
||||
func newSafeTensorsPackage(metadata *pkg.SafeTensorsModelInfo, locations ...file.Location) pkg.Package {
|
||||
p := pkg.Package{
|
||||
Locations: file.NewLocationSet(locations...),
|
||||
|
||||
@ -47,12 +47,7 @@ type dockerAIModelConfig struct {
|
||||
} `json:"config"`
|
||||
}
|
||||
|
||||
// parseSafeTensorsOCIConfig decodes the Docker AI model-config blob and emits
|
||||
// a nameless package whose metadata mirrors the producer-declared aggregate
|
||||
// fields (Format, Quantization, Parameters, Size, TensorCount). For any
|
||||
// format other than "safetensors" it emits nothing so the GGUF cataloger can
|
||||
// claim the artifact. Naming, license, and HF-companion enrichment all run
|
||||
// once per group in safeTensorsMergeProcessor.
|
||||
// parseSafeTensorsOCIConfig decodes the Docker AI model-config blob
|
||||
func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
defer internal.CloseAndLogError(reader, reader.Path())
|
||||
|
||||
@ -88,10 +83,7 @@ func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.En
|
||||
}
|
||||
|
||||
// parseSafeTensorsOCILayer decodes the JSON header of a SafeTensors weight
|
||||
// layer fetched from an OCI model artifact (the source layer caps each layer
|
||||
// at a small prefix; tensor data is never downloaded). It emits a nameless
|
||||
// package; safeTensorsMergeProcessor folds it into the artifact's group and
|
||||
// rolls per-shard fields up into the final merged package.
|
||||
// layer fetched from an OCI model artifact
|
||||
func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
defer internal.CloseAndLogError(reader, reader.Path())
|
||||
|
||||
|
||||
@ -15,12 +15,6 @@ import (
|
||||
// assembly. SafeTensors packages reach it nameless from the parsers; it groups
|
||||
// them per model, merges the per-shard metadata, resolves a name + licenses, and
|
||||
// drops any model it cannot name.
|
||||
//
|
||||
// There are exactly two sources, each handled by its own path:
|
||||
// - an OCI model artifact, where the source presents every layer at the
|
||||
// virtual path "/" and the whole scan is a single model (mergeOCIModel)
|
||||
// - a filesystem scan, where models are grouped by the directory their files
|
||||
// live in (mergeDirModels)
|
||||
func safeTensorsMergeProcessor(ctx context.Context, resolver file.Resolver, pkgs []pkg.Package, rels []artifact.Relationship, err error) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
if err != nil || len(pkgs) == 0 {
|
||||
return pkgs, rels, err
|
||||
@ -55,12 +49,6 @@ func partitionSafeTensorsPackages(pkgs []pkg.Package) (safeTensors, other []pkg.
|
||||
// That source (the ContainerImageModel resolver) presents every layer at the
|
||||
// virtual path "/", whereas a filesystem scan always carries a real file path. A
|
||||
// single scan is one source, so the first package is representative of the rest.
|
||||
//
|
||||
// This deliberately keys off the path signal rather than type-asserting the
|
||||
// resolver to file.OCIMediaTypeResolver: the test harness wraps resolvers in an
|
||||
// ObservingResolver that implements that interface unconditionally, so an
|
||||
// interface check would misclassify directory scans as OCI. The "/" path is the
|
||||
// genuine, testable signal the OCI model source produces.
|
||||
func fromOCIArtifact(pkgs []pkg.Package) bool {
|
||||
loc := primaryEvidenceLocation(pkgs[0])
|
||||
return loc != nil && loc.RealPath == "/"
|
||||
@ -83,8 +71,7 @@ func mergeOCIModel(ctx context.Context, resolver file.Resolver, pkgs []pkg.Packa
|
||||
}
|
||||
|
||||
// mergeDirModels groups filesystem-scanned files by their parent directory and
|
||||
// emits one model per directory, named from a sibling config.json/README or the
|
||||
// directory itself.
|
||||
// emits one model per directory
|
||||
func mergeDirModels(ctx context.Context, resolver file.Resolver, pkgs []pkg.Package) []pkg.Package {
|
||||
groups := groupByParentDir(pkgs)
|
||||
|
||||
|
||||
@ -2,24 +2,16 @@ package pkg
|
||||
|
||||
// SafeTensorsModelInfo holds the model details extracted from SafeTensors content.
|
||||
// SafeTensors is a simple, safe serialization format for storing tensors, used
|
||||
// as the default weight format for Hugging Face transformer models. Syft may
|
||||
// populate this struct from these sources:
|
||||
// - a single .safetensors file (header-only parse)
|
||||
// - the per-shard headers of a multi-shard model, merged into one package
|
||||
// - a Docker AI OCI model artifact: the config blob
|
||||
// (vnd.docker.ai.model.config.v0.1+json) plus each weight layer's header
|
||||
//
|
||||
// Model name, license, and version live on the enclosing syft Package rather
|
||||
// than in this struct.
|
||||
// as the default weight format for Hugging Face transformer models.
|
||||
// Model name, license, and version live on the syft Package
|
||||
type SafeTensorsModelInfo struct {
|
||||
// Format is the source format label (always "safetensors" for this metadata type).
|
||||
// Present because the Docker AI model config blob carries an explicit format field
|
||||
// that can also be "gguf", and recording it here makes the origin explicit.
|
||||
Format string `json:"format,omitempty" cyclonedx:"format"`
|
||||
|
||||
// Architecture is the model architecture (e.g., "LlamaForCausalLM",
|
||||
// "Qwen3MoeForConditionalGeneration"). It is not present in the SafeTensors
|
||||
// header itself; it is enriched from the companion Hugging Face config.json
|
||||
// header itself; it is enriched from the companion config.json
|
||||
// "architectures" array when one is found alongside the model.
|
||||
Architecture string `json:"architecture,omitempty" cyclonedx:"architecture"`
|
||||
|
||||
@ -42,15 +34,11 @@ type SafeTensorsModelInfo struct {
|
||||
ShardCount int `json:"shardCount,omitempty" cyclonedx:"shardCount"`
|
||||
|
||||
// UserMetadata is the optional "__metadata__" map from a .safetensors file header
|
||||
// (string-to-string key/values set by the producer). Stored as a sorted KeyValues
|
||||
// slice rather than a Go map so SBOM output is stable across runs.
|
||||
// (string-to-string key/values set by the producer).
|
||||
UserMetadata KeyValues `json:"userMetadata,omitempty" cyclonedx:"userMetadata"`
|
||||
|
||||
// MetadataHash is an xxhash over the on-disk SafeTensors header (sorted tensor
|
||||
// entries + __metadata__). It is derived ONLY from the safetensors file bytes —
|
||||
// never from OCI manifest, layer descriptor, or config-blob fields — so the same
|
||||
// model content scanned via a directory source and via an OCI image produces the
|
||||
// same value. Treat this as the cross-source content fingerprint.
|
||||
// entries + __metadata__). It is derived ONLY from the safetensors file bytes.
|
||||
MetadataHash string `json:"metadataHash,omitempty" cyclonedx:"metadataHash"`
|
||||
|
||||
// Parts contains metadata from additional SafeTensors shards or OCI layers that
|
||||
|
||||
@ -86,14 +86,7 @@ func validateAndFetchArtifact(ctx context.Context, client *registryClient, refer
|
||||
|
||||
// fetchAndStoreModelHeaders fetches the blobs needed to catalog a Docker AI
|
||||
// model artifact and stores them on disk so the ContainerImageModel resolver
|
||||
// can serve them by media type:
|
||||
//
|
||||
// - For GGUF: the first maxHeaderBytes of each weight layer.
|
||||
// - For SafeTensors: the model-config blob (already in memory as RawConfig),
|
||||
// each companion layer in full, and the first maxHeaderBytes of each
|
||||
// weight layer. The weight-layer prefix is enough to read the JSON header
|
||||
// (tensor map + __metadata__), which is what the cataloger hashes for
|
||||
// cross-source identity.
|
||||
// can serve them by media type
|
||||
func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, artifact *modelArtifact) (string, *fileresolver.ContainerImageModel, error) {
|
||||
tempDir, err := os.MkdirTemp("", "syft-oci-model")
|
||||
if err != nil {
|
||||
@ -119,8 +112,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti
|
||||
}
|
||||
|
||||
// For SafeTensors artifacts, expose the model-config blob to the resolver
|
||||
// so parseSafeTensorsOCIConfig can match it by media type. RawConfig was
|
||||
// already fetched as part of the manifest walk.
|
||||
// so parseSafeTensorsOCIConfig can match it by media type.
|
||||
if artifact.Format == modelFormatSafeTensors && len(artifact.RawConfig) > 0 {
|
||||
li, err := storeConfigBlobAsLayer(artifact, tempDir)
|
||||
if err != nil {
|
||||
@ -144,9 +136,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti
|
||||
}
|
||||
|
||||
// SafeTensors weight-layer headers. We only pull the leading prefix (same
|
||||
// budget as a GGUF header) because the JSON header lives at the very start
|
||||
// of the file — multi-GB tensor data that follows is intentionally not
|
||||
// downloaded.
|
||||
// budget as a GGUF header)
|
||||
if artifact.Format == modelFormatSafeTensors {
|
||||
for _, layer := range artifact.SafeTensorsLayers {
|
||||
li, err := fetchSafeTensorsLayerHeader(ctx, client, artifact.Reference, layer, tempDir)
|
||||
@ -168,7 +158,7 @@ func fetchAndStoreModelHeaders(ctx context.Context, client *registryClient, arti
|
||||
}
|
||||
|
||||
// storeConfigBlobAsLayer writes the already-fetched raw config bytes to a temp
|
||||
// file so the resolver can serve them via media type.
|
||||
// file so the resolver can serve them via media type
|
||||
func storeConfigBlobAsLayer(artifact *modelArtifact, tempDir string) (fileresolver.LayerInfo, error) {
|
||||
digest := artifact.Manifest.Config.Digest.String()
|
||||
safeDigest := strings.ReplaceAll(digest, ":", "-")
|
||||
@ -182,9 +172,7 @@ func storeConfigBlobAsLayer(artifact *modelArtifact, tempDir string) (fileresolv
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fetchCompanionLayer downloads a companion (non-weight) layer to a temp file.
|
||||
// Unlike weight layers we fetch up to maxCompanionBytes, which comfortably
|
||||
// covers READMEs, HF config.json, tokenizer.json, and LICENSE text.
|
||||
// fetchCompanionLayer downloads a companion (non-weight) layer to a temp file
|
||||
func fetchCompanionLayer(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) {
|
||||
data, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxCompanionBytes)
|
||||
if err != nil {
|
||||
@ -201,9 +189,9 @@ func fetchCompanionLayer(ctx context.Context, client *registryClient, ref name.R
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fetchSingleGGUFHeader fetches a single GGUF layer header and writes it to a temp file.
|
||||
// fetchSingleGGUFHeader fetches a single GGUF layer header and writes it to a temp file
|
||||
func fetchSingleGGUFHeader(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) {
|
||||
headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxHeaderBytes)
|
||||
headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxWeightHeaderBytes)
|
||||
if err != nil {
|
||||
return fileresolver.LayerInfo{}, fmt.Errorf("failed to fetch GGUF layer header: %w", err)
|
||||
}
|
||||
@ -222,9 +210,9 @@ func fetchSingleGGUFHeader(ctx context.Context, client *registryClient, ref name
|
||||
}
|
||||
|
||||
// fetchSafeTensorsLayerHeader fetches the leading bytes of a SafeTensors weight
|
||||
// layer (enough to cover the JSON header) and writes them to a temp file.
|
||||
// layer (enough to cover the JSON header) and writes them to a temp file
|
||||
func fetchSafeTensorsLayerHeader(ctx context.Context, client *registryClient, ref name.Reference, layer v1.Descriptor, tempDir string) (fileresolver.LayerInfo, error) {
|
||||
headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxHeaderBytes)
|
||||
headerData, err := client.fetchBlobRange(ctx, ref, layer.Digest, maxWeightHeaderBytes)
|
||||
if err != nil {
|
||||
return fileresolver.LayerInfo{}, fmt.Errorf("failed to fetch safetensors layer header: %w", err)
|
||||
}
|
||||
@ -241,7 +229,7 @@ func fetchSafeTensorsLayerHeader(ctx context.Context, client *registryClient, re
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildMetadata constructs OCIModelMetadata from a modelArtifact.
|
||||
// buildMetadata constructs OCIModelMetadata from a modelArtifact
|
||||
func buildMetadata(artifact *modelArtifact) source.OCIModelMetadata {
|
||||
// layers
|
||||
layers := make([]source.LayerMetadata, len(artifact.Manifest.Layers))
|
||||
|
||||
@ -37,10 +37,13 @@ const (
|
||||
modelFormatGGUF = "gguf"
|
||||
modelFormatSafeTensors = "safetensors"
|
||||
|
||||
// Maximum bytes to read/return for weight-layer headers (GGUF + safetensors).
|
||||
maxHeaderBytes = 8 * 1024 * 1024 // 8 MB
|
||||
// Maximum bytes to fetch for a companion metadata layer (README, config.json, license).
|
||||
// These blobs are small by convention; cap well below a safetensors header.
|
||||
// maxWeightHeaderBytes is the leading slice we range-GET from a (multi-GB)
|
||||
// weight layer — enough to cover the GGUF/safetensors header.
|
||||
maxWeightHeaderBytes = 8 * 1024 * 1024 // 8 MB
|
||||
|
||||
// maxCompanionBytes caps a whole companion blob (README, config.json,
|
||||
// license); these are small by convention. Matches the 4 MB read cap in
|
||||
// classifyOCIModelFileLayer.
|
||||
maxCompanionBytes = 4 * 1024 * 1024 // 4 MB
|
||||
)
|
||||
|
||||
@ -129,17 +132,13 @@ type modelArtifact struct {
|
||||
Format string
|
||||
|
||||
// GGUFLayers are descriptors for layers carrying GGUF-format weights.
|
||||
// We fetch the first few MB of each to read the header.
|
||||
// We fetch the first few MB of each to read the header data
|
||||
GGUFLayers []v1.Descriptor
|
||||
|
||||
// SafeTensorsLayers are descriptors for layers carrying SafeTensors-format weights.
|
||||
// We fetch the first maxHeaderBytes of each so the cataloger can read the JSON
|
||||
// header (tensor map + __metadata__) without pulling the multi-GB tensor data.
|
||||
SafeTensorsLayers []v1.Descriptor
|
||||
|
||||
// CompanionLayers are non-weight layers (README, config.json, license) that
|
||||
// we do fetch (in full, given their small size) so companion-file parsing
|
||||
// in the safetensors cataloger can find them via media type.
|
||||
// CompanionLayers are non-weight layers (README, config.json, license)
|
||||
CompanionLayers []v1.Descriptor
|
||||
}
|
||||
|
||||
@ -199,10 +198,7 @@ func (c *registryClient) fetchModelArtifact(ctx context.Context, refStr string)
|
||||
}
|
||||
|
||||
// detectModelFormat returns a single format string when either GGUF or
|
||||
// SafeTensors weight layers are present. When both appear (not expected in
|
||||
// practice for Docker Model Runner artifacts), GGUF wins because the GGUF
|
||||
// cataloger is the more established path. Empty result means the manifest has
|
||||
// no recognized weight layers.
|
||||
// SafeTensors weight layers are present.
|
||||
func detectModelFormat(ggufCount, safetensorsCount int) string {
|
||||
switch {
|
||||
case ggufCount > 0:
|
||||
@ -271,14 +267,6 @@ func (c *registryClient) fetchBlobRange(ctx context.Context, ref name.Reference,
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get layer reader: %w", err)
|
||||
}
|
||||
// this defer is what causes the download to stop
|
||||
// 1. io.ReadFull(reader, data) reads exactly 8MB into the buffer
|
||||
// 2. The function returns with data[:n]
|
||||
// 3. defer reader.Close() executes, closing the HTTP response body
|
||||
// 4. Closing the response body closes the underlying TCP connection
|
||||
// 5. The server receives TCP FIN/RST and stops sending
|
||||
// note: some data is already in flight when we close so we will see > 8mb over the wire
|
||||
// the full image will not download given we terminate the reader early here
|
||||
defer reader.Close()
|
||||
|
||||
// Note: this is not some arbitrary number picked out of the blue.
|
||||
@ -286,6 +274,7 @@ func (c *registryClient) fetchBlobRange(ctx context.Context, ref name.Reference,
|
||||
// https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#file-structure
|
||||
data := make([]byte, maxBytes)
|
||||
n, err := io.ReadFull(reader, data)
|
||||
|
||||
// ErrUnexpectedEOF means the layer is smaller than maxBytes; EOF means it is
|
||||
// empty. Both mean we read everything there was, not a failure.
|
||||
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user