Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/video rest api example #268

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 280 additions & 0 deletions nbs/stable_video_v1_REST_API_alpha.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "AvTo6cVeF3ip"
},
"outputs": [],
"source": [
"#@title Install requirements\n",
"import base64\n",
"from io import BytesIO\n",
"import json\n",
"import mimetypes\n",
"import os\n",
"from PIL import Image\n",
"import requests\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "pVBZ1o3fH1HX"
},
"outputs": [],
"source": [
"#@title Define helper functions\n",
"\n",
"def image_to_bytes(\n",
" img: Image,\n",
" format=\"PNG\"\n",
"):\n",
" im_file = BytesIO()\n",
" img.save(im_file, format=format)\n",
" img_bytes = im_file.getvalue()\n",
" return img_bytes\n",
"\n",
"def get_image_format(\n",
" image_path : str\n",
"):\n",
" image_mime_type = mimetypes.guess_type(image_path)[0]\n",
" if image_mime_type is None:\n",
" raise ValueError(f\"Unknown image mime type for {image_path}\")\n",
" image_format = image_mime_type.split(\"/\")[-1].upper()\n",
" return image_format\n",
"\n",
"def resize_and_crop(\n",
" image:Image,\n",
" width:int,\n",
" height:int\n",
" ):\n",
" # Resize the image so one side is the required size, then center crop the other side\n",
" # This is to avoid stretching the image\n",
" w, h = image.size\n",
" if w < h:\n",
" image = image.resize((width, int(h * width / w)))\n",
" else:\n",
" image = image.resize((int(w * height / h), height))\n",
" # center crop\n",
" left = (image.width - width) / 2\n",
" top = (image.height - height) / 2\n",
" right = (image.width + width) / 2\n",
" bottom = (image.height + height) / 2\n",
" image = image.crop((left, top, right, bottom))\n",
" return image\n",
"\n",
"def get_closest_valid_dims(\n",
" image : Image\n",
"):\n",
" # Finds the closest aspect ratio to the input image that are valid for SSC\n",
" # Valid dimensions are 1024x576, 768x768, 1024\n",
" w,h = image.size\n",
" aspect_ratio = w/h\n",
" portrait_aspect_ratio = 9/16\n",
" landscape_aspect_ratio = 16/9\n",
" portrait_aspect_ratio_midpoint = (portrait_aspect_ratio + 1)/2\n",
" landscape_aspect_ratio_midpoint = (landscape_aspect_ratio + 1)/2\n",
" if aspect_ratio < 1.0:\n",
" # portrait\n",
" width,height = (576,1024) if aspect_ratio < portrait_aspect_ratio_midpoint else (768,768)\n",
" else:\n",
" # landscape\n",
" width,height = (1024,576) if aspect_ratio > landscape_aspect_ratio_midpoint else (768,768)\n",
"\n",
" return width, height\n",
"\n",
"\n",
"def image_to_valid_bytes(\n",
" image_path : str\n",
" ):\n",
" image = Image.open(image_path)\n",
" width, height = get_closest_valid_dims(image)\n",
" format = get_image_format(image_path)\n",
" print(f\"Resizing image to {width}x{height}\")\n",
" image = resize_and_crop(image, width, height)\n",
" image_bytes = image_to_bytes(image, format=format)\n",
" return image_bytes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "dtw-2LAC7NgM"
},
"outputs": [],
"source": [
"#@title Set up credentials\n",
"\n",
"import getpass\n",
"# @markdown To get your API key visit https://platform.stability.ai/account/keys\n",
"STABILITY_KEY = getpass.getpass('Enter your API Key')\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"cellView": "form",
"id": "0lDpGa2jAmAs"
},
"outputs": [],
"source": [
"#@title Define input\n",
"\n",
"#@markdown - Drag and drop image to file folder on left\n",
"#@markdown - Right click it and choose Copy path\n",
"#@markdown - Paste that path into init_image field below\n",
"#@markdown <br><br>\n",
"\n",
"init_image = \"/content/img.jpg\" #@param {type:\"string\"}\n",
"seed = 0 #@param {type:\"integer\"}\n",
"cfg_scale = 2.5 #@param {type:\"number\"}\n",
"motion_bucket_id = 40 #@param {type:\"integer\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YNrNGa107s1M"
},
"outputs": [],
"source": [
"#@title Use REST API\n",
"\n",
"headers = {\n",
" \"Accept\": \"application/json\",\n",
" \"Authorization\": f\"Bearer {STABILITY_KEY}\"\n",
"}\n",
"host = f\"https://api.stability.ai/v2alpha/generation/image-to-video\"\n",
"\n",
"# Note:\n",
"# Valid input images must be 576x1024, 768x768, or 1024x576\n",
"# The method image_to_valid_bytes will resize to the nearest aspect ratio\n",
"init_image_bytes = image_to_valid_bytes(init_image)\n",
"\n",
"image_mime_type = mimetypes.guess_type(init_image)[0]\n",
"files = {\n",
" \"image\": (\"file\", init_image_bytes, image_mime_type),\n",
" }\n",
"params = {\n",
" \"seed\": seed,\n",
" \"cfg_scale\": cfg_scale,\n",
" \"motion_bucket_id\": motion_bucket_id\n",
" }\n",
"\n",
"for k,v in params.items():\n",
" if isinstance(v, bool):\n",
" v = str(v).lower()\n",
" files[k] = (None, str(v).encode('utf-8'))\n",
"\n",
"print(f\"Sending REST request to {host}...\")\n",
"\n",
"response = requests.post(\n",
" host,\n",
" headers=headers,\n",
" files=files,\n",
" )\n",
"\n",
"if not response.ok:\n",
" raise Exception(f\"HTTP {response.status_code}: {response.text}\")\n",
"\n",
"#\n",
"# Process async response\n",
"#\n",
"response_dict = json.loads(response.text)\n",
"request_id = response_dict.get(\"id\", None)\n",
"assert request_id is not None, \"Expected id in response\"\n",
"\n",
"# Loop until video result or timeout\n",
"timeout = int(os.getenv(\"WORKER_TIMEOUT\", 500))\n",
"start = time.time()\n",
"status_code = 202\n",
"while status_code == 202:\n",
"\n",
" response = requests.get(\n",
" f\"{host}/result/{request_id}\",\n",
" headers=headers,\n",
" )\n",
"\n",
" if not response.ok:\n",
" raise Exception(f\"HTTP {response.status_code}: {response.text}\")\n",
" status_code = response.status_code\n",
" time.sleep(2)\n",
" if time.time() - start > timeout:\n",
" raise Exception(f\"Timeout after {timeout} seconds\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "q6aXWOraCW0T"
},
"outputs": [],
"source": [
"#@title Decode response\n",
"json_data = response.json()\n",
"\n",
"video = base64.b64decode(json_data[\"video\"])\n",
"seed = json_data[\"seed\"]\n",
"finish_reason = json_data[\"finishReason\"]\n",
"\n",
"if finish_reason == 'CONTENT_FILTERED':\n",
" raise Warning(\"Video failed NSFW classifier\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PenTBs3oGZ2b"
},
"outputs": [],
"source": [
"#@title Save and display result\n",
"\n",
"filename = f\"video_{seed}.mp4\"\n",
"with open(filename, \"wb\") as f:\n",
" f.write(video)\n",
"print(f\"Saved video {filename}\")\n",
"\n",
"import IPython\n",
"mp4 = open(filename,'rb').read()\n",
"data_url = f\"data:video/mp4;base64,\" + base64.b64encode(mp4).decode()\n",
"IPython.display.display(IPython.display.HTML(f'<video controls loop><source src=\"{data_url}\" type=\"video/mp4\"></video>'))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "L-0LuAXdba8V"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}