From 2d10e875164d6b2682ccfae8d0876d8da991c95e Mon Sep 17 00:00:00 2001 From: Cemlyn Date: Thu, 13 Oct 2022 21:59:33 +0100 Subject: [PATCH] fix: fixed type hinting for `extra_loss_fns` in mappo (#234) --- brax/experimental/composer/training/mappo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brax/experimental/composer/training/mappo.py b/brax/experimental/composer/training/mappo.py index 1fd65e84f..1e5b2bf45 100644 --- a/brax/experimental/composer/training/mappo.py +++ b/brax/experimental/composer/training/mappo.py @@ -83,7 +83,7 @@ def compute_ppo_loss( lambda_: float = 0.95, ppo_epsilon: float = 0.3, extra_loss_update_ratios: Optional[Dict[str, float]] = None, - extra_loss_fns: Optional[Dict[str, Callable[[ppo.StepData], + extra_loss_fns: Optional[Dict[str, Callable[[StepData], jnp.ndarray]]] = None, action_shapes: Dict[str, Dict[str, Any]] = None, agent_name: str = None,