Skip to content

Commit

Permalink
Add basic endpoint to unblock testing
Browse files Browse the repository at this point in the history
  • Loading branch information
ua741 committed Jul 29, 2024
1 parent 950b2bb commit 2cc8714
Show file tree
Hide file tree
Showing 11 changed files with 446 additions and 2 deletions.
6 changes: 6 additions & 0 deletions server/cmd/museum/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,12 @@ func main() {
privateAPI.GET("/files/download/v2/:fileID", fileHandler.Get)
privateAPI.GET("/files/preview/:fileID", fileHandler.GetThumbnail)
privateAPI.GET("/files/preview/v2/:fileID", fileHandler.GetThumbnail)

privateAPI.GET("/files/file-data/playlist/:fileID", fileHandler.GetVideoPlaylist)
privateAPI.POST("/files/file-data/playlist", fileHandler.ReportVideoPlayList)
privateAPI.GET("/files/file-data/preview/upload-url/:fileID", fileHandler.GetVideoUploadURL)
privateAPI.GET("/files/file-data/preview/:fileID", fileHandler.GetVideoUploadURL)

privateAPI.POST("/files", fileHandler.CreateOrUpdate)
privateAPI.POST("/files/copy", fileHandler.CopyFiles)
privateAPI.PUT("/files/update", fileHandler.Update)
Expand Down
45 changes: 45 additions & 0 deletions server/ente/fileobjects/type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package fileobjects

import (
"database/sql/driver"
"errors"
"fmt"
)

type Type string

const (
OriginalFile Type = "file"
OriginalThumbnail Type = "thumb"
PreviewImage Type = "previewImage"
PreviewVideo Type = "previewVideo"
Derived Type = "derived"
)

func (ft Type) IsValid() bool {
switch ft {
case OriginalFile, OriginalThumbnail, PreviewImage, PreviewVideo, Derived:
return true
}
return false
}

func (ft *Type) Scan(value interface{}) error {
strValue, ok := value.(string)
if !ok {
return errors.New("type should be a string")
}

*ft = Type(strValue)
if !ft.IsValid() {
return fmt.Errorf("invalid FileType value: %s", strValue)
}
return nil
}

func (ft Type) Value() (driver.Value, error) {
if !ft.IsValid() {
return nil, fmt.Errorf("invalid FileType value: %s", ft)
}
return string(ft), nil
}
3 changes: 1 addition & 2 deletions server/migrations/89_derived_data_table.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ CREATE TABLE file_data (
-- following field contains list of buckets from where we need to delete the data as the given data_type will not longer be persisted in that dc
delete_from_buckets s3region[] NOT NULL DEFAULT '{}',
pending_sync BOOLEAN NOT NULL DEFAULT false,
is_deleted BOOLEAN NOT NULL DEFAULT false,
last_sync_time BIGINT NOT NULL DEFAULT 0,
created_at BIGINT NOT NULL DEFAULT now_utc_micro_seconds(),
updated_at BIGINT NOT NULL DEFAULT now_utc_micro_seconds(),
PRIMARY KEY (file_id, data_type)
);

-- Add primary key
ALTER TABLE file_data ADD PRIMARY KEY (file_id, data_type);

-- Add index for user_id and data_type
CREATE INDEX idx_file_data_user_id_data_type ON file_data (user_id, data_type);
Expand Down
46 changes: 46 additions & 0 deletions server/pkg/api/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,52 @@ func (h *FileHandler) GetUploadURLs(c *gin.Context) {
})
}

func (h *FileHandler) GetVideoUploadURL(c *gin.Context) {
enteApp := auth.GetApp(c)
userID, fileID := getUserAndFileIDs(c)
urls, err := h.Controller.GetVideoUploadUrl(c, userID, fileID, enteApp)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, urls)
}

func (h *FileHandler) GetVideoPreviewUrl(c *gin.Context) {
userID, fileID := getUserAndFileIDs(c)
url, err := h.Controller.GetPreviewUrl(c, userID, fileID)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.Redirect(http.StatusTemporaryRedirect, url)
}

func (h *FileHandler) ReportVideoPlayList(c *gin.Context) {
var request ente.InsertOrUpdateEmbeddingRequest
if err := c.ShouldBindJSON(&request); err != nil {
handler.Error(c,
stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err)))
return
}
err := h.Controller.ReportVideoPreview(c, request)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.Status(http.StatusOK)
}

func (h *FileHandler) GetVideoPlaylist(c *gin.Context) {
fileID, _ := strconv.ParseInt(c.Param("fileID"), 10, 64)
response, err := h.Controller.GetPlaylist(c, fileID)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, response)
}

