fix(cpe-generation): set start and end date (#4600)

* fix(cpe-generation): set start and end date

Previously, the update job was silently failing because the NVD API
returns a 404 with no body if a start date is specified but not an end
date. Further, the API returns an error if more than 120 days are in
range of the start and end date.

Update the API client to:
1. Return a non-nil error on http 404
2. Chunk the date range into 120 day chunks
3. Pass start and end date to avoid errors.

Also add more tolerant timestamp parsing since the previous update job
would fail with timestamp format errors.

Signed-off-by: Will Murphy <willmurphyscode@users.noreply.github.com>

* refactor(cpe-generator): remove callbacks

Previously, this job had callbacks that were there to make sure that
incremental progress could be written to disk. However, incremental
progress was not being written to disk, and there were issues related to
the callbacks like double logging. Therefore, just remove the callbacks
and do simple imperative code to page through the API results.

Signed-off-by: Will Murphy <willmurphyscode@users.noreply.github.com>

---------

Signed-off-by: Will Murphy <willmurphyscode@users.noreply.github.com>
This commit is contained in:
Will Murphy 2026-02-05 09:54:24 -05:00 committed by GitHub
parent 6755377554
commit 138cb1be0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 421 additions and 104 deletions

View File

@ -10,6 +10,23 @@ import (
const cacheDir = ".cpe-cache" const cacheDir = ".cpe-cache"
// parseNVDTimestamp parses timestamps from NVD API which may or may not include timezone
func parseNVDTimestamp(ts string) (time.Time, error) {
// try RFC3339 first (with timezone)
if t, err := time.Parse(time.RFC3339, ts); err == nil {
return t, nil
}
// try without timezone (NVD sometimes returns this format)
if t, err := time.Parse("2006-01-02T15:04:05.000", ts); err == nil {
return t, nil
}
// try without milliseconds
if t, err := time.Parse("2006-01-02T15:04:05", ts); err == nil {
return t, nil
}
return time.Time{}, fmt.Errorf("unable to parse timestamp: %s", ts)
}
// IncrementMetadata tracks a single fetch increment for a monthly batch // IncrementMetadata tracks a single fetch increment for a monthly batch
type IncrementMetadata struct { type IncrementMetadata struct {
FetchedAt time.Time `json:"fetchedAt"` FetchedAt time.Time `json:"fetchedAt"`
@ -215,7 +232,7 @@ func groupProductsByMonth(products []NVDProduct) (map[string][]NVDProduct, error
productsByMonth := make(map[string][]NVDProduct) productsByMonth := make(map[string][]NVDProduct)
for _, product := range products { for _, product := range products {
lastMod, err := time.Parse(time.RFC3339, product.CPE.LastModified) lastMod, err := parseNVDTimestamp(product.CPE.LastModified)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse lastModified for %s: %w", product.CPE.CPENameID, err) return nil, fmt.Errorf("failed to parse lastModified for %s: %w", product.CPE.CPENameID, err)
} }
@ -342,8 +359,8 @@ func (m *CacheManager) LoadAllProducts() ([]NVDProduct, error) {
} }
// compare lastModified timestamps to keep the newer one // compare lastModified timestamps to keep the newer one
newMod, _ := time.Parse(time.RFC3339, p.CPE.LastModified) newMod, _ := parseNVDTimestamp(p.CPE.LastModified)
existingMod, _ := time.Parse(time.RFC3339, existing.CPE.LastModified) existingMod, _ := parseNVDTimestamp(existing.CPE.LastModified)
if newMod.After(existingMod) { if newMod.After(existingMod) {
productMap[p.CPE.CPENameID] = p productMap[p.CPE.CPENameID] = p

View File

@ -317,3 +317,59 @@ func TestCacheManager_SaveProducts(t *testing.T) {
assert.Equal(t, 1, metadata.MonthlyBatches["2024-12"].TotalProducts) assert.Equal(t, 1, metadata.MonthlyBatches["2024-12"].TotalProducts)
}) })
} }
func TestParseNVDTimestamp(t *testing.T) {
tests := []struct {
name string
input string
expectError bool
expectedStr string // expected time formatted as 2006-01-02T15:04:05
}{
{
name: "RFC3339 with Z timezone",
input: "2024-11-15T10:00:00.000Z",
expectError: false,
expectedStr: "2024-11-15T10:00:00",
},
{
name: "RFC3339 with offset timezone",
input: "2024-11-15T10:00:00+00:00",
expectError: false,
expectedStr: "2024-11-15T10:00:00",
},
{
name: "without timezone (NVD format)",
input: "2026-01-27T15:53:34.823",
expectError: false,
expectedStr: "2026-01-27T15:53:34",
},
{
name: "without timezone or milliseconds",
input: "2024-11-15T10:00:00",
expectError: false,
expectedStr: "2024-11-15T10:00:00",
},
{
name: "invalid format",
input: "not-a-timestamp",
expectError: true,
},
{
name: "empty string",
input: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseNVDTimestamp(tt.input)
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expectedStr, result.Format("2006-01-02T15:04:05"))
}
})
}
}

