From 78a9ad42516fae2680e9ef5bb15fefa5af8d7cd8 Mon Sep 17 00:00:00 2001 From: smilesun Date: Tue, 3 Dec 2024 13:11:16 +0100 Subject: [PATCH 01/10] lr-scheduler in trainerBasic --- domainlab/algos/trainers/a_trainer.py | 8 ++++++-- domainlab/algos/trainers/train_basic.py | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index a51c34c14..b8d1de63b 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -5,6 +5,7 @@ import torch from torch import optim +from torch.optim.lr_scheduler import CosineAnnealingLR from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler @@ -16,6 +17,7 @@ def mk_opt(model, aconf): if model._decoratee is None: class_opt = getattr(optim, aconf.opt) optimizer = class_opt(model.parameters(), lr=aconf.lr) + scheduler = CosineAnnealingLR(optimizer, T_max=aconf.epos) else: var1 = model.parameters() var2 = model._decoratee.parameters() @@ -27,7 +29,7 @@ def mk_opt(model, aconf): # {'params': model._decoratee.parameters()} # ], lr=aconf.lr) optimizer = optim.Adam(list_par, lr=aconf.lr) - return optimizer + return optimizer, scheduler class AbstractTrainer(AbstractChainNodeHandler, metaclass=abc.ABCMeta): @@ -94,6 +96,8 @@ def __init__(self, successor_node=None, extend=None): self.list_reg_over_task_ratio = None # MIRO self.input_tensor_shape = None + # LR-scheduler + self.lr_scheduler = None @property def model(self): @@ -168,7 +172,7 @@ def reset(self): """ make a new optimizer to clear internal state """ - self.optimizer = mk_opt(self.model, self.aconf) + self.optimizer, self.lr_scheduler = mk_opt(self.model, self.aconf) @abc.abstractmethod def tr_epoch(self, epoch): diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 179848467..10ac3b06f 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -82,6 
+82,8 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch): loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others) loss.backward() self.optimizer.step() + if self.lr_scheduler: + self.lr_scheduler.step() self.after_batch(epoch, ind_batch) self.counter_batch += 1 From 1d2dc78e4e4d646e74d34f0f287222fd26411ff3 Mon Sep 17 00:00:00 2001 From: smilesun Date: Wed, 4 Dec 2024 17:46:34 +0100 Subject: [PATCH 02/10] scheduler always exists --- domainlab/algos/trainers/a_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index b8d1de63b..2de178320 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -14,10 +14,10 @@ def mk_opt(model, aconf): """ create optimizer """ + scheduler = CosineAnnealingLR(optimizer, T_max=aconf.epos) if model._decoratee is None: class_opt = getattr(optim, aconf.opt) optimizer = class_opt(model.parameters(), lr=aconf.lr) - scheduler = CosineAnnealingLR(optimizer, T_max=aconf.epos) else: var1 = model.parameters() var2 = model._decoratee.parameters() From 41bc0e70e9fdafc7133658359d689fe356e20a55 Mon Sep 17 00:00:00 2001 From: smilesun Date: Wed, 4 Dec 2024 17:55:16 +0100 Subject: [PATCH 03/10] lr scheduler via cmd arguments --- domainlab/algos/trainers/a_trainer.py | 8 ++++++-- domainlab/arg_parser.py | 7 +++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 2de178320..051cc1e6f 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -5,7 +5,7 @@ import torch from torch import optim -from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim import lr_scheduler from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler @@ -14,7 +14,6 @@ def mk_opt(model, aconf): """ create optimizer """ - scheduler = 
CosineAnnealingLR(optimizer, T_max=aconf.epos) if model._decoratee is None: class_opt = getattr(optim, aconf.opt) optimizer = class_opt(model.parameters(), lr=aconf.lr) @@ -29,6 +28,11 @@ def mk_opt(model, aconf): # {'params': model._decoratee.parameters()} # ], lr=aconf.lr) optimizer = optim.Adam(list_par, lr=aconf.lr) + if aconf.lr_scheduler is not None: + class_lr_scheduler = getattr(lr_scheduler, aconf.lr_scheduler) + scheduler = class_lr_scheduler(optimizer, T_max=aconf.epos) + else: + scheduler = None return optimizer, scheduler diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 046810a66..47272d7fb 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -264,6 +264,13 @@ def mk_parser_main(): help="name of pytorch optimizer", ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="CosineAnnealingLR", + help="name of pytorch learning rate scheduler", + ) + parser.add_argument( "--param_idx", type=bool, From de004b5c1334b3060013c66991179e7c415a5711 Mon Sep 17 00:00:00 2001 From: smilesun Date: Wed, 4 Dec 2024 19:35:01 +0100 Subject: [PATCH 04/10] unit test --- tests/test_lr_scheduler.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 tests/test_lr_scheduler.py diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py new file mode 100644 index 000000000..7bb7ab92f --- /dev/null +++ b/tests/test_lr_scheduler.py @@ -0,0 +1,14 @@ + +""" +unit and end-end test for lr scheduler +""" +from tests.utils_test import utils_test_algo + + +def test_lr_scheduler(): + """ + train + """ + args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \ + --nname=conv_bn_pool_2 --no_dump --lr_scheduler=CosineAnnealingLR" + utils_test_algo(args) From 06257726b48178d84188c690b99ab466321c2d4b Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Wed, 4 Dec 2024 19:39:40 +0100 Subject: [PATCH 05/10] Update arg_parser.py --- domainlab/arg_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 47272d7fb..bb7bda2b4 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -267,7 +267,7 @@ def mk_parser_main(): parser.add_argument( "--lr_scheduler", type=str, - default="CosineAnnealingLR", + default=None, help="name of pytorch learning rate scheduler", ) From 51c26d1e724ff9305aab216c33c2a9a26a5c56f5 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Dec 2024 13:37:20 +0100 Subject: [PATCH 06/10] latex table in script --- scripts/generate_latex_table.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 scripts/generate_latex_table.py diff --git a/scripts/generate_latex_table.py b/scripts/generate_latex_table.py new file mode 100644 index 000000000..4b35f4209 --- /dev/null +++ b/scripts/generate_latex_table.py @@ -0,0 +1,25 @@ +""" +aggregate benchmark csv file to generate latex table +""" +import argparse +import pandas as pd + + +def gen_latex_table(raw_df, fname="table_perf.tex", + group="method", str_perf="acc"): + """ + aggregate benchmark csv file to generate latex table + """ + df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std"]) + latex_table = df_result.to_latex(float_format="%.3f") + with open(fname, 'w') as file: + file.write(latex_table) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Read a CSV file") + parser.add_argument("filename", help="Name of the CSV file to read") + args = parser.parse_args() + + df = pd.read_csv(args.filename, index_col=False, skipinitialspace=True) + gen_latex_table(df) From 1b98e5831a75d3b1c160f42a4666c3cee01abf05 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Dec 2024 13:43:24 +0100 Subject: [PATCH 07/10] csv agg mean, std to text table --- scripts/generate_latex_table.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/generate_latex_table.py b/scripts/generate_latex_table.py index 4b35f4209..caa333aab 100644 --- 
a/scripts/generate_latex_table.py +++ b/scripts/generate_latex_table.py @@ -12,6 +12,8 @@ def gen_latex_table(raw_df, fname="table_perf.tex", """ df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std"]) latex_table = df_result.to_latex(float_format="%.3f") + str_table = df_result.to_string() + print(str_table) with open(fname, 'w') as file: file.write(latex_table) From e6ac695ebfc38c32c0b9da0e12f1e1209c206b73 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Mon, 9 Dec 2024 16:55:19 +0100 Subject: [PATCH 08/10] Update generate_latex_table.py --- scripts/generate_latex_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/generate_latex_table.py b/scripts/generate_latex_table.py index caa333aab..ebade0583 100644 --- a/scripts/generate_latex_table.py +++ b/scripts/generate_latex_table.py @@ -10,7 +10,7 @@ def gen_latex_table(raw_df, fname="table_perf.tex", """ aggregate benchmark csv file to generate latex table """ - df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std"]) + df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std", "count"]) latex_table = df_result.to_latex(float_format="%.3f") str_table = df_result.to_string() print(str_table) From b5cc60a76d68c70d7d75e973a44c8c31981aa788 Mon Sep 17 00:00:00 2001 From: smilesun Date: Mon, 9 Dec 2024 17:06:08 +0100 Subject: [PATCH 09/10] . 
--- scripts/sh_genplot.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/sh_genplot.sh b/scripts/sh_genplot.sh index a1b1ecfad..9906e5f2e 100755 --- a/scripts/sh_genplot.sh +++ b/scripts/sh_genplot.sh @@ -1,2 +1,3 @@ mkdir $2 -python main_out.py --gen_plots $1 --outp_dir $2 +merge_csvs.sh +python main_out.py --gen_plots merged_data.csv --outp_dir $2 From 163ca911f4b7803f63404ccf376ed18e77108de7 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Mon, 9 Dec 2024 17:11:18 +0100 Subject: [PATCH 10/10] Update sh_genplot.sh --- scripts/sh_genplot.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/sh_genplot.sh b/scripts/sh_genplot.sh index 9906e5f2e..39f71f16d 100755 --- a/scripts/sh_genplot.sh +++ b/scripts/sh_genplot.sh @@ -1,3 +1,3 @@ -mkdir $2 -merge_csvs.sh -python main_out.py --gen_plots merged_data.csv --outp_dir $2 +# mkdir $2 +sh scripts/merge_csvs.sh $1 +python main_out.py --gen_plots merged_data.csv --outp_dir partial_agg_plots \ No newline at end of file