diff --git a/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/cache_manager.go b/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/cache_manager.go
index 997dc93cd..7b4a8e6e2 100644
--- a/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/cache_manager.go
+++ b/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/cache_manager.go
@@ -10,6 +10,23 @@ import (
 
 const cacheDir = ".cpe-cache"
 
+// parseNVDTimestamp parses timestamps from the NVD API, which may or may not include a timezone
+func parseNVDTimestamp(ts string) (time.Time, error) {
+	// try RFC3339 first (with timezone)
+	if t, err := time.Parse(time.RFC3339, ts); err == nil {
+		return t, nil
+	}
+	// try without timezone (NVD sometimes returns this format)
+	if t, err := time.Parse("2006-01-02T15:04:05.000", ts); err == nil {
+		return t, nil
+	}
+	// try without milliseconds
+	if t, err := time.Parse("2006-01-02T15:04:05", ts); err == nil {
+		return t, nil
+	}
+	return time.Time{}, fmt.Errorf("unable to parse timestamp: %s", ts)
+}
+
 // IncrementMetadata tracks a single fetch increment for a monthly batch
 type IncrementMetadata struct {
 	FetchedAt time.Time `json:"fetchedAt"`
@@ -215,7 +232,7 @@ func groupProductsByMonth(products []NVDProduct) (map[string][]NVDProduct, error
 	productsByMonth := make(map[string][]NVDProduct)
 
 	for _, product := range products {
-		lastMod, err := time.Parse(time.RFC3339, product.CPE.LastModified)
+		lastMod, err := parseNVDTimestamp(product.CPE.LastModified)
 		if err != nil {
 			return nil, fmt.Errorf("failed to parse lastModified for %s: %w", product.CPE.CPENameID, err)
 		}
@@ -342,8 +359,8 @@ func (m *CacheManager) LoadAllProducts() ([]NVDProduct, error) {
 		}
 
 		// compare lastModified timestamps to keep the newer one
-		newMod, _ := time.Parse(time.RFC3339, p.CPE.LastModified)
-		existingMod, _ := time.Parse(time.RFC3339, existing.CPE.LastModified)
+		newMod, _ := parseNVDTimestamp(p.CPE.LastModified)
+		existingMod, _ := parseNVDTimestamp(existing.CPE.LastModified)
 
 		if newMod.After(existingMod) {
 			productMap[p.CPE.CPENameID] = p
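
Note for reviewers: the fallback chain above relies on Go's reference-layout parsing, trying each candidate layout in order until one matches. A self-contained sketch of the same idea (not part of the patch; names are illustrative, and the sample inputs are taken from the test table in the test file that follows):

    package main

    import (
        "fmt"
        "time"
    )

    // layouts tried in order, mirroring parseNVDTimestamp above:
    // RFC3339 first, then the timezone-less variants NVD is known to return
    var nvdLayouts = []string{
        time.RFC3339,
        "2006-01-02T15:04:05.000",
        "2006-01-02T15:04:05",
    }

    func main() {
        samples := []string{
            "2024-11-15T10:00:00.000Z", // RFC3339 with timezone
            "2026-01-27T15:53:34.823",  // NVD's timezone-less format
            "2024-11-15T10:00:00",      // no milliseconds either
        }
        for _, ts := range samples {
            for _, layout := range nvdLayouts {
                if t, err := time.Parse(layout, ts); err == nil {
                    fmt.Printf("%-28s -> %s (layout %q)\n", ts, t.UTC().Format(time.RFC3339), layout)
                    break
                }
            }
        }
    }
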
diff --git a/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/cache_manager_test.go b/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/cache_manager_test.go
index 2ed185849..7e108f103 100644
--- a/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/cache_manager_test.go
+++ b/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/cache_manager_test.go
@@ -317,3 +317,59 @@ func TestCacheManager_SaveProducts(t *testing.T) {
 		assert.Equal(t, 1, metadata.MonthlyBatches["2024-12"].TotalProducts)
 	})
 }
+
+func TestParseNVDTimestamp(t *testing.T) {
+	tests := []struct {
+		name        string
+		input       string
+		expectError bool
+		expectedStr string // expected time formatted as 2006-01-02T15:04:05
+	}{
+		{
+			name:        "RFC3339 with Z timezone",
+			input:       "2024-11-15T10:00:00.000Z",
+			expectError: false,
+			expectedStr: "2024-11-15T10:00:00",
+		},
+		{
+			name:        "RFC3339 with offset timezone",
+			input:       "2024-11-15T10:00:00+00:00",
+			expectError: false,
+			expectedStr: "2024-11-15T10:00:00",
+		},
+		{
+			name:        "without timezone (NVD format)",
+			input:       "2026-01-27T15:53:34.823",
+			expectError: false,
+			expectedStr: "2026-01-27T15:53:34",
+		},
+		{
+			name:        "without timezone or milliseconds",
+			input:       "2024-11-15T10:00:00",
+			expectError: false,
+			expectedStr: "2024-11-15T10:00:00",
+		},
+		{
+			name:        "invalid format",
+			input:       "not-a-timestamp",
+			expectError: true,
+		},
+		{
+			name:        "empty string",
+			input:       "",
+			expectError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, err := parseNVDTimestamp(tt.input)
+			if tt.expectError {
+				require.Error(t, err)
+			} else {
+				require.NoError(t, err)
+				assert.Equal(t, tt.expectedStr, result.Format("2006-01-02T15:04:05"))
+			}
+		})
+	}
+}
diff --git a/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/main.go b/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/main.go
index be6a9ffe0..b8c09475f 100644
--- a/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/main.go
+++ b/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/main.go
@@ -53,33 +53,26 @@ func updateCache(ctx context.Context, cacheManager *CacheManager, forceFullRefre
 
 	lastModStartDate, isFullRefresh := determineUpdateMode(metadata, forceFullRefresh)
 
-	// use resume index if available
-	resumeFromIndex := 0
-	if !isFullRefresh && metadata.LastStartIndex > 0 {
-		resumeFromIndex = metadata.LastStartIndex
-		fmt.Printf("Resuming from index %d...\n", resumeFromIndex)
-	}
-
-	allProducts, increment, err := fetchProducts(ctx, lastModStartDate, resumeFromIndex)
+	products, increment, err := fetchProducts(ctx, lastModStartDate)
 	if err != nil {
 		// if we have partial products, save them before returning error
-		if len(allProducts) > 0 {
-			fmt.Printf("\nError occurred but saving %d products fetched so far...\n", len(allProducts))
-			if saveErr := saveAndReportResults(cacheManager, allProducts, isFullRefresh, metadata, increment); saveErr != nil {
+		if len(products) > 0 {
+			fmt.Printf("\nError occurred but saving %d products fetched so far...\n", len(products))
+			if saveErr := saveAndReportResults(cacheManager, products, isFullRefresh, metadata, increment); saveErr != nil {
 				fmt.Printf("WARNING: Failed to save partial progress: %v\n", saveErr)
 			} else {
-				fmt.Println("Partial progress saved successfully. Run again to resume from this point.")
+				fmt.Println("Partial progress saved successfully.")
 			}
 		}
 		return err
 	}
 
-	if len(allProducts) == 0 {
+	if len(products) == 0 {
 		fmt.Println("No products fetched (already up to date)")
 		return nil
 	}
 
-	return saveAndReportResults(cacheManager, allProducts, isFullRefresh, metadata, increment)
+	return saveAndReportResults(cacheManager, products, isFullRefresh, metadata, increment)
 }
 
 // determineUpdateMode decides whether to do a full refresh or incremental update
@@ -94,48 +87,25 @@ func determineUpdateMode(metadata *CacheMetadata, forceFullRefresh bool) (time.T
 }
 
 // fetchProducts fetches products from the NVD API
-func fetchProducts(ctx context.Context, lastModStartDate time.Time, resumeFromIndex int) ([]NVDProduct, IncrementMetadata, error) {
+func fetchProducts(ctx context.Context, lastModStartDate time.Time) ([]NVDProduct, IncrementMetadata, error) {
 	apiClient := NewNVDAPIClient()
 
 	fmt.Println("Fetching CPE data from NVD Products API...")
 
-	var allProducts []NVDProduct
-	var totalResults int
-	var firstStartIndex, lastEndIndex int
-
-	onPageFetched := func(startIndex int, response NVDProductsResponse) error {
-		if totalResults == 0 {
-			totalResults = response.TotalResults
-			firstStartIndex = startIndex
-		}
-		lastEndIndex = startIndex + response.ResultsPerPage
-		allProducts = append(allProducts, response.Products...)
-		fmt.Printf("Fetched %d/%d products...\n", len(allProducts), totalResults)
-		return nil
-	}
-
-	if err := apiClient.FetchProductsSince(ctx, lastModStartDate, resumeFromIndex, onPageFetched); err != nil {
-		// return partial products with increment metadata so they can be saved
-		increment := IncrementMetadata{
-			FetchedAt:        time.Now(),
-			LastModStartDate: lastModStartDate,
-			LastModEndDate:   time.Now(),
-			Products:         len(allProducts),
-			StartIndex:       firstStartIndex,
-			EndIndex:         lastEndIndex,
-		}
-		return allProducts, increment, fmt.Errorf("failed to fetch products from NVD API: %w", err)
-	}
+	products, err := apiClient.FetchProductsSince(ctx, lastModStartDate)
 
 	increment := IncrementMetadata{
 		FetchedAt:        time.Now(),
 		LastModStartDate: lastModStartDate,
 		LastModEndDate:   time.Now(),
-		Products:         len(allProducts),
-		StartIndex:       firstStartIndex,
-		EndIndex:         lastEndIndex,
+		Products:         len(products),
 	}
 
-	return allProducts, increment, nil
+	if err != nil {
+		// return partial products so they can be saved
+		return products, increment, fmt.Errorf("failed to fetch products from NVD API: %w", err)
+	}
+
+	return products, increment, nil
 }
 
 // saveAndReportResults saves products and metadata, then reports success
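
The reworked contract in fetchProducts is that IncrementMetadata is always built and whatever products were collected are always returned, even when the fetch fails partway, so updateCache can persist partial progress before surfacing the error. A toy sketch of that pattern in isolation (the fetchPages name and failure point are hypothetical, not from this patch):

    package main

    import (
        "errors"
        "fmt"
    )

    // fetchPages simulates a paged fetch that fails partway through but still
    // returns what it collected, mirroring how fetchProducts/FetchProductsSince
    // propagate partial results alongside the error
    func fetchPages() ([]int, error) {
        var results []int
        for page := 0; page < 5; page++ {
            if page == 3 {
                // simulated network failure mid-run
                return results, fmt.Errorf("page %d: %w", page, errors.New("connection reset"))
            }
            results = append(results, page)
        }
        return results, nil
    }

    func main() {
        results, err := fetchPages()
        if err != nil {
            // the caller can persist partial progress before surfacing the error,
            // just as updateCache saves partial products via saveAndReportResults
            fmt.Printf("error after %d pages: %v (saving partial progress)\n", len(results), err)
            return
        }
        fmt.Printf("fetched all %d pages\n", len(results))
    }
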
diff --git a/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/nvd_api_client.go b/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/nvd_api_client.go
index ec22a889c..904639f6b 100644
--- a/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/nvd_api_client.go
+++ b/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/nvd_api_client.go
@@ -24,6 +24,9 @@ const (
 	// retry configuration for rate limiting
 	maxRetries     = 5
 	baseRetryDelay = 30 * time.Second // NVD uses 30-second rolling windows
+
+	// NVD API has a maximum date range of 120 days for queries with date filters
+	maxDateRangeDays = 120
 )
 
 // NVDAPIClient handles communication with the NVD Products API
@@ -109,45 +112,98 @@ func NewNVDAPIClient() *NVDAPIClient {
 	}
 }
 
-// PageCallback is called after each page is successfully fetched
-// it receives the startIndex and the response for that page
-type PageCallback func(startIndex int, response NVDProductsResponse) error
+// FetchProductsSince fetches all products modified since the given date.
+// If lastModStartDate is zero, fetches all products (no date filter).
+// If lastModStartDate is set, fetches in 120-day chunks (NVD API limit) from that date to now.
+// Returns partial results on error so progress can be saved.
+func (c *NVDAPIClient) FetchProductsSince(ctx context.Context, lastModStartDate time.Time) ([]NVDProduct, error) {
+	// if no date filter, fetch all products in a single pass
+	if lastModStartDate.IsZero() {
+		return c.fetchDateRange(ctx, time.Time{}, time.Time{})
+	}
 
-// FetchProductsSince fetches all products modified since the given date
-// if lastModStartDate is zero, fetches all products
-// calls onPageFetched callback after each successful page fetch for incremental saving
-// if resumeFromIndex > 0, starts fetching from that index
-func (c *NVDAPIClient) FetchProductsSince(ctx context.Context, lastModStartDate time.Time, resumeFromIndex int, onPageFetched PageCallback) error {
-	startIndex := resumeFromIndex
+	// fetch in 120-day chunks from lastModStartDate to now
+	chunks := buildDateChunks(lastModStartDate, time.Now().UTC())
+	if len(chunks) > 1 {
+		fmt.Printf("Date range spans %d chunks of up to %d days each\n", len(chunks), maxDateRangeDays)
+	}
+
+	var allProducts []NVDProduct
+	for i, chunk := range chunks {
+		if len(chunks) > 1 {
+			fmt.Printf("Fetching chunk %d/%d: %s to %s\n", i+1, len(chunks),
+				chunk.start.Format("2006-01-02"), chunk.end.Format("2006-01-02"))
+		}
+
+		products, err := c.fetchDateRange(ctx, chunk.start, chunk.end)
+		if err != nil {
+			// return partial results so caller can save progress
+			return allProducts, err
+		}
+
+		allProducts = append(allProducts, products...)
+		if len(chunks) > 1 {
+			fmt.Printf("Chunk %d complete: %d products (total so far: %d)\n", i+1, len(products), len(allProducts))
+		}
+	}
+
+	fmt.Printf("Fetched %d products total\n", len(allProducts))
+	return allProducts, nil
+}
+
+// dateChunk represents a date range for fetching
+type dateChunk struct {
+	start time.Time
+	end   time.Time
+}
+
+// buildDateChunks splits a date range into chunks of at most maxDateRangeDays each
+func buildDateChunks(start, end time.Time) []dateChunk {
+	var chunks []dateChunk
+	chunkStart := start
+
+	for chunkStart.Before(end) {
+		chunkEnd := chunkStart.AddDate(0, 0, maxDateRangeDays)
+		if chunkEnd.After(end) {
+			chunkEnd = end
+		}
+		chunks = append(chunks, dateChunk{start: chunkStart, end: chunkEnd})
+		chunkStart = chunkEnd
+	}
+
+	return chunks
+}
+
+// fetchDateRange fetches all products within a single date range (must be <= 120 days).
+// If start and end are both zero, fetches all products without date filtering.
+func (c *NVDAPIClient) fetchDateRange(ctx context.Context, start, end time.Time) ([]NVDProduct, error) {
+	var products []NVDProduct
+	startIndex := 0
 
 	for {
-		resp, err := c.fetchPage(ctx, startIndex, lastModStartDate)
+		resp, err := c.fetchPage(ctx, startIndex, start, end)
 		if err != nil {
-			return fmt.Errorf("failed to fetch page at index %d: %w", startIndex, err)
+			return products, fmt.Errorf("failed to fetch page at index %d: %w", startIndex, err)
 		}
 
-		// call callback to save progress immediately
-		if onPageFetched != nil {
-			if err := onPageFetched(startIndex, resp); err != nil {
-				return fmt.Errorf("callback failed at index %d: %w", startIndex, err)
-			}
-		}
+		products = append(products, resp.Products...)
+		fmt.Printf(" Fetched %d/%d products...\n", len(products), resp.TotalResults)
 
 		// check if we've fetched all results
 		if startIndex+resp.ResultsPerPage >= resp.TotalResults {
-			fmt.Printf("Fetched %d/%d products (complete)\n", resp.TotalResults, resp.TotalResults)
 			break
 		}
 
 		startIndex += resp.ResultsPerPage
-		fmt.Printf("Fetched %d/%d products...\n", startIndex, resp.TotalResults)
 	}
 
-	return nil
+	return products, nil
 }
 
 // fetchPage fetches a single page of results from the NVD API with retry logic for rate limiting
-func (c *NVDAPIClient) fetchPage(ctx context.Context, startIndex int, lastModStartDate time.Time) (NVDProductsResponse, error) {
+// if both start and end are zero, fetches without date filtering
+// if start and end are set, they must form a range <= 120 days (enforced by caller)
+func (c *NVDAPIClient) fetchPage(ctx context.Context, startIndex int, start, end time.Time) (NVDProductsResponse, error) {
 	var lastErr error
 
 	for attempt := 0; attempt < maxRetries; attempt++ {
@@ -160,10 +216,12 @@ func (c *NVDAPIClient) fetchPage(ctx context.Context, startIndex int, lastModSta
 		url := fmt.Sprintf("%s?resultsPerPage=%d&startIndex=%d", nvdProductsAPIURL, resultsPerPage, startIndex)
 
 		// add date range if specified (incremental update)
-		if !lastModStartDate.IsZero() {
-			// NVD API requires RFC3339 format: 2024-01-01T00:00:00.000
-			lastModStartStr := lastModStartDate.Format("2006-01-02T15:04:05.000")
-			url += fmt.Sprintf("&lastModStartDate=%s", lastModStartStr)
+		// NVD API requires both lastModStartDate and lastModEndDate when either is present
+		if !start.IsZero() && !end.IsZero() {
+			// NVD API requires this format: 2024-01-01T00:00:00.000
+			startStr := start.Format("2006-01-02T15:04:05.000")
+			endStr := end.Format("2006-01-02T15:04:05.000")
+			url += fmt.Sprintf("&lastModStartDate=%s&lastModEndDate=%s", startStr, endStr)
 		}
 
 		// create request
@@ -191,14 +249,12 @@ func (c *NVDAPIClient) fetchPage(ctx context.Context, startIndex int, lastModSta
 			continue // retry
 		}
 
-		// handle HTTP status codes
-		statusResponse, handled, err := c.handleHTTPStatus(httpResp, startIndex)
-		if handled {
-			// either error or special case (404 with empty results)
-			return statusResponse, err
+		// check for error status codes
+		if err := checkHTTPStatus(httpResp); err != nil {
+			return NVDProductsResponse{}, err
 		}
 
-		// success - parse response
+		// parse response
 		var response NVDProductsResponse
 		if err := json.NewDecoder(httpResp.Body).Decode(&response); err != nil {
 			httpResp.Body.Close()
@@ -235,31 +291,16 @@ func (c *NVDAPIClient) handleRateLimit(ctx context.Context, httpResp *http.Respo
 	}
 }
 
-// handleHTTPStatus handles non-429 HTTP status codes
-// returns (response, handled, error) where:
-//   - handled=true means the status was processed (either success case like 404 or error)
-//   - handled=false means continue to normal response parsing
-func (c *NVDAPIClient) handleHTTPStatus(httpResp *http.Response, startIndex int) (NVDProductsResponse, bool, error) {
-	// handle 404 as "no results found" (common when querying recent dates with no updates)
-	if httpResp.StatusCode == http.StatusNotFound {
-		httpResp.Body.Close()
-		return NVDProductsResponse{
-			ResultsPerPage: 0,
-			StartIndex:     startIndex,
-			TotalResults:   0,
-			Products:       []NVDProduct{},
-		}, true, nil
+// checkHTTPStatus returns an error for non-200 status codes.
+// NVD API returns 200 with TotalResults=0 when there are no results,
+// so any non-200 status (including 404) indicates an actual error.
+func checkHTTPStatus(httpResp *http.Response) error {
+	if httpResp.StatusCode == http.StatusOK {
+		return nil
 	}
-
-	// check for other non-200 status codes
-	if httpResp.StatusCode != http.StatusOK {
-		body, _ := io.ReadAll(httpResp.Body)
-		httpResp.Body.Close()
-		return NVDProductsResponse{}, true, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body))
-	}
-
-	// status OK - let caller parse response
-	return NVDProductsResponse{}, false, nil
+	body, _ := io.ReadAll(httpResp.Body)
+	httpResp.Body.Close()
+	return fmt.Errorf("NVD API error (status %d): %s", httpResp.StatusCode, string(body))
 }
 
 // parseRetryAfter parses the Retry-After header from HTTP 429 responses
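
To see the chunking concretely: buildDateChunks walks forward in 120-day steps and clamps the final chunk to the requested end date. A standalone sketch under the same rules (the names split and maxDays are illustrative stand-ins for buildDateChunks and maxDateRangeDays); running it reproduces the four-chunk split asserted for a full leap year in the new test file below:

    package main

    import (
        "fmt"
        "time"
    )

    const maxDays = 120 // stands in for maxDateRangeDays

    type chunk struct{ start, end time.Time }

    // split mirrors buildDateChunks: walk forward in maxDays steps,
    // clamping the final chunk to the requested end date
    func split(start, end time.Time) []chunk {
        var out []chunk
        for cur := start; cur.Before(end); {
            next := cur.AddDate(0, 0, maxDays)
            if next.After(end) {
                next = end
            }
            out = append(out, chunk{cur, next})
            cur = next
        }
        return out
    }

    func main() {
        start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
        end := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
        for i, c := range split(start, end) {
            fmt.Printf("chunk %d: %s -> %s\n", i+1, c.start.Format("2006-01-02"), c.end.Format("2006-01-02"))
        }
        // prints four contiguous chunks; the last is clamped to 2025-01-01,
        // matching the "multiple chunks for a full year" test case
    }
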
diff --git a/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/nvd_api_client_test.go b/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/nvd_api_client_test.go
new file mode 100644
index 000000000..a304e8c8d
--- /dev/null
+++ b/syft/pkg/cataloger/internal/cpegenerate/dictionary/index-generator/nvd_api_client_test.go
@@ -0,0 +1,233 @@
+package main
+
+import (
+	"io"
+	"net/http"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestBuildDateChunks(t *testing.T) {
+	tests := []struct {
+		name           string
+		start          time.Time
+		end            time.Time
+		expectedChunks int
+		validateChunks func(t *testing.T, chunks []dateChunk)
+	}{
+		{
+			name:           "single chunk when range is under 120 days",
+			start:          time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
+			end:            time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC), // 31 days
+			expectedChunks: 1,
+			validateChunks: func(t *testing.T, chunks []dateChunk) {
+				assert.Equal(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), chunks[0].start)
+				assert.Equal(t, time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC), chunks[0].end)
+			},
+		},
+		{
+			name:           "single chunk when range is exactly 120 days",
+			start:          time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
+			end:            time.Date(2025, 5, 1, 0, 0, 0, 0, time.UTC), // 120 days
+			expectedChunks: 1,
+			validateChunks: func(t *testing.T, chunks []dateChunk) {
+				assert.Equal(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), chunks[0].start)
+				assert.Equal(t, time.Date(2025, 5, 1, 0, 0, 0, 0, time.UTC), chunks[0].end)
+			},
+		},
+		{
+			name:           "two chunks when range is 121 days",
+			start:          time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
+			end:            time.Date(2025, 5, 2, 0, 0, 0, 0, time.UTC), // 121 days
+			expectedChunks: 2,
+			validateChunks: func(t *testing.T, chunks []dateChunk) {
+				// first chunk: Jan 1 to May 1 (120 days)
+				assert.Equal(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), chunks[0].start)
+				assert.Equal(t, time.Date(2025, 5, 1, 0, 0, 0, 0, time.UTC), chunks[0].end)
+				// second chunk: May 1 to May 2 (1 day)
+				assert.Equal(t, time.Date(2025, 5, 1, 0, 0, 0, 0, time.UTC), chunks[1].start)
+				assert.Equal(t, time.Date(2025, 5, 2, 0, 0, 0, 0, time.UTC), chunks[1].end)
+			},
+		},
+		{
+			name:           "multiple chunks for a full year",
+			start:          time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
+			end:            time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), // 366 days (leap year)
+			expectedChunks: 4,
+			validateChunks: func(t *testing.T, chunks []dateChunk) {
+				// verify chunks are contiguous (each chunk starts where previous ended)
+				for i := 1; i < len(chunks); i++ {
+					assert.Equal(t, chunks[i-1].end, chunks[i].start, "chunks should be contiguous")
+				}
+				// verify first and last
+				assert.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), chunks[0].start)
+				assert.Equal(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), chunks[len(chunks)-1].end)
+			},
+		},
+		{
+			name:           "empty result when start equals end",
+			start:          time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
+			end:            time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
+			expectedChunks: 0,
+		},
+		{
+			name:           "empty result when start is after end",
+			start:          time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
+			end:            time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
+			expectedChunks: 0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			chunks := buildDateChunks(tt.start, tt.end)
+			assert.Len(t, chunks, tt.expectedChunks)
+			if tt.validateChunks != nil && len(chunks) > 0 {
+				tt.validateChunks(t, chunks)
+			}
+		})
+	}
+}
+
+func TestBuildDateChunks_ChunkSizeNeverExceeds120Days(t *testing.T) {
+	// test with various date ranges to ensure no chunk exceeds 120 days
+	testCases := []struct {
+		start time.Time
+		end   time.Time
+	}{
+		{time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)},
+		{time.Date(2023, 6, 15, 0, 0, 0, 0, time.UTC), time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)},
+		{time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2024, 12, 31, 0, 0, 0, 0, time.UTC)},
+	}
+
+	for _, tc := range testCases {
+		chunks := buildDateChunks(tc.start, tc.end)
+		for i, chunk := range chunks {
+			days := chunk.end.Sub(chunk.start).Hours() / 24
+			assert.LessOrEqual(t, days, float64(maxDateRangeDays),
+				"chunk %d exceeds max days: start=%s, end=%s, days=%.0f",
+				i, chunk.start.Format("2006-01-02"), chunk.end.Format("2006-01-02"), days)
+		}
+	}
+}
+
+// mockReadCloser wraps an io.Reader and records whether Close was called
+type mockReadCloser struct {
+	io.Reader
+	closed bool
+}
+
+func (m *mockReadCloser) Close() error {
+	m.closed = true
+	return nil
+}
+
+func newMockResponse(statusCode int, body string) *http.Response {
+	return &http.Response{
+		StatusCode: statusCode,
+		Body:       &mockReadCloser{Reader: strings.NewReader(body)},
+	}
+}
+
+func TestCheckHTTPStatus(t *testing.T) {
+	tests := []struct {
+		name          string
+		statusCode    int
+		body          string
+		expectError   bool
+		errorContains string
+	}{
+		{
+			name:        "200 OK returns nil",
+			statusCode:  http.StatusOK,
+			body:        `{"totalResults": 0}`,
+			expectError: false,
+		},
+		{
+			name:          "404 returns error",
+			statusCode:    http.StatusNotFound,
+			body:          "",
+			expectError:   true,
+			errorContains: "status 404",
+		},
+		{
+			name:          "400 Bad Request returns error with body",
+			statusCode:    http.StatusBadRequest,
+			body:          "Both lastModStartDate and lastModEndDate are required",
+			expectError:   true,
+			errorContains: "lastModStartDate and lastModEndDate are required",
+		},
+		{
+			name:          "500 Internal Server Error returns error",
+			statusCode:    http.StatusInternalServerError,
+			body:          "Internal server error",
+			expectError:   true,
+			errorContains: "status 500",
+		},
+		{
+			name:          "503 Service Unavailable returns error",
+			statusCode:    http.StatusServiceUnavailable,
+			body:          "Service temporarily unavailable",
+			expectError:   true,
+			errorContains: "status 503",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			resp := newMockResponse(tt.statusCode, tt.body)
+			err := checkHTTPStatus(resp)
+
+			if tt.expectError {
+				require.Error(t, err)
+				assert.Contains(t, err.Error(), tt.errorContains)
+				// verify body was closed on error
+				assert.True(t, resp.Body.(*mockReadCloser).closed, "response body should be closed on error")
+			} else {
+				require.NoError(t, err)
+				// verify body was NOT closed on success (caller needs to read it)
+				assert.False(t, resp.Body.(*mockReadCloser).closed, "response body should not be closed on success")
+			}
+		})
+	}
+}
+
+func TestParseRetryAfter(t *testing.T) {
+	tests := []struct {
+		name     string
+		header   string
+		expected time.Duration
+	}{
+		{
+			name:     "empty header returns 0",
+			header:   "",
+			expected: 0,
+		},
+		{
+			name:     "numeric seconds",
+			header:   "30",
+			expected: 30 * time.Second,
+		},
+		{
+			name:     "numeric seconds - larger value",
+			header:   "120",
+			expected: 120 * time.Second,
+		},
+		{
+			name:     "invalid value returns 0",
+			header:   "not-a-number",
+			expected: 0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := parseRetryAfter(tt.header)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
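
parseRetryAfter itself is unchanged by this diff, so its implementation is not shown here. Judging only from the expectations in TestParseRetryAfter above, a compatible sketch would treat the header as integer seconds and fall back to 0 on anything else (note the real Retry-After header may also carry an HTTP-date, which these tests do not cover):

    package main

    import (
        "fmt"
        "strconv"
        "time"
    )

    // a compatible sketch of parseRetryAfter (the real implementation lives in
    // nvd_api_client.go and is not part of this diff): numeric seconds are
    // honored, anything else falls back to 0 so the caller uses its default delay
    func parseRetryAfterSketch(header string) time.Duration {
        if header == "" {
            return 0
        }
        secs, err := strconv.Atoi(header)
        if err != nil || secs < 0 {
            return 0
        }
        return time.Duration(secs) * time.Second
    }

    func main() {
        // the same inputs exercised by TestParseRetryAfter
        for _, h := range []string{"", "30", "120", "not-a-number"} {
            fmt.Printf("Retry-After %q -> %v\n", h, parseRetryAfterSketch(h))
        }
    }
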