From 87cbc6888ca685eb818687d3c35d467bd05da5c9 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Dec 2024 12:55:45 +0100 Subject: [PATCH] not allowing naked erm be combined with fbopt --- domainlab/algos/trainers/fbopt_mu_controller.py | 3 +++ tests/test_fbopt.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py index ce53fc0df..272d34908 100644 --- a/domainlab/algos/trainers/fbopt_mu_controller.py +++ b/domainlab/algos/trainers/fbopt_mu_controller.py @@ -45,6 +45,9 @@ def __init__(self, trainer, **kwargs): self.mu_min = trainer.aconf.mu_min self.mu_clip = trainer.aconf.mu_clip + if not kwargs: + raise RuntimeError("feedback scheduler requires **kwargs, the set \ + of multipliers non-empty") self.mmu = kwargs # force initial value of mu self.mmu = {key: self.init_mu for key, val in self.mmu.items()} diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py index c442bf090..84b97f0f9 100644 --- a/tests/test_fbopt.py +++ b/tests/test_fbopt.py @@ -1,6 +1,7 @@ """ unit and end-end test for deep all, mldg """ +import pytest from tests.utils_test import utils_test_algo @@ -27,13 +28,24 @@ def test_diva_fbopt(): args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3" utils_test_algo(args) + def test_erm_fbopt(): """ erm """ args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3" # pylint: disable=line-too-long + with pytest.raises(RuntimeError): + utils_test_algo(args) + + +def test_irm_fbopt(): + """ + irm + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt_irm --nname=alexnet --epos=3" # pylint: disable=line-too-long utils_test_algo(args) + def test_forcesetpoint_fbopt(): """ diva