Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.tree.map->jax.tree_map #1535

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions t5x/adafactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,14 @@ def _parse_rule(
factor_dims = None

if fallback_to_heuristics and param_ndim <= 2 and not batched_dims:
logging.warning(
'Since rank of parameter %s %d is less than or equal to 2, the '
'factorization method falls back to heuristics and the provided '
'factor rule %s is ignored.',
path,
param_ndim,
rule,
)
# logging.warning(
# 'Since rank of parameter %s %d is less than or equal to 2, the '
# 'factorization method falls back to heuristics and the provided '
# 'factor rule %s is ignored.',
# path,
# param_ndim,
# rule,
# )
return tuple(np.arange(param_ndim)), HEURISTIC_RULE

return averaging_dims, factor_dims
Expand Down
2 changes: 2 additions & 0 deletions t5x/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
model: models.BaseModel,
partitioner: partitioning.BasePartitioner,
log_dir: Optional[str] = None,
num_examples: Optional[str] = None,
verify_matching_vocabs_fn: Optional[
Callable[[utils.DatasetConfig, models.BaseTransformerModel], None]
] = utils.verify_matching_vocabs,
Expand Down Expand Up @@ -112,6 +113,7 @@ def __init__(
seed=infer_eval_dataset_cfg.seed,
sequence_length=infer_eval_dataset_cfg.task_feature_lengths,
use_memory_cache=infer_eval_dataset_cfg.use_memory_cache,
num_examples=num_examples,
**kwargs,
)
# Lazily initialized upon the first `evaluate` call.
Expand Down
136 changes: 68 additions & 68 deletions t5x/examples/t5/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,74 +487,74 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
return output


class Embed(nn.Module):
"""A parameterized function from integers [0, n) to d-dimensional vectors.

Attributes:
num_embeddings: number of embeddings.
features: number of feature dimensions for each embedding.
dtype: the dtype of the embedding vectors (default: float32).
embedding_init: embedding initializer.
one_hot: performs the gather with a one-hot contraction rather than a true
gather. This is currently needed for SPMD partitioning.
"""

num_embeddings: int
features: int
cast_input_dtype: Optional[DType] = None
dtype: DType = jnp.float32
attend_dtype: Optional[DType] = None
embedding_init: Initializer = default_embed_init
one_hot: bool = False
embedding: Array = dataclasses.field(init=False)

def setup(self):
self.embedding = param_with_axes(
'embedding',
self.embedding_init,
(self.num_embeddings, self.features),
jnp.float32,
axes=('vocab', 'embed'),
)

def __call__(self, inputs: Array) -> Array:
"""Embeds the inputs along the last dimension.

Args:
inputs: input data, all dimensions are considered batch dimensions.

Returns:
Output which is embedded input data. The output shape follows the input,
with an additional `features` dimension appended.
"""
if self.cast_input_dtype:
inputs = inputs.astype(self.cast_input_dtype)
if not jnp.issubdtype(inputs.dtype, jnp.integer):
raise ValueError('Input type must be an integer or unsigned integer.')
if self.one_hot:
iota = lax.iota(jnp.int32, self.num_embeddings)
one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
else:
output = jnp.asarray(self.embedding, self.dtype)[inputs]
output = with_sharding_constraint(output, ('batch', 'length', 'embed'))
return output

def attend(self, query: Array) -> Array:
"""Attend over the embedding using a query array.

Args:
query: array with last dimension equal the feature depth `features` of the
embedding.

Returns:
An array with final dim `num_embeddings` corresponding to the batched
inner-product of the array of query vectors against each embedding.
Commonly used for weight-sharing between embeddings and logit transform
in NLP models.
"""
dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)
# class Embed(nn.Module):
# """A parameterized function from integers [0, n) to d-dimensional vectors.

