From 2cc87140edfa5877021b85a39d8954532fe8cdbd Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Mon, 29 Jul 2024 15:37:58 +0530 Subject: [PATCH] Add basic endpoint to unblock testing --- server/cmd/museum/main.go | 6 + server/ente/fileobjects/type.go | 45 +++++ .../migrations/89_derived_data_table.up.sql | 3 +- server/pkg/api/file.go | 46 +++++ server/pkg/api/file_preview.go | 1 + server/pkg/controller/file.go | 2 + server/pkg/controller/file_preview.go | 156 ++++++++++++++++ server/pkg/controller/filedata/controller.go | 4 + server/pkg/controller/filedata/file_object.go | 5 + server/pkg/controller/preview/controller.go | 8 + server/pkg/repo/filedata/repository.go | 172 ++++++++++++++++++ 11 files changed, 446 insertions(+), 2 deletions(-) create mode 100644 server/ente/fileobjects/type.go create mode 100644 server/pkg/api/file_preview.go create mode 100644 server/pkg/controller/file_preview.go create mode 100644 server/pkg/controller/filedata/controller.go create mode 100644 server/pkg/controller/filedata/file_object.go create mode 100644 server/pkg/controller/preview/controller.go create mode 100644 server/pkg/repo/filedata/repository.go diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 9258fa9b77..3fb5359289 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -408,6 +408,12 @@ func main() { privateAPI.GET("/files/download/v2/:fileID", fileHandler.Get) privateAPI.GET("/files/preview/:fileID", fileHandler.GetThumbnail) privateAPI.GET("/files/preview/v2/:fileID", fileHandler.GetThumbnail) + + privateAPI.GET("/files/file-data/playlist/:fileID", fileHandler.GetVideoPlaylist) + privateAPI.POST("/files/file-data/playlist", fileHandler.ReportVideoPlayList) + privateAPI.GET("/files/file-data/preview/upload-url/:fileID", fileHandler.GetVideoUploadURL) + privateAPI.GET("/files/file-data/preview/:fileID", fileHandler.GetVideoPreviewUrl) + privateAPI.POST("/files", fileHandler.CreateOrUpdate) 
privateAPI.POST("/files/copy", fileHandler.CopyFiles) privateAPI.PUT("/files/update", fileHandler.Update) diff --git a/server/ente/fileobjects/type.go b/server/ente/fileobjects/type.go new file mode 100644 index 0000000000..8c5ce65919 --- /dev/null +++ b/server/ente/fileobjects/type.go @@ -0,0 +1,45 @@ +package fileobjects + +import ( + "database/sql/driver" + "errors" + "fmt" +) + +type Type string + +const ( + OriginalFile Type = "file" + OriginalThumbnail Type = "thumb" + PreviewImage Type = "previewImage" + PreviewVideo Type = "previewVideo" + Derived Type = "derived" +) + +func (ft Type) IsValid() bool { + switch ft { + case OriginalFile, OriginalThumbnail, PreviewImage, PreviewVideo, Derived: + return true + } + return false +} + +func (ft *Type) Scan(value interface{}) error { + strValue, ok := value.(string) + if !ok { + return errors.New("type should be a string") + } + + *ft = Type(strValue) + if !ft.IsValid() { + return fmt.Errorf("invalid FileType value: %s", strValue) + } + return nil +} + +func (ft Type) Value() (driver.Value, error) { + if !ft.IsValid() { + return nil, fmt.Errorf("invalid FileType value: %s", ft) + } + return string(ft), nil +} diff --git a/server/migrations/89_derived_data_table.up.sql b/server/migrations/89_derived_data_table.up.sql index 4ba869c71f..548af5f12a 100644 --- a/server/migrations/89_derived_data_table.up.sql +++ b/server/migrations/89_derived_data_table.up.sql @@ -12,14 +12,13 @@ CREATE TABLE file_data ( -- following field contains list of buckets from where we need to delete the data as the given data_type will not longer be persisted in that dc delete_from_buckets s3region[] NOT NULL DEFAULT '{}', pending_sync BOOLEAN NOT NULL DEFAULT false, + is_deleted BOOLEAN NOT NULL DEFAULT false, last_sync_time BIGINT NOT NULL DEFAULT 0, created_at BIGINT NOT NULL DEFAULT now_utc_micro_seconds(), updated_at BIGINT NOT NULL DEFAULT now_utc_micro_seconds(), PRIMARY KEY (file_id, data_type) ); --- Add primary key -ALTER TABLE 
file_data ADD PRIMARY KEY (file_id, data_type); -- Add index for user_id and data_type CREATE INDEX idx_file_data_user_id_data_type ON file_data (user_id, data_type); diff --git a/server/pkg/api/file.go b/server/pkg/api/file.go index 064bc3be08..94760049af 100644 --- a/server/pkg/api/file.go +++ b/server/pkg/api/file.go @@ -120,6 +120,52 @@ func (h *FileHandler) GetUploadURLs(c *gin.Context) { }) } +func (h *FileHandler) GetVideoUploadURL(c *gin.Context) { + enteApp := auth.GetApp(c) + userID, fileID := getUserAndFileIDs(c) + urls, err := h.Controller.GetVideoUploadUrl(c, userID, fileID, enteApp) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, urls) +} + +func (h *FileHandler) GetVideoPreviewUrl(c *gin.Context) { + userID, fileID := getUserAndFileIDs(c) + url, err := h.Controller.GetPreviewUrl(c, userID, fileID) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.Redirect(http.StatusTemporaryRedirect, url) +} + +func (h *FileHandler) ReportVideoPlayList(c *gin.Context) { + var request ente.InsertOrUpdateEmbeddingRequest + if err := c.ShouldBindJSON(&request); err != nil { + handler.Error(c, + stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err))) + return + } + err := h.Controller.ReportVideoPreview(c, request) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.Status(http.StatusOK) +} + +func (h *FileHandler) GetVideoPlaylist(c *gin.Context) { + fileID, _ := strconv.ParseInt(c.Param("fileID"), 10, 64) + response, err := h.Controller.GetPlaylist(c, fileID) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, response) +} + // GetMultipartUploadURLs returns an array of PartUpload PresignedURLs func (h *FileHandler) GetMultipartUploadURLs(c *gin.Context) { enteApp := auth.GetApp(c) diff --git a/server/pkg/api/file_preview.go b/server/pkg/api/file_preview.go new 
file mode 100644 index 0000000000..778f64ec17 --- /dev/null +++ b/server/pkg/api/file_preview.go @@ -0,0 +1 @@ +package api diff --git a/server/pkg/controller/file.go b/server/pkg/controller/file.go index 5bb4f47bf9..3d51b6f2db 100644 --- a/server/pkg/controller/file.go +++ b/server/pkg/controller/file.go @@ -284,6 +284,8 @@ func (c *FileController) GetUploadURLs(ctx context.Context, userID int64, count return urls, nil } + + // GetFileURL verifies permissions and returns a presigned url to the requested file func (c *FileController) GetFileURL(ctx *gin.Context, userID int64, fileID int64) (string, error) { err := c.verifyFileAccess(userID, fileID) diff --git a/server/pkg/controller/file_preview.go b/server/pkg/controller/file_preview.go new file mode 100644 index 0000000000..1aa3bfb1dc --- /dev/null +++ b/server/pkg/controller/file_preview.go @@ -0,0 +1,156 @@ +package controller + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/ente-io/museum/ente" + "github.com/ente-io/museum/pkg/utils/auth" + "github.com/ente-io/museum/pkg/utils/network" + "github.com/ente-io/stacktrace" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "strconv" + "strings" +) + +const ( + _model = "hls_video" +) + +// GetVideoUploadUrl returns a presigned URL for uploading the video preview (hls_video) object of the given file to derived storage +func (c *FileController) GetVideoUploadUrl(ctx context.Context, userID int64, fileID int64, app ente.App) (*ente.UploadURL, error) { + err := c.UsageCtrl.CanUploadFile(ctx, userID, nil, app) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + s3Client := c.S3Config.GetDerivedStorageS3Client() + dc := c.S3Config.GetDerivedStorageDataCenter() + bucket := c.S3Config.GetDerivedStorageBucket() + objectKey := strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/" + _model + url, err := c.getObjectURL(s3Client, 
dc, bucket, objectKey) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + log.Infof("Got upload URL for %s", objectKey) + return &url, nil +} + +func (c *FileController) GetPreviewUrl(ctx context.Context, userID int64, fileID int64) (string, error) { + err := c.verifyFileAccess(userID, fileID) + if err != nil { + return "", err + } + objectKey := strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/hls_video" + s3Client := c.S3Config.GetDerivedStorageS3Client() + r, _ := s3Client.GetObjectRequest(&s3.GetObjectInput{ + Bucket: c.S3Config.GetDerivedStorageBucket(), + Key: &objectKey, + }) + return r.Presign(PreSignedRequestValidityDuration) +} + +func (c *FileController) GetPlaylist(ctx *gin.Context, fileID int64) (ente.EmbeddingObject, error) { + objectKey := strconv.FormatInt(auth.GetUserID(ctx.Request.Header), 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/hls_video_playlist.m3u8" + // check if object exists + err := c.checkObjectExists(ctx, objectKey, c.S3Config.GetDerivedStorageDataCenter()) + if err != nil { + return ente.EmbeddingObject{}, stacktrace.Propagate(ente.NewBadRequestWithMessage("Video playlist does not exist"), fmt.Sprintf("objectKey: %s", objectKey)) + } + return c.downloadObject(ctx, objectKey, c.S3Config.GetDerivedStorageDataCenter()) +} + +func (c *FileController) ReportVideoPreview(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) error { + userID := auth.GetUserID(ctx.Request.Header) + if strings.Compare(req.Model, "hls_video") != 0 { + return stacktrace.Propagate(ente.NewBadRequestWithMessage("Model should be hls_video"), "Invalid model") + } + count, err := c.CollectionRepo.GetCollectionCount(req.FileID) + if err != nil { + return stacktrace.Propagate(err, "") + } + if count < 1 { + return stacktrace.Propagate(ente.ErrNotFound, "") + } + version := 1 + if req.Version != nil { + version = *req.Version + } + objectKey := strconv.FormatInt(userID, 10) + "/ml-data/" + 
strconv.FormatInt(req.FileID, 10) + "/hls_video" + playlistKey := objectKey + "_playlist.m3u8" + + // verify that objectKey exists + err = c.checkObjectExists(ctx, objectKey, c.S3Config.GetDerivedStorageDataCenter()) + if err != nil { + return stacktrace.Propagate(ente.NewBadRequestWithMessage("Video object does not exist, upload that before playlist reporting"), fmt.Sprintf("objectKey: %s", objectKey)) + } + + obj := ente.EmbeddingObject{ + Version: version, + EncryptedEmbedding: req.EncryptedEmbedding, + DecryptionHeader: req.DecryptionHeader, + Client: network.GetClientInfo(ctx), + } + _, uploadErr := c.uploadObject(obj, playlistKey, c.S3Config.GetDerivedStorageDataCenter()) + if uploadErr != nil { + log.Error(uploadErr) + return stacktrace.Propagate(uploadErr, "") + } + return nil +} + +func (c *FileController) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) { + embeddingObj, _ := json.Marshal(obj) + s3Client := c.S3Config.GetS3Client(dc) + s3Bucket := c.S3Config.GetBucket(dc) + uploader := s3manager.NewUploaderWithClient(&s3Client) + up := s3manager.UploadInput{ + Bucket: s3Bucket, + Key: &key, + Body: bytes.NewReader(embeddingObj), + } + result, err := uploader.Upload(&up) + if err != nil { + log.Error(err) + return -1, stacktrace.Propagate(err, "") + } + + log.Infof("Uploaded to bucket %s", result.Location) + return len(embeddingObj), nil +} + +func (c *FileController) downloadObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) { + var obj ente.EmbeddingObject + buff := &aws.WriteAtBuffer{} + bucket := c.S3Config.GetBucket(dc) + s3Client := c.S3Config.GetS3Client(dc) + downloader := s3manager.NewDownloaderWithClient(&s3Client) + _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ + Bucket: bucket, + Key: &objectKey, + }) + if err != nil { + return obj, err + } + err = json.Unmarshal(buff.Bytes(), &obj) + if err != nil { + return obj, stacktrace.Propagate(err, "unmarshal 
failed") + } + return obj, nil +} + +func (c *FileController) checkObjectExists(ctx context.Context, objectKey string, dc string) error { + s3Client := c.S3Config.GetS3Client(dc) + _, err := s3Client.HeadObject(&s3.HeadObjectInput{ + Bucket: c.S3Config.GetBucket(dc), + Key: &objectKey, + }) + if err != nil { + return err + } + return nil +} diff --git a/server/pkg/controller/filedata/controller.go b/server/pkg/controller/filedata/controller.go new file mode 100644 index 0000000000..cb86174825 --- /dev/null +++ b/server/pkg/controller/filedata/controller.go @@ -0,0 +1,4 @@ +package filedata + +type Controller struct { +} diff --git a/server/pkg/controller/filedata/file_object.go b/server/pkg/controller/filedata/file_object.go new file mode 100644 index 0000000000..aa627fb82c --- /dev/null +++ b/server/pkg/controller/filedata/file_object.go @@ -0,0 +1,5 @@ +package filedata + +func (c *Controller) f() { + +} diff --git a/server/pkg/controller/preview/controller.go b/server/pkg/controller/preview/controller.go new file mode 100644 index 0000000000..e5ecaa71f0 --- /dev/null +++ b/server/pkg/controller/preview/controller.go @@ -0,0 +1,8 @@ +package preview + +type Controller struct { +} + +func NewController() *Controller { + return &Controller{} +} diff --git a/server/pkg/repo/filedata/repository.go b/server/pkg/repo/filedata/repository.go new file mode 100644 index 0000000000..b06ddcfded --- /dev/null +++ b/server/pkg/repo/filedata/repository.go @@ -0,0 +1,172 @@ +package filedata + +import ( + "context" + "database/sql" + "github.com/ente-io/stacktrace" + "github.com/lib/pq" + "github.com/pkg/errors" +) + +// FileData represents the structure of the file_data table. 
+type FileData struct { + FileID int64 + UserID int64 + DataType string + Size int64 + LatestBucket string + ReplicatedBuckets []string + DeleteFromBuckets []string + PendingSync bool + IsDeleted bool + LastSyncTime int64 + CreatedAt int64 + UpdatedAt int64 +} + +// Repository defines the methods for inserting, updating, and retrieving file data. +type Repository struct { + DB *sql.DB +} + +// Insert inserts a new file_data record +func (r *Repository) Insert(ctx context.Context, data FileData) error { + query := ` + INSERT INTO file_data + (file_id, user_id, data_type, size, latest_bucket, replicated_buckets) + VALUES + ($1, $2, $3, $4, $5, $6) + ON CONFLICT (file_id, data_type) + DO UPDATE SET + size = $4, + latest_bucket = $5, + replicated_buckets = $6 ` + _, err := r.DB.ExecContext(ctx, query, + data.FileID, data.UserID, data.DataType, data.Size, data.LatestBucket, pq.Array(data.ReplicatedBuckets)) + if err != nil { + return stacktrace.Propagate(err, "failed to insert file data") + } + return nil +} + +// UpdateReplicatedBuckets updates the replicated_buckets for a given file and data type. +func (r *Repository) UpdateReplicatedBuckets(ctx context.Context, fileID int64, dataType string, newBuckets []string, previousUpdatedAt int64) error { + query := ` + UPDATE file_data + SET replicated_buckets = $1, updated_at = now_utc_micro_seconds() + WHERE file_id = $2 AND data_type = $3 AND updated_at = $4` + res, err := r.DB.ExecContext(ctx, query, pq.Array(newBuckets), fileID, dataType, previousUpdatedAt) + if err != nil { + return errors.Wrap(err, "failed to update replicated buckets") + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return errors.Wrap(err, "failed to check rows affected") + } + if rowsAffected == 0 { + return errors.New("no rows were updated, possible concurrent modification") + } + return nil +} + +// UpdateDeleteFromBuckets updates the delete_from_buckets for a given file and data type. 
+func (r *Repository) UpdateDeleteFromBuckets(ctx context.Context, fileID int64, dataType string, newBuckets []string, previousUpdatedAt int64) error { + query := ` + UPDATE file_data + SET delete_from_buckets = $1, updated_at = now_utc_micro_seconds() + WHERE file_id = $2 AND data_type = $3 AND updated_at = $4` + res, err := r.DB.ExecContext(ctx, query, pq.Array(newBuckets), fileID, dataType, previousUpdatedAt) + if err != nil { + return errors.Wrap(err, "failed to update delete from buckets") + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return errors.Wrap(err, "failed to check rows affected") + } + if rowsAffected == 0 { + return errors.New("no rows were updated, possible concurrent modification") + } + return nil +} + +// DeleteFileData deletes a file_data record by file_id and data_type if both replicated_buckets and delete_from_buckets are empty. +func (r *Repository) DeleteFileData(ctx context.Context, fileID int64, dataType string, previousUpdatedAt int64) error { + // First, check if both replicated_buckets and delete_from_buckets are empty. + var replicatedBuckets, deleteFromBuckets []string + query := `SELECT replicated_buckets, delete_from_buckets FROM file_data WHERE file_id = $1 AND data_type = $2` + err := r.DB.QueryRowContext(ctx, query, fileID, dataType).Scan(pq.Array(&replicatedBuckets), pq.Array(&deleteFromBuckets)) + if err != nil { + if err == sql.ErrNoRows { + return errors.New("no file data found for the given file_id and data_type") + } + return errors.Wrap(err, "failed to check buckets before deleting file data") + } + + if len(replicatedBuckets) > 0 || len(deleteFromBuckets) > 0 { + return errors.New("cannot delete file data with non-empty replicated_buckets or delete_from_buckets") + } + + // Proceed with deletion if both arrays are empty and updated_at matches. 
+ deleteQuery := `DELETE FROM file_data WHERE file_id = $1 AND data_type = $2 AND updated_at = $3` + res, err := r.DB.ExecContext(ctx, deleteQuery, fileID, dataType, previousUpdatedAt) + if err != nil { + return errors.Wrap(err, "failed to delete file data") + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return errors.Wrap(err, "failed to check rows affected") + } + if rowsAffected == 0 { + return errors.New("no rows were deleted, possible concurrent modification") + } + return nil +} + +// GetFileData retrieves a single file_data record by file_id and data_type. +func (r *Repository) GetFileData(ctx context.Context, fileID int64, dataType string) (FileData, error) { + var data FileData + query := `SELECT file_id, user_id, data_type, size, latest_bucket, replicated_buckets, delete_from_buckets, pending_sync, is_deleted, last_sync_time, created_at, updated_at + FROM file_data + WHERE file_id = $1 AND data_type = $2` + err := r.DB.QueryRowContext(ctx, query, fileID, dataType).Scan( + &data.FileID, &data.UserID, &data.DataType, &data.Size, &data.LatestBucket, pq.Array(&data.ReplicatedBuckets), pq.Array(&data.DeleteFromBuckets), &data.PendingSync, &data.IsDeleted, &data.LastSyncTime, &data.CreatedAt, &data.UpdatedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return FileData{}, errors.Wrap(err, "no file data found") + } + return FileData{}, errors.Wrap(err, "failed to retrieve file data") + } + return data, nil +} + +// ListFileData retrieves all file_data records for a given user_id. 
+func (r *Repository) ListFileData(ctx context.Context, userID int64) ([]FileData, error) { + query := `SELECT file_id, user_id, data_type, size, latest_bucket, replicated_buckets, delete_from_buckets, pending_sync, is_deleted, last_sync_time, created_at, updated_at + FROM file_data + WHERE user_id = $1` + rows, err := r.DB.QueryContext(ctx, query, userID) + if err != nil { + return nil, errors.Wrap(err, "failed to list file data") + } + defer rows.Close() + + var fileDataList []FileData + for rows.Next() { + var data FileData + err := rows.Scan( + &data.FileID, &data.UserID, &data.DataType, &data.Size, &data.LatestBucket, pq.Array(&data.ReplicatedBuckets), pq.Array(&data.DeleteFromBuckets), &data.PendingSync, &data.IsDeleted, &data.LastSyncTime, &data.CreatedAt, &data.UpdatedAt, + ) + if err != nil { + return nil, errors.Wrap(err, "failed to scan file data row") + } + fileDataList = append(fileDataList, data) + } + if err = rows.Err(); err != nil { + return nil, errors.Wrap(err, "error iterating file data rows") + } + return fileDataList, nil +}