View File

@ -53,33 +53,26 @@ func updateCache(ctx context.Context, cacheManager *CacheManager, forceFullRefre
lastModStartDate, isFullRefresh := determineUpdateMode(metadata, forceFullRefresh) lastModStartDate, isFullRefresh := determineUpdateMode(metadata, forceFullRefresh)
// use resume index if available products, increment, err := fetchProducts(ctx, lastModStartDate)
resumeFromIndex := 0
if !isFullRefresh && metadata.LastStartIndex > 0 {
resumeFromIndex = metadata.LastStartIndex
fmt.Printf("Resuming from index %d...\n", resumeFromIndex)
}
allProducts, increment, err := fetchProducts(ctx, lastModStartDate, resumeFromIndex)
if err != nil { if err != nil {
// if we have partial products, save them before returning error // if we have partial products, save them before returning error
if len(allProducts) > 0 { if len(products) > 0 {
fmt.Printf("\nError occurred but saving %d products fetched so far...\n", len(allProducts)) fmt.Printf("\nError occurred but saving %d products fetched so far...\n", len(products))
if saveErr := saveAndReportResults(cacheManager, allProducts, isFullRefresh, metadata, increment); saveErr != nil { if saveErr := saveAndReportResults(cacheManager, products, isFullRefresh, metadata, increment); saveErr != nil {
fmt.Printf("WARNING: Failed to save partial progress: %v\n", saveErr) fmt.Printf("WARNING: Failed to save partial progress: %v\n", saveErr)
} else { } else {
fmt.Println("Partial progress saved successfully. Run again to resume from this point.") fmt.Println("Partial progress saved successfully.")
} }
} }
return err return err
} }
if len(allProducts) == 0 { if len(products) == 0 {
fmt.Println("No products fetched (already up to date)") fmt.Println("No products fetched (already up to date)")
return nil return nil
} }
return saveAndReportResults(cacheManager, allProducts, isFullRefresh, metadata, increment) return saveAndReportResults(cacheManager, products, isFullRefresh, metadata, increment)
} }
// determineUpdateMode decides whether to do a full refresh or incremental update // determineUpdateMode decides whether to do a full refresh or incremental update
@ -94,48 +87,25 @@ func determineUpdateMode(metadata *CacheMetadata, forceFullRefresh bool) (time.T
} }
// fetchProducts fetches products from the NVD API // fetchProducts fetches products from the NVD API
func fetchProducts(ctx context.Context, lastModStartDate time.Time, resumeFromIndex int) ([]NVDProduct, IncrementMetadata, error) { func fetchProducts(ctx context.Context, lastModStartDate time.Time) ([]NVDProduct, IncrementMetadata, error) {
apiClient := NewNVDAPIClient() apiClient := NewNVDAPIClient()
fmt.Println("Fetching CPE data from NVD Products API...") fmt.Println("Fetching CPE data from NVD Products API...")
var allProducts []NVDProduct products, err := apiClient.FetchProductsSince(ctx, lastModStartDate)
var totalResults int
var firstStartIndex, lastEndIndex int
onPageFetched := func(startIndex int, response NVDProductsResponse) error {
if totalResults == 0 {
totalResults = response.TotalResults
firstStartIndex = startIndex
}
lastEndIndex = startIndex + response.ResultsPerPage
allProducts = append(allProducts, response.Products...)
fmt.Printf("Fetched %d/%d products...\n", len(allProducts), totalResults)
return nil
}
if err := apiClient.FetchProductsSince(ctx, lastModStartDate, resumeFromIndex, onPageFetched); err != nil {
// return partial products with increment metadata so they can be saved
increment := IncrementMetadata{
FetchedAt: time.Now(),
LastModStartDate: lastModStartDate,
LastModEndDate: time.Now(),
Products: len(allProducts),
StartIndex: firstStartIndex,
EndIndex: lastEndIndex,
}
return allProducts, increment, fmt.Errorf("failed to fetch products from NVD API: %w", err)
}
increment := IncrementMetadata{ increment := IncrementMetadata{
FetchedAt: time.Now(), FetchedAt: time.Now(),
LastModStartDate: lastModStartDate, LastModStartDate: lastModStartDate,
LastModEndDate: time.Now(), LastModEndDate: time.Now(),
Products: len(allProducts), Products: len(products),
StartIndex: firstStartIndex,
EndIndex: lastEndIndex,
} }
return allProducts, increment, nil if err != nil {
// return partial products so they can be saved
return products, increment, fmt.Errorf("failed to fetch products from NVD API: %w", err)
}
return products, increment, nil
} }
// saveAndReportResults saves products and metadata, then reports success // saveAndReportResults saves products and metadata, then reports success

