From f6d4a7d27a094861a3a9c112d3818d27dab4cd0f Mon Sep 17 00:00:00 2001 From: Adam McClenaghan Date: Wed, 23 Apr 2025 21:01:10 +0100 Subject: [PATCH] Perf: skip license scanner injection (#3796) * (perf): allow library users to skip default scanner injection Signed-off-by: Adam McClenaghan * (perf): remove prints Signed-off-by: Adam McClenaghan * perf: move to cataloging licenses.go Signed-off-by: adammcclenaghan * perf: Simplify to expose a SetContextLicenseScanner func Signed-off-by: adammcclenaghan --------- Signed-off-by: Adam McClenaghan Signed-off-by: adammcclenaghan --- internal/licenses/context.go | 11 ++++++-- internal/licenses/context_test.go | 47 +++++++++++++++++++++++++++++++ syft/create_sbom.go | 15 +++++++--- 3 files changed, 67 insertions(+), 6 deletions(-) create mode 100644 internal/licenses/context_test.go diff --git a/internal/licenses/context.go b/internal/licenses/context.go index 9d735a8a3..91301cc20 100644 --- a/internal/licenses/context.go +++ b/internal/licenses/context.go @@ -6,12 +6,19 @@ import ( type licenseScannerKey struct{} +var ctxKey = licenseScannerKey{} + func SetContextLicenseScanner(ctx context.Context, s Scanner) context.Context { - return context.WithValue(ctx, licenseScannerKey{}, s) + return context.WithValue(ctx, ctxKey, s) +} + +func IsContextLicenseScannerSet(ctx context.Context) bool { + _, ok := ctx.Value(ctxKey).(Scanner) + return ok } func ContextLicenseScanner(ctx context.Context) (Scanner, error) { - if s, ok := ctx.Value(licenseScannerKey{}).(Scanner); ok { + if s, ok := ctx.Value(ctxKey).(Scanner); ok { return s, nil } return NewDefaultScanner() diff --git a/internal/licenses/context_test.go b/internal/licenses/context_test.go new file mode 100644 index 000000000..822be2fa4 --- /dev/null +++ b/internal/licenses/context_test.go @@ -0,0 +1,47 @@ +package licenses + +import ( + "context" + "github.com/stretchr/testify/require" + "testing" +) + +func TestSetContextLicenseScanner(t *testing.T) { + scanner := testScanner(true) + ctx := context.Background() + ctx = SetContextLicenseScanner(ctx, scanner) + + val := ctx.Value(ctxKey) + require.NotNil(t, val) + s, ok := val.(Scanner) + require.True(t, ok) + require.Equal(t, scanner, s) +} + +func TestIsContextLicenseScannerSet(t *testing.T) { + scanner := testScanner(true) + ctx := context.Background() + require.False(t, IsContextLicenseScannerSet(ctx)) + + ctx = SetContextLicenseScanner(ctx, scanner) + require.True(t, IsContextLicenseScannerSet(ctx)) +} + +func TestContextLicenseScanner(t *testing.T) { + t.Run("with scanner", func(t *testing.T) { + scanner := testScanner(true) + ctx := SetContextLicenseScanner(context.Background(), scanner) + s, err := ContextLicenseScanner(ctx) + if err != nil || s != scanner { + t.Fatal("expected scanner from context") + } + }) + + t.Run("without scanner", func(t *testing.T) { + ctx := context.Background() + s, err := ContextLicenseScanner(ctx) + if err != nil || s == nil { + t.Fatal("expected default scanner") + } + }) +} diff --git a/syft/create_sbom.go b/syft/create_sbom.go index aec9fdcae..68605d564 100644 --- a/syft/create_sbom.go +++ b/syft/create_sbom.go @@ -96,14 +96,21 @@ func setupContext(ctx context.Context, cfg *CreateSBOMConfig) (context.Context, ctx = setContextExecutors(ctx, cfg) // configure license scanner - return setContextLicenseScanner(ctx, cfg) + // skip injecting a license scanner if one already set on context + if licenses.IsContextLicenseScannerSet(ctx) { + return ctx, nil + } + + return SetContextLicenseScanner(ctx, cfg.Licenses) } -func setContextLicenseScanner(ctx context.Context, cfg *CreateSBOMConfig) (context.Context, error) { +// SetContextLicenseScanner creates and sets a license scanner +// on the provided context using the provided license config. +func SetContextLicenseScanner(ctx context.Context, cfg cataloging.LicenseConfig) (context.Context, error) { // inject a single license scanner and content config for all package cataloging tasks into context licenseScanner, err := licenses.NewDefaultScanner( - licenses.WithIncludeLicenseContent(cfg.Licenses.IncludeUnkownLicenseContent), - licenses.WithCoverage(cfg.Licenses.Coverage), + licenses.WithIncludeLicenseContent(cfg.IncludeUnkownLicenseContent), + licenses.WithCoverage(cfg.Coverage), ) if err != nil { return nil, fmt.Errorf("could not build licenseScanner for cataloging: %w", err)