Skip to content

Commit

Permalink
Flake8 Format fixes, Comments update
Browse files Browse the repository at this point in the history
  • Loading branch information
OmarElChamaa committed Feb 20, 2024
1 parent 1f4389a commit f5f8675
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 50 deletions.
6 changes: 2 additions & 4 deletions demo/demo_text_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def demo(self, model: ModelsTextConversation):
dim=-1) if model.conversation_step > 0 else (
new_user_input_ids)

# generated a response while limiting the total chat history to 1000 tokens,
# generated a response while limiting the total
# chat history to 1000 tokens,
chat_history_ids = model.pipeline.generate(
bot_input_ids, max_length=1000,
pad_token_id=model.tokenizer.eos_token_id)
Expand All @@ -55,12 +56,9 @@ def demo(self, model: ModelsTextConversation):
new_user_input_ids = model.tokenizer.encode(
"How are you ? " + model.tokenizer.eos_token,
return_tensors='pt')

# append the new user input tokens to the chat history
bot_input_ids = torch.cat(
[model.chat_history_ids, new_user_input_ids],
dim=-1) if model.conversation_step > 0 else new_user_input_ids
# generated a response while limiting the total chat history to 1000 tokens,
chat_history_ids = model.pipeline.generate(
bot_input_ids,
max_length=1000,
Expand Down
6 changes: 4 additions & 2 deletions demo/demo_txt_to_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ class DemoMainTxtToImg():

def __init__(self):
options = OptionsTextToImage(
prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
prompt="Astronaut in a jungle, cold color palette, "
"muted colors, detailed, 8k",
device=Devices.GPU,
image_width=512,
image_height=512
Expand All @@ -19,7 +20,8 @@ def __init__(self):
model_management = ModelsManagement()
model_stabilityai = ModelTextToImage(model_stabilityai_name)

model_management.add_model(new_model=model_stabilityai, model_options=options)
model_management.add_model(new_model=model_stabilityai,
model_options=options)
model_management.load_model(model_stabilityai_name)

image = model_management.generate_prompt()
Expand Down
3 changes: 2 additions & 1 deletion models/model_text_conversation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, ConversationalPipeline, Conversation
from transformers import (AutoModelForCausalLM, AutoTokenizer,
ConversationalPipeline, Conversation)
from models.model import Model
from options.options import Devices
from options.options_text_conversation import OptionsTextConversation
Expand Down
76 changes: 33 additions & 43 deletions options/options_text_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,73 +15,59 @@ class OptionsTextConversation(Options):
device Union[str, Devices]:
The device to use for inference,
specified from the `Devices` enum.
prompt (str, optional):
The prompt or starting text for text generation.
prompt (str):
The prompt or starting text for text generation.
Default is an empty string.
model (Union[str, PreTrainedModel, "TFPreTrainedModel"], optional):
The model to use for text generation.
model (Union[str, PreTrainedModel, "TFPreTrainedModel"]):
The model to use for text generation.
This can be a model identifier, a pre-trained model instance
inheriting from `PreTrainedModel` for PyTorch,
inheriting from `PreTrainedModel` for PyTorch,
or `"TFPreTrainedModel"` for TensorFlow.
tokenizer (Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"],
optional):
tokenizer (Union[str, PreTrainedTokenizer,
"PreTrainedTokenizerFast"]):
The tokenizer to use for tokenizing input text.
This can be a model identifier, a pre-trained tokenizer
instance inheriting from `PreTrainedTokenizer`,
instance inheriting from `PreTrainedTokenizer`,
or `"PreTrainedTokenizerFast"`.
model_card (Optional[Union[str, ModelCard]], optional):
The model card providing details about the model,
The model card providing details about the model,
such as usage guidelines and citations.
framework (Optional[str], optional):
The deep learning framework used for the model,
The deep learning framework used for the model,
such as 'torch' or 'tf'.
task (str, optional):
The task for which the model is being used.
task (str):
The task for which the model is being used.
Default is an empty string.
num_workers (Optional[int], optional):
num_workers (int):
The number of worker processes for data loading.
Default is 8.
batch_size (Optional[int], optional):
The batch size for inference.
batch_size (int):
The batch size for inference.
Default is 1.
arg_parser (Optional[Dict[str, Any]], optional):
Optional additional arguments for the model.
torch_dtype (Optional[Union[str, torch.dtype]], optional):
The data type for PyTorch tensors,
such as 'float32' or 'float64'.
binary_output (Optional[bool], optional):
The data type for PyTorch tensors,
such as 'float32' or 'float64.
binary_output (bool):
Whether the output should be binary or text.
Default is False.
min_length_for_response (Optional[int], optional):
min_length_for_response (int):
The minimum length of response generated by the model.
Default is 32.
minimum_tokens (Optional[int], optional):
The minimum number of tokens required for a valid response.
minimum_tokens (int):
The minimum number of tokens required for a valid response.
Default is 10.
max_steps (Optional[int], optional):
The maximum number of steps for generating text.
max_steps (int):
The maximum number of steps for generating text.
Default is 50.
"""

prompt: str
model: Union[str, PreTrainedModel, "TFPreTrainedModel"] = None
tokenizer: Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"] = None

model: Union[str, PreTrainedModel,
"TFPreTrainedModel"] = None
tokenizer: Union[str, PreTrainedTokenizer,
"PreTrainedTokenizerFast"] = None
model_card: Optional[Union[str, ModelCard]] = None
framework: Optional[str] = None
task: str = ""
Expand All @@ -97,8 +83,12 @@ class OptionsTextConversation(Options):

def __init__(self,
prompt: str,
model: Union[str, PreTrainedModel, "TFPreTrainedModel"] = None,
tokenizer: Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"] = None,
model:
Union[str, PreTrainedModel, "TFPreTrainedModel"] = None,
tokenizer:
Union[str, PreTrainedTokenizer,
"PreTrainedTokenizerFast"] = None,

model_card: Optional[Union[str, ModelCard]] = None,
framework: Optional[str] = None,
task: str = "",
Expand Down

0 comments on commit f5f8675

Please sign in to comment.