# Attributes:
# num_embeddings: number of embeddings.
# features: number of feature dimensions for each embedding.
# dtype: the dtype of the embedding vectors (default: float32).
# embedding_init: embedding initializer.
# one_hot: performs the gather with a one-hot contraction rather than a true
# gather. This is currently needed for SPMD partitioning.
# """

# num_embeddings: int
# features: int
# cast_input_dtype: Optional[DType] = None
# dtype: DType = jnp.float32
# attend_dtype: Optional[DType] = None
# embedding_init: Initializer = default_embed_init
# one_hot: bool = False
# embedding: Array = dataclasses.field(init=False)

# def setup(self):
# self.embedding = param_with_axes(
# 'embedding',
# self.embedding_init,
# (self.num_embeddings, self.features),
# jnp.float32,
# axes=('vocab', 'embed'),
# )

# def __call__(self, inputs: Array) -> Array:
# """Embeds the inputs along the last dimension.

# Args:
# inputs: input data, all dimensions are considered batch dimensions.

# Returns:
# Output which is embedded input data. The output shape follows the input,
# with an additional `features` dimension appended.
# """
# if self.cast_input_dtype:
# inputs = inputs.astype(self.cast_input_dtype)
# if not jnp.issubdtype(inputs.dtype, jnp.integer):
# raise ValueError('Input type must be an integer or unsigned integer.')
# if self.one_hot:
# iota = lax.iota(jnp.int32, self.num_embeddings)
# one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
# output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
# else:
# output = jnp.asarray(self.embedding, self.dtype)[inputs]
# output = with_sharding_constraint(output, ('batch', 'length', 'embed'))
# return output

# def attend(self, query: Array) -> Array:
# """Attend over the embedding using a query array.

# Args:
# query: array with last dimension equal the feature depth `features` of the
# embedding.

# Returns:
# An array with final dim `num_embeddings` corresponding to the batched
# inner-product of the array of query vectors against each embedding.
# Commonly used for weight-sharing between embeddings and logit transform
# in NLP models.
# """
# dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
# return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)


class RelativePositionBiases(nn.Module):
Expand Down
36 changes: 18 additions & 18 deletions t5x/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,14 +673,14 @@ def _put_to_devices(x, global_shape):

return device_buffers

device_buffers = jax.tree.map(_put_to_devices, host_arrays, global_shapes)
device_buffers = jax.tree_map(_put_to_devices, host_arrays, global_shapes)

def _jax_array(dbs, global_shape):
return jax.make_array_from_single_device_arrays(
global_shape, jax.sharding.NamedSharding(global_mesh, axes), dbs
)

