diff --git a/syft/pkg/cataloger/ai/parse_safetensors_model.go b/syft/pkg/cataloger/ai/parse_safetensors_model.go index 1a7b81fe0..a1e70d655 100644 --- a/syft/pkg/cataloger/ai/parse_safetensors_model.go +++ b/syft/pkg/cataloger/ai/parse_safetensors_model.go @@ -8,6 +8,7 @@ import ( "io" "path" "path/filepath" + "strconv" "strings" "gopkg.in/yaml.v3" @@ -213,27 +214,34 @@ func parseFrontmatter(buf []byte) *readmeFrontmatter { if end < 0 { return nil } - var fm readmeFrontmatter - if err := yaml.Unmarshal(rest[:end], &fm); err != nil { + + // base_model may be either a scalar ("org/model") or a sequence; decode it + // as a yaml.Node so a scalar value does not fail the whole block. + var raw struct { + License string `yaml:"license"` + BaseModel yaml.Node `yaml:"base_model"` + } + if err := yaml.Unmarshal(rest[:end], &raw); err != nil { log.Debugf("failed to parse README frontmatter: %v", err) return nil } - // base_model may also appear as a scalar; yaml.Unmarshal will fail silently in that case. - if fm.License == "" && len(fm.BaseModel) == 0 { - var alt struct { - License string `yaml:"license"` - BaseModel string `yaml:"base_model"` - } - if err := yaml.Unmarshal(rest[:end], &alt); err == nil { - fm.License = alt.License - if alt.BaseModel != "" { - fm.BaseModel = []string{alt.BaseModel} - } + + fm := readmeFrontmatter{License: raw.License} + switch raw.BaseModel.Kind { + case yaml.ScalarNode: + if raw.BaseModel.Value != "" { + fm.BaseModel = []string{raw.BaseModel.Value} } + case yaml.SequenceNode: + _ = raw.BaseModel.Decode(&fm.BaseModel) } return &fm } +// defaultModelName is the fallback package name when no model name can be +// derived from sibling files, the file path, or OCI companion layers. +const defaultModelName = "safetensors-model" + // modelNameFromPath turns "/models/foo/model.safetensors" into "foo". // For a bare filename "weights.safetensors" we return "weights". func modelNameFromPath(p string) string { @@ -252,7 +260,7 @@ func modelNameFromIndexPath(p string) string { if dir != "" && dir != "." && dir != string(filepath.Separator) { return dir } - return "safetensors-model" + return defaultModelName } // formatParameterCount prints a count like 6_700_000_000 as "6.7B" using B/M/K @@ -274,8 +282,8 @@ func formatParameterCount(n uint64) string { // "71.90GB". Non-numeric inputs are passed through unchanged so we never lose // producer-declared strings such as "71.90GB". func formatByteSize(s string) string { - var n uint64 - if _, err := fmt.Sscanf(s, "%d", &n); err != nil || n == 0 { + n, err := strconv.ParseUint(s, 10, 64) + if err != nil || n == 0 { return s } const ( diff --git a/syft/pkg/cataloger/ai/parse_safetensors_oci.go b/syft/pkg/cataloger/ai/parse_safetensors_oci.go index 5da4ce92d..499daf944 100644 --- a/syft/pkg/cataloger/ai/parse_safetensors_oci.go +++ b/syft/pkg/cataloger/ai/parse_safetensors_oci.go @@ -72,6 +72,9 @@ func parseSafeTensorsOCIConfig(_ context.Context, resolver file.Resolver, _ *gen } name, license := enrichFromDockerAILayers(resolver, &md) + if name == "" { + name = defaultModelName + } p := newSafeTensorsPackage( &md, @@ -99,8 +102,18 @@ func enrichFromDockerAILayers(resolver file.Resolver, md *pkg.SafeTensorsModelIn if err != nil { log.Debugf("failed to list docker AI model-file layers: %v", err) } + + // Collect name candidates separately so precedence does not depend on the + // order the resolver returns layers in. config.json's _name_or_path wins over + // a README base_model, matching enrichFromSiblings. + var configName, readmeName string for _, loc := range modelFileLocations { - readAndClassifyDockerAILayer(resolver, loc, md, &name, &license) + readAndClassifyDockerAILayer(resolver, loc, md, &configName, &readmeName, &license) + } + + name = configName + if name == "" { + name = readmeName } if license == "" { @@ -113,7 +126,7 @@ func enrichFromDockerAILayers(resolver file.Resolver, md *pkg.SafeTensorsModelIn // readAndClassifyDockerAILayer fetches a single Docker AI model-file layer and // passes its contents to classifyAndMerge. Split out from the calling loop so // the resolver handle is closed via defer on every iteration. -func readAndClassifyDockerAILayer(resolver file.Resolver, loc file.Location, md *pkg.SafeTensorsModelInfo, name, license *string) { +func readAndClassifyDockerAILayer(resolver file.Resolver, loc file.Location, md *pkg.SafeTensorsModelInfo, configName, readmeName, license *string) { rc, err := resolver.FileContentsByLocation(loc) if err != nil { return @@ -124,13 +137,13 @@ func readAndClassifyDockerAILayer(resolver file.Resolver, loc file.Location, md if err != nil { return } - classifyAndMerge(buf, md, name, license) + classifyAndMerge(buf, md, configName, readmeName, license) } // classifyAndMerge sniffs a vnd.docker.ai.model.file blob (which can be README.md, // config.json, generation_config.json, tokenizer.json, etc.) and folds useful // fields into the metadata struct and out-parameters. -func classifyAndMerge(buf []byte, md *pkg.SafeTensorsModelInfo, name, license *string) { +func classifyAndMerge(buf []byte, md *pkg.SafeTensorsModelInfo, configName, readmeName, license *string) { trimmed := trimLeadingWhitespace(buf) switch { case hasPrefix(trimmed, "---"): @@ -138,8 +151,8 @@ func classifyAndMerge(buf []byte, md *pkg.SafeTensorsModelInfo, name, license *s if *license == "" { *license = fm.License } - if *name == "" && len(fm.BaseModel) > 0 { - *name = lastPathSegment(fm.BaseModel[0]) + if *readmeName == "" && len(fm.BaseModel) > 0 { + *readmeName = lastPathSegment(fm.BaseModel[0]) } } case hasPrefix(trimmed, "{"): @@ -156,8 +169,8 @@ func classifyAndMerge(buf []byte, md *pkg.SafeTensorsModelInfo, name, license *s if md.TransformersVersion == "" { md.TransformersVersion = cfg.TransformersVersion } - if *name == "" && cfg.NameOrPath != "" { - *name = lastPathSegment(cfg.NameOrPath) + if *configName == "" && cfg.NameOrPath != "" { + *configName = lastPathSegment(cfg.NameOrPath) } } } diff --git a/syft/pkg/cataloger/ai/parse_safetensors_test.go b/syft/pkg/cataloger/ai/parse_safetensors_test.go new file mode 100644 index 000000000..783edf5a0 --- /dev/null +++ b/syft/pkg/cataloger/ai/parse_safetensors_test.go @@ -0,0 +1,412 @@ +package ai + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "io" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/anchore/syft/syft/file" + "github.com/anchore/syft/syft/pkg" + "github.com/anchore/syft/syft/pkg/cataloger/internal/pkgtest" +) + +// buildSafeTensorsFile builds the on-disk bytes of a .safetensors file: an +// 8-byte little-endian header length followed by the JSON header. Tensor data +// is omitted because the parser only reads the header. +func buildSafeTensorsFile(t *testing.T, metadata map[string]string, tensors map[string]safeTensorsEntry) []byte { + t.Helper() + raw := map[string]any{} + if metadata != nil { + raw["__metadata__"] = metadata + } + for name, entry := range tensors { + raw[name] = entry + } + body, err := json.Marshal(raw) + require.NoError(t, err) + + out := make([]byte, 8+len(body)) + binary.LittleEndian.PutUint64(out[:8], uint64(len(body))) + copy(out[8:], body) + return out +} + +func TestSafeTensorsCataloger_singleFile(t *testing.T) { + userMeta := map[string]string{"format": "pt"} + tensors := map[string]safeTensorsEntry{ + "model.embed.weight": {DType: "BF16", Shape: []int64{1000, 16}, DataOffsets: []int64{0, 32000}}, + "model.layer.weight": {DType: "BF16", Shape: []int64{16, 16}, DataOffsets: []int64{32000, 32512}}, + } + // the dedicated hash test below locks the algorithm; here we only assert the + // cataloger wires the header hash through to the package metadata. + wantHash := (&safeTensorsHeader{metadata: userMeta, tensors: tensors}).metadataHash() + + dir := t.TempDir() + modelDir := filepath.Join(dir, "models") + require.NoError(t, os.MkdirAll(modelDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(modelDir, "model.safetensors"), buildSafeTensorsFile(t, userMeta, tensors), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(modelDir, "config.json"), + []byte(`{"architectures":["LlamaForCausalLM"],"torch_dtype":"bfloat16","transformers_version":"4.40.0","_name_or_path":"meta-llama/Llama-3-8B"}`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(modelDir, "README.md"), + []byte("---\nlicense: Apache-2.0\nbase_model:\n - meta-llama/Llama-3\n---\n# Llama 3\n"), 0o644)) + + expected := []pkg.Package{ + { + Name: "Llama-3-8B", + Type: pkg.ModelPkg, + Licenses: pkg.NewLicenseSet( + pkg.NewLicenseFromFields("Apache-2.0", "", nil), + ), + Metadata: pkg.SafeTensorsModelInfo{ + Format: "safetensors", + Architecture: "LlamaForCausalLM", + Quantization: "BF16", + Parameters: "16.26K", + TensorCount: 2, + TorchDtype: "bfloat16", + TransformersVersion: "4.40.0", + ShardCount: 1, + UserMetadata: userMeta, + MetadataHash: wantHash, + }, + }, + } + + pkgtest.NewCatalogTester(). + FromDirectory(t, dir). + Expects(expected, nil). + IgnoreLocationLayer(). + IgnorePackageFields("FoundBy", "Locations"). + TestCataloger(t, NewSafeTensorsCataloger()) +} + +func TestSafeTensorsCataloger_shardedIndex(t *testing.T) { + dir := t.TempDir() + modelDir := filepath.Join(dir, "my-model") + require.NoError(t, os.MkdirAll(modelDir, 0o755)) + index := `{ + "metadata": {"total_size": 16000000000}, + "weight_map": { + "layer.0.weight": "model-00001-of-00002.safetensors", + "layer.1.weight": "model-00001-of-00002.safetensors", + "layer.2.weight": "model-00002-of-00002.safetensors" + } + }` + require.NoError(t, os.WriteFile(filepath.Join(modelDir, "model.safetensors.index.json"), []byte(index), 0o644)) + + expected := []pkg.Package{ + { + Name: "my-model", + Type: pkg.ModelPkg, + Licenses: pkg.NewLicenseSet(), + Metadata: pkg.SafeTensorsModelInfo{ + Format: "safetensors", + TensorCount: 3, + ShardCount: 2, + TotalSize: "14.90GB", + }, + }, + } + + pkgtest.NewCatalogTester(). + FromDirectory(t, dir). + Expects(expected, nil). + IgnoreLocationLayer(). + IgnorePackageFields("FoundBy", "Locations"). + TestCataloger(t, NewSafeTensorsCataloger()) +} + +func TestParseSafeTensorsOCIConfig(t *testing.T) { + configBlob := []byte(`{"config":{"format":"safetensors","quantization":"Q4_K_M","parameters":"8B","size":"16.00GB","safetensors":{"tensor_count":291}}}`) + + t.Run("enriches from companion layers", func(t *testing.T) { + dir := t.TempDir() + readmePath := filepath.Join(dir, "README.md") + require.NoError(t, os.WriteFile(readmePath, + []byte("---\nlicense: mit\nbase_model:\n - org/My-Model\n---\n# card\n"), 0o644)) + hfConfigPath := filepath.Join(dir, "config.json") + require.NoError(t, os.WriteFile(hfConfigPath, + []byte(`{"architectures":["Qwen3ForCausalLM"],"torch_dtype":"bfloat16"}`), 0o644)) + + resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{ + dockerAIModelFileMediaType: {file.NewLocation(readmePath), file.NewLocation(hfConfigPath)}, + }) + + pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob)) + require.NoError(t, err) + require.Len(t, pkgs, 1) + + p := pkgs[0] + assert.Equal(t, "My-Model", p.Name) + assert.Equal(t, pkg.ModelPkg, p.Type) + assertHasLicense(t, p, "mit") + + md := p.Metadata.(pkg.SafeTensorsModelInfo) + assert.Equal(t, "safetensors", md.Format) + assert.Equal(t, "Qwen3ForCausalLM", md.Architecture) + assert.Equal(t, "bfloat16", md.TorchDtype) + assert.Equal(t, "Q4_K_M", md.Quantization) + assert.Equal(t, "8B", md.Parameters) + assert.Equal(t, "16.00GB", md.TotalSize) + assert.Equal(t, uint64(291), md.TensorCount) + }) + + t.Run("falls back to license layer", func(t *testing.T) { + dir := t.TempDir() + readmePath := filepath.Join(dir, "README.md") + require.NoError(t, os.WriteFile(readmePath, + []byte("---\nbase_model:\n - org/My-Model\n---\n"), 0o644)) + licensePath := filepath.Join(dir, "LICENSE") + require.NoError(t, os.WriteFile(licensePath, + []byte(" Apache License\n Version 2.0, January 2004\n"), 0o644)) + + resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{ + dockerAIModelFileMediaType: {file.NewLocation(readmePath)}, + dockerAILicenseMediaType: {file.NewLocation(licensePath)}, + }) + + pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob)) + require.NoError(t, err) + require.Len(t, pkgs, 1) + assertHasLicense(t, pkgs[0], "Apache-2.0") + }) + + t.Run("config _name_or_path wins over README base_model regardless of layer order", func(t *testing.T) { + dir := t.TempDir() + readmePath := filepath.Join(dir, "README.md") + require.NoError(t, os.WriteFile(readmePath, []byte("---\nbase_model:\n - org/Readme-Name\n---\n"), 0o644)) + hfConfigPath := filepath.Join(dir, "config.json") + require.NoError(t, os.WriteFile(hfConfigPath, []byte(`{"_name_or_path":"org/Config-Name"}`), 0o644)) + + // both layer orderings must yield the same (config-derived) name + orderings := [][]file.Location{ + {file.NewLocation(readmePath), file.NewLocation(hfConfigPath)}, + {file.NewLocation(hfConfigPath), file.NewLocation(readmePath)}, + } + for _, locs := range orderings { + resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{ + dockerAIModelFileMediaType: locs, + }) + pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob)) + require.NoError(t, err) + require.Len(t, pkgs, 1) + assert.Equal(t, "Config-Name", pkgs[0].Name) + } + }) + + t.Run("falls back to default name when none derivable", func(t *testing.T) { + resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{}) + + pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob)) + require.NoError(t, err) + require.Len(t, pkgs, 1) + assert.Equal(t, "safetensors-model", pkgs[0].Name, "model must still be emitted, not dropped") + }) + + t.Run("ignores non-safetensors format", func(t *testing.T) { + ggufBlob := []byte(`{"config":{"format":"gguf","quantization":"Q4_K_M"}}`) + resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{}) + + pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(ggufBlob)) + require.NoError(t, err) + assert.Empty(t, pkgs) + }) +} + +func TestSafeTensorsMergeProcessor(t *testing.T) { + named := pkg.Package{Name: "model-a", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "aaaa"}} + nameless := pkg.Package{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "bbbb"}} + + t.Run("merges nameless into named parts", func(t *testing.T) { + out, _, err := safeTensorsMergeProcessor([]pkg.Package{named, nameless}, nil, nil) + require.NoError(t, err) + require.Len(t, out, 1) + assert.Equal(t, "model-a", out[0].Name) + md := out[0].Metadata.(pkg.SafeTensorsModelInfo) + require.Len(t, md.Parts, 1) + assert.Empty(t, md.Parts[0].MetadataHash, "nameless part hash should be cleared") + }) + + t.Run("drops result when no named package", func(t *testing.T) { + out, _, err := safeTensorsMergeProcessor([]pkg.Package{nameless}, nil, nil) + require.NoError(t, err) + assert.Empty(t, out) + }) + + t.Run("passes through upstream error", func(t *testing.T) { + sentinel := assert.AnError + out, _, err := safeTensorsMergeProcessor([]pkg.Package{named}, nil, sentinel) + assert.Equal(t, sentinel, err) + assert.Len(t, out, 1) + }) +} + +func configReader(blob []byte) file.LocationReadCloser { + return file.NewLocationReadCloser(file.NewLocation("/config.json"), io.NopCloser(bytes.NewReader(blob))) +} + +func assertHasLicense(t *testing.T, p pkg.Package, value string) { + t.Helper() + for _, l := range p.Licenses.ToSlice() { + if l.Value == value { + return + } + } + t.Errorf("expected license %q, got %+v", value, p.Licenses.ToSlice()) +} + +func TestReadSafeTensorsHeader(t *testing.T) { + t.Run("valid header", func(t *testing.T) { + data := buildSafeTensorsFile(t, map[string]string{"format": "pt"}, map[string]safeTensorsEntry{ + "w": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}}, + }) + h, n, err := readSafeTensorsHeader(bytes.NewReader(data)) + require.NoError(t, err) + assert.Equal(t, uint64(len(data)-8), n) + assert.Len(t, h.tensors, 1) + assert.Equal(t, "pt", h.metadata["format"]) + }) + + t.Run("zero-length header", func(t *testing.T) { + var buf [8]byte // length prefix of 0 + _, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:])) + require.Error(t, err) + }) + + t.Run("truncated body", func(t *testing.T) { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], 100) // claims 100 bytes but supplies none + _, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:])) + require.Error(t, err) + }) +} + +func TestSafeTensorsHeader_metadataHash(t *testing.T) { + base := &safeTensorsHeader{ + metadata: map[string]string{"format": "pt"}, + tensors: map[string]safeTensorsEntry{ + "a.weight": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}}, + "b.weight": {DType: "F16", Shape: []int64{4}, DataOffsets: []int64{16, 24}}, + }, + } + + // deterministic across calls and independent of map insertion order + reordered := &safeTensorsHeader{ + metadata: map[string]string{"format": "pt"}, + tensors: map[string]safeTensorsEntry{ + "b.weight": {DType: "F16", Shape: []int64{4}, DataOffsets: []int64{16, 24}}, + "a.weight": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}}, + }, + } + assert.Equal(t, base.metadataHash(), reordered.metadataHash()) + assert.Len(t, base.metadataHash(), 16) + + // changing a tensor changes the hash + changed := &safeTensorsHeader{ + metadata: base.metadata, + tensors: map[string]safeTensorsEntry{ + "a.weight": {DType: "F32", Shape: []int64{2, 3}, DataOffsets: []int64{0, 24}}, + "b.weight": {DType: "F16", Shape: []int64{4}, DataOffsets: []int64{24, 32}}, + }, + } + assert.NotEqual(t, base.metadataHash(), changed.metadataHash()) + + // changing __metadata__ changes the hash + differentMeta := &safeTensorsHeader{metadata: map[string]string{"format": "np"}, tensors: base.tensors} + assert.NotEqual(t, base.metadataHash(), differentMeta.metadataHash()) +} + +func TestSafeTensorsHeader_parameterCountAndDType(t *testing.T) { + h := &safeTensorsHeader{tensors: map[string]safeTensorsEntry{ + "big": {DType: "BF16", Shape: []int64{1000, 16}}, + "small": {DType: "F32", Shape: []int64{16, 16}}, + "scalar": {DType: "F32", Shape: []int64{}}, // empty shape contributes 1 + }} + assert.Equal(t, uint64(1000*16+16*16+1), h.parameterCount()) + assert.Equal(t, "BF16", h.dominantDType()) +} + +func TestNormalizeDType(t *testing.T) { + cases := map[string]string{ + "BF16": "BF16", + "float16": "F16", + "FP32": "F32", + "int8": "I8", + "U8": "U8", + "bool": "BOOL", + "weird": "WEIRD", + } + for in, want := range cases { + assert.Equalf(t, want, normalizeDType(in), "normalizeDType(%q)", in) + } +} + +func TestFormatParameterCount(t *testing.T) { + cases := map[uint64]string{ + 512: "512", + 16256: "16.26K", + 2_680_000_000: "2.68B", + 35_000_000: "35.00M", + } + for in, want := range cases { + assert.Equalf(t, want, formatParameterCount(in), "formatParameterCount(%d)", in) + } +} + +func TestFormatByteSize(t *testing.T) { + cases := map[string]string{ + "16000000000": "14.90GB", + "2048": "2.00KB", + "500": "500B", + "71.90GB": "71.90GB", // non-numeric passes through unchanged + "": "", + } + for in, want := range cases { + assert.Equalf(t, want, formatByteSize(in), "formatByteSize(%q)", in) + } +} + +func TestParseFrontmatter(t *testing.T) { + t.Run("list base_model", func(t *testing.T) { + fm := parseFrontmatter([]byte("---\nlicense: mit\nbase_model:\n - org/Model\n---\nbody")) + require.NotNil(t, fm) + assert.Equal(t, "mit", fm.License) + assert.Equal(t, []string{"org/Model"}, fm.BaseModel) + }) + + t.Run("scalar base_model", func(t *testing.T) { + fm := parseFrontmatter([]byte("---\nlicense: apache-2.0\nbase_model: org/Model\n---\n")) + require.NotNil(t, fm) + assert.Equal(t, "apache-2.0", fm.License) + assert.Equal(t, []string{"org/Model"}, fm.BaseModel) + }) + + t.Run("leading BOM", func(t *testing.T) { + fm := parseFrontmatter([]byte("\xef\xbb\xbf---\nlicense: mit\n---\n")) + require.NotNil(t, fm) + assert.Equal(t, "mit", fm.License) + }) + + t.Run("no frontmatter", func(t *testing.T) { + assert.Nil(t, parseFrontmatter([]byte("# just a heading\n"))) + }) + + t.Run("unterminated frontmatter", func(t *testing.T) { + assert.Nil(t, parseFrontmatter([]byte("---\nlicense: mit\n"))) + }) +} + +func TestModelNameFromPath(t *testing.T) { + assert.Equal(t, "foo", modelNameFromPath("/models/foo/model.safetensors")) + assert.Equal(t, "weights", modelNameFromPath("weights.safetensors")) + assert.Equal(t, "my-model", modelNameFromIndexPath("/models/my-model/model.safetensors.index.json")) + assert.Equal(t, "safetensors-model", modelNameFromIndexPath("model.safetensors.index.json")) +} diff --git a/syft/source/ocimodelsource/registry_client_safetensors_test.go b/syft/source/ocimodelsource/registry_client_safetensors_test.go new file mode 100644 index 000000000..2ca91b4a2 --- /dev/null +++ b/syft/source/ocimodelsource/registry_client_safetensors_test.go @@ -0,0 +1,55 @@ +package ocimodelsource + +import ( + "testing" + + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/types" + "github.com/stretchr/testify/assert" + + "github.com/anchore/syft/syft/source" +) + +func TestDetectModelFormat(t *testing.T) { + tests := []struct { + name string + gguf int + safetensors int + expected string + }{ + {name: "gguf only", gguf: 2, safetensors: 0, expected: modelFormatGGUF}, + {name: "safetensors only", gguf: 0, safetensors: 3, expected: modelFormatSafeTensors}, + {name: "both prefers gguf", gguf: 1, safetensors: 1, expected: modelFormatGGUF}, + {name: "neither", gguf: 0, safetensors: 0, expected: ""}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expected, detectModelFormat(test.gguf, test.safetensors)) + }) + } +} + +func TestExtractSafeTensorsLayers(t *testing.T) { + manifest := &v1.Manifest{Layers: []v1.Descriptor{ + {MediaType: types.MediaType(safetensorsLayerMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "a"}}, + {MediaType: types.MediaType(ggufLayerMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "b"}}, + {MediaType: types.MediaType(safetensorsLayerMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "c"}}, + }} + assert.Len(t, extractSafeTensorsLayers(manifest), 2) +} + +func TestExtractCompanionLayers(t *testing.T) { + manifest := &v1.Manifest{Layers: []v1.Descriptor{ + {MediaType: types.MediaType(modelFileMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "readme"}}, + {MediaType: types.MediaType(licenseMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "license"}}, + {MediaType: types.MediaType(safetensorsLayerMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "weights"}}, + {MediaType: types.DockerLayer, Digest: v1.Hash{Algorithm: "sha256", Hex: "other"}}, + }} + // only the model.file and license layers should be selected (not weights or arbitrary layers) + assert.Len(t, extractCompanionLayers(manifest), 2) +} + +func TestCalculateTotalSize(t *testing.T) { + layers := []source.LayerMetadata{{Size: 100}, {Size: 250}, {Size: 0}} + assert.Equal(t, int64(350), calculateTotalSize(layers)) +}