Add a prefill option for generate()
mattdangerw committed Jan 10, 2024
1 parent fde5879 commit 6f76de8
Showing 3 changed files with 57 additions and 20 deletions.
31 changes: 25 additions & 6 deletions keras_nlp/models/generative_task.py
@@ -13,6 +13,7 @@
# limitations under the License.

import itertools
from functools import partial

import tensorflow as tf
import tree
@@ -49,7 +50,7 @@ def compile(
# Clear the compiled generate function.
self.generate_function = None

def generate_step(self):
def generate_step(self, end_token_id, prefill):
"""Run generation on a single batch of input."""
raise NotImplementedError

@@ -64,10 +65,11 @@ def make_generate_function(self):

def wrapped_generate_function(
inputs,
prefill=False,
end_token_id=None,
):
with torch.no_grad():
return self.generate_step(inputs, end_token_id)
return self.generate_step(inputs, prefill, end_token_id)

self.generate_function = wrapped_generate_function
elif config.backend() == "tensorflow" and not self.run_eagerly:
@@ -80,8 +82,13 @@ def wrapped_generate_function(
elif config.backend() == "jax" and not self.run_eagerly:
import jax

@jax.jit
def compiled_generate_function(inputs, end_token_id, state):
@partial(jax.jit, static_argnames=["prefill", "end_token_id"])
def compiled_generate_function(
inputs,
prefill,
end_token_id,
state,
):
(
sampler_variables,
trainable_variables,
@@ -94,7 +101,7 @@ def compiled_generate_function(inputs, end_token_id, state):
)

with keras.StatelessScope(state_mapping=mapping) as scope:
outputs = self.generate_step(inputs, end_token_id)
outputs = self.generate_step(inputs, prefill, end_token_id)

# Get updated sampler variables from the stateless scope.
sampler_variables = []
@@ -110,6 +117,7 @@

def wrapped_generate_function(
inputs,
prefill=False,
end_token_id=None,
):
# Create an explicit tuple of all variable state.
@@ -121,6 +129,7 @@ def wrapped_generate_function(
inputs = tree.map_structure(ops.convert_to_tensor, inputs)
outputs, state = compiled_generate_function(
inputs,
prefill,
end_token_id,
state,
)
@@ -209,6 +218,7 @@ def generate(
self,
inputs,
max_length=None,
prefill=False,
):
"""Generate text given prompt `inputs`.
@@ -237,6 +247,11 @@
`preprocessor`. If `preprocessor` is `None`, `inputs` should be
padded to the desired maximum length and this argument
will be ignored.
prefill: Optional. bool. If `True`, the output state of the fixed
prompt in `inputs` is computed in a single forward pass, consuming
more memory but speeding up generation. If `False`, the output
state of the prompt will be computed token by token, slowing
generation but saving memory. Defaults to `False`.
"""
# Setup our three main passes.
# 1. Optionally preprocessing strings to dense integer tensors.
@@ -253,7 +268,11 @@ def preprocess(x):
)

def generate(x):
return generate_function(x, end_token_id=end_token_id)
return generate_function(
x,
prefill=prefill,
end_token_id=end_token_id,
)

def postprocess(x):
return self.preprocessor.generate_postprocess(x)
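For orientation, here is a rough sketch of how the new `prefill` flag would be used through the public `generate()` API documented above. The GPT-2 preset and prompt are only illustrative; they are not part of this commit.

import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Default: the prompt's hidden states are recomputed token by token
# inside the decoding loop (slower, but lower peak memory).
causal_lm.generate("The quick brown fox", max_length=64)

# Prefill: one forward pass over the whole prompt seeds the cache up
# front (more memory, faster generation for long prompts).
causal_lm.generate("The quick brown fox", max_length=64, prefill=True)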
22 changes: 15 additions & 7 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -245,16 +245,18 @@ def _build_cache(self, token_ids):
max_length = ops.shape(token_ids)[1]
num_layers = self.backbone.num_layers
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads
model_dim = self.backbone.hidden_dim
head_dim = model_dim // self.backbone.num_heads
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
cache = ops.zeros(shape, dtype=self.compute_dtype)
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
shape = [batch_size, max_length, self.backbone.hidden_dim]
hidden_states = ops.zeros(shape, dtype=self.compute_dtype)
return hidden_states, cache

def generate_step(
self,
inputs,
prefill=False,
end_token_id=None,
):
"""A compilable generation function for a single batch of inputs.
@@ -273,10 +275,16 @@ def generate_step(
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user-inputted token ids.
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
# Start at the first index that has no user-inputted id.
index = ops.min(row_lengths)

if prefill:
# Compute the lengths of all user-inputted token ids.
# Start at the first index that has no user-inputted id.
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
index = ops.min(row_lengths)
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
else:
index = 1

def next(prompt, cache, index):
# The cache index is the index of our previous token.
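The `if prefill:` branch above is the core of the change: with `prefill=True` the whole prompt is pushed through `call_with_cache` once to seed the key/value cache, and sampling starts at the end of the shortest prompt; with `prefill=False` the cache starts empty and the loop begins at index 1, rebuilding the prompt's cache one token at a time. A framework-free toy sketch of that control flow (the cache and "model" here are stand-ins, not the KerasNLP implementation):

def toy_forward(tokens):
    """Toy per-token hidden state: just echo the token id."""
    return list(tokens)

def generate_step_sketch(prompt, max_length, prefill=False):
    cache = [None] * max_length
    if prefill:
        # Single pass over the prompt fills every prompt slot at once.
        cache[: len(prompt)] = toy_forward(prompt)
        index = len(prompt)  # start decoding right after the prompt
    else:
        # No prefill: only the first token is processed up front; the
        # decoding loop below replays the rest of the prompt.
        cache[0] = toy_forward(prompt[:1])[0]
        index = 1
    output = list(prompt)
    while index < max_length:
        if index < len(prompt):
            token = prompt[index]  # still replaying the prompt
        else:
            token = cache[index - 1] + 1  # dummy "sampled" next token
        cache[index] = toy_forward([token])[0]
        if index >= len(prompt):
            output.append(token)
        index += 1
    return output

# Both paths produce the same tokens; prefill just does less loop work.
print(generate_step_sketch([5, 6, 7], max_length=6, prefill=True))
print(generate_step_sketch([5, 6, 7], max_length=6))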
24 changes: 17 additions & 7 deletions keras_nlp/models/opt/opt_causal_lm.py
@@ -241,16 +241,18 @@ def _build_cache(self, token_ids):
max_length = ops.shape(token_ids)[1]
num_layers = self.backbone.num_layers
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads
model_dim = self.backbone.hidden_dim
head_dim = model_dim // self.backbone.num_heads
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
cache = ops.zeros(shape, dtype=self.compute_dtype)
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
shape = [batch_size, max_length, self.backbone.hidden_dim]
hidden_states = ops.zeros(shape, dtype=self.compute_dtype)
return hidden_states, cache

def generate_step(
self,
inputs,
prefill=False,
end_token_id=None,
):
"""A compilable generation function for a single batch of inputs.
@@ -269,10 +271,18 @@ def generate_step(
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user-inputted token ids.
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
# Start at the first index that has no user-inputted id.
index = ops.min(row_lengths)

if prefill:
# Compute the lengths of all user-inputted token ids.
# Start at the first index that has no user-inputted id.
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
index = ops.min(row_lengths)
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
else:
# If we are not prefilling with the prompt, start the loop at the
# first predicted token.
index = 1

def next(prompt, cache, index):
# The cache index is the index of our previous token.
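One detail worth calling out from the generative_task.py hunk: on the JAX backend the compiled function is now wrapped with `partial(jax.jit, static_argnames=["prefill", "end_token_id"])`. Because `generate_step` branches in Python on `prefill`, that value has to be known at trace time, and each distinct value of a static argument gets its own compiled version. A small standalone illustration of the pattern (a toy function, not the KerasNLP code):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["prefill"])
def step(x, prefill):
    # `prefill` is static, so this Python-level `if` is resolved at
    # trace time; True and False each trigger their own compilation.
    if prefill:
        return x * 2
    return x + 1

print(step(jnp.arange(3), prefill=True))   # [0 2 4]
print(step(jnp.arange(3), prefill=False))  # [1 2 3]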
