Skip to content

Commit

Permalink
jax.numpy.clip: update use of deprecated arguments.
Browse files Browse the repository at this point in the history
- a is now positional-only
- a_min is now min
- a_max is now max

The old argument names have been deprecated since JAX v0.4.27.

PiperOrigin-RevId: 714108629
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Jan 14, 2025
1 parent 81a2e2e commit abd1dde
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion baselines/t5/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# Clip the last value caused by the `diff` operator of the padding
# value 0.0 and the last cumulative sum. That value is positive because
# the sum of all log probabilities is negative.
finished_token_scores = jnp.clip(finished_token_scores, a_max=0.)
finished_token_scores = jnp.clip(finished_token_scores, max=0.)

if return_token_scores:
return finished_seqs[:, :, 1:], finished_token_scores
Expand Down

0 comments on commit abd1dde

Please sign in to comment.