From 0669ff43e018e71cc36a4c721748fdc17fc85797 Mon Sep 17 00:00:00 2001
From: Joel Oskarsson
Date: Thu, 29 Feb 2024 11:50:27 +0100
Subject: [PATCH] Re-define RMSE metric to take sqrt after sample averaging
 (#10)

---
 neural_lam/metrics.py         | 37 -------------------------------------
 neural_lam/models/ar_model.py | 31 ++++++++++++++++++++++-----------
 2 files changed, 20 insertions(+), 48 deletions(-)

diff --git a/neural_lam/metrics.py b/neural_lam/metrics.py
index 93014fd3..7db2cca6 100644
--- a/neural_lam/metrics.py
+++ b/neural_lam/metrics.py
@@ -108,42 +108,6 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
     )
 
 
-def rmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
-    """
-    Root Mean Squared Error
-
-    Note: here take sqrt only after spatial averaging, averaging the RMSE
-    of forecasts. This is consistent with Weatherbench and others.
-    Because of this, averaging over grid must be set to true.
-
-    (...,) is any number of batch dimensions, potentially different
-        but broadcastable
-    pred: (..., N, d_state), prediction
-    target: (..., N, d_state), target
-    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
-    mask: (N,), boolean mask describing which grid nodes to use in metric
-    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
-    sum_vars: boolean, if variable dimension -1 should be reduced (sum
-        over d_state)
-
-    Returns:
-    metric_val: One of (...,), (..., d_state), depending on reduction arguments
-    """
-    assert average_grid, "Can not compute RMSE without averaging grid"
-
-    # Spatially averaged mse, masking is also performed here
-    averaged_mse = mse(
-        pred, target, pred_std, mask, average_grid=True, sum_vars=False
-    )  # (..., d_state)
-    entry_rmse = torch.sqrt(averaged_mse)  # (..., d_state)
-
-    # Optionally sum over variables here manually
-    if sum_vars:
-        return torch.sum(entry_rmse, dim=-1)  # (...,)
-
-    return entry_rmse  # (..., d_state)
-
-
 def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
     """
     Weighted Mean Absolute Error
@@ -266,7 +230,6 @@ def crps_gauss(
 DEFINED_METRICS = {
     "mse": mse,
     "mae": mae,
-    "rmse": rmse,
     "wmse": wmse,
     "wmae": wmae,
     "nll": nll,
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index d28aa36c..42e0e3e0 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -75,10 +75,10 @@ def __init__(self, args):
         self.step_length = args.step_length  # Number of hours per pred. step
 
         self.val_metrics = {
-            "rmse": [],
+            "mse": [],
         }
         self.test_metrics = {
-            "rmse": [],
+            "mse": [],
             "mae": [],
         }
         if self.output_std:
@@ -238,7 +238,9 @@ def all_gather_cat(self, tensor_to_gather):
         """
         return self.all_gather(tensor_to_gather).flatten(0, 1)
 
-    def validation_step(self, batch):
+    # newer lightning versions require batch_idx argument, even if unused
+    # pylint: disable-next=unused-argument
+    def validation_step(self, batch, batch_idx):
         """
         Run validation on single batch
         """
@@ -262,15 +264,15 @@ def validation_step(self, batch):
             val_log_dict, on_step=False, on_epoch=True, sync_dist=True
         )
 
-        # Store RMSEs
-        entry_rmses = metrics.rmse(
+        # Store MSEs
+        entry_mses = metrics.mse(
             prediction,
             target,
             pred_std,
             mask=self.interior_mask_bool,
             sum_vars=False,
         )  # (B, pred_steps, d_f)
-        self.val_metrics["rmse"].append(entry_rmses)
+        self.val_metrics["mse"].append(entry_mses)
 
     def on_validation_epoch_end(self):
         """
@@ -283,7 +285,8 @@ def on_validation_epoch_end(self):
         for metric_list in self.val_metrics.values():
             metric_list.clear()
 
-    def test_step(self, batch):
+    # pylint: disable-next=unused-argument
+    def test_step(self, batch, batch_idx):
         """
         Run test on single batch
         """
@@ -314,7 +317,7 @@ def test_step(self, batch, batch_idx):
         # Note: explicitly list metrics here, as test_metrics can contain
         # additional ones, computed differently, but that should be aggregated
         # on_test_epoch_end
-        for metric_name in ("rmse", "mae"):
+        for metric_name in ("mse", "mae"):
             metric_func = metrics.get_metric(metric_name)
             batch_metric_vals = metric_func(
                 prediction,
@@ -508,10 +511,16 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
         )  # (N_eval, pred_steps, d_f)
 
         if self.trainer.is_global_zero:
+            metric_tensor_averaged = torch.mean(metric_tensor, dim=0)
+            # (pred_steps, d_f)
+
+            # Take square root after all averaging to change MSE to RMSE
+            if "mse" in metric_name:
+                metric_tensor_averaged = torch.sqrt(metric_tensor_averaged)
+                metric_name = metric_name.replace("mse", "rmse")
+
             # Note: we here assume rescaling for all metrics is linear
-            metric_rescaled = (
-                torch.mean(metric_tensor, dim=0) * self.data_std
-            )
+            metric_rescaled = metric_tensor_averaged * self.data_std
             # (pred_steps, d_f)
             log_dict.update(
                 self.create_metric_log_dict(
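
Why the order of sqrt and averaging matters: sqrt is concave, so by Jensen's
inequality the mean of per-forecast RMSEs is at most the RMSE of the mean MSE,
and the two definitions disagree whenever errors vary across the eval set.
Below is a minimal standalone sketch of the difference; it is not part of the
patch, and the tensor values are made up purely for illustration:

    import torch

    # Hypothetical spatially-averaged MSEs for two forecasts in an eval set
    # (illustrative values only, not from neural_lam)
    per_sample_mse = torch.tensor([1.0, 9.0])

    # Old behaviour: sqrt per forecast, then average (mean of RMSEs)
    mean_of_rmse = torch.sqrt(per_sample_mse).mean()  # (1 + 3) / 2 = 2.0

    # New behaviour: average MSE over all samples first, then sqrt
    rmse_of_mean = torch.sqrt(per_sample_mse.mean())  # sqrt(5) ~= 2.236

    print(mean_of_rmse.item(), rmse_of_mean.item())  # 2.0  2.2360...

Deferring the sqrt until after the average over all eval samples (the dim=0
mean in aggregate_and_plot_metrics) therefore reports values at least as
large as the old per-batch rmse metric, which is why the patch stores raw
MSEs during validation and testing and only converts to RMSE at aggregation
time.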