diff --git a/dopamine/jax/losses.py b/dopamine/jax/losses.py
index 736107cd..67c2dc72 100644
--- a/dopamine/jax/losses.py
+++ b/dopamine/jax/losses.py
@@ -17,9 +17,9 @@
 import jax.numpy as jnp
 
 
-def huber_loss(targets: jnp.array,
-               predictions: jnp.array,
-               delta: float = 1.0) -> jnp.ndarray:
+def huber_loss(
+    targets: jnp.ndarray, predictions: jnp.ndarray, delta: float = 1.0
+) -> jnp.ndarray:
   """Implementation of the Huber loss with threshold delta.
 
   Let `x = |targets - predictions|`, the Huber loss is defined as:
@@ -40,12 +40,13 @@ def huber_loss(targets: jnp.array,
                    0.5 * delta**2 + delta * (x - delta))
 
 
-def mse_loss(targets: jnp.array, predictions: jnp.array) -> jnp.ndarray:
+def mse_loss(targets: jnp.ndarray, predictions: jnp.ndarray) -> jnp.ndarray:
   """Implementation of the mean squared error loss."""
   return jnp.power((targets - predictions), 2)
 
 
-def softmax_cross_entropy_loss_with_logits(labels: jnp.array,
-                                           logits: jnp.array) -> jnp.ndarray:
+def softmax_cross_entropy_loss_with_logits(
+    labels: jnp.ndarray, logits: jnp.ndarray
+) -> jnp.ndarray:
   """Implementation of the softmax cross entropy loss."""
   return -jnp.sum(labels * nn.log_softmax(logits))