View File

@ -24,6 +24,9 @@ const (
// retry configuration for rate limiting // retry configuration for rate limiting
maxRetries = 5 maxRetries = 5
baseRetryDelay = 30 * time.Second // NVD uses 30-second rolling windows baseRetryDelay = 30 * time.Second // NVD uses 30-second rolling windows
// NVD API has a maximum date range of 120 days for queries with date filters
maxDateRangeDays = 120
) )
// NVDAPIClient handles communication with the NVD Products API // NVDAPIClient handles communication with the NVD Products API
@ -109,45 +112,98 @@ func NewNVDAPIClient() *NVDAPIClient {
} }
} }
// PageCallback is called after each page is successfully fetched // FetchProductsSince fetches all products modified since the given date.
// it receives the startIndex and the response for that page // If lastModStartDate is zero, fetches all products (no date filter).
type PageCallback func(startIndex int, response NVDProductsResponse) error // If lastModStartDate is set, fetches in 120-day chunks (NVD API limit) from that date to now.
// Returns partial results on error so progress can be saved.
func (c *NVDAPIClient) FetchProductsSince(ctx context.Context, lastModStartDate time.Time) ([]NVDProduct, error) {
// if no date filter, fetch all products in a single pass
if lastModStartDate.IsZero() {
return c.fetchDateRange(ctx, time.Time{}, time.Time{})
}
// FetchProductsSince fetches all products modified since the given date // fetch in 120-day chunks from lastModStartDate to now
// if lastModStartDate is zero, fetches all products chunks := buildDateChunks(lastModStartDate, time.Now().UTC())
// calls onPageFetched callback after each successful page fetch for incremental saving if len(chunks) > 1 {
// if resumeFromIndex > 0, starts fetching from that index fmt.Printf("Date range spans %d chunks of up to %d days each\n", len(chunks), maxDateRangeDays)
func (c *NVDAPIClient) FetchProductsSince(ctx context.Context, lastModStartDate time.Time, resumeFromIndex int, onPageFetched PageCallback) error { }
startIndex := resumeFromIndex
var allProducts []NVDProduct
for i, chunk := range chunks {
if len(chunks) > 1 {
fmt.Printf("Fetching chunk %d/%d: %s to %s\n", i+1, len(chunks),
chunk.start.Format("2006-01-02"), chunk.end.Format("2006-01-02"))
}
products, err := c.fetchDateRange(ctx, chunk.start, chunk.end)
if err != nil {
// return partial results so caller can save progress
return allProducts, err
}
allProducts = append(allProducts, products...)
if len(chunks) > 1 {
fmt.Printf("Chunk %d complete: %d products (total so far: %d)\n", i+1, len(products), len(allProducts))
}
}
fmt.Printf("Fetched %d products total\n", len(allProducts))
return allProducts, nil
}
// dateChunk represents a date range for fetching
type dateChunk struct {
start time.Time
end time.Time
}
// buildDateChunks splits a date range into chunks of maxDateRangeDays
func buildDateChunks(start, end time.Time) []dateChunk {
var chunks []dateChunk
chunkStart := start
for chunkStart.Before(end) {
chunkEnd := chunkStart.AddDate(0, 0, maxDateRangeDays)
if chunkEnd.After(end) {
chunkEnd = end
}
chunks = append(chunks, dateChunk{start: chunkStart, end: chunkEnd})
chunkStart = chunkEnd
}
return chunks
}
// fetchDateRange fetches all products within a single date range (must be <= 120 days).
// If start and end are both zero, fetches all products without date filtering.
func (c *NVDAPIClient) fetchDateRange(ctx context.Context, start, end time.Time) ([]NVDProduct, error) {
var products []NVDProduct
startIndex := 0
for { for {
resp, err := c.fetchPage(ctx, startIndex, lastModStartDate) resp, err := c.fetchPage(ctx, startIndex, start, end)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch page at index %d: %w", startIndex, err) return products, fmt.Errorf("failed to fetch page at index %d: %w", startIndex, err)
} }
// call callback to save progress immediately products = append(products, resp.Products...)
if onPageFetched != nil { fmt.Printf(" Fetched %d/%d products...\n", len(products), resp.TotalResults)
if err := onPageFetched(startIndex, resp); err != nil {
return fmt.Errorf("callback failed at index %d: %w", startIndex, err)
}
}
// check if we've fetched all results // check if we've fetched all results
if startIndex+resp.ResultsPerPage >= resp.TotalResults { if startIndex+resp.ResultsPerPage >= resp.TotalResults {
fmt.Printf("Fetched %d/%d products (complete)\n", resp.TotalResults, resp.TotalResults)
break break
} }
startIndex += resp.ResultsPerPage startIndex += resp.ResultsPerPage
fmt.Printf("Fetched %d/%d products...\n", startIndex, resp.TotalResults)
} }
return nil return products, nil
} }
// fetchPage fetches a single page of results from the NVD API with retry logic for rate limiting // fetchPage fetches a single page of results from the NVD API with retry logic for rate limiting
func (c *NVDAPIClient) fetchPage(ctx context.Context, startIndex int, lastModStartDate time.Time) (NVDProductsResponse, error) { // if both start and end are zero, fetches without date filtering
// if start and end are set, they must form a range <= 120 days (enforced by caller)
func (c *NVDAPIClient) fetchPage(ctx context.Context, startIndex int, start, end time.Time) (NVDProductsResponse, error) {
var lastErr error var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ { for attempt := 0; attempt < maxRetries; attempt++ {
@ -160,10 +216,12 @@ func (c *NVDAPIClient) fetchPage(ctx context.Context, startIndex int, lastModSta
url := fmt.Sprintf("%s?resultsPerPage=%d&startIndex=%d", nvdProductsAPIURL, resultsPerPage, startIndex) url := fmt.Sprintf("%s?resultsPerPage=%d&startIndex=%d", nvdProductsAPIURL, resultsPerPage, startIndex)
// add date range if specified (incremental update) // add date range if specified (incremental update)
if !lastModStartDate.IsZero() { // NVD API requires both lastModStartDate and lastModEndDate when either is present
// NVD API requires RFC3339 format: 2024-01-01T00:00:00.000 if !start.IsZero() && !end.IsZero() {
lastModStartStr := lastModStartDate.Format("2006-01-02T15:04:05.000") // NVD API requires this format: 2024-01-01T00:00:00.000
url += fmt.Sprintf("&lastModStartDate=%s", lastModStartStr) startStr := start.Format("2006-01-02T15:04:05.000")
endStr := end.Format("2006-01-02T15:04:05.000")
url += fmt.Sprintf("&lastModStartDate=%s&lastModEndDate=%s", startStr, endStr)
} }
// create request // create request
@ -191,14 +249,12 @@ func (c *NVDAPIClient) fetchPage(ctx context.Context, startIndex int, lastModSta
continue // retry continue // retry
} }
// handle HTTP status codes // check for error status codes
statusResponse, handled, err := c.handleHTTPStatus(httpResp, startIndex) if err := checkHTTPStatus(httpResp); err != nil {
if handled { return NVDProductsResponse{}, err
// either error or special case (404 with empty results)
return statusResponse, err
} }
// success - parse response // parse response
var response NVDProductsResponse var response NVDProductsResponse
if err := json.NewDecoder(httpResp.Body).Decode(&response); err != nil { if err := json.NewDecoder(httpResp.Body).Decode(&response); err != nil {
httpResp.Body.Close() httpResp.Body.Close()
@ -235,31 +291,16 @@ func (c *NVDAPIClient) handleRateLimit(ctx context.Context, httpResp *http.Respo
} }
} }
// handleHTTPStatus handles non-429 HTTP status codes // checkHTTPStatus returns an error for non-200 status codes.
// returns (response, handled, error) where: // NVD API returns 200 with TotalResults=0 when there are no results,
// - handled=true means the status was processed (either success case like 404 or error) // so any non-200 status (including 404) indicates an actual error.
// - handled=false means continue to normal response parsing func checkHTTPStatus(httpResp *http.Response) error {
func (c *NVDAPIClient) handleHTTPStatus(httpResp *http.Response, startIndex int) (NVDProductsResponse, bool, error) { if httpResp.StatusCode == http.StatusOK {
// handle 404 as "no results found" (common when querying recent dates with no updates) return nil
if httpResp.StatusCode == http.StatusNotFound {
httpResp.Body.Close()
return NVDProductsResponse{
ResultsPerPage: 0,
StartIndex: startIndex,
TotalResults: 0,
Products: []NVDProduct{},
}, true, nil
} }
body, _ := io.ReadAll(httpResp.Body)
// check for other non-200 status codes httpResp.Body.Close()
if httpResp.StatusCode != http.StatusOK { return fmt.Errorf("NVD API error (status %d): %s", httpResp.StatusCode, string(body))
body, _ := io.ReadAll(httpResp.Body)
httpResp.Body.Close()
return NVDProductsResponse{}, true, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body))
}
// status OK - let caller parse response
return NVDProductsResponse{}, false, nil
} }
// parseRetryAfter parses the Retry-After header from HTTP 429 responses // parseRetryAfter parses the Retry-After header from HTTP 429 responses

