diff --git a/.github/workflows/unit_testing.yaml b/.github/workflows/unit_testing.yaml index b3004e69..9f94c1e1 100644 --- a/.github/workflows/unit_testing.yaml +++ b/.github/workflows/unit_testing.yaml @@ -21,13 +21,14 @@ jobs: cache: 'pip' cache-dependency-path: | **/setup.py + - name: Install dependencies run: | python -m pip install --upgrade pip if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Install package run: | - pip install .[dev] + pip install .[dev,anim] - name: Test with pytest run: | pytest -s --ignore=src/stability_sdk/interfaces diff --git a/.gitignore b/.gitignore index 26e36ec4..1f255a33 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ *.pyc *.egg-info -*.png +.vscode/ +build/ dist/ pyenv/ *venv/ .env generation-*.pb.json +Pipfile* \ No newline at end of file diff --git a/README.md b/README.md index 358f93d6..39787fe0 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,14 @@ It will generate and put PNGs in your current directory. To upscale: `python3 -m stability_sdk upscale -i "/path/to/image.png"` +## Animation UI + +Install with +`pip install stability-sdk[anim_ui]` + +Then run with +`python3 -m stability_sdk animate --gui` + ## SDK Usage Be sure to check out [Platform](https://platform.stability.ai) for comprehensive documentation on how to interact with our API. @@ -57,7 +65,7 @@ options: --width WIDTH, -W WIDTH [512] width of image --start_schedule START_SCHEDULE - [0.5] start schedule for init image (must be greater than 0, 1 is full strength + [0.5] start schedule for init image (must be greater than 0; 1 is full strength text prompt, no trace of image) --end_schedule END_SCHEDULE [0.01] end schedule for init image diff --git a/nbs/animation.ipynb b/nbs/animation.ipynb new file mode 100644 index 00000000..16f58781 --- /dev/null +++ b/nbs/animation.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "rWXrHouW3pq_" + }, + "source": [ + "# Animation SDK example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "jCZ-IphH3prD" + }, + "outputs": [], + "source": [ + "#@title Mount Google Drive\n", + "try:\n", + " from google.colab import drive\n", + " drive.mount('/content/gdrive')\n", + " outputs_path = \"/content/gdrive/MyDrive/AI/StableAnimation\"\n", + " !mkdir -p $outputs_path\n", + "except:\n", + " outputs_path = \".\"\n", + "print(f\"Animations will be saved to {outputs_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "cellView": "form", + "id": "zj56t6tc3prF" + }, + "outputs": [], + "source": [ + "%%capture\n", + "#@title Connect to the Stability API\n", + "\n", + "# install Stability Animation SDK for Python\n", + "%pip install stability-sdk[anim]\n", + "\n", + "import datetime\n", + "import json\n", + "import os\n", + "import panel as pn\n", + "import param\n", + "import shutil\n", + "import sys\n", + "\n", + "from base64 import b64encode\n", + "from IPython import display\n", + "from pathlib import Path\n", + "from PIL import Image\n", + "from tqdm import tqdm\n", + "from types import SimpleNamespace\n", + "\n", + "from stability_sdk.api import Context\n", + "from stability_sdk.animation import AnimationArgs, Animator\n", + "from stability_sdk.utils import create_video_from_frames\n", + "\n", + "\n", + "# Enter your API key from dreamstudio.ai\n", + "STABILITY_HOST = \"grpc.stability.ai:443\" #@param {type:\"string\"}\n", + "STABILITY_KEY = \"\" #@param 
{type:\"string\"}\n", + "\n", + "# Connect to Stability API\n", + "api_context = Context(STABILITY_HOST, STABILITY_KEY)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ldUAFmur3prH" + }, + "outputs": [], + "source": [ + "# @title Settings\n", + "\n", + "# @markdown Run this cell to reveal the settings UI. After entering values, move on to the next step.\n", + "\n", + "# @markdown To reset values to default, simply re-run this cell.\n", + "\n", + "# @markdown NB: Settings are grouped across several tabs.\n", + "\n", + "show_documentation = True # @param {type:'boolean'}\n", + "\n", + "# #@markdown ####**Resume:**\n", + "resume_timestring = \"\" #@param {type:\"string\"}\n", + "\n", + "#@markdown ####**Override Settings:**\n", + "override_settings_path = \"\" #@param {type:\"string\"}\n", + "\n", + "###################\n", + "\n", + "from stability_sdk.animation import (\n", + " AnimationArgs,\n", + " Animator,\n", + " AnimationSettings,\n", + " BasicSettings,\n", + " CoherenceSettings,\n", + " ColorSettings,\n", + " DepthSettings,\n", + " InpaintingSettings,\n", + " Rendering3dSettings,\n", + " CameraSettings,\n", + " VideoInputSettings,\n", + " VideoOutputSettings,\n", + ")\n", + "\n", + "args_generation = BasicSettings()\n", + "args_animation = AnimationSettings()\n", + "args_camera = CameraSettings()\n", + "args_coherence = CoherenceSettings()\n", + "args_color = ColorSettings()\n", + "args_depth = DepthSettings()\n", + "args_render_3d = Rendering3dSettings()\n", + "args_inpaint = InpaintingSettings()\n", + "args_vid_in = VideoInputSettings()\n", + "args_vid_out = VideoOutputSettings()\n", + "arg_objs = (\n", + " args_generation,\n", + " args_animation,\n", + " args_camera,\n", + " args_coherence,\n", + " args_color,\n", + " args_depth,\n", + " args_render_3d,\n", + " args_inpaint,\n", + " args_vid_in,\n", + " args_vid_out,\n", + ")\n", + "\n", + "def _show_docs(component):\n", + " cols = []\n", + " for k, v in component.param.objects().items():\n", + " if k == 'name':\n", + " continue\n", + " col = pn.Column(v, v.doc)\n", + " cols.append(col)\n", + " return pn.Column(*cols)\n", + "\n", + "def build(component):\n", + " if show_documentation:\n", + " component = _show_docs(component)\n", + " return pn.Row(component, width=1000)\n", + "\n", + "pn.extension()\n", + "\n", + "pn.Tabs(*[(a.name[:-5], build(a)) for a in arg_objs])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_SudvbZG3prI" + }, + "source": [ + "### Prompts" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "FT9slDSw3prJ" + }, + "outputs": [], + "source": [ + "animation_prompts = {\n", + " 0: \"a painting of a delicious cheeseburger\",\n", + " 24: \"a painting of the the answer to life the universe and everything\",\n", + "}\n", + "\n", + "negative_prompt = \"\"\n", + "negative_prompt_weight = -1.0\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Rpqv6t303prJ" + }, + "outputs": [], + "source": [ + "#@title Render the animation\n", + "\n", + "args_d = {}\n", + "[args_d.update(a.param.values()) for a in arg_objs]\n", + "args=AnimationArgs(**args_d)\n", + "\n", + "\n", + "# load override settings if provided\n", + "if override_settings_path:\n", + " if not os.path.exists(override_settings_path):\n", + " raise ValueError(f\"Override settings file not found: {override_settings_path}\")\n", + " with open(override_settings_path, 'r') as f:\n", + " overrides 
= json.load(f)\n", + " args = vars(args)\n", + " for k in args.keys():\n", + " if k in overrides:\n", + " args[k] = overrides[k]\n", + " args = SimpleNamespace(**args)\n", + " animation_prompts = overrides.get('animation_prompts', animation_prompts)\n", + " animation_prompts = {int(k): v for k, v in animation_prompts.items()}\n", + " negative_prompt = overrides.get('negative_prompt', negative_prompt)\n", + " negative_prompt_weight = overrides.get('negative_prompt_weight', negative_prompt_weight)\n", + "\n", + "# create folder for frames output\n", + "if resume_timestring:\n", + " out_dir = os.path.join(outputs_path, resume_timestring)\n", + " if not os.path.exists(out_dir):\n", + " raise Exception(f\"Can't resume {resume_timestring} because path {out_dir} doesn't exist. Please make sure the timestring is correct.\")\n", + " timestring = resume_timestring\n", + "else:\n", + " timestring = datetime.datetime.now().strftime('%Y%m%d%H%M%S')\n", + " out_dir = os.path.join(outputs_path, timestring)\n", + " os.makedirs(out_dir, exist_ok=True)\n", + "print(f\"Saving animation frames to {out_dir}...\")\n", + "\n", + "animator = Animator(\n", + " api_context=api_context,\n", + " animation_prompts=animation_prompts,\n", + " args=args,\n", + " out_dir=out_dir, \n", + " negative_prompt=negative_prompt,\n", + " negative_prompt_weight=negative_prompt_weight,\n", + " resume=len(resume_timestring) != 0,\n", + ")\n", + "animator.save_settings(f\"{timestring}_settings.txt\")\n", + "\n", + "for frame in tqdm(animator.render(), initial=animator.start_frame_idx, total=args.max_frames):\n", + " display.clear_output(wait=True)\n", + " display.display(frame)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "aWhJnLNX3prL" + }, + "outputs": [], + "source": [ + "#@title Create video from frames\n", + "skip_video_for_run_all = False #@param {type: 'boolean'}\n", + "fps = 12 #@param {type:\"number\"}\n", + "\n", + "if skip_video_for_run_all == True:\n", + " print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n", + "else:\n", + " mp4_path = os.path.join(out_dir, f\"{timestring}.mp4\")\n", + " print(f\"Compiling animation frames to {mp4_path}...\")\n", + " create_video_from_frames(out_dir, mp4_path, fps)\n", + "\n", + " mp4 = open(mp4_path,'rb').read()\n", + " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", + " display.display( display.HTML(f'<video controls loop><source src=\"{data_url}\" type=\"video/mp4\"></video>') )" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "provenance": [] + }, + "kernelspec": { + "display_name": "client", + "language": "python", + "name": "client" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "fb02550c4ef2b9a37ba5f7f381e893a74079cea154f791601856f87ae67cf67c" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/nbs/animation_gradio.ipynb b/nbs/animation_gradio.ipynb new file mode 100644 index 00000000..c38a0db6 --- /dev/null +++ b/nbs/animation_gradio.ipynb @@ -0,0 +1,110 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "tPZFjbUDTwYE" + }, + "source": [ + "# Stable Animation notebook" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { 
+ "base_uri": "https://localhost:8080/" + }, + "id": "LUMF8i8BTwYH", + "outputId": "f61c635e-bc57-48ab-cc3d-5166286b158f" + }, + "outputs": [], + "source": [ + "#@title Mount Google Drive\n", + "import os\n", + "try:\n", + " from google.colab import drive\n", + " drive.mount('/content/gdrive')\n", + " outputs_path = \"/content/gdrive/MyDrive/AI/StableAnimation\"\n", + " os.makedirs(outputs_path, exist_ok=True)\n", + "except:\n", + " outputs_path = \".\"\n", + "print(f\"Animations will be saved to {outputs_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "tgZBrk8DTwYI" + }, + "outputs": [], + "source": [ + "#@title Install Animation SDK and connect to the Stability API\n", + "%pip install stability-sdk[anim_ui]\n", + "\n", + "from stability_sdk.api import Context\n", + "from stability_sdk.animation_ui import create_ui\n", + "\n", + "# Enter your API key from dreamstudio.ai\n", + "STABILITY_HOST = \"grpc.stability.ai:443\" #@param {type:\"string\"}\n", + "STABILITY_KEY = \"\" #@param {type:\"string\"}\n", + "\n", + "# Connect to Stability API\n", + "api_context = Context(STABILITY_HOST, STABILITY_KEY)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "cellView": "form", + "id": "QVCAr8xcTwYI" + }, + "outputs": [], + "source": [ + "#@title Animation UI\n", + "show_ui_in_notebook = True #@param {type:\"boolean\"}\n", + "\n", + "ui = create_ui(api_context, outputs_path)\n", + "\n", + "ui.queue(concurrency_count=2, max_size=2)\n", + "ui.launch(show_api=False, debug=True, inline=show_ui_in_notebook, height=768, share=True, show_error=True)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "fb02550c4ef2b9a37ba5f7f381e893a74079cea154f791601856f87ae67cf67c" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/setup.py b/setup.py index 0854938a..7cd124e5 100644 --- a/setup.py +++ b/setup.py @@ -1,42 +1,56 @@ # fmt: off -from setuptools import setup, find_namespace_packages +from setuptools import ( + setup, + find_namespace_packages, +) with open('README.md','r') as f: README = f.read() setup( name='stability-sdk', - version='0.7.0', + version='0.8.0', author='Stability AI', author_email='support@stability.ai', maintainer='Stability AI', maintainer_email='support@stability.ai', url='https://beta.dreamstudio.ai/', download_url='https://github.com/Stability-AI/stability-sdk/', - description='Python SDK for interacting with stability.ai APIs', long_description=README, long_description_content_type="text/markdown", - install_requires=[ 'Pillow', 'grpcio==1.53.0', 'grpcio-tools==1.53.0', 'python-dotenv', + 'param', 'protobuf==4.21.12' ], extras_require={ 'dev': [ 'pytest', 'grpcio-testing' - ]}, + ], + 'anim': [ + 'keyframed', + 'numpy', + 'opencv-python-headless', + ], + 'anim_ui': [ + 'keyframed', + 'gradio', + 'numpy', + 'opencv-python-headless', + 'tqdm', + ] + }, packages=find_namespace_packages( where='src', include=['stability_sdk*'], ), package_dir = {"": "src"}, - classifiers=[ 'Intended Audience :: Developers', 'Intended Audience :: Education', diff --git 
a/src/stability_sdk/__init__.py b/src/stability_sdk/__init__.py index 29430654..f9a14d1c 100644 --- a/src/stability_sdk/__init__.py +++ b/src/stability_sdk/__init__.py @@ -2,8 +2,10 @@ # this is necessary because of how the auto-generated code constructs its imports # should be a way to move this upstream -thisPath = pathlib.Path(__file__).parent.resolve() -genPath = thisPath / "interfaces/gooseai/generation" -tensPath = thisPath / "interfaces/src/tensorizer/tensors" -#sys.path.append(str(genPath)) -sys.path.extend([str(genPath), str(tensPath)]) \ No newline at end of file +this_path = pathlib.Path(__file__).parent.resolve() +sys.path.extend([ + str(this_path / "interfaces/gooseai/dashboard"), + str(this_path / "interfaces/gooseai/generation"), + str(this_path / "interfaces/gooseai/project"), + str(this_path / "interfaces/src/tensorizer/tensors") +]) \ No newline at end of file diff --git a/src/stability_sdk/animation.py b/src/stability_sdk/animation.py new file mode 100644 index 00000000..5f183749 --- /dev/null +++ b/src/stability_sdk/animation.py @@ -0,0 +1,1131 @@ +import base64 +import bisect +import cv2 +import glob +import json +import logging +import math +import numpy as np +import os +import param +import random +import shutil +import subprocess + +from collections import OrderedDict, deque +from dataclasses import dataclass, fields +from keyframed.dsl import curve_from_cn_string +from PIL import Image, ImageOps +from types import SimpleNamespace +from typing import Callable, cast, Deque, Dict, Generator, List, Optional, Tuple, Union + +from stability_sdk.api import Context, generation +from stability_sdk.utils import ( + camera_pose_transform, + color_adjust_transform, + depth_calc_transform, + guidance_from_string, + image_mix, + image_to_png_bytes, + interpolate_mode_from_string, + resample_transform, + sampler_from_string, +) +import stability_sdk.matrix as matrix + +logger = logging.getLogger(__name__) +logger.setLevel(level=logging.INFO) + +DEFAULT_MODEL = 'stable-diffusion-v1-5' +TRANSLATION_SCALE = 1.0/200.0 # matches Disco and Deforum + +docstring_bordermode = ( + "Method that will be used to fill empty regions, e.g. after a rotation transform." + "\n\t* reflect - Mirror pixels across the image edge to fill empty regions." + "\n\t* replicate - Use closest pixel values (default)." + "\n\t* wrap - Treat image borders as if they were connected, i.e. use pixels from left edge to fill empty regions touching the right edge." + "\n\t* zero - Fill empty regions with black pixels." + "\n\t* prefill - Do simple inpainting over empty regions." +) + +class BasicSettings(param.Parameterized): + width = param.Integer(default=512, doc="Output image dimensions. Will be resized to a multiple of 64.") + height = param.Integer(default=512, doc="Output image dimensions. 
Will be resized to a multiple of 64.") + sampler = param.Selector( + default='K_dpmpp_2m', + objects=[ + "DDIM", "PLMS", "K_euler", "K_euler_ancestral", "K_heun", "K_dpm_2", + "K_dpm_2_ancestral", "K_lms", "K_dpmpp_2m", "K_dpmpp_2s_ancestral" + ] + ) + model = param.Selector( + default=DEFAULT_MODEL, + check_on_set=False, # allow old and new models without raising ValueError + objects=[ + "stable-diffusion-v1-5", "stable-diffusion-512-v2-1", "stable-diffusion-768-v2-1", + "stable-diffusion-depth-v2-0", "stable-diffusion-xl-beta-v2-2-2", + "custom" + ] + ) + custom_model = param.String(default="", doc="Identifier of custom model to use.") + seed = param.Integer(default=-1, doc="Provide a seed value for more deterministic behavior. Negative seed values will be replaced with a random seed (default).") + cfg_scale = param.Number(default=7, softbounds=(0,20), doc="Classifier-free guidance scale. Strength of prompt influence on denoising process. `cfg_scale=0` gives unconditioned sampling.") + clip_guidance = param.Selector(default='None', objects=["None", "Simple", "FastBlue", "FastGreen"], doc="CLIP-guidance preset.") + init_image = param.String(default='', doc="Path to image. Height and width dimensions will be inherited from image.") + init_sizing = param.Selector(default='stretch', objects=["cover", "stretch", "resize-canvas"]) + mask_path = param.String(default="", doc="Path to image or video mask") + mask_invert = param.Boolean(default=False, doc="White in mask marks areas to change by default.") + preset = param.Selector( + default='None', + objects=[ + 'None', '3d-model', 'analog-film', 'anime', 'cinematic', 'comic-book', 'digital-art', + 'enhance', 'fantasy-art', 'isometric', 'line-art', 'low-poly', 'modeling-compound', + 'neon-punk', 'origami', 'photographic', 'pixel-art', + ] + ) + +class AnimationSettings(param.Parameterized): + animation_mode = param.Selector(default='3D warp', objects=['2D', '3D warp', '3D render', 'Video Input']) + max_frames = param.Integer(default=72, doc="Force stop of animation job after this many frames are generated.") + border = param.Selector(default='replicate', objects=['reflect', 'replicate', 'wrap', 'zero', 'prefill'], doc=docstring_bordermode) + noise_add_curve = param.String(default="0:(0.02)") + noise_scale_curve = param.String(default="0:(0.99)") + strength_curve = param.String(default="0:(0.65)", doc="Image strength (of the init image relative to the prompt). 0 ignores the init image and attends only to the prompt; 1 returns the init image unmodified.") + steps_curve = param.String(default="0:(30)", doc="Diffusion steps") + steps_strength_adj = param.Boolean(default=False, doc="Adjusts the number of diffusion steps based on the previous frame's strength value.") + interpolate_prompts = param.Boolean(default=False, doc="Smoothly interpolate prompts between keyframes. Defaults to False") + locked_seed = param.Boolean(default=False) + +class CameraSettings(param.Parameterized): + """ + See disco/deforum keyframing syntax, originally developed by Chigozie Nri + General syntax: "<frame>:(<value>), f2:(v2), f3:(v3), ..." + Values between intermediate keyframes will be linearly interpolated by default to produce smooth transitions. + For abrupt transitions, specify values at adjacent keyframes. 
+ """ + angle = param.String(default="0:(0)", doc="Camera rotation angle in degrees for 2D mode") + zoom = param.String(default="0:(1)", doc="Camera zoom factor for 2D mode (<1 zooms out, >1 zooms in)") + translation_x = param.String(default="0:(0)") + translation_y = param.String(default="0:(0)") + translation_z = param.String(default="0:(0)") + rotation_x = param.String(default="0:(0)", doc="Camera rotation around X-axis in degrees for 3D modes") + rotation_y = param.String(default="0:(0)", doc="Camera rotation around Y-axis in degrees for 3D modes") + rotation_z = param.String(default="0:(0)", doc="Camera rotation around Z-axis in degrees for 3D modes") + + +class CoherenceSettings(param.Parameterized): + diffusion_cadence_curve = param.String(default="0:(1)", doc="One greater than the number of frames between diffusion operations. A cadence of 1 performs diffusion on each frame. Values greater than one will generate frames using interpolation methods.") + cadence_interp = param.Selector(default='mix', objects=['film', 'mix', 'rife', 'vae-lerp', 'vae-slerp']) + cadence_spans = param.Boolean(default=False, doc="Experimental diffusion cadence mode for better outpainting") + + +class ColorSettings(param.Parameterized): + color_coherence = param.Selector(default='LAB', objects=['None', 'HSV', 'LAB', 'RGB'], doc="Color space that will be used for inter-frame color adjustments.") + brightness_curve = param.String(default="0:(1.0)") + contrast_curve = param.String(default="0:(1.0)") + hue_curve = param.String(default="0:(0.0)") + saturation_curve = param.String(default="0:(1.0)") + lightness_curve = param.String(default="0:(0.0)") + color_match_animate = param.Boolean(default=True, doc="Animate color match between key frames.") + + +class DepthSettings(param.Parameterized): + depth_model_weight = param.Number(default=0.3, softbounds=(0,1), doc="Blend factor between AdaBins and MiDaS depth models.") + near_plane = param.Number(default=200, doc="Distance to nearest plane of camera view volume.") + far_plane = param.Number(default=10000, doc="Distance to furthest plane of camera view volume.") + fov_curve = param.String(default="0:(25)", doc="FOV angle of camera volume in degrees.") + depth_blur_curve = param.String(default="0:(0.0)", doc="Blur strength of depth map.") + depth_warp_curve = param.String(default="0:(1.0)", doc="Depth warp strength.") + save_depth_maps = param.Boolean(default=False) + + +class Rendering3dSettings(param.Parameterized): + camera_type = param.Selector(default='perspective', objects=['perspective', 'orthographic']) + render_mode = param.Selector(default='mesh', objects=['mesh', 'pointcloud'], doc="Mode for image and mask rendering. 'pointcloud' is a bit faster, but 'mesh' is more stable") + mask_power = param.Number(default=0.3, softbounds=(0, 4), doc="Raises each mask (0, 1) value to this power. The higher the value the more changes will be applied to the nearest objects") + +class InpaintingSettings(param.Parameterized): + use_inpainting_model = param.Boolean(default=False, doc="If True, inpainting will be performed using dedicated inpainting model. If False, inpainting will be performed with the regular model that is selected") + inpaint_border = param.Boolean(default=False, doc="Use inpainting on top of border regions for 2D and 3D warp modes. Defaults to False") + mask_min_value = param.String(default="0:(0.25)", doc="Mask postprocessing for non-inpainting model. 
Mask floor values will be clipped by this value prior to inpainting") + mask_binarization_thr = param.Number(default=0.5, softbounds=(0,1), doc="Grayscale mask values lower than this value will be set to 0, values that are higher — to 1.") + save_inpaint_masks = param.Boolean(default=False) + +class VideoInputSettings(param.Parameterized): + video_init_path = param.String(default="", doc="Path to video input") + extract_nth_frame = param.Integer(default=1, bounds=(1,None), doc="Only use every Nth frame of the video") + video_mix_in_curve = param.String(default="0:(0.02)") + video_flow_warp = param.Boolean(default=True, doc="Whether or not to transfer the optical flow from the video to the generated animation as a warp effect.") + +class VideoOutputSettings(param.Parameterized): + fps = param.Integer(default=12, doc="Frame rate to use when generating video output.") + reverse = param.Boolean(default=False, doc="Whether to reverse the output video or not.") + +class AnimationArgs( + BasicSettings, + AnimationSettings, + CameraSettings, + CoherenceSettings, + ColorSettings, + DepthSettings, + Rendering3dSettings, + InpaintingSettings, + VideoInputSettings, + VideoOutputSettings +): + """ + Aggregates parameters from the multiple settings classes. + """ + +@dataclass +class FrameArgs: + """Expansion of key framed Args to per-frame values""" + angle: List[float] + zoom: List[float] + translation_x: List[float] + translation_y: List[float] + translation_z: List[float] + rotation_x: List[float] + rotation_y: List[float] + rotation_z: List[float] + brightness_curve: List[float] + contrast_curve: List[float] + hue_curve: List[float] + saturation_curve: List[float] + lightness_curve: List[float] + noise_add_curve: List[float] + noise_scale_curve: List[float] + steps_curve: List[float] + strength_curve: List[float] + diffusion_cadence_curve: List[float] + fov_curve: List[float] + depth_blur_curve: List[float] + depth_warp_curve: List[float] + video_mix_in_curve: List[float] + mask_min_value: List[float] + + +def args_to_dict(args): + """ + Converts arguments object to an OrderedDict + """ + if isinstance(args, param.Parameterized): + return OrderedDict(args.param.values()) + elif isinstance(args, SimpleNamespace): + return OrderedDict(vars(args)) + else: + raise NotImplementedError(f"Unsupported arguments object type: {type(args)}") + +def cv2_to_pil(cv2_img: np.ndarray) -> Image.Image: + """Convert a cv2 BGR ndarray to a PIL Image""" + return Image.fromarray(cv2_img[:, :, ::-1]) + +def interpolate_frames( + context: Context, + frames_path: str, + out_path: str, + interp_mode: generation.InterpolateMode, + interp_factor: int +) -> Generator[Image.Image, None, None]: + """Interpolates frames in a directory using the specified interpolation mode.""" + assert interp_factor > 1, "Interpolation factor must be greater than 1" + + # gather source frames + frame_files = glob.glob(os.path.join(frames_path, "frame_*.png")) + frame_files.sort() + + # perform frame interpolation + os.makedirs(out_path, exist_ok=True) + ratios = np.linspace(0, 1, interp_factor+1)[1:-1].tolist() + for i in range(len(frame_files) - 1): + shutil.copy(frame_files[i], os.path.join(out_path, f"frame_{i * interp_factor:05d}.png")) + frame1 = Image.open(frame_files[i]) + frame2 = Image.open(frame_files[i + 1]) + yield frame1 + tweens = context.interpolate([frame1, frame2], ratios, interp_mode) + for ti, tween in enumerate(tweens): + tween.save(os.path.join(out_path, f"frame_{i * interp_factor + ti + 1:05d}.png")) + yield tween + + # copy 
final frame + shutil.copy(frame_files[-1], os.path.join(out_path, f"frame_{(len(frame_files)-1) * interp_factor:05d}.png")) + +def mask_erode_blur(mask: Image.Image, mask_erode: int, mask_blur: int) -> Image.Image: + mask = np.array(mask) + if mask_erode > 0: + ks = mask_erode*2 + 1 + mask = cv2.erode(mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ks, ks)), iterations=1) + if mask_blur > 0: + ks = mask_blur*2 + 1 + mask = cv2.GaussianBlur(mask, (ks, ks), 0) + return Image.fromarray(mask) + +def make_xform_2d( + w: float, h: float, + rotation_angle: float, # in radians + scale_factor: float, + translate_x: float, + translate_y: float, +) -> matrix.Matrix: + center = (w / 2, h / 2) + pre = matrix.translation(-center[0], -center[1], 0) + post = matrix.translation(center[0], center[1], 0) + rotate = matrix.rotation_euler(0, 0, rotation_angle) + scale = matrix.scale(scale_factor, scale_factor, 1) + rotate_scale = matrix.multiply(post, matrix.multiply(rotate, matrix.multiply(scale, pre))) + # match 3D camera translation, +X moves camera to right, +Y moves camera up + translate = matrix.translation(-translate_x, translate_y, 0) + return matrix.multiply(rotate_scale, translate) + +def model_supports_clip_guidance(model_name: str) -> bool: + return not model_name.startswith('stable-diffusion-xl') + +def model_requires_depth(model_name: str) -> bool: + return model_name == 'stable-diffusion-depth-v2-0' + +def sampler_supports_clip_guidance(sampler_name: str) -> bool: + supported_samplers = [ + generation.SAMPLER_K_EULER_ANCESTRAL, + generation.SAMPLER_K_DPM_2_ANCESTRAL, + generation.SAMPLER_K_DPMPP_2S_ANCESTRAL + ] + return sampler_from_string(sampler_name) in supported_samplers + +def to_3x3(m: matrix.Matrix) -> matrix.Matrix: + # convert 4x4 matrix with 2D rotation, scale, and translation to 3x3 matrix + return [[m[0][0], m[0][1], m[0][3]], + [m[1][0], m[1][1], m[1][3]], + [m[3][0], m[3][1], m[3][3]]] + + +class Animator: + def __init__( + self, + api_context: Context, + animation_prompts: dict, + args: Optional[AnimationArgs] = None, + out_dir: Optional[str] = None, + negative_prompt: str = '', + negative_prompt_weight: float = -1.0, + resume: bool = False + ): + self.api = api_context + self.animation_prompts = animation_prompts + self.args = args or AnimationArgs() + self.color_match_images: Optional[Dict[int, Image.Image]] = {} + self.diffusion_cadence_ofs: int = 0 + self.frame_args: FrameArgs + self.inpaint_mask: Optional[Image.Image] = None + self.key_frame_values: List[int] = [] + self.out_dir: Optional[str] = out_dir + self.mask: Optional[Image.Image] = None + self.mask_reader = None + self.cadence_on: bool = False + self.prior_frames: Deque[Image.Image] = deque([], 1) # forward warped prior frames. stores one image with cadence off, two images otherwise + self.prior_diffused: Deque[Image.Image] = deque([], 1) # results of diffusion. stores one image with cadence off, two images otherwise + self.prior_xforms: Deque[matrix.Matrix] = deque([], 1) # accumulated transforms since last diffusion. 
stores one with cadence off, two otherwise + self.negative_prompt: str = negative_prompt + self.negative_prompt_weight: float = negative_prompt_weight + self.start_frame_idx: int = 0 + self.video_prev_frame: Optional[Image.Image] = None + self.video_reader: Optional[cv2.VideoCapture] = None + + # configure Api to retry on classifier obfuscations + self.api._retry_obfuscation = True + + # create output directory + if self.out_dir is not None: + os.makedirs(self.out_dir, exist_ok=True) + elif self.args.save_depth_maps or self.args.save_inpaint_masks: + raise ValueError('out_dir must be specified when saving depth maps or inpaint masks') + + self.setup_animation(resume) + + def build_frame_xform(self, frame_idx) -> matrix.Matrix: + args, frame_args = self.args, self.frame_args + + if self.args.animation_mode == '2D': + angle = frame_args.angle[frame_idx] + scale = frame_args.zoom[frame_idx] + dx = frame_args.translation_x[frame_idx] + dy = frame_args.translation_y[frame_idx] + return make_xform_2d(args.width, args.height, math.radians(angle), scale, dx, dy) + + elif self.args.animation_mode in ('3D warp', '3D render'): + dx = frame_args.translation_x[frame_idx] + dy = frame_args.translation_y[frame_idx] + dz = frame_args.translation_z[frame_idx] + rx = frame_args.rotation_x[frame_idx] + ry = frame_args.rotation_y[frame_idx] + rz = frame_args.rotation_z[frame_idx] + + dx, dy, dz = -dx*TRANSLATION_SCALE, dy*TRANSLATION_SCALE, -dz*TRANSLATION_SCALE + rx, ry, rz = math.radians(rx), math.radians(ry), math.radians(rz) + + # create xform for the current frame + world_view = matrix.multiply(matrix.translation(dx, dy, dz), matrix.rotation_euler(rx, ry, rz)) + return world_view + + else: + return matrix.identity + + def emit_frame(self, frame_idx: int, out_frame: Image.Image) -> Image.Image: + if self.args.save_depth_maps: + depth_image = self.generate_depth_image(out_frame) + self.save_to_out_dir(frame_idx, depth_image, prefix='depth') + + if self.args.save_inpaint_masks and self.inpaint_mask is not None: + self.save_to_out_dir(frame_idx, self.inpaint_mask, prefix='mask') + + self.save_to_out_dir(frame_idx, out_frame) + return out_frame + + def generate_depth_image(self, image: Image.Image) -> Image.Image: + results, _ = self.api.transform( + [image], + depth_calc_transform(blend_weight=self.args.depth_model_weight) + ) + return results[0] + + def get_animation_prompts_weights(self, frame_idx: int) -> Tuple[List[str], List[float]]: + prev, next, tween = self.get_key_frame_tween(frame_idx) + if prev == next or not self.args.interpolate_prompts: + return [self.animation_prompts[prev]], [1.0] + else: + return [self.animation_prompts[prev], self.animation_prompts[next]], [1.0 - tween, tween] + + def get_color_match_image(self, frame_idx: int) -> Image.Image: + if not self.args.color_match_animate: + return self.color_match_images.get(0) + + prev, next, tween = self.get_key_frame_tween(frame_idx) + + if prev not in self.color_match_images: + self.color_match_images[prev] = self._render_frame(prev, self.args.seed) + prev_match = self.color_match_images[prev] + if prev == next: + return prev_match + + if next not in self.color_match_images: + self.color_match_images[next] = self._render_frame(next, self.args.seed) + next_match = self.color_match_images[next] + + # Create image combining colors from previous and next key frames without mixing + # the RGB values. Tiles of next key frame are filled in over tiles of previous + # key frame. 
The tween value increases the subtile size on each axis so the transition + # is non-linear - staying with previous key frame longer then quickly moving to next. + blended = prev_match.copy() + width, height, tile_size = blended.width, blended.height, 64 + for y in range(0, height, tile_size): + for x in range(0, width, tile_size): + cut = next_match.crop((x, y, x + int(tile_size * tween), y + int(tile_size * tween))) + blended.paste(cut, (x, y)) + return blended + + def get_key_frame_tween(self, frame_idx: int) -> Tuple[int, int, float]: + """Returns previous and next key frames along with in between ratio""" + keys = self.key_frame_values + idx = bisect.bisect_right(keys, frame_idx) + prev, next = idx - 1, idx + if next == len(keys): + return keys[-1], keys[-1], 1.0 + else: + tween = (frame_idx - keys[prev]) / (keys[next] - keys[prev]) + return keys[prev], keys[next], tween + + def get_frame_filename(self, frame_idx, prefix="frame") -> Optional[str]: + return os.path.join(self.out_dir, f"{prefix}_{frame_idx:05d}.png") if self.out_dir else None + + def image_resize(self, img: Image.Image, mode: str = 'stretch') -> Image.Image: + width, height = img.size + if mode == 'cover': + scale = max(self.args.width / width, self.args.height / height) + img = img.resize((int(width * scale), int(height * scale)), resample=Image.LANCZOS) + x = (img.width - self.args.width) // 2 + y = (img.height - self.args.height) // 2 + img = img.crop((x, y, x + self.args.width, y + self.args.height)) + elif mode == 'stretch': + img = img.resize((self.args.width, self.args.height), resample=Image.LANCZOS) + else: # 'resize-canvas' + width, height = map(lambda x: x - x % 64, (width, height)) + self.args.width, self.args.height = width, height + return img + + def inpaint_frame( + self, + frame_idx: int, + image: Image.Image, + mask: Image.Image, + seed: Optional[int] = None, + mask_blur_radius: Optional[int] = 8 + ) -> Image.Image: + args = self.args + steps = int(self.frame_args.steps_curve[frame_idx]) + sampler = sampler_from_string(args.sampler.lower()) + guidance = guidance_from_string(args.clip_guidance) + + # fetch set of prompts and weights for this frame + prompts, weights = self.get_animation_prompts_weights(frame_idx) + if len(self.negative_prompt) and self.negative_prompt_weight != 0.0: + prompts.append(self.negative_prompt) + weights.append(-abs(self.negative_prompt_weight)) + + if args.use_inpainting_model: + binary_mask = self._postprocess_inpainting_mask( + mask, binarize=True, blur_radius=mask_blur_radius) + results = self.api.inpaint( + image, binary_mask, + prompts, weights, + steps=steps, + seed=seed if seed is not None else args.seed, + cfg_scale=args.cfg_scale, + sampler=sampler, + init_strength=0.0, + masked_area_init=generation.MASKED_AREA_INIT_ZERO, + guidance_preset=guidance, + preset=args.preset, + ) + else: + mask_min_value = self.frame_args.mask_min_value[frame_idx] + binary_mask = self._postprocess_inpainting_mask( + mask, binarize=True, min_val=mask_min_value, blur_radius=mask_blur_radius) + adjusted_steps = max(5, int(steps * (1.0 - mask_min_value))) if args.steps_strength_adj else steps + noise_scale = self.frame_args.noise_scale_curve[frame_idx] + results = self.api.generate( + prompts, weights, + args.width, args.height, + steps=adjusted_steps, + seed=seed if seed is not None else args.seed, + cfg_scale=args.cfg_scale, + sampler=sampler, + init_image=image, + init_strength=mask_min_value, + init_noise_scale=noise_scale, + mask=binary_mask, + 
masked_area_init=generation.MASKED_AREA_INIT_ORIGINAL, + guidance_preset=guidance, + preset=args.preset, + ) + return results[generation.ARTIFACT_IMAGE][0] + + def load_init_image(self, fpath=None): + if fpath is None: + fpath = self.args.init_image + if not fpath: + return + + img = self.image_resize(Image.open(fpath), self.args.init_sizing) + + self.prior_frames.extend([img, img]) + self.prior_diffused.extend([img, img]) + + def load_mask(self): + if not self.args.mask_path: + return + + # try to load mask as an image + mask = Image.open(self.args.mask_path) + if mask is not None: + self.set_mask(mask) + + # try to load mask as a video + if self.mask is None: + self.mask_reader = cv2.VideoCapture(self.args.mask_path) + self.next_mask() + + if self.mask is None: + raise Exception(f"Failed to read mask from {self.args.mask_path}") + + def load_video(self): + if self.args.animation_mode != 'Video Input' or not self.args.video_init_path: + return + + self.video_reader = cv2.VideoCapture(self.args.video_init_path) + if self.video_reader is not None: + success, image = self.video_reader.read() + if not success: + raise Exception(f"Failed to read first frame from {self.args.video_init_path}") + self.video_prev_frame = self.image_resize(cv2_to_pil(image), 'cover') + self.prior_frames.extend([self.video_prev_frame, self.video_prev_frame]) + self.prior_diffused.extend([self.video_prev_frame, self.video_prev_frame]) + + def next_mask(self): + if not self.mask_reader: + return False + + for _ in range(self.args.extract_nth_frame): + success, mask = self.mask_reader.read() + if not success: + return + + self.set_mask(cv2_to_pil(mask)) + + def prepare_init_ops(self, init_image: Optional[Image.Image], frame_idx: int, noise_seed:int) -> List[generation.TransformParameters]: + if init_image is None: + return [] + + args, frame_args = self.args, self.frame_args + brightness = frame_args.brightness_curve[frame_idx] + contrast = frame_args.contrast_curve[frame_idx] + hue = frame_args.hue_curve[frame_idx] + saturation = frame_args.saturation_curve[frame_idx] + lightness = frame_args.lightness_curve[frame_idx] + noise_amount = frame_args.noise_add_curve[frame_idx] + + color_match_image = None + if args.color_coherence != 'None' and frame_idx > 0: + color_match_image = self.get_color_match_image(frame_idx) + + do_color_match = args.color_coherence != 'None' and color_match_image is not None + do_bchsl = brightness != 1.0 or contrast != 1.0 or hue != 0.0 or saturation != 1.0 or lightness != 0.0 + do_noise = noise_amount > 0.0 + + init_ops: List[generation.TransformParameters] = [] + + if do_color_match or do_bchsl or do_noise: + init_ops.append(color_adjust_transform( + brightness=brightness, + contrast=contrast, + hue=hue, + saturation=saturation, + lightness=lightness, + match_image=color_match_image, + match_mode=args.color_coherence, + noise_amount=noise_amount, + noise_seed=noise_seed + )) + + return init_ops + + def render(self) -> Generator[Image.Image, None, None]: + args = self.args + seed = args.seed + + # experimental span-based outpainting mode + if args.cadence_spans and args.animation_mode != 'Video Input': + for idx, frame in self._spans_render(): + yield self.emit_frame(idx, frame) + return + + for frame_idx in range(self.start_frame_idx, args.max_frames): + # select image generation model + self.api._generate.engine_id = args.custom_model if args.model == "custom" else args.model + if model_requires_depth(args.model) and not self.prior_frames: + self.api._generate.engine_id = DEFAULT_MODEL + + 
diffusion_cadence = max(1, int(self.frame_args.diffusion_cadence_curve[frame_idx])) + self.set_cadence_mode(enabled=(diffusion_cadence > 1)) + is_diffusion_frame = (frame_idx - self.diffusion_cadence_ofs) % diffusion_cadence == 0 + + steps = int(self.frame_args.steps_curve[frame_idx]) + strength = max(0.0, self.frame_args.strength_curve[frame_idx]) + + # fetch set of prompts and weights for this frame + prompts, weights = self.get_animation_prompts_weights(frame_idx) + if len(self.negative_prompt) and self.negative_prompt_weight != 0.0: + prompts.append(self.negative_prompt) + weights.append(-abs(self.negative_prompt_weight)) + + + # transform prior frames + stashed_prior_frames = [i.copy() for i in self.prior_frames] if self.mask is not None else [] + self.inpaint_mask = None + if args.animation_mode == '2D': + self.inpaint_mask = self.transform_2d(frame_idx) + elif args.animation_mode in ('3D render', '3D warp'): + self.inpaint_mask = self.transform_3d(frame_idx) + elif args.animation_mode == 'Video Input': + self.inpaint_mask = self.transform_video(frame_idx) + + # apply inpainting + # If cadence is disabled and inpainting is performed using the same model as for generation, + # we can optimize inpaint->generate calls into a single generate call. + if args.inpaint_border and self.inpaint_mask is not None \ + and (self.cadence_on or args.use_inpainting_model): + for i in range(len(self.prior_frames)): + # The earliest prior frame will be popped right after the generation step, so its inpainting would be redundant. + if self.cadence_on and is_diffusion_frame and i==0: + continue + self.prior_frames[i] = self.inpaint_frame( + frame_idx, self.prior_frames[i], self.inpaint_mask, + seed=None if args.use_inpainting_model else seed) + + # apply mask to transformed prior frames + self.next_mask() + if self.mask is not None: + for i in range(len(self.prior_frames)): + if self.cadence_on and is_diffusion_frame and i==0: + continue + self.prior_frames[i] = image_mix(self.prior_frames[i], stashed_prior_frames[i], self.mask) + + # either run diffusion or emit an inbetween frame + if is_diffusion_frame: + init_image = self.prior_frames[-1] if len(self.prior_frames) and strength > 0 else None + init_strength = strength if init_image is not None else 0.0 + + # mix video frame into init image + mix_in = self.frame_args.video_mix_in_curve[frame_idx] + if init_image is not None and mix_in > 0 and self.video_prev_frame is not None: + init_image = image_mix(init_image, self.video_prev_frame, mix_in) + + # when using depth model, compute a depth init image + init_depth = None + if init_image is not None and model_requires_depth(args.model): + depth_source = self.video_prev_frame if self.video_prev_frame is not None else init_image + params = depth_calc_transform(blend_weight=1.0, blur_radius=0, reverse=True) + results, _ = self.api.transform([depth_source], params) + init_depth = results[0] + + # builds set of transform ops to prepare init image for generation + init_image_ops = self.prepare_init_ops(init_image, frame_idx, seed) + + # For in-diffusion frames instead of a full run through inpainting model and then generate call, + # inpainting can be done in a single call with non-inpainting model + do_inpainting = not self.cadence_on and not args.use_inpainting_model \ + and self.inpaint_mask is not None \ + and (args.inpaint_border or args.animation_mode == '3D render') + if do_inpainting: + mask_min_value = self.frame_args.mask_min_value[frame_idx] + init_strength = min(strength, mask_min_value) + 
self.inpaint_mask = self._postprocess_inpainting_mask( + self.inpaint_mask, + mask_pow=args.mask_power if args.animation_mode == '3D render' else None, + mask_multiplier=strength, + blur_radius=None, + min_val=mask_min_value) + + # generate the next frame + sampler = sampler_from_string(args.sampler.lower()) + guidance = guidance_from_string(args.clip_guidance) + noise_scale = self.frame_args.noise_scale_curve[frame_idx] + adjusted_steps = int(max(5, steps*(1.0-init_strength))) if args.steps_strength_adj else int(steps) + generate_request = self.api.generate( + prompts, weights, + args.width, args.height, + steps=adjusted_steps, + seed=seed, + cfg_scale=args.cfg_scale, + sampler=sampler, + init_image=init_image if init_image_ops is None else None, + init_strength=init_strength, + init_noise_scale=noise_scale, + init_depth=init_depth, + mask = self.inpaint_mask if do_inpainting else self.mask, + masked_area_init=generation.MASKED_AREA_INIT_ORIGINAL, + guidance_preset=guidance, + preset=args.preset, + return_request=True + ) + image = self.api.transform_and_generate(init_image, init_image_ops, generate_request) + + if args.color_coherence != 'None' and frame_idx == 0: + self.color_match_images[0] = image + if not len(self.prior_frames): + self.prior_frames.append(image) + self.prior_diffused.append(image) + self.prior_xforms.append(matrix.identity) + + self.prior_frames.append(image) + self.prior_diffused.append(image) + self.prior_xforms.append(matrix.identity) + self.diffusion_cadence_ofs = frame_idx + out_frame = image if not self.cadence_on else self.prior_frames[0] + else: + assert self.cadence_on + # smoothly blend between prior frames + tween = ((frame_idx - self.diffusion_cadence_ofs) % diffusion_cadence) / float(diffusion_cadence) + out_frame = self.api.interpolate( + [self.prior_frames[0], self.prior_frames[1]], + [tween], + interpolate_mode_from_string(args.cadence_interp) + )[0] + + # save and return final frame + yield self.emit_frame(frame_idx, out_frame) + + if not args.locked_seed: + seed += 1 + + def save_settings(self, filename: str): + settings_filepath = os.path.join(self.out_dir, filename) if self.out_dir else filename + with open(settings_filepath, "w", encoding="utf-8") as f: + save_dict = args_to_dict(self.args) + for k in ['angle', 'zoom', 'translation_x', 'translation_y', 'translation_z', 'rotation_x', 'rotation_y', 'rotation_z']: + save_dict.move_to_end(k, last=True) + save_dict['animation_prompts'] = self.animation_prompts + save_dict['negative_prompt'] = self.negative_prompt + save_dict['negative_prompt_weight'] = self.negative_prompt_weight + json.dump(save_dict, f, ensure_ascii=False, indent=4) + + def save_to_out_dir(self, frame_idx: int, image: Image.Image, prefix: str = "frame"): + if self.out_dir is not None: + image.save(self.get_frame_filename(frame_idx, prefix=prefix)) + + def set_mask(self, mask: Image.Image): + self.mask = mask.convert('L').resize((self.args.width, self.args.height), resample=Image.LANCZOS) + + # this is intentionally flipped because we want white in the mask to represent + # areas that should change which is opposite from the backend which treats + # the mask as per pixel offset in the schedule starting value + if not self.args.mask_invert: + self.mask = ImageOps.invert(self.mask) + + def set_cadence_mode(self, enabled: bool): + def set_queue_size(prior_queue: deque, prev_length: int, new_length: int) -> deque: + assert new_length in (1, 2) + if new_length == prev_length: + return prior_queue + new_queue: deque = deque([], new_length) 
+ if len(prior_queue) > 0: + if new_length == 2 and prev_length == 1: + new_queue.extend([prior_queue[0], prior_queue[0]]) + elif new_length == 1 and prev_length == 2: + new_queue.append(prior_queue[-1]) + return new_queue + + if enabled == self.cadence_on: + return + elif enabled: + self.prior_frames = set_queue_size(self.prior_frames, 1, 2) + self.prior_diffused = set_queue_size(self.prior_diffused, 1, 2) + self.prior_xforms = set_queue_size(self.prior_xforms, 1, 2) + else: + self.prior_frames = set_queue_size(self.prior_frames, 2, 1) + self.prior_diffused = set_queue_size(self.prior_diffused, 2, 1) + self.prior_xforms = set_queue_size(self.prior_xforms, 2, 1) + self.cadence_on = enabled + + def setup_animation(self, resume): + args = self.args + + # change request for random seed into explicit value so it is saved to settings + if args.seed <= 0: + args.seed = random.randint(0, 2**32 - 1) + + # select image generation model + self.api._generate.engine_id = args.custom_model if args.model == "custom" else args.model + + # validate border settings + if args.border == 'wrap' and args.animation_mode != '2D': + args.border = 'reflect' + logger.warning(f"Border 'wrap' is only supported in 2D mode, switching to '{args.border}'.") + + # validate clip guidance setting against selected model and sampler + if args.clip_guidance.lower() != 'none': + if not (model_supports_clip_guidance(args.model) and sampler_supports_clip_guidance(args.sampler)): + unsupported = args.model if not model_supports_clip_guidance(args.model) else args.sampler + logger.warning(f"CLIP guidance is not supported by {unsupported}, disabling guidance.") + args.clip_guidance = 'None' + + def curve_to_series(curve: str) -> List[float]: + return curve_from_cn_string(curve) + + # expand key frame strings to per frame series + frame_args_dict = {f.name: curve_to_series(getattr(args, f.name)) for f in fields(FrameArgs)} + self.frame_args = FrameArgs(**frame_args_dict) + + # prepare sorted list of key frames + self.key_frame_values = sorted(list(self.animation_prompts.keys())) + if self.key_frame_values[0] != 0: + raise ValueError("First keyframe must be 0") + if len(self.key_frame_values) != len(set(self.key_frame_values)): + raise ValueError("Duplicate keyframes are not allowed!") + + diffusion_cadence = max(1, int(self.frame_args.diffusion_cadence_curve[self.start_frame_idx])) + # initialize accumulated transforms + self.set_cadence_mode(enabled=(diffusion_cadence > 1)) + self.prior_xforms.extend([matrix.identity, matrix.identity]) + + # prepare inputs + self.load_mask() + self.load_video() + self.load_init_image() + + # handle resuming animation from last frames of a previous run + if resume: + if not self.out_dir: + raise ValueError("Cannot resume animation without out_dir specified") + frames = [f for f in os.listdir(self.out_dir) if f.endswith(".png") and f.startswith("frame_")] + self.start_frame_idx = len(frames) + self.diffusion_cadence_ofs = self.start_frame_idx + if self.start_frame_idx > 2: + prev = Image.open(self.get_frame_filename(self.start_frame_idx-2)) + next = Image.open(self.get_frame_filename(self.start_frame_idx-1)) + self.prior_frames.extend([prev, next]) + self.prior_diffused.extend([prev, next]) + elif self.start_frame_idx > 1 and not self.cadence_on: + prev = Image.open(self.get_frame_filename(self.start_frame_idx-1)) + self.prior_frames.append(prev) + self.prior_diffused.append(prev) + + def transform_2d(self, frame_idx) -> Optional[Image.Image]: + if not len(self.prior_frames): + return None + + # 
create xform for the current frame + xform = self.build_frame_xform(frame_idx) + + # check if we can skip transform request + if np.allclose(xform, matrix.identity): + return None + + args = self.args + if not args.inpaint_border: + # apply xform to prior frames running xforms + for i in range(len(self.prior_xforms)): + self.prior_xforms[i] = matrix.multiply(xform, self.prior_xforms[i]) + + # warp prior diffused frames by accumulated xforms + for i in range(len(self.prior_diffused)): + params = resample_transform(args.border, to_3x3(self.prior_xforms[i]), export_mask=args.inpaint_border) + xformed, mask = self.api.transform([self.prior_diffused[i]], params) + self.prior_frames[i] = xformed[0] + else: + params = resample_transform(args.border, to_3x3(xform), export_mask=args.inpaint_border) + transformed_prior_frames, mask = self.api.transform(self.prior_frames, params) + self.prior_frames.extend(transformed_prior_frames) + + return mask[0] if isinstance(mask, list) else mask + + def transform_3d(self, frame_idx) -> Optional[Image.Image]: + if not len(self.prior_frames): + return None + + args, frame_args = self.args, self.frame_args + near, far = args.near_plane, args.far_plane + fov = frame_args.fov_curve[frame_idx] + depth_blur = int(frame_args.depth_blur_curve[frame_idx]) + depth_warp = frame_args.depth_warp_curve[frame_idx] + + depth_calc = depth_calc_transform(args.depth_model_weight, depth_blur) + + # create xform for the current frame + world_view = self.build_frame_xform(frame_idx) + projection = matrix.projection_fov(math.radians(fov), 1.0, near, far) + + if False: + # currently disabled. for 3D mode transform accumulation needs additional + # depth map changes to work properly without swimming artifacts + + # apply world_view xform to prior frames running xforms + for i in range(len(self.prior_xforms)): + self.prior_xforms[i] = matrix.multiply(world_view, self.prior_xforms[i]) + + # warp prior diffused frames by accumulated xforms + for i in range(len(self.prior_diffused)): + wvp = matrix.multiply(projection, self.prior_xforms[i]) + resample = resample_transform(args.border, wvp, projection, depth_warp=depth_warp, export_mask=args.inpaint_border) + xformed, mask = self.api.transform_3d([self.prior_diffused[i]], depth_calc, resample) + self.prior_frames[i] = xformed[0] + else: + if args.animation_mode == '3D warp': + wvp = matrix.multiply(projection, world_view) + transform_op = resample_transform(args.border, wvp, projection, depth_warp=depth_warp, export_mask=args.inpaint_border) + else: + transform_op = camera_pose_transform( + world_view, near, far, fov, + args.camera_type, + render_mode=args.render_mode, + do_prefill=not args.use_inpainting_model) + transformed_prior_frames, mask = self.api.transform_3d(self.prior_frames, depth_calc, transform_op) + self.prior_frames.extend(transformed_prior_frames) + return mask[0] if isinstance(mask, list) else mask + + def transform_video(self, frame_idx) -> Optional[Image.Image]: + assert self.video_reader is not None + if not len(self.prior_frames): + return None + + args = self.args + for _ in range(args.extract_nth_frame): + success, video_next_frame = self.video_reader.read() + video_next_frame = cv2_to_pil(video_next_frame) + if success: + video_next_frame = self.image_resize(video_next_frame, 'cover') + mask = None + if args.video_flow_warp and video_next_frame is not None: + # warp_flow is in `extras` and will change in the future + prev_b64 = base64.b64encode(image_to_png_bytes(self.video_prev_frame)).decode('utf-8') + next_b64 = 
base64.b64encode(image_to_png_bytes(video_next_frame)).decode('utf-8') + extras = { "warp_flow": { "prev_frame": prev_b64, "next_frame": next_b64, "export_mask": args.inpaint_border } } + transformed_prior_frames, masks = self.api.transform(self.prior_frames, generation.TransformParameters(), extras=extras) + if masks is not None: + mask = masks[0] + self.prior_frames.extend(transformed_prior_frames) + self.video_prev_frame = video_next_frame + self.color_match_image = video_next_frame + return mask + return None + + def _postprocess_inpainting_mask( + self, + mask: Union[Image.Image, np.ndarray], + mask_pow: Optional[float] = None, + mask_multiplier: Optional[float] = None, + binarize: bool = False, + blur_radius: Optional[int] = None, + min_val: Optional[float] = None + ) -> Image.Image: + # Being applied in 3D render mode. Camera pose transform operation returns a mask which pixel values encode + # how much signal from the previous frame is present there. But a mapping from the signal presence values + # to the optimal per-pixel init strength is unknown, and roughly guessed as a per-pixel power function. + # Leaving mask_pow=1 results in near objects changing to a greater extent than a natural emergence of fine details when approaching an object. + if isinstance(mask, Image.Image): + mask = np.array(mask) + if mask_pow is not None: + mask = (np.power(mask / 255., mask_pow) * 255).astype(np.uint8) + if mask_multiplier is not None: + mask = (mask * mask_multiplier).astype(np.uint8) + if binarize: + mask = np.where(mask > self.args.mask_binarization_thr * 255, 255, 0).astype(np.uint8) + if blur_radius: + kernel_size = blur_radius*2+1 + mask = cv2.erode(mask, np.ones((kernel_size, kernel_size), np.uint8)) + mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0) + if min_val is not None: + mask = mask.clip(255 * min_val, 255).astype(np.uint8) + return Image.fromarray(mask) + + def _render_frame( + self, + frame_idx: int, + seed: int, + init: Optional[Image.Image]=None, + mask: Optional[Image.Image]=None, + strength: Optional[float]=None + ) -> Image.Image: + args = self.args + steps = int(self.frame_args.steps_curve[frame_idx]) + strength = strength if strength is not None else max(0.0, self.frame_args.strength_curve[frame_idx]) + adjusted_steps = int(max(5, steps*(1.0-strength))) if args.steps_strength_adj else int(steps) + + # fetch set of prompts and weights for this frame + prompts, weights = self.get_animation_prompts_weights(frame_idx) + if len(self.negative_prompt) and self.negative_prompt_weight != 0.0: + prompts.append(self.negative_prompt) + weights.append(-abs(self.negative_prompt_weight)) + + init_ops = self.prepare_init_ops(init, frame_idx, seed) + + sampler = sampler_from_string(args.sampler.lower()) + guidance = guidance_from_string(args.clip_guidance) + generate_request = self.api.generate( + prompts, weights, + args.width, args.height, + steps = adjusted_steps, + seed = seed, + cfg_scale = args.cfg_scale, + sampler = sampler, + init_image = init if init_ops is None else None, + init_strength = strength if init is not None else 0.0, + init_noise_scale = self.frame_args.noise_scale_curve[frame_idx], + mask = mask if mask is not None else self.mask, + masked_area_init = generation.MASKED_AREA_INIT_ORIGINAL, + guidance_preset = guidance, + preset = args.preset, + return_request = True + ) + + result_image = self.api.transform_and_generate(init, init_ops, generate_request) + + if args.color_coherence != 'None' and frame_idx == 0: + self.color_match_images[0] = result_image 
+ + return result_image + + def _span_render(self, start: int, end: int, prev_frame: Image.Image, next_seed: Callable[[], int]) -> Generator[Tuple[int, Image.Image], None, None]: + args = self.args + + def apply_xform(frame: Image.Image, xform: matrix.Matrix, frame_idx: int) -> Tuple[Image.Image, Image.Image]: + args, frame_args = self.args, self.frame_args + if args.animation_mode == '2D': + xform = to_3x3(xform) + frames, masks = self.api.transform([frame], resample_transform(args.border, xform, export_mask=True)) + else: + fov = frame_args.fov_curve[frame_idx] + depth_blur = int(frame_args.depth_blur_curve[frame_idx]) + depth_warp = frame_args.depth_warp_curve[frame_idx] + projection = matrix.projection_fov(math.radians(fov), 1.0, args.near_plane, args.far_plane) + wvp = matrix.multiply(projection, xform) + depth_calc = depth_calc_transform(args.depth_model_weight, depth_blur) + resample = resample_transform(args.border, wvp, projection, depth_warp=depth_warp, export_mask=True) + frames, masks = self.api.transform_3d([frame], depth_calc, resample) + masks = cast(List[Image.Image], masks) + return frames[0], masks[0] + + # transform the previous frame forward + accum_xform = matrix.identity + forward_frames, forward_masks = [], [] + for frame_idx in range(start, end): + accum_xform = matrix.multiply(self.build_frame_xform(frame_idx), accum_xform) + frame, mask = apply_xform(prev_frame, accum_xform, frame_idx) + forward_frames.append(frame) + forward_masks.append(mask) + + # inpaint the final frame + if not np.all(forward_masks[-1]): + forward_frames[-1] = self.inpaint_frame( + end-1, forward_frames[-1], forward_masks[-1], + mask_blur_radius=0, seed=next_seed()) + + # run diffusion on top of the final result to allow content to evolve over time + strength = max(0.0, self.frame_args.strength_curve[end-1]) + if strength < 1.0: + final_frame = self._render_frame(end-1, next_seed(), forward_frames[-1]) + else: + final_frame = forward_frames[-1] + + # go backwards through the frames in the span + backward_frames, backward_masks = [final_frame], [Image.new('L', forward_masks[-1].size, 255)] + accum_xform = matrix.identity + for frame_idx in range(end-2, start-1, -1): + frame_xform = self.build_frame_xform(frame_idx+1) + inv_xform = np.linalg.inv(frame_xform).tolist() + accum_xform = matrix.multiply(inv_xform, accum_xform) + xformed, mask = apply_xform(backward_frames[-1], accum_xform, frame_idx) + backward_frames.insert(0, xformed) + backward_masks.insert(0, mask) + + # inpaint the backwards frame + if not np.all(backward_masks[0]): + backward_frames[0] = self.inpaint_frame( + start, backward_frames[0], backward_masks[0], + mask_blur_radius=0, seed=next_seed()) + + # yield the final frames blending from forward to backward + for idx, (frame_fwd, frame_bwd) in enumerate(zip(forward_frames, backward_frames)): + t = (idx) / max(1, end-start-1) + fwd_fill = image_mix(frame_bwd, frame_fwd, mask_erode_blur(forward_masks[idx], 8, 8)) + bwd_fill = image_mix(frame_fwd, frame_bwd, mask_erode_blur(backward_masks[idx], 8, 8)) + blended = self.api.interpolate( + [fwd_fill, bwd_fill], + [t], + interpolate_mode_from_string(args.cadence_interp) + )[0] + yield start+idx, blended + + def _spans_render(self) -> Generator[Tuple[int, Image.Image], None, None]: + frame_idx = self.start_frame_idx + seed = self.args.seed + def next_seed() -> int: + nonlocal seed + if not self.args.locked_seed: + seed += 1 + return seed + + prev_frame = self._render_frame(frame_idx, seed, None) + yield frame_idx, prev_frame + + while 
frame_idx < self.args.max_frames: + # determine how many frames the span will process together + diffusion_cadence = max(1, int(self.frame_args.diffusion_cadence_curve[frame_idx])) + if frame_idx + diffusion_cadence > self.args.max_frames: + diffusion_cadence = self.args.max_frames - frame_idx + + # render all frames in the span + for idx, frame in self._span_render(frame_idx, frame_idx + diffusion_cadence, prev_frame, next_seed): + yield idx, frame + prev_frame = frame + next_seed() + + frame_idx += diffusion_cadence diff --git a/src/stability_sdk/animation_ui.py b/src/stability_sdk/animation_ui.py new file mode 100644 index 00000000..88a2010a --- /dev/null +++ b/src/stability_sdk/animation_ui.py @@ -0,0 +1,835 @@ +import glob +import json +import locale +import os +import param +import shutil +import traceback + +from collections import OrderedDict +from PIL import Image +from tqdm import tqdm +from typing import Any, Dict, List, Optional + +try: + import gradio as gr +except ImportError: + raise ImportError( + "Failed to import animation UI requirements. To use the animation UI, install the dependencies with:\n" + " pip install --upgrade stability_sdk[anim_ui]" + ) + +from .api import ( + ClassifierException, + Context, + OutOfCreditsException, +) +from .animation import ( + AnimationArgs, + Animator, + AnimationSettings, + BasicSettings, + CameraSettings, + CoherenceSettings, + ColorSettings, + DepthSettings, + InpaintingSettings, + Rendering3dSettings, + VideoInputSettings, + VideoOutputSettings, + interpolate_frames +) +from .utils import ( + create_video_from_frames, + extract_frames_from_video, + interpolate_mode_from_string +) + + +DATA_VERSION = "0.1" +DATA_GENERATOR = "stability_sdk.animation_ui" + +PRESETS = { + "Default": {}, + "3D warp rotate": { + "animation_mode": "3D warp", "rotation_y":"0:(0.4)", "translation_x":"0:(-1.2)", "depth_model_weight":1.0, + "animation_prompts": "{\n0:\"a flower vase on a table\"\n}" + }, + "3D warp zoom": { + "animation_mode":"3D warp", "diffusion_cadence_curve":"0:(4)", "noise_scale_curve":"0:(1.04)", + "strength_curve":"0:(0.7)", "translation_z":"0:(1.0)", + }, + "3D render rotate": { + "animation_mode": "3D render", "depth_model_weight":1.0, + "translation_x":"0:(-3.5)", "rotation_y":"0:(1.7)", "translation_z":"0:(-0.5)", + "diffusion_cadence_curve":"0:(1)", "strength_curve":"0:(0.96)", "noise_scale_curve":"0:(1.01)", + "mask_min_value":"0:(0.35)", "use_inpainting_model":False, "preset": "anime", + "animation_prompts": "{\n0:\"beautiful portrait of a ninja in a sunflower field\"\n}" + }, + "3D render explore": { + "animation_mode": "3D render", "translation_z":"0:(10)", "translation_x":"0:(2), 20:(-2), 40:(2)", + "rotation_y":"0:(0), 10:(1.5), 30:(-2), 50: (3)", "rotation_x":"0:(0.4)", + "diffusion_cadence_curve":"0:(1)", "strength_curve":"0:(0.98)", + "noise_scale_curve":"0:(1.01)", "depth_model_weight":1.0, + "mask_min_value":"0:(0.1)", "use_inpainting_model":False, "preset":"3d-model", + "animation_prompts": "{\n0:\"Phantasmagoric carnival, carnival attractions shifting and changing, bizarre surreal circus\"\n}" + }, + "Prompt interpolate": { + "animation_mode":"2D", "interpolate_prompts":True, "locked_seed":True, "max_frames":24, + "strength_curve":"0:(0)", "diffusion_cadence_curve":"0:(4)", "cadence_interp":"film", + "clip_guidance":"None", "animation_prompts": "{\n0:\"a photo of a cute cat\",\n24:\"a photo of a cute dog\"\n}" + }, + "Translate and inpaint": { + "animation_mode":"2D", "inpaint_border":True, "use_inpainting_model":False, 
"translation_x":"0:(-20)", + "diffusion_cadence_curve":"0:(3)", "strength_curve":"0:(0.85)", "noise_scale_curve":"0:(1.01)", "border":"reflect", + "animation_prompts": "{\n0:\"Mystical pumpkin field landscapes on starry Halloween night, pop surrealism art\"\n}" + }, + "Outpaint": { + "animation_mode":"2D", "diffusion_cadence_curve":"0:(16)", "cadence_spans":True, "use_inpainting_model":True, + "strength_curve":"0:(1)", "reverse":True, "preset": "fantasy-art", "inpaint_border":True, "zoom":"0:(0.95)", + "animation_prompts": "{\n0:\"an ancient and magical portal, in a fantasy corridor\"\n}" + }, + "Video Stylize": { + "animation_mode":"Video Input", "model":"stable-diffusion-depth-v2-0", "locked_seed":True, + "strength_curve":"0:(0.22)", "clip_guidance":"None", "video_mix_in_curve":"0:(1.0)", "video_flow_warp":True, + }, +} + +class Project(): + def __init__(self, title, settings={}) -> None: + self.folder = title.replace("/", "_").replace("\\", "_").replace(":", "") + self.settings = settings + self.title = title + + @classmethod + def list_projects(cls) -> List["Project"]: + projects = [] + for path in os.listdir(outputs_path): + directory = os.path.join(outputs_path, path) + if not os.path.isdir(directory): + continue + + json_files = glob.glob(os.path.join(directory, '*.json')) + json_files = sorted(json_files, key=lambda x: os.stat(x).st_mtime) + if not json_files: + continue + + filename = os.path.basename(json_files[-1]) + if not '(' in filename: + continue + + project = cls(filename[:filename.rfind('(')-1].strip()) + try: + project.settings = json.load(open(os.path.join(directory, filename), 'r')) + except: + continue + projects.append(project) + return projects + + +context = None +outputs_path = None + +args_generation = BasicSettings() +args_animation = AnimationSettings() +args_camera = CameraSettings() +args_coherence = CoherenceSettings() +args_color = ColorSettings() +args_depth = DepthSettings() +args_render_3d = Rendering3dSettings() +args_inpaint = InpaintingSettings() +args_vid_in = VideoInputSettings() +args_vid_out = VideoOutputSettings() +arg_objs = ( + args_generation, + args_animation, + args_camera, + args_coherence, + args_color, + args_depth, + args_render_3d, + args_inpaint, + args_vid_in, + args_vid_out, +) + +animation_prompts = "{\n0: \"\"\n}" +negative_prompt = "blurry, low resolution" +negative_prompt_weight = -1.0 + +controls: Dict[str, gr.components.Component] = {} +header = gr.HTML("", show_progress=False) +interrupt = False +last_interp_factor = None +last_interp_mode = None +last_project_settings_path = None +last_upscale = None +projects: List[Project] = [] +project: Optional[Project] = None +resume_checkbox = gr.Checkbox(label="Resume", value=False, interactive=True) +resume_from_number = gr.Number(label="Resume from frame", value=-1, interactive=True, precision=0, + info="Positive frame number to resume from, or -1 to resume from the last") + +project_create_button = gr.Button("Create") +project_data_log = gr.Textbox(label="Status", visible=False) +project_load_button = gr.Button("Load") +project_new_title = gr.Text(label="Name", value="My amazing animation", interactive=True) +project_preset_dropdown = gr.Dropdown(label="Preset", choices=list(PRESETS.keys()), value=list(PRESETS.keys())[0], interactive=True) +project_row_create = None +project_row_import = None +project_row_load = None +projects_dropdown = gr.Dropdown([p.title for p in projects], label="Project", visible=True, interactive=True) + +project_import_button = gr.Button("Import") 
+project_import_file = gr.File(label="Project file", file_types=[".json", ".txt"], type="binary") +project_import_title = gr.Text(label="Name", value="Imported project", interactive=True) + + +def accordion_for_color(args: ColorSettings): + p = args.param + with gr.Accordion("Color", open=False): + controls["color_coherence"] = gr.Dropdown(label="Color coherence", choices=p.color_coherence.objects, value=p.color_coherence.default, interactive=True) + with gr.Row(): + controls["brightness_curve"] = gr.Text(label="Brightness curve", value=p.brightness_curve.default, interactive=True) + controls["contrast_curve"] = gr.Text(label="Contrast curve", value=p.contrast_curve.default, interactive=True) + with gr.Row(): + controls["hue_curve"] = gr.Text(label="Hue curve", value=p.hue_curve.default, interactive=True) + controls["saturation_curve"] = gr.Text(label="Saturation curve", value=p.saturation_curve.default, interactive=True) + controls["lightness_curve"] = gr.Text(label="Lightness curve", value=p.lightness_curve.default, interactive=True) + controls["color_match_animate"] = gr.Checkbox(label="Animated color match", value=p.color_match_animate.default, interactive=True) + +def accordion_from_args(name: str, args: param.Parameterized, exclude: List[str]=[], open=False): + with gr.Accordion(name, open=open): + ui_from_args(args, exclude) + +def args_reset_to_defaults(): + for args in arg_objs: + for k, v in args.param.objects().items(): + if k == "name": + continue + setattr(args, k, v.default) + +def args_to_controls(data: Optional[dict]=None) -> dict: + # go through all the parameters and load their settings from the data + global animation_prompts, negative_prompt + if data: + for arg in arg_objs: + for k, v in arg.param.objects().items(): + if k != "name" and k in data: + arg.param.set_param(k, data[k]) + if "animation_prompts" in data: + animation_prompts = data["animation_prompts"] + if "negative_prompt" in data: + negative_prompt = data["negative_prompt"] + + returns = {} + returns[controls['animation_prompts']] = gr.update(value=animation_prompts) + returns[controls['negative_prompt']] = gr.update(value=negative_prompt) + + for args in arg_objs: + for k, v in args.param.objects().items(): + if k in controls: + c = controls[k] + returns[c] = gr.update(value=getattr(args, k)) + + return returns + +def ensure_api_context(): + if context is None: + raise gr.Error("Not connected to Stability API") + +def format_header_html() -> str: + try: + balance, profile_picture = context.get_user_info() + except: + return "" + formatted_number = locale.format_string("%d", balance, grouping=True) + return f""" +
+        <div style="display: flex; justify-content: space-between; align-items: center;">
+            <h3>Stable Animation UI</h3>
+            <div style="display: flex; align-items: center;">
+                <span>{formatted_number}</span>
+                <img src="{profile_picture}" alt="user avatar" style="width: 32px; height: 32px; border-radius: 50%; margin-left: 8px;"/>
+            </div>
+        </div>
+ """ + +def get_default_project(): + data = OrderedDict(AnimationArgs().param.values()) + data.update({ + "version": DATA_VERSION, + "generator": DATA_GENERATOR + }) + return data + +def post_process_tab(): + with gr.Row(): + with gr.Column(): + with gr.Row(visible=False): + use_video_instead = gr.Checkbox(label="Postprocess a video instead", value=False, interactive=True) + video_to_postprocess = gr.Text(label="Videofile to postprocess", value="", interactive=True) + fps = gr.Number(label="Output FPS", value=24, interactive=True, precision=0) + reverse = gr.Checkbox(label="Reverse", value=False, interactive=True) + with gr.Row(): + frame_interp_mode = gr.Dropdown(label="Frame interpolation mode", choices=['None', 'film', 'rife'], value='None', interactive=True) + frame_interp_factor = gr.Dropdown(label="Frame interpolation factor", choices=[2, 4, 8], value=2, interactive=True) + with gr.Row(): + upscale = gr.Checkbox(label="Upscale 2X", value=False, interactive=True) + with gr.Column(): + image_out = gr.Image(label="image", visible=True) + video_out = gr.Video(label="video", visible=False) + process_button = gr.Button("Process") + stop_button = gr.Button("Stop", visible=False) + error_log = gr.Textbox(label="Error", lines=3, visible=False) + + def postprocess_video(fps: int, reverse: bool, interp_mode: str, interp_factor: int, upscale: bool, + use_video_instead: bool, video_to_postprocess: str): + global interrupt, last_interp_factor, last_interp_mode, last_upscale + interrupt = False + if not use_video_instead and last_project_settings_path is None: + raise gr.Error("Please render an animation first or specify a videofile to postprocess") + if use_video_instead and not os.path.exists(video_to_postprocess): + raise gr.Error("Videofile does not exist") + + yield { + header: gr.update(), + image_out: gr.update(visible=True, label=""), + video_out: gr.update(visible=False), + process_button: gr.update(visible=False), + stop_button: gr.update(visible=True), + error_log: gr.update(visible=False), + } + + error = None + try: + outdir = os.path.dirname(last_project_settings_path) \ + if not use_video_instead \ + else extract_frames_from_video(video_to_postprocess) + suffix = "" + + can_skip_upscale = last_upscale == upscale + can_skip_interp = can_skip_upscale and last_interp_factor == interp_factor and last_interp_mode == interp_mode + + if upscale: + suffix += "_x2" + upscale_dir = os.path.join(outdir, "upscale") + os.makedirs(upscale_dir, exist_ok=True) + frame_paths = sorted(glob.glob(os.path.join(outdir, "frame_*.png"))) + num_frames = len(frame_paths) + if not can_skip_upscale: + remove_frames_from_path(upscale_dir) + for frame_idx in tqdm(range(num_frames)): + frame = Image.open(frame_paths[frame_idx]) + frame = context.upscale(frame) + frame.save(os.path.join(upscale_dir, os.path.basename(frame_paths[frame_idx]))) + yield { + header: gr.update(value=format_header_html()) if frame_idx % 12 == 0 else gr.update(), + image_out: gr.update(value=frame, label=f"upscale {frame_idx}/{num_frames}", visible=True), + video_out: gr.update(visible=False), + process_button: gr.update(visible=False), + stop_button: gr.update(visible=True), + error_log: gr.update(visible=False), + } + if interrupt: + break + last_upscale = upscale + outdir = upscale_dir + + if interp_mode != 'None': + suffix += f"_{interp_mode}{interp_factor}" + interp_dir = os.path.join(outdir, "interpolate") + interp_mode = interpolate_mode_from_string(interp_mode) + if not can_skip_interp: + remove_frames_from_path(interp_dir) + 
num_frames = interp_factor * len(glob.glob(os.path.join(outdir, "frame_*.png"))) + for frame_idx, frame in enumerate(tqdm(interpolate_frames(context, outdir, interp_dir, interp_mode, interp_factor), total=num_frames)): + yield { + header: gr.update(value=format_header_html()) if frame_idx % 12 == 0 else gr.update(), + image_out: gr.update(value=frame, label=f"interpolate {frame_idx}/{num_frames}", visible=True), + video_out: gr.update(visible=False), + process_button: gr.update(visible=False), + stop_button: gr.update(visible=True), + error_log: gr.update(visible=False), + } + if interrupt: + break + last_interp_mode, last_interp_factor = interp_mode, interp_factor + outdir = interp_dir + + if not use_video_instead: + output_video = last_project_settings_path.replace(".json", f"{suffix}.mp4") + else: + _, video_ext = os.path.splitext(video_to_postprocess) + output_video = video_to_postprocess.replace(video_ext, f"{suffix}.mp4") + create_video_from_frames(outdir, output_video, fps=fps, reverse=reverse) + except Exception as e: + traceback.print_exc() + error = f"Post-processing terminated early due to exception: {e}" + + yield { + header: gr.update(value=format_header_html()), + image_out: gr.update(visible=False), + video_out: gr.update(value=output_video, visible=True), + process_button: gr.update(visible=True), + stop_button: gr.update(visible=False), + error_log: gr.update(value=error, visible=bool(error)) + } + + process_button.click( + postprocess_video, + inputs=[fps, reverse, frame_interp_mode, frame_interp_factor, upscale, use_video_instead, video_to_postprocess], + outputs=[header, image_out, video_out, process_button, stop_button, error_log] + ) + + def stop_button_click(): + global interrupt + interrupt = True + stop_button.click(stop_button_click) + + +def project_create(title, preset): + ensure_api_context() + global project, projects + titles = [p.title for p in projects] + if title in titles: + raise gr.Error(f"Project with title '{title}' already exists") + project = Project(title, get_default_project()) + projects.append(project) + projects = sorted(projects, key=lambda p: p.title) + + # grab each setting from the preset and add to settings + for k, v in PRESETS[preset].items(): + project.settings[k] = v + + log = f"Created project '{title}'" + + args_reset_to_defaults() + returns = args_to_controls(project.settings) + returns[project_data_log] = gr.update(value=log, visible=True) + returns[projects_dropdown] = gr.update(choices=[p.title for p in projects], visible=True, value=title) + returns[project_row_load] = gr.update(visible=len(projects) > 0) + return returns + +def project_import(title, file): + ensure_api_context() + global project, projects + titles = [p.title for p in projects] + if title in titles: + raise gr.Error(f"Project with title '{title}' already exists") + + # read json from file + try: + settings = json.loads(file.decode('utf-8')) + except Exception as e: + raise gr.Error(f"Failed to read settings from file: {e}") + + project = Project(title, settings) + projects.append(project) + projects = sorted(projects, key=lambda p: p.title) + + log = f"Imported project '{title}'" + + args_reset_to_defaults() + returns = args_to_controls(project.settings) + returns[project_data_log] = gr.update(value=log, visible=True) + returns[projects_dropdown] = gr.update(choices=[p.title for p in projects], visible=True, value=title) + returns[project_row_load] = gr.update(visible=len(projects) > 0) + return returns + +def project_load(title: str): + ensure_api_context() + 
global project + project = next(p for p in projects if p.title == title) + data = project.settings + + log = f"Loaded project '{title}'" + + # filter project file to latest version + if "animation_mode" in data and data["animation_mode"] == "3D": + data["animation_mode"] = "3D warp" + if "midas_weight" in data: + data["depth_model_weight"] = data["midas_weight"] + del data["midas_weight"] + + # update the ui controls + returns = args_to_controls(data) + returns[project_data_log] = gr.update(value=log, visible=True) + return returns + +def project_tab(): + global project_row_create, project_row_import, project_row_load + + button_load_projects = gr.Button("Load Projects") + with gr.Accordion("Load a project", open=True, visible=False) as projects_row_: + project_row_load = projects_row_ + with gr.Row(): + projects_dropdown.render() + with gr.Column(): + project_load_button.render() + with gr.Row(): + delete_btn = gr.Button("Delete") + confirm_btn = gr.Button("Confirm delete", variant="stop", visible=False) + cancel_btn = gr.Button("Cancel", visible=False) + delete_btn.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)], None, [delete_btn, confirm_btn, cancel_btn]) + cancel_btn.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)], None, [delete_btn, confirm_btn, cancel_btn]) + + with gr.Accordion("Create a new project", open=True, visible=False) as project_row_create_: + project_row_create = project_row_create_ + with gr.Column(): + with gr.Row(): + project_new_title.render() + project_preset_dropdown.render() + with gr.Column(): + project_create_button.render() + + with gr.Accordion("Import a project file", open=False, visible=False) as project_row_import_: + project_row_import = project_row_import_ + with gr.Column(): + with gr.Row(): + project_import_title.render() + project_import_file.render() + with gr.Column(): + project_import_button.render() + + project_data_log.render() + + def delete_project(title: str): + ensure_api_context() + global project, projects + + project = next(p for p in projects if p.title == title) + project_path = os.path.join(outputs_path, project.folder) + if os.path.exists(project_path): + shutil.rmtree(project_path) + + projects.remove(project) + project = None + + log = f"Deleted project \"{title}\" at \"{project_path}\"" + return { + projects_dropdown: gr.update(choices=[p.title for p in projects], visible=True), + project_row_load: gr.update(visible=len(projects) > 0), + project_data_log: gr.update(value=log, visible=True), + delete_btn: gr.update(visible=True), + confirm_btn: gr.update(visible=False), + cancel_btn: gr.update(visible=False) + } + + def load_projects(): + ensure_api_context() + global projects + projects = Project.list_projects() + return { + button_load_projects: gr.update(visible=False), + projects_dropdown: gr.update(choices=[p.title for p in projects], visible=True), + project_row_create: gr.update(visible=True), + project_row_import: gr.update(visible=True), + project_row_load: gr.update(visible=len(projects) > 0), + header: gr.update(value=format_header_html()) + } + + button_load_projects.click(load_projects, outputs=[button_load_projects, projects_dropdown, project_row_create, project_row_import, project_row_load, header]) + confirm_btn.click(delete_project, inputs=projects_dropdown, outputs=[projects_dropdown, project_row_load, project_data_log, delete_btn, confirm_btn, cancel_btn]) + +def remove_frames_from_path(path: str, leave_first: Optional[int]=None): 
+ if os.path.isdir(path): + frames = sorted(glob.glob(os.path.join(path, "frame_*.png"))) + if leave_first: + frames = frames[leave_first:] + for f in frames: + os.remove(f) + +def render_tab(): + with gr.Row(): + with gr.Column(): + ui_layout_tabs() + with gr.Column(): + image_out = gr.Image(label="image", visible=True) + video_out = gr.Video(label="video", visible=False) + button = gr.Button("Render") + button_stop = gr.Button("Stop", visible=False) + error_log = gr.Textbox(label="Error", lines=3, visible=False) + + def render(resume: bool, resume_from: int, *render_args): + global interrupt, last_interp_factor, last_interp_mode, last_project_settings_path, last_upscale, project + interrupt = False + + if not project: + raise gr.Error("No project active!") + + # create local folder for the project + outdir = os.path.join(outputs_path, project.folder) + os.makedirs(outdir, exist_ok=True) + + # each render gets a unique run index + run_index = 0 + while True: + project_settings_path = os.path.join(outdir, f"{project.folder} ({run_index}).json") + if not os.path.exists(project_settings_path): + break + run_index += 1 + + # gather up all the settings from sub-objects + args_d = {k: v for k, v in zip(controls.keys(), render_args)} + animation_prompts, negative_prompt = args_d['animation_prompts'], args_d['negative_prompt'] + del args_d['animation_prompts'], args_d['negative_prompt'] + args = AnimationArgs(**args_d) + + if args.animation_mode == "Video Input" and not args.video_init_path: + raise gr.Error("No video input file selected!") + + # convert animation_prompts from string (JSON or python) to dict + try: + prompts = json.loads(animation_prompts) + except json.JSONDecodeError: + try: + prompts = eval(animation_prompts) + except Exception as e: + raise gr.Error("Invalid JSON or Python code for animation_prompts!") + prompts = {int(k): v for k, v in prompts.items()} + + # save settings to a dict + save_dict = OrderedDict() + save_dict['version'] = DATA_VERSION + save_dict['generator'] = DATA_GENERATOR + save_dict.update(args.param.values()) + save_dict['animation_prompts'] = animation_prompts + save_dict['negative_prompt'] = negative_prompt + project.settings = save_dict + with open(project_settings_path, 'w', encoding='utf-8') as f: + json.dump(save_dict, f, indent=4) + + # initial yield to switch render button to stop button + yield { + button: gr.update(visible=False), + button_stop: gr.update(visible=True), + image_out: gr.update(visible=True, label=""), + video_out: gr.update(visible=False), + header: gr.update(), + error_log: gr.update(visible=False), + } + + # delete frames from previous animation + if resume: + if resume_from > 0: + remove_frames_from_path(outdir, resume_from) + elif resume_from == 0 or resume_from < -1: + raise gr.Error("Frame number to resume from must be positive, or -1 to resume from the last frame") + else: + remove_frames_from_path(outdir) + + frame_idx, error = 0, None + try: + animator = Animator( + api_context=context, + animation_prompts=prompts, + args=args, + out_dir=outdir, + negative_prompt=negative_prompt, + negative_prompt_weight=negative_prompt_weight, + resume=resume, + ) + for frame_idx, frame in enumerate(tqdm(animator.render(), initial=animator.start_frame_idx, total=args.max_frames), start=animator.start_frame_idx): + if interrupt: + break + + # saving frames to project + #frame_uuid = project.put_image_asset(frame) + + yield { + button: gr.update(visible=False), + button_stop: gr.update(visible=True), + image_out: gr.update(value=frame, 
label=f"frame {frame_idx}/{args.max_frames}", visible=True), + video_out: gr.update(visible=False), + header: gr.update(value=format_header_html()) if frame_idx % 12 == 0 else gr.update(), + error_log: gr.update(visible=False), + } + except ClassifierException as e: + error = "Animation terminated early due to NSFW classifier." + if e.prompt is not None: + error += "\nPlease revise your prompt: " + e.prompt + except OutOfCreditsException as e: + error = f"Animation terminated early, out of credits.\n{e.details}" + except Exception as e: + traceback.print_exc() + error = f"Animation terminated early due to exception: {e}" + + if frame_idx: + last_project_settings_path = project_settings_path + last_interp_factor, last_interp_mode, last_upscale = None, None, None + output_video = project_settings_path.replace(".json", ".mp4") + try: + create_video_from_frames(outdir, output_video, fps=args.fps, reverse=args.reverse) + except RuntimeError as e: + error = f"Error creating video: {e}" + output_video = None + else: + output_video = None + yield { + button: gr.update(visible=True), + button_stop: gr.update(visible=False), + image_out: gr.update(visible=False), + video_out: gr.update(value=output_video, visible=True), + header: gr.update(value=format_header_html()), + error_log: gr.update(value=error, visible=bool(error)), + } + + button.click( + render, + inputs=[resume_checkbox, resume_from_number] + list(controls.values()), + outputs=[button, button_stop, image_out, video_out, header, error_log] + ) + + # stop animation in progress + def stop(): + global interrupt + interrupt = True + button_stop.click(stop) + +def ui_for_animation_settings(args: AnimationSettings): + with gr.Row(): + controls["steps_strength_adj"] = gr.Checkbox(label="Steps strength adj", value=args.param.steps_strength_adj.default, interactive=True) + controls["interpolate_prompts"] = gr.Checkbox(label="Interpolate prompts", value=args.param.interpolate_prompts.default, interactive=True) + controls["locked_seed"] = gr.Checkbox(label="Locked seed", value=args.param.locked_seed.default, interactive=True) + controls["noise_add_curve"] = gr.Text(label="Noise add curve", value=args.param.noise_add_curve.default, interactive=True) + controls["noise_scale_curve"] = gr.Text(label="Noise scale curve", value=args.param.noise_scale_curve.default, interactive=True) + controls["strength_curve"] = gr.Text(label="Previous frame strength curve", value=args.param.strength_curve.default, interactive=True) + controls["steps_curve"] = gr.Text(label="Steps curve", value=args.param.steps_curve.default, interactive=True) + +def ui_for_generation(args: AnimationSettings): + p = args.param + with gr.Row(): + controls["width"] = gr.Number(label="Width", value=p.width.default, interactive=True, precision=0) + controls["height"] = gr.Number(label="Height", value=p.height.default, interactive=True, precision=0) + with gr.Row(): + controls["model"] = gr.Dropdown(label="Model", choices=p.model.objects, value=p.model.default, interactive=True) + controls["custom_model"] = gr.Text(label="Custom model", value=p.custom_model.default, interactive=True) + with gr.Row(): + controls["preset"] = gr.Dropdown(label="Style preset", choices=p.preset.objects, value=p.preset.default, interactive=True) + with gr.Row(): + controls["sampler"] = gr.Dropdown(label="Sampler", choices=p.sampler.objects, value=p.sampler.default, interactive=True) + controls["seed"] = gr.Number(label="Seed", value=p.seed.default, interactive=True, precision=0) + controls["cfg_scale"] = 
gr.Number(label="Guidance scale", value=p.cfg_scale.default, interactive=True) + controls["clip_guidance"] = gr.Dropdown(label="CLIP guidance", choices=p.clip_guidance.objects, value=p.clip_guidance.default, interactive=True) + +def ui_for_init_and_mask(args_generation): + p = args_generation.param + with gr.Row(): + controls["init_image"] = gr.Text(label="Init image", value=p.init_image.default, interactive=True) + controls["init_sizing"] = gr.Dropdown(label="Init sizing", choices=p.init_sizing.objects, value=p.init_sizing.default, interactive=True) + with gr.Row(): + controls["mask_path"] = gr.Text(label="Mask path", value=p.mask_path.default, interactive=True) + controls["mask_invert"] = gr.Checkbox(label="Mask invert", value=p.mask_invert.default, interactive=True) + +def ui_for_video_output(args: VideoOutputSettings): + p = args.param + controls["fps"] = gr.Number(label="FPS", value=p.fps.default, interactive=True, precision=0) + controls["reverse"] = gr.Checkbox(label="Reverse", value=p.reverse.default, interactive=True) + +def ui_from_args(args: param.Parameterized, exclude: List[str]=[]): + for k, v in args.param.objects().items(): + if k == "name" or k in exclude: + continue + if isinstance(v, param.Boolean): + t = gr.Checkbox(label=v.label, value=v.default, interactive=True) + elif isinstance(v, param.Integer): + t = gr.Number(label=v.label, value=v.default, interactive=True, precision=0) + elif isinstance(v, param.Number): + t = gr.Number(label=v.label, value=v.default, interactive=True) + elif isinstance(v, param.Selector): + t = gr.Dropdown(label=v.label, choices=v.objects, value=v.default, interactive=True) + elif isinstance(v, param.String): + t = gr.Text(label=v.label, value=v.default, interactive=True) + else: + raise Exception(f"Unknown parameter type {v} for param {k}") + controls[k] = t + +def ui_layout_tabs(): + with gr.Tab("Prompts"): + with gr.Row(): + controls['animation_prompts'] = gr.TextArea(label="Animation prompts", max_lines=8, value=animation_prompts, interactive=True) + with gr.Row(): + controls['negative_prompt'] = gr.Textbox(label="Negative prompt", max_lines=1, value=negative_prompt, interactive=True) + with gr.Tab("Config"): + with gr.Row(): + args = args_animation + controls["animation_mode"] = gr.Dropdown(label="Animation mode", choices=args.param.animation_mode.objects, value=args.param.animation_mode.default, interactive=True) + controls["max_frames"] = gr.Number(label="Max frames", value=args.param.max_frames.default, interactive=True, precision=0) + controls["border"] = gr.Dropdown(label="Border", choices=args.param.border.objects, value=args.param.border.default, interactive=True) + ui_for_generation(args_generation) + ui_for_animation_settings(args_animation) + accordion_from_args("Coherence", args_coherence, open=False) + accordion_for_color(args_color) + accordion_from_args("Depth", args_depth, exclude=["near_plane", "far_plane"], open=False) + accordion_from_args("3D render", args_render_3d, open=False) + accordion_from_args("Inpainting", args_inpaint, open=False) + with gr.Tab("Input"): + with gr.Row(): + resume_checkbox.render() + resume_from_number.render() + ui_for_init_and_mask(args_generation) + with gr.Column(): + p = args_vid_in.param + with gr.Row(): + controls["video_init_path"] = gr.Text(label="Video init path", value=p.video_init_path.default, interactive=True) + with gr.Row(): + controls["video_mix_in_curve"] = gr.Text(label="Mix in curve", value=p.video_mix_in_curve.default, interactive=True) + controls["extract_nth_frame"] = 
gr.Number(label="Extract nth frame", value=p.extract_nth_frame.default, interactive=True, precision=0) + controls["video_flow_warp"] = gr.Checkbox(label="Flow warp", value=p.video_flow_warp.default, interactive=True) + + with gr.Tab("Camera"): + p = args_camera.param + gr.Markdown("2D Camera") + controls["angle"] = gr.Text(label="Angle", value=p.angle.default, interactive=True) + controls["zoom"] = gr.Text(label="Zoom", value=p.zoom.default, interactive=True) + + gr.Markdown("2D and 3D Camera translation") + controls["translation_x"] = gr.Text(label="Translation X", value=p.translation_x.default, interactive=True) + controls["translation_y"] = gr.Text(label="Translation Y", value=p.translation_y.default, interactive=True) + controls["translation_z"] = gr.Text(label="Translation Z", value=p.translation_z.default, interactive=True) + + gr.Markdown("3D Camera rotation") + controls["rotation_x"] = gr.Text(label="Rotation X", value=p.rotation_x.default, interactive=True) + controls["rotation_y"] = gr.Text(label="Rotation Y", value=p.rotation_y.default, interactive=True) + controls["rotation_z"] = gr.Text(label="Rotation Z", value=p.rotation_z.default, interactive=True) + + with gr.Tab("Output"): + ui_for_video_output(args_vid_out) + + +def create_ui(api_context: Context, outputs_root_path: str): + global context, outputs_path, projects + context, outputs_path = api_context, outputs_root_path + + locale.setlocale(locale.LC_ALL, '') + + with gr.Blocks() as ui: + header.render() + + with gr.Tab("Project"): + project_tab() + + with gr.Tab("Render"): + render_tab() + + with gr.Tab("Post-process"): + post_process_tab() + + load_project_outputs = [project_data_log] + load_project_outputs.extend(controls.values()) + project_load_button.click(project_load, inputs=projects_dropdown, outputs=load_project_outputs) + + create_project_outputs = [project_data_log, projects_dropdown, project_row_load] + create_project_outputs.extend(controls.values()) + project_create_button.click(project_create, inputs=[project_new_title, project_preset_dropdown], outputs=create_project_outputs) + project_import_button.click(project_import, inputs=[project_import_title, project_import_file], outputs=create_project_outputs) + + return ui diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py new file mode 100644 index 00000000..2057961f --- /dev/null +++ b/src/stability_sdk/api.py @@ -0,0 +1,662 @@ +import grpc +import io +import logging +import random +import time + +from google.protobuf.struct_pb2 import Struct +from PIL import Image +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import stability_sdk.interfaces.gooseai.dashboard.dashboard_pb2 as dashboard +import stability_sdk.interfaces.gooseai.dashboard.dashboard_pb2_grpc as dashboard_grpc +import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation +import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc + +from .utils import ( + image_mix, + image_to_prompt, + tensor_to_prompt, +) + + +logger = logging.getLogger(__name__) +logger.setLevel(level=logging.INFO) + + +def open_channel(host: str, api_key: str = None, max_message_len: int = 20*1024*1024) -> grpc.Channel: + options=[ + ('grpc.max_send_message_length', max_message_len), + ('grpc.max_receive_message_length', max_message_len), + ] + if host.endswith(":443"): + call_credentials = [grpc.access_token_call_credentials(api_key)] + channel_credentials = grpc.composite_channel_credentials( + grpc.ssl_channel_credentials(), 
*call_credentials + ) + channel = grpc.secure_channel(host, channel_credentials, options=options) + else: + channel = grpc.insecure_channel(host, options=options) + return channel + + +class ClassifierException(Exception): + """Raised when server classifies generated content as inappropriate. + + Attributes: + classifier_result: Categories the result image exceeded the threshold for + prompt: The prompt that was classified as inappropriate + """ + def __init__(self, classifier_result: Optional[generation.ClassifierParameters]=None, prompt: Optional[str]=None): + self.classifier_result = classifier_result + self.prompt = prompt + +class OutOfCreditsException(Exception): + """Raised when account doesn't have enough credits to perform a request.""" + def __init__(self, details: str): + self.details = details + + +class Endpoint: + def __init__(self, stub, engine_id): + self.stub = stub + self.engine_id = engine_id + + +class Context: + def __init__( + self, + host: str="", + api_key: str=None, + stub: generation_grpc.GenerationServiceStub=None, + generate_engine_id: str="stable-diffusion-xl-beta-v2-2-2", + inpaint_engine_id: str="stable-inpainting-512-v2-0", + interpolate_engine_id: str="interpolation-server-v1", + transform_engine_id: str="transform-server-v1", + upscale_engine_id: str="esrgan-v1-x2plus", + ): + if not host and stub is None: + raise Exception("Must provide either GRPC host or stub to Api") + + channel = open_channel(host, api_key) if host else None + if not stub: + stub = generation_grpc.GenerationServiceStub(channel) + + self._dashboard_stub = dashboard_grpc.DashboardServiceStub(channel) if channel else None + + self._generate = Endpoint(stub, generate_engine_id) + self._inpaint = Endpoint(stub, inpaint_engine_id) + self._interpolate = Endpoint(stub, interpolate_engine_id) + self._transform = Endpoint(stub, transform_engine_id) + self._upscale = Endpoint(stub, upscale_engine_id) + + self._debug_no_chains = False + self._max_retries = 5 # retry request on RPC error + self._request_timeout = 30.0 # timeout in seconds for each request + self._retry_delay = 1.0 # base delay in seconds between retries, each attempt will double + self._retry_obfuscation = False # retry request with different seed on classifier obfuscation + self._retry_schedule_offset = 0.1 # increase schedule start by this amount on each retry after the first + + self._user_organization_id: Optional[str] = None + self._user_profile_picture: str = '' + + def generate( + self, + prompts: List[str], + weights: List[float], + width: int = 512, + height: int = 512, + steps: Optional[int] = None, + seed: Union[Sequence[int], int] = 0, + samples: int = 1, + cfg_scale: float = 7.0, + sampler: generation.DiffusionSampler = None, + init_image: Optional[Image.Image] = None, + init_strength: float = 0.0, + init_noise_scale: Optional[float] = None, + init_depth: Optional[Image.Image] = None, + mask: Optional[Image.Image] = None, + masked_area_init: generation.MaskedAreaInit = generation.MASKED_AREA_INIT_ORIGINAL, + guidance_preset: generation.GuidancePreset = generation.GUIDANCE_PRESET_NONE, + guidance_cuts: int = 0, + guidance_strength: float = 0.0, + preset: Optional[str] = None, + return_request: bool = False, + ) -> Dict[int, List[Any]]: + """ + Generate an image from a set of weighted prompts. 
+ + :param prompts: List of text prompts + :param weights: List of prompt weights + :param width: Width of the generated image + :param height: Height of the generated image + :param steps: Number of steps to run the diffusion process + :param seed: Random seed for the starting noise + :param samples: Number of samples to generate + :param cfg_scale: Classifier free guidance scale + :param sampler: Sampler to use for the diffusion process + :param init_image: Initial image to use + :param init_strength: Strength of the initial image + :param init_noise_scale: Scale of the initial noise + :param mask: Mask to use (0 for pixels to change, 255 for pixels to keep) + :param masked_area_init: How to initialize the masked area + :param guidance_preset: Preset to use for CLIP guidance + :param guidance_cuts: Number of cuts to use with CLIP guidance + :param guidance_strength: Strength of CLIP guidance + :param preset: Style preset to use + :param return_request: Whether to return the request instead of running it + :return: dict mapping artifact type to data + """ + if not prompts and init_image is None: + raise ValueError("prompt and/or init_image must be provided") + + if (mask is not None) and (init_image is None) and not return_request: + raise ValueError("If mask_image is provided, init_image must also be provided") + + p = [generation.Prompt(text=prompt, parameters=generation.PromptParameters(weight=weight)) for prompt,weight in zip(prompts, weights)] + if init_image is not None: + p.append(image_to_prompt(init_image)) + if mask is not None: + p.append(image_to_prompt(mask, type=generation.ARTIFACT_MASK)) + if init_depth is not None: + p.append(image_to_prompt(init_depth, type=generation.ARTIFACT_DEPTH)) + + start_schedule = 1.0 - init_strength + image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale, + start_schedule, init_noise_scale, masked_area_init, + guidance_preset, guidance_cuts, guidance_strength) + + extras = Struct() + if preset and preset.lower() != 'none': + extras.update({ '$IPC': { "preset": preset } }) + + request = generation.Request(engine_id=self._generate.engine_id, prompt=p, image=image_params, extras=extras) + if return_request: + return request + + results = self._run_request(self._generate, request) + + return results + + def get_user_info(self) -> Tuple[float, str]: + """Get the number of credits the user has remaining and their profile picture.""" + if not self._user_organization_id: + user = self._dashboard_stub.GetMe(dashboard.EmptyRequest()) + self._user_profile_picture = user.profile_picture + self._user_organization_id = user.organizations[0].organization.id + organization = self._dashboard_stub.GetOrganization(dashboard.GetOrganizationRequest(id=self._user_organization_id)) + return organization.payment_info.balance * 100, self._user_profile_picture + + def inpaint( + self, + image: Image.Image, + mask: Image.Image, + prompts: List[str], + weights: List[float], + steps: Optional[int] = None, + seed: Union[Sequence[int], int] = 0, + samples: int = 1, + cfg_scale: float = 7.0, + sampler: generation.DiffusionSampler = None, + init_strength: float = 0.0, + init_noise_scale: Optional[float] = None, + masked_area_init: generation.MaskedAreaInit = generation.MASKED_AREA_INIT_ZERO, + guidance_preset: generation.GuidancePreset = generation.GUIDANCE_PRESET_NONE, + guidance_cuts: int = 0, + guidance_strength: float = 0.0, + preset: Optional[str] = None, + ) -> Dict[int, List[Any]]: + """ + Apply inpainting to an image. 
+ + :param image: Source image + :param mask: Mask image with 0 for pixels to change and 255 for pixels to keep + :param prompts: List of text prompts + :param weights: List of prompt weights + :param steps: Number of steps to run + :param seed: Random seed + :param samples: Number of samples to generate + :param cfg_scale: Classifier free guidance scale + :param sampler: Sampler to use for the diffusion process + :param init_strength: Strength of the initial image + :param init_noise_scale: Scale of the initial noise + :param masked_area_init: How to initialize the masked area + :param guidance_preset: Preset to use for CLIP guidance + :param guidance_cuts: Number of cuts to use with CLIP guidance + :param guidance_strength: Strength of CLIP guidance + :param preset: Style preset to use + :return: dict mapping artifact type to data + """ + p = [generation.Prompt(text=prompt, parameters=generation.PromptParameters(weight=weight)) for prompt,weight in zip(prompts, weights)] + p.append(image_to_prompt(image)) + p.append(image_to_prompt(mask, type=generation.ARTIFACT_MASK)) + + width, height = image.size + start_schedule = 1.0-init_strength + image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale, + start_schedule, init_noise_scale, masked_area_init, + guidance_preset, guidance_cuts, guidance_strength) + + extras = Struct() + if preset and preset.lower() != 'none': + extras.update({ '$IPC': { "preset": preset } }) + + request = generation.Request(engine_id=self._inpaint.engine_id, prompt=p, image=image_params, extras=extras) + results = self._run_request(self._inpaint, request) + + return results + + def interpolate( + self, + images: Sequence[Image.Image], + ratios: List[float], + mode: generation.InterpolateMode = generation.INTERPOLATE_LINEAR, + ) -> List[Image.Image]: + """ + Interpolate between two images + + :param images: Two images with matching resolution + :param ratios: In-between ratios to interpolate at + :param mode: Interpolation mode + :return: One image for each ratio + """ + assert len(images) == 2 + assert len(ratios) >= 1 + + if len(ratios) == 1: + if ratios[0] == 0.0: + return [images[0]] + elif ratios[0] == 1.0: + return [images[1]] + elif mode == generation.INTERPOLATE_LINEAR: + return [image_mix(images[0], images[1], ratios[0])] + + p = [image_to_prompt(image) for image in images] + request = generation.Request( + engine_id=self._interpolate.engine_id, + prompt=p, + interpolate=generation.InterpolateParameters(ratios=ratios, mode=mode) + ) + + results = self._run_request(self._interpolate, request) + return results[generation.ARTIFACT_IMAGE] + + def transform_and_generate( + self, + image: Optional[Image.Image], + params: List[generation.TransformParameters], + generate_request: generation.Request, + extras: Optional[Dict] = None, + ) -> Image.Image: + extras_struct = None + if extras is not None: + extras_struct = Struct() + extras_struct.update(extras) + + if not params: + results = self._run_request(self._generate, generate_request) + return results[generation.ARTIFACT_IMAGE][0] + + assert image is not None + requests = [ + generation.Request( + engine_id=self._transform.engine_id, + requested_type=generation.ARTIFACT_TENSOR, + prompt=[image_to_prompt(image)], + transform=param, + extras=extras_struct, + ) for param in params + ] + + if self._debug_no_chains: + prev_result = None + for rq in requests: + if prev_result is not None: + rq.prompt.pop() + rq.prompt.append(tensor_to_prompt(prev_result)) + prev_result = 
self._run_request(self._transform, rq)[generation.ARTIFACT_TENSOR][0] + generate_request.prompt.append(tensor_to_prompt(prev_result)) + results = self._run_request(self._generate, generate_request) + else: + stages = [] + for idx, rq in enumerate(requests): + stages.append(generation.Stage( + id=str(idx), + request=rq, + on_status=[generation.OnStatus( + action=[generation.STAGE_ACTION_PASS], + target=str(idx+1) + )] + )) + stages.append(generation.Stage( + id=str(len(params)), + request=generate_request, + on_status=[generation.OnStatus( + action=[generation.STAGE_ACTION_RETURN], + target=None + )] + )) + chain_rq = generation.ChainRequest(request_id="xform_gen_chain", stage=stages) + results = self._run_request(self._transform, chain_rq) + + return results[generation.ARTIFACT_IMAGE][0] + + def transform( + self, + images: Sequence[Image.Image], + params: Union[generation.TransformParameters, List[generation.TransformParameters]], + extras: Optional[Dict] = None + ) -> Tuple[List[Image.Image], Optional[List[Image.Image]]]: + """ + Transform images + + :param images: One or more images to transform + :param params: Transform operations to apply to each image + :return: One image artifact for each image and one transform dependent mask + """ + assert len(images) + assert isinstance(images[0], Image.Image) + + extras_struct = None + if extras is not None: + extras_struct = Struct() + extras_struct.update(extras) + + if isinstance(params, List) and len(params) > 1: + if self._debug_no_chains: + for param in params: + images, mask = self.transform(images, param, extras) + return images, mask + + assert extras is None + stages = [] + for idx, param in enumerate(params): + final = idx == len(params) - 1 + rq = generation.Request( + engine_id=self._transform.engine_id, + prompt=[image_to_prompt(image) for image in images] if idx == 0 else None, + transform=param, + extras_struct=extras_struct + ) + stages.append(generation.Stage( + id=str(idx), + request=rq, + on_status=[generation.OnStatus( + action=[generation.STAGE_ACTION_PASS if not final else generation.STAGE_ACTION_RETURN], + target=str(idx+1) if not final else None + )] + )) + chain_rq = generation.ChainRequest(request_id="xform_chain", stage=stages) + results = self._run_request(self._transform, chain_rq) + else: + request = generation.Request( + engine_id=self._transform.engine_id, + prompt=[image_to_prompt(image) for image in images], + transform=params[0] if isinstance(params, List) else params, + extras=extras_struct + ) + results = self._run_request(self._transform, request) + + images = results.get(generation.ARTIFACT_IMAGE, []) + results.get(generation.ARTIFACT_DEPTH, []) + masks = results.get(generation.ARTIFACT_MASK, None) + return images, masks + + def transform_3d( + self, + images: Sequence[Image.Image], + depth_calc: generation.TransformParameters, + transform: generation.TransformParameters, + extras: Optional[Dict] = None + ) -> Tuple[List[Image.Image], Optional[List[Image.Image]]]: + assert len(images) + assert isinstance(images[0], Image.Image) + + image_prompts = [image_to_prompt(image) for image in images] + warped_images = [] + warp_mask = None + op_id = "resample" if transform.HasField("resample") else "camera_pose" + + extras_struct = Struct() + if extras is not None: + extras_struct.update(extras) + + rq_depth = generation.Request( + engine_id=self._transform.engine_id, + requested_type=generation.ARTIFACT_TENSOR, + prompt=[image_prompts[0]], + transform=depth_calc, + ) + rq_transform = generation.Request( + 
engine_id=self._transform.engine_id, + prompt=image_prompts, + transform=transform, + extras=extras_struct + ) + + if self._debug_no_chains: + results = self._run_request(self._transform, rq_depth) + rq_transform.prompt.append( + generation.Prompt( + artifact=generation.Artifact( + type=generation.ARTIFACT_TENSOR, + tensor=results[generation.ARTIFACT_TENSOR][0] + ) + ) + ) + results = self._run_request(self._transform, rq_transform) + else: + chain_rq = generation.ChainRequest( + request_id=f"{op_id}_3d_chain", + stage=[ + generation.Stage( + id="depth_calc", + request=rq_depth, + on_status=[generation.OnStatus(action=[generation.STAGE_ACTION_PASS], target=op_id)] + ), + generation.Stage( + id=op_id, + request=rq_transform, + on_status=[generation.OnStatus(action=[generation.STAGE_ACTION_RETURN])] + ) + ]) + results = self._run_request(self._transform, chain_rq) + + warped_images = results[generation.ARTIFACT_IMAGE] + warp_mask = results.get(generation.ARTIFACT_MASK, None) + + return warped_images, warp_mask + + def upscale( + self, + init_image: Image.Image, + width: Optional[int] = None, + height: Optional[int] = None, + prompt: Union[str, generation.Prompt] = None, + steps: Optional[int] = 20, + cfg_scale: Optional[float] = 7.0, + seed: int = 0 + ) -> Image.Image: + """ + Upscale an image. + + :param init_image: Image to upscale. + + Optional parameters for upscale method: + + :param width: Width of the output images. + :param height: Height of the output images. + :param prompt: Prompt used in text conditioned models + :param steps: Number of diffusion steps + :param cfg_scale: Intensity of the prompt, when a prompt is used + :param seed: Seed for the random number generator. + + Some variables are not used for specific engines, but are included for consistency. 
+ + Variables ignored in ESRGAN engines: prompt, steps, cfg_scale, seed + + :return: Tuple of (prompts, image_parameters) + """ + + prompts = [image_to_prompt(init_image)] + if prompt: + if isinstance(prompt, str): + prompt = generation.Prompt(text=prompt) + elif not isinstance(prompt, generation.Prompt): + raise ValueError("prompt must be a string or Prompt object") + prompts.append(prompt) + + request = generation.Request( + engine_id=self._upscale.engine_id, + prompt=prompts, + image=generation.ImageParameters( + width=width, + height=height, + seed=[seed], + steps=steps, + parameters=[generation.StepParameter( + sampler=generation.SamplerParameters(cfg_scale=cfg_scale) + )], + ) + ) + results = self._run_request(self._upscale, request) + return results[generation.ARTIFACT_IMAGE][0] + + def _adjust_request_engine(self, request: generation.Request): + if request.engine_id == self._transform.engine_id: + assert request.HasField("transform") + if request.transform.HasField("color_adjust") or \ + (request.transform.HasField("resample") and len(request.transform.resample.transform.data) == 9): + request.engine_id = self._transform.engine_id + "-cpu" + + def _adjust_request_for_retry(self, request: generation.Request, attempt: int): + logger.warning(f" adjusting request, will retry {self._max_retries-attempt} more times") + request.image.seed[:] = [random.randrange(0, 4294967295) for _ in request.image.seed] + if attempt > 0 and request.image.parameters and request.image.parameters[0].HasField("schedule"): + schedule = request.image.parameters[0].schedule + if schedule.HasField("start"): + schedule.start = max(0.0, min(1.0, schedule.start + self._retry_schedule_offset)) + + def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_scale, + schedule_start, init_noise_scale, masked_area_init, + guidance_preset, guidance_cuts, guidance_strength): + + if not seed: + seed = [random.randrange(0, 4294967295)] + elif isinstance(seed, int): + seed = [seed] + else: + seed = list(seed) + + step_parameters = { + "scaled_step": 0, + "sampler": generation.SamplerParameters(cfg_scale=cfg_scale, init_noise_scale=init_noise_scale), + } + if schedule_start != 1.0: + step_parameters["schedule"] = generation.ScheduleParameters(start=schedule_start) + + if guidance_preset is not generation.GUIDANCE_PRESET_NONE: + cutouts = generation.CutoutParameters(count=guidance_cuts) if guidance_cuts else None + if guidance_strength == 0.0: + guidance_strength = None + step_parameters["guidance"] = generation.GuidanceParameters( + guidance_preset=guidance_preset, + instances=[ + generation.GuidanceInstanceParameters( + cutouts=cutouts, + guidance_strength=guidance_strength, + models=None, prompt=None + ) + ] + ) + + return generation.ImageParameters( + transform=None if sampler is None else generation.TransformType(diffusion=sampler), + height=height, + width=width, + seed=seed, + steps=steps, + samples=samples, + masked_area_init=masked_area_init, + parameters=[generation.StepParameter(**step_parameters)], + ) + + def _process_response(self, response) -> Dict[int, List[Any]]: + results: Dict[int, List[Any]] = {} + for resp in response: + for artifact in resp.artifacts: + # check for classifier rejecting a text prompt + if artifact.finish_reason == generation.FILTER and artifact.type == generation.ARTIFACT_TEXT: + raise ClassifierException(prompt=artifact.text) + + if artifact.type not in results: + results[artifact.type] = [] + + if artifact.type == generation.ARTIFACT_CLASSIFICATIONS: + 
results[artifact.type].append(artifact.classifier) + elif artifact.type in (generation.ARTIFACT_DEPTH, generation.ARTIFACT_IMAGE, generation.ARTIFACT_MASK): + image = Image.open(io.BytesIO(artifact.binary)) + results[artifact.type].append(image) + elif artifact.type == generation.ARTIFACT_TENSOR: + results[artifact.type].append(artifact.tensor) + elif artifact.type == generation.ARTIFACT_TEXT: + results[artifact.type].append(artifact.text) + + return results + + def _run_request( + self, + endpoint: Endpoint, + request: Union[generation.ChainRequest, generation.Request] + ) -> Dict[int, List[Any]]: + if isinstance(request, generation.Request): + self._adjust_request_engine(request) + elif isinstance(request, generation.ChainRequest): + for stage in request.stage: + self._adjust_request_engine(stage.request) + + for attempt in range(self._max_retries+1): + try: + if isinstance(request, generation.Request): + response = endpoint.stub.Generate(request, timeout=self._request_timeout) + else: + response = endpoint.stub.ChainGenerate(request, timeout=self._request_timeout) + + results = self._process_response(response) + + # check for classifier obfuscation + if generation.ARTIFACT_CLASSIFICATIONS in results: + for classifier in results[generation.ARTIFACT_CLASSIFICATIONS]: + if classifier.realized_action == generation.ACTION_OBFUSCATE: + raise ClassifierException(classifier) + + break + except ClassifierException as ce: + if attempt == self._max_retries or not self._retry_obfuscation or ce.prompt is not None: + raise ce + + for exceed in ce.classifier_result.exceeds: + logger.warning(f"Received classifier obfuscation. Exceeded {exceed.name} threshold") + + if isinstance(request, generation.Request) and request.HasField("image"): + self._adjust_request_for_retry(request, attempt) + elif isinstance(request, generation.ChainRequest): + for stage in request.stage: + if stage.request.HasField("image"): + self._adjust_request_for_retry(stage.request, attempt) + else: + raise ce + except grpc.RpcError as rpc_error: + if hasattr(rpc_error, "code"): + if rpc_error.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: + if "message larger than max" in rpc_error.details(): + raise rpc_error + raise OutOfCreditsException(rpc_error.details()) + elif rpc_error.code() == grpc.StatusCode.UNAUTHENTICATED: + raise rpc_error + + if attempt == self._max_retries: + raise rpc_error + + logger.warning(f"Received RpcError: {rpc_error} will retry {self._max_retries-attempt} more times") + time.sleep(self._retry_delay * 2**attempt) + return results diff --git a/src/stability_sdk/client.py b/src/stability_sdk/client.py index f5100087..20547f00 100644 --- a/src/stability_sdk/client.py +++ b/src/stability_sdk/client.py @@ -2,65 +2,39 @@ # fmt: off -import pathlib -import sys +import getpass +import grpc +import logging +import mimetypes import os -import uuid import random -import io -import logging +import sys import time -import mimetypes +import uuid -import grpc from argparse import ArgumentParser, Namespace -from typing import Dict, Generator, List, Optional, Union, Any, Sequence, Tuple from google.protobuf.json_format import MessageToJson from google.protobuf.struct_pb2 import Struct from PIL import Image - -try: - from dotenv import load_dotenv -except ModuleNotFoundError: - pass -else: - load_dotenv() +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation import 
stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc -from stability_sdk.utils import ( +from .api import open_channel +from .utils import ( SAMPLERS, MAX_FILENAME_SZ, - artifact_type_to_str, - truncate_fit, - get_sampler_from_str, + artifact_type_to_string, + image_to_prompt, open_images, + sampler_from_string, + truncate_fit, ) - logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) -def image_to_prompt(im, init: bool = False, mask: bool = False) -> generation.Prompt: - if init and mask: - raise ValueError("init and mask cannot both be True") - buf = io.BytesIO() - im.save(buf, format="PNG") - buf.seek(0) - if mask: - return generation.Prompt( - artifact=generation.Artifact( - type=generation.ARTIFACT_MASK, binary=buf.getvalue() - ) - ) - return generation.Prompt( - artifact=generation.Artifact( - type=generation.ARTIFACT_IMAGE, binary=buf.getvalue() - ), - parameters=generation.PromptParameters(init=init), - ) - def process_artifacts_from_answers( prefix: str, @@ -100,17 +74,17 @@ def process_artifacts_from_answers( ext = ".pb" contents = artifact.SerializeToString() out_p = truncate_fit(prefix, prompt, ext, int(artifact_start), idx, MAX_FILENAME_SZ) - is_allowed_type = filter_types is None or artifact_type_to_str(artifact.type) in filter_types + is_allowed_type = filter_types is None or artifact_type_to_string(artifact.type) in filter_types if write: if is_allowed_type: with open(out_p, "wb") as f: f.write(bytes(contents)) if verbose: - logger.info(f"wrote {artifact_type_to_str(artifact.type)} to {out_p}") + logger.info(f"wrote {artifact_type_to_string(artifact.type)} to {out_p}") else: if verbose: logger.info( - f"skipping {artifact_type_to_str(artifact.type)} due to artifact type filter") + f"skipping {artifact_type_to_string(artifact.type)} due to artifact type filter") yield (out_p, artifact) idx += 1 @@ -142,7 +116,6 @@ def __init__( self.upscale_engine = upscale_engine self.grpc_args = {"wait_for_ready": wait_for_ready} - if verbose: logger.info(f"Opening channel to {host}") @@ -259,10 +232,10 @@ def generate( start=start_schedule, end=end_schedule, ) - prompts += [image_to_prompt(init_image, init=True)] + prompts += [image_to_prompt(init_image)] if mask_image is not None: - prompts += [image_to_prompt(mask_image, mask=True)] + prompts += [image_to_prompt(mask_image, type=generation.ARTIFACT_MASK)] if guidance_prompt: @@ -360,7 +333,7 @@ def upscale( parameters=[generation.StepParameter(**step_parameters)], ) - prompts = [image_to_prompt(init_image, init=True)] + prompts = [image_to_prompt(init_image)] if prompt: if isinstance(prompt, str): @@ -403,7 +376,7 @@ def emit_request( if self.verbose: if len(answer.artifacts) > 0: artifact_ts = [ - artifact_type_to_str(artifact.type) + artifact_type_to_string(artifact.type) for artifact in answer.artifacts ] logger.info( @@ -446,17 +419,12 @@ def process_cli(logger: logging.Logger = None, STABILITY_HOST = os.getenv("STABILITY_HOST", "grpc.stability.ai:443") STABILITY_KEY = os.getenv("STABILITY_KEY", "") - if not STABILITY_HOST: - logger.warning("STABILITY_HOST environment variable needs to be set.") - sys.exit(1) - if not STABILITY_KEY: - logger.warning( - "STABILITY_KEY environment variable needs to be set. You may" - " need to login to the Stability website to obtain the" - " API key." + print( + "Please enter your API key from dreamstudio.ai or set the " + "STABILITY_KEY environment variable to skip this prompt." 
) - sys.exit(1) + STABILITY_KEY = getpass.getpass("Enter your Stability API key: ") # CLI parsing parser = ArgumentParser() @@ -515,6 +483,12 @@ def process_cli(logger: logging.Logger = None, parser_upscale.add_argument( "prompt", nargs="*" ) + + + parser_animate = subparsers.add_parser('animate') + parser_animate.add_argument("--gui", action="store_true", help="serve Gradio UI") + parser_animate.add_argument("--share", action="store_true", help="create shareable UI link") + parser_animate.add_argument("--output", "-o", type=str, default=".", help="root output folder") parser_generate = subparsers.add_parser('generate') @@ -601,10 +575,10 @@ def process_cli(logger: logging.Logger = None, if command not in subparsers.choices.keys() and command != '-h' and command != '--help': logger.warning(f"command {command} not recognized, defaulting to 'generate'") logger.warning( - "[Deprecation Warning] The method you have used to invoke the sdk will be deprecated shortly." - "[Deprecation Warning] Please modify your code to call the sdk with the following syntax:" - "[Deprecation Warning] python -m stability_sdk " - "[Deprecation Warning] Where is one of: upscale, generate" + "[Deprecation Warning] The method you have used to invoke the sdk will be deprecated shortly." + "[Deprecation Warning] Please modify your code to call the sdk with the following syntax:" + "[Deprecation Warning] python -m stability_sdk " + "[Deprecation Warning] Where is one of: upscale, generate" ) input_args = ['generate'] + input_args @@ -624,7 +598,7 @@ def process_cli(logger: logging.Logger = None, "seed": args.seed, "cfg_scale": args.cfg_scale, "prompt": args.prompt, - } + } stability_api = StabilityInference( STABILITY_HOST, STABILITY_KEY, upscale_engine=args.engine, verbose=True ) @@ -660,7 +634,7 @@ def process_cli(logger: logging.Logger = None, } if args.sampler: - request["sampler"] = get_sampler_from_str(args.sampler) + request["sampler"] = sampler_from_string(args.sampler) if args.steps: request["steps"] = args.steps @@ -673,6 +647,18 @@ def process_cli(logger: logging.Logger = None, args.prefix, args.prompt, answers, write=not args.no_store, verbose=True, filter_types=args.artifact_types, ) + elif args.command == "animate": + if args.gui: + from .animation_ui import create_ui + from .api import Context + ui = create_ui(Context(STABILITY_HOST, STABILITY_KEY), args.output) + ui.queue(concurrency_count=2, max_size=2) + ui.launch(show_api=False, debug=True, height=768, share=args.share, show_error=True) + sys.exit(0) + else: + logger.warning("animate must be invoked with --gui") + sys.exit(1) + if args.show: for artifact in open_images(artifacts, verbose=True): diff --git a/src/stability_sdk/interfaces b/src/stability_sdk/interfaces index 193a1e41..f3a50851 160000 --- a/src/stability_sdk/interfaces +++ b/src/stability_sdk/interfaces @@ -1 +1 @@ -Subproject commit 193a1e41a984c6e452aa1f99d9c1c20c15c692a4 +Subproject commit f3a50851f8ea158fef1b1d76661cfd9a8cf83e01 diff --git a/src/stability_sdk/matrix.py b/src/stability_sdk/matrix.py new file mode 100644 index 00000000..31118755 --- /dev/null +++ b/src/stability_sdk/matrix.py @@ -0,0 +1,76 @@ +""" +Minimal set of 4x4 column-major matrix functions for building transforms +compatible with the animation transform API. This serves as reference +implementation for the different languages we will support so only basic +types and no external libraries are used. + + [sx, 10, 20, tx] [x] + [01, sy, 21, ty] . 
[y] + [02, 12, sz, tz] [z] + [03, 13, 23, 33] [1] + +""" +import math +from typing import List + +Matrix = List[List[float]] + +identity = [[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]] + +def multiply(a: Matrix, b: Matrix) -> Matrix: + assert len(a) == len(b) == 4 + c = [[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]] + for row in range(4): + for col in range(4): + for k in range(4): + c[row][col] += a[row][k] * b[k][col] + return c + +def projection_fov(fov_y: float, aspect: float, near: float, far: float) -> Matrix: + min_x, min_y = -1, -1 + max_x, max_y = 1, 1 + h1 = (max_y + min_y) / (max_y - min_y) + w1 = (max_x + min_x) / (max_x - min_x) + t = math.tan(fov_y / 2) + s1 = 1 / t + s2 = 1 / (t * aspect) + + # map z to the range [0, 1] + f1 = far / (far - near) + f2 = -(far * near) / (far - near) + + return [[s1, 0., w1, 0.], + [0., s2, h1, 0.], + [0., 0., f1, f2], + [0., 0., 1., 0.]] + +def rotation_euler(x: float, y: float, z: float) -> Matrix: + """Returns a rotation matrix for the given Euler angles (in radians) using XYZ order.""" + a, b = math.cos(x), math.sin(x) + c, d = math.cos(y), math.sin(y) + e, f = math.cos(z), math.sin(z) + + ae = a * e + af = a * f + be = b * e + bf = b * f + + return [[ c * e, af + be * d, bf - ae * d, 0.], + [-c * f, ae - bf * d, be + af * d, 0.], + [ d, -b * c, a * c, 0.], + [ 0., 0., 0., 1.]] + +def scale(sx: float, sy: float, sz: float) -> Matrix: + return [[sx, 0., 0., 0.], + [0., sy, 0., 0.], + [0., 0., sz, 0.], + [0., 0., 0., 1.]] + +def translation(tx: float, ty: float, tz: float) -> Matrix: + return [[1., 0., 0., tx], + [0., 1., 0., ty], + [0., 0., 1., tz], + [0., 0., 0., 1.]] diff --git a/src/stability_sdk/utils.py b/src/stability_sdk/utils.py index 96b75a97..fd178a58 100644 --- a/src/stability_sdk/utils.py +++ b/src/stability_sdk/utils.py @@ -1,23 +1,64 @@ -import pathlib -import sys -import os -import uuid -import random import io import logging -import time -from typing import Dict, Generator, List, Optional, Union, Any, Sequence, Tuple -import mimetypes - +import os +import subprocess from PIL import Image +from typing import Dict, Generator, Optional, Sequence, Tuple, Type, TypeVar, Union + +from .api import generation +from .matrix import Matrix -import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation -import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) +MAX_FILENAME_SZ = int(os.getenv("MAX_FILENAME_SZ", 200)) + + +#============================================================================== +# Mappings from strings to protobuf enums +#============================================================================== + +BORDER_MODES = { + 'replicate': generation.BORDER_REPLICATE, + 'reflect': generation.BORDER_REFLECT, + 'wrap': generation.BORDER_WRAP, + 'zero': generation.BORDER_ZERO, + 'prefill': generation.BORDER_PREFILL, +} + +CAMERA_TYPES = { + 'perspective': generation.CAMERA_PERSPECTIVE, + 'orthographic': generation.CAMERA_ORTHOGRAPHIC, +} + +COLOR_MATCH_MODES = { + "hsv": generation.COLOR_MATCH_HSV, + "lab": generation.COLOR_MATCH_LAB, + "rgb": generation.COLOR_MATCH_RGB, +} + +GUIDANCE_PRESETS: Dict[str, int] = { + "none": generation.GUIDANCE_PRESET_NONE, + "simple": generation.GUIDANCE_PRESET_SIMPLE, + "fastblue": generation.GUIDANCE_PRESET_FAST_BLUE, + "fastgreen": generation.GUIDANCE_PRESET_FAST_GREEN, +} + +INTERPOLATE_MODES = { + 'film': 
generation.INTERPOLATE_FILM, + 'mix': generation.INTERPOLATE_LINEAR, + 'rife': generation.INTERPOLATE_RIFE, + 'vae-lerp': generation.INTERPOLATE_VAE_LINEAR, + 'vae-slerp': generation.INTERPOLATE_VAE_SLERP, +} + +RENDER_MODES = { + 'mesh': generation.RENDER_MESH, + 'pointcloud': generation.RENDER_POINTCLOUD, +} + SAMPLERS: Dict[str, int] = { "ddim": generation.SAMPLER_DDIM, "plms": generation.SAMPLER_DDPM, @@ -30,10 +71,128 @@ "k_dpmpp_2m": generation.SAMPLER_K_DPMPP_2M, "k_dpmpp_2s_ancestral": generation.SAMPLER_K_DPMPP_2S_ANCESTRAL } - -MAX_FILENAME_SZ = int(os.getenv("MAX_FILENAME_SZ", 200)) -def artifact_type_to_str(artifact_type: generation.ArtifactType): +T = TypeVar('T') + +def _from_string(s: str, mapping: Dict[str, T], name: str, enum_cls: Type[T]) -> T: + enum_value = mapping.get(s.lower().strip()) + if enum_value is None: + raise ValueError(f"invalid {name}: {s}") + return enum_value + +def border_mode_from_string(s: str) -> generation.BorderMode: + return _from_string(s, BORDER_MODES, "border mode", generation.BorderMode) + +def camera_type_from_string(s: str) -> generation.CameraType: + return _from_string(s, CAMERA_TYPES, "camera type", generation.CameraType) + +def color_match_from_string(s: str) -> generation.ColorMatchMode: + return _from_string(s, COLOR_MATCH_MODES, "color match", generation.ColorMatchMode) + +def guidance_from_string(s: str) -> generation.GuidancePreset: + return _from_string(s, GUIDANCE_PRESETS, "guidance preset", generation.GuidancePreset) + +def interpolate_mode_from_string(s: str) -> generation.InterpolateMode: + return _from_string(s, INTERPOLATE_MODES, "interpolate mode", generation.InterpolateMode) + +def render_mode_from_string(s: str) -> generation.RenderMode: + return _from_string(s, RENDER_MODES, "render mode", generation.RenderMode) + +def sampler_from_string(s: str) -> generation.DiffusionSampler: + return _from_string(s, SAMPLERS, "sampler", generation.DiffusionSampler) + + +#============================================================================== +# Transform helper functions +#============================================================================== + +def camera_pose_transform( + transform: Matrix, + near_plane: float, + far_plane: float, + fov: float, + camera_type: str='perspective', + render_mode: str='mesh', + do_prefill: bool=True, +) -> generation.TransformParameters: + camera_parameters = generation.CameraParameters( + camera_type=camera_type_from_string(camera_type), + near_plane=near_plane, far_plane=far_plane, fov=fov) + return generation.TransformParameters( + camera_pose=generation.TransformCameraPose( + world_to_view_matrix=generation.TransformMatrix(data=sum(transform, [])), + camera_parameters=camera_parameters, + render_mode=render_mode_from_string(render_mode), + do_prefill=do_prefill + ) + ) + +def color_adjust_transform( + brightness: float=1.0, + contrast: float=1.0, + hue: float=0.0, + saturation: float=1.0, + lightness: float=0.0, + match_image: Optional[Image.Image]=None, + match_mode: str='LAB', + noise_amount: float=0.0, + noise_seed: int=0 +) -> generation.TransformParameters: + if match_mode == 'None': + match_mode = 'RGB' + match_image = None + return generation.TransformParameters( + color_adjust=generation.TransformColorAdjust( + brightness=brightness, + contrast=contrast, + hue=hue, + saturation=saturation, + lightness=lightness, + match_image=generation.Artifact( + type=generation.ARTIFACT_IMAGE, + binary=image_to_jpg_bytes(match_image), + ) if match_image is not None else None, + 
match_mode=color_match_from_string(match_mode), + noise_amount=noise_amount, + noise_seed=noise_seed, + )) + +def depth_calc_transform( + blend_weight: float, + blur_radius: int=0, + reverse: bool=False, +) -> generation.TransformParameters: + return generation.TransformParameters( + depth_calc=generation.TransformDepthCalc( + blend_weight=blend_weight, + blur_radius=blur_radius, + reverse=reverse + ) + ) + +def resample_transform( + border_mode: str, + transform: Matrix, + prev_transform: Optional[Matrix]=None, + depth_warp: float=1.0, + export_mask: bool=False +) -> generation.TransformParameters: + return generation.TransformParameters( + resample=generation.TransformResample( + border_mode=border_mode_from_string(border_mode), + transform=generation.TransformMatrix(data=sum(transform, [])), + prev_transform=generation.TransformMatrix(data=sum(prev_transform, [])) if prev_transform else None, + depth_warp=depth_warp, + export_mask=export_mask + ) + ) + + +#============================================================================== +# General utility functions +#============================================================================== + +def artifact_type_to_string(artifact_type: generation.ArtifactType): """ Convert ArtifactType to a string. :param artifact_type: The ArtifactType to convert. @@ -49,33 +208,117 @@ def artifact_type_to_str(artifact_type: generation.ArtifactType): ) return "ARTIFACT_UNRECOGNIZED" -def truncate_fit(prefix: str, prompt: str, ext: str, ts: int, idx: int, max: int) -> str: +def create_video_from_frames(frames_path: str, mp4_path: str, fps: int=24, reverse: bool=False): """ - Constructs an output filename from a collection of required fields. + Convert a series of image frames to a video file using ffmpeg. + + :param frames_path: The path to the directory containing the image frames named frame_00000.png, frame_00001.png, etc. + :param mp4_path: The path to save the output video file. + :param fps: The frames per second for the output video. Default is 24. + :param reverse: A flag to reverse the order of the frames in the output video. Default is False. + """ + + cmd = [ + 'ffmpeg', + '-y', + '-vcodec', 'png', + '-r', str(fps), + '-start_number', str(0), + '-i', os.path.join(frames_path, "frame_%05d.png"), + '-c:v', 'libx264', + '-vf', + f'fps={fps}', + '-pix_fmt', 'yuv420p', + '-crf', '17', + '-preset', 'veryslow', + mp4_path + ] + if reverse: + cmd.insert(-1, '-vf') + cmd.insert(-1, 'reverse') + + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + _, stderr = process.communicate() + if process.returncode != 0: + raise RuntimeError(stderr) + +def extract_frames_from_video(video_path: str, frames_subdir: str='frames'): + """ + Extracts all frames from a video to a subdirectory of the video's parent folder. + :param video_path: A path to the video. + :param frames_subdir: Name of the subdirectory to save the frames into. + :return: The frames subdirectory path. + """ + out_dir = os.path.join(os.path.dirname(video_path), frames_subdir) + if not os.path.exists(out_dir): + os.mkdir(out_dir) - Given an over-budget threshold of `max`, trims the prompt string to satisfy the budget. - NB: As implemented, 'max' is the smallest filename length that will trigger truncation. - It is presumed that the sum of the lengths of the other filename fields is smaller than `max`. - If they exceed `max`, this function will just always construct a filename with no prompt component. 
+ cmd = [ + 'ffmpeg', + '-i', video_path, + os.path.join(out_dir, "frame_%05d.png"), + ] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + _, stderr = process.communicate() + if process.returncode != 0: + raise RuntimeError(stderr) + + return out_dir + +def image_mix(img_a: Image.Image, img_b: Image.Image, ratio: Union[float, Image.Image]) -> Image.Image: """ - post = f"_{ts}_{idx}" - prompt_budget = max - prompt_budget -= len(prefix) - prompt_budget -= len(post) - prompt_budget -= len(ext) + 1 - return f"{prefix}{prompt[:prompt_budget]}{post}{ext}" + Performs a linear interpolation between two images + :param img_a: The first image. + :param img_b: The second image. + :param ratio: Mix ratio or mask image. + :return: The mixed image + """ + if img_a.size != img_b.size: + raise ValueError(f"img_a size {img_a.size} does not match img_b size {img_b.size}") + + if isinstance(ratio, Image.Image): + if ratio.size != img_a.size: + raise ValueError(f"mix ratio size {ratio.size} does not match img_a size {img_a.size}") + return Image.composite(img_b, img_a, ratio) + + return Image.blend(img_a, img_b, ratio) + +def image_to_jpg_bytes(image: Image.Image, quality: int=90) -> bytes: + """ + Compresses an image to a JPEG byte array. + :param image: The image to convert. + :param quality: The JPEG quality to use. + :return: The JPEG byte array. + """ + buf = io.BytesIO() + image.save(buf, format="JPEG", quality=quality) + buf.seek(0) + return buf.getvalue() -def get_sampler_from_str(s: str) -> generation.DiffusionSampler: +def image_to_png_bytes(image: Image.Image) -> bytes: """ - Convert a string to a DiffusionSampler enum. - :param s: The string to convert. - :return: The DiffusionSampler enum. + Compresses an image to a PNG byte array. + :param image: The image to convert. + :return: The PNG byte array. """ - algorithm_key = s.lower().strip() - algorithm = SAMPLERS.get(algorithm_key, None) - if algorithm is None: - raise ValueError(f"unknown sampler {s}") - return algorithm + buf = io.BytesIO() + image.save(buf, format="PNG") + buf.seek(0) + return buf.getvalue() + +def image_to_prompt( + image: Image.Image, + type: generation.ArtifactType=generation.ARTIFACT_IMAGE +) -> generation.Prompt: + """ + Create Prompt message type from an image. + :param image: The image. + :param type: The ArtifactType to use (ARTIFACT_IMAGE, ARTIFACT_MASK, or ARTIFACT_DEPTH). + """ + return generation.Prompt(artifact=generation.Artifact( + type=type, + binary=image_to_png_bytes(image) + )) def open_images( images: Union[ @@ -97,3 +340,29 @@ def open_images( img = Image.open(io.BytesIO(artifact.binary)) img.show() yield (path, artifact) + +def tensor_to_prompt(tensor: 'tensors_pb.Tensor') -> generation.Prompt: + """ + Create Prompt message type from a tensor. + :param tensor: The tensor. + """ + return generation.Prompt(artifact=generation.Artifact( + type=generation.ARTIFACT_TENSOR, + tensor=tensor + )) + +def truncate_fit(prefix: str, prompt: str, ext: str, ts: int, idx: int, max: int) -> str: + """ + Constructs an output filename from a collection of required fields. + + Given an over-budget threshold of `max`, trims the prompt string to satisfy the budget. + NB: As implemented, 'max' is the smallest filename length that will trigger truncation. + It is presumed that the sum of the lengths of the other filename fields is smaller than `max`. + If they exceed `max`, this function will just always construct a filename with no prompt component. 
+ """ + post = f"_{ts}_{idx}" + prompt_budget = max + prompt_budget -= len(prefix) + prompt_budget -= len(post) + prompt_budget -= len(ext) + 1 + return f"{prefix}{prompt[:prompt_budget]}{post}{ext}" diff --git a/tests/assets/4166726513_giant__rainbow_sequoia__tree_by_hayao_miyazaki___earth_tones__a_row_of_western_cedar_nurse_trees_che.png b/tests/assets/4166726513_giant__rainbow_sequoia__tree_by_hayao_miyazaki___earth_tones__a_row_of_western_cedar_nurse_trees_che.png new file mode 100644 index 00000000..8cfef971 Binary files /dev/null and b/tests/assets/4166726513_giant__rainbow_sequoia__tree_by_hayao_miyazaki___earth_tones__a_row_of_western_cedar_nurse_trees_che.png differ diff --git a/tests/assets/HebyMorgongava_512kb.mp4 b/tests/assets/HebyMorgongava_512kb.mp4 new file mode 100755 index 00000000..eb4360df Binary files /dev/null and b/tests/assets/HebyMorgongava_512kb.mp4 differ diff --git a/tests/conftest.py b/tests/conftest.py index 782ab634..12bc8653 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,23 +1,11 @@ -from concurrent import futures - import grpc -import pytest - -import logging import pathlib -import sys - -thisPath = pathlib.Path(__file__).parent.parent.resolve() -genPath = thisPath / "src/stability_sdk/interfaces/gooseai/generation" -tensPath = thisPath / "src/stability_sdk/interfaces/src/tensorizer/tensors" -assert genPath.exists() -assert tensPath.exists() +import pytest -logger = logging.getLogger(__name__) -sys.path.extend([str(genPath), str(tensPath)]) +from concurrent import futures +from PIL import Image -import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation -import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc +from stability_sdk.api import generation_grpc # modified from https://github.com/justdoit0823/grpc-resolver/blob/master/tests/conftest.py @@ -38,3 +26,15 @@ def grpc_server(grpc_addr): server.start() yield server server.stop(0) + +@pytest.fixture(scope='module') +def impath() -> str: + return str(next(pathlib.Path('.').glob('**/tests/assets/*.png')).resolve()) + +@pytest.fixture(scope='module') +def pil_image(impath) -> Image.Image: + return Image.open(impath) + +@pytest.fixture(scope='module') +def vidpath() -> str: + return str(next(pathlib.Path('.').glob('**/tests/assets/*.mp4'))) diff --git a/tests/test_animator.py b/tests/test_animator.py new file mode 100644 index 00000000..304363fa --- /dev/null +++ b/tests/test_animator.py @@ -0,0 +1,74 @@ +import pytest + +from pathlib import Path + +from stability_sdk.animation import Animator, AnimationArgs +from stability_sdk.api import Context + +from .test_api import MockStub + +animation_prompts={0:"foo bar"} + +def test_init_animator(): + Animator( + Context(stub=MockStub()), + args=AnimationArgs(), + animation_prompts=animation_prompts, + ) + +def test_init_animator_prompts_notoptional(): + with pytest.raises(TypeError, match="missing 1 required positional argument: 'animation_prompts'"): + Animator( + Context(stub=MockStub()), + args=AnimationArgs(), + ) + +def test_save_settings(): + animator = Animator(Context(stub=MockStub()), args=AnimationArgs(), animation_prompts=animation_prompts) + animator.save_settings("settings.txt") + +def test_get_weights(): + animator = Animator(Context(stub=MockStub()), args=AnimationArgs(), animation_prompts=animation_prompts) + animator.get_animation_prompts_weights(frame_idx=0) + +def test_load_video(vidpath): + args = AnimationArgs() + args.animation_mode = 'Video Input' + args.video_init_path = 
vidpath + animator = Animator(Context(stub=MockStub()), args=args, animation_prompts=animation_prompts) + assert len(animator.prior_frames) > 0 + assert animator.video_prev_frame is not None + assert all([v is not None for v in animator.prior_frames]) + +@pytest.mark.parametrize('animation_mode', ['Video Input','2D','3D warp','3D render']) +def test_render(animation_mode, vidpath): + args = AnimationArgs() + args.animation_mode = animation_mode + args.video_init_path = vidpath + animator = Animator(Context(stub=MockStub()), args=args, animation_prompts=animation_prompts) + if animation_mode == 'Video Input': + print(len(animator.prior_frames)) + print([type(p) for p in animator.prior_frames]) + _ = animator.render() + +def test_init_image_none(): + animator = Animator(Context(stub=MockStub()), args=AnimationArgs(), animation_prompts=animation_prompts) + assert len(animator.prior_frames) == 0 + +def test_init_image_from_args(impath): + args = AnimationArgs() + args.init_image = impath + animator = Animator(Context(stub=MockStub()), args=args, animation_prompts=animation_prompts) + assert len(animator.prior_frames) == 1 + +def test_init_image_from_input(impath): + animator = Animator(Context(stub=MockStub()), args=AnimationArgs(), animation_prompts=animation_prompts) + print(animator.prior_frames) + assert len(animator.prior_frames) == 0 + animator.load_init_image() + assert len(animator.prior_frames) == 0 + assert Path(impath).exists() + animator.load_init_image(impath) + assert len(animator.prior_frames) == 1 + animator.set_cadence_mode(True) + assert len(animator.prior_frames) == 2 \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 00000000..b8a6c740 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,173 @@ +import io +import numpy as np +from PIL import Image +from typing import Generator + +import stability_sdk.matrix as matrix +from stability_sdk import utils +from stability_sdk.api import Context, generation + +def _artifact_from_image(image: Image.Image) -> generation.Artifact: + binary = utils.image_to_png_bytes(image) + return generation.Artifact( + type=generation.ARTIFACT_IMAGE, + mime="image/png", + binary=binary, + size=len(binary) + ) + +def _rand_image(width: int=512, height: int=512) -> Image.Image: + return Image.fromarray(np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)) + +class MockStub: + def __init__(self): + pass + + def ChainGenerate(self, chain: generation.ChainRequest, **kwargs) -> Generator[generation.Answer, None, None]: + # Not a full implementation of chaining, but enough to test current api.Context layer + artifacts = [] + for stage in chain.stage: + stage.request.MergeFrom(generation.Request(prompt=[generation.Prompt(artifact=a) for a in artifacts])) + artifacts = [] + for answer in self.Generate(stage.request): + artifacts.extend(answer.artifacts) + for artifact in artifacts: + yield generation.Answer(artifacts=[artifact]) + + def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]: + if request.HasField("image"): + image = _rand_image(request.image.width or 512, request.image.height or 512) + yield generation.Answer(artifacts=[_artifact_from_image(image)]) + + elif request.HasField("interpolate"): + assert len(request.prompt) == 2 + assert request.prompt[0].artifact.type == generation.ARTIFACT_IMAGE + assert request.prompt[1].artifact.type == generation.ARTIFACT_IMAGE + image_a = Image.open(io.BytesIO(request.prompt[0].artifact.binary)) + image_b 
= Image.open(io.BytesIO(request.prompt[1].artifact.binary)) + assert image_a.size == image_b.size + for ratio in request.interpolate.ratios: + tween = utils.image_mix(image_a, image_b, ratio) + yield generation.Answer(artifacts=[_artifact_from_image(tween)]) + + elif request.HasField("transform"): + assert len(request.prompt) >= 1 + + has_depth_input, has_tensor_input = False, False + for prompt in request.prompt: + if prompt.artifact.type == generation.ARTIFACT_DEPTH: + has_depth_input = True + elif prompt.artifact.type == generation.ARTIFACT_TENSOR: + has_tensor_input = True + + # 3D resample and camera pose require a depth or depth tensor artifact + if request.transform.HasField("resample") and len(request.transform.resample.transform.data) == 16: + assert has_depth_input or has_tensor_input + if request.transform.HasField("camera_pose"): + assert has_depth_input or has_tensor_input + + export_mask = request.transform.HasField("camera_pose") + if request.transform.HasField("resample"): + if request.transform.resample.HasField("export_mask"): + export_mask = request.transform.resample.export_mask + + for prompt in request.prompt: + if prompt.artifact.type == generation.ARTIFACT_IMAGE: + image = Image.open(io.BytesIO(prompt.artifact.binary)) + artifact = _artifact_from_image(image) + if request.transform.HasField("depth_calc"): + if request.requested_type == generation.ARTIFACT_TENSOR: + artifact.type = generation.ARTIFACT_TENSOR + else: + artifact.type = generation.ARTIFACT_DEPTH + yield generation.Answer(artifacts=[artifact]) + + if export_mask: + mask = _rand_image(image.width, image.height).convert("L") + artifact = _artifact_from_image(mask) + artifact.type = generation.ARTIFACT_MASK + yield generation.Answer(artifacts=[artifact]) + +def test_api_generate(): + api = Context(stub=MockStub()) + width, height = 512, 768 + results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height) + assert isinstance(results, dict) + assert generation.ARTIFACT_IMAGE in results + assert len(results[generation.ARTIFACT_IMAGE]) == 1 + image = results[generation.ARTIFACT_IMAGE][0] + assert isinstance(image, Image.Image) + assert image.size == (width, height) + +def test_api_inpaint(): + api = Context(stub=MockStub()) + width, height = 512, 768 + image = _rand_image(width, height) + mask = _rand_image(width, height).convert("L") + results = api.inpaint(image, mask, prompts=["foo bar"], weights=[1.0]) + assert generation.ARTIFACT_IMAGE in results + assert len(results[generation.ARTIFACT_IMAGE]) == 1 + image = results[generation.ARTIFACT_IMAGE][0] + assert isinstance(image, Image.Image) + assert image.size == (width, height) + +def test_api_interpolate(): + api = Context(stub=MockStub()) + width, height = 512, 768 + image_a = _rand_image(width, height) + image_b = _rand_image(width, height) + results = api.interpolate([image_a, image_b], [0.3, 0.5, 0.6]) + assert len(results) == 3 + for image in results: + assert isinstance(image, Image.Image) + assert image.size == (width, height) + +def test_api_transform_and_generate(): + api = Context(stub=MockStub()) + width, height = 512, 704 + init_image = _rand_image(width, height) + generate_request = api.generate(["a cute cat"], [1], width=width, height=height, + init_strength=0.65, return_request=True) + assert isinstance(generate_request, generation.Request) + image = api.transform_and_generate(init_image, [utils.color_adjust_transform()], generate_request) + assert isinstance(image, Image.Image) + assert image.size == (width, height) + 
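Note: the 3D transform tests in this file drive `transform_3d` with `matrix.identity` only. As a reference for reviewers, here is a minimal sketch of a hypothetical extra test (not part of this change) that composes a non-trivial world-to-view matrix from the helpers added in `stability_sdk/matrix.py` and runs it through both the resample and camera-pose paths. It is written as if added to this file, so it reuses the module's imports plus `MockStub` and `_rand_image`; the motion values are arbitrary illustrations.

```python
def test_api_transform_3d_composed_pose():
    api = Context(stub=MockStub())
    image = _rand_image()

    # Dolly back one unit and yaw slightly (0.087 rad is roughly 5 degrees);
    # values are hypothetical and only meant to show matrix composition.
    world_to_view = matrix.multiply(
        matrix.translation(0.0, 0.0, -1.0),
        matrix.rotation_euler(0.0, 0.087, 0.0),
    )
    depth = utils.depth_calc_transform(blend_weight=1.0)

    # Resample path: warp through the new pose and export the mask of newly exposed regions.
    resample = utils.resample_transform("reflect", world_to_view,
                                        prev_transform=matrix.identity, export_mask=True)
    images, masks = api.transform_3d([image], depth, resample)
    assert len(images) == 1 and len(masks) == 1

    # Camera-pose path: re-render the scene through a perspective camera instead.
    pose = utils.camera_pose_transform(world_to_view, 0.1, 100.0, 75.0)
    images, masks = api.transform_3d([image], depth, pose)
    assert len(images) == 1 and len(masks) == 1
```

Against the live service the same calls apply unchanged, with `Context(STABILITY_HOST, STABILITY_KEY)` in place of the stub.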
+def test_api_transform_camera_pose(): + api = Context(stub=MockStub()) + image = _rand_image() + xform = matrix.identity + pose = utils.camera_pose_transform( + xform, 0.1, 100.0, 75.0, + camera_type='perspective', + render_mode='mesh', + do_prefill=True + ) + images, masks = api.transform_3d([image], utils.depth_calc_transform(blend_weight=1.0), pose) + assert len(images) == 1 and len(masks) == 1 + assert isinstance(images[0], Image.Image) + assert isinstance(masks[0], Image.Image) + +def test_api_transform_color_adjust(): + api = Context(stub=MockStub()) + image = _rand_image() + images, masks = api.transform([image], utils.color_adjust_transform()) + assert len(images) == 1 and not masks + assert isinstance(images[0], Image.Image) + images, masks = api.transform([image, image], utils.color_adjust_transform()) + assert len(images) == 2 and not masks + +def test_api_transform_resample_3d(): + api = Context(stub=MockStub()) + image = _rand_image() + xform = matrix.identity + resample = utils.resample_transform('replicate', xform, xform, export_mask=True) + images, masks = api.transform_3d([image], utils.depth_calc_transform(blend_weight=0.5), resample) + assert len(images) == 1 and len(masks) == 1 + assert isinstance(images[0], Image.Image) + assert isinstance(masks[0], Image.Image) + +def test_api_upscale(): + api = Context(stub=MockStub()) + result = api.upscale(_rand_image()) + assert isinstance(result, Image.Image) diff --git a/tests/test_client.py b/tests/test_client.py index a17feb72..69b93ca8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,60 +1,28 @@ -import pytest from PIL import Image +from typing import Generator from stability_sdk import client -import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation -import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc - -import grpc - -# feel like we should be using this, not sure how/where -import grpc_testing - -from typing import Generator +from stability_sdk.api import generation -def test_client_import(): - from stability_sdk import client - assert True def test_StabilityInference_init(): - class_instance = client.StabilityInference(key='thisIsNotARealKey') + _ = client.StabilityInference(key='thisIsNotARealKey') assert True -def test_StabilityInference_init_nokey_error(): - try: - class_instance = client.StabilityInference() - assert False - except ValueError: - assert True - def test_StabilityInference_init_nokey_insecure_host(): - class_instance = client.StabilityInference(host='foo.bar.baz') + _ = client.StabilityInference(host='foo.bar.baz') assert True -def test_image_to_prompt(): - im = Image.new('RGB',(1,1)) - prompt = client.image_to_prompt(im, init=False, mask=False) - assert isinstance(prompt, generation.Prompt) - def test_image_to_prompt_init(): - im = Image.new('RGB',(1,1)) - prompt = client.image_to_prompt(im, init=True, mask=False) + im = Image.new('RGB', (1,1)) + prompt = client.image_to_prompt(im) assert isinstance(prompt, generation.Prompt) def test_image_to_prompt_mask(): - im = Image.new('RGB',(1,1)) - prompt = client.image_to_prompt(im, init=False, mask=True) + im = Image.new('RGB', (1,1)) + prompt = client.image_to_prompt(im, type=generation.ARTIFACT_MASK) assert isinstance(prompt, generation.Prompt) -def test_image_to_prompt_init_mask(): - im = Image.new('RGB',(1,1)) - try: - prompt = client.image_to_prompt(im, init=True, mask=True) - assert False - except ValueError: - assert True - - def test_server_mocking(grpc_server, grpc_addr): 
class_instance = client.StabilityInference(host=grpc_addr[0]) response = class_instance.generate(prompt="foo bar") diff --git a/tests/test_utils.py b/tests/test_utils.py index a9c909c5..88c375da 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,30 +1,74 @@ import pytest -import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation +from PIL import Image +from typing import ByteString + +import stability_sdk.matrix as matrix +from stability_sdk.api import generation from stability_sdk.utils import ( + BORDER_MODES, + COLOR_MATCH_MODES, + GUIDANCE_PRESETS, SAMPLERS, - artifact_type_to_str, - get_sampler_from_str, + artifact_type_to_string, + border_mode_from_string, + color_adjust_transform, + color_match_from_string, + depth_calc_transform, + guidance_from_string, + image_mix, + image_to_jpg_bytes, + image_to_png_bytes, + image_to_prompt, + resample_transform, + sampler_from_string, truncate_fit, ) +@pytest.mark.parametrize("border", BORDER_MODES.keys()) +def test_border_mode_from_str_2d_valid(border): + border_mode_from_string(s=border) + assert True + +def test_border_mode_from_str_2d_invalid(): + with pytest.raises(ValueError, match="invalid border mode"): + border_mode_from_string(s='not a real border mode') + @pytest.mark.parametrize("artifact_type", generation.ArtifactType.values()) def test_artifact_type_to_str_valid(artifact_type): - type_str = artifact_type_to_str(artifact_type) + type_str = artifact_type_to_string(artifact_type) assert type_str == generation.ArtifactType.Name(artifact_type) def test_artifact_type_to_str_invalid(): - type_str = artifact_type_to_str(-1) + type_str = artifact_type_to_string(-1) assert type_str == 'ARTIFACT_UNRECOGNIZED' @pytest.mark.parametrize("sampler_name", SAMPLERS.keys()) -def test_get_sampler_from_str_valid(sampler_name): - get_sampler_from_str(s=sampler_name) +def test_sampler_from_str_valid(sampler_name): + sampler_from_string(s=sampler_name) assert True -def test_get_sampler_from_str_invalid(): - with pytest.raises(ValueError, match="unknown sampler"): - get_sampler_from_str(s='not a real sampler') +def test_sampler_from_str_invalid(): + with pytest.raises(ValueError, match="invalid sampler"): + sampler_from_string(s='not a real sampler') + +@pytest.mark.parametrize("preset_name", GUIDANCE_PRESETS.keys()) +def test_guidance_from_string_valid(preset_name): + guidance_from_string(s=preset_name) + assert True + +def test_guidance_from_string_invalid(): + with pytest.raises(ValueError, match="invalid guidance preset"): + guidance_from_string(s='not a real preset') + +@pytest.mark.parametrize("color_match_mode", COLOR_MATCH_MODES.keys()) +def test_color_match_from_string_valid(color_match_mode): + color_match_from_string(s=color_match_mode) + assert True + +def test_color_match_from_string_invalid(): + with pytest.raises(ValueError, match="invalid color match"): + color_match_from_string(s='not a real color match mode') #################################### @@ -49,4 +93,82 @@ def test_truncate_fit1(): idx=0, max=22) assert outv == 'foo_ba_12345678_0.baz' - + + +#============================================================================== +# Image functions +#============================================================================== + +def test_image_mix(pil_image): + result = image_mix(img_a=pil_image, img_b=pil_image, ratio=0.5) + assert isinstance(result, Image.Image) + assert result.size == pil_image.size + result = image_mix(img_a=Image.new('L', (64,64), 0), img_b=Image.new('L', (64,64), 255), ratio=1.0) 
+ assert all(pixel_value == 255 for pixel_value in result.getdata()) + +def test_image_mix_mask(pil_image): + result = image_mix(img_a=pil_image, img_b=pil_image, ratio=pil_image.convert('L')) + assert isinstance(result, Image.Image) + assert result.size == pil_image.size + result = image_mix(img_a=Image.new('L', (64,64), 0), img_b=Image.new('L', (64,64), 255), ratio=Image.new('L', (64,64), 255)) + assert all(pixel_value == 255 for pixel_value in result.getdata()) + +def test_image_to_jpg_bytes(pil_image): + result = image_to_jpg_bytes(pil_image) + assert isinstance(result, ByteString) + +def test_image_to_png_bytes(pil_image): + result = image_to_png_bytes(image=pil_image) + assert isinstance(result, ByteString) + +def test_image_to_prompt(pil_image): + result = image_to_prompt(pil_image) + assert isinstance(result, generation.Prompt) + assert result.artifact.type == generation.ARTIFACT_IMAGE + +def test_image_to_prompt_mask(pil_image): + result = image_to_prompt(pil_image, type=generation.ARTIFACT_MASK) + assert isinstance(result, generation.Prompt) + assert result.artifact.type == generation.ARTIFACT_MASK + + +#============================================================================== +# Transform functions +#============================================================================== + +@pytest.mark.parametrize("color_mode", COLOR_MATCH_MODES.keys()) +def test_colormatch_valid(pil_image, color_mode): + op = color_adjust_transform( + match_image=pil_image, + match_mode=color_mode + ) + assert isinstance(op, generation.TransformParameters) + +def test_colormatch_invalid(pil_image): + with pytest.raises(ValueError, match="invalid color match"): + _ = color_adjust_transform( + match_image=pil_image, + match_mode="not a real color match mode", + ) + +@pytest.mark.parametrize("border_mode", BORDER_MODES.keys()) +def test_resample_valid(border_mode): + op = resample_transform( + border_mode=border_mode, + transform=matrix.identity, + prev_transform=matrix.identity, + depth_warp=1.0, + export_mask=False + ) + assert isinstance(op, generation.TransformParameters) + +@pytest.mark.parametrize("border_mode", ['not a border mode']) +def test_resample_invalid(border_mode): + with pytest.raises(ValueError, match="invalid border mode"): + _ = resample_transform( + border_mode=border_mode, + transform=matrix.identity, + prev_transform=matrix.identity, + depth_warp=1.0, + export_mask=False + )
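Closing note on the new utilities: the transform helpers and the ffmpeg wrappers added to `utils.py` compose into a simple frame-by-frame post-processing loop. The sketch below is illustrative only — the input path, saturation value, and output names are hypothetical, `ffmpeg` must be on the PATH, and it talks to the live API via `Context(host, key)`, so a valid `STABILITY_KEY` is assumed.

```python
import os

from PIL import Image

from stability_sdk import utils
from stability_sdk.api import Context

# Assumes STABILITY_KEY is set; host falls back to the same default the CLI uses.
api = Context(os.getenv("STABILITY_HOST", "grpc.stability.ai:443"), os.environ["STABILITY_KEY"])

# Split a clip into frames (hypothetical input path; requires ffmpeg on PATH).
frames_dir = utils.extract_frames_from_video("input.mp4")

# Desaturate each frame with the server-side color_adjust transform.
op = utils.color_adjust_transform(saturation=0.5)
out_dir = os.path.join(os.path.dirname(frames_dir), "graded")
os.makedirs(out_dir, exist_ok=True)
for idx, name in enumerate(sorted(os.listdir(frames_dir))):
    frame = Image.open(os.path.join(frames_dir, name))
    images, _ = api.transform([frame], op)
    images[0].save(os.path.join(out_dir, f"frame_{idx:05d}.png"))

# Re-encode; create_video_from_frames expects frame_00000.png, frame_00001.png, ...
utils.create_video_from_frames(out_dir, "graded.mp4", fps=24)
```

Because the request carries only a `color_adjust` transform, `_adjust_request_engine` routes it to the "-cpu" variant of the transform engine, as shown earlier in `api.py`.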