Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 14, 2024
1 parent d25ea25 commit 89e63c8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
3 changes: 2 additions & 1 deletion setup_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,7 +1435,7 @@ def models(args, gen_only=False):
from mpgan import Graph_GAN

G = Graph_GAN(gen=False, args=deepcopy(args))

logging.info(f"# of parameters in D: {count_parameters(D)}")

if args.load_model:
Expand Down Expand Up @@ -1531,6 +1531,7 @@ def get_model_args(args):

return model_train_args, model_eval_args, extra_args


def optimizers(args, G, D):
if args.spectral_norm_gen:
G_params = filter(lambda p: p.requires_grad, G.parameters())
Expand Down
9 changes: 7 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import time


def main():
start_time = time.time()
device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -93,7 +94,7 @@ def main():
losses, best_epoch = setup_training.losses(args)
loss_calc_end_time = time.time()
logging.info("Loss calculation took: %s seconds" % (loss_calc_end_time - loss_calc_start_time))

train_start_time = time.time()
train(
args,
Expand All @@ -116,6 +117,7 @@ def main():
end_time = time.time()
logging.info("Total execution time: %s seconds" % (end_time - start_time))


def get_gen_noise(
model_args,
num_samples: int,
Expand Down Expand Up @@ -507,6 +509,7 @@ def calc_G_loss(loss, fake_outputs):
time_calc_G_loss = time.time() - start_time
return G_loss, time_calc_G_loss


def train_G(
model_args,
D,
Expand Down Expand Up @@ -791,7 +794,7 @@ def eval_save_plot(
else:
gen_mask = None
real_mask = None

gen_jets = gen_jets.numpy()
if gen_mask is not None:
gen_mask = gen_mask.numpy()
Expand Down Expand Up @@ -877,6 +880,7 @@ def eval_save_plot(
logging.info("Total save_lowest took: %s seconds" % (end_save_lowest - start_save_lowest))
logging.info("Total eval_save_plot took: %s seconds" % (total_end_time - start_time))


def train_loop(
args,
X_train_loaded,
Expand Down Expand Up @@ -976,6 +980,7 @@ def train_loop(
logging.info("Total calc_G_loss took: %s seconds" % (total_calc_G_loss))
logging.info("Total train_loop took: %s seconds" % (total_end_time - start_time))


def train(
args,
X_train,
Expand Down

0 comments on commit 89e63c8

Please sign in to comment.