Skip to content

Commit

Permalink
[server] Add timeout while fetching embedding (#1676)
Browse files (browse the repository at this point in the history)
## Description

## Tests
  • Loading branch information
vishnukvmd authored May 10, 2024
2 parents 5caa9c5 + 3a70dcd commit 32e8a44
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions server/pkg/controller/embedding/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package embedding

import (
"bytes"
+ "context"
"encoding/json"
"errors"
"fmt"
"github.com/ente-io/museum/pkg/utils/array"
"strconv"
"sync"
+ gTime "time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
Expand Down Expand Up @@ -306,7 +308,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
defer wg.Done()
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore

- obj, err := c.getEmbeddingObject(objectKey, downloader)
+ obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
if err != nil {
errs = append(errs, err)
log.Error("error fetching embedding object: "+objectKey, err)
Expand Down Expand Up @@ -343,7 +345,9 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
defer wg.Done()
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
- obj, err := c.getEmbeddingObject(objectKey, downloader)
+ ctx, cancel := context.WithTimeout(context.Background(), gTime.Second*10) // 10 seconds timeout
+ defer cancel()
+ obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0)
if err != nil {
log.Error("error fetching embedding object: "+objectKey, err)
embeddingObjects[i] = embeddingObjectResult{
Expand All @@ -363,21 +367,21 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
return embeddingObjects, nil
}

- func (c *Controller) getEmbeddingObject(objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
- return c.getEmbeddingObjectWithRetries(objectKey, downloader, 3)
+ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
+ return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
}

- func (c *Controller) getEmbeddingObjectWithRetries(objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
+ func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
var obj ente.EmbeddingObject
buff := &aws.WriteAtBuffer{}
- _, err := downloader.Download(buff, &s3.GetObjectInput{
+ _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
Bucket: c.S3Config.GetHotBucket(),
Key: &objectKey,
})
if err != nil {
log.Error(err)
if retryCount > 0 {
- return c.getEmbeddingObjectWithRetries(objectKey, downloader, retryCount-1)
+ return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1)
}
return obj, stacktrace.Propagate(err, "")
}
Expand Down

0 comments on commit 32e8a44

Please sign in to comment.