Skip to content

Commit

Permalink
Merge branch 'master' into mhof_dev
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun authored Dec 10, 2024
2 parents ae907b2 + 163ca91 commit 924f4cc
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 4 deletions.
12 changes: 10 additions & 2 deletions domainlab/algos/trainers/a_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
from torch import optim
from torch.optim import lr_scheduler

from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler

Expand All @@ -27,7 +28,12 @@ def mk_opt(model, aconf):
# {'params': model._decoratee.parameters()}
# ], lr=aconf.lr)
optimizer = optim.Adam(list_par, lr=aconf.lr)
return optimizer
if aconf.lr_scheduler is not None:
class_lr_scheduler = getattr(lr_scheduler, aconf.lr_scheduler)
scheduler = class_lr_scheduler(optimizer, T_max=aconf.epos)
else:
scheduler = None
return optimizer, scheduler


class AbstractTrainer(AbstractChainNodeHandler, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -102,6 +108,8 @@ def __init__(self, successor_node=None, extend=None):
self.dict_multiplier = {}
# MIRO
self.input_tensor_shape = None
# LR-scheduler
self.lr_scheduler = None

@property
def model(self):
Expand Down Expand Up @@ -178,7 +186,7 @@ def reset(self):
"""
make a new optimizer to clear internal state
"""
self.optimizer = mk_opt(self.model, self.aconf)
self.optimizer, self.lr_scheduler = mk_opt(self.model, self.aconf)

@abc.abstractmethod
def tr_epoch(self, epoch):
Expand Down
2 changes: 2 additions & 0 deletions domainlab/algos/trainers/train_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
loss, *_ = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
loss.backward()
self.optimizer.step()
if self.lr_scheduler:
self.lr_scheduler.step()
self.after_batch(epoch, ind_batch)
self.counter_batch += 1

Expand Down
7 changes: 7 additions & 0 deletions domainlab/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,13 @@ def mk_parser_main():
help="name of pytorch optimizer",
)

parser.add_argument(
"--lr_scheduler",
type=str,
default=None,
help="name of pytorch learning rate scheduler",
)

parser.add_argument(
"--param_idx",
type=bool,
Expand Down
27 changes: 27 additions & 0 deletions scripts/generate_latex_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
aggregate benchmark csv file to generate latex table
"""
import argparse
import pandas as pd


def gen_latex_table(raw_df, fname="table_perf.tex",
group="method", str_perf="acc"):
"""
aggregate benchmark csv file to generate latex table
"""
df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std", "count"])
latex_table = df_result.to_latex(float_format="%.3f")
str_table = df_result.to_string()
print(str_table)
with open(fname, 'w') as file:
file.write(latex_table)


if __name__ == "__main__":
    # CLI entry point: read a benchmark CSV and emit a LaTeX summary table.
    parser = argparse.ArgumentParser(description="Read a CSV file")
    parser.add_argument("filename", help="Name of the CSV file to read")
    # optional knobs; defaults reproduce the previous hard-coded behavior
    parser.add_argument("--group", default="method",
                        help="column to group results by")
    parser.add_argument("--str_perf", default="acc",
                        help="performance column to aggregate")
    parser.add_argument("--fname", default="table_perf.tex",
                        help="output path for the LaTeX table")
    args = parser.parse_args()

    # skipinitialspace tolerates CSVs with padding after the delimiter
    df = pd.read_csv(args.filename, index_col=False, skipinitialspace=True)
    gen_latex_table(df, fname=args.fname, group=args.group,
                    str_perf=args.str_perf)
5 changes: 3 additions & 2 deletions scripts/sh_genplot.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
mkdir $2
python main_out.py --gen_plots $1 --outp_dir $2
# mkdir $2
sh scripts/merge_csvs.sh $1
python main_out.py --gen_plots merged_data.csv --outp_dir partial_agg_plots
14 changes: 14 additions & 0 deletions tests/test_lr_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

"""
unit and end-end test for lr scheduler
"""
from tests.utils_test import utils_test_algo


def test_lr_scheduler():
    """
    End-to-end smoke test for the --lr_scheduler CLI option: trains an ERM
    model on mnistcolor10 with the CosineAnnealingLR scheduler enabled.
    """
    # NOTE(review): --debug presumably shortens the run for test speed —
    # confirm against utils_test_algo / the main training loop.
    args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \
        --nname=conv_bn_pool_2 --no_dump --lr_scheduler=CosineAnnealingLR"
    utils_test_algo(args)

0 comments on commit 924f4cc

Please sign in to comment.