From 829eab90d9299b9cdb619d9a9afebe5abdeeb5b0 Mon Sep 17 00:00:00 2001 From: Dom Date: Tue, 18 Jul 2023 16:37:55 +0100 Subject: [PATCH] Allow additional attributes in trial result --- .../hydra_optuna_sweeper/_impl.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py index d17bbb1f70a..3a427e38464 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py @@ -352,23 +352,33 @@ def sweep(self, arguments: List[str]) -> None: self.job_idx += len(returns) failures = [] for trial, ret in zip(trials, returns): - values: Optional[List[float]] = None + return_value = ret.return_value + if isinstance(return_value, dict): + if "result" not in return_value: + raise KeyError("'result' key must be present in the return_value dictionary.") + result = return_value.pop("result") + user_attrs = return_value + for attr, attr_value in user_attrs.items(): + trial.set_user_attr(attr, attr_value) + else: + result = return_value state: optuna.trial.TrialState = optuna.trial.TrialState.COMPLETE + values: Optional[List[float]] = None try: if len(directions) == 1: try: - values = [float(ret.return_value)] + values = [float(result)] except (ValueError, TypeError): raise ValueError( - f"Return value must be float-castable. Got '{ret.return_value}'." + f"Return value must be float-castable. Got '{result}'." ).with_traceback(sys.exc_info()[2]) else: try: - values = [float(v) for v in ret.return_value] + values = [float(v) for v in result] except (ValueError, TypeError): raise ValueError( "Return value must be a list or tuple of float-castable values." - f" Got '{ret.return_value}'." + f" Got '{result}'." ).with_traceback(sys.exc_info()[2]) if len(values) != len(directions): raise ValueError(