checkpointer logging, script fixes post-refactor
pvankatwyk committed Oct 30, 2024
1 parent 69c7b33 commit cd04ddc
Showing 9 changed files with 136 additions and 211 deletions.
74 changes: 45 additions & 29 deletions ise/models/density_estimators/normalizing_flow.py
@@ -87,32 +87,43 @@ def fit(self, X, y, epochs=100, batch_size=64, save_checkpoints=True, checkpoint
checkpointer = CheckpointSaver(self, self.optimizer, checkpoint_path, verbose)
checkpointer.best_loss = best_loss

for epoch in range(start_epoch, epochs + 1):
epoch_loss = []
for i, (x, y) in enumerate(data_loader):
x = x.to(self.device).view(x.shape[0], -1)
y = y.to(self.device)
self.optimizer.zero_grad()
loss = torch.mean(-self.flow.log_prob(inputs=y, context=x))
loss.backward()
self.optimizer.step()
epoch_loss.append(loss.item())
average_epoch_loss = sum(epoch_loss) / len(epoch_loss)

if save_checkpoints:
checkpointer(average_epoch_loss)
if hasattr(checkpointer, "early_stop") and checkpointer.early_stop:
if verbose:
print("Early stopping")
break

if start_epoch < epochs:
for epoch in range(start_epoch, epochs + 1):
epoch_loss = []
for i, (x, y) in enumerate(data_loader):
x = x.to(self.device).view(x.shape[0], -1)
y = y.to(self.device)
self.optimizer.zero_grad()
loss = torch.mean(-self.flow.log_prob(inputs=y, context=x))
loss.backward()
self.optimizer.step()
epoch_loss.append(loss.item())
average_epoch_loss = sum(epoch_loss) / len(epoch_loss)

if save_checkpoints:
checkpointer(average_epoch_loss, epoch)
if hasattr(checkpointer, "early_stop") and checkpointer.early_stop:
if verbose:
print("Early stopping")
break

if verbose:
print(f"[epoch/total]: [{epoch}/{epochs}], loss: {average_epoch_loss}{f' -- {checkpointer.log}' if save_checkpoints else ''}")
else:
if verbose:
print(f"[epoch/total]: [{epoch}/{epochs}], loss: {average_epoch_loss}{f' -- {checkpointer.log}' if early_stopping else ''}")

print(f"Training already completed ({epochs}/{epochs}).")
self.trained = True

if early_stopping:
self.load_state_dict(torch.load(checkpoint_path))
if save_checkpoints:
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint.keys():
self.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.best_loss = checkpoint['best_loss']
self.epochs_trained = checkpoint['epoch']
else:
self.load_state_dict(checkpoint)
os.remove(checkpoint_path)

def sample(self, features, num_samples, return_type="numpy"):
@@ -150,7 +161,9 @@ def save(self, path):
metadata = {
"input_size": self.num_input_features,
"output_size": self.num_predicted_sle,
"device": self.device
"device": self.device,
"best_loss": self.best_loss,
"epochs_trained": self.epochs_trained,
}
metadata_path = path + "_metadata.json"

@@ -177,11 +190,14 @@ def load(path):

checkpoint = torch.load(path, map_location="cpu" if not torch.cuda.is_available() else None)

#
# model.load_state_dict(checkpoint['model_state_dict'])
model.load_state_dict(checkpoint)
# model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# model.trained = checkpoint['trained']
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint.keys():
model.load_state_dict(checkpoint['model_state_dict'])
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
model.trained = checkpoint['trained']
else:
model.load_state_dict(checkpoint)
model.trained = True

model.trained = True
model.to(model.device)
model.eval()
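Note on the checkpoint format: the resume branch in fit() and the updated load() both expect the richer dictionary written by CheckpointSaver, while still accepting older checkpoints saved as a bare state_dict. A minimal sketch of that layout with a hypothetical toy model (the key names follow those read back above):

import torch
import torch.nn as nn

# Hypothetical toy model and optimizer; only the checkpoint keys matter here.
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters())

checkpoint = {
    "epoch": 42,                                     # epoch at which the best loss was observed
    "model_state_dict": model.state_dict(),          # weights of the best model so far
    "optimizer_state_dict": optimizer.state_dict(),  # optimizer state, so training can resume exactly
    "best_loss": 0.0123,                             # lowest loss seen so far
}
torch.save(checkpoint, "checkpoint.pt")

# Loading mirrors the backward-compatible branch added above.
loaded = torch.load("checkpoint.pt", map_location="cpu")
if isinstance(loaded, dict) and "model_state_dict" in loaded:
    model.load_state_dict(loaded["model_state_dict"])
    optimizer.load_state_dict(loaded["optimizer_state_dict"])
else:
    model.load_state_dict(loaded)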
6 changes: 6 additions & 0 deletions ise/models/predictors/deep_ensemble.py
@@ -109,6 +109,8 @@ def save(self, model_path):
"output_size": member.output_size,
"trained": member.trained,
"path": os.path.join("ensemble_members", f"member_{i+1}.pth"),
"best_loss": float(member.best_loss),
"epochs_trained": int(member.epochs_trained),
}
for i, member in enumerate(self.ensemble_members)
],
@@ -129,6 +131,10 @@ def save(self, model_path):
member_path = os.path.join(ensemble_dir, f"member_{i+1}.pth")
torch.save(member.state_dict(), member_path)
print(f"Ensemble Member {i+1} saved to {member_path}")

print('Removing checkpoints after saving to model directory...')
[os.remove(member.checkpoint_path) for member in self.ensemble_members if hasattr(member, "checkpoint_path")]


@classmethod
def load(cls, model_path):
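The cleanup step above removes each member's temporary training checkpoint once the ensemble has been saved, relying on fit() having stored checkpoint_path on the member. A plain-loop sketch of the same idea, with a hypothetical helper name and an extra existence check:

import os

def remove_member_checkpoints(ensemble_members):
    # Delete each member's temporary training checkpoint, if one was recorded during fit().
    for member in ensemble_members:
        path = getattr(member, "checkpoint_path", None)
        if path is not None and os.path.exists(path):
            os.remove(path)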
84 changes: 48 additions & 36 deletions ise/models/predictors/lstm.py
@@ -82,6 +82,7 @@ def fit(
# Check if a checkpoint exists and load it
start_epoch = 1
best_loss = float("inf")
self.checkpoint_path = checkpoint_path
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
self.load_state_dict(checkpoint['model_state_dict'])
@@ -134,47 +135,58 @@ def fit(
checkpointer.best_loss = best_loss

# Training loop
for epoch in range(start_epoch, epochs + 1):
self.train()
batch_losses = []
for i, (x, y) in enumerate(data_loader):
x = x.to(self.device)
y = y.to(self.device)
self.optimizer.zero_grad()
y_pred = self.forward(x)
loss = self.criterion(y_pred, y) # Renamed to 'loss' for clarity
loss.backward()
self.optimizer.step()
batch_losses.append(loss.item())

# Print average batch loss and validation loss (if provided)
if validate:
val_preds = self.predict(
X_val, sequence_length=sequence_length, batch_size=batch_size
).to(self.device)
val_loss = F.mse_loss(val_preds.squeeze(), y_val.squeeze())

if save_checkpoints:
checkpointer(val_loss)

if hasattr(checkpointer, "early_stop") and checkpointer.early_stop:
if verbose:
print("Early stopping")
break

if verbose:
print(f"[epoch/total]: [{epoch}/{epochs}], train loss: {sum(batch_losses) / len(batch_losses)}, val mse: {val_loss:.6f} -- {getattr(checkpointer, 'log', '')}")
else:
average_batch_loss = sum(batch_losses) / len(batch_losses)
if verbose:
print(f"[epoch/total]: [{epoch}/{epochs}], train loss: {average_batch_loss}")
if start_epoch < epochs:
for epoch in range(start_epoch, epochs + 1):
self.train()
batch_losses = []
for i, (x, y) in enumerate(data_loader):
x = x.to(self.device)
y = y.to(self.device)
self.optimizer.zero_grad()
y_pred = self.forward(x)
loss = self.criterion(y_pred, y) # Renamed to 'loss' for clarity
loss.backward()
self.optimizer.step()
batch_losses.append(loss.item())

# Print average batch loss and validation loss (if provided)
if validate:
val_preds = self.predict(
X_val, sequence_length=sequence_length, batch_size=batch_size
).to(self.device)
val_loss = F.mse_loss(val_preds.squeeze(), y_val.squeeze())

if save_checkpoints:
checkpointer(val_loss, epoch)

if hasattr(checkpointer, "early_stop") and checkpointer.early_stop:
if verbose:
print("Early stopping")
break

if verbose:
print(f"[epoch/total]: [{epoch}/{epochs}], train loss: {sum(batch_losses) / len(batch_losses)}, val mse: {val_loss:.6f} -- {getattr(checkpointer, 'log', '')}")
else:
average_batch_loss = sum(batch_losses) / len(batch_losses)
if verbose:
print(f"[epoch/total]: [{epoch}/{epochs}], train loss: {average_batch_loss}")
else:
if verbose:
print(f"Training already completed ({epochs}/{epochs}).")

self.trained = True

# loads best model
if save_checkpoints:
self.load_state_dict(torch.load(checkpoint_path))
os.remove(checkpoint_path)
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint.keys():
self.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.best_loss = checkpoint['best_loss']
self.epochs_trained = checkpoint['epoch']
else:
self.load_state_dict(checkpoint)
# os.remove(checkpoint_path)

def predict(self, X, sequence_length=5, batch_size=64, dataclass=EmulatorDataset):
self.eval()
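Guarding the loop with the start_epoch < epochs check lets fit() skip training entirely when an existing checkpoint already covers the requested epochs. A generic resume sketch (illustrative names, not the exact ise code):

import os
import torch
import torch.nn as nn

def maybe_resume(model, optimizer, checkpoint_path):
    # Restore weights, optimizer state, and best loss from an existing checkpoint
    # so training picks up where it stopped instead of starting over.
    start_epoch, best_loss = 1, float("inf")
    if os.path.exists(checkpoint_path):
        ckpt = torch.load(checkpoint_path)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            model.load_state_dict(ckpt["model_state_dict"])
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            start_epoch = ckpt["epoch"]  # the diff does not show whether a +1 offset is applied
            best_loss = ckpt["best_loss"]
    return start_epoch, best_loss

# Usage sketch with a toy model: skip the loop entirely when nothing is left to train.
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters())
start_epoch, best_loss = maybe_resume(model, optimizer, "checkpoint.pt")
epochs = 100
if start_epoch < epochs:
    pass  # training loop would run here
else:
    print(f"Training already completed ({epochs}/{epochs}).")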
17 changes: 10 additions & 7 deletions ise/utils/training.py
@@ -8,16 +8,19 @@ def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, che
self.optimizer = optimizer
self.best_loss = float('inf')
self.verbose = verbose
self.log = None

def __call__(self, loss, epoch, save_best_only=True, path=None):
def __call__(self, loss, epoch, save_best_only=True,):
is_better = self._determine_if_better(loss) if save_best_only else True

if is_better or not save_best_only: # Save if loss improves or save_best_only is False
self.save_checkpoint(epoch, loss, path)
self.save_checkpoint(epoch, loss, self.checkpoint_path)
if self.verbose:
print(f"Loss decreased ({self.best_loss:.6f} --> {loss:.6f}). Saving checkpoint.")
self.log = f"Loss decreased ({self.best_loss:.6f} --> {loss:.6f}). Saving checkpoint to {self.checkpoint_path}."
self._update_best_loss(loss)
return True
else:
self.log = ""
return False

def _determine_if_better(self, loss: float):
@@ -36,8 +39,8 @@ def save_checkpoint(self, epoch, loss, path: str = None):
'best_loss': self.best_loss,
}
torch.save(checkpoint, checkpoint_path)
if self.verbose:
print(f"Checkpoint saved to {checkpoint_path}")
# if self.verbose:
# print(f"Checkpoint saved to {checkpoint_path}")

def load_checkpoint(self, path: str = None):
checkpoint_path = path or self.checkpoint_path
@@ -57,8 +60,8 @@ def __init__(self, model, optimizer, checkpoint_path='checkpoint.pt', patience=1
self.counter = 0
self.early_stop = False

def __call__(self, loss, epoch, save_best_only=True, path=None):
saved = super().__call__(loss, epoch, save_best_only, path)
def __call__(self, loss, epoch, save_best_only=True,):
saved = super().__call__(loss, epoch, save_best_only,)
if saved:
self.counter = 0 # Reset counter if the model improved
else:
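The checkpointer now records its status message in a log attribute instead of printing from save_checkpoint(), and the fit() methods above interpolate that log into their per-epoch output. A usage sketch with a toy model (constructor arguments follow the positional call in fit(); only the early-stopping subclass sets early_stop):

import torch
import torch.nn as nn
from ise.utils.training import CheckpointSaver

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters())
checkpointer = CheckpointSaver(model, optimizer, "toy_checkpoint.pt", False)  # model, optimizer, path, verbose

for epoch in range(1, 6):
    loss = 1.0 / epoch                  # hypothetical, steadily improving loss
    checkpointer(loss, epoch)           # saves a checkpoint only when the loss improves
    print(f"epoch {epoch}: loss={loss:.4f} -- {checkpointer.log}")
    if getattr(checkpointer, "early_stop", False):  # present only on the early-stopping subclass
        break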
13 changes: 9 additions & 4 deletions manuscripts/ISEFlow/scripts/get_best_nn.py
@@ -31,15 +31,20 @@ def get_best_nn(data_directory, export_directory, iterations=10, with_chars=True
X_val_df, _ = f.get_X_y(pd.read_csv(f"{data_directory}/val.csv"), 'sectors', return_format='pandas', with_chars=with_chars)
X_val, y_val = f.get_X_y(pd.read_csv(f"{data_directory}/val.csv"), 'sectors', return_format='numpy', with_chars=with_chars)
cur_time = time.time()
de = DeepEnsemble(num_predictors=num_predictors, forcing_size=X_train.shape[1], )
nf = NormalizingFlow(forcing_size=X_train.shape[1])
de = DeepEnsemble(num_ensemble_members=num_predictors, input_size=X_train.shape[1], )
nf = NormalizingFlow(input_size=X_train.shape[1])
emulator = ISEFlow(de, nf)

nf_epochs = 100
de_epochs = 100
train_time_start = time.time()
print('\n\nTraining model with ', num_predictors, 'predictors,', nf_epochs, 'NF epochs, and', de_epochs, 'DE epochs')
emulator.fit(X_train, y_train, X_val=X_val, y_val=y_val, early_stopping=True, patience=20, delta=1e-5, nf_epochs=nf_epochs, de_epochs=de_epochs, early_stopping_path=f"checkpoint_{ice_sheet}")
emulator.fit(
X_train, y_train, X_val=X_val, y_val=y_val,
save_checkpoints=True, checkpoint_path=f"checkpoint_{ice_sheet}",
early_stopping=True, patience=20,
nf_epochs=nf_epochs, de_epochs=de_epochs,
)
train_time_end = time.time()
total_train_time = (train_time_end - train_time_start) / 60.0

@@ -70,7 +75,7 @@ def get_best_nn(data_directory, export_directory, iterations=10, with_chars=True
if __name__ == '__main__':
ICE_SHEET = 'GrIS'
WITH_CHARS = True
ITERATIONS = 10
ITERATIONS = 1
DATA_DIRECTORY = f'/oscar/home/pvankatw/data/pvankatw/pvankatw-bfoxkemp/ISEFlow/data/ml/{ICE_SHEET}/'
EXPORT_DIRECTORY = f'/oscar/home/pvankatw/data/pvankatw/pvankatw-bfoxkemp/ISEFlow/models/all_variables/{"with_characteristics" if WITH_CHARS else "without_characteristics"}/{ICE_SHEET}/'
get_best_nn(DATA_DIRECTORY, EXPORT_DIRECTORY, iterations=ITERATIONS, with_chars=WITH_CHARS)
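The script updates track the renamed keywords from the refactor: num_predictors and forcing_size become num_ensemble_members and input_size, and the fit call swaps early_stopping_path and delta for save_checkpoints and checkpoint_path. A condensed sketch with synthetic data (import paths are assumptions, since this diff does not show the script's imports):

import numpy as np
# Import paths assumed from the file layout above; adjust to the package's actual exports.
from ise.models.predictors.deep_ensemble import DeepEnsemble
from ise.models.density_estimators.normalizing_flow import NormalizingFlow

X_train = np.random.rand(256, 12).astype("float32")  # hypothetical forcing matrix
y_train = np.random.rand(256, 1).astype("float32")   # hypothetical SLE target

# Post-refactor constructor keywords, replacing num_predictors= and forcing_size=.
de = DeepEnsemble(num_ensemble_members=3, input_size=X_train.shape[1])
nf = NormalizingFlow(input_size=X_train.shape[1])

# The ISEFlow wrapper is then built from (de, nf) and fit with the new checkpoint
# keywords, e.g.:
# emulator.fit(X_train, y_train, X_val=X_val, y_val=y_val,
#              save_checkpoints=True, checkpoint_path="checkpoint_GrIS",
#              early_stopping=True, patience=20, nf_epochs=100, de_epochs=100)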
19 changes: 13 additions & 6 deletions manuscripts/ISEFlow/scripts/get_best_onlytemp_nn.py
@@ -82,8 +82,8 @@ def get_optimal_temponly_model(ice_sheet, out_dir, iterations=10, with_chars=Fal
cur_time = time.time()

# Initialize the model
de = DeepEnsemble(num_predictors=num_predictors, forcing_size=X_train.shape[1])
nf = NormalizingFlow(forcing_size=X_train.shape[1])
de = DeepEnsemble(num_ensemble_members=num_predictors, input_size=X_train.shape[1])
nf = NormalizingFlow(input_size=X_train.shape[1])
emulator = ISEFlow(de, nf)

# Randomly choose epochs for normalizing flow and deep ensemble training
@@ -94,10 +94,10 @@ def get_optimal_temponly_model(ice_sheet, out_dir, iterations=10, with_chars=Fal
train_time_start = time.time()
print(f"\n\nTraining model with {num_predictors} predictors, {nf_epochs} NF epochs, and {de_epochs} DE epochs")
emulator.fit(
X_train, y_train, X_val, y_val,
early_stopping=True, patience=10, delta=1e-5,
X_train, y_train, X_val=X_val, y_val=y_val,
save_checkpoints=True, checkpoint_path=f"{ice_sheet}_onlysmb_checkpoint.pt",
early_stopping=True, patience=10,
nf_epochs=nf_epochs, de_epochs=de_epochs,
early_stopping_path=f"{ice_sheet}_onlysmb_checkpoint.pt"
)
train_time_end = time.time()
total_train_time = (train_time_end - train_time_start) / 60.0
@@ -147,9 +147,16 @@ def get_optimal_temponly_model(ice_sheet, out_dir, iterations=10, with_chars=Fal
with_chars = False

# Call the main function to start the model training process
# get_optimal_temponly_model(
# ice_sheet,
# f'/oscar/home/pvankatw/data/pvankatw/pvankatw-bfoxkemp/ISEFlow/models/isolated_variables/SMB_only/{ice_sheet}/',
# iterations=iterations,
# with_chars=with_chars
# )

get_optimal_temponly_model(
ice_sheet,
f'/oscar/home/pvankatw/data/pvankatw/pvankatw-bfoxkemp/ISEFlow/models/isolated_variables/SMB_only/{ice_sheet}/',
f'/users/pvankatw/research/ise/delete/',
iterations=iterations,
with_chars=with_chars
)