From f01a6475fe26a23ce0b00620834cbf7bc5d08942 Mon Sep 17 00:00:00 2001 From: enzymezoo-code <103286087+enzymezoo-code@users.noreply.github.com> Date: Mon, 13 Mar 2023 12:14:46 -0500 Subject: [PATCH] Upscale (#186) * Adding upscaler functionality * Update help and readme * Update notebook with upscaler * Specs and test * Client side message size limits increased * Update README.md * Bump version --- README.md | 32 +++++- nbs/demo_colab.ipynb | 28 ++++- setup.py | 2 +- src/stability_sdk/client.py | 217 ++++++++++++++++++++++++++---------- tests/specs.md | 94 ++++++++++++++++ tests/test_client.py | 10 ++ 6 files changed, 321 insertions(+), 62 deletions(-) create mode 100644 tests/specs.md diff --git a/README.md b/README.md index 44b3fd55..597f23b7 100644 --- a/README.md +++ b/README.md @@ -25,10 +25,13 @@ You can manage API keys in your dreamstudio account [here](https://beta.dreamstu Then to invoke: -`python3 -m stability_sdk.client -W 512 -H 512 "A stunning house."` +`python3 -m stability_sdk generate -W 512 -H 512 "A stunning house."` It will generate and put PNGs in your current directory. 
+To upscale: +`python3 -m stability_sdk upscale -i "/path/to/image.png"` + ## SDK Usage See usage demo notebooks in ./nbs @@ -36,7 +39,7 @@ See usage demo notebooks in ./nbs ## Command line usage ``` -usage: python -m stability_sdk [-h] [--height HEIGHT] [--width WIDTH] [--start_schedule START_SCHEDULE] +usage: python -m stability_sdk generate [-h] [--height HEIGHT] [--width WIDTH] [--start_schedule START_SCHEDULE] [--end_schedule END_SCHEDULE] [--cfg_scale CFG_SCALE] [--sampler SAMPLER] [--steps STEPS] [--seed SEED] [--prefix PREFIX] [--engine ENGINE] [--num_samples NUM_SAMPLES] [--artifact_types ARTIFACT_TYPES] @@ -80,6 +83,31 @@ options: --mask_image MASK_IMAGE, -m MASK_IMAGE Mask image ``` +For upscale: +``` +usage: python -m stability_sdk upscale + [-h] + --init_image INIT_IMAGE + [--height HEIGHT] [--width WIDTH] [--prefix PREFIX] [--artifact_types ARTIFACT_TYPES] + [--no-store] [--show] [--engine ENGINE] + +options: + -h, --help show this help message and exit + --init_image INIT_IMAGE, -i INIT_IMAGE + Init image + --height HEIGHT, -H HEIGHT + height of upscaled image in pixels + --width WIDTH, -W WIDTH + width of upscaled image in pixels + --prefix PREFIX, -p PREFIX + output prefixes for artifacts + --artifact_types ARTIFACT_TYPES, -t ARTIFACT_TYPES + filter artifacts by type (ARTIFACT_IMAGE, ARTIFACT_TEXT, ARTIFACT_CLASSIFICATIONS, etc) + --no-store do not write out artifacts + --show open artifacts using PIL + --engine ENGINE, -e ENGINE + engine to use for upscale +``` ## Connecting to the API using languages other than Python diff --git a/nbs/demo_colab.ipynb b/nbs/demo_colab.ipynb index 7aff34db..0a08ebd0 100644 --- a/nbs/demo_colab.ipynb +++ b/nbs/demo_colab.ipynb @@ -300,11 +300,33 @@ " print('GUIDANCE: SLOWER:')\n", " display(img5)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Upscaling\n", + "answers = stability_api.upscale(\n", + " init_image=img3\n", + ")\n", + "\n", + "for resp in 
answers:\n", + " for artifact in resp.artifacts:\n", + " if artifact.finish_reason == generation.FILTER:\n", + " warnings.warn(\n", + " \"Your request activated the API's safety filters and could not be processed.\"\n", + " \"Please submit a different image and try again.\")\n", + " if artifact.type == generation.ARTIFACT_IMAGE:\n", + " img3 = Image.open(io.BytesIO(artifact.binary))\n", + " display(img3)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.7 ('dmarx-je5LfYh2')", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -318,12 +340,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.9.5" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "57881a85d677a34ea29564e0084ef84f4058c4e30a2bb466eb0e0b908d0628df" + "hash": "f9f85f796d01129d0dd105a088854619f454435301f6ffec2fea96ecbd9be4ac" } } }, diff --git a/setup.py b/setup.py index 0bc8bbb1..9d334dd0 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name='stability-sdk', - version='0.3.2', + version='0.4.0', author='Wes Brown', author_email='wesbrown18@gmail.com', maintainer='David Marx', diff --git a/src/stability_sdk/client.py b/src/stability_sdk/client.py index 4549d681..7db517b5 100644 --- a/src/stability_sdk/client.py +++ b/src/stability_sdk/client.py @@ -121,6 +121,7 @@ def __init__( host: str = "grpc.stability.ai:443", key: str = "", engine: str = "stable-diffusion-v1-5", + upscale_engine: str = "esrgan-v1-x2plus", verbose: bool = False, wait_for_ready: bool = True, ): @@ -130,12 +131,14 @@ def __init__( :param host: Host to connect to. :param key: Key to use for authentication. :param engine: Engine to use. + :param upscale_engine: Upscale engine to use. :param verbose: Whether to print debug messages. :param wait_for_ready: Whether to wait for the server to be ready, or to fail immediately. 
""" self.verbose = verbose self.engine = engine + self.upscale_engine = upscale_engine self.grpc_args = {"wait_for_ready": wait_for_ready} @@ -144,6 +147,15 @@ def __init__( call_credentials = [] + # Increase the max message size to 10MB to allow for larger images. + max_message_size: int = os.getenv("MAX_MESSAGE_SIZE") + if max_message_size is None: + max_message_size = 10 * 1024 * 1024 # 10MB + options = [ + ("grpc.max_send_message_length", int(max_message_size)), + ("grpc.max_receive_message_length",int(max_message_size)), + ] + if host.endswith("443"): if key: call_credentials.append(grpc.access_token_call_credentials(f"{key}")) @@ -152,13 +164,13 @@ def __init__( channel_credentials = grpc.composite_channel_credentials( grpc.ssl_channel_credentials(), *call_credentials ) - channel = grpc.secure_channel(host, channel_credentials) + channel = grpc.secure_channel(host, channel_credentials, options=options) else: if key: logger.warning( "Not using authentication token due to non-secure transport" ) - channel = grpc.insecure_channel(host) + channel = grpc.insecure_channel(host, options=options) if verbose: logger.info(f"Channel opened to {host}") @@ -304,6 +316,21 @@ def generate( return self.emit_request(prompt=prompts, image_parameters=image_parameters) + + def upscale( + self, + init_image: Image.Image, + height: int = None, + width: int = None, + ) -> Generator[generation.Answer, None, None]: + image_parameters=generation.ImageParameters( + height=height, + width=width, + ) + + prompts = [image_to_prompt(init_image, init=True)] + + return self.emit_request(prompt=prompts, image_parameters=image_parameters, engine_id=self.upscale_engine) # The motivation here is to facilitate constructing requests by passing protobuf objects directly. 
@@ -386,120 +413,198 @@ def emit_request( # CLI parsing parser = ArgumentParser() - parser.add_argument( + subparsers = parser.add_subparsers(dest='command') + + parser_upscale = subparsers.add_parser('upscale') + parser_upscale.add_argument( + "--init_image", + "-i", + type=str, + help="Init image", + required=True + ) + parser_upscale.add_argument( + "--height", "-H", type=int, default=None, help="height of upscaled image" + ) + parser_upscale.add_argument( + "--width", "-W", type=int, default=None, help="width of upscaled image" + ) + parser_upscale.add_argument( + "--prefix", + "-p", + type=str, + default="upscale_", + help="output prefixes for artifacts", + ) + parser_upscale.add_argument( + "--artifact_types", + "-t", + action='append', + type=str, + help="filter artifacts by type (ARTIFACT_IMAGE, ARTIFACT_TEXT, ARTIFACT_CLASSIFICATIONS, etc)" + ) + parser_upscale.add_argument( + "--no-store", action="store_true", help="do not write out artifacts" + ) + parser_upscale.add_argument("--show", action="store_true", help="open artifacts using PIL") + parser_upscale.add_argument( + "--engine", + "-e", + type=str, + help="engine to use for upscale", + default="esrgan-v1-x2plus", + ) + + + parser_generate = subparsers.add_parser('generate') + parser_generate.add_argument( "--height", "-H", type=int, default=512, help="[512] height of image" ) - parser.add_argument( + parser_generate.add_argument( "--width", "-W", type=int, default=512, help="[512] width of image" ) - parser.add_argument( + parser_generate.add_argument( "--start_schedule", type=float, default=0.5, help="[0.5] start schedule for init image (must be greater than 0, 1 is full strength text prompt, no trace of image)", ) - parser.add_argument( + parser_generate.add_argument( "--end_schedule", type=float, default=0.01, help="[0.01] end schedule for init image", ) - parser.add_argument( + parser_generate.add_argument( "--cfg_scale", "-C", type=float, default=7.0, help="[7.0] CFG scale factor" ) - 
parser.add_argument( + parser_generate.add_argument( "--sampler", "-A", type=str, help="[auto-select] (" + ", ".join(SAMPLERS.keys()) + ")", ) - parser.add_argument( + parser_generate.add_argument( "--steps", "-s", type=int, default=None, help="[auto] number of steps" ) - parser.add_argument("--seed", "-S", type=int, default=0, help="random seed to use") - parser.add_argument( + parser_generate.add_argument("--seed", "-S", type=int, default=0, help="random seed to use") + parser_generate.add_argument( "--prefix", "-p", type=str, default="generation_", help="output prefixes for artifacts", ) - parser.add_argument( + parser_generate.add_argument( "--artifact_types", "-t", action='append', type=str, help="filter artifacts by type (ARTIFACT_IMAGE, ARTIFACT_TEXT, ARTIFACT_CLASSIFICATIONS, etc)" ) - parser.add_argument( + parser_generate.add_argument( "--no-store", action="store_true", help="do not write out artifacts" ) - parser.add_argument( + parser_generate.add_argument( "--num_samples", "-n", type=int, default=1, help="number of samples to generate" ) - parser.add_argument("--show", action="store_true", help="open artifacts using PIL") - parser.add_argument( + parser_generate.add_argument("--show", action="store_true", help="open artifacts using PIL") + parser_generate.add_argument( "--engine", "-e", type=str, help="engine to use for inference", default="stable-diffusion-v1-5", ) - parser.add_argument( + parser_generate.add_argument( "--init_image", "-i", type=str, help="Init image", ) - parser.add_argument( + parser_generate.add_argument( "--mask_image", "-m", type=str, help="Mask image", ) - parser.add_argument("prompt", nargs="*") - - args = parser.parse_args() - if not args.prompt and not args.init_image: - logger.warning("prompt or init image must be provided") - parser.print_help() - sys.exit(1) - else: - args.prompt = " ".join(args.prompt) - - if args.init_image: + parser_generate.add_argument("prompt", nargs="*") + + + # handle backwards compatibility, 
default command to generate + input_args = sys.argv[1:] + command = None + if len(input_args)>0: + command = input_args[0] + if command not in subparsers.choices.keys() and command != '-h' and command != '--help': + logger.warning(f"command {command} not recognized, defaulting to 'generate'") + logger.warning( + "[Deprecation Warning] The method you have used to invoke the sdk will be deprecated shortly." + "[Deprecation Warning] Please modify your code to call the sdk with the following syntax:" + "[Deprecation Warning] python -m stability_sdk <command> <args>" + "[Deprecation Warning] Where <command> is one of: upscale, generate" + ) + input_args = ['generate'] + input_args + + args = parser.parse_args(input_args) + + if args.command == "upscale": args.init_image = Image.open(args.init_image) - if args.mask_image: - args.mask_image = Image.open(args.mask_image) - - request = { + request = { "height": args.height, "width": args.width, - "start_schedule": args.start_schedule, - "end_schedule": args.end_schedule, - "cfg_scale": args.cfg_scale, - "seed": args.seed, - "samples": args.num_samples, "init_image": args.init_image, - "mask_image": args.mask_image, - } - - if args.sampler: - request["sampler"] = get_sampler_from_str(args.sampler) - - if args.steps: - request["steps"] = args.steps - - stability_api = StabilityInference( - STABILITY_HOST, STABILITY_KEY, engine=args.engine, verbose=True - ) - - answers = stability_api.generate(args.prompt, **request) - artifacts = process_artifacts_from_answers( - args.prefix, args.prompt, answers, write=not args.no_store, verbose=True, - filter_types=args.artifact_types, - ) + } + stability_api = StabilityInference( + STABILITY_HOST, STABILITY_KEY, upscale_engine=args.engine, verbose=True + ) + answers = stability_api.upscale(**request) + artifacts = process_artifacts_from_answers( + args.prefix, "", answers, write=not args.no_store, verbose=True, + filter_types=args.artifact_types, + ) + elif args.command == "generate": + if not args.prompt and not 
args.init_image: + logger.warning("prompt or init image must be provided") + parser.print_help() + sys.exit(1) + else: + args.prompt = " ".join(args.prompt) + + if args.init_image: + args.init_image = Image.open(args.init_image) + + if args.mask_image: + args.mask_image = Image.open(args.mask_image) + + request = { + "height": args.height, + "width": args.width, + "start_schedule": args.start_schedule, + "end_schedule": args.end_schedule, + "cfg_scale": args.cfg_scale, + "seed": args.seed, + "samples": args.num_samples, + "init_image": args.init_image, + "mask_image": args.mask_image, + } + + if args.sampler: + request["sampler"] = get_sampler_from_str(args.sampler) + + if args.steps: + request["steps"] = args.steps + + stability_api = StabilityInference( + STABILITY_HOST, STABILITY_KEY, engine=args.engine, verbose=True + ) + answers = stability_api.generate(args.prompt, **request) + artifacts = process_artifacts_from_answers( + args.prefix, args.prompt, answers, write=not args.no_store, verbose=True, + filter_types=args.artifact_types, + ) + if args.show: for artifact in open_images(artifacts, verbose=True): pass diff --git a/tests/specs.md b/tests/specs.md new file mode 100644 index 00000000..3b6761be --- /dev/null +++ b/tests/specs.md @@ -0,0 +1,94 @@ +# stability-sdk + +This document contains usage expectations not contained in the main README.md + +# Generation + +These examples generate and put PNGs in your current directory. + +Command line: + +`python3 -m stability_sdk generate -W 512 -H 512 "A stunning house."` + +SDK Usage: + +See usage demo notebooks in ./nbs + +# Upscale + +## Engine selection +The upscale engine can be optionally chosen when initializing the client: + +``` +stability_api = client.StabilityInference( + key=os.environ['STABILITY_KEY'], # API Key reference. + upscale_engine="upscale_engine_name", # The name of the upscaling model we want to use. 
+) +``` + +Command line example: + +`python3 -m stability_sdk upscale -e "upscale_engine_name" -i "/path/to/img.png"` + +Default upscale_engine_name is "esrgan-v1-x2plus" + +## Inputs +**Required inputs:** + +init_image + +**Optional inputs:** + +height +width + +## Additional requirements: +Max input size = 1048576 pixels (i.e. the total pixels in a 1024 x 1024 image) +Max output size = 4194304 pixels (i.e. the total pixels in a 2048 x 2048 image) + + +The default output size is set by the specific endpoint. +For example, upscale_engine == "esrgan-v1-x2plus" will upscale to 2x the input size + + +If either height or width (but not both) is provided, the original aspect ratio will be maintained. + +Specifying both height and width will throw an error. This is so that the original aspect ratio is maintained. + +For example: +``` +# This is fine +answers = stability_api.upscale( + init_image=img +) # results in a 2x image if using default upscale_engine + +# This is fine +answers = stability_api.upscale( + width=1000, + init_image=img +) + +# !! This will throw an error !! 
+answers = stability_api.upscale( + width=1000, + height=1000, + init_image=img +) +``` + +## Example calls + +Command line: + +`python3 -m stability_sdk upscale -i "/path/to/image.png"` + +`python3 -m stability_sdk upscale --engine "esrgan-v1-x2plus" -i "/path/to/image.png"` + +`python3 -m stability_sdk upscale -H 1200 -i "/path/to/image.png"` + +`python3 -m stability_sdk upscale -W 1200 -i "/path/to/image.png"` + +SDK Usage: + +See usage demo notebooks in ./nbs + diff --git a/tests/test_client.py b/tests/test_client.py index 7adedf0a..a17feb72 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -62,3 +62,13 @@ def test_server_mocking(grpc_server, grpc_addr): # might need this link later: # - https://stackoverflow.com/questions/54541338/calling-function-that-yields-from-a-pytest-fixture assert isinstance(response, Generator) + +def test_upscale(grpc_server, grpc_addr): + class_instance = client.StabilityInference(host=grpc_addr[0]) + im = Image.new('RGB',(1,1)) + response = class_instance.upscale(init_image=im) + print(response) + # might need this link later: + # - https://stackoverflow.com/questions/54541338/calling-function-that-yields-from-a-pytest-fixture + assert isinstance(response, Generator) +