From 89e63c89a9372d33b1199380ce761b45b410e895 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Feb 2024 00:55:23 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- setup_training.py | 3 ++- train.py | 9 +++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/setup_training.py b/setup_training.py index e2b1bef..e832f53 100644 --- a/setup_training.py +++ b/setup_training.py @@ -1435,7 +1435,7 @@ def models(args, gen_only=False): from mpgan import Graph_GAN G = Graph_GAN(gen=False, args=deepcopy(args)) - + logging.info(f"# of parameters in D: {count_parameters(D)}") if args.load_model: @@ -1531,6 +1531,7 @@ def get_model_args(args): return model_train_args, model_eval_args, extra_args + def optimizers(args, G, D): if args.spectral_norm_gen: G_params = filter(lambda p: p.requires_grad, G.parameters()) diff --git a/train.py b/train.py index 8c4d6af..f80b0a0 100644 --- a/train.py +++ b/train.py @@ -25,6 +25,7 @@ import time + def main(): start_time = time.time() device = "cuda" if torch.cuda.is_available() else "cpu" @@ -93,7 +94,7 @@ def main(): losses, best_epoch = setup_training.losses(args) loss_calc_end_time = time.time() logging.info("Loss calculation took: %s seconds" % (loss_calc_end_time - loss_calc_start_time)) - + train_start_time = time.time() train( args, @@ -116,6 +117,7 @@ def main(): end_time = time.time() logging.info("Total execution time: %s seconds" % (end_time - start_time)) + def get_gen_noise( model_args, num_samples: int, @@ -507,6 +509,7 @@ def calc_G_loss(loss, fake_outputs): time_calc_G_loss = time.time() - start_time return G_loss, time_calc_G_loss + def train_G( model_args, D, @@ -791,7 +794,7 @@ def eval_save_plot( else: gen_mask = None real_mask = None - + gen_jets = gen_jets.numpy() if gen_mask is not None: gen_mask = gen_mask.numpy() @@ -877,6 +880,7 @@ def eval_save_plot( logging.info("Total save_lowest took: %s seconds" % (end_save_lowest - start_save_lowest)) logging.info("Total eval_save_plot took: %s seconds" % (total_end_time - start_time)) + def train_loop( args, X_train_loaded, @@ -976,6 +980,7 @@ def train_loop( logging.info("Total calc_G_loss took: %s seconds" % (total_calc_G_loss)) logging.info("Total train_loop took: %s seconds" % (total_end_time - start_time)) + def train( args, X_train,