fix: protect against traversal in file source

Signed-off-by: Christopher Phillips <32073428+spiffcs@users.noreply.github.com>
This commit is contained in:
Christopher Phillips 2025-10-13 12:11:59 -04:00
parent 69baca8804
commit 3f14eb7eaf
No known key found for this signature in database
4 changed files with 96 additions and 5 deletions

View File

@ -152,7 +152,7 @@ func ContentsFromZip(ctx context.Context, archivePath string, paths ...string) (
// UnzipToDir extracts a zip archive to a target directory.
func UnzipToDir(ctx context.Context, archivePath, targetDir string) error {
visitor := func(_ context.Context, file archives.FileInfo) error {
joinedPath, err := safeJoin(targetDir, file.NameInArchive)
joinedPath, err := SafeJoin(targetDir, file.NameInArchive)
if err != nil {
return err
}
@ -163,8 +163,8 @@ func UnzipToDir(ctx context.Context, archivePath, targetDir string) error {
return TraverseFilesInZip(ctx, archivePath, visitor)
}
// safeJoin ensures that any destinations do not resolve to a path above the prefix path.
func safeJoin(prefix string, dest ...string) (string, error) {
// SafeJoin ensures that any destinations do not resolve to a path above the prefix path.
func SafeJoin(prefix string, dest ...string) (string, error) {
joinResult := filepath.Join(append([]string{prefix}, dest...)...)
cleanJoinResult := filepath.Clean(joinResult)
if !strings.HasPrefix(cleanJoinResult, filepath.Clean(prefix)) {

View File

@ -308,7 +308,7 @@ func TestSafeJoin(t *testing.T) {
for _, test := range tests {
t.Run(fmt.Sprintf("%+v:%+v", test.prefix, test.args), func(t *testing.T) {
actual, err := safeJoin(test.prefix, test.args...)
actual, err := SafeJoin(test.prefix, test.args...)
test.errAssertion(t, err)
assert.Equal(t, test.expected, actual)
})

View File

@ -264,7 +264,11 @@ func unarchiveToTmp(path string, unarchiver archives.Extractor) (string, func()
}
visitor := func(_ context.Context, file archives.FileInfo) error {
destPath := filepath.Join(tempDir, file.NameInArchive)
destPath, err := intFile.SafeJoin(tempDir, file.NameInArchive)
if err != nil {
return fmt.Errorf("unsafe path in archive (potential path traversal): %w", err)
}
if file.IsDir() {
return os.MkdirAll(destPath, file.Mode())
}

View File

@ -1,6 +1,7 @@
package filesource
import (
"archive/tar"
"io"
"os"
"os/exec"
@ -318,3 +319,89 @@ func Test_FileSource_ID(t *testing.T) {
})
}
}
func TestUnarchiveToTmp_PathTraversalProtection(t *testing.T) {
// This test verifies that malicious archives with path traversal attempts
// (e.g., ../../../etc/passwd) are properly blocked by SafeJoin
testutil.Chdir(t, "..") // run with source/test-fixtures
// Create a malicious tar archive with path traversal attempts
tempDir := t.TempDir()
maliciousArchive := filepath.Join(tempDir, "malicious.tar")
// Create a temporary directory with a file that we'll add to the archive
sourceDir := filepath.Join(tempDir, "source")
require.NoError(t, os.MkdirAll(sourceDir, 0755))
testFile := filepath.Join(sourceDir, "test.txt")
require.NoError(t, os.WriteFile(testFile, []byte("malicious content"), 0644))
// Create a malicious tar manually using Go's archive/tar
// This allows us to inject path traversal entries
archiveFile, err := os.Create(maliciousArchive)
require.NoError(t, err)
defer archiveFile.Close()
tw := tar.NewWriter(archiveFile)
defer tw.Close()
// Add a file with path traversal in its name
content := []byte("malicious content")
header := &tar.Header{
Name: "../../../tmp/malicious.txt",
Mode: 0644,
Size: int64(len(content)),
}
require.NoError(t, tw.WriteHeader(header))
_, err = tw.Write(content)
require.NoError(t, err)
require.NoError(t, tw.Close())
require.NoError(t, archiveFile.Close())
// Attempt to create a source from the malicious archive
// This should fail due to path traversal protection
cfg := Config{
Path: maliciousArchive,
SkipExtractArchive: false,
}
src, err := New(cfg)
// We expect an error containing "path traversal" or "unsafe path"
if err == nil {
if src != nil {
src.Close()
}
t.Fatal("expected error when extracting archive with path traversal, but got none")
}
// Verify the error message indicates path traversal was detected
assert.Contains(t, err.Error(), "path traversal",
"error should mention path traversal, got: %v", err)
}
func TestUnarchiveToTmp_LegitimateArchive(t *testing.T) {
// This test verifies that legitimate archives without path traversal work correctly
testutil.Chdir(t, "..") // run with source/test-fixtures
archivePath := setupArchiveTest(t, "test-fixtures/path-detected", false)
cfg := Config{
Path: archivePath,
SkipExtractArchive: false,
}
src, err := New(cfg)
require.NoError(t, err, "legitimate archive should extract without error")
require.NotNil(t, src)
t.Cleanup(func() {
require.NoError(t, src.Close())
})
// Verify we can access the resolver
res, err := src.FileResolver(source.SquashedScope)
require.NoError(t, err)
require.NotNil(t, res)
}