diff --git a/t5x/adafactor.py b/t5x/adafactor.py
index 1920fd5ce..bb7cee196 100644
--- a/t5x/adafactor.py
+++ b/t5x/adafactor.py
@@ -382,14 +382,14 @@ def _parse_rule(
     factor_dims = None

   if fallback_to_heuristics and param_ndim <= 2 and not batched_dims:
-    logging.warning(
-        'Since rank of parameter %s %d is less than or equal to 2, the '
-        'factorization method falls back to heuristics and the provided '
-        'factor rule %s is ignored.',
-        path,
-        param_ndim,
-        rule,
-    )
+    # logging.warning(
+    #     'Since rank of parameter %s %d is less than or equal to 2, the '
+    #     'factorization method falls back to heuristics and the provided '
+    #     'factor rule %s is ignored.',
+    #     path,
+    #     param_ndim,
+    #     rule,
+    # )
     return tuple(np.arange(param_ndim)), HEURISTIC_RULE

   return averaging_dims, factor_dims
diff --git a/t5x/eval.py b/t5x/eval.py
index d8fcc38e4..de9dc28cc 100644
--- a/t5x/eval.py
+++ b/t5x/eval.py
@@ -71,6 +71,7 @@ def __init__(
       model: models.BaseModel,
       partitioner: partitioning.BasePartitioner,
       log_dir: Optional[str] = None,
+      num_examples: Optional[int] = None,
       verify_matching_vocabs_fn: Optional[
           Callable[[utils.DatasetConfig, models.BaseTransformerModel], None]
       ] = utils.verify_matching_vocabs,
@@ -112,6 +113,7 @@ def __init__(
         seed=infer_eval_dataset_cfg.seed,
         sequence_length=infer_eval_dataset_cfg.task_feature_lengths,
         use_memory_cache=infer_eval_dataset_cfg.use_memory_cache,
+        num_examples=num_examples,
         **kwargs,
     )
     # Lazily initialized upon the first `evaluate` call.
diff --git a/t5x/examples/t5/layers.py b/t5x/examples/t5/layers.py
index cc2e6ac8b..af007380b 100644
--- a/t5x/examples/t5/layers.py
+++ b/t5x/examples/t5/layers.py
@@ -487,74 +487,74 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
     return output


-class Embed(nn.Module):
-  """A parameterized function from integers [0, n) to d-dimensional vectors.
-
-  Attributes:
-    num_embeddings: number of embeddings.
-    features: number of feature dimensions for each embedding.
-    dtype: the dtype of the embedding vectors (default: float32).
-    embedding_init: embedding initializer.
-    one_hot: performs the gather with a one-hot contraction rather than a true
-      gather. This is currently needed for SPMD partitioning.
-  """
-
-  num_embeddings: int
-  features: int
-  cast_input_dtype: Optional[DType] = None
-  dtype: DType = jnp.float32
-  attend_dtype: Optional[DType] = None
-  embedding_init: Initializer = default_embed_init
-  one_hot: bool = False
-  embedding: Array = dataclasses.field(init=False)
-
-  def setup(self):
-    self.embedding = param_with_axes(
-        'embedding',
-        self.embedding_init,
-        (self.num_embeddings, self.features),
-        jnp.float32,
-        axes=('vocab', 'embed'),
-    )
-
-  def __call__(self, inputs: Array) -> Array:
-    """Embeds the inputs along the last dimension.
-
-    Args:
-      inputs: input data, all dimensions are considered batch dimensions.
-
-    Returns:
-      Output which is embedded input data. The output shape follows the input,
-      with an additional `features` dimension appended.
- """ - if self.cast_input_dtype: - inputs = inputs.astype(self.cast_input_dtype) - if not jnp.issubdtype(inputs.dtype, jnp.integer): - raise ValueError('Input type must be an integer or unsigned integer.') - if self.one_hot: - iota = lax.iota(jnp.int32, self.num_embeddings) - one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) - output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) - else: - output = jnp.asarray(self.embedding, self.dtype)[inputs] - output = with_sharding_constraint(output, ('batch', 'length', 'embed')) - return output - - def attend(self, query: Array) -> Array: - """Attend over the embedding using a query array. - - Args: - query: array with last dimension equal the feature depth `features` of the - embedding. - - Returns: - An array with final dim `num_embeddings` corresponding to the batched - inner-product of the array of query vectors against each embedding. - Commonly used for weight-sharing between embeddings and logit transform - in NLP models. - """ - dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype - return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) +# class Embed(nn.Module): +# """A parameterized function from integers [0, n) to d-dimensional vectors. + +# Attributes: +# num_embeddings: number of embeddings. +# features: number of feature dimensions for each embedding. +# dtype: the dtype of the embedding vectors (default: float32). +# embedding_init: embedding initializer. +# one_hot: performs the gather with a one-hot contraction rather than a true +# gather. This is currently needed for SPMD partitioning. +# """ + +# num_embeddings: int +# features: int +# cast_input_dtype: Optional[DType] = None +# dtype: DType = jnp.float32 +# attend_dtype: Optional[DType] = None +# embedding_init: Initializer = default_embed_init +# one_hot: bool = False +# embedding: Array = dataclasses.field(init=False) + +# def setup(self): +# self.embedding = param_with_axes( +# 'embedding', +# self.embedding_init, +# (self.num_embeddings, self.features), +# jnp.float32, +# axes=('vocab', 'embed'), +# ) + +# def __call__(self, inputs: Array) -> Array: +# """Embeds the inputs along the last dimension. + +# Args: +# inputs: input data, all dimensions are considered batch dimensions. + +# Returns: +# Output which is embedded input data. The output shape follows the input, +# with an additional `features` dimension appended. +# """ +# if self.cast_input_dtype: +# inputs = inputs.astype(self.cast_input_dtype) +# if not jnp.issubdtype(inputs.dtype, jnp.integer): +# raise ValueError('Input type must be an integer or unsigned integer.') +# if self.one_hot: +# iota = lax.iota(jnp.int32, self.num_embeddings) +# one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) +# output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) +# else: +# output = jnp.asarray(self.embedding, self.dtype)[inputs] +# output = with_sharding_constraint(output, ('batch', 'length', 'embed')) +# return output + +# def attend(self, query: Array) -> Array: +# """Attend over the embedding using a query array. + +# Args: +# query: array with last dimension equal the feature depth `features` of the +# embedding. + +# Returns: +# An array with final dim `num_embeddings` corresponding to the batched +# inner-product of the array of query vectors against each embedding. +# Commonly used for weight-sharing between embeddings and logit transform +# in NLP models. 
+# """ +# dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype +# return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) class RelativePositionBiases(nn.Module): diff --git a/t5x/utils.py b/t5x/utils.py index 36ead8d0a..46280255b 100644 --- a/t5x/utils.py +++ b/t5x/utils.py @@ -673,14 +673,14 @@ def _put_to_devices(x, global_shape): return device_buffers - device_buffers = jax.tree.map(_put_to_devices, host_arrays, global_shapes) + device_buffers = jax.tree_map(_put_to_devices, host_arrays, global_shapes) def _jax_array(dbs, global_shape): return jax.make_array_from_single_device_arrays( global_shape, jax.sharding.NamedSharding(global_mesh, axes), dbs ) - return jax.tree.map( + return jax.tree_map( _jax_array, device_buffers, global_shapes, @@ -748,7 +748,7 @@ def prepare_train_iter( ) - input_shapes = jax.tree.map( + input_shapes = jax.tree_map( lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec ) train_iter = ShardedDatasetIterator(train_iter, partitioner, input_shapes) @@ -1394,18 +1394,18 @@ def log_model_info( state_dict = full_train_state.state_dict() total_num_params = jax.tree_util.tree_reduce( - np.add, jax.tree.map(np.size, state_dict['target']) + np.add, jax.tree_map(np.size, state_dict['target']) ) logical_axes = partitioner.get_logical_axes(full_train_state).state_dict() - mesh_axes = jax.tree.map( + mesh_axes = jax.tree_map( lambda x: tuple(x) if x is not None else None, partitioner.get_mesh_axes(full_train_state).state_dict(), ) def _log_info_and_write_to_file(writer, format_str, *args): - logging.info(format_str, *args) + # logging.info(format_str, *args) if writer is not None: writer.write(format_str % args + '\n') @@ -1447,7 +1447,7 @@ def _log_variable( mesh_axes, ) - jax.tree.map( + jax.tree_map( _log_variable, state_utils.get_name_tree(state_dict['target'], keep_empty_nodes=True), state_dict['target'], @@ -1462,7 +1462,7 @@ def _log_variable( # Add a blank line between params and states. _log_info_and_write_to_file(writer, '') - jax.tree.map( + jax.tree_map( _log_variable, state_utils.get_name_tree(state_dict['state'], keep_empty_nodes=True), state_dict['state'], @@ -1534,7 +1534,7 @@ def _remove_padding(all_inferences, all_indices): """ non_pad_idxs = np.where(all_indices >= 0) all_indices = all_indices[non_pad_idxs] - all_inferences = jax.tree.map(lambda x: x[non_pad_idxs], all_inferences) + all_inferences = jax.tree_map(lambda x: x[non_pad_idxs], all_inferences) return all_inferences, all_indices @@ -1623,7 +1623,7 @@ def infer_fn( train_state: train_state_lib.TrainState, rng: Optional[jnp.ndarray] = None, ): - ds_shapes = jax.tree.map(lambda x: jnp.array(x.shape), ds.element_spec) + ds_shapes = jax.tree_map(lambda x: jnp.array(x.shape), ds.element_spec) multihost_assert_equal( ds_shapes, ( @@ -1724,7 +1724,7 @@ def infer_fn( index, train_state.flax_mutables, ) - logging.info('Inference of batch %s done.', index) + # logging.info('Inference of batch %s done.', index) def _copy_to_host_async(x): @@ -1737,8 +1737,8 @@ def _copy_to_host_async(x): return x try: - batch_result = jax.tree.map(_copy_to_host_async, batch_result) - batch_indices = jax.tree.map(_copy_to_host_async, batch_indices) + batch_result = jax.tree_map(_copy_to_host_async, batch_result) + batch_indices = jax.tree_map(_copy_to_host_async, batch_indices) except AttributeError: # Similar to jax.device_get, we skip transfers for non DeviceArrays. 
         pass
@@ -1750,7 +1750,7 @@ def _copy_to_host_async(x):
     all_inferences = batched_results

     # List[B * shard_count, ...] -> [B * shard_count * batch_count, ...]
-    all_inferences = jax.tree.map(
+    all_inferences = jax.tree_map(
         lambda *args: np.concatenate(args), *all_inferences
     )
     all_indices = np.concatenate(all_indices)
@@ -1761,7 +1761,7 @@ def _copy_to_host_async(x):
     # Note: remove padding first, as -1 indices would mess up this operation.
     # Note: all_inferences may be a PyTree, not just an array, e.g. if
     # `infer_step` is `model.predict_batch_with_aux`.
-    all_inferences = jax.tree.map(lambda x: x[all_indices], all_inferences)
+    all_inferences = jax.tree_map(lambda x: x[all_indices], all_inferences)
     all_indices = all_indices[all_indices]

     # aux_values is supposed to be a dictionary that maps strings to a set of
@@ -1793,7 +1793,7 @@ def _copy_to_host_async(x):
           zip(*all_inferences),
       )
       indices_and_outputs = list(zip(all_indices, all_inferences))
-      indices_and_outputs = jax.tree.map(
+      indices_and_outputs = jax.tree_map(
           lambda x: np.array(x).tolist(), indices_and_outputs
       )
       if len(indices_and_outputs) != original_ds_length:
@@ -1807,9 +1807,9 @@ def _copy_to_host_async(x):
       return indices_and_outputs
     else:
       if keep_aux_as_numpy:
-        aux_values = jax.tree.map(lambda x: list(np.array(x)), aux_values)
+        aux_values = jax.tree_map(lambda x: list(np.array(x)), aux_values)
       else:
-        aux_values = jax.tree.map(lambda x: np.array(x).tolist(), aux_values)
+        aux_values = jax.tree_map(lambda x: np.array(x).tolist(), aux_values)
       return indices_and_outputs, aux_values

   return infer_fn
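
Note: the t5x/utils.py hunks above swap jax.tree.map for the older jax.tree_map alias, presumably so the code also runs on JAX releases that predate the jax.tree namespace (newer releases deprecate and eventually remove the flat alias). The following is a minimal compatibility sketch, not part of the patch; the module-level alias name tree_map is an illustrative choice that would avoid touching every call site.

import jax

# Resolve whichever spelling the installed JAX release provides.
try:
  tree_map = jax.tree.map  # newer JAX: pytree functions live under jax.tree
except AttributeError:
  tree_map = jax.tree_map  # older JAX: only the flat alias exists

# Usage example: double every leaf of a small pytree.
doubled = tree_map(lambda x: x * 2, {'a': 1, 'b': (2, 3)})
print(doubled)  # {'a': 2, 'b': (4, 6)}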