pr: refactor

Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
Christopher Phillips 2026-06-05 03:16:02 -04:00
parent fe392a490b
commit 0dc4abc0e1
No known key found for this signature in database
10 changed files with 150 additions and 54 deletions

View File

@ -20,8 +20,10 @@ type LayerInfo struct {
} }
// ContainerImageModel is a file.Resolver implementation that provides access to // ContainerImageModel is a file.Resolver implementation that provides access to
// GGUF header data fetched from OCI model artifacts via range-GET requests. // model header and metadata data (GGUF and SafeTensors headers, the model config
// This does not fetch the entire model from the registry, only a sliver of it. // blob, and companion layers) fetched from OCI model artifacts via range-GET
// requests. This does not fetch the entire model from the registry, only a
// sliver of it.
type ContainerImageModel struct { type ContainerImageModel struct {
tempDir string // temp directory containing all layer files tempDir string // temp directory containing all layer files
layerFiles map[string]LayerInfo // digest -> layer info (temp path + media type) layerFiles map[string]LayerInfo // digest -> layer info (temp path + media type)

View File

@ -36,7 +36,21 @@ func resolveSafeTensorsOCIIdentity(ctx context.Context, resolver file.Resolver,
var configName, readmeName, readmeLicense string var configName, readmeName, readmeLicense string
var supporting []file.Location var supporting []file.Location
for _, loc := range modelFileLocs { for _, loc := range modelFileLocs {
if classifyOCIModelFileLayer(resolver, loc, md, &configName, &readmeName, &readmeLicense) { cfg, fm := classifyOCIModelFileLayer(resolver, loc)
switch {
case cfg != nil:
applyHFConfig(md, cfg)
if configName == "" {
configName = cfg.NameOrPath
}
supporting = append(supporting, loc)
case fm != nil:
if readmeLicense == "" {
readmeLicense = fm.License
}
if readmeName == "" && len(fm.BaseModel) > 0 {
readmeName = fm.BaseModel[0]
}
supporting = append(supporting, loc) supporting = append(supporting, loc)
} }
} }
@ -53,21 +67,19 @@ func resolveSafeTensorsOCIIdentity(ctx context.Context, resolver file.Resolver,
supporting: supporting, supporting: supporting,
} }
// License precedence: a README model-card license wins over dedicated // License precedence: a dedicated vnd.docker.ai.license layer is a
// license layers (mirrors the dir-scan path, where README frontmatter is the // producer-curated signal and outranks the free-text license field in a model
// license source). // card's README frontmatter.
licLocs, err := ociResolver.FilesByMediaType(dockerAILicenseMediaType)
if err != nil {
log.Debugf("failed to list docker AI license layers: %v", err)
}
switch { switch {
case len(licLocs) > 0:
id.licenses = identifyLicenseLayers(ctx, resolver, licLocs)
id.supporting = append(id.supporting, licLocs...)
case readmeLicense != "": case readmeLicense != "":
id.licenses = pkg.NewLicensesFromValuesWithContext(ctx, readmeLicense) id.licenses = pkg.NewLicensesFromValuesWithContext(ctx, readmeLicense)
default:
licLocs, err := ociResolver.FilesByMediaType(dockerAILicenseMediaType)
if err != nil {
log.Debugf("failed to list docker AI license layers: %v", err)
}
if len(licLocs) > 0 {
id.licenses = identifyLicenseLayers(ctx, resolver, licLocs)
id.supporting = append(id.supporting, licLocs...)
}
} }
return id return id
@ -100,45 +112,34 @@ func ociImageRefBasename(resolver file.Resolver) string {
return path.Base(parsed.Context().RepositoryStr()) return path.Base(parsed.Context().RepositoryStr())
} }
// classifyOCIModelFileLayer reads up to 4 MiB of a model.file layer and // classifyOCIModelFileLayer reads up to 4 MiB of a model.file layer and decodes
// classifies it as README frontmatter or HF config.json based on its leading bytes. // it as either an HF config.json or a README model card's YAML frontmatter,
func classifyOCIModelFileLayer(resolver file.Resolver, loc file.Location, md *pkg.SafeTensorsModelInfo, configName, readmeName, license *string) bool { // based on its leading bytes. It returns whichever it recognized; both are nil
// when the layer is neither (or fails to decode). The caller owns precedence and
// metadata enrichment.
func classifyOCIModelFileLayer(resolver file.Resolver, loc file.Location) (*hfConfig, *readmeFrontmatter) {
rc, err := resolver.FileContentsByLocation(loc) rc, err := resolver.FileContentsByLocation(loc)
if err != nil { if err != nil {
return false return nil, nil
} }
defer internal.CloseAndLogError(rc, loc.RealPath) defer internal.CloseAndLogError(rc, loc.RealPath)
buf, err := io.ReadAll(io.LimitReader(rc, 4*1024*1024)) buf, err := io.ReadAll(io.LimitReader(rc, 4*1024*1024))
if err != nil { if err != nil {
return false return nil, nil
} }
trimmed := bytes.TrimLeft(buf, "\xef\xbb\xbf \t\r\n") trimmed := bytes.TrimLeft(buf, "\xef\xbb\xbf \t\r\n")
switch { switch {
case bytes.HasPrefix(trimmed, []byte("---")): case bytes.HasPrefix(trimmed, []byte("---")):
fm := parseFrontmatter(buf) return nil, parseFrontmatter(buf)
if fm == nil {
return false
}
if *license == "" {
*license = fm.License
}
if *readmeName == "" && len(fm.BaseModel) > 0 {
*readmeName = fm.BaseModel[0]
}
return true
case bytes.HasPrefix(trimmed, []byte("{")): case bytes.HasPrefix(trimmed, []byte("{")):
var cfg hfConfig var cfg hfConfig
if err := json.Unmarshal(buf, &cfg); err != nil { if err := json.Unmarshal(buf, &cfg); err != nil {
return false return nil, nil
} }
applyHFConfig(md, &cfg) return &cfg, nil
if *configName == "" && cfg.NameOrPath != "" {
*configName = cfg.NameOrPath
}
return true
} }
return false return nil, nil
} }
// identifyLicenseLayers turns Docker AI license-layer locations into // identifyLicenseLayers turns Docker AI license-layer locations into

View File

@ -59,9 +59,6 @@ func mergeAggregatesInto(merged *pkg.SafeTensorsModelInfo, aggregates []pkg.Safe
if merged.TensorCount == 0 { if merged.TensorCount == 0 {
merged.TensorCount = a.TensorCount merged.TensorCount = a.TensorCount
} }
if merged.ShardCount == 0 {
merged.ShardCount = a.ShardCount
}
firstNonEmpty(&merged.Parameters, a.Parameters) firstNonEmpty(&merged.Parameters, a.Parameters)
firstNonEmpty(&merged.TotalSize, a.TotalSize) firstNonEmpty(&merged.TotalSize, a.TotalSize)
firstNonEmpty(&merged.Quantization, a.Quantization) firstNonEmpty(&merged.Quantization, a.Quantization)
@ -71,7 +68,7 @@ func mergeAggregatesInto(merged *pkg.SafeTensorsModelInfo, aggregates []pkg.Safe
// mergeShardsInto folds the per-shard header metadata into merged, returning // mergeShardsInto folds the per-shard header metadata into merged, returning
// the summed shard TensorCount and the list of non-empty per-shard hashes for // the summed shard TensorCount and the list of non-empty per-shard hashes for
// the rollup. Shards carry only the content-derived fields (Quantization, // the rollup. Shards carry only the content-derived fields (Quantization,
// Parameters, UserMetadata); // Parameters, UserMetadata), so those are the only fields folded in here.
func mergeShardsInto(merged *pkg.SafeTensorsModelInfo, shards []pkg.SafeTensorsModelInfo) (shardTensorTotal uint64, hashes []string) { func mergeShardsInto(merged *pkg.SafeTensorsModelInfo, shards []pkg.SafeTensorsModelInfo) (shardTensorTotal uint64, hashes []string) {
seenKV := map[string]bool{} seenKV := map[string]bool{}
for _, s := range shards { for _, s := range shards {

View File

@ -148,10 +148,8 @@ func (h *safeTensorsHeader) metadataHash() string {
} }
// userMetadataKeyValues converts the safetensors __metadata__ map into a // userMetadataKeyValues converts the safetensors __metadata__ map into a
// KeyValues slice sorted by key. We do not use the convention of returning a // KeyValues slice sorted by key, so SBOM output is stable across runs. Returns
// nil slice for an empty input — instead, an empty input maps to an empty // nil for empty input (omitempty then drops the field).
// (length-0, non-nil) KeyValues — so downstream JSON serialization remains
// stable: `omitempty` drops the field either way.
func userMetadataKeyValues(m map[string]string) pkg.KeyValues { func userMetadataKeyValues(m map[string]string) pkg.KeyValues {
if len(m) == 0 { if len(m) == 0 {
return nil return nil

View File

@ -26,11 +26,12 @@ func parseSafeTensorsFile(_ context.Context, _ file.Resolver, _ *generic.Environ
return nil, nil, fmt.Errorf("failed to read safetensors header: %w", err) return nil, nil, fmt.Errorf("failed to read safetensors header: %w", err)
} }
// ShardCount is intentionally not set here: the merge processor is the single
// owner of ShardCount and derives it from the number of shards in the group.
md := pkg.SafeTensorsModelInfo{ md := pkg.SafeTensorsModelInfo{
Format: "safetensors", Format: "safetensors",
TensorCount: uint64(len(header.tensors)), TensorCount: uint64(len(header.tensors)),
Quantization: normalizeDType(header.dominantDType()), Quantization: normalizeDType(header.dominantDType()),
ShardCount: 1,
UserMetadata: userMetadataKeyValues(header.metadata), UserMetadata: userMetadataKeyValues(header.metadata),
MetadataHash: header.metadataHash(), MetadataHash: header.metadataHash(),
} }

View File

@ -122,6 +122,27 @@ func TestSafeTensorsCataloger(t *testing.T) {
}, },
}, },
}, },
{
// rung 1 via README: with no config.json, the README model card's
// base_model names the model (path.Base applied), still beating the
// directory fallback ("readme-named").
name: "README base_model names the model when there is no config.json",
setup: func(t *testing.T) string {
dir := t.TempDir()
modelDir := filepath.Join(dir, "readme-named")
model(t, modelDir)
writeFile(t, filepath.Join(modelDir, "README.md"),
"---\nbase_model:\n - org/base-model-name\n---\n# Card\n")
return dir
},
expectedPackages: []pkg.Package{
{
Name: "base-model-name",
Type: pkg.ModelPkg,
Metadata: wantMetadata(""),
},
},
},
{ {
// rung 2: no config.json at all, so the model is named after its // rung 2: no config.json at all, so the model is named after its
// immediate parent directory. // immediate parent directory.
@ -265,6 +286,46 @@ func TestSafeTensorsCataloger(t *testing.T) {
} }
} }
// TestSafeTensorsCataloger_shardedDirectory covers the primary multi-shard shape:
// several `model-0000N-of-0000M.safetensors` files in one directory. The
// cataloger must group the shards into a single package, sum their tensor counts,
// record the shard count, and roll each shard up into Parts. (The OCI multi-shard
// path is covered separately in TestSafeTensorsMergeProcessor.)
func TestSafeTensorsCataloger_shardedDirectory(t *testing.T) {
userMeta := map[string]string{"format": "pt"}
// Two shards with distinct tensors → distinct per-shard metadata hashes, so
// the merge treats them as separate shards (3 tensors total across 2 shards).
shardA := map[string]safeTensorsEntry{
"layers.0.weight": {DType: "BF16", Shape: []int64{10, 10}, DataOffsets: []int64{0, 200}},
}
shardB := map[string]safeTensorsEntry{
"layers.1.weight": {DType: "BF16", Shape: []int64{10, 10}, DataOffsets: []int64{0, 200}},
"layers.2.weight": {DType: "BF16", Shape: []int64{10, 10}, DataOffsets: []int64{200, 400}},
}
dir := t.TempDir()
modelDir := filepath.Join(dir, "llama-sharded")
require.NoError(t, os.MkdirAll(modelDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(modelDir, "model-00001-of-00002.safetensors"),
buildSafeTensorsFile(t, userMeta, shardA), 0o644))
require.NoError(t, os.WriteFile(filepath.Join(modelDir, "model-00002-of-00002.safetensors"),
buildSafeTensorsFile(t, userMeta, shardB), 0o644))
pkgtest.NewCatalogTester().
FromDirectory(t, dir).
ExpectsAssertion(func(t *testing.T, pkgs []pkg.Package, _ []artifact.Relationship) {
require.Len(t, pkgs, 1)
got := pkgs[0]
assert.Equal(t, "llama-sharded", got.Name, "a sharded model with no config.json is named by its directory")
md := got.Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, 2, md.ShardCount)
assert.Equal(t, uint64(3), md.TensorCount, "tensor counts are summed across shards")
assert.Len(t, md.Parts, 2, "each shard is rolled up into Parts")
assert.Equal(t, "BF16", md.Quantization)
}).
TestCataloger(t, NewSafeTensorsCataloger())
}
// TestParseSafeTensorsOCIConfig covers the parser in isolation: it should emit // TestParseSafeTensorsOCIConfig covers the parser in isolation: it should emit
// a nameless package mirroring the config blob's producer-declared fields, and // a nameless package mirroring the config blob's producer-declared fields, and
// emit nothing for non-safetensors formats so the GGUF cataloger can claim the // emit nothing for non-safetensors formats so the GGUF cataloger can claim the
@ -510,6 +571,34 @@ spdx-id: Apache-2.0
assertHasLicense(t, out[0], "Apache-2.0") assertHasLicense(t, out[0], "Apache-2.0")
}) })
t.Run("OCI: license layer wins over a README model-card license", func(t *testing.T) {
// When both a dedicated license layer and a README model-card license are
// present, the producer-curated license layer is authoritative. (If the
// README won, the resolved license would be MIT and this assertion fails.)
dir := t.TempDir()
licensePath := filepath.Join(dir, "LICENSE")
require.NoError(t, os.WriteFile(licensePath, []byte("---\nspdx-id: Apache-2.0\n---\n"), 0o644))
readmePath := filepath.Join(dir, "README.md")
require.NoError(t, os.WriteFile(readmePath,
[]byte("---\nlicense: MIT\nbase_model:\n - org/base\n---\n# Card\n"), 0o644))
configPath := filepath.Join(dir, "config.json")
require.NoError(t, os.WriteFile(configPath, []byte(`{"_name_or_path":"org/precedence-model"}`), 0o644))
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{
dockerAIModelFileMediaType: {file.NewLocation(configPath), file.NewLocation(readmePath)},
dockerAILicenseMediaType: {file.NewLocation(licensePath)},
})
configMd := pkg.SafeTensorsModelInfo{Format: "safetensors", TensorCount: 1}
out, _, err := safeTensorsMergeProcessor(
context.Background(), resolver,
[]pkg.Package{ociPkg(configMd)}, nil, nil,
)
require.NoError(t, err)
require.Len(t, out, 1)
assert.Equal(t, "precedence-model", out[0].Name)
assertHasLicense(t, out[0], "Apache-2.0")
})
t.Run("passes through upstream error", func(t *testing.T) { t.Run("passes through upstream error", func(t *testing.T) {
sentinel := assert.AnError sentinel := assert.AnError
p := dirPkg("/models/x/y.safetensors", pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "h"}) p := dirPkg("/models/x/y.safetensors", pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "h"})

View File

@ -52,9 +52,15 @@ func partitionSafeTensorsPackages(pkgs []pkg.Package) (safeTensors, other []pkg.
} }
// fromOCIArtifact reports whether the packages came from an OCI model artifact. // fromOCIArtifact reports whether the packages came from an OCI model artifact.
// That source presents every layer at the virtual path "/", whereas a filesystem // That source (the ContainerImageModel resolver) presents every layer at the
// scan always carries a real file path. A single scan is one source, so the // virtual path "/", whereas a filesystem scan always carries a real file path. A
// first package is representative of the rest. // single scan is one source, so the first package is representative of the rest.
//
// This deliberately keys off the path signal rather than type-asserting the
// resolver to file.OCIMediaTypeResolver: the test harness wraps resolvers in an
// ObservingResolver that implements that interface unconditionally, so an
// interface check would misclassify directory scans as OCI. The "/" path is the
// genuine, testable signal the OCI model source produces.
func fromOCIArtifact(pkgs []pkg.Package) bool { func fromOCIArtifact(pkgs []pkg.Package) bool {
loc := primaryEvidenceLocation(pkgs[0]) loc := primaryEvidenceLocation(pkgs[0])
return loc != nil && loc.RealPath == "/" return loc != nil && loc.RealPath == "/"

View File

@ -34,7 +34,7 @@ type SafeTensorsModelInfo struct {
TensorCount uint64 `json:"tensorCount,omitempty" cyclonedx:"tensorCount"` TensorCount uint64 `json:"tensorCount,omitempty" cyclonedx:"tensorCount"`
// TotalSize is the total byte size of tensor data across all shards when known // TotalSize is the total byte size of tensor data across all shards when known
// (from the Docker AI model config "size" field or the sharded index "total_size"). // (from the Docker AI model config "size" field).
TotalSize string `json:"totalSize,omitempty" cyclonedx:"totalSize"` TotalSize string `json:"totalSize,omitempty" cyclonedx:"totalSize"`
// ShardCount is the number of .safetensors shards for a sharded model (1 for a // ShardCount is the number of .safetensors shards for a sharded model (1 for a

View File

@ -336,7 +336,8 @@ func (s *ociModelSource) Describe() source.Description {
} }
} }
// FileResolver returns a file resolver for accessing header of GGUF files. // FileResolver returns a file resolver for accessing model headers and companion
// metadata (GGUF/SafeTensors headers, the model config blob, and companion layers).
func (s *ociModelSource) FileResolver(_ source.Scope) (file.Resolver, error) { func (s *ociModelSource) FileResolver(_ source.Scope) (file.Resolver, error) {
return s.resolver, nil return s.resolver, nil
} }

View File

@ -286,8 +286,9 @@ func (c *registryClient) fetchBlobRange(ctx context.Context, ref name.Reference,
// https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#file-structure // https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#file-structure
data := make([]byte, maxBytes) data := make([]byte, maxBytes)
n, err := io.ReadFull(reader, data) n, err := io.ReadFull(reader, data)
if err != nil && err != io.ErrUnexpectedEOF { // ErrUnexpectedEOF means the layer is smaller than maxBytes; EOF means it is
// ErrUnexpectedEOF is okay - it means the file is smaller than maxBytes // empty. Both mean we read everything there was, not a failure.
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("failed to read layer data: %w", err) return nil, fmt.Errorf("failed to read layer data: %w", err)
} }