From e90167c53ff1e5cca7403c4b582dfb55daae7de3 Mon Sep 17 00:00:00 2001 From: ilibarra Date: Sun, 30 Jun 2024 23:57:51 +0200 Subject: [PATCH] models and pred edits --- mubind/models/models.py | 4 +++- mubind/tl/prediction.py | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mubind/models/models.py b/mubind/models/models.py index 7829898..050acc3 100644 --- a/mubind/models/models.py +++ b/mubind/models/models.py @@ -568,6 +568,8 @@ def optimize_simple(self, # print('use_tqdm', use_tqdm) for epoch in tqdm(range(num_epochs)) if use_tqdm else range(num_epochs): + # print('train') + self.train() # for epoch in range(num_epochs): running_loss = 0 running_loss_sym_weights = 0 @@ -1316,7 +1318,7 @@ def vprint(*args, **kwargs): key=lambda x: x[-1], ) if verbose != 0: - print("sorted") + print("filter rearrangments (sorted by observed r2)") best_df = pd.DataFrame(best, columns=["expand.left", "expand.right", "shift", "model", 'pos_w_sum', 'width', "loss_diff_pct", "loss", 'r2'], ) diff --git a/mubind/tl/prediction.py b/mubind/tl/prediction.py index 65a0460..70fc1db 100644 --- a/mubind/tl/prediction.py +++ b/mubind/tl/prediction.py @@ -137,8 +137,13 @@ def optimize_simple( # the total number of trials n_trials = sum([d.dataset.rounds.shape[0] for d in (dataloader if isinstance(dataloader, list) else [dataloader])]) - + for epoch in range(num_epochs): + + # declare train statement + # model.train() + # assert False + running_loss = 0 running_crit = 0 running_rec = 0