Perf: skip license scanner injection (#3796)

* (perf): allow library users to skip default scanner injection

Signed-off-by: Adam McClenaghan <adam@mcclenaghan.co.uk>

* (perf): remove prints

Signed-off-by: Adam McClenaghan <adam@mcclenaghan.co.uk>

* perf: move to cataloging licenses.go

Signed-off-by: adammcclenaghan <adam.mcclenaghan@upwind.io>

* perf: Simplify to expose a SetContextLicenseScanner func

Signed-off-by: adammcclenaghan <adam.mcclenaghan@upwind.io>

---------

Signed-off-by: Adam McClenaghan <adam@mcclenaghan.co.uk>
Signed-off-by: adammcclenaghan <adam.mcclenaghan@upwind.io>
This commit is contained in:
Adam McClenaghan 2025-04-23 21:01:10 +01:00 committed by GitHub
parent 273d414b6b
commit f6d4a7d27a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 6 deletions

View File

@ -6,12 +6,19 @@ import (
type licenseScannerKey struct{} type licenseScannerKey struct{}
var ctxKey = licenseScannerKey{}
func SetContextLicenseScanner(ctx context.Context, s Scanner) context.Context { func SetContextLicenseScanner(ctx context.Context, s Scanner) context.Context {
return context.WithValue(ctx, licenseScannerKey{}, s) return context.WithValue(ctx, ctxKey, s)
}
func IsContextLicenseScannerSet(ctx context.Context) bool {
_, ok := ctx.Value(ctxKey).(Scanner)
return ok
} }
func ContextLicenseScanner(ctx context.Context) (Scanner, error) { func ContextLicenseScanner(ctx context.Context) (Scanner, error) {
if s, ok := ctx.Value(licenseScannerKey{}).(Scanner); ok { if s, ok := ctx.Value(ctxKey).(Scanner); ok {
return s, nil return s, nil
} }
return NewDefaultScanner() return NewDefaultScanner()

View File

@ -0,0 +1,47 @@
package licenses
import (
"context"
"github.com/stretchr/testify/require"
"testing"
)
func TestSetContextLicenseScanner(t *testing.T) {
scanner := testScanner(true)
ctx := context.Background()
ctx = SetContextLicenseScanner(ctx, scanner)
val := ctx.Value(ctxKey)
require.NotNil(t, val)
s, ok := val.(Scanner)
require.True(t, ok)
require.Equal(t, scanner, s)
}
func TestIsContextLicenseScannerSet(t *testing.T) {
scanner := testScanner(true)
ctx := context.Background()
require.False(t, IsContextLicenseScannerSet(ctx))
ctx = SetContextLicenseScanner(ctx, scanner)
require.True(t, IsContextLicenseScannerSet(ctx))
}
func TestContextLicenseScanner(t *testing.T) {
t.Run("with scanner", func(t *testing.T) {
scanner := testScanner(true)
ctx := SetContextLicenseScanner(context.Background(), scanner)
s, err := ContextLicenseScanner(ctx)
if err != nil || s != scanner {
t.Fatal("expected scanner from context")
}
})
t.Run("without scanner", func(t *testing.T) {
ctx := context.Background()
s, err := ContextLicenseScanner(ctx)
if err != nil || s == nil {
t.Fatal("expected default scanner")
}
})
}

View File

@ -96,14 +96,21 @@ func setupContext(ctx context.Context, cfg *CreateSBOMConfig) (context.Context,
ctx = setContextExecutors(ctx, cfg) ctx = setContextExecutors(ctx, cfg)
// configure license scanner // configure license scanner
return setContextLicenseScanner(ctx, cfg) // skip injecting a license scanner if one already set on context
if licenses.IsContextLicenseScannerSet(ctx) {
return ctx, nil
}
return SetContextLicenseScanner(ctx, cfg.Licenses)
} }
func setContextLicenseScanner(ctx context.Context, cfg *CreateSBOMConfig) (context.Context, error) { // SetContextLicenseScanner creates and sets a license scanner
// on the provided context using the provided license config.
func SetContextLicenseScanner(ctx context.Context, cfg cataloging.LicenseConfig) (context.Context, error) {
// inject a single license scanner and content config for all package cataloging tasks into context // inject a single license scanner and content config for all package cataloging tasks into context
licenseScanner, err := licenses.NewDefaultScanner( licenseScanner, err := licenses.NewDefaultScanner(
licenses.WithIncludeLicenseContent(cfg.Licenses.IncludeUnkownLicenseContent), licenses.WithIncludeLicenseContent(cfg.IncludeUnkownLicenseContent),
licenses.WithCoverage(cfg.Licenses.Coverage), licenses.WithCoverage(cfg.Coverage),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("could not build licenseScanner for cataloging: %w", err) return nil, fmt.Errorf("could not build licenseScanner for cataloging: %w", err)