mirror of
https://github.com/anchore/syft.git
synced 2026-07-05 02:28:25 +02:00
fix: non deterministic name on iteration
Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
parent
52653e24fc
commit
1f035bc369
@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
@ -213,27 +214,34 @@ func parseFrontmatter(buf []byte) *readmeFrontmatter {
|
||||
if end < 0 {
|
||||
return nil
|
||||
}
|
||||
var fm readmeFrontmatter
|
||||
if err := yaml.Unmarshal(rest[:end], &fm); err != nil {
|
||||
|
||||
// base_model may be either a scalar ("org/model") or a sequence; decode it
|
||||
// as a yaml.Node so a scalar value does not fail the whole block.
|
||||
var raw struct {
|
||||
License string `yaml:"license"`
|
||||
BaseModel yaml.Node `yaml:"base_model"`
|
||||
}
|
||||
if err := yaml.Unmarshal(rest[:end], &raw); err != nil {
|
||||
log.Debugf("failed to parse README frontmatter: %v", err)
|
||||
return nil
|
||||
}
|
||||
// base_model may also appear as a scalar; yaml.Unmarshal will fail silently in that case.
|
||||
if fm.License == "" && len(fm.BaseModel) == 0 {
|
||||
var alt struct {
|
||||
License string `yaml:"license"`
|
||||
BaseModel string `yaml:"base_model"`
|
||||
}
|
||||
if err := yaml.Unmarshal(rest[:end], &alt); err == nil {
|
||||
fm.License = alt.License
|
||||
if alt.BaseModel != "" {
|
||||
fm.BaseModel = []string{alt.BaseModel}
|
||||
}
|
||||
|
||||
fm := readmeFrontmatter{License: raw.License}
|
||||
switch raw.BaseModel.Kind {
|
||||
case yaml.ScalarNode:
|
||||
if raw.BaseModel.Value != "" {
|
||||
fm.BaseModel = []string{raw.BaseModel.Value}
|
||||
}
|
||||
case yaml.SequenceNode:
|
||||
_ = raw.BaseModel.Decode(&fm.BaseModel)
|
||||
}
|
||||
return &fm
|
||||
}
|
||||
|
||||
// defaultModelName is the fallback package name when no model name can be
|
||||
// derived from sibling files, the file path, or OCI companion layers.
|
||||
const defaultModelName = "safetensors-model"
|
||||
|
||||
// modelNameFromPath turns "/models/foo/model.safetensors" into "foo".
|
||||
// For a bare filename "weights.safetensors" we return "weights".
|
||||
func modelNameFromPath(p string) string {
|
||||
@ -252,7 +260,7 @@ func modelNameFromIndexPath(p string) string {
|
||||
if dir != "" && dir != "." && dir != string(filepath.Separator) {
|
||||
return dir
|
||||
}
|
||||
return "safetensors-model"
|
||||
return defaultModelName
|
||||
}
|
||||
|
||||
// formatParameterCount prints a count like 6_700_000_000 as "6.7B" using B/M/K
|
||||
@ -274,8 +282,8 @@ func formatParameterCount(n uint64) string {
|
||||
// "71.90GB". Non-numeric inputs are passed through unchanged so we never lose
|
||||
// producer-declared strings such as "71.90GB".
|
||||
func formatByteSize(s string) string {
|
||||
var n uint64
|
||||
if _, err := fmt.Sscanf(s, "%d", &n); err != nil || n == 0 {
|
||||
n, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil || n == 0 {
|
||||
return s
|
||||
}
|
||||
const (
|
||||
|
||||
@ -72,6 +72,9 @@ func parseSafeTensorsOCIConfig(_ context.Context, resolver file.Resolver, _ *gen
|
||||
}
|
||||
|
||||
name, license := enrichFromDockerAILayers(resolver, &md)
|
||||
if name == "" {
|
||||
name = defaultModelName
|
||||
}
|
||||
|
||||
p := newSafeTensorsPackage(
|
||||
&md,
|
||||
@ -99,8 +102,18 @@ func enrichFromDockerAILayers(resolver file.Resolver, md *pkg.SafeTensorsModelIn
|
||||
if err != nil {
|
||||
log.Debugf("failed to list docker AI model-file layers: %v", err)
|
||||
}
|
||||
|
||||
// Collect name candidates separately so precedence does not depend on the
|
||||
// order the resolver returns layers in. config.json's _name_or_path wins over
|
||||
// a README base_model, matching enrichFromSiblings.
|
||||
var configName, readmeName string
|
||||
for _, loc := range modelFileLocations {
|
||||
readAndClassifyDockerAILayer(resolver, loc, md, &name, &license)
|
||||
readAndClassifyDockerAILayer(resolver, loc, md, &configName, &readmeName, &license)
|
||||
}
|
||||
|
||||
name = configName
|
||||
if name == "" {
|
||||
name = readmeName
|
||||
}
|
||||
|
||||
if license == "" {
|
||||
@ -113,7 +126,7 @@ func enrichFromDockerAILayers(resolver file.Resolver, md *pkg.SafeTensorsModelIn
|
||||
// readAndClassifyDockerAILayer fetches a single Docker AI model-file layer and
|
||||
// passes its contents to classifyAndMerge. Split out from the calling loop so
|
||||
// the resolver handle is closed via defer on every iteration.
|
||||
func readAndClassifyDockerAILayer(resolver file.Resolver, loc file.Location, md *pkg.SafeTensorsModelInfo, name, license *string) {
|
||||
func readAndClassifyDockerAILayer(resolver file.Resolver, loc file.Location, md *pkg.SafeTensorsModelInfo, configName, readmeName, license *string) {
|
||||
rc, err := resolver.FileContentsByLocation(loc)
|
||||
if err != nil {
|
||||
return
|
||||
@ -124,13 +137,13 @@ func readAndClassifyDockerAILayer(resolver file.Resolver, loc file.Location, md
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
classifyAndMerge(buf, md, name, license)
|
||||
classifyAndMerge(buf, md, configName, readmeName, license)
|
||||
}
|
||||
|
||||
// classifyAndMerge sniffs a vnd.docker.ai.model.file blob (which can be README.md,
|
||||
// config.json, generation_config.json, tokenizer.json, etc.) and folds useful
|
||||
// fields into the metadata struct and out-parameters.
|
||||
func classifyAndMerge(buf []byte, md *pkg.SafeTensorsModelInfo, name, license *string) {
|
||||
func classifyAndMerge(buf []byte, md *pkg.SafeTensorsModelInfo, configName, readmeName, license *string) {
|
||||
trimmed := trimLeadingWhitespace(buf)
|
||||
switch {
|
||||
case hasPrefix(trimmed, "---"):
|
||||
@ -138,8 +151,8 @@ func classifyAndMerge(buf []byte, md *pkg.SafeTensorsModelInfo, name, license *s
|
||||
if *license == "" {
|
||||
*license = fm.License
|
||||
}
|
||||
if *name == "" && len(fm.BaseModel) > 0 {
|
||||
*name = lastPathSegment(fm.BaseModel[0])
|
||||
if *readmeName == "" && len(fm.BaseModel) > 0 {
|
||||
*readmeName = lastPathSegment(fm.BaseModel[0])
|
||||
}
|
||||
}
|
||||
case hasPrefix(trimmed, "{"):
|
||||
@ -156,8 +169,8 @@ func classifyAndMerge(buf []byte, md *pkg.SafeTensorsModelInfo, name, license *s
|
||||
if md.TransformersVersion == "" {
|
||||
md.TransformersVersion = cfg.TransformersVersion
|
||||
}
|
||||
if *name == "" && cfg.NameOrPath != "" {
|
||||
*name = lastPathSegment(cfg.NameOrPath)
|
||||
if *configName == "" && cfg.NameOrPath != "" {
|
||||
*configName = lastPathSegment(cfg.NameOrPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
412
syft/pkg/cataloger/ai/parse_safetensors_test.go
Normal file
412
syft/pkg/cataloger/ai/parse_safetensors_test.go
Normal file
@ -0,0 +1,412 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/anchore/syft/syft/file"
|
||||
"github.com/anchore/syft/syft/pkg"
|
||||
"github.com/anchore/syft/syft/pkg/cataloger/internal/pkgtest"
|
||||
)
|
||||
|
||||
// buildSafeTensorsFile builds the on-disk bytes of a .safetensors file: an
|
||||
// 8-byte little-endian header length followed by the JSON header. Tensor data
|
||||
// is omitted because the parser only reads the header.
|
||||
func buildSafeTensorsFile(t *testing.T, metadata map[string]string, tensors map[string]safeTensorsEntry) []byte {
|
||||
t.Helper()
|
||||
raw := map[string]any{}
|
||||
if metadata != nil {
|
||||
raw["__metadata__"] = metadata
|
||||
}
|
||||
for name, entry := range tensors {
|
||||
raw[name] = entry
|
||||
}
|
||||
body, err := json.Marshal(raw)
|
||||
require.NoError(t, err)
|
||||
|
||||
out := make([]byte, 8+len(body))
|
||||
binary.LittleEndian.PutUint64(out[:8], uint64(len(body)))
|
||||
copy(out[8:], body)
|
||||
return out
|
||||
}
|
||||
|
||||
func TestSafeTensorsCataloger_singleFile(t *testing.T) {
|
||||
userMeta := map[string]string{"format": "pt"}
|
||||
tensors := map[string]safeTensorsEntry{
|
||||
"model.embed.weight": {DType: "BF16", Shape: []int64{1000, 16}, DataOffsets: []int64{0, 32000}},
|
||||
"model.layer.weight": {DType: "BF16", Shape: []int64{16, 16}, DataOffsets: []int64{32000, 32512}},
|
||||
}
|
||||
// the dedicated hash test below locks the algorithm; here we only assert the
|
||||
// cataloger wires the header hash through to the package metadata.
|
||||
wantHash := (&safeTensorsHeader{metadata: userMeta, tensors: tensors}).metadataHash()
|
||||
|
||||
dir := t.TempDir()
|
||||
modelDir := filepath.Join(dir, "models")
|
||||
require.NoError(t, os.MkdirAll(modelDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(modelDir, "model.safetensors"), buildSafeTensorsFile(t, userMeta, tensors), 0o644))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(modelDir, "config.json"),
|
||||
[]byte(`{"architectures":["LlamaForCausalLM"],"torch_dtype":"bfloat16","transformers_version":"4.40.0","_name_or_path":"meta-llama/Llama-3-8B"}`), 0o644))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(modelDir, "README.md"),
|
||||
[]byte("---\nlicense: Apache-2.0\nbase_model:\n - meta-llama/Llama-3\n---\n# Llama 3\n"), 0o644))
|
||||
|
||||
expected := []pkg.Package{
|
||||
{
|
||||
Name: "Llama-3-8B",
|
||||
Type: pkg.ModelPkg,
|
||||
Licenses: pkg.NewLicenseSet(
|
||||
pkg.NewLicenseFromFields("Apache-2.0", "", nil),
|
||||
),
|
||||
Metadata: pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
Architecture: "LlamaForCausalLM",
|
||||
Quantization: "BF16",
|
||||
Parameters: "16.26K",
|
||||
TensorCount: 2,
|
||||
TorchDtype: "bfloat16",
|
||||
TransformersVersion: "4.40.0",
|
||||
ShardCount: 1,
|
||||
UserMetadata: userMeta,
|
||||
MetadataHash: wantHash,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
pkgtest.NewCatalogTester().
|
||||
FromDirectory(t, dir).
|
||||
Expects(expected, nil).
|
||||
IgnoreLocationLayer().
|
||||
IgnorePackageFields("FoundBy", "Locations").
|
||||
TestCataloger(t, NewSafeTensorsCataloger())
|
||||
}
|
||||
|
||||
func TestSafeTensorsCataloger_shardedIndex(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
modelDir := filepath.Join(dir, "my-model")
|
||||
require.NoError(t, os.MkdirAll(modelDir, 0o755))
|
||||
index := `{
|
||||
"metadata": {"total_size": 16000000000},
|
||||
"weight_map": {
|
||||
"layer.0.weight": "model-00001-of-00002.safetensors",
|
||||
"layer.1.weight": "model-00001-of-00002.safetensors",
|
||||
"layer.2.weight": "model-00002-of-00002.safetensors"
|
||||
}
|
||||
}`
|
||||
require.NoError(t, os.WriteFile(filepath.Join(modelDir, "model.safetensors.index.json"), []byte(index), 0o644))
|
||||
|
||||
expected := []pkg.Package{
|
||||
{
|
||||
Name: "my-model",
|
||||
Type: pkg.ModelPkg,
|
||||
Licenses: pkg.NewLicenseSet(),
|
||||
Metadata: pkg.SafeTensorsModelInfo{
|
||||
Format: "safetensors",
|
||||
TensorCount: 3,
|
||||
ShardCount: 2,
|
||||
TotalSize: "14.90GB",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
pkgtest.NewCatalogTester().
|
||||
FromDirectory(t, dir).
|
||||
Expects(expected, nil).
|
||||
IgnoreLocationLayer().
|
||||
IgnorePackageFields("FoundBy", "Locations").
|
||||
TestCataloger(t, NewSafeTensorsCataloger())
|
||||
}
|
||||
|
||||
func TestParseSafeTensorsOCIConfig(t *testing.T) {
|
||||
configBlob := []byte(`{"config":{"format":"safetensors","quantization":"Q4_K_M","parameters":"8B","size":"16.00GB","safetensors":{"tensor_count":291}}}`)
|
||||
|
||||
t.Run("enriches from companion layers", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
readmePath := filepath.Join(dir, "README.md")
|
||||
require.NoError(t, os.WriteFile(readmePath,
|
||||
[]byte("---\nlicense: mit\nbase_model:\n - org/My-Model\n---\n# card\n"), 0o644))
|
||||
hfConfigPath := filepath.Join(dir, "config.json")
|
||||
require.NoError(t, os.WriteFile(hfConfigPath,
|
||||
[]byte(`{"architectures":["Qwen3ForCausalLM"],"torch_dtype":"bfloat16"}`), 0o644))
|
||||
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{
|
||||
dockerAIModelFileMediaType: {file.NewLocation(readmePath), file.NewLocation(hfConfigPath)},
|
||||
})
|
||||
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pkgs, 1)
|
||||
|
||||
p := pkgs[0]
|
||||
assert.Equal(t, "My-Model", p.Name)
|
||||
assert.Equal(t, pkg.ModelPkg, p.Type)
|
||||
assertHasLicense(t, p, "mit")
|
||||
|
||||
md := p.Metadata.(pkg.SafeTensorsModelInfo)
|
||||
assert.Equal(t, "safetensors", md.Format)
|
||||
assert.Equal(t, "Qwen3ForCausalLM", md.Architecture)
|
||||
assert.Equal(t, "bfloat16", md.TorchDtype)
|
||||
assert.Equal(t, "Q4_K_M", md.Quantization)
|
||||
assert.Equal(t, "8B", md.Parameters)
|
||||
assert.Equal(t, "16.00GB", md.TotalSize)
|
||||
assert.Equal(t, uint64(291), md.TensorCount)
|
||||
})
|
||||
|
||||
t.Run("falls back to license layer", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
readmePath := filepath.Join(dir, "README.md")
|
||||
require.NoError(t, os.WriteFile(readmePath,
|
||||
[]byte("---\nbase_model:\n - org/My-Model\n---\n"), 0o644))
|
||||
licensePath := filepath.Join(dir, "LICENSE")
|
||||
require.NoError(t, os.WriteFile(licensePath,
|
||||
[]byte(" Apache License\n Version 2.0, January 2004\n"), 0o644))
|
||||
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{
|
||||
dockerAIModelFileMediaType: {file.NewLocation(readmePath)},
|
||||
dockerAILicenseMediaType: {file.NewLocation(licensePath)},
|
||||
})
|
||||
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pkgs, 1)
|
||||
assertHasLicense(t, pkgs[0], "Apache-2.0")
|
||||
})
|
||||
|
||||
t.Run("config _name_or_path wins over README base_model regardless of layer order", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
readmePath := filepath.Join(dir, "README.md")
|
||||
require.NoError(t, os.WriteFile(readmePath, []byte("---\nbase_model:\n - org/Readme-Name\n---\n"), 0o644))
|
||||
hfConfigPath := filepath.Join(dir, "config.json")
|
||||
require.NoError(t, os.WriteFile(hfConfigPath, []byte(`{"_name_or_path":"org/Config-Name"}`), 0o644))
|
||||
|
||||
// both layer orderings must yield the same (config-derived) name
|
||||
orderings := [][]file.Location{
|
||||
{file.NewLocation(readmePath), file.NewLocation(hfConfigPath)},
|
||||
{file.NewLocation(hfConfigPath), file.NewLocation(readmePath)},
|
||||
}
|
||||
for _, locs := range orderings {
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{
|
||||
dockerAIModelFileMediaType: locs,
|
||||
})
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pkgs, 1)
|
||||
assert.Equal(t, "Config-Name", pkgs[0].Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to default name when none derivable", func(t *testing.T) {
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{})
|
||||
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(configBlob))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pkgs, 1)
|
||||
assert.Equal(t, "safetensors-model", pkgs[0].Name, "model must still be emitted, not dropped")
|
||||
})
|
||||
|
||||
t.Run("ignores non-safetensors format", func(t *testing.T) {
|
||||
ggufBlob := []byte(`{"config":{"format":"gguf","quantization":"Q4_K_M"}}`)
|
||||
resolver := file.NewMockResolverForMediaTypes(map[string][]file.Location{})
|
||||
|
||||
pkgs, _, err := parseSafeTensorsOCIConfig(context.Background(), resolver, nil, configReader(ggufBlob))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, pkgs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSafeTensorsMergeProcessor(t *testing.T) {
|
||||
named := pkg.Package{Name: "model-a", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "aaaa"}}
|
||||
nameless := pkg.Package{Name: "", Type: pkg.ModelPkg, Metadata: pkg.SafeTensorsModelInfo{Format: "safetensors", MetadataHash: "bbbb"}}
|
||||
|
||||
t.Run("merges nameless into named parts", func(t *testing.T) {
|
||||
out, _, err := safeTensorsMergeProcessor([]pkg.Package{named, nameless}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
assert.Equal(t, "model-a", out[0].Name)
|
||||
md := out[0].Metadata.(pkg.SafeTensorsModelInfo)
|
||||
require.Len(t, md.Parts, 1)
|
||||
assert.Empty(t, md.Parts[0].MetadataHash, "nameless part hash should be cleared")
|
||||
})
|
||||
|
||||
t.Run("drops result when no named package", func(t *testing.T) {
|
||||
out, _, err := safeTensorsMergeProcessor([]pkg.Package{nameless}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, out)
|
||||
})
|
||||
|
||||
t.Run("passes through upstream error", func(t *testing.T) {
|
||||
sentinel := assert.AnError
|
||||
out, _, err := safeTensorsMergeProcessor([]pkg.Package{named}, nil, sentinel)
|
||||
assert.Equal(t, sentinel, err)
|
||||
assert.Len(t, out, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func configReader(blob []byte) file.LocationReadCloser {
|
||||
return file.NewLocationReadCloser(file.NewLocation("/config.json"), io.NopCloser(bytes.NewReader(blob)))
|
||||
}
|
||||
|
||||
func assertHasLicense(t *testing.T, p pkg.Package, value string) {
|
||||
t.Helper()
|
||||
for _, l := range p.Licenses.ToSlice() {
|
||||
if l.Value == value {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Errorf("expected license %q, got %+v", value, p.Licenses.ToSlice())
|
||||
}
|
||||
|
||||
func TestReadSafeTensorsHeader(t *testing.T) {
|
||||
t.Run("valid header", func(t *testing.T) {
|
||||
data := buildSafeTensorsFile(t, map[string]string{"format": "pt"}, map[string]safeTensorsEntry{
|
||||
"w": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}},
|
||||
})
|
||||
h, n, err := readSafeTensorsHeader(bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint64(len(data)-8), n)
|
||||
assert.Len(t, h.tensors, 1)
|
||||
assert.Equal(t, "pt", h.metadata["format"])
|
||||
})
|
||||
|
||||
t.Run("zero-length header", func(t *testing.T) {
|
||||
var buf [8]byte // length prefix of 0
|
||||
_, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated body", func(t *testing.T) {
|
||||
var buf [8]byte
|
||||
binary.LittleEndian.PutUint64(buf[:], 100) // claims 100 bytes but supplies none
|
||||
_, _, err := readSafeTensorsHeader(bytes.NewReader(buf[:]))
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSafeTensorsHeader_metadataHash(t *testing.T) {
|
||||
base := &safeTensorsHeader{
|
||||
metadata: map[string]string{"format": "pt"},
|
||||
tensors: map[string]safeTensorsEntry{
|
||||
"a.weight": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}},
|
||||
"b.weight": {DType: "F16", Shape: []int64{4}, DataOffsets: []int64{16, 24}},
|
||||
},
|
||||
}
|
||||
|
||||
// deterministic across calls and independent of map insertion order
|
||||
reordered := &safeTensorsHeader{
|
||||
metadata: map[string]string{"format": "pt"},
|
||||
tensors: map[string]safeTensorsEntry{
|
||||
"b.weight": {DType: "F16", Shape: []int64{4}, DataOffsets: []int64{16, 24}},
|
||||
"a.weight": {DType: "F32", Shape: []int64{2, 2}, DataOffsets: []int64{0, 16}},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, base.metadataHash(), reordered.metadataHash())
|
||||
assert.Len(t, base.metadataHash(), 16)
|
||||
|
||||
// changing a tensor changes the hash
|
||||
changed := &safeTensorsHeader{
|
||||
metadata: base.metadata,
|
||||
tensors: map[string]safeTensorsEntry{
|
||||
"a.weight": {DType: "F32", Shape: []int64{2, 3}, DataOffsets: []int64{0, 24}},
|
||||
"b.weight": {DType: "F16", Shape: []int64{4}, DataOffsets: []int64{24, 32}},
|
||||
},
|
||||
}
|
||||
assert.NotEqual(t, base.metadataHash(), changed.metadataHash())
|
||||
|
||||
// changing __metadata__ changes the hash
|
||||
differentMeta := &safeTensorsHeader{metadata: map[string]string{"format": "np"}, tensors: base.tensors}
|
||||
assert.NotEqual(t, base.metadataHash(), differentMeta.metadataHash())
|
||||
}
|
||||
|
||||
func TestSafeTensorsHeader_parameterCountAndDType(t *testing.T) {
|
||||
h := &safeTensorsHeader{tensors: map[string]safeTensorsEntry{
|
||||
"big": {DType: "BF16", Shape: []int64{1000, 16}},
|
||||
"small": {DType: "F32", Shape: []int64{16, 16}},
|
||||
"scalar": {DType: "F32", Shape: []int64{}}, // empty shape contributes 1
|
||||
}}
|
||||
assert.Equal(t, uint64(1000*16+16*16+1), h.parameterCount())
|
||||
assert.Equal(t, "BF16", h.dominantDType())
|
||||
}
|
||||
|
||||
func TestNormalizeDType(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"BF16": "BF16",
|
||||
"float16": "F16",
|
||||
"FP32": "F32",
|
||||
"int8": "I8",
|
||||
"U8": "U8",
|
||||
"bool": "BOOL",
|
||||
"weird": "WEIRD",
|
||||
}
|
||||
for in, want := range cases {
|
||||
assert.Equalf(t, want, normalizeDType(in), "normalizeDType(%q)", in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatParameterCount(t *testing.T) {
|
||||
cases := map[uint64]string{
|
||||
512: "512",
|
||||
16256: "16.26K",
|
||||
2_680_000_000: "2.68B",
|
||||
35_000_000: "35.00M",
|
||||
}
|
||||
for in, want := range cases {
|
||||
assert.Equalf(t, want, formatParameterCount(in), "formatParameterCount(%d)", in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatByteSize(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"16000000000": "14.90GB",
|
||||
"2048": "2.00KB",
|
||||
"500": "500B",
|
||||
"71.90GB": "71.90GB", // non-numeric passes through unchanged
|
||||
"": "",
|
||||
}
|
||||
for in, want := range cases {
|
||||
assert.Equalf(t, want, formatByteSize(in), "formatByteSize(%q)", in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFrontmatter(t *testing.T) {
|
||||
t.Run("list base_model", func(t *testing.T) {
|
||||
fm := parseFrontmatter([]byte("---\nlicense: mit\nbase_model:\n - org/Model\n---\nbody"))
|
||||
require.NotNil(t, fm)
|
||||
assert.Equal(t, "mit", fm.License)
|
||||
assert.Equal(t, []string{"org/Model"}, fm.BaseModel)
|
||||
})
|
||||
|
||||
t.Run("scalar base_model", func(t *testing.T) {
|
||||
fm := parseFrontmatter([]byte("---\nlicense: apache-2.0\nbase_model: org/Model\n---\n"))
|
||||
require.NotNil(t, fm)
|
||||
assert.Equal(t, "apache-2.0", fm.License)
|
||||
assert.Equal(t, []string{"org/Model"}, fm.BaseModel)
|
||||
})
|
||||
|
||||
t.Run("leading BOM", func(t *testing.T) {
|
||||
fm := parseFrontmatter([]byte("\xef\xbb\xbf---\nlicense: mit\n---\n"))
|
||||
require.NotNil(t, fm)
|
||||
assert.Equal(t, "mit", fm.License)
|
||||
})
|
||||
|
||||
t.Run("no frontmatter", func(t *testing.T) {
|
||||
assert.Nil(t, parseFrontmatter([]byte("# just a heading\n")))
|
||||
})
|
||||
|
||||
t.Run("unterminated frontmatter", func(t *testing.T) {
|
||||
assert.Nil(t, parseFrontmatter([]byte("---\nlicense: mit\n")))
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelNameFromPath(t *testing.T) {
|
||||
assert.Equal(t, "foo", modelNameFromPath("/models/foo/model.safetensors"))
|
||||
assert.Equal(t, "weights", modelNameFromPath("weights.safetensors"))
|
||||
assert.Equal(t, "my-model", modelNameFromIndexPath("/models/my-model/model.safetensors.index.json"))
|
||||
assert.Equal(t, "safetensors-model", modelNameFromIndexPath("model.safetensors.index.json"))
|
||||
}
|
||||
@ -0,0 +1,55 @@
|
||||
package ocimodelsource
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
v1 "github.com/google/go-containerregistry/pkg/v1"
|
||||
"github.com/google/go-containerregistry/pkg/v1/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/anchore/syft/syft/source"
|
||||
)
|
||||
|
||||
func TestDetectModelFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
gguf int
|
||||
safetensors int
|
||||
expected string
|
||||
}{
|
||||
{name: "gguf only", gguf: 2, safetensors: 0, expected: modelFormatGGUF},
|
||||
{name: "safetensors only", gguf: 0, safetensors: 3, expected: modelFormatSafeTensors},
|
||||
{name: "both prefers gguf", gguf: 1, safetensors: 1, expected: modelFormatGGUF},
|
||||
{name: "neither", gguf: 0, safetensors: 0, expected: ""},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
assert.Equal(t, test.expected, detectModelFormat(test.gguf, test.safetensors))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSafeTensorsLayers(t *testing.T) {
|
||||
manifest := &v1.Manifest{Layers: []v1.Descriptor{
|
||||
{MediaType: types.MediaType(safetensorsLayerMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "a"}},
|
||||
{MediaType: types.MediaType(ggufLayerMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "b"}},
|
||||
{MediaType: types.MediaType(safetensorsLayerMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "c"}},
|
||||
}}
|
||||
assert.Len(t, extractSafeTensorsLayers(manifest), 2)
|
||||
}
|
||||
|
||||
func TestExtractCompanionLayers(t *testing.T) {
|
||||
manifest := &v1.Manifest{Layers: []v1.Descriptor{
|
||||
{MediaType: types.MediaType(modelFileMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "readme"}},
|
||||
{MediaType: types.MediaType(licenseMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "license"}},
|
||||
{MediaType: types.MediaType(safetensorsLayerMediaType), Digest: v1.Hash{Algorithm: "sha256", Hex: "weights"}},
|
||||
{MediaType: types.DockerLayer, Digest: v1.Hash{Algorithm: "sha256", Hex: "other"}},
|
||||
}}
|
||||
// only the model.file and license layers should be selected (not weights or arbitrary layers)
|
||||
assert.Len(t, extractCompanionLayers(manifest), 2)
|
||||
}
|
||||
|
||||
func TestCalculateTotalSize(t *testing.T) {
|
||||
layers := []source.LayerMetadata{{Size: 100}, {Size: 250}, {Size: 0}}
|
||||
assert.Equal(t, int64(350), calculateTotalSize(layers))
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user