Skip to content

Commit

Permalink
models and pred edits
Browse files Browse the repository at this point in the history
  • Loading branch information
ilibarra committed Jun 30, 2024
1 parent 4495680 commit e90167c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
4 changes: 3 additions & 1 deletion mubind/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,8 @@ def optimize_simple(self,
# print('use_tqdm', use_tqdm)

for epoch in tqdm(range(num_epochs)) if use_tqdm else range(num_epochs):
# print('train')
self.train()
# for epoch in range(num_epochs):
running_loss = 0
running_loss_sym_weights = 0
Expand Down Expand Up @@ -1316,7 +1318,7 @@ def vprint(*args, **kwargs):
key=lambda x: x[-1],
)
if verbose != 0:
print("sorted")
print("filter rearrangments (sorted by observed r2)")
best_df = pd.DataFrame(best, columns=["expand.left", "expand.right", "shift", "model",
'pos_w_sum', 'width', "loss_diff_pct", "loss", 'r2'],
)
Expand Down
7 changes: 6 additions & 1 deletion mubind/tl/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,13 @@ def optimize_simple(

# the total number of trials
n_trials = sum([d.dataset.rounds.shape[0] for d in (dataloader if isinstance(dataloader, list) else [dataloader])])

for epoch in range(num_epochs):

# declare train statement
# model.train()
# assert False

running_loss = 0
running_crit = 0
running_rec = 0
Expand Down

0 comments on commit e90167c

Please sign in to comment.