Update image task related notebooks
PiperOrigin-RevId: 670080374
vertex-mg-bot authored and copybara-github committed Sep 2, 2024
1 parent 390eea8 commit 9df704f
Showing 4 changed files with 191 additions and 524 deletions.
@@ -89,6 +89,7 @@
{
"cell_type": "code",
"execution_count": null,
"language": "python",
"metadata": {
"cellView": "form",
"id": "QvQjsmIJ6Y3f"
@@ -111,7 +112,7 @@
"import re\n",
"import uuid\n",
"from datetime import datetime\n",
"from typing import List, Sequence, Tuple\n",
"from typing import Sequence, Tuple\n",
"\n",
"import gradio as gr\n",
"import matplotlib as mpl\n",
@@ -220,109 +221,6 @@
"}\n",
"\n",
"\n",
"def resize_image(image: Image.Image, new_width: int = 1000) -> Image.Image:\n",
" width, height = image.size\n",
" print(f\"original input image size: {width}, {height}\")\n",
" new_height = int(height * new_width / width)\n",
" new_img = image.resize((new_width, new_height))\n",
" print(f\"resized input image size: {new_width}, {new_height}\")\n",
" return new_img\n",
"\n",
"\n",
"def vqa_predict(\n",
" endpoint: aiplatform.Endpoint,\n",
" image: Image.Image,\n",
" prompts: List[str],\n",
" new_width: int = 1000,\n",
") -> List[str]:\n",
" \"\"\"Predicts the answer to a question about an image using an Endpoint.\"\"\"\n",
" # Resize and convert image to base64 string.\n",
" resized_image = resize_image(image, new_width)\n",
" resized_image_base64 = common_util.image_to_base64(resized_image)\n",
"\n",
" # Format question prompt\n",
" question_prompt_format = \"answer en {}\\n\"\n",
"\n",
" instances = []\n",
" for question_prompt in prompts:\n",
" if question_prompt:\n",
" instances.append(\n",
" {\n",
" \"prompt\": question_prompt_format.format(question_prompt),\n",
" \"image\": resized_image_base64,\n",
" }\n",
" )\n",
"\n",
" response = endpoint.predict(instances=instances)\n",
" return [pred.get(\"response\") for pred in response.predictions]\n",
"\n",
"\n",
"def caption_predict(\n",
" endpoint: aiplatform.Endpoint,\n",
" image: Image.Image = None,\n",
" language_code: str = \"en\",\n",
" new_width: int = 1000,\n",
") -> str:\n",
" \"\"\"Predicts a caption for a given image using an Endpoint.\"\"\"\n",
" # Resize and convert image to base64 string.\n",
" resized_image = resize_image(image, new_width)\n",
" resized_image_base64 = common_util.image_to_base64(resized_image)\n",
"\n",
" # Format caption prompt\n",
" caption_prompt = f\"caption {language_code}\\n\"\n",
"\n",
" instances = [\n",
" {\n",
" \"prompt\": caption_prompt,\n",
" \"image\": resized_image_base64,\n",
" },\n",
" ]\n",
" response = endpoint.predict(instances=instances)\n",
" return response.predictions[0].get(\"response\")\n",
"\n",
"\n",
"def ocr_predict(\n",
" endpoint: aiplatform.Endpoint,\n",
" image: Image.Image = None,\n",
" new_width: int = 1000,\n",
") -> str:\n",
" \"\"\"Extracts text from a given image using an Endpoint.\"\"\"\n",
" # Resize and convert image to base64 string.\n",
" resized_image = resize_image(image, new_width)\n",
" resized_image_base64 = common_util.image_to_base64(resized_image)\n",
"\n",
" instances = [\n",
" {\n",
" \"prompt\": \"ocr\",\n",
" \"image\": resized_image_base64,\n",
" },\n",
" ]\n",
" response = endpoint.predict(instances=instances)\n",
" return response.predictions[0].get(\"response\")\n",
"\n",
"\n",
"def detect_predict(\n",
" endpoint: aiplatform.Endpoint,\n",
" image: Image.Image,\n",
" prompt: str,\n",
" new_width: int = 1000,\n",
"):\n",
" \"\"\"Predicts the answer to a question about an image using an Endpoint.\"\"\"\n",
" # Resize and convert image to base64 string.\n",
" resized_image = resize_image(image, new_width)\n",
" resized_image_base64 = common_util.image_to_base64(resized_image)\n",
"\n",
" instances = [\n",
" {\n",
" \"prompt\": f\"detect {prompt}\",\n",
" \"image\": resized_image_base64,\n",
" }\n",
" ]\n",
"\n",
" response = endpoint.predict(instances=instances)\n",
" return response.predictions[0].get(\"response\")\n",
"\n",
"\n",
"def parse_detections(txt):\n",
" \"\"\"Parses bounding boxes from a detection string.\"\"\"\n",
" bboxes = []\n",
@@ -547,6 +445,7 @@
"\n",
"# @markdown This section uses the deployed PaliGemma model to answer questions about a given image.\n",
"\n",
"# @markdown Once deployment succeeds, you can send requests to the endpoint with images and questions.\n",
"# @markdown ![](https://images.pexels.com/photos/4012966/pexels-photo-4012966.jpeg?w=1260&h=750)\n",
"image_url = \"https://images.pexels.com/photos/4012966/pexels-photo-4012966.jpeg\" # @param {type:\"string\"}\n",
"\n",
@@ -568,12 +467,11 @@
" question_prompt_4,\n",
" question_prompt_5,\n",
"]\n",
"questions_list = [question for question in questions_list if question]\n",
"\n",
"questions = [question for question in questions_list if question]\n",
"\n",
"answers = vqa_predict(endpoints[\"endpoint\"], image, questions_list)\n",
"answers = common_util.vqa_predict(endpoints[\"endpoint\"], questions, image)\n",
"\n",
"for question, answer in zip(questions_list, answers):\n",
"for question, answer in zip(questions, answers):\n",
" print(f\"Question: {question}\")\n",
" print(f\"Answer: {answer}\")\n",
"# @markdown Click \"Show Code\" to see more details."
@@ -589,20 +487,24 @@
"outputs": [],
"source": [
"# @title Image Captioning\n",
"\n",
"# @markdown This section uses the deployed PaliGemma model to caption and describe an image in a chosen language.\n",
"\n",
"# @markdown ![](https://images.pexels.com/photos/20427316/pexels-photo-20427316/free-photo-of-a-moped-parked-in-front-of-a-blue-door.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2)\n",
"caption_prompt = True\n",
"\n",
"image_url = \"https://images.pexels.com/photos/20427316/pexels-photo-20427316/free-photo-of-a-moped-parked-in-front-of-a-blue-door.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2\" # @param {type:\"string\"}\n",
"# @markdown <img src=\"https://storage.googleapis.com/longcap100/91.jpeg\" width=\"400\" >\n",
"\n",
"image_url = \"https://storage.googleapis.com/longcap100/91.jpeg\" # @param {type:\"string\"}\n",
"language_code = \"en\" # @param {type: \"string\"}\n",
"\n",
"image = common_util.download_image(image_url)\n",
"display(image)\n",
"\n",
"# Make a prediction.\n",
"image_base64 = common_util.image_to_base64(image)\n",
"language_code = \"en\" # @param {type: \"string\"}\n",
"caption = caption_predict(endpoints[\"endpoint\"], image, language_code)\n",
"\n",
"caption = common_util.caption_predict(\n",
" endpoints[\"endpoint\"], language_code, image, caption_prompt\n",
")\n",
"\n",
"print(\"Caption: \", caption)\n",
"# @markdown Click \"Show Code\" to see more details."
@@ -619,13 +521,14 @@
"source": [
"# @title OCR\n",
"# @markdown This section uses the deployed PaliGemma model to extract text from an image, starting from the top left.\n",
"ocr_prompt = \"ocr\"\n",
"\n",
"# @markdown ![](https://images.pexels.com/photos/8919535/pexels-photo-8919535.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2)\n",
"image_url = \"https://images.pexels.com/photos/8919535/pexels-photo-8919535.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2\" # @param {type:\"string\"}\n",
"\n",
"image = common_util.download_image(image_url)\n",
"display(image)\n",
"text_found = ocr_predict(endpoints[\"endpoint\"], image)\n",
"text_found = common_util.ocr_predict(endpoints[\"endpoint\"], ocr_prompt, image)\n",
"\n",
"print(f\"Text found: {text_found}\")\n",
"# @markdown Click \"Show Code\" to see more details."
@@ -644,22 +547,24 @@
"# @markdown This section uses the deployed PaliGemma model to output bounding boxes for specified object image in a given image.\n",
"# @markdown The text output will be parsed into bounding boxes and overlaid on the original image.\n",
"\n",
"# @markdown ![](https://images.pexels.com/photos/1006293/pexels-photo-1006293.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2)\n",
"# @markdown Specify what object to detect. To specify multiple objects, enter them as a semicolon separated list as shown below.\n",
"objects = \"plant ; pineapple ; glasses\" # @param {type:\"string\"}\n",
"detect_promt = f\"detect {objects}\"\n",
"\n",
"# @markdown ![](https://images.pexels.com/photos/1006293/pexels-photo-1006293.jpeg?auto=compress&cs=tinysrgb&w=630&h=375&dpr=2)\n",
"image_url = \"https://images.pexels.com/photos/1006293/pexels-photo-1006293.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2\" # @param {type:\"string\"}\n",
"\n",
"# @markdown Specify what object to detect. To specify multiple objects, enter them as a semicolon separated list as shown below.\n",
"\n",
"objects = \"plant ; pineapple ; glasses\" # @param {type:\"string\"}\n",
"image = common_util.download_image(image_url)\n",
"display(image)\n",
"\n",
"# Make a prediction.\n",
"detection_response = detect_predict(endpoints[\"endpoint\"], image, objects)\n",
"detection_response = common_util.detect_predict(\n",
" endpoints[\"endpoint\"], detect_promt, image\n",
")\n",
"\n",
"print(\"Output: \", detection_response)\n",
"bboxes = parse_detections(detection_response)\n",
"plot_bounding_boxes(image, bboxes)\n",
"print(\"Output: \", detection_response)\n",
"# @markdown Click \"Show Code\" to see more details."
]
},
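For reference, the raw detection response that `parse_detections` consumes is a string of location tokens plus labels. An illustrative round trip using the cell's `image` and the notebook's `parse_detections` and `plot_bounding_boxes` helpers (the token values below are invented for demonstration):

```python
# Illustrative raw response; the <locYYYY> values are made up.
detection_response = (
    "<loc0102><loc0338><loc0914><loc0660> plant ; "
    "<loc0400><loc0051><loc0702><loc0298> pineapple"
)
bboxes = parse_detections(detection_response)  # four coords + label per object
plot_bounding_boxes(image, bboxes)  # overlays the parsed boxes on the image
```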
@@ -806,13 +711,13 @@
" raise gr.Error(\"You must upload an image!\")\n",
" endpoint = get_endpoint(endpoint_name)\n",
" if interface_name == Task.VQA.value:\n",
" return vqa_predict(endpoint, image, [prompt])[0], None\n",
" return common_util.vqa_predict(endpoint, [prompt], image)[0], None\n",
" elif interface_name == Task.CAPTION.value:\n",
" return caption_predict(endpoint, image, language_code), None\n",
" return common_util.caption_predict(endpoint, language_code, image, True), None\n",
" elif interface_name == Task.OCR.value:\n",
" return ocr_predict(endpoint, image), None\n",
" return common_util.ocr_predict(endpoint, ocr_prompt, image), None\n",
" elif interface_name == Task.DETECT.value:\n",
" text_output = detect_predict(endpoint, image, prompt)\n",
" text_output = common_util.detect_predict(endpoint, f\"detect {prompt}\", image)\n",
" bboxes = parse_detections(text_output)\n",
" return text_output, plot_bounding_boxes(image, bboxes)\n",
" else:\n",
