diff --git a/internal/storage/cp/cp.go b/internal/storage/cp/cp.go index e55120a1c..d1912b2c9 100644 --- a/internal/storage/cp/cp.go +++ b/internal/storage/cp/cp.go @@ -52,7 +52,7 @@ func Run(ctx context.Context, src, dst string, recursive bool, maxJobs uint, fsy if recursive { return UploadStorageObjectAll(ctx, api, dstParsed.Path, localPath, maxJobs, fsys, opts...) } - return api.UploadObject(ctx, dstParsed.Path, src, fsys, opts...) + return api.UploadObject(ctx, dstParsed.Path, src, afero.NewIOFS(fsys), opts...) } else if strings.EqualFold(srcParsed.Scheme, client.STORAGE_SCHEME) && strings.EqualFold(dstParsed.Scheme, client.STORAGE_SCHEME) { return errors.New("Copying between buckets is not supported") } @@ -149,7 +149,7 @@ func UploadStorageObjectAll(ctx context.Context, api storage.StorageAPI, remoteP } fmt.Fprintln(os.Stderr, "Uploading:", filePath, "=>", dstPath) job := func() error { - err := api.UploadObject(ctx, dstPath, filePath, fsys, opts...) + err := api.UploadObject(ctx, dstPath, filePath, afero.NewIOFS(fsys), opts...) if err != nil && strings.Contains(err.Error(), `"error":"Bucket not found"`) { // Retry after creating bucket if bucket, prefix := client.SplitBucketPrefix(dstPath); len(prefix) > 0 { @@ -162,7 +162,7 @@ func UploadStorageObjectAll(ctx context.Context, api storage.StorageAPI, remoteP if _, err := api.CreateBucket(ctx, body); err != nil { return err } - err = api.UploadObject(ctx, dstPath, filePath, fsys, opts...) + err = api.UploadObject(ctx, dstPath, filePath, afero.NewIOFS(fsys), opts...) } } return err diff --git a/pkg/storage/batch.go b/pkg/storage/batch.go index 18a7d933e..63638e9a9 100644 --- a/pkg/storage/batch.go +++ b/pkg/storage/batch.go @@ -91,17 +91,9 @@ func (s *StorageAPI) UpsertObjects(ctx context.Context, bucketConfig config.Buck } fmt.Fprintln(os.Stderr, "Uploading:", filePath, "=>", dstPath) job := func() error { - f, err := fsys.Open(filePath) - if err != nil { - return errors.Errorf("failed to open file: %w", err) - } - defer f.Close() - fo, err := ParseFileOptions(f, filePath) - if err != nil { - return err - } - fo.Overwrite = true - return s.UploadObjectStream(ctx, dstPath, f, *fo) + return s.UploadObject(ctx, dstPath, filePath, fsys, func(fo *FileOptions) { + fo.Overwrite = true + }) } return jq.Put(job) } diff --git a/pkg/storage/objects.go b/pkg/storage/objects.go index 170e6862b..a60fdac55 100644 --- a/pkg/storage/objects.go +++ b/pkg/storage/objects.go @@ -86,18 +86,11 @@ func ParseFileOptions(f fs.File, localPath string, opts ...func(*FileOptions)) ( } else if _, err = s.Seek(0, io.SeekStart); err != nil { return nil, errors.Errorf("failed to seek file: %w", err) } - // For text/plain content types, we try to determine a more specific type - // based on the file extension, as the initial detection might be too generic - if strings.Contains(fo.ContentType, "text/plain") { - if extensionType := mime.TypeByExtension(filepath.Ext(localPath)); extensionType != "" { - fo.ContentType = extensionType - } - } } return fo, nil } -func (s *StorageAPI) UploadObject(ctx context.Context, remotePath, localPath string, fsys afero.Fs, opts ...func(*FileOptions)) error { +func (s *StorageAPI) UploadObject(ctx context.Context, remotePath, localPath string, fsys fs.FS, opts ...func(*FileOptions)) error { f, err := fsys.Open(localPath) if err != nil { return errors.Errorf("failed to open file: %w", err) @@ -107,6 +100,13 @@ func (s *StorageAPI) UploadObject(ctx context.Context, remotePath, localPath str if err != nil { return err } + // For text/plain content types, we try to determine a more specific type + // based on the file extension, as the initial detection might be too generic + if strings.Contains(fo.ContentType, "text/plain") { + if extensionType := mime.TypeByExtension(filepath.Ext(localPath)); extensionType != "" { + fo.ContentType = extensionType + } + } return s.UploadObjectStream(ctx, remotePath, f, *fo) } diff --git a/pkg/storage/objects_test.go b/pkg/storage/objects_test.go index f7d178996..e5a3bd752 100644 --- a/pkg/storage/objects_test.go +++ b/pkg/storage/objects_test.go @@ -1,14 +1,22 @@ package storage import ( + "context" "mime" + "net/http" "testing" + fs "testing/fstest" - "github.com/spf13/afero" + "github.com/h2non/gock" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/supabase/cli/internal/testing/apitest" + "github.com/supabase/cli/pkg/fetcher" ) +var mockApi = StorageAPI{Fetcher: fetcher.NewFetcher( + "http://127.0.0.1", +)} + func TestParseFileOptionsContentTypeDetection(t *testing.T) { tests := []struct { name string @@ -56,7 +64,7 @@ func TestParseFileOptionsContentTypeDetection(t *testing.T) { { name: "respects custom content type", content: []byte("const hello = () => console.log('Hello, World!');"), - filename: "script.js", + filename: "custom.js", wantMimeType: "application/custom", wantCacheCtrl: "max-age=3600", opts: []func(*FileOptions){func(fo *FileOptions) { fo.ContentType = "application/custom" }}, @@ -66,20 +74,20 @@ func TestParseFileOptionsContentTypeDetection(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Create a temporary file with test content - fs := afero.NewMemMapFs() - require.NoError(t, afero.WriteFile(fs, tt.filename, tt.content, 0644)) - - f, err := fs.Open(tt.filename) - require.NoError(t, err) - defer f.Close() - + fsys := fs.MapFS{tt.filename: &fs.MapFile{Data: tt.content}} + // Setup mock api + defer gock.OffAll() + gock.New("http://127.0.0.1"). + Post("/storage/v1/object/"+tt.filename). + MatchHeader("Content-Type", tt.wantMimeType). + MatchHeader("Cache-Control", tt.wantCacheCtrl). + Reply(http.StatusOK) // Parse options - fo, err := ParseFileOptions(f, tt.filename, tt.opts...) - require.NoError(t, err) - + // err := mockApi.UploadObject(context.Background(), tt.filename, tt.filename, fsys, tt.opts...) + err := mockApi.UploadObject(context.Background(), tt.filename, tt.filename, fsys) // Assert results - assert.Equal(t, tt.wantMimeType, fo.ContentType) - assert.Equal(t, tt.wantCacheCtrl, fo.CacheControl) + assert.NoError(t, err) + assert.Empty(t, apitest.ListUnmatchedRequests()) }) } }