diff --git a/cmd/syft/cli/commands/update.go b/cmd/syft/cli/commands/update.go index 3ef1812e1..88f159700 100644 --- a/cmd/syft/cli/commands/update.go +++ b/cmd/syft/cli/commands/update.go @@ -38,7 +38,7 @@ func applicationUpdateCheck(id clio.Identification, check *options.UpdateCheck) func checkForApplicationUpdate(id clio.Identification) { log.Debugf("checking if a new version of %s is available", id.Name) - isAvailable, newVersion, err := isUpdateAvailable(id.Version) + isAvailable, newVersion, err := isUpdateAvailable(id) if err != nil { // this should never stop the application log.Errorf(err.Error()) @@ -59,18 +59,18 @@ func checkForApplicationUpdate(id clio.Identification) { } // isUpdateAvailable indicates if there is a newer application version available, and if so, what the new version is. -func isUpdateAvailable(version string) (bool, string, error) { - if !isProductionBuild(version) { +func isUpdateAvailable(id clio.Identification) (bool, string, error) { + if !isProductionBuild(id.Version) { // don't allow for non-production builds to check for a version. return false, "", nil } - currentVersion, err := hashiVersion.NewVersion(version) + currentVersion, err := hashiVersion.NewVersion(id.Version) if err != nil { return false, "", fmt.Errorf("failed to parse current application version: %w", err) } - latestVersion, err := fetchLatestApplicationVersion() + latestVersion, err := fetchLatestApplicationVersion(id) if err != nil { return false, "", err } @@ -89,11 +89,12 @@ func isProductionBuild(version string) bool { return true } -func fetchLatestApplicationVersion() (*hashiVersion.Version, error) { +func fetchLatestApplicationVersion(id clio.Identification) (*hashiVersion.Version, error) { req, err := http.NewRequest(http.MethodGet, latestAppVersionURL.host+latestAppVersionURL.path, nil) if err != nil { return nil, fmt.Errorf("failed to create request for latest version: %w", err) } + req.Header.Add("User-Agent", fmt.Sprintf("%v %v", id.Name, id.Version)) client := http.Client{} resp, err := client.Do(req) diff --git a/cmd/syft/cli/commands/update_test.go b/cmd/syft/cli/commands/update_test.go index 96cd7de34..bc04a7465 100644 --- a/cmd/syft/cli/commands/update_test.go +++ b/cmd/syft/cli/commands/update_test.go @@ -5,6 +5,7 @@ import ( "net/http/httptest" "testing" + "github.com/anchore/clio" hashiVersion "github.com/anchore/go-version" "github.com/anchore/syft/cmd/syft/internal" ) @@ -106,7 +107,7 @@ func TestIsUpdateAvailable(t *testing.T) { t.Run(test.name, func(t *testing.T) { // setup mocks // local... - version := test.buildVersion + id := clio.Identification{Name: "Syft", Version: test.buildVersion} // remote... handler := http.NewServeMux() handler.HandleFunc(latestAppVersionURL.path, func(w http.ResponseWriter, r *http.Request) { @@ -117,7 +118,7 @@ func TestIsUpdateAvailable(t *testing.T) { latestAppVersionURL.host = mockSrv.URL defer mockSrv.Close() - isAvailable, newVersion, err := isUpdateAvailable(version) + isAvailable, newVersion, err := isUpdateAvailable(id) if err != nil && !test.err { t.Fatalf("got error but expected none: %+v", err) } else if err == nil && test.err { @@ -138,52 +139,67 @@ func TestIsUpdateAvailable(t *testing.T) { func TestFetchLatestApplicationVersion(t *testing.T) { tests := []struct { - name string - response string - code int - err bool - expected *hashiVersion.Version + name string + response string + code int + err bool + id clio.Identification + expected *hashiVersion.Version + expectedHeaders map[string]string }{ { - name: "gocase", - response: "1.0.0", - code: 200, - expected: hashiVersion.Must(hashiVersion.NewVersion("1.0.0")), + name: "gocase", + response: "1.0.0", + code: 200, + id: clio.Identification{Name: "Syft", Version: "0.0.0"}, + expected: hashiVersion.Must(hashiVersion.NewVersion("1.0.0")), + expectedHeaders: map[string]string{"User-Agent": "Syft 0.0.0"}, + err: false, }, { - name: "garbage", - response: "garbage", - code: 200, - expected: nil, - err: true, + name: "garbage", + response: "garbage", + code: 200, + id: clio.Identification{Name: "Syft", Version: "0.0.0"}, + expected: nil, + expectedHeaders: nil, + err: true, }, { - name: "http 500", - response: "1.0.0", - code: 500, - expected: nil, - err: true, + name: "http 500", + response: "1.0.0", + code: 500, + id: clio.Identification{Name: "Syft", Version: "0.0.0"}, + expected: nil, + expectedHeaders: nil, + err: true, }, { - name: "http 404", - response: "1.0.0", - code: 404, - expected: nil, - err: true, + name: "http 404", + response: "1.0.0", + code: 404, + id: clio.Identification{Name: "Syft", Version: "0.0.0"}, + expected: nil, + expectedHeaders: nil, + err: true, }, { - name: "empty", - response: "", - code: 200, - expected: nil, - err: true, + name: "empty", + response: "", + code: 200, + id: clio.Identification{Name: "Syft", Version: "0.0.0"}, + expected: nil, + expectedHeaders: nil, + err: true, }, { - name: "too long", - response: "this is really long this is really long this is really long this is really long this is really long this is really long this is really long this is really long ", - code: 200, - expected: nil, - err: true, + name: "too long", + response: "this is really long this is really long this is really long this is really long this is really long this is really long this is really long this is really long ", + code: 200, + id: clio.Identification{Name: "Syft", Version: "0.0.0"}, + expected: nil, + expectedHeaders: nil, + err: true, }, } @@ -192,6 +208,15 @@ func TestFetchLatestApplicationVersion(t *testing.T) { // setup mock handler := http.NewServeMux() handler.HandleFunc(latestAppVersionURL.path, func(w http.ResponseWriter, r *http.Request) { + if test.expectedHeaders != nil { + for headerName, headerValue := range test.expectedHeaders { + actualHeader := r.Header.Get(headerName) + if actualHeader != headerValue { + t.Fatalf("expected header %v=%v but got %v", headerName, headerValue, actualHeader) + } + } + } + w.WriteHeader(test.code) _, _ = w.Write([]byte(test.response)) }) @@ -199,7 +224,7 @@ func TestFetchLatestApplicationVersion(t *testing.T) { latestAppVersionURL.host = mockSrv.URL defer mockSrv.Close() - actual, err := fetchLatestApplicationVersion() + actual, err := fetchLatestApplicationVersion(test.id) if err != nil && !test.err { t.Fatalf("got error but expected none: %+v", err) } else if err == nil && test.err {