Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634007125
Jake VanderPlas authored and copybara-github committed May 17, 2024
1 parent 2e3093e commit 8fa2a10
Showing 21 changed files with 61 additions and 61 deletions.
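Every hunk below is the same mechanical rename. A minimal before/after sketch of the pattern (toy pytree, assuming JAX >= 0.4.25 so the `jax.tree` aliases exist):

```python
import jax

params = {'w': 1.0, 'b': 2.0}  # toy pytree for illustration

# Deprecated top-level aliases being removed:
#   jax.tree_map(f, params), jax.tree_flatten(params), jax.tree_leaves(params)

# Preferred spellings (jax.tree.* are thin aliases for jax.tree_util.tree_*):
doubled = jax.tree.map(lambda x: 2 * x, params)   # jax.tree_util.tree_map
leaves, treedef = jax.tree.flatten(params)        # jax.tree_util.tree_flatten
flat = jax.tree.leaves(params)                    # jax.tree_util.tree_leaves
```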
2 changes: 1 addition & 1 deletion baselines/jft/active_learning.py
@@ -759,7 +759,7 @@ def write_note(note):
params_cpu = init(rng_init)

if jax.process_index() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

2 changes: 1 addition & 1 deletion baselines/jft/al_utils.py
@@ -121,7 +121,7 @@ def as_dataset(self, # pytype: disable=signature-mismatch # overriding-paramet
element_spec = dataset.element_spec.copy()
element_spec['id'] = tf.TensorSpec(shape=(), dtype=tf.int64, name=None)
logging.info(msg=f'element_spec = {element_spec}; '
- f'type = {jax.tree_map(type, element_spec)}')
+ f'type = {jax.tree.map(type, element_spec)}')

dataset = tf.data.Dataset.from_generator(
_subset_generator(

2 changes: 1 addition & 1 deletion baselines/jft/batchensemble.py
@@ -283,7 +283,7 @@ def init(rng):
params_cpu = init(rng_init)

if jax.process_index() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

12 changes: 6 additions & 6 deletions baselines/jft/batchensemble_utils.py
@@ -59,10 +59,10 @@ def log_average_sigmoid_probs(logits: jnp.ndarray) -> jnp.ndarray:

def tree_clip_norm_global_pmax(tree, max_norm, axis_name):
"""Global norm clipping, with pmax of global norm before clipping."""
- global_norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree_leaves(tree)))
+ global_norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree.leaves(tree)))
global_norm = jax.lax.pmax(global_norm, axis_name=axis_name)
factor = jnp.minimum(1.0, max_norm / global_norm)
- return jax.tree_map(lambda x: factor * x, tree), global_norm
+ return jax.tree.map(lambda x: factor * x, tree), global_norm


def _traverse_with_names(tree):
@@ -93,7 +93,7 @@ def tree_flatten_with_names(tree):
A list of values with names: [(name, value), ...].
A PyTreeDef tree definition object.
"""
- vals, tree_def = jax.tree_flatten(tree)
+ vals, tree_def = jax.tree.flatten(tree)

# "Fake" token tree that is use to track jax internal tree traversal and
# adjust our custom tree traversal to be compatible with it.
@@ -111,7 +111,7 @@ def tree_flatten_with_names(tree):


def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True):
- """Like jax.tree_map but with a filter on the leaf path name.
+ """Like jax.tree.map but with a filter on the leaf path name.
Args:
f: The function to be applied to each parameter in `param_tree`.
@@ -132,8 +132,8 @@ def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True):

def tree_rngs_split(rngs, num_splits=2):
"""Splits a PyTree of PRNGKeys into num_splits PyTrees."""
- rngs = jax.tree_map(lambda rng: jax.random.split(rng, num_splits), rngs)
- slice_rngs = lambda rngs, i: jax.tree_map(lambda rng: rng[i], rngs)
+ rngs = jax.tree.map(lambda rng: jax.random.split(rng, num_splits), rngs)
+ slice_rngs = lambda rngs, i: jax.tree.map(lambda rng: rng[i], rngs)
return tuple(slice_rngs(rngs, i) for i in range(num_splits))

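For context, `tree_rngs_split` above is unchanged apart from the alias it calls; a small standalone sketch of what it does (key names are made up for the example, no pmap involved):

```python
import jax

def tree_rngs_split(rngs, num_splits=2):
  """Splits a PyTree of PRNGKeys into num_splits PyTrees."""
  rngs = jax.tree.map(lambda rng: jax.random.split(rng, num_splits), rngs)
  slice_rngs = lambda rngs, i: jax.tree.map(lambda rng: rng[i], rngs)
  return tuple(slice_rngs(rngs, i) for i in range(num_splits))

# One key per stochastic module (hypothetical names).
rngs = {'dropout': jax.random.PRNGKey(0), 'be_noise': jax.random.PRNGKey(1)}
step_rngs, carry_rngs = tree_rngs_split(rngs)
# Both results keep the {'dropout': ..., 'be_noise': ...} structure.
```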
12 changes: 6 additions & 6 deletions baselines/jft/begp.py
@@ -321,7 +321,7 @@ def init(rng):
params_cpu, states_cpu = init(rng_init)

if jax.process_index() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -501,19 +501,19 @@ def loss_fn(params, states, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
- grads, _ = jax.tree_flatten(g)
+ grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

# Optionally resize the global gradient to a maximum norm. We found this
# useful in some cases across optimizers, hence it's in the main loop.
if config.get('grad_clip_norm'):
g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
- g = jax.tree_map(lambda p: g_factor * p, g)
+ g = jax.tree.map(lambda p: g_factor * p, g)
opt = opt.apply_gradient(g, learning_rate=lr)
opt = opt.replace(target=weight_decay_fn(opt.target, lr))

- params, _ = jax.tree_flatten(opt.target)
+ params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

# Compute training accuracy by the ensemble members independently to save
@@ -644,8 +644,8 @@ def loss_fn(params, states, images, labels):
# (`states`). This is ok since `random features` are frozen throughout
# pre-training, and `precision matrix` is a finetuning-specific parameters
# that will be re-learned in the finetuning task.
- opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
- states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+ opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
+ states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)

# Check whether we want to keep a copy of the current checkpoint.
copy_step = None

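The training-loop hunks in this and the following files all share the same global-norm measurement and clipping idiom; a minimal standalone sketch of it with the new aliases (hypothetical helper name, no pmap or optimizer machinery):

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm(grads, max_norm):
  # Measure the global L2 norm across all leaves of the gradient pytree ...
  leaves, _ = jax.tree.flatten(grads)
  l2_g = jnp.sqrt(sum(jnp.vdot(p, p) for p in leaves))
  # ... then rescale every leaf by a common factor, as in the loops above.
  g_factor = jnp.minimum(1.0, max_norm / l2_g)
  return jax.tree.map(lambda p: g_factor * p, grads), l2_g

grads = {'w': jnp.ones((3, 3)), 'b': jnp.ones((3,))}
clipped, l2_g = clip_by_global_norm(grads, max_norm=1.0)
```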
6 changes: 3 additions & 3 deletions baselines/jft/bit_deterministic.py
@@ -277,7 +277,7 @@ def init(rng):
params_cpu = init(rng_init)

if jax.host_id() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -375,7 +375,7 @@ def loss_fn(params, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
- grads, _ = jax.tree_flatten(g)
+ grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

@@ -397,7 +397,7 @@ def decay_fn(v, wd):
target=train_utils.tree_map_with_regex(decay_fn, opt.target,
decay_rules))

- params, _ = jax.tree_flatten(opt.target)
+ params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

return opt, l, rng, measurements

6 changes: 3 additions & 3 deletions baselines/jft/bit_heteroscedastic.py
@@ -303,7 +303,7 @@ def init(rng):
params_cpu = init(rng_init)

if jax.host_id() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -417,7 +417,7 @@ def loss_fn(params, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
- grads, _ = jax.tree_flatten(g)
+ grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

@@ -439,7 +439,7 @@ def decay_fn(v, wd):
target=train_utils.tree_map_with_regex(decay_fn, opt.target,
decay_rules))

- params, _ = jax.tree_flatten(opt.target)
+ params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

return opt, l, rng, measurements

2 changes: 1 addition & 1 deletion baselines/jft/checkpoint_utils.py
@@ -199,7 +199,7 @@ def _tree_flatten_with_names(tree):
Returns:
A list of values with names: [(name, value), ...].
"""
- vals, tree_def = jax.tree_flatten(tree)
+ vals, tree_def = jax.tree.flatten(tree)

# "Fake" token tree that is use to track jax internal tree traversal and
# adjust our custom tree traversal to be compatible with it.

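`_tree_flatten_with_names` builds on the flatten/unflatten contract of the tree utilities; a quick illustration of that contract with the new aliases (toy tree; the token trick is only a plausible reading of the helper's comment above, not its verified implementation):

```python
import jax

tree = {'a': 1, 'b': {'c': 2, 'd': 3}}
vals, tree_def = jax.tree.flatten(tree)        # vals == [1, 2, 3], key-sorted
restored = jax.tree.unflatten(tree_def, vals)  # round-trips to the same dict

# The "fake" token tree mentioned in the comment presumably looks like this:
# the original structure with positional indices as leaves, which lets the
# helper align its own name traversal with JAX's flattening order.
tokens = jax.tree.unflatten(tree_def, list(range(len(vals))))
```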
6 changes: 3 additions & 3 deletions baselines/jft/deterministic.py
@@ -277,7 +277,7 @@ def init(rng):
params_cpu = init(rng_init)

if jax.process_index() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -397,7 +397,7 @@ def loss_fn(params, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
- grads, _ = jax.tree_flatten(g)
+ grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

@@ -410,7 +410,7 @@ def loss_fn(params, images, labels):

opt = opt.replace(target=weight_decay_fn(opt.target, lr))

- params, _ = jax.tree_flatten(opt.target)
+ params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

top1_idx = jnp.argmax(logits, axis=1)

4 changes: 2 additions & 2 deletions baselines/jft/deterministic_utils.py
@@ -143,7 +143,7 @@ def loss_fn(params, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
- grads, _ = jax.tree_flatten(g)
+ grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g
logging.info(msg=f'measurements = {measurements}')
@@ -157,7 +157,7 @@ def loss_fn(params, images, labels):

opt = opt.replace(target=weight_decay_fn(opt.target, lr))

- params, _ = jax.tree_flatten(opt.target)
+ params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

top1_idx = jnp.argmax(logits, axis=1)

10 changes: 5 additions & 5 deletions baselines/jft/heteroscedastic.py
@@ -298,7 +298,7 @@ def init(rng):
params_cpu = init(rng_init)

if jax.process_index() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -438,19 +438,19 @@ def loss_fn(params, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
- grads, _ = jax.tree_flatten(g)
+ grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

# Optionally resize the global gradient to a maximum norm. We found this
# useful in some cases across optimizers, hence it's in the main loop.
if config.get('grad_clip_norm'):
g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
- g = jax.tree_map(lambda p: g_factor * p, g)
+ g = jax.tree.map(lambda p: g_factor * p, g)
opt = opt.apply_gradient(g, learning_rate=lr)
opt = opt.replace(target=weight_decay_fn(opt.target, lr))

- params, _ = jax.tree_flatten(opt.target)
+ params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

return opt, l, rng, measurements
@@ -570,7 +570,7 @@ def loss_fn(params, images, labels):
# We need to transfer the weights over now or else we risk keeping them
# alive while they'll be updated in a future step, creating hard to debug
# memory errors (see b/160593526). Also, takes device 0's params only.
- opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
+ opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)

# Check whether we want to keep a copy of the current checkpoint.
copy_step = None

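The checkpointing hunks here and in the files below use the same one-liner to pull replica 0 of a pmap-replicated pytree back to host memory; a toy sketch (the leading axis of size 8 stands in for the device axis):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for a replicated optimizer pytree: every leaf has a device axis.
opt_repl = {'w': jnp.ones((8, 3)), 'b': jnp.zeros((8,))}

# Take device 0's copy and materialize it on the host as NumPy arrays.
opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
```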
12 changes: 6 additions & 6 deletions baselines/jft/hetgpbe.py
@@ -310,7 +310,7 @@ def init(rng):
params_cpu, states_cpu = init(rng_init)

if jax.process_index() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -486,19 +486,19 @@ def loss_fn(params, states, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
- grads, _ = jax.tree_flatten(g)
+ grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

# Optionally resize the global gradient to a maximum norm. We found this
# useful in some cases across optimizers, hence it's in the main loop.
if config.get('grad_clip_norm'):
g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
- g = jax.tree_map(lambda p: g_factor * p, g)
+ g = jax.tree.map(lambda p: g_factor * p, g)
opt = opt.apply_gradient(g, learning_rate=lr)
opt = opt.replace(target=weight_decay_fn(opt.target, lr))

- params, _ = jax.tree_flatten(opt.target)
+ params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

# Compute training accuracy by the ensemble members independently to save
@@ -646,8 +646,8 @@ def loss_fn(params, states, images, labels):
# (`states`). This is ok since `random features` are frozen throughout
# pre-training, and `precision matrix` is a finetuning-specific parameters
# that will be re-learned in the finetuning task.
- opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
- states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+ opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
+ states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)

# Check whether we want to keep a copy of the current checkpoint.
copy_step = None

12 changes: 6 additions & 6 deletions baselines/jft/hetsngp.py
@@ -312,7 +312,7 @@ def init(rng):
params_cpu, states_cpu = init(rng_init)

if jax.process_index() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -463,19 +463,19 @@ def loss_fn(params, states, images, labels):
# or if we don't use grad_accum_steps, as they interact badly.
do_grad_clip = config.get('grad_clip_norm', -1.) > 0.
if config.get('grad_accum_steps', 1) == 1 or do_grad_clip:
- grads, _ = jax.tree_flatten(g)
+ grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

# Optionally resize the global gradient to a maximum norm. We found this
# useful in some cases across optimizers, hence it's in the main loop.
if do_grad_clip:
g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
- g = jax.tree_map(lambda p: g_factor * p, g)
+ g = jax.tree.map(lambda p: g_factor * p, g)
opt = opt.apply_gradient(g, learning_rate=lr)
opt = opt.replace(target=weight_decay_fn(opt.target, lr))

- params, _ = jax.tree_flatten(opt.target)
+ params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

return opt, s, l, rng, measurements
@@ -595,8 +595,8 @@ def loss_fn(params, states, images, labels):
# (`states`). This is ok since `random features` are frozen throughout
# pre-training, and `precision matrix` is a finetuning-specific parameters
# that will be re-learned in the finetuning task.
- opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
- states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+ opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
+ states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)

# Check whether we want to keep a copy of the current checkpoint.
copy_step = None

2 changes: 1 addition & 1 deletion baselines/jft/input_utils.py
@@ -361,7 +361,7 @@ def _prepare(x):
# https://github.com/tensorflow/tensorflow/issues/33254#issuecomment-542379165
return np.asarray(memoryview(x))

- it = (jax.tree_map(_prepare, xs) for xs in it)
+ it = (jax.tree.map(_prepare, xs) for xs in it)

if n_prefetch:
it = flax.jax_utils.prefetch_to_device(it, n_prefetch, devices=devices)

6 changes: 3 additions & 3 deletions baselines/jft/mimo.py
@@ -333,7 +333,7 @@ def init(rng):
params_cpu = init(rng_init)

if jax.process_index() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -473,7 +473,7 @@ def loss_fn(params, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
- grads, _ = jax.tree_flatten(g)
+ grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

@@ -486,7 +486,7 @@ def loss_fn(params, images, labels):

opt = opt.replace(target=weight_decay_fn(opt.target, lr))

- params, _ = jax.tree_flatten(opt.target)
+ params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

return opt, l, rng, measurements

2 changes: 1 addition & 1 deletion baselines/jft/plex.py
@@ -304,7 +304,7 @@ def init(rng):
params_cpu = init(rng_init)

if jax.process_index() == 0:
- num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+ num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})
