diff --git a/.github/workflows/unit_testing.yaml b/.github/workflows/unit_testing.yaml
index b3004e69..9f94c1e1 100644
--- a/.github/workflows/unit_testing.yaml
+++ b/.github/workflows/unit_testing.yaml
@@ -21,13 +21,14 @@ jobs:
cache: 'pip'
cache-dependency-path: |
**/setup.py
+
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Install package
run: |
- pip install .[dev]
+ pip install .[dev,anim]
- name: Test with pytest
run: |
pytest -s --ignore=src/stability_sdk/interfaces
diff --git a/.gitignore b/.gitignore
index 26e36ec4..1f255a33 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,8 +1,10 @@
*.pyc
*.egg-info
-*.png
+.vscode/
+build/
dist/
pyenv/
*venv/
.env
generation-*.pb.json
+Pipfile*
\ No newline at end of file
diff --git a/README.md b/README.md
index 358f93d6..39787fe0 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,14 @@ It will generate and put PNGs in your current directory.
To upscale:
`python3 -m stability_sdk upscale -i "/path/to/image.png"`
+## Animation UI
+
+Install with
+`pip install stability-sdk[anim_ui]`
+
+Then run with
+`python3 -m stability_sdk animate --gui`
+
## SDK Usage
Be sure to check out [Platform](https://platform.stability.ai) for comprehensive documentation on how to interact with our API.
@@ -57,7 +65,7 @@ options:
--width WIDTH, -W WIDTH
[512] width of image
--start_schedule START_SCHEDULE
- [0.5] start schedule for init image (must be greater than 0, 1 is full strength
+ [0.5] start schedule for init image (must be greater than 0; 1 is full strength
text prompt, no trace of image)
--end_schedule END_SCHEDULE
[0.01] end schedule for init image
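The animation entry points added in this diff can also be driven directly from Python. A minimal sketch mirroring the calls in `nbs/animation.ipynb` below (the API key and output directory are placeholders):

```python
# Minimal sketch of programmatic use, mirroring nbs/animation.ipynb.
# The API key and output directory are placeholders.
from stability_sdk.api import Context
from stability_sdk.animation import AnimationArgs, Animator

context = Context("grpc.stability.ai:443", "YOUR-API-KEY")
args = AnimationArgs(animation_mode="2D", max_frames=24, zoom="0:(1.02)")

animator = Animator(
    api_context=context,
    animation_prompts={0: "a painting of a delicious cheeseburger"},
    args=args,
    out_dir="./frames",
)
for frame in animator.render():
    pass  # frames are also written to ./frames as frame_00000.png, frame_00001.png, ...
```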
diff --git a/nbs/animation.ipynb b/nbs/animation.ipynb
new file mode 100644
index 00000000..16f58781
--- /dev/null
+++ b/nbs/animation.ipynb
@@ -0,0 +1,303 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rWXrHouW3pq_"
+ },
+ "source": [
+ "# Animation SDK example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "jCZ-IphH3prD"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Mount Google Drive\n",
+ "try:\n",
+ " from google.colab import drive\n",
+ " drive.mount('/content/gdrive')\n",
+ " outputs_path = \"/content/gdrive/MyDrive/AI/StableAnimation\"\n",
+ " !mkdir -p $outputs_path\n",
+ "except ImportError:\n",
+ " outputs_path = \".\"\n",
+ "print(f\"Animations will be saved to {outputs_path}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "cellView": "form",
+ "id": "zj56t6tc3prF"
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "#@title Connect to the Stability API\n",
+ "\n",
+ "# install Stability Animation SDK for Python\n",
+ "%pip install stability-sdk[anim]\n",
+ "\n",
+ "import datetime\n",
+ "import json\n",
+ "import os\n",
+ "import panel as pn\n",
+ "import param\n",
+ "import shutil\n",
+ "import sys\n",
+ "\n",
+ "from base64 import b64encode\n",
+ "from IPython import display\n",
+ "from pathlib import Path\n",
+ "from PIL import Image\n",
+ "from tqdm import tqdm\n",
+ "from types import SimpleNamespace\n",
+ "\n",
+ "from stability_sdk.api import Context\n",
+ "from stability_sdk.animation import AnimationArgs, Animator\n",
+ "from stability_sdk.utils import create_video_from_frames\n",
+ "\n",
+ "\n",
+ "# Enter your API key from dreamstudio.ai\n",
+ "STABILITY_HOST = \"grpc.stability.ai:443\" #@param {type:\"string\"}\n",
+ "STABILITY_KEY = \"\" #@param {type:\"string\"}\n",
+ "\n",
+ "# Connect to Stability API\n",
+ "api_context = Context(STABILITY_HOST, STABILITY_KEY)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "ldUAFmur3prH"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Settings\n",
+ "\n",
+ "# @markdown Run this cell to reveal the settings UI. After entering values, move on to the next step.\n",
+ "\n",
+ "# @markdown To reset values to default, simply re-run this cell.\n",
+ "\n",
+ "# @markdown NB: Settings are grouped across several tabs.\n",
+ "\n",
+ "show_documentation = True # @param {type:'boolean'}\n",
+ "\n",
+ "# #@markdown ####**Resume:**\n",
+ "resume_timestring = \"\" #@param {type:\"string\"}\n",
+ "\n",
+ "#@markdown ####**Override Settings:**\n",
+ "override_settings_path = \"\" #@param {type:\"string\"}\n",
+ "\n",
+ "###################\n",
+ "\n",
+ "from stability_sdk.animation import (\n",
+ " AnimationArgs,\n",
+ " Animator,\n",
+ " AnimationSettings,\n",
+ " BasicSettings,\n",
+ " CoherenceSettings,\n",
+ " ColorSettings,\n",
+ " DepthSettings,\n",
+ " InpaintingSettings,\n",
+ " Rendering3dSettings,\n",
+ " CameraSettings,\n",
+ " VideoInputSettings,\n",
+ " VideoOutputSettings,\n",
+ ")\n",
+ "\n",
+ "args_generation = BasicSettings()\n",
+ "args_animation = AnimationSettings()\n",
+ "args_camera = CameraSettings()\n",
+ "args_coherence = CoherenceSettings()\n",
+ "args_color = ColorSettings()\n",
+ "args_depth = DepthSettings()\n",
+ "args_render_3d = Rendering3dSettings()\n",
+ "args_inpaint = InpaintingSettings()\n",
+ "args_vid_in = VideoInputSettings()\n",
+ "args_vid_out = VideoOutputSettings()\n",
+ "arg_objs = (\n",
+ " args_generation,\n",
+ " args_animation,\n",
+ " args_camera,\n",
+ " args_coherence,\n",
+ " args_color,\n",
+ " args_depth,\n",
+ " args_render_3d,\n",
+ " args_inpaint,\n",
+ " args_vid_in,\n",
+ " args_vid_out,\n",
+ ")\n",
+ "\n",
+ "def _show_docs(component):\n",
+ " cols = []\n",
+ " for k, v in component.param.objects().items():\n",
+ " if k == 'name':\n",
+ " continue\n",
+ " col = pn.Column(v, v.doc)\n",
+ " cols.append(col)\n",
+ " return pn.Column(*cols)\n",
+ "\n",
+ "def build(component):\n",
+ " if show_documentation:\n",
+ " component = _show_docs(component)\n",
+ " return pn.Row(component, width=1000)\n",
+ "\n",
+ "pn.extension()\n",
+ "\n",
+ "pn.Tabs(*[(a.name[:-5], build(a)) for a in arg_objs])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_SudvbZG3prI"
+ },
+ "source": [
+ "### Prompts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "FT9slDSw3prJ"
+ },
+ "outputs": [],
+ "source": [
+ "animation_prompts = {\n",
+ " 0: \"a painting of a delicious cheeseburger\",\n",
+ "    24: \"a painting of the answer to life the universe and everything\",\n",
+ "}\n",
+ "\n",
+ "negative_prompt = \"\"\n",
+ "negative_prompt_weight = -1.0\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "Rpqv6t303prJ"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Render the animation\n",
+ "\n",
+ "args_d = {}\n",
+ "for a in arg_objs:\n",
+ "    args_d.update(a.param.values())\n",
+ "args = AnimationArgs(**args_d)\n",
+ "\n",
+ "\n",
+ "# load override settings if provided\n",
+ "if override_settings_path:\n",
+ " if not os.path.exists(override_settings_path):\n",
+ " raise ValueError(f\"Override settings file not found: {override_settings_path}\")\n",
+ " with open(override_settings_path, 'r') as f:\n",
+ " overrides = json.load(f)\n",
+ " args = vars(args)\n",
+ " for k in args.keys():\n",
+ " if k in overrides:\n",
+ " args[k] = overrides[k]\n",
+ " args = SimpleNamespace(**args)\n",
+ " animation_prompts = overrides.get('animation_prompts', animation_prompts)\n",
+ " animation_prompts = {int(k): v for k, v in animation_prompts.items()}\n",
+ " negative_prompt = overrides.get('negative_prompt', negative_prompt)\n",
+ " negative_prompt_weight = overrides.get('negative_prompt_weight', negative_prompt_weight)\n",
+ "\n",
+ "# create folder for frames output\n",
+ "if resume_timestring:\n",
+ " out_dir = os.path.join(outputs_path, resume_timestring)\n",
+ " if not os.path.exists(out_dir):\n",
+ "        raise Exception(f\"Can't resume {resume_timestring} because path {out_dir} doesn't exist. Please make sure the timestring is correct.\")\n",
+ " timestring = resume_timestring\n",
+ "else:\n",
+ " timestring = datetime.datetime.now().strftime('%Y%m%d%H%M%S')\n",
+ " out_dir = os.path.join(outputs_path, timestring)\n",
+ " os.makedirs(out_dir, exist_ok=True)\n",
+ "print(f\"Saving animation frames to {out_dir}...\")\n",
+ "\n",
+ "animator = Animator(\n",
+ " api_context=api_context,\n",
+ " animation_prompts=animation_prompts,\n",
+ " args=args,\n",
+ " out_dir=out_dir, \n",
+ " negative_prompt=negative_prompt,\n",
+ " negative_prompt_weight=negative_prompt_weight,\n",
+ " resume=len(resume_timestring) != 0,\n",
+ ")\n",
+ "animator.save_settings(f\"{timestring}_settings.txt\")\n",
+ "\n",
+ "for frame in tqdm(animator.render(), initial=animator.start_frame_idx, total=args.max_frames):\n",
+ " display.clear_output(wait=True)\n",
+ " display.display(frame)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "aWhJnLNX3prL"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Create video from frames\n",
+ "skip_video_for_run_all = False #@param {type: 'boolean'}\n",
+ "fps = 12 #@param {type:\"number\"}\n",
+ "\n",
+ "if skip_video_for_run_all:\n",
+ " print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n",
+ "else:\n",
+ " mp4_path = os.path.join(out_dir, f\"{timestring}.mp4\")\n",
+ " print(f\"Compiling animation frames to {mp4_path}...\")\n",
+ " create_video_from_frames(out_dir, mp4_path, fps)\n",
+ "\n",
+ "    with open(mp4_path, 'rb') as f:\n",
+ "        mp4 = f.read()\n",
+ "    data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+ "    display.display(display.HTML(f'<video controls loop><source src=\"{data_url}\" type=\"video/mp4\"></video>'))"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "client",
+ "language": "python",
+ "name": "client"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.5"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "fb02550c4ef2b9a37ba5f7f381e893a74079cea154f791601856f87ae67cf67c"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
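For reference, the `override_settings_path` consumed by the render cell above points to a plain JSON file whose keys mirror the `AnimationArgs` fields plus the prompt entries the cell reads back. A small illustrative sketch of producing one (all values are placeholders):

```python
# Illustrative sketch: write an override settings file for the render cell above.
# Keys mirror AnimationArgs fields plus the prompt entries the cell reads;
# the values here are placeholders.
import json

overrides = {
    "animation_mode": "2D",
    "max_frames": 48,
    "zoom": "0:(1.02)",
    "animation_prompts": {"0": "a painting of a delicious cheeseburger"},
    "negative_prompt": "",
    "negative_prompt_weight": -1.0,
}
with open("overrides.json", "w") as f:
    json.dump(overrides, f, indent=4)
```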
diff --git a/nbs/animation_gradio.ipynb b/nbs/animation_gradio.ipynb
new file mode 100644
index 00000000..c38a0db6
--- /dev/null
+++ b/nbs/animation_gradio.ipynb
@@ -0,0 +1,110 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tPZFjbUDTwYE"
+ },
+ "source": [
+ "# Stable Animation notebook"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "LUMF8i8BTwYH",
+ "outputId": "f61c635e-bc57-48ab-cc3d-5166286b158f"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Mount Google Drive\n",
+ "import os\n",
+ "try:\n",
+ " from google.colab import drive\n",
+ " drive.mount('/content/gdrive')\n",
+ " outputs_path = \"/content/gdrive/MyDrive/AI/StableAnimation\"\n",
+ " os.makedirs(outputs_path, exist_ok=True)\n",
+ "except ImportError:\n",
+ " outputs_path = \".\"\n",
+ "print(f\"Animations will be saved to {outputs_path}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "tgZBrk8DTwYI"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Install Animation SDK and connect to the Stability API\n",
+ "%pip install stability-sdk[anim_ui]\n",
+ "\n",
+ "from stability_sdk.api import Context\n",
+ "from stability_sdk.animation_ui import create_ui\n",
+ "\n",
+ "# Enter your API key from dreamstudio.ai\n",
+ "STABILITY_HOST = \"grpc.stability.ai:443\" #@param {type:\"string\"}\n",
+ "STABILITY_KEY = \"\" #@param {type:\"string\"}\n",
+ "\n",
+ "# Connect to Stability API\n",
+ "api_context = Context(STABILITY_HOST, STABILITY_KEY)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "cellView": "form",
+ "id": "QVCAr8xcTwYI"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Animation UI\n",
+ "show_ui_in_notebook = True #@param {type:\"boolean\"}\n",
+ "\n",
+ "ui = create_ui(api_context, outputs_path)\n",
+ "\n",
+ "ui.queue(concurrency_count=2, max_size=2)\n",
+ "ui.launch(show_api=False, debug=True, inline=show_ui_in_notebook, height=768, share=True, show_error=True)"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.5"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "fb02550c4ef2b9a37ba5f7f381e893a74079cea154f791601856f87ae67cf67c"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
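The UI cell above targets notebooks; a minimal sketch of launching the same Gradio UI from a plain Python script (assuming the API key is supplied via the `STABILITY_KEY` environment variable):

```python
# Minimal sketch: launch the animation UI from a plain Python script.
# Assumes the API key is provided via the STABILITY_KEY environment variable.
import os

from stability_sdk.api import Context
from stability_sdk.animation_ui import create_ui

context = Context("grpc.stability.ai:443", os.environ["STABILITY_KEY"])
ui = create_ui(context, "./animations")
ui.launch(show_api=False, show_error=True)
```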
diff --git a/setup.py b/setup.py
index 0854938a..7cd124e5 100644
--- a/setup.py
+++ b/setup.py
@@ -1,42 +1,56 @@
# fmt: off
-from setuptools import setup, find_namespace_packages
+from setuptools import (
+ setup,
+ find_namespace_packages,
+)
with open('README.md','r') as f:
README = f.read()
setup(
name='stability-sdk',
- version='0.7.0',
+ version='0.8.0',
author='Stability AI',
author_email='support@stability.ai',
maintainer='Stability AI',
maintainer_email='support@stability.ai',
url='https://beta.dreamstudio.ai/',
download_url='https://github.com/Stability-AI/stability-sdk/',
-
description='Python SDK for interacting with stability.ai APIs',
long_description=README,
long_description_content_type="text/markdown",
-
install_requires=[
'Pillow',
'grpcio==1.53.0',
'grpcio-tools==1.53.0',
'python-dotenv',
+ 'param',
'protobuf==4.21.12'
],
extras_require={
'dev': [
'pytest',
'grpcio-testing'
- ]},
+ ],
+ 'anim': [
+ 'keyframed',
+ 'numpy',
+ 'opencv-python-headless',
+ ],
+ 'anim_ui': [
+ 'keyframed',
+ 'gradio',
+ 'numpy',
+ 'opencv-python-headless',
+ 'tqdm',
+ ]
+ },
packages=find_namespace_packages(
where='src',
include=['stability_sdk*'],
),
package_dir = {"": "src"},
-
classifiers=[
'Intended Audience :: Developers',
'Intended Audience :: Education',
diff --git a/src/stability_sdk/__init__.py b/src/stability_sdk/__init__.py
index 29430654..f9a14d1c 100644
--- a/src/stability_sdk/__init__.py
+++ b/src/stability_sdk/__init__.py
@@ -2,8 +2,10 @@
# this is necessary because of how the auto-generated code constructs its imports
# should be a way to move this upstream
-thisPath = pathlib.Path(__file__).parent.resolve()
-genPath = thisPath / "interfaces/gooseai/generation"
-tensPath = thisPath / "interfaces/src/tensorizer/tensors"
-#sys.path.append(str(genPath))
-sys.path.extend([str(genPath), str(tensPath)])
\ No newline at end of file
+this_path = pathlib.Path(__file__).parent.resolve()
+sys.path.extend([
+ str(this_path / "interfaces/gooseai/dashboard"),
+ str(this_path / "interfaces/gooseai/generation"),
+ str(this_path / "interfaces/gooseai/project"),
+ str(this_path / "interfaces/src/tensorizer/tensors")
+])
\ No newline at end of file
diff --git a/src/stability_sdk/animation.py b/src/stability_sdk/animation.py
new file mode 100644
index 00000000..5f183749
--- /dev/null
+++ b/src/stability_sdk/animation.py
@@ -0,0 +1,1131 @@
+import base64
+import bisect
+import cv2
+import glob
+import json
+import logging
+import math
+import numpy as np
+import os
+import param
+import random
+import shutil
+import subprocess
+
+from collections import OrderedDict, deque
+from dataclasses import dataclass, fields
+from keyframed.dsl import curve_from_cn_string
+from PIL import Image, ImageOps
+from types import SimpleNamespace
+from typing import Callable, cast, Deque, Dict, Generator, List, Optional, Tuple, Union
+
+from stability_sdk.api import Context, generation
+from stability_sdk.utils import (
+ camera_pose_transform,
+ color_adjust_transform,
+ depth_calc_transform,
+ guidance_from_string,
+ image_mix,
+ image_to_png_bytes,
+ interpolate_mode_from_string,
+ resample_transform,
+ sampler_from_string,
+)
+import stability_sdk.matrix as matrix
+
+logger = logging.getLogger(__name__)
+logger.setLevel(level=logging.INFO)
+
+DEFAULT_MODEL = 'stable-diffusion-v1-5'
+TRANSLATION_SCALE = 1.0/200.0 # matches Disco and Deforum
+
+docstring_bordermode = (
+ "Method that will be used to fill empty regions, e.g. after a rotation transform."
+ "\n\t* reflect - Mirror pixels across the image edge to fill empty regions."
+ "\n\t* replicate - Use closest pixel values (default)."
+ "\n\t* wrap - Treat image borders as if they were connected, i.e. use pixels from left edge to fill empty regions touching the right edge."
+ "\n\t* zero - Fill empty regions with black pixels."
+ "\n\t* prefill - Do simple inpainting over empty regions."
+)
+
+class BasicSettings(param.Parameterized):
+ width = param.Integer(default=512, doc="Output image dimensions. Will be resized to a multiple of 64.")
+ height = param.Integer(default=512, doc="Output image dimensions. Will be resized to a multiple of 64.")
+ sampler = param.Selector(
+ default='K_dpmpp_2m',
+ objects=[
+ "DDIM", "PLMS", "K_euler", "K_euler_ancestral", "K_heun", "K_dpm_2",
+ "K_dpm_2_ancestral", "K_lms", "K_dpmpp_2m", "K_dpmpp_2s_ancestral"
+ ]
+ )
+ model = param.Selector(
+ default=DEFAULT_MODEL,
+ check_on_set=False, # allow old and new models without raising ValueError
+ objects=[
+ "stable-diffusion-v1-5", "stable-diffusion-512-v2-1", "stable-diffusion-768-v2-1",
+ "stable-diffusion-depth-v2-0", "stable-diffusion-xl-beta-v2-2-2",
+ "custom"
+ ]
+ )
+ custom_model = param.String(default="", doc="Identifier of custom model to use.")
+ seed = param.Integer(default=-1, doc="Provide a seed value for more deterministic behavior. Negative seed values will be replaced with a random seed (default).")
+ cfg_scale = param.Number(default=7, softbounds=(0,20), doc="Classifier-free guidance scale. Strength of prompt influence on denoising process. `cfg_scale=0` gives unconditioned sampling.")
+ clip_guidance = param.Selector(default='None', objects=["None", "Simple", "FastBlue", "FastGreen"], doc="CLIP-guidance preset.")
+ init_image = param.String(default='', doc="Path to image. Height and width dimensions will be inherited from image.")
+ init_sizing = param.Selector(default='stretch', objects=["cover", "stretch", "resize-canvas"])
+ mask_path = param.String(default="", doc="Path to image or video mask")
+ mask_invert = param.Boolean(default=False, doc="White in mask marks areas to change by default.")
+ preset = param.Selector(
+ default='None',
+ objects=[
+ 'None', '3d-model', 'analog-film', 'anime', 'cinematic', 'comic-book', 'digital-art',
+ 'enhance', 'fantasy-art', 'isometric', 'line-art', 'low-poly', 'modeling-compound',
+ 'neon-punk', 'origami', 'photographic', 'pixel-art',
+ ]
+ )
+
+class AnimationSettings(param.Parameterized):
+ animation_mode = param.Selector(default='3D warp', objects=['2D', '3D warp', '3D render', 'Video Input'])
+ max_frames = param.Integer(default=72, doc="Force stop of animation job after this many frames are generated.")
+ border = param.Selector(default='replicate', objects=['reflect', 'replicate', 'wrap', 'zero', 'prefill'], doc=docstring_bordermode)
+ noise_add_curve = param.String(default="0:(0.02)")
+ noise_scale_curve = param.String(default="0:(0.99)")
+    strength_curve = param.String(default="0:(0.65)", doc="Image strength of the init image relative to the prompt. 0 ignores the init image and attends only to the prompt; 1 returns the init image unmodified.")
+ steps_curve = param.String(default="0:(30)", doc="Diffusion steps")
+    steps_strength_adj = param.Boolean(default=False, doc="Adjusts the number of diffusion steps based on the current frame's init image strength.")
+ interpolate_prompts = param.Boolean(default=False, doc="Smoothly interpolate prompts between keyframes. Defaults to False")
+ locked_seed = param.Boolean(default=False)
+
+class CameraSettings(param.Parameterized):
+ """
+    See disco/deforum keyframing syntax, originally developed by Chigozie Nri.
+    General syntax: "f1:(v1), f2:(v2), f3:(v3), ..."
+ Values between intermediate keyframes will be linearly interpolated by default to produce smooth transitions.
+ For abrupt transitions, specify values at adjacent keyframes.
+ """
+ angle = param.String(default="0:(0)", doc="Camera rotation angle in degrees for 2D mode")
+ zoom = param.String(default="0:(1)", doc="Camera zoom factor for 2D mode (<1 zooms out, >1 zooms in)")
+ translation_x = param.String(default="0:(0)")
+ translation_y = param.String(default="0:(0)")
+ translation_z = param.String(default="0:(0)")
+ rotation_x = param.String(default="0:(0)", doc="Camera rotation around X-axis in degrees for 3D modes")
+ rotation_y = param.String(default="0:(0)", doc="Camera rotation around Y-axis in degrees for 3D modes")
+ rotation_z = param.String(default="0:(0)", doc="Camera rotation around Z-axis in degrees for 3D modes")
+
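+# Illustrative example (not part of the SDK): keyframe strings are expanded to
+# per-frame series via curve_from_cn_string, e.g. "0:(0), 24:(180)" produces a
+# series that linearly interpolates from 0 at frame 0 to 180 at frame 24
+# (index 12 evaluates to 90.0).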
+
+class CoherenceSettings(param.Parameterized):
+ diffusion_cadence_curve = param.String(default="0:(1)", doc="One greater than the number of frames between diffusion operations. A cadence of 1 performs diffusion on each frame. Values greater than one will generate frames using interpolation methods.")
+ cadence_interp = param.Selector(default='mix', objects=['film', 'mix', 'rife', 'vae-lerp', 'vae-slerp'])
+ cadence_spans = param.Boolean(default=False, doc="Experimental diffusion cadence mode for better outpainting")
+
+
+class ColorSettings(param.Parameterized):
+ color_coherence = param.Selector(default='LAB', objects=['None', 'HSV', 'LAB', 'RGB'], doc="Color space that will be used for inter-frame color adjustments.")
+ brightness_curve = param.String(default="0:(1.0)")
+ contrast_curve = param.String(default="0:(1.0)")
+ hue_curve = param.String(default="0:(0.0)")
+ saturation_curve = param.String(default="0:(1.0)")
+ lightness_curve = param.String(default="0:(0.0)")
+ color_match_animate = param.Boolean(default=True, doc="Animate color match between key frames.")
+
+
+class DepthSettings(param.Parameterized):
+ depth_model_weight = param.Number(default=0.3, softbounds=(0,1), doc="Blend factor between AdaBins and MiDaS depth models.")
+ near_plane = param.Number(default=200, doc="Distance to nearest plane of camera view volume.")
+ far_plane = param.Number(default=10000, doc="Distance to furthest plane of camera view volume.")
+ fov_curve = param.String(default="0:(25)", doc="FOV angle of camera volume in degrees.")
+ depth_blur_curve = param.String(default="0:(0.0)", doc="Blur strength of depth map.")
+ depth_warp_curve = param.String(default="0:(1.0)", doc="Depth warp strength.")
+ save_depth_maps = param.Boolean(default=False)
+
+
+class Rendering3dSettings(param.Parameterized):
+ camera_type = param.Selector(default='perspective', objects=['perspective', 'orthographic'])
+ render_mode = param.Selector(default='mesh', objects=['mesh', 'pointcloud'], doc="Mode for image and mask rendering. 'pointcloud' is a bit faster, but 'mesh' is more stable")
+    mask_power = param.Number(default=0.3, softbounds=(0, 4), doc="Raises each mask value in (0, 1) to this power. The higher the value, the more change is applied to the nearest objects.")
+
+class InpaintingSettings(param.Parameterized):
+ use_inpainting_model = param.Boolean(default=False, doc="If True, inpainting will be performed using dedicated inpainting model. If False, inpainting will be performed with the regular model that is selected")
+ inpaint_border = param.Boolean(default=False, doc="Use inpainting on top of border regions for 2D and 3D warp modes. Defaults to False")
+ mask_min_value = param.String(default="0:(0.25)", doc="Mask postprocessing for non-inpainting model. Mask floor values will be clipped by this value prior to inpainting")
+    mask_binarization_thr = param.Number(default=0.5, softbounds=(0,1), doc="Grayscale mask values below this threshold will be set to 0; values above it will be set to 1.")
+ save_inpaint_masks = param.Boolean(default=False)
+
+class VideoInputSettings(param.Parameterized):
+ video_init_path = param.String(default="", doc="Path to video input")
+ extract_nth_frame = param.Integer(default=1, bounds=(1,None), doc="Only use every Nth frame of the video")
+ video_mix_in_curve = param.String(default="0:(0.02)")
+ video_flow_warp = param.Boolean(default=True, doc="Whether or not to transfer the optical flow from the video to the generated animation as a warp effect.")
+
+class VideoOutputSettings(param.Parameterized):
+ fps = param.Integer(default=12, doc="Frame rate to use when generating video output.")
+ reverse = param.Boolean(default=False, doc="Whether to reverse the output video or not.")
+
+class AnimationArgs(
+ BasicSettings,
+ AnimationSettings,
+ CameraSettings,
+ CoherenceSettings,
+ ColorSettings,
+ DepthSettings,
+ Rendering3dSettings,
+ InpaintingSettings,
+ VideoInputSettings,
+ VideoOutputSettings
+):
+ """
+    Aggregates parameters from all of the settings classes.
+ """
+
+@dataclass
+class FrameArgs:
+ """Expansion of key framed Args to per-frame values"""
+ angle: List[float]
+ zoom: List[float]
+ translation_x: List[float]
+ translation_y: List[float]
+ translation_z: List[float]
+ rotation_x: List[float]
+ rotation_y: List[float]
+ rotation_z: List[float]
+ brightness_curve: List[float]
+ contrast_curve: List[float]
+ hue_curve: List[float]
+ saturation_curve: List[float]
+ lightness_curve: List[float]
+ noise_add_curve: List[float]
+ noise_scale_curve: List[float]
+ steps_curve: List[float]
+ strength_curve: List[float]
+ diffusion_cadence_curve: List[float]
+ fov_curve: List[float]
+ depth_blur_curve: List[float]
+ depth_warp_curve: List[float]
+ video_mix_in_curve: List[float]
+ mask_min_value: List[float]
+
+
+def args_to_dict(args):
+ """
+ Converts arguments object to an OrderedDict
+ """
+ if isinstance(args, param.Parameterized):
+ return OrderedDict(args.param.values())
+ elif isinstance(args, SimpleNamespace):
+ return OrderedDict(vars(args))
+ else:
+ raise NotImplementedError(f"Unsupported arguments object type: {type(args)}")
+
+def cv2_to_pil(cv2_img: np.ndarray) -> Image.Image:
+ """Convert a cv2 BGR ndarray to a PIL Image"""
+ return Image.fromarray(cv2_img[:, :, ::-1])
+
+def interpolate_frames(
+ context: Context,
+ frames_path: str,
+ out_path: str,
+ interp_mode: generation.InterpolateMode,
+ interp_factor: int
+) -> Generator[Image.Image, None, None]:
+ """Interpolates frames in a directory using the specified interpolation mode."""
+ assert interp_factor > 1, "Interpolation factor must be greater than 1"
+
+ # gather source frames
+ frame_files = glob.glob(os.path.join(frames_path, "frame_*.png"))
+ frame_files.sort()
+
+ # perform frame interpolation
+ os.makedirs(out_path, exist_ok=True)
+ ratios = np.linspace(0, 1, interp_factor+1)[1:-1].tolist()
+ for i in range(len(frame_files) - 1):
+ shutil.copy(frame_files[i], os.path.join(out_path, f"frame_{i * interp_factor:05d}.png"))
+ frame1 = Image.open(frame_files[i])
+ frame2 = Image.open(frame_files[i + 1])
+ yield frame1
+ tweens = context.interpolate([frame1, frame2], ratios, interp_mode)
+ for ti, tween in enumerate(tweens):
+ tween.save(os.path.join(out_path, f"frame_{i * interp_factor + ti + 1:05d}.png"))
+ yield tween
+
+ # copy final frame
+ shutil.copy(frame_files[-1], os.path.join(out_path, f"frame_{(len(frame_files)-1) * interp_factor:05d}.png"))
+
+def mask_erode_blur(mask: Image.Image, mask_erode: int, mask_blur: int) -> Image.Image:
+ mask = np.array(mask)
+ if mask_erode > 0:
+ ks = mask_erode*2 + 1
+ mask = cv2.erode(mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ks, ks)), iterations=1)
+ if mask_blur > 0:
+ ks = mask_blur*2 + 1
+ mask = cv2.GaussianBlur(mask, (ks, ks), 0)
+ return Image.fromarray(mask)
+
+def make_xform_2d(
+ w: float, h: float,
+ rotation_angle: float, # in radians
+ scale_factor: float,
+ translate_x: float,
+ translate_y: float,
+) -> matrix.Matrix:
+ center = (w / 2, h / 2)
+ pre = matrix.translation(-center[0], -center[1], 0)
+ post = matrix.translation(center[0], center[1], 0)
+ rotate = matrix.rotation_euler(0, 0, rotation_angle)
+ scale = matrix.scale(scale_factor, scale_factor, 1)
+ rotate_scale = matrix.multiply(post, matrix.multiply(rotate, matrix.multiply(scale, pre)))
+ # match 3D camera translation, +X moves camera to right, +Y moves camera up
+ translate = matrix.translation(-translate_x, translate_y, 0)
+ return matrix.multiply(rotate_scale, translate)
+
+def model_supports_clip_guidance(model_name: str) -> bool:
+ return not model_name.startswith('stable-diffusion-xl')
+
+def model_requires_depth(model_name: str) -> bool:
+ return model_name == 'stable-diffusion-depth-v2-0'
+
+def sampler_supports_clip_guidance(sampler_name: str) -> bool:
+ supported_samplers = [
+ generation.SAMPLER_K_EULER_ANCESTRAL,
+ generation.SAMPLER_K_DPM_2_ANCESTRAL,
+ generation.SAMPLER_K_DPMPP_2S_ANCESTRAL
+ ]
+ return sampler_from_string(sampler_name) in supported_samplers
+
+def to_3x3(m: matrix.Matrix) -> matrix.Matrix:
+ # convert 4x4 matrix with 2D rotation, scale, and translation to 3x3 matrix
+ return [[m[0][0], m[0][1], m[0][3]],
+ [m[1][0], m[1][1], m[1][3]],
+ [m[3][0], m[3][1], m[3][3]]]
+
+
+class Animator:
+ def __init__(
+ self,
+ api_context: Context,
+ animation_prompts: dict,
+ args: Optional[AnimationArgs] = None,
+ out_dir: Optional[str] = None,
+ negative_prompt: str = '',
+ negative_prompt_weight: float = -1.0,
+ resume: bool = False
+ ):
+ self.api = api_context
+ self.animation_prompts = animation_prompts
+ self.args = args or AnimationArgs()
+        self.color_match_images: Dict[int, Image.Image] = {}
+ self.diffusion_cadence_ofs: int = 0
+ self.frame_args: FrameArgs
+ self.inpaint_mask: Optional[Image.Image] = None
+ self.key_frame_values: List[int] = []
+ self.out_dir: Optional[str] = out_dir
+ self.mask: Optional[Image.Image] = None
+ self.mask_reader = None
+ self.cadence_on: bool = False
+ self.prior_frames: Deque[Image.Image] = deque([], 1) # forward warped prior frames. stores one image with cadence off, two images otherwise
+ self.prior_diffused: Deque[Image.Image] = deque([], 1) # results of diffusion. stores one image with cadence off, two images otherwise
+ self.prior_xforms: Deque[matrix.Matrix] = deque([], 1) # accumulated transforms since last diffusion. stores one with cadence off, two otherwise
+ self.negative_prompt: str = negative_prompt
+ self.negative_prompt_weight: float = negative_prompt_weight
+ self.start_frame_idx: int = 0
+ self.video_prev_frame: Optional[Image.Image] = None
+ self.video_reader: Optional[cv2.VideoCapture] = None
+
+ # configure Api to retry on classifier obfuscations
+ self.api._retry_obfuscation = True
+
+ # create output directory
+ if self.out_dir is not None:
+ os.makedirs(self.out_dir, exist_ok=True)
+ elif self.args.save_depth_maps or self.args.save_inpaint_masks:
+ raise ValueError('out_dir must be specified when saving depth maps or inpaint masks')
+
+ self.setup_animation(resume)
+
+ def build_frame_xform(self, frame_idx) -> matrix.Matrix:
+ args, frame_args = self.args, self.frame_args
+
+ if self.args.animation_mode == '2D':
+ angle = frame_args.angle[frame_idx]
+ scale = frame_args.zoom[frame_idx]
+ dx = frame_args.translation_x[frame_idx]
+ dy = frame_args.translation_y[frame_idx]
+ return make_xform_2d(args.width, args.height, math.radians(angle), scale, dx, dy)
+
+ elif self.args.animation_mode in ('3D warp', '3D render'):
+ dx = frame_args.translation_x[frame_idx]
+ dy = frame_args.translation_y[frame_idx]
+ dz = frame_args.translation_z[frame_idx]
+ rx = frame_args.rotation_x[frame_idx]
+ ry = frame_args.rotation_y[frame_idx]
+ rz = frame_args.rotation_z[frame_idx]
+
+ dx, dy, dz = -dx*TRANSLATION_SCALE, dy*TRANSLATION_SCALE, -dz*TRANSLATION_SCALE
+ rx, ry, rz = math.radians(rx), math.radians(ry), math.radians(rz)
+
+ # create xform for the current frame
+ world_view = matrix.multiply(matrix.translation(dx, dy, dz), matrix.rotation_euler(rx, ry, rz))
+ return world_view
+
+ else:
+ return matrix.identity
+
+ def emit_frame(self, frame_idx: int, out_frame: Image.Image) -> Image.Image:
+ if self.args.save_depth_maps:
+ depth_image = self.generate_depth_image(out_frame)
+ self.save_to_out_dir(frame_idx, depth_image, prefix='depth')
+
+ if self.args.save_inpaint_masks and self.inpaint_mask is not None:
+ self.save_to_out_dir(frame_idx, self.inpaint_mask, prefix='mask')
+
+ self.save_to_out_dir(frame_idx, out_frame)
+ return out_frame
+
+ def generate_depth_image(self, image: Image.Image) -> Image.Image:
+ results, _ = self.api.transform(
+ [image],
+ depth_calc_transform(blend_weight=self.args.depth_model_weight)
+ )
+ return results[0]
+
+ def get_animation_prompts_weights(self, frame_idx: int) -> Tuple[List[str], List[float]]:
+ prev, next, tween = self.get_key_frame_tween(frame_idx)
+ if prev == next or not self.args.interpolate_prompts:
+ return [self.animation_prompts[prev]], [1.0]
+ else:
+ return [self.animation_prompts[prev], self.animation_prompts[next]], [1.0 - tween, tween]
+
+ def get_color_match_image(self, frame_idx: int) -> Image.Image:
+ if not self.args.color_match_animate:
+ return self.color_match_images.get(0)
+
+ prev, next, tween = self.get_key_frame_tween(frame_idx)
+
+ if prev not in self.color_match_images:
+ self.color_match_images[prev] = self._render_frame(prev, self.args.seed)
+ prev_match = self.color_match_images[prev]
+ if prev == next:
+ return prev_match
+
+ if next not in self.color_match_images:
+ self.color_match_images[next] = self._render_frame(next, self.args.seed)
+ next_match = self.color_match_images[next]
+
+ # Create image combining colors from previous and next key frames without mixing
+ # the RGB values. Tiles of next key frame are filled in over tiles of previous
+ # key frame. The tween value increases the subtile size on each axis so the transition
+        # is non-linear, staying with the previous key frame longer, then quickly moving to the next.
+ blended = prev_match.copy()
+ width, height, tile_size = blended.width, blended.height, 64
+ for y in range(0, height, tile_size):
+ for x in range(0, width, tile_size):
+ cut = next_match.crop((x, y, x + int(tile_size * tween), y + int(tile_size * tween)))
+ blended.paste(cut, (x, y))
+ return blended
+
+ def get_key_frame_tween(self, frame_idx: int) -> Tuple[int, int, float]:
+        """Returns the previous and next key frames along with the in-between ratio."""
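+        # Illustrative: with key frames at {0, 24}, frame_idx=12 yields
+        # (0, 24, 0.5), while any frame_idx at or past the last key frame
+        # yields (last, last, 1.0).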
+ keys = self.key_frame_values
+ idx = bisect.bisect_right(keys, frame_idx)
+ prev, next = idx - 1, idx
+ if next == len(keys):
+ return keys[-1], keys[-1], 1.0
+ else:
+ tween = (frame_idx - keys[prev]) / (keys[next] - keys[prev])
+ return keys[prev], keys[next], tween
+
+ def get_frame_filename(self, frame_idx, prefix="frame") -> Optional[str]:
+ return os.path.join(self.out_dir, f"{prefix}_{frame_idx:05d}.png") if self.out_dir else None
+
+ def image_resize(self, img: Image.Image, mode: str = 'stretch') -> Image.Image:
+ width, height = img.size
+ if mode == 'cover':
+ scale = max(self.args.width / width, self.args.height / height)
+ img = img.resize((int(width * scale), int(height * scale)), resample=Image.LANCZOS)
+ x = (img.width - self.args.width) // 2
+ y = (img.height - self.args.height) // 2
+ img = img.crop((x, y, x + self.args.width, y + self.args.height))
+ elif mode == 'stretch':
+ img = img.resize((self.args.width, self.args.height), resample=Image.LANCZOS)
+ else: # 'resize-canvas'
+ width, height = map(lambda x: x - x % 64, (width, height))
+ self.args.width, self.args.height = width, height
+ return img
+
+ def inpaint_frame(
+ self,
+ frame_idx: int,
+ image: Image.Image,
+ mask: Image.Image,
+ seed: Optional[int] = None,
+ mask_blur_radius: Optional[int] = 8
+ ) -> Image.Image:
+ args = self.args
+ steps = int(self.frame_args.steps_curve[frame_idx])
+ sampler = sampler_from_string(args.sampler.lower())
+ guidance = guidance_from_string(args.clip_guidance)
+
+ # fetch set of prompts and weights for this frame
+ prompts, weights = self.get_animation_prompts_weights(frame_idx)
+ if len(self.negative_prompt) and self.negative_prompt_weight != 0.0:
+ prompts.append(self.negative_prompt)
+ weights.append(-abs(self.negative_prompt_weight))
+
+ if args.use_inpainting_model:
+ binary_mask = self._postprocess_inpainting_mask(
+ mask, binarize=True, blur_radius=mask_blur_radius)
+ results = self.api.inpaint(
+ image, binary_mask,
+ prompts, weights,
+ steps=steps,
+ seed=seed if seed is not None else args.seed,
+ cfg_scale=args.cfg_scale,
+ sampler=sampler,
+ init_strength=0.0,
+ masked_area_init=generation.MASKED_AREA_INIT_ZERO,
+ guidance_preset=guidance,
+ preset=args.preset,
+ )
+ else:
+ mask_min_value = self.frame_args.mask_min_value[frame_idx]
+ binary_mask = self._postprocess_inpainting_mask(
+ mask, binarize=True, min_val=mask_min_value, blur_radius=mask_blur_radius)
+ adjusted_steps = max(5, int(steps * (1.0 - mask_min_value))) if args.steps_strength_adj else steps
+ noise_scale = self.frame_args.noise_scale_curve[frame_idx]
+ results = self.api.generate(
+ prompts, weights,
+ args.width, args.height,
+ steps=adjusted_steps,
+ seed=seed if seed is not None else args.seed,
+ cfg_scale=args.cfg_scale,
+ sampler=sampler,
+ init_image=image,
+ init_strength=mask_min_value,
+ init_noise_scale=noise_scale,
+ mask=binary_mask,
+ masked_area_init=generation.MASKED_AREA_INIT_ORIGINAL,
+ guidance_preset=guidance,
+ preset=args.preset,
+ )
+ return results[generation.ARTIFACT_IMAGE][0]
+
+ def load_init_image(self, fpath=None):
+ if fpath is None:
+ fpath = self.args.init_image
+ if not fpath:
+ return
+
+ img = self.image_resize(Image.open(fpath), self.args.init_sizing)
+
+ self.prior_frames.extend([img, img])
+ self.prior_diffused.extend([img, img])
+
+ def load_mask(self):
+ if not self.args.mask_path:
+ return
+
+ # try to load mask as an image
+        try:
+            self.set_mask(Image.open(self.args.mask_path))
+        except Exception:
+            pass
+
+ # try to load mask as a video
+ if self.mask is None:
+ self.mask_reader = cv2.VideoCapture(self.args.mask_path)
+ self.next_mask()
+
+ if self.mask is None:
+ raise Exception(f"Failed to read mask from {self.args.mask_path}")
+
+ def load_video(self):
+ if self.args.animation_mode != 'Video Input' or not self.args.video_init_path:
+ return
+
+ self.video_reader = cv2.VideoCapture(self.args.video_init_path)
+ if self.video_reader is not None:
+ success, image = self.video_reader.read()
+ if not success:
+ raise Exception(f"Failed to read first frame from {self.args.video_init_path}")
+ self.video_prev_frame = self.image_resize(cv2_to_pil(image), 'cover')
+ self.prior_frames.extend([self.video_prev_frame, self.video_prev_frame])
+ self.prior_diffused.extend([self.video_prev_frame, self.video_prev_frame])
+
+ def next_mask(self):
+ if not self.mask_reader:
+ return False
+
+ for _ in range(self.args.extract_nth_frame):
+ success, mask = self.mask_reader.read()
+ if not success:
+ return
+
+ self.set_mask(cv2_to_pil(mask))
+
+ def prepare_init_ops(self, init_image: Optional[Image.Image], frame_idx: int, noise_seed:int) -> List[generation.TransformParameters]:
+ if init_image is None:
+ return []
+
+ args, frame_args = self.args, self.frame_args
+ brightness = frame_args.brightness_curve[frame_idx]
+ contrast = frame_args.contrast_curve[frame_idx]
+ hue = frame_args.hue_curve[frame_idx]
+ saturation = frame_args.saturation_curve[frame_idx]
+ lightness = frame_args.lightness_curve[frame_idx]
+ noise_amount = frame_args.noise_add_curve[frame_idx]
+
+ color_match_image = None
+ if args.color_coherence != 'None' and frame_idx > 0:
+ color_match_image = self.get_color_match_image(frame_idx)
+
+ do_color_match = args.color_coherence != 'None' and color_match_image is not None
+ do_bchsl = brightness != 1.0 or contrast != 1.0 or hue != 0.0 or saturation != 1.0 or lightness != 0.0
+ do_noise = noise_amount > 0.0
+
+ init_ops: List[generation.TransformParameters] = []
+
+ if do_color_match or do_bchsl or do_noise:
+ init_ops.append(color_adjust_transform(
+ brightness=brightness,
+ contrast=contrast,
+ hue=hue,
+ saturation=saturation,
+ lightness=lightness,
+ match_image=color_match_image,
+ match_mode=args.color_coherence,
+ noise_amount=noise_amount,
+ noise_seed=noise_seed
+ ))
+
+ return init_ops
+
+ def render(self) -> Generator[Image.Image, None, None]:
+ args = self.args
+ seed = args.seed
+
+ # experimental span-based outpainting mode
+ if args.cadence_spans and args.animation_mode != 'Video Input':
+ for idx, frame in self._spans_render():
+ yield self.emit_frame(idx, frame)
+ return
+
+ for frame_idx in range(self.start_frame_idx, args.max_frames):
+ # select image generation model
+ self.api._generate.engine_id = args.custom_model if args.model == "custom" else args.model
+ if model_requires_depth(args.model) and not self.prior_frames:
+ self.api._generate.engine_id = DEFAULT_MODEL
+
+ diffusion_cadence = max(1, int(self.frame_args.diffusion_cadence_curve[frame_idx]))
+ self.set_cadence_mode(enabled=(diffusion_cadence > 1))
+ is_diffusion_frame = (frame_idx - self.diffusion_cadence_ofs) % diffusion_cadence == 0
+
+ steps = int(self.frame_args.steps_curve[frame_idx])
+ strength = max(0.0, self.frame_args.strength_curve[frame_idx])
+
+ # fetch set of prompts and weights for this frame
+ prompts, weights = self.get_animation_prompts_weights(frame_idx)
+ if len(self.negative_prompt) and self.negative_prompt_weight != 0.0:
+ prompts.append(self.negative_prompt)
+ weights.append(-abs(self.negative_prompt_weight))
+
+
+ # transform prior frames
+ stashed_prior_frames = [i.copy() for i in self.prior_frames] if self.mask is not None else []
+ self.inpaint_mask = None
+ if args.animation_mode == '2D':
+ self.inpaint_mask = self.transform_2d(frame_idx)
+ elif args.animation_mode in ('3D render', '3D warp'):
+ self.inpaint_mask = self.transform_3d(frame_idx)
+ elif args.animation_mode == 'Video Input':
+ self.inpaint_mask = self.transform_video(frame_idx)
+
+ # apply inpainting
+ # If cadence is disabled and inpainting is performed using the same model as for generation,
+ # we can optimize inpaint->generate calls into a single generate call.
+ if args.inpaint_border and self.inpaint_mask is not None \
+ and (self.cadence_on or args.use_inpainting_model):
+ for i in range(len(self.prior_frames)):
+ # The earliest prior frame will be popped right after the generation step, so its inpainting would be redundant.
+ if self.cadence_on and is_diffusion_frame and i==0:
+ continue
+ self.prior_frames[i] = self.inpaint_frame(
+ frame_idx, self.prior_frames[i], self.inpaint_mask,
+ seed=None if args.use_inpainting_model else seed)
+
+ # apply mask to transformed prior frames
+ self.next_mask()
+ if self.mask is not None:
+ for i in range(len(self.prior_frames)):
+ if self.cadence_on and is_diffusion_frame and i==0:
+ continue
+ self.prior_frames[i] = image_mix(self.prior_frames[i], stashed_prior_frames[i], self.mask)
+
+ # either run diffusion or emit an inbetween frame
+ if is_diffusion_frame:
+ init_image = self.prior_frames[-1] if len(self.prior_frames) and strength > 0 else None
+ init_strength = strength if init_image is not None else 0.0
+
+ # mix video frame into init image
+ mix_in = self.frame_args.video_mix_in_curve[frame_idx]
+ if init_image is not None and mix_in > 0 and self.video_prev_frame is not None:
+ init_image = image_mix(init_image, self.video_prev_frame, mix_in)
+
+ # when using depth model, compute a depth init image
+ init_depth = None
+ if init_image is not None and model_requires_depth(args.model):
+ depth_source = self.video_prev_frame if self.video_prev_frame is not None else init_image
+ params = depth_calc_transform(blend_weight=1.0, blur_radius=0, reverse=True)
+ results, _ = self.api.transform([depth_source], params)
+ init_depth = results[0]
+
+ # builds set of transform ops to prepare init image for generation
+ init_image_ops = self.prepare_init_ops(init_image, frame_idx, seed)
+
+                # For diffusion frames, instead of a full pass through the inpainting model
+                # followed by a generate call, inpainting can be done in a single generate
+                # call with the non-inpainting model.
+ do_inpainting = not self.cadence_on and not args.use_inpainting_model \
+ and self.inpaint_mask is not None \
+ and (args.inpaint_border or args.animation_mode == '3D render')
+ if do_inpainting:
+ mask_min_value = self.frame_args.mask_min_value[frame_idx]
+ init_strength = min(strength, mask_min_value)
+ self.inpaint_mask = self._postprocess_inpainting_mask(
+ self.inpaint_mask,
+ mask_pow=args.mask_power if args.animation_mode == '3D render' else None,
+ mask_multiplier=strength,
+ blur_radius=None,
+ min_val=mask_min_value)
+
+ # generate the next frame
+ sampler = sampler_from_string(args.sampler.lower())
+ guidance = guidance_from_string(args.clip_guidance)
+ noise_scale = self.frame_args.noise_scale_curve[frame_idx]
+ adjusted_steps = int(max(5, steps*(1.0-init_strength))) if args.steps_strength_adj else int(steps)
+ generate_request = self.api.generate(
+ prompts, weights,
+ args.width, args.height,
+ steps=adjusted_steps,
+ seed=seed,
+ cfg_scale=args.cfg_scale,
+ sampler=sampler,
+ init_image=init_image if init_image_ops is None else None,
+ init_strength=init_strength,
+ init_noise_scale=noise_scale,
+ init_depth=init_depth,
+ mask = self.inpaint_mask if do_inpainting else self.mask,
+ masked_area_init=generation.MASKED_AREA_INIT_ORIGINAL,
+ guidance_preset=guidance,
+ preset=args.preset,
+ return_request=True
+ )
+ image = self.api.transform_and_generate(init_image, init_image_ops, generate_request)
+
+ if args.color_coherence != 'None' and frame_idx == 0:
+ self.color_match_images[0] = image
+ if not len(self.prior_frames):
+ self.prior_frames.append(image)
+ self.prior_diffused.append(image)
+ self.prior_xforms.append(matrix.identity)
+
+ self.prior_frames.append(image)
+ self.prior_diffused.append(image)
+ self.prior_xforms.append(matrix.identity)
+ self.diffusion_cadence_ofs = frame_idx
+ out_frame = image if not self.cadence_on else self.prior_frames[0]
+ else:
+ assert self.cadence_on
+ # smoothly blend between prior frames
+ tween = ((frame_idx - self.diffusion_cadence_ofs) % diffusion_cadence) / float(diffusion_cadence)
+ out_frame = self.api.interpolate(
+ [self.prior_frames[0], self.prior_frames[1]],
+ [tween],
+ interpolate_mode_from_string(args.cadence_interp)
+ )[0]
+
+ # save and return final frame
+ yield self.emit_frame(frame_idx, out_frame)
+
+ if not args.locked_seed:
+ seed += 1
+
+ def save_settings(self, filename: str):
+ settings_filepath = os.path.join(self.out_dir, filename) if self.out_dir else filename
+ with open(settings_filepath, "w", encoding="utf-8") as f:
+ save_dict = args_to_dict(self.args)
+ for k in ['angle', 'zoom', 'translation_x', 'translation_y', 'translation_z', 'rotation_x', 'rotation_y', 'rotation_z']:
+ save_dict.move_to_end(k, last=True)
+ save_dict['animation_prompts'] = self.animation_prompts
+ save_dict['negative_prompt'] = self.negative_prompt
+ save_dict['negative_prompt_weight'] = self.negative_prompt_weight
+ json.dump(save_dict, f, ensure_ascii=False, indent=4)
+
+ def save_to_out_dir(self, frame_idx: int, image: Image.Image, prefix: str = "frame"):
+ if self.out_dir is not None:
+ image.save(self.get_frame_filename(frame_idx, prefix=prefix))
+
+ def set_mask(self, mask: Image.Image):
+ self.mask = mask.convert('L').resize((self.args.width, self.args.height), resample=Image.LANCZOS)
+
+ # this is intentionally flipped because we want white in the mask to represent
+ # areas that should change which is opposite from the backend which treats
+ # the mask as per pixel offset in the schedule starting value
+ if not self.args.mask_invert:
+ self.mask = ImageOps.invert(self.mask)
+
+ def set_cadence_mode(self, enabled: bool):
+ def set_queue_size(prior_queue: deque, prev_length: int, new_length: int) -> deque:
+ assert new_length in (1, 2)
+ if new_length == prev_length:
+ return prior_queue
+ new_queue: deque = deque([], new_length)
+ if len(prior_queue) > 0:
+ if new_length == 2 and prev_length == 1:
+ new_queue.extend([prior_queue[0], prior_queue[0]])
+ elif new_length == 1 and prev_length == 2:
+ new_queue.append(prior_queue[-1])
+ return new_queue
+
+ if enabled == self.cadence_on:
+ return
+ elif enabled:
+ self.prior_frames = set_queue_size(self.prior_frames, 1, 2)
+ self.prior_diffused = set_queue_size(self.prior_diffused, 1, 2)
+ self.prior_xforms = set_queue_size(self.prior_xforms, 1, 2)
+ else:
+ self.prior_frames = set_queue_size(self.prior_frames, 2, 1)
+ self.prior_diffused = set_queue_size(self.prior_diffused, 2, 1)
+ self.prior_xforms = set_queue_size(self.prior_xforms, 2, 1)
+ self.cadence_on = enabled
+
+ def setup_animation(self, resume):
+ args = self.args
+
+ # change request for random seed into explicit value so it is saved to settings
+ if args.seed <= 0:
+ args.seed = random.randint(0, 2**32 - 1)
+
+ # select image generation model
+ self.api._generate.engine_id = args.custom_model if args.model == "custom" else args.model
+
+ # validate border settings
+ if args.border == 'wrap' and args.animation_mode != '2D':
+ args.border = 'reflect'
+ logger.warning(f"Border 'wrap' is only supported in 2D mode, switching to '{args.border}'.")
+
+ # validate clip guidance setting against selected model and sampler
+ if args.clip_guidance.lower() != 'none':
+ if not (model_supports_clip_guidance(args.model) and sampler_supports_clip_guidance(args.sampler)):
+ unsupported = args.model if not model_supports_clip_guidance(args.model) else args.sampler
+ logger.warning(f"CLIP guidance is not supported by {unsupported}, disabling guidance.")
+ args.clip_guidance = 'None'
+
+ def curve_to_series(curve: str) -> List[float]:
+ return curve_from_cn_string(curve)
+
+ # expand key frame strings to per frame series
+ frame_args_dict = {f.name: curve_to_series(getattr(args, f.name)) for f in fields(FrameArgs)}
+ self.frame_args = FrameArgs(**frame_args_dict)
+
+ # prepare sorted list of key frames
+ self.key_frame_values = sorted(list(self.animation_prompts.keys()))
+ if self.key_frame_values[0] != 0:
+ raise ValueError("First keyframe must be 0")
+ if len(self.key_frame_values) != len(set(self.key_frame_values)):
+ raise ValueError("Duplicate keyframes are not allowed!")
+
+ diffusion_cadence = max(1, int(self.frame_args.diffusion_cadence_curve[self.start_frame_idx]))
+ # initialize accumulated transforms
+ self.set_cadence_mode(enabled=(diffusion_cadence > 1))
+ self.prior_xforms.extend([matrix.identity, matrix.identity])
+
+ # prepare inputs
+ self.load_mask()
+ self.load_video()
+ self.load_init_image()
+
+ # handle resuming animation from last frames of a previous run
+ if resume:
+ if not self.out_dir:
+ raise ValueError("Cannot resume animation without out_dir specified")
+ frames = [f for f in os.listdir(self.out_dir) if f.endswith(".png") and f.startswith("frame_")]
+ self.start_frame_idx = len(frames)
+ self.diffusion_cadence_ofs = self.start_frame_idx
+ if self.start_frame_idx > 2:
+ prev = Image.open(self.get_frame_filename(self.start_frame_idx-2))
+ next = Image.open(self.get_frame_filename(self.start_frame_idx-1))
+ self.prior_frames.extend([prev, next])
+ self.prior_diffused.extend([prev, next])
+ elif self.start_frame_idx > 1 and not self.cadence_on:
+ prev = Image.open(self.get_frame_filename(self.start_frame_idx-1))
+ self.prior_frames.append(prev)
+ self.prior_diffused.append(prev)
+
+ def transform_2d(self, frame_idx) -> Optional[Image.Image]:
+ if not len(self.prior_frames):
+ return None
+
+ # create xform for the current frame
+ xform = self.build_frame_xform(frame_idx)
+
+ # check if we can skip transform request
+ if np.allclose(xform, matrix.identity):
+ return None
+
+ args = self.args
+ if not args.inpaint_border:
+ # apply xform to prior frames running xforms
+ for i in range(len(self.prior_xforms)):
+ self.prior_xforms[i] = matrix.multiply(xform, self.prior_xforms[i])
+
+ # warp prior diffused frames by accumulated xforms
+ for i in range(len(self.prior_diffused)):
+ params = resample_transform(args.border, to_3x3(self.prior_xforms[i]), export_mask=args.inpaint_border)
+ xformed, mask = self.api.transform([self.prior_diffused[i]], params)
+ self.prior_frames[i] = xformed[0]
+ else:
+ params = resample_transform(args.border, to_3x3(xform), export_mask=args.inpaint_border)
+ transformed_prior_frames, mask = self.api.transform(self.prior_frames, params)
+ self.prior_frames.extend(transformed_prior_frames)
+
+ return mask[0] if isinstance(mask, list) else mask
+
+ def transform_3d(self, frame_idx) -> Optional[Image.Image]:
+ if not len(self.prior_frames):
+ return None
+
+ args, frame_args = self.args, self.frame_args
+ near, far = args.near_plane, args.far_plane
+ fov = frame_args.fov_curve[frame_idx]
+ depth_blur = int(frame_args.depth_blur_curve[frame_idx])
+ depth_warp = frame_args.depth_warp_curve[frame_idx]
+
+ depth_calc = depth_calc_transform(args.depth_model_weight, depth_blur)
+
+ # create xform for the current frame
+ world_view = self.build_frame_xform(frame_idx)
+ projection = matrix.projection_fov(math.radians(fov), 1.0, near, far)
+
+ if False:
+ # currently disabled. for 3D mode transform accumulation needs additional
+ # depth map changes to work properly without swimming artifacts
+
+ # apply world_view xform to prior frames running xforms
+ for i in range(len(self.prior_xforms)):
+ self.prior_xforms[i] = matrix.multiply(world_view, self.prior_xforms[i])
+
+ # warp prior diffused frames by accumulated xforms
+ for i in range(len(self.prior_diffused)):
+ wvp = matrix.multiply(projection, self.prior_xforms[i])
+ resample = resample_transform(args.border, wvp, projection, depth_warp=depth_warp, export_mask=args.inpaint_border)
+ xformed, mask = self.api.transform_3d([self.prior_diffused[i]], depth_calc, resample)
+ self.prior_frames[i] = xformed[0]
+ else:
+ if args.animation_mode == '3D warp':
+ wvp = matrix.multiply(projection, world_view)
+ transform_op = resample_transform(args.border, wvp, projection, depth_warp=depth_warp, export_mask=args.inpaint_border)
+ else:
+ transform_op = camera_pose_transform(
+ world_view, near, far, fov,
+ args.camera_type,
+ render_mode=args.render_mode,
+ do_prefill=not args.use_inpainting_model)
+ transformed_prior_frames, mask = self.api.transform_3d(self.prior_frames, depth_calc, transform_op)
+ self.prior_frames.extend(transformed_prior_frames)
+ return mask[0] if isinstance(mask, list) else mask
+
+ def transform_video(self, frame_idx) -> Optional[Image.Image]:
+ assert self.video_reader is not None
+ if not len(self.prior_frames):
+ return None
+
+ args = self.args
+ for _ in range(args.extract_nth_frame):
+ success, video_next_frame = self.video_reader.read()
+        if success:
+            video_next_frame = self.image_resize(cv2_to_pil(video_next_frame), 'cover')
+ mask = None
+ if args.video_flow_warp and video_next_frame is not None:
+ # warp_flow is in `extras` and will change in the future
+ prev_b64 = base64.b64encode(image_to_png_bytes(self.video_prev_frame)).decode('utf-8')
+ next_b64 = base64.b64encode(image_to_png_bytes(video_next_frame)).decode('utf-8')
+ extras = { "warp_flow": { "prev_frame": prev_b64, "next_frame": next_b64, "export_mask": args.inpaint_border } }
+ transformed_prior_frames, masks = self.api.transform(self.prior_frames, generation.TransformParameters(), extras=extras)
+ if masks is not None:
+ mask = masks[0]
+ self.prior_frames.extend(transformed_prior_frames)
+ self.video_prev_frame = video_next_frame
+            self.color_match_images[0] = video_next_frame
+ return mask
+ return None
+
+ def _postprocess_inpainting_mask(
+ self,
+ mask: Union[Image.Image, np.ndarray],
+ mask_pow: Optional[float] = None,
+ mask_multiplier: Optional[float] = None,
+ binarize: bool = False,
+ blur_radius: Optional[int] = None,
+ min_val: Optional[float] = None
+ ) -> Image.Image:
+        # Applied in 3D render mode. The camera pose transform operation returns a mask whose
+        # pixel values encode how much signal from the previous frame is present at each pixel.
+        # The mapping from signal presence to optimal per-pixel init strength is unknown, and is
+        # roughly guessed here as a per-pixel power function. Leaving mask_pow=1 makes near
+        # objects change more than the natural, gradual emergence of fine detail when
+        # approaching an object would suggest.
+ if isinstance(mask, Image.Image):
+ mask = np.array(mask)
+ if mask_pow is not None:
+ mask = (np.power(mask / 255., mask_pow) * 255).astype(np.uint8)
+ if mask_multiplier is not None:
+ mask = (mask * mask_multiplier).astype(np.uint8)
+ if binarize:
+ mask = np.where(mask > self.args.mask_binarization_thr * 255, 255, 0).astype(np.uint8)
+ if blur_radius:
+ kernel_size = blur_radius*2+1
+ mask = cv2.erode(mask, np.ones((kernel_size, kernel_size), np.uint8))
+ mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)
+ if min_val is not None:
+ mask = mask.clip(255 * min_val, 255).astype(np.uint8)
+ return Image.fromarray(mask)
+
+ def _render_frame(
+ self,
+ frame_idx: int,
+ seed: int,
+ init: Optional[Image.Image]=None,
+ mask: Optional[Image.Image]=None,
+ strength: Optional[float]=None
+ ) -> Image.Image:
+ args = self.args
+ steps = int(self.frame_args.steps_curve[frame_idx])
+ strength = strength if strength is not None else max(0.0, self.frame_args.strength_curve[frame_idx])
+ adjusted_steps = int(max(5, steps*(1.0-strength))) if args.steps_strength_adj else int(steps)
+
+ # fetch set of prompts and weights for this frame
+ prompts, weights = self.get_animation_prompts_weights(frame_idx)
+ if len(self.negative_prompt) and self.negative_prompt_weight != 0.0:
+ prompts.append(self.negative_prompt)
+ weights.append(-abs(self.negative_prompt_weight))
+
+ init_ops = self.prepare_init_ops(init, frame_idx, seed)
+
+ sampler = sampler_from_string(args.sampler.lower())
+ guidance = guidance_from_string(args.clip_guidance)
+ generate_request = self.api.generate(
+ prompts, weights,
+ args.width, args.height,
+ steps = adjusted_steps,
+ seed = seed,
+ cfg_scale = args.cfg_scale,
+ sampler = sampler,
+ init_image = init if init_ops is None else None,
+ init_strength = strength if init is not None else 0.0,
+ init_noise_scale = self.frame_args.noise_scale_curve[frame_idx],
+ mask = mask if mask is not None else self.mask,
+ masked_area_init = generation.MASKED_AREA_INIT_ORIGINAL,
+ guidance_preset = guidance,
+ preset = args.preset,
+ return_request = True
+ )
+
+ result_image = self.api.transform_and_generate(init, init_ops, generate_request)
+
+ if args.color_coherence != 'None' and frame_idx == 0:
+ self.color_match_images[0] = result_image
+
+ return result_image
+
+ def _span_render(self, start: int, end: int, prev_frame: Image.Image, next_seed: Callable[[], int]) -> Generator[Tuple[int, Image.Image], None, None]:
+ args = self.args
+
+ def apply_xform(frame: Image.Image, xform: matrix.Matrix, frame_idx: int) -> Tuple[Image.Image, Image.Image]:
+ args, frame_args = self.args, self.frame_args
+ if args.animation_mode == '2D':
+ xform = to_3x3(xform)
+ frames, masks = self.api.transform([frame], resample_transform(args.border, xform, export_mask=True))
+ else:
+ fov = frame_args.fov_curve[frame_idx]
+ depth_blur = int(frame_args.depth_blur_curve[frame_idx])
+ depth_warp = frame_args.depth_warp_curve[frame_idx]
+ projection = matrix.projection_fov(math.radians(fov), 1.0, args.near_plane, args.far_plane)
+ wvp = matrix.multiply(projection, xform)
+ depth_calc = depth_calc_transform(args.depth_model_weight, depth_blur)
+ resample = resample_transform(args.border, wvp, projection, depth_warp=depth_warp, export_mask=True)
+ frames, masks = self.api.transform_3d([frame], depth_calc, resample)
+ masks = cast(List[Image.Image], masks)
+ return frames[0], masks[0]
+
+ # transform the previous frame forward
+ accum_xform = matrix.identity
+ forward_frames, forward_masks = [], []
+ for frame_idx in range(start, end):
+ accum_xform = matrix.multiply(self.build_frame_xform(frame_idx), accum_xform)
+ frame, mask = apply_xform(prev_frame, accum_xform, frame_idx)
+ forward_frames.append(frame)
+ forward_masks.append(mask)
+
+ # inpaint the final frame
+ if not np.all(forward_masks[-1]):
+ forward_frames[-1] = self.inpaint_frame(
+ end-1, forward_frames[-1], forward_masks[-1],
+ mask_blur_radius=0, seed=next_seed())
+
+ # run diffusion on top of the final result to allow content to evolve over time
+ strength = max(0.0, self.frame_args.strength_curve[end-1])
+ if strength < 1.0:
+ final_frame = self._render_frame(end-1, next_seed(), forward_frames[-1])
+ else:
+ final_frame = forward_frames[-1]
+
+ # go backwards through the frames in the span
+ backward_frames, backward_masks = [final_frame], [Image.new('L', forward_masks[-1].size, 255)]
+ accum_xform = matrix.identity
+ for frame_idx in range(end-2, start-1, -1):
+ frame_xform = self.build_frame_xform(frame_idx+1)
+ inv_xform = np.linalg.inv(frame_xform).tolist()
+ accum_xform = matrix.multiply(inv_xform, accum_xform)
+ xformed, mask = apply_xform(backward_frames[-1], accum_xform, frame_idx)
+ backward_frames.insert(0, xformed)
+ backward_masks.insert(0, mask)
+
+ # inpaint the backwards frame
+ if not np.all(backward_masks[0]):
+ backward_frames[0] = self.inpaint_frame(
+ start, backward_frames[0], backward_masks[0],
+ mask_blur_radius=0, seed=next_seed())
+
+ # yield the final frames blending from forward to backward
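+        # Note: t runs from 0 to 1 across the span, so early frames favor the forward
+        # warp of prev_frame and later frames favor the freshly diffused final frame.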
+ for idx, (frame_fwd, frame_bwd) in enumerate(zip(forward_frames, backward_frames)):
+            t = idx / max(1, end - start - 1)
+ fwd_fill = image_mix(frame_bwd, frame_fwd, mask_erode_blur(forward_masks[idx], 8, 8))
+ bwd_fill = image_mix(frame_fwd, frame_bwd, mask_erode_blur(backward_masks[idx], 8, 8))
+ blended = self.api.interpolate(
+ [fwd_fill, bwd_fill],
+ [t],
+ interpolate_mode_from_string(args.cadence_interp)
+ )[0]
+ yield start+idx, blended
+
+ def _spans_render(self) -> Generator[Tuple[int, Image.Image], None, None]:
+ frame_idx = self.start_frame_idx
+ seed = self.args.seed
+ def next_seed() -> int:
+ nonlocal seed
+ if not self.args.locked_seed:
+ seed += 1
+ return seed
+
+ prev_frame = self._render_frame(frame_idx, seed, None)
+ yield frame_idx, prev_frame
+
+ while frame_idx < self.args.max_frames:
+ # determine how many frames the span will process together
+ diffusion_cadence = max(1, int(self.frame_args.diffusion_cadence_curve[frame_idx]))
+ if frame_idx + diffusion_cadence > self.args.max_frames:
+ diffusion_cadence = self.args.max_frames - frame_idx
+
+ # render all frames in the span
+ for idx, frame in self._span_render(frame_idx, frame_idx + diffusion_cadence, prev_frame, next_seed):
+ yield idx, frame
+ prev_frame = frame
+ next_seed()
+
+ frame_idx += diffusion_cadence
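+
+        # Illustrative: with diffusion_cadence_curve "0:(4)" each span covers 4 frames;
+        # only the span's final frame is fully diffused, while the frames in between
+        # come from warping the bracketing frames forward/backward and blending them.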
diff --git a/src/stability_sdk/animation_ui.py b/src/stability_sdk/animation_ui.py
new file mode 100644
index 00000000..88a2010a
--- /dev/null
+++ b/src/stability_sdk/animation_ui.py
@@ -0,0 +1,835 @@
+import glob
+import json
+import locale
+import os
+import param
+import shutil
+import traceback
+
+from collections import OrderedDict
+from PIL import Image
+from tqdm import tqdm
+from typing import Any, Dict, List, Optional
+
+try:
+ import gradio as gr
+except ImportError:
+ raise ImportError(
+ "Failed to import animation UI requirements. To use the animation UI, install the dependencies with:\n"
+ " pip install --upgrade stability_sdk[anim_ui]"
+ )
+
+from .api import (
+ ClassifierException,
+ Context,
+ OutOfCreditsException,
+)
+from .animation import (
+ AnimationArgs,
+ Animator,
+ AnimationSettings,
+ BasicSettings,
+ CameraSettings,
+ CoherenceSettings,
+ ColorSettings,
+ DepthSettings,
+ InpaintingSettings,
+ Rendering3dSettings,
+ VideoInputSettings,
+ VideoOutputSettings,
+ interpolate_frames
+)
+from .utils import (
+ create_video_from_frames,
+ extract_frames_from_video,
+ interpolate_mode_from_string
+)
+
+
+DATA_VERSION = "0.1"
+DATA_GENERATOR = "stability_sdk.animation_ui"
+
+PRESETS = {
+ "Default": {},
+ "3D warp rotate": {
+ "animation_mode": "3D warp", "rotation_y":"0:(0.4)", "translation_x":"0:(-1.2)", "depth_model_weight":1.0,
+ "animation_prompts": "{\n0:\"a flower vase on a table\"\n}"
+ },
+ "3D warp zoom": {
+ "animation_mode":"3D warp", "diffusion_cadence_curve":"0:(4)", "noise_scale_curve":"0:(1.04)",
+ "strength_curve":"0:(0.7)", "translation_z":"0:(1.0)",
+ },
+ "3D render rotate": {
+ "animation_mode": "3D render", "depth_model_weight":1.0,
+ "translation_x":"0:(-3.5)", "rotation_y":"0:(1.7)", "translation_z":"0:(-0.5)",
+ "diffusion_cadence_curve":"0:(1)", "strength_curve":"0:(0.96)", "noise_scale_curve":"0:(1.01)",
+ "mask_min_value":"0:(0.35)", "use_inpainting_model":False, "preset": "anime",
+ "animation_prompts": "{\n0:\"beautiful portrait of a ninja in a sunflower field\"\n}"
+ },
+ "3D render explore": {
+ "animation_mode": "3D render", "translation_z":"0:(10)", "translation_x":"0:(2), 20:(-2), 40:(2)",
+ "rotation_y":"0:(0), 10:(1.5), 30:(-2), 50: (3)", "rotation_x":"0:(0.4)",
+ "diffusion_cadence_curve":"0:(1)", "strength_curve":"0:(0.98)",
+ "noise_scale_curve":"0:(1.01)", "depth_model_weight":1.0,
+ "mask_min_value":"0:(0.1)", "use_inpainting_model":False, "preset":"3d-model",
+ "animation_prompts": "{\n0:\"Phantasmagoric carnival, carnival attractions shifting and changing, bizarre surreal circus\"\n}"
+ },
+ "Prompt interpolate": {
+ "animation_mode":"2D", "interpolate_prompts":True, "locked_seed":True, "max_frames":24,
+ "strength_curve":"0:(0)", "diffusion_cadence_curve":"0:(4)", "cadence_interp":"film",
+ "clip_guidance":"None", "animation_prompts": "{\n0:\"a photo of a cute cat\",\n24:\"a photo of a cute dog\"\n}"
+ },
+ "Translate and inpaint": {
+ "animation_mode":"2D", "inpaint_border":True, "use_inpainting_model":False, "translation_x":"0:(-20)",
+ "diffusion_cadence_curve":"0:(3)", "strength_curve":"0:(0.85)", "noise_scale_curve":"0:(1.01)", "border":"reflect",
+ "animation_prompts": "{\n0:\"Mystical pumpkin field landscapes on starry Halloween night, pop surrealism art\"\n}"
+ },
+ "Outpaint": {
+ "animation_mode":"2D", "diffusion_cadence_curve":"0:(16)", "cadence_spans":True, "use_inpainting_model":True,
+ "strength_curve":"0:(1)", "reverse":True, "preset": "fantasy-art", "inpaint_border":True, "zoom":"0:(0.95)",
+ "animation_prompts": "{\n0:\"an ancient and magical portal, in a fantasy corridor\"\n}"
+ },
+ "Video Stylize": {
+ "animation_mode":"Video Input", "model":"stable-diffusion-depth-v2-0", "locked_seed":True,
+ "strength_curve":"0:(0.22)", "clip_guidance":"None", "video_mix_in_curve":"0:(1.0)", "video_flow_warp":True,
+ },
+}
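+
+# Creating a project merges the chosen preset over the defaults, roughly equivalent
+# to (illustrative):
+#
+#   settings = get_default_project()
+#   settings.update(PRESETS["Prompt interpolate"])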
+
+class Project():
+    def __init__(self, title, settings: Optional[dict] = None) -> None:
+        self.folder = title.replace("/", "_").replace("\\", "_").replace(":", "")
+        self.settings = settings if settings is not None else {}
+        self.title = title
+
+ @classmethod
+ def list_projects(cls) -> List["Project"]:
+ projects = []
+ for path in os.listdir(outputs_path):
+ directory = os.path.join(outputs_path, path)
+ if not os.path.isdir(directory):
+ continue
+
+ json_files = glob.glob(os.path.join(directory, '*.json'))
+ json_files = sorted(json_files, key=lambda x: os.stat(x).st_mtime)
+ if not json_files:
+ continue
+
+ filename = os.path.basename(json_files[-1])
+            if '(' not in filename:
+ continue
+
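+            # settings files are saved as "<title> (<run_index>).json" by render_tab;
+            # strip the trailing "(<run_index>).json" to recover the project title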
+ project = cls(filename[:filename.rfind('(')-1].strip())
+ try:
+                with open(os.path.join(directory, filename), 'r') as f:
+                    project.settings = json.load(f)
+            except Exception:
+ continue
+ projects.append(project)
+ return projects
+
+
+context = None
+outputs_path = None
+
+args_generation = BasicSettings()
+args_animation = AnimationSettings()
+args_camera = CameraSettings()
+args_coherence = CoherenceSettings()
+args_color = ColorSettings()
+args_depth = DepthSettings()
+args_render_3d = Rendering3dSettings()
+args_inpaint = InpaintingSettings()
+args_vid_in = VideoInputSettings()
+args_vid_out = VideoOutputSettings()
+arg_objs = (
+ args_generation,
+ args_animation,
+ args_camera,
+ args_coherence,
+ args_color,
+ args_depth,
+ args_render_3d,
+ args_inpaint,
+ args_vid_in,
+ args_vid_out,
+)
+
+animation_prompts = "{\n0: \"\"\n}"
+negative_prompt = "blurry, low resolution"
+negative_prompt_weight = -1.0
+
+controls: Dict[str, gr.components.Component] = {}
+header = gr.HTML("", show_progress=False)
+interrupt = False
+last_interp_factor = None
+last_interp_mode = None
+last_project_settings_path = None
+last_upscale = None
+projects: List[Project] = []
+project: Optional[Project] = None
+resume_checkbox = gr.Checkbox(label="Resume", value=False, interactive=True)
+resume_from_number = gr.Number(label="Resume from frame", value=-1, interactive=True, precision=0,
+ info="Positive frame number to resume from, or -1 to resume from the last")
+
+project_create_button = gr.Button("Create")
+project_data_log = gr.Textbox(label="Status", visible=False)
+project_load_button = gr.Button("Load")
+project_new_title = gr.Text(label="Name", value="My amazing animation", interactive=True)
+project_preset_dropdown = gr.Dropdown(label="Preset", choices=list(PRESETS.keys()), value=list(PRESETS.keys())[0], interactive=True)
+project_row_create = None
+project_row_import = None
+project_row_load = None
+projects_dropdown = gr.Dropdown([p.title for p in projects], label="Project", visible=True, interactive=True)
+
+project_import_button = gr.Button("Import")
+project_import_file = gr.File(label="Project file", file_types=[".json", ".txt"], type="binary")
+project_import_title = gr.Text(label="Name", value="Imported project", interactive=True)
+
+
+def accordion_for_color(args: ColorSettings):
+ p = args.param
+ with gr.Accordion("Color", open=False):
+ controls["color_coherence"] = gr.Dropdown(label="Color coherence", choices=p.color_coherence.objects, value=p.color_coherence.default, interactive=True)
+ with gr.Row():
+ controls["brightness_curve"] = gr.Text(label="Brightness curve", value=p.brightness_curve.default, interactive=True)
+ controls["contrast_curve"] = gr.Text(label="Contrast curve", value=p.contrast_curve.default, interactive=True)
+ with gr.Row():
+ controls["hue_curve"] = gr.Text(label="Hue curve", value=p.hue_curve.default, interactive=True)
+ controls["saturation_curve"] = gr.Text(label="Saturation curve", value=p.saturation_curve.default, interactive=True)
+ controls["lightness_curve"] = gr.Text(label="Lightness curve", value=p.lightness_curve.default, interactive=True)
+ controls["color_match_animate"] = gr.Checkbox(label="Animated color match", value=p.color_match_animate.default, interactive=True)
+
+def accordion_from_args(name: str, args: param.Parameterized, exclude: List[str]=[], open=False):
+ with gr.Accordion(name, open=open):
+ ui_from_args(args, exclude)
+
+def args_reset_to_defaults():
+ for args in arg_objs:
+ for k, v in args.param.objects().items():
+ if k == "name":
+ continue
+ setattr(args, k, v.default)
+
+def args_to_controls(data: Optional[dict]=None) -> dict:
+ # go through all the parameters and load their settings from the data
+ global animation_prompts, negative_prompt
+ if data:
+ for arg in arg_objs:
+ for k, v in arg.param.objects().items():
+ if k != "name" and k in data:
+ arg.param.set_param(k, data[k])
+ if "animation_prompts" in data:
+ animation_prompts = data["animation_prompts"]
+ if "negative_prompt" in data:
+ negative_prompt = data["negative_prompt"]
+
+ returns = {}
+ returns[controls['animation_prompts']] = gr.update(value=animation_prompts)
+ returns[controls['negative_prompt']] = gr.update(value=negative_prompt)
+
+ for args in arg_objs:
+ for k, v in args.param.objects().items():
+ if k in controls:
+ c = controls[k]
+ returns[c] = gr.update(value=getattr(args, k))
+
+ return returns
+
+def ensure_api_context():
+ if context is None:
+ raise gr.Error("Not connected to Stability API")
+
+def format_header_html() -> str:
+ try:
+ balance, profile_picture = context.get_user_info()
+    except Exception:
+ return ""
+ formatted_number = locale.format_string("%d", balance, grouping=True)
+    # header bar: title on the left, credit balance and profile picture on the right
+    return f"""
+    <div style="display: flex; justify-content: space-between; align-items: center;">
+        <span style="font-size: 1.25em; font-weight: bold;">Stable Animation UI</span>
+        <span style="display: flex; align-items: center; gap: 8px;">
+            <span>{formatted_number}</span>
+            <img src="{profile_picture}" style="width: 36px; height: 36px; border-radius: 50%;"/>
+        </span>
+    </div>
+    """
+
+def get_default_project():
+ data = OrderedDict(AnimationArgs().param.values())
+ data.update({
+ "version": DATA_VERSION,
+ "generator": DATA_GENERATOR
+ })
+ return data
+
+def post_process_tab():
+ with gr.Row():
+ with gr.Column():
+ with gr.Row(visible=False):
+ use_video_instead = gr.Checkbox(label="Postprocess a video instead", value=False, interactive=True)
+ video_to_postprocess = gr.Text(label="Videofile to postprocess", value="", interactive=True)
+ fps = gr.Number(label="Output FPS", value=24, interactive=True, precision=0)
+ reverse = gr.Checkbox(label="Reverse", value=False, interactive=True)
+ with gr.Row():
+ frame_interp_mode = gr.Dropdown(label="Frame interpolation mode", choices=['None', 'film', 'rife'], value='None', interactive=True)
+ frame_interp_factor = gr.Dropdown(label="Frame interpolation factor", choices=[2, 4, 8], value=2, interactive=True)
+ with gr.Row():
+ upscale = gr.Checkbox(label="Upscale 2X", value=False, interactive=True)
+ with gr.Column():
+ image_out = gr.Image(label="image", visible=True)
+ video_out = gr.Video(label="video", visible=False)
+ process_button = gr.Button("Process")
+ stop_button = gr.Button("Stop", visible=False)
+ error_log = gr.Textbox(label="Error", lines=3, visible=False)
+
+ def postprocess_video(fps: int, reverse: bool, interp_mode: str, interp_factor: int, upscale: bool,
+ use_video_instead: bool, video_to_postprocess: str):
+ global interrupt, last_interp_factor, last_interp_mode, last_upscale
+ interrupt = False
+ if not use_video_instead and last_project_settings_path is None:
+ raise gr.Error("Please render an animation first or specify a videofile to postprocess")
+ if use_video_instead and not os.path.exists(video_to_postprocess):
+ raise gr.Error("Videofile does not exist")
+
+ yield {
+ header: gr.update(),
+ image_out: gr.update(visible=True, label=""),
+ video_out: gr.update(visible=False),
+ process_button: gr.update(visible=False),
+ stop_button: gr.update(visible=True),
+ error_log: gr.update(visible=False),
+ }
+
+        error, output_video = None, None
+ try:
+ outdir = os.path.dirname(last_project_settings_path) \
+ if not use_video_instead \
+ else extract_frames_from_video(video_to_postprocess)
+ suffix = ""
+
+ can_skip_upscale = last_upscale == upscale
+ can_skip_interp = can_skip_upscale and last_interp_factor == interp_factor and last_interp_mode == interp_mode
+
+ if upscale:
+ suffix += "_x2"
+ upscale_dir = os.path.join(outdir, "upscale")
+ os.makedirs(upscale_dir, exist_ok=True)
+ frame_paths = sorted(glob.glob(os.path.join(outdir, "frame_*.png")))
+ num_frames = len(frame_paths)
+ if not can_skip_upscale:
+ remove_frames_from_path(upscale_dir)
+ for frame_idx in tqdm(range(num_frames)):
+ frame = Image.open(frame_paths[frame_idx])
+ frame = context.upscale(frame)
+ frame.save(os.path.join(upscale_dir, os.path.basename(frame_paths[frame_idx])))
+ yield {
+ header: gr.update(value=format_header_html()) if frame_idx % 12 == 0 else gr.update(),
+ image_out: gr.update(value=frame, label=f"upscale {frame_idx}/{num_frames}", visible=True),
+ video_out: gr.update(visible=False),
+ process_button: gr.update(visible=False),
+ stop_button: gr.update(visible=True),
+ error_log: gr.update(visible=False),
+ }
+ if interrupt:
+ break
+ last_upscale = upscale
+ outdir = upscale_dir
+
+ if interp_mode != 'None':
+ suffix += f"_{interp_mode}{interp_factor}"
+ interp_dir = os.path.join(outdir, "interpolate")
+                interp_mode_enum = interpolate_mode_from_string(interp_mode)
+ if not can_skip_interp:
+ remove_frames_from_path(interp_dir)
+ num_frames = interp_factor * len(glob.glob(os.path.join(outdir, "frame_*.png")))
+                for frame_idx, frame in enumerate(tqdm(interpolate_frames(context, outdir, interp_dir, interp_mode_enum, interp_factor), total=num_frames)):
+ yield {
+ header: gr.update(value=format_header_html()) if frame_idx % 12 == 0 else gr.update(),
+ image_out: gr.update(value=frame, label=f"interpolate {frame_idx}/{num_frames}", visible=True),
+ video_out: gr.update(visible=False),
+ process_button: gr.update(visible=False),
+ stop_button: gr.update(visible=True),
+ error_log: gr.update(visible=False),
+ }
+ if interrupt:
+ break
+ last_interp_mode, last_interp_factor = interp_mode, interp_factor
+ outdir = interp_dir
+
+ if not use_video_instead:
+ output_video = last_project_settings_path.replace(".json", f"{suffix}.mp4")
+ else:
+ _, video_ext = os.path.splitext(video_to_postprocess)
+ output_video = video_to_postprocess.replace(video_ext, f"{suffix}.mp4")
+ create_video_from_frames(outdir, output_video, fps=fps, reverse=reverse)
+ except Exception as e:
+ traceback.print_exc()
+ error = f"Post-processing terminated early due to exception: {e}"
+
+ yield {
+ header: gr.update(value=format_header_html()),
+ image_out: gr.update(visible=False),
+ video_out: gr.update(value=output_video, visible=True),
+ process_button: gr.update(visible=True),
+ stop_button: gr.update(visible=False),
+ error_log: gr.update(value=error, visible=bool(error))
+ }
+
+ process_button.click(
+ postprocess_video,
+ inputs=[fps, reverse, frame_interp_mode, frame_interp_factor, upscale, use_video_instead, video_to_postprocess],
+ outputs=[header, image_out, video_out, process_button, stop_button, error_log]
+ )
+
+ def stop_button_click():
+ global interrupt
+ interrupt = True
+ stop_button.click(stop_button_click)
+
+
+def project_create(title, preset):
+ ensure_api_context()
+ global project, projects
+ titles = [p.title for p in projects]
+ if title in titles:
+ raise gr.Error(f"Project with title '{title}' already exists")
+ project = Project(title, get_default_project())
+ projects.append(project)
+ projects = sorted(projects, key=lambda p: p.title)
+
+ # grab each setting from the preset and add to settings
+ for k, v in PRESETS[preset].items():
+ project.settings[k] = v
+
+ log = f"Created project '{title}'"
+
+ args_reset_to_defaults()
+ returns = args_to_controls(project.settings)
+ returns[project_data_log] = gr.update(value=log, visible=True)
+ returns[projects_dropdown] = gr.update(choices=[p.title for p in projects], visible=True, value=title)
+ returns[project_row_load] = gr.update(visible=len(projects) > 0)
+ return returns
+
+def project_import(title, file):
+ ensure_api_context()
+ global project, projects
+ titles = [p.title for p in projects]
+ if title in titles:
+ raise gr.Error(f"Project with title '{title}' already exists")
+
+ # read json from file
+ try:
+ settings = json.loads(file.decode('utf-8'))
+ except Exception as e:
+ raise gr.Error(f"Failed to read settings from file: {e}")
+
+ project = Project(title, settings)
+ projects.append(project)
+ projects = sorted(projects, key=lambda p: p.title)
+
+ log = f"Imported project '{title}'"
+
+ args_reset_to_defaults()
+ returns = args_to_controls(project.settings)
+ returns[project_data_log] = gr.update(value=log, visible=True)
+ returns[projects_dropdown] = gr.update(choices=[p.title for p in projects], visible=True, value=title)
+ returns[project_row_load] = gr.update(visible=len(projects) > 0)
+ return returns
+
+def project_load(title: str):
+ ensure_api_context()
+ global project
+ project = next(p for p in projects if p.title == title)
+ data = project.settings
+
+ log = f"Loaded project '{title}'"
+
+ # filter project file to latest version
+ if "animation_mode" in data and data["animation_mode"] == "3D":
+ data["animation_mode"] = "3D warp"
+ if "midas_weight" in data:
+ data["depth_model_weight"] = data["midas_weight"]
+ del data["midas_weight"]
+
+ # update the ui controls
+ returns = args_to_controls(data)
+ returns[project_data_log] = gr.update(value=log, visible=True)
+ return returns
+
+def project_tab():
+ global project_row_create, project_row_import, project_row_load
+
+ button_load_projects = gr.Button("Load Projects")
+ with gr.Accordion("Load a project", open=True, visible=False) as projects_row_:
+ project_row_load = projects_row_
+ with gr.Row():
+ projects_dropdown.render()
+ with gr.Column():
+ project_load_button.render()
+ with gr.Row():
+ delete_btn = gr.Button("Delete")
+ confirm_btn = gr.Button("Confirm delete", variant="stop", visible=False)
+ cancel_btn = gr.Button("Cancel", visible=False)
+ delete_btn.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)], None, [delete_btn, confirm_btn, cancel_btn])
+ cancel_btn.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)], None, [delete_btn, confirm_btn, cancel_btn])
+
+ with gr.Accordion("Create a new project", open=True, visible=False) as project_row_create_:
+ project_row_create = project_row_create_
+ with gr.Column():
+ with gr.Row():
+ project_new_title.render()
+ project_preset_dropdown.render()
+ with gr.Column():
+ project_create_button.render()
+
+ with gr.Accordion("Import a project file", open=False, visible=False) as project_row_import_:
+ project_row_import = project_row_import_
+ with gr.Column():
+ with gr.Row():
+ project_import_title.render()
+ project_import_file.render()
+ with gr.Column():
+ project_import_button.render()
+
+ project_data_log.render()
+
+ def delete_project(title: str):
+ ensure_api_context()
+ global project, projects
+
+ project = next(p for p in projects if p.title == title)
+ project_path = os.path.join(outputs_path, project.folder)
+ if os.path.exists(project_path):
+ shutil.rmtree(project_path)
+
+ projects.remove(project)
+ project = None
+
+ log = f"Deleted project \"{title}\" at \"{project_path}\""
+ return {
+ projects_dropdown: gr.update(choices=[p.title for p in projects], visible=True),
+ project_row_load: gr.update(visible=len(projects) > 0),
+ project_data_log: gr.update(value=log, visible=True),
+ delete_btn: gr.update(visible=True),
+ confirm_btn: gr.update(visible=False),
+ cancel_btn: gr.update(visible=False)
+ }
+
+ def load_projects():
+ ensure_api_context()
+ global projects
+ projects = Project.list_projects()
+ return {
+ button_load_projects: gr.update(visible=False),
+ projects_dropdown: gr.update(choices=[p.title for p in projects], visible=True),
+ project_row_create: gr.update(visible=True),
+ project_row_import: gr.update(visible=True),
+ project_row_load: gr.update(visible=len(projects) > 0),
+ header: gr.update(value=format_header_html())
+ }
+
+ button_load_projects.click(load_projects, outputs=[button_load_projects, projects_dropdown, project_row_create, project_row_import, project_row_load, header])
+ confirm_btn.click(delete_project, inputs=projects_dropdown, outputs=[projects_dropdown, project_row_load, project_data_log, delete_btn, confirm_btn, cancel_btn])
+
+def remove_frames_from_path(path: str, leave_first: Optional[int]=None):
+ if os.path.isdir(path):
+ frames = sorted(glob.glob(os.path.join(path, "frame_*.png")))
+ if leave_first:
+ frames = frames[leave_first:]
+ for f in frames:
+ os.remove(f)
+
+def render_tab():
+ with gr.Row():
+ with gr.Column():
+ ui_layout_tabs()
+ with gr.Column():
+ image_out = gr.Image(label="image", visible=True)
+ video_out = gr.Video(label="video", visible=False)
+ button = gr.Button("Render")
+ button_stop = gr.Button("Stop", visible=False)
+ error_log = gr.Textbox(label="Error", lines=3, visible=False)
+
+ def render(resume: bool, resume_from: int, *render_args):
+ global interrupt, last_interp_factor, last_interp_mode, last_project_settings_path, last_upscale, project
+ interrupt = False
+
+ if not project:
+ raise gr.Error("No project active!")
+
+ # create local folder for the project
+ outdir = os.path.join(outputs_path, project.folder)
+ os.makedirs(outdir, exist_ok=True)
+
+ # each render gets a unique run index
+ run_index = 0
+ while True:
+ project_settings_path = os.path.join(outdir, f"{project.folder} ({run_index}).json")
+ if not os.path.exists(project_settings_path):
+ break
+ run_index += 1
+
+ # gather up all the settings from sub-objects
+ args_d = {k: v for k, v in zip(controls.keys(), render_args)}
+ animation_prompts, negative_prompt = args_d['animation_prompts'], args_d['negative_prompt']
+ del args_d['animation_prompts'], args_d['negative_prompt']
+ args = AnimationArgs(**args_d)
+
+ if args.animation_mode == "Video Input" and not args.video_init_path:
+ raise gr.Error("No video input file selected!")
+
+ # convert animation_prompts from string (JSON or python) to dict
+ try:
+ prompts = json.loads(animation_prompts)
+ except json.JSONDecodeError:
+ try:
+ prompts = eval(animation_prompts)
+ except Exception as e:
+ raise gr.Error("Invalid JSON or Python code for animation_prompts!")
+ prompts = {int(k): v for k, v in prompts.items()}
+
+ # save settings to a dict
+ save_dict = OrderedDict()
+ save_dict['version'] = DATA_VERSION
+ save_dict['generator'] = DATA_GENERATOR
+ save_dict.update(args.param.values())
+ save_dict['animation_prompts'] = animation_prompts
+ save_dict['negative_prompt'] = negative_prompt
+ project.settings = save_dict
+ with open(project_settings_path, 'w', encoding='utf-8') as f:
+ json.dump(save_dict, f, indent=4)
+
+ # initial yield to switch render button to stop button
+ yield {
+ button: gr.update(visible=False),
+ button_stop: gr.update(visible=True),
+ image_out: gr.update(visible=True, label=""),
+ video_out: gr.update(visible=False),
+ header: gr.update(),
+ error_log: gr.update(visible=False),
+ }
+
+ # delete frames from previous animation
+ if resume:
+ if resume_from > 0:
+ remove_frames_from_path(outdir, resume_from)
+ elif resume_from == 0 or resume_from < -1:
+ raise gr.Error("Frame number to resume from must be positive, or -1 to resume from the last frame")
+ else:
+ remove_frames_from_path(outdir)
+
+ frame_idx, error = 0, None
+ try:
+ animator = Animator(
+ api_context=context,
+ animation_prompts=prompts,
+ args=args,
+ out_dir=outdir,
+ negative_prompt=negative_prompt,
+ negative_prompt_weight=negative_prompt_weight,
+ resume=resume,
+ )
+ for frame_idx, frame in enumerate(tqdm(animator.render(), initial=animator.start_frame_idx, total=args.max_frames), start=animator.start_frame_idx):
+ if interrupt:
+ break
+
+ # saving frames to project
+ #frame_uuid = project.put_image_asset(frame)
+
+ yield {
+ button: gr.update(visible=False),
+ button_stop: gr.update(visible=True),
+ image_out: gr.update(value=frame, label=f"frame {frame_idx}/{args.max_frames}", visible=True),
+ video_out: gr.update(visible=False),
+ header: gr.update(value=format_header_html()) if frame_idx % 12 == 0 else gr.update(),
+ error_log: gr.update(visible=False),
+ }
+ except ClassifierException as e:
+ error = "Animation terminated early due to NSFW classifier."
+ if e.prompt is not None:
+ error += "\nPlease revise your prompt: " + e.prompt
+ except OutOfCreditsException as e:
+ error = f"Animation terminated early, out of credits.\n{e.details}"
+ except Exception as e:
+ traceback.print_exc()
+ error = f"Animation terminated early due to exception: {e}"
+
+ if frame_idx:
+ last_project_settings_path = project_settings_path
+ last_interp_factor, last_interp_mode, last_upscale = None, None, None
+ output_video = project_settings_path.replace(".json", ".mp4")
+ try:
+ create_video_from_frames(outdir, output_video, fps=args.fps, reverse=args.reverse)
+ except RuntimeError as e:
+ error = f"Error creating video: {e}"
+ output_video = None
+ else:
+ output_video = None
+ yield {
+ button: gr.update(visible=True),
+ button_stop: gr.update(visible=False),
+ image_out: gr.update(visible=False),
+ video_out: gr.update(value=output_video, visible=True),
+ header: gr.update(value=format_header_html()),
+ error_log: gr.update(value=error, visible=bool(error)),
+ }
+
+ button.click(
+ render,
+ inputs=[resume_checkbox, resume_from_number] + list(controls.values()),
+ outputs=[button, button_stop, image_out, video_out, header, error_log]
+ )
+
+ # stop animation in progress
+ def stop():
+ global interrupt
+ interrupt = True
+ button_stop.click(stop)
+
+def ui_for_animation_settings(args: AnimationSettings):
+ with gr.Row():
+ controls["steps_strength_adj"] = gr.Checkbox(label="Steps strength adj", value=args.param.steps_strength_adj.default, interactive=True)
+ controls["interpolate_prompts"] = gr.Checkbox(label="Interpolate prompts", value=args.param.interpolate_prompts.default, interactive=True)
+ controls["locked_seed"] = gr.Checkbox(label="Locked seed", value=args.param.locked_seed.default, interactive=True)
+ controls["noise_add_curve"] = gr.Text(label="Noise add curve", value=args.param.noise_add_curve.default, interactive=True)
+ controls["noise_scale_curve"] = gr.Text(label="Noise scale curve", value=args.param.noise_scale_curve.default, interactive=True)
+ controls["strength_curve"] = gr.Text(label="Previous frame strength curve", value=args.param.strength_curve.default, interactive=True)
+ controls["steps_curve"] = gr.Text(label="Steps curve", value=args.param.steps_curve.default, interactive=True)
+
+def ui_for_generation(args: AnimationSettings):
+ p = args.param
+ with gr.Row():
+ controls["width"] = gr.Number(label="Width", value=p.width.default, interactive=True, precision=0)
+ controls["height"] = gr.Number(label="Height", value=p.height.default, interactive=True, precision=0)
+ with gr.Row():
+ controls["model"] = gr.Dropdown(label="Model", choices=p.model.objects, value=p.model.default, interactive=True)
+ controls["custom_model"] = gr.Text(label="Custom model", value=p.custom_model.default, interactive=True)
+ with gr.Row():
+ controls["preset"] = gr.Dropdown(label="Style preset", choices=p.preset.objects, value=p.preset.default, interactive=True)
+ with gr.Row():
+ controls["sampler"] = gr.Dropdown(label="Sampler", choices=p.sampler.objects, value=p.sampler.default, interactive=True)
+ controls["seed"] = gr.Number(label="Seed", value=p.seed.default, interactive=True, precision=0)
+ controls["cfg_scale"] = gr.Number(label="Guidance scale", value=p.cfg_scale.default, interactive=True)
+ controls["clip_guidance"] = gr.Dropdown(label="CLIP guidance", choices=p.clip_guidance.objects, value=p.clip_guidance.default, interactive=True)
+
+def ui_for_init_and_mask(args_generation):
+ p = args_generation.param
+ with gr.Row():
+ controls["init_image"] = gr.Text(label="Init image", value=p.init_image.default, interactive=True)
+ controls["init_sizing"] = gr.Dropdown(label="Init sizing", choices=p.init_sizing.objects, value=p.init_sizing.default, interactive=True)
+ with gr.Row():
+ controls["mask_path"] = gr.Text(label="Mask path", value=p.mask_path.default, interactive=True)
+ controls["mask_invert"] = gr.Checkbox(label="Mask invert", value=p.mask_invert.default, interactive=True)
+
+def ui_for_video_output(args: VideoOutputSettings):
+ p = args.param
+ controls["fps"] = gr.Number(label="FPS", value=p.fps.default, interactive=True, precision=0)
+ controls["reverse"] = gr.Checkbox(label="Reverse", value=p.reverse.default, interactive=True)
+
+def ui_from_args(args: param.Parameterized, exclude: List[str]=[]):
+ for k, v in args.param.objects().items():
+ if k == "name" or k in exclude:
+ continue
+ if isinstance(v, param.Boolean):
+ t = gr.Checkbox(label=v.label, value=v.default, interactive=True)
+ elif isinstance(v, param.Integer):
+ t = gr.Number(label=v.label, value=v.default, interactive=True, precision=0)
+ elif isinstance(v, param.Number):
+ t = gr.Number(label=v.label, value=v.default, interactive=True)
+ elif isinstance(v, param.Selector):
+ t = gr.Dropdown(label=v.label, choices=v.objects, value=v.default, interactive=True)
+ elif isinstance(v, param.String):
+ t = gr.Text(label=v.label, value=v.default, interactive=True)
+ else:
+ raise Exception(f"Unknown parameter type {v} for param {k}")
+ controls[k] = t
+
+def ui_layout_tabs():
+ with gr.Tab("Prompts"):
+ with gr.Row():
+ controls['animation_prompts'] = gr.TextArea(label="Animation prompts", max_lines=8, value=animation_prompts, interactive=True)
+ with gr.Row():
+ controls['negative_prompt'] = gr.Textbox(label="Negative prompt", max_lines=1, value=negative_prompt, interactive=True)
+ with gr.Tab("Config"):
+ with gr.Row():
+ args = args_animation
+ controls["animation_mode"] = gr.Dropdown(label="Animation mode", choices=args.param.animation_mode.objects, value=args.param.animation_mode.default, interactive=True)
+ controls["max_frames"] = gr.Number(label="Max frames", value=args.param.max_frames.default, interactive=True, precision=0)
+ controls["border"] = gr.Dropdown(label="Border", choices=args.param.border.objects, value=args.param.border.default, interactive=True)
+ ui_for_generation(args_generation)
+ ui_for_animation_settings(args_animation)
+ accordion_from_args("Coherence", args_coherence, open=False)
+ accordion_for_color(args_color)
+ accordion_from_args("Depth", args_depth, exclude=["near_plane", "far_plane"], open=False)
+ accordion_from_args("3D render", args_render_3d, open=False)
+ accordion_from_args("Inpainting", args_inpaint, open=False)
+ with gr.Tab("Input"):
+ with gr.Row():
+ resume_checkbox.render()
+ resume_from_number.render()
+ ui_for_init_and_mask(args_generation)
+ with gr.Column():
+ p = args_vid_in.param
+ with gr.Row():
+ controls["video_init_path"] = gr.Text(label="Video init path", value=p.video_init_path.default, interactive=True)
+ with gr.Row():
+ controls["video_mix_in_curve"] = gr.Text(label="Mix in curve", value=p.video_mix_in_curve.default, interactive=True)
+ controls["extract_nth_frame"] = gr.Number(label="Extract nth frame", value=p.extract_nth_frame.default, interactive=True, precision=0)
+ controls["video_flow_warp"] = gr.Checkbox(label="Flow warp", value=p.video_flow_warp.default, interactive=True)
+
+ with gr.Tab("Camera"):
+ p = args_camera.param
+ gr.Markdown("2D Camera")
+ controls["angle"] = gr.Text(label="Angle", value=p.angle.default, interactive=True)
+ controls["zoom"] = gr.Text(label="Zoom", value=p.zoom.default, interactive=True)
+
+ gr.Markdown("2D and 3D Camera translation")
+ controls["translation_x"] = gr.Text(label="Translation X", value=p.translation_x.default, interactive=True)
+ controls["translation_y"] = gr.Text(label="Translation Y", value=p.translation_y.default, interactive=True)
+ controls["translation_z"] = gr.Text(label="Translation Z", value=p.translation_z.default, interactive=True)
+
+ gr.Markdown("3D Camera rotation")
+ controls["rotation_x"] = gr.Text(label="Rotation X", value=p.rotation_x.default, interactive=True)
+ controls["rotation_y"] = gr.Text(label="Rotation Y", value=p.rotation_y.default, interactive=True)
+ controls["rotation_z"] = gr.Text(label="Rotation Z", value=p.rotation_z.default, interactive=True)
+
+ with gr.Tab("Output"):
+ ui_for_video_output(args_vid_out)
+
+
+def create_ui(api_context: Context, outputs_root_path: str):
+ global context, outputs_path, projects
+ context, outputs_path = api_context, outputs_root_path
+
+ locale.setlocale(locale.LC_ALL, '')
+
+ with gr.Blocks() as ui:
+ header.render()
+
+ with gr.Tab("Project"):
+ project_tab()
+
+ with gr.Tab("Render"):
+ render_tab()
+
+ with gr.Tab("Post-process"):
+ post_process_tab()
+
+ load_project_outputs = [project_data_log]
+ load_project_outputs.extend(controls.values())
+ project_load_button.click(project_load, inputs=projects_dropdown, outputs=load_project_outputs)
+
+ create_project_outputs = [project_data_log, projects_dropdown, project_row_load]
+ create_project_outputs.extend(controls.values())
+ project_create_button.click(project_create, inputs=[project_new_title, project_preset_dropdown], outputs=create_project_outputs)
+ project_import_button.click(project_import, inputs=[project_import_title, project_import_file], outputs=create_project_outputs)
+
+ return ui
diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py
new file mode 100644
index 00000000..2057961f
--- /dev/null
+++ b/src/stability_sdk/api.py
@@ -0,0 +1,662 @@
+import grpc
+import io
+import logging
+import random
+import time
+
+from google.protobuf.struct_pb2 import Struct
+from PIL import Image
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import stability_sdk.interfaces.gooseai.dashboard.dashboard_pb2 as dashboard
+import stability_sdk.interfaces.gooseai.dashboard.dashboard_pb2_grpc as dashboard_grpc
+import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
+import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc
+
+from .utils import (
+ image_mix,
+ image_to_prompt,
+ tensor_to_prompt,
+)
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(level=logging.INFO)
+
+
+def open_channel(host: str, api_key: str = None, max_message_len: int = 20*1024*1024) -> grpc.Channel:
+ options=[
+ ('grpc.max_send_message_length', max_message_len),
+ ('grpc.max_receive_message_length', max_message_len),
+ ]
+ if host.endswith(":443"):
+ call_credentials = [grpc.access_token_call_credentials(api_key)]
+ channel_credentials = grpc.composite_channel_credentials(
+ grpc.ssl_channel_credentials(), *call_credentials
+ )
+ channel = grpc.secure_channel(host, channel_credentials, options=options)
+ else:
+ channel = grpc.insecure_channel(host, options=options)
+ return channel
+
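+# Illustrative usage: the ":443" suffix selects a secure channel with token credentials.
+#
+#   channel = open_channel("grpc.stability.ai:443", api_key="MY_API_KEY")
+#   stub = generation_grpc.GenerationServiceStub(channel)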
+
+class ClassifierException(Exception):
+ """Raised when server classifies generated content as inappropriate.
+
+ Attributes:
+ classifier_result: Categories the result image exceeded the threshold for
+ prompt: The prompt that was classified as inappropriate
+ """
+ def __init__(self, classifier_result: Optional[generation.ClassifierParameters]=None, prompt: Optional[str]=None):
+ self.classifier_result = classifier_result
+ self.prompt = prompt
+
+class OutOfCreditsException(Exception):
+ """Raised when account doesn't have enough credits to perform a request."""
+ def __init__(self, details: str):
+ self.details = details
+
+
+class Endpoint:
+ def __init__(self, stub, engine_id):
+ self.stub = stub
+ self.engine_id = engine_id
+
+
+class Context:
+ def __init__(
+ self,
+ host: str="",
+ api_key: str=None,
+ stub: generation_grpc.GenerationServiceStub=None,
+ generate_engine_id: str="stable-diffusion-xl-beta-v2-2-2",
+ inpaint_engine_id: str="stable-inpainting-512-v2-0",
+ interpolate_engine_id: str="interpolation-server-v1",
+ transform_engine_id: str="transform-server-v1",
+ upscale_engine_id: str="esrgan-v1-x2plus",
+ ):
+ if not host and stub is None:
+ raise Exception("Must provide either GRPC host or stub to Api")
+
+ channel = open_channel(host, api_key) if host else None
+ if not stub:
+ stub = generation_grpc.GenerationServiceStub(channel)
+
+ self._dashboard_stub = dashboard_grpc.DashboardServiceStub(channel) if channel else None
+
+ self._generate = Endpoint(stub, generate_engine_id)
+ self._inpaint = Endpoint(stub, inpaint_engine_id)
+ self._interpolate = Endpoint(stub, interpolate_engine_id)
+ self._transform = Endpoint(stub, transform_engine_id)
+ self._upscale = Endpoint(stub, upscale_engine_id)
+
+ self._debug_no_chains = False
+ self._max_retries = 5 # retry request on RPC error
+ self._request_timeout = 30.0 # timeout in seconds for each request
+ self._retry_delay = 1.0 # base delay in seconds between retries, each attempt will double
+ self._retry_obfuscation = False # retry request with different seed on classifier obfuscation
+ self._retry_schedule_offset = 0.1 # increase schedule start by this amount on each retry after the first
+
+ self._user_organization_id: Optional[str] = None
+ self._user_profile_picture: str = ''
+
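+    # Typical construction (illustrative):
+    #
+    #   api = Context(host="grpc.stability.ai:443", api_key=STABILITY_KEY)
+    #
+    # The engine ids above are defaults and may be overridden per deployment.
+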
+ def generate(
+ self,
+ prompts: List[str],
+ weights: List[float],
+ width: int = 512,
+ height: int = 512,
+ steps: Optional[int] = None,
+ seed: Union[Sequence[int], int] = 0,
+ samples: int = 1,
+ cfg_scale: float = 7.0,
+ sampler: generation.DiffusionSampler = None,
+ init_image: Optional[Image.Image] = None,
+ init_strength: float = 0.0,
+ init_noise_scale: Optional[float] = None,
+ init_depth: Optional[Image.Image] = None,
+ mask: Optional[Image.Image] = None,
+ masked_area_init: generation.MaskedAreaInit = generation.MASKED_AREA_INIT_ORIGINAL,
+ guidance_preset: generation.GuidancePreset = generation.GUIDANCE_PRESET_NONE,
+ guidance_cuts: int = 0,
+ guidance_strength: float = 0.0,
+ preset: Optional[str] = None,
+ return_request: bool = False,
+ ) -> Dict[int, List[Any]]:
+ """
+ Generate an image from a set of weighted prompts.
+
+ :param prompts: List of text prompts
+ :param weights: List of prompt weights
+ :param width: Width of the generated image
+ :param height: Height of the generated image
+ :param steps: Number of steps to run the diffusion process
+ :param seed: Random seed for the starting noise
+ :param samples: Number of samples to generate
+ :param cfg_scale: Classifier free guidance scale
+ :param sampler: Sampler to use for the diffusion process
+ :param init_image: Initial image to use
+ :param init_strength: Strength of the initial image
+ :param init_noise_scale: Scale of the initial noise
+        :param init_depth: Depth image to condition depth-aware models (optional)
+        :param mask: Mask to use (0 for pixels to change, 255 for pixels to keep)
+ :param masked_area_init: How to initialize the masked area
+ :param guidance_preset: Preset to use for CLIP guidance
+ :param guidance_cuts: Number of cuts to use with CLIP guidance
+ :param guidance_strength: Strength of CLIP guidance
+ :param preset: Style preset to use
+ :param return_request: Whether to return the request instead of running it
+ :return: dict mapping artifact type to data
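+
+        Example (illustrative; `api` is a connected Context)::
+
+            results = api.generate(["a photo of a cute cat"], [1.0])
+            results[generation.ARTIFACT_IMAGE][0].save("cat.png")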
+ """
+ if not prompts and init_image is None:
+ raise ValueError("prompt and/or init_image must be provided")
+
+ if (mask is not None) and (init_image is None) and not return_request:
+ raise ValueError("If mask_image is provided, init_image must also be provided")
+
+ p = [generation.Prompt(text=prompt, parameters=generation.PromptParameters(weight=weight)) for prompt,weight in zip(prompts, weights)]
+ if init_image is not None:
+ p.append(image_to_prompt(init_image))
+ if mask is not None:
+ p.append(image_to_prompt(mask, type=generation.ARTIFACT_MASK))
+ if init_depth is not None:
+ p.append(image_to_prompt(init_depth, type=generation.ARTIFACT_DEPTH))
+
+ start_schedule = 1.0 - init_strength
+ image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale,
+ start_schedule, init_noise_scale, masked_area_init,
+ guidance_preset, guidance_cuts, guidance_strength)
+
+ extras = Struct()
+ if preset and preset.lower() != 'none':
+ extras.update({ '$IPC': { "preset": preset } })
+
+ request = generation.Request(engine_id=self._generate.engine_id, prompt=p, image=image_params, extras=extras)
+ if return_request:
+ return request
+
+ results = self._run_request(self._generate, request)
+
+ return results
+
+ def get_user_info(self) -> Tuple[float, str]:
+ """Get the number of credits the user has remaining and their profile picture."""
+ if not self._user_organization_id:
+ user = self._dashboard_stub.GetMe(dashboard.EmptyRequest())
+ self._user_profile_picture = user.profile_picture
+ self._user_organization_id = user.organizations[0].organization.id
+ organization = self._dashboard_stub.GetOrganization(dashboard.GetOrganizationRequest(id=self._user_organization_id))
+ return organization.payment_info.balance * 100, self._user_profile_picture
+
+ def inpaint(
+ self,
+ image: Image.Image,
+ mask: Image.Image,
+ prompts: List[str],
+ weights: List[float],
+ steps: Optional[int] = None,
+ seed: Union[Sequence[int], int] = 0,
+ samples: int = 1,
+ cfg_scale: float = 7.0,
+ sampler: generation.DiffusionSampler = None,
+ init_strength: float = 0.0,
+ init_noise_scale: Optional[float] = None,
+ masked_area_init: generation.MaskedAreaInit = generation.MASKED_AREA_INIT_ZERO,
+ guidance_preset: generation.GuidancePreset = generation.GUIDANCE_PRESET_NONE,
+ guidance_cuts: int = 0,
+ guidance_strength: float = 0.0,
+ preset: Optional[str] = None,
+ ) -> Dict[int, List[Any]]:
+ """
+ Apply inpainting to an image.
+
+ :param image: Source image
+ :param mask: Mask image with 0 for pixels to change and 255 for pixels to keep
+ :param prompts: List of text prompts
+ :param weights: List of prompt weights
+ :param steps: Number of steps to run
+ :param seed: Random seed
+ :param samples: Number of samples to generate
+ :param cfg_scale: Classifier free guidance scale
+ :param sampler: Sampler to use for the diffusion process
+ :param init_strength: Strength of the initial image
+ :param init_noise_scale: Scale of the initial noise
+ :param masked_area_init: How to initialize the masked area
+ :param guidance_preset: Preset to use for CLIP guidance
+ :param guidance_cuts: Number of cuts to use with CLIP guidance
+ :param guidance_strength: Strength of CLIP guidance
+ :param preset: Style preset to use
+ :return: dict mapping artifact type to data
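+
+        Example (illustrative; `api` is a connected Context and `mask` marks the
+        pixels to regenerate with 0)::
+
+            results = api.inpaint(image, mask, ["a starry night sky"], [1.0])
+            inpainted = results[generation.ARTIFACT_IMAGE][0]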
+ """
+ p = [generation.Prompt(text=prompt, parameters=generation.PromptParameters(weight=weight)) for prompt,weight in zip(prompts, weights)]
+ p.append(image_to_prompt(image))
+ p.append(image_to_prompt(mask, type=generation.ARTIFACT_MASK))
+
+ width, height = image.size
+ start_schedule = 1.0-init_strength
+ image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale,
+ start_schedule, init_noise_scale, masked_area_init,
+ guidance_preset, guidance_cuts, guidance_strength)
+
+ extras = Struct()
+ if preset and preset.lower() != 'none':
+ extras.update({ '$IPC': { "preset": preset } })
+
+ request = generation.Request(engine_id=self._inpaint.engine_id, prompt=p, image=image_params, extras=extras)
+ results = self._run_request(self._inpaint, request)
+
+ return results
+
+ def interpolate(
+ self,
+ images: Sequence[Image.Image],
+ ratios: List[float],
+ mode: generation.InterpolateMode = generation.INTERPOLATE_LINEAR,
+ ) -> List[Image.Image]:
+ """
+ Interpolate between two images
+
+ :param images: Two images with matching resolution
+ :param ratios: In-between ratios to interpolate at
+ :param mode: Interpolation mode
+ :return: One image for each ratio
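+
+        Example (illustrative; `api` is a connected Context)::
+
+            frames = api.interpolate([image_a, image_b], [0.25, 0.5, 0.75])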
+ """
+ assert len(images) == 2
+ assert len(ratios) >= 1
+
+ if len(ratios) == 1:
+ if ratios[0] == 0.0:
+ return [images[0]]
+ elif ratios[0] == 1.0:
+ return [images[1]]
+ elif mode == generation.INTERPOLATE_LINEAR:
+ return [image_mix(images[0], images[1], ratios[0])]
+
+ p = [image_to_prompt(image) for image in images]
+ request = generation.Request(
+ engine_id=self._interpolate.engine_id,
+ prompt=p,
+ interpolate=generation.InterpolateParameters(ratios=ratios, mode=mode)
+ )
+
+ results = self._run_request(self._interpolate, request)
+ return results[generation.ARTIFACT_IMAGE]
+
+ def transform_and_generate(
+ self,
+ image: Optional[Image.Image],
+ params: List[generation.TransformParameters],
+ generate_request: generation.Request,
+ extras: Optional[Dict] = None,
+ ) -> Image.Image:
+ extras_struct = None
+ if extras is not None:
+ extras_struct = Struct()
+ extras_struct.update(extras)
+
+ if not params:
+ results = self._run_request(self._generate, generate_request)
+ return results[generation.ARTIFACT_IMAGE][0]
+
+ assert image is not None
+ requests = [
+ generation.Request(
+ engine_id=self._transform.engine_id,
+ requested_type=generation.ARTIFACT_TENSOR,
+ prompt=[image_to_prompt(image)],
+ transform=param,
+ extras=extras_struct,
+ ) for param in params
+ ]
+
+ if self._debug_no_chains:
+ prev_result = None
+ for rq in requests:
+ if prev_result is not None:
+ rq.prompt.pop()
+ rq.prompt.append(tensor_to_prompt(prev_result))
+ prev_result = self._run_request(self._transform, rq)[generation.ARTIFACT_TENSOR][0]
+ generate_request.prompt.append(tensor_to_prompt(prev_result))
+ results = self._run_request(self._generate, generate_request)
+ else:
+ stages = []
+ for idx, rq in enumerate(requests):
+ stages.append(generation.Stage(
+ id=str(idx),
+ request=rq,
+ on_status=[generation.OnStatus(
+ action=[generation.STAGE_ACTION_PASS],
+ target=str(idx+1)
+ )]
+ ))
+ stages.append(generation.Stage(
+ id=str(len(params)),
+ request=generate_request,
+ on_status=[generation.OnStatus(
+ action=[generation.STAGE_ACTION_RETURN],
+ target=None
+ )]
+ ))
+ chain_rq = generation.ChainRequest(request_id="xform_gen_chain", stage=stages)
+ results = self._run_request(self._transform, chain_rq)
+
+ return results[generation.ARTIFACT_IMAGE][0]
+
+ def transform(
+ self,
+ images: Sequence[Image.Image],
+ params: Union[generation.TransformParameters, List[generation.TransformParameters]],
+ extras: Optional[Dict] = None
+ ) -> Tuple[List[Image.Image], Optional[List[Image.Image]]]:
+ """
+ Transform images
+
+ :param images: One or more images to transform
+        :param params: Transform operations to apply to each image
+        :param extras: Optional extra parameters passed through to the engine (experimental)
+        :return: One image artifact for each input image and one transform-dependent mask
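+
+        Example (illustrative; `api` is a connected Context and `resample` is a
+        TransformParameters built with a helper such as resample_transform)::
+
+            images, masks = api.transform([frame], resample)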
+ """
+ assert len(images)
+ assert isinstance(images[0], Image.Image)
+
+ extras_struct = None
+ if extras is not None:
+ extras_struct = Struct()
+ extras_struct.update(extras)
+
+ if isinstance(params, List) and len(params) > 1:
+ if self._debug_no_chains:
+ for param in params:
+ images, mask = self.transform(images, param, extras)
+ return images, mask
+
+ assert extras is None
+ stages = []
+ for idx, param in enumerate(params):
+ final = idx == len(params) - 1
+ rq = generation.Request(
+ engine_id=self._transform.engine_id,
+ prompt=[image_to_prompt(image) for image in images] if idx == 0 else None,
+ transform=param,
+                    extras=extras_struct
+ )
+ stages.append(generation.Stage(
+ id=str(idx),
+ request=rq,
+ on_status=[generation.OnStatus(
+ action=[generation.STAGE_ACTION_PASS if not final else generation.STAGE_ACTION_RETURN],
+ target=str(idx+1) if not final else None
+ )]
+ ))
+ chain_rq = generation.ChainRequest(request_id="xform_chain", stage=stages)
+ results = self._run_request(self._transform, chain_rq)
+ else:
+ request = generation.Request(
+ engine_id=self._transform.engine_id,
+ prompt=[image_to_prompt(image) for image in images],
+ transform=params[0] if isinstance(params, List) else params,
+ extras=extras_struct
+ )
+ results = self._run_request(self._transform, request)
+
+ images = results.get(generation.ARTIFACT_IMAGE, []) + results.get(generation.ARTIFACT_DEPTH, [])
+ masks = results.get(generation.ARTIFACT_MASK, None)
+ return images, masks
+
+ def transform_3d(
+ self,
+ images: Sequence[Image.Image],
+ depth_calc: generation.TransformParameters,
+ transform: generation.TransformParameters,
+ extras: Optional[Dict] = None
+ ) -> Tuple[List[Image.Image], Optional[List[Image.Image]]]:
+ assert len(images)
+ assert isinstance(images[0], Image.Image)
+
+ image_prompts = [image_to_prompt(image) for image in images]
+ warped_images = []
+ warp_mask = None
+ op_id = "resample" if transform.HasField("resample") else "camera_pose"
+
+ extras_struct = Struct()
+ if extras is not None:
+ extras_struct.update(extras)
+
+ rq_depth = generation.Request(
+ engine_id=self._transform.engine_id,
+ requested_type=generation.ARTIFACT_TENSOR,
+ prompt=[image_prompts[0]],
+ transform=depth_calc,
+ )
+ rq_transform = generation.Request(
+ engine_id=self._transform.engine_id,
+ prompt=image_prompts,
+ transform=transform,
+ extras=extras_struct
+ )
+
+ if self._debug_no_chains:
+ results = self._run_request(self._transform, rq_depth)
+ rq_transform.prompt.append(
+ generation.Prompt(
+ artifact=generation.Artifact(
+ type=generation.ARTIFACT_TENSOR,
+ tensor=results[generation.ARTIFACT_TENSOR][0]
+ )
+ )
+ )
+ results = self._run_request(self._transform, rq_transform)
+ else:
+ chain_rq = generation.ChainRequest(
+ request_id=f"{op_id}_3d_chain",
+ stage=[
+ generation.Stage(
+ id="depth_calc",
+ request=rq_depth,
+ on_status=[generation.OnStatus(action=[generation.STAGE_ACTION_PASS], target=op_id)]
+ ),
+ generation.Stage(
+ id=op_id,
+ request=rq_transform,
+ on_status=[generation.OnStatus(action=[generation.STAGE_ACTION_RETURN])]
+ )
+ ])
+ results = self._run_request(self._transform, chain_rq)
+
+ warped_images = results[generation.ARTIFACT_IMAGE]
+ warp_mask = results.get(generation.ARTIFACT_MASK, None)
+
+ return warped_images, warp_mask
+
+ def upscale(
+ self,
+ init_image: Image.Image,
+ width: Optional[int] = None,
+ height: Optional[int] = None,
+ prompt: Union[str, generation.Prompt] = None,
+ steps: Optional[int] = 20,
+ cfg_scale: Optional[float] = 7.0,
+ seed: int = 0
+ ) -> Image.Image:
+ """
+ Upscale an image.
+
+ :param init_image: Image to upscale.
+
+ Optional parameters for upscale method:
+
+ :param width: Width of the output images.
+ :param height: Height of the output images.
+ :param prompt: Prompt used in text conditioned models
+ :param steps: Number of diffusion steps
+ :param cfg_scale: Intensity of the prompt, when a prompt is used
+ :param seed: Seed for the random number generator.
+
+ Some variables are not used for specific engines, but are included for consistency.
+
+ Variables ignored in ESRGAN engines: prompt, steps, cfg_scale, seed
+
+        :return: The upscaled image.
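+
+        Example (illustrative; `api` is a connected Context)::
+
+            upscaled = api.upscale(Image.open("frame_00000.png"))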
+ """
+
+ prompts = [image_to_prompt(init_image)]
+ if prompt:
+ if isinstance(prompt, str):
+ prompt = generation.Prompt(text=prompt)
+ elif not isinstance(prompt, generation.Prompt):
+ raise ValueError("prompt must be a string or Prompt object")
+ prompts.append(prompt)
+
+ request = generation.Request(
+ engine_id=self._upscale.engine_id,
+ prompt=prompts,
+ image=generation.ImageParameters(
+ width=width,
+ height=height,
+ seed=[seed],
+ steps=steps,
+ parameters=[generation.StepParameter(
+ sampler=generation.SamplerParameters(cfg_scale=cfg_scale)
+ )],
+ )
+ )
+ results = self._run_request(self._upscale, request)
+ return results[generation.ARTIFACT_IMAGE][0]
+
+ def _adjust_request_engine(self, request: generation.Request):
+ if request.engine_id == self._transform.engine_id:
+ assert request.HasField("transform")
+ if request.transform.HasField("color_adjust") or \
+ (request.transform.HasField("resample") and len(request.transform.resample.transform.data) == 9):
+ request.engine_id = self._transform.engine_id + "-cpu"
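+        # e.g. a 2D resample uses a 3x3 matrix (9 values) and is routed to the CPU
+        # transform engine; 3D resamples use 4x4 matrices and stay on the default engine.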
+
+ def _adjust_request_for_retry(self, request: generation.Request, attempt: int):
+ logger.warning(f" adjusting request, will retry {self._max_retries-attempt} more times")
+ request.image.seed[:] = [random.randrange(0, 4294967295) for _ in request.image.seed]
+ if attempt > 0 and request.image.parameters and request.image.parameters[0].HasField("schedule"):
+ schedule = request.image.parameters[0].schedule
+ if schedule.HasField("start"):
+ schedule.start = max(0.0, min(1.0, schedule.start + self._retry_schedule_offset))
+
+ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_scale,
+ schedule_start, init_noise_scale, masked_area_init,
+ guidance_preset, guidance_cuts, guidance_strength):
+
+ if not seed:
+ seed = [random.randrange(0, 4294967295)]
+ elif isinstance(seed, int):
+ seed = [seed]
+ else:
+ seed = list(seed)
+
+ step_parameters = {
+ "scaled_step": 0,
+ "sampler": generation.SamplerParameters(cfg_scale=cfg_scale, init_noise_scale=init_noise_scale),
+ }
+ if schedule_start != 1.0:
+ step_parameters["schedule"] = generation.ScheduleParameters(start=schedule_start)
+
+ if guidance_preset != generation.GUIDANCE_PRESET_NONE:
+ cutouts = generation.CutoutParameters(count=guidance_cuts) if guidance_cuts else None
+ if guidance_strength == 0.0:
+ guidance_strength = None
+ step_parameters["guidance"] = generation.GuidanceParameters(
+ guidance_preset=guidance_preset,
+ instances=[
+ generation.GuidanceInstanceParameters(
+ cutouts=cutouts,
+ guidance_strength=guidance_strength,
+ models=None, prompt=None
+ )
+ ]
+ )
+
+ return generation.ImageParameters(
+ transform=None if sampler is None else generation.TransformType(diffusion=sampler),
+ height=height,
+ width=width,
+ seed=seed,
+ steps=steps,
+ samples=samples,
+ masked_area_init=masked_area_init,
+ parameters=[generation.StepParameter(**step_parameters)],
+ )
+
+ def _process_response(self, response) -> Dict[int, List[Any]]:
+ results: Dict[int, List[Any]] = {}
+ for resp in response:
+ for artifact in resp.artifacts:
+ # check for classifier rejecting a text prompt
+ if artifact.finish_reason == generation.FILTER and artifact.type == generation.ARTIFACT_TEXT:
+ raise ClassifierException(prompt=artifact.text)
+
+ if artifact.type not in results:
+ results[artifact.type] = []
+
+ if artifact.type == generation.ARTIFACT_CLASSIFICATIONS:
+ results[artifact.type].append(artifact.classifier)
+ elif artifact.type in (generation.ARTIFACT_DEPTH, generation.ARTIFACT_IMAGE, generation.ARTIFACT_MASK):
+ image = Image.open(io.BytesIO(artifact.binary))
+ results[artifact.type].append(image)
+ elif artifact.type == generation.ARTIFACT_TENSOR:
+ results[artifact.type].append(artifact.tensor)
+ elif artifact.type == generation.ARTIFACT_TEXT:
+ results[artifact.type].append(artifact.text)
+
+ return results
+
+ def _run_request(
+ self,
+ endpoint: Endpoint,
+ request: Union[generation.ChainRequest, generation.Request]
+ ) -> Dict[int, List[Any]]:
+ if isinstance(request, generation.Request):
+ self._adjust_request_engine(request)
+ elif isinstance(request, generation.ChainRequest):
+ for stage in request.stage:
+ self._adjust_request_engine(stage.request)
+
+ for attempt in range(self._max_retries+1):
+ try:
+ if isinstance(request, generation.Request):
+ response = endpoint.stub.Generate(request, timeout=self._request_timeout)
+ else:
+ response = endpoint.stub.ChainGenerate(request, timeout=self._request_timeout)
+
+ results = self._process_response(response)
+
+ # check for classifier obfuscation
+ if generation.ARTIFACT_CLASSIFICATIONS in results:
+ for classifier in results[generation.ARTIFACT_CLASSIFICATIONS]:
+ if classifier.realized_action == generation.ACTION_OBFUSCATE:
+ raise ClassifierException(classifier)
+
+ break
+ except ClassifierException as ce:
+ if attempt == self._max_retries or not self._retry_obfuscation or ce.prompt is not None:
+ raise ce
+
+ for exceed in ce.classifier_result.exceeds:
+ logger.warning(f"Received classifier obfuscation. Exceeded {exceed.name} threshold")
+
+ if isinstance(request, generation.Request) and request.HasField("image"):
+ self._adjust_request_for_retry(request, attempt)
+ elif isinstance(request, generation.ChainRequest):
+ for stage in request.stage:
+ if stage.request.HasField("image"):
+ self._adjust_request_for_retry(stage.request, attempt)
+ else:
+ raise ce
+ except grpc.RpcError as rpc_error:
+ if hasattr(rpc_error, "code"):
+ if rpc_error.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
+ if "message larger than max" in rpc_error.details():
+ raise rpc_error
+ raise OutOfCreditsException(rpc_error.details())
+ elif rpc_error.code() == grpc.StatusCode.UNAUTHENTICATED:
+ raise rpc_error
+
+ if attempt == self._max_retries:
+ raise rpc_error
+
+ logger.warning(f"Received RpcError: {rpc_error} will retry {self._max_retries-attempt} more times")
+ time.sleep(self._retry_delay * 2**attempt)
+ return results
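The `Context` class above centralizes request construction, response processing, classifier retries, and exponential backoff. As a minimal usage sketch (not part of this diff; the host string and `upscale` signature are taken from the code above, and the file names are illustrative):

```python
from PIL import Image

from stability_sdk.api import Context

# Connect to the API and upscale a local image. ESRGAN upscale engines
# ignore prompt/steps/cfg_scale/seed, as noted in the upscale() docstring.
context = Context("grpc.stability.ai:443", "YOUR-STABILITY-KEY")
upscaled = context.upscale(Image.open("input.png"), width=2048)
upscaled.save("input_upscaled.png")
```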
diff --git a/src/stability_sdk/client.py b/src/stability_sdk/client.py
index f5100087..20547f00 100644
--- a/src/stability_sdk/client.py
+++ b/src/stability_sdk/client.py
@@ -2,65 +2,39 @@
# fmt: off
-import pathlib
-import sys
+import getpass
+import grpc
+import logging
+import mimetypes
import os
-import uuid
import random
-import io
-import logging
+import sys
import time
-import mimetypes
+import uuid
-import grpc
from argparse import ArgumentParser, Namespace
-from typing import Dict, Generator, List, Optional, Union, Any, Sequence, Tuple
from google.protobuf.json_format import MessageToJson
from google.protobuf.struct_pb2 import Struct
from PIL import Image
-
-try:
- from dotenv import load_dotenv
-except ModuleNotFoundError:
- pass
-else:
- load_dotenv()
+from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc
-from stability_sdk.utils import (
+from .api import open_channel
+from .utils import (
SAMPLERS,
MAX_FILENAME_SZ,
- artifact_type_to_str,
- truncate_fit,
- get_sampler_from_str,
+ artifact_type_to_string,
+ image_to_prompt,
open_images,
+ sampler_from_string,
+ truncate_fit,
)
-
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
-def image_to_prompt(im, init: bool = False, mask: bool = False) -> generation.Prompt:
- if init and mask:
- raise ValueError("init and mask cannot both be True")
- buf = io.BytesIO()
- im.save(buf, format="PNG")
- buf.seek(0)
- if mask:
- return generation.Prompt(
- artifact=generation.Artifact(
- type=generation.ARTIFACT_MASK, binary=buf.getvalue()
- )
- )
- return generation.Prompt(
- artifact=generation.Artifact(
- type=generation.ARTIFACT_IMAGE, binary=buf.getvalue()
- ),
- parameters=generation.PromptParameters(init=init),
- )
-
def process_artifacts_from_answers(
prefix: str,
@@ -100,17 +74,17 @@ def process_artifacts_from_answers(
ext = ".pb"
contents = artifact.SerializeToString()
out_p = truncate_fit(prefix, prompt, ext, int(artifact_start), idx, MAX_FILENAME_SZ)
- is_allowed_type = filter_types is None or artifact_type_to_str(artifact.type) in filter_types
+ is_allowed_type = filter_types is None or artifact_type_to_string(artifact.type) in filter_types
if write:
if is_allowed_type:
with open(out_p, "wb") as f:
f.write(bytes(contents))
if verbose:
- logger.info(f"wrote {artifact_type_to_str(artifact.type)} to {out_p}")
+ logger.info(f"wrote {artifact_type_to_string(artifact.type)} to {out_p}")
else:
if verbose:
logger.info(
- f"skipping {artifact_type_to_str(artifact.type)} due to artifact type filter")
+ f"skipping {artifact_type_to_string(artifact.type)} due to artifact type filter")
yield (out_p, artifact)
idx += 1
@@ -142,7 +116,6 @@ def __init__(
self.upscale_engine = upscale_engine
self.grpc_args = {"wait_for_ready": wait_for_ready}
-
if verbose:
logger.info(f"Opening channel to {host}")
@@ -259,10 +232,10 @@ def generate(
start=start_schedule,
end=end_schedule,
)
- prompts += [image_to_prompt(init_image, init=True)]
+ prompts += [image_to_prompt(init_image)]
if mask_image is not None:
- prompts += [image_to_prompt(mask_image, mask=True)]
+ prompts += [image_to_prompt(mask_image, type=generation.ARTIFACT_MASK)]
if guidance_prompt:
@@ -360,7 +333,7 @@ def upscale(
parameters=[generation.StepParameter(**step_parameters)],
)
- prompts = [image_to_prompt(init_image, init=True)]
+ prompts = [image_to_prompt(init_image)]
if prompt:
if isinstance(prompt, str):
@@ -403,7 +376,7 @@ def emit_request(
if self.verbose:
if len(answer.artifacts) > 0:
artifact_ts = [
- artifact_type_to_str(artifact.type)
+ artifact_type_to_string(artifact.type)
for artifact in answer.artifacts
]
logger.info(
@@ -446,17 +419,12 @@ def process_cli(logger: logging.Logger = None,
STABILITY_HOST = os.getenv("STABILITY_HOST", "grpc.stability.ai:443")
STABILITY_KEY = os.getenv("STABILITY_KEY", "")
- if not STABILITY_HOST:
- logger.warning("STABILITY_HOST environment variable needs to be set.")
- sys.exit(1)
-
if not STABILITY_KEY:
- logger.warning(
- "STABILITY_KEY environment variable needs to be set. You may"
- " need to login to the Stability website to obtain the"
- " API key."
+ print(
+ "Please enter your API key from dreamstudio.ai or set the "
+ "STABILITY_KEY environment variable to skip this prompt."
)
- sys.exit(1)
+ STABILITY_KEY = getpass.getpass("Enter your Stability API key: ")
# CLI parsing
parser = ArgumentParser()
@@ -515,6 +483,12 @@ def process_cli(logger: logging.Logger = None,
parser_upscale.add_argument(
"prompt", nargs="*"
)
+
+
+ parser_animate = subparsers.add_parser('animate')
+ parser_animate.add_argument("--gui", action="store_true", help="serve Gradio UI")
+ parser_animate.add_argument("--share", action="store_true", help="create shareable UI link")
+ parser_animate.add_argument("--output", "-o", type=str, default=".", help="root output folder")
parser_generate = subparsers.add_parser('generate')
@@ -601,10 +575,10 @@ def process_cli(logger: logging.Logger = None,
if command not in subparsers.choices.keys() and command != '-h' and command != '--help':
logger.warning(f"command {command} not recognized, defaulting to 'generate'")
logger.warning(
- "[Deprecation Warning] The method you have used to invoke the sdk will be deprecated shortly."
- "[Deprecation Warning] Please modify your code to call the sdk with the following syntax:"
- "[Deprecation Warning] python -m stability_sdk "
- "[Deprecation Warning] Where is one of: upscale, generate"
+ "[Deprecation Warning] The method you have used to invoke the sdk will be deprecated shortly."
+ "[Deprecation Warning] Please modify your code to call the sdk with the following syntax:"
+ "[Deprecation Warning] python -m stability_sdk "
+ "[Deprecation Warning] Where is one of: upscale, generate"
)
input_args = ['generate'] + input_args
@@ -624,7 +598,7 @@ def process_cli(logger: logging.Logger = None,
"seed": args.seed,
"cfg_scale": args.cfg_scale,
"prompt": args.prompt,
- }
+ }
stability_api = StabilityInference(
STABILITY_HOST, STABILITY_KEY, upscale_engine=args.engine, verbose=True
)
@@ -660,7 +634,7 @@ def process_cli(logger: logging.Logger = None,
}
if args.sampler:
- request["sampler"] = get_sampler_from_str(args.sampler)
+ request["sampler"] = sampler_from_string(args.sampler)
if args.steps:
request["steps"] = args.steps
@@ -673,6 +647,18 @@ def process_cli(logger: logging.Logger = None,
args.prefix, args.prompt, answers, write=not args.no_store, verbose=True,
filter_types=args.artifact_types,
)
+ elif args.command == "animate":
+ if args.gui:
+ from .animation_ui import create_ui
+ from .api import Context
+ ui = create_ui(Context(STABILITY_HOST, STABILITY_KEY), args.output)
+ ui.queue(concurrency_count=2, max_size=2)
+ ui.launch(show_api=False, debug=True, height=768, share=args.share, show_error=True)
+ sys.exit(0)
+ else:
+ logger.warning("animate must be invoked with --gui")
+ sys.exit(1)
+
if args.show:
for artifact in open_images(artifacts, verbose=True):
diff --git a/src/stability_sdk/interfaces b/src/stability_sdk/interfaces
index 193a1e41..f3a50851 160000
--- a/src/stability_sdk/interfaces
+++ b/src/stability_sdk/interfaces
@@ -1 +1 @@
-Subproject commit 193a1e41a984c6e452aa1f99d9c1c20c15c692a4
+Subproject commit f3a50851f8ea158fef1b1d76661cfd9a8cf83e01
diff --git a/src/stability_sdk/matrix.py b/src/stability_sdk/matrix.py
new file mode 100644
index 00000000..31118755
--- /dev/null
+++ b/src/stability_sdk/matrix.py
@@ -0,0 +1,76 @@
+"""
+Minimal set of 4x4 column-major matrix functions for building transforms
+compatible with the animation transform API. This serves as a reference
+implementation for the different languages we will support, so only basic
+types and no external libraries are used.
+
+ [sx, 10, 20, tx] [x]
+ [01, sy, 21, ty] . [y]
+ [02, 12, sz, tz] [z]
+ [03, 13, 23, 33] [1]
+
+"""
+import math
+from typing import List
+
+Matrix = List[List[float]]
+
+identity = [[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]]
+
+def multiply(a: Matrix, b: Matrix) -> Matrix:
+ assert len(a) == len(b) == 4
+ c = [[0., 0., 0., 0.],
+ [0., 0., 0., 0.],
+ [0., 0., 0., 0.],
+ [0., 0., 0., 0.]]
+ for row in range(4):
+ for col in range(4):
+ for k in range(4):
+ c[row][col] += a[row][k] * b[k][col]
+ return c
+
+def projection_fov(fov_y: float, aspect: float, near: float, far: float) -> Matrix:
+ min_x, min_y = -1, -1
+ max_x, max_y = 1, 1
+ h1 = (max_y + min_y) / (max_y - min_y)
+ w1 = (max_x + min_x) / (max_x - min_x)
+ t = math.tan(fov_y / 2)
+ s1 = 1 / t
+ s2 = 1 / (t * aspect)
+
+ # map z to the range [0, 1]
+ f1 = far / (far - near)
+ f2 = -(far * near) / (far - near)
+
+ return [[s1, 0., w1, 0.],
+ [0., s2, h1, 0.],
+ [0., 0., f1, f2],
+ [0., 0., 1., 0.]]
+
+def rotation_euler(x: float, y: float, z: float) -> Matrix:
+ """Returns a rotation matrix for the given Euler angles (in radians) using XYZ order."""
+ a, b = math.cos(x), math.sin(x)
+ c, d = math.cos(y), math.sin(y)
+ e, f = math.cos(z), math.sin(z)
+
+ ae = a * e
+ af = a * f
+ be = b * e
+ bf = b * f
+
+ return [[ c * e, af + be * d, bf - ae * d, 0.],
+ [-c * f, ae - bf * d, be + af * d, 0.],
+ [ d, -b * c, a * c, 0.],
+ [ 0., 0., 0., 1.]]
+
+def scale(sx: float, sy: float, sz: float) -> Matrix:
+ return [[sx, 0., 0., 0.],
+ [0., sy, 0., 0.],
+ [0., 0., sz, 0.],
+ [0., 0., 0., 1.]]
+
+def translation(tx: float, ty: float, tz: float) -> Matrix:
+ return [[1., 0., 0., tx],
+ [0., 1., 0., ty],
+ [0., 0., 1., tz],
+ [0., 0., 0., 1.]]
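A quick composition sketch for these helpers (example values of my own; with column vectors, `multiply(a, b)` applies `b` first). utils.py below flattens the result row by row with `sum(matrix, [])` when building `TransformMatrix` messages:

```python
import math

import stability_sdk.matrix as matrix

# Rotate 5 degrees around Y, then pull the camera back one unit along Z.
rotate = matrix.rotation_euler(0.0, math.radians(5.0), 0.0)
pull_back = matrix.translation(0.0, 0.0, -1.0)
world_to_view = matrix.multiply(pull_back, rotate)

flat = sum(world_to_view, [])  # 16 floats, row by row
assert len(flat) == 16
```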
diff --git a/src/stability_sdk/utils.py b/src/stability_sdk/utils.py
index 96b75a97..fd178a58 100644
--- a/src/stability_sdk/utils.py
+++ b/src/stability_sdk/utils.py
@@ -1,23 +1,64 @@
-import pathlib
-import sys
-import os
-import uuid
-import random
import io
import logging
-import time
-from typing import Dict, Generator, List, Optional, Union, Any, Sequence, Tuple
-import mimetypes
-
+import os
+import subprocess
from PIL import Image
+from typing import Dict, Generator, Optional, Sequence, Tuple, Type, TypeVar, Union
+
+from .api import generation
+from .matrix import Matrix
-import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
-import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
+MAX_FILENAME_SZ = int(os.getenv("MAX_FILENAME_SZ", 200))
+
+
+#==============================================================================
+# Mappings from strings to protobuf enums
+#==============================================================================
+
+BORDER_MODES = {
+ 'replicate': generation.BORDER_REPLICATE,
+ 'reflect': generation.BORDER_REFLECT,
+ 'wrap': generation.BORDER_WRAP,
+ 'zero': generation.BORDER_ZERO,
+ 'prefill': generation.BORDER_PREFILL,
+}
+
+CAMERA_TYPES = {
+ 'perspective': generation.CAMERA_PERSPECTIVE,
+ 'orthographic': generation.CAMERA_ORTHOGRAPHIC,
+}
+
+COLOR_MATCH_MODES = {
+ "hsv": generation.COLOR_MATCH_HSV,
+ "lab": generation.COLOR_MATCH_LAB,
+ "rgb": generation.COLOR_MATCH_RGB,
+}
+
+GUIDANCE_PRESETS: Dict[str, int] = {
+ "none": generation.GUIDANCE_PRESET_NONE,
+ "simple": generation.GUIDANCE_PRESET_SIMPLE,
+ "fastblue": generation.GUIDANCE_PRESET_FAST_BLUE,
+ "fastgreen": generation.GUIDANCE_PRESET_FAST_GREEN,
+}
+
+INTERPOLATE_MODES = {
+ 'film': generation.INTERPOLATE_FILM,
+ 'mix': generation.INTERPOLATE_LINEAR,
+ 'rife': generation.INTERPOLATE_RIFE,
+ 'vae-lerp': generation.INTERPOLATE_VAE_LINEAR,
+ 'vae-slerp': generation.INTERPOLATE_VAE_SLERP,
+}
+
+RENDER_MODES = {
+ 'mesh': generation.RENDER_MESH,
+ 'pointcloud': generation.RENDER_POINTCLOUD,
+}
+
SAMPLERS: Dict[str, int] = {
"ddim": generation.SAMPLER_DDIM,
"plms": generation.SAMPLER_DDPM,
@@ -30,10 +71,128 @@
"k_dpmpp_2m": generation.SAMPLER_K_DPMPP_2M,
"k_dpmpp_2s_ancestral": generation.SAMPLER_K_DPMPP_2S_ANCESTRAL
}
-
-MAX_FILENAME_SZ = int(os.getenv("MAX_FILENAME_SZ", 200))
-def artifact_type_to_str(artifact_type: generation.ArtifactType):
+T = TypeVar('T')
+
+def _from_string(s: str, mapping: Dict[str, T], name: str, enum_cls: Type[T]) -> T:
+ enum_value = mapping.get(s.lower().strip())
+ if enum_value is None:
+ raise ValueError(f"invalid {name}: {s}")
+ return enum_value
+
+def border_mode_from_string(s: str) -> generation.BorderMode:
+ return _from_string(s, BORDER_MODES, "border mode", generation.BorderMode)
+
+def camera_type_from_string(s: str) -> generation.CameraType:
+ return _from_string(s, CAMERA_TYPES, "camera type", generation.CameraType)
+
+def color_match_from_string(s: str) -> generation.ColorMatchMode:
+ return _from_string(s, COLOR_MATCH_MODES, "color match", generation.ColorMatchMode)
+
+def guidance_from_string(s: str) -> generation.GuidancePreset:
+ return _from_string(s, GUIDANCE_PRESETS, "guidance preset", generation.GuidancePreset)
+
+def interpolate_mode_from_string(s: str) -> generation.InterpolateMode:
+ return _from_string(s, INTERPOLATE_MODES, "interpolate mode", generation.InterpolateMode)
+
+def render_mode_from_string(s: str) -> generation.RenderMode:
+ return _from_string(s, RENDER_MODES, "render mode", generation.RenderMode)
+
+def sampler_from_string(s: str) -> generation.DiffusionSampler:
+ return _from_string(s, SAMPLERS, "sampler", generation.DiffusionSampler)
+
+
+#==============================================================================
+# Transform helper functions
+#==============================================================================
+
+def camera_pose_transform(
+ transform: Matrix,
+ near_plane: float,
+ far_plane: float,
+ fov: float,
+ camera_type: str='perspective',
+ render_mode: str='mesh',
+ do_prefill: bool=True,
+) -> generation.TransformParameters:
+ camera_parameters = generation.CameraParameters(
+ camera_type=camera_type_from_string(camera_type),
+ near_plane=near_plane, far_plane=far_plane, fov=fov)
+ return generation.TransformParameters(
+ camera_pose=generation.TransformCameraPose(
+ world_to_view_matrix=generation.TransformMatrix(data=sum(transform, [])),
+ camera_parameters=camera_parameters,
+ render_mode=render_mode_from_string(render_mode),
+ do_prefill=do_prefill
+ )
+ )
+
+def color_adjust_transform(
+ brightness: float=1.0,
+ contrast: float=1.0,
+ hue: float=0.0,
+ saturation: float=1.0,
+ lightness: float=0.0,
+ match_image: Optional[Image.Image]=None,
+ match_mode: str='LAB',
+ noise_amount: float=0.0,
+ noise_seed: int=0
+) -> generation.TransformParameters:
+ if match_mode == 'None':
+ match_mode = 'RGB'
+ match_image = None
+ return generation.TransformParameters(
+ color_adjust=generation.TransformColorAdjust(
+ brightness=brightness,
+ contrast=contrast,
+ hue=hue,
+ saturation=saturation,
+ lightness=lightness,
+ match_image=generation.Artifact(
+ type=generation.ARTIFACT_IMAGE,
+ binary=image_to_jpg_bytes(match_image),
+ ) if match_image is not None else None,
+ match_mode=color_match_from_string(match_mode),
+ noise_amount=noise_amount,
+ noise_seed=noise_seed,
+ ))
+
+def depth_calc_transform(
+ blend_weight: float,
+ blur_radius: int=0,
+ reverse: bool=False,
+) -> generation.TransformParameters:
+ return generation.TransformParameters(
+ depth_calc=generation.TransformDepthCalc(
+ blend_weight=blend_weight,
+ blur_radius=blur_radius,
+ reverse=reverse
+ )
+ )
+
+def resample_transform(
+ border_mode: str,
+ transform: Matrix,
+ prev_transform: Optional[Matrix]=None,
+ depth_warp: float=1.0,
+ export_mask: bool=False
+) -> generation.TransformParameters:
+ return generation.TransformParameters(
+ resample=generation.TransformResample(
+ border_mode=border_mode_from_string(border_mode),
+ transform=generation.TransformMatrix(data=sum(transform, [])),
+ prev_transform=generation.TransformMatrix(data=sum(prev_transform, [])) if prev_transform else None,
+ depth_warp=depth_warp,
+ export_mask=export_mask
+ )
+ )
+
+
+#==============================================================================
+# General utility functions
+#==============================================================================
+
+def artifact_type_to_string(artifact_type: generation.ArtifactType):
"""
Convert ArtifactType to a string.
:param artifact_type: The ArtifactType to convert.
@@ -49,33 +208,117 @@ def artifact_type_to_str(artifact_type: generation.ArtifactType):
)
return "ARTIFACT_UNRECOGNIZED"
-def truncate_fit(prefix: str, prompt: str, ext: str, ts: int, idx: int, max: int) -> str:
+def create_video_from_frames(frames_path: str, mp4_path: str, fps: int=24, reverse: bool=False):
"""
- Constructs an output filename from a collection of required fields.
+ Convert a series of image frames to a video file using ffmpeg.
+
+ :param frames_path: The path to the directory containing the image frames named frame_00000.png, frame_00001.png, etc.
+ :param mp4_path: The path to save the output video file.
+ :param fps: The frames per second for the output video. Default is 24.
+ :param reverse: A flag to reverse the order of the frames in the output video. Default is False.
+ """
+
+ cmd = [
+ 'ffmpeg',
+ '-y',
+ '-vcodec', 'png',
+ '-r', str(fps),
+ '-start_number', str(0),
+ '-i', os.path.join(frames_path, "frame_%05d.png"),
+ '-c:v', 'libx264',
+ '-vf',
+ f'fps={fps}',
+ '-pix_fmt', 'yuv420p',
+ '-crf', '17',
+ '-preset', 'veryslow',
+ mp4_path
+ ]
+ if reverse:
+ # ffmpeg applies only the last -vf option, so extend the existing filter chain
+ cmd[cmd.index('-vf') + 1] += ',reverse'
+
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ _, stderr = process.communicate()
+ if process.returncode != 0:
+ raise RuntimeError(stderr)
+
+def extract_frames_from_video(video_path: str, frames_subdir: str='frames'):
+ """
+ Extracts all frames from a video to a subdirectory of the video's parent folder.
+ :param video_path: A path to the video.
+ :param frames_subdir: Name of the subdirectory to save the frames into.
+ :return: The frames subdirectory path.
+ """
+ out_dir = os.path.join(os.path.dirname(video_path), frames_subdir)
+ if not os.path.exists(out_dir):
+ os.mkdir(out_dir)
- Given an over-budget threshold of `max`, trims the prompt string to satisfy the budget.
- NB: As implemented, 'max' is the smallest filename length that will trigger truncation.
- It is presumed that the sum of the lengths of the other filename fields is smaller than `max`.
- If they exceed `max`, this function will just always construct a filename with no prompt component.
+ cmd = [
+ 'ffmpeg',
+ '-i', video_path,
+ os.path.join(out_dir, "frame_%05d.png"),
+ ]
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ _, stderr = process.communicate()
+ if process.returncode != 0:
+ raise RuntimeError(stderr)
+
+ return out_dir
+
+def image_mix(img_a: Image.Image, img_b: Image.Image, ratio: Union[float, Image.Image]) -> Image.Image:
"""
- post = f"_{ts}_{idx}"
- prompt_budget = max
- prompt_budget -= len(prefix)
- prompt_budget -= len(post)
- prompt_budget -= len(ext) + 1
- return f"{prefix}{prompt[:prompt_budget]}{post}{ext}"
+ Performs a linear interpolation between two images.
+ :param img_a: The first image.
+ :param img_b: The second image.
+ :param ratio: Mix ratio or mask image.
+ :return: The mixed image.
+ """
+ if img_a.size != img_b.size:
+ raise ValueError(f"img_a size {img_a.size} does not match img_b size {img_b.size}")
+
+ if isinstance(ratio, Image.Image):
+ if ratio.size != img_a.size:
+ raise ValueError(f"mix ratio size {ratio.size} does not match img_a size {img_a.size}")
+ return Image.composite(img_b, img_a, ratio)
+
+ return Image.blend(img_a, img_b, ratio)
+
+def image_to_jpg_bytes(image: Image.Image, quality: int=90) -> bytes:
+ """
+ Compresses an image to a JPEG byte array.
+ :param image: The image to convert.
+ :param quality: The JPEG quality to use.
+ :return: The JPEG byte array.
+ """
+ buf = io.BytesIO()
+ image.save(buf, format="JPEG", quality=quality)
+ buf.seek(0)
+ return buf.getvalue()
-def get_sampler_from_str(s: str) -> generation.DiffusionSampler:
+def image_to_png_bytes(image: Image.Image) -> bytes:
"""
- Convert a string to a DiffusionSampler enum.
- :param s: The string to convert.
- :return: The DiffusionSampler enum.
+ Compresses an image to a PNG byte array.
+ :param image: The image to convert.
+ :return: The PNG byte array.
"""
- algorithm_key = s.lower().strip()
- algorithm = SAMPLERS.get(algorithm_key, None)
- if algorithm is None:
- raise ValueError(f"unknown sampler {s}")
- return algorithm
+ buf = io.BytesIO()
+ image.save(buf, format="PNG")
+ buf.seek(0)
+ return buf.getvalue()
+
+def image_to_prompt(
+ image: Image.Image,
+ type: generation.ArtifactType=generation.ARTIFACT_IMAGE
+) -> generation.Prompt:
+ """
+ Create Prompt message type from an image.
+ :param image: The image.
+ :param type: The ArtifactType to use (ARTIFACT_IMAGE, ARTIFACT_MASK, or ARTIFACT_DEPTH).
+ """
+ return generation.Prompt(artifact=generation.Artifact(
+ type=type,
+ binary=image_to_png_bytes(image)
+ ))
def open_images(
images: Union[
@@ -97,3 +340,29 @@ def open_images(
img = Image.open(io.BytesIO(artifact.binary))
img.show()
yield (path, artifact)
+
+def tensor_to_prompt(tensor: 'tensors_pb.Tensor') -> generation.Prompt:
+ """
+ Create Prompt message type from a tensor.
+ :param tensor: The tensor.
+ """
+ return generation.Prompt(artifact=generation.Artifact(
+ type=generation.ARTIFACT_TENSOR,
+ tensor=tensor
+ ))
+
+def truncate_fit(prefix: str, prompt: str, ext: str, ts: int, idx: int, max: int) -> str:
+ """
+ Constructs an output filename from a collection of required fields.
+
+ Given an over-budget threshold of `max`, trims the prompt string to satisfy the budget.
+ NB: As implemented, 'max' is the smallest filename length that will trigger truncation.
+ It is presumed that the sum of the lengths of the other filename fields is smaller than `max`.
+ If they exceed `max`, this function will just always construct a filename with no prompt component.
+ """
+ post = f"_{ts}_{idx}"
+ prompt_budget = max
+ prompt_budget -= len(prefix)
+ prompt_budget -= len(post)
+ prompt_budget -= len(ext) + 1
+ return f"{prefix}{prompt[:prompt_budget]}{post}{ext}"
diff --git a/tests/assets/4166726513_giant__rainbow_sequoia__tree_by_hayao_miyazaki___earth_tones__a_row_of_western_cedar_nurse_trees_che.png b/tests/assets/4166726513_giant__rainbow_sequoia__tree_by_hayao_miyazaki___earth_tones__a_row_of_western_cedar_nurse_trees_che.png
new file mode 100644
index 00000000..8cfef971
Binary files /dev/null and b/tests/assets/4166726513_giant__rainbow_sequoia__tree_by_hayao_miyazaki___earth_tones__a_row_of_western_cedar_nurse_trees_che.png differ
diff --git a/tests/assets/HebyMorgongava_512kb.mp4 b/tests/assets/HebyMorgongava_512kb.mp4
new file mode 100755
index 00000000..eb4360df
Binary files /dev/null and b/tests/assets/HebyMorgongava_512kb.mp4 differ
diff --git a/tests/conftest.py b/tests/conftest.py
index 782ab634..12bc8653 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,23 +1,11 @@
-from concurrent import futures
-
import grpc
-import pytest
-
-import logging
import pathlib
-import sys
-
-thisPath = pathlib.Path(__file__).parent.parent.resolve()
-genPath = thisPath / "src/stability_sdk/interfaces/gooseai/generation"
-tensPath = thisPath / "src/stability_sdk/interfaces/src/tensorizer/tensors"
-assert genPath.exists()
-assert tensPath.exists()
+import pytest
-logger = logging.getLogger(__name__)
-sys.path.extend([str(genPath), str(tensPath)])
+from concurrent import futures
+from PIL import Image
-import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
-import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc
+from stability_sdk.api import generation_grpc
# modified from https://github.com/justdoit0823/grpc-resolver/blob/master/tests/conftest.py
@@ -38,3 +26,15 @@ def grpc_server(grpc_addr):
server.start()
yield server
server.stop(0)
+
+@pytest.fixture(scope='module')
+def impath() -> str:
+ return str(next(pathlib.Path('.').glob('**/tests/assets/*.png')).resolve())
+
+@pytest.fixture(scope='module')
+def pil_image(impath) -> Image.Image:
+ return Image.open(impath)
+
+@pytest.fixture(scope='module')
+def vidpath() -> str:
+ return str(next(pathlib.Path('.').glob('**/tests/assets/*.mp4')))
diff --git a/tests/test_animator.py b/tests/test_animator.py
new file mode 100644
index 00000000..304363fa
--- /dev/null
+++ b/tests/test_animator.py
@@ -0,0 +1,74 @@
+import pytest
+
+from pathlib import Path
+
+from stability_sdk.animation import Animator, AnimationArgs
+from stability_sdk.api import Context
+
+from .test_api import MockStub
+
+animation_prompts={0:"foo bar"}
+
+def test_init_animator():
+ Animator(
+ Context(stub=MockStub()),
+ args=AnimationArgs(),
+ animation_prompts=animation_prompts,
+ )
+
+def test_init_animator_prompts_notoptional():
+ with pytest.raises(TypeError, match="missing 1 required positional argument: 'animation_prompts'"):
+ Animator(
+ Context(stub=MockStub()),
+ args=AnimationArgs(),
+ )
+
+def test_save_settings():
+ animator = Animator(Context(stub=MockStub()), args=AnimationArgs(), animation_prompts=animation_prompts)
+ animator.save_settings("settings.txt")
+
+def test_get_weights():
+ animator = Animator(Context(stub=MockStub()), args=AnimationArgs(), animation_prompts=animation_prompts)
+ animator.get_animation_prompts_weights(frame_idx=0)
+
+def test_load_video(vidpath):
+ args = AnimationArgs()
+ args.animation_mode = 'Video Input'
+ args.video_init_path = vidpath
+ animator = Animator(Context(stub=MockStub()), args=args, animation_prompts=animation_prompts)
+ assert len(animator.prior_frames) > 0
+ assert animator.video_prev_frame is not None
+ assert all([v is not None for v in animator.prior_frames])
+
+@pytest.mark.parametrize('animation_mode', ['Video Input','2D','3D warp','3D render'])
+def test_render(animation_mode, vidpath):
+ args = AnimationArgs()
+ args.animation_mode = animation_mode
+ args.video_init_path = vidpath
+ animator = Animator(Context(stub=MockStub()), args=args, animation_prompts=animation_prompts)
+ if animation_mode == 'Video Input':
+ assert len(animator.prior_frames) > 0
+ _ = animator.render()
+
+def test_init_image_none():
+ animator = Animator(Context(stub=MockStub()), args=AnimationArgs(), animation_prompts=animation_prompts)
+ assert len(animator.prior_frames) == 0
+
+def test_init_image_from_args(impath):
+ args = AnimationArgs()
+ args.init_image = impath
+ animator = Animator(Context(stub=MockStub()), args=args, animation_prompts=animation_prompts)
+ assert len(animator.prior_frames) == 1
+
+def test_init_image_from_input(impath):
+ animator = Animator(Context(stub=MockStub()), args=AnimationArgs(), animation_prompts=animation_prompts)
+ assert len(animator.prior_frames) == 0
+ animator.load_init_image()
+ assert len(animator.prior_frames) == 0
+ assert Path(impath).exists()
+ animator.load_init_image(impath)
+ assert len(animator.prior_frames) == 1
+ animator.set_cadence_mode(True)
+ assert len(animator.prior_frames) == 2
\ No newline at end of file
diff --git a/tests/test_api.py b/tests/test_api.py
new file mode 100644
index 00000000..b8a6c740
--- /dev/null
+++ b/tests/test_api.py
@@ -0,0 +1,173 @@
+import io
+import numpy as np
+from PIL import Image
+from typing import Generator
+
+import stability_sdk.matrix as matrix
+from stability_sdk import utils
+from stability_sdk.api import Context, generation
+
+def _artifact_from_image(image: Image.Image) -> generation.Artifact:
+ binary = utils.image_to_png_bytes(image)
+ return generation.Artifact(
+ type=generation.ARTIFACT_IMAGE,
+ mime="image/png",
+ binary=binary,
+ size=len(binary)
+ )
+
+def _rand_image(width: int=512, height: int=512) -> Image.Image:
+ return Image.fromarray(np.random.randint(0, 255, (height, width, 3), dtype=np.uint8))
+
+class MockStub:
+ def __init__(self):
+ pass
+
+ def ChainGenerate(self, chain: generation.ChainRequest, **kwargs) -> Generator[generation.Answer, None, None]:
+ # Not a full implementation of chaining, but enough to test current api.Context layer
+ artifacts = []
+ for stage in chain.stage:
+ stage.request.MergeFrom(generation.Request(prompt=[generation.Prompt(artifact=a) for a in artifacts]))
+ artifacts = []
+ for answer in self.Generate(stage.request):
+ artifacts.extend(answer.artifacts)
+ for artifact in artifacts:
+ yield generation.Answer(artifacts=[artifact])
+
+ def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
+ if request.HasField("image"):
+ image = _rand_image(request.image.width or 512, request.image.height or 512)
+ yield generation.Answer(artifacts=[_artifact_from_image(image)])
+
+ elif request.HasField("interpolate"):
+ assert len(request.prompt) == 2
+ assert request.prompt[0].artifact.type == generation.ARTIFACT_IMAGE
+ assert request.prompt[1].artifact.type == generation.ARTIFACT_IMAGE
+ image_a = Image.open(io.BytesIO(request.prompt[0].artifact.binary))
+ image_b = Image.open(io.BytesIO(request.prompt[1].artifact.binary))
+ assert image_a.size == image_b.size
+ for ratio in request.interpolate.ratios:
+ tween = utils.image_mix(image_a, image_b, ratio)
+ yield generation.Answer(artifacts=[_artifact_from_image(tween)])
+
+ elif request.HasField("transform"):
+ assert len(request.prompt) >= 1
+
+ has_depth_input, has_tensor_input = False, False
+ for prompt in request.prompt:
+ if prompt.artifact.type == generation.ARTIFACT_DEPTH:
+ has_depth_input = True
+ elif prompt.artifact.type == generation.ARTIFACT_TENSOR:
+ has_tensor_input = True
+
+ # 3D resample and camera pose require a depth or depth tensor artifact
+ if request.transform.HasField("resample") and len(request.transform.resample.transform.data) == 16:
+ assert has_depth_input or has_tensor_input
+ if request.transform.HasField("camera_pose"):
+ assert has_depth_input or has_tensor_input
+
+ export_mask = request.transform.HasField("camera_pose")
+ if request.transform.HasField("resample"):
+ if request.transform.resample.HasField("export_mask"):
+ export_mask = request.transform.resample.export_mask
+
+ for prompt in request.prompt:
+ if prompt.artifact.type == generation.ARTIFACT_IMAGE:
+ image = Image.open(io.BytesIO(prompt.artifact.binary))
+ artifact = _artifact_from_image(image)
+ if request.transform.HasField("depth_calc"):
+ if request.requested_type == generation.ARTIFACT_TENSOR:
+ artifact.type = generation.ARTIFACT_TENSOR
+ else:
+ artifact.type = generation.ARTIFACT_DEPTH
+ yield generation.Answer(artifacts=[artifact])
+
+ if export_mask:
+ mask = _rand_image(image.width, image.height).convert("L")
+ artifact = _artifact_from_image(mask)
+ artifact.type = generation.ARTIFACT_MASK
+ yield generation.Answer(artifacts=[artifact])
+
+def test_api_generate():
+ api = Context(stub=MockStub())
+ width, height = 512, 768
+ results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height)
+ assert isinstance(results, dict)
+ assert generation.ARTIFACT_IMAGE in results
+ assert len(results[generation.ARTIFACT_IMAGE]) == 1
+ image = results[generation.ARTIFACT_IMAGE][0]
+ assert isinstance(image, Image.Image)
+ assert image.size == (width, height)
+
+def test_api_inpaint():
+ api = Context(stub=MockStub())
+ width, height = 512, 768
+ image = _rand_image(width, height)
+ mask = _rand_image(width, height).convert("L")
+ results = api.inpaint(image, mask, prompts=["foo bar"], weights=[1.0])
+ assert generation.ARTIFACT_IMAGE in results
+ assert len(results[generation.ARTIFACT_IMAGE]) == 1
+ image = results[generation.ARTIFACT_IMAGE][0]
+ assert isinstance(image, Image.Image)
+ assert image.size == (width, height)
+
+def test_api_interpolate():
+ api = Context(stub=MockStub())
+ width, height = 512, 768
+ image_a = _rand_image(width, height)
+ image_b = _rand_image(width, height)
+ results = api.interpolate([image_a, image_b], [0.3, 0.5, 0.6])
+ assert len(results) == 3
+ for image in results:
+ assert isinstance(image, Image.Image)
+ assert image.size == (width, height)
+
+def test_api_transform_and_generate():
+ api = Context(stub=MockStub())
+ width, height = 512, 704
+ init_image = _rand_image(width, height)
+ generate_request = api.generate(["a cute cat"], [1], width=width, height=height,
+ init_strength=0.65, return_request=True)
+ assert isinstance(generate_request, generation.Request)
+ image = api.transform_and_generate(init_image, [utils.color_adjust_transform()], generate_request)
+ assert isinstance(image, Image.Image)
+ assert image.size == (width, height)
+
+def test_api_transform_camera_pose():
+ api = Context(stub=MockStub())
+ image = _rand_image()
+ xform = matrix.identity
+ pose = utils.camera_pose_transform(
+ xform, 0.1, 100.0, 75.0,
+ camera_type='perspective',
+ render_mode='mesh',
+ do_prefill=True
+ )
+ images, masks = api.transform_3d([image], utils.depth_calc_transform(blend_weight=1.0), pose)
+ assert len(images) == 1 and len(masks) == 1
+ assert isinstance(images[0], Image.Image)
+ assert isinstance(masks[0], Image.Image)
+
+def test_api_transform_color_adjust():
+ api = Context(stub=MockStub())
+ image = _rand_image()
+ images, masks = api.transform([image], utils.color_adjust_transform())
+ assert len(images) == 1 and not masks
+ assert isinstance(images[0], Image.Image)
+ images, masks = api.transform([image, image], utils.color_adjust_transform())
+ assert len(images) == 2 and not masks
+
+def test_api_transform_resample_3d():
+ api = Context(stub=MockStub())
+ image = _rand_image()
+ xform = matrix.identity
+ resample = utils.resample_transform('replicate', xform, xform, export_mask=True)
+ images, masks = api.transform_3d([image], utils.depth_calc_transform(blend_weight=0.5), resample)
+ assert len(images) == 1 and len(masks) == 1
+ assert isinstance(images[0], Image.Image)
+ assert isinstance(masks[0], Image.Image)
+
+def test_api_upscale():
+ api = Context(stub=MockStub())
+ result = api.upscale(_rand_image())
+ assert isinstance(result, Image.Image)
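`MockStub` fabricates plausible artifacts locally, so the `Context` layer can be exercised without network access, an API key, or credits. A sketch mirroring the tests above:

```python
from PIL import Image

from stability_sdk.api import Context, generation
from tests.test_api import MockStub

api = Context(stub=MockStub())  # offline; nothing hits the real service
results = api.generate(prompts=["foo bar"], weights=[1.0], width=256, height=256)
assert isinstance(results[generation.ARTIFACT_IMAGE][0], Image.Image)
```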
diff --git a/tests/test_client.py b/tests/test_client.py
index a17feb72..69b93ca8 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1,60 +1,28 @@
-import pytest
from PIL import Image
+from typing import Generator
from stability_sdk import client
-import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
-import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc
-
-import grpc
-
-# feel like we should be using this, not sure how/where
-import grpc_testing
-
-from typing import Generator
+from stability_sdk.api import generation
-def test_client_import():
- from stability_sdk import client
- assert True
def test_StabilityInference_init():
- class_instance = client.StabilityInference(key='thisIsNotARealKey')
+ _ = client.StabilityInference(key='thisIsNotARealKey')
assert True
-def test_StabilityInference_init_nokey_error():
- try:
- class_instance = client.StabilityInference()
- assert False
- except ValueError:
- assert True
-
def test_StabilityInference_init_nokey_insecure_host():
- class_instance = client.StabilityInference(host='foo.bar.baz')
+ _ = client.StabilityInference(host='foo.bar.baz')
assert True
-def test_image_to_prompt():
- im = Image.new('RGB',(1,1))
- prompt = client.image_to_prompt(im, init=False, mask=False)
- assert isinstance(prompt, generation.Prompt)
-
def test_image_to_prompt_init():
- im = Image.new('RGB',(1,1))
- prompt = client.image_to_prompt(im, init=True, mask=False)
+ im = Image.new('RGB', (1,1))
+ prompt = client.image_to_prompt(im)
assert isinstance(prompt, generation.Prompt)
def test_image_to_prompt_mask():
- im = Image.new('RGB',(1,1))
- prompt = client.image_to_prompt(im, init=False, mask=True)
+ im = Image.new('RGB', (1,1))
+ prompt = client.image_to_prompt(im, type=generation.ARTIFACT_MASK)
assert isinstance(prompt, generation.Prompt)
-def test_image_to_prompt_init_mask():
- im = Image.new('RGB',(1,1))
- try:
- prompt = client.image_to_prompt(im, init=True, mask=True)
- assert False
- except ValueError:
- assert True
-
-
def test_server_mocking(grpc_server, grpc_addr):
class_instance = client.StabilityInference(host=grpc_addr[0])
response = class_instance.generate(prompt="foo bar")
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a9c909c5..88c375da 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,30 +1,74 @@
import pytest
-import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
+from PIL import Image
+from typing import ByteString
+
+import stability_sdk.matrix as matrix
+from stability_sdk.api import generation
from stability_sdk.utils import (
+ BORDER_MODES,
+ COLOR_MATCH_MODES,
+ GUIDANCE_PRESETS,
SAMPLERS,
- artifact_type_to_str,
- get_sampler_from_str,
+ artifact_type_to_string,
+ border_mode_from_string,
+ color_adjust_transform,
+ color_match_from_string,
+ depth_calc_transform,
+ guidance_from_string,
+ image_mix,
+ image_to_jpg_bytes,
+ image_to_png_bytes,
+ image_to_prompt,
+ resample_transform,
+ sampler_from_string,
truncate_fit,
)
+@pytest.mark.parametrize("border", BORDER_MODES.keys())
+def test_border_mode_from_str_2d_valid(border):
+ border_mode_from_string(s=border)
+ assert True
+
+def test_border_mode_from_str_2d_invalid():
+ with pytest.raises(ValueError, match="invalid border mode"):
+ border_mode_from_string(s='not a real border mode')
+
@pytest.mark.parametrize("artifact_type", generation.ArtifactType.values())
def test_artifact_type_to_str_valid(artifact_type):
- type_str = artifact_type_to_str(artifact_type)
+ type_str = artifact_type_to_string(artifact_type)
assert type_str == generation.ArtifactType.Name(artifact_type)
def test_artifact_type_to_str_invalid():
- type_str = artifact_type_to_str(-1)
+ type_str = artifact_type_to_string(-1)
assert type_str == 'ARTIFACT_UNRECOGNIZED'
@pytest.mark.parametrize("sampler_name", SAMPLERS.keys())
-def test_get_sampler_from_str_valid(sampler_name):
- get_sampler_from_str(s=sampler_name)
+def test_sampler_from_str_valid(sampler_name):
+ sampler_from_string(s=sampler_name)
assert True
-def test_get_sampler_from_str_invalid():
- with pytest.raises(ValueError, match="unknown sampler"):
- get_sampler_from_str(s='not a real sampler')
+def test_sampler_from_str_invalid():
+ with pytest.raises(ValueError, match="invalid sampler"):
+ sampler_from_string(s='not a real sampler')
+
+@pytest.mark.parametrize("preset_name", GUIDANCE_PRESETS.keys())
+def test_guidance_from_string_valid(preset_name):
+ guidance_from_string(s=preset_name)
+ assert True
+
+def test_guidance_from_string_invalid():
+ with pytest.raises(ValueError, match="invalid guidance preset"):
+ guidance_from_string(s='not a real preset')
+
+@pytest.mark.parametrize("color_match_mode", COLOR_MATCH_MODES.keys())
+def test_color_match_from_string_valid(color_match_mode):
+ color_match_from_string(s=color_match_mode)
+ assert True
+
+def test_color_match_from_string_invalid():
+ with pytest.raises(ValueError, match="invalid color match"):
+ color_match_from_string(s='not a real color match mode')
####################################
@@ -49,4 +93,82 @@ def test_truncate_fit1():
idx=0,
max=22)
assert outv == 'foo_ba_12345678_0.baz'
-
+
+
+#==============================================================================
+# Image functions
+#==============================================================================
+
+def test_image_mix(pil_image):
+ result = image_mix(img_a=pil_image, img_b=pil_image, ratio=0.5)
+ assert isinstance(result, Image.Image)
+ assert result.size == pil_image.size
+ result = image_mix(img_a=Image.new('L', (64,64), 0), img_b=Image.new('L', (64,64), 255), ratio=1.0)
+ assert all(pixel_value == 255 for pixel_value in result.getdata())
+
+def test_image_mix_mask(pil_image):
+ result = image_mix(img_a=pil_image, img_b=pil_image, ratio=pil_image.convert('L'))
+ assert isinstance(result, Image.Image)
+ assert result.size == pil_image.size
+ result = image_mix(img_a=Image.new('L', (64,64), 0), img_b=Image.new('L', (64,64), 255), ratio=Image.new('L', (64,64), 255))
+ assert all(pixel_value == 255 for pixel_value in result.getdata())
+
+def test_image_to_jpg_bytes(pil_image):
+ result = image_to_jpg_bytes(pil_image)
+ assert isinstance(result, ByteString)
+
+def test_image_to_png_bytes(pil_image):
+ result = image_to_png_bytes(image=pil_image)
+ assert isinstance(result, ByteString)
+
+def test_image_to_prompt(pil_image):
+ result = image_to_prompt(pil_image)
+ assert isinstance(result, generation.Prompt)
+ assert result.artifact.type == generation.ARTIFACT_IMAGE
+
+def test_image_to_prompt_mask(pil_image):
+ result = image_to_prompt(pil_image, type=generation.ARTIFACT_MASK)
+ assert isinstance(result, generation.Prompt)
+ assert result.artifact.type == generation.ARTIFACT_MASK
+
+
+#==============================================================================
+# Transform functions
+#==============================================================================
+
+@pytest.mark.parametrize("color_mode", COLOR_MATCH_MODES.keys())
+def test_colormatch_valid(pil_image, color_mode):
+ op = color_adjust_transform(
+ match_image=pil_image,
+ match_mode=color_mode
+ )
+ assert isinstance(op, generation.TransformParameters)
+
+def test_colormatch_invalid(pil_image):
+ with pytest.raises(ValueError, match="invalid color match"):
+ _ = color_adjust_transform(
+ match_image=pil_image,
+ match_mode="not a real color match mode",
+ )
+
+@pytest.mark.parametrize("border_mode", BORDER_MODES.keys())
+def test_resample_valid(border_mode):
+ op = resample_transform(
+ border_mode=border_mode,
+ transform=matrix.identity,
+ prev_transform=matrix.identity,
+ depth_warp=1.0,
+ export_mask=False
+ )
+ assert isinstance(op, generation.TransformParameters)
+
+@pytest.mark.parametrize("border_mode", ['not a border mode'])
+def test_resample_invalid(border_mode):
+ with pytest.raises(ValueError, match="invalid border mode"):
+ _ = resample_transform(
+ border_mode=border_mode,
+ transform=matrix.identity,
+ prev_transform=matrix.identity,
+ depth_warp=1.0,
+ export_mask=False
+ )