diff --git a/Makefile b/Makefile index 5fc797f..6aaf992 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ format: black manifest/ tests/ web_app/ check: - isort -c -v manifest/ tests/ web_app/ + isort -c manifest/ tests/ web_app/ black manifest/ tests/ web_app/ --check flake8 manifest/ tests/ web_app/ mypy manifest/ tests/ web_app/ diff --git a/manifest/api/models/diffuser.py b/manifest/api/models/diffuser.py index b42ed3a..e04db4f 100644 --- a/manifest/api/models/diffuser.py +++ b/manifest/api/models/diffuser.py @@ -75,7 +75,7 @@ def get_init_params(self) -> Dict: @torch.no_grad() def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[int], List[float]]]: + ) -> List[Tuple[Any, float, List[str], List[float]]]: """ Generate the prompt from model. diff --git a/manifest/api/models/huggingface.py b/manifest/api/models/huggingface.py index 89c4c2b..912832b 100644 --- a/manifest/api/models/huggingface.py +++ b/manifest/api/models/huggingface.py @@ -132,7 +132,7 @@ def __init__( def __call__( self, text: Union[str, List[str]], **kwargs: Any - ) -> List[Dict[str, Union[str, List[float]]]]: + ) -> List[Dict[str, Union[str, List[float], List[str]]]]: """Generate from text. Args: @@ -162,6 +162,7 @@ def __call__( top_p=kwargs.get("top_p"), repetition_penalty=kwargs.get("repetition_penalty"), num_return_sequences=kwargs.get("num_return_sequences"), + do_sample=kwargs.get("do_sample"), ) kwargs_to_pass = {k: v for k, v in kwargs_to_pass.items() if v is not None} output_dict = self.model.generate( # type: ignore @@ -587,7 +588,7 @@ def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray: @torch.no_grad() def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[int], List[float]]]: + ) -> List[Tuple[Any, float, List[str], List[float]]]: """ Generate the prompt from model. @@ -616,7 +617,7 @@ def generate( ( cast(str, r["generated_text"]), sum(cast(List[float], r["logprobs"])), - cast(List[int], r["tokens"]), + cast(List[str], r["tokens"]), cast(List[float], r["logprobs"]), ) for r in result diff --git a/manifest/api/models/model.py b/manifest/api/models/model.py index 3317211..dcb04b9 100644 --- a/manifest/api/models/model.py +++ b/manifest/api/models/model.py @@ -45,7 +45,7 @@ def get_init_params(self) -> Dict: def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[int], List[float]]]: + ) -> List[Tuple[Any, float, List[str], List[float]]]: """ Generate the prompt from model. diff --git a/manifest/api/models/sentence_transformer.py b/manifest/api/models/sentence_transformer.py index bd3f5fa..5f6c2fb 100644 --- a/manifest/api/models/sentence_transformer.py +++ b/manifest/api/models/sentence_transformer.py @@ -66,7 +66,7 @@ def get_init_params(self) -> Dict: @torch.no_grad() def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[int], List[float]]]: + ) -> List[Tuple[Any, float, List[str], List[float]]]: """ Generate the prompt from model.