mirror of
https://github.com/anchore/syft.git
synced 2026-07-05 02:28:25 +02:00
fix: move non safetensor layer fetch to post
Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
parent
69b7c5e3d0
commit
a75c3086f6
@ -40,5 +40,5 @@ func NewSafeTensorsCataloger() pkg.Cataloger {
|
||||
WithParserByGlobs(parseSafeTensorsIndex, "**/*.safetensors.index.json").
|
||||
WithParserByMediaType(parseSafeTensorsOCIConfig, dockerAIModelConfigMediaTypes...).
|
||||
WithParserByMediaType(parseSafeTensorsOCILayer, dockerAISafeTensorsMediaType).
|
||||
WithProcessors(safeTensorsMergeProcessor)
|
||||
WithResolvingProcessors(safeTensorsMergeProcessor)
|
||||
}
|
||||
|
||||
@ -21,13 +21,15 @@ func newGGUFPackage(metadata *pkg.GGUFFileHeader, modelName, version, license st
|
||||
return p
|
||||
}
|
||||
|
||||
func newSafeTensorsPackage(metadata *pkg.SafeTensorsModelInfo, modelName, version, license string, locations ...file.Location) pkg.Package {
|
||||
// newSafeTensorsPackage creates a SafeTensors package with the given metadata
|
||||
// and locations. Name and Licenses are intentionally not set here — the
|
||||
// safetensors cataloger emits nameless packages from every parser, and the
|
||||
// merge processor is the single owner of naming, license resolution, and
|
||||
// supporting-evidence attachment.
|
||||
func newSafeTensorsPackage(metadata *pkg.SafeTensorsModelInfo, locations ...file.Location) pkg.Package {
|
||||
p := pkg.Package{
|
||||
Name: modelName,
|
||||
Version: version,
|
||||
Locations: file.NewLocationSet(locations...),
|
||||
Type: pkg.ModelPkg,
|
||||
Licenses: pkg.NewLicenseSet(pkg.NewLicensesFromValues(license)...),
|
||||
Metadata: *metadata,
|
||||
// PURL is intentionally not set: package-url has not yet finalized ML model support.
|
||||
}
|
||||
|
||||
@ -1,20 +1,13 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/anchore/syft/internal"
|
||||
"github.com/anchore/syft/internal/log"
|
||||
"github.com/anchore/syft/internal/unknown"
|
||||
"github.com/anchore/syft/syft/artifact"
|
||||
"github.com/anchore/syft/syft/file"
|
||||
@ -22,10 +15,11 @@ import (
|
||||
"github.com/anchore/syft/syft/pkg/cataloger/generic"
|
||||
)
|
||||
|
||||
// parseSafeTensorsFile parses a single .safetensors file by reading only its
|
||||
// JSON header, then enriches the resulting package with metadata from sibling
|
||||
// config.json and README.md files when the resolver can find them.
|
||||
func parseSafeTensorsFile(_ context.Context, resolver file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
// parseSafeTensorsFile decodes the JSON header of a single .safetensors file
|
||||
// and emits a nameless package whose metadata is derived purely from the
|
||||
// header bytes. Naming, license resolution, sibling enrichment, and cross-
|
||||
// shard rollup are all the responsibility of safeTensorsMergeProcessor.
|
||||
func parseSafeTensorsFile(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
defer internal.CloseAndLogError(reader, reader.Path())
|
||||
|
||||
header, _, err := readSafeTensorsHeader(&io.LimitedReader{R: reader, N: maxSafeTensorsHeaderSize + 8})
|
||||
@ -45,27 +39,19 @@ func parseSafeTensorsFile(_ context.Context, resolver file.Resolver, _ *generic.
|
||||
md.Parameters = formatParameterCount(p)
|
||||
}
|
||||
|
||||
name, license := enrichFromSiblings(resolver, reader.Path(), &md)
|
||||
if name == "" {
|
||||
name = modelNameFromPath(reader.Path())
|
||||
}
|
||||
|
||||
p := newSafeTensorsPackage(
|
||||
&md,
|
||||
name,
|
||||
"",
|
||||
license,
|
||||
reader.WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation),
|
||||
)
|
||||
|
||||
return []pkg.Package{p}, nil, unknown.IfEmptyf([]pkg.Package{p}, "unable to parse safetensors file")
|
||||
}
|
||||
|
||||
// parseSafeTensorsIndex parses a model.safetensors.index.json file for a sharded
|
||||
// model. The index lists every tensor and the shard file it lives in; from this
|
||||
// we derive tensor count, unique shard count, and (when present) the producer-
|
||||
// declared total_size.
|
||||
func parseSafeTensorsIndex(_ context.Context, resolver file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
// parseSafeTensorsIndex decodes a model.safetensors.index.json file for a
|
||||
// sharded model and emits a nameless package recording tensor count, unique
|
||||
// shard count, and (when present) the producer-declared total_size. Like
|
||||
// parseSafeTensorsFile, naming and sibling enrichment happen in the merge
|
||||
// processor.
|
||||
func parseSafeTensorsIndex(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
defer internal.CloseAndLogError(reader, reader.Path())
|
||||
|
||||
var doc struct {
|
||||
@ -92,179 +78,16 @@ func parseSafeTensorsIndex(_ context.Context, resolver file.Resolver, _ *generic
|
||||
md.TotalSize = formatByteSize(doc.Metadata.TotalSize.String())
|
||||
}
|
||||
|
||||
name, license := enrichFromSiblings(resolver, reader.Path(), &md)
|
||||
if name == "" {
|
||||
name = modelNameFromIndexPath(reader.Path())
|
||||
}
|
||||
|
||||
p := newSafeTensorsPackage(
|
||||
&md,
|
||||
name,
|
||||
"",
|
||||
license,
|
||||
reader.WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation),
|
||||
)
|
||||
|
||||
return []pkg.Package{p}, nil, unknown.IfEmptyf([]pkg.Package{p}, "unable to parse safetensors index")
|
||||
}
|
||||
|
||||
// enrichFromSiblings looks for a sibling config.json and README.md next to the
|
||||
// safetensors artifact and folds their values into the metadata struct. It
|
||||
// returns a name and license derived from those sources, with the caller free
|
||||
// to fall back to a filename-derived default.
|
||||
func enrichFromSiblings(resolver file.Resolver, sourcePath string, md *pkg.SafeTensorsModelInfo) (name, license string) {
|
||||
if resolver == nil {
|
||||
return "", ""
|
||||
}
|
||||
dir := path.Dir(sourcePath)
|
||||
|
||||
if cfg := readSiblingJSON(resolver, path.Join(dir, "config.json")); cfg != nil {
|
||||
if md.Architecture == "" && len(cfg.Architectures) > 0 {
|
||||
md.Architecture = cfg.Architectures[0]
|
||||
}
|
||||
if md.TorchDtype == "" {
|
||||
md.TorchDtype = cfg.TorchDtype
|
||||
}
|
||||
if md.TransformersVersion == "" {
|
||||
md.TransformersVersion = cfg.TransformersVersion
|
||||
}
|
||||
if cfg.NameOrPath != "" {
|
||||
name = path.Base(cfg.NameOrPath)
|
||||
}
|
||||
}
|
||||
|
||||
if fm := readReadmeFrontmatter(resolver, path.Join(dir, "README.md")); fm != nil {
|
||||
if license == "" {
|
||||
license = fm.License
|
||||
}
|
||||
if name == "" && len(fm.BaseModel) > 0 {
|
||||
name = path.Base(fm.BaseModel[0])
|
||||
}
|
||||
}
|
||||
|
||||
return name, license
|
||||
}
|
||||
|
||||
// hfConfig captures only the Hugging Face config.json keys this package
// reads; every other key in the document is ignored during decoding.
type hfConfig struct {
	// Architectures is decoded from the "architectures" array.
	Architectures []string `json:"architectures"`
	// TorchDtype is decoded from "torch_dtype".
	TorchDtype string `json:"torch_dtype"`
	// TransformersVersion is decoded from "transformers_version".
	TransformersVersion string `json:"transformers_version"`
	// NameOrPath is decoded from "_name_or_path".
	NameOrPath string `json:"_name_or_path"`
}
|
||||
|
||||
func readSiblingJSON(resolver file.Resolver, p string) *hfConfig {
|
||||
locations, err := resolver.FilesByPath(p)
|
||||
if err != nil || len(locations) == 0 {
|
||||
return nil
|
||||
}
|
||||
rc, err := resolver.FileContentsByLocation(locations[0])
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer internal.CloseAndLogError(rc, p)
|
||||
|
||||
var cfg hfConfig
|
||||
if err := json.NewDecoder(rc).Decode(&cfg); err != nil {
|
||||
log.Debugf("failed to decode %s: %v", p, err)
|
||||
return nil
|
||||
}
|
||||
return &cfg
|
||||
}
|
||||
|
||||
// readmeFrontmatter is the slice of a README's YAML frontmatter that this
// package extracts.
type readmeFrontmatter struct {
	// License is the value of the "license" key.
	License string `yaml:"license"`
	// BaseModel lists the entries of the "base_model" key.
	BaseModel []string `yaml:"base_model"`
}
|
||||
|
||||
// readReadmeFrontmatter extracts the leading YAML frontmatter block from a README.
|
||||
// The block is delimited by "---" lines at the start of the file.
|
||||
func readReadmeFrontmatter(resolver file.Resolver, p string) *readmeFrontmatter {
|
||||
locations, err := resolver.FilesByPath(p)
|
||||
if err != nil || len(locations) == 0 {
|
||||
return nil
|
||||
}
|
||||
rc, err := resolver.FileContentsByLocation(locations[0])
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer internal.CloseAndLogError(rc, p)
|
||||
|
||||
buf, err := io.ReadAll(io.LimitReader(rc, 1024*1024))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return parseFrontmatter(buf)
|
||||
}
|
||||
|
||||
// parseFrontmatter pulls the YAML block between the first and second "---" lines
|
||||
// of a file (if present) and decodes known fields from it.
|
||||
func parseFrontmatter(buf []byte) *readmeFrontmatter {
|
||||
trimmed := bytes.TrimLeft(buf, "\xef\xbb\xbf \t\r\n")
|
||||
if !bytes.HasPrefix(trimmed, []byte("---")) {
|
||||
return nil
|
||||
}
|
||||
rest := trimmed[3:]
|
||||
// trim the newline directly following the opening delimiter
|
||||
if i := bytes.IndexByte(rest, '\n'); i >= 0 {
|
||||
rest = rest[i+1:]
|
||||
}
|
||||
end := bytes.Index(rest, []byte("\n---"))
|
||||
if end < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// base_model may be either a scalar ("org/model") or a sequence; decode it
|
||||
// as a yaml.Node so a scalar value does not fail the whole block.
|
||||
var raw struct {
|
||||
License string `yaml:"license"`
|
||||
BaseModel yaml.Node `yaml:"base_model"`
|
||||
}
|
||||
if err := yaml.Unmarshal(rest[:end], &raw); err != nil {
|
||||
log.Debugf("failed to parse README frontmatter: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
fm := readmeFrontmatter{License: raw.License}
|
||||
switch raw.BaseModel.Kind {
|
||||
case yaml.ScalarNode:
|
||||
if raw.BaseModel.Value != "" {
|
||||
fm.BaseModel = []string{raw.BaseModel.Value}
|
||||
}
|
||||
case yaml.SequenceNode:
|
||||
_ = raw.BaseModel.Decode(&fm.BaseModel)
|
||||
}
|
||||
return &fm
|
||||
}
|
||||
|
||||
// defaultModelName is the package name of last resort, used when neither
// sibling files, the file path, nor OCI companion layers yield a model name.
const defaultModelName = "safetensors-model"
|
||||
|
||||
// modelNameFromPath derives a model name from the location of a .safetensors
// file. The parent directory name is preferred, so
// "/models/foo/model.safetensors" yields "foo". When the parent is ".", the
// path separator, or empty, the file's own base name with the ".safetensors"
// suffix removed is used instead, so "weights.safetensors" yields "weights".
func modelNameFromPath(p string) string {
	switch parent := filepath.Base(filepath.Dir(p)); parent {
	case "", ".", string(filepath.Separator):
		return strings.TrimSuffix(filepath.Base(p), ".safetensors")
	default:
		return parent
	}
}
|
||||
|
||||
// modelNameFromIndexPath derives a model name from the index filename's parent
|
||||
// directory, defaulting to "safetensors-model" if no useful directory name exists.
|
||||
func modelNameFromIndexPath(p string) string {
|
||||
dir := filepath.Base(filepath.Dir(p))
|
||||
if dir != "" && dir != "." && dir != string(filepath.Separator) {
|
||||
return dir
|
||||
}
|
||||
return defaultModelName
|
||||
}
|
||||
|
||||
// formatParameterCount prints a count like 6_700_000_000 as "6.7B" using B/M/K
|
||||
// thresholds matching the notation used by Hugging Face and Docker AI labels.
|
||||
// formatParameterCount prints a count like 6_700_000_000 as "6.70B" using
|
||||
// B/M/K thresholds matching the notation used by Hugging Face and Docker AI
|
||||
// labels.
|
||||
func formatParameterCount(n uint64) string {
|
||||
switch {
|
||||
case n >= 1_000_000_000:
|
||||
@ -278,9 +101,9 @@ func formatParameterCount(n uint64) string {
|
||||
}
|
||||
}
|
||||
|
||||
// formatByteSize turns a numeric string (bytes) into a human-friendly size like
|
||||
// "71.90GB". Non-numeric inputs are passed through unchanged so we never lose
|
||||
// producer-declared strings such as "71.90GB".
|
||||
// formatByteSize turns a numeric string (bytes) into a human-friendly size
|
||||
// like "71.90GB". Non-numeric inputs are passed through unchanged so producer-
|
||||
// declared strings (e.g. "71.90GB" from a Docker AI config blob) survive.
|
||||
func formatByteSize(s string) string {
|
||||
n, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil || n == 0 {
|
||||
|
||||
@ -8,7 +8,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/anchore/syft/internal"
|
||||
"github.com/anchore/syft/internal/log"
|
||||
"github.com/anchore/syft/internal/unknown"
|
||||
"github.com/anchore/syft/syft/artifact"
|
||||
"github.com/anchore/syft/syft/file"
|
||||
@ -48,13 +47,13 @@ type dockerAIModelConfig struct {
|
||||
} `json:"config"`
|
||||
}
|
||||
|
||||
// parseSafeTensorsOCIConfig parses a Docker AI model-config blob. When the blob
|
||||
// advertises format=="safetensors" it emits a single named package whose
|
||||
// metadata is enriched by scanning sibling OCI layers (README.md for license +
|
||||
// base_model name, config.json for architecture, LICENSE text for a license
|
||||
// fallback). For any other format it emits nothing so the GGUF cataloger can
|
||||
// claim the image.
|
||||
func parseSafeTensorsOCIConfig(_ context.Context, resolver file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
// parseSafeTensorsOCIConfig decodes the Docker AI model-config blob and emits
|
||||
// a nameless package whose metadata mirrors the producer-declared aggregate
|
||||
// fields (Format, Quantization, Parameters, Size, TensorCount). For any
|
||||
// format other than "safetensors" it emits nothing so the GGUF cataloger can
|
||||
// claim the artifact. Naming, license, and HF-companion enrichment all run
|
||||
// once per group in safeTensorsMergeProcessor.
|
||||
func parseSafeTensorsOCIConfig(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
defer internal.CloseAndLogError(reader, reader.Path())
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(reader, 1024*1024))
|
||||
@ -81,174 +80,18 @@ func parseSafeTensorsOCIConfig(_ context.Context, resolver file.Resolver, _ *gen
|
||||
md.TensorCount = uint64(n)
|
||||
}
|
||||
|
||||
name, license := enrichFromDockerAILayers(resolver, &md)
|
||||
if name == "" {
|
||||
name = defaultModelName
|
||||
}
|
||||
|
||||
p := newSafeTensorsPackage(
|
||||
&md,
|
||||
name,
|
||||
"",
|
||||
license,
|
||||
reader.WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation),
|
||||
)
|
||||
|
||||
return []pkg.Package{p}, nil, unknown.IfEmptyf([]pkg.Package{p}, "unable to parse docker AI safetensors config")
|
||||
}
|
||||
|
||||
// enrichFromDockerAILayers walks sibling Docker AI layers via the OCI resolver
|
||||
// and mines them for a model name, architecture, and license. README.md carries
|
||||
// YAML frontmatter with license + base_model; HF config.json carries
|
||||
// architectures/torch_dtype/transformers_version; the vnd.docker.ai.license
|
||||
// blob is plain license text.
|
||||
func enrichFromDockerAILayers(resolver file.Resolver, md *pkg.SafeTensorsModelInfo) (name, license string) {
|
||||
ociResolver, ok := resolver.(file.OCIMediaTypeResolver)
|
||||
if !ok {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
modelFileLocations, err := ociResolver.FilesByMediaType(dockerAIModelFileMediaType)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list docker AI model-file layers: %v", err)
|
||||
}
|
||||
|
||||
// Collect name candidates separately so precedence does not depend on the
|
||||
// order the resolver returns layers in. config.json's _name_or_path wins over
|
||||
// a README base_model, matching enrichFromSiblings.
|
||||
var configName, readmeName string
|
||||
for _, loc := range modelFileLocations {
|
||||
readAndClassifyDockerAILayer(resolver, loc, md, &configName, &readmeName, &license)
|
||||
}
|
||||
|
||||
name = configName
|
||||
if name == "" {
|
||||
name = readmeName
|
||||
}
|
||||
|
||||
if license == "" {
|
||||
license = readDockerAILicense(resolver, ociResolver)
|
||||
}
|
||||
|
||||
return name, license
|
||||
}
|
||||
|
||||
// readAndClassifyDockerAILayer fetches a single Docker AI model-file layer and
|
||||
// passes its contents to classifyAndMerge. Split out from the calling loop so
|
||||
// the resolver handle is closed via defer on every iteration.
|
||||
func readAndClassifyDockerAILayer(resolver file.Resolver, loc file.Location, md *pkg.SafeTensorsModelInfo, configName, readmeName, license *string) {
|
||||
rc, err := resolver.FileContentsByLocation(loc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogError(rc, loc.RealPath)
|
||||
|
||||
buf, err := io.ReadAll(io.LimitReader(rc, 4*1024*1024))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
classifyAndMerge(buf, md, configName, readmeName, license)
|
||||
}
|
||||
|
||||
// classifyAndMerge sniffs a vnd.docker.ai.model.file blob (which can be README.md,
|
||||
// config.json, generation_config.json, tokenizer.json, etc.) and folds useful
|
||||
// fields into the metadata struct and out-parameters.
|
||||
func classifyAndMerge(buf []byte, md *pkg.SafeTensorsModelInfo, configName, readmeName, license *string) {
|
||||
trimmed := trimLeadingWhitespace(buf)
|
||||
switch {
|
||||
case hasPrefix(trimmed, "---"):
|
||||
if fm := parseFrontmatter(buf); fm != nil {
|
||||
if *license == "" {
|
||||
*license = fm.License
|
||||
}
|
||||
if *readmeName == "" && len(fm.BaseModel) > 0 {
|
||||
*readmeName = lastPathSegment(fm.BaseModel[0])
|
||||
}
|
||||
}
|
||||
case hasPrefix(trimmed, "{"):
|
||||
var cfg hfConfig
|
||||
if err := json.Unmarshal(buf, &cfg); err != nil {
|
||||
return
|
||||
}
|
||||
if md.Architecture == "" && len(cfg.Architectures) > 0 {
|
||||
md.Architecture = cfg.Architectures[0]
|
||||
}
|
||||
if md.TorchDtype == "" {
|
||||
md.TorchDtype = cfg.TorchDtype
|
||||
}
|
||||
if md.TransformersVersion == "" {
|
||||
md.TransformersVersion = cfg.TransformersVersion
|
||||
}
|
||||
if *configName == "" && cfg.NameOrPath != "" {
|
||||
*configName = lastPathSegment(cfg.NameOrPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readDockerAILicense extracts a short license identifier from the first line
|
||||
// of a vnd.docker.ai.license layer. Docker packages the full license text, so
|
||||
// we only peek at a prefix looking for well-known titles like "Apache License".
|
||||
func readDockerAILicense(resolver file.Resolver, ociResolver file.OCIMediaTypeResolver) string {
|
||||
locations, err := ociResolver.FilesByMediaType(dockerAILicenseMediaType)
|
||||
if err != nil || len(locations) == 0 {
|
||||
return ""
|
||||
}
|
||||
rc, err := resolver.FileContentsByLocation(locations[0])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer internal.CloseAndLogError(rc, locations[0].RealPath)
|
||||
|
||||
buf, err := io.ReadAll(io.LimitReader(rc, 2048))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
text := strings.ToLower(string(buf))
|
||||
switch {
|
||||
case strings.Contains(text, "apache license") && strings.Contains(text, "version 2.0"):
|
||||
return "Apache-2.0"
|
||||
case strings.Contains(text, "mit license"):
|
||||
return "MIT"
|
||||
case strings.Contains(text, "bsd 3-clause"):
|
||||
return "BSD-3-Clause"
|
||||
case strings.Contains(text, "bsd 2-clause"):
|
||||
return "BSD-2-Clause"
|
||||
case strings.Contains(text, "gnu general public license") && strings.Contains(text, "version 3"):
|
||||
return "GPL-3.0"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// hasPrefix reports whether the byte slice b begins with the string s.
// An empty s is a prefix of every slice, including a nil one.
func hasPrefix(b []byte, s string) bool {
	if len(s) > len(b) {
		return false
	}
	return string(b[:len(s)]) == s
}
|
||||
|
||||
// trimLeadingWhitespace returns b with any leading run of ASCII whitespace
// (space, tab, CR, LF) and UTF-8 byte order marks removed. The two may
// appear in any order: a BOM is normally the very first bytes of a file, so
// whitespace that follows it has to be skipped as well as whitespace that
// precedes it. The result shares b's backing array.
func trimLeadingWhitespace(b []byte) []byte {
	i := 0
	for i < len(b) {
		switch {
		case b[i] == ' ' || b[i] == '\t' || b[i] == '\r' || b[i] == '\n':
			i++
		case len(b)-i >= 3 && b[i] == 0xEF && b[i+1] == 0xBB && b[i+2] == 0xBF:
			// UTF-8 encoded byte order mark
			i += 3
		default:
			return b[i:]
		}
	}
	return b[i:]
}
|
||||
|
||||
// lastPathSegment returns whatever follows the final "/" or "\" in s, or s
// itself when it contains neither separator. A trailing separator therefore
// yields an empty string.
func lastPathSegment(s string) string {
	cut := strings.LastIndexAny(s, "/\\")
	if cut < 0 {
		return s
	}
	return s[cut+1:]
}
|
||||
|
||||
// parseSafeTensorsOCILayer parses a SafeTensors weight layer from an OCI model
|
||||
// artifact by reading only its JSON header (the layer is fetched up to a small
|
||||
// byte cap by the source layer; tensor data is never downloaded). It emits a
|
||||
// nameless package so safeTensorsMergeProcessor folds the result into the
|
||||
// config-derived named package as a Part. The point of this parser is to give
|
||||
// OCI scans the same content-derived fields the directory-scan path produces:
|
||||
// real tensor count, normalized quantization, __metadata__, and MetadataHash.
|
||||
// parseSafeTensorsOCILayer decodes the JSON header of a SafeTensors weight
|
||||
// layer fetched from an OCI model artifact (the source layer caps each layer
|
||||
// at a small prefix; tensor data is never downloaded). It emits a nameless
|
||||
// package; safeTensorsMergeProcessor folds it into the artifact's group and
|
||||
// rolls per-shard fields up into the final merged package.
|
||||
func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Environment, reader file.LocationReadCloser) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
defer internal.CloseAndLogError(reader, reader.Path())
|
||||
|
||||
@ -268,17 +111,10 @@ func parseSafeTensorsOCILayer(_ context.Context, _ file.Resolver, _ *generic.Env
|
||||
md.Parameters = formatParameterCount(p)
|
||||
}
|
||||
|
||||
// Emit nameless; safeTensorsMergeProcessor will absorb this into the
|
||||
// config-derived named package as a Part. The merge runs even when only
|
||||
// nameless packages exist, in which case the result is dropped.
|
||||
p := newSafeTensorsPackage(
|
||||
&md,
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
reader.WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation),
|
||||
)
|
||||
|
||||
return []pkg.Package{p}, nil, nil
|
||||
}
|
||||
|
||||
|
||||
@ -124,167 +124,230 @@ func TestSafeTensorsCataloger_shardedIndex(t *testing.T) {
|
||||
TestCataloger(t, NewSafeTensorsCataloger())
|
||||
}
|
||||
|
||||
// TestParseSafeTensorsOCIConfig covers the parser in isolation: it should emit
|
||||
// a nameless package mirroring the config blob's producer-declared fields, and
|
||||
// emit nothing for non-safetensors formats so the GGUF cataloger can claim the
|
||||
// artifact. Naming and license resolution happen in the merge processor and are
|
||||
// tested separately under TestSafeTensorsMergeProcessor.
|
||||
func TestParseSafeTensorsOCIConfig(t *testing.T) {
|
||||
configBlob := []byte(`{"config":{"format":"safetensors","quantization":"Q4_K_M","parameters":"8B","size":"16.00GB","safetensors":{"tensor_count":291}}}`)
|
||||
t.Run("emits a nameless package with config-blob fields", func(t *testing.T) {
|
||||
blob := []byte(`{"config":{"format":"safetensors","quantization":"Q4_K_M","parameters":"8B","size":"16.00GB","safetensors":{"tensor_count":291}}}`)
|
||||
|
||||
t.Run("enriches from companion layers", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
readmePath := filepath.Join(dir, "README.md")
|
||||
require.NoError(t, os.WriteFile(readmePath,
|
||||
[]byte("---\nlicense: mit\nbase_model:\n - org/My-Model\n---\n# card\n"), 0o644))
|
||||
hfConfigPath := filepath.Join(dir, "config.json")
|
||||
require.NoError(t, os.WriteFile(hfConfigPath,
|
||||
[]byte(`{"architectures":["Qwen3ForCausalLM"],"torch_dtype":"bfloat16"}`), 0o644))
|
||||
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{
|
||||
dockerAIModelFileMediaType: {file.NewLocation(readmePath), file.NewLocation(hfConfigPath)},
|
||||
})
|
||||
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob))
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), nil, nil, configReader(blob))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pkgs, 1)
|
||||
|
||||
p := pkgs[0]
|
||||
assert.Equal(t, "My-Model", p.Name)
|
||||
assert.Equal(t, pkg.ModelPkg, p.Type)
|
||||
assertHasLicense(t, p, "mit")
|
||||
|
||||
assert.Empty(t, p.Name, "config-blob parser must emit nameless; the merge processor names it")
|
||||
assert.Empty(t, p.Licenses.ToSlice(), "license resolution belongs to the merge processor")
|
||||
md := p.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, "safetensors", md.Format)
|
||||
assert.Equal(t, "Qwen3ForCausalLM", md.Architecture)
|
||||
assert.Equal(t, "bfloat16", md.TorchDtype)
|
||||
assert.Equal(t, "Q4_K_M", md.Quantization)
|
||||
assert.Equal(t, "8B", md.Parameters)
|
||||
assert.Equal(t, "16.00GB", md.TotalSize)
|
||||
assert.Equal(t, uint64(291), md.TensorCount)
|
||||
})
|
||||
|
||||
t.Run("falls back to license layer", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
readmePath := filepath.Join(dir, "README.md")
|
||||
require.NoError(t, os.WriteFile(readmePath,
|
||||
[]byte("---\nbase_model:\n - org/My-Model\n---\n"), 0o644))
|
||||
licensePath := filepath.Join(dir, "LICENSE")
|
||||
require.NoError(t, os.WriteFile(licensePath,
|
||||
[]byte(" Apache License\n Version 2.0, January 2004\n"), 0o644))
|
||||
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{
|
||||
dockerAIModelFileMediaType: {file.NewLocation(readmePath)},
|
||||
dockerAILicenseMediaType: {file.NewLocation(licensePath)},
|
||||
})
|
||||
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pkgs, 1)
|
||||
assertHasLicense(t, pkgs[0], "Apache-2.0")
|
||||
})
|
||||
|
||||
t.Run("config _name_or_path wins over README base_model regardless of layer order", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
readmePath := filepath.Join(dir, "README.md")
|
||||
require.NoError(t, os.WriteFile(readmePath, []byte("---\nbase_model:\n - org/Readme-Name\n---\n"), 0o644))
|
||||
hfConfigPath := filepath.Join(dir, "config.json")
|
||||
require.NoError(t, os.WriteFile(hfConfigPath, []byte(`{"_name_or_path":"org/Config-Name"}`), 0o644))
|
||||
|
||||
// both layer orderings must yield the same (config-derived) name
|
||||
orderings := [][]file.Location{
|
||||
{file.NewLocation(readmePath), file.NewLocation(hfConfigPath)},
|
||||
{file.NewLocation(hfConfigPath), file.NewLocation(readmePath)},
|
||||
}
|
||||
for _, locs := range orderings {
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{
|
||||
dockerAIModelFileMediaType: locs,
|
||||
})
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pkgs, 1)
|
||||
assert.Equal(t, "Config-Name", pkgs[0].Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to default name when none derivable", func(t *testing.T) {
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{})
|
||||
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pkgs, 1)
|
||||
assert.Equal(t, "safetensors-model", pkgs[0].Name, "model must still be emitted, not dropped")
|
||||
assert.Empty(t, md.MetadataHash, "config blobs have no header content to hash")
|
||||
})
|
||||
|
||||
t.Run("ignores non-safetensors format", func(t *testing.T) {
|
||||
ggufBlob := []byte(`{"config":{"format":"gguf","quantization":"Q4_K_M"}}`)
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{})
|
||||
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(ggufBlob))
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), nil, nil, configReader(ggufBlob))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, pkgs)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSafeTensorsMergeProcessor exercises the merge processor directly with
|
||||
// synthetic input. The full-cataloger integration tests cover the realistic
|
||||
// happy paths; this focuses on grouping, the naming precedence chain, the
|
||||
// drop-when-unnameable rule, and cross-shard rollup.
|
||||
func TestSafeTensorsMergeProcessor(t *testing.T) {
|
||||
named := pkg.Package{Name: "model-a", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "aaaa"}}
|
||||
nameless := pkg.Package{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "bbbb"}}
|
||||
|
||||
t.Run("preserves the part's MetadataHash when the named package already has one", func(t *testing.T) {
|
||||
out, _, err := safeTensorsMergeProcessor([]pkg.Package{named, nameless}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
assert.Equal(t, "model-a", out[0].Name)
|
||||
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
|
||||
require.Len(t, md.Parts, 1)
|
||||
assert.Equal(t, "bbbb", md.Parts[0].MetadataHash, "part hash must survive: it is the cross-source fingerprint")
|
||||
assert.Equal(t, "aaaa", md.MetadataHash, "named package's own hash is not overwritten")
|
||||
assert.Equal(t, 1, md.ShardCount)
|
||||
})
|
||||
|
||||
t.Run("lifts the single part's MetadataHash to top-level when named has none", func(t *testing.T) {
|
||||
// This is the OCI single-shard shape: the config-blob parser produces a
|
||||
// named package with no hash; the weight-layer parser produces a nameless
|
||||
// part with the real header hash. Top-level should land in the same field
|
||||
// a dir-scan single-file would populate, so callers can correlate them.
|
||||
namedNoHash := pkg.Package{Name: "model-b", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors"}}
|
||||
part := pkg.Package{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "deadbeef"}}
|
||||
out, _, err := safeTensorsMergeProcessor([]pkg.Package{namedNoHash, part}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, "deadbeef", md.MetadataHash, "single-shard lift makes OCI top-level match dir-scan top-level")
|
||||
require.Len(t, md.Parts, 1)
|
||||
assert.Equal(t, "deadbeef", md.Parts[0].MetadataHash, "part also retains its hash")
|
||||
})
|
||||
|
||||
t.Run("multi-shard preserves per-part hashes and sorts deterministically", func(t *testing.T) {
|
||||
// Three nameless layer packages absorbed into one named config-derived package.
|
||||
// Top-level MetadataHash stays empty (no canonical single hash for a sharded
|
||||
// model — callers must combine the per-shard hashes themselves).
|
||||
namedNoHash := pkg.Package{Name: "model-c", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors"}}
|
||||
parts := []pkg.Package{
|
||||
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "cccc"}},
|
||||
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "aaaa"}},
|
||||
{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "bbbb"}},
|
||||
dirPkg := func(realPath string, md pkg.SafeTensorsModelInfo) pkg.Package {
|
||||
return pkg.Package{
|
||||
Type: pkg.ModelPkg,
|
||||
Metadata: md,
|
||||
Locations: file.NewLocationSet(
|
||||
file.NewLocation(realPath).
|
||||
WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation),
|
||||
),
|
||||
}
|
||||
out, _, err := safeTensorsMergeProcessor(append([]pkg.Package{namedNoHash}, parts...), nil, nil)
|
||||
}
|
||||
ociPkg := func(md pkg.SafeTensorsModelInfo) pkg.Package {
|
||||
return pkg.Package{
|
||||
Type: pkg.ModelPkg,
|
||||
Metadata: md,
|
||||
Locations: file.NewLocationSet(
|
||||
file.NewLocation("/").
|
||||
WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("dir scan: parent-dir fallback names a bare safetensors with no siblings", func(t *testing.T) {
|
||||
// case #1: model.safetensors in /models/tiny-llama/ with no config.json
|
||||
// or README. The processor cannot derive a producer name and Architecture
|
||||
// is empty, so it lands on the parent-dir rung.
|
||||
p := dirPkg("/models/tiny-llama/weights.safetensors", pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
TensorCount: 4,
|
||||
Quantization: "BF16",
|
||||
MetadataHash: "abc",
|
||||
})
|
||||
resolver := file.NewMockResolverForPaths() // no config.json / README available
|
||||
out, _, err := safeTensorsMergeProcessor(context.Background(), resolver, []pkg.Package{p}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, 3, md.ShardCount)
|
||||
assert.Empty(t, md.MetadataHash, "multi-shard leaves top-level hash unset")
|
||||
require.Len(t, md.Parts, 3)
|
||||
// Parts sorted by MetadataHash for deterministic SBOM output regardless of resolver order.
|
||||
assert.Equal(t, []string{"aaaa", "bbbb", "cccc"}, []string{md.Parts[0].MetadataHash, md.Parts[1].MetadataHash, md.Parts[2].MetadataHash})
|
||||
assert.Equal(t, "tiny-llama", out[0].Name)
|
||||
})
|
||||
|
||||
t.Run("drops result when no named package", func(t *testing.T) {
|
||||
out, _, err := safeTensorsMergeProcessor([]pkg.Package{nameless}, nil, nil)
|
||||
t.Run("dir scan: parent-dir fallback rescues a metadata-only header", func(t *testing.T) {
|
||||
// case #3: header carries only __metadata__, no tensors. Parameters and
|
||||
// Architecture are both empty, so Arch-Parameters can't fire either —
|
||||
// the parent-dir fallback is the only thing that names the package.
|
||||
p := dirPkg("/scan/edge/headeronly/model.safetensors", pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
MetadataHash: "xyz",
|
||||
UserMetadata: pkg.KeyValues{{Key: "producer", Value: "stgen"}},
|
||||
})
|
||||
resolver := file.NewMockResolverForPaths()
|
||||
out, _, err := safeTensorsMergeProcessor(context.Background(), resolver, []pkg.Package{p}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, out)
|
||||
require.Len(t, out, 1)
|
||||
assert.Equal(t, "headeronly", out[0].Name)
|
||||
})
|
||||
|
||||
t.Run("dir scan: Architecture-Parameters synthetic wins over parent-dir", func(t *testing.T) {
|
||||
// Architecture and Parameters are both populated → synthetic wins over
|
||||
// the parent-dir fallback. _name_or_path is not available (no sibling
|
||||
// config.json mock).
|
||||
p := dirPkg("/models/tiny/weights.safetensors", pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
Architecture: "LlamaForCausalLM",
|
||||
Parameters: "2.68B",
|
||||
TensorCount: 4,
|
||||
MetadataHash: "abc",
|
||||
})
|
||||
resolver := file.NewMockResolverForPaths()
|
||||
out, _, err := safeTensorsMergeProcessor(context.Background(), resolver, []pkg.Package{p}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
assert.Equal(t, "LlamaForCausalLM-2.68B", out[0].Name)
|
||||
})
|
||||
|
||||
t.Run("OCI: dropped when no name source is available", func(t *testing.T) {
|
||||
// The vllm-style shape: config-blob package + a weight-layer package,
|
||||
// both at virtual path "/", no model.file companions on the resolver.
|
||||
// With nothing to derive a name from, the group is dropped (no opaque
|
||||
// fallback / no parent-dir option for OCI).
|
||||
configMd := pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
TensorCount: 5,
|
||||
TotalSize: "1GB",
|
||||
}
|
||||
shardMd := pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
TensorCount: 5,
|
||||
Quantization: "BF16",
|
||||
MetadataHash: "deadbeef",
|
||||
}
|
||||
resolver := file.NewMockResolverForMediaTypes(nil)
|
||||
out, _, err := safeTensorsMergeProcessor(
|
||||
context.Background(), resolver,
|
||||
[]pkg.Package{ociPkg(configMd), ociPkg(shardMd)}, nil, nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, out, "OCI group with no naming source must be dropped")
|
||||
})
|
||||
|
||||
t.Run("OCI: merges config + shard and names from companion config.json", func(t *testing.T) {
|
||||
// Write a single model.file companion blob containing HF config.json so
|
||||
// the processor can derive _name_or_path and Architecture from it.
|
||||
dir := t.TempDir()
|
||||
hfConfigPath := filepath.Join(dir, "config.json")
|
||||
require.NoError(t, os.WriteFile(hfConfigPath,
|
||||
[]byte(`{"architectures":["Qwen3ForCausalLM"],"torch_dtype":"bfloat16","_name_or_path":"org/qwen-tiny"}`), 0o644))
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{
|
||||
dockerAIModelFileMediaType: {file.NewLocation(hfConfigPath)},
|
||||
})
|
||||
|
||||
configMd := pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
Quantization: "Q4_K_M", // raw producer-declared value
|
||||
Parameters: "8B",
|
||||
TotalSize: "16.00GB",
|
||||
TensorCount: 291,
|
||||
}
|
||||
shardMd := pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
TensorCount: 100, // per-shard count — must NOT be summed onto the aggregate's 291
|
||||
Quantization: "BF16",
|
||||
MetadataHash: "deadbeef",
|
||||
UserMetadata: pkg.KeyValues{{Key: "format", Value: "pt"}},
|
||||
}
|
||||
out, _, err := safeTensorsMergeProcessor(
|
||||
context.Background(), resolver,
|
||||
[]pkg.Package{ociPkg(configMd), ociPkg(shardMd)}, nil, nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
|
||||
got := out[0]
|
||||
assert.Equal(t, "qwen-tiny", got.Name, "name comes from path.Base(_name_or_path)")
|
||||
md := got.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, uint64(291), md.TensorCount, "aggregate TensorCount must win — never double-count by summing the shard")
|
||||
assert.Equal(t, "16.00GB", md.TotalSize)
|
||||
assert.Equal(t, "8B", md.Parameters)
|
||||
assert.Equal(t, "Qwen3ForCausalLM", md.Architecture, "Architecture enriched from companion config.json")
|
||||
assert.Equal(t, "bfloat16", md.TorchDtype)
|
||||
assert.Equal(t, "Q4_K_M", md.Quantization, "aggregate Quantization wins over shard's normalized dtype when both present")
|
||||
assert.Equal(t, "deadbeef", md.MetadataHash, "single-shard rollup is the lone shard's hash")
|
||||
assert.Equal(t, pkg.KeyValues{{Key: "format", Value: "pt"}}, md.UserMetadata)
|
||||
assert.Nil(t, md.Parts, "single-shard groups skip Parts; the outer view already exposes everything")
|
||||
})
|
||||
|
||||
t.Run("OCI: multi-shard rollup hashes are stable and sorted", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
hfConfigPath := filepath.Join(dir, "config.json")
|
||||
require.NoError(t, os.WriteFile(hfConfigPath,
|
||||
[]byte(`{"architectures":["X"],"_name_or_path":"org/multi"}`), 0o644))
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{
|
||||
dockerAIModelFileMediaType: {file.NewLocation(hfConfigPath)},
|
||||
})
|
||||
|
||||
configMd := pkg.SafeTensorsModelInfo{Format: "safetensors", TensorCount: 9, TotalSize: "3GB"}
|
||||
shard := func(hash string, cnt uint64) pkg.SafeTensorsModelInfo {
|
||||
return pkg.SafeTensorsModelInfo{Format: "safetensors", TensorCount: cnt, Quantization: "BF16", MetadataHash: hash}
|
||||
}
|
||||
in := []pkg.Package{
|
||||
ociPkg(configMd),
|
||||
ociPkg(shard("cccc", 3)),
|
||||
ociPkg(shard("aaaa", 3)),
|
||||
ociPkg(shard("bbbb", 3)),
|
||||
}
|
||||
out1, _, err := safeTensorsMergeProcessor(context.Background(), resolver, in, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out1, 1)
|
||||
md1 := out1[0].Metadata.(pkg.SafeTensorsModelInfo)
|
||||
require.Len(t, md1.Parts, 3)
|
||||
// Parts deterministically sorted by MetadataHash.
|
||||
assert.Equal(t,
|
||||
[]string{"aaaa", "bbbb", "cccc"},
|
||||
[]string{md1.Parts[0].MetadataHash, md1.Parts[1].MetadataHash, md1.Parts[2].MetadataHash},
|
||||
)
|
||||
// Rollup hash is stable across input ordering.
|
||||
shuffled := []pkg.Package{ociPkg(shard("bbbb", 3)), ociPkg(configMd), ociPkg(shard("aaaa", 3)), ociPkg(shard("cccc", 3))}
|
||||
out2, _, err := safeTensorsMergeProcessor(context.Background(), resolver, shuffled, nil, nil)
|
||||
require.NoError(t, err)
|
||||
md2 := out2[0].Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, md1.MetadataHash, md2.MetadataHash, "rollup hash must not depend on input order")
|
||||
})
|
||||
|
||||
t.Run("passes through upstream error", func(t *testing.T) {
|
||||
sentinel := assert.AnError
|
||||
out, _, err := safeTensorsMergeProcessor([]pkg.Package{named}, nil, sentinel)
|
||||
p := dirPkg("/models/x/y.safetensors", pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "h"})
|
||||
out, _, err := safeTensorsMergeProcessor(context.Background(), nil, []pkg.Package{p}, nil, sentinel)
|
||||
assert.Equal(t, sentinel, err)
|
||||
assert.Len(t, out, 1)
|
||||
assert.Equal(t, []pkg.Package{p}, out)
|
||||
})
|
||||
}
|
||||
|
||||
@ -314,42 +377,65 @@ func TestParseSafeTensorsOCILayer(t *testing.T) {
|
||||
assert.Equal(t, wantHash, md.MetadataHash)
|
||||
})
|
||||
|
||||
t.Run("merges with config-derived named package and lifts ShardCount", func(t *testing.T) {
|
||||
// Synthesize what the OCI scan would produce: one config-derived named
|
||||
// package + one weight-layer derived nameless package. Run them through
|
||||
// the merge processor and assert the result looks like a complete model.
|
||||
configMd := pkg.SafeTensorsModelInfo{
|
||||
t.Run("merged via processor: aggregate fields preserved, hash lifted from single shard", func(t *testing.T) {
|
||||
// Synthesize the OCI single-shard shape: a config-blob-derived nameless
|
||||
// package + the weight-layer parser's nameless package (both at virtual
|
||||
// path "/"). With a companion HF config.json on the resolver to provide
|
||||
// _name_or_path, the merge processor produces a single named model.
|
||||
dir := t.TempDir()
|
||||
hfConfigPath := filepath.Join(dir, "config.json")
|
||||
require.NoError(t, os.WriteFile(hfConfigPath,
|
||||
[]byte(`{"architectures":["Qwen3ForCausalLM"],"_name_or_path":"org/qwen-test"}`), 0o644))
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{
|
||||
dockerAIModelFileMediaType: {file.NewLocation(hfConfigPath)},
|
||||
})
|
||||
|
||||
configPkg := pkg.Package{
|
||||
Type: pkg.ModelPkg,
|
||||
Metadata: pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
Architecture: "Qwen3ForCausalLM",
|
||||
Parameters: "2.68B",
|
||||
TotalSize: "5.00GB",
|
||||
Quantization: "Q4_K_M", // raw producer string
|
||||
TensorCount: 9999,
|
||||
},
|
||||
Locations: file.NewLocationSet(
|
||||
file.NewLocation("/").WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation),
|
||||
),
|
||||
}
|
||||
named := pkg.Package{Name: "qwen", Type: pkg.ModelPkg, Metadata: configMd}
|
||||
|
||||
reader := file.NewLocationReadCloser(file.NewLocation("/"), io.NopCloser(bytes.NewReader(blob)))
|
||||
reader := file.NewLocationReadCloser(
|
||||
file.NewLocation("/").WithAnnotation(pkg.EvidenceAnnotationKey, pkg.PrimaryEvidenceAnnotation),
|
||||
io.NopCloser(bytes.NewReader(blob)),
|
||||
)
|
||||
layerPkgs, _, err := parseSafeTensorsOCILayer(context.Background(), nil, nil, reader)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, layerPkgs, 1)
|
||||
|
||||
out, _, err := safeTensorsMergeProcessor(append([]pkg.Package{named}, layerPkgs...), nil, nil)
|
||||
out, _, err := safeTensorsMergeProcessor(
|
||||
context.Background(), resolver,
|
||||
append([]pkg.Package{configPkg}, layerPkgs...), nil, nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
|
||||
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, 1, md.ShardCount, "merge processor should set ShardCount from absorbed parts")
|
||||
// Producer-declared top-level fields are preserved.
|
||||
assert.Equal(t, "Qwen3ForCausalLM", md.Architecture)
|
||||
got := out[0]
|
||||
assert.Equal(t, "qwen-test", got.Name, "name comes from the companion config.json _name_or_path")
|
||||
md := got.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
// Aggregate-declared fields win for totals; per-shard count must NOT be
|
||||
// summed into the aggregate.
|
||||
assert.Equal(t, uint64(9999), md.TensorCount)
|
||||
assert.Equal(t, "5.00GB", md.TotalSize)
|
||||
assert.Equal(t, "2.68B", md.Parameters)
|
||||
// Aggregate Quantization wins when set; shard's normalized dtype is the
|
||||
// fallback (not exercised here because the config had Q4_K_M).
|
||||
assert.Equal(t, "Q4_K_M", md.Quantization)
|
||||
// Single-shard: the header-derived MetadataHash is lifted to top-level so
|
||||
// it matches the field a dir-scan would populate.
|
||||
assert.Equal(t, wantHash, md.MetadataHash, "single-shard OCI scan must expose the hash at the same field as a dir scan")
|
||||
// The full per-shard breakdown is also preserved under Parts.
|
||||
require.Len(t, md.Parts, 1)
|
||||
assert.Equal(t, wantHash, md.Parts[0].MetadataHash)
|
||||
assert.Equal(t, wantUserMetadata, md.Parts[0].UserMetadata)
|
||||
assert.Equal(t, uint64(2), md.Parts[0].TensorCount)
|
||||
assert.Equal(t, "BF16", md.Parts[0].Quantization, "part keeps the normalized header dtype")
|
||||
// Architecture comes from companion HF config.json enrichment.
|
||||
assert.Equal(t, "Qwen3ForCausalLM", md.Architecture)
|
||||
// Single-shard groups skip Parts; the rollup hash is the lone shard's hash.
|
||||
assert.Nil(t, md.Parts)
|
||||
assert.Equal(t, wantHash, md.MetadataHash)
|
||||
assert.Equal(t, wantUserMetadata, md.UserMetadata)
|
||||
assert.Equal(t, 1, md.ShardCount)
|
||||
})
|
||||
}
|
||||
|
||||
@ -577,13 +663,6 @@ func TestParseFrontmatter(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelNameFromPath(t *testing.T) {
|
||||
assert.Equal(t, "foo", modelNameFromPath("/models/foo/model.safetensors"))
|
||||
assert.Equal(t, "weights", modelNameFromPath("weights.safetensors"))
|
||||
assert.Equal(t, "my-model", modelNameFromIndexPath("/models/my-model/model.safetensors.index.json"))
|
||||
assert.Equal(t, "safetensors-model", modelNameFromIndexPath("model.safetensors.index.json"))
|
||||
}
|
||||
|
||||
func TestDockerAIModelConfigMediaTypes(t *testing.T) {
|
||||
// supported mirrors how the resolver matches: filepath.Match each registered
|
||||
// media type against a layer's media type.
|
||||
|
||||
@ -1,12 +1,32 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/anchore/syft/internal"
|
||||
"github.com/anchore/syft/internal/log"
|
||||
"github.com/anchore/syft/syft/artifact"
|
||||
"github.com/anchore/syft/syft/file"
|
||||
"github.com/anchore/syft/syft/pkg"
|
||||
"github.com/anchore/syft/syft/pkg/cataloger/internal/licenses"
|
||||
)
|
||||
|
||||
// ociGroupKey is the sentinel grouping key for every safetensors package that
// originated from an OCI model artifact. The ContainerImageModel resolver gives
// each layer the virtual RealPath "/" regardless of layer media type, so all
// safetensors packages from a single OCI scan collapse into one group.
const ociGroupKey = "@oci@"
|
||||
|
||||
// ggufMergeProcessor consolidates multiple GGUF packages into a single package
|
||||
// representing the AI model. When scanning OCI images with multiple layers,
|
||||
// each layer may produce a separate package. This processor finds the package
|
||||
@ -16,7 +36,6 @@ func ggufMergeProcessor(pkgs []pkg.Package, rels []artifact.Relationship, err er
|
||||
if err != nil {
|
||||
return pkgs, rels, err
|
||||
}
|
||||
|
||||
if len(pkgs) == 0 {
|
||||
return pkgs, rels, err
|
||||
}
|
||||
@ -55,69 +74,563 @@ func ggufMergeProcessor(pkgs []pkg.Package, rels []artifact.Relationship, err er
|
||||
}
|
||||
}
|
||||
|
||||
// Largest number of key value
|
||||
|
||||
return namedPkgs, rels, err
|
||||
}
|
||||
|
||||
// safeTensorsMergeProcessor mirrors ggufMergeProcessor for SafeTensors packages.
|
||||
// When scanning an OCI AI artifact, the model-config blob produces one named
|
||||
// package and each safetensors weight layer produces a nameless package. The
|
||||
// nameless packages are absorbed into the named one's Parts slice.
|
||||
// safeTensorsMergeProcessor is the single owner of naming, license resolution,
|
||||
// HF config.json mining, cross-shard rollup, and supporting-evidence attachment
|
||||
// for safetensors packages. The parsers it processes are intentionally minimal
|
||||
// — they only decode the safetensors-specific format and emit nameless packages
|
||||
// with content-derived metadata. This function:
|
||||
//
|
||||
// MetadataHash is intentionally preserved on absorbed parts: it is derived
|
||||
// purely from the on-disk safetensors header (see SafeTensorsModelInfo doc),
|
||||
// so it acts as the cross-source content fingerprint. For a single-shard
|
||||
// model we also copy it up to the named package's top-level MetadataHash so
|
||||
// that an OCI scan and a directory scan of the same single .safetensors file
|
||||
// expose the hash at the same field — `md.MetadataHash` — without callers
|
||||
// having to inspect Parts.
|
||||
func safeTensorsMergeProcessor(pkgs []pkg.Package, rels []artifact.Relationship, err error) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
if err != nil {
|
||||
return pkgs, rels, err
|
||||
}
|
||||
if len(pkgs) == 0 {
|
||||
// 1. groups all nameless packages by parent directory (or a single sentinel
|
||||
// for OCI artifacts, since the ContainerImageModel resolver puts every
|
||||
// layer at virtual path "/");
|
||||
// 2. merges the per-shard metadata (tensor count, dominant dtype, total size,
|
||||
// UserMetadata, rollup MetadataHash) into one package per group;
|
||||
// 3. enriches the merged package by consulting the resolver ONCE per group —
|
||||
// sibling config.json + README.md for dir scans, the model-file companion
|
||||
// layers + license layer for OCI — and attaches those locations as
|
||||
// supporting evidence;
|
||||
// 4. picks a name via the precedence chain
|
||||
// config.json _name_or_path → Architecture-Parameters → parent-dir
|
||||
// and drops the group when none of those produced a name (no opaque
|
||||
// fallback / no MetadataHash-as-name).
|
||||
func safeTensorsMergeProcessor(ctx context.Context, resolver file.Resolver, pkgs []pkg.Package, rels []artifact.Relationship, err error) ([]pkg.Package, []artifact.Relationship, error) {
|
||||
if err != nil || len(pkgs) == 0 {
|
||||
return pkgs, rels, err
|
||||
}
|
||||
|
||||
var namedPkgs []pkg.Package
|
||||
var namelessParts []pkg.SafeTensorsModelInfo
|
||||
// Defensively split off non-safetensors packages — the cataloger only emits
|
||||
// SafeTensorsModelInfo today, but this keeps the processor robust if other
|
||||
// types ever flow through.
|
||||
var stPkgs, other []pkg.Package
|
||||
for _, p := range pkgs {
|
||||
if p.Name != "" {
|
||||
namedPkgs = append(namedPkgs, p)
|
||||
if _, ok := p.Metadata.(pkg.SafeTensorsModelInfo); ok {
|
||||
stPkgs = append(stPkgs, p)
|
||||
continue
|
||||
}
|
||||
if md, ok := p.Metadata.(pkg.SafeTensorsModelInfo); ok {
|
||||
namelessParts = append(namelessParts, md)
|
||||
other = append(other, p)
|
||||
}
|
||||
if len(stPkgs) == 0 {
|
||||
return pkgs, rels, err
|
||||
}
|
||||
|
||||
if len(namedPkgs) == 0 {
|
||||
return nil, rels, err
|
||||
groups := groupSafeTensorsPackages(stPkgs)
|
||||
|
||||
// Deterministic iteration order so the SBOM doesn't depend on map order.
|
||||
keys := make([]string, 0, len(groups))
|
||||
for k := range groups {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
out := other
|
||||
for _, key := range keys {
|
||||
merged := mergeSafeTensorsGroup(groups[key])
|
||||
nameOrPath := enrichSafeTensorsGroup(ctx, resolver, key, &merged)
|
||||
name := pickSafeTensorsName(merged, key, nameOrPath)
|
||||
if name == "" {
|
||||
continue // drop unnameable groups, per design (no opaque fallback)
|
||||
}
|
||||
merged.Name = name
|
||||
merged.SetID()
|
||||
out = append(out, merged)
|
||||
}
|
||||
return out, rels, nil
|
||||
}
|
||||
|
||||
if len(namedPkgs) == 1 && len(namelessParts) > 0 {
|
||||
// Sort by MetadataHash so OCI layer order (map iteration) doesn't leak
|
||||
// into the SBOM output.
|
||||
sort.Slice(namelessParts, func(i, j int) bool {
|
||||
return namelessParts[i].MetadataHash < namelessParts[j].MetadataHash
|
||||
// groupSafeTensorsPackages buckets packages by the parent directory of their
|
||||
// primary-evidence location, or the OCI sentinel when the location lives at
|
||||
// the ContainerImageModel resolver's virtual "/" path.
|
||||
func groupSafeTensorsPackages(pkgs []pkg.Package) map[string][]pkg.Package {
|
||||
out := make(map[string][]pkg.Package)
|
||||
for _, p := range pkgs {
|
||||
key := safeTensorsGroupKey(p)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = append(out[key], p)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func safeTensorsGroupKey(p pkg.Package) string {
|
||||
loc := primaryEvidenceLocation(p)
|
||||
if loc == nil {
|
||||
return ""
|
||||
}
|
||||
if loc.RealPath == "/" {
|
||||
return ociGroupKey
|
||||
}
|
||||
return path.Dir(loc.RealPath)
|
||||
}
|
||||
|
||||
func primaryEvidenceLocation(p pkg.Package) *file.Location {
|
||||
locs := p.Locations.ToSlice()
|
||||
for i, l := range locs {
|
||||
if l.Annotations != nil && l.Annotations[pkg.EvidenceAnnotationKey] == pkg.PrimaryEvidenceAnnotation {
|
||||
return &locs[i]
|
||||
}
|
||||
}
|
||||
if len(locs) > 0 {
|
||||
return &locs[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeSafeTensorsGroup folds a group's per-member metadata into a single
|
||||
// package. Members are classified into two buckets to avoid double-counting:
|
||||
//
|
||||
// - "aggregate" members have producer-declared totals (TensorCount, TotalSize,
|
||||
// ShardCount, Parameters) but no MetadataHash — these are the Docker AI
|
||||
// config blob and the sharded-index file.
|
||||
// - "shard" members have a content-derived MetadataHash and per-shard counts —
|
||||
// these are the individual .safetensors header parsers, both dir-scan and
|
||||
// OCI weight-layer.
|
||||
//
|
||||
// Aggregate values are the source of truth for the merged totals when present;
|
||||
// shards contribute Quantization, UserMetadata, the rollup MetadataHash, and
|
||||
// (for multi-shard models) the Parts breakdown.
|
||||
func mergeSafeTensorsGroup(members []pkg.Package) pkg.Package {
|
||||
locSet := unionLocations(members)
|
||||
aggregates, shards := bucketSafeTensorsMembers(members)
|
||||
|
||||
merged := pkg.SafeTensorsModelInfo{Format: "safetensors"}
|
||||
mergeAggregatesInto(&merged, aggregates)
|
||||
shardTensorTotal, hashes := mergeShardsInto(&merged, shards)
|
||||
|
||||
if merged.TensorCount == 0 {
|
||||
merged.TensorCount = shardTensorTotal
|
||||
}
|
||||
if merged.ShardCount == 0 {
|
||||
if len(shards) > 0 {
|
||||
merged.ShardCount = len(shards)
|
||||
} else {
|
||||
merged.ShardCount = 1
|
||||
}
|
||||
}
|
||||
merged.MetadataHash = rollupHash(hashes)
|
||||
|
||||
// Parts only carry value for multi-shard models; for a single shard the
|
||||
// outer view already exposes every per-shard field.
|
||||
if len(shards) > 1 {
|
||||
parts := append([]pkg.SafeTensorsModelInfo(nil), shards...)
|
||||
sort.Slice(parts, func(i, j int) bool {
|
||||
return parts[i].MetadataHash < parts[j].MetadataHash
|
||||
})
|
||||
winner := &namedPkgs[0]
|
||||
if md, ok := winner.Metadata.(pkg.SafeTensorsModelInfo); ok {
|
||||
md.Parts = namelessParts
|
||||
// Trust per-shard headers over the producer-declared shard count.
|
||||
md.ShardCount = len(namelessParts)
|
||||
// Single-shard: lift the part's content fingerprint to the top
|
||||
// level so the field placement matches a dir-scan single file.
|
||||
// Only lift when the named package has no hash of its own (the
|
||||
// OCI config-blob parser never sets one; dir-scan paths never
|
||||
// produce nameless parts, so they don't reach this branch).
|
||||
if len(namelessParts) == 1 && md.MetadataHash == "" {
|
||||
md.MetadataHash = namelessParts[0].MetadataHash
|
||||
merged.Parts = parts
|
||||
}
|
||||
winner.Metadata = md
|
||||
|
||||
return pkg.Package{
|
||||
Locations: locSet,
|
||||
Type: pkg.ModelPkg,
|
||||
Metadata: merged,
|
||||
}
|
||||
}
|
||||
|
||||
return namedPkgs, rels, err
|
||||
// mergeAggregatesInto folds aggregate-declared totals (config blob or sharded
|
||||
// index) into merged. First non-empty wins, so the order aggregates are passed
|
||||
// in determines tie-breaking — in practice there is one config blob and one
|
||||
// index per group, never two of the same kind.
|
||||
func mergeAggregatesInto(merged *pkg.SafeTensorsModelInfo, aggregates []pkg.SafeTensorsModelInfo) {
|
||||
for _, a := range aggregates {
|
||||
if merged.TensorCount == 0 {
|
||||
merged.TensorCount = a.TensorCount
|
||||
}
|
||||
if merged.ShardCount == 0 {
|
||||
merged.ShardCount = a.ShardCount
|
||||
}
|
||||
firstNonEmpty(&merged.Parameters, a.Parameters)
|
||||
firstNonEmpty(&merged.TotalSize, a.TotalSize)
|
||||
firstNonEmpty(&merged.Architecture, a.Architecture)
|
||||
firstNonEmpty(&merged.Quantization, a.Quantization)
|
||||
firstNonEmpty(&merged.TorchDtype, a.TorchDtype)
|
||||
firstNonEmpty(&merged.TransformersVersion, a.TransformersVersion)
|
||||
}
|
||||
}
|
||||
|
||||
// mergeShardsInto folds the per-shard header metadata into merged, returning
|
||||
// the summed shard TensorCount and the list of non-empty per-shard hashes for
|
||||
// the rollup. Architecture / TorchDtype / TransformersVersion are accepted as
|
||||
// fallbacks if a shard ever carries them (the current parsers don't, but the
|
||||
// resolver-backed enrichment runs afterwards and won't overwrite anything
|
||||
// already set, so it's safe to populate them earlier).
|
||||
func mergeShardsInto(merged *pkg.SafeTensorsModelInfo, shards []pkg.SafeTensorsModelInfo) (shardTensorTotal uint64, hashes []string) {
|
||||
seenKV := map[string]bool{}
|
||||
for _, s := range shards {
|
||||
shardTensorTotal += s.TensorCount
|
||||
firstNonEmpty(&merged.Quantization, s.Quantization)
|
||||
firstNonEmpty(&merged.Parameters, s.Parameters)
|
||||
firstNonEmpty(&merged.Architecture, s.Architecture)
|
||||
firstNonEmpty(&merged.TorchDtype, s.TorchDtype)
|
||||
firstNonEmpty(&merged.TransformersVersion, s.TransformersVersion)
|
||||
for _, kv := range s.UserMetadata {
|
||||
if seenKV[kv.Key] {
|
||||
continue
|
||||
}
|
||||
seenKV[kv.Key] = true
|
||||
merged.UserMetadata = append(merged.UserMetadata, kv)
|
||||
}
|
||||
if s.MetadataHash != "" {
|
||||
hashes = append(hashes, s.MetadataHash)
|
||||
}
|
||||
}
|
||||
return shardTensorTotal, hashes
|
||||
}
|
||||
|
||||
// firstNonEmpty assigns v to *dst only when *dst is still the empty string, so
// repeated calls keep the first non-empty value offered. dst must not be nil.
func firstNonEmpty(dst *string, v string) {
	if *dst == "" {
		*dst = v
	}
}
|
||||
|
||||
// unionLocations gathers every location from every member into a single set.
|
||||
func unionLocations(members []pkg.Package) file.LocationSet {
|
||||
out := file.NewLocationSet()
|
||||
for _, m := range members {
|
||||
for _, l := range m.Locations.ToSlice() {
|
||||
out.Add(l)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// bucketSafeTensorsMembers splits group members into aggregate-flavored entries
|
||||
// (no MetadataHash — Docker AI config blob or sharded index) and shard-flavored
|
||||
// entries (carry a content-derived MetadataHash from a header parser).
|
||||
func bucketSafeTensorsMembers(members []pkg.Package) (aggregates, shards []pkg.SafeTensorsModelInfo) {
|
||||
for _, m := range members {
|
||||
md, ok := m.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if md.MetadataHash != "" {
|
||||
shards = append(shards, md)
|
||||
continue
|
||||
}
|
||||
aggregates = append(aggregates, md)
|
||||
}
|
||||
return aggregates, shards
|
||||
}
|
||||
|
||||
// rollupHash returns a stable hash across the sorted set of per-member
|
||||
// content-derived hashes. For a single member it returns that hash unchanged,
|
||||
// so a single-file dir scan and an OCI scan with one safetensors layer surface
|
||||
// the same value. For multi-shard models the rollup is the xxhash of the
|
||||
// sorted hashes joined with "|".
|
||||
func rollupHash(hashes []string) string {
|
||||
if len(hashes) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(hashes) == 1 {
|
||||
return hashes[0]
|
||||
}
|
||||
sorted := append([]string(nil), hashes...)
|
||||
sort.Strings(sorted)
|
||||
return fmt.Sprintf("%016x", xxhash.Sum64String(strings.Join(sorted, "|")))
|
||||
}
|
||||
|
||||
// enrichSafeTensorsGroup reads the resolver once for the group to populate the
|
||||
// merged metadata's Architecture / TorchDtype / TransformersVersion, set the
|
||||
// licenses on the merged package, and attach the location of every consulted
|
||||
// supporting file as SupportingEvidence. Returns the raw _name_or_path so the
|
||||
// caller can apply path.Base in its naming step.
|
||||
func enrichSafeTensorsGroup(ctx context.Context, resolver file.Resolver, groupKey string, merged *pkg.Package) (nameOrPath string) {
|
||||
md := merged.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
|
||||
var (
|
||||
lics []pkg.License
|
||||
supporting []file.Location
|
||||
)
|
||||
if groupKey == ociGroupKey {
|
||||
nameOrPath, lics, supporting = enrichSafeTensorsOCI(ctx, resolver, &md)
|
||||
} else {
|
||||
nameOrPath, lics, supporting = enrichSafeTensorsDir(ctx, resolver, groupKey, &md)
|
||||
}
|
||||
|
||||
merged.Metadata = md
|
||||
if len(lics) > 0 {
|
||||
merged.Licenses = pkg.NewLicenseSet(lics...)
|
||||
}
|
||||
for _, loc := range supporting {
|
||||
merged.Locations.Add(loc.WithAnnotation(pkg.EvidenceAnnotationKey, pkg.SupportingEvidenceAnnotation))
|
||||
}
|
||||
return nameOrPath
|
||||
}
|
||||
|
||||
// enrichSafeTensorsDir handles the directory-scan case: look for sibling
|
||||
// config.json and README.md next to the model files.
|
||||
func enrichSafeTensorsDir(ctx context.Context, resolver file.Resolver, dir string, md *pkg.SafeTensorsModelInfo) (nameOrPath string, lics []pkg.License, supporting []file.Location) {
|
||||
if loc, cfg := readDirHFConfig(resolver, path.Join(dir, "config.json")); cfg != nil {
|
||||
applyHFConfig(md, cfg)
|
||||
nameOrPath = cfg.NameOrPath
|
||||
supporting = append(supporting, *loc)
|
||||
}
|
||||
|
||||
if loc, fm := readDirReadmeFrontmatter(resolver, path.Join(dir, "README.md")); fm != nil {
|
||||
if fm.License != "" {
|
||||
lics = pkg.NewLicensesFromValuesWithContext(ctx, fm.License)
|
||||
}
|
||||
if nameOrPath == "" && len(fm.BaseModel) > 0 {
|
||||
nameOrPath = fm.BaseModel[0]
|
||||
}
|
||||
supporting = append(supporting, *loc)
|
||||
}
|
||||
return nameOrPath, lics, supporting
|
||||
}
|
||||
|
||||
// enrichSafeTensorsOCI handles the OCI-artifact case: walk the
|
||||
// vnd.docker.ai.model.file layers (READMEs and HF config.json all ride that
|
||||
// media type — we sniff content to tell them apart), then fall back to the
|
||||
// vnd.docker.ai.license layer through the shared license scanner.
|
||||
func enrichSafeTensorsOCI(ctx context.Context, resolver file.Resolver, md *pkg.SafeTensorsModelInfo) (nameOrPath string, lics []pkg.License, supporting []file.Location) {
|
||||
ociResolver, ok := resolver.(file.OCIMediaTypeResolver)
|
||||
if !ok {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
modelFileLocs, err := ociResolver.FilesByMediaType(dockerAIModelFileMediaType)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list docker AI model-file layers: %v", err)
|
||||
}
|
||||
|
||||
// Collect config / readme candidates separately so the layer-iteration order
|
||||
// returned by the resolver doesn't decide the precedence.
|
||||
var configName, readmeName, readmeLicense string
|
||||
for _, loc := range modelFileLocs {
|
||||
if classifyOCIModelFileLayer(resolver, loc, md, &configName, &readmeName, &readmeLicense) {
|
||||
supporting = append(supporting, loc)
|
||||
}
|
||||
}
|
||||
|
||||
// Precedence: config.json _name_or_path > README base_model.
|
||||
if configName != "" {
|
||||
nameOrPath = configName
|
||||
} else {
|
||||
nameOrPath = readmeName
|
||||
}
|
||||
|
||||
// README license takes precedence; fall back to the license layer via the
|
||||
// shared scanner (which understands SPDX text far better than a hand-rolled
|
||||
// substring match).
|
||||
switch {
|
||||
case readmeLicense != "":
|
||||
lics = pkg.NewLicensesFromValuesWithContext(ctx, readmeLicense)
|
||||
default:
|
||||
licLocs, lErr := ociResolver.FilesByMediaType(dockerAILicenseMediaType)
|
||||
if lErr != nil {
|
||||
log.Debugf("failed to list docker AI license layers: %v", lErr)
|
||||
}
|
||||
if len(licLocs) > 0 {
|
||||
lics = licenses.FindAtLocations(ctx, resolver, licLocs...)
|
||||
supporting = append(supporting, licLocs...)
|
||||
}
|
||||
}
|
||||
return nameOrPath, lics, supporting
|
||||
}
|
||||
|
||||
// classifyOCIModelFileLayer reads up to 4 MiB of a model.file layer and
|
||||
// classifies it as README frontmatter or HF config.json based on its leading
|
||||
// bytes. Side-effects: applies HF config fields onto md, accumulates name and
|
||||
// license candidates via the out-params. Returns true when the layer was
|
||||
// successfully classified (and should be recorded as supporting evidence).
|
||||
func classifyOCIModelFileLayer(resolver file.Resolver, loc file.Location, md *pkg.SafeTensorsModelInfo, configName, readmeName, license *string) bool {
|
||||
rc, err := resolver.FileContentsByLocation(loc)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer internal.CloseAndLogError(rc, loc.RealPath)
|
||||
|
||||
buf, err := io.ReadAll(io.LimitReader(rc, 4*1024*1024))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
trimmed := trimLeadingWhitespace(buf)
|
||||
switch {
|
||||
case hasPrefix(trimmed, "---"):
|
||||
fm := parseFrontmatter(buf)
|
||||
if fm == nil {
|
||||
return false
|
||||
}
|
||||
if *license == "" {
|
||||
*license = fm.License
|
||||
}
|
||||
if *readmeName == "" && len(fm.BaseModel) > 0 {
|
||||
*readmeName = fm.BaseModel[0]
|
||||
}
|
||||
return true
|
||||
case hasPrefix(trimmed, "{"):
|
||||
var cfg hfConfig
|
||||
if err := json.Unmarshal(buf, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
applyHFConfig(md, &cfg)
|
||||
if *configName == "" && cfg.NameOrPath != "" {
|
||||
*configName = cfg.NameOrPath
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// applyHFConfig folds the subset of HF config.json fields we surface in our
|
||||
// metadata onto md. Fields already populated on md are left alone — earlier
|
||||
// content-derived values (Quantization, TensorCount, etc., from header bytes)
|
||||
// always win over producer-declared ones in case of conflict.
|
||||
func applyHFConfig(md *pkg.SafeTensorsModelInfo, cfg *hfConfig) {
|
||||
if md.Architecture == "" && len(cfg.Architectures) > 0 {
|
||||
md.Architecture = cfg.Architectures[0]
|
||||
}
|
||||
if md.TorchDtype == "" {
|
||||
md.TorchDtype = cfg.TorchDtype
|
||||
}
|
||||
if md.TransformersVersion == "" {
|
||||
md.TransformersVersion = cfg.TransformersVersion
|
||||
}
|
||||
}
|
||||
|
||||
// pickSafeTensorsName implements the documented naming precedence chain:
|
||||
//
|
||||
// 1. config.json _name_or_path (path.Base, so "org/Model" → "Model")
|
||||
// 2. OCI manifest title (deferred to a follow-up; reserved here)
|
||||
// 3. Architecture-Parameters synthetic (only when both are populated)
|
||||
// 4. parent directory of the group (dir-scan only — OCI has no useful path)
|
||||
//
|
||||
// Returns "" to signal the merge processor should drop the group rather than
|
||||
// invent a name.
|
||||
func pickSafeTensorsName(merged pkg.Package, groupKey, nameOrPath string) string {
|
||||
md, _ := merged.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
|
||||
if nameOrPath != "" {
|
||||
return path.Base(nameOrPath)
|
||||
}
|
||||
// 2. OCI manifest title — follow-up.
|
||||
|
||||
if md.Architecture != "" && md.Parameters != "" {
|
||||
return md.Architecture + "-" + md.Parameters
|
||||
}
|
||||
|
||||
if groupKey != ociGroupKey {
|
||||
base := path.Base(groupKey)
|
||||
if base != "" && base != "." && base != "/" {
|
||||
return base
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// --- Relocated enrichment helpers ----------------------------------------
|
||||
//
|
||||
// These types and functions used to live in the parser files; they moved here
|
||||
// when the parsers shrank to "just decode the safetensors-specific format" and
|
||||
// every resolver-backed read centralized in the merge processor.
|
||||
|
||||
// hfConfig captures only the Hugging Face config.json keys this cataloger
// consumes. It is decoded with the default encoding/json settings, so any
// other key in the document is ignored.
type hfConfig struct {
	// Architectures is the "architectures" list; only its first entry is
	// surfaced in package metadata.
	Architectures []string `json:"architectures"`
	// TorchDtype is the "torch_dtype" string.
	TorchDtype string `json:"torch_dtype"`
	// TransformersVersion is the "transformers_version" string.
	TransformersVersion string `json:"transformers_version"`
	// NameOrPath is the "_name_or_path" string, used as a naming candidate.
	NameOrPath string `json:"_name_or_path"`
}
|
||||
|
||||
// readmeFrontmatter is the normalized result of parsing a README's YAML
// frontmatter: just the fields this cataloger uses.
type readmeFrontmatter struct {
	// License is the frontmatter "license" value, passed on to license
	// resolution as-is.
	License string `yaml:"license"`
	// BaseModel lists the frontmatter "base_model" entries; a scalar value in
	// the document is represented here as a one-element slice.
	BaseModel []string `yaml:"base_model"`
}
|
||||
|
||||
func readDirHFConfig(resolver file.Resolver, p string) (*file.Location, *hfConfig) {
|
||||
locations, err := resolver.FilesByPath(p)
|
||||
if err != nil || len(locations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
rc, err := resolver.FileContentsByLocation(locations[0])
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer internal.CloseAndLogError(rc, p)
|
||||
|
||||
var cfg hfConfig
|
||||
if err := json.NewDecoder(rc).Decode(&cfg); err != nil {
|
||||
log.Debugf("failed to decode %s: %v", p, err)
|
||||
return nil, nil
|
||||
}
|
||||
return &locations[0], &cfg
|
||||
}
|
||||
|
||||
func readDirReadmeFrontmatter(resolver file.Resolver, p string) (*file.Location, *readmeFrontmatter) {
|
||||
locations, err := resolver.FilesByPath(p)
|
||||
if err != nil || len(locations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
rc, err := resolver.FileContentsByLocation(locations[0])
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer internal.CloseAndLogError(rc, p)
|
||||
|
||||
buf, err := io.ReadAll(io.LimitReader(rc, 1024*1024))
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
fm := parseFrontmatter(buf)
|
||||
if fm == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return &locations[0], fm
|
||||
}
|
||||
|
||||
// parseFrontmatter pulls the YAML block between the first and second "---"
|
||||
// lines of a file (if present) and decodes the fields we care about. base_model
|
||||
// is decoded via yaml.Node so a scalar value ("org/model") doesn't fail the
|
||||
// whole block.
|
||||
func parseFrontmatter(buf []byte) *readmeFrontmatter {
|
||||
trimmed := bytes.TrimLeft(buf, "\xef\xbb\xbf \t\r\n")
|
||||
if !bytes.HasPrefix(trimmed, []byte("---")) {
|
||||
return nil
|
||||
}
|
||||
rest := trimmed[3:]
|
||||
if i := bytes.IndexByte(rest, '\n'); i >= 0 {
|
||||
rest = rest[i+1:]
|
||||
}
|
||||
end := bytes.Index(rest, []byte("\n---"))
|
||||
if end < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var raw struct {
|
||||
License string `yaml:"license"`
|
||||
BaseModel yaml.Node `yaml:"base_model"`
|
||||
}
|
||||
if err := yaml.Unmarshal(rest[:end], &raw); err != nil {
|
||||
log.Debugf("failed to parse README frontmatter: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
fm := readmeFrontmatter{License: raw.License}
|
||||
switch raw.BaseModel.Kind {
|
||||
case yaml.ScalarNode:
|
||||
if raw.BaseModel.Value != "" {
|
||||
fm.BaseModel = []string{raw.BaseModel.Value}
|
||||
}
|
||||
case yaml.SequenceNode:
|
||||
_ = raw.BaseModel.Decode(&fm.BaseModel)
|
||||
}
|
||||
return &fm
|
||||
}
|
||||
|
||||
// hasPrefix reports whether the byte slice b begins with the bytes of s. An
// empty s is a prefix of every b, including a nil one.
func hasPrefix(b []byte, s string) bool {
	if len(s) > len(b) {
		return false
	}
	for i := 0; i < len(s); i++ {
		if b[i] != s[i] {
			return false
		}
	}
	return true
}
|
||||
|
||||
// trimLeadingWhitespace returns b with every leading ASCII whitespace byte
// (space, tab, CR, LF) and UTF-8 byte-order mark (EF BB BF) removed, in
// whatever order they appear. The result is a sub-slice of b, not a copy.
//
// A BOM, when present, is normally the very first bytes of a file and may be
// followed by blank lines, so whitespace and BOM are stripped in a single
// loop rather than "whitespace first, then one BOM". This accepts the same
// leading bytes as the trimming parseFrontmatter does, so content sniffing and
// frontmatter parsing agree on where a document starts.
func trimLeadingWhitespace(b []byte) []byte {
	for len(b) > 0 {
		switch {
		case b[0] == ' ' || b[0] == '\t' || b[0] == '\r' || b[0] == '\n':
			b = b[1:]
		case len(b) >= 3 && b[0] == 0xEF && b[1] == 0xBB && b[2] == 0xBF:
			b = b[3:]
		default:
			return b
		}
	}
	return b
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user