// GetMultipartUploadURLs returns an array of PartUpload PresignedURLs
func (h *FileHandler) GetMultipartUploadURLs(c *gin.Context) {
enteApp := auth.GetApp(c)
Expand Down
1 change: 1 addition & 0 deletions server/pkg/api/file_preview.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package api
2 changes: 2 additions & 0 deletions server/pkg/controller/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ func (c *FileController) GetUploadURLs(ctx context.Context, userID int64, count
return urls, nil
}



// GetFileURL verifies permissions and returns a presigned url to the requested file
func (c *FileController) GetFileURL(ctx *gin.Context, userID int64, fileID int64) (string, error) {
err := c.verifyFileAccess(userID, fileID)
Expand Down
156 changes: 156 additions & 0 deletions server/pkg/controller/file_preview.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package controller

import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/utils/auth"
"github.com/ente-io/museum/pkg/utils/network"
"github.com/ente-io/stacktrace"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"strconv"
"strings"
)

const (
_model = "hls_video"
)

// GetUploadURLs returns a bunch of presigned URLs for uploading files
func (c *FileController) GetVideoUploadUrl(ctx context.Context, userID int64, fileID int64, app ente.App) (*ente.UploadURL, error) {
err := c.UsageCtrl.CanUploadFile(ctx, userID, nil, app)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
s3Client := c.S3Config.GetDerivedStorageS3Client()
dc := c.S3Config.GetDerivedStorageDataCenter()
bucket := c.S3Config.GetDerivedStorageBucket()
objectKey := strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/" + _model
url, err := c.getObjectURL(s3Client, dc, bucket, objectKey)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
log.Infof("Got upload URL for %s", objectKey)
return &url, nil
}

func (c *FileController) GetPreviewUrl(ctx context.Context, userID int64, fileID int64) (string, error) {
err := c.verifyFileAccess(userID, fileID)
if err != nil {
return "", err
}
objectKey := strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/hls_video"
s3Client := c.S3Config.GetDerivedStorageS3Client()
r, _ := s3Client.GetObjectRequest(&s3.GetObjectInput{
Bucket: c.S3Config.GetDerivedStorageBucket(),
Key: &objectKey,
})
return r.Presign(PreSignedRequestValidityDuration)
}

func (c *FileController) GetPlaylist(ctx *gin.Context, fileID int64) (ente.EmbeddingObject, error) {
objectKey := strconv.FormatInt(auth.GetUserID(ctx.Request.Header), 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/hls_video_playlist.m3u8"
// check if object exists
err := c.checkObjectExists(ctx, objectKey, c.S3Config.GetDerivedStorageDataCenter())
if err != nil {
return ente.EmbeddingObject{}, stacktrace.Propagate(ente.NewBadRequestWithMessage("Video playlist does not exist"), fmt.Sprintf("objectKey: %s", objectKey))
}
return c.downloadObject(ctx, objectKey, c.S3Config.GetDerivedStorageDataCenter())
}

func (c *FileController) ReportVideoPreview(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) error {
userID := auth.GetUserID(ctx.Request.Header)
if strings.Compare(req.Model, "hls_video") != 0 {
return stacktrace.Propagate(ente.NewBadRequestWithMessage("Model should be hls_video"), "Invalid fileID")
}
count, err := c.CollectionRepo.GetCollectionCount(req.FileID)
if err != nil {
return stacktrace.Propagate(err, "")
}
if count < 1 {
return stacktrace.Propagate(ente.ErrNotFound, "")
}
version := 1
if req.Version != nil {
version = *req.Version
}
objectKey := strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(req.FileID, 10) + "/hls_video"
playlistKey := objectKey + "_playlist.m3u8"

// verify that objectKey exists
err = c.checkObjectExists(ctx, objectKey, c.S3Config.GetDerivedStorageDataCenter())
if err != nil {
return stacktrace.Propagate(ente.NewBadRequestWithMessage("Video object does not exist, upload that before playlist reporting"), fmt.Sprintf("objectKey: %s", objectKey))
}

obj := ente.EmbeddingObject{
Version: version,
EncryptedEmbedding: req.EncryptedEmbedding,
DecryptionHeader: req.DecryptionHeader,
Client: network.GetClientInfo(ctx),
}
_, uploadErr := c.uploadObject(obj, playlistKey, c.S3Config.GetDerivedStorageDataCenter())
if uploadErr != nil {
log.Error(uploadErr)
return stacktrace.Propagate(uploadErr, "")
}
return nil
}

func (c *FileController) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
embeddingObj, _ := json.Marshal(obj)
s3Client := c.S3Config.GetS3Client(dc)
s3Bucket := c.S3Config.GetBucket(dc)
uploader := s3manager.NewUploaderWithClient(&s3Client)
up := s3manager.UploadInput{
Bucket: s3Bucket,
Key: &key,
Body: bytes.NewReader(embeddingObj),
}
result, err := uploader.Upload(&up)
if err != nil {
log.Error(err)
return -1, stacktrace.Propagate(err, "")
}

log.Infof("Uploaded to bucket %s", result.Location)
return len(embeddingObj), nil
}

func (c *FileController) downloadObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
var obj ente.EmbeddingObject
buff := &aws.WriteAtBuffer{}
bucket := c.S3Config.GetBucket(dc)
s3Client := c.S3Config.GetS3Client(dc)
downloader := s3manager.NewDownloaderWithClient(&s3Client)
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
Bucket: bucket,
Key: &objectKey,
})
if err != nil {
return obj, err
}
err = json.Unmarshal(buff.Bytes(), &obj)
if err != nil {
return obj, stacktrace.Propagate(err, "unmarshal failed")
}
return obj, nil
}

func (c *FileController) checkObjectExists(ctx context.Context, objectKey string, dc string) error {
s3Client := c.S3Config.GetS3Client(dc)
_, err := s3Client.HeadObject(&s3.HeadObjectInput{
Bucket: c.S3Config.GetBucket(dc),
Key: &objectKey,
})
if err != nil {
return err
}
return nil
}
4 changes: 4 additions & 0 deletions server/pkg/controller/filedata/controller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package filedata

type Controller struct {
}
5 changes: 5 additions & 0 deletions server/pkg/controller/filedata/file_object.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package filedata

func (c *Controller) f() {

}
8 changes: 8 additions & 0 deletions server/pkg/controller/preview/controller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package preview

type Controller struct {
}

func NewController() *Controller {
return &Controller{}
}
Loading

0 comments on commit 2cc8714

Please sign in to comment.