From 8fa2a1089925bb3b9f5ed63e8a79b5d11bee0299 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Wed, 15 May 2024 10:51:43 -0700
Subject: [PATCH] Replace deprecated `jax.tree_*` functions with `jax.tree.*`

The top-level `jax.tree_*` aliases have long been deprecated, and will soon be
removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the
`jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634007125
---
 baselines/jft/active_learning.py     |  2 +-
 baselines/jft/al_utils.py            |  2 +-
 baselines/jft/batchensemble.py       |  2 +-
 baselines/jft/batchensemble_utils.py | 12 ++++++------
 baselines/jft/begp.py                | 12 ++++++------
 baselines/jft/bit_deterministic.py   |  6 +++---
 baselines/jft/bit_heteroscedastic.py |  6 +++---
 baselines/jft/checkpoint_utils.py    |  2 +-
 baselines/jft/deterministic.py       |  6 +++---
 baselines/jft/deterministic_utils.py |  4 ++--
 baselines/jft/heteroscedastic.py     | 10 +++++-----
 baselines/jft/hetgpbe.py             | 12 ++++++------
 baselines/jft/hetsngp.py             | 12 ++++++------
 baselines/jft/input_utils.py         |  2 +-
 baselines/jft/mimo.py                |  6 +++---
 baselines/jft/plex.py                |  2 +-
 baselines/jft/rank1_bnn.py           |  2 +-
 baselines/jft/sngp.py                | 12 ++++++------
 baselines/jft/train_utils.py         |  6 +++---
 baselines/jft/vmoe.py                |  2 +-
 baselines/jft/vmoe_utils_test.py     |  2 +-
 21 files changed, 61 insertions(+), 61 deletions(-)

diff --git a/baselines/jft/active_learning.py b/baselines/jft/active_learning.py
index 022534e3..12079b49 100644
--- a/baselines/jft/active_learning.py
+++ b/baselines/jft/active_learning.py
@@ -759,7 +759,7 @@ def write_note(note):
   params_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
diff --git a/baselines/jft/al_utils.py b/baselines/jft/al_utils.py
index 966d8441..4af1d695 100644
--- a/baselines/jft/al_utils.py
+++ b/baselines/jft/al_utils.py
@@ -121,7 +121,7 @@ def as_dataset(self, # pytype: disable=signature-mismatch # overriding-paramet
     element_spec = dataset.element_spec.copy()
     element_spec['id'] = tf.TensorSpec(shape=(), dtype=tf.int64, name=None)
     logging.info(msg=f'element_spec = {element_spec}; '
-                 f'type = {jax.tree_map(type, element_spec)}')
+                 f'type = {jax.tree.map(type, element_spec)}')
 
     dataset = tf.data.Dataset.from_generator(
         _subset_generator(
diff --git a/baselines/jft/batchensemble.py b/baselines/jft/batchensemble.py
index 5390686c..8354567d 100644
--- a/baselines/jft/batchensemble.py
+++ b/baselines/jft/batchensemble.py
@@ -283,7 +283,7 @@ def init(rng):
   params_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
diff --git a/baselines/jft/batchensemble_utils.py b/baselines/jft/batchensemble_utils.py
index aeee7345..c73484c8 100644
--- a/baselines/jft/batchensemble_utils.py
+++ b/baselines/jft/batchensemble_utils.py
@@ -59,10 +59,10 @@ def log_average_sigmoid_probs(logits: jnp.ndarray) -> jnp.ndarray:
 
 def tree_clip_norm_global_pmax(tree, max_norm, axis_name):
   """Global norm clipping, with pmax of global norm before clipping."""
-  global_norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree_leaves(tree)))
+  global_norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree.leaves(tree)))
   global_norm = jax.lax.pmax(global_norm, axis_name=axis_name)
   factor = jnp.minimum(1.0, max_norm / global_norm)
-  return jax.tree_map(lambda x: factor * x, tree), global_norm
+  return jax.tree.map(lambda x: factor * x, tree), global_norm
 
 
 def _traverse_with_names(tree):
@@ -93,7 +93,7 @@ def tree_flatten_with_names(tree):
     A list of values with names: [(name, value), ...].
     A PyTreeDef tree definition object.
   """
-  vals, tree_def = jax.tree_flatten(tree)
+  vals, tree_def = jax.tree.flatten(tree)
 
   # "Fake" token tree that is use to track jax internal tree traversal and
   # adjust our custom tree traversal to be compatible with it.
@@ -111,7 +111,7 @@ def tree_flatten_with_names(tree):
 
 
 def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True):
-  """Like jax.tree_map but with a filter on the leaf path name.
+  """Like jax.tree.map but with a filter on the leaf path name.
 
   Args:
     f: The function to be applied to each parameter in `param_tree`.
@@ -132,8 +132,8 @@ def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True):
 
 def tree_rngs_split(rngs, num_splits=2):
   """Splits a PyTree of PRNGKeys into num_splits PyTrees."""
-  rngs = jax.tree_map(lambda rng: jax.random.split(rng, num_splits), rngs)
-  slice_rngs = lambda rngs, i: jax.tree_map(lambda rng: rng[i], rngs)
+  rngs = jax.tree.map(lambda rng: jax.random.split(rng, num_splits), rngs)
+  slice_rngs = lambda rngs, i: jax.tree.map(lambda rng: rng[i], rngs)
   return tuple(slice_rngs(rngs, i) for i in range(num_splits))
 
 
diff --git a/baselines/jft/begp.py b/baselines/jft/begp.py
index a8513ca5..847b5a3a 100644
--- a/baselines/jft/begp.py
+++ b/baselines/jft/begp.py
@@ -321,7 +321,7 @@ def init(rng):
   params_cpu, states_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
@@ -501,7 +501,7 @@ def loss_fn(params, states, images, labels):
     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
 
@@ -509,11 +509,11 @@ def loss_fn(params, states, images, labels):
     # useful in some cases across optimizers, hence it's in the main loop.
     if config.get('grad_clip_norm'):
       g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
-      g = jax.tree_map(lambda p: g_factor * p, g)
+      g = jax.tree.map(lambda p: g_factor * p, g)
     opt = opt.apply_gradient(g, learning_rate=lr)
     opt = opt.replace(target=weight_decay_fn(opt.target, lr))
 
-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
 
     # Compute training accuracy by the ensemble members independently to save
@@ -644,8 +644,8 @@ def loss_fn(params, states, images, labels):
       # (`states`). This is ok since `random features` are frozen throughout
       # pre-training, and `precision matrix` is a finetuning-specific parameters
       # that will be re-learned in the finetuning task.
-      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
-      states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+      opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
+      states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)
 
       # Check whether we want to keep a copy of the current checkpoint.
       copy_step = None
diff --git a/baselines/jft/bit_deterministic.py b/baselines/jft/bit_deterministic.py
index e4537f00..84286d84 100644
--- a/baselines/jft/bit_deterministic.py
+++ b/baselines/jft/bit_deterministic.py
@@ -277,7 +277,7 @@ def init(rng):
   params_cpu = init(rng_init)
 
   if jax.host_id() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
@@ -375,7 +375,7 @@ def loss_fn(params, images, labels):
     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
 
@@ -397,7 +397,7 @@ def decay_fn(v, wd):
           target=train_utils.tree_map_with_regex(decay_fn, opt.target,
                                                  decay_rules))
 
-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
 
     return opt, l, rng, measurements
diff --git a/baselines/jft/bit_heteroscedastic.py b/baselines/jft/bit_heteroscedastic.py
index ff1d025f..7fd2e4de 100644
--- a/baselines/jft/bit_heteroscedastic.py
+++ b/baselines/jft/bit_heteroscedastic.py
@@ -303,7 +303,7 @@ def init(rng):
   params_cpu = init(rng_init)
 
   if jax.host_id() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
@@ -417,7 +417,7 @@ def loss_fn(params, images, labels):
     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
 
@@ -439,7 +439,7 @@ def decay_fn(v, wd):
           target=train_utils.tree_map_with_regex(decay_fn, opt.target,
                                                  decay_rules))
 
-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
 
     return opt, l, rng, measurements
diff --git a/baselines/jft/checkpoint_utils.py b/baselines/jft/checkpoint_utils.py
index 3bab8299..da8ec19b 100644
--- a/baselines/jft/checkpoint_utils.py
+++ b/baselines/jft/checkpoint_utils.py
@@ -199,7 +199,7 @@ def _tree_flatten_with_names(tree):
   Returns:
     A list of values with names: [(name, value), ...].
   """
-  vals, tree_def = jax.tree_flatten(tree)
+  vals, tree_def = jax.tree.flatten(tree)
 
   # "Fake" token tree that is use to track jax internal tree traversal and
   # adjust our custom tree traversal to be compatible with it.
diff --git a/baselines/jft/deterministic.py b/baselines/jft/deterministic.py
index 05b871d0..5def28a1 100644
--- a/baselines/jft/deterministic.py
+++ b/baselines/jft/deterministic.py
@@ -277,7 +277,7 @@ def init(rng):
   params_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
@@ -397,7 +397,7 @@ def loss_fn(params, images, labels):
     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
 
@@ -410,7 +410,7 @@ def loss_fn(params, images, labels):
 
     opt = opt.replace(target=weight_decay_fn(opt.target, lr))
 
-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
 
     top1_idx = jnp.argmax(logits, axis=1)
diff --git a/baselines/jft/deterministic_utils.py b/baselines/jft/deterministic_utils.py
index 6ddc6e1a..d84a6498 100644
--- a/baselines/jft/deterministic_utils.py
+++ b/baselines/jft/deterministic_utils.py
@@ -143,7 +143,7 @@ def loss_fn(params, images, labels):
     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
       logging.info(msg=f'measurements = {measurements}')
@@ -157,7 +157,7 @@ def loss_fn(params, images, labels):
 
     opt = opt.replace(target=weight_decay_fn(opt.target, lr))
 
-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
 
     top1_idx = jnp.argmax(logits, axis=1)
diff --git a/baselines/jft/heteroscedastic.py b/baselines/jft/heteroscedastic.py
index bf7fc43a..9aa79175 100644
--- a/baselines/jft/heteroscedastic.py
+++ b/baselines/jft/heteroscedastic.py
@@ -298,7 +298,7 @@ def init(rng):
   params_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
@@ -438,7 +438,7 @@ def loss_fn(params, images, labels):
     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
 
@@ -446,11 +446,11 @@ def loss_fn(params, images, labels):
     # useful in some cases across optimizers, hence it's in the main loop.
     if config.get('grad_clip_norm'):
       g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
-      g = jax.tree_map(lambda p: g_factor * p, g)
+      g = jax.tree.map(lambda p: g_factor * p, g)
     opt = opt.apply_gradient(g, learning_rate=lr)
     opt = opt.replace(target=weight_decay_fn(opt.target, lr))
 
-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
 
     return opt, l, rng, measurements
@@ -570,7 +570,7 @@ def loss_fn(params, images, labels):
       # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
       # memory errors (see b/160593526). Also, takes device 0's params only.
-      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
+      opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
 
       # Check whether we want to keep a copy of the current checkpoint.
       copy_step = None
diff --git a/baselines/jft/hetgpbe.py b/baselines/jft/hetgpbe.py
index 3873b7ae..4749e945 100644
--- a/baselines/jft/hetgpbe.py
+++ b/baselines/jft/hetgpbe.py
@@ -310,7 +310,7 @@ def init(rng):
   params_cpu, states_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
@@ -486,7 +486,7 @@ def loss_fn(params, states, images, labels):
     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
 
@@ -494,11 +494,11 @@ def loss_fn(params, states, images, labels):
     # useful in some cases across optimizers, hence it's in the main loop.
     if config.get('grad_clip_norm'):
       g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
-      g = jax.tree_map(lambda p: g_factor * p, g)
+      g = jax.tree.map(lambda p: g_factor * p, g)
     opt = opt.apply_gradient(g, learning_rate=lr)
     opt = opt.replace(target=weight_decay_fn(opt.target, lr))
 
-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
 
     # Compute training accuracy by the ensemble members independently to save
@@ -646,8 +646,8 @@ def loss_fn(params, states, images, labels):
       # (`states`). This is ok since `random features` are frozen throughout
       # pre-training, and `precision matrix` is a finetuning-specific parameters
       # that will be re-learned in the finetuning task.
-      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
-      states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+      opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
+      states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)
 
       # Check whether we want to keep a copy of the current checkpoint.
       copy_step = None
diff --git a/baselines/jft/hetsngp.py b/baselines/jft/hetsngp.py
index 5f6773bb..387fcf4d 100644
--- a/baselines/jft/hetsngp.py
+++ b/baselines/jft/hetsngp.py
@@ -312,7 +312,7 @@ def init(rng):
   params_cpu, states_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
@@ -463,7 +463,7 @@ def loss_fn(params, states, images, labels):
     # or if we don't use grad_accum_steps, as they interact badly.
     do_grad_clip = config.get('grad_clip_norm', -1.) > 0.
     if config.get('grad_accum_steps', 1) == 1 or do_grad_clip:
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
 
@@ -471,11 +471,11 @@ def loss_fn(params, states, images, labels):
     # useful in some cases across optimizers, hence it's in the main loop.
     if do_grad_clip:
       g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
-      g = jax.tree_map(lambda p: g_factor * p, g)
+      g = jax.tree.map(lambda p: g_factor * p, g)
     opt = opt.apply_gradient(g, learning_rate=lr)
     opt = opt.replace(target=weight_decay_fn(opt.target, lr))
 
-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
 
     return opt, s, l, rng, measurements
@@ -595,8 +595,8 @@ def loss_fn(params, states, images, labels):
      # (`states`). This is ok since `random features` are frozen throughout
       # pre-training, and `precision matrix` is a finetuning-specific parameters
       # that will be re-learned in the finetuning task.
-      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
-      states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+      opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
+      states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)
 
       # Check whether we want to keep a copy of the current checkpoint.
       copy_step = None
diff --git a/baselines/jft/input_utils.py b/baselines/jft/input_utils.py
index bc9d9ee5..1fc093bf 100644
--- a/baselines/jft/input_utils.py
+++ b/baselines/jft/input_utils.py
@@ -361,7 +361,7 @@ def _prepare(x):
     # https://github.com/tensorflow/tensorflow/issues/33254#issuecomment-542379165
     return np.asarray(memoryview(x))
 
-  it = (jax.tree_map(_prepare, xs) for xs in it)
+  it = (jax.tree.map(_prepare, xs) for xs in it)
 
   if n_prefetch:
     it = flax.jax_utils.prefetch_to_device(it, n_prefetch, devices=devices)
diff --git a/baselines/jft/mimo.py b/baselines/jft/mimo.py
index 33f7e492..8ca30467 100644
--- a/baselines/jft/mimo.py
+++ b/baselines/jft/mimo.py
@@ -333,7 +333,7 @@ def init(rng):
   params_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
@@ -473,7 +473,7 @@ def loss_fn(params, images, labels):
     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
 
@@ -486,7 +486,7 @@ def loss_fn(params, images, labels):
 
     opt = opt.replace(target=weight_decay_fn(opt.target, lr))
 
-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
 
     return opt, l, rng, measurements
diff --git a/baselines/jft/plex.py b/baselines/jft/plex.py
index d2dd9e57..451db1bc 100644
--- a/baselines/jft/plex.py
+++ b/baselines/jft/plex.py
@@ -304,7 +304,7 @@ def init(rng):
   params_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
diff --git a/baselines/jft/rank1_bnn.py b/baselines/jft/rank1_bnn.py
index e8774cce..3be844a6 100644
--- a/baselines/jft/rank1_bnn.py
+++ b/baselines/jft/rank1_bnn.py
@@ -446,7 +446,7 @@ def init(rng):
   params_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
diff --git a/baselines/jft/sngp.py b/baselines/jft/sngp.py
index 1233aef3..75d7be51 100644
--- a/baselines/jft/sngp.py
+++ b/baselines/jft/sngp.py
@@ -309,7 +309,7 @@ def init(rng):
   params_cpu, states_cpu = init(rng_init)
 
   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
 
@@ -451,7 +451,7 @@ def loss_fn(params, states, images, labels):
     # or if we don't use grad_accum_steps, as they interact badly.
     do_grad_clip = config.get('grad_clip_norm', -1.) > 0.
     if config.get('grad_accum_steps', 1) == 1 or do_grad_clip:
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
 
@@ -459,11 +459,11 @@ def loss_fn(params, states, images, labels):
     # useful in some cases across optimizers, hence it's in the main loop.
     if do_grad_clip:
       g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
-      g = jax.tree_map(lambda p: g_factor * p, g)
+      g = jax.tree.map(lambda p: g_factor * p, g)
     opt = opt.apply_gradient(g, learning_rate=lr)
     opt = opt.replace(target=weight_decay_fn(opt.target, lr))
 
-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
     measurements['reset_covmat'] = reset_covmat
 
@@ -585,8 +585,8 @@ def loss_fn(params, states, images, labels):
      # (`states`). This is ok since `random features` are frozen throughout
       # pre-training, and `precision matrix` is a finetuning-specific parameters
       # that will be re-learned in the finetuning task.
-      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
-      states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+      opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
+      states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)
 
       # Check whether we want to keep a copy of the current checkpoint.
       copy_step = None
diff --git a/baselines/jft/train_utils.py b/baselines/jft/train_utils.py
index 03da48a6..3e315e61 100644
--- a/baselines/jft/train_utils.py
+++ b/baselines/jft/train_utils.py
@@ -70,7 +70,7 @@ def acc_grad_and_loss(i, l_and_g):
                                          (step_size, labels.shape[1]))
       li, gi = loss_and_grad_fn(params, imgs, lbls)
       l, g = l_and_g
-      return (l + li, jax.tree_map(lambda x, y: x + y, g, gi))
+      return (l + li, jax.tree.map(lambda x, y: x + y, g, gi))
 
     l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
     return jax.tree_util.tree_map(lambda x: x / accum_steps, (l, g))
@@ -107,10 +107,10 @@ def acc_grad_and_loss(i, l_s_g):
       # Update state and accumulate gradient.
       l, s, g = l_s_g
       (li, si), gi = loss_and_grad_fn(params, s, imgs, lbls)
-      return (l + li, si, jax.tree_map(lambda x, y: x + y, g, gi))
+      return (l + li, si, jax.tree.map(lambda x, y: x + y, g, gi))
 
     l, s, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, s, g))
-    l, g = jax.tree_map(lambda x: x / accum_steps, (l, g))
+    l, g = jax.tree.map(lambda x: x / accum_steps, (l, g))
     return (l, s), g
   else:
     return loss_and_grad_fn(params, states, images, labels)
diff --git a/baselines/jft/vmoe.py b/baselines/jft/vmoe.py
index 06707967..6ce79fda 100644
--- a/baselines/jft/vmoe.py
+++ b/baselines/jft/vmoe.py
@@ -333,7 +333,7 @@ def single_model_pred_fn(params, images):
   pjit_partition_params_fn = pjit.pjit(
       fun=lambda x: x,
       in_shardings=(
-          jax.tree_map(lambda _: jax.sharding.PartitionSpec(), model_params),
+          jax.tree.map(lambda _: jax.sharding.PartitionSpec(), model_params),
       ),
       out_shardings=variables_partition_spec[model_key],
   )
diff --git a/baselines/jft/vmoe_utils_test.py b/baselines/jft/vmoe_utils_test.py
index 2e69b184..13dce512 100644
--- a/baselines/jft/vmoe_utils_test.py
+++ b/baselines/jft/vmoe_utils_test.py
@@ -42,7 +42,7 @@ def test_variables_partition_spec(self):
             'self-attention': jax.sharding.PartitionSpec(),
         }
     }
-    jax.tree_map(np.testing.assert_equal, expected_partition_spec,
+    jax.tree.map(np.testing.assert_equal, expected_partition_spec,
                  partition_spec)
 
   def test_deep_ensemble_reshape_outputs(self):
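
Editor's note for reviewers porting similar code: the change above is a mechanical rename. The snippet below is an illustrative sketch only (it is not part of the patch, and the toy `params` dict is invented for the example); it shows the before/after pattern, assuming JAX >= 0.4.25 so that the `jax.tree` aliases exist.

# Illustrative sketch, not part of the patch.
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}  # hypothetical pytree

# Deprecated top-level aliases being removed from JAX:
#   jax.tree_map(f, tree), jax.tree_flatten(tree), jax.tree_leaves(tree)

# Replacements used throughout this patch: the `jax.tree.*` aliases.
scaled = jax.tree.map(lambda x: 0.5 * x, params)
leaves, treedef = jax.tree.flatten(params)
num_params = sum(p.size for p in jax.tree.leaves(params))
print(num_params)  # 2 * 3 + 3 == 9

# Equivalent long-form APIs in `jax.tree_util`, which also work on older JAX:
scaled_alt = jax.tree_util.tree_map(lambda x: 0.5 * x, params)
leaves_alt, treedef_alt = jax.tree_util.tree_flatten(params)

Either spelling is a drop-in replacement; the patch uses the shorter `jax.tree.*` form since the repository already requires a recent JAX release.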