Skip to content

Commit

Permalink
fix masking output
Browse files Browse the repository at this point in the history
  • Loading branch information
palp committed Jul 22, 2023
1 parent f6432aa commit 6b9337e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/stability_sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ def api_request_to_proto(req: GenerationRequest) -> generation.Request:
if init_image.mode != "RGBA":
init_image = init_image.convert("RGBA")
mask_image = init_image.split()[-1] # Extract alpha channel
mask_binary = mask_image.tobytes()
mask_bytes = BytesIO()
mask_image.save(mask_bytes, format="PNG")
mask_binary = mask_bytes.getvalue()
elif mask_source == MaskSource.MASK_IMAGE_WHITE:
# Inverts the provided mask image, having the effect of masking out white pixels.
if req.mask_image is None:
Expand All @@ -201,7 +203,9 @@ def api_request_to_proto(req: GenerationRequest) -> generation.Request:
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_image = ImageOps.invert(mask_image)
mask_binary = mask_image.tobytes()
mask_bytes = BytesIO()
mask_image.save(mask_bytes, format="PNG")
mask_binary = mask_bytes.getvalue()
elif mask_source == MaskSource.MASK_IMAGE_BLACK:
# Uses the given mask image as-is, so that black pixels are masked out.
if req.mask_image is None:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import time
import base64
from PIL import Image
from io import BytesIO

from stability_sdk.api import CreateRequest, CreateResponse, GenerationResponse
from stability_sdk.interfaces.gooseai.generation.generation_pb2 import Answer, Artifact
Expand Down Expand Up @@ -116,6 +118,8 @@ def test_image_to_image_with_init_image_alpha():
assert prompts[1].artifact.binary is not None
assert prompts[2].artifact is not None
assert prompts[2].artifact.binary is not None
image = Image.open(BytesIO(prompts[2].artifact.binary))
assert image is not None


def test_image_to_image_with_mask_image_white():
Expand Down Expand Up @@ -146,6 +150,8 @@ def test_image_to_image_with_mask_image_white():
assert prompts[1].artifact.binary is not None
assert prompts[2].artifact is not None
assert prompts[2].artifact.binary is not None
image = Image.open(BytesIO(prompts[2].artifact.binary))
assert image is not None


def test_image_to_image_with_mask_image_black():
Expand Down Expand Up @@ -176,6 +182,8 @@ def test_image_to_image_with_mask_image_black():
assert prompts[1].artifact.binary is not None
assert prompts[2].artifact is not None
assert prompts[2].artifact.binary is not None
image = Image.open(BytesIO(prompts[2].artifact.binary))
assert image is not None

def test_generation_response_success():
test_result = {'result': 'success', 'artifacts': [{'base64': 'blahblah', 'finishReason': 'SUCCESS', 'seed': 1}]}
Expand Down

0 comments on commit 6b9337e

Please sign in to comment.