View File

@ -0,0 +1,233 @@
package main
import (
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBuildDateChunks(t *testing.T) {
tests := []struct {
name string
start time.Time
end time.Time
expectedChunks int
validateChunks func(t *testing.T, chunks []dateChunk)
}{
{
name: "single chunk when range is under 120 days",
start: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC), // 31 days
expectedChunks: 1,
validateChunks: func(t *testing.T, chunks []dateChunk) {
assert.Equal(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), chunks[0].start)
assert.Equal(t, time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC), chunks[0].end)
},
},
{
name: "single chunk when range is exactly 120 days",
start: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 5, 1, 0, 0, 0, 0, time.UTC), // 120 days
expectedChunks: 1,
validateChunks: func(t *testing.T, chunks []dateChunk) {
assert.Equal(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), chunks[0].start)
assert.Equal(t, time.Date(2025, 5, 1, 0, 0, 0, 0, time.UTC), chunks[0].end)
},
},
{
name: "two chunks when range is 121 days",
start: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 5, 2, 0, 0, 0, 0, time.UTC), // 121 days
expectedChunks: 2,
validateChunks: func(t *testing.T, chunks []dateChunk) {
// first chunk: Jan 1 to May 1 (120 days)
assert.Equal(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), chunks[0].start)
assert.Equal(t, time.Date(2025, 5, 1, 0, 0, 0, 0, time.UTC), chunks[0].end)
// second chunk: May 1 to May 2 (1 day)
assert.Equal(t, time.Date(2025, 5, 1, 0, 0, 0, 0, time.UTC), chunks[1].start)
assert.Equal(t, time.Date(2025, 5, 2, 0, 0, 0, 0, time.UTC), chunks[1].end)
},
},
{
name: "multiple chunks for a full year",
start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), // 366 days (leap year)
expectedChunks: 4,
validateChunks: func(t *testing.T, chunks []dateChunk) {
// verify chunks are contiguous (each chunk starts where previous ended)
for i := 1; i < len(chunks); i++ {
assert.Equal(t, chunks[i-1].end, chunks[i].start, "chunks should be contiguous")
}
// verify first and last
assert.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), chunks[0].start)
assert.Equal(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), chunks[len(chunks)-1].end)
},
},
{
name: "empty result when start equals end",
start: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
expectedChunks: 0,
},
{
name: "empty result when start is after end",
start: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
expectedChunks: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chunks := buildDateChunks(tt.start, tt.end)
assert.Len(t, chunks, tt.expectedChunks)
if tt.validateChunks != nil && len(chunks) > 0 {
tt.validateChunks(t, chunks)
}
})
}
}
func TestBuildDateChunks_ChunkSizeNeverExceeds120Days(t *testing.T) {
// test with various date ranges to ensure no chunk exceeds 120 days
testCases := []struct {
start time.Time
end time.Time
}{
{time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)},
{time.Date(2023, 6, 15, 0, 0, 0, 0, time.UTC), time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)},
{time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2024, 12, 31, 0, 0, 0, 0, time.UTC)},
}
for _, tc := range testCases {
chunks := buildDateChunks(tc.start, tc.end)
for i, chunk := range chunks {
days := chunk.end.Sub(chunk.start).Hours() / 24
assert.LessOrEqual(t, days, float64(maxDateRangeDays),
"chunk %d exceeds max days: start=%s, end=%s, days=%.0f",
i, chunk.start.Format("2006-01-02"), chunk.end.Format("2006-01-02"), days)
}
}
}
// mockResponseBody creates an io.ReadCloser from a string for testing
type mockReadCloser struct {
io.Reader
closed bool
}
func (m *mockReadCloser) Close() error {
m.closed = true
return nil
}
func newMockResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Body: &mockReadCloser{Reader: strings.NewReader(body)},
}
}
func TestCheckHTTPStatus(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
expectError bool
errorContains string
}{
{
name: "200 OK returns nil",
statusCode: http.StatusOK,
body: `{"totalResults": 0}`,
expectError: false,
},
{
name: "404 returns error",
statusCode: http.StatusNotFound,
body: "",
expectError: true,
errorContains: "status 404",
},
{
name: "400 Bad Request returns error with body",
statusCode: http.StatusBadRequest,
body: "Both lastModStartDate and lastModEndDate are required",
expectError: true,
errorContains: "lastModStartDate and lastModEndDate are required",
},
{
name: "500 Internal Server Error returns error",
statusCode: http.StatusInternalServerError,
body: "Internal server error",
expectError: true,
errorContains: "status 500",
},
{
name: "503 Service Unavailable returns error",
statusCode: http.StatusServiceUnavailable,
body: "Service temporarily unavailable",
expectError: true,
errorContains: "status 503",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := newMockResponse(tt.statusCode, tt.body)
err := checkHTTPStatus(resp)
if tt.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errorContains)
// verify body was closed on error
assert.True(t, resp.Body.(*mockReadCloser).closed, "response body should be closed on error")
} else {
require.NoError(t, err)
// verify body was NOT closed on success (caller needs to read it)
assert.False(t, resp.Body.(*mockReadCloser).closed, "response body should not be closed on success")
}
})
}
}
func TestParseRetryAfter(t *testing.T) {
tests := []struct {
name string
header string
expected time.Duration
}{
{
name: "empty header returns 0",
header: "",
expected: 0,
},
{
name: "numeric seconds",
header: "30",
expected: 30 * time.Second,
},
{
name: "numeric seconds - larger value",
header: "120",
expected: 120 * time.Second,
},
{
name: "invalid value returns 0",
header: "not-a-number",
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseRetryAfter(tt.header)
assert.Equal(t, tt.expected, result)
})
}
}