diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
index ca4294e4..6a87ab45 100644
--- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py
+++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -202,6 +202,10 @@ def run(
         stop_sequences: Optional[List[str]] = None,
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
+        input_tokens: bool = False,
+        generated_tokens: bool = True,
+        token_logprobs: bool = True,
+        token_ranks: bool = True,
     ) -> GeneratedTextResult:
         f"""Run inference against the model running in TGIS.
@@ -221,6 +225,10 @@ def run(
         return self.tgis_generation_client.unary_generate(
             text=verbalized_text,
             preserve_input_text=preserve_input_text,
+            input_tokens=input_tokens,
+            generated_tokens=generated_tokens,
+            token_logprobs=token_logprobs,
+            token_ranks=token_ranks,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
@@ -256,6 +264,10 @@ def run_stream_out(
         stop_sequences: Optional[List[str]] = None,
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
+        input_tokens: bool = False,
+        generated_tokens: bool = True,
+        token_logprobs: bool = True,
+        token_ranks: bool = True,
     ) -> Iterable[GeneratedTextStreamResult]:
         f"""Run output stream inferencing against the model running in TGIS
@@ -275,6 +287,10 @@ def run_stream_out(
         return self.tgis_generation_client.stream_generate(
             text=verbalized_text,
             preserve_input_text=preserve_input_text,
+            input_tokens=input_tokens,
+            generated_tokens=generated_tokens,
+            token_logprobs=token_logprobs,
+            token_ranks=token_ranks,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py
index 598cdd59..7c55cc2c 100644
--- a/caikit_nlp/modules/text_generation/text_generation_tgis.py
+++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -227,6 +227,10 @@ def run(
         stop_sequences: Optional[List[str]] = None,
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
+        input_tokens: bool = False,
+        generated_tokens: bool = True,
+        token_logprobs: bool = True,
+        token_ranks: bool = True,
     ) -> GeneratedTextResult:
         f"""Run inference against the model running in TGIS.
@@ -236,11 +240,14 @@ def run(
             GeneratedTextResult
                 Generated text result produced by TGIS.
         """
-
         if self._model_loaded:
             return self.tgis_generation_client.unary_generate(
                 text=text,
                 preserve_input_text=preserve_input_text,
+                input_tokens=input_tokens,
+                generated_tokens=generated_tokens,
+                token_logprobs=token_logprobs,
+                token_ranks=token_ranks,
                 max_new_tokens=max_new_tokens,
                 min_new_tokens=min_new_tokens,
                 truncate_input_tokens=truncate_input_tokens,
@@ -276,6 +283,10 @@ def run_stream_out(
         stop_sequences: Optional[List[str]] = None,
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
+        input_tokens: bool = False,
+        generated_tokens: bool = True,
+        token_logprobs: bool = True,
+        token_ranks: bool = True,
     ) -> Iterable[GeneratedTextStreamResult]:
         f"""Run output stream inferencing for text generation module.
@@ -289,6 +300,10 @@ def run_stream_out(
             return self.tgis_generation_client.stream_generate(
                 text=text,
                 preserve_input_text=preserve_input_text,
+                input_tokens=input_tokens,
+                generated_tokens=generated_tokens,
+                token_logprobs=token_logprobs,
+                token_ranks=token_ranks,
                 max_new_tokens=max_new_tokens,
                 min_new_tokens=min_new_tokens,
                 truncate_input_tokens=truncate_input_tokens,
diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py
index 876e3a8d..8b50dd15 100644
--- a/caikit_nlp/toolkit/text_generation/tgis_utils.py
+++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py
@@ -37,9 +37,19 @@
 GENERATE_FUNCTION_TGIS_ARGS = """
     {}
-    preserve_input_text: str
+    preserve_input_text: bool
         Whether or not the source string should be contained in the generated output,
         e.g., as a prefix.
+    input_tokens: bool
+        Whether or not to include list of input tokens.
+    generated_tokens: bool
+        Whether or not to include list of individual generated tokens.
+    token_logprobs: bool
+        Whether or not to include logprob for each returned token.
+        Applicable only if generated_tokens == true and/or input_tokens == true
+    token_ranks: bool
+        Whether or not to include rank of each returned token.
+        Applicable only if generated_tokens == true and/or input_tokens == true
 """.format(
     GENERATE_FUNCTION_ARGS
 )
@@ -48,6 +58,10 @@
 def validate_inf_params(
     text,
     preserve_input_text,
+    input_tokens,
+    generated_tokens,
+    token_logprobs,
+    token_ranks,
     eos_token,
     max_new_tokens,
     min_new_tokens,
@@ -74,6 +88,10 @@ def validate_inf_params(
     )
     error.type_check("", str, text=text)
     error.type_check("", bool, preserve_input_text=preserve_input_text)
+    error.type_check("", bool, input_tokens=input_tokens)
+    error.type_check("", bool, generated_tokens=generated_tokens)
+    error.type_check("", bool, token_logprobs=token_logprobs)
+    error.type_check("", bool, token_ranks=token_ranks)
     error.type_check("", str, allow_none=True, eos_token=eos_token)
     error.type_check(
         "",
@@ -174,6 +192,10 @@ def validate_inf_params(
 def get_params(
     preserve_input_text,
+    input_tokens,
+    generated_tokens,
+    token_logprobs,
+    token_ranks,
     max_new_tokens,
     min_new_tokens,
     truncate_input_tokens,
@@ -211,10 +233,10 @@ def get_params(
     res_options = generation_pb2.ResponseOptions(
         input_text=preserve_input_text,
-        generated_tokens=True,
-        input_tokens=False,
-        token_logprobs=True,
-        token_ranks=True,
+        generated_tokens=generated_tokens,
+        input_tokens=input_tokens,
+        token_logprobs=token_logprobs,
+        token_ranks=token_ranks,
     )
     stopping = generation_pb2.StoppingCriteria(
         stop_sequences=stop_sequences,
@@ -268,6 +290,10 @@ def unary_generate(
         self,
         text,
         preserve_input_text,
+        input_tokens,
+        generated_tokens,
+        token_logprobs,
+        token_ranks,
         max_new_tokens,
         min_new_tokens,
         truncate_input_tokens,
@@ -305,6 +331,10 @@ def unary_generate(
         validate_inf_params(
             text=text,
             preserve_input_text=preserve_input_text,
+            input_tokens=input_tokens,
+            generated_tokens=generated_tokens,
+            token_logprobs=token_logprobs,
+            token_ranks=token_ranks,
             eos_token=self.eos_token,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
@@ -325,6 +355,10 @@ def unary_generate(
         params = get_params(
             preserve_input_text=preserve_input_text,
+            input_tokens=input_tokens,
+            generated_tokens=generated_tokens,
+            token_logprobs=token_logprobs,
+            token_ranks=token_ranks,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
@@ -366,6 +400,24 @@ def unary_generate(
         )
         response = batch_response.responses[0]
+        token_list = []
+        if response.tokens is not None:
+            for token in response.tokens:
+                token_list.append(
+                    GeneratedToken(
+                        text=token.text, logprob=token.logprob, rank=token.rank
+                    )
+                )
+
+        input_token_list = []
+        if response.input_tokens is not None:
+            for token in response.input_tokens:
+                input_token_list.append(
+                    GeneratedToken(
+                        text=token.text, logprob=token.logprob, rank=token.rank
+                    )
+                )
+
         return GeneratedTextResult(
             generated_text=response.text,
             generated_tokens=response.generated_token_count,
@@ -373,12 +425,18 @@ def unary_generate(
             producer_id=self.producer_id,
             input_token_count=response.input_token_count,
             seed=seed,
+            tokens=token_list,
+            input_tokens=input_token_list,
         )

     def stream_generate(
         self,
         text,
         preserve_input_text,
+        input_tokens,
+        generated_tokens,
+        token_logprobs,
+        token_ranks,
         max_new_tokens,
         min_new_tokens,
         truncate_input_tokens,
@@ -416,6 +474,10 @@ def stream_generate(
         validate_inf_params(
             text=text,
             preserve_input_text=preserve_input_text,
+            input_tokens=input_tokens,
+            generated_tokens=generated_tokens,
+            token_logprobs=token_logprobs,
+            token_ranks=token_ranks,
             eos_token=self.eos_token,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
@@ -434,6 +496,10 @@ def stream_generate(
         params = get_params(
             preserve_input_text=preserve_input_text,
+            input_tokens=input_tokens,
+            generated_tokens=generated_tokens,
+            token_logprobs=token_logprobs,
+            token_ranks=token_ranks,
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
             truncate_input_tokens=truncate_input_tokens,
@@ -476,13 +542,25 @@ def stream_generate(
                 input_token_count=stream_part.input_token_count,
             )
             token_list = []
-            for token in stream_part.tokens:
-                token_list.append(
-                    GeneratedToken(text=token.text, logprob=token.logprob)
-                )
+            if stream_part.tokens is not None:
+                for token in stream_part.tokens:
+                    token_list.append(
+                        GeneratedToken(
+                            text=token.text, logprob=token.logprob, rank=token.rank
+                        )
+                    )
+            input_token_list = []
+            if stream_part.input_tokens is not None:
+                for token in stream_part.input_tokens:
+                    input_token_list.append(
+                        GeneratedToken(
+                            text=token.text, logprob=token.logprob, rank=token.rank
+                        )
+                    )
             yield GeneratedTextStreamResult(
                 generated_text=stream_part.text,
                 tokens=token_list,
+                input_tokens=input_token_list,
                 details=details,
             )
diff --git a/pyproject.toml b/pyproject.toml
index 37ccc71d..e1ce63ec 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ classifiers=[
     "License :: OSI Approved :: Apache Software License"
 ]
 dependencies = [
-    "caikit[runtime-grpc,runtime-http]>=0.26.14,<0.27.0",
+    "caikit[runtime-grpc,runtime-http]>=0.26.17,<0.27.0",
     "caikit-tgis-backend>=0.1.27,<0.2.0",
     # TODO: loosen dependencies
     "accelerate>=0.22.0",
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index bccee629..e7440e83 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -215,6 +215,16 @@ def unary_generate(request):
         fake_result.generated_token_count = 1
         fake_result.text = "moose"
         fake_result.input_token_count = 1
+        token = mock.Mock()
+        token.text = "moose"
+        token.logprob = 0.2
+        token.rank = 1
+        fake_result.tokens = [token]
+        input_tokens = mock.Mock()
+        input_tokens.text = "moose"
+        input_tokens.logprob = 0.2
+        input_tokens.rank = 1
+        fake_result.input_tokens = [input_tokens]
         fake_response.responses = [fake_result]
         return fake_response
@@ -228,7 +238,13 @@ def stream_generate(request):
         token = mock.Mock()
         token.text = "moose"
         token.logprob = 0.2
+        token.rank = 1
         fake_stream.tokens = [token]
+        input_tokens = mock.Mock()
+        input_tokens.text = "moose"
+        input_tokens.logprob = 0.2
+        input_tokens.rank = 1
+        fake_stream.input_tokens = [input_tokens]
         fake_stream.text = "moose"
         for _ in range(3):
             yield fake_stream
@@ -248,6 +264,12 @@ def validate_unary_generate_response(result):
         assert result.generated_tokens == 1
         assert result.finish_reason == 5
         assert result.input_token_count == 1
+        assert result.tokens[0].text == "moose"
+        assert result.tokens[0].logprob == 0.2
+        assert result.tokens[0].rank == 1
+        assert result.input_tokens[0].text == "moose"
+        assert result.input_tokens[0].logprob == 0.2
+        assert result.input_tokens[0].rank == 1

     @staticmethod
     def validate_stream_generate_response(stream_result):
@@ -259,6 +281,10 @@ def validate_stream_generate_response(stream_result):
         assert first_result.generated_text == "moose"
         assert first_result.tokens[0].text == "moose"
         assert first_result.tokens[0].logprob == 0.2
+        assert first_result.tokens[0].rank == 1
+        assert first_result.input_tokens[0].text == "moose"
+        assert first_result.input_tokens[0].logprob == 0.2
+        assert first_result.input_tokens[0].rank == 1
         assert first_result.details.finish_reason == 5
         assert first_result.details.generated_tokens == 1
         assert first_result.details.seed == 10
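
For context, a minimal sketch of how a caller could exercise the new response options once the change above is applied. The module class name, the `load("sample_text_generation")` bootstrap, and the configured TGIS backend are assumptions for illustration; the four flags and the `tokens` / `input_tokens` fields on the result come directly from the diff.

# Illustrative sketch only: assumes a caikit TGIS backend is configured and
# that "sample_text_generation" is a saved caikit-nlp text generation module.
from caikit_nlp.modules.text_generation import TextGenerationTGIS

model = TextGenerationTGIS.load("sample_text_generation")  # hypothetical path

# input_tokens defaults to False; generated_tokens, token_logprobs, and
# token_ranks default to True, matching the previously hard-coded
# ResponseOptions values in get_params.
result = model.run(
    "What is the boiling point of water?",
    max_new_tokens=20,
    input_tokens=True,      # also return the tokenized prompt
    generated_tokens=True,  # return each generated token
    token_logprobs=True,    # include a logprob per returned token
    token_ranks=True,       # include a rank per returned token
)

# With this change, the unary path populates both token lists on the result.
for tok in result.tokens:        # generated tokens
    print(tok.text, tok.logprob, tok.rank)
for tok in result.input_tokens:  # prompt tokens
    print(tok.text, tok.logprob, tok.rank)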