diff --git a/file/file.go b/file/file.go index 5ac0dbb..c24dc10 100644 --- a/file/file.go +++ b/file/file.go @@ -18,6 +18,7 @@ import ( "github.com/asaskevich/govalidator" "github.com/pkg/errors" + sliceutil "github.com/projectdiscovery/utils/slice" stringsutil "github.com/projectdiscovery/utils/strings" "gopkg.in/yaml.v3" ) @@ -34,6 +35,36 @@ func FileExists(filename string) bool { return !info.IsDir() } +// FileExistsIn checks if the file exists in the allowed paths +func FileExistsIn(file string, allowedPaths ...string) (string, error) { + fileAbsPath, err := filepath.Abs(file) + if err != nil { + return "", err + } + + uniqAllowedPaths := sliceutil.Dedupe(allowedPaths) + + for _, allowedPath := range uniqAllowedPaths { + allowedAbsPath, err := filepath.Abs(allowedPath) + if err != nil { + return "", err + } + // reject any path that for some reason was cleaned up and starts with . + if stringsutil.HasPrefixAny(allowedAbsPath, ".") { + return "", errors.New("invalid path") + } + + allowedDirPath := allowedAbsPath + if filepath.Ext(allowedAbsPath) != "" { + allowedDirPath = filepath.Dir(allowedAbsPath) + } + if strings.HasPrefix(fileAbsPath, allowedDirPath) && FileExists(fileAbsPath) { + return allowedDirPath, nil + } + } + return "", errors.New("no allowed path found") +} + // FolderExists checks if the folder exists func FolderExists(foldername string) bool { info, err := os.Stat(foldername) diff --git a/file/file_test.go b/file/file_test.go index 9f9d07c..53b9834 100644 --- a/file/file_test.go +++ b/file/file_test.go @@ -581,3 +581,54 @@ func TestOpenOrCreateFile(t *testing.T) { require.Error(t, err) }) } + +func TestFileExistsIn(t *testing.T) { + tempDir := t.TempDir() + anotherTempDir := t.TempDir() + tempFile := filepath.Join(tempDir, "file.txt") + err := os.WriteFile(tempFile, []byte("content"), 0644) + if err != nil { + t.Fatalf("failed to write to temporary file: %v", err) + } + defer os.RemoveAll(tempFile) + + tests := []struct { + name string + file string + allowedFiles []string + expectedPath string + expectedErr bool + }{ + { + name: "file exists in allowed directory", + file: tempFile, + allowedFiles: []string{filepath.Join(tempDir, "tempfile.txt")}, + expectedPath: tempDir, + expectedErr: false, + }, + { + name: "file does not exist in allowed directory", + file: tempFile, + allowedFiles: []string{anotherTempDir}, + expectedPath: "", + expectedErr: true, + }, + { + name: "path starting with .", + file: tempFile, + allowedFiles: []string{"."}, + expectedPath: "", + expectedErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + allowedPath, err := FileExistsIn(tc.file, tc.allowedFiles...) + gotErr := err != nil + require.Equal(t, tc.expectedErr, gotErr, "expected err but got %v", gotErr) + require.Equal(t, tc.expectedPath, allowedPath) + + }) + } +}