diff --git a/README.md b/README.md
index 746a4cc64..9003d4ebb 100644
--- a/README.md
+++ b/README.md
@@ -131,7 +131,7 @@ copy compared to `tensor._numpy()`.
 ```python
 for batch in iter(ds):
-  train_step(jax.tree_map(lambda y: y._numpy(), batch))
+  train_step(jax.tree.map(lambda y: y._numpy(), batch))
 ```

 ### Models
diff --git a/baselines/diabetic_retinopathy_detection/batchensemble_utils.py b/baselines/diabetic_retinopathy_detection/batchensemble_utils.py
index 4d891775e..fc4f74200 100644
--- a/baselines/diabetic_retinopathy_detection/batchensemble_utils.py
+++ b/baselines/diabetic_retinopathy_detection/batchensemble_utils.py
@@ -60,10 +60,10 @@ def log_average_sigmoid_probs(logits: jnp.ndarray) -> jnp.ndarray:

 def tree_clip_norm_global_pmax(tree, max_norm, axis_name):
   """Global norm clipping, with pmax of global norm before clipping."""
-  global_norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree_leaves(tree)))
+  global_norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree.leaves(tree)))
   global_norm = jax.lax.pmax(global_norm, axis_name=axis_name)
   factor = jnp.minimum(1.0, max_norm / global_norm)
-  return jax.tree_map(lambda x: factor * x, tree), global_norm
+  return jax.tree.map(lambda x: factor * x, tree), global_norm


 def _traverse_with_names(tree):
@@ -94,7 +94,7 @@ def tree_flatten_with_names(tree):
     A list of values with names: [(name, value), ...].
     A PyTreeDef tree definition object.
   """
-  vals, tree_def = jax.tree_flatten(tree)
+  vals, tree_def = jax.tree.flatten(tree)

   # "Fake" token tree that is use to track jax internal tree traversal and
   # adjust our custom tree traversal to be compatible with it.
@@ -112,7 +112,7 @@ def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True):
-  """Like jax.tree_map but with a filter on the leaf path name.
+  """Like jax.tree.map but with a filter on the leaf path name.

   Args:
     f: The function to be applied to each parameter in `param_tree`.
@@ -133,8 +133,8 @@ def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True):


 def tree_rngs_split(rngs, num_splits=2):
   """Splits a PyTree of PRNGKeys into num_splits PyTrees."""
-  rngs = jax.tree_map(lambda rng: jax.random.split(rng, num_splits), rngs)
-  slice_rngs = lambda rngs, i: jax.tree_map(lambda rng: rng[i], rngs)
+  rngs = jax.tree.map(lambda rng: jax.random.split(rng, num_splits), rngs)
+  slice_rngs = lambda rngs, i: jax.tree.map(lambda rng: rng[i], rngs)
   return tuple(slice_rngs(rngs, i) for i in range(num_splits))

diff --git a/baselines/diabetic_retinopathy_detection/checkpoint_utils.py b/baselines/diabetic_retinopathy_detection/checkpoint_utils.py
index 46a9bf2f3..cebff9e0e 100644
--- a/baselines/diabetic_retinopathy_detection/checkpoint_utils.py
+++ b/baselines/diabetic_retinopathy_detection/checkpoint_utils.py
@@ -136,7 +136,7 @@ def _tree_flatten_with_names(tree):
   Returns:
     A list of values with names: [(name, value), ...].
   """
-  vals, tree_def = jax.tree_flatten(tree)
+  vals, tree_def = jax.tree.flatten(tree)

   # "Fake" token tree that is use to track jax internal tree traversal and
   # adjust our custom tree traversal to be compatible with it.
diff --git a/baselines/diabetic_retinopathy_detection/input_utils.py b/baselines/diabetic_retinopathy_detection/input_utils.py
index b18edf0bf..bcf2e717c 100644
--- a/baselines/diabetic_retinopathy_detection/input_utils.py
+++ b/baselines/diabetic_retinopathy_detection/input_utils.py
@@ -277,7 +277,7 @@ def _prepare(x):
     # https://github.com/tensorflow/tensorflow/issues/33254#issuecomment-542379165
     return np.asarray(memoryview(x))

-  it = (jax.tree_map(_prepare, xs) for xs in it)
+  it = (jax.tree.map(_prepare, xs) for xs in it)

   if n_prefetch:
     it = flax.jax_utils.prefetch_to_device(it, n_prefetch, devices=devices)
diff --git a/baselines/diabetic_retinopathy_detection/jax_finetune_batchensemble.py b/baselines/diabetic_retinopathy_detection/jax_finetune_batchensemble.py
index c81546d5b..33da259c4 100644
--- a/baselines/diabetic_retinopathy_detection/jax_finetune_batchensemble.py
+++ b/baselines/diabetic_retinopathy_detection/jax_finetune_batchensemble.py
@@ -253,7 +253,7 @@ def init(rng):

   params_cpu = init(rng_init)

   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
diff --git a/baselines/diabetic_retinopathy_detection/jax_finetune_deterministic.py b/baselines/diabetic_retinopathy_detection/jax_finetune_deterministic.py
index 9fa236443..0cd08e9af 100644
--- a/baselines/diabetic_retinopathy_detection/jax_finetune_deterministic.py
+++ b/baselines/diabetic_retinopathy_detection/jax_finetune_deterministic.py
@@ -255,7 +255,7 @@ def init(rng):

   params_cpu = init(rng_init)

   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
@@ -312,7 +312,7 @@ def loss_fn(params, images, labels):

     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None:
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
@@ -336,7 +336,7 @@ def decay_fn(v, wd):
           target=train_utils.tree_map_with_regex(decay_fn, opt.target,
                                                   decay_rules))

-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(
         sum([jnp.vdot(p, p) for p in params]))
diff --git a/baselines/diabetic_retinopathy_detection/jax_finetune_sngp.py b/baselines/diabetic_retinopathy_detection/jax_finetune_sngp.py
index 02f9dce46..d7472b58c 100644
--- a/baselines/diabetic_retinopathy_detection/jax_finetune_sngp.py
+++ b/baselines/diabetic_retinopathy_detection/jax_finetune_sngp.py
@@ -267,7 +267,7 @@ def init(rng):

   params_cpu, states_cpu = init(rng_init)

   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
@@ -359,7 +359,7 @@ def loss_fn(params, states, images, labels):

     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None:
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
@@ -367,11 +367,11 @@ def loss_fn(params, states, images, labels):
     # useful in some cases across optimizers, hence it's in the main loop.
     if grad_clip_norm is not None:
       g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
-      g = jax.tree_map(lambda p: g_factor * p, g)
+      g = jax.tree.map(lambda p: g_factor * p, g)

     opt = opt.apply_gradient(g, learning_rate=lr)
     opt = opt.replace(target=weight_decay_fn(opt.target, lr))

-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
     measurements['reset_covmat'] = reset_covmat
@@ -489,8 +489,8 @@ def loss_fn(params, states, images, labels):
       # (`states`). This is ok since `random features` are frozen throughout
       # pre-training, and `precision matrix` is a finetuning-specific parameters
      # that will be re-learned in the finetuning task.
-      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
-      states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+      opt_cpu = jax.tree.map(lambda x: np.array(x[0]), opt_repl)
+      states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
diff --git a/baselines/diabetic_retinopathy_detection/train_utils.py b/baselines/diabetic_retinopathy_detection/train_utils.py
index 6023ce7ff..494b8748f 100644
--- a/baselines/diabetic_retinopathy_detection/train_utils.py
+++ b/baselines/diabetic_retinopathy_detection/train_utils.py
@@ -70,7 +70,7 @@ def acc_grad_and_loss(i, l_and_g):
                                    (step_size, labels.shape[1]))
       li, gi = loss_and_grad_fn(params, imgs, lbls)
       l, g = l_and_g
-      return (l + li, jax.tree_map(lambda x, y: x + y, g, gi))
+      return (l + li, jax.tree.map(lambda x, y: x + y, g, gi))

     l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
     return jax.tree_util.tree_map(lambda x: x / accum_steps, (l, g))
diff --git a/baselines/diabetic_retinopathy_detection/utils/vit_utils.py b/baselines/diabetic_retinopathy_detection/utils/vit_utils.py
index e54cb5b60..08b71cfda 100644
--- a/baselines/diabetic_retinopathy_detection/utils/vit_utils.py
+++ b/baselines/diabetic_retinopathy_detection/utils/vit_utils.py
@@ -120,10 +120,10 @@ def acc_grad_and_loss(i, l_s_g):
       # Update state and accumulate gradient.
       l, s, g = l_s_g
       (li, si), gi = loss_and_grad_fn(params, s, imgs, lbls)
-      return (l + li, si, jax.tree_map(lambda x, y: x + y, g, gi))
+      return (l + li, si, jax.tree.map(lambda x, y: x + y, g, gi))

     l, s, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, s, g))
-    l, g = jax.tree_map(lambda x: x / accum_steps, (l, g))
+    l, g = jax.tree.map(lambda x: x / accum_steps, (l, g))
     return (l, s), g
   else:
     return loss_and_grad_fn(params, states, images, labels)
diff --git a/baselines/t5/utils.py b/baselines/t5/utils.py
index f31482c94..e9abde40c 100644
--- a/baselines/t5/utils.py
+++ b/baselines/t5/utils.py
@@ -145,7 +145,7 @@ def _json_compat(value):
       json_dict['score'] = _json_compat(predictions)
       # This is a section where we deviate from the original function.
       if aux_values:
-        json_dict['intermediates'] = jax.tree_map(_json_compat, aux_values)
+        json_dict['intermediates'] = jax.tree.map(_json_compat, aux_values)
     elif mode == 'predict_batch_with_aux':
       assert vocabulary is not None
       # This is a section where we deviate from the original function.
@@ -156,7 +156,7 @@ def _json_compat(value):

       predict_dict = {k: _json_compat(v) for k, v in predict_dict.items()}
       json_dict.update(predict_dict)

-      json_dict['aux'] = jax.tree_map(_json_compat, aux_values)
+      json_dict['aux'] = jax.tree.map(_json_compat, aux_values)
     else:
       raise ValueError(f'Invalid mode: {mode}')
diff --git a/experimental/multimodal/checkpoint_utils.py b/experimental/multimodal/checkpoint_utils.py
index 8407a0dbe..903fa8376 100644
--- a/experimental/multimodal/checkpoint_utils.py
+++ b/experimental/multimodal/checkpoint_utils.py
@@ -197,7 +197,7 @@ def _tree_flatten_with_names(tree):
   Returns:
     A list of values with names: [(name, value), ...].
   """
-  vals, tree_def = jax.tree_flatten(tree)
+  vals, tree_def = jax.tree.flatten(tree)

   # "Fake" token tree that is use to track jax internal tree traversal and
   # adjust our custom tree traversal to be compatible with it.
diff --git a/experimental/multimodal/deterministic.py b/experimental/multimodal/deterministic.py
index 13dcb724d..34ef1b84a 100644
--- a/experimental/multimodal/deterministic.py
+++ b/experimental/multimodal/deterministic.py
@@ -254,7 +254,7 @@ def init(rng):

   params_cpu, states_cpu = init(rng_init)

   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
@@ -385,7 +385,7 @@ def loss_fn(params, states, images, texts):

     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
@@ -398,7 +398,7 @@ def loss_fn(params, states, images, texts):

     opt = opt.replace(target=weight_decay_fn(opt.target, lr))

-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

     top1_idx = jnp.argmax(logits, axis=1)
@@ -504,7 +504,7 @@ def loss_fn(params, states, images, texts):
     # alive while they'll be updated in a future step, creating hard to debug
     # memory errors (see b/160593526). Also, takes device 0's params only.
     opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl)
-    states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
+    states_cpu = jax.tree.map(lambda x: np.array(x[0]), states_repl)

     # Check whether we want to keep a copy of the current checkpoint.
     copy_step = None
diff --git a/experimental/multimodal/input_utils.py b/experimental/multimodal/input_utils.py
index 84b0ef60e..77a6155d6 100644
--- a/experimental/multimodal/input_utils.py
+++ b/experimental/multimodal/input_utils.py
@@ -358,7 +358,7 @@ def _prepare(x):
     # https://github.com/tensorflow/tensorflow/issues/33254#issuecomment-542379165
     return np.asarray(memoryview(x))

-  it = (jax.tree_map(_prepare, xs) for xs in it)
+  it = (jax.tree.map(_prepare, xs) for xs in it)

   if n_prefetch:
     it = flax.jax_utils.prefetch_to_device(it, n_prefetch, devices=devices)
diff --git a/experimental/near_ood/vit/deterministic.py b/experimental/near_ood/vit/deterministic.py
index 614c4232a..19c3212d7 100644
--- a/experimental/near_ood/vit/deterministic.py
+++ b/experimental/near_ood/vit/deterministic.py
@@ -264,7 +264,7 @@ def init(rng):

   params_cpu = init(rng_init)

   if jax.process_index() == 0:
-    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
+    num_params = sum(p.size for p in jax.tree.flatten(params_cpu)[0])
     parameter_overview.log_parameter_overview(params_cpu)
     writer.write_scalars(step=0, scalars={'num_params': num_params})
@@ -383,7 +383,7 @@ def loss_fn(params, images, labels):

     # Log the gradient norm only if we need to compute it anyways (clipping)
     # or if we don't use grad_accum_steps, as they interact badly.
     if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
-      grads, _ = jax.tree_flatten(g)
+      grads, _ = jax.tree.flatten(g)
       l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
       measurements['l2_grads'] = l2_g
@@ -396,7 +396,7 @@ def loss_fn(params, images, labels):

     opt = opt.replace(target=weight_decay_fn(opt.target, lr))

-    params, _ = jax.tree_flatten(opt.target)
+    params, _ = jax.tree.flatten(opt.target)
     measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

     return opt, l, rng, measurements
diff --git a/experimental/robust_segvit/custom_segmentation_trainer.py b/experimental/robust_segvit/custom_segmentation_trainer.py
index ca4294fb0..79597bb6d 100644
--- a/experimental/robust_segvit/custom_segmentation_trainer.py
+++ b/experimental/robust_segvit/custom_segmentation_trainer.py
@@ -883,9 +883,9 @@ def train(

       train_summary = train_utils.log_train_summary(
           step=step,
-          train_metrics=jax.tree_map(train_utils.unreplicate_and_get,
+          train_metrics=jax.tree.map(train_utils.unreplicate_and_get,
                                      train_metrics),
-          extra_training_logs=jax.tree_map(train_utils.unreplicate_and_get,
+          extra_training_logs=jax.tree.map(train_utils.unreplicate_and_get,
                                            extra_training_logs),
           writer=writer)

diff --git a/experimental/robust_segvit/custom_segmentation_trainer_test.py b/experimental/robust_segvit/custom_segmentation_trainer_test.py
index 615f2e2cb..4ad323285 100644
--- a/experimental/robust_segvit/custom_segmentation_trainer_test.py
+++ b/experimental/robust_segvit/custom_segmentation_trainer_test.py
@@ -112,7 +112,7 @@ def train_and_evaluation(self, model, train_state, fake_batches, metrics_fn):
       metrics = train_utils.unreplicate_and_get(metrics)
       eval_metrics.append(metrics)
     eval_metrics = train_utils.stack_forest(eval_metrics)
-    eval_summary = jax.tree_map(lambda x: x.sum(), eval_metrics)
+    eval_summary = jax.tree.map(lambda x: x.sum(), eval_metrics)
     for key, val in eval_summary.items():
       eval_summary[key] = val[0] / val[1]
     return eval_summary
diff --git a/experimental/zero-shot-multi-modal/zero_shot_clip_evaluation.ipynb b/experimental/zero-shot-multi-modal/zero_shot_clip_evaluation.ipynb
index c8a3c24ab..b5f83fe4b 100644
--- a/experimental/zero-shot-multi-modal/zero_shot_clip_evaluation.ipynb
+++ b/experimental/zero-shot-multi-modal/zero_shot_clip_evaluation.ipynb
@@ -163,7 +163,7 @@
     "    d['image'] = normalize(preprocess(d['image']))\n",
     "    return d\n",
     "  def _prepare(d):\n",
-    "    return jax.tree_map(lambda x: x._numpy(), d)\n",
+    "    return jax.tree.map(lambda x: x._numpy(), d)\n",
     "  batched_dataset = ds.map(_preprocess).batch(batch_size)\n",
     "  batched_dataset = map(_prepare, batched_dataset)\n",
     "  return batched_dataset"
diff --git a/uncertainty_baselines/models/vit_gp_test.py b/uncertainty_baselines/models/vit_gp_test.py
index 981c3b281..397eb7df1 100644
--- a/uncertainty_baselines/models/vit_gp_test.py
+++ b/uncertainty_baselines/models/vit_gp_test.py
@@ -64,7 +64,7 @@ def test_vision_transformer(self, classifier, representation_size,
     key = jax.random.PRNGKey(0)
     variables = model.init(key, inputs, train=False)

-    param_count = sum(p.size for p in jax.tree_flatten(variables)[0])
+    param_count = sum(p.size for p in jax.tree.flatten(variables)[0])
     self.assertEqual(param_count, expected_param_count)

     logits, outputs = model.apply(variables, inputs, train=False)
diff --git a/uncertainty_baselines/models/vit_test.py b/uncertainty_baselines/models/vit_test.py
index 20d1b76b0..832674d3a 100644
--- a/uncertainty_baselines/models/vit_test.py
+++ b/uncertainty_baselines/models/vit_test.py
@@ -54,7 +54,7 @@ def test_vision_transformer(self, classifier, representation_size,
     key = jax.random.PRNGKey(0)
     variables = model.init(key, inputs, train=False)

-    param_count = sum(p.size for p in jax.tree_flatten(variables)[0])
+    param_count = sum(p.size for p in jax.tree.flatten(variables)[0])
     self.assertEqual(param_count, expected_param_count)

     logits, outputs = model.apply(variables, inputs, train=False)
diff --git a/uncertainty_baselines/models/vit_tram_test.py b/uncertainty_baselines/models/vit_tram_test.py
index af5acacb6..c9a95b7b3 100644
--- a/uncertainty_baselines/models/vit_tram_test.py
+++ b/uncertainty_baselines/models/vit_tram_test.py
@@ -79,7 +79,7 @@ def test_vision_transformer_tram(self, classifier, representation_size,
     key = jax.random.PRNGKey(0)
     variables = model.init(key, inputs, pi_inputs, train=False)

-    param_count = sum(p.size for p in jax.tree_flatten(variables)[0])
+    param_count = sum(p.size for p in jax.tree.flatten(variables)[0])
     self.assertEqual(param_count, expected_param_count)

     logits, outputs = model.apply(variables, inputs, pi_inputs, train=False)
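For reference, every hunk above applies the same mechanical substitution: the deprecated top-level pytree helpers (`jax.tree_map`, `jax.tree_flatten`, `jax.tree_leaves`) are replaced by their counterparts under the `jax.tree` namespace. A minimal sketch of the correspondence, assuming a JAX release recent enough to expose `jax.tree.map`, `jax.tree.flatten`, and `jax.tree.leaves`; the `params` pytree below is purely illustrative:

```python
import jax
import jax.numpy as jnp

# Illustrative pytree; any nested dict/list/tuple of arrays works.
params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros((3,))}

# Deprecated spelling            ->  jax.tree namespace equivalent
#   jax.tree_map(f, tree)        ->  jax.tree.map(f, tree)
#   jax.tree_flatten(tree)       ->  jax.tree.flatten(tree)
#   jax.tree_leaves(tree)        ->  jax.tree.leaves(tree)
scaled = jax.tree.map(lambda x: 0.5 * x, params)          # map over leaves
leaves, tree_def = jax.tree.flatten(params)                # leaves + treedef
num_params = sum(p.size for p in jax.tree.leaves(params))  # as used above
```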