From cee405827580d8af6e44808ed9d4450a10fb6a12 Mon Sep 17 00:00:00 2001 From: "Brandon N. Benton" Date: Fri, 20 Dec 2024 10:19:44 -0800 Subject: [PATCH 01/32] dual sampler, queue, and batch handler with obs. modifying Sup3rDataset to work with three data members. --- sup3r/preprocessing/__init__.py | 8 +- sup3r/preprocessing/base.py | 55 +++++--- .../preprocessing/batch_handlers/__init__.py | 1 + sup3r/preprocessing/batch_handlers/factory.py | 7 + sup3r/preprocessing/batch_queues/__init__.py | 1 + sup3r/preprocessing/batch_queues/with_obs.py | 50 +++++++ sup3r/preprocessing/samplers/__init__.py | 1 + sup3r/preprocessing/samplers/with_obs.py | 69 ++++++++++ tests/training/test_train_dual_with_obs.py | 125 ++++++++++++++++++ 9 files changed, 297 insertions(+), 20 deletions(-) create mode 100644 sup3r/preprocessing/batch_queues/with_obs.py create mode 100644 sup3r/preprocessing/samplers/with_obs.py create mode 100644 tests/training/test_train_dual_with_obs.py diff --git a/sup3r/preprocessing/__init__.py b/sup3r/preprocessing/__init__.py index 770f09159..0832d8d3a 100644 --- a/sup3r/preprocessing/__init__.py +++ b/sup3r/preprocessing/__init__.py @@ -33,8 +33,14 @@ BatchHandlerMom2SepSF, BatchHandlerMom2SF, DualBatchHandler, + DualBatchHandlerWithObs, +) +from .batch_queues import ( + BatchQueueDC, + DualBatchQueue, + DualBatchQueueWithObs, + SingleBatchQueue, ) -from .batch_queues import BatchQueueDC, DualBatchQueue, SingleBatchQueue from .cachers import Cacher from .collections import Collection, StatsCollection from .data_handlers import ( diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index 5ddd9dea0..bc1897e17 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -183,17 +183,18 @@ def rewrap(self, data): """Rewrap data as ``Sup3rDataset`` after calling parent method.""" if isinstance(data, type(self)): return data - return ( - type(self)(low_res=data[0], high_res=data[1]) - if len(data) > 1 - else type(self)(high_res=data[0]) - ) + if len(data) == 2: + return type(self)(low_res=data[0], high_res=data[1]) + if len(data) == 3: + return type(self)(low_res=data[0], high_res=data[1], obs=data[2]) + return type(self)(high_res=data[0]) def sample(self, idx): """Get samples from ``self._ds`` members. idx should be either a tuple of slices for the dimensions (south_north, west_east, time) and a list - of feature names or a 2-tuple of the same, for dual datasets.""" - if len(self._ds) == 2: + of feature names or a tuple of the same, for multi-member datasets + (dual datasets and dual with observations datasets).""" + if len(self._ds) > 1: return tuple(d.sample(idx[i]) for i, d in enumerate(self)) return self._ds[-1].sample(idx) @@ -228,10 +229,12 @@ def shape(self): def features(self): """The features are determined by the set of features from all data members.""" + if len(self._ds) == 1: + return self._ds[0].features feats = [ - f for f in self._ds[0].features if f not in self._ds[-1].features + f for f in self._ds[0].features if f not in self._ds[1].features ] - feats += self._ds[-1].features + feats += self._ds[1].features return feats @property @@ -258,13 +261,13 @@ def mean(self, **kwargs): """Use the high_res members to compute the means. These are used for normalization during training.""" kwargs['skipna'] = kwargs.get('skipna', True) - return self._ds[-1].mean(**kwargs) + return self._ds[1 if len(self._ds) > 1 else 0].mean(**kwargs) def std(self, **kwargs): """Use the high_res members to compute the stds. 
These are used for normalization during training.""" kwargs['skipna'] = kwargs.get('skipna', True) - return self._ds[-1].std(**kwargs) + return self._ds[1 if len(self._ds) > 1 else 0].std(**kwargs) def normalize(self, means, stds): """Normalize dataset using the given mean and stds. These are provided @@ -309,10 +312,12 @@ def __init__( such. This is a tuple when the `.data` attribute belongs to a :class:`~.collections.base.Collection` object like :class:`~.batch_handlers.factory.BatchHandler`. Otherwise this is - :class:`~.Sup3rDataset` object, which is either a wrapped 2-tuple - or 1-tuple (e.g. ``len(data) == 2`` or ``len(data) == 1)``. This is - a 2-tuple when ``.data`` belongs to a dual container object like - :class:`~.samplers.DualSampler` and a 1-tuple otherwise. + :class:`~.Sup3rDataset` object, which is either a wrapped 3-tuple, + 2-tuple, or 1-tuple (e.g. ``len(data) == 3``, ``len(data) == 2`` or + ``len(data) == 1)``. This is a 3-tuple when ``.data`` belongs to a + container object like :class:`~.samplers.DualSamplerWithObs`, a + 2-tuple when ``.data`` belongs to a dual container object like + :class:`~.samplers.DualSampler`, and a 1-tuple otherwise. """ self.data = data @@ -345,10 +350,12 @@ def wrap(self, data): tuple when the `.data` attribute belongs to a :class:`~.collections.base.Collection` object like :class:`~.batch_handlers.factory.BatchHandler`. Otherwise this is - :class:`~.Sup3rDataset` object, which is either a wrapped 2-tuple or - 1-tuple (e.g. ``len(data) == 2`` or ``len(data) == 1)``. This is a - 2-tuple when ``.data`` belongs to a dual container object like - :class:`~.samplers.DualSampler` and a 1-tuple otherwise. + :class:`~.Sup3rDataset` object, which is either a wrapped 3-tuple, + 2-tuple, or 1-tuple (e.g. ``len(data) == 3``, ``len(data) == 2`` or + ``len(data) == 1)``. This is a 3-tuple when ``.data`` belongs to a + container object like :class:`~.samplers.DualSamplerWithObs`, a 2-tuple + when ``.data`` belongs to a dual container object like + :class:`~.samplers.DualSampler`, and a 1-tuple otherwise. """ if data is None: return data @@ -365,6 +372,16 @@ def wrap(self, data): logger.warning(msg) warn(msg) data = Sup3rDataset(low_res=data[0], high_res=data[1]) + elif isinstance(data, tuple) and len(data) == 3: + msg = ( + f'{self.__class__.__name__}.data is being set with a ' + '3-tuple without explicit dataset names. 
We will assume ' + 'first tuple member is low-res, second is high-res, and third ' + 'is obs' + ) + logger.warning(msg) + warn(msg) + data = Sup3rDataset(low_res=data[0], high_res=data[1], obs=data[2]) elif not isinstance(data, Sup3rDataset): name = getattr(data, 'name', None) or 'high_res' data = Sup3rDataset(**{name: data}) diff --git a/sup3r/preprocessing/batch_handlers/__init__.py b/sup3r/preprocessing/batch_handlers/__init__.py index 08bba8d6b..d66b10126 100644 --- a/sup3r/preprocessing/batch_handlers/__init__.py +++ b/sup3r/preprocessing/batch_handlers/__init__.py @@ -11,4 +11,5 @@ BatchHandlerMom2SepSF, BatchHandlerMom2SF, DualBatchHandler, + DualBatchHandlerWithObs, ) diff --git a/sup3r/preprocessing/batch_handlers/factory.py b/sup3r/preprocessing/batch_handlers/factory.py index 7e63c34b6..4b6ddf3aa 100644 --- a/sup3r/preprocessing/batch_handlers/factory.py +++ b/sup3r/preprocessing/batch_handlers/factory.py @@ -14,10 +14,12 @@ QueueMom2SF, ) from sup3r.preprocessing.batch_queues.dual import DualBatchQueue +from sup3r.preprocessing.batch_queues.with_obs import DualBatchQueueWithObs from sup3r.preprocessing.collections.stats import StatsCollection from sup3r.preprocessing.samplers.base import Sampler from sup3r.preprocessing.samplers.cc import DualSamplerCC from sup3r.preprocessing.samplers.dual import DualSampler +from sup3r.preprocessing.samplers.with_obs import DualSamplerWithObs from sup3r.preprocessing.utilities import ( check_signatures, get_class_kwargs, @@ -315,6 +317,11 @@ def stop(self): DualBatchHandler = BatchHandlerFactory( DualBatchQueue, DualSampler, name='DualBatchHandler' ) + +DualBatchHandlerWithObs = BatchHandlerFactory( + DualBatchQueueWithObs, DualSamplerWithObs, name='DualBatchHandlerWithObs' +) + BatchHandlerCC = BatchHandlerFactory( DualBatchQueue, DualSamplerCC, name='BatchHandlerCC' ) diff --git a/sup3r/preprocessing/batch_queues/__init__.py b/sup3r/preprocessing/batch_queues/__init__.py index 63053f123..067d21070 100644 --- a/sup3r/preprocessing/batch_queues/__init__.py +++ b/sup3r/preprocessing/batch_queues/__init__.py @@ -12,3 +12,4 @@ ) from .dc import BatchQueueDC, ValBatchQueueDC from .dual import DualBatchQueue +from .with_obs import DualBatchQueueWithObs diff --git a/sup3r/preprocessing/batch_queues/with_obs.py b/sup3r/preprocessing/batch_queues/with_obs.py new file mode 100644 index 000000000..4c2a16571 --- /dev/null +++ b/sup3r/preprocessing/batch_queues/with_obs.py @@ -0,0 +1,50 @@ +"""DualBatchQueue with additional observation data on the same grid as the +high-res data. The observation data is sampled with the same index as the +high-res data during training.""" + +import logging + +from scipy.ndimage import gaussian_filter + +from .dual import DualBatchQueue + +logger = logging.getLogger(__name__) + + +class DualBatchQueueWithObs(DualBatchQueue): + """Base BatchQueue for use with + :class:`~sup3r.preprocessing.samplers.DualSamplerWithObs` objects.""" + + _signature_objs = (DualBatchQueue,) + + @property + def queue_shape(self): + """Shape of objects stored in the queue.""" + return [ + (self.batch_size, *self.lr_shape), + (self.batch_size, *self.hr_shape), + (self.batch_size, *self.hr_shape), + ] + + def transform(self, samples, smoothing=None, smoothing_ignore=None): + """Perform smoothing if requested. 
+ + Note + ---- + This does not include temporal or spatial coarsening like + :class:`SingleBatchQueue` + """ + low_res, high_res, obs = samples + + if smoothing is not None: + feat_iter = [ + j + for j in range(low_res.shape[-1]) + if self.features[j] not in smoothing_ignore + ] + for i in range(low_res.shape[0]): + for j in feat_iter: + low_res[i, ..., j] = gaussian_filter( + low_res[i, ..., j], smoothing, mode='nearest' + ) + return low_res, high_res, obs diff --git a/sup3r/preprocessing/samplers/__init__.py b/sup3r/preprocessing/samplers/__init__.py index e281616d5..990e23861 100644 --- a/sup3r/preprocessing/samplers/__init__.py +++ b/sup3r/preprocessing/samplers/__init__.py @@ -9,3 +9,4 @@ from .cc import DualSamplerCC from .dc import SamplerDC from .dual import DualSampler +from .with_obs import DualSamplerWithObs diff --git a/sup3r/preprocessing/samplers/with_obs.py b/sup3r/preprocessing/samplers/with_obs.py new file mode 100644 index 000000000..267794b9b --- /dev/null +++ b/sup3r/preprocessing/samplers/with_obs.py @@ -0,0 +1,69 @@ +"""Extended Sampler for sampling observation data in addition to standard +gridded training data.""" + +from typing import Dict, Optional + +from sup3r.preprocessing.base import Sup3rDataset +from sup3r.preprocessing.samplers.dual import DualSampler + + +class DualSamplerWithObs(DualSampler): + """Dual Sampler which also samples from extra observation data. The + observation data is on the same grid as the high-resolution data but + includes NaNs at points where observation data doesn't exist. This will + be used in an additional content loss term.""" + + def __init__( + self, + data: Sup3rDataset, + sample_shape: Optional[tuple] = None, + batch_size: int = 16, + s_enhance: int = 1, + t_enhance: int = 24, + feature_sets: Optional[Dict] = None, + ): + """ + Parameters + ---------- + data : Sup3rDataset + A :class:`~sup3r.preprocessing.base.Sup3rDataset` instance with + low-res, high-res, and obs data members. The observation data is on + the same grid as the high-res data. + sample_shape : tuple + Size of arrays to sample from the high-res data. The sample shape + for the low-res sampler will be determined from the enhancement + factors. + s_enhance : int + Spatial enhancement factor + t_enhance : int + Temporal enhancement factor + feature_sets : Optional[dict] + Optional dictionary describing how the full set of features is + split between `lr_only_features` and `hr_exo_features`. + + lr_only_features : list | tuple + List of feature names or patt*erns that should only be + included in the low-res training set and not the high-res + observations. + hr_exo_features : list | tuple + List of feature names or patt*erns that should be included + in the high-resolution observation but not expected to be + output from the generative model. An example is high-res + topography that is to be injected mid-network. 
+ """ + super().__init__( + data, + sample_shape=sample_shape, + batch_size=batch_size, + s_enhance=s_enhance, + t_enhance=t_enhance, + feature_sets=feature_sets, + ) + + def get_sample_index(self, n_obs=None): + """Get paired sample index, consisting of index for the low res sample + and the index for the high res sample with the same spatiotemporal + extent, with an additional index (same as the index for the high-res + data) for the observation data""" + lr_index, hr_index = super().get_sample_index(n_obs=n_obs) + return (lr_index, hr_index, hr_index) diff --git a/tests/training/test_train_dual_with_obs.py b/tests/training/test_train_dual_with_obs.py new file mode 100644 index 000000000..b66b3c85d --- /dev/null +++ b/tests/training/test_train_dual_with_obs.py @@ -0,0 +1,125 @@ +"""Test the training of GANs with dual data handler""" + +import os +import tempfile + +import numpy as np +import pytest + +from sup3r.models import Sup3rGan +from sup3r.preprocessing import ( + DataHandler, + DualBatchHandlerWithObs, + DualRasterizer, + Sup3rDataset, +) +from sup3r.preprocessing.samplers import DualSamplerWithObs +from sup3r.utilities.pytest.helpers import BatchHandlerTesterFactory + +TARGET_COORD = (39.01, -105.15) +FEATURES = ['u_100m', 'v_100m'] + + +DualBatchHandlerWithObsTester = BatchHandlerTesterFactory( + DualBatchHandlerWithObs, DualSamplerWithObs +) + + +@pytest.mark.parametrize( + [ + 'fp_gen', + 'fp_disc', + 's_enhance', + 't_enhance', + 'sample_shape', + 'mode', + ], + [ + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'lazy'), + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'eager'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'lazy'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'eager'), + ], +) +def test_train_h5_nc( + fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, mode, n_epoch=2 +): + """Test model training with a dual data handler / batch handler with h5 and + era as hr / lr datasets with additional observation data used in extra + content loss. 
Tests both spatiotemporal and spatial models.""" + + lr = 1e-5 + kwargs = { + 'features': FEATURES, + 'target': TARGET_COORD, + 'shape': (20, 20), + } + hr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + time_slice=slice(None, None, 1), + ) + lr_handler = DataHandler( + pytest.FP_ERA, + features=FEATURES, + time_slice=slice(None, None, 30), + ) + + # time indices conflict with t_enhance + with pytest.raises(AssertionError): + dual_rasterizer = DualRasterizer( + data=(lr_handler.data, hr_handler.data), + s_enhance=s_enhance, + t_enhance=t_enhance, + ) + + lr_handler = DataHandler( + pytest.FP_ERA, + features=FEATURES, + time_slice=slice(None, None, t_enhance), + ) + + dual_rasterizer = DualRasterizer( + data=(lr_handler.data, hr_handler.data), + s_enhance=s_enhance, + t_enhance=t_enhance, + ) + obs_data = dual_rasterizer.high_res.copy() + + dual_with_obs = Sup3rDataset( + low_res=dual_rasterizer.low_res, + high_res=dual_rasterizer.high_res, + obs=obs_data, + ) + + batch_handler = DualBatchHandlerWithObsTester( + train_containers=[dual_with_obs], + val_containers=[], + sample_shape=sample_shape, + batch_size=3, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=3, + mode=mode, + ) + + Sup3rGan.seed() + model = Sup3rGan( + fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' + ) + + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': n_epoch, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batch_handler, **model_kwargs) + + tlossg = model.history['train_loss_gen'].values + assert np.sum(np.diff(tlossg)) < 0 From 61f6206533852e138a8b469937c25af1609311c4 Mon Sep 17 00:00:00 2001 From: "Brandon N. Benton" Date: Fri, 20 Dec 2024 15:43:13 -0800 Subject: [PATCH 02/32] training with obs test --- sup3r/models/abstract.py | 15 +- sup3r/models/base.py | 486 ++++++++++--------- sup3r/preprocessing/batch_queues/with_obs.py | 17 + tests/training/test_train_dual_with_obs.py | 43 +- 4 files changed, 302 insertions(+), 259 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index b4b7fb869..9179380c9 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1292,6 +1292,7 @@ def run_gradient_descent( low_res, hi_res_true, training_weights, + obs_data=None, optimizer=None, multi_gpu=False, **calc_loss_kwargs, @@ -1313,6 +1314,8 @@ def run_gradient_descent( training_weights : list A list of layer weights that are to-be-trained based on the current loss weight values. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. optimizer : tf.keras.optimizers.Optimizer Optimizer class to use to update weights. 
This can be different if you're training just the generator or one of the discriminator @@ -1341,6 +1344,7 @@ def run_gradient_descent( low_res, hi_res_true, training_weights, + obs_data=obs_data, device_name=self.default_device, **calc_loss_kwargs, ) @@ -1354,6 +1358,11 @@ def run_gradient_descent( futures = [] lr_chunks = np.array_split(low_res, len(self.gpu_list)) hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) + obs_data_chunks = ( + [None] * len(hr_true_chunks) + if obs_data is None + else np.array_split(obs_data, len(self.gpu_list)) + ) split_mask = False mask_chunks = None if 'mask' in calc_loss_kwargs: @@ -1372,6 +1381,7 @@ def run_gradient_descent( lr_chunks[i], hr_true_chunks[i], training_weights, + obs_data=obs_data_chunks[i], device_name=f'/gpu:{i}', **calc_loss_kwargs, ) @@ -1594,6 +1604,7 @@ def get_single_grad( low_res, hi_res_true, training_weights, + obs_data=None, device_name=None, **calc_loss_kwargs, ): @@ -1613,6 +1624,8 @@ def get_single_grad( training_weights : list A list of layer weights that are to-be-trained based on the current loss weight values. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. device_name : None | str Optional tensorflow device name for GPU placement. Note that if a GPU is available, variables will be placed on that GPU even if @@ -1636,7 +1649,7 @@ def get_single_grad( hi_res_exo = self.get_high_res_exo_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss_out = self.calc_loss( - hi_res_true, hi_res_gen, **calc_loss_kwargs + hi_res_true, hi_res_gen, obs_data=obs_data, **calc_loss_kwargs ) loss, loss_details = loss_out grad = tape.gradient(loss, training_weights) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index cfdc47f73..4ff617226 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -546,242 +546,6 @@ def calc_loss_disc(disc_out_true, disc_out_gen): ) return tf.reduce_mean(loss_disc) - @tf.function - def calc_loss( - self, - hi_res_true, - hi_res_gen, - weight_gen_advers=0.001, - train_gen=True, - train_disc=False, - ): - """Calculate the GAN loss function using generated and true high - resolution data. - - Parameters - ---------- - hi_res_true : tf.Tensor - Ground truth high resolution spatiotemporal data. - hi_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. - train_gen : bool - True if generator is being trained, then loss=loss_gen - train_disc : bool - True if disc is being trained, then loss=loss_disc - - Returns - ------- - loss : tf.Tensor - 0D tensor representing the loss value for the network being trained - (either generator or one of the discriminators) - loss_details : dict - Namespace of the breakdown of loss components - """ - hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen) - - if hi_res_gen.shape != hi_res_true.shape: - msg = ( - 'The tensor shapes of the synthetic output {} and ' - 'true high res {} did not have matching shape! 
' - 'Check the spatiotemporal enhancement multipliers in your ' - 'your model config and data handlers.'.format( - hi_res_gen.shape, hi_res_true.shape - ) - ) - logger.error(msg) - raise RuntimeError(msg) - - disc_out_true = self._tf_discriminate(hi_res_true) - disc_out_gen = self._tf_discriminate(hi_res_gen) - - loss_gen_content = self.calc_loss_gen_content(hi_res_true, hi_res_gen) - loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen) - loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers - - loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen) - - loss = None - if train_gen: - loss = loss_gen - elif train_disc: - loss = loss_disc - - loss_details = { - 'loss_gen': loss_gen, - 'loss_gen_content': loss_gen_content, - 'loss_gen_advers': loss_gen_advers, - 'loss_disc': loss_disc, - } - - return loss, loss_details - - def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): - """Calculate the validation loss at the current state of model training - - Parameters - ---------- - batch_handler : sup3r.preprocessing.BatchHandler - BatchHandler object to iterate through - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. - loss_details : dict - Namespace of the breakdown of loss components - - Returns - ------- - loss_details : dict - Same as input but now includes val_* loss info - """ - logger.debug('Starting end-of-epoch validation loss calculation...') - loss_details['n_obs'] = 0 - for val_batch in batch_handler.val_data: - val_exo_data = self.get_high_res_exo_input(val_batch.high_res) - high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data) - _, v_loss_details = self.calc_loss( - val_batch.high_res, - high_res_gen, - weight_gen_advers=weight_gen_advers, - train_gen=False, - train_disc=False, - ) - - loss_details = self.update_loss_details( - loss_details, v_loss_details, len(val_batch), prefix='val_' - ) - return loss_details - - def train_epoch( - self, - batch_handler, - weight_gen_advers, - train_gen, - train_disc, - disc_loss_bounds, - multi_gpu=False, - ): - """Train the GAN for one epoch. - - Parameters - ---------- - batch_handler : sup3r.preprocessing.BatchHandler - BatchHandler object to iterate through - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. - train_gen : bool - Flag whether to train the generator for this set of epochs - train_disc : bool - Flag whether to train the discriminator for this set of epochs - disc_loss_bounds : tuple - Lower and upper bounds for the discriminator loss outside of which - the discriminators will not train unless train_disc=True or - and train_gen=False. - multi_gpu : bool - Flag to break up the batch for parallel gradient descent - calculations on multiple gpus. If True and multiple GPUs are - present, each batch from the batch_handler will be divided up - between the GPUs and resulting gradients from each GPU will be - summed and then applied once per batch at the nominal learning - rate that the model and optimizer were initialized with. 
- If true and multiple gpus are found, ``default_device`` device - should be set to /gpu:0 - - Returns - ------- - loss_details : dict - Namespace of the breakdown of loss components - """ - - disc_th_low = np.min(disc_loss_bounds) - disc_th_high = np.max(disc_loss_bounds) - loss_details = {'n_obs': 0, 'train_loss_disc': 0} - - only_gen = train_gen and not train_disc - only_disc = train_disc and not train_gen - - if self._write_tb_profile: - tf.summary.trace_on(graph=True, profiler=True) - - for ib, batch in enumerate(batch_handler): - trained_gen = False - trained_disc = False - b_loss_details = {} - loss_disc = loss_details['train_loss_disc'] - disc_too_good = loss_disc <= disc_th_low - disc_too_bad = (loss_disc > disc_th_high) and train_disc - gen_too_good = disc_too_bad - - if not self.generator_weights: - self.init_weights(batch.low_res.shape, batch.high_res.shape) - - if only_gen or (train_gen and not gen_too_good): - trained_gen = True - b_loss_details = self.timer(self.run_gradient_descent)( - batch.low_res, - batch.high_res, - self.generator_weights, - weight_gen_advers=weight_gen_advers, - optimizer=self.optimizer, - train_gen=True, - train_disc=False, - multi_gpu=multi_gpu, - ) - - if only_disc or (train_disc and not disc_too_good): - trained_disc = True - b_loss_details = self.timer(self.run_gradient_descent)( - batch.low_res, - batch.high_res, - self.discriminator_weights, - weight_gen_advers=weight_gen_advers, - optimizer=self.optimizer_disc, - train_gen=False, - train_disc=True, - multi_gpu=multi_gpu, - ) - - b_loss_details['gen_trained_frac'] = float(trained_gen) - b_loss_details['disc_trained_frac'] = float(trained_disc) - - self.dict_to_tensorboard(b_loss_details) - self.dict_to_tensorboard(self.timer.log) - - loss_details = self.update_loss_details( - loss_details, - b_loss_details, - batch_handler.batch_size, - prefix='train_', - ) - logger.debug( - 'Batch {} out of {} has epoch-average ' - '(gen / disc) loss of: ({:.2e} / {:.2e}). ' - 'Trained (gen / disc): ({} / {})'.format( - ib + 1, - len(batch_handler), - loss_details['train_loss_gen'], - loss_details['train_loss_disc'], - trained_gen, - trained_disc, - ) - ) - if all([not trained_gen, not trained_disc]): - msg = ( - 'For some reason none of the GAN networks trained ' - 'during batch {} out of {}!'.format(ib, len(batch_handler)) - ) - logger.warning(msg) - warn(msg) - self.total_batches += 1 - - loss_details['total_batches'] = int(self.total_batches) - self.profile_to_tensorboard('training_epoch') - return loss_details - def update_adversarial_weights( self, history, @@ -795,7 +559,7 @@ def update_adversarial_weights( Parameters ---------- - history : dict + history : dicts Dictionary with information on how often discriminators were trained during current and previous epochs. adaptive_update_fraction : float @@ -1052,3 +816,251 @@ def train( break batch_handler.stop() + + @tf.function + def calc_loss( + self, + hi_res_true, + hi_res_gen, + obs_data=None, + weight_gen_advers=0.001, + train_gen=True, + train_disc=False, + ): + """Calculate the GAN loss function using generated and true high + resolution data. + + Parameters + ---------- + hi_res_true : tf.Tensor + Ground truth high resolution spatiotemporal data. + hi_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. 
+ weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + train_gen : bool + True if generator is being trained, then loss=loss_gen + train_disc : bool + True if disc is being trained, then loss=loss_disc + + Returns + ------- + loss : tf.Tensor + 0D tensor representing the loss value for the network being trained + (either generator or one of the discriminators) + loss_details : dict + Namespace of the breakdown of loss components + """ + hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen) + + if hi_res_gen.shape != hi_res_true.shape: + msg = ( + 'The tensor shapes of the synthetic output {} and ' + 'true high res {} did not have matching shape! ' + 'Check the spatiotemporal enhancement multipliers in your ' + 'your model config and data handlers.'.format( + hi_res_gen.shape, hi_res_true.shape + ) + ) + logger.error(msg) + raise RuntimeError(msg) + + disc_out_true = self._tf_discriminate(hi_res_true) + disc_out_gen = self._tf_discriminate(hi_res_gen) + + loss_gen_content = self.calc_loss_gen_content(hi_res_true, hi_res_gen) + loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen) + loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers + + loss_obs = None + if obs_data is not None: + mask = tf.math.is_nan(obs_data) + loss_obs = self.loss_fun(obs_data[~mask], hi_res_gen[~mask]) + loss_gen += loss_obs + + loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen) + + loss = None + if train_gen: + loss = loss_gen + elif train_disc: + loss = loss_disc + + loss_details = { + 'loss_gen': loss_gen, + 'loss_obs': loss_obs, + 'loss_gen_content': loss_gen_content, + 'loss_gen_advers': loss_gen_advers, + 'loss_disc': loss_disc, + } + + return loss, loss_details + + def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): + """Calculate the validation loss at the current state of model training + + Parameters + ---------- + batch_handler : sup3r.preprocessing.BatchHandler + BatchHandler object to iterate through + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + loss_details : dict + Namespace of the breakdown of loss components + + Returns + ------- + loss_details : dict + Same as input but now includes val_* loss info + """ + logger.debug('Starting end-of-epoch validation loss calculation...') + loss_details['n_obs'] = 0 + for val_batch in batch_handler.val_data: + val_exo_data = self.get_high_res_exo_input(val_batch.high_res) + high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data) + _, v_loss_details = self.calc_loss( + val_batch.high_res, + high_res_gen, + obs_data=getattr(val_batch, 'obs', None), + weight_gen_advers=weight_gen_advers, + train_gen=False, + train_disc=False, + ) + + loss_details = self.update_loss_details( + loss_details, v_loss_details, len(val_batch), prefix='val_' + ) + return loss_details + + def train_epoch( + self, + batch_handler, + weight_gen_advers, + train_gen, + train_disc, + disc_loss_bounds, + multi_gpu=False, + ): + """Train the GAN for one epoch. + + Parameters + ---------- + batch_handler : sup3r.preprocessing.BatchHandler + BatchHandler object to iterate through + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. 
+ train_gen : bool + Flag whether to train the generator for this set of epochs + train_disc : bool + Flag whether to train the discriminator for this set of epochs + disc_loss_bounds : tuple + Lower and upper bounds for the discriminator loss outside of which + the discriminators will not train unless train_disc=True or + and train_gen=False. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with. + If true and multiple gpus are found, ``default_device`` device + should be set to /gpu:0 + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components + """ + + disc_th_low = np.min(disc_loss_bounds) + disc_th_high = np.max(disc_loss_bounds) + loss_details = {'n_obs': 0, 'train_loss_disc': 0} + + only_gen = train_gen and not train_disc + only_disc = train_disc and not train_gen + + if self._write_tb_profile: + tf.summary.trace_on(graph=True, profiler=True) + + for ib, batch in enumerate(batch_handler): + trained_gen = False + trained_disc = False + b_loss_details = {} + loss_disc = loss_details['train_loss_disc'] + disc_too_good = loss_disc <= disc_th_low + disc_too_bad = (loss_disc > disc_th_high) and train_disc + gen_too_good = disc_too_bad + + if not self.generator_weights: + self.init_weights(batch.low_res.shape, batch.high_res.shape) + + if only_gen or (train_gen and not gen_too_good): + trained_gen = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.generator_weights, + obs_data=getattr(batch, 'obs', None), + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer, + train_gen=True, + train_disc=False, + multi_gpu=multi_gpu, + ) + + if only_disc or (train_disc and not disc_too_good): + trained_disc = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.discriminator_weights, + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer_disc, + train_gen=False, + train_disc=True, + multi_gpu=multi_gpu, + ) + + b_loss_details['gen_trained_frac'] = float(trained_gen) + b_loss_details['disc_trained_frac'] = float(trained_disc) + + self.dict_to_tensorboard(b_loss_details) + self.dict_to_tensorboard(self.timer.log) + + loss_details = self.update_loss_details( + loss_details, + b_loss_details, + batch_handler.batch_size, + prefix='train_', + ) + logger.debug( + 'Batch {} out of {} has epoch-average ' + '(gen / disc) loss of: ({:.2e} / {:.2e}). 
' + 'Trained (gen / disc): ({} / {})'.format( + ib + 1, + len(batch_handler), + loss_details['train_loss_gen'], + loss_details['train_loss_disc'], + trained_gen, + trained_disc, + ) + ) + if all([not trained_gen, not trained_disc]): + msg = ( + 'For some reason none of the GAN networks trained ' + 'during batch {} out of {}!'.format(ib, len(batch_handler)) + ) + logger.warning(msg) + warn(msg) + self.total_batches += 1 + + loss_details['total_batches'] = int(self.total_batches) + self.profile_to_tensorboard('training_epoch') + return loss_details diff --git a/sup3r/preprocessing/batch_queues/with_obs.py b/sup3r/preprocessing/batch_queues/with_obs.py index 4c2a16571..0a0a0890e 100644 --- a/sup3r/preprocessing/batch_queues/with_obs.py +++ b/sup3r/preprocessing/batch_queues/with_obs.py @@ -3,6 +3,7 @@ high-res data during training.""" import logging +from collections import namedtuple from scipy.ndimage import gaussian_filter @@ -15,6 +16,8 @@ class DualBatchQueueWithObs(DualBatchQueue): """Base BatchQueue for use with :class:`~sup3r.preprocessing.samplers.DualSamplerWithObs` objects.""" + Batch = namedtuple('Batch', ['low_res', 'high_res', 'obs']) + _signature_objs = (DualBatchQueue,) @property @@ -48,3 +51,17 @@ def transform(self, samples, smoothing=None, smoothing_ignore=None): low_res[i, ..., j], smoothing, mode='nearest' ) return low_res, high_res, obs + + def post_proc(self, samples) -> Batch: + """Performs some post proc on dequeued samples before sending out for + training. Post processing can include coarsening on high-res data (if + :class:`Collection` consists of :class:`Sampler` objects and not + :class:`DualSampler` objects), smoothing, etc + + Returns + ------- + Batch : namedtuple + namedtuple with `low_res`, `high_res`, and `obs` attributes + """ + lr, hr, obs = self.transform(samples, **self.transform_kwargs) + return self.Batch(low_res=lr, high_res=hr, obs=obs) diff --git a/tests/training/test_train_dual_with_obs.py b/tests/training/test_train_dual_with_obs.py index b66b3c85d..ae5a88357 100644 --- a/tests/training/test_train_dual_with_obs.py +++ b/tests/training/test_train_dual_with_obs.py @@ -1,5 +1,6 @@ """Test the training of GANs with dual data handler""" +import itertools import os import tempfile @@ -8,6 +9,7 @@ from sup3r.models import Sup3rGan from sup3r.preprocessing import ( + Container, DataHandler, DualBatchHandlerWithObs, DualRasterizer, @@ -41,7 +43,7 @@ (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'eager'), ], ) -def test_train_h5_nc( +def test_train_coarse_h5( fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, mode, n_epoch=2 ): """Test model training with a dual data handler / batch handler with h5 and @@ -59,23 +61,11 @@ def test_train_h5_nc( **kwargs, time_slice=slice(None, None, 1), ) - lr_handler = DataHandler( - pytest.FP_ERA, - features=FEATURES, - time_slice=slice(None, None, 30), - ) - - # time indices conflict with t_enhance - with pytest.raises(AssertionError): - dual_rasterizer = DualRasterizer( - data=(lr_handler.data, hr_handler.data), - s_enhance=s_enhance, - t_enhance=t_enhance, - ) lr_handler = DataHandler( - pytest.FP_ERA, - features=FEATURES, + pytest.FP_WTK, + **kwargs, + hr_spatial_coarsen=s_enhance, time_slice=slice(None, None, t_enhance), ) @@ -85,11 +75,20 @@ def test_train_h5_nc( t_enhance=t_enhance, ) obs_data = dual_rasterizer.high_res.copy() - - dual_with_obs = Sup3rDataset( - low_res=dual_rasterizer.low_res, - high_res=dual_rasterizer.high_res, - obs=obs_data, + for feat in FEATURES: + tmp = 
np.full(obs_data[feat].shape, np.nan) + lat_ids = list(range(0, 20, 4)) + lon_ids = list(range(0, 20, 4)) + for ilat, ilon in itertools.product(lat_ids, lon_ids): + tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] + obs_data[feat] = (obs_data[feat].dims, tmp) + + dual_with_obs = Container( + data=Sup3rDataset( + low_res=dual_rasterizer.low_res, + high_res=dual_rasterizer.high_res, + obs=obs_data, + ) ) batch_handler = DualBatchHandlerWithObsTester( @@ -122,4 +121,6 @@ def test_train_h5_nc( model.train(batch_handler, **model_kwargs) tlossg = model.history['train_loss_gen'].values + tlosso = model.history['train_loss_obs'].values assert np.sum(np.diff(tlossg)) < 0 + assert np.sum(np.diff(tlosso)) < 0 From e6b88182d576a119487b5acedc6656ea4d83bbd1 Mon Sep 17 00:00:00 2001 From: "Brandon N. Benton" Date: Sat, 21 Dec 2024 08:18:27 -0800 Subject: [PATCH 03/32] split up interface and abstact model --- sup3r/models/abstract.py | 621 ++------------------- sup3r/models/base.py | 3 +- sup3r/models/conditional.py | 3 +- sup3r/models/interface.py | 527 +++++++++++++++++ sup3r/models/tensorboard.py | 85 +++ sup3r/preprocessing/base.py | 39 +- tests/training/test_train_dual_with_obs.py | 17 +- 7 files changed, 686 insertions(+), 609 deletions(-) create mode 100644 sup3r/models/interface.py create mode 100644 sup3r/models/tensorboard.py diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 9179380c9..dc94bc52a 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1,11 +1,9 @@ """Abstract class defining the required interface for Sup3r model subclasses""" import json -import locale import logging import os import pprint -import re import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -23,588 +21,18 @@ from sup3r.preprocessing.data_handlers import ExoData from sup3r.preprocessing.utilities import numpy_if_tensor from sup3r.utilities import VERSION_RECORD -from sup3r.utilities.utilities import Timer, safe_cast +from sup3r.utilities.utilities import safe_cast -logger = logging.getLogger(__name__) - - -class TensorboardMixIn: - """MixIn class for tensorboard logging and profiling.""" - - def __init__(self): - self._tb_writer = None - self._tb_log_dir = None - self._write_tb_profile = False - self._total_batches = None - self._history = None - self.timer = Timer() - - @property - def total_batches(self): - """Record of total number of batches for logging.""" - if self._total_batches is None and self._history is None: - self._total_batches = 0 - elif self._history is None and 'total_batches' in self._history: - self._total_batches = self._history['total_batches'].values[-1] - elif self._total_batches is None and self._history is not None: - self._total_batches = 0 - return self._total_batches - - @total_batches.setter - def total_batches(self, value): - """Set total number of batches.""" - self._total_batches = value - - def dict_to_tensorboard(self, entry): - """Write data to tensorboard log file. This is usually a loss_details - dictionary. - - Parameters - ---------- - entry: dict - Dictionary of values to write to tensorboard log file - """ - if self._tb_writer is not None: - with self._tb_writer.as_default(): - for name, value in entry.items(): - if isinstance(value, str): - tf.summary.text(name, value, self.total_batches) - else: - tf.summary.scalar(name, value, self.total_batches) - - def profile_to_tensorboard(self, name): - """Write profile data to tensorboard log file. 
- - Parameters - ---------- - name : str - Tag name to use for profile info - """ - if self._tb_writer is not None and self._write_tb_profile: - with self._tb_writer.as_default(): - tf.summary.trace_export( - name=name, - step=self.total_batches, - profiler_outdir=self._tb_log_dir, - ) - - def _init_tensorboard_writer(self, out_dir): - """Initialize the ``tf.summary.SummaryWriter`` to use for writing - tensorboard compatible log files. - - Parameters - ---------- - out_dir : str - Standard out_dir where model epochs are saved. e.g. './gan_{epoch}' - """ - tb_log_pardir = os.path.abspath(os.path.join(out_dir, os.pardir)) - self._tb_log_dir = os.path.join(tb_log_pardir, 'logs') - os.makedirs(self._tb_log_dir, exist_ok=True) - self._tb_writer = tf.summary.create_file_writer(self._tb_log_dir) - - -class AbstractInterface(ABC): - """ - Abstract class to define the required interface for Sup3r model subclasses - - Note that this only sets the required interfaces for a GAN that can be - loaded from disk and used to predict synthetic outputs. The interface for - models that can be trained will be set in another class. - """ - - @classmethod - @abstractmethod - def load(cls, model_dir, verbose=True): - """Load the GAN with its sub-networks from a previously saved-to output - directory. - - Parameters - ---------- - model_dir - Directory to load GAN model files from. - verbose : bool - Flag to log information about the loaded model. - - Returns - ------- - out : BaseModel - Returns a pretrained gan model that was previously saved to - model_dir - """ - - @abstractmethod - def generate( - self, low_res, norm_in=True, un_norm_out=True, exogenous_data=None - ): - """Use the generator model to generate high res data from low res - input. This is the public generate function.""" - - @staticmethod - def seed(s=0): - """ - Set the random seed for reproducible results. - - Parameters - ---------- - s : int - Random seed - """ - CustomNetwork.seed(s=s) +from .tensorboard import TensorboardMixIn - @property - def input_dims(self): - """Get dimension of model generator input. This is usually 4D for - spatial models and 5D for spatiotemporal models. This gives the input - to the first step if the model is multi-step. Returns 5 for linear - models. - - Returns - ------- - int - """ - # pylint: disable=E1101 - if hasattr(self, '_gen'): - return self._gen.layers[0].rank - if hasattr(self, 'models'): - return self.models[0].input_dims - return 5 - - @property - def is_5d(self): - """Check if model expects spatiotemporal input""" - return self.input_dims == 5 - - @property - def is_4d(self): - """Check if model expects spatial only input""" - return self.input_dims == 4 - - # pylint: disable=E1101 - def get_s_enhance_from_layers(self): - """Compute factor by which model will enhance spatial resolution from - layer attributes. Used in model training during high res coarsening""" - s_enhance = None - if hasattr(self, '_gen'): - s_enhancements = [ - getattr(layer, '_spatial_mult', 1) - for layer in self._gen.layers - ] - s_enhance = int(np.prod(s_enhancements)) - return s_enhance - - # pylint: disable=E1101 - def get_t_enhance_from_layers(self): - """Compute factor by which model will enhance temporal resolution from - layer attributes. 
Used in model training during high res coarsening""" - t_enhance = None - if hasattr(self, '_gen'): - t_enhancements = [ - getattr(layer, '_temporal_mult', 1) - for layer in self._gen.layers - ] - t_enhance = int(np.prod(t_enhancements)) - return t_enhance - - @property - def s_enhance(self): - """Factor by which model will enhance spatial resolution. Used in - model training during high res coarsening and also in forward pass - routine to determine shape of needed exogenous data""" - models = getattr(self, 'models', [self]) - s_enhances = [m.meta.get('s_enhance', None) for m in models] - s_enhance = ( - self.get_s_enhance_from_layers() - if any(s is None for s in s_enhances) - else int(np.prod(s_enhances)) - ) - if len(models) == 1: - self.meta['s_enhance'] = s_enhance - return s_enhance - - @property - def t_enhance(self): - """Factor by which model will enhance temporal resolution. Used in - model training during high res coarsening and also in forward pass - routine to determine shape of needed exogenous data""" - models = getattr(self, 'models', [self]) - t_enhances = [m.meta.get('t_enhance', None) for m in models] - t_enhance = ( - self.get_t_enhance_from_layers() - if any(t is None for t in t_enhances) - else int(np.prod(t_enhances)) - ) - if len(models) == 1: - self.meta['t_enhance'] = t_enhance - return t_enhance - - @property - def s_enhancements(self): - """List of spatial enhancement factors. In the case of a single step - model this is just ``[self.s_enhance]``. This is used to determine - shapes of needed exogenous data in forward pass routine""" - if hasattr(self, 'models'): - return [model.s_enhance for model in self.models] - return [self.s_enhance] - - @property - def t_enhancements(self): - """List of temporal enhancement factors. In the case of a single step - model this is just ``[self.t_enhance]``. This is used to determine - shapes of needed exogenous data in forward pass routine""" - if hasattr(self, 'models'): - return [model.t_enhance for model in self.models] - return [self.t_enhance] - - @property - def input_resolution(self): - """Resolution of input data. Given as a dictionary - ``{'spatial': '...km', 'temporal': '...min'}``. The numbers are - required to be integers in the units specified. The units are not - strict as long as the resolution of the exogenous data, when extracting - exogenous data, is specified in the same units.""" - input_resolution = self.meta.get('input_resolution', None) - msg = 'model.input_resolution is None. This needs to be set.' - assert input_resolution is not None, msg - return input_resolution - - def _get_numerical_resolutions(self): - """Get the input and output resolutions without units. e.g. 
for - ``{"spatial": "30km", "temporal": "60min"}`` this returns - ``{"spatial": 30, "temporal": 60}``""" - ires_num = { - k: int(re.search(r'\d+', v).group(0)) - for k, v in self.input_resolution.items() - } - enhancements = {'spatial': self.s_enhance, 'temporal': self.t_enhance} - ores_num = {k: v // enhancements[k] for k, v in ires_num.items()} - return ires_num, ores_num - - def _ensure_valid_input_resolution(self): - """Ensure ehancement factors evenly divide input_resolution""" - - if self.input_resolution is None: - return - - ires_num, ores_num = self._get_numerical_resolutions() - s_enhance = self.meta['s_enhance'] - t_enhance = self.meta['t_enhance'] - check = ( - ires_num['temporal'] / ores_num['temporal'] == t_enhance - and ires_num['spatial'] / ores_num['spatial'] == s_enhance - ) - msg = ( - f'Enhancement factors (s_enhance={s_enhance}, ' - f't_enhance={t_enhance}) do not evenly divide ' - f'input resolution ({self.input_resolution})' - ) - if not check: - logger.error(msg) - raise RuntimeError(msg) - - def _ensure_valid_enhancement_factors(self): - """Ensure user provided enhancement factors are the same as those - computed from layer attributes""" - t_enhance = self.meta.get('t_enhance', None) - s_enhance = self.meta.get('s_enhance', None) - if s_enhance is None or t_enhance is None: - return - - layer_se = self.get_s_enhance_from_layers() - layer_te = self.get_t_enhance_from_layers() - layer_se = layer_se if layer_se is not None else self.meta['s_enhance'] - layer_te = layer_te if layer_te is not None else self.meta['t_enhance'] - msg = ( - f'Enhancement factors computed from layer attributes ' - f'(s_enhance={layer_se}, t_enhance={layer_te}) ' - f'conflict with user provided values (s_enhance={s_enhance}, ' - f't_enhance={t_enhance})' - ) - check = layer_se == s_enhance or layer_te == t_enhance - if not check: - logger.error(msg) - raise RuntimeError(msg) - - @property - def output_resolution(self): - """Resolution of output data. Given as a dictionary - {'spatial': '...km', 'temporal': '...min'}. This is computed from the - input resolution and the enhancement factors.""" - output_res = self.meta.get('output_resolution', None) - if self.input_resolution is not None and output_res is None: - ires_num, ores_num = self._get_numerical_resolutions() - output_res = { - k: v.replace(str(ires_num[k]), str(ores_num[k])) - for k, v in self.input_resolution.items() - } - self.meta['output_resolution'] = output_res - return output_res - - def _combine_fwp_input(self, low_res, exogenous_data=None): - """Combine exogenous_data at input resolution with low_res data prior - to forward pass through generator - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - exogenous_data : dict | ExoData | None - Special dictionary (class:`ExoData`) of exogenous feature data with - entries describing whether features should be combined at input, a - mid network layer, or with output. This doesn't have to include - the 'model' key since this data is for a single step model. 
- - Returns - ------- - low_res : np.ndarray - Low-resolution input data combined with exogenous_data, usually a - 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - if exogenous_data is None: - return low_res - - if ( - not isinstance(exogenous_data, ExoData) - and exogenous_data is not None - ): - exogenous_data = ExoData(exogenous_data) - - fnum_diff = len(self.lr_features) - low_res.shape[-1] - exo_feats = [] if fnum_diff <= 0 else self.lr_features[-fnum_diff:] - msg = ( - f'Provided exogenous_data: {exogenous_data} is missing some ' - f'required features ({exo_feats})' - ) - assert all(feature in exogenous_data for feature in exo_feats), msg - if exogenous_data is not None and fnum_diff > 0: - for feature in exo_feats: - exo_input = exogenous_data.get_combine_type_data( - feature, 'input' - ) - if exo_input is not None: - low_res = np.concatenate((low_res, exo_input), axis=-1) - - return low_res - - def _combine_fwp_output(self, hi_res, exogenous_data=None): - """Combine exogenous_data at output resolution with generated hi_res - data following forward pass output. - - Parameters - ---------- - hi_res : np.ndarray - High-resolution output data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - exogenous_data : dict | ExoData | None - Special dictionary (class:`ExoData`) of exogenous feature data with - entries describing whether features should be combined at input, a - mid network layer, or with output. This doesn't have to include - the 'model' key since this data is for a single step model. - - Returns - ------- - hi_res : np.ndarray - High-resolution output data combined with exogenous_data, usually a - 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - if exogenous_data is None: - return hi_res - - if ( - not isinstance(exogenous_data, ExoData) - and exogenous_data is not None - ): - exogenous_data = ExoData(exogenous_data) - - fnum_diff = len(self.hr_out_features) - hi_res.shape[-1] - exo_feats = [] if fnum_diff <= 0 else self.hr_out_features[-fnum_diff:] - msg = ( - 'Provided exogenous_data is missing some required features ' - f'({exo_feats})' - ) - assert all(feature in exogenous_data for feature in exo_feats), msg - if exogenous_data is not None and fnum_diff > 0: - for feature in exo_feats: - exo_output = exogenous_data.get_combine_type_data( - feature, 'output' - ) - if exo_output is not None: - hi_res = np.concatenate((hi_res, exo_output), axis=-1) - return hi_res - - @tf.function - def _combine_loss_input(self, high_res_true, high_res_gen): - """Combine exogenous feature data from high_res_true with high_res_gen - for loss calculation - - Parameters - ---------- - high_res_true : tf.Tensor - Ground truth high resolution spatiotemporal data. - high_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. 
- - Returns - ------- - high_res_gen : tf.Tensor - Same as input with exogenous data combined with high_res input - """ - if high_res_true.shape[-1] > high_res_gen.shape[-1]: - for feature in self.hr_exo_features: - f_idx = self.hr_exo_features.index(feature) - f_idx += len(self.hr_out_features) - exo_data = high_res_true[..., f_idx : f_idx + 1] - high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1) - return high_res_gen - - @property - @abstractmethod - def meta(self): - """Get meta data dictionary that defines how the model was created""" - - @property - def lr_features(self): - """Get a list of low-resolution features input to the generative model. - This includes low-resolution features that might be supplied - exogenously at inference time but that were in the low-res batches - during training""" - return self.meta.get('lr_features', []) - - @property - def hr_out_features(self): - """Get the list of high-resolution output feature names that the - generative model outputs.""" - return self.meta.get('hr_out_features', []) - - @property - def hr_exo_features(self): - """Get list of high-resolution exogenous filter names the model uses. - If the model has N concat or add layers this list will be the last N - features in the training features list. The ordering is assumed to be - the same as the order of concat or add layers. If training features is - [..., topo, sza], and the model has 2 concat or add layers, exo - features will be [topo, sza]. Topo will then be used in the first - concat layer and sza will be used in the second""" - # pylint: disable=E1101 - features = [] - if hasattr(self, '_gen'): - features = [ - layer.name - for layer in self._gen.layers - if isinstance(layer, (Sup3rAdder, Sup3rConcat)) - ] - return features - - @property - def smoothing(self): - """Value of smoothing parameter used in gaussian filtering of coarsened - high res data.""" - return self.meta.get('smoothing', None) - - @property - def smoothed_features(self): - """Get the list of smoothed input feature names that the generative - model was trained on.""" - return self.meta.get('smoothed_features', []) - - @property - def model_params(self): - """ - Model parameters, used to save model to disc - - Returns - ------- - dict - """ - return {'meta': self.meta} - - @property - def version_record(self): - """A record of important versions that this model was built with. 
- - Returns - ------- - dict - """ - return VERSION_RECORD - - def set_model_params(self, **kwargs): - """Set parameters used for training the model - - Parameters - ---------- - kwargs : dict - Keyword arguments including 'input_resolution', - 'lr_features', 'hr_exo_features', 'hr_out_features', - 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' - """ - - keys = ( - 'input_resolution', - 'lr_features', - 'hr_exo_features', - 'hr_out_features', - 'smoothed_features', - 's_enhance', - 't_enhance', - 'smoothing', - ) - keys = [k for k in keys if k in kwargs] - - hr_exo_feat = kwargs.get('hr_exo_features', []) - msg = ( - f'Expected high-res exo features {self.hr_exo_features} ' - f'based on model architecture but received "hr_exo_features" ' - f'from data handler: {hr_exo_feat}' - ) - assert list(self.hr_exo_features) == list(hr_exo_feat), msg - - for var in keys: - val = self.meta.get(var, None) - if val is None: - self.meta[var] = kwargs[var] - elif val != kwargs[var]: - msg = ( - 'Model was previously trained with {var}={} but ' - 'received new {var}={}'.format(val, kwargs[var], var=var) - ) - logger.warning(msg) - warn(msg) - - self._ensure_valid_enhancement_factors() - self._ensure_valid_input_resolution() - - def save_params(self, out_dir): - """ - Parameters - ---------- - out_dir : str - Directory to save linear model params. This directory will be - created if it does not already exist. - """ - if not os.path.exists(out_dir): - os.makedirs(out_dir, exist_ok=True) - - fp_params = os.path.join(out_dir, 'model_params.json') - with open( - fp_params, 'w', encoding=locale.getpreferredencoding(False) - ) as f: - params = self.model_params - json.dump(params, f, sort_keys=True, indent=2, default=safe_cast) +logger = logging.getLogger(__name__) # pylint: disable=E1101,W0201,E0203 class AbstractSingleModel(ABC, TensorboardMixIn): """ - Abstract class to define the required training interface - for Sup3r model subclasses + Abstract class to define the required training interface for Sup3r model + subclasses """ def __init__(self): @@ -1654,3 +1082,42 @@ def get_single_grad( loss, loss_details = loss_out grad = tape.gradient(loss, training_weights) return grad, loss_details + + @abstractmethod + def calc_loss( + self, + hi_res_true, + hi_res_gen, + obs_data=None, + weight_gen_advers=0.001, + train_gen=True, + train_disc=False, + ): + """Calculate the GAN loss function using generated and true high + resolution data. + + Parameters + ---------- + hi_res_true : tf.Tensor + Ground truth high resolution spatiotemporal data. + hi_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. 
+ train_gen : bool + True if generator is being trained, then loss=loss_gen + train_disc : bool + True if disc is being trained, then loss=loss_disc + + Returns + ------- + loss : tf.Tensor + 0D tensor representing the loss value for the network being trained + (either generator or one of the discriminators) + loss_details : dict + Namespace of the breakdown of loss components + """ diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 4ff617226..c6405deac 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -14,7 +14,8 @@ from sup3r.preprocessing.utilities import get_class_kwargs from sup3r.utilities import VERSION_RECORD -from .abstract import AbstractInterface, AbstractSingleModel +from .abstract import AbstractSingleModel +from .interface import AbstractInterface from .utilities import get_optimizer_class logger = logging.getLogger(__name__) diff --git a/sup3r/models/conditional.py b/sup3r/models/conditional.py index a5025419e..6faed1a6f 100644 --- a/sup3r/models/conditional.py +++ b/sup3r/models/conditional.py @@ -12,7 +12,8 @@ from sup3r.utilities import VERSION_RECORD -from .abstract import AbstractInterface, AbstractSingleModel +from .abstract import AbstractSingleModel +from .interface import AbstractInterface logger = logging.getLogger(__name__) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py new file mode 100644 index 000000000..6d18ef7d1 --- /dev/null +++ b/sup3r/models/interface.py @@ -0,0 +1,527 @@ +"""Abstract class defining the required interface for Sup3r model subclasses""" + +import json +import locale +import logging +import os +import pprint +import re +import time +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from inspect import signature +from warnings import warn + +import numpy as np +import tensorflow as tf +from phygnn import CustomNetwork +from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat +from rex.utilities.utilities import safe_json_load +from tensorflow.keras import optimizers + +import sup3r.utilities.loss_metrics +from sup3r.preprocessing.data_handlers import ExoData +from sup3r.preprocessing.utilities import numpy_if_tensor +from sup3r.utilities import VERSION_RECORD +from sup3r.utilities.utilities import safe_cast + +from .tensorboard import TensorboardMixIn + +logger = logging.getLogger(__name__) + + +class AbstractInterface(ABC): + """ + Abstract class to define the required interface for Sup3r model subclasses + + Note that this only sets the required interfaces for a GAN that can be + loaded from disk and used to predict synthetic outputs. The interface for + models that can be trained will be set in another class. + """ + + @classmethod + @abstractmethod + def load(cls, model_dir, verbose=True): + """Load the GAN with its sub-networks from a previously saved-to output + directory. + + Parameters + ---------- + model_dir + Directory to load GAN model files from. + verbose : bool + Flag to log information about the loaded model. + + Returns + ------- + out : BaseModel + Returns a pretrained gan model that was previously saved to + model_dir + """ + + @abstractmethod + def generate( + self, low_res, norm_in=True, un_norm_out=True, exogenous_data=None + ): + """Use the generator model to generate high res data from low res + input. This is the public generate function.""" + + @staticmethod + def seed(s=0): + """ + Set the random seed for reproducible results. 
+ + Parameters + ---------- + s : int + Random seed + """ + CustomNetwork.seed(s=s) + + @property + def input_dims(self): + """Get dimension of model generator input. This is usually 4D for + spatial models and 5D for spatiotemporal models. This gives the input + to the first step if the model is multi-step. Returns 5 for linear + models. + + Returns + ------- + int + """ + # pylint: disable=E1101 + if hasattr(self, '_gen'): + return self._gen.layers[0].rank + if hasattr(self, 'models'): + return self.models[0].input_dims + return 5 + + @property + def is_5d(self): + """Check if model expects spatiotemporal input""" + return self.input_dims == 5 + + @property + def is_4d(self): + """Check if model expects spatial only input""" + return self.input_dims == 4 + + # pylint: disable=E1101 + def get_s_enhance_from_layers(self): + """Compute factor by which model will enhance spatial resolution from + layer attributes. Used in model training during high res coarsening""" + s_enhance = None + if hasattr(self, '_gen'): + s_enhancements = [ + getattr(layer, '_spatial_mult', 1) + for layer in self._gen.layers + ] + s_enhance = int(np.prod(s_enhancements)) + return s_enhance + + # pylint: disable=E1101 + def get_t_enhance_from_layers(self): + """Compute factor by which model will enhance temporal resolution from + layer attributes. Used in model training during high res coarsening""" + t_enhance = None + if hasattr(self, '_gen'): + t_enhancements = [ + getattr(layer, '_temporal_mult', 1) + for layer in self._gen.layers + ] + t_enhance = int(np.prod(t_enhancements)) + return t_enhance + + @property + def s_enhance(self): + """Factor by which model will enhance spatial resolution. Used in + model training during high res coarsening and also in forward pass + routine to determine shape of needed exogenous data""" + models = getattr(self, 'models', [self]) + s_enhances = [m.meta.get('s_enhance', None) for m in models] + s_enhance = ( + self.get_s_enhance_from_layers() + if any(s is None for s in s_enhances) + else int(np.prod(s_enhances)) + ) + if len(models) == 1: + self.meta['s_enhance'] = s_enhance + return s_enhance + + @property + def t_enhance(self): + """Factor by which model will enhance temporal resolution. Used in + model training during high res coarsening and also in forward pass + routine to determine shape of needed exogenous data""" + models = getattr(self, 'models', [self]) + t_enhances = [m.meta.get('t_enhance', None) for m in models] + t_enhance = ( + self.get_t_enhance_from_layers() + if any(t is None for t in t_enhances) + else int(np.prod(t_enhances)) + ) + if len(models) == 1: + self.meta['t_enhance'] = t_enhance + return t_enhance + + @property + def s_enhancements(self): + """List of spatial enhancement factors. In the case of a single step + model this is just ``[self.s_enhance]``. This is used to determine + shapes of needed exogenous data in forward pass routine""" + if hasattr(self, 'models'): + return [model.s_enhance for model in self.models] + return [self.s_enhance] + + @property + def t_enhancements(self): + """List of temporal enhancement factors. In the case of a single step + model this is just ``[self.t_enhance]``. This is used to determine + shapes of needed exogenous data in forward pass routine""" + if hasattr(self, 'models'): + return [model.t_enhance for model in self.models] + return [self.t_enhance] + + @property + def input_resolution(self): + """Resolution of input data. Given as a dictionary + ``{'spatial': '...km', 'temporal': '...min'}``. 
The numbers are + required to be integers in the units specified. The units are not + strict as long as the resolution of the exogenous data, when extracting + exogenous data, is specified in the same units.""" + input_resolution = self.meta.get('input_resolution', None) + msg = 'model.input_resolution is None. This needs to be set.' + assert input_resolution is not None, msg + return input_resolution + + def _get_numerical_resolutions(self): + """Get the input and output resolutions without units. e.g. for + ``{"spatial": "30km", "temporal": "60min"}`` this returns + ``{"spatial": 30, "temporal": 60}``""" + ires_num = { + k: int(re.search(r'\d+', v).group(0)) + for k, v in self.input_resolution.items() + } + enhancements = {'spatial': self.s_enhance, 'temporal': self.t_enhance} + ores_num = {k: v // enhancements[k] for k, v in ires_num.items()} + return ires_num, ores_num + + def _ensure_valid_input_resolution(self): + """Ensure ehancement factors evenly divide input_resolution""" + + if self.input_resolution is None: + return + + ires_num, ores_num = self._get_numerical_resolutions() + s_enhance = self.meta['s_enhance'] + t_enhance = self.meta['t_enhance'] + check = ( + ires_num['temporal'] / ores_num['temporal'] == t_enhance + and ires_num['spatial'] / ores_num['spatial'] == s_enhance + ) + msg = ( + f'Enhancement factors (s_enhance={s_enhance}, ' + f't_enhance={t_enhance}) do not evenly divide ' + f'input resolution ({self.input_resolution})' + ) + if not check: + logger.error(msg) + raise RuntimeError(msg) + + def _ensure_valid_enhancement_factors(self): + """Ensure user provided enhancement factors are the same as those + computed from layer attributes""" + t_enhance = self.meta.get('t_enhance', None) + s_enhance = self.meta.get('s_enhance', None) + if s_enhance is None or t_enhance is None: + return + + layer_se = self.get_s_enhance_from_layers() + layer_te = self.get_t_enhance_from_layers() + layer_se = layer_se if layer_se is not None else self.meta['s_enhance'] + layer_te = layer_te if layer_te is not None else self.meta['t_enhance'] + msg = ( + f'Enhancement factors computed from layer attributes ' + f'(s_enhance={layer_se}, t_enhance={layer_te}) ' + f'conflict with user provided values (s_enhance={s_enhance}, ' + f't_enhance={t_enhance})' + ) + check = layer_se == s_enhance or layer_te == t_enhance + if not check: + logger.error(msg) + raise RuntimeError(msg) + + @property + def output_resolution(self): + """Resolution of output data. Given as a dictionary + {'spatial': '...km', 'temporal': '...min'}. 
This is computed from the + input resolution and the enhancement factors.""" + output_res = self.meta.get('output_resolution', None) + if self.input_resolution is not None and output_res is None: + ires_num, ores_num = self._get_numerical_resolutions() + output_res = { + k: v.replace(str(ires_num[k]), str(ores_num[k])) + for k, v in self.input_resolution.items() + } + self.meta['output_resolution'] = output_res + return output_res + + def _combine_fwp_input(self, low_res, exogenous_data=None): + """Combine exogenous_data at input resolution with low_res data prior + to forward pass through generator + + Parameters + ---------- + low_res : np.ndarray + Low-resolution input data, usually a 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + exogenous_data : dict | ExoData | None + Special dictionary (class:`ExoData`) of exogenous feature data with + entries describing whether features should be combined at input, a + mid network layer, or with output. This doesn't have to include + the 'model' key since this data is for a single step model. + + Returns + ------- + low_res : np.ndarray + Low-resolution input data combined with exogenous_data, usually a + 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + """ + if exogenous_data is None: + return low_res + + if ( + not isinstance(exogenous_data, ExoData) + and exogenous_data is not None + ): + exogenous_data = ExoData(exogenous_data) + + fnum_diff = len(self.lr_features) - low_res.shape[-1] + exo_feats = [] if fnum_diff <= 0 else self.lr_features[-fnum_diff:] + msg = ( + f'Provided exogenous_data: {exogenous_data} is missing some ' + f'required features ({exo_feats})' + ) + assert all(feature in exogenous_data for feature in exo_feats), msg + if exogenous_data is not None and fnum_diff > 0: + for feature in exo_feats: + exo_input = exogenous_data.get_combine_type_data( + feature, 'input' + ) + if exo_input is not None: + low_res = np.concatenate((low_res, exo_input), axis=-1) + + return low_res + + def _combine_fwp_output(self, hi_res, exogenous_data=None): + """Combine exogenous_data at output resolution with generated hi_res + data following forward pass output. + + Parameters + ---------- + hi_res : np.ndarray + High-resolution output data, usually a 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + exogenous_data : dict | ExoData | None + Special dictionary (class:`ExoData`) of exogenous feature data with + entries describing whether features should be combined at input, a + mid network layer, or with output. This doesn't have to include + the 'model' key since this data is for a single step model. 
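
As a concrete illustration of the input-side combination described above (toy arrays, and a plain dict standing in for the ``ExoData`` container): features the generator expects but the low-res batch lacks are appended along the trailing feature axis.

import numpy as np

# Toy setup: the generator expects three low-res features but the batch only
# carries two, so 'topography' is supplied exogenously at the input.
lr_features = ['u_100m', 'v_100m', 'topography']
low_res = np.random.rand(1, 10, 10, 6, 2)
exo_input = {'topography': np.random.rand(1, 10, 10, 6, 1)}

fnum_diff = len(lr_features) - low_res.shape[-1]
for feature in lr_features[-fnum_diff:]:
    low_res = np.concatenate((low_res, exo_input[feature]), axis=-1)

assert low_res.shape[-1] == len(lr_features)
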
+ + Returns + ------- + hi_res : np.ndarray + High-resolution output data combined with exogenous_data, usually a + 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + """ + if exogenous_data is None: + return hi_res + + if ( + not isinstance(exogenous_data, ExoData) + and exogenous_data is not None + ): + exogenous_data = ExoData(exogenous_data) + + fnum_diff = len(self.hr_out_features) - hi_res.shape[-1] + exo_feats = [] if fnum_diff <= 0 else self.hr_out_features[-fnum_diff:] + msg = ( + 'Provided exogenous_data is missing some required features ' + f'({exo_feats})' + ) + assert all(feature in exogenous_data for feature in exo_feats), msg + if exogenous_data is not None and fnum_diff > 0: + for feature in exo_feats: + exo_output = exogenous_data.get_combine_type_data( + feature, 'output' + ) + if exo_output is not None: + hi_res = np.concatenate((hi_res, exo_output), axis=-1) + return hi_res + + @tf.function + def _combine_loss_input(self, high_res_true, high_res_gen): + """Combine exogenous feature data from high_res_true with high_res_gen + for loss calculation + + Parameters + ---------- + high_res_true : tf.Tensor + Ground truth high resolution spatiotemporal data. + high_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. + + Returns + ------- + high_res_gen : tf.Tensor + Same as input with exogenous data combined with high_res input + """ + if high_res_true.shape[-1] > high_res_gen.shape[-1]: + for feature in self.hr_exo_features: + f_idx = self.hr_exo_features.index(feature) + f_idx += len(self.hr_out_features) + exo_data = high_res_true[..., f_idx : f_idx + 1] + high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1) + return high_res_gen + + @property + @abstractmethod + def meta(self): + """Get meta data dictionary that defines how the model was created""" + + @property + def lr_features(self): + """Get a list of low-resolution features input to the generative model. + This includes low-resolution features that might be supplied + exogenously at inference time but that were in the low-res batches + during training""" + return self.meta.get('lr_features', []) + + @property + def hr_out_features(self): + """Get the list of high-resolution output feature names that the + generative model outputs.""" + return self.meta.get('hr_out_features', []) + + @property + def hr_exo_features(self): + """Get list of high-resolution exogenous filter names the model uses. + If the model has N concat or add layers this list will be the last N + features in the training features list. The ordering is assumed to be + the same as the order of concat or add layers. If training features is + [..., topo, sza], and the model has 2 concat or add layers, exo + features will be [topo, sza]. 
Topo will then be used in the first + concat layer and sza will be used in the second""" + # pylint: disable=E1101 + features = [] + if hasattr(self, '_gen'): + features = [ + layer.name + for layer in self._gen.layers + if isinstance(layer, (Sup3rAdder, Sup3rConcat)) + ] + return features + + @property + def smoothing(self): + """Value of smoothing parameter used in gaussian filtering of coarsened + high res data.""" + return self.meta.get('smoothing', None) + + @property + def smoothed_features(self): + """Get the list of smoothed input feature names that the generative + model was trained on.""" + return self.meta.get('smoothed_features', []) + + @property + def model_params(self): + """ + Model parameters, used to save model to disc + + Returns + ------- + dict + """ + return {'meta': self.meta} + + @property + def version_record(self): + """A record of important versions that this model was built with. + + Returns + ------- + dict + """ + return VERSION_RECORD + + def set_model_params(self, **kwargs): + """Set parameters used for training the model + + Parameters + ---------- + kwargs : dict + Keyword arguments including 'input_resolution', + 'lr_features', 'hr_exo_features', 'hr_out_features', + 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' + """ + + keys = ( + 'input_resolution', + 'lr_features', + 'hr_exo_features', + 'hr_out_features', + 'smoothed_features', + 's_enhance', + 't_enhance', + 'smoothing', + ) + keys = [k for k in keys if k in kwargs] + + hr_exo_feat = kwargs.get('hr_exo_features', []) + msg = ( + f'Expected high-res exo features {self.hr_exo_features} ' + f'based on model architecture but received "hr_exo_features" ' + f'from data handler: {hr_exo_feat}' + ) + assert list(self.hr_exo_features) == list(hr_exo_feat), msg + + for var in keys: + val = self.meta.get(var, None) + if val is None: + self.meta[var] = kwargs[var] + elif val != kwargs[var]: + msg = ( + 'Model was previously trained with {var}={} but ' + 'received new {var}={}'.format(val, kwargs[var], var=var) + ) + logger.warning(msg) + warn(msg) + + self._ensure_valid_enhancement_factors() + self._ensure_valid_input_resolution() + + def save_params(self, out_dir): + """ + Parameters + ---------- + out_dir : str + Directory to save linear model params. This directory will be + created if it does not already exist. 
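
A short sketch of how the high-res exogenous feature ordering described above plays out at loss time (toy tensors; two output features followed by one exogenous feature): the exogenous channel is pulled from the ground-truth batch and appended to the generated output so the two tensors line up, mirroring the loss-input combination shown earlier.

import tensorflow as tf

hr_out_features = ['u_100m', 'v_100m']   # produced by the generator
hr_exo_features = ['topography']         # injected via a concat/add layer

high_res_true = tf.random.normal((1, 8, 8, 12, 3))  # outputs + topography
high_res_gen = tf.random.normal((1, 8, 8, 12, 2))   # outputs only

if high_res_true.shape[-1] > high_res_gen.shape[-1]:
    for feature in hr_exo_features:
        f_idx = hr_exo_features.index(feature) + len(hr_out_features)
        exo_data = high_res_true[..., f_idx : f_idx + 1]
        high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1)

assert high_res_gen.shape[-1] == high_res_true.shape[-1]
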
+ """ + if not os.path.exists(out_dir): + os.makedirs(out_dir, exist_ok=True) + + fp_params = os.path.join(out_dir, 'model_params.json') + with open( + fp_params, 'w', encoding=locale.getpreferredencoding(False) + ) as f: + params = self.model_params + json.dump(params, f, sort_keys=True, indent=2, default=safe_cast) diff --git a/sup3r/models/tensorboard.py b/sup3r/models/tensorboard.py new file mode 100644 index 000000000..00c576462 --- /dev/null +++ b/sup3r/models/tensorboard.py @@ -0,0 +1,85 @@ +"""Abstract class defining the required interface for Sup3r model subclasses""" + +import logging +import os + +import tensorflow as tf + +from sup3r.utilities.utilities import Timer + +logger = logging.getLogger(__name__) + + +class TensorboardMixIn: + """MixIn class for tensorboard logging and profiling.""" + + def __init__(self): + self._tb_writer = None + self._tb_log_dir = None + self._write_tb_profile = False + self._total_batches = None + self._history = None + self.timer = Timer() + + @property + def total_batches(self): + """Record of total number of batches for logging.""" + if self._total_batches is None and self._history is None: + self._total_batches = 0 + elif self._history is None and 'total_batches' in self._history: + self._total_batches = self._history['total_batches'].values[-1] + elif self._total_batches is None and self._history is not None: + self._total_batches = 0 + return self._total_batches + + @total_batches.setter + def total_batches(self, value): + """Set total number of batches.""" + self._total_batches = value + + def dict_to_tensorboard(self, entry): + """Write data to tensorboard log file. This is usually a loss_details + dictionary. + + Parameters + ---------- + entry: dict + Dictionary of values to write to tensorboard log file + """ + if self._tb_writer is not None: + with self._tb_writer.as_default(): + for name, value in entry.items(): + if isinstance(value, str): + tf.summary.text(name, value, self.total_batches) + else: + tf.summary.scalar(name, value, self.total_batches) + + def profile_to_tensorboard(self, name): + """Write profile data to tensorboard log file. + + Parameters + ---------- + name : str + Tag name to use for profile info + """ + if self._tb_writer is not None and self._write_tb_profile: + with self._tb_writer.as_default(): + tf.summary.trace_export( + name=name, + step=self.total_batches, + profiler_outdir=self._tb_log_dir, + ) + + def _init_tensorboard_writer(self, out_dir): + """Initialize the ``tf.summary.SummaryWriter`` to use for writing + tensorboard compatible log files. + + Parameters + ---------- + out_dir : str + Standard out_dir where model epochs are saved. e.g. 
'./gan_{epoch}' + """ + tb_log_pardir = os.path.abspath(os.path.join(out_dir, os.pardir)) + self._tb_log_dir = os.path.join(tb_log_pardir, 'logs') + os.makedirs(self._tb_log_dir, exist_ok=True) + self._tb_writer = tf.summary.create_file_writer(self._tb_log_dir) \ No newline at end of file diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index bc1897e17..1cfa0c7ea 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -12,7 +12,7 @@ import pprint from abc import ABCMeta from collections import namedtuple -from typing import Mapping, Tuple, Union +from typing import Dict, Mapping, Tuple, Union from warnings import warn import numpy as np @@ -183,11 +183,9 @@ def rewrap(self, data): """Rewrap data as ``Sup3rDataset`` after calling parent method.""" if isinstance(data, type(self)): return data - if len(data) == 2: - return type(self)(low_res=data[0], high_res=data[1]) - if len(data) == 3: - return type(self)(low_res=data[0], high_res=data[1], obs=data[2]) - return type(self)(high_res=data[0]) + if len(data) == 1: + return type(self)(high_res=data[0]) + return type(self)(**dict(zip(['low_res', 'high_res', 'obs'], data))) def sample(self, idx): """Get samples from ``self._ds`` members. idx should be either a tuple @@ -295,7 +293,12 @@ class Container(metaclass=Sup3rMeta): def __init__( self, data: Union[ - Sup3rX, Sup3rDataset, Tuple[Sup3rX, ...], Tuple[Sup3rDataset, ...] + Sup3rX, + Sup3rDataset, + Tuple[Sup3rX, ...], + Tuple[Sup3rDataset, ...], + Dict[str, Sup3rX], + Dict[str, Sup3rDataset], ] = None, ): """ @@ -363,25 +366,19 @@ def wrap(self, data): if is_type_of(data, Sup3rDataset): return data - if isinstance(data, tuple) and len(data) == 2: - msg = ( - f'{self.__class__.__name__}.data is being set with a ' - '2-tuple without explicit dataset names. We will assume ' - 'first tuple member is low-res and second is high-res.' - ) - logger.warning(msg) - warn(msg) - data = Sup3rDataset(low_res=data[0], high_res=data[1]) - elif isinstance(data, tuple) and len(data) == 3: + if isinstance(data, dict): + data = Sup3rDataset(**data) + + default_names = ['low_res', 'high_res', 'obs'] + if isinstance(data, tuple) and len(data) > 1: msg = ( f'{self.__class__.__name__}.data is being set with a ' - '3-tuple without explicit dataset names. We will assume ' - 'first tuple member is low-res, second is high-res, and third ' - 'is obs' + f'{len(data)}-tuple without explicit dataset names. We will ' + f'assume name ordering: {default_names[:len(data)]}' ) logger.warning(msg) warn(msg) - data = Sup3rDataset(low_res=data[0], high_res=data[1], obs=data[2]) + data = Sup3rDataset(**dict(zip(default_names, data))) elif not isinstance(data, Sup3rDataset): name = getattr(data, 'name', None) or 'high_res' data = Sup3rDataset(**{name: data}) diff --git a/tests/training/test_train_dual_with_obs.py b/tests/training/test_train_dual_with_obs.py index ae5a88357..01ab41e50 100644 --- a/tests/training/test_train_dual_with_obs.py +++ b/tests/training/test_train_dual_with_obs.py @@ -13,7 +13,6 @@ DataHandler, DualBatchHandlerWithObs, DualRasterizer, - Sup3rDataset, ) from sup3r.preprocessing.samplers import DualSamplerWithObs from sup3r.utilities.pytest.helpers import BatchHandlerTesterFactory @@ -46,9 +45,9 @@ def test_train_coarse_h5( fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, mode, n_epoch=2 ): - """Test model training with a dual data handler / batch handler with h5 and - era as hr / lr datasets with additional observation data used in extra - content loss. 
Tests both spatiotemporal and spatial models.""" + """Test model training with a dual data handler / batch handler with + additional sparse observation data used in extra content loss term. Tests + both spatiotemporal and spatial models.""" lr = 1e-5 kwargs = { @@ -84,11 +83,11 @@ def test_train_coarse_h5( obs_data[feat] = (obs_data[feat].dims, tmp) dual_with_obs = Container( - data=Sup3rDataset( - low_res=dual_rasterizer.low_res, - high_res=dual_rasterizer.high_res, - obs=obs_data, - ) + data={ + 'low_res': dual_rasterizer.low_res, + 'high_res': dual_rasterizer.high_res, + 'obs': obs_data, + } ) batch_handler = DualBatchHandlerWithObsTester( From ee79c133975496461b6394a0292d1fdbf6478d0e Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 22 Dec 2024 07:52:24 -0800 Subject: [PATCH 04/32] made dual batch queue flexible enough to account for additional obs member in data --- README.rst | 2 +- pyproject.toml | 15 ++++- sup3r/models/base.py | 2 +- sup3r/models/interface.py | 10 --- sup3r/models/linear.py | 2 +- sup3r/models/multi_step.py | 2 +- sup3r/models/tensorboard.py | 2 +- sup3r/preprocessing/__init__.py | 8 ++- sup3r/preprocessing/batch_handlers/factory.py | 6 +- sup3r/preprocessing/batch_queues/__init__.py | 3 +- sup3r/preprocessing/batch_queues/abstract.py | 18 +++-- sup3r/preprocessing/batch_queues/dual.py | 20 ++++-- sup3r/preprocessing/batch_queues/with_obs.py | 67 ------------------- sup3r/preprocessing/samplers/dual.py | 4 +- sup3r/preprocessing/samplers/with_obs.py | 27 ++++++-- 15 files changed, 82 insertions(+), 106 deletions(-) delete mode 100644 sup3r/preprocessing/batch_queues/with_obs.py diff --git a/README.rst b/README.rst index d3b0aea48..3d321accc 100644 --- a/README.rst +++ b/README.rst @@ -78,4 +78,4 @@ Brandon Benton, Grant Buster, Guilherme Pimenta Castelao, Malik Hassanaly, Pavlo Acknowledgments =============== -This work was authored by the National Renewable Energy Laboratory, operated by Alliance for Sustainable Energy, LLC, for the U.S. Department of Energy (DOE) under Contract No. DE-AC36-08GO28308. This research was supported by the Grid Modernization Initiative of the U.S. Department of Energy (DOE) as part of its Grid Modernization Laboratory Consortium, a strategic partnership between DOE and the national laboratories to bring together leading experts, technologies, and resources to collaborate on the goal of modernizing the nation’s grid. Funding provided by the the DOE Office of Energy Efficiency and Renewable Energy (EERE), the DOE Office of Electricity (OE), DOE Grid Deployment Office (GDO), the DOE Office of Fossil Energy and Carbon Management (FECM), and the DOE Office of Cybersecurity, Energy Security, and Emergency Response (CESER), the DOE Advanced Scientific Computing Research (ASCR) program, the DOE Solar Energy Technologies Office (SETO), the DOE Wind Energy Technologies Office (WETO), the United States Agency for International Development (USAID), and the Laboratory Directed Research and Development (LDRD) program at the National Renewable Energy Laboratory. The research was performed using computational resources sponsored by the Department of Energy's Office of Energy Efficiency and Renewable Energy and located at the National Renewable Energy Laboratory. The views expressed in the article do not necessarily represent the views of the DOE or the U.S. Government. The U.S. Government retains and the publisher, by accepting the article for publication, acknowledges that the U.S. 
Government retains a nonexclusive, paid-up, irrevocable, worldwide license to publish or reproduce the published form of this work, or allow others to do so, for U.S. Government purposes. +This work was authored by the National Renewable Energy Laboratory, operated by Alliance for Sustainable Energy, LLC, for the U.S. Department of Energy (DOE) under Contract No. DE-AC36-08GO28308. This research was supported by the Grid Modernization Initiative of the U.S. Department of Energy (DOE) as part of its Grid Modernization Laboratory Consortium, a strategic partnership between DOE and the national laboratories to bring together leading experts, technologies, and resources to collaborate on the goal of modernizing the nation’s grid. Funding provided by the the DOE Office of Energy Efficiency and Renewable Energy (EERE), the DOE Office of Electricity (OE), DOE Grid Deployment Office (GDO), the DOE Office of Fossil Energy and Carbon Management (FECM), and the DOE Office of Cybersecurity, Energy Security, and Emergency Response (CESER), the DOE Advanced Scientific Computing Research (ASCR) program, the DOE Solar Energy Technologies Office (SETO), the DOE Wind Energy Technologies Office (WETO), the United States Agency for International Development (USAID), and the Laboratory Directed Research and Development (LDRD) program at the National Renewable Energy Laboratory. The research was performed using computational resources sponsored by the Department of Energy's Office of Energy Efficiency and Renewable Energy and located at the National Renewable Energy Laboratory. The views expressed in the article do not necessarily represent the views of the DOE or the U.S. Government. The U.S. Government retains and the publisher, by accepting the article for publication, acknowledges that the U.S. Government retains a nonexclusive, paid-up, irrevocable, worldwide license to publish or reproduce the published form of this work, or allow others to do so, for U.S. Government purposes. 
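
Before the packaging and model changes below, a small sketch of the three-member dataset layout that the preprocessing changes above standardize on (plain numpy arrays standing in for the wrapped ``Sup3rX`` members; the obs member lives on the high-res grid and is NaN wherever no observation exists):

import numpy as np

low_res = np.random.rand(5, 5, 12, 2)
high_res = np.random.rand(20, 20, 48, 2)

# Sparse observations: NaN everywhere except a handful of observed points.
obs = np.full_like(high_res, np.nan)
idx = tuple(np.random.randint(0, s, 50) for s in high_res.shape)
obs[idx] = high_res[idx]

# Default naming applied when a bare tuple reaches Container / Sup3rDataset;
# a 2-tuple gets only the first two names, so the same path covers the plain
# dual case and the dual-with-observations case.
default_names = ['low_res', 'high_res', 'obs']
named = dict(zip(default_names, (low_res, high_res, obs)))
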
diff --git a/pyproject.toml b/pyproject.toml index 80bc45ec4..e1c92afaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,10 +42,21 @@ dependencies = [ "pytest>=5.2", "scipy>=1.0.0", "sphinx>=7.0", - "tensorflow>2.4,<2.16", "xarray>=2023.0" ] +# If used, cause glibc conflict +# [tool.pixi.target.linux-64.dependencies] +# cuda = ">=11.8" +# cudnn = {version = ">=8.6.0", channel = "conda-forge"} +# # 8.9.7 + +[tool.pixi.target.linux-64.pypi-dependencies] +tensorflow = {version = "~=2.15.1", extras = ["and-cuda"] } + +[tool.pixi.target.osx-arm64.dependencies] +tensorflow = {version = "~=2.15.0", channel = "conda-forge"} + [project.optional-dependencies] dev = [ "build>=0.5", @@ -272,7 +283,6 @@ matplotlib = ">=3.1" numpy = "~=1.7" pandas = ">=2.0" scipy = ">=1.0.0" -tensorflow = ">2.4,<2.16" xarray = ">=2023.0" [tool.pixi.pypi-dependencies] @@ -284,6 +294,7 @@ NREL-farms = { version = ">=1.0.4" } [tool.pixi.environments] default = { solve-group = "default" } +kestrel = { features = ["kestrel"], solve-group = "default" } dev = { features = ["dev", "doc", "test"], solve-group = "default" } doc = { features = ["doc"], solve-group = "default" } test = { features = ["test"], solve-group = "default" } diff --git a/sup3r/models/base.py b/sup3r/models/base.py index c6405deac..a3deb0243 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -877,7 +877,7 @@ def calc_loss( loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen) loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers - loss_obs = None + loss_obs = np.nan if obs_data is not None: mask = tf.math.is_nan(obs_data) loss_obs = self.loss_fun(obs_data[~mask], hi_res_gen[~mask]) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index 6d18ef7d1..c81065f86 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -4,29 +4,19 @@ import locale import logging import os -import pprint import re -import time from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor -from inspect import signature from warnings import warn import numpy as np import tensorflow as tf from phygnn import CustomNetwork from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat -from rex.utilities.utilities import safe_json_load -from tensorflow.keras import optimizers -import sup3r.utilities.loss_metrics from sup3r.preprocessing.data_handlers import ExoData -from sup3r.preprocessing.utilities import numpy_if_tensor from sup3r.utilities import VERSION_RECORD from sup3r.utilities.utilities import safe_cast -from .tensorboard import TensorboardMixIn - logger = logging.getLogger(__name__) diff --git a/sup3r/models/linear.py b/sup3r/models/linear.py index 2c00a02b5..8ca9f0a0a 100644 --- a/sup3r/models/linear.py +++ b/sup3r/models/linear.py @@ -6,7 +6,7 @@ import numpy as np -from .abstract import AbstractInterface +from .interface import AbstractInterface from .utilities import st_interp logger = logging.getLogger(__name__) diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index d447633cf..e0d33cca0 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -14,8 +14,8 @@ import sup3r.models from sup3r.preprocessing.data_handlers import ExoData -from .abstract import AbstractInterface from .base import Sup3rGan +from .interface import AbstractInterface logger = logging.getLogger(__name__) diff --git a/sup3r/models/tensorboard.py b/sup3r/models/tensorboard.py index 00c576462..4712c9522 100644 --- a/sup3r/models/tensorboard.py +++ 
b/sup3r/models/tensorboard.py @@ -82,4 +82,4 @@ def _init_tensorboard_writer(self, out_dir): tb_log_pardir = os.path.abspath(os.path.join(out_dir, os.pardir)) self._tb_log_dir = os.path.join(tb_log_pardir, 'logs') os.makedirs(self._tb_log_dir, exist_ok=True) - self._tb_writer = tf.summary.create_file_writer(self._tb_log_dir) \ No newline at end of file + self._tb_writer = tf.summary.create_file_writer(self._tb_log_dir) diff --git a/sup3r/preprocessing/__init__.py b/sup3r/preprocessing/__init__.py index 0832d8d3a..4420efba6 100644 --- a/sup3r/preprocessing/__init__.py +++ b/sup3r/preprocessing/__init__.py @@ -64,4 +64,10 @@ Rasterizer, SzaRasterizer, ) -from .samplers import DualSampler, DualSamplerCC, Sampler, SamplerDC +from .samplers import ( + DualSampler, + DualSamplerCC, + DualSamplerWithObs, + Sampler, + SamplerDC, +) diff --git a/sup3r/preprocessing/batch_handlers/factory.py b/sup3r/preprocessing/batch_handlers/factory.py index 4b6ddf3aa..a2b0be526 100644 --- a/sup3r/preprocessing/batch_handlers/factory.py +++ b/sup3r/preprocessing/batch_handlers/factory.py @@ -13,8 +13,10 @@ QueueMom2SepSF, QueueMom2SF, ) -from sup3r.preprocessing.batch_queues.dual import DualBatchQueue -from sup3r.preprocessing.batch_queues.with_obs import DualBatchQueueWithObs +from sup3r.preprocessing.batch_queues.dual import ( + DualBatchQueue, + DualBatchQueueWithObs, +) from sup3r.preprocessing.collections.stats import StatsCollection from sup3r.preprocessing.samplers.base import Sampler from sup3r.preprocessing.samplers.cc import DualSamplerCC diff --git a/sup3r/preprocessing/batch_queues/__init__.py b/sup3r/preprocessing/batch_queues/__init__.py index 067d21070..1e91caa1b 100644 --- a/sup3r/preprocessing/batch_queues/__init__.py +++ b/sup3r/preprocessing/batch_queues/__init__.py @@ -11,5 +11,4 @@ QueueMom2SF, ) from .dc import BatchQueueDC, ValBatchQueueDC -from .dual import DualBatchQueue -from .with_obs import DualBatchQueueWithObs +from .dual import DualBatchQueue, DualBatchQueueWithObs diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 772a172c6..6bc09afdc 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -22,7 +22,11 @@ from sup3r.utilities.utilities import RANDOM_GENERATOR, Timer if TYPE_CHECKING: - from sup3r.preprocessing.samplers import DualSampler, Sampler + from sup3r.preprocessing.samplers import ( + DualSampler, + DualSamplerWithObs, + Sampler, + ) logger = logging.getLogger(__name__) @@ -36,7 +40,9 @@ class AbstractBatchQueue(Collection, ABC): def __init__( self, - samplers: Union[List['Sampler'], List['DualSampler']], + samplers: Union[ + List['Sampler'], List['DualSampler'], List['DualSamplerWithObs'] + ], batch_size: int = 16, n_batches: int = 64, s_enhance: int = 1, @@ -183,10 +189,12 @@ def post_proc(self, samples) -> Batch: Returns ------- Batch : namedtuple - namedtuple with `low_res` and `high_res` attributes + namedtuple with `low_res` and `high_res` attributes. Could also + include additional members for subclass queues + (i.e. 
``DualBatchQueueWithObs``) """ - lr, hr = self.transform(samples, **self.transform_kwargs) - return self.Batch(low_res=lr, high_res=hr) + tsamps = self.transform(samples, **self.transform_kwargs) + return self.Batch(**dict(zip(self.Batch._fields, tsamps))) def start(self) -> None: """Start thread to keep sample queue full for batches.""" diff --git a/sup3r/preprocessing/batch_queues/dual.py b/sup3r/preprocessing/batch_queues/dual.py index 691350cf7..549145f42 100644 --- a/sup3r/preprocessing/batch_queues/dual.py +++ b/sup3r/preprocessing/batch_queues/dual.py @@ -2,6 +2,7 @@ interface with models.""" import logging +from collections import namedtuple from scipy.ndimage import gaussian_filter @@ -28,10 +29,10 @@ def __init__(self, samplers, **kwargs): @property def queue_shape(self): """Shape of objects stored in the queue.""" - return [ - (self.batch_size, *self.lr_shape), - (self.batch_size, *self.hr_shape), - ] + queue_shapes = [(self.batch_size, *self.lr_shape)] + hr_mems = len(self.Batch._fields) - 1 + queue_shapes += [(self.batch_size, *self.hr_shape)] * hr_mems + return queue_shapes def check_enhancement_factors(self): """Make sure each DualSampler has the same enhancment factors and they @@ -58,7 +59,7 @@ def transform(self, samples, smoothing=None, smoothing_ignore=None): This does not include temporal or spatial coarsening like :class:`SingleBatchQueue` """ - low_res, high_res = samples + low_res = samples[0] if smoothing is not None: feat_iter = [ @@ -71,4 +72,11 @@ def transform(self, samples, smoothing=None, smoothing_ignore=None): low_res[i, ..., j] = gaussian_filter( low_res[i, ..., j], smoothing, mode='nearest' ) - return low_res, high_res + return low_res, *samples[1:] + + +class DualBatchQueueWithObs(DualBatchQueue): + """BatchQueue for use with + :class:`~sup3r.preprocessing.samplers.DualSamplerWithObs` objects.""" + + Batch = namedtuple('Batch', ['low_res', 'high_res', 'obs']) diff --git a/sup3r/preprocessing/batch_queues/with_obs.py b/sup3r/preprocessing/batch_queues/with_obs.py deleted file mode 100644 index 0a0a0890e..000000000 --- a/sup3r/preprocessing/batch_queues/with_obs.py +++ /dev/null @@ -1,67 +0,0 @@ -"""DualBatchQueue with additional observation data on the same grid as the -high-res data. The observation data is sampled with the same index as the -high-res data during training.""" - -import logging -from collections import namedtuple - -from scipy.ndimage import gaussian_filter - -from .dual import DualBatchQueue - -logger = logging.getLogger(__name__) - - -class DualBatchQueueWithObs(DualBatchQueue): - """Base BatchQueue for use with - :class:`~sup3r.preprocessing.samplers.DualSamplerWithObs` objects.""" - - Batch = namedtuple('Batch', ['low_res', 'high_res', 'obs']) - - _signature_objs = (DualBatchQueue,) - - @property - def queue_shape(self): - """Shape of objects stored in the queue.""" - return [ - (self.batch_size, *self.lr_shape), - (self.batch_size, *self.hr_shape), - (self.batch_size, *self.hr_shape), - ] - - def transform(self, samples, smoothing=None, smoothing_ignore=None): - """Perform smoothing if requested. 
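
The queue generalization above boils down to letting the ``Batch`` namedtuple fields drive how many members flow through ``post_proc``; a tiny sketch with placeholder arrays:

from collections import namedtuple

import numpy as np

# Three fields when observations are present, two otherwise; the same
# zip-based construction handles both cases.
Batch = namedtuple('Batch', ['low_res', 'high_res', 'obs'])

tsamps = (
    np.zeros((16, 5, 5, 12, 2)),     # low-res batch
    np.zeros((16, 20, 20, 48, 2)),   # high-res batch
    np.zeros((16, 20, 20, 48, 2)),   # obs batch on the high-res grid
)
batch = Batch(**dict(zip(Batch._fields, tsamps)))
assert batch.obs.shape == batch.high_res.shape
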
- - Note - ---- - This does not include temporal or spatial coarsening like - :class:`SingleBatchQueue` - """ - low_res, high_res, obs = samples - - if smoothing is not None: - feat_iter = [ - j - for j in range(low_res.shape[-1]) - if self.features[j] not in smoothing_ignore - ] - for i in range(low_res.shape[0]): - for j in feat_iter: - low_res[i, ..., j] = gaussian_filter( - low_res[i, ..., j], smoothing, mode='nearest' - ) - return low_res, high_res, obs - - def post_proc(self, samples) -> Batch: - """Performs some post proc on dequeued samples before sending out for - training. Post processing can include coarsening on high-res data (if - :class:`Collection` consists of :class:`Sampler` objects and not - :class:`DualSampler` objects), smoothing, etc - - Returns - ------- - Batch : namedtuple - namedtuple with `low_res`, `high_res`, and `obs` attributes - """ - lr, hr, obs = self.transform(samples, **self.transform_kwargs) - return self.Batch(low_res=lr, high_res=hr, obs=obs) diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index bf259dff3..008e0577f 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -109,11 +109,11 @@ def check_for_consistent_shapes(self): self.lr_data.shape[2] * self.t_enhance, ) msg = ( - f'hr_data.shape {self.hr_data.shape} and enhanced ' + f'hr_data.shape {self.hr_data.shape[:-1]} and enhanced ' f'lr_data.shape {enhanced_shape} are not compatible with ' 'the given enhancement factors' ) - assert self.hr_data.shape[:3] == enhanced_shape, msg + assert self.hr_data.shape[:-1] == enhanced_shape, msg def get_sample_index(self, n_obs=None): """Get paired sample index, consisting of index for the low res sample diff --git a/sup3r/preprocessing/samplers/with_obs.py b/sup3r/preprocessing/samplers/with_obs.py index 267794b9b..1f74f8d80 100644 --- a/sup3r/preprocessing/samplers/with_obs.py +++ b/sup3r/preprocessing/samplers/with_obs.py @@ -1,10 +1,14 @@ """Extended Sampler for sampling observation data in addition to standard gridded training data.""" +import logging from typing import Dict, Optional from sup3r.preprocessing.base import Sup3rDataset -from sup3r.preprocessing.samplers.dual import DualSampler + +from .dual import DualSampler + +logger = logging.getLogger(__name__) class DualSamplerWithObs(DualSampler): @@ -19,7 +23,7 @@ def __init__( sample_shape: Optional[tuple] = None, batch_size: int = 16, s_enhance: int = 1, - t_enhance: int = 24, + t_enhance: int = 1, feature_sets: Optional[Dict] = None, ): """ @@ -27,8 +31,7 @@ def __init__( ---------- data : Sup3rDataset A :class:`~sup3r.preprocessing.base.Sup3rDataset` instance with - low-res, high-res, and obs data members. The observation data is on - the same grid as the high-res data. + low-res and high-res data members sample_shape : tuple Size of arrays to sample from the high-res data. The sample shape for the low-res sampler will be determined from the enhancement @@ -51,6 +54,22 @@ def __init__( output from the generative model. An example is high-res topography that is to be injected mid-network. 
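
The quick arithmetic behind the relaxed shape check above (assumed enhancement factors): every high-res dimension except the feature channel must equal the corresponding low-res dimension times the spatial or temporal enhancement factor.

# Assumed toy shapes and factors, not values from the patch.
lr_shape = (25, 25, 12, 3)          # (south_north, west_east, time, features)
s_enhance, t_enhance = 4, 6
enhanced_shape = (
    lr_shape[0] * s_enhance,
    lr_shape[1] * s_enhance,
    lr_shape[2] * t_enhance,
)
hr_shape = (100, 100, 72, 3)
assert hr_shape[:-1] == enhanced_shape
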
""" + + msg = ( + f'{self.__class__.__name__} requires a Sup3rDataset object ' + 'with `.low_res`, `.high_res`, and `.obs` data members, in that ' + 'order' + ) + assert ( + hasattr(data, 'low_res') + and hasattr(data, 'high_res') + and hasattr(data, 'obs') + ), msg + assert ( + data.low_res == data[0] + and data.high_res == data[1] + and data.obs == data[2] + ), msg super().__init__( data, sample_shape=sample_shape, From 1b61a66356d34662e140959026680f3c18ad80fd Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 22 Dec 2024 12:26:42 -0800 Subject: [PATCH 05/32] tensorboard mixin moved to model utilities. dual queue completely absorbed WithObs variant. --- sup3r/models/abstract.py | 2 +- sup3r/models/tensorboard.py | 85 ------------------ sup3r/models/utilities.py | 79 +++++++++++++++++ sup3r/preprocessing/__init__.py | 7 +- sup3r/preprocessing/batch_handlers/factory.py | 10 +-- sup3r/preprocessing/batch_queues/__init__.py | 2 +- sup3r/preprocessing/batch_queues/abstract.py | 4 +- sup3r/preprocessing/batch_queues/dual.py | 8 +- sup3r/preprocessing/samplers/__init__.py | 3 +- sup3r/preprocessing/samplers/dual.py | 77 ++++++++++++++++ sup3r/preprocessing/samplers/with_obs.py | 88 ------------------- 11 files changed, 166 insertions(+), 199 deletions(-) delete mode 100644 sup3r/models/tensorboard.py delete mode 100644 sup3r/preprocessing/samplers/with_obs.py diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index dc94bc52a..17836deb2 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -23,7 +23,7 @@ from sup3r.utilities import VERSION_RECORD from sup3r.utilities.utilities import safe_cast -from .tensorboard import TensorboardMixIn +from .utilities import TensorboardMixIn logger = logging.getLogger(__name__) diff --git a/sup3r/models/tensorboard.py b/sup3r/models/tensorboard.py deleted file mode 100644 index 4712c9522..000000000 --- a/sup3r/models/tensorboard.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Abstract class defining the required interface for Sup3r model subclasses""" - -import logging -import os - -import tensorflow as tf - -from sup3r.utilities.utilities import Timer - -logger = logging.getLogger(__name__) - - -class TensorboardMixIn: - """MixIn class for tensorboard logging and profiling.""" - - def __init__(self): - self._tb_writer = None - self._tb_log_dir = None - self._write_tb_profile = False - self._total_batches = None - self._history = None - self.timer = Timer() - - @property - def total_batches(self): - """Record of total number of batches for logging.""" - if self._total_batches is None and self._history is None: - self._total_batches = 0 - elif self._history is None and 'total_batches' in self._history: - self._total_batches = self._history['total_batches'].values[-1] - elif self._total_batches is None and self._history is not None: - self._total_batches = 0 - return self._total_batches - - @total_batches.setter - def total_batches(self, value): - """Set total number of batches.""" - self._total_batches = value - - def dict_to_tensorboard(self, entry): - """Write data to tensorboard log file. This is usually a loss_details - dictionary. 
- - Parameters - ---------- - entry: dict - Dictionary of values to write to tensorboard log file - """ - if self._tb_writer is not None: - with self._tb_writer.as_default(): - for name, value in entry.items(): - if isinstance(value, str): - tf.summary.text(name, value, self.total_batches) - else: - tf.summary.scalar(name, value, self.total_batches) - - def profile_to_tensorboard(self, name): - """Write profile data to tensorboard log file. - - Parameters - ---------- - name : str - Tag name to use for profile info - """ - if self._tb_writer is not None and self._write_tb_profile: - with self._tb_writer.as_default(): - tf.summary.trace_export( - name=name, - step=self.total_batches, - profiler_outdir=self._tb_log_dir, - ) - - def _init_tensorboard_writer(self, out_dir): - """Initialize the ``tf.summary.SummaryWriter`` to use for writing - tensorboard compatible log files. - - Parameters - ---------- - out_dir : str - Standard out_dir where model epochs are saved. e.g. './gan_{epoch}' - """ - tb_log_pardir = os.path.abspath(os.path.join(out_dir, os.pardir)) - self._tb_log_dir = os.path.join(tb_log_pardir, 'logs') - os.makedirs(self._tb_log_dir, exist_ok=True) - self._tb_writer = tf.summary.create_file_writer(self._tb_log_dir) diff --git a/sup3r/models/utilities.py b/sup3r/models/utilities.py index 8e825f124..68e0f0102 100644 --- a/sup3r/models/utilities.py +++ b/sup3r/models/utilities.py @@ -1,13 +1,17 @@ """Utilities shared across the `sup3r.models` module""" import logging +import os import sys import threading import numpy as np +import tensorflow as tf from scipy.interpolate import RegularGridInterpolator from tensorflow.keras import optimizers +from sup3r.utilities.utilities import Timer + logger = logging.getLogger(__name__) @@ -58,6 +62,81 @@ def run(self): model_thread.join() +class TensorboardMixIn: + """MixIn class for tensorboard logging and profiling.""" + + def __init__(self): + self._tb_writer = None + self._tb_log_dir = None + self._write_tb_profile = False + self._total_batches = None + self._history = None + self.timer = Timer() + + @property + def total_batches(self): + """Record of total number of batches for logging.""" + if self._total_batches is None and self._history is None: + self._total_batches = 0 + elif self._history is None and 'total_batches' in self._history: + self._total_batches = self._history['total_batches'].values[-1] + elif self._total_batches is None and self._history is not None: + self._total_batches = 0 + return self._total_batches + + @total_batches.setter + def total_batches(self, value): + """Set total number of batches.""" + self._total_batches = value + + def dict_to_tensorboard(self, entry): + """Write data to tensorboard log file. This is usually a loss_details + dictionary. + + Parameters + ---------- + entry: dict + Dictionary of values to write to tensorboard log file + """ + if self._tb_writer is not None: + with self._tb_writer.as_default(): + for name, value in entry.items(): + if isinstance(value, str): + tf.summary.text(name, value, self.total_batches) + else: + tf.summary.scalar(name, value, self.total_batches) + + def profile_to_tensorboard(self, name): + """Write profile data to tensorboard log file. 
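
A short usage sketch of the logging helper above (hypothetical log directory and values): numeric entries are written with ``tf.summary.scalar`` and strings with ``tf.summary.text``.

import tensorflow as tf

writer = tf.summary.create_file_writer('./logs_example')  # hypothetical path
loss_details = {'loss_gen': 0.21, 'loss_disc': 0.68, 'note': 'first epoch'}
step = 0

with writer.as_default():
    for name, value in loss_details.items():
        if isinstance(value, str):
            tf.summary.text(name, value, step)
        else:
            tf.summary.scalar(name, value, step)
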
+ + Parameters + ---------- + name : str + Tag name to use for profile info + """ + if self._tb_writer is not None and self._write_tb_profile: + with self._tb_writer.as_default(): + tf.summary.trace_export( + name=name, + step=self.total_batches, + profiler_outdir=self._tb_log_dir, + ) + + def _init_tensorboard_writer(self, out_dir): + """Initialize the ``tf.summary.SummaryWriter`` to use for writing + tensorboard compatible log files. + + Parameters + ---------- + out_dir : str + Standard out_dir where model epochs are saved. e.g. './gan_{epoch}' + """ + tb_log_pardir = os.path.abspath(os.path.join(out_dir, os.pardir)) + self._tb_log_dir = os.path.join(tb_log_pardir, 'logs') + os.makedirs(self._tb_log_dir, exist_ok=True) + self._tb_writer = tf.summary.create_file_writer(self._tb_log_dir) + + def get_optimizer_class(conf): """Get optimizer class from keras""" if hasattr(optimizers, conf['name']): diff --git a/sup3r/preprocessing/__init__.py b/sup3r/preprocessing/__init__.py index 4420efba6..0070fad12 100644 --- a/sup3r/preprocessing/__init__.py +++ b/sup3r/preprocessing/__init__.py @@ -35,12 +35,7 @@ DualBatchHandler, DualBatchHandlerWithObs, ) -from .batch_queues import ( - BatchQueueDC, - DualBatchQueue, - DualBatchQueueWithObs, - SingleBatchQueue, -) +from .batch_queues import BatchQueueDC, DualBatchQueue, SingleBatchQueue from .cachers import Cacher from .collections import Collection, StatsCollection from .data_handlers import ( diff --git a/sup3r/preprocessing/batch_handlers/factory.py b/sup3r/preprocessing/batch_handlers/factory.py index a2b0be526..c43e9198a 100644 --- a/sup3r/preprocessing/batch_handlers/factory.py +++ b/sup3r/preprocessing/batch_handlers/factory.py @@ -13,15 +13,11 @@ QueueMom2SepSF, QueueMom2SF, ) -from sup3r.preprocessing.batch_queues.dual import ( - DualBatchQueue, - DualBatchQueueWithObs, -) +from sup3r.preprocessing.batch_queues.dual import DualBatchQueue from sup3r.preprocessing.collections.stats import StatsCollection from sup3r.preprocessing.samplers.base import Sampler from sup3r.preprocessing.samplers.cc import DualSamplerCC -from sup3r.preprocessing.samplers.dual import DualSampler -from sup3r.preprocessing.samplers.with_obs import DualSamplerWithObs +from sup3r.preprocessing.samplers.dual import DualSampler, DualSamplerWithObs from sup3r.preprocessing.utilities import ( check_signatures, get_class_kwargs, @@ -321,7 +317,7 @@ def stop(self): ) DualBatchHandlerWithObs = BatchHandlerFactory( - DualBatchQueueWithObs, DualSamplerWithObs, name='DualBatchHandlerWithObs' + DualBatchQueue, DualSamplerWithObs, name='DualBatchHandlerWithObs' ) BatchHandlerCC = BatchHandlerFactory( diff --git a/sup3r/preprocessing/batch_queues/__init__.py b/sup3r/preprocessing/batch_queues/__init__.py index 1e91caa1b..63053f123 100644 --- a/sup3r/preprocessing/batch_queues/__init__.py +++ b/sup3r/preprocessing/batch_queues/__init__.py @@ -11,4 +11,4 @@ QueueMom2SF, ) from .dc import BatchQueueDC, ValBatchQueueDC -from .dual import DualBatchQueue, DualBatchQueueWithObs +from .dual import DualBatchQueue diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 6bc09afdc..01c6d98c4 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -190,8 +190,8 @@ def post_proc(self, samples) -> Batch: ------- Batch : namedtuple namedtuple with `low_res` and `high_res` attributes. Could also - include additional members for subclass queues - (i.e. 
``DualBatchQueueWithObs``) + include additional members for integration with + ``DualSamplerWithObs`` """ tsamps = self.transform(samples, **self.transform_kwargs) return self.Batch(**dict(zip(self.Batch._fields, tsamps))) diff --git a/sup3r/preprocessing/batch_queues/dual.py b/sup3r/preprocessing/batch_queues/dual.py index 549145f42..4b64c593d 100644 --- a/sup3r/preprocessing/batch_queues/dual.py +++ b/sup3r/preprocessing/batch_queues/dual.py @@ -21,6 +21,7 @@ def __init__(self, samplers, **kwargs): -------- :class:`~sup3r.preprocessing.batch_queues.abstract.AbstractBatchQueue` """ + self.Batch = namedtuple('Batch', samplers[0]._fields) super().__init__(samplers, **kwargs) self.check_enhancement_factors() @@ -73,10 +74,3 @@ def transform(self, samples, smoothing=None, smoothing_ignore=None): low_res[i, ..., j], smoothing, mode='nearest' ) return low_res, *samples[1:] - - -class DualBatchQueueWithObs(DualBatchQueue): - """BatchQueue for use with - :class:`~sup3r.preprocessing.samplers.DualSamplerWithObs` objects.""" - - Batch = namedtuple('Batch', ['low_res', 'high_res', 'obs']) diff --git a/sup3r/preprocessing/samplers/__init__.py b/sup3r/preprocessing/samplers/__init__.py index 990e23861..545da5a3c 100644 --- a/sup3r/preprocessing/samplers/__init__.py +++ b/sup3r/preprocessing/samplers/__init__.py @@ -8,5 +8,4 @@ from .base import Sampler from .cc import DualSamplerCC from .dc import SamplerDC -from .dual import DualSampler -from .with_obs import DualSamplerWithObs +from .dual import DualSampler, DualSamplerWithObs diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index 008e0577f..af144b8ca 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -137,3 +137,80 @@ def get_sample_index(self, n_obs=None): ] hr_index = (*hr_index, self.hr_features) return (lr_index, hr_index) + + +class DualSamplerWithObs(DualSampler): + """Dual Sampler which also samples from extra observation data. The + observation data is on the same grid as the high-resolution data but + includes NaNs at points where observation data doesn't exist. This will + be used in an additional content loss term.""" + + def __init__( + self, + data: Sup3rDataset, + sample_shape: Optional[tuple] = None, + batch_size: int = 16, + s_enhance: int = 1, + t_enhance: int = 1, + feature_sets: Optional[Dict] = None, + ): + """ + Parameters + ---------- + data : Sup3rDataset + A :class:`~sup3r.preprocessing.base.Sup3rDataset` instance with + low-res and high-res data members + sample_shape : tuple + Size of arrays to sample from the high-res data. The sample shape + for the low-res sampler will be determined from the enhancement + factors. + s_enhance : int + Spatial enhancement factor + t_enhance : int + Temporal enhancement factor + feature_sets : Optional[dict] + Optional dictionary describing how the full set of features is + split between `lr_only_features` and `hr_exo_features`. + + lr_only_features : list | tuple + List of feature names or patt*erns that should only be + included in the low-res training set and not the high-res + observations. + hr_exo_features : list | tuple + List of feature names or patt*erns that should be included + in the high-resolution observation but not expected to be + output from the generative model. An example is high-res + topography that is to be injected mid-network. 
+ """ + + msg = ( + f'{self.__class__.__name__} requires a Sup3rDataset object ' + 'with `.low_res`, `.high_res`, and `.obs` data members, in that ' + 'order' + ) + assert ( + hasattr(data, 'low_res') + and hasattr(data, 'high_res') + and hasattr(data, 'obs') + ), msg + assert ( + data.low_res == data[0] + and data.high_res == data[1] + and data.obs == data[2] + ), msg + super().__init__( + data, + sample_shape=sample_shape, + batch_size=batch_size, + s_enhance=s_enhance, + t_enhance=t_enhance, + feature_sets=feature_sets, + ) + + def get_sample_index(self, n_obs=None): + """Get paired sample index, consisting of index for the low res sample + and the index for the high res sample with the same spatiotemporal + extent, with an additional index (same as the index for the high-res + data) for the observation data""" + lr_index, hr_index = super().get_sample_index(n_obs=n_obs) + return (lr_index, hr_index, hr_index) diff --git a/sup3r/preprocessing/samplers/with_obs.py b/sup3r/preprocessing/samplers/with_obs.py deleted file mode 100644 index 1f74f8d80..000000000 --- a/sup3r/preprocessing/samplers/with_obs.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Extended Sampler for sampling observation data in addition to standard -gridded training data.""" - -import logging -from typing import Dict, Optional - -from sup3r.preprocessing.base import Sup3rDataset - -from .dual import DualSampler - -logger = logging.getLogger(__name__) - - -class DualSamplerWithObs(DualSampler): - """Dual Sampler which also samples from extra observation data. The - observation data is on the same grid as the high-resolution data but - includes NaNs at points where observation data doesn't exist. This will - be used in an additional content loss term.""" - - def __init__( - self, - data: Sup3rDataset, - sample_shape: Optional[tuple] = None, - batch_size: int = 16, - s_enhance: int = 1, - t_enhance: int = 1, - feature_sets: Optional[Dict] = None, - ): - """ - Parameters - ---------- - data : Sup3rDataset - A :class:`~sup3r.preprocessing.base.Sup3rDataset` instance with - low-res and high-res data members - sample_shape : tuple - Size of arrays to sample from the high-res data. The sample shape - for the low-res sampler will be determined from the enhancement - factors. - s_enhance : int - Spatial enhancement factor - t_enhance : int - Temporal enhancement factor - feature_sets : Optional[dict] - Optional dictionary describing how the full set of features is - split between `lr_only_features` and `hr_exo_features`. - - lr_only_features : list | tuple - List of feature names or patt*erns that should only be - included in the low-res training set and not the high-res - observations. - hr_exo_features : list | tuple - List of feature names or patt*erns that should be included - in the high-resolution observation but not expected to be - output from the generative model. An example is high-res - topography that is to be injected mid-network. 
- """ - - msg = ( - f'{self.__class__.__name__} requires a Sup3rDataset object ' - 'with `.low_res`, `.high_res`, and `.obs` data members, in that ' - 'order' - ) - assert ( - hasattr(data, 'low_res') - and hasattr(data, 'high_res') - and hasattr(data, 'obs') - ), msg - assert ( - data.low_res == data[0] - and data.high_res == data[1] - and data.obs == data[2] - ), msg - super().__init__( - data, - sample_shape=sample_shape, - batch_size=batch_size, - s_enhance=s_enhance, - t_enhance=t_enhance, - feature_sets=feature_sets, - ) - - def get_sample_index(self, n_obs=None): - """Get paired sample index, consisting of index for the low res sample - and the index for the high res sample with the same spatiotemporal - extent, with an additional index (same as the index for the high-res - data) for the observation data""" - lr_index, hr_index = super().get_sample_index(n_obs=n_obs) - return (lr_index, hr_index, hr_index) From df698e11e5d946a85bce79c70d165e6cf5496bcf Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 22 Dec 2024 19:13:50 -0800 Subject: [PATCH 06/32] integrated dual sampler with obs into base dual sampler. --- sup3r/preprocessing/__init__.py | 2 - .../preprocessing/batch_handlers/__init__.py | 1 - sup3r/preprocessing/batch_handlers/factory.py | 6 +- sup3r/preprocessing/batch_queues/abstract.py | 2 +- sup3r/preprocessing/rasterizers/dual.py | 4 +- sup3r/preprocessing/samplers/__init__.py | 2 +- sup3r/preprocessing/samplers/dual.py | 102 +++------------- tests/training/test_train_dual_with_obs.py | 113 +++++++++++++++++- 8 files changed, 135 insertions(+), 97 deletions(-) diff --git a/sup3r/preprocessing/__init__.py b/sup3r/preprocessing/__init__.py index 0070fad12..315e9b79a 100644 --- a/sup3r/preprocessing/__init__.py +++ b/sup3r/preprocessing/__init__.py @@ -33,7 +33,6 @@ BatchHandlerMom2SepSF, BatchHandlerMom2SF, DualBatchHandler, - DualBatchHandlerWithObs, ) from .batch_queues import BatchQueueDC, DualBatchQueue, SingleBatchQueue from .cachers import Cacher @@ -62,7 +61,6 @@ from .samplers import ( DualSampler, DualSamplerCC, - DualSamplerWithObs, Sampler, SamplerDC, ) diff --git a/sup3r/preprocessing/batch_handlers/__init__.py b/sup3r/preprocessing/batch_handlers/__init__.py index d66b10126..08bba8d6b 100644 --- a/sup3r/preprocessing/batch_handlers/__init__.py +++ b/sup3r/preprocessing/batch_handlers/__init__.py @@ -11,5 +11,4 @@ BatchHandlerMom2SepSF, BatchHandlerMom2SF, DualBatchHandler, - DualBatchHandlerWithObs, ) diff --git a/sup3r/preprocessing/batch_handlers/factory.py b/sup3r/preprocessing/batch_handlers/factory.py index c43e9198a..3a65b64bb 100644 --- a/sup3r/preprocessing/batch_handlers/factory.py +++ b/sup3r/preprocessing/batch_handlers/factory.py @@ -17,7 +17,7 @@ from sup3r.preprocessing.collections.stats import StatsCollection from sup3r.preprocessing.samplers.base import Sampler from sup3r.preprocessing.samplers.cc import DualSamplerCC -from sup3r.preprocessing.samplers.dual import DualSampler, DualSamplerWithObs +from sup3r.preprocessing.samplers.dual import DualSampler from sup3r.preprocessing.utilities import ( check_signatures, get_class_kwargs, @@ -316,10 +316,6 @@ def stop(self): DualBatchQueue, DualSampler, name='DualBatchHandler' ) -DualBatchHandlerWithObs = BatchHandlerFactory( - DualBatchQueue, DualSamplerWithObs, name='DualBatchHandlerWithObs' -) - BatchHandlerCC = BatchHandlerFactory( DualBatchQueue, DualSamplerCC, name='BatchHandlerCC' ) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 
01c6d98c4..3f8876db7 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -4,7 +4,7 @@ (1) Figure out apparent "blocking" issue with threaded enqueue batches. max_workers=1 is the fastest? (2) Setup distributed data handling so this can work with data distributed - over multiple nodes. + over multiple nodes. """ import logging diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index 777f6b5dc..e968ff27e 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -76,7 +76,9 @@ def __init__( self.s_enhance = s_enhance self.t_enhance = t_enhance if isinstance(data, tuple): - data = Sup3rDataset(low_res=data[0], high_res=data[1]) + data = {'low_res': data[0], 'high_res': data[1]} + if isinstance(data, dict): + data = Sup3rDataset(data) msg = ( 'The DualRasterizer requires either a data tuple with two ' 'members, low and high resolution in that order, or a ' diff --git a/sup3r/preprocessing/samplers/__init__.py b/sup3r/preprocessing/samplers/__init__.py index 545da5a3c..e281616d5 100644 --- a/sup3r/preprocessing/samplers/__init__.py +++ b/sup3r/preprocessing/samplers/__init__.py @@ -8,4 +8,4 @@ from .base import Sampler from .cc import DualSamplerCC from .dc import SamplerDC -from .dual import DualSampler, DualSamplerWithObs +from .dual import DualSampler diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index af144b8ca..b4bf11eeb 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -16,7 +16,10 @@ class DualSampler(Sampler): """Sampler for sampling from paired (or dual) datasets. Pairs consist of - low and high resolution data, which are contained by a Sup3rDataset.""" + low and high resolution data, which are contained by a Sup3rDataset. This + can also include extra observation data on the same grid as the + high-resolution data which has NaNs at points where observation data + doesn't exist. This will be used in an additional content loss term.""" def __init__( self, @@ -32,7 +35,7 @@ def __init__( ---------- data : Sup3rDataset A :class:`~sup3r.preprocessing.base.Sup3rDataset` instance with - low-res and high-res data members + low-res and high-res data members, and optionally an obs member. sample_shape : tuple Size of arrays to sample from the high-res data. The sample shape for the low-res sampler will be determined from the enhancement @@ -57,10 +60,15 @@ def __init__( """ msg = ( f'{self.__class__.__name__} requires a Sup3rDataset object ' - 'with `.low_res` and `.high_res` data members, in that order' + 'with `.low_res` and `.high_res` data members, and optionally an ' + '`.obs` member, in that order' ) - assert hasattr(data, 'low_res') and hasattr(data, 'high_res'), msg - assert data.low_res == data[0] and data.high_res == data[1], msg + check = hasattr(data, 'low_res') and hasattr(data, 'high_res') + check = check and data.low_res == data[0] and data.high_res == data[1] + if len(data) == 2: + check = check and (hasattr(data, 'obs') and data.obs == data[2]) + assert check, msg + super().__init__( data=data, sample_shape=sample_shape, batch_size=batch_size ) @@ -118,7 +126,8 @@ def check_for_consistent_shapes(self): def get_sample_index(self, n_obs=None): """Get paired sample index, consisting of index for the low res sample and the index for the high res sample with the same spatiotemporal - extent.""" + extent. 
Optionally includes an extra high res index if the sample data + includes observation data.""" n_obs = n_obs or self.batch_size spatial_slice = uniform_box_sampler( self.lr_data.shape, self.lr_sample_shape[:2] @@ -136,81 +145,8 @@ def get_sample_index(self, n_obs=None): for s in lr_index[2:-1] ] hr_index = (*hr_index, self.hr_features) - return (lr_index, hr_index) - -class DualSamplerWithObs(DualSampler): - """Dual Sampler which also samples from extra observation data. The - observation data is on the same grid as the high-resolution data but - includes NaNs at points where observation data doesn't exist. This will - be used in an additional content loss term.""" - - def __init__( - self, - data: Sup3rDataset, - sample_shape: Optional[tuple] = None, - batch_size: int = 16, - s_enhance: int = 1, - t_enhance: int = 1, - feature_sets: Optional[Dict] = None, - ): - """ - Parameters - ---------- - data : Sup3rDataset - A :class:`~sup3r.preprocessing.base.Sup3rDataset` instance with - low-res and high-res data members - sample_shape : tuple - Size of arrays to sample from the high-res data. The sample shape - for the low-res sampler will be determined from the enhancement - factors. - s_enhance : int - Spatial enhancement factor - t_enhance : int - Temporal enhancement factor - feature_sets : Optional[dict] - Optional dictionary describing how the full set of features is - split between `lr_only_features` and `hr_exo_features`. - - lr_only_features : list | tuple - List of feature names or patt*erns that should only be - included in the low-res training set and not the high-res - observations. - hr_exo_features : list | tuple - List of feature names or patt*erns that should be included - in the high-resolution observation but not expected to be - output from the generative model. An example is high-res - topography that is to be injected mid-network. 
- """ - - msg = ( - f'{self.__class__.__name__} requires a Sup3rDataset object ' - 'with `.low_res`, `.high_res`, and `.obs` data members, in that ' - 'order' - ) - assert ( - hasattr(data, 'low_res') - and hasattr(data, 'high_res') - and hasattr(data, 'obs') - ), msg - assert ( - data.low_res == data[0] - and data.high_res == data[1] - and data.obs == data[2] - ), msg - super().__init__( - data, - sample_shape=sample_shape, - batch_size=batch_size, - s_enhance=s_enhance, - t_enhance=t_enhance, - feature_sets=feature_sets, - ) - - def get_sample_index(self, n_obs=None): - """Get paired sample index, consisting of index for the low res sample - and the index for the high res sample with the same spatiotemporal - extent, with an additional index (same as the index for the high-res - data) for the observation data""" - lr_index, hr_index = super().get_sample_index(n_obs=n_obs) - return (lr_index, hr_index, hr_index) + sample_index = (lr_index, hr_index) + if hasattr(self.data, 'obs'): + sample_index += (hr_index,) + return sample_index diff --git a/tests/training/test_train_dual_with_obs.py b/tests/training/test_train_dual_with_obs.py index 01ab41e50..48a9c7d67 100644 --- a/tests/training/test_train_dual_with_obs.py +++ b/tests/training/test_train_dual_with_obs.py @@ -11,10 +11,10 @@ from sup3r.preprocessing import ( Container, DataHandler, - DualBatchHandlerWithObs, + DualBatchHandler, DualRasterizer, ) -from sup3r.preprocessing.samplers import DualSamplerWithObs +from sup3r.preprocessing.samplers import DualSampler from sup3r.utilities.pytest.helpers import BatchHandlerTesterFactory TARGET_COORD = (39.01, -105.15) @@ -22,10 +22,112 @@ DualBatchHandlerWithObsTester = BatchHandlerTesterFactory( - DualBatchHandlerWithObs, DualSamplerWithObs + DualBatchHandler, DualSampler ) +@pytest.mark.parametrize( + [ + 'fp_gen', + 'fp_disc', + 's_enhance', + 't_enhance', + 'sample_shape', + 'mode', + ], + [ + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'lazy'), + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'eager'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'lazy'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'eager'), + ], +) +def test_train_h5_nc( + fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, mode, n_epoch=2 +): + """Test model training with a dual data handler / batch handler with h5 and + era as hr / lr datasets. 
Tests both spatiotemporal and spatial models.""" + + lr = 1e-5 + kwargs = { + 'features': FEATURES, + 'target': TARGET_COORD, + 'shape': (20, 20), + } + hr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + time_slice=slice(None, None, 1), + ) + + lr_handler = DataHandler( + pytest.FP_ERA, + features=FEATURES, + time_slice=slice(None, None, t_enhance), + ) + + dual_rasterizer = DualRasterizer( + data=(lr_handler.data, hr_handler.data), + s_enhance=s_enhance, + t_enhance=t_enhance, + ) + obs_data = dual_rasterizer.high_res.copy() + for feat in FEATURES: + tmp = np.full(obs_data[feat].shape, np.nan) + lat_ids = list(range(0, 20, 4)) + lon_ids = list(range(0, 20, 4)) + for ilat, ilon in itertools.product(lat_ids, lon_ids): + tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] + obs_data[feat] = (obs_data[feat].dims, tmp) + + dual_with_obs = Container( + data={ + 'low_res': dual_rasterizer.low_res, + 'high_res': dual_rasterizer.high_res, + 'obs': obs_data, + } + ) + + batch_handler = DualBatchHandlerWithObsTester( + train_containers=[dual_with_obs], + val_containers=[], + sample_shape=sample_shape, + batch_size=3, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=3, + mode=mode, + ) + + for batch in batch_handler: + assert hasattr(batch, 'obs') + assert not np.isnan(batch.obs).all() + assert np.isnan(batch.obs).any() + + Sup3rGan.seed() + model = Sup3rGan( + fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' + ) + + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': n_epoch, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batch_handler, **model_kwargs) + + tlossg = model.history['train_loss_gen'].values + tlosso = model.history['train_loss_obs'].values + assert np.sum(np.diff(tlossg)) < 0 + assert np.sum(np.diff(tlosso)) < 0 + + @pytest.mark.parametrize( [ 'fp_gen', @@ -101,6 +203,11 @@ def test_train_coarse_h5( mode=mode, ) + for batch in batch_handler: + assert hasattr(batch, 'obs') + assert not np.isnan(batch.obs).all() + assert np.isnan(batch.obs).any() + Sup3rGan.seed() model = Sup3rGan( fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' From 0cb253c945f29f6d8c2a7cfb683d596e864f524a Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 23 Dec 2024 07:30:58 -0800 Subject: [PATCH 07/32] examples added to DataHandler doc string. Some instructions on sup3rwind model training added to examples README. --- examples/sup3rwind/README.rst | 39 ++++++++++++- sup3r/preprocessing/base.py | 9 +-- sup3r/preprocessing/batch_queues/abstract.py | 9 +-- sup3r/preprocessing/cachers/base.py | 2 +- sup3r/preprocessing/data_handlers/base.py | 25 +++++++++ sup3r/preprocessing/rasterizers/dual.py | 13 +++-- sup3r/preprocessing/samplers/dual.py | 15 +++-- sup3r/utilities/utilities.py | 2 +- tests/data/extract_raster_wtk.py | 59 -------------------- 9 files changed, 85 insertions(+), 88 deletions(-) delete mode 100644 tests/data/extract_raster_wtk.py diff --git a/examples/sup3rwind/README.rst b/examples/sup3rwind/README.rst index 8ca30dc57..4cb0c5fa3 100644 --- a/examples/sup3rwind/README.rst +++ b/examples/sup3rwind/README.rst @@ -2,7 +2,7 @@ Sup3rWind Examples ################### -Super-Resolution for Renewable Energy Resource Data with Wind from Reanalysis Data (Sup3rWind) is one application of the sup3r software. 
In this work, we train generative models to create high-resolution (2km 5-minute) wind data based on coarse (30km hourly) ERA5 data. The generative models and high-resolution output data is publicly available via the `Open Energy Data Initiative (OEDI) `__ and via HSDS at the bucket ``nrel-pds-hsds`` and path ``/nrel/wtk/sup3rwind``. This data covers recent historical time periods for an expanding selection of countries. +Super-Resolution for Renewable Energy Resource Data with Wind from Reanalysis Data (Sup3rWind) is one application of the sup3r software. In this work, we train generative models to create high-resolution (2km 5-minute) wind data based on coarse (30km hourly) ERA5 data. The generative models, high-resolution output data, and training data is publicly available via the `Open Energy Data Initiative (OEDI) `__ and via HSDS at the bucket ``nrel-pds-hsds`` and path ``/nrel/wtk/sup3rwind``. This data covers recent historical time periods for an expanding selection of countries. Sup3rWind Data Access ---------------------- @@ -11,8 +11,8 @@ The Sup3rWind data and models are publicly available in a public AWS S3 bucket. The Sup3rWind data is also loaded into `HSDS `__ so that you may stream the data via the `NREL developer API `__ or your own HSDS server. This is the best option if you're not going to want a full annual dataset. See these `rex instructions `__ for more details on how to access this data with HSDS and rex. -Example Sup3rWind Data Usage ------------------------------ +Sup3rWind Data Usage +--------------------- Sup3rWind data can be used in generally the same way as `Sup3rCC `__ data, with the condition that Sup3rWind includes only wind data and ancillary variables for modeling wind energy generation. Refer to the Sup3rCC `example notebook `__ for usage patterns. @@ -32,6 +32,39 @@ The process for running the Sup3rWind models is much the same as for `Sup3rCC `__. This data is for training the spatial enhancement models only. The 2024-01 `models `__ perform spatial enhancement in two steps, 3x from ERA5 to coarsened WTK and 5x from coarsened WTK to uncoarsened WTK. The currently used approach performs spatial enhancement in a single 15x step. + +For a given year and training domain, initialize low-resolution and high-resolution data handlers and wrap these in a dual rasterizer object. Do this for as many years and training regions as desired, and use these containers to initialize a batch handler. To train models for 3x spatial enhancement use ``hr_spatial_coarsen=5`` in the ``hr_dh``. To train models for 15x (the currently used approach) ``hr_spatial_coarsen=1``. (Refer to tests and docs for information on additional arguments, denoted by the ellipses):: + + from sup3r.preprocessing import DataHandler, DualBatchHandler, DualRasterizer + containers = [] + for tdir in training_dirs: + lr_dh = DataHandler(f"{tdir}/lr_*.h5", ...) + hr_dh = DataHandler(f"{tdir}/hr_*.h5", hr_spatial_coarsen=...) + container = DualRasterizer({'low_res': lr_dh, 'high_res': hr_dh}, ...) + containers.append(container) + bh = DualBatchHandler(train_containers=containers, ...) + +To train a 5x model use the ``hr_*.h5`` files for both the ``lr_dh`` and the ``hr_dh``. Use ``hr_spatial_coarsen=3`` in the ``lr_dh`` and ``hr_spatial_coarsen=1`` in the ``hr_dh``:: + + for tdir in training_dirs: + lr_dh = DataHandler(f"{tdir}/hr_*.h5", hr_spatial_coarsen=3, ...) + hr_dh = DataHandler(f"{tdir}/hr_*.h5", hr_spatial_coarsen=1, ...) + container = DualRasterizer({'low_res': lr_dh, 'high_res': hr_dh}, ...) 
+ containers.append(container) + bh = DualBatchHandler(train_containers=containers, ...) + + +Initialize a 3x, 5x, or 15x spatial enhancement model, with 14 output channels, and train for the desired number of epochs. (The 3x and 5x generator configs can be copied from the ``model_params.json`` files in each OEDI model `directory `__. The 15x generator config can be created from the OEDI model configs by changing the spatial enhancement factor or from the configs in the repo by changing the enhancement factor and the number of output channels):: + + from sup3r.models import Sup3rGan + model = Sup3rGan(gen_layers="./gen_config.json", disc_layers="./disc_config.json", ...) + model.train(batch_handler, ...) + + Sup3rWind Versions ------------------- diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index 1cfa0c7ea..a6e9074e9 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -116,6 +116,8 @@ class Sup3rDataset: from the high_res non coarsened variable. """ + DSET_NAMES = ('low_res', 'high_res', 'obs') + def __init__( self, **dsets: Mapping[str, Union[xr.Dataset, Sup3rX]], @@ -185,7 +187,7 @@ def rewrap(self, data): return data if len(data) == 1: return type(self)(high_res=data[0]) - return type(self)(**dict(zip(['low_res', 'high_res', 'obs'], data))) + return type(self)(**dict(zip(self.DSET_NAMES, data))) def sample(self, idx): """Get samples from ``self._ds`` members. idx should be either a tuple @@ -369,16 +371,15 @@ def wrap(self, data): if isinstance(data, dict): data = Sup3rDataset(**data) - default_names = ['low_res', 'high_res', 'obs'] if isinstance(data, tuple) and len(data) > 1: msg = ( f'{self.__class__.__name__}.data is being set with a ' f'{len(data)}-tuple without explicit dataset names. We will ' - f'assume name ordering: {default_names[:len(data)]}' + f'assume name ordering: {Sup3rDataset.DSET_NAMES[:len(data)]}' ) logger.warning(msg) warn(msg) - data = Sup3rDataset(**dict(zip(default_names, data))) + data = Sup3rDataset(**dict(zip(Sup3rDataset.DSET_NAMES, data))) elif not isinstance(data, Sup3rDataset): name = getattr(data, 'name', None) or 'high_res' data = Sup3rDataset(**{name: data}) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 3f8876db7..a9423fe7e 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -22,11 +22,7 @@ from sup3r.utilities.utilities import RANDOM_GENERATOR, Timer if TYPE_CHECKING: - from sup3r.preprocessing.samplers import ( - DualSampler, - DualSamplerWithObs, - Sampler, - ) + from sup3r.preprocessing.samplers import DualSampler, Sampler logger = logging.getLogger(__name__) @@ -41,8 +37,7 @@ class AbstractBatchQueue(Collection, ABC): def __init__( self, samplers: Union[ - List['Sampler'], List['DualSampler'], List['DualSamplerWithObs'] - ], + List['Sampler'], List['DualSampler']], batch_size: int = 16, n_batches: int = 64, s_enhance: int = 1, diff --git a/sup3r/preprocessing/cachers/base.py b/sup3r/preprocessing/cachers/base.py index be380b673..4df526039 100644 --- a/sup3r/preprocessing/cachers/base.py +++ b/sup3r/preprocessing/cachers/base.py @@ -330,7 +330,7 @@ def write_h5( ] if Dimension.TIME in data: - data[Dimension.TIME] = data[Dimension.TIME].astype(int) + data[Dimension.TIME] = data[Dimension.TIME].astype('int64') for dset in [*coord_names, *features]: data_var, chunksizes = cls.get_chunksizes(dset, data, chunks) diff --git a/sup3r/preprocessing/data_handlers/base.py 
b/sup3r/preprocessing/data_handlers/base.py index 452518d63..3d8f73125 100644 --- a/sup3r/preprocessing/data_handlers/base.py +++ b/sup3r/preprocessing/data_handlers/base.py @@ -145,6 +145,31 @@ def __init__( Dictionary of additional keyword args for :class:`~sup3r.preprocessing.rasterizers.Rasterizer`, used specifically for rasterizing flattened data + + Examples + -------- + Extract windspeed at 40m and 80m above the ground from files for u/v at + 10m and 100m. Windspeed will be interpolated from surrounding levels + using a log profile. ``dh`` will contain dask arrays of this data with + 10x10x50 chunk sizes. Data will be cached to files named + 'windspeed_40m.h5' and 'windspeed_80m.h5' in './cache_dir' with + 5x5x10 chunks on disk. + >>> cache_chunks = {'south_north': 5, 'west_east': 5, 'time': 10} + >>> load_chunks = {'south_north': 10, 'west_east': 10, 'time': 50} + >>> grid_size = (50, 50) + >>> lower_left_coordinate = (39.7, -105.2) + >>> dh = DataHandler( + ... file_paths=['./data_dir/u_10m.nc', './data_dir/u_100m.nc', + ... './data_dir/v_10m.nc', './data_dir/v_100m.nc'], + ... features=['windspeed_40m', 'windspeed_80m'], + ... shape=grid_size, time_slice=slice(0, 100), + ... target=lower_left_coordinate, hr_spatial_coarsen=2, + ... chunks=load_chunks, interp_kwargs={'method': 'log'}, + ... cache_kwargs={'cache_pattern': './cache_dir/{feature}.h5', + ... 'chunks': cache_chunks}) + + Derive more features from already initialized data handler: + >>> dh['windspeed_60m'] = dh.derive('windspeed_60m') """ # pylint: disable=line-too-long features = parse_to_list(features=features) diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index e968ff27e..47706aa4e 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -2,7 +2,7 @@ datasets""" import logging -from typing import Tuple, Union +from typing import Dict, Tuple, Union from warnings import warn import numpy as np @@ -38,7 +38,9 @@ class DualRasterizer(Container): @log_args def __init__( self, - data: Union[Sup3rDataset, Tuple[xr.Dataset, xr.Dataset]], + data: Union[ + Sup3rDataset, Tuple[xr.Dataset, xr.Dataset], Dict[str, xr.Dataset] + ], regrid_workers=1, regrid_lr=True, s_enhance=1, @@ -51,7 +53,8 @@ def __init__( Parameters ---------- - data : Sup3rDataset | Tuple[xr.Dataset, xr.Dataset] + data : Sup3rDataset | Tuple[xr.Dataset, xr.Dataset] | + Dict[str, xr.Dataset] A tuple of xr.Dataset instances. The first must be low-res and the second must be high-res data regrid_workers : int | None @@ -78,9 +81,9 @@ def __init__( if isinstance(data, tuple): data = {'low_res': data[0], 'high_res': data[1]} if isinstance(data, dict): - data = Sup3rDataset(data) + data = Sup3rDataset(**data) msg = ( - 'The DualRasterizer requires either a data tuple with two ' + 'The DualRasterizer requires a data tuple or dictionary with two ' 'members, low and high resolution in that order, or a ' f'Sup3rDataset instance. Received {type(data)}.' 
) diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index b4bf11eeb..62e5c4233 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -63,10 +63,11 @@ def __init__( 'with `.low_res` and `.high_res` data members, and optionally an ' '`.obs` member, in that order' ) - check = hasattr(data, 'low_res') and hasattr(data, 'high_res') - check = check and data.low_res == data[0] and data.high_res == data[1] - if len(data) == 2: - check = check and (hasattr(data, 'obs') and data.obs == data[2]) + dnames = ['low_res', 'high_res', 'obs'][:len(data)] + check = ( + hasattr(data, dname) and getattr(data, dname) == data[i] + for i, dname in enumerate(dnames) + ) assert check, msg super().__init__( @@ -146,7 +147,5 @@ def get_sample_index(self, n_obs=None): ] hr_index = (*hr_index, self.hr_features) - sample_index = (lr_index, hr_index) - if hasattr(self.data, 'obs'): - sample_index += (hr_index,) - return sample_index + sample_index = (lr_index, hr_index, hr_index) + return sample_index[:len(self.data)] diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 9b7be514a..3a563119a 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -26,7 +26,7 @@ def preprocess_datasets(dset): dset.indexes['time'], 'to_datetimeindex' ): dset['time'] = dset.indexes['time'].to_datetimeindex() - ti = dset['time'].astype(int) + ti = dset['time'].astype('int64') dset['time'] = ti if 'latitude' in dset.dims: dset = dset.swap_dims({'latitude': 'south_north'}) diff --git a/tests/data/extract_raster_wtk.py b/tests/data/extract_raster_wtk.py deleted file mode 100644 index ed826d134..000000000 --- a/tests/data/extract_raster_wtk.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Script to extract data subset in raster shape from flattened WTK h5 files. - -TODO: Is this worth keeping for any reason? 
-""" - -import matplotlib.pyplot as plt -from rex import init_logger -from rex.outputs import Outputs -from rex.resource_extraction.resource_extraction import WindX - -if __name__ == '__main__': - init_logger('rex', log_level='DEBUG') - - res_fp = '/datasets/WIND/conus/v1.0.0/wtk_conus_2013.h5' - fout = './test_wtk_co_2013.h5' - - dsets = ['windspeed_80m', 'windspeed_100m', 'winddirection_80m', - 'winddirection_100m', 'temperature_100m', 'pressure_100m'] - - target = (39.0, -105.15) - shape = (20, 20) - - with WindX(res_fp) as res: - meta = res.meta - raster_index_2d = res.get_raster_index(target, shape, max_delta=20) - - for d in ('elevation', 'latitude', 'longitude'): - data = meta[d].values[raster_index_2d] - a = plt.imshow(data) - plt.colorbar(a, label=d) - plt.savefig(d + '.png') - plt.close() - - raster_index = sorted(raster_index_2d.ravel()) - - attrs = {k: res.resource.attrs[k] for k in dsets} - chunks = dict.fromkeys(dsets) - dtypes = {k: res.resource.dtypes[k] for k in dsets} - meta = meta.iloc[raster_index].reset_index(drop=True) - time_index = res.time_index - shapes = {k: (len(time_index), len(meta)) for k in dsets} - print(shapes) - - Outputs.init_h5(fout, dsets, shapes, attrs, chunks, dtypes, meta, - time_index=time_index) - - with Outputs(fout, mode='a') as f: - for d in dsets: - f[d] = res[d, :, raster_index] - - with Outputs(fout, mode='r') as f: - meta = f.meta - for d in dsets: - data = f[d].mean(axis=0) - data = data.reshape(shape) - a = plt.imshow(data) - plt.colorbar(a, label=d) - plt.savefig(d + '.png') - plt.close() From 13664d1173bf889d7c6902db6e2dd5b545fc17b3 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 26 Dec 2024 14:22:32 -0800 Subject: [PATCH 08/32] removed namedtuple from Sup3rDataset to make Sup3rDataset picklable. --- sup3r/models/abstract.py | 14 +++--- sup3r/preprocessing/base.py | 25 ++++++++++- sup3r/preprocessing/batch_queues/abstract.py | 47 ++++++++++++-------- sup3r/preprocessing/cachers/base.py | 1 + sup3r/utilities/pytest/helpers.py | 27 +++++------ 5 files changed, 74 insertions(+), 40 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 17836deb2..506f85b38 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -684,9 +684,9 @@ def finish_epoch( """ self.log_loss_details(loss_details) self._history.at[epoch, 'elapsed_time'] = time.time() - t0 - for key, value in loss_details.items(): - if key != 'n_obs': - self._history.at[epoch, key] = value + cols = [k for k in loss_details if k != 'n_obs'] + entry = np.vstack([loss_details[k] for k in cols]) + self._history.loc[epoch, cols] = entry.T last_epoch = epoch == epochs[-1] chp = checkpoint_int is not None and (epoch % checkpoint_int) == 0 @@ -710,8 +710,8 @@ def finish_epoch( self.save(out_dir.format(epoch=epoch)) if extras is not None: - for k, v in extras.items(): - self._history.at[epoch, k] = safe_cast(v) + entry = np.vstack([safe_cast(v) for v in extras.values()]) + self._history.loc[epoch, list(extras)] = entry.T return stop @@ -744,6 +744,8 @@ def run_gradient_descent( current loss weight values. obs_data : tf.Tensor | None Optional observation data to use in additional content loss term. + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) optimizer : tf.keras.optimizers.Optimizer Optimizer class to use to update weights. This can be different if you're training just the generator or one of the discriminator @@ -1054,6 +1056,8 @@ def get_single_grad( current loss weight values. 
obs_data : tf.Tensor | None Optional observation data to use in additional content loss term. + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) device_name : None | str Optional tensorflow device name for GPU placement. Note that if a GPU is available, variables will be placed on that GPU even if diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index a6e9074e9..8a82efe3d 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -11,7 +11,6 @@ import logging import pprint from abc import ABCMeta -from collections import namedtuple from typing import Dict, Mapping, Tuple, Union from warnings import warn @@ -70,6 +69,28 @@ def __repr__(cls): return f"" +class DsetTuple: + """A simple class to mimic namedtuple behavior with dynamic attributes + while being serializable""" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def __iter__(self): + return iter(self.__dict__.values()) + + def __getitem__(self, key): + if isinstance(key, int): + key = list(self.__dict__)[key] + return self.__dict__[key] + + def __len__(self): + return len(self.__dict__) + + def __repr__(self): + return f"DsetTuple({self.__dict__})" + + class Sup3rDataset: """Interface for interacting with one or two ``xr.Dataset`` instances. This is a wrapper around one or two ``Sup3rX`` objects so they work well @@ -149,7 +170,7 @@ def __init__( assert len(dset) == 1, msg dsets[name] = dset._ds[0] - self._ds = namedtuple('Dataset', list(dsets))(**dsets) + self._ds = DsetTuple(**dsets) def __iter__(self): yield from self._ds diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index a9423fe7e..b3083e11e 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -12,9 +12,9 @@ import time from abc import ABC, abstractmethod from collections import namedtuple -from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, List, Optional, Union +import dask import numpy as np import tensorflow as tf @@ -36,8 +36,7 @@ class AbstractBatchQueue(Collection, ABC): def __init__( self, - samplers: Union[ - List['Sampler'], List['DualSampler']], + samplers: Union[List['Sampler'], List['DualSampler']], batch_size: int = 16, n_batches: int = 64, s_enhance: int = 1, @@ -237,6 +236,24 @@ def running(self): and not self.queue.is_closed() ) + def sample_batches(self, n_batches) -> None: + """Sample N batches from samplers. Returns N batches which are then + used to fill the queue.""" + if n_batches == 1: + return [self.sample_batch()] + + tasks = [dask.delayed(self.sample_batch)() for _ in range(n_batches)] + logger.debug('Added %s sample_batch futures to %s queue.', + n_batches, + self._thread_name) + + if self.max_workers == 1: + batches = dask.compute(*tasks, scheduler='single-threaded') + else: + batches = dask.compute( + *tasks, scheduler='threads', num_workers=self.max_workers) + return batches + def enqueue_batches(self) -> None: """Callback function for queue thread. While training, the queue is checked for empty spots and filled. 
In the training thread, batches are @@ -244,16 +261,15 @@ def enqueue_batches(self) -> None: log_time = time.time() while self.running: needed = self.queue_cap - self.queue.size().numpy() - if needed == 1 or self.max_workers == 1: - self.enqueue_batch() - elif needed > 0: - with ThreadPoolExecutor(self.max_workers) as exe: - _ = [exe.submit(self.enqueue_batch) for _ in range(needed)] - logger.debug( - 'Added %s enqueue futures to %s queue.', - needed, - self._thread_name, - ) + + # no point in getting more than one batch at a time if + # max_workers == 1 + needed = 1 if needed > 0 and self.max_workers == 1 else needed + + if needed > 0: + for batch in self.sample_batches(n_batches=needed): + self.queue.enqueue(batch) + if time.time() > log_time + 10: logger.debug(self.log_queue_info()) log_time = time.time() @@ -317,11 +333,6 @@ def log_queue_info(self): self.queue_cap, ) - def enqueue_batch(self): - """Build batch and send to queue.""" - if self.running and self.queue.size().numpy() < self.queue_cap: - self.queue.enqueue(self.sample_batch()) - @property def lr_shape(self): """Shape of low resolution sample in a low-res / high-res pair. (e.g. diff --git a/sup3r/preprocessing/cachers/base.py b/sup3r/preprocessing/cachers/base.py index 4df526039..82a4bd5c1 100644 --- a/sup3r/preprocessing/cachers/base.py +++ b/sup3r/preprocessing/cachers/base.py @@ -330,6 +330,7 @@ def write_h5( ] if Dimension.TIME in data: + # int64 used explicity to avoid incorrect encoding as int32 data[Dimension.TIME] = data[Dimension.TIME].astype('int64') for dset in [*coord_names, *features]: diff --git a/sup3r/utilities/pytest/helpers.py b/sup3r/utilities/pytest/helpers.py index 3f9b320a5..bb9912c9f 100644 --- a/sup3r/utilities/pytest/helpers.py +++ b/sup3r/utilities/pytest/helpers.py @@ -107,9 +107,17 @@ class DummySampler(Sampler): """Dummy container with random data.""" def __init__( - self, sample_shape, data_shape, features, batch_size, feature_sets=None + self, + sample_shape, + data_shape, + features, + batch_size, + feature_sets=None, + chunk_shape=None, ): data = make_fake_dset(data_shape, features=features) + if chunk_shape is not None: + data = data.chunk(chunk_shape) super().__init__( Sup3rDataset(high_res=data), sample_shape, @@ -314,10 +322,7 @@ def make_collect_chunks(td): out_files = [] for t, slice_hr in enumerate(t_slices_hr): for s, (s1_hr, s2_hr) in enumerate(product(s_slices_hr, s_slices_hr)): - out_file = out_pattern.format( - t=str(t).zfill(6), - s=str(s).zfill(6) - ) + out_file = out_pattern.format(t=str(t).zfill(6), s=str(s).zfill(6)) out_files.append(out_file) OutputHandlerH5._write_output( data[s1_hr, s2_hr, slice_hr, :], @@ -330,15 +335,7 @@ def make_collect_chunks(td): gids=gids[s1_hr, s2_hr], ) - return ( - out_files, - data, - ws_true, - wd_true, - features, - hr_lat_lon, - hr_times - ) + return (out_files, data, ws_true, wd_true, features, hr_lat_lon, hr_times) def make_fake_h5_chunks(td): @@ -436,7 +433,7 @@ def make_fake_h5_chunks(td): s_slices_lr, s_slices_hr, low_res_lat_lon, - low_res_times + low_res_times, ) From c31ed723a1934ad96e4c859249ef09b4ebc18eef Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 27 Dec 2024 08:04:01 -0800 Subject: [PATCH 09/32] parallel batch queue test added. 
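
The new test checks that batch generation with ``max_workers > 1`` is faster
than with ``max_workers = 1``. A minimal sketch of the pattern it exercises is
below; shapes, feature names, and worker counts are illustrative only, and
``DummySampler`` is the existing pytest helper rather than a public API::

    from sup3r.preprocessing import SingleBatchQueue
    from sup3r.utilities.pytest.helpers import DummySampler

    # single sampler over random data; shapes are arbitrary for the sketch
    sampler = DummySampler(
        sample_shape=(10, 10, 20),
        data_shape=(100, 100, 1000),
        features=['windspeed', 'winddirection'],
        batch_size=10,
    )

    # max_workers > 1 runs sample_batch calls in a thread pool while the
    # training loop dequeues batches that are already built
    batcher = SingleBatchQueue(
        samplers=[sampler], n_batches=20, batch_size=10, max_workers=10
    )
    for batch in batcher:
        pass  # train on batch.low_res / batch.high_res here
    batcher.stop()  # stop the queue thread when finished
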
--- sup3r/preprocessing/base.py | 3 ++ sup3r/preprocessing/batch_handlers/factory.py | 8 +-- sup3r/preprocessing/batch_queues/abstract.py | 36 +++++++------ .../preprocessing/batch_queues/conditional.py | 14 ++---- sup3r/preprocessing/batch_queues/dual.py | 3 +- tests/batch_queues/test_bq_general.py | 50 +++++++++++++++++++ 6 files changed, 83 insertions(+), 31 deletions(-) diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index 8a82efe3d..e046438fa 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -386,6 +386,9 @@ def wrap(self, data): if data is None: return data + if hasattr(data, 'data'): + data = data.data + if is_type_of(data, Sup3rDataset): return data diff --git a/sup3r/preprocessing/batch_handlers/factory.py b/sup3r/preprocessing/batch_handlers/factory.py index 3a65b64bb..4037f2224 100644 --- a/sup3r/preprocessing/batch_handlers/factory.py +++ b/sup3r/preprocessing/batch_handlers/factory.py @@ -265,26 +265,26 @@ def init_samplers( """Initialize samplers from given data containers.""" train_samplers = [ self.SAMPLER( - data=c.data, + data=container, sample_shape=sample_shape, feature_sets=feature_sets, batch_size=batch_size, **sampler_kwargs, ) - for c in train_containers + for container in train_containers ] val_samplers = ( [] if val_containers is None else [ self.SAMPLER( - data=c.data, + data=container, sample_shape=sample_shape, feature_sets=feature_sets, batch_size=batch_size, **sampler_kwargs, ) - for c in val_containers + for container in val_containers ] ) return train_samplers, val_samplers diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index b3083e11e..f5b6b49b9 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -11,13 +11,13 @@ import threading import time from abc import ABC, abstractmethod -from collections import namedtuple from typing import TYPE_CHECKING, List, Optional, Union import dask import numpy as np import tensorflow as tf +from sup3r.preprocessing.base import DsetTuple from sup3r.preprocessing.collections.base import Collection from sup3r.utilities.utilities import RANDOM_GENERATOR, Timer @@ -32,7 +32,7 @@ class AbstractBatchQueue(Collection, ABC): generator and maintains a queue of batches in a dedicated thread so the training routine can proceed as soon as batches are available.""" - Batch = namedtuple('Batch', ['low_res', 'high_res']) + BATCH_MEMBERS = ('low_res', 'high_res') def __init__( self, @@ -46,6 +46,7 @@ def __init__( max_workers: int = 1, thread_name: str = 'training', mode: str = 'lazy', + verbose: bool = False ): """ Parameters @@ -77,6 +78,8 @@ def __init__( Loading mode. Default is 'lazy', which only loads data into memory as batches are queued. 'eager' will load all data into memory right away. + verbose : bool + Whether to log timing information for batch steps. """ msg = ( f'{self.__class__.__name__} requires a list of samplers. ' @@ -101,6 +104,7 @@ def __init__( 'smoothing_ignore': [], 'smoothing': None, } + self.verbose = verbose self.timer = Timer() self.preflight() @@ -174,7 +178,7 @@ def transform(self, samples, **kwargs): high res samples. For a dual dataset queue this will just include smoothing.""" - def post_proc(self, samples) -> Batch: + def post_proc(self, samples) -> DsetTuple: """Performs some post proc on dequeued samples before sending out for training. 
Post processing can include coarsening on high-res data (if :class:`Collection` consists of :class:`Sampler` objects and not @@ -182,13 +186,12 @@ def post_proc(self, samples) -> Batch: Returns ------- - Batch : namedtuple - namedtuple with `low_res` and `high_res` attributes. Could also - include additional members for integration with - ``DualSamplerWithObs`` + Batch : DsetTuple + namedtuple-like object with `low_res` and `high_res` attributes. + Could also include `obs` member. """ tsamps = self.transform(samples, **self.transform_kwargs) - return self.Batch(**dict(zip(self.Batch._fields, tsamps))) + return DsetTuple(**dict(zip(self.BATCH_MEMBERS, tsamps))) def start(self) -> None: """Start thread to keep sample queue full for batches.""" @@ -216,7 +219,7 @@ def __iter__(self): self.start() return self - def get_batch(self) -> Batch: + def get_batch(self) -> DsetTuple: """Get batch from queue or directly from a ``Sampler`` through ``sample_batch``.""" if ( @@ -274,10 +277,10 @@ def enqueue_batches(self) -> None: logger.debug(self.log_queue_info()) log_time = time.time() - def __next__(self) -> Batch: + def __next__(self) -> DsetTuple: """Dequeue batch samples, squeeze if for a spatial only model, perform some post-proc like smoothing, coarsening, etc, and then send out for - training as a namedtuple of low_res / high_res arrays. + training as a namedtuple-like object of low_res / high_res arrays. Returns ------- @@ -295,11 +298,12 @@ def __next__(self) -> Batch: batch = self.post_proc(samples) self.timer.stop() self._batch_count += 1 - logger.debug( - 'Batch step %s finished in %s.', - self._batch_count, - self.timer.elapsed_str, - ) + if self.verbose: + logger.debug( + 'Batch step %s finished in %s.', + self._batch_count, + self.timer.elapsed_str, + ) else: raise StopIteration return batch diff --git a/sup3r/preprocessing/batch_queues/conditional.py b/sup3r/preprocessing/batch_queues/conditional.py index 43d479b5f..488e70b2c 100644 --- a/sup3r/preprocessing/batch_queues/conditional.py +++ b/sup3r/preprocessing/batch_queues/conditional.py @@ -2,12 +2,12 @@ import logging from abc import abstractmethod -from collections import namedtuple from typing import TYPE_CHECKING, Dict, List, Optional, Union import numpy as np from sup3r.models.conditional import Sup3rCondMom +from sup3r.preprocessing.base import DsetTuple from sup3r.preprocessing.utilities import numpy_if_tensor from .base import SingleBatchQueue @@ -22,10 +22,6 @@ class ConditionalBatchQueue(SingleBatchQueue): """BatchQueue class for conditional moment estimation.""" - ConditionalBatch = namedtuple( - 'ConditionalBatch', ['low_res', 'high_res', 'output', 'mask'] - ) - def __init__( self, samplers: Union[List['Sampler'], List['DualSampler']], @@ -160,14 +156,14 @@ def post_proc(self, samples): Returns ------- - namedtuple - Named tuple with `low_res`, `high_res`, `mask`, and `output` - attributes + DsetTuple + Namedtuple-like object with `low_res`, `high_res`, `mask`, and + `output` attributes """ lr, hr = self.transform(samples, **self.transform_kwargs) mask = self.make_mask(high_res=hr) output = self.make_output(samples=(lr, hr)) - return self.ConditionalBatch( + return DsetTuple( low_res=lr, high_res=hr, output=output, mask=mask ) diff --git a/sup3r/preprocessing/batch_queues/dual.py b/sup3r/preprocessing/batch_queues/dual.py index 4b64c593d..9a89ce8bd 100644 --- a/sup3r/preprocessing/batch_queues/dual.py +++ b/sup3r/preprocessing/batch_queues/dual.py @@ -2,7 +2,6 @@ interface with models.""" import logging -from collections 
import namedtuple from scipy.ndimage import gaussian_filter @@ -21,7 +20,7 @@ def __init__(self, samplers, **kwargs): -------- :class:`~sup3r.preprocessing.batch_queues.abstract.AbstractBatchQueue` """ - self.Batch = namedtuple('Batch', samplers[0]._fields) + self.BATCH_MEMBERS = samplers[0]._fields super().__init__(samplers, **kwargs) self.check_enhancement_factors() diff --git a/tests/batch_queues/test_bq_general.py b/tests/batch_queues/test_bq_general.py index 52b2ffb25..b130aa7b7 100644 --- a/tests/batch_queues/test_bq_general.py +++ b/tests/batch_queues/test_bq_general.py @@ -12,6 +12,7 @@ DummyData, DummySampler, ) +from sup3r.utilities.utilities import Timer FEATURES = ['windspeed', 'winddirection'] @@ -53,6 +54,55 @@ def test_batch_queue(): batcher.stop() +def test_batch_queue_workers(): + """Check that using max_workers > 1 for a batch queue is faster than using + max_workers = 1.""" + + timer = Timer() + sample_shape = (10, 10, 20) + n_batches = 20 + batch_size = 10 + max_workers = 10 + n_epochs = 10 + chunk_shape = {'south_north': 20, 'west_east': 20, 'time': 40} + samplers = [ + DummySampler( + sample_shape, + data_shape=(100, 100, 1000), + batch_size=batch_size, + features=FEATURES, + chunk_shape=chunk_shape + ) + ] + batcher = SingleBatchQueue( + samplers=samplers, + n_batches=n_batches, + batch_size=batch_size, + max_workers=1, + ) + timer.start() + for _ in range(n_epochs): + _ = list(batcher) + timer.stop() + batcher.stop() + serial_time = timer.elapsed / (n_epochs * n_batches) + + batcher = SingleBatchQueue( + samplers=samplers, + n_batches=n_batches, + batch_size=batch_size, + max_workers=max_workers, + ) + timer.start() + for _ in range(n_epochs): + _ = list(batcher) + timer.stop() + batcher.stop() + parallel_time = timer.elapsed / (n_epochs * n_batches) + print(f'Parallel / Serial Time: {parallel_time} / {serial_time}') + assert parallel_time < serial_time + + def test_spatial_batch_queue(): """Smoke test for spatial batch queue. 
A batch queue returns batches for spatial models if the sample shapes have 1 for the time axis""" From 89c6cdec0fa20cefa1cc06761e0ac7b2c70808a8 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 27 Dec 2024 08:40:09 -0800 Subject: [PATCH 10/32] namedtuple -> DsetTuple missing attr fix --- sup3r/preprocessing/base.py | 18 ++++++++++++------ sup3r/preprocessing/batch_queues/dual.py | 4 ++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index e046438fa..1e04cba44 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -74,21 +74,27 @@ class DsetTuple: while being serializable""" def __init__(self, **kwargs): + self.dset_names = list(kwargs) self.__dict__.update(kwargs) + @property + def dsets(self): + """Dictionary with only dset names and associated values.""" + return {k: v for k, v in self.__dict__.items() if k in self.dset_names} + def __iter__(self): - return iter(self.__dict__.values()) + return iter(self.dsets.values()) def __getitem__(self, key): if isinstance(key, int): - key = list(self.__dict__)[key] - return self.__dict__[key] + key = list(self.dsets)[key] + return self.dsets[key] def __len__(self): - return len(self.__dict__) + return len(self.dsets) def __repr__(self): - return f"DsetTuple({self.__dict__})" + return f'DsetTuple({self.dsets})' class Sup3rDataset: @@ -237,7 +243,7 @@ def __getitem__(self, keys): if len(self._ds) == 1: return out[-1] if all(isinstance(o, Sup3rX) for o in out): - return type(self)(**dict(zip(self._ds._fields, out))) + return type(self)(**dict(zip(self._ds.dset_names, out))) return out @property diff --git a/sup3r/preprocessing/batch_queues/dual.py b/sup3r/preprocessing/batch_queues/dual.py index 9a89ce8bd..b7e26bf94 100644 --- a/sup3r/preprocessing/batch_queues/dual.py +++ b/sup3r/preprocessing/batch_queues/dual.py @@ -20,7 +20,7 @@ def __init__(self, samplers, **kwargs): -------- :class:`~sup3r.preprocessing.batch_queues.abstract.AbstractBatchQueue` """ - self.BATCH_MEMBERS = samplers[0]._fields + self.BATCH_MEMBERS = samplers[0].dset_names super().__init__(samplers, **kwargs) self.check_enhancement_factors() @@ -30,7 +30,7 @@ def __init__(self, samplers, **kwargs): def queue_shape(self): """Shape of objects stored in the queue.""" queue_shapes = [(self.batch_size, *self.lr_shape)] - hr_mems = len(self.Batch._fields) - 1 + hr_mems = len(self.BATCH_MEMBERS) - 1 queue_shapes += [(self.batch_size, *self.hr_shape)] * hr_mems return queue_shapes From 5976aa2f08a7f22a436088a232d155eb21e53301 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 27 Dec 2024 13:20:07 -0700 Subject: [PATCH 11/32] gust added to era download variables. 
len dunder added to ``Container`` class to enable passing ``xr.Dataset`` and like objects directly to ``BatchHandlers`` without needing to invoke ``.data`` --- sup3r/preprocessing/base.py | 3 +++ sup3r/preprocessing/names.py | 1 + 2 files changed, 4 insertions(+) diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index 1e04cba44..4b3e9d681 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -428,6 +428,9 @@ def shape(self): """Get shape of underlying data.""" return self.data.shape + def __len__(self): + return len(self.data) + def __contains__(self, vals): return vals in self.data diff --git a/sup3r/preprocessing/names.py b/sup3r/preprocessing/names.py index 0ade726c6..700dce81b 100644 --- a/sup3r/preprocessing/names.py +++ b/sup3r/preprocessing/names.py @@ -162,6 +162,7 @@ def dims_4d_bc(cls): 'northward_turbulent_surface_stress', 'eastward_turbulent_surface_stress', 'sea_surface_temperature', + 'instantaneous_10m_wind_gust' ] # variables available on multiple pressure levels From ec8b7390346113611f331b5ea91a9e734ebf4059 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 28 Dec 2024 08:49:15 -0700 Subject: [PATCH 12/32] computing before reshaping is 2x faster. --- sup3r/preprocessing/samplers/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index dbf75f61f..b824cd360 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -194,10 +194,11 @@ def _reshape_samples(self, samples): new_shape[2] // self.batch_size, new_shape[-1], ] + out = compute_if_dask(samples) # (lats, lons, batch_size, times, feats) - out = np.reshape(samples, new_shape) + out = np.reshape(out, new_shape) # (batch_size, lats, lons, times, feats) - return compute_if_dask(np.transpose(out, axes=(2, 0, 1, 3, 4))) + return np.transpose(out, axes=(2, 0, 1, 3, 4)) def _stack_samples(self, samples): """Used to build batch arrays in the case of independent time samples From e265cd5216156b27ca9ebffa6962d1285046d417 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 28 Dec 2024 13:44:14 -0700 Subject: [PATCH 13/32] obs_index fix - sampler needs to use hr_out_features for the obs member. 
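
The observation member has to be sampled with ``hr_out_features`` so its
channels line up with the slice of the generator output used in the masked
MAE term. A rough illustration of that alignment is below; the tensors and
feature names are made up for the example::

    import tensorflow as tf
    from tensorflow.keras.losses import MeanAbsoluteError

    hr_out_features = ['u_100m', 'v_100m']  # generated output channels only
    # obs sample: NaN where no observation exists, on the high-res grid
    obs_data = tf.constant([[[[1.0, float('nan')]]]])
    # generator output: output channels first, then e.g. an exo channel
    hi_res_gen = tf.constant([[[[1.2, 0.8, 3.0]]]])

    mask = tf.math.is_nan(obs_data)
    loss_obs = MeanAbsoluteError()(
        obs_data[~mask],
        hi_res_gen[..., : len(hr_out_features)][~mask],
    )
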
--- sup3r/models/base.py | 5 +- sup3r/preprocessing/base.py | 2 - sup3r/preprocessing/batch_queues/abstract.py | 64 +++++++++++------ sup3r/preprocessing/collections/base.py | 2 - sup3r/preprocessing/samplers/dual.py | 3 +- tests/batch_handlers/test_bh_general.py | 76 ++++++++++++++++++++ tests/batch_queues/test_bq_general.py | 50 ------------- 7 files changed, 123 insertions(+), 79 deletions(-) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index a3deb0243..7cbe4af98 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -10,6 +10,7 @@ import numpy as np import pandas as pd import tensorflow as tf +from tensorflow.keras.losses import MeanAbsoluteError from sup3r.preprocessing.utilities import get_class_kwargs from sup3r.utilities import VERSION_RECORD @@ -880,7 +881,9 @@ def calc_loss( loss_obs = np.nan if obs_data is not None: mask = tf.math.is_nan(obs_data) - loss_obs = self.loss_fun(obs_data[~mask], hi_res_gen[~mask]) + loss_obs = MeanAbsoluteError()( + obs_data[~mask], + hi_res_gen[..., : len(self.hr_out_features)][~mask]) loss_gen += loss_obs loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen) diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index 4b3e9d681..ce9d1d6c4 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -451,8 +451,6 @@ def __setitem__(self, keys, data): def __getattr__(self, attr): """Check if attribute is available from ``.data``""" - if attr in dir(self): - return self.__getattribute__(attr) try: data = self.__getattribute__('_data') return getattr(data, attr) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index f5b6b49b9..0c8d2f315 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -11,9 +11,9 @@ import threading import time from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, List, Optional, Union -import dask import numpy as np import tensorflow as tf @@ -46,7 +46,7 @@ def __init__( max_workers: int = 1, thread_name: str = 'training', mode: str = 'lazy', - verbose: bool = False + verbose: bool = False, ): """ Parameters @@ -91,6 +91,7 @@ def __init__( self._queue_thread = None self._training_flag = threading.Event() self._thread_name = thread_name + self._thread_pool = ThreadPoolExecutor(max_workers=max_workers) self.mode = mode self.s_enhance = s_enhance self.t_enhance = t_enhance @@ -115,6 +116,22 @@ def queue_shape(self): this is (batch_size, *sample_shape, len(features)). 
For dual dataset queues this is [(batch_size, *lr_shape), (batch_size, *hr_shape)]""" + @property + def queue_len(self): + """Get number of batches in the queue.""" + return self.queue.size().numpy() + + @property + def queue_futures(self): + """Get number of scheduled futures that will eventually add batches to + the queue.""" + return self._thread_pool._work_queue.qsize() + + @property + def queue_free(self): + """Get number of free spots in the queue.""" + return self.queue_cap - self.queue_len + def get_queue(self): """Return FIFO queue for storing batches.""" return tf.queue.FIFOQueue( @@ -222,13 +239,9 @@ def __iter__(self): def get_batch(self) -> DsetTuple: """Get batch from queue or directly from a ``Sampler`` through ``sample_batch``.""" - if ( - self.mode == 'eager' - or self.queue_cap == 0 - or self.queue.size().numpy() == 0 - ): - return self.sample_batch() - return self.queue.dequeue() + if self.queue_len > 0 or self.queue_futures > 0: + return self.queue.dequeue() + return self.sample_batch() @property def running(self): @@ -245,17 +258,19 @@ def sample_batches(self, n_batches) -> None: if n_batches == 1: return [self.sample_batch()] - tasks = [dask.delayed(self.sample_batch)() for _ in range(n_batches)] - logger.debug('Added %s sample_batch futures to %s queue.', - n_batches, - self._thread_name) - if self.max_workers == 1: - batches = dask.compute(*tasks, scheduler='single-threaded') - else: - batches = dask.compute( - *tasks, scheduler='threads', num_workers=self.max_workers) - return batches + return [self.sample_batch() for _ in range(n_batches)] + + tasks = [ + self._thread_pool.submit(self.sample_batch) + for _ in range(n_batches) + ] + logger.debug( + 'Added %s sample_batch futures to %s queue.', + n_batches, + self._thread_name, + ) + return [task.result() for task in tasks] def enqueue_batches(self) -> None: """Callback function for queue thread. 
While training, the queue is @@ -263,8 +278,10 @@ def enqueue_batches(self) -> None: removed from the queue.""" log_time = time.time() while self.running: - needed = self.queue_cap - self.queue.size().numpy() - + needed = min( + self.queue_free - self.queue_futures, + self.n_batches - self._batch_count + ) # no point in getting more than one batch at a time if # max_workers == 1 needed = 1 if needed > 0 and self.max_workers == 1 else needed @@ -331,10 +348,11 @@ def sample_batch(self): def log_queue_info(self): """Log info about queue size.""" - return '{} queue length: {} / {}.'.format( + return '{} queue length: {} / {}, with {} futures'.format( self._thread_name.title(), - self.queue.size().numpy(), + self.queue_len, self.queue_cap, + self.queue_futures ) @property diff --git a/sup3r/preprocessing/collections/base.py b/sup3r/preprocessing/collections/base.py index f9f373bc9..736606a73 100644 --- a/sup3r/preprocessing/collections/base.py +++ b/sup3r/preprocessing/collections/base.py @@ -60,8 +60,6 @@ def container_weights(self): def __getattr__(self, attr): """Get attributes from self or the first container in the collection.""" - if attr in dir(self): - return self.__getattribute__(attr) return self.check_shared_attr(attr) def check_shared_attr(self, attr): diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index 62e5c4233..56b7c4fd5 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -145,7 +145,8 @@ def get_sample_index(self, n_obs=None): slice(s.start * self.t_enhance, s.stop * self.t_enhance) for s in lr_index[2:-1] ] + obs_index = (*hr_index, self.hr_out_features) hr_index = (*hr_index, self.hr_features) - sample_index = (lr_index, hr_index, hr_index) + sample_index = (lr_index, hr_index, obs_index) return sample_index[:len(self.data)] diff --git a/tests/batch_handlers/test_bh_general.py b/tests/batch_handlers/test_bh_general.py index b8f10fa8d..0c3e89401 100644 --- a/tests/batch_handlers/test_bh_general.py +++ b/tests/batch_handlers/test_bh_general.py @@ -2,8 +2,11 @@ import copy +import dask.array as da import numpy as np +import pandas as pd import pytest +import xarray as xr from scipy.ndimage import gaussian_filter from sup3r.preprocessing import ( @@ -17,6 +20,7 @@ ) from sup3r.utilities.utilities import ( RANDOM_GENERATOR, + Timer, spatial_coarsening, temporal_coarsening, ) @@ -29,6 +33,78 @@ BatchHandlerTester = BatchHandlerTesterFactory(BatchHandler, SamplerTester) +def test_batch_handler_workers(): + """Check that it is faster to get batches with max_workers > 1 than with + max_workers = 1.""" + + timer = Timer() + n_lats = 200 + n_lons = 200 + sample_shape = (20, 20, 30) + chunk_shape = ( + 2 * sample_shape[0], + 2 * sample_shape[1], + 2 * sample_shape[-1], + ) + n_obs = 40 + max_workers = 5 + n_batches = 40 + + lons, lats = np.meshgrid( + np.linspace(0, 180, n_lats), np.linspace(40, 60, n_lons) + ) + time = pd.date_range('2023-01-01', '2023-05-01', freq='h') + u_arr = da.random.random((*lats.shape, len(time))) + v_arr = da.random.random((*lats.shape, len(time))) + ds = xr.Dataset( + coords={ + 'latitude': (('south_north', 'west_east'), lats), + 'longitude': (('south_north', 'west_east'), lons), + 'time': time, + }, + data_vars={ + 'u_100m': (('south_north', 'west_east', 'time'), u_arr), + 'v_100m': (('south_north', 'west_east', 'time'), v_arr), + }, + ) + ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape))) + + batcher = BatchHandler( + [ds], + n_batches=n_batches, + 
batch_size=n_obs, + sample_shape=sample_shape, + max_workers=max_workers, + ) + timer.start() + for _ in range(10): + _ = list(batcher) + timer.stop() + parallel_time = timer.elapsed / (n_batches * 10) + batcher.stop() + + batcher = BatchHandler( + [ds], + n_batches=n_batches, + batch_size=n_obs, + sample_shape=sample_shape, + max_workers=1, + ) + timer.start() + for _ in range(10): + _ = list(batcher) + timer.stop() + serial_time = timer.elapsed / (n_batches * 10) + batcher.stop() + + print( + 'Elapsed (serial / parallel): {} / {}'.format( + serial_time, parallel_time + ) + ) + assert serial_time > parallel_time + + def test_eager_vs_lazy(): """Make sure eager and lazy loading agree. We use queue_cap = 0 here so there is no disagreement that results from dequeuing vs direct batch diff --git a/tests/batch_queues/test_bq_general.py b/tests/batch_queues/test_bq_general.py index b130aa7b7..52b2ffb25 100644 --- a/tests/batch_queues/test_bq_general.py +++ b/tests/batch_queues/test_bq_general.py @@ -12,7 +12,6 @@ DummyData, DummySampler, ) -from sup3r.utilities.utilities import Timer FEATURES = ['windspeed', 'winddirection'] @@ -54,55 +53,6 @@ def test_batch_queue(): batcher.stop() -def test_batch_queue_workers(): - """Check that using max_workers > 1 for a batch queue is faster than using - max_workers = 1.""" - - timer = Timer() - sample_shape = (10, 10, 20) - n_batches = 20 - batch_size = 10 - max_workers = 10 - n_epochs = 10 - chunk_shape = {'south_north': 20, 'west_east': 20, 'time': 40} - samplers = [ - DummySampler( - sample_shape, - data_shape=(100, 100, 1000), - batch_size=batch_size, - features=FEATURES, - chunk_shape=chunk_shape - ) - ] - batcher = SingleBatchQueue( - samplers=samplers, - n_batches=n_batches, - batch_size=batch_size, - max_workers=1, - ) - timer.start() - for _ in range(n_epochs): - _ = list(batcher) - timer.stop() - batcher.stop() - serial_time = timer.elapsed / (n_epochs * n_batches) - - batcher = SingleBatchQueue( - samplers=samplers, - n_batches=n_batches, - batch_size=batch_size, - max_workers=max_workers, - ) - timer.start() - for _ in range(n_epochs): - _ = list(batcher) - timer.stop() - batcher.stop() - parallel_time = timer.elapsed / (n_epochs * n_batches) - print(f'Parallel / Serial Time: {parallel_time} / {serial_time}') - assert parallel_time < serial_time - - def test_spatial_batch_queue(): """Smoke test for spatial batch queue. 
A batch queue returns batches for spatial models if the sample shapes have 1 for the time axis""" From d4c009d3bcce7c83887d07d1b7567dc65e9851e5 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 29 Dec 2024 06:48:39 -0800 Subject: [PATCH 14/32] split up ``calc_loss`` and ``calc_loss_obs`` --- sup3r/models/abstract.py | 158 ++++++++++++++++++++---------------- sup3r/models/base.py | 18 ++-- sup3r/preprocessing/base.py | 5 +- 3 files changed, 97 insertions(+), 84 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 506f85b38..794554ea0 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -16,6 +16,7 @@ from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat from rex.utilities.utilities import safe_json_load from tensorflow.keras import optimizers +from tensorflow.keras.losses import MeanAbsoluteError import sup3r.utilities.loss_metrics from sup3r.preprocessing.data_handlers import ExoData @@ -715,6 +716,69 @@ def finish_epoch( return stop + def _get_parallel_grad( + self, + low_res, + hi_res_true, + training_weights, + obs_data=None, + **calc_loss_kwargs, + ): + """Compute gradient for one mini-batch of (low_res, hi_res_true) + across multiple GPUs""" + + futures = [] + lr_chunks = np.array_split(low_res, len(self.gpu_list)) + hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) + obs_data_chunks = ( + [None] * len(hr_true_chunks) + if obs_data is None + else np.array_split(obs_data, len(self.gpu_list)) + ) + split_mask = False + mask_chunks = None + if 'mask' in calc_loss_kwargs: + split_mask = True + mask_chunks = np.array_split( + calc_loss_kwargs['mask'], len(self.gpu_list) + ) + + with ThreadPoolExecutor(max_workers=len(self.gpu_list)) as exe: + for i in range(len(self.gpu_list)): + if split_mask: + calc_loss_kwargs['mask'] = mask_chunks[i] + futures.append( + exe.submit( + self.get_single_grad, + lr_chunks[i], + hr_true_chunks[i], + training_weights, + obs_data=obs_data_chunks[i], + device_name=f'/gpu:{i}', + **calc_loss_kwargs, + ) + ) + + # sum the gradients from each gpu to weight equally in + # optimizer momentum calculation + total_grad = None + for future in futures: + grad, loss_details = future.result() + if total_grad is None: + total_grad = grad + else: + for i, igrad in enumerate(grad): + total_grad[i] += igrad + + self.timer.stop() + logger.debug( + 'Finished %s gradient descent steps on %s GPUs in %s', + len(futures), + len(self.gpu_list), + self.timer.elapsed_str, + ) + return total_grad, loss_details + def run_gradient_descent( self, low_res, @@ -725,7 +789,6 @@ def run_gradient_descent( multi_gpu=False, **calc_loss_kwargs, ): - # pylint: disable=E0602 """Run gradient descent for one mini-batch of (low_res, hi_res_true) and update weights @@ -765,6 +828,7 @@ def run_gradient_descent( loss_details : dict Namespace of the breakdown of loss components """ + self.timer.start() if optimizer is None: optimizer = self.optimizer @@ -785,58 +849,15 @@ def run_gradient_descent( self.timer.elapsed_str, ) else: - futures = [] - lr_chunks = np.array_split(low_res, len(self.gpu_list)) - hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) - obs_data_chunks = ( - [None] * len(hr_true_chunks) - if obs_data is None - else np.array_split(obs_data, len(self.gpu_list)) + total_grad, loss_details = self._get_parallel_grad( + low_res, + hi_res_true, + training_weights, + obs_data, + **calc_loss_kwargs, ) - split_mask = False - mask_chunks = None - if 'mask' in calc_loss_kwargs: - split_mask = True - mask_chunks = 
np.array_split( - calc_loss_kwargs['mask'], len(self.gpu_list) - ) - - with ThreadPoolExecutor(max_workers=len(self.gpu_list)) as exe: - for i in range(len(self.gpu_list)): - if split_mask: - calc_loss_kwargs['mask'] = mask_chunks[i] - futures.append( - exe.submit( - self.get_single_grad, - lr_chunks[i], - hr_true_chunks[i], - training_weights, - obs_data=obs_data_chunks[i], - device_name=f'/gpu:{i}', - **calc_loss_kwargs, - ) - ) - - # sum the gradients from each gpu to weight equally in - # optimizer momentum calculation - total_grad = None - for future in futures: - grad, loss_details = future.result() - if total_grad is None: - total_grad = grad - else: - for i, igrad in enumerate(grad): - total_grad[i] += igrad - optimizer.apply_gradients(zip(total_grad, training_weights)) - self.timer.stop() - logger.debug( - 'Finished %s gradient descent steps on %s GPUs in %s', - len(futures), - len(self.gpu_list), - self.timer.elapsed_str, - ) return loss_details def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True): @@ -1081,9 +1102,13 @@ def get_single_grad( hi_res_exo = self.get_high_res_exo_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss_out = self.calc_loss( - hi_res_true, hi_res_gen, obs_data=obs_data, **calc_loss_kwargs + hi_res_true, hi_res_gen, **calc_loss_kwargs ) loss, loss_details = loss_out + if obs_data is not None: + loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) + loss += loss_obs + loss_details['loss_obs'] = loss_obs grad = tape.gradient(loss, training_weights) return grad, loss_details @@ -1092,36 +1117,33 @@ def calc_loss( self, hi_res_true, hi_res_gen, - obs_data=None, weight_gen_advers=0.001, train_gen=True, train_disc=False, ): """Calculate the GAN loss function using generated and true high - resolution data. + resolution data.""" + + @tf.function + def calc_loss_obs(self, obs_data, hi_res_gen): + """Calculate loss term for the observation data vs generated + high-resolution data Parameters ---------- - hi_res_true : tf.Tensor - Ground truth high resolution spatiotemporal data. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. hi_res_gen : tf.Tensor Superresolved high resolution spatiotemporal data generated by the generative model. - obs_data : tf.Tensor | None - Optional observation data to use in additional content loss term. - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. 
- train_gen : bool - True if generator is being trained, then loss=loss_gen - train_disc : bool - True if disc is being trained, then loss=loss_disc Returns ------- loss : tf.Tensor - 0D tensor representing the loss value for the network being trained - (either generator or one of the discriminators) - loss_details : dict - Namespace of the breakdown of loss components + 0D tensor of observation loss """ + mask = tf.math.is_nan(obs_data) + return MeanAbsoluteError()( + obs_data[~mask], + hi_res_gen[..., : len(self.hr_out_features)][~mask], + ) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 7cbe4af98..a31094f9a 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -10,7 +10,6 @@ import numpy as np import pandas as pd import tensorflow as tf -from tensorflow.keras.losses import MeanAbsoluteError from sup3r.preprocessing.utilities import get_class_kwargs from sup3r.utilities import VERSION_RECORD @@ -824,7 +823,6 @@ def calc_loss( self, hi_res_true, hi_res_gen, - obs_data=None, weight_gen_advers=0.001, train_gen=True, train_disc=False, @@ -877,15 +875,6 @@ def calc_loss( loss_gen_content = self.calc_loss_gen_content(hi_res_true, hi_res_gen) loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen) loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers - - loss_obs = np.nan - if obs_data is not None: - mask = tf.math.is_nan(obs_data) - loss_obs = MeanAbsoluteError()( - obs_data[~mask], - hi_res_gen[..., : len(self.hr_out_features)][~mask]) - loss_gen += loss_obs - loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen) loss = None @@ -896,7 +885,6 @@ def calc_loss( loss_details = { 'loss_gen': loss_gen, - 'loss_obs': loss_obs, 'loss_gen_content': loss_gen_content, 'loss_gen_advers': loss_gen_advers, 'loss_disc': loss_disc, @@ -930,11 +918,15 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): _, v_loss_details = self.calc_loss( val_batch.high_res, high_res_gen, - obs_data=getattr(val_batch, 'obs', None), weight_gen_advers=weight_gen_advers, train_gen=False, train_disc=False, ) + obs_data = getattr(val_batch, 'obs', None) + if obs_data is not None: + v_loss_details['loss_obs'] = self.cal_loss_obs( + obs_data, high_res_gen + ) loss_details = self.update_loss_details( loss_details, v_loss_details, len(val_batch), prefix='val_' diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index ce9d1d6c4..7d81f9d27 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -379,7 +379,7 @@ def data(self, data): def wrap(self, data): """ Return a :class:`~.Sup3rDataset` object or tuple of such. This is a - tuple when the `.data` attribute belongs to a + tuple when the ``.data`` attribute belongs to a :class:`~.collections.base.Collection` object like :class:`~.batch_handlers.factory.BatchHandler`. Otherwise this is :class:`~.Sup3rDataset` object, which is either a wrapped 3-tuple, @@ -452,8 +452,7 @@ def __setitem__(self, keys, data): def __getattr__(self, attr): """Check if attribute is available from ``.data``""" try: - data = self.__getattribute__('_data') - return getattr(data, attr) + return getattr(self._data, attr) except Exception as e: msg = f'{self.__class__.__name__} object has no attribute "{attr}"' raise AttributeError(msg) from e From ae564acc76c1f3a9d27c40320d46254d5bbdae50 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 29 Dec 2024 07:43:59 -0700 Subject: [PATCH 15/32] Optional run_qa flag in ``DualRasterizer``. 
Queue shape fix for queues with obs data --- sup3r/models/abstract.py | 1 + sup3r/models/base.py | 2 +- sup3r/preprocessing/batch_queues/dual.py | 18 +++++++++++++----- sup3r/preprocessing/rasterizers/dual.py | 9 +++++++-- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 794554ea0..579dab5e9 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -539,6 +539,7 @@ def update_loss_details(loss_details, new_data, batch_len, prefix=None): if key in loss_details: saved_value = loss_details[key] + saved_value = 0 if np.isnan(saved_value) else saved_value saved_value *= prior_n_obs saved_value += batch_len * new_value saved_value /= new_n_obs diff --git a/sup3r/models/base.py b/sup3r/models/base.py index a31094f9a..cd4a437f7 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -977,7 +977,7 @@ def train_epoch( disc_th_low = np.min(disc_loss_bounds) disc_th_high = np.max(disc_loss_bounds) - loss_details = {'n_obs': 0, 'train_loss_disc': 0} + loss_details = {'n_obs': 0, 'train_loss_disc': 0, 'train_loss_obs': 0} only_gen = train_gen and not train_disc only_disc = train_disc and not train_gen diff --git a/sup3r/preprocessing/batch_queues/dual.py b/sup3r/preprocessing/batch_queues/dual.py index b7e26bf94..56b2b08d4 100644 --- a/sup3r/preprocessing/batch_queues/dual.py +++ b/sup3r/preprocessing/batch_queues/dual.py @@ -28,11 +28,19 @@ def __init__(self, samplers, **kwargs): @property def queue_shape(self): - """Shape of objects stored in the queue.""" - queue_shapes = [(self.batch_size, *self.lr_shape)] - hr_mems = len(self.BATCH_MEMBERS) - 1 - queue_shapes += [(self.batch_size, *self.hr_shape)] * hr_mems - return queue_shapes + """Shape of objects stored in the queue. Optionally includes shape of + observation data which would be included in an extra content loss + term""" + obs_shape = ( + *self.hr_shape[:-1], + len(self.containers[0].hr_out_features), + ) + queue_shapes = [ + (self.batch_size, *self.lr_shape), + (self.batch_size, *self.hr_shape), + (self.batch_size, *obs_shape), + ] + return queue_shapes[: len(self.BATCH_MEMBERS)] def check_enhancement_factors(self): """Make sure each DualSampler has the same enhancment factors and they diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index 47706aa4e..a70f01b08 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -43,6 +43,7 @@ def __init__( ], regrid_workers=1, regrid_lr=True, + run_qa=False, s_enhance=1, t_enhance=1, lr_cache_kwargs=None, @@ -63,6 +64,9 @@ def __init__( Flag to regrid the low-res data to the high-res grid. This will take care of any minor inconsistencies in different projections. Disable this if the grids are known to be the same. + run_qa : bool + Flag to run qa on the regridded low-res data. This will check for + NaNs and fill them if there are not too many. 
s_enhance : int Spatial enhancement factor t_enhance : int @@ -135,7 +139,8 @@ def __init__( self.update_hr_data() super().__init__(data=(self.lr_data, self.hr_data)) - self.check_regridded_lr_data() + if run_qa: + self.check_regridded_lr_data() if lr_cache_kwargs is not None: Cacher(self.lr_data, lr_cache_kwargs) @@ -205,7 +210,7 @@ def update_lr_data(self): lr_coords_new = { Dimension.LATITUDE: self.lr_lat_lon[..., 0], Dimension.LONGITUDE: self.lr_lat_lon[..., 1], - Dimension.TIME: self.lr_data.indexes['time'][ + Dimension.TIME: self.lr_data.indexes[Dimension.TIME][ : self.lr_required_shape[2] ], } From 4085a078904e1954233e9239f0af6c7d8f7c68ec Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 29 Dec 2024 08:27:46 -0800 Subject: [PATCH 16/32] ``run_qa=True`` default for ``DualRasterizer`` --- sup3r/preprocessing/rasterizers/dual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index a70f01b08..7b3df385d 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -43,7 +43,7 @@ def __init__( ], regrid_workers=1, regrid_lr=True, - run_qa=False, + run_qa=True, s_enhance=1, t_enhance=1, lr_cache_kwargs=None, From dc933f90742a1a98ec5a345918b319ff42cfdf19 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 29 Dec 2024 12:13:09 -0800 Subject: [PATCH 17/32] better tracking of batch counting. (this can be tricky for parallel queueing, since batches can be sampled directly if there are none in the queue). --- sup3r/preprocessing/batch_queues/abstract.py | 41 ++++++++++---------- tests/batch_handlers/test_bh_general.py | 4 +- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 0c8d2f315..1537ca3bb 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -119,7 +119,7 @@ def queue_shape(self): @property def queue_len(self): """Get number of batches in the queue.""" - return self.queue.size().numpy() + return self.queue.size().numpy() + self.queue_futures @property def queue_futures(self): @@ -127,11 +127,6 @@ def queue_futures(self): the queue.""" return self._thread_pool._work_queue.qsize() - @property - def queue_free(self): - """Get number of free spots in the queue.""" - return self.queue_cap - self.queue_len - def get_queue(self): """Return FIFO queue for storing batches.""" return tf.queue.FIFOQueue( @@ -232,16 +227,16 @@ def __len__(self): return self.n_batches def __iter__(self): - self._batch_count = 0 self.start() + self._batch_count = 0 return self def get_batch(self) -> DsetTuple: """Get batch from queue or directly from a ``Sampler`` through ``sample_batch``.""" - if self.queue_len > 0 or self.queue_futures > 0: - return self.queue.dequeue() - return self.sample_batch() + if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0: + return self.sample_batch() + return self.queue.dequeue() @property def running(self): @@ -272,19 +267,26 @@ def sample_batches(self, n_batches) -> None: ) return [task.result() for task in tasks] + @property + def needed_batches(self): + """Number of batches needed to either fill or the queue or hit the + epoch limit.""" + remaining = self.n_batches - self._batch_count - self.queue_len - 1 + return min(self.queue_cap - self.queue_len, remaining) + def enqueue_batches(self) -> None: """Callback function for queue thread. 
While training, the queue is checked for empty spots and filled. In the training thread, batches are removed from the queue.""" log_time = time.time() while self.running: - needed = min( - self.queue_free - self.queue_futures, - self.n_batches - self._batch_count - ) # no point in getting more than one batch at a time if # max_workers == 1 - needed = 1 if needed > 0 and self.max_workers == 1 else needed + needed = ( + 1 + if self.needed_batches > 0 and self.max_workers == 1 + else self.needed_batches + ) if needed > 0: for batch in self.sample_batches(n_batches=needed): @@ -307,6 +309,7 @@ def __next__(self) -> DsetTuple: if self._batch_count < self.n_batches: self.timer.start() samples = self.get_batch() + self._batch_count += 1 if self.sample_shape[2] == 1: if isinstance(samples, (list, tuple)): samples = tuple(s[..., 0, :] for s in samples) @@ -314,7 +317,6 @@ def __next__(self) -> DsetTuple: samples = samples[..., 0, :] batch = self.post_proc(samples) self.timer.stop() - self._batch_count += 1 if self.verbose: logger.debug( 'Batch step %s finished in %s.', @@ -348,11 +350,8 @@ def sample_batch(self): def log_queue_info(self): """Log info about queue size.""" - return '{} queue length: {} / {}, with {} futures'.format( - self._thread_name.title(), - self.queue_len, - self.queue_cap, - self.queue_futures + return '{} queue length: {} / {}'.format( + self._thread_name.title(), self.queue_len, self.queue_cap ) @property diff --git a/tests/batch_handlers/test_bh_general.py b/tests/batch_handlers/test_bh_general.py index 0c3e89401..348ad5a2f 100644 --- a/tests/batch_handlers/test_bh_general.py +++ b/tests/batch_handlers/test_bh_general.py @@ -47,8 +47,8 @@ def test_batch_handler_workers(): 2 * sample_shape[-1], ) n_obs = 40 - max_workers = 5 - n_batches = 40 + max_workers = 20 + n_batches = 20 lons, lats = np.meshgrid( np.linspace(0, 180, n_lats), np.linspace(40, 60, n_lons) From f26f3047d1748c6d62f010ae23b3b7cfa4e32860 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 29 Dec 2024 15:41:00 -0800 Subject: [PATCH 18/32] missed compute call for slow batching. this was hidden by queueing and dequeueing since this would cast to tensors. 
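The change below hinges on how lazily evaluated samples get materialized. A minimal sketch of that behavior follows; it is not part of the patch, the array and variable names are made up for illustration, and it assumes only ``dask`` and ``tensorflow``. Enqueueing a dask-backed array into a ``tf.queue.FIFOQueue`` casts it to an eager tensor, which silently forces the computation, while a batch returned directly from a sampler stays lazy unless it is computed explicitly, hence the added compute call.

import dask.array as da
import numpy as np
import tensorflow as tf

lazy_sample = da.random.random((4, 10, 10, 2))  # dask array, not yet computed

queue = tf.queue.FIFOQueue(
    capacity=1, dtypes=[tf.float32], shapes=[lazy_sample.shape]
)
# enqueue/dequeue converts to an eager tensor, which forces the compute
queue.enqueue(tf.convert_to_tensor(lazy_sample, dtype=tf.float32))
from_queue = queue.dequeue()

# bypassing the queue hands back the lazy array; an explicit cast or
# compute is needed before the batch reaches the model
direct_batch = np.asarray(lazy_sample)
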
--- sup3r/preprocessing/batch_queues/abstract.py | 4 ++-- sup3r/preprocessing/samplers/base.py | 16 ++++++++-------- sup3r/preprocessing/utilities.py | 2 ++ tests/batch_handlers/test_bh_general.py | 4 ++-- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 1537ca3bb..6c022118d 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -271,7 +271,7 @@ def sample_batches(self, n_batches) -> None: def needed_batches(self): """Number of batches needed to either fill or the queue or hit the epoch limit.""" - remaining = self.n_batches - self._batch_count - self.queue_len - 1 + remaining = self.n_batches - self._batch_count - self.queue_len return min(self.queue_cap - self.queue_len, remaining) def enqueue_batches(self) -> None: @@ -309,7 +309,6 @@ def __next__(self) -> DsetTuple: if self._batch_count < self.n_batches: self.timer.start() samples = self.get_batch() - self._batch_count += 1 if self.sample_shape[2] == 1: if isinstance(samples, (list, tuple)): samples = tuple(s[..., 0, :] for s in samples) @@ -317,6 +316,7 @@ def __next__(self) -> DsetTuple: samples = samples[..., 0, :] batch = self.post_proc(samples) self.timer.stop() + self._batch_count += 1 if self.verbose: logger.debug( 'Batch step %s finished in %s.', diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index b824cd360..e576c1521 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -7,7 +7,6 @@ from typing import Dict, Optional, Tuple from warnings import warn -import dask.array as da import numpy as np from sup3r.preprocessing.base import Container @@ -194,9 +193,8 @@ def _reshape_samples(self, samples): new_shape[2] // self.batch_size, new_shape[-1], ] - out = compute_if_dask(samples) # (lats, lons, batch_size, times, feats) - out = np.reshape(out, new_shape) + out = np.reshape(samples, new_shape) # (batch_size, lats, lons, times, feats) return np.transpose(out, axes=(2, 0, 1, 3, 4)) @@ -223,25 +221,27 @@ def _stack_samples(self, samples): (batch_size, samp_shape[0], samp_shape[1], samp_shape[2], n_feats) """ if isinstance(samples[0], tuple): - lr = da.stack([s[0] for s in samples], axis=0) - hr = da.stack([s[1] for s in samples], axis=0) + lr = np.stack([s[0] for s in samples], axis=0) + hr = np.stack([s[1] for s in samples], axis=0) return (lr, hr) - return da.stack(samples, axis=0) + return np.stack(samples, axis=0) def _fast_batch(self): """Get batch of samples with adjacent time slices.""" out = self.data.sample(self.get_sample_index(n_obs=self.batch_size)) + out = compute_if_dask(out) if isinstance(out, tuple): return tuple(self._reshape_samples(o) for o in out) return self._reshape_samples(out) def _slow_batch(self): """Get batch of samples with random time slices.""" - samples = [ + out = [ self.data.sample(self.get_sample_index(n_obs=1)) for _ in range(self.batch_size) ] - return self._stack_samples(samples) + out = compute_if_dask(out) + return self._stack_samples(out) def _fast_batch_possible(self): return self.batch_size * self.sample_shape[2] <= self.data.shape[2] diff --git a/sup3r/preprocessing/utilities.py b/sup3r/preprocessing/utilities.py index 8a51fadbf..f52e91b72 100644 --- a/sup3r/preprocessing/utilities.py +++ b/sup3r/preprocessing/utilities.py @@ -228,6 +228,8 @@ def compute_if_dask(arr): compute_if_dask(arr.stop), compute_if_dask(arr.step), ) + if isinstance(arr, 
(tuple, list)): + return type(arr)(compute_if_dask(a) for a in arr) return arr.compute() if hasattr(arr, 'compute') else arr diff --git a/tests/batch_handlers/test_bh_general.py b/tests/batch_handlers/test_bh_general.py index 348ad5a2f..10dabab23 100644 --- a/tests/batch_handlers/test_bh_general.py +++ b/tests/batch_handlers/test_bh_general.py @@ -47,8 +47,8 @@ def test_batch_handler_workers(): 2 * sample_shape[-1], ) n_obs = 40 - max_workers = 20 - n_batches = 20 + max_workers = 32 + n_batches = 40 lons, lats = np.meshgrid( np.linspace(0, 180, n_lats), np.linspace(40, 60, n_lons) From b43e8078a9f7144e5f8474027c344fb8bf9597e4 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 30 Dec 2024 07:12:20 -0800 Subject: [PATCH 19/32] Included convert to tensor in ``sample_batch``. Test for training with ``max_workers > 1``. --- sup3r/preprocessing/batch_queues/abstract.py | 5 +- sup3r/utilities/utilities.py | 2 + tests/batch_handlers/test_bh_general.py | 17 ++-- tests/rasterizers/test_dual.py | 10 ++- tests/training/test_train_dual.py | 6 +- tests/training/test_train_dual_with_obs.py | 4 +- tests/training/test_train_gan.py | 81 +++++++++++++++++++- 7 files changed, 106 insertions(+), 19 deletions(-) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 6c022118d..e64c9dad8 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -346,7 +346,10 @@ def sample_batch(self): These samples are wrapped in an ``np.asarray`` call, so they have been loaded into memory. """ - return next(self.get_random_container()) + out = next(self.get_random_container()) + if not isinstance(out, tuple): + return tf.convert_to_tensor(out, dtype=tf.float32) + return tuple(tf.convert_to_tensor(o, dtype=tf.float32) for o in out) def log_queue_info(self): """Log info about queue size.""" diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 3a563119a..f0f64246a 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -91,6 +91,8 @@ def stop(self): @property def elapsed(self): """Elapsed time between start and stop.""" + if self._stop is None: + return time.time() - self._start return self._stop - self._start @property diff --git a/tests/batch_handlers/test_bh_general.py b/tests/batch_handlers/test_bh_general.py index 10dabab23..8eca9d58c 100644 --- a/tests/batch_handlers/test_bh_general.py +++ b/tests/batch_handlers/test_bh_general.py @@ -49,13 +49,14 @@ def test_batch_handler_workers(): n_obs = 40 max_workers = 32 n_batches = 40 + n_epochs = 3 lons, lats = np.meshgrid( np.linspace(0, 180, n_lats), np.linspace(40, 60, n_lons) ) time = pd.date_range('2023-01-01', '2023-05-01', freq='h') - u_arr = da.random.random((*lats.shape, len(time))) - v_arr = da.random.random((*lats.shape, len(time))) + u_arr = da.random.random((*lats.shape, len(time))).astype('float32') + v_arr = da.random.random((*lats.shape, len(time))).astype('float32') ds = xr.Dataset( coords={ 'latitude': (('south_north', 'west_east'), lats), @@ -75,12 +76,14 @@ def test_batch_handler_workers(): batch_size=n_obs, sample_shape=sample_shape, max_workers=max_workers, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, ) timer.start() - for _ in range(10): + for _ in range(n_epochs): _ = list(batcher) timer.stop() - parallel_time = timer.elapsed / (n_batches * 10) + parallel_time = timer.elapsed / (n_batches * n_epochs) batcher.stop() batcher = BatchHandler( @@ -89,12 +92,14 @@ def 
test_batch_handler_workers(): batch_size=n_obs, sample_shape=sample_shape, max_workers=1, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, ) timer.start() - for _ in range(10): + for _ in range(n_epochs): _ = list(batcher) timer.stop() - serial_time = timer.elapsed / (n_batches * 10) + serial_time = timer.elapsed / (n_batches * n_epochs) batcher.stop() print( diff --git a/tests/rasterizers/test_dual.py b/tests/rasterizers/test_dual.py index ed11feecb..725c2fdb0 100644 --- a/tests/rasterizers/test_dual.py +++ b/tests/rasterizers/test_dual.py @@ -30,7 +30,9 @@ def test_dual_rasterizer_shapes(full_shape=(20, 20)): ) pair_rasterizer = DualRasterizer( - (lr_container.data, hr_container.data), s_enhance=2, t_enhance=1 + {'low_res': lr_container.data, 'high_res': hr_container.data}, + s_enhance=2, + t_enhance=1, ) assert pair_rasterizer.lr_data.shape == ( pair_rasterizer.hr_data.shape[0] // 2, @@ -63,7 +65,9 @@ def test_dual_nan_fill(full_shape=(20, 20)): assert np.isnan(lr_container.data.as_array()).any() pair_rasterizer = DualRasterizer( - (lr_container.data, hr_container.data), s_enhance=1, t_enhance=1 + {'low_res': lr_container.data, 'high_res': hr_container.data}, + s_enhance=1, + t_enhance=1, ) assert not np.isnan(pair_rasterizer.lr_data.as_array()).any() @@ -89,7 +93,7 @@ def test_regrid_caching(full_shape=(20, 20)): lr_cache_pattern = os.path.join(td, 'lr_{feature}.h5') hr_cache_pattern = os.path.join(td, 'hr_{feature}.h5') pair_rasterizer = DualRasterizer( - (lr_container.data, hr_container.data), + {'low_res': lr_container.data, 'high_res': hr_container.data}, s_enhance=2, t_enhance=1, lr_cache_kwargs={'cache_pattern': lr_cache_pattern}, diff --git a/tests/training/test_train_dual.py b/tests/training/test_train_dual.py index f92443961..29e38a32b 100644 --- a/tests/training/test_train_dual.py +++ b/tests/training/test_train_dual.py @@ -66,7 +66,7 @@ def test_train_h5_nc( # time indices conflict with t_enhance with pytest.raises(AssertionError): dual_rasterizer = DualRasterizer( - data=(lr_handler.data, hr_handler.data), + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, s_enhance=s_enhance, t_enhance=t_enhance, ) @@ -78,7 +78,7 @@ def test_train_h5_nc( ) dual_rasterizer = DualRasterizer( - data=(lr_handler.data, hr_handler.data), + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, s_enhance=s_enhance, t_enhance=t_enhance, ) @@ -158,7 +158,7 @@ def test_train_coarse_h5( ) dual_rasterizer = DualRasterizer( - data=(lr_handler.data, hr_handler.data), + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, s_enhance=s_enhance, t_enhance=t_enhance, ) diff --git a/tests/training/test_train_dual_with_obs.py b/tests/training/test_train_dual_with_obs.py index 48a9c7d67..0bf8244ee 100644 --- a/tests/training/test_train_dual_with_obs.py +++ b/tests/training/test_train_dual_with_obs.py @@ -67,7 +67,7 @@ def test_train_h5_nc( ) dual_rasterizer = DualRasterizer( - data=(lr_handler.data, hr_handler.data), + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, s_enhance=s_enhance, t_enhance=t_enhance, ) @@ -171,7 +171,7 @@ def test_train_coarse_h5( ) dual_rasterizer = DualRasterizer( - data=(lr_handler.data, hr_handler.data), + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, s_enhance=s_enhance, t_enhance=t_enhance, ) diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 81d1ebd63..79d09119d 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -11,6 
+11,7 @@ from sup3r.models import Sup3rGan from sup3r.preprocessing import BatchHandler, DataHandler +from sup3r.utilities.utilities import Timer TARGET_COORD = (39.01, -105.15) FEATURES = ['u_100m', 'v_100m'] @@ -113,10 +114,13 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): assert 'OptmGen/learning_rate' in model.history assert 'OptmDisc/learning_rate' in model.history - msg = ('Could not find OptmGen states in columns: ' - f'{sorted(model.history.columns)}') - check = [col.startswith('OptmGen/Adam/v') - for col in model.history.columns] + msg = ( + 'Could not find OptmGen states in columns: ' + f'{sorted(model.history.columns)}' + ) + check = [ + col.startswith('OptmGen/Adam/v') for col in model.history.columns + ] assert any(check), msg assert 'config_generator' in loaded.meta @@ -165,6 +169,75 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): batch_handler.stop() +def test_train_workers(n_epoch=3): + """Test that model training with max_workers > 1 for the batch queue is + faster than for max_workers = 1.""" + + lr = 5e-5 + Sup3rGan.seed() + model = Sup3rGan( + pytest.ST_FP_GEN, + pytest.ST_FP_DISC, + learning_rate=lr, + loss='MeanAbsoluteError', + ) + + train_handler, val_handler = _get_handlers() + timer = Timer() + + with tempfile.TemporaryDirectory() as td: + batch_handler = BatchHandler( + train_containers=[train_handler], + val_containers=[val_handler], + sample_shape=(12, 12, 16), + batch_size=15, + s_enhance=3, + t_enhance=4, + n_batches=10, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, + max_workers=10, + ) + + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': n_epoch, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + timer.start() + model.train(batch_handler, **model_kwargs) + parallel_time = timer.elapsed + + batch_handler = BatchHandler( + train_containers=[train_handler], + val_containers=[val_handler], + sample_shape=(12, 12, 16), + batch_size=15, + s_enhance=3, + t_enhance=4, + n_batches=10, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, + max_workers=1, + ) + + timer.start() + model.train(batch_handler, **model_kwargs) + serial_time = timer.elapsed + + print( + 'Elapsed (parallel / serial): {} / {}'.format( + parallel_time, serial_time + ) + ) + assert parallel_time < serial_time + + def test_train_st_weight_update(n_epoch=2): """Test basic spatiotemporal model training with discriminators and adversarial loss updating.""" From 9b542bec80834663e47aff1c087359b58c1f745e Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 31 Dec 2024 11:53:54 -0700 Subject: [PATCH 20/32] cc batch handler test fix --- sup3r/models/base.py | 7 ++- sup3r/preprocessing/accessor.py | 52 ++++++++--------- sup3r/preprocessing/batch_queues/abstract.py | 60 +++++++------------- sup3r/preprocessing/rasterizers/dual.py | 2 +- sup3r/preprocessing/utilities.py | 1 + tests/batch_handlers/test_bh_general.py | 6 +- tests/batch_handlers/test_bh_h5_cc.py | 10 ++-- tests/training/test_train_gan.py | 45 ++++++++------- 8 files changed, 88 insertions(+), 95 deletions(-) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index cd4a437f7..53b159464 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -815,6 +815,11 @@ def train( ) if stop: break + logger.info( + 'Finished training %s epochs in %s seconds', + n_epoch, + time.time() - t0, + ) 
batch_handler.stop() @@ -977,7 +982,7 @@ def train_epoch( disc_th_low = np.min(disc_loss_bounds) disc_th_high = np.max(disc_loss_bounds) - loss_details = {'n_obs': 0, 'train_loss_disc': 0, 'train_loss_obs': 0} + loss_details = {'n_obs': 0, 'train_loss_disc': 0} only_gen = train_gen and not train_disc only_disc = train_disc and not train_gen diff --git a/sup3r/preprocessing/accessor.py b/sup3r/preprocessing/accessor.py index 8eaa28d90..aad25b884 100644 --- a/sup3r/preprocessing/accessor.py +++ b/sup3r/preprocessing/accessor.py @@ -488,8 +488,8 @@ def size(self): def time_index(self): """Base time index for contained data.""" return ( - pd.to_datetime(self._ds.indexes['time']) - if 'time' in self._ds.indexes + pd.to_datetime(self._ds.indexes[Dimension.TIME]) + if Dimension.TIME in self._ds.indexes else None ) @@ -562,40 +562,40 @@ def flatten(self): """Flatten rasterized dataset so that there is only a single spatial dimension.""" if not self.flattened: - self._ds = self._ds.stack( - {Dimension.FLATTENED_SPATIAL: Dimension.dims_2d()} - ) - self._ds = self._ds.assign( - { - Dimension.FLATTENED_SPATIAL: np.arange( - len(self._ds[Dimension.FLATTENED_SPATIAL]) - ) - } - ) + dims = {Dimension.FLATTENED_SPATIAL: Dimension.dims_2d()} + self._ds = self._ds.stack(dims) + index = np.arange(len(self._ds[Dimension.FLATTENED_SPATIAL])) + self._ds = self._ds.assign({Dimension.FLATTENED_SPATIAL: index}) else: msg = 'Dataset is already flattened' logger.warning(msg) warn(msg) return self - def _qa(self, feature): + def _qa(self, feature, stats=None): """Get qa info for given feature.""" info = {} + stats = stats or ['nan_perc', 'std', 'mean', 'min', 'max'] logger.info('Running qa on feature: %s', feature) nan_count = 100 * np.isnan(self[feature].data).sum() nan_perc = nan_count / self[feature].size - info['nan_perc'] = compute_if_dask(nan_perc) - info['std'] = compute_if_dask(self[feature].std().data) - info['mean'] = compute_if_dask(self[feature].mean().data) - info['min'] = compute_if_dask(self[feature].min().data) - info['max'] = compute_if_dask(self[feature].max().data) + + for stat in stats: + logger.info('Running QA method %s on feature: %s', stat, feature) + if stat == 'nan_perc': + info['nan_perc'] = compute_if_dask(nan_perc) + else: + msg = f'Unknown QA method requested: {stat}' + assert hasattr(self[feature], stat), msg + qa_data = getattr(self[feature], stat)().data + info[stat] = compute_if_dask(qa_data) return info - def qa(self): - """Check NaNs and stats for all features.""" + def qa(self, stats=None): + """Check NaNs and the given stats for all features.""" qa_info = {} for f in self.features: - qa_info[f] = self._qa(f) + qa_info[f] = self._qa(f, stats=stats) return qa_info def __mul__(self, other): @@ -604,9 +604,8 @@ def __mul__(self, other): try: return type(self)(other * self._ds) except Exception as e: - raise NotImplementedError( - f'Multiplication not supported for type {type(other)}.' - ) from e + msg = f'Multiplication not supported for type {type(other)}.' + raise NotImplementedError(msg) from e def __rmul__(self, other): return self.__mul__(other) @@ -617,6 +616,5 @@ def __pow__(self, other): try: return type(self)(self._ds**other) except Exception as e: - raise NotImplementedError( - f'Exponentiation not supported for type {type(other)}.' - ) from e + msg = f'Exponentiation not supported for type {type(other)}.' 
+ raise NotImplementedError(msg) from e diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index e64c9dad8..55928d7f6 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -11,7 +11,7 @@ import threading import time from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import TYPE_CHECKING, List, Optional, Union import numpy as np @@ -227,14 +227,14 @@ def __len__(self): return self.n_batches def __iter__(self): - self.start() self._batch_count = 0 + self.start() return self def get_batch(self) -> DsetTuple: """Get batch from queue or directly from a ``Sampler`` through ``sample_batch``.""" - if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0: + if self.mode == 'eager' or self.queue_cap == 0: return self.sample_batch() return self.queue.dequeue() @@ -247,32 +247,24 @@ def running(self): and not self.queue.is_closed() ) - def sample_batches(self, n_batches) -> None: - """Sample N batches from samplers. Returns N batches which are then - used to fill the queue.""" - if n_batches == 1: - return [self.sample_batch()] - - if self.max_workers == 1: - return [self.sample_batch() for _ in range(n_batches)] - - tasks = [ - self._thread_pool.submit(self.sample_batch) - for _ in range(n_batches) - ] - logger.debug( - 'Added %s sample_batch futures to %s queue.', - n_batches, - self._thread_name, - ) - return [task.result() for task in tasks] + def _enqueue_batches(self, n_batches) -> None: + """Sample N batches and enqueue them as they are sampled.""" + if n_batches == 1 or self.max_workers == 1: + for _ in range(n_batches): + self.queue.enqueue(self.sample_batch()) - @property - def needed_batches(self): - """Number of batches needed to either fill or the queue or hit the - epoch limit.""" - remaining = self.n_batches - self._batch_count - self.queue_len - return min(self.queue_cap - self.queue_len, remaining) + else: + tasks = [ + self._thread_pool.submit(self.sample_batch) + for _ in range(n_batches) + ] + logger.debug( + 'Added %s sample_batch futures to %s queue.', + n_batches, + self._thread_name, + ) + for batch in as_completed(tasks): + self.queue.enqueue(batch.result()) def enqueue_batches(self) -> None: """Callback function for queue thread. 
While training, the queue is @@ -280,17 +272,9 @@ def enqueue_batches(self) -> None: removed from the queue.""" log_time = time.time() while self.running: - # no point in getting more than one batch at a time if - # max_workers == 1 - needed = ( - 1 - if self.needed_batches > 0 and self.max_workers == 1 - else self.needed_batches - ) - + needed = max(self.queue_cap - self.queue_len, 0) if needed > 0: - for batch in self.sample_batches(n_batches=needed): - self.queue.enqueue(batch) + self._enqueue_batches(n_batches=needed) if time.time() > log_time + 10: logger.debug(self.log_queue_info()) diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index 7b3df385d..2df70fa84 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -223,7 +223,7 @@ def check_regridded_lr_data(self): """Check for NaNs after regridding and do NN fill if needed.""" fill_feats = [] logger.info('Checking for NaNs after regridding') - qa_info = self.lr_data.qa() + qa_info = self.lr_data.qa(stats=['nan_perc']) for f in self.lr_data.features: nan_perc = qa_info[f]['nan_perc'] if nan_perc > 0: diff --git a/sup3r/preprocessing/utilities.py b/sup3r/preprocessing/utilities.py index f52e91b72..db0dd5909 100644 --- a/sup3r/preprocessing/utilities.py +++ b/sup3r/preprocessing/utilities.py @@ -411,6 +411,7 @@ def parse_keys( order. If keys is empty then we just want to return the coordinate data, so features will be set to just the coordinate names.""" + keys = list(keys) if isinstance(keys, set) else keys keys = keys if isinstance(keys, tuple) else (keys,) has_feats = is_type_of(keys[0], str) diff --git a/tests/batch_handlers/test_bh_general.py b/tests/batch_handlers/test_bh_general.py index 8eca9d58c..3c6e3e595 100644 --- a/tests/batch_handlers/test_bh_general.py +++ b/tests/batch_handlers/test_bh_general.py @@ -46,9 +46,9 @@ def test_batch_handler_workers(): 2 * sample_shape[1], 2 * sample_shape[-1], ) - n_obs = 40 - max_workers = 32 - n_batches = 40 + n_obs = 10 + max_workers = 10 + n_batches = 10 n_epochs = 3 lons, lats = np.meshgrid( diff --git a/tests/batch_handlers/test_bh_h5_cc.py b/tests/batch_handlers/test_bh_h5_cc.py index 0fac7520a..942e52eed 100644 --- a/tests/batch_handlers/test_bh_h5_cc.py +++ b/tests/batch_handlers/test_bh_h5_cc.py @@ -1,6 +1,5 @@ """pytests for H5 climate change data batch handlers""" - import matplotlib.pyplot as plt import numpy as np import pytest @@ -370,10 +369,11 @@ def test_surf_min_max_vars(): assert batch.low_res.shape[-1] == len(surf_features) # compare daily avg temp vs min and max - assert (batch.low_res[..., 0] > batch.low_res[..., 2]).all() - assert (batch.low_res[..., 0] < batch.low_res[..., 3]).all() + blr = batch.low_res.numpy() + assert (blr[..., 0] > blr[..., 2]).all() + assert (blr[..., 0] < blr[..., 3]).all() # compare daily avg rh vs min and max - assert (batch.low_res[..., 1] > batch.low_res[..., 4]).all() - assert (batch.low_res[..., 1] < batch.low_res[..., 5]).all() + assert (blr[..., 1] > blr[..., 4]).all() + assert (blr[..., 1] < blr[..., 5]).all() batcher.stop() diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 79d09119d..a5dbe9864 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -169,34 +169,37 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): batch_handler.stop() -def test_train_workers(n_epoch=3): +def test_train_workers(n_epoch=20): """Test that model training with 
max_workers > 1 for the batch queue is faster than for max_workers = 1.""" lr = 5e-5 + train_handler, val_handler = _get_handlers() + timer = Timer() + n_batches = 40 + batch_size = 40 + Sup3rGan.seed() model = Sup3rGan( - pytest.ST_FP_GEN, - pytest.ST_FP_DISC, + pytest.S_FP_GEN, + pytest.S_FP_DISC, learning_rate=lr, loss='MeanAbsoluteError', ) - train_handler, val_handler = _get_handlers() - timer = Timer() - with tempfile.TemporaryDirectory() as td: + batch_handler = BatchHandler( train_containers=[train_handler], val_containers=[val_handler], - sample_shape=(12, 12, 16), - batch_size=15, - s_enhance=3, - t_enhance=4, - n_batches=10, + sample_shape=(10, 10, 1), + batch_size=batch_size, + s_enhance=2, + t_enhance=1, + n_batches=n_batches, means={'u_100m': 0, 'v_100m': 0}, stds={'u_100m': 1, 'v_100m': 1}, - max_workers=10, + max_workers=5, ) model_kwargs = { @@ -205,22 +208,23 @@ def test_train_workers(n_epoch=3): 'weight_gen_advers': 0.0, 'train_gen': True, 'train_disc': False, - 'checkpoint_int': 1, + 'checkpoint_int': 10, 'out_dir': os.path.join(td, 'test_{epoch}'), } timer.start() model.train(batch_handler, **model_kwargs) - parallel_time = timer.elapsed + timer.stop() + parallel_time = timer.elapsed / n_epoch batch_handler = BatchHandler( train_containers=[train_handler], val_containers=[val_handler], - sample_shape=(12, 12, 16), - batch_size=15, - s_enhance=3, - t_enhance=4, - n_batches=10, + sample_shape=(10, 10, 1), + batch_size=batch_size, + s_enhance=2, + t_enhance=1, + n_batches=n_batches, means={'u_100m': 0, 'v_100m': 0}, stds={'u_100m': 1, 'v_100m': 1}, max_workers=1, @@ -228,7 +232,8 @@ def test_train_workers(n_epoch=3): timer.start() model.train(batch_handler, **model_kwargs) - serial_time = timer.elapsed + timer.stop() + serial_time = timer.elapsed / n_epoch print( 'Elapsed (parallel / serial): {} / {}'.format( From 10dbc9c0c37a0f4b6e056dc817b6d035b3c0ae57 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 31 Dec 2024 14:17:03 -0700 Subject: [PATCH 21/32] added test for new disc with "valid" padding --- tests/conftest.py | 4 ++++ tests/training/test_train_gan.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9c23755cf..2f52610da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,6 +39,10 @@ def pytest_configure(config): # pylint: disable=unused-argument # noqa: ARG001 pytest.ST_FP_DISC = os.path.join(TEST_DATA_DIR, 'config_disc_st_test.json') pytest.S_FP_DISC = os.path.join(TEST_DATA_DIR, 'config_disc_s_test.json') + pytest.ST_FP_DISC_PROD = os.path.join( + CONFIG_DIR, 'spatiotemporal/disc.json' + ) + pytest.FPS_GCM = [ os.path.join(TEST_DATA_DIR, 'ua_test.nc'), os.path.join(TEST_DATA_DIR, 'va_test.nc'), diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index a5dbe9864..b60245e16 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -43,6 +43,7 @@ def _get_handlers(): ['fp_gen', 'fp_disc', 's_enhance', 't_enhance', 'sample_shape'], [ (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16)), + (pytest.ST_FP_GEN, pytest.ST_FP_DISC_PROD, 3, 4, (12, 12, 16)), (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (10, 10, 1)), ], ) @@ -53,8 +54,7 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): lr = 5e-5 Sup3rGan.seed() model = Sup3rGan( - fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' - ) + fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError') train_handler, val_handler = _get_handlers() @@ 
-252,7 +252,7 @@ def test_train_st_weight_update(n_epoch=2): pytest.ST_FP_GEN, pytest.ST_FP_DISC, learning_rate=1e-4, - learning_rate_disc=4e-4, + learning_rate_disc=4e-4 ) train_handler, val_handler = _get_handlers() From 2e6ed144f378c86a14f5f65946d357b7bcea80bb Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 Jan 2025 10:24:28 -0700 Subject: [PATCH 22/32] parallel sampling batch sampling test. --- sup3r/preprocessing/batch_queues/abstract.py | 40 +++++----- tests/batch_handlers/test_bh_general.py | 84 ++++++++++++++------ tests/training/test_train_gan.py | 9 +-- 3 files changed, 87 insertions(+), 46 deletions(-) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 55928d7f6..c6a9200d1 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -247,24 +247,21 @@ def running(self): and not self.queue.is_closed() ) - def _enqueue_batches(self, n_batches) -> None: - """Sample N batches and enqueue them as they are sampled.""" + def sample_batches(self, n_batches) -> None: + """Sample given number of batches either in serial or with thread + pool.""" if n_batches == 1 or self.max_workers == 1: - for _ in range(n_batches): - self.queue.enqueue(self.sample_batch()) - - else: - tasks = [ - self._thread_pool.submit(self.sample_batch) - for _ in range(n_batches) - ] - logger.debug( - 'Added %s sample_batch futures to %s queue.', - n_batches, - self._thread_name, - ) - for batch in as_completed(tasks): - self.queue.enqueue(batch.result()) + return [self.sample_batch() for _ in range(n_batches)] + tasks = [ + self._thread_pool.submit(self.sample_batch) + for _ in range(n_batches) + ] + logger.debug( + 'Added %s sample_batch futures to %s queue.', + n_batches, + self._thread_name, + ) + return tasks def enqueue_batches(self) -> None: """Callback function for queue thread. While training, the queue is @@ -273,8 +270,15 @@ def enqueue_batches(self) -> None: log_time = time.time() while self.running: needed = max(self.queue_cap - self.queue_len, 0) + needed = min(self.max_workers, needed) if needed > 0: - self._enqueue_batches(n_batches=needed) + batches = self.sample_batches(n_batches=needed) + if needed > 1 and self.max_workers > 1: + for batch in as_completed(batches): + self.queue.enqueue(batch.result()) + else: + for batch in batches: + self.queue.enqueue(batch) if time.time() > log_time + 10: logger.debug(self.log_queue_info()) diff --git a/tests/batch_handlers/test_bh_general.py b/tests/batch_handlers/test_bh_general.py index 3c6e3e595..1a6504ef5 100644 --- a/tests/batch_handlers/test_bh_general.py +++ b/tests/batch_handlers/test_bh_general.py @@ -2,11 +2,8 @@ import copy -import dask.array as da import numpy as np -import pandas as pd import pytest -import xarray as xr from scipy.ndimage import gaussian_filter from sup3r.preprocessing import ( @@ -33,13 +30,12 @@ BatchHandlerTester = BatchHandlerTesterFactory(BatchHandler, SamplerTester) -def test_batch_handler_workers(): - """Check that it is faster to get batches with max_workers > 1 than with - max_workers = 1.""" +def test_batch_sampling_workers(): + """Check that it is faster to sample batches with max_workers > 1 than with + max_workers = 1. 
This does not include enqueueing and dequeueing.""" timer = Timer() - n_lats = 200 - n_lons = 200 + ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m']) sample_shape = (20, 20, 30) chunk_shape = ( 2 * sample_shape[0], @@ -51,23 +47,65 @@ def test_batch_handler_workers(): n_batches = 10 n_epochs = 3 - lons, lats = np.meshgrid( - np.linspace(0, 180, n_lats), np.linspace(40, 60, n_lons) + ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape))) + + batcher = BatchHandler( + [ds], + n_batches=n_batches, + batch_size=n_obs, + sample_shape=sample_shape, + max_workers=max_workers, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, + ) + timer.start() + for _ in range(n_epochs): + batches = batcher.sample_batches(n_batches) + _ = [batch.result() for batch in batches] + timer.stop() + parallel_time = timer.elapsed / (n_batches * n_epochs) + batcher.stop() + + batcher = BatchHandler( + [ds], + n_batches=n_batches, + batch_size=n_obs, + sample_shape=sample_shape, + max_workers=1, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, + ) + timer.start() + for _ in range(n_epochs): + _ = batcher.sample_batches(n_batches) + timer.stop() + serial_time = timer.elapsed / (n_batches * n_epochs) + batcher.stop() + + print( + 'Elapsed (serial / parallel): {} / {}'.format( + serial_time, parallel_time + ) ) - time = pd.date_range('2023-01-01', '2023-05-01', freq='h') - u_arr = da.random.random((*lats.shape, len(time))).astype('float32') - v_arr = da.random.random((*lats.shape, len(time))).astype('float32') - ds = xr.Dataset( - coords={ - 'latitude': (('south_north', 'west_east'), lats), - 'longitude': (('south_north', 'west_east'), lons), - 'time': time, - }, - data_vars={ - 'u_100m': (('south_north', 'west_east', 'time'), u_arr), - 'v_100m': (('south_north', 'west_east', 'time'), v_arr), - }, + assert serial_time > parallel_time + + +def test_batch_queue_workers(): + """Check that it is faster to queue batches with max_workers > 1 than with + max_workers = 1.""" + + timer = Timer() + ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m']) + sample_shape = (20, 20, 30) + chunk_shape = ( + 2 * sample_shape[0], + 2 * sample_shape[1], + 2 * sample_shape[-1], ) + n_obs = 10 + max_workers = 10 + n_batches = 10 + n_epochs = 3 ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape))) batcher = BatchHandler( diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index b60245e16..4bf974b8d 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -43,7 +43,6 @@ def _get_handlers(): ['fp_gen', 'fp_disc', 's_enhance', 't_enhance', 'sample_shape'], [ (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16)), - (pytest.ST_FP_GEN, pytest.ST_FP_DISC_PROD, 3, 4, (12, 12, 16)), (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (10, 10, 1)), ], ) @@ -54,7 +53,8 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): lr = 5e-5 Sup3rGan.seed() model = Sup3rGan( - fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError') + fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' + ) train_handler, val_handler = _get_handlers() @@ -169,7 +169,7 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): batch_handler.stop() -def test_train_workers(n_epoch=20): +def test_train_workers(n_epoch=10): """Test that model training with max_workers > 1 for the batch queue is faster than for max_workers = 1.""" @@ -188,7 +188,6 @@ def 
test_train_workers(n_epoch=20): ) with tempfile.TemporaryDirectory() as td: - batch_handler = BatchHandler( train_containers=[train_handler], val_containers=[val_handler], @@ -252,7 +251,7 @@ def test_train_st_weight_update(n_epoch=2): pytest.ST_FP_GEN, pytest.ST_FP_DISC, learning_rate=1e-4, - learning_rate_disc=4e-4 + learning_rate_disc=4e-4, ) train_handler, val_handler = _get_handlers() From 8f2218d8e0a0bf7c5fb3f56310f3bf59c9e57ec5 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 2 Jan 2025 12:00:18 -0700 Subject: [PATCH 23/32] removed workers tests. max_workers > 1 still not consistently faster. just sampling is, except for macos, but training is not. --- sup3r/preprocessing/batch_queues/abstract.py | 2 +- tests/batch_handlers/test_bh_general.py | 168 ++++++++----------- tests/training/test_train_gan.py | 74 -------- 3 files changed, 67 insertions(+), 177 deletions(-) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index c6a9200d1..c0d418ddc 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -234,7 +234,7 @@ def __iter__(self): def get_batch(self) -> DsetTuple: """Get batch from queue or directly from a ``Sampler`` through ``sample_batch``.""" - if self.mode == 'eager' or self.queue_cap == 0: + if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0: return self.sample_batch() return self.queue.dequeue() diff --git a/tests/batch_handlers/test_bh_general.py b/tests/batch_handlers/test_bh_general.py index 1a6504ef5..616eb1833 100644 --- a/tests/batch_handlers/test_bh_general.py +++ b/tests/batch_handlers/test_bh_general.py @@ -1,14 +1,14 @@ """Smoke tests for batcher objects. Just make sure things run without errors""" - import copy +import os +import time +from tempfile import TemporaryDirectory import numpy as np import pytest from scipy.ndimage import gaussian_filter -from sup3r.preprocessing import ( - BatchHandler, -) +from sup3r.preprocessing import BatchHandler, DataHandler from sup3r.preprocessing.base import Container from sup3r.utilities.pytest.helpers import ( BatchHandlerTesterFactory, @@ -35,8 +35,7 @@ def test_batch_sampling_workers(): max_workers = 1. 
This does not include enqueueing and dequeueing.""" timer = Timer() - ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m']) - sample_shape = (20, 20, 30) + sample_shape = (100, 100, 30) chunk_shape = ( 2 * sample_shape[0], 2 * sample_shape[1], @@ -44,108 +43,73 @@ def test_batch_sampling_workers(): ) n_obs = 10 max_workers = 10 - n_batches = 10 + n_batches = 50 n_epochs = 3 + chunks = dict(zip(['south_north', 'west_east', 'time'], chunk_shape)) - ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape))) - - batcher = BatchHandler( - [ds], - n_batches=n_batches, - batch_size=n_obs, - sample_shape=sample_shape, - max_workers=max_workers, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - ) - timer.start() - for _ in range(n_epochs): - batches = batcher.sample_batches(n_batches) - _ = [batch.result() for batch in batches] - timer.stop() - parallel_time = timer.elapsed / (n_batches * n_epochs) - batcher.stop() - - batcher = BatchHandler( - [ds], - n_batches=n_batches, - batch_size=n_obs, - sample_shape=sample_shape, - max_workers=1, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - ) - timer.start() - for _ in range(n_epochs): - _ = batcher.sample_batches(n_batches) - timer.stop() - serial_time = timer.elapsed / (n_batches * n_epochs) - batcher.stop() + with TemporaryDirectory() as td: + ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m']) + ds.to_netcdf(os.path.join(td, 'test.nc')) + ds = DataHandler(os.path.join(td, 'test.nc'), chunks=chunks) - print( - 'Elapsed (serial / parallel): {} / {}'.format( - serial_time, parallel_time + batcher = BatchHandler( + [ds], + n_batches=n_batches, + batch_size=n_obs, + sample_shape=sample_shape, + max_workers=max_workers, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, ) - ) - assert serial_time > parallel_time - - -def test_batch_queue_workers(): - """Check that it is faster to queue batches with max_workers > 1 than with - max_workers = 1.""" + timer.start() + queue_time = 0 + for _ in range(n_epochs): + batches = batcher.sample_batches(n_batches) + batches = [batch.result() for batch in batches] + queue_start = time.time() + for batch in batches: + batcher.queue.enqueue(batch) + _ = batcher.queue.dequeue() + queue_time += (time.time() - queue_start) + timer.stop() + parallel_time = timer.elapsed / (n_batches * n_epochs) + parallel_queue_time = queue_time / (n_batches * n_epochs) + batcher.stop() - timer = Timer() - ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m']) - sample_shape = (20, 20, 30) - chunk_shape = ( - 2 * sample_shape[0], - 2 * sample_shape[1], - 2 * sample_shape[-1], - ) - n_obs = 10 - max_workers = 10 - n_batches = 10 - n_epochs = 3 - ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape))) - - batcher = BatchHandler( - [ds], - n_batches=n_batches, - batch_size=n_obs, - sample_shape=sample_shape, - max_workers=max_workers, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - ) - timer.start() - for _ in range(n_epochs): - _ = list(batcher) - timer.stop() - parallel_time = timer.elapsed / (n_batches * n_epochs) - batcher.stop() - - batcher = BatchHandler( - [ds], - n_batches=n_batches, - batch_size=n_obs, - sample_shape=sample_shape, - max_workers=1, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - ) - timer.start() - for _ in range(n_epochs): - _ = list(batcher) - timer.stop() - serial_time = timer.elapsed / (n_batches * n_epochs) - batcher.stop() - - print( - 'Elapsed (serial / 
parallel): {} / {}'.format( - serial_time, parallel_time + batcher = BatchHandler( + [ds], + n_batches=n_batches, + batch_size=n_obs, + sample_shape=sample_shape, + max_workers=1, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, ) - ) - assert serial_time > parallel_time + timer.start() + queue_time = 0 + for _ in range(n_epochs): + batches = batcher.sample_batches(n_batches) + queue_start = time.time() + for batch in batches: + batcher.queue.enqueue(batch) + _ = batcher.queue.dequeue() + queue_time += time.time() - queue_start + timer.stop() + serial_time = timer.elapsed / (n_batches * n_epochs) + serial_queue_time = queue_time / (n_batches * n_epochs) + batcher.stop() + + print( + 'Elapsed total time (serial / parallel): {} / {}'.format( + serial_time, parallel_time + ) + ) + print( + 'Elapsed queue time (serial / parallel): {} / {}'.format( + serial_queue_time, parallel_queue_time + ) + ) + assert serial_time > parallel_time def test_eager_vs_lazy(): diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 4bf974b8d..954bc3cb2 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -11,7 +11,6 @@ from sup3r.models import Sup3rGan from sup3r.preprocessing import BatchHandler, DataHandler -from sup3r.utilities.utilities import Timer TARGET_COORD = (39.01, -105.15) FEATURES = ['u_100m', 'v_100m'] @@ -169,79 +168,6 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): batch_handler.stop() -def test_train_workers(n_epoch=10): - """Test that model training with max_workers > 1 for the batch queue is - faster than for max_workers = 1.""" - - lr = 5e-5 - train_handler, val_handler = _get_handlers() - timer = Timer() - n_batches = 40 - batch_size = 40 - - Sup3rGan.seed() - model = Sup3rGan( - pytest.S_FP_GEN, - pytest.S_FP_DISC, - learning_rate=lr, - loss='MeanAbsoluteError', - ) - - with tempfile.TemporaryDirectory() as td: - batch_handler = BatchHandler( - train_containers=[train_handler], - val_containers=[val_handler], - sample_shape=(10, 10, 1), - batch_size=batch_size, - s_enhance=2, - t_enhance=1, - n_batches=n_batches, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - max_workers=5, - ) - - model_kwargs = { - 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, - 'n_epoch': n_epoch, - 'weight_gen_advers': 0.0, - 'train_gen': True, - 'train_disc': False, - 'checkpoint_int': 10, - 'out_dir': os.path.join(td, 'test_{epoch}'), - } - - timer.start() - model.train(batch_handler, **model_kwargs) - timer.stop() - parallel_time = timer.elapsed / n_epoch - - batch_handler = BatchHandler( - train_containers=[train_handler], - val_containers=[val_handler], - sample_shape=(10, 10, 1), - batch_size=batch_size, - s_enhance=2, - t_enhance=1, - n_batches=n_batches, - means={'u_100m': 0, 'v_100m': 0}, - stds={'u_100m': 1, 'v_100m': 1}, - max_workers=1, - ) - - timer.start() - model.train(batch_handler, **model_kwargs) - timer.stop() - serial_time = timer.elapsed / n_epoch - - print( - 'Elapsed (parallel / serial): {} / {}'.format( - parallel_time, serial_time - ) - ) - assert parallel_time < serial_time - - def test_train_st_weight_update(n_epoch=2): """Test basic spatiotemporal model training with discriminators and adversarial loss updating.""" From 02a9ecc6bcebcb348f6840f0c42eae28a8526ea5 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 3 Jan 2025 15:42:25 -0700 Subject: [PATCH 24/32] ``Sup3rGanWithObs`` model subclass. Other misc model refactoring. 
--- sup3r/models/__init__.py | 1 + sup3r/models/abstract.py | 146 ++++----- sup3r/models/base.py | 186 ++++++++--- sup3r/models/interface.py | 27 -- sup3r/models/with_obs.py | 352 +++++++++++++++++++++ tests/training/test_train_dual_with_obs.py | 10 +- 6 files changed, 561 insertions(+), 161 deletions(-) create mode 100644 sup3r/models/with_obs.py diff --git a/sup3r/models/__init__.py b/sup3r/models/__init__.py index 5d6b51344..20fff799f 100644 --- a/sup3r/models/__init__.py +++ b/sup3r/models/__init__.py @@ -6,6 +6,7 @@ from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan from .solar_cc import SolarCC from .surface import SurfaceSpatialMetModel +from .with_obs import Sup3rGanWithObs SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 579dab5e9..c5b4c9dfe 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -16,7 +16,6 @@ from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat from rex.utilities.utilities import safe_json_load from tensorflow.keras import optimizers -from tensorflow.keras.losses import MeanAbsoluteError import sup3r.utilities.loss_metrics from sup3r.preprocessing.data_handlers import ExoData @@ -420,6 +419,30 @@ def get_high_res_exo_input(self, high_res): exo_data[feature] = exo_fdata return exo_data + @tf.function + def _combine_loss_input(self, high_res_true, high_res_gen): + """Combine exogenous feature data from high_res_true with high_res_gen + for loss calculation + + Parameters + ---------- + high_res_true : tf.Tensor + Ground truth high resolution spatiotemporal data. + high_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. + + Returns + ------- + high_res_gen : tf.Tensor + Same as input with exogenous data combined with high_res input + """ + if high_res_true.shape[-1] > high_res_gen.shape[-1]: + exo_dict = self.get_high_res_exo_input(high_res_true) + exo_data = [exo_dict[feat] for feat in self.hr_exo_features] + high_res_gen = tf.concat((high_res_gen, *exo_data), axis=-1) + return high_res_gen + @staticmethod def get_loss_fun(loss): """Get the initialized loss function class from the sup3r loss library @@ -717,25 +740,42 @@ def finish_epoch( return stop + def _sum_parallel_grad(self, futures, start_time): + """Sum gradient descent future results""" + + # sum the gradients from each gpu to weight equally in + # optimizer momentum calculation + total_grad = None + for future in futures: + grad, loss_details = future.result() + if total_grad is None: + total_grad = grad + else: + for i, igrad in enumerate(grad): + total_grad[i] += igrad + + msg = ( + f'Finished {len(futures)} gradient descent steps on ' + f'{len(self.gpu_list)} GPUs in {time.time() - start_time:.4f} ' + 'seconds' + ) + logger.info(msg) + return total_grad, loss_details + def _get_parallel_grad( self, low_res, hi_res_true, training_weights, - obs_data=None, **calc_loss_kwargs, ): """Compute gradient for one mini-batch of (low_res, hi_res_true) across multiple GPUs""" futures = [] + start_time = time.time() lr_chunks = np.array_split(low_res, len(self.gpu_list)) hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) - obs_data_chunks = ( - [None] * len(hr_true_chunks) - if obs_data is None - else np.array_split(obs_data, len(self.gpu_list)) - ) split_mask = False mask_chunks = None if 'mask' in calc_loss_kwargs: @@ -754,38 +794,17 @@ def _get_parallel_grad( lr_chunks[i], hr_true_chunks[i], 
training_weights, - obs_data=obs_data_chunks[i], device_name=f'/gpu:{i}', **calc_loss_kwargs, ) ) - - # sum the gradients from each gpu to weight equally in - # optimizer momentum calculation - total_grad = None - for future in futures: - grad, loss_details = future.result() - if total_grad is None: - total_grad = grad - else: - for i, igrad in enumerate(grad): - total_grad[i] += igrad - - self.timer.stop() - logger.debug( - 'Finished %s gradient descent steps on %s GPUs in %s', - len(futures), - len(self.gpu_list), - self.timer.elapsed_str, - ) - return total_grad, loss_details + return self._sum_parallel_grad(futures, start_time=start_time) def run_gradient_descent( self, low_res, hi_res_true, training_weights, - obs_data=None, optimizer=None, multi_gpu=False, **calc_loss_kwargs, @@ -806,10 +825,6 @@ def run_gradient_descent( training_weights : list A list of layer weights that are to-be-trained based on the current loss weight values. - obs_data : tf.Tensor | None - Optional observation data to use in additional content loss term. - (n_observations, spatial_1, spatial_2, features) - (n_observations, spatial_1, spatial_2, temporal, features) optimizer : tf.keras.optimizers.Optimizer Optimizer class to use to update weights. This can be different if you're training just the generator or one of the discriminator @@ -829,32 +844,27 @@ def run_gradient_descent( loss_details : dict Namespace of the breakdown of loss components """ - - self.timer.start() if optimizer is None: optimizer = self.optimizer if not multi_gpu or len(self.gpu_list) < 2: + start_time = time.time() grad, loss_details = self.get_single_grad( low_res, hi_res_true, training_weights, - obs_data=obs_data, device_name=self.default_device, **calc_loss_kwargs, ) optimizer.apply_gradients(zip(grad, training_weights)) - self.timer.stop() - logger.debug( - 'Finished single gradient descent step in %s', - self.timer.elapsed_str, - ) + msg = ('Finished single gradient descent step in ' + f'{time.time() - start_time:.4f} seconds') + logger.debug(msg) else: total_grad, loss_details = self._get_parallel_grad( low_res, hi_res_true, training_weights, - obs_data, **calc_loss_kwargs, ) optimizer.apply_gradients(zip(total_grad, training_weights)) @@ -1050,13 +1060,25 @@ def _tf_generate(self, low_res, hi_res_exo=None): return hi_res + def _get_hr_exo_and_loss( + self, + low_res, + hi_res_true, + **calc_loss_kwargs, + ): + """Get high-resolution exogenous data, generate synthetic output, and + compute loss.""" + hi_res_exo = self.get_high_res_exo_input(hi_res_true) + hi_res_gen = self._tf_generate(low_res, hi_res_exo) + loss_out = self.calc_loss(hi_res_true, hi_res_gen, **calc_loss_kwargs) + return *loss_out, hi_res_gen + @tf.function def get_single_grad( self, low_res, hi_res_true, training_weights, - obs_data=None, device_name=None, **calc_loss_kwargs, ): @@ -1076,10 +1098,6 @@ def get_single_grad( training_weights : list A list of layer weights that are to-be-trained based on the current loss weight values. - obs_data : tf.Tensor | None - Optional observation data to use in additional content loss term. - (n_observations, spatial_1, spatial_2, features) - (n_observations, spatial_1, spatial_2, temporal, features) device_name : None | str Optional tensorflow device name for GPU placement. 
Note that if a GPU is available, variables will be placed on that GPU even if @@ -1100,16 +1118,10 @@ def get_single_grad( watch_accessed_variables=False ) as tape: tape.watch(training_weights) - hi_res_exo = self.get_high_res_exo_input(hi_res_true) - hi_res_gen = self._tf_generate(low_res, hi_res_exo) - loss_out = self.calc_loss( - hi_res_true, hi_res_gen, **calc_loss_kwargs + *loss_out, _ = self._get_hr_exo_and_loss( + low_res, hi_res_true, **calc_loss_kwargs ) loss, loss_details = loss_out - if obs_data is not None: - loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) - loss += loss_obs - loss_details['loss_obs'] = loss_obs grad = tape.gradient(loss, training_weights) return grad, loss_details @@ -1124,27 +1136,3 @@ def calc_loss( ): """Calculate the GAN loss function using generated and true high resolution data.""" - - @tf.function - def calc_loss_obs(self, obs_data, hi_res_gen): - """Calculate loss term for the observation data vs generated - high-resolution data - - Parameters - ---------- - obs_data : tf.Tensor | None - Optional observation data to use in additional content loss term. - hi_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. - - Returns - ------- - loss : tf.Tensor - 0D tensor of observation loss - """ - mask = tf.math.is_nan(obs_data) - return MeanAbsoluteError()( - obs_data[~mask], - hi_res_gen[..., : len(self.hr_out_features)][~mask], - ) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 53b159464..f91e6f89c 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -750,6 +750,7 @@ def train( ) for epoch in epochs: + t_epoch = time.time() loss_details = self.train_epoch( batch_handler, weight_gen_advers, @@ -813,12 +814,18 @@ def train( early_stop_n_epoch, extras=extras, ) + logger.info( + 'Finished training epoch in {:.4f} seconds'.format( + time.time() - t_epoch + ) + ) if stop: break logger.info( - 'Finished training %s epochs in %s seconds', - n_epoch, - time.time() - t0, + 'Finished training {} epochs in {:.4f} seconds'.format( + n_epoch, + time.time() - t0, + ) ) batch_handler.stop() @@ -842,8 +849,6 @@ def calc_loss( hi_res_gen : tf.Tensor Superresolved high resolution spatiotemporal data generated by the generative model. - obs_data : tf.Tensor | None - Optional observation data to use in additional content loss term. weight_gen_advers : float Weight factor for the adversarial loss component of the generator vs. the discriminator. @@ -897,6 +902,37 @@ def calc_loss( return loss, loss_details + def _calc_val_loss(self, batch, weight_gen_advers, loss_details): + """Calculate the validation loss at the current state of model training + for a given batch + + Parameters + ---------- + batch : DsetTuple + Object with ``.high_res`` and ``.low_res`` arrays + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. 
+ loss_details : dict + Namespace of the breakdown of loss components + + Returns + ------- + loss_details : dict + Same as input but now includes val_* loss info + """ + _, v_loss_details, _ = self._get_hr_exo_and_loss( + batch.low_res, + batch.high_res, + weight_gen_advers=weight_gen_advers, + train_gen=False, + train_disc=False, + ) + loss_details = self.update_loss_details( + loss_details, v_loss_details, len(batch), prefix='val_' + ) + return loss_details + def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): """Calculate the validation loss at the current state of model training @@ -918,25 +954,93 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - val_exo_data = self.get_high_res_exo_input(val_batch.high_res) - high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data) - _, v_loss_details = self.calc_loss( - val_batch.high_res, - high_res_gen, + loss_details = self._calc_val_loss( + val_batch, weight_gen_advers, loss_details + ) + return loss_details + + def _get_batch_loss_details( + self, + batch, + train_gen, + only_gen, + gen_too_good, + train_disc, + only_disc, + disc_too_good, + weight_gen_advers, + multi_gpu=False, + ): + """Get loss details for a given batch for the current epoch. + + Parameters + ---------- + batch : sup3r.preprocessing.base.DsetTuple + Object with ``.low_res`` and ``.high_res`` arrays + train_gen : bool + Flag whether to train the generator for this set of epochs + only_gen : bool + Flag whether to only train the generator for this set of epochs + gen_too_good : bool + Flag whether to skip training the generator and only train the + discriminator, due to superior performance, for this batch. + train_disc : bool + Flag whether to train the discriminator for this set of epochs + only_disc : bool + Flag whether to only train the discriminator for this set of epochs + gen_too_good : bool + Flag whether to skip training the discriminator and only train the + generator, due to superior performance, for this batch. + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with. 
+ If true and multiple gpus are found, ``default_device`` device + should be set to /gpu:0 + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components for the given batch + """ + + trained_gen = False + trained_disc = False + if only_gen or (train_gen and not gen_too_good): + trained_gen = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.generator_weights, weight_gen_advers=weight_gen_advers, - train_gen=False, + optimizer=self.optimizer, + train_gen=True, train_disc=False, + multi_gpu=multi_gpu, ) - obs_data = getattr(val_batch, 'obs', None) - if obs_data is not None: - v_loss_details['loss_obs'] = self.cal_loss_obs( - obs_data, high_res_gen - ) - loss_details = self.update_loss_details( - loss_details, v_loss_details, len(val_batch), prefix='val_' + if only_disc or (train_disc and not disc_too_good): + trained_disc = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.discriminator_weights, + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer_disc, + train_gen=False, + train_disc=True, + multi_gpu=multi_gpu, ) - return loss_details + + b_loss_details['gen_trained_frac'] = float(trained_gen) + b_loss_details['disc_trained_frac'] = float(trained_disc) + return b_loss_details def train_epoch( self, @@ -991,8 +1095,6 @@ def train_epoch( tf.summary.trace_on(graph=True, profiler=True) for ib, batch in enumerate(batch_handler): - trained_gen = False - trained_disc = False b_loss_details = {} loss_disc = loss_details['train_loss_disc'] disc_too_good = loss_disc <= disc_th_low @@ -1002,35 +1104,19 @@ def train_epoch( if not self.generator_weights: self.init_weights(batch.low_res.shape, batch.high_res.shape) - if only_gen or (train_gen and not gen_too_good): - trained_gen = True - b_loss_details = self.timer(self.run_gradient_descent)( - batch.low_res, - batch.high_res, - self.generator_weights, - obs_data=getattr(batch, 'obs', None), - weight_gen_advers=weight_gen_advers, - optimizer=self.optimizer, - train_gen=True, - train_disc=False, - multi_gpu=multi_gpu, - ) - - if only_disc or (train_disc and not disc_too_good): - trained_disc = True - b_loss_details = self.timer(self.run_gradient_descent)( - batch.low_res, - batch.high_res, - self.discriminator_weights, - weight_gen_advers=weight_gen_advers, - optimizer=self.optimizer_disc, - train_gen=False, - train_disc=True, - multi_gpu=multi_gpu, - ) - - b_loss_details['gen_trained_frac'] = float(trained_gen) - b_loss_details['disc_trained_frac'] = float(trained_disc) + b_loss_details = self._get_batch_loss_details( + batch, + train_gen, + only_gen, + gen_too_good, + train_disc, + only_disc, + disc_too_good, + weight_gen_advers, + multi_gpu, + ) + trained_gen = bool(b_loss_details.get('gen_trained_frac', False)) + trained_disc = bool(b_loss_details.get('disc_trained_frac', False)) self.dict_to_tensorboard(b_loss_details) self.dict_to_tensorboard(self.timer.log) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index c81065f86..e284cc21a 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -9,7 +9,6 @@ from warnings import warn import numpy as np -import tensorflow as tf from phygnn import CustomNetwork from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat @@ -355,32 +354,6 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None): hi_res = np.concatenate((hi_res, exo_output), axis=-1) return hi_res - @tf.function - def _combine_loss_input(self, 
high_res_true, high_res_gen): - """Combine exogenous feature data from high_res_true with high_res_gen - for loss calculation - - Parameters - ---------- - high_res_true : tf.Tensor - Ground truth high resolution spatiotemporal data. - high_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. - - Returns - ------- - high_res_gen : tf.Tensor - Same as input with exogenous data combined with high_res input - """ - if high_res_true.shape[-1] > high_res_gen.shape[-1]: - for feature in self.hr_exo_features: - f_idx = self.hr_exo_features.index(feature) - f_idx += len(self.hr_out_features) - exo_data = high_res_true[..., f_idx : f_idx + 1] - high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1) - return high_res_gen - @property @abstractmethod def meta(self): diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py new file mode 100644 index 000000000..e8ce3e417 --- /dev/null +++ b/sup3r/models/with_obs.py @@ -0,0 +1,352 @@ +"""Sup3r model with training on observation data.""" + +import logging +import time +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import tensorflow as tf +from tensorflow.keras.losses import MeanAbsoluteError + +from .base import Sup3rGan + +logger = logging.getLogger(__name__) + + +class Sup3rGanWithObs(Sup3rGan): + """Sup3r GAN model which incorporates observation data into content loss. + """ + + def _calc_val_loss(self, batch, weight_gen_advers, loss_details): + """Calculate the validation loss at the current state of model training + for a given batch + + Parameters + ---------- + batch : DsetTuple + Object with ``.high_res``, ``.low_res``, and ``.obs`` arrays + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + loss_details : dict + Namespace of the breakdown of loss components + + Returns + ------- + loss_details : dict + Same as input with updated val_* loss info + """ + val_exo_data = self.get_high_res_exo_input(batch.high_res) + high_res_gen = self._tf_generate(batch.low_res, val_exo_data) + _, v_loss_details = self.calc_loss( + batch.high_res, + high_res_gen, + weight_gen_advers=weight_gen_advers, + train_gen=False, + train_disc=False, + ) + v_loss_details['loss_obs'] = self.cal_loss_obs(batch.obs, high_res_gen) + + loss_details = self.update_loss_details( + loss_details, v_loss_details, len(batch), prefix='val_' + ) + return loss_details + + def _get_batch_loss_details( + self, + batch, + train_gen, + only_gen, + gen_too_good, + train_disc, + only_disc, + disc_too_good, + weight_gen_advers, + multi_gpu=False, + ): + """Get loss details for a given batch for the current epoch. + + Parameters + ---------- + batch : sup3r.preprocessing.base.DsetTuple + Object with ``.low_res``, ``.high_res``, and ``.obs`` arrays + train_gen : bool + Flag whether to train the generator for this set of epochs + only_gen : bool + Flag whether to only train the generator for this set of epochs + gen_too_good : bool + Flag whether to skip training the generator and only train the + discriminator, due to superior performance, for this batch. + train_disc : bool + Flag whether to train the discriminator for this set of epochs + only_disc : bool + Flag whether to only train the discriminator for this set of epochs + gen_too_good : bool + Flag whether to skip training the discriminator and only train the + generator, due to superior performance, for this batch. 
+ weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with. + If true and multiple gpus are found, ``default_device`` device + should be set to /gpu:0 + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components for the given batch + """ + trained_gen = False + trained_disc = False + if only_gen or (train_gen and not gen_too_good): + trained_gen = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.generator_weights, + obs_data=getattr(batch, 'obs', None), + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer, + train_gen=True, + train_disc=False, + multi_gpu=multi_gpu, + ) + + if only_disc or (train_disc and not disc_too_good): + trained_disc = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.discriminator_weights, + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer_disc, + train_gen=False, + train_disc=True, + multi_gpu=multi_gpu, + ) + + b_loss_details['gen_trained_frac'] = float(trained_gen) + b_loss_details['disc_trained_frac'] = float(trained_disc) + return b_loss_details + + def _get_parallel_grad( + self, + low_res, + hi_res_true, + training_weights, + obs_data=None, + **calc_loss_kwargs, + ): + """Compute gradient for one mini-batch of (low_res, hi_res_true, + obs_data) across multiple GPUs. Can include observation data as well. 
+ """ + + futures = [] + start_time = time.time() + lr_chunks = np.array_split(low_res, len(self.gpu_list)) + hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) + obs_data_chunks = ( + [None] * len(hr_true_chunks) + if obs_data is None + else np.array_split(obs_data, len(self.gpu_list)) + ) + split_mask = False + mask_chunks = None + if 'mask' in calc_loss_kwargs: + split_mask = True + mask_chunks = np.array_split( + calc_loss_kwargs['mask'], len(self.gpu_list) + ) + + with ThreadPoolExecutor(max_workers=len(self.gpu_list)) as exe: + for i in range(len(self.gpu_list)): + if split_mask: + calc_loss_kwargs['mask'] = mask_chunks[i] + futures.append( + exe.submit( + self.get_single_grad, + lr_chunks[i], + hr_true_chunks[i], + training_weights, + obs_data=obs_data_chunks[i], + device_name=f'/gpu:{i}', + **calc_loss_kwargs, + ) + ) + + return self._sum_parallel_grad(futures, start_time=start_time) + + def run_gradient_descent( + self, + low_res, + hi_res_true, + training_weights, + obs_data=None, + optimizer=None, + multi_gpu=False, + **calc_loss_kwargs, + ): + """Run gradient descent for one mini-batch of (low_res, hi_res_true) + and update weights + + Parameters + ---------- + low_res : np.ndarray + Real low-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + hi_res_true : np.ndarray + Real high-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + training_weights : list + A list of layer weights that are to-be-trained based on the + current loss weight values. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. + This needs to have NaNs where there is no observation data. + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + optimizer : tf.keras.optimizers.Optimizer + Optimizer class to use to update weights. This can be different if + you're training just the generator or one of the discriminator + models. Defaults to the generator optimizer. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with. 
+ calc_loss_kwargs : dict + Kwargs to pass to the self.calc_loss() method + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components + """ + + self.timer.start() + if optimizer is None: + optimizer = self.optimizer + + if not multi_gpu or len(self.gpu_list) < 2: + grad, loss_details = self.get_single_grad( + low_res, + hi_res_true, + training_weights, + obs_data=obs_data, + device_name=self.default_device, + **calc_loss_kwargs, + ) + optimizer.apply_gradients(zip(grad, training_weights)) + self.timer.stop() + logger.debug( + 'Finished single gradient descent step in %s', + self.timer.elapsed_str, + ) + else: + total_grad, loss_details = self._get_parallel_grad( + low_res, + hi_res_true, + training_weights, + obs_data, + **calc_loss_kwargs, + ) + optimizer.apply_gradients(zip(total_grad, training_weights)) + + return loss_details + + @tf.function + def get_single_grad( + self, + low_res, + hi_res_true, + training_weights, + obs_data=None, + device_name=None, + **calc_loss_kwargs, + ): + """Run gradient descent for one mini-batch of (low_res, hi_res_true), + do not update weights, just return gradient details. + + Parameters + ---------- + low_res : np.ndarray + Real low-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + hi_res_true : np.ndarray + Real high-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + training_weights : list + A list of layer weights that are to-be-trained based on the + current loss weight values. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. + This needs to have NaNs where there is no observation data. + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + device_name : None | str + Optional tensorflow device name for GPU placement. Note that if a + GPU is available, variables will be placed on that GPU even if + device_name=None. + calc_loss_kwargs : dict + Kwargs to pass to the self.calc_loss() method + + Returns + ------- + grad : list + a list or nested structure of Tensors (or IndexedSlices, or None, + or CompositeTensor) representing the gradients for the + training_weights + loss_details : dict + Namespace of the breakdown of loss components + """ + with tf.device(device_name), tf.GradientTape( + watch_accessed_variables=False + ) as tape: + tape.watch(training_weights) + *loss_out, hi_res_gen = self._get_hr_exo_and_loss( + low_res, hi_res_true, **calc_loss_kwargs + ) + loss, loss_details = loss_out + if obs_data is not None: + loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) + loss += loss_obs + loss_details['loss_obs'] = loss_obs + grad = tape.gradient(loss, training_weights) + return grad, loss_details + + @tf.function + def calc_loss_obs(self, obs_data, hi_res_gen): + """Calculate loss term for the observation data vs generated + high-resolution data + + Parameters + ---------- + obs_data : tf.Tensor | None + Observation data to use in additional content loss term. + This needs to have NaNs where there is no observation data. + hi_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. 
+ + Returns + ------- + loss : tf.Tensor + 0D tensor of observation loss + """ + mask = tf.math.is_nan(obs_data) + return MeanAbsoluteError()( + obs_data[~mask], + hi_res_gen[..., : len(self.hr_out_features)][~mask], + ) diff --git a/tests/training/test_train_dual_with_obs.py b/tests/training/test_train_dual_with_obs.py index 0bf8244ee..2399b21ba 100644 --- a/tests/training/test_train_dual_with_obs.py +++ b/tests/training/test_train_dual_with_obs.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from sup3r.models import Sup3rGan +from sup3r.models import Sup3rGanWithObs from sup3r.preprocessing import ( Container, DataHandler, @@ -104,8 +104,8 @@ def test_train_h5_nc( assert not np.isnan(batch.obs).all() assert np.isnan(batch.obs).any() - Sup3rGan.seed() - model = Sup3rGan( + Sup3rGanWithObs.seed() + model = Sup3rGanWithObs( fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' ) @@ -208,8 +208,8 @@ def test_train_coarse_h5( assert not np.isnan(batch.obs).all() assert np.isnan(batch.obs).any() - Sup3rGan.seed() - model = Sup3rGan( + Sup3rGanWithObs.seed() + model = Sup3rGanWithObs( fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' ) From dd70a576e89bb090485a8b89e57eaefd64adb0dc Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 5 Jan 2025 09:09:29 -0700 Subject: [PATCH 25/32] moved ``_run`` method to bias correction interface ``AbstractBiasCorrection`` --- sup3r/models/base.py | 2 +- sup3r/models/utilities.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index f91e6f89c..f7c3a43f8 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -1146,8 +1146,8 @@ def train_epoch( ) logger.warning(msg) warn(msg) - self.total_batches += 1 + self.total_batches += len(batch_handler) loss_details['total_batches'] = int(self.total_batches) self.profile_to_tensorboard('training_epoch') return loss_details diff --git a/sup3r/models/utilities.py b/sup3r/models/utilities.py index 68e0f0102..dee3c51bc 100644 --- a/sup3r/models/utilities.py +++ b/sup3r/models/utilities.py @@ -76,12 +76,11 @@ def __init__(self): @property def total_batches(self): """Record of total number of batches for logging.""" - if self._total_batches is None and self._history is None: - self._total_batches = 0 - elif self._history is None and 'total_batches' in self._history: - self._total_batches = self._history['total_batches'].values[-1] - elif self._total_batches is None and self._history is not None: - self._total_batches = 0 + if self._total_batches is None: + if self._history is not None and 'total_batches' in self._history: + self._total_batches = self._history['total_batches'].values[-1] + else: + self._total_batches = 0 return self._total_batches @total_batches.setter From d93a6b11a568717c802d090e5bc02c672d4f7613 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 5 Jan 2025 16:28:37 -0700 Subject: [PATCH 26/32] moved ``_run`` method to bias correction interface ``AbstractBiasCorrection`` --- sup3r/preprocessing/derivers/base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index 77a5ca528..3c2ce0675 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -56,7 +56,7 @@ def __init__( a method to derive the feature in the registry. interp_kwargs : dict | None Dictionary of kwargs for level interpolation. Can include "method" - and "run_level_check". 
"method" specifies how to perform height + and "run_level_check" keys. Method specifies how to perform height interpolation. e.g. Deriving u_20m from u_10m and u_100m. Options are "linear" and "log". See :py:meth:`sup3r.preprocessing.derivers.Deriver.do_level_interpolation` @@ -65,7 +65,7 @@ def __init__( self.FEATURE_REGISTRY = FeatureRegistry super().__init__(data=data) - self.interp_kwargs = interp_kwargs or {} + self.interp_kwargs = interp_kwargs features = parse_to_list(data=data, features=features) new_features = [f for f in features if f not in self.data] for f in new_features: @@ -269,6 +269,7 @@ def get_single_level_data(self, feature): var_array = da.stack(var_list, axis=-1) sl_shape = (*var_array.shape[:-1], len(lev_list)) lev_array = da.broadcast_to(da.from_array(lev_list), sl_shape) + return var_array, lev_array def get_multi_level_data(self, feature): @@ -295,8 +296,8 @@ def get_multi_level_data(self, feature): assert can_calc_height or have_height, msg if can_calc_height: - lev_array = self.data['zg'] - self.data['topography'] - lev_array = lev_array.data + lev_array = self.data[['zg', 'topography']].as_array() + lev_array = lev_array[..., 0] - lev_array[..., 1] else: lev_array = da.broadcast_to( self.data[Dimension.HEIGHT].astype(np.float32), From 5afa9edc3117bc50936e8d84437cf3567ac54628 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 8 Jan 2025 11:10:40 -0700 Subject: [PATCH 27/32] fix: tensorboard issue with loss obs details --- sup3r/models/with_obs.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py index e8ce3e417..e15032e18 100644 --- a/sup3r/models/with_obs.py +++ b/sup3r/models/with_obs.py @@ -319,10 +319,10 @@ def get_single_grad( low_res, hi_res_true, **calc_loss_kwargs ) loss, loss_details = loss_out - if obs_data is not None: - loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) + loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) + if not tf.reduce_any(tf.math.is_nan(loss_obs)): loss += loss_obs - loss_details['loss_obs'] = loss_obs + loss_details.update({'loss_obs': loss_obs}) grad = tape.gradient(loss, training_weights) return grad, loss_details @@ -345,8 +345,13 @@ def calc_loss_obs(self, obs_data, hi_res_gen): loss : tf.Tensor 0D tensor of observation loss """ - mask = tf.math.is_nan(obs_data) - return MeanAbsoluteError()( - obs_data[~mask], - hi_res_gen[..., : len(self.hr_out_features)][~mask], - ) + obs_loss = tf.constant(np.nan) + if obs_data is not None: + mask = tf.math.is_nan(obs_data) + masked_obs = obs_data[~mask] + if len(masked_obs) > 0: + obs_loss = MeanAbsoluteError()( + masked_obs, + hi_res_gen[..., : len(self.hr_out_features)][~mask], + ) + return obs_loss From 32edc38e34f30b40b096fd71d39e24969011743c Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 19 Jan 2025 09:54:23 -0700 Subject: [PATCH 28/32] Adding obs loss to logging of loss gen --- sup3r/models/abstract.py | 15 +++++++++------ sup3r/models/base.py | 1 + sup3r/models/with_obs.py | 17 ++++++++++------- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index c5b4c9dfe..54307843c 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -857,8 +857,10 @@ def run_gradient_descent( **calc_loss_kwargs, ) optimizer.apply_gradients(zip(grad, training_weights)) - msg = ('Finished single gradient descent step in ' - f'{time.time() - start_time:.4f} seconds') + msg = ( + 'Finished single gradient descent step in ' + 
f'{time.time() - start_time:.4f} seconds' + ) logger.debug(msg) else: total_grad, loss_details = self._get_parallel_grad( @@ -1070,8 +1072,10 @@ def _get_hr_exo_and_loss( compute loss.""" hi_res_exo = self.get_high_res_exo_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) - loss_out = self.calc_loss(hi_res_true, hi_res_gen, **calc_loss_kwargs) - return *loss_out, hi_res_gen + loss, loss_details = self.calc_loss( + hi_res_true, hi_res_gen, **calc_loss_kwargs + ) + return loss, loss_details, hi_res_gen @tf.function def get_single_grad( @@ -1118,10 +1122,9 @@ def get_single_grad( watch_accessed_variables=False ) as tape: tape.watch(training_weights) - *loss_out, _ = self._get_hr_exo_and_loss( + loss, loss_details, _ = self._get_hr_exo_and_loss( low_res, hi_res_true, **calc_loss_kwargs ) - loss, loss_details = loss_out grad = tape.gradient(loss, training_weights) return grad, loss_details diff --git a/sup3r/models/base.py b/sup3r/models/base.py index f7c3a43f8..e7e081e02 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -1012,6 +1012,7 @@ def _get_batch_loss_details( trained_gen = False trained_disc = False + b_loss_details = {} if only_gen or (train_gen and not gen_too_good): trained_gen = True b_loss_details = self.timer(self.run_gradient_descent)( diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py index e15032e18..7fa3937c5 100644 --- a/sup3r/models/with_obs.py +++ b/sup3r/models/with_obs.py @@ -14,8 +14,7 @@ class Sup3rGanWithObs(Sup3rGan): - """Sup3r GAN model which incorporates observation data into content loss. - """ + """Sup3r GAN model with additional observation data content loss.""" def _calc_val_loss(self, batch, weight_gen_advers, loss_details): """Calculate the validation loss at the current state of model training @@ -133,6 +132,11 @@ def _get_batch_loss_details( b_loss_details['gen_trained_frac'] = float(trained_gen) b_loss_details['disc_trained_frac'] = float(trained_disc) + + if 'loss_obs' in b_loss_details: + loss_update = b_loss_details['loss_gen'] + loss_update += b_loss_details['loss_obs'] + b_loss_details.update({'loss_gen': loss_update}) return b_loss_details def _get_parallel_grad( @@ -315,10 +319,9 @@ def get_single_grad( watch_accessed_variables=False ) as tape: tape.watch(training_weights) - *loss_out, hi_res_gen = self._get_hr_exo_and_loss( + loss, loss_details, hi_res_gen = self._get_hr_exo_and_loss( low_res, hi_res_true, **calc_loss_kwargs ) - loss, loss_details = loss_out loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) if not tf.reduce_any(tf.math.is_nan(loss_obs)): loss += loss_obs @@ -345,13 +348,13 @@ def calc_loss_obs(self, obs_data, hi_res_gen): loss : tf.Tensor 0D tensor of observation loss """ - obs_loss = tf.constant(np.nan) + loss_obs = tf.constant(np.nan) if obs_data is not None: mask = tf.math.is_nan(obs_data) masked_obs = obs_data[~mask] if len(masked_obs) > 0: - obs_loss = MeanAbsoluteError()( + loss_obs = MeanAbsoluteError()( masked_obs, hi_res_gen[..., : len(self.hr_out_features)][~mask], ) - return obs_loss + return loss_obs From dffdee2f9d02c8302de5d7da073657af0ffe0d84 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 21 Jan 2025 06:31:09 -0700 Subject: [PATCH 29/32] Adding ``loss_obs`` to ``loss_gen`` so the total loss shows in log output. 
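
A minimal sketch of the bookkeeping change described here: if an observation loss
was computed for the batch (i.e. it is not NaN), fold it into the reported
generator loss so the logged total reflects both terms. The standalone helper
below is illustrative only; the patch applies this inline in
``_get_batch_loss_details``.

    import tensorflow as tf

    def fold_obs_into_gen(loss_details):
        """Add loss_obs to loss_gen in the logged details when it is finite."""
        loss_obs = loss_details.get('loss_obs')
        if loss_obs is not None and not tf.math.is_nan(loss_obs):
            loss_details['loss_gen'] = loss_details['loss_gen'] + loss_obs
        return loss_details
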
--- sup3r/models/with_obs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py index 7fa3937c5..1e25014ec 100644 --- a/sup3r/models/with_obs.py +++ b/sup3r/models/with_obs.py @@ -133,7 +133,9 @@ def _get_batch_loss_details( b_loss_details['gen_trained_frac'] = float(trained_gen) b_loss_details['disc_trained_frac'] = float(trained_disc) - if 'loss_obs' in b_loss_details: + if 'loss_obs' in b_loss_details and not tf.math.is_nan( + b_loss_details['loss_obs'] + ): loss_update = b_loss_details['loss_gen'] loss_update += b_loss_details['loss_obs'] b_loss_details.update({'loss_gen': loss_update}) From ea1d3fd48966155dcebe7bc6b3e0399a7bec8990 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 11 Jan 2025 11:22:34 -0700 Subject: [PATCH 30/32] generalized min pad width for padding slices so that this can accomodate models with increased receptive field and larger padding values. --- sup3r/pipeline/slicer.py | 99 ++++++++++++++++++++++++-------------- sup3r/pipeline/strategy.py | 14 ++++++ 2 files changed, 78 insertions(+), 35 deletions(-) diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py index 332d62a31..3df7ace7c 100644 --- a/sup3r/pipeline/slicer.py +++ b/sup3r/pipeline/slicer.py @@ -3,7 +3,7 @@ import itertools as it import logging from dataclasses import dataclass -from typing import Union +from typing import Optional, Union from warnings import warn import numpy as np @@ -27,9 +27,23 @@ class ForwardPassSlicer: time_steps : int Number of time steps for full temporal domain of low res data. This is used to construct a dummy_time_index from np.arange(time_steps) + s_enhance : int + Spatial enhancement factor + t_enhance : int + Temporal enhancement factor time_slice : slice | list Slice to use to extract range from time_index. Can be a ``slice(start, stop, step)`` or list ``[start, stop, step]`` + temporal_pad : int + Size of temporal overlap between coarse chunks passed to forward + passes for subsequent temporal stitching. This overlap will pad + both sides of the fwp_chunk_shape. Note that the first and last + chunks in the temporal dimension will not be padded. + spatial_pad : int + Size of spatial overlap between coarse chunks passed to forward + passes for subsequent spatial stitching. This overlap will pad both + sides of the fwp_chunk_shape. Note that the first and last chunks + in any of the spatial dimension will not be padded. chunk_shape : tuple Max shape (spatial_1, spatial_2, temporal) of an unpadded coarse chunk to use for a forward pass. The number of nodes that the @@ -41,20 +55,11 @@ class ForwardPassSlicer: to the generator can be bigger than this shape. If running in serial set this equal to the shape of the full spatiotemporal data volume for best performance. - s_enhance : int - Spatial enhancement factor - t_enhance : int - Temporal enhancement factor - spatial_pad : int - Size of spatial overlap between coarse chunks passed to forward - passes for subsequent spatial stitching. This overlap will pad both - sides of the fwp_chunk_shape. Note that the first and last chunks - in any of the spatial dimension will not be padded. - temporal_pad : int - Size of temporal overlap between coarse chunks passed to forward - passes for subsequent temporal stitching. This overlap will pad - both sides of the fwp_chunk_shape. Note that the first and last - chunks in the temporal dimension will not be padded. 
+ min_width : tuple + Minimum width of padded slices, with each element providing the min + width for the corresponding dimension. e.g. (spatial_1, spatial_2, + temporal). This is used to make sure generator network input meets the + minimum size requirement for padding layers. """ coarse_shape: Union[tuple, list] @@ -65,6 +70,7 @@ class ForwardPassSlicer: temporal_pad: int spatial_pad: int chunk_shape: Union[tuple, list] + min_width: Optional[Union[tuple, list]] = None @log_args def __post_init__(self): @@ -88,6 +94,9 @@ def __post_init__(self): self._s_hr_crop_slices = None self._t_hr_crop_slices = None self._hr_crop_slices = None + self.min_width = ( + self.chunk_shape if self.min_width is None else self.min_width + ) def get_spatial_slices(self): """Get spatial slices for small data chunks that are passed through @@ -254,6 +263,8 @@ def s1_hr_crop_slices(self): self._s1_hr_crop_slices = self.check_boundary_slice( unpadded_slices=self.s1_lr_slices, cropped_slices=self._s1_hr_crop_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, dim=0, ) return self._s1_hr_crop_slices @@ -273,6 +284,8 @@ def s2_hr_crop_slices(self): self._s2_hr_crop_slices = self.check_boundary_slice( unpadded_slices=self.s2_lr_slices, cropped_slices=self._s2_hr_crop_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, dim=1, ) return self._s2_hr_crop_slices @@ -525,21 +538,22 @@ def get_padded_slices(slices, shape, enhancement, padding, step=None): pad_slices.append(slice(start, end, step)) return pad_slices - def check_boundary_slice(self, unpadded_slices, cropped_slices, dim): + def check_boundary_slice( + self, unpadded_slices, cropped_slices, enhancement, padding, dim + ): """Check cropped slice at the right boundary for minimum shape. It is possible for the forward pass chunk shape to divide the grid size such that the last slice (right boundary) does not meet the minimum - number of elements. (Padding layers in the generator typically require - a minimum shape of 4). e.g. ``grid_size = (8, 8)`` with - ``fwp_chunk_shape = (7, 7, ...)`` results in unpadded slices with just - one element. If the padding is 0 or 1 these padded slices have length - less than 4. When this minimum shape is not met we apply extra padding - in :meth:`self._get_pad_width`. Cropped slices have to be adjusted to + number of elements. (Padding layers in the generator require a minimum + shape). e.g. ``grid_size = (8, 8)`` with ``fwp_chunk_shape = (7, 7, + ...)`` results in unpadded slices with just one element. When this + minimum shape is not met we apply extra padding in + :meth:`self._get_pad_width`. Cropped slices have to be adjusted to account for this here.""" warn_msg = ( - 'The final spatial slice for dimension #%s is too small ' + 'The final slice for dimension #%s is too small ' '(slice=slice(%s, %s), padding=%s). The start of this slice will ' 'be reduced to try to meet the minimum slice length.' 
) @@ -548,19 +562,22 @@ def check_boundary_slice(self, unpadded_slices, cropped_slices, dim): lr_slice_stop = unpadded_slices[-1].stop or self.coarse_shape[dim] # last slice adjustment - if 2 * self.spatial_pad + (lr_slice_stop - lr_slice_start) < 4: + if ( + 2 * padding + (lr_slice_stop - lr_slice_start) + <= self.min_width[dim] + ): + half_width = self.min_width[dim] // 2 + 1 logger.warning( warn_msg, dim + 1, lr_slice_start, lr_slice_stop, - self.spatial_pad, + padding, ) - warn( - warn_msg - % (dim + 1, lr_slice_start, lr_slice_stop, self.spatial_pad) + warn(warn_msg % (dim + 1, lr_slice_start, lr_slice_stop, padding)) + cropped_slices[-1] = slice( + half_width * enhancement, -half_width * enhancement ) - cropped_slices[-1] = slice(2 * self.s_enhance, -2 * self.s_enhance) return cropped_slices @@ -600,7 +617,9 @@ def get_cropped_slices(unpadded_slices, padded_slices, enhancement): return cropped_slices @staticmethod - def _get_pad_width(window, max_steps, max_pad, check_boundary=False): + def _get_pad_width( + window, max_steps, max_pad, min_width=None, check_boundary=False + ): """ Parameters ---------- @@ -610,6 +629,10 @@ def _get_pad_width(window, max_steps, max_pad, check_boundary=False): Maximum number of steps available. Padding cannot extend past this max_pad : int Maximum amount of padding to apply. + min_width : int | None + Minimum width to enforce. This could be the forward pass chunk + shape or the padding value in the first padding layer of the + generator network. This is only used if ``check_boundary = True`` check_bounary : bool Whether to check the final slice for minimum size requirement @@ -625,14 +648,16 @@ def _get_pad_width(window, max_steps, max_pad, check_boundary=False): # We add minimum padding to the last slice if the padded window is # too small for the generator. This can happen if 2 * spatial_pad + - # modulo(grid_size, fwp_chunk_shape) < 4 + # modulo(grid_size, fwp_chunk_shape) is less than the padding applied + # in the first padding layer of the generator if ( check_boundary and win_stop == max_steps - and (2 * max_pad + win_stop - win_start) < 4 + and (2 * max_pad + win_stop - win_start) < min_width ): - stop = np.max([2, max_pad]) - start = np.max([2, max_pad]) + half_width = min_width // 2 + 1 + stop = np.max([half_width, max_pad]) + start = np.max([half_width, max_pad]) return (start, stop) @@ -662,16 +687,20 @@ def get_pad_width(self, chunk_index): lr_slice[0], self.coarse_shape[0], self.spatial_pad, + self.min_width[0], check_boundary=True, ), self._get_pad_width( lr_slice[1], self.coarse_shape[1], self.spatial_pad, + self.min_width[1], check_boundary=True, ), self._get_pad_width( - ti_slice, len(self.dummy_time_index), self.temporal_pad + ti_slice, + len(self.dummy_time_index), + self.temporal_pad ), ) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 9a488e6a1..6b01294a3 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -231,6 +231,7 @@ def __post_init__(self): t_enhance=self.t_enhance, spatial_pad=self.spatial_pad, temporal_pad=self.temporal_pad, + min_width=self.get_min_pad_width(model) ) self.n_chunks = self.fwp_slicer.n_chunks @@ -253,6 +254,19 @@ def __post_init__(self): self.preflight() + def get_min_pad_width(self, model): + """Get the padding values applied in the first padding layer of the + model. 
This is used to determine the minimum width of padded slices + used to chunk the generator input.""" + pad_width = (1, 1, 1) + for layer in model._gen.layers: + if hasattr(layer, 'paddings'): + pad_width = np.max(layer.paddings, axis=1)[1:-1] + if len(pad_width) < 3: + pad_width = (*pad_width, 1) + break + return pad_width + @property def meta(self): """Meta data dictionary for the strategy. Used to add info to forward From 06221d78ae19e8a3f87bc6df6bdf72bf74ba59ad Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 15 Jan 2025 09:00:58 -0700 Subject: [PATCH 31/32] min padding depends on the ``.paddings`` attribute of the ``FlexiblePadding`` layers in the generator model. Generalized current min padding to use these values. --- sup3r/pipeline/slicer.py | 6 +++++- sup3r/pipeline/strategy.py | 13 +++++++------ tests/forward_pass/test_forward_pass.py | 4 ++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py index 3df7ace7c..424e0d0ef 100644 --- a/sup3r/pipeline/slicer.py +++ b/sup3r/pipeline/slicer.py @@ -326,6 +326,8 @@ def s_lr_crop_slices(self): s1_crop_slices = self.check_boundary_slice( unpadded_slices=self.s1_lr_slices, cropped_slices=s1_crop_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, dim=0, ) s2_crop_slices = self.get_cropped_slices( @@ -334,6 +336,8 @@ def s_lr_crop_slices(self): s2_crop_slices = self.check_boundary_slice( unpadded_slices=self.s2_lr_slices, cropped_slices=s2_crop_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, dim=1, ) self._s_lr_crop_slices = list( @@ -653,7 +657,7 @@ def _get_pad_width( if ( check_boundary and win_stop == max_steps - and (2 * max_pad + win_stop - win_start) < min_width + and (2 * max_pad + win_stop - win_start) <= min_width ): half_width = min_width // 2 + 1 stop = np.max([half_width, max_pad]) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 6b01294a3..1a51f7646 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -221,7 +221,6 @@ def __post_init__(self): self.input_handler_kwargs.get('time_slice', slice(None)) ) self.fwp_chunk_shape = self._get_fwp_chunk_shape() - self.fwp_slicer = ForwardPassSlicer( coarse_shape=self.input_handler.grid_shape, time_steps=len(self.input_handler.time_index), @@ -231,7 +230,7 @@ def __post_init__(self): t_enhance=self.t_enhance, spatial_pad=self.spatial_pad, temporal_pad=self.temporal_pad, - min_width=self.get_min_pad_width(model) + min_width=self.get_min_pad_width(model), ) self.n_chunks = self.fwp_slicer.n_chunks @@ -261,10 +260,12 @@ def get_min_pad_width(self, model): pad_width = (1, 1, 1) for layer in model._gen.layers: if hasattr(layer, 'paddings'): - pad_width = np.max(layer.paddings, axis=1)[1:-1] - if len(pad_width) < 3: - pad_width = (*pad_width, 1) - break + new_pw = np.max(layer.paddings, axis=1)[1:-1] + if len(new_pw) < 3: + new_pw = (*new_pw, 1) + pad_width = [ + np.max((new_pw[i], pad_width[i])) for i in range(3) + ] return pad_width @property diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 76d268762..944c0bf2c 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -802,8 +802,8 @@ def test_slicing_auto_boundary_pad(input_files, spatial_pad): fwp.strategy.ti_slices[t_idx], ) - assert chunk.input_data.shape[0] > 3 - assert chunk.input_data.shape[1] > 3 + assert chunk.input_data.shape[0] > strategy.fwp_slicer.min_width[0] + assert chunk.input_data.shape[1] > 
strategy.fwp_slicer.min_width[1] input_data = chunk.input_data.copy() if spatial_pad > 0: slices = [ From 89807351af6560ed81bcd293cdf908a7355b94b6 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 21 Jan 2025 07:38:59 -0700 Subject: [PATCH 32/32] `max_paddings` method in `interface` instead of in `strategy.py`. --- sup3r/models/interface.py | 21 +++++++++++++++++++++ sup3r/pipeline/strategy.py | 17 +---------------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index e284cc21a..25f63f410 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -85,6 +85,27 @@ def input_dims(self): return self.models[0].input_dims return 5 + @property + def max_paddings(self): + """Get the maximum padding values used by the generator. This is used + to apply extra padding during forward passes if the raw input doesn't + meet the minimum input shape.""" + + paddings = (1, 1, 1) + if hasattr(self, '_gen'): + for layer in self._gen.layers: + if hasattr(layer, 'paddings'): + new_pw = np.max(layer.paddings, axis=1)[1:-1] + if len(new_pw) < 3: + new_pw = (*new_pw, 1) + paddings = [ + np.max((new_pw[i], paddings[i])) for i in range(3) + ] + return paddings + if hasattr(self, 'models'): + return self.models[0].max_paddings + return paddings + @property def is_5d(self): """Check if model expects spatiotemporal input""" diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 1a51f7646..db05a72aa 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -230,7 +230,7 @@ def __post_init__(self): t_enhance=self.t_enhance, spatial_pad=self.spatial_pad, temporal_pad=self.temporal_pad, - min_width=self.get_min_pad_width(model), + min_width=model.max_paddings, ) self.n_chunks = self.fwp_slicer.n_chunks @@ -253,21 +253,6 @@ def __post_init__(self): self.preflight() - def get_min_pad_width(self, model): - """Get the padding values applied in the first padding layer of the - model. This is used to determine the minimum width of padded slices - used to chunk the generator input.""" - pad_width = (1, 1, 1) - for layer in model._gen.layers: - if hasattr(layer, 'paddings'): - new_pw = np.max(layer.paddings, axis=1)[1:-1] - if len(new_pw) < 3: - new_pw = (*new_pw, 1) - pad_width = [ - np.max((new_pw[i], pad_width[i])) for i in range(3) - ] - return pad_width - @property def meta(self): """Meta data dictionary for the strategy. Used to add info to forward