Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634412237
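
For reference, a minimal sketch of the rename this commit performs, using illustrative names rather than code from the changed files (assumes JAX 0.4.25 or newer, where the `jax.tree` aliases exist):

```python
# Illustrative only: shows the deprecated aliases next to their replacements.
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Deprecated top-level aliases replaced in this commit:
#   jax.tree_map, jax.tree_flatten, jax.tree_leaves
scaled = jax.tree.map(lambda x: 2.0 * x, params)            # was jax.tree_map
leaves, treedef = jax.tree.flatten(params)                  # was jax.tree_flatten
num_params = sum(p.size for p in jax.tree.leaves(params))   # was jax.tree_leaves

# The longer-form jax.tree_util APIs remain available and unchanged.
also_scaled = jax.tree_util.tree_map(lambda x: 2.0 * x, params)
```

The `jax.tree.*` names are simply the shorter aliases for the `jax.tree_util` functions, so behavior is unchanged; the commit only moves off the deprecated top-level spellings.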
Jake VanderPlas authored and copybara-github committed May 16, 2024
1 parent 348a88d commit 2e3093e
Showing 20 changed files with 40 additions and 40 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -131,7 +131,7 @@ copy compared to `tensor._numpy()`.
```python
for batch in iter(ds):
-train_step(jax.tree_map(lambda y: y._numpy(), batch))
+train_step(jax.tree.map(lambda y: y._numpy(), batch))
```
### Models
12 changes: 6 additions & 6 deletions baselines/diabetic_retinopathy_detection/batchensemble_utils.py
@@ -60,10 +60,10 @@ def log_average_sigmoid_probs(logits: jnp.ndarray) -> jnp.ndarray:

def tree_clip_norm_global_pmax(tree, max_norm, axis_name):
"""Global norm clipping, with pmax of global norm before clipping."""
-global_norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree_leaves(tree)))
+global_norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree.leaves(tree)))
global_norm = jax.lax.pmax(global_norm, axis_name=axis_name)
factor = jnp.minimum(1.0, max_norm / global_norm)
-return jax.tree_map(lambda x: factor * x, tree), global_norm
+return jax.tree.map(lambda x: factor * x, tree), global_norm


def _traverse_with_names(tree):
@@ -94,7 +94,7 @@ def tree_flatten_with_names(tree):
A list of values with names: [(name, value), ...].
A PyTreeDef tree definition object.
"""
-vals, tree_def = jax.tree_flatten(tree)
+vals, tree_def = jax.tree.flatten(tree)

# "Fake" token tree that is use to track jax internal tree traversal and
# adjust our custom tree traversal to be compatible with it.
@@ -112,7 +112,7 @@ def tree_flatten_with_names(tree):


def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True):
"""Like jax.tree_map but with a filter on the leaf path name.
"""Like jax.tree.map but with a filter on the leaf path name.
Args:
f: The function to be applied to each parameter in `param_tree`.
@@ -133,8 +133,8 @@ def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True):

def tree_rngs_split(rngs, num_splits=2):
"""Splits a PyTree of PRNGKeys into num_splits PyTrees."""
-rngs = jax.tree_map(lambda rng: jax.random.split(rng, num_splits), rngs)
-slice_rngs = lambda rngs, i: jax.tree_map(lambda rng: rng[i], rngs)
+rngs = jax.tree.map(lambda rng: jax.random.split(rng, num_splits), rngs)
+slice_rngs = lambda rngs, i: jax.tree.map(lambda rng: rng[i], rngs)
return tuple(slice_rngs(rngs, i) for i in range(num_splits))


Original file line number Diff line number Diff line change
@@ -136,7 +136,7 @@ def _tree_flatten_with_names(tree):
Returns:
A list of values with names: [(name, value), ...].
"""
-vals, tree_def = jax.tree_flatten(tree)
+vals, tree_def = jax.tree.flatten(tree)

# "Fake" token tree that is use to track jax internal tree traversal and
# adjust our custom tree traversal to be compatible with it.
2 changes: 1 addition & 1 deletion baselines/diabetic_retinopathy_detection/input_utils.py
@@ -277,7 +277,7 @@ def _prepare(x):
# https://github.com/tensorflow/tensorflow/issues/33254#issuecomment-542379165
return np.asarray(memoryview(x))

-it = (jax.tree_map(_prepare, xs) for xs in it)
+it = (jax.tree.map(_prepare, xs) for xs in it)

if n_prefetch:
it = flax.jax_utils.prefetch_to_device(it, n_prefetch, devices=devices)
Original file line number Diff line number Diff line change
@@ -253,7 +253,7 @@ def init(rng):
params_cpu = init(rng_init)

if jax.process_index() == 0:
-num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

Original file line number Diff line number Diff line change
@@ -255,7 +255,7 @@ def init(rng):
params_cpu = init(rng_init)

if jax.process_index() == 0:
-num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -312,7 +312,7 @@ def loss_fn(params, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None:
-grads, _ = jax.tree_flatten(g)
+grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

@@ -336,7 +336,7 @@ def decay_fn(v, wd):
target=train_utils.tree_map_with_regex(decay_fn, opt.target,
decay_rules))

-params, _ = jax.tree_flatten(opt.target)
+params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(
sum([jnp.vdot(p, p) for p in params]))

12 changes: 6 additions & 6 deletions baselines/diabetic_retinopathy_detection/jax_finetune_sngp.py
@@ -267,7 +267,7 @@ def init(rng):
params_cpu, states_cpu = init(rng_init)

if jax.process_index() == 0:
-num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -359,19 +359,19 @@ def loss_fn(params, states, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None:
-grads, _ = jax.tree_flatten(g)
+grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

# Optionally resize the global gradient to a maximum norm. We found this
# useful in some cases across optimizers, hence it's in the main loop.
if grad_clip_norm is not None:
g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
-g = jax.tree_map(lambda p: g_factor * p, g)
+g = jax.tree.map(lambda p: g_factor * p, g)
opt = opt.apply_gradient(g, learning_rate=lr)
opt = opt.replace(target=weight_decay_fn(opt.target, lr))

-params, _ = jax.tree_flatten(opt.target)
+params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
measurements['reset_covmat'] = reset_covmat

@@ -489,8 +489,8 @@ def loss_fn(params, states, images, labels):
# (`states`). This is ok since `random features` are frozen throughout
# pre-training, and `precision matrix` is a finetuning-specific parameters
# that will be re-learned in the finetuning task.
-opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
-states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
+states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)

# Check whether we want to keep a copy of the current checkpoint.
copy_step = None
2 changes: 1 addition & 1 deletion baselines/diabetic_retinopathy_detection/train_utils.py
@@ -70,7 +70,7 @@ def acc_grad_and_loss(i, l_and_g):
(step_size, labels.shape[1]))
li, gi = loss_and_grad_fn(params, imgs, lbls)
l, g = l_and_g
-return (l + li, jax.tree_map(lambda x, y: x + y, g, gi))
+return (l + li, jax.tree.map(lambda x, y: x + y, g, gi))

l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
return jax.tree_util.tree_map(lambda x: x / accum_steps, (l, g))
4 changes: 2 additions & 2 deletions baselines/diabetic_retinopathy_detection/utils/vit_utils.py
@@ -120,10 +120,10 @@ def acc_grad_and_loss(i, l_s_g):
# Update state and accumulate gradient.
l, s, g = l_s_g
(li, si), gi = loss_and_grad_fn(params, s, imgs, lbls)
-return (l + li, si, jax.tree_map(lambda x, y: x + y, g, gi))
+return (l + li, si, jax.tree.map(lambda x, y: x + y, g, gi))

l, s, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, s, g))
-l, g = jax.tree_map(lambda x: x / accum_steps, (l, g))
+l, g = jax.tree.map(lambda x: x / accum_steps, (l, g))
return (l, s), g
else:
return loss_and_grad_fn(params, states, images, labels)
4 changes: 2 additions & 2 deletions baselines/t5/utils.py
@@ -145,7 +145,7 @@ def _json_compat(value):
json_dict['score'] = _json_compat(predictions)
# This is a section where we deviate from the original function.
if aux_values:
-json_dict['intermediates'] = jax.tree_map(_json_compat, aux_values)
+json_dict['intermediates'] = jax.tree.map(_json_compat, aux_values)
elif mode == 'predict_batch_with_aux':
assert vocabulary is not None
# This is a section where we deviate from the original function.
@@ -156,7 +156,7 @@ def _json_compat(value):
predict_dict = {k: _json_compat(v) for k, v in predict_dict.items()}
json_dict.update(predict_dict)

-json_dict['aux'] = jax.tree_map(_json_compat, aux_values)
+json_dict['aux'] = jax.tree.map(_json_compat, aux_values)
else:
raise ValueError(f'Invalid mode: {mode}')

2 changes: 1 addition & 1 deletion experimental/multimodal/checkpoint_utils.py
@@ -197,7 +197,7 @@ def _tree_flatten_with_names(tree):
Returns:
A list of values with names: [(name, value), ...].
"""
-vals, tree_def = jax.tree_flatten(tree)
+vals, tree_def = jax.tree.flatten(tree)

# "Fake" token tree that is use to track jax internal tree traversal and
# adjust our custom tree traversal to be compatible with it.
8 changes: 4 additions & 4 deletions experimental/multimodal/deterministic.py
@@ -254,7 +254,7 @@ def init(rng):
params_cpu, states_cpu = init(rng_init)

if jax.process_index() == 0:
-num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -385,7 +385,7 @@ def loss_fn(params, states, images, texts):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-grads, _ = jax.tree_flatten(g)
+grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

@@ -398,7 +398,7 @@

opt = opt.replace(target=weight_decay_fn(opt.target, lr))

-params, _ = jax.tree_flatten(opt.target)
+params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

top1_idx = jnp.argmax(logits, axis=1)
@@ -504,7 +504,7 @@ def loss_fn(params, states, images, texts):
# alive while they'll be updated in a future step, creating hard to debug
# memory errors (see b/160593526). Also, takes device 0's params only.
opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl)
-states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)

# Check whether we want to keep a copy of the current checkpoint.
copy_step = None
2 changes: 1 addition & 1 deletion experimental/multimodal/input_utils.py
@@ -358,7 +358,7 @@ def _prepare(x):
# https://github.com/tensorflow/tensorflow/issues/33254#issuecomment-542379165
return np.asarray(memoryview(x))

-it = (jax.tree_map(_prepare, xs) for xs in it)
+it = (jax.tree.map(_prepare, xs) for xs in it)

if n_prefetch:
it = flax.jax_utils.prefetch_to_device(it, n_prefetch, devices=devices)
6 changes: 3 additions & 3 deletions experimental/near_ood/vit/deterministic.py
@@ -264,7 +264,7 @@ def init(rng):
params_cpu = init(rng_init)

if jax.process_index() == 0:
-num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
parameter_overview.log_parameter_overview(params_cpu)
writer.write_scalars(step=0, scalars={'num_params': num_params})

@@ -383,7 +383,7 @@ def loss_fn(params, images, labels):
# Log the gradient norm only if we need to compute it anyways (clipping)
# or if we don't use grad_accum_steps, as they interact badly.
if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-grads, _ = jax.tree_flatten(g)
+grads, _ = jax.tree.flatten(g)
l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
measurements['l2_grads'] = l2_g

@@ -396,7 +396,7 @@

opt = opt.replace(target=weight_decay_fn(opt.target, lr))

-params, _ = jax.tree_flatten(opt.target)
+params, _ = jax.tree.flatten(opt.target)
measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

return opt, l, rng, measurements
4 changes: 2 additions & 2 deletions experimental/robust_segvit/custom_segmentation_trainer.py
@@ -883,9 +883,9 @@ def train(

train_summary = train_utils.log_train_summary(
step=step,
-train_metrics=jax.tree_map(train_utils.unreplicate_and_get,
+train_metrics=jax.tree.map(train_utils.unreplicate_and_get,
train_metrics),
-extra_training_logs=jax.tree_map(train_utils.unreplicate_and_get,
+extra_training_logs=jax.tree.map(train_utils.unreplicate_and_get,
extra_training_logs),
writer=writer)

Original file line number Diff line number Diff line change
@@ -112,7 +112,7 @@ def train_and_evaluation(self, model, train_state, fake_batches, metrics_fn):
metrics = train_utils.unreplicate_and_get(metrics)
eval_metrics.append(metrics)
eval_metrics = train_utils.stack_forest(eval_metrics)
-eval_summary = jax.tree_map(lambda x: x.sum(), eval_metrics)
+eval_summary = jax.tree.map(lambda x: x.sum(), eval_metrics)
for key, val in eval_summary.items():
eval_summary[key] = val[0] / val[1]
return eval_summary
Original file line number Diff line number Diff line change
@@ -163,7 +163,7 @@
" d['image'] = normalize(preprocess(d['image']))\n",
" return d\n",
" def _prepare(d):\n",
" return jax.tree_map(lambda x: x._numpy(), d)\n",
" return jax.tree.map(lambda x: x._numpy(), d)\n",
" batched_dataset = ds.map(_preprocess).batch(batch_size)\n",
" batched_dataset = map(_prepare, batched_dataset)\n",
" return batched_dataset"
2 changes: 1 addition & 1 deletion uncertainty_baselines/models/vit_gp_test.py
@@ -64,7 +64,7 @@ def test_vision_transformer(self, classifier, representation_size,
key = jax.random.PRNGKey(0)
variables = model.init(key, inputs, train=False)

-param_count = sum(p.size for p in jax.tree_flatten(variables)[0])
+param_count = sum(p.size for p in jax.tree.flatten(variables)[0])
self.assertEqual(param_count, expected_param_count)

logits, outputs = model.apply(variables, inputs, train=False)
2 changes: 1 addition & 1 deletion uncertainty_baselines/models/vit_test.py
@@ -54,7 +54,7 @@ def test_vision_transformer(self, classifier, representation_size,
key = jax.random.PRNGKey(0)
variables = model.init(key, inputs, train=False)

-param_count = sum(p.size for p in jax.tree_flatten(variables)[0])
+param_count = sum(p.size for p in jax.tree.flatten(variables)[0])
self.assertEqual(param_count, expected_param_count)

logits, outputs = model.apply(variables, inputs, train=False)
2 changes: 1 addition & 1 deletion uncertainty_baselines/models/vit_tram_test.py
@@ -79,7 +79,7 @@ def test_vision_transformer_tram(self, classifier, representation_size,
key = jax.random.PRNGKey(0)
variables = model.init(key, inputs, pi_inputs, train=False)

-param_count = sum(p.size for p in jax.tree_flatten(variables)[0])
+param_count = sum(p.size for p in jax.tree.flatten(variables)[0])
self.assertEqual(param_count, expected_param_count)

logits, outputs = model.apply(variables, inputs, pi_inputs, train=False)

