test: fixture test with real safetensor data

Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
Christopher Phillips 2026-05-28 12:17:39 -04:00
parent 324fecf4a4
commit 1a1f2af92b
No known key found for this signature in database
4 changed files with 209 additions and 0 deletions

View File

@ -333,6 +333,38 @@ func TestParseSafeTensorsOCILayer(t *testing.T) {
})
}
// TestParseSafeTensorsOCILayer_realFixture grounds the OCI layer parser
// against a real `[prefix + JSON header]` captured from a public Docker AI
// model artifact (docker.io/ai/nomic-embed-text-v2-moe-safetensors:475M).
// The fixture and the tool that produced it live in
// testdata/safetensors/; see the README there to refresh.
//
// Locking in the field values guards against changes to the header parser
// silently breaking on real-world content shape.
func TestParseSafeTensorsOCILayer_realFixture(t *testing.T) {
data, err := os.ReadFile(filepath.Join("testdata", "safetensors", "nomic-embed-475M.header.safetensors"))
require.NoError(t, err)
require.Greater(t, len(data), 8, "fixture must include the 8-byte length prefix")
reader := file.NewLocationReadCloser(file.NewLocation("/"), io.NopCloser(bytes.NewReader(data)))
pkgs, _, err := parseSafeTensorsOCILayer(context.Background(), nil, nil, reader)
require.NoError(t, err)
require.Len(t, pkgs, 1)
assert.Empty(t, pkgs[0].Name, "weight-layer packages are nameless before the merge processor runs")
md := pkgs[0].Metadata.(pkg.SafeTensorsModelInfo)
assert.Equal(t, "safetensors", md.Format)
assert.Equal(t, uint64(148), md.TensorCount, "nomic-embed-v2-moe 475M ships 148 tensor entries in this shard")
assert.Equal(t, "F32", md.Quantization, "every tensor in the captured shard is F32")
assert.Equal(t, "475.29M", md.Parameters)
assert.Equal(t, map[string]string{"format": "pt"}, md.UserMetadata)
// MetadataHash is locked to the exact value the parser produces for this
// captured input. The fixture is immutable on disk; if this value changes
// either the hash algorithm or the canonicalization changed, both of which
// callers may rely on for cross-source identity.
assert.Equal(t, "051a14e686673dea", md.MetadataHash)
}
func TestSafeTensorsCrossSourceHashParity(t *testing.T) {
// Same content, two paths: a directory scan via parseSafeTensorsFile, and an
// OCI weight-layer scan via parseSafeTensorsOCILayer. The MetadataHash of

View File

@ -0,0 +1,28 @@
# SafeTensors header fixtures
These fixtures are `[8-byte length prefix + JSON header]` captures from
public Docker AI model artifacts on the registry.
`extract_header.go` does a range-GET of the first several MB of the layer,
slices off just `[prefix + JSON header]`, and writes
that to disk.
## Refreshing a fixture
```sh
# from the package root
go run ./testdata/safetensors/extract_header.go \
docker.io/ai/nomic-embed-text-v2-moe-safetensors:475M \
./testdata/safetensors/nomic-embed-475M.header.safetensors
```
The tool prints the layer digest it selected and the number of top-level keys
in the captured header. If you see `header length N does not fit in M fetched
bytes`, raise `fetchBytes` in `extract_header.go` and rerun.
## Notes
- Pick one shard, not the full sharded set. The fixture is meant to exercise
the per-shard parser; merging across shards has its own tests.
- Don't commit anything larger than ~1 MB. If a model has an unusually large
header, capture a smaller model instead.

View File

@ -0,0 +1,149 @@
// extract_header is a manual fixture tool that captures the real on-disk
// safetensors header from a Docker AI OCI model artifact (a vnd.docker.ai.safetensors
// layer) and writes just [8-byte length prefix + JSON header] to a destination
// file. Tensor data following the header is never downloaded, so the resulting
// fixture is a few KB to a few MB even for multi-GB models.
//
// This file lives under testdata/ so the Go build system ignores it. Run it
// manually when refreshing fixtures:
//
// go run ./testdata/safetensors/extract_header.go \
// docker.io/ai/nomic-embed-text-v2-moe-safetensors:475M \
// ./testdata/safetensors/nomic-embed-475M.header.safetensors
package main
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
v1 "github.com/google/go-containerregistry/pkg/v1"
"github.com/google/go-containerregistry/pkg/v1/remote"
)
const (
safetensorsLayerMediaType = "application/vnd.docker.ai.safetensors"
// 8 MB matches maxHeaderBytes in the OCI model source. Real model headers
// are well under 1 MB; the extra slack covers outliers.
fetchBytes = 8 * 1024 * 1024
)
func main() {
if len(os.Args) != 3 {
fmt.Fprintf(os.Stderr, "usage: %s <registry-ref> <output-path>\n", os.Args[0])
os.Exit(2)
}
if err := run(os.Args[1], os.Args[2]); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func run(refStr, outPath string) error {
ctx := context.Background()
ref, err := name.ParseReference(refStr)
if err != nil {
return fmt.Errorf("parse reference: %w", err)
}
opts := []remote.Option{
remote.WithAuthFromKeychain(authn.DefaultKeychain),
remote.WithContext(ctx),
}
desc, err := remote.Get(ref, opts...)
if err != nil {
return fmt.Errorf("fetch descriptor: %w", err)
}
manifest := &v1.Manifest{}
if err := json.Unmarshal(desc.Manifest, manifest); err != nil {
return fmt.Errorf("decode manifest: %w", err)
}
weightLayer := pickWeightLayer(manifest)
if weightLayer == nil {
return fmt.Errorf("no %q layer found in %s", safetensorsLayerMediaType, ref)
}
fmt.Fprintf(os.Stderr, "selected layer %s (%d bytes on-disk)\n", weightLayer.Digest, weightLayer.Size)
prefix, err := fetchPrefix(ctx, ref, weightLayer.Digest, opts)
if err != nil {
return fmt.Errorf("fetch layer prefix: %w", err)
}
header, err := sliceHeader(prefix)
if err != nil {
return fmt.Errorf("extract header: %w", err)
}
if err := os.WriteFile(outPath, header, 0o644); err != nil {
return fmt.Errorf("write fixture: %w", err)
}
fmt.Fprintf(os.Stderr, "wrote %d bytes to %s\n", len(header), outPath)
return nil
}
// pickWeightLayer returns the first vnd.docker.ai.safetensors layer in the
// manifest, or nil if none exists. For sharded models we deliberately only
// capture one shard: the fixture is meant to exercise the parser, not the
// merge step.
func pickWeightLayer(manifest *v1.Manifest) *v1.Descriptor {
for i := range manifest.Layers {
if string(manifest.Layers[i].MediaType) == safetensorsLayerMediaType {
return &manifest.Layers[i]
}
}
return nil
}
// fetchPrefix range-reads the first fetchBytes of a layer. Closing the reader
// terminates the underlying HTTP body, so we never download the tensor data
// that follows the header.
func fetchPrefix(_ context.Context, ref name.Reference, digest v1.Hash, opts []remote.Option) ([]byte, error) {
layer, err := remote.Layer(ref.Context().Digest(digest.String()), opts...)
if err != nil {
return nil, err
}
reader, err := layer.Compressed()
if err != nil {
return nil, err
}
defer reader.Close()
buf := make([]byte, fetchBytes)
n, err := io.ReadFull(reader, buf)
if err != nil && err != io.ErrUnexpectedEOF {
return nil, err
}
return buf[:n], nil
}
// sliceHeader reads the 8-byte little-endian length prefix and returns just
// [prefix + JSON header]. It also probes the JSON to make sure the captured
// fixture is well-formed, so we never commit a half-truncated header.
func sliceHeader(buf []byte) ([]byte, error) {
if len(buf) < 8 {
return nil, fmt.Errorf("short read: only %d bytes", len(buf))
}
headerLen := binary.LittleEndian.Uint64(buf[:8])
if headerLen == 0 {
return nil, fmt.Errorf("header length is zero")
}
if headerLen > uint64(len(buf)-8) {
return nil, fmt.Errorf("header length %d does not fit in %d fetched bytes; increase fetchBytes", headerLen, len(buf))
}
out := buf[:8+int(headerLen)]
var probe map[string]json.RawMessage
if err := json.Unmarshal(out[8:], &probe); err != nil {
return nil, fmt.Errorf("captured JSON does not parse: %w", err)
}
fmt.Fprintf(os.Stderr, "header parses cleanly: %d top-level keys\n", len(probe))
return out, nil
}