Commit b5fed9a
[JAX] Replace uses of jnp.array in types with jnp.ndarray.
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation.

Presumably the intent was to write `jnp.ndarray`, a.k.a. `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`.

PiperOrigin-RevId: 557089187
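
To make the commit's point concrete, here is a minimal, self-contained sketch (not part of the change; the `scale` function is hypothetical) contrasting the function `jnp.array` with the type `jnp.ndarray`:

```python
# Sketch only: jnp.array constructs arrays; jnp.ndarray (a.k.a. jax.Array,
# per the commit message) is the type that belongs in annotations.
import jax.numpy as jnp

assert callable(jnp.array)                     # a function, not a type
assert isinstance(jnp.arange(3), jnp.ndarray)  # the actual array type

# Hypothetical function using the corrected annotation style.
def scale(x: jnp.ndarray, factor: float = 2.0) -> jnp.ndarray:
  return factor * x

print(scale(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```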
Dopamine Team authored and psc-g committed Nov 27, 2023
1 parent 76990f6 · commit b5fed9a
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions dopamine/jax/losses.py
@@ -17,9 +17,9 @@
 import jax.numpy as jnp
 
 
-def huber_loss(targets: jnp.array,
-               predictions: jnp.array,
-               delta: float = 1.0) -> jnp.ndarray:
+def huber_loss(
+    targets: jnp.ndarray, predictions: jnp.ndarray, delta: float = 1.0
+) -> jnp.ndarray:
   """Implementation of the Huber loss with threshold delta.
 
   Let `x = |targets - predictions|`, the Huber loss is defined as:
@@ -40,12 +40,13 @@ def huber_loss(targets: jnp.array,
                    0.5 * delta**2 + delta * (x - delta))
 
 
-def mse_loss(targets: jnp.array, predictions: jnp.array) -> jnp.ndarray:
+def mse_loss(targets: jnp.ndarray, predictions: jnp.ndarray) -> jnp.ndarray:
   """Implementation of the mean squared error loss."""
   return jnp.power((targets - predictions), 2)
 
 
-def softmax_cross_entropy_loss_with_logits(labels: jnp.array,
-                                           logits: jnp.array) -> jnp.ndarray:
+def softmax_cross_entropy_loss_with_logits(
+    labels: jnp.ndarray, logits: jnp.ndarray
+) -> jnp.ndarray:
   """Implementation of the softmax cross entropy loss."""
   return -jnp.sum(labels * nn.log_softmax(logits))

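For reference, a standalone sketch of the patched `huber_loss` with a small worked example; the line `x = jnp.abs(targets - predictions)` falls between the two hunks and is not shown in the diff above, so the body here is an assumption reconstructed from the visible `return` expression:

```python
import jax.numpy as jnp

def huber_loss(
    targets: jnp.ndarray, predictions: jnp.ndarray, delta: float = 1.0
) -> jnp.ndarray:
  """Quadratic for |error| <= delta, linear beyond it."""
  x = jnp.abs(targets - predictions)  # assumed elided line
  return jnp.where(x <= delta,
                   0.5 * x**2,
                   0.5 * delta**2 + delta * (x - delta))

targets = jnp.array([0.0, 0.0, 0.0])
predictions = jnp.array([0.5, 1.0, 3.0])
# Elementwise: 0.5*0.5**2 = 0.125; 0.5*1.0**2 = 0.5; 0.5 + 1.0*(3.0-1.0) = 2.5
print(huber_loss(targets, predictions))  # [0.125 0.5   2.5  ]
```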