diff --git a/learned_optimization/research/discovered_policy_optimisation/truncated_step.py b/learned_optimization/research/discovered_policy_optimisation/truncated_step.py index 67a4596..1e4f188 100644 --- a/learned_optimization/research/discovered_policy_optimisation/truncated_step.py +++ b/learned_optimization/research/discovered_policy_optimisation/truncated_step.py @@ -236,7 +236,7 @@ def compute_drift_loss( if ppo_init: eps = 0.2 drift_ppo = flax.linen.relu( - (rho_s - jnp.clip(rho_s, a_min=1 - eps, a_max=1 + eps)) * advantages) + (rho_s - jnp.clip(rho_s, min=1 - eps, max=1 + eps)) * advantages) drift = out_fn(drift_ppo + drift - 0.0001) else: drift = out_fn(drift - 0.0001)