Skip to content

Commit

Permalink
Fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
asmith26 committed Sep 11, 2020
1 parent de152db commit 360fa5c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax_toolkit/losses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_haiku_loss_function(
def loss_function_wrapper(
params: hk.Params, x: jnp.ndarray, y_true: jnp.ndarray, rng: jnp.ndarray = None
) -> jnp.ndarray:
# rng argument can be used is net_transform.apply() is non-deterministic, and you require and "random seed"
# rng argument can be used if net_transform.apply() is non-deterministic, and you require a "random seed"
y_pred: jnp.ndarray = net_transform.apply(params, rng, x)
loss_value: jnp.ndarray = loss_function(y_true, y_pred)
return loss_value
Expand Down

0 comments on commit 360fa5c

Please sign in to comment.