return jax.tree.map(
return jax.tree_map(
_jax_array,
device_buffers,
global_shapes,
Expand Down Expand Up @@ -748,7 +748,7 @@ def prepare_train_iter(
)


input_shapes = jax.tree.map(
input_shapes = jax.tree_map(
lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec
)
train_iter = ShardedDatasetIterator(train_iter, partitioner, input_shapes)
Expand Down Expand Up @@ -1394,18 +1394,18 @@ def log_model_info(

state_dict = full_train_state.state_dict()
total_num_params = jax.tree_util.tree_reduce(
np.add, jax.tree.map(np.size, state_dict['target'])
np.add, jax.tree_map(np.size, state_dict['target'])
)

logical_axes = partitioner.get_logical_axes(full_train_state).state_dict()

mesh_axes = jax.tree.map(
mesh_axes = jax.tree_map(
lambda x: tuple(x) if x is not None else None,
partitioner.get_mesh_axes(full_train_state).state_dict(),
)

def _log_info_and_write_to_file(writer, format_str, *args):
logging.info(format_str, *args)
# logging.info(format_str, *args)
if writer is not None:
writer.write(format_str % args + '\n')

Expand Down Expand Up @@ -1447,7 +1447,7 @@ def _log_variable(
mesh_axes,
)

jax.tree.map(
jax.tree_map(
_log_variable,
state_utils.get_name_tree(state_dict['target'], keep_empty_nodes=True),
state_dict['target'],
Expand All @@ -1462,7 +1462,7 @@ def _log_variable(
# Add a blank line between params and states.
_log_info_and_write_to_file(writer, '')

jax.tree.map(
jax.tree_map(
_log_variable,
state_utils.get_name_tree(state_dict['state'], keep_empty_nodes=True),
state_dict['state'],
Expand Down Expand Up @@ -1534,7 +1534,7 @@ def _remove_padding(all_inferences, all_indices):
"""
non_pad_idxs = np.where(all_indices >= 0)
all_indices = all_indices[non_pad_idxs]
all_inferences = jax.tree.map(lambda x: x[non_pad_idxs], all_inferences)
all_inferences = jax.tree_map(lambda x: x[non_pad_idxs], all_inferences)
return all_inferences, all_indices


Expand Down Expand Up @@ -1623,7 +1623,7 @@ def infer_fn(
train_state: train_state_lib.TrainState,
rng: Optional[jnp.ndarray] = None,
):
ds_shapes = jax.tree.map(lambda x: jnp.array(x.shape), ds.element_spec)
ds_shapes = jax.tree_map(lambda x: jnp.array(x.shape), ds.element_spec)
multihost_assert_equal(
ds_shapes,
(
Expand Down Expand Up @@ -1724,7 +1724,7 @@ def infer_fn(
index,
train_state.flax_mutables,
)
logging.info('Inference of batch %s done.', index)
# logging.info('Inference of batch %s done.', index)


def _copy_to_host_async(x):
Expand All @@ -1737,8 +1737,8 @@ def _copy_to_host_async(x):
return x

try:
batch_result = jax.tree.map(_copy_to_host_async, batch_result)
batch_indices = jax.tree.map(_copy_to_host_async, batch_indices)
batch_result = jax.tree_map(_copy_to_host_async, batch_result)
batch_indices = jax.tree_map(_copy_to_host_async, batch_indices)
except AttributeError:
# Similar to jax.device_get, we skip transfers for non DeviceArrays.
pass
Expand All @@ -1750,7 +1750,7 @@ def _copy_to_host_async(x):
all_inferences = batched_results

# List[B * shard_count, ...] -> [B * shard_count * batch_count, ...]
all_inferences = jax.tree.map(
all_inferences = jax.tree_map(
lambda *args: np.concatenate(args), *all_inferences
)
all_indices = np.concatenate(all_indices)
Expand All @@ -1761,7 +1761,7 @@ def _copy_to_host_async(x):
# Note: remove padding first, as -1 indices would mess up this operation.
# Note: all_inferences may be a PyTree, not just an array, e.g. if
# `infer_step` is `model.predict_batch_with_aux`.
all_inferences = jax.tree.map(lambda x: x[all_indices], all_inferences)
all_inferences = jax.tree_map(lambda x: x[all_indices], all_inferences)
all_indices = all_indices[all_indices]

# aux_values is supposed to be a dictionary that maps strings to a set of
Expand Down Expand Up @@ -1793,7 +1793,7 @@ def _copy_to_host_async(x):
zip(*all_inferences),
)
indices_and_outputs = list(zip(all_indices, all_inferences))
indices_and_outputs = jax.tree.map(
indices_and_outputs = jax.tree_map(
lambda x: np.array(x).tolist(), indices_and_outputs
)
if len(indices_and_outputs) != original_ds_length:
Expand All @@ -1807,9 +1807,9 @@ def _copy_to_host_async(x):
return indices_and_outputs
else:
if keep_aux_as_numpy:
aux_values = jax.tree.map(lambda x: list(np.array(x)), aux_values)
aux_values = jax.tree_map(lambda x: list(np.array(x)), aux_values)
else:
aux_values = jax.tree.map(lambda x: np.array(x).tolist(), aux_values)
aux_values = jax.tree_map(lambda x: np.array(x).tolist(), aux_values)
return indices_and_outputs, aux_values

return infer_fn
Expand Down