Skip to content

Commit

Permalink
Refactoring result construction for embeddings #48
Browse files: browse the repository at this point in the history
  • Loading branch information
MrCsabaToth committed Oct 18, 2024
1 parent 756669a commit 7cb84eb
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions functions/fn_impl/embed.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -76,20 +76,7 @@ def embed(req: https_fn.Request) -> https_fn.Response:
region = 'us-central1'
vertexai.init(project=project_id, location=region)

multi_modal_embedding_model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")

image = VMImage.load_from_file(image_path) if image_path else None
video = VMVideo.load_from_file(video_path) if video_path else None
dimension = 1408

embeddings = multi_modal_embedding_model.get_embeddings(
image=image,
video=video,
video_segment_config=VideoSegmentConfig(),
contextual_text=text,
dimension=dimension,
) if image or video else []

embeddings = []
if text:
# Multi-lingual text embedding
# The task type for embedding. Check the available tasks in the model's documentation.
Expand All @@ -99,17 +86,42 @@ def embed(req: https_fn.Request) -> https_fn.Response:
kwargs = dict(output_dimensionality=768)
try:
text_embeddings = multi_lingual_embedding_model.get_embeddings(inputs, **kwargs)
embeddings.extend(text_embeddings)
embeddings.append([embedding.values for embedding in text_embeddings])
except Exception as e:
client = google.cloud.logging.Client()
client.setup_logging()
logging.exception(e)
return embeddings, 500

embedding_values = [embedding.values for embedding in embeddings]
if image_path or video_path:
multi_modal_embedding_model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")

image = VMImage.load_from_file(image_path) if image_path else None
video = VMVideo.load_from_file(video_path) if video_path else None
dimension = 1408

try:
multi_modal_embeddings = multi_modal_embedding_model.get_embeddings(
image=image,
video=video,
video_segment_config=VideoSegmentConfig(),
contextual_text=text,
dimension=dimension,
)

if multi_modal_embeddings:
embeddings.append(multi_modal_embeddings.image_embedding if image else [])
if video and multi_modal_embeddings.video_embeddings:
for video_embedding in multi_modal_embeddings.video_embeddings:
embeddings.append(video_embedding.embedding)
except Exception as e:
client = google.cloud.logging.Client()
client.setup_logging()
logging.exception(e)
return embeddings, 500

return https_fn.Response(
json.dumps(dict(data=embedding_values)),
json.dumps(dict(data=embeddings)),
status=200,
content_type='application/json',
)

0 comments on commit 7cb84eb

Please sign in to comment.