Skip to content

Commit

Permalink
docs: blackify readme
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jul 7, 2023
1 parent 5514288 commit 3dc2ab3
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ from docarray import DocList, BaseDoc

from transformers import pipeline


class Prompt(BaseDoc):
text: str

Expand All @@ -111,10 +112,11 @@ class Generation(BaseDoc):


class StableLM(Executor):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.generator = pipeline('text-generation', model='stabilityai/stablelm-base-alpha-3b')
self.generator = pipeline(
'text-generation', model='stabilityai/stablelm-base-alpha-3b'
)

@requests
def generate(self, docs: DocList[Prompt], **kwargs) -> DocList[Generation]:
Expand All @@ -124,6 +126,7 @@ class StableLM(Executor):
for prompt, output in zip(prompts, llm_outputs):
generations.append(Generation(prompt=prompt, text=output))
return generations

```

</td>
Expand Down Expand Up @@ -176,6 +179,7 @@ Use [Jina Client](https://docs.jina.ai/concepts/client/) to make requests to the
from jina import Client
from docarray import DocList, BaseDoc
class Prompt(BaseDoc):
text: str
Expand All @@ -186,7 +190,7 @@ class Generation(BaseDoc):
prompt = Prompt(
text = 'suggest an interesting image generation prompt for a mona lisa variant'
text='suggest an interesting image generation prompt for a mona lisa variant'
)
client = Client(port=12345) # use port from output above
Expand Down Expand Up @@ -230,6 +234,7 @@ from jina import Executor, requests
from docarray import BaseDoc, DocList
from docarray.documents import ImageDoc
class Generation(BaseDoc):
prompt: str
text: str
Expand All @@ -240,11 +245,16 @@ class TextToImage(Executor):
super().__init__(**kwargs)
from diffusers import StableDiffusionPipeline
import torch
self.pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
self.pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
@requests
def generate_image(self, docs: DocList[Generation], **kwargs) -> DocList[ImageDoc]:
images = self.pipe(docs.text).images # image here is in [PIL format](https://pillow.readthedocs.io/en/stable/)
images = self.pipe(
docs.text
).images # image here is in [PIL format](https://pillow.readthedocs.io/en/stable/)
for i, doc in enumerate(docs):
doc.tensor = np.array(images[i])
```
Expand Down Expand Up @@ -314,18 +324,20 @@ from jina import Client
from docarray import DocList, BaseDoc
from docarray.documents import ImageDoc
class Prompt(BaseDoc):
text: str
prompt = Prompt(
text = 'suggest an interesting image generation prompt for a mona lisa variant'
text='suggest an interesting image generation prompt for a mona lisa variant'
)
client = Client(port=12345) # use port from output above
response = client.post(on='/', inputs=[prompt], return_type=DocList[ImageDoc])
response[0].display()
```

![](./.github/images/mona-lisa.png)
Expand Down

0 comments on commit 3dc2ab3

Please sign in to comment.