diff --git a/README.rst b/README.rst index d3b0aea48..3d321accc 100644 --- a/README.rst +++ b/README.rst @@ -78,4 +78,4 @@ Brandon Benton, Grant Buster, Guilherme Pimenta Castelao, Malik Hassanaly, Pavlo Acknowledgments =============== -This work was authored by the National Renewable Energy Laboratory, operated by Alliance for Sustainable Energy, LLC, for the U.S. Department of Energy (DOE) under Contract No. DE-AC36-08GO28308. This research was supported by the Grid Modernization Initiative of the U.S. Department of Energy (DOE) as part of its Grid Modernization Laboratory Consortium, a strategic partnership between DOE and the national laboratories to bring together leading experts, technologies, and resources to collaborate on the goal of modernizing the nation’s grid. Funding provided by the the DOE Office of Energy Efficiency and Renewable Energy (EERE), the DOE Office of Electricity (OE), DOE Grid Deployment Office (GDO), the DOE Office of Fossil Energy and Carbon Management (FECM), and the DOE Office of Cybersecurity, Energy Security, and Emergency Response (CESER), the DOE Advanced Scientific Computing Research (ASCR) program, the DOE Solar Energy Technologies Office (SETO), the DOE Wind Energy Technologies Office (WETO), the United States Agency for International Development (USAID), and the Laboratory Directed Research and Development (LDRD) program at the National Renewable Energy Laboratory. The research was performed using computational resources sponsored by the Department of Energy's Office of Energy Efficiency and Renewable Energy and located at the National Renewable Energy Laboratory. The views expressed in the article do not necessarily represent the views of the DOE or the U.S. Government. The U.S. Government retains and the publisher, by accepting the article for publication, acknowledges that the U.S. Government retains a nonexclusive, paid-up, irrevocable, worldwide license to publish or reproduce the published form of this work, or allow others to do so, for U.S. Government purposes. +This work was authored by the National Renewable Energy Laboratory, operated by Alliance for Sustainable Energy, LLC, for the U.S. Department of Energy (DOE) under Contract No. DE-AC36-08GO28308. This research was supported by the Grid Modernization Initiative of the U.S. Department of Energy (DOE) as part of its Grid Modernization Laboratory Consortium, a strategic partnership between DOE and the national laboratories to bring together leading experts, technologies, and resources to collaborate on the goal of modernizing the nation’s grid. Funding provided by the DOE Office of Energy Efficiency and Renewable Energy (EERE), the DOE Office of Electricity (OE), DOE Grid Deployment Office (GDO), the DOE Office of Fossil Energy and Carbon Management (FECM), and the DOE Office of Cybersecurity, Energy Security, and Emergency Response (CESER), the DOE Advanced Scientific Computing Research (ASCR) program, the DOE Solar Energy Technologies Office (SETO), the DOE Wind Energy Technologies Office (WETO), the United States Agency for International Development (USAID), and the Laboratory Directed Research and Development (LDRD) program at the National Renewable Energy Laboratory. The research was performed using computational resources sponsored by the Department of Energy's Office of Energy Efficiency and Renewable Energy and located at the National Renewable Energy Laboratory. The views expressed in the article do not necessarily represent the views of the DOE or the U.S. Government.
The U.S. Government retains and the publisher, by accepting the article for publication, acknowledges that the U.S. Government retains a nonexclusive, paid-up, irrevocable, worldwide license to publish or reproduce the published form of this work, or allow others to do so, for U.S. Government purposes. diff --git a/examples/sup3rwind/README.rst b/examples/sup3rwind/README.rst index 8ca30dc57..4cb0c5fa3 100644 --- a/examples/sup3rwind/README.rst +++ b/examples/sup3rwind/README.rst @@ -2,7 +2,7 @@ Sup3rWind Examples ################### -Super-Resolution for Renewable Energy Resource Data with Wind from Reanalysis Data (Sup3rWind) is one application of the sup3r software. In this work, we train generative models to create high-resolution (2km 5-minute) wind data based on coarse (30km hourly) ERA5 data. The generative models and high-resolution output data is publicly available via the `Open Energy Data Initiative (OEDI) `__ and via HSDS at the bucket ``nrel-pds-hsds`` and path ``/nrel/wtk/sup3rwind``. This data covers recent historical time periods for an expanding selection of countries. +Super-Resolution for Renewable Energy Resource Data with Wind from Reanalysis Data (Sup3rWind) is one application of the sup3r software. In this work, we train generative models to create high-resolution (2km 5-minute) wind data based on coarse (30km hourly) ERA5 data. The generative models, high-resolution output data, and training data are publicly available via the `Open Energy Data Initiative (OEDI) `__ and via HSDS at the bucket ``nrel-pds-hsds`` and path ``/nrel/wtk/sup3rwind``. This data covers recent historical time periods for an expanding selection of countries. Sup3rWind Data Access ---------------------- @@ -11,8 +11,8 @@ The Sup3rWind data and models are publicly available in a public AWS S3 bucket. The Sup3rWind data is also loaded into `HSDS `__ so that you may stream the data via the `NREL developer API `__ or your own HSDS server. This is the best option if you're not going to want a full annual dataset. See these `rex instructions `__ for more details on how to access this data with HSDS and rex. -Example Sup3rWind Data Usage ------------------------------ +Sup3rWind Data Usage +--------------------- Sup3rWind data can be used in generally the same way as `Sup3rCC `__ data, with the condition that Sup3rWind includes only wind data and ancillary variables for modeling wind energy generation. Refer to the Sup3rCC `example notebook `__ for usage patterns. @@ -32,6 +32,39 @@ The process for running the Sup3rWind models is much the same as for `Sup3rCC `__. This data is for training the spatial enhancement models only. The 2024-01 `models `__ perform spatial enhancement in two steps, 3x from ERA5 to coarsened WTK and 5x from coarsened WTK to uncoarsened WTK. The currently used approach performs spatial enhancement in a single 15x step. + +For a given year and training domain, initialize low-resolution and high-resolution data handlers and wrap these in a dual rasterizer object. Do this for as many years and training regions as desired, and use these containers to initialize a batch handler. To train models for 3x spatial enhancement, use ``hr_spatial_coarsen=5`` in the ``hr_dh``. To train models for 15x (the currently used approach), use ``hr_spatial_coarsen=1``.
(Refer to tests and docs for information on additional arguments, denoted by the ellipses):: + + from sup3r.preprocessing import DataHandler, DualBatchHandler, DualRasterizer + containers = [] + for tdir in training_dirs: + lr_dh = DataHandler(f"{tdir}/lr_*.h5", ...) + hr_dh = DataHandler(f"{tdir}/hr_*.h5", hr_spatial_coarsen=...) + container = DualRasterizer({'low_res': lr_dh, 'high_res': hr_dh}, ...) + containers.append(container) + bh = DualBatchHandler(train_containers=containers, ...) + +To train a 5x model use the ``hr_*.h5`` files for both the ``lr_dh`` and the ``hr_dh``. Use ``hr_spatial_coarsen=3`` in the ``lr_dh`` and ``hr_spatial_coarsen=1`` in the ``hr_dh``:: + + for tdir in training_dirs: + lr_dh = DataHandler(f"{tdir}/hr_*.h5", hr_spatial_coarsen=3, ...) + hr_dh = DataHandler(f"{tdir}/hr_*.h5", hr_spatial_coarsen=1, ...) + container = DualRasterizer({'low_res': lr_dh, 'high_res': hr_dh}, ...) + containers.append(container) + bh = DualBatchHandler(train_containers=containers, ...) + + +Initialize a 3x, 5x, or 15x spatial enhancement model, with 14 output channels, and train for the desired number of epochs. (The 3x and 5x generator configs can be copied from the ``model_params.json`` files in each OEDI model `directory `__. The 15x generator config can be created from the OEDI model configs by changing the spatial enhancement factor or from the configs in the repo by changing the enhancement factor and the number of output channels):: + + from sup3r.models import Sup3rGan + model = Sup3rGan(gen_layers="./gen_config.json", disc_layers="./disc_config.json", ...) + model.train(batch_handler, ...) + + Sup3rWind Versions ------------------- diff --git a/pyproject.toml b/pyproject.toml index 80bc45ec4..e1c92afaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,10 +42,21 @@ dependencies = [ "pytest>=5.2", "scipy>=1.0.0", "sphinx>=7.0", - "tensorflow>2.4,<2.16", "xarray>=2023.0" ] +# If used, cause glibc conflict +# [tool.pixi.target.linux-64.dependencies] +# cuda = ">=11.8" +# cudnn = {version = ">=8.6.0", channel = "conda-forge"} +# # 8.9.7 + +[tool.pixi.target.linux-64.pypi-dependencies] +tensorflow = {version = "~=2.15.1", extras = ["and-cuda"] } + +[tool.pixi.target.osx-arm64.dependencies] +tensorflow = {version = "~=2.15.0", channel = "conda-forge"} + [project.optional-dependencies] dev = [ "build>=0.5", @@ -272,7 +283,6 @@ matplotlib = ">=3.1" numpy = "~=1.7" pandas = ">=2.0" scipy = ">=1.0.0" -tensorflow = ">2.4,<2.16" xarray = ">=2023.0" [tool.pixi.pypi-dependencies] @@ -284,6 +294,7 @@ NREL-farms = { version = ">=1.0.4" } [tool.pixi.environments] default = { solve-group = "default" } +kestrel = { features = ["kestrel"], solve-group = "default" } dev = { features = ["dev", "doc", "test"], solve-group = "default" } doc = { features = ["doc"], solve-group = "default" } test = { features = ["test"], solve-group = "default" } diff --git a/sup3r/models/__init__.py b/sup3r/models/__init__.py index 5d6b51344..20fff799f 100644 --- a/sup3r/models/__init__.py +++ b/sup3r/models/__init__.py @@ -6,6 +6,7 @@ from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan from .solar_cc import SolarCC from .surface import SurfaceSpatialMetModel +from .with_obs import Sup3rGanWithObs SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index b4b7fb869..54307843c 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1,11 +1,9 @@ """Abstract class defining the required 
interface for Sup3r model subclasses""" import json -import locale import logging import os import pprint -import re import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -23,588 +21,18 @@ from sup3r.preprocessing.data_handlers import ExoData from sup3r.preprocessing.utilities import numpy_if_tensor from sup3r.utilities import VERSION_RECORD -from sup3r.utilities.utilities import Timer, safe_cast +from sup3r.utilities.utilities import safe_cast -logger = logging.getLogger(__name__) - - -class TensorboardMixIn: - """MixIn class for tensorboard logging and profiling.""" - - def __init__(self): - self._tb_writer = None - self._tb_log_dir = None - self._write_tb_profile = False - self._total_batches = None - self._history = None - self.timer = Timer() - - @property - def total_batches(self): - """Record of total number of batches for logging.""" - if self._total_batches is None and self._history is None: - self._total_batches = 0 - elif self._history is None and 'total_batches' in self._history: - self._total_batches = self._history['total_batches'].values[-1] - elif self._total_batches is None and self._history is not None: - self._total_batches = 0 - return self._total_batches - - @total_batches.setter - def total_batches(self, value): - """Set total number of batches.""" - self._total_batches = value - - def dict_to_tensorboard(self, entry): - """Write data to tensorboard log file. This is usually a loss_details - dictionary. - - Parameters - ---------- - entry: dict - Dictionary of values to write to tensorboard log file - """ - if self._tb_writer is not None: - with self._tb_writer.as_default(): - for name, value in entry.items(): - if isinstance(value, str): - tf.summary.text(name, value, self.total_batches) - else: - tf.summary.scalar(name, value, self.total_batches) - - def profile_to_tensorboard(self, name): - """Write profile data to tensorboard log file. - - Parameters - ---------- - name : str - Tag name to use for profile info - """ - if self._tb_writer is not None and self._write_tb_profile: - with self._tb_writer.as_default(): - tf.summary.trace_export( - name=name, - step=self.total_batches, - profiler_outdir=self._tb_log_dir, - ) - - def _init_tensorboard_writer(self, out_dir): - """Initialize the ``tf.summary.SummaryWriter`` to use for writing - tensorboard compatible log files. - - Parameters - ---------- - out_dir : str - Standard out_dir where model epochs are saved. e.g. './gan_{epoch}' - """ - tb_log_pardir = os.path.abspath(os.path.join(out_dir, os.pardir)) - self._tb_log_dir = os.path.join(tb_log_pardir, 'logs') - os.makedirs(self._tb_log_dir, exist_ok=True) - self._tb_writer = tf.summary.create_file_writer(self._tb_log_dir) - - -class AbstractInterface(ABC): - """ - Abstract class to define the required interface for Sup3r model subclasses - - Note that this only sets the required interfaces for a GAN that can be - loaded from disk and used to predict synthetic outputs. The interface for - models that can be trained will be set in another class. - """ - - @classmethod - @abstractmethod - def load(cls, model_dir, verbose=True): - """Load the GAN with its sub-networks from a previously saved-to output - directory. - - Parameters - ---------- - model_dir - Directory to load GAN model files from. - verbose : bool - Flag to log information about the loaded model. 
- - Returns - ------- - out : BaseModel - Returns a pretrained gan model that was previously saved to - model_dir - """ - - @abstractmethod - def generate( - self, low_res, norm_in=True, un_norm_out=True, exogenous_data=None - ): - """Use the generator model to generate high res data from low res - input. This is the public generate function.""" - - @staticmethod - def seed(s=0): - """ - Set the random seed for reproducible results. - - Parameters - ---------- - s : int - Random seed - """ - CustomNetwork.seed(s=s) - - @property - def input_dims(self): - """Get dimension of model generator input. This is usually 4D for - spatial models and 5D for spatiotemporal models. This gives the input - to the first step if the model is multi-step. Returns 5 for linear - models. - - Returns - ------- - int - """ - # pylint: disable=E1101 - if hasattr(self, '_gen'): - return self._gen.layers[0].rank - if hasattr(self, 'models'): - return self.models[0].input_dims - return 5 - - @property - def is_5d(self): - """Check if model expects spatiotemporal input""" - return self.input_dims == 5 - - @property - def is_4d(self): - """Check if model expects spatial only input""" - return self.input_dims == 4 - - # pylint: disable=E1101 - def get_s_enhance_from_layers(self): - """Compute factor by which model will enhance spatial resolution from - layer attributes. Used in model training during high res coarsening""" - s_enhance = None - if hasattr(self, '_gen'): - s_enhancements = [ - getattr(layer, '_spatial_mult', 1) - for layer in self._gen.layers - ] - s_enhance = int(np.prod(s_enhancements)) - return s_enhance - - # pylint: disable=E1101 - def get_t_enhance_from_layers(self): - """Compute factor by which model will enhance temporal resolution from - layer attributes. Used in model training during high res coarsening""" - t_enhance = None - if hasattr(self, '_gen'): - t_enhancements = [ - getattr(layer, '_temporal_mult', 1) - for layer in self._gen.layers - ] - t_enhance = int(np.prod(t_enhancements)) - return t_enhance - - @property - def s_enhance(self): - """Factor by which model will enhance spatial resolution. Used in - model training during high res coarsening and also in forward pass - routine to determine shape of needed exogenous data""" - models = getattr(self, 'models', [self]) - s_enhances = [m.meta.get('s_enhance', None) for m in models] - s_enhance = ( - self.get_s_enhance_from_layers() - if any(s is None for s in s_enhances) - else int(np.prod(s_enhances)) - ) - if len(models) == 1: - self.meta['s_enhance'] = s_enhance - return s_enhance - - @property - def t_enhance(self): - """Factor by which model will enhance temporal resolution. Used in - model training during high res coarsening and also in forward pass - routine to determine shape of needed exogenous data""" - models = getattr(self, 'models', [self]) - t_enhances = [m.meta.get('t_enhance', None) for m in models] - t_enhance = ( - self.get_t_enhance_from_layers() - if any(t is None for t in t_enhances) - else int(np.prod(t_enhances)) - ) - if len(models) == 1: - self.meta['t_enhance'] = t_enhance - return t_enhance - - @property - def s_enhancements(self): - """List of spatial enhancement factors. In the case of a single step - model this is just ``[self.s_enhance]``. 
This is used to determine - shapes of needed exogenous data in forward pass routine""" - if hasattr(self, 'models'): - return [model.s_enhance for model in self.models] - return [self.s_enhance] - - @property - def t_enhancements(self): - """List of temporal enhancement factors. In the case of a single step - model this is just ``[self.t_enhance]``. This is used to determine - shapes of needed exogenous data in forward pass routine""" - if hasattr(self, 'models'): - return [model.t_enhance for model in self.models] - return [self.t_enhance] - - @property - def input_resolution(self): - """Resolution of input data. Given as a dictionary - ``{'spatial': '...km', 'temporal': '...min'}``. The numbers are - required to be integers in the units specified. The units are not - strict as long as the resolution of the exogenous data, when extracting - exogenous data, is specified in the same units.""" - input_resolution = self.meta.get('input_resolution', None) - msg = 'model.input_resolution is None. This needs to be set.' - assert input_resolution is not None, msg - return input_resolution - - def _get_numerical_resolutions(self): - """Get the input and output resolutions without units. e.g. for - ``{"spatial": "30km", "temporal": "60min"}`` this returns - ``{"spatial": 30, "temporal": 60}``""" - ires_num = { - k: int(re.search(r'\d+', v).group(0)) - for k, v in self.input_resolution.items() - } - enhancements = {'spatial': self.s_enhance, 'temporal': self.t_enhance} - ores_num = {k: v // enhancements[k] for k, v in ires_num.items()} - return ires_num, ores_num - - def _ensure_valid_input_resolution(self): - """Ensure ehancement factors evenly divide input_resolution""" - - if self.input_resolution is None: - return - - ires_num, ores_num = self._get_numerical_resolutions() - s_enhance = self.meta['s_enhance'] - t_enhance = self.meta['t_enhance'] - check = ( - ires_num['temporal'] / ores_num['temporal'] == t_enhance - and ires_num['spatial'] / ores_num['spatial'] == s_enhance - ) - msg = ( - f'Enhancement factors (s_enhance={s_enhance}, ' - f't_enhance={t_enhance}) do not evenly divide ' - f'input resolution ({self.input_resolution})' - ) - if not check: - logger.error(msg) - raise RuntimeError(msg) - - def _ensure_valid_enhancement_factors(self): - """Ensure user provided enhancement factors are the same as those - computed from layer attributes""" - t_enhance = self.meta.get('t_enhance', None) - s_enhance = self.meta.get('s_enhance', None) - if s_enhance is None or t_enhance is None: - return - - layer_se = self.get_s_enhance_from_layers() - layer_te = self.get_t_enhance_from_layers() - layer_se = layer_se if layer_se is not None else self.meta['s_enhance'] - layer_te = layer_te if layer_te is not None else self.meta['t_enhance'] - msg = ( - f'Enhancement factors computed from layer attributes ' - f'(s_enhance={layer_se}, t_enhance={layer_te}) ' - f'conflict with user provided values (s_enhance={s_enhance}, ' - f't_enhance={t_enhance})' - ) - check = layer_se == s_enhance or layer_te == t_enhance - if not check: - logger.error(msg) - raise RuntimeError(msg) - - @property - def output_resolution(self): - """Resolution of output data. Given as a dictionary - {'spatial': '...km', 'temporal': '...min'}. 
This is computed from the - input resolution and the enhancement factors.""" - output_res = self.meta.get('output_resolution', None) - if self.input_resolution is not None and output_res is None: - ires_num, ores_num = self._get_numerical_resolutions() - output_res = { - k: v.replace(str(ires_num[k]), str(ores_num[k])) - for k, v in self.input_resolution.items() - } - self.meta['output_resolution'] = output_res - return output_res - - def _combine_fwp_input(self, low_res, exogenous_data=None): - """Combine exogenous_data at input resolution with low_res data prior - to forward pass through generator - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - exogenous_data : dict | ExoData | None - Special dictionary (class:`ExoData`) of exogenous feature data with - entries describing whether features should be combined at input, a - mid network layer, or with output. This doesn't have to include - the 'model' key since this data is for a single step model. - - Returns - ------- - low_res : np.ndarray - Low-resolution input data combined with exogenous_data, usually a - 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - if exogenous_data is None: - return low_res - - if ( - not isinstance(exogenous_data, ExoData) - and exogenous_data is not None - ): - exogenous_data = ExoData(exogenous_data) - - fnum_diff = len(self.lr_features) - low_res.shape[-1] - exo_feats = [] if fnum_diff <= 0 else self.lr_features[-fnum_diff:] - msg = ( - f'Provided exogenous_data: {exogenous_data} is missing some ' - f'required features ({exo_feats})' - ) - assert all(feature in exogenous_data for feature in exo_feats), msg - if exogenous_data is not None and fnum_diff > 0: - for feature in exo_feats: - exo_input = exogenous_data.get_combine_type_data( - feature, 'input' - ) - if exo_input is not None: - low_res = np.concatenate((low_res, exo_input), axis=-1) - - return low_res - - def _combine_fwp_output(self, hi_res, exogenous_data=None): - """Combine exogenous_data at output resolution with generated hi_res - data following forward pass output. - - Parameters - ---------- - hi_res : np.ndarray - High-resolution output data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - exogenous_data : dict | ExoData | None - Special dictionary (class:`ExoData`) of exogenous feature data with - entries describing whether features should be combined at input, a - mid network layer, or with output. This doesn't have to include - the 'model' key since this data is for a single step model. 
+from .utilities import TensorboardMixIn - Returns - ------- - hi_res : np.ndarray - High-resolution output data combined with exogenous_data, usually a - 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - if exogenous_data is None: - return hi_res - - if ( - not isinstance(exogenous_data, ExoData) - and exogenous_data is not None - ): - exogenous_data = ExoData(exogenous_data) - - fnum_diff = len(self.hr_out_features) - hi_res.shape[-1] - exo_feats = [] if fnum_diff <= 0 else self.hr_out_features[-fnum_diff:] - msg = ( - 'Provided exogenous_data is missing some required features ' - f'({exo_feats})' - ) - assert all(feature in exogenous_data for feature in exo_feats), msg - if exogenous_data is not None and fnum_diff > 0: - for feature in exo_feats: - exo_output = exogenous_data.get_combine_type_data( - feature, 'output' - ) - if exo_output is not None: - hi_res = np.concatenate((hi_res, exo_output), axis=-1) - return hi_res - - @tf.function - def _combine_loss_input(self, high_res_true, high_res_gen): - """Combine exogenous feature data from high_res_true with high_res_gen - for loss calculation - - Parameters - ---------- - high_res_true : tf.Tensor - Ground truth high resolution spatiotemporal data. - high_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. - - Returns - ------- - high_res_gen : tf.Tensor - Same as input with exogenous data combined with high_res input - """ - if high_res_true.shape[-1] > high_res_gen.shape[-1]: - for feature in self.hr_exo_features: - f_idx = self.hr_exo_features.index(feature) - f_idx += len(self.hr_out_features) - exo_data = high_res_true[..., f_idx : f_idx + 1] - high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1) - return high_res_gen - - @property - @abstractmethod - def meta(self): - """Get meta data dictionary that defines how the model was created""" - - @property - def lr_features(self): - """Get a list of low-resolution features input to the generative model. - This includes low-resolution features that might be supplied - exogenously at inference time but that were in the low-res batches - during training""" - return self.meta.get('lr_features', []) - - @property - def hr_out_features(self): - """Get the list of high-resolution output feature names that the - generative model outputs.""" - return self.meta.get('hr_out_features', []) - - @property - def hr_exo_features(self): - """Get list of high-resolution exogenous filter names the model uses. - If the model has N concat or add layers this list will be the last N - features in the training features list. The ordering is assumed to be - the same as the order of concat or add layers. If training features is - [..., topo, sza], and the model has 2 concat or add layers, exo - features will be [topo, sza]. 
Topo will then be used in the first - concat layer and sza will be used in the second""" - # pylint: disable=E1101 - features = [] - if hasattr(self, '_gen'): - features = [ - layer.name - for layer in self._gen.layers - if isinstance(layer, (Sup3rAdder, Sup3rConcat)) - ] - return features - - @property - def smoothing(self): - """Value of smoothing parameter used in gaussian filtering of coarsened - high res data.""" - return self.meta.get('smoothing', None) - - @property - def smoothed_features(self): - """Get the list of smoothed input feature names that the generative - model was trained on.""" - return self.meta.get('smoothed_features', []) - - @property - def model_params(self): - """ - Model parameters, used to save model to disc - - Returns - ------- - dict - """ - return {'meta': self.meta} - - @property - def version_record(self): - """A record of important versions that this model was built with. - - Returns - ------- - dict - """ - return VERSION_RECORD - - def set_model_params(self, **kwargs): - """Set parameters used for training the model - - Parameters - ---------- - kwargs : dict - Keyword arguments including 'input_resolution', - 'lr_features', 'hr_exo_features', 'hr_out_features', - 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' - """ - - keys = ( - 'input_resolution', - 'lr_features', - 'hr_exo_features', - 'hr_out_features', - 'smoothed_features', - 's_enhance', - 't_enhance', - 'smoothing', - ) - keys = [k for k in keys if k in kwargs] - - hr_exo_feat = kwargs.get('hr_exo_features', []) - msg = ( - f'Expected high-res exo features {self.hr_exo_features} ' - f'based on model architecture but received "hr_exo_features" ' - f'from data handler: {hr_exo_feat}' - ) - assert list(self.hr_exo_features) == list(hr_exo_feat), msg - - for var in keys: - val = self.meta.get(var, None) - if val is None: - self.meta[var] = kwargs[var] - elif val != kwargs[var]: - msg = ( - 'Model was previously trained with {var}={} but ' - 'received new {var}={}'.format(val, kwargs[var], var=var) - ) - logger.warning(msg) - warn(msg) - - self._ensure_valid_enhancement_factors() - self._ensure_valid_input_resolution() - - def save_params(self, out_dir): - """ - Parameters - ---------- - out_dir : str - Directory to save linear model params. This directory will be - created if it does not already exist. - """ - if not os.path.exists(out_dir): - os.makedirs(out_dir, exist_ok=True) - - fp_params = os.path.join(out_dir, 'model_params.json') - with open( - fp_params, 'w', encoding=locale.getpreferredencoding(False) - ) as f: - params = self.model_params - json.dump(params, f, sort_keys=True, indent=2, default=safe_cast) +logger = logging.getLogger(__name__) # pylint: disable=E1101,W0201,E0203 class AbstractSingleModel(ABC, TensorboardMixIn): """ - Abstract class to define the required training interface - for Sup3r model subclasses + Abstract class to define the required training interface for Sup3r model + subclasses """ def __init__(self): @@ -991,6 +419,30 @@ def get_high_res_exo_input(self, high_res): exo_data[feature] = exo_fdata return exo_data + @tf.function + def _combine_loss_input(self, high_res_true, high_res_gen): + """Combine exogenous feature data from high_res_true with high_res_gen + for loss calculation + + Parameters + ---------- + high_res_true : tf.Tensor + Ground truth high resolution spatiotemporal data. + high_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. 
+ + Returns + ------- + high_res_gen : tf.Tensor + Same as input with exogenous data combined with high_res input + """ + if high_res_true.shape[-1] > high_res_gen.shape[-1]: + exo_dict = self.get_high_res_exo_input(high_res_true) + exo_data = [exo_dict[feat] for feat in self.hr_exo_features] + high_res_gen = tf.concat((high_res_gen, *exo_data), axis=-1) + return high_res_gen + @staticmethod def get_loss_fun(loss): """Get the initialized loss function class from the sup3r loss library @@ -1110,6 +562,7 @@ def update_loss_details(loss_details, new_data, batch_len, prefix=None): if key in loss_details: saved_value = loss_details[key] + saved_value = 0 if np.isnan(saved_value) else saved_value saved_value *= prior_n_obs saved_value += batch_len * new_value saved_value /= new_n_obs @@ -1256,9 +709,9 @@ def finish_epoch( """ self.log_loss_details(loss_details) self._history.at[epoch, 'elapsed_time'] = time.time() - t0 - for key, value in loss_details.items(): - if key != 'n_obs': - self._history.at[epoch, key] = value + cols = [k for k in loss_details if k != 'n_obs'] + entry = np.vstack([loss_details[k] for k in cols]) + self._history.loc[epoch, cols] = entry.T last_epoch = epoch == epochs[-1] chp = checkpoint_int is not None and (epoch % checkpoint_int) == 0 @@ -1282,11 +735,71 @@ def finish_epoch( self.save(out_dir.format(epoch=epoch)) if extras is not None: - for k, v in extras.items(): - self._history.at[epoch, k] = safe_cast(v) + entry = np.vstack([safe_cast(v) for v in extras.values()]) + self._history.loc[epoch, list(extras)] = entry.T return stop + def _sum_parallel_grad(self, futures, start_time): + """Sum gradient descent future results""" + + # sum the gradients from each gpu to weight equally in + # optimizer momentum calculation + total_grad = None + for future in futures: + grad, loss_details = future.result() + if total_grad is None: + total_grad = grad + else: + for i, igrad in enumerate(grad): + total_grad[i] += igrad + + msg = ( + f'Finished {len(futures)} gradient descent steps on ' + f'{len(self.gpu_list)} GPUs in {time.time() - start_time:.4f} ' + 'seconds' + ) + logger.info(msg) + return total_grad, loss_details + + def _get_parallel_grad( + self, + low_res, + hi_res_true, + training_weights, + **calc_loss_kwargs, + ): + """Compute gradient for one mini-batch of (low_res, hi_res_true) + across multiple GPUs""" + + futures = [] + start_time = time.time() + lr_chunks = np.array_split(low_res, len(self.gpu_list)) + hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) + split_mask = False + mask_chunks = None + if 'mask' in calc_loss_kwargs: + split_mask = True + mask_chunks = np.array_split( + calc_loss_kwargs['mask'], len(self.gpu_list) + ) + + with ThreadPoolExecutor(max_workers=len(self.gpu_list)) as exe: + for i in range(len(self.gpu_list)): + if split_mask: + calc_loss_kwargs['mask'] = mask_chunks[i] + futures.append( + exe.submit( + self.get_single_grad, + lr_chunks[i], + hr_true_chunks[i], + training_weights, + device_name=f'/gpu:{i}', + **calc_loss_kwargs, + ) + ) + return self._sum_parallel_grad(futures, start_time=start_time) + def run_gradient_descent( self, low_res, @@ -1296,7 +809,6 @@ def run_gradient_descent( multi_gpu=False, **calc_loss_kwargs, ): - # pylint: disable=E0602 """Run gradient descent for one mini-batch of (low_res, hi_res_true) and update weights @@ -1332,11 +844,11 @@ def run_gradient_descent( loss_details : dict Namespace of the breakdown of loss components """ - self.timer.start() if optimizer is None: optimizer = self.optimizer if 
not multi_gpu or len(self.gpu_list) < 2: + start_time = time.time() grad, loss_details = self.get_single_grad( low_res, hi_res_true, @@ -1345,58 +857,20 @@ def run_gradient_descent( **calc_loss_kwargs, ) optimizer.apply_gradients(zip(grad, training_weights)) - self.timer.stop() - logger.debug( - 'Finished single gradient descent step in %s', - self.timer.elapsed_str, + msg = ( + 'Finished single gradient descent step in ' + f'{time.time() - start_time:.4f} seconds' ) + logger.debug(msg) else: - futures = [] - lr_chunks = np.array_split(low_res, len(self.gpu_list)) - hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) - split_mask = False - mask_chunks = None - if 'mask' in calc_loss_kwargs: - split_mask = True - mask_chunks = np.array_split( - calc_loss_kwargs['mask'], len(self.gpu_list) - ) - - with ThreadPoolExecutor(max_workers=len(self.gpu_list)) as exe: - for i in range(len(self.gpu_list)): - if split_mask: - calc_loss_kwargs['mask'] = mask_chunks[i] - futures.append( - exe.submit( - self.get_single_grad, - lr_chunks[i], - hr_true_chunks[i], - training_weights, - device_name=f'/gpu:{i}', - **calc_loss_kwargs, - ) - ) - - # sum the gradients from each gpu to weight equally in - # optimizer momentum calculation - total_grad = None - for future in futures: - grad, loss_details = future.result() - if total_grad is None: - total_grad = grad - else: - for i, igrad in enumerate(grad): - total_grad[i] += igrad - + total_grad, loss_details = self._get_parallel_grad( + low_res, + hi_res_true, + training_weights, + **calc_loss_kwargs, + ) optimizer.apply_gradients(zip(total_grad, training_weights)) - self.timer.stop() - logger.debug( - 'Finished %s gradient descent steps on %s GPUs in %s', - len(futures), - len(self.gpu_list), - self.timer.elapsed_str, - ) return loss_details def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True): @@ -1588,6 +1062,21 @@ def _tf_generate(self, low_res, hi_res_exo=None): return hi_res + def _get_hr_exo_and_loss( + self, + low_res, + hi_res_true, + **calc_loss_kwargs, + ): + """Get high-resolution exogenous data, generate synthetic output, and + compute loss.""" + hi_res_exo = self.get_high_res_exo_input(hi_res_true) + hi_res_gen = self._tf_generate(low_res, hi_res_exo) + loss, loss_details = self.calc_loss( + hi_res_true, hi_res_gen, **calc_loss_kwargs + ) + return loss, loss_details, hi_res_gen + @tf.function def get_single_grad( self, @@ -1633,11 +1122,20 @@ def get_single_grad( watch_accessed_variables=False ) as tape: tape.watch(training_weights) - hi_res_exo = self.get_high_res_exo_input(hi_res_true) - hi_res_gen = self._tf_generate(low_res, hi_res_exo) - loss_out = self.calc_loss( - hi_res_true, hi_res_gen, **calc_loss_kwargs + loss, loss_details, _ = self._get_hr_exo_and_loss( + low_res, hi_res_true, **calc_loss_kwargs ) - loss, loss_details = loss_out grad = tape.gradient(loss, training_weights) return grad, loss_details + + @abstractmethod + def calc_loss( + self, + hi_res_true, + hi_res_gen, + weight_gen_advers=0.001, + train_gen=True, + train_disc=False, + ): + """Calculate the GAN loss function using generated and true high + resolution data.""" diff --git a/sup3r/models/base.py b/sup3r/models/base.py index cfdc47f73..e7e081e02 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -14,7 +14,8 @@ from sup3r.preprocessing.utilities import get_class_kwargs from sup3r.utilities import VERSION_RECORD -from .abstract import AbstractInterface, AbstractSingleModel +from .abstract import AbstractSingleModel +from 
.interface import AbstractInterface from .utilities import get_optimizer_class logger = logging.getLogger(__name__) @@ -546,242 +547,6 @@ def calc_loss_disc(disc_out_true, disc_out_gen): ) return tf.reduce_mean(loss_disc) - @tf.function - def calc_loss( - self, - hi_res_true, - hi_res_gen, - weight_gen_advers=0.001, - train_gen=True, - train_disc=False, - ): - """Calculate the GAN loss function using generated and true high - resolution data. - - Parameters - ---------- - hi_res_true : tf.Tensor - Ground truth high resolution spatiotemporal data. - hi_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. - train_gen : bool - True if generator is being trained, then loss=loss_gen - train_disc : bool - True if disc is being trained, then loss=loss_disc - - Returns - ------- - loss : tf.Tensor - 0D tensor representing the loss value for the network being trained - (either generator or one of the discriminators) - loss_details : dict - Namespace of the breakdown of loss components - """ - hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen) - - if hi_res_gen.shape != hi_res_true.shape: - msg = ( - 'The tensor shapes of the synthetic output {} and ' - 'true high res {} did not have matching shape! ' - 'Check the spatiotemporal enhancement multipliers in your ' - 'your model config and data handlers.'.format( - hi_res_gen.shape, hi_res_true.shape - ) - ) - logger.error(msg) - raise RuntimeError(msg) - - disc_out_true = self._tf_discriminate(hi_res_true) - disc_out_gen = self._tf_discriminate(hi_res_gen) - - loss_gen_content = self.calc_loss_gen_content(hi_res_true, hi_res_gen) - loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen) - loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers - - loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen) - - loss = None - if train_gen: - loss = loss_gen - elif train_disc: - loss = loss_disc - - loss_details = { - 'loss_gen': loss_gen, - 'loss_gen_content': loss_gen_content, - 'loss_gen_advers': loss_gen_advers, - 'loss_disc': loss_disc, - } - - return loss, loss_details - - def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): - """Calculate the validation loss at the current state of model training - - Parameters - ---------- - batch_handler : sup3r.preprocessing.BatchHandler - BatchHandler object to iterate through - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. - loss_details : dict - Namespace of the breakdown of loss components - - Returns - ------- - loss_details : dict - Same as input but now includes val_* loss info - """ - logger.debug('Starting end-of-epoch validation loss calculation...') - loss_details['n_obs'] = 0 - for val_batch in batch_handler.val_data: - val_exo_data = self.get_high_res_exo_input(val_batch.high_res) - high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data) - _, v_loss_details = self.calc_loss( - val_batch.high_res, - high_res_gen, - weight_gen_advers=weight_gen_advers, - train_gen=False, - train_disc=False, - ) - - loss_details = self.update_loss_details( - loss_details, v_loss_details, len(val_batch), prefix='val_' - ) - return loss_details - - def train_epoch( - self, - batch_handler, - weight_gen_advers, - train_gen, - train_disc, - disc_loss_bounds, - multi_gpu=False, - ): - """Train the GAN for one epoch. 
- - Parameters - ---------- - batch_handler : sup3r.preprocessing.BatchHandler - BatchHandler object to iterate through - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. - train_gen : bool - Flag whether to train the generator for this set of epochs - train_disc : bool - Flag whether to train the discriminator for this set of epochs - disc_loss_bounds : tuple - Lower and upper bounds for the discriminator loss outside of which - the discriminators will not train unless train_disc=True or - and train_gen=False. - multi_gpu : bool - Flag to break up the batch for parallel gradient descent - calculations on multiple gpus. If True and multiple GPUs are - present, each batch from the batch_handler will be divided up - between the GPUs and resulting gradients from each GPU will be - summed and then applied once per batch at the nominal learning - rate that the model and optimizer were initialized with. - If true and multiple gpus are found, ``default_device`` device - should be set to /gpu:0 - - Returns - ------- - loss_details : dict - Namespace of the breakdown of loss components - """ - - disc_th_low = np.min(disc_loss_bounds) - disc_th_high = np.max(disc_loss_bounds) - loss_details = {'n_obs': 0, 'train_loss_disc': 0} - - only_gen = train_gen and not train_disc - only_disc = train_disc and not train_gen - - if self._write_tb_profile: - tf.summary.trace_on(graph=True, profiler=True) - - for ib, batch in enumerate(batch_handler): - trained_gen = False - trained_disc = False - b_loss_details = {} - loss_disc = loss_details['train_loss_disc'] - disc_too_good = loss_disc <= disc_th_low - disc_too_bad = (loss_disc > disc_th_high) and train_disc - gen_too_good = disc_too_bad - - if not self.generator_weights: - self.init_weights(batch.low_res.shape, batch.high_res.shape) - - if only_gen or (train_gen and not gen_too_good): - trained_gen = True - b_loss_details = self.timer(self.run_gradient_descent)( - batch.low_res, - batch.high_res, - self.generator_weights, - weight_gen_advers=weight_gen_advers, - optimizer=self.optimizer, - train_gen=True, - train_disc=False, - multi_gpu=multi_gpu, - ) - - if only_disc or (train_disc and not disc_too_good): - trained_disc = True - b_loss_details = self.timer(self.run_gradient_descent)( - batch.low_res, - batch.high_res, - self.discriminator_weights, - weight_gen_advers=weight_gen_advers, - optimizer=self.optimizer_disc, - train_gen=False, - train_disc=True, - multi_gpu=multi_gpu, - ) - - b_loss_details['gen_trained_frac'] = float(trained_gen) - b_loss_details['disc_trained_frac'] = float(trained_disc) - - self.dict_to_tensorboard(b_loss_details) - self.dict_to_tensorboard(self.timer.log) - - loss_details = self.update_loss_details( - loss_details, - b_loss_details, - batch_handler.batch_size, - prefix='train_', - ) - logger.debug( - 'Batch {} out of {} has epoch-average ' - '(gen / disc) loss of: ({:.2e} / {:.2e}). 
' - 'Trained (gen / disc): ({} / {})'.format( - ib + 1, - len(batch_handler), - loss_details['train_loss_gen'], - loss_details['train_loss_disc'], - trained_gen, - trained_disc, - ) - ) - if all([not trained_gen, not trained_disc]): - msg = ( - 'For some reason none of the GAN networks trained ' - 'during batch {} out of {}!'.format(ib, len(batch_handler)) - ) - logger.warning(msg) - warn(msg) - self.total_batches += 1 - - loss_details['total_batches'] = int(self.total_batches) - self.profile_to_tensorboard('training_epoch') - return loss_details - def update_adversarial_weights( self, history, @@ -795,7 +560,7 @@ def update_adversarial_weights( Parameters ---------- - history : dict + history : dicts Dictionary with information on how often discriminators were trained during current and previous epochs. adaptive_update_fraction : float @@ -985,6 +750,7 @@ def train( ) for epoch in epochs: + t_epoch = time.time() loss_details = self.train_epoch( batch_handler, weight_gen_advers, @@ -1048,7 +814,341 @@ def train( early_stop_n_epoch, extras=extras, ) + logger.info( + 'Finished training epoch in {:.4f} seconds'.format( + time.time() - t_epoch + ) + ) if stop: break + logger.info( + 'Finished training {} epochs in {:.4f} seconds'.format( + n_epoch, + time.time() - t0, + ) + ) batch_handler.stop() + + @tf.function + def calc_loss( + self, + hi_res_true, + hi_res_gen, + weight_gen_advers=0.001, + train_gen=True, + train_disc=False, + ): + """Calculate the GAN loss function using generated and true high + resolution data. + + Parameters + ---------- + hi_res_true : tf.Tensor + Ground truth high resolution spatiotemporal data. + hi_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + train_gen : bool + True if generator is being trained, then loss=loss_gen + train_disc : bool + True if disc is being trained, then loss=loss_disc + + Returns + ------- + loss : tf.Tensor + 0D tensor representing the loss value for the network being trained + (either generator or one of the discriminators) + loss_details : dict + Namespace of the breakdown of loss components + """ + hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen) + + if hi_res_gen.shape != hi_res_true.shape: + msg = ( + 'The tensor shapes of the synthetic output {} and ' + 'true high res {} did not have matching shape! 
' + 'Check the spatiotemporal enhancement multipliers in your ' + 'model config and data handlers.'.format( + hi_res_gen.shape, hi_res_true.shape + ) + ) + logger.error(msg) + raise RuntimeError(msg) + + disc_out_true = self._tf_discriminate(hi_res_true) + disc_out_gen = self._tf_discriminate(hi_res_gen) + + loss_gen_content = self.calc_loss_gen_content(hi_res_true, hi_res_gen) + loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen) + loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers + loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen) + + loss = None + if train_gen: + loss = loss_gen + elif train_disc: + loss = loss_disc + + loss_details = { + 'loss_gen': loss_gen, + 'loss_gen_content': loss_gen_content, + 'loss_gen_advers': loss_gen_advers, + 'loss_disc': loss_disc, + } + + return loss, loss_details + + def _calc_val_loss(self, batch, weight_gen_advers, loss_details): + """Calculate the validation loss at the current state of model training + for a given batch + + Parameters + ---------- + batch : DsetTuple + Object with ``.high_res`` and ``.low_res`` arrays + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + loss_details : dict + Namespace of the breakdown of loss components + + Returns + ------- + loss_details : dict + Same as input but now includes val_* loss info + """ + _, v_loss_details, _ = self._get_hr_exo_and_loss( + batch.low_res, + batch.high_res, + weight_gen_advers=weight_gen_advers, + train_gen=False, + train_disc=False, + ) + loss_details = self.update_loss_details( + loss_details, v_loss_details, len(batch), prefix='val_' + ) + return loss_details + + def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): + """Calculate the validation loss at the current state of model training + + Parameters + ---------- + batch_handler : sup3r.preprocessing.BatchHandler + BatchHandler object to iterate through + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + loss_details : dict + Namespace of the breakdown of loss components + + Returns + ------- + loss_details : dict + Same as input but now includes val_* loss info + """ + logger.debug('Starting end-of-epoch validation loss calculation...') + loss_details['n_obs'] = 0 + for val_batch in batch_handler.val_data: + loss_details = self._calc_val_loss( + val_batch, weight_gen_advers, loss_details + ) + return loss_details + + def _get_batch_loss_details( + self, + batch, + train_gen, + only_gen, + gen_too_good, + train_disc, + only_disc, + disc_too_good, + weight_gen_advers, + multi_gpu=False, + ): + """Get loss details for a given batch for the current epoch. + + Parameters + ---------- + batch : sup3r.preprocessing.base.DsetTuple + Object with ``.low_res`` and ``.high_res`` arrays + train_gen : bool + Flag whether to train the generator for this set of epochs + only_gen : bool + Flag whether to only train the generator for this set of epochs + gen_too_good : bool + Flag whether to skip training the generator and only train the + discriminator, due to superior performance, for this batch. + train_disc : bool + Flag whether to train the discriminator for this set of epochs + only_disc : bool + Flag whether to only train the discriminator for this set of epochs + disc_too_good : bool + Flag whether to skip training the discriminator and only train the + generator, due to superior performance, for this batch.
+ weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with. + If true and multiple gpus are found, ``default_device`` device + should be set to /gpu:0 + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components for the given batch + """ + + trained_gen = False + trained_disc = False + b_loss_details = {} + if only_gen or (train_gen and not gen_too_good): + trained_gen = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.generator_weights, + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer, + train_gen=True, + train_disc=False, + multi_gpu=multi_gpu, + ) + + if only_disc or (train_disc and not disc_too_good): + trained_disc = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.discriminator_weights, + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer_disc, + train_gen=False, + train_disc=True, + multi_gpu=multi_gpu, + ) + + b_loss_details['gen_trained_frac'] = float(trained_gen) + b_loss_details['disc_trained_frac'] = float(trained_disc) + return b_loss_details + + def train_epoch( + self, + batch_handler, + weight_gen_advers, + train_gen, + train_disc, + disc_loss_bounds, + multi_gpu=False, + ): + """Train the GAN for one epoch. + + Parameters + ---------- + batch_handler : sup3r.preprocessing.BatchHandler + BatchHandler object to iterate through + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + train_gen : bool + Flag whether to train the generator for this set of epochs + train_disc : bool + Flag whether to train the discriminator for this set of epochs + disc_loss_bounds : tuple + Lower and upper bounds for the discriminator loss outside of which + the discriminators will not train unless train_disc=True + and train_gen=False. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with.
+ If true and multiple gpus are found, ``default_device`` device + should be set to /gpu:0 + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components + """ + + disc_th_low = np.min(disc_loss_bounds) + disc_th_high = np.max(disc_loss_bounds) + loss_details = {'n_obs': 0, 'train_loss_disc': 0} + + only_gen = train_gen and not train_disc + only_disc = train_disc and not train_gen + + if self._write_tb_profile: + tf.summary.trace_on(graph=True, profiler=True) + + for ib, batch in enumerate(batch_handler): + b_loss_details = {} + loss_disc = loss_details['train_loss_disc'] + disc_too_good = loss_disc <= disc_th_low + disc_too_bad = (loss_disc > disc_th_high) and train_disc + gen_too_good = disc_too_bad + + if not self.generator_weights: + self.init_weights(batch.low_res.shape, batch.high_res.shape) + + b_loss_details = self._get_batch_loss_details( + batch, + train_gen, + only_gen, + gen_too_good, + train_disc, + only_disc, + disc_too_good, + weight_gen_advers, + multi_gpu, + ) + trained_gen = bool(b_loss_details.get('gen_trained_frac', False)) + trained_disc = bool(b_loss_details.get('disc_trained_frac', False)) + + self.dict_to_tensorboard(b_loss_details) + self.dict_to_tensorboard(self.timer.log) + + loss_details = self.update_loss_details( + loss_details, + b_loss_details, + batch_handler.batch_size, + prefix='train_', + ) + logger.debug( + 'Batch {} out of {} has epoch-average ' + '(gen / disc) loss of: ({:.2e} / {:.2e}). ' + 'Trained (gen / disc): ({} / {})'.format( + ib + 1, + len(batch_handler), + loss_details['train_loss_gen'], + loss_details['train_loss_disc'], + trained_gen, + trained_disc, + ) + ) + if all([not trained_gen, not trained_disc]): + msg = ( + 'For some reason none of the GAN networks trained ' + 'during batch {} out of {}!'.format(ib, len(batch_handler)) + ) + logger.warning(msg) + warn(msg) + + self.total_batches += len(batch_handler) + loss_details['total_batches'] = int(self.total_batches) + self.profile_to_tensorboard('training_epoch') + return loss_details diff --git a/sup3r/models/conditional.py b/sup3r/models/conditional.py index a5025419e..6faed1a6f 100644 --- a/sup3r/models/conditional.py +++ b/sup3r/models/conditional.py @@ -12,7 +12,8 @@ from sup3r.utilities import VERSION_RECORD -from .abstract import AbstractInterface, AbstractSingleModel +from .abstract import AbstractSingleModel +from .interface import AbstractInterface logger = logging.getLogger(__name__) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py new file mode 100644 index 000000000..25f63f410 --- /dev/null +++ b/sup3r/models/interface.py @@ -0,0 +1,511 @@ +"""Abstract class defining the required interface for Sup3r model subclasses""" + +import json +import locale +import logging +import os +import re +from abc import ABC, abstractmethod +from warnings import warn + +import numpy as np +from phygnn import CustomNetwork +from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat + +from sup3r.preprocessing.data_handlers import ExoData +from sup3r.utilities import VERSION_RECORD +from sup3r.utilities.utilities import safe_cast + +logger = logging.getLogger(__name__) + + +class AbstractInterface(ABC): + """ + Abstract class to define the required interface for Sup3r model subclasses + + Note that this only sets the required interfaces for a GAN that can be + loaded from disk and used to predict synthetic outputs. The interface for + models that can be trained will be set in another class. 
+ """ + + @classmethod + @abstractmethod + def load(cls, model_dir, verbose=True): + """Load the GAN with its sub-networks from a previously saved-to output + directory. + + Parameters + ---------- + model_dir + Directory to load GAN model files from. + verbose : bool + Flag to log information about the loaded model. + + Returns + ------- + out : BaseModel + Returns a pretrained gan model that was previously saved to + model_dir + """ + + @abstractmethod + def generate( + self, low_res, norm_in=True, un_norm_out=True, exogenous_data=None + ): + """Use the generator model to generate high res data from low res + input. This is the public generate function.""" + + @staticmethod + def seed(s=0): + """ + Set the random seed for reproducible results. + + Parameters + ---------- + s : int + Random seed + """ + CustomNetwork.seed(s=s) + + @property + def input_dims(self): + """Get dimension of model generator input. This is usually 4D for + spatial models and 5D for spatiotemporal models. This gives the input + to the first step if the model is multi-step. Returns 5 for linear + models. + + Returns + ------- + int + """ + # pylint: disable=E1101 + if hasattr(self, '_gen'): + return self._gen.layers[0].rank + if hasattr(self, 'models'): + return self.models[0].input_dims + return 5 + + @property + def max_paddings(self): + """Get the maximum padding values used by the generator. This is used + to apply extra padding during forward passes if the raw input doesn't + meet the minimum input shape.""" + + paddings = (1, 1, 1) + if hasattr(self, '_gen'): + for layer in self._gen.layers: + if hasattr(layer, 'paddings'): + new_pw = np.max(layer.paddings, axis=1)[1:-1] + if len(new_pw) < 3: + new_pw = (*new_pw, 1) + paddings = [ + np.max((new_pw[i], paddings[i])) for i in range(3) + ] + return paddings + if hasattr(self, 'models'): + return self.models[0].max_paddings + return paddings + + @property + def is_5d(self): + """Check if model expects spatiotemporal input""" + return self.input_dims == 5 + + @property + def is_4d(self): + """Check if model expects spatial only input""" + return self.input_dims == 4 + + # pylint: disable=E1101 + def get_s_enhance_from_layers(self): + """Compute factor by which model will enhance spatial resolution from + layer attributes. Used in model training during high res coarsening""" + s_enhance = None + if hasattr(self, '_gen'): + s_enhancements = [ + getattr(layer, '_spatial_mult', 1) + for layer in self._gen.layers + ] + s_enhance = int(np.prod(s_enhancements)) + return s_enhance + + # pylint: disable=E1101 + def get_t_enhance_from_layers(self): + """Compute factor by which model will enhance temporal resolution from + layer attributes. Used in model training during high res coarsening""" + t_enhance = None + if hasattr(self, '_gen'): + t_enhancements = [ + getattr(layer, '_temporal_mult', 1) + for layer in self._gen.layers + ] + t_enhance = int(np.prod(t_enhancements)) + return t_enhance + + @property + def s_enhance(self): + """Factor by which model will enhance spatial resolution. 
Used in + model training during high res coarsening and also in forward pass + routine to determine shape of needed exogenous data""" + models = getattr(self, 'models', [self]) + s_enhances = [m.meta.get('s_enhance', None) for m in models] + s_enhance = ( + self.get_s_enhance_from_layers() + if any(s is None for s in s_enhances) + else int(np.prod(s_enhances)) + ) + if len(models) == 1: + self.meta['s_enhance'] = s_enhance + return s_enhance + + @property + def t_enhance(self): + """Factor by which model will enhance temporal resolution. Used in + model training during high res coarsening and also in forward pass + routine to determine shape of needed exogenous data""" + models = getattr(self, 'models', [self]) + t_enhances = [m.meta.get('t_enhance', None) for m in models] + t_enhance = ( + self.get_t_enhance_from_layers() + if any(t is None for t in t_enhances) + else int(np.prod(t_enhances)) + ) + if len(models) == 1: + self.meta['t_enhance'] = t_enhance + return t_enhance + + @property + def s_enhancements(self): + """List of spatial enhancement factors. In the case of a single step + model this is just ``[self.s_enhance]``. This is used to determine + shapes of needed exogenous data in forward pass routine""" + if hasattr(self, 'models'): + return [model.s_enhance for model in self.models] + return [self.s_enhance] + + @property + def t_enhancements(self): + """List of temporal enhancement factors. In the case of a single step + model this is just ``[self.t_enhance]``. This is used to determine + shapes of needed exogenous data in forward pass routine""" + if hasattr(self, 'models'): + return [model.t_enhance for model in self.models] + return [self.t_enhance] + + @property + def input_resolution(self): + """Resolution of input data. Given as a dictionary + ``{'spatial': '...km', 'temporal': '...min'}``. The numbers are + required to be integers in the units specified. The units are not + strict as long as the resolution of the exogenous data, when extracting + exogenous data, is specified in the same units.""" + input_resolution = self.meta.get('input_resolution', None) + msg = 'model.input_resolution is None. This needs to be set.' + assert input_resolution is not None, msg + return input_resolution + + def _get_numerical_resolutions(self): + """Get the input and output resolutions without units. e.g. 
for + ``{"spatial": "30km", "temporal": "60min"}`` this returns + ``{"spatial": 30, "temporal": 60}``""" + ires_num = { + k: int(re.search(r'\d+', v).group(0)) + for k, v in self.input_resolution.items() + } + enhancements = {'spatial': self.s_enhance, 'temporal': self.t_enhance} + ores_num = {k: v // enhancements[k] for k, v in ires_num.items()} + return ires_num, ores_num + + def _ensure_valid_input_resolution(self): + """Ensure ehancement factors evenly divide input_resolution""" + + if self.input_resolution is None: + return + + ires_num, ores_num = self._get_numerical_resolutions() + s_enhance = self.meta['s_enhance'] + t_enhance = self.meta['t_enhance'] + check = ( + ires_num['temporal'] / ores_num['temporal'] == t_enhance + and ires_num['spatial'] / ores_num['spatial'] == s_enhance + ) + msg = ( + f'Enhancement factors (s_enhance={s_enhance}, ' + f't_enhance={t_enhance}) do not evenly divide ' + f'input resolution ({self.input_resolution})' + ) + if not check: + logger.error(msg) + raise RuntimeError(msg) + + def _ensure_valid_enhancement_factors(self): + """Ensure user provided enhancement factors are the same as those + computed from layer attributes""" + t_enhance = self.meta.get('t_enhance', None) + s_enhance = self.meta.get('s_enhance', None) + if s_enhance is None or t_enhance is None: + return + + layer_se = self.get_s_enhance_from_layers() + layer_te = self.get_t_enhance_from_layers() + layer_se = layer_se if layer_se is not None else self.meta['s_enhance'] + layer_te = layer_te if layer_te is not None else self.meta['t_enhance'] + msg = ( + f'Enhancement factors computed from layer attributes ' + f'(s_enhance={layer_se}, t_enhance={layer_te}) ' + f'conflict with user provided values (s_enhance={s_enhance}, ' + f't_enhance={t_enhance})' + ) + check = layer_se == s_enhance or layer_te == t_enhance + if not check: + logger.error(msg) + raise RuntimeError(msg) + + @property + def output_resolution(self): + """Resolution of output data. Given as a dictionary + {'spatial': '...km', 'temporal': '...min'}. This is computed from the + input resolution and the enhancement factors.""" + output_res = self.meta.get('output_resolution', None) + if self.input_resolution is not None and output_res is None: + ires_num, ores_num = self._get_numerical_resolutions() + output_res = { + k: v.replace(str(ires_num[k]), str(ores_num[k])) + for k, v in self.input_resolution.items() + } + self.meta['output_resolution'] = output_res + return output_res + + def _combine_fwp_input(self, low_res, exogenous_data=None): + """Combine exogenous_data at input resolution with low_res data prior + to forward pass through generator + + Parameters + ---------- + low_res : np.ndarray + Low-resolution input data, usually a 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + exogenous_data : dict | ExoData | None + Special dictionary (class:`ExoData`) of exogenous feature data with + entries describing whether features should be combined at input, a + mid network layer, or with output. This doesn't have to include + the 'model' key since this data is for a single step model. 
+ + Returns + ------- + low_res : np.ndarray + Low-resolution input data combined with exogenous_data, usually a + 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + """ + if exogenous_data is None: + return low_res + + if ( + not isinstance(exogenous_data, ExoData) + and exogenous_data is not None + ): + exogenous_data = ExoData(exogenous_data) + + fnum_diff = len(self.lr_features) - low_res.shape[-1] + exo_feats = [] if fnum_diff <= 0 else self.lr_features[-fnum_diff:] + msg = ( + f'Provided exogenous_data: {exogenous_data} is missing some ' + f'required features ({exo_feats})' + ) + assert all(feature in exogenous_data for feature in exo_feats), msg + if exogenous_data is not None and fnum_diff > 0: + for feature in exo_feats: + exo_input = exogenous_data.get_combine_type_data( + feature, 'input' + ) + if exo_input is not None: + low_res = np.concatenate((low_res, exo_input), axis=-1) + + return low_res + + def _combine_fwp_output(self, hi_res, exogenous_data=None): + """Combine exogenous_data at output resolution with generated hi_res + data following forward pass output. + + Parameters + ---------- + hi_res : np.ndarray + High-resolution output data, usually a 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + exogenous_data : dict | ExoData | None + Special dictionary (class:`ExoData`) of exogenous feature data with + entries describing whether features should be combined at input, a + mid network layer, or with output. This doesn't have to include + the 'model' key since this data is for a single step model. + + Returns + ------- + hi_res : np.ndarray + High-resolution output data combined with exogenous_data, usually a + 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + """ + if exogenous_data is None: + return hi_res + + if ( + not isinstance(exogenous_data, ExoData) + and exogenous_data is not None + ): + exogenous_data = ExoData(exogenous_data) + + fnum_diff = len(self.hr_out_features) - hi_res.shape[-1] + exo_feats = [] if fnum_diff <= 0 else self.hr_out_features[-fnum_diff:] + msg = ( + 'Provided exogenous_data is missing some required features ' + f'({exo_feats})' + ) + assert all(feature in exogenous_data for feature in exo_feats), msg + if exogenous_data is not None and fnum_diff > 0: + for feature in exo_feats: + exo_output = exogenous_data.get_combine_type_data( + feature, 'output' + ) + if exo_output is not None: + hi_res = np.concatenate((hi_res, exo_output), axis=-1) + return hi_res + + @property + @abstractmethod + def meta(self): + """Get meta data dictionary that defines how the model was created""" + + @property + def lr_features(self): + """Get a list of low-resolution features input to the generative model. + This includes low-resolution features that might be supplied + exogenously at inference time but that were in the low-res batches + during training""" + return self.meta.get('lr_features', []) + + @property + def hr_out_features(self): + """Get the list of high-resolution output feature names that the + generative model outputs.""" + return self.meta.get('hr_out_features', []) + + @property + def hr_exo_features(self): + """Get list of high-resolution exogenous filter names the model uses. + If the model has N concat or add layers this list will be the last N + features in the training features list. 
The ordering is assumed to be + the same as the order of concat or add layers. If training features is + [..., topo, sza], and the model has 2 concat or add layers, exo + features will be [topo, sza]. Topo will then be used in the first + concat layer and sza will be used in the second""" + # pylint: disable=E1101 + features = [] + if hasattr(self, '_gen'): + features = [ + layer.name + for layer in self._gen.layers + if isinstance(layer, (Sup3rAdder, Sup3rConcat)) + ] + return features + + @property + def smoothing(self): + """Value of smoothing parameter used in gaussian filtering of coarsened + high res data.""" + return self.meta.get('smoothing', None) + + @property + def smoothed_features(self): + """Get the list of smoothed input feature names that the generative + model was trained on.""" + return self.meta.get('smoothed_features', []) + + @property + def model_params(self): + """ + Model parameters, used to save model to disc + + Returns + ------- + dict + """ + return {'meta': self.meta} + + @property + def version_record(self): + """A record of important versions that this model was built with. + + Returns + ------- + dict + """ + return VERSION_RECORD + + def set_model_params(self, **kwargs): + """Set parameters used for training the model + + Parameters + ---------- + kwargs : dict + Keyword arguments including 'input_resolution', + 'lr_features', 'hr_exo_features', 'hr_out_features', + 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' + """ + + keys = ( + 'input_resolution', + 'lr_features', + 'hr_exo_features', + 'hr_out_features', + 'smoothed_features', + 's_enhance', + 't_enhance', + 'smoothing', + ) + keys = [k for k in keys if k in kwargs] + + hr_exo_feat = kwargs.get('hr_exo_features', []) + msg = ( + f'Expected high-res exo features {self.hr_exo_features} ' + f'based on model architecture but received "hr_exo_features" ' + f'from data handler: {hr_exo_feat}' + ) + assert list(self.hr_exo_features) == list(hr_exo_feat), msg + + for var in keys: + val = self.meta.get(var, None) + if val is None: + self.meta[var] = kwargs[var] + elif val != kwargs[var]: + msg = ( + 'Model was previously trained with {var}={} but ' + 'received new {var}={}'.format(val, kwargs[var], var=var) + ) + logger.warning(msg) + warn(msg) + + self._ensure_valid_enhancement_factors() + self._ensure_valid_input_resolution() + + def save_params(self, out_dir): + """ + Parameters + ---------- + out_dir : str + Directory to save linear model params. This directory will be + created if it does not already exist. 
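+
+ As a hedged sketch (the output directory here is hypothetical), this
+ writes ``model_params`` to ``<out_dir>/model_params.json``:
+
+ >>> model.save_params('./model_dir')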
+ """ + if not os.path.exists(out_dir): + os.makedirs(out_dir, exist_ok=True) + + fp_params = os.path.join(out_dir, 'model_params.json') + with open( + fp_params, 'w', encoding=locale.getpreferredencoding(False) + ) as f: + params = self.model_params + json.dump(params, f, sort_keys=True, indent=2, default=safe_cast) diff --git a/sup3r/models/linear.py b/sup3r/models/linear.py index 2c00a02b5..8ca9f0a0a 100644 --- a/sup3r/models/linear.py +++ b/sup3r/models/linear.py @@ -6,7 +6,7 @@ import numpy as np -from .abstract import AbstractInterface +from .interface import AbstractInterface from .utilities import st_interp logger = logging.getLogger(__name__) diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index d447633cf..e0d33cca0 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -14,8 +14,8 @@ import sup3r.models from sup3r.preprocessing.data_handlers import ExoData -from .abstract import AbstractInterface from .base import Sup3rGan +from .interface import AbstractInterface logger = logging.getLogger(__name__) diff --git a/sup3r/models/utilities.py b/sup3r/models/utilities.py index 8e825f124..dee3c51bc 100644 --- a/sup3r/models/utilities.py +++ b/sup3r/models/utilities.py @@ -1,13 +1,17 @@ """Utilities shared across the `sup3r.models` module""" import logging +import os import sys import threading import numpy as np +import tensorflow as tf from scipy.interpolate import RegularGridInterpolator from tensorflow.keras import optimizers +from sup3r.utilities.utilities import Timer + logger = logging.getLogger(__name__) @@ -58,6 +62,80 @@ def run(self): model_thread.join() +class TensorboardMixIn: + """MixIn class for tensorboard logging and profiling.""" + + def __init__(self): + self._tb_writer = None + self._tb_log_dir = None + self._write_tb_profile = False + self._total_batches = None + self._history = None + self.timer = Timer() + + @property + def total_batches(self): + """Record of total number of batches for logging.""" + if self._total_batches is None: + if self._history is not None and 'total_batches' in self._history: + self._total_batches = self._history['total_batches'].values[-1] + else: + self._total_batches = 0 + return self._total_batches + + @total_batches.setter + def total_batches(self, value): + """Set total number of batches.""" + self._total_batches = value + + def dict_to_tensorboard(self, entry): + """Write data to tensorboard log file. This is usually a loss_details + dictionary. + + Parameters + ---------- + entry: dict + Dictionary of values to write to tensorboard log file + """ + if self._tb_writer is not None: + with self._tb_writer.as_default(): + for name, value in entry.items(): + if isinstance(value, str): + tf.summary.text(name, value, self.total_batches) + else: + tf.summary.scalar(name, value, self.total_batches) + + def profile_to_tensorboard(self, name): + """Write profile data to tensorboard log file. + + Parameters + ---------- + name : str + Tag name to use for profile info + """ + if self._tb_writer is not None and self._write_tb_profile: + with self._tb_writer.as_default(): + tf.summary.trace_export( + name=name, + step=self.total_batches, + profiler_outdir=self._tb_log_dir, + ) + + def _init_tensorboard_writer(self, out_dir): + """Initialize the ``tf.summary.SummaryWriter`` to use for writing + tensorboard compatible log files. + + Parameters + ---------- + out_dir : str + Standard out_dir where model epochs are saved. e.g. 
'./gan_{epoch}' + """ + tb_log_pardir = os.path.abspath(os.path.join(out_dir, os.pardir)) + self._tb_log_dir = os.path.join(tb_log_pardir, 'logs') + os.makedirs(self._tb_log_dir, exist_ok=True) + self._tb_writer = tf.summary.create_file_writer(self._tb_log_dir) + + def get_optimizer_class(conf): """Get optimizer class from keras""" if hasattr(optimizers, conf['name']): diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py new file mode 100644 index 000000000..1e25014ec --- /dev/null +++ b/sup3r/models/with_obs.py @@ -0,0 +1,362 @@ +"""Sup3r model with training on observation data.""" + +import logging +import time +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import tensorflow as tf +from tensorflow.keras.losses import MeanAbsoluteError + +from .base import Sup3rGan + +logger = logging.getLogger(__name__) + + +class Sup3rGanWithObs(Sup3rGan): + """Sup3r GAN model with additional observation data content loss.""" + + def _calc_val_loss(self, batch, weight_gen_advers, loss_details): + """Calculate the validation loss at the current state of model training + for a given batch + + Parameters + ---------- + batch : DsetTuple + Object with ``.high_res``, ``.low_res``, and ``.obs`` arrays + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + loss_details : dict + Namespace of the breakdown of loss components + + Returns + ------- + loss_details : dict + Same as input with updated val_* loss info + """ + val_exo_data = self.get_high_res_exo_input(batch.high_res) + high_res_gen = self._tf_generate(batch.low_res, val_exo_data) + _, v_loss_details = self.calc_loss( + batch.high_res, + high_res_gen, + weight_gen_advers=weight_gen_advers, + train_gen=False, + train_disc=False, + ) + v_loss_details['loss_obs'] = self.cal_loss_obs(batch.obs, high_res_gen) + + loss_details = self.update_loss_details( + loss_details, v_loss_details, len(batch), prefix='val_' + ) + return loss_details + + def _get_batch_loss_details( + self, + batch, + train_gen, + only_gen, + gen_too_good, + train_disc, + only_disc, + disc_too_good, + weight_gen_advers, + multi_gpu=False, + ): + """Get loss details for a given batch for the current epoch. + + Parameters + ---------- + batch : sup3r.preprocessing.base.DsetTuple + Object with ``.low_res``, ``.high_res``, and ``.obs`` arrays + train_gen : bool + Flag whether to train the generator for this set of epochs + only_gen : bool + Flag whether to only train the generator for this set of epochs + gen_too_good : bool + Flag whether to skip training the generator and only train the + discriminator, due to superior performance, for this batch. + train_disc : bool + Flag whether to train the discriminator for this set of epochs + only_disc : bool + Flag whether to only train the discriminator for this set of epochs + gen_too_good : bool + Flag whether to skip training the discriminator and only train the + generator, due to superior performance, for this batch. + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with. 
+ If true and multiple gpus are found, ``default_device`` device + should be set to /gpu:0 + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components for the given batch + """ + trained_gen = False + trained_disc = False + if only_gen or (train_gen and not gen_too_good): + trained_gen = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.generator_weights, + obs_data=getattr(batch, 'obs', None), + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer, + train_gen=True, + train_disc=False, + multi_gpu=multi_gpu, + ) + + if only_disc or (train_disc and not disc_too_good): + trained_disc = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.discriminator_weights, + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer_disc, + train_gen=False, + train_disc=True, + multi_gpu=multi_gpu, + ) + + b_loss_details['gen_trained_frac'] = float(trained_gen) + b_loss_details['disc_trained_frac'] = float(trained_disc) + + if 'loss_obs' in b_loss_details and not tf.math.is_nan( + b_loss_details['loss_obs'] + ): + loss_update = b_loss_details['loss_gen'] + loss_update += b_loss_details['loss_obs'] + b_loss_details.update({'loss_gen': loss_update}) + return b_loss_details + + def _get_parallel_grad( + self, + low_res, + hi_res_true, + training_weights, + obs_data=None, + **calc_loss_kwargs, + ): + """Compute gradient for one mini-batch of (low_res, hi_res_true, + obs_data) across multiple GPUs. Can include observation data as well. + """ + + futures = [] + start_time = time.time() + lr_chunks = np.array_split(low_res, len(self.gpu_list)) + hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) + obs_data_chunks = ( + [None] * len(hr_true_chunks) + if obs_data is None + else np.array_split(obs_data, len(self.gpu_list)) + ) + split_mask = False + mask_chunks = None + if 'mask' in calc_loss_kwargs: + split_mask = True + mask_chunks = np.array_split( + calc_loss_kwargs['mask'], len(self.gpu_list) + ) + + with ThreadPoolExecutor(max_workers=len(self.gpu_list)) as exe: + for i in range(len(self.gpu_list)): + if split_mask: + calc_loss_kwargs['mask'] = mask_chunks[i] + futures.append( + exe.submit( + self.get_single_grad, + lr_chunks[i], + hr_true_chunks[i], + training_weights, + obs_data=obs_data_chunks[i], + device_name=f'/gpu:{i}', + **calc_loss_kwargs, + ) + ) + + return self._sum_parallel_grad(futures, start_time=start_time) + + def run_gradient_descent( + self, + low_res, + hi_res_true, + training_weights, + obs_data=None, + optimizer=None, + multi_gpu=False, + **calc_loss_kwargs, + ): + """Run gradient descent for one mini-batch of (low_res, hi_res_true) + and update weights + + Parameters + ---------- + low_res : np.ndarray + Real low-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + hi_res_true : np.ndarray + Real high-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + training_weights : list + A list of layer weights that are to-be-trained based on the + current loss weight values. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. + This needs to have NaNs where there is no observation data. 
+ (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + optimizer : tf.keras.optimizers.Optimizer + Optimizer class to use to update weights. This can be different if + you're training just the generator or one of the discriminator + models. Defaults to the generator optimizer. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with. + calc_loss_kwargs : dict + Kwargs to pass to the self.calc_loss() method + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components + """ + + self.timer.start() + if optimizer is None: + optimizer = self.optimizer + + if not multi_gpu or len(self.gpu_list) < 2: + grad, loss_details = self.get_single_grad( + low_res, + hi_res_true, + training_weights, + obs_data=obs_data, + device_name=self.default_device, + **calc_loss_kwargs, + ) + optimizer.apply_gradients(zip(grad, training_weights)) + self.timer.stop() + logger.debug( + 'Finished single gradient descent step in %s', + self.timer.elapsed_str, + ) + else: + total_grad, loss_details = self._get_parallel_grad( + low_res, + hi_res_true, + training_weights, + obs_data, + **calc_loss_kwargs, + ) + optimizer.apply_gradients(zip(total_grad, training_weights)) + + return loss_details + + @tf.function + def get_single_grad( + self, + low_res, + hi_res_true, + training_weights, + obs_data=None, + device_name=None, + **calc_loss_kwargs, + ): + """Run gradient descent for one mini-batch of (low_res, hi_res_true), + do not update weights, just return gradient details. + + Parameters + ---------- + low_res : np.ndarray + Real low-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + hi_res_true : np.ndarray + Real high-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + training_weights : list + A list of layer weights that are to-be-trained based on the + current loss weight values. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. + This needs to have NaNs where there is no observation data. + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + device_name : None | str + Optional tensorflow device name for GPU placement. Note that if a + GPU is available, variables will be placed on that GPU even if + device_name=None. 
+ calc_loss_kwargs : dict + Kwargs to pass to the self.calc_loss() method + + Returns + ------- + grad : list + a list or nested structure of Tensors (or IndexedSlices, or None, + or CompositeTensor) representing the gradients for the + training_weights + loss_details : dict + Namespace of the breakdown of loss components + """ + with tf.device(device_name), tf.GradientTape( + watch_accessed_variables=False + ) as tape: + tape.watch(training_weights) + loss, loss_details, hi_res_gen = self._get_hr_exo_and_loss( + low_res, hi_res_true, **calc_loss_kwargs + ) + loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) + if not tf.reduce_any(tf.math.is_nan(loss_obs)): + loss += loss_obs + loss_details.update({'loss_obs': loss_obs}) + grad = tape.gradient(loss, training_weights) + return grad, loss_details + + @tf.function + def calc_loss_obs(self, obs_data, hi_res_gen): + """Calculate loss term for the observation data vs generated + high-resolution data + + Parameters + ---------- + obs_data : tf.Tensor | None + Observation data to use in additional content loss term. + This needs to have NaNs where there is no observation data. + hi_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. + + Returns + ------- + loss : tf.Tensor + 0D tensor of observation loss + """ + loss_obs = tf.constant(np.nan) + if obs_data is not None: + mask = tf.math.is_nan(obs_data) + masked_obs = obs_data[~mask] + if len(masked_obs) > 0: + loss_obs = MeanAbsoluteError()( + masked_obs, + hi_res_gen[..., : len(self.hr_out_features)][~mask], + ) + return loss_obs diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py index 332d62a31..424e0d0ef 100644 --- a/sup3r/pipeline/slicer.py +++ b/sup3r/pipeline/slicer.py @@ -3,7 +3,7 @@ import itertools as it import logging from dataclasses import dataclass -from typing import Union +from typing import Optional, Union from warnings import warn import numpy as np @@ -27,9 +27,23 @@ class ForwardPassSlicer: time_steps : int Number of time steps for full temporal domain of low res data. This is used to construct a dummy_time_index from np.arange(time_steps) + s_enhance : int + Spatial enhancement factor + t_enhance : int + Temporal enhancement factor time_slice : slice | list Slice to use to extract range from time_index. Can be a ``slice(start, stop, step)`` or list ``[start, stop, step]`` + temporal_pad : int + Size of temporal overlap between coarse chunks passed to forward + passes for subsequent temporal stitching. This overlap will pad + both sides of the fwp_chunk_shape. Note that the first and last + chunks in the temporal dimension will not be padded. + spatial_pad : int + Size of spatial overlap between coarse chunks passed to forward + passes for subsequent spatial stitching. This overlap will pad both + sides of the fwp_chunk_shape. Note that the first and last chunks + in any of the spatial dimension will not be padded. chunk_shape : tuple Max shape (spatial_1, spatial_2, temporal) of an unpadded coarse chunk to use for a forward pass. The number of nodes that the @@ -41,20 +55,11 @@ class ForwardPassSlicer: to the generator can be bigger than this shape. If running in serial set this equal to the shape of the full spatiotemporal data volume for best performance. - s_enhance : int - Spatial enhancement factor - t_enhance : int - Temporal enhancement factor - spatial_pad : int - Size of spatial overlap between coarse chunks passed to forward - passes for subsequent spatial stitching. 
This overlap will pad both - sides of the fwp_chunk_shape. Note that the first and last chunks - in any of the spatial dimension will not be padded. - temporal_pad : int - Size of temporal overlap between coarse chunks passed to forward - passes for subsequent temporal stitching. This overlap will pad - both sides of the fwp_chunk_shape. Note that the first and last - chunks in the temporal dimension will not be padded. + min_width : tuple + Minimum width of padded slices, with each element providing the min + width for the corresponding dimension. e.g. (spatial_1, spatial_2, + temporal). This is used to make sure generator network input meets the + minimum size requirement for padding layers. """ coarse_shape: Union[tuple, list] @@ -65,6 +70,7 @@ class ForwardPassSlicer: temporal_pad: int spatial_pad: int chunk_shape: Union[tuple, list] + min_width: Optional[Union[tuple, list]] = None @log_args def __post_init__(self): @@ -88,6 +94,9 @@ def __post_init__(self): self._s_hr_crop_slices = None self._t_hr_crop_slices = None self._hr_crop_slices = None + self.min_width = ( + self.chunk_shape if self.min_width is None else self.min_width + ) def get_spatial_slices(self): """Get spatial slices for small data chunks that are passed through @@ -254,6 +263,8 @@ def s1_hr_crop_slices(self): self._s1_hr_crop_slices = self.check_boundary_slice( unpadded_slices=self.s1_lr_slices, cropped_slices=self._s1_hr_crop_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, dim=0, ) return self._s1_hr_crop_slices @@ -273,6 +284,8 @@ def s2_hr_crop_slices(self): self._s2_hr_crop_slices = self.check_boundary_slice( unpadded_slices=self.s2_lr_slices, cropped_slices=self._s2_hr_crop_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, dim=1, ) return self._s2_hr_crop_slices @@ -313,6 +326,8 @@ def s_lr_crop_slices(self): s1_crop_slices = self.check_boundary_slice( unpadded_slices=self.s1_lr_slices, cropped_slices=s1_crop_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, dim=0, ) s2_crop_slices = self.get_cropped_slices( @@ -321,6 +336,8 @@ def s_lr_crop_slices(self): s2_crop_slices = self.check_boundary_slice( unpadded_slices=self.s2_lr_slices, cropped_slices=s2_crop_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, dim=1, ) self._s_lr_crop_slices = list( @@ -525,21 +542,22 @@ def get_padded_slices(slices, shape, enhancement, padding, step=None): pad_slices.append(slice(start, end, step)) return pad_slices - def check_boundary_slice(self, unpadded_slices, cropped_slices, dim): + def check_boundary_slice( + self, unpadded_slices, cropped_slices, enhancement, padding, dim + ): """Check cropped slice at the right boundary for minimum shape. It is possible for the forward pass chunk shape to divide the grid size such that the last slice (right boundary) does not meet the minimum - number of elements. (Padding layers in the generator typically require - a minimum shape of 4). e.g. ``grid_size = (8, 8)`` with - ``fwp_chunk_shape = (7, 7, ...)`` results in unpadded slices with just - one element. If the padding is 0 or 1 these padded slices have length - less than 4. When this minimum shape is not met we apply extra padding - in :meth:`self._get_pad_width`. Cropped slices have to be adjusted to + number of elements. (Padding layers in the generator require a minimum + shape). e.g. ``grid_size = (8, 8)`` with ``fwp_chunk_shape = (7, 7, + ...)`` results in unpadded slices with just one element. 
When this + minimum shape is not met we apply extra padding in + :meth:`self._get_pad_width`. Cropped slices have to be adjusted to account for this here.""" warn_msg = ( - 'The final spatial slice for dimension #%s is too small ' + 'The final slice for dimension #%s is too small ' '(slice=slice(%s, %s), padding=%s). The start of this slice will ' 'be reduced to try to meet the minimum slice length.' ) @@ -548,19 +566,22 @@ def check_boundary_slice(self, unpadded_slices, cropped_slices, dim): lr_slice_stop = unpadded_slices[-1].stop or self.coarse_shape[dim] # last slice adjustment - if 2 * self.spatial_pad + (lr_slice_stop - lr_slice_start) < 4: + if ( + 2 * padding + (lr_slice_stop - lr_slice_start) + <= self.min_width[dim] + ): + half_width = self.min_width[dim] // 2 + 1 logger.warning( warn_msg, dim + 1, lr_slice_start, lr_slice_stop, - self.spatial_pad, + padding, ) - warn( - warn_msg - % (dim + 1, lr_slice_start, lr_slice_stop, self.spatial_pad) + warn(warn_msg % (dim + 1, lr_slice_start, lr_slice_stop, padding)) + cropped_slices[-1] = slice( + half_width * enhancement, -half_width * enhancement ) - cropped_slices[-1] = slice(2 * self.s_enhance, -2 * self.s_enhance) return cropped_slices @@ -600,7 +621,9 @@ def get_cropped_slices(unpadded_slices, padded_slices, enhancement): return cropped_slices @staticmethod - def _get_pad_width(window, max_steps, max_pad, check_boundary=False): + def _get_pad_width( + window, max_steps, max_pad, min_width=None, check_boundary=False + ): """ Parameters ---------- @@ -610,6 +633,10 @@ def _get_pad_width(window, max_steps, max_pad, check_boundary=False): Maximum number of steps available. Padding cannot extend past this max_pad : int Maximum amount of padding to apply. + min_width : int | None + Minimum width to enforce. This could be the forward pass chunk + shape or the padding value in the first padding layer of the + generator network. This is only used if ``check_boundary = True`` check_bounary : bool Whether to check the final slice for minimum size requirement @@ -625,14 +652,16 @@ def _get_pad_width(window, max_steps, max_pad, check_boundary=False): # We add minimum padding to the last slice if the padded window is # too small for the generator. 
This can happen if 2 * spatial_pad + - # modulo(grid_size, fwp_chunk_shape) < 4 + # modulo(grid_size, fwp_chunk_shape) is less than the padding applied + # in the first padding layer of the generator if ( check_boundary and win_stop == max_steps - and (2 * max_pad + win_stop - win_start) < 4 + and (2 * max_pad + win_stop - win_start) <= min_width ): - stop = np.max([2, max_pad]) - start = np.max([2, max_pad]) + half_width = min_width // 2 + 1 + stop = np.max([half_width, max_pad]) + start = np.max([half_width, max_pad]) return (start, stop) @@ -662,16 +691,20 @@ def get_pad_width(self, chunk_index): lr_slice[0], self.coarse_shape[0], self.spatial_pad, + self.min_width[0], check_boundary=True, ), self._get_pad_width( lr_slice[1], self.coarse_shape[1], self.spatial_pad, + self.min_width[1], check_boundary=True, ), self._get_pad_width( - ti_slice, len(self.dummy_time_index), self.temporal_pad + ti_slice, + len(self.dummy_time_index), + self.temporal_pad ), ) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 9a488e6a1..db05a72aa 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -221,7 +221,6 @@ def __post_init__(self): self.input_handler_kwargs.get('time_slice', slice(None)) ) self.fwp_chunk_shape = self._get_fwp_chunk_shape() - self.fwp_slicer = ForwardPassSlicer( coarse_shape=self.input_handler.grid_shape, time_steps=len(self.input_handler.time_index), @@ -231,6 +230,7 @@ def __post_init__(self): t_enhance=self.t_enhance, spatial_pad=self.spatial_pad, temporal_pad=self.temporal_pad, + min_width=model.max_paddings, ) self.n_chunks = self.fwp_slicer.n_chunks diff --git a/sup3r/preprocessing/__init__.py b/sup3r/preprocessing/__init__.py index 770f09159..315e9b79a 100644 --- a/sup3r/preprocessing/__init__.py +++ b/sup3r/preprocessing/__init__.py @@ -58,4 +58,9 @@ Rasterizer, SzaRasterizer, ) -from .samplers import DualSampler, DualSamplerCC, Sampler, SamplerDC +from .samplers import ( + DualSampler, + DualSamplerCC, + Sampler, + SamplerDC, +) diff --git a/sup3r/preprocessing/accessor.py b/sup3r/preprocessing/accessor.py index 8eaa28d90..aad25b884 100644 --- a/sup3r/preprocessing/accessor.py +++ b/sup3r/preprocessing/accessor.py @@ -488,8 +488,8 @@ def size(self): def time_index(self): """Base time index for contained data.""" return ( - pd.to_datetime(self._ds.indexes['time']) - if 'time' in self._ds.indexes + pd.to_datetime(self._ds.indexes[Dimension.TIME]) + if Dimension.TIME in self._ds.indexes else None ) @@ -562,40 +562,40 @@ def flatten(self): """Flatten rasterized dataset so that there is only a single spatial dimension.""" if not self.flattened: - self._ds = self._ds.stack( - {Dimension.FLATTENED_SPATIAL: Dimension.dims_2d()} - ) - self._ds = self._ds.assign( - { - Dimension.FLATTENED_SPATIAL: np.arange( - len(self._ds[Dimension.FLATTENED_SPATIAL]) - ) - } - ) + dims = {Dimension.FLATTENED_SPATIAL: Dimension.dims_2d()} + self._ds = self._ds.stack(dims) + index = np.arange(len(self._ds[Dimension.FLATTENED_SPATIAL])) + self._ds = self._ds.assign({Dimension.FLATTENED_SPATIAL: index}) else: msg = 'Dataset is already flattened' logger.warning(msg) warn(msg) return self - def _qa(self, feature): + def _qa(self, feature, stats=None): """Get qa info for given feature.""" info = {} + stats = stats or ['nan_perc', 'std', 'mean', 'min', 'max'] logger.info('Running qa on feature: %s', feature) nan_count = 100 * np.isnan(self[feature].data).sum() nan_perc = nan_count / self[feature].size - info['nan_perc'] = compute_if_dask(nan_perc) - info['std'] = 
compute_if_dask(self[feature].std().data) - info['mean'] = compute_if_dask(self[feature].mean().data) - info['min'] = compute_if_dask(self[feature].min().data) - info['max'] = compute_if_dask(self[feature].max().data) + + for stat in stats: + logger.info('Running QA method %s on feature: %s', stat, feature) + if stat == 'nan_perc': + info['nan_perc'] = compute_if_dask(nan_perc) + else: + msg = f'Unknown QA method requested: {stat}' + assert hasattr(self[feature], stat), msg + qa_data = getattr(self[feature], stat)().data + info[stat] = compute_if_dask(qa_data) return info - def qa(self): - """Check NaNs and stats for all features.""" + def qa(self, stats=None): + """Check NaNs and the given stats for all features.""" qa_info = {} for f in self.features: - qa_info[f] = self._qa(f) + qa_info[f] = self._qa(f, stats=stats) return qa_info def __mul__(self, other): @@ -604,9 +604,8 @@ def __mul__(self, other): try: return type(self)(other * self._ds) except Exception as e: - raise NotImplementedError( - f'Multiplication not supported for type {type(other)}.' - ) from e + msg = f'Multiplication not supported for type {type(other)}.' + raise NotImplementedError(msg) from e def __rmul__(self, other): return self.__mul__(other) @@ -617,6 +616,5 @@ def __pow__(self, other): try: return type(self)(self._ds**other) except Exception as e: - raise NotImplementedError( - f'Exponentiation not supported for type {type(other)}.' - ) from e + msg = f'Exponentiation not supported for type {type(other)}.' + raise NotImplementedError(msg) from e diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index 5ddd9dea0..7d81f9d27 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -11,8 +11,7 @@ import logging import pprint from abc import ABCMeta -from collections import namedtuple -from typing import Mapping, Tuple, Union +from typing import Dict, Mapping, Tuple, Union from warnings import warn import numpy as np @@ -70,6 +69,34 @@ def __repr__(cls): return f"" +class DsetTuple: + """A simple class to mimic namedtuple behavior with dynamic attributes + while being serializable""" + + def __init__(self, **kwargs): + self.dset_names = list(kwargs) + self.__dict__.update(kwargs) + + @property + def dsets(self): + """Dictionary with only dset names and associated values.""" + return {k: v for k, v in self.__dict__.items() if k in self.dset_names} + + def __iter__(self): + return iter(self.dsets.values()) + + def __getitem__(self, key): + if isinstance(key, int): + key = list(self.dsets)[key] + return self.dsets[key] + + def __len__(self): + return len(self.dsets) + + def __repr__(self): + return f'DsetTuple({self.dsets})' + + class Sup3rDataset: """Interface for interacting with one or two ``xr.Dataset`` instances. This is a wrapper around one or two ``Sup3rX`` objects so they work well @@ -116,6 +143,8 @@ class Sup3rDataset: from the high_res non coarsened variable. 
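+
+ As a hedged construction sketch (``lr``, ``hr``, and ``obs`` are
+ hypothetical ``xr.Dataset`` or ``Sup3rX`` members), datasets are built
+ with explicit member names:
+
+ >>> dual = Sup3rDataset(low_res=lr, high_res=hr)
+ >>> with_obs = Sup3rDataset(low_res=lr, high_res=hr, obs=obs)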
""" + DSET_NAMES = ('low_res', 'high_res', 'obs') + def __init__( self, **dsets: Mapping[str, Union[xr.Dataset, Sup3rX]], @@ -147,7 +176,7 @@ def __init__( assert len(dset) == 1, msg dsets[name] = dset._ds[0] - self._ds = namedtuple('Dataset', list(dsets))(**dsets) + self._ds = DsetTuple(**dsets) def __iter__(self): yield from self._ds @@ -183,17 +212,16 @@ def rewrap(self, data): """Rewrap data as ``Sup3rDataset`` after calling parent method.""" if isinstance(data, type(self)): return data - return ( - type(self)(low_res=data[0], high_res=data[1]) - if len(data) > 1 - else type(self)(high_res=data[0]) - ) + if len(data) == 1: + return type(self)(high_res=data[0]) + return type(self)(**dict(zip(self.DSET_NAMES, data))) def sample(self, idx): """Get samples from ``self._ds`` members. idx should be either a tuple of slices for the dimensions (south_north, west_east, time) and a list - of feature names or a 2-tuple of the same, for dual datasets.""" - if len(self._ds) == 2: + of feature names or a tuple of the same, for multi-member datasets + (dual datasets and dual with observations datasets).""" + if len(self._ds) > 1: return tuple(d.sample(idx[i]) for i, d in enumerate(self)) return self._ds[-1].sample(idx) @@ -215,7 +243,7 @@ def __getitem__(self, keys): if len(self._ds) == 1: return out[-1] if all(isinstance(o, Sup3rX) for o in out): - return type(self)(**dict(zip(self._ds._fields, out))) + return type(self)(**dict(zip(self._ds.dset_names, out))) return out @property @@ -228,10 +256,12 @@ def shape(self): def features(self): """The features are determined by the set of features from all data members.""" + if len(self._ds) == 1: + return self._ds[0].features feats = [ - f for f in self._ds[0].features if f not in self._ds[-1].features + f for f in self._ds[0].features if f not in self._ds[1].features ] - feats += self._ds[-1].features + feats += self._ds[1].features return feats @property @@ -258,13 +288,13 @@ def mean(self, **kwargs): """Use the high_res members to compute the means. These are used for normalization during training.""" kwargs['skipna'] = kwargs.get('skipna', True) - return self._ds[-1].mean(**kwargs) + return self._ds[1 if len(self._ds) > 1 else 0].mean(**kwargs) def std(self, **kwargs): """Use the high_res members to compute the stds. These are used for normalization during training.""" kwargs['skipna'] = kwargs.get('skipna', True) - return self._ds[-1].std(**kwargs) + return self._ds[1 if len(self._ds) > 1 else 0].std(**kwargs) def normalize(self, means, stds): """Normalize dataset using the given mean and stds. These are provided @@ -292,7 +322,12 @@ class Container(metaclass=Sup3rMeta): def __init__( self, data: Union[ - Sup3rX, Sup3rDataset, Tuple[Sup3rX, ...], Tuple[Sup3rDataset, ...] + Sup3rX, + Sup3rDataset, + Tuple[Sup3rX, ...], + Tuple[Sup3rDataset, ...], + Dict[str, Sup3rX], + Dict[str, Sup3rDataset], ] = None, ): """ @@ -309,10 +344,12 @@ def __init__( such. This is a tuple when the `.data` attribute belongs to a :class:`~.collections.base.Collection` object like :class:`~.batch_handlers.factory.BatchHandler`. Otherwise this is - :class:`~.Sup3rDataset` object, which is either a wrapped 2-tuple - or 1-tuple (e.g. ``len(data) == 2`` or ``len(data) == 1)``. This is - a 2-tuple when ``.data`` belongs to a dual container object like - :class:`~.samplers.DualSampler` and a 1-tuple otherwise. + :class:`~.Sup3rDataset` object, which is either a wrapped 3-tuple, + 2-tuple, or 1-tuple (e.g. ``len(data) == 3``, ``len(data) == 2`` or + ``len(data) == 1)``. 
This is a 3-tuple when ``.data`` belongs to a + container object like :class:`~.samplers.DualSamplerWithObs`, a + 2-tuple when ``.data`` belongs to a dual container object like + :class:`~.samplers.DualSampler`, and a 1-tuple otherwise. """ self.data = data @@ -342,29 +379,37 @@ def data(self, data): def wrap(self, data): """ Return a :class:`~.Sup3rDataset` object or tuple of such. This is a - tuple when the `.data` attribute belongs to a + tuple when the ``.data`` attribute belongs to a :class:`~.collections.base.Collection` object like :class:`~.batch_handlers.factory.BatchHandler`. Otherwise this is - :class:`~.Sup3rDataset` object, which is either a wrapped 2-tuple or - 1-tuple (e.g. ``len(data) == 2`` or ``len(data) == 1)``. This is a - 2-tuple when ``.data`` belongs to a dual container object like - :class:`~.samplers.DualSampler` and a 1-tuple otherwise. + :class:`~.Sup3rDataset` object, which is either a wrapped 3-tuple, + 2-tuple, or 1-tuple (e.g. ``len(data) == 3``, ``len(data) == 2`` or + ``len(data) == 1)``. This is a 3-tuple when ``.data`` belongs to a + container object like :class:`~.samplers.DualSamplerWithObs`, a 2-tuple + when ``.data`` belongs to a dual container object like + :class:`~.samplers.DualSampler`, and a 1-tuple otherwise. """ if data is None: return data + if hasattr(data, 'data'): + data = data.data + if is_type_of(data, Sup3rDataset): return data - if isinstance(data, tuple) and len(data) == 2: + if isinstance(data, dict): + data = Sup3rDataset(**data) + + if isinstance(data, tuple) and len(data) > 1: msg = ( f'{self.__class__.__name__}.data is being set with a ' - '2-tuple without explicit dataset names. We will assume ' - 'first tuple member is low-res and second is high-res.' + f'{len(data)}-tuple without explicit dataset names. 
We will ' + f'assume name ordering: {Sup3rDataset.DSET_NAMES[:len(data)]}' ) logger.warning(msg) warn(msg) - data = Sup3rDataset(low_res=data[0], high_res=data[1]) + data = Sup3rDataset(**dict(zip(Sup3rDataset.DSET_NAMES, data))) elif not isinstance(data, Sup3rDataset): name = getattr(data, 'name', None) or 'high_res' data = Sup3rDataset(**{name: data}) @@ -383,6 +428,9 @@ def shape(self): """Get shape of underlying data.""" return self.data.shape + def __len__(self): + return len(self.data) + def __contains__(self, vals): return vals in self.data @@ -403,11 +451,8 @@ def __setitem__(self, keys, data): def __getattr__(self, attr): """Check if attribute is available from ``.data``""" - if attr in dir(self): - return self.__getattribute__(attr) try: - data = self.__getattribute__('_data') - return getattr(data, attr) + return getattr(self._data, attr) except Exception as e: msg = f'{self.__class__.__name__} object has no attribute "{attr}"' raise AttributeError(msg) from e diff --git a/sup3r/preprocessing/batch_handlers/factory.py b/sup3r/preprocessing/batch_handlers/factory.py index 7e63c34b6..4037f2224 100644 --- a/sup3r/preprocessing/batch_handlers/factory.py +++ b/sup3r/preprocessing/batch_handlers/factory.py @@ -265,26 +265,26 @@ def init_samplers( """Initialize samplers from given data containers.""" train_samplers = [ self.SAMPLER( - data=c.data, + data=container, sample_shape=sample_shape, feature_sets=feature_sets, batch_size=batch_size, **sampler_kwargs, ) - for c in train_containers + for container in train_containers ] val_samplers = ( [] if val_containers is None else [ self.SAMPLER( - data=c.data, + data=container, sample_shape=sample_shape, feature_sets=feature_sets, batch_size=batch_size, **sampler_kwargs, ) - for c in val_containers + for container in val_containers ] ) return train_samplers, val_samplers @@ -315,6 +315,7 @@ def stop(self): DualBatchHandler = BatchHandlerFactory( DualBatchQueue, DualSampler, name='DualBatchHandler' ) + BatchHandlerCC = BatchHandlerFactory( DualBatchQueue, DualSamplerCC, name='BatchHandlerCC' ) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 772a172c6..c0d418ddc 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -4,20 +4,20 @@ (1) Figure out apparent "blocking" issue with threaded enqueue batches. max_workers=1 is the fastest? (2) Setup distributed data handling so this can work with data distributed - over multiple nodes. + over multiple nodes. 
""" import logging import threading import time from abc import ABC, abstractmethod -from collections import namedtuple -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import TYPE_CHECKING, List, Optional, Union import numpy as np import tensorflow as tf +from sup3r.preprocessing.base import DsetTuple from sup3r.preprocessing.collections.base import Collection from sup3r.utilities.utilities import RANDOM_GENERATOR, Timer @@ -32,7 +32,7 @@ class AbstractBatchQueue(Collection, ABC): generator and maintains a queue of batches in a dedicated thread so the training routine can proceed as soon as batches are available.""" - Batch = namedtuple('Batch', ['low_res', 'high_res']) + BATCH_MEMBERS = ('low_res', 'high_res') def __init__( self, @@ -46,6 +46,7 @@ def __init__( max_workers: int = 1, thread_name: str = 'training', mode: str = 'lazy', + verbose: bool = False, ): """ Parameters @@ -77,6 +78,8 @@ def __init__( Loading mode. Default is 'lazy', which only loads data into memory as batches are queued. 'eager' will load all data into memory right away. + verbose : bool + Whether to log timing information for batch steps. """ msg = ( f'{self.__class__.__name__} requires a list of samplers. ' @@ -88,6 +91,7 @@ def __init__( self._queue_thread = None self._training_flag = threading.Event() self._thread_name = thread_name + self._thread_pool = ThreadPoolExecutor(max_workers=max_workers) self.mode = mode self.s_enhance = s_enhance self.t_enhance = t_enhance @@ -101,6 +105,7 @@ def __init__( 'smoothing_ignore': [], 'smoothing': None, } + self.verbose = verbose self.timer = Timer() self.preflight() @@ -111,6 +116,17 @@ def queue_shape(self): this is (batch_size, *sample_shape, len(features)). For dual dataset queues this is [(batch_size, *lr_shape), (batch_size, *hr_shape)]""" + @property + def queue_len(self): + """Get number of batches in the queue.""" + return self.queue.size().numpy() + self.queue_futures + + @property + def queue_futures(self): + """Get number of scheduled futures that will eventually add batches to + the queue.""" + return self._thread_pool._work_queue.qsize() + def get_queue(self): """Return FIFO queue for storing batches.""" return tf.queue.FIFOQueue( @@ -174,7 +190,7 @@ def transform(self, samples, **kwargs): high res samples. For a dual dataset queue this will just include smoothing.""" - def post_proc(self, samples) -> Batch: + def post_proc(self, samples) -> DsetTuple: """Performs some post proc on dequeued samples before sending out for training. Post processing can include coarsening on high-res data (if :class:`Collection` consists of :class:`Sampler` objects and not @@ -182,11 +198,12 @@ def post_proc(self, samples) -> Batch: Returns ------- - Batch : namedtuple - namedtuple with `low_res` and `high_res` attributes + Batch : DsetTuple + namedtuple-like object with `low_res` and `high_res` attributes. + Could also include `obs` member. 
""" - lr, hr = self.transform(samples, **self.transform_kwargs) - return self.Batch(low_res=lr, high_res=hr) + tsamps = self.transform(samples, **self.transform_kwargs) + return DsetTuple(**dict(zip(self.BATCH_MEMBERS, tsamps))) def start(self) -> None: """Start thread to keep sample queue full for batches.""" @@ -214,14 +231,10 @@ def __iter__(self): self.start() return self - def get_batch(self) -> Batch: + def get_batch(self) -> DsetTuple: """Get batch from queue or directly from a ``Sampler`` through ``sample_batch``.""" - if ( - self.mode == 'eager' - or self.queue_cap == 0 - or self.queue.size().numpy() == 0 - ): + if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0: return self.sample_batch() return self.queue.dequeue() @@ -234,31 +247,47 @@ def running(self): and not self.queue.is_closed() ) + def sample_batches(self, n_batches) -> None: + """Sample given number of batches either in serial or with thread + pool.""" + if n_batches == 1 or self.max_workers == 1: + return [self.sample_batch() for _ in range(n_batches)] + tasks = [ + self._thread_pool.submit(self.sample_batch) + for _ in range(n_batches) + ] + logger.debug( + 'Added %s sample_batch futures to %s queue.', + n_batches, + self._thread_name, + ) + return tasks + def enqueue_batches(self) -> None: """Callback function for queue thread. While training, the queue is checked for empty spots and filled. In the training thread, batches are removed from the queue.""" log_time = time.time() while self.running: - needed = self.queue_cap - self.queue.size().numpy() - if needed == 1 or self.max_workers == 1: - self.enqueue_batch() - elif needed > 0: - with ThreadPoolExecutor(self.max_workers) as exe: - _ = [exe.submit(self.enqueue_batch) for _ in range(needed)] - logger.debug( - 'Added %s enqueue futures to %s queue.', - needed, - self._thread_name, - ) + needed = max(self.queue_cap - self.queue_len, 0) + needed = min(self.max_workers, needed) + if needed > 0: + batches = self.sample_batches(n_batches=needed) + if needed > 1 and self.max_workers > 1: + for batch in as_completed(batches): + self.queue.enqueue(batch.result()) + else: + for batch in batches: + self.queue.enqueue(batch) + if time.time() > log_time + 10: logger.debug(self.log_queue_info()) log_time = time.time() - def __next__(self) -> Batch: + def __next__(self) -> DsetTuple: """Dequeue batch samples, squeeze if for a spatial only model, perform some post-proc like smoothing, coarsening, etc, and then send out for - training as a namedtuple of low_res / high_res arrays. + training as a namedtuple-like object of low_res / high_res arrays. Returns ------- @@ -276,11 +305,12 @@ def __next__(self) -> Batch: batch = self.post_proc(samples) self.timer.stop() self._batch_count += 1 - logger.debug( - 'Batch step %s finished in %s.', - self._batch_count, - self.timer.elapsed_str, - ) + if self.verbose: + logger.debug( + 'Batch step %s finished in %s.', + self._batch_count, + self.timer.elapsed_str, + ) else: raise StopIteration return batch @@ -304,21 +334,17 @@ def sample_batch(self): These samples are wrapped in an ``np.asarray`` call, so they have been loaded into memory. 
""" - return next(self.get_random_container()) + out = next(self.get_random_container()) + if not isinstance(out, tuple): + return tf.convert_to_tensor(out, dtype=tf.float32) + return tuple(tf.convert_to_tensor(o, dtype=tf.float32) for o in out) def log_queue_info(self): """Log info about queue size.""" - return '{} queue length: {} / {}.'.format( - self._thread_name.title(), - self.queue.size().numpy(), - self.queue_cap, + return '{} queue length: {} / {}'.format( + self._thread_name.title(), self.queue_len, self.queue_cap ) - def enqueue_batch(self): - """Build batch and send to queue.""" - if self.running and self.queue.size().numpy() < self.queue_cap: - self.queue.enqueue(self.sample_batch()) - @property def lr_shape(self): """Shape of low resolution sample in a low-res / high-res pair. (e.g. diff --git a/sup3r/preprocessing/batch_queues/conditional.py b/sup3r/preprocessing/batch_queues/conditional.py index 43d479b5f..488e70b2c 100644 --- a/sup3r/preprocessing/batch_queues/conditional.py +++ b/sup3r/preprocessing/batch_queues/conditional.py @@ -2,12 +2,12 @@ import logging from abc import abstractmethod -from collections import namedtuple from typing import TYPE_CHECKING, Dict, List, Optional, Union import numpy as np from sup3r.models.conditional import Sup3rCondMom +from sup3r.preprocessing.base import DsetTuple from sup3r.preprocessing.utilities import numpy_if_tensor from .base import SingleBatchQueue @@ -22,10 +22,6 @@ class ConditionalBatchQueue(SingleBatchQueue): """BatchQueue class for conditional moment estimation.""" - ConditionalBatch = namedtuple( - 'ConditionalBatch', ['low_res', 'high_res', 'output', 'mask'] - ) - def __init__( self, samplers: Union[List['Sampler'], List['DualSampler']], @@ -160,14 +156,14 @@ def post_proc(self, samples): Returns ------- - namedtuple - Named tuple with `low_res`, `high_res`, `mask`, and `output` - attributes + DsetTuple + Namedtuple-like object with `low_res`, `high_res`, `mask`, and + `output` attributes """ lr, hr = self.transform(samples, **self.transform_kwargs) mask = self.make_mask(high_res=hr) output = self.make_output(samples=(lr, hr)) - return self.ConditionalBatch( + return DsetTuple( low_res=lr, high_res=hr, output=output, mask=mask ) diff --git a/sup3r/preprocessing/batch_queues/dual.py b/sup3r/preprocessing/batch_queues/dual.py index 691350cf7..56b2b08d4 100644 --- a/sup3r/preprocessing/batch_queues/dual.py +++ b/sup3r/preprocessing/batch_queues/dual.py @@ -20,6 +20,7 @@ def __init__(self, samplers, **kwargs): -------- :class:`~sup3r.preprocessing.batch_queues.abstract.AbstractBatchQueue` """ + self.BATCH_MEMBERS = samplers[0].dset_names super().__init__(samplers, **kwargs) self.check_enhancement_factors() @@ -27,11 +28,19 @@ def __init__(self, samplers, **kwargs): @property def queue_shape(self): - """Shape of objects stored in the queue.""" - return [ + """Shape of objects stored in the queue. 
Optionally includes shape of + observation data which would be included in an extra content loss + term""" + obs_shape = ( + *self.hr_shape[:-1], + len(self.containers[0].hr_out_features), + ) + queue_shapes = [ (self.batch_size, *self.lr_shape), (self.batch_size, *self.hr_shape), + (self.batch_size, *obs_shape), ] + return queue_shapes[: len(self.BATCH_MEMBERS)] def check_enhancement_factors(self): """Make sure each DualSampler has the same enhancment factors and they @@ -58,7 +67,7 @@ def transform(self, samples, smoothing=None, smoothing_ignore=None): This does not include temporal or spatial coarsening like :class:`SingleBatchQueue` """ - low_res, high_res = samples + low_res = samples[0] if smoothing is not None: feat_iter = [ @@ -71,4 +80,4 @@ def transform(self, samples, smoothing=None, smoothing_ignore=None): low_res[i, ..., j] = gaussian_filter( low_res[i, ..., j], smoothing, mode='nearest' ) - return low_res, high_res + return low_res, *samples[1:] diff --git a/sup3r/preprocessing/cachers/base.py b/sup3r/preprocessing/cachers/base.py index be380b673..82a4bd5c1 100644 --- a/sup3r/preprocessing/cachers/base.py +++ b/sup3r/preprocessing/cachers/base.py @@ -330,7 +330,8 @@ def write_h5( ] if Dimension.TIME in data: - data[Dimension.TIME] = data[Dimension.TIME].astype(int) + # int64 used explicity to avoid incorrect encoding as int32 + data[Dimension.TIME] = data[Dimension.TIME].astype('int64') for dset in [*coord_names, *features]: data_var, chunksizes = cls.get_chunksizes(dset, data, chunks) diff --git a/sup3r/preprocessing/collections/base.py b/sup3r/preprocessing/collections/base.py index f9f373bc9..736606a73 100644 --- a/sup3r/preprocessing/collections/base.py +++ b/sup3r/preprocessing/collections/base.py @@ -60,8 +60,6 @@ def container_weights(self): def __getattr__(self, attr): """Get attributes from self or the first container in the collection.""" - if attr in dir(self): - return self.__getattribute__(attr) return self.check_shared_attr(attr) def check_shared_attr(self, attr): diff --git a/sup3r/preprocessing/data_handlers/base.py b/sup3r/preprocessing/data_handlers/base.py index 452518d63..3d8f73125 100644 --- a/sup3r/preprocessing/data_handlers/base.py +++ b/sup3r/preprocessing/data_handlers/base.py @@ -145,6 +145,31 @@ def __init__( Dictionary of additional keyword args for :class:`~sup3r.preprocessing.rasterizers.Rasterizer`, used specifically for rasterizing flattened data + + Examples + -------- + Extract windspeed at 40m and 80m above the ground from files for u/v at + 10m and 100m. Windspeed will be interpolated from surrounding levels + using a log profile. ``dh`` will contain dask arrays of this data with + 10x10x50 chunk sizes. Data will be cached to files named + 'windspeed_40m.h5' and 'windspeed_80m.h5' in './cache_dir' with + 5x5x10 chunks on disk. + >>> cache_chunks = {'south_north': 5, 'west_east': 5, 'time': 10} + >>> load_chunks = {'south_north': 10, 'west_east': 10, 'time': 50} + >>> grid_size = (50, 50) + >>> lower_left_coordinate = (39.7, -105.2) + >>> dh = DataHandler( + ... file_paths=['./data_dir/u_10m.nc', './data_dir/u_100m.nc', + ... './data_dir/v_10m.nc', './data_dir/v_100m.nc'], + ... features=['windspeed_40m', 'windspeed_80m'], + ... shape=grid_size, time_slice=slice(0, 100), + ... target=lower_left_coordinate, hr_spatial_coarsen=2, + ... chunks=load_chunks, interp_kwargs={'method': 'log'}, + ... cache_kwargs={'cache_pattern': './cache_dir/{feature}.h5', + ... 
'chunks': cache_chunks}) + + Derive more features from already initialized data handler: + >>> dh['windspeed_60m'] = dh.derive('windspeed_60m') """ # pylint: disable=line-too-long features = parse_to_list(features=features) diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index 77a5ca528..3c2ce0675 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -56,7 +56,7 @@ def __init__( a method to derive the feature in the registry. interp_kwargs : dict | None Dictionary of kwargs for level interpolation. Can include "method" - and "run_level_check". "method" specifies how to perform height + and "run_level_check" keys. Method specifies how to perform height interpolation. e.g. Deriving u_20m from u_10m and u_100m. Options are "linear" and "log". See :py:meth:`sup3r.preprocessing.derivers.Deriver.do_level_interpolation` @@ -65,7 +65,7 @@ def __init__( self.FEATURE_REGISTRY = FeatureRegistry super().__init__(data=data) - self.interp_kwargs = interp_kwargs or {} + self.interp_kwargs = interp_kwargs features = parse_to_list(data=data, features=features) new_features = [f for f in features if f not in self.data] for f in new_features: @@ -269,6 +269,7 @@ def get_single_level_data(self, feature): var_array = da.stack(var_list, axis=-1) sl_shape = (*var_array.shape[:-1], len(lev_list)) lev_array = da.broadcast_to(da.from_array(lev_list), sl_shape) + return var_array, lev_array def get_multi_level_data(self, feature): @@ -295,8 +296,8 @@ def get_multi_level_data(self, feature): assert can_calc_height or have_height, msg if can_calc_height: - lev_array = self.data['zg'] - self.data['topography'] - lev_array = lev_array.data + lev_array = self.data[['zg', 'topography']].as_array() + lev_array = lev_array[..., 0] - lev_array[..., 1] else: lev_array = da.broadcast_to( self.data[Dimension.HEIGHT].astype(np.float32), diff --git a/sup3r/preprocessing/names.py b/sup3r/preprocessing/names.py index 0ade726c6..700dce81b 100644 --- a/sup3r/preprocessing/names.py +++ b/sup3r/preprocessing/names.py @@ -162,6 +162,7 @@ def dims_4d_bc(cls): 'northward_turbulent_surface_stress', 'eastward_turbulent_surface_stress', 'sea_surface_temperature', + 'instantaneous_10m_wind_gust' ] # variables available on multiple pressure levels diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index 777f6b5dc..2df70fa84 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -2,7 +2,7 @@ datasets""" import logging -from typing import Tuple, Union +from typing import Dict, Tuple, Union from warnings import warn import numpy as np @@ -38,9 +38,12 @@ class DualRasterizer(Container): @log_args def __init__( self, - data: Union[Sup3rDataset, Tuple[xr.Dataset, xr.Dataset]], + data: Union[ + Sup3rDataset, Tuple[xr.Dataset, xr.Dataset], Dict[str, xr.Dataset] + ], regrid_workers=1, regrid_lr=True, + run_qa=True, s_enhance=1, t_enhance=1, lr_cache_kwargs=None, @@ -51,7 +54,8 @@ def __init__( Parameters ---------- - data : Sup3rDataset | Tuple[xr.Dataset, xr.Dataset] + data : Sup3rDataset | Tuple[xr.Dataset, xr.Dataset] | + Dict[str, xr.Dataset] A tuple of xr.Dataset instances. The first must be low-res and the second must be high-res data regrid_workers : int | None @@ -60,6 +64,9 @@ def __init__( Flag to regrid the low-res data to the high-res grid. This will take care of any minor inconsistencies in different projections. Disable this if the grids are known to be the same. 
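# [editor's sketch] Quick illustration of the "log" height-interpolation option
# referenced by ``interp_kwargs={'method': 'log'}`` above: fit u(z) = a + b*ln(z)
# through two known heights and evaluate it at the requested height. This is a
# common log-profile approach shown for intuition, not necessarily the exact
# scheme implemented in the sup3r deriver.
import numpy as np


def log_interp(u_low, z_low, u_high, z_high, z_out):
    """Interpolate wind speed to ``z_out`` assuming a logarithmic profile."""
    b = (u_high - u_low) / (np.log(z_high) - np.log(z_low))
    a = u_low - b * np.log(z_low)
    return a + b * np.log(z_out)


# e.g. a 40 m wind speed from 10 m and 100 m values (roughly 6.8 m/s here)
print(log_interp(u_low=5.0, z_low=10.0, u_high=8.0, z_high=100.0, z_out=40.0))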
+ run_qa : bool + Flag to run qa on the regridded low-res data. This will check for + NaNs and fill them if there are not too many. s_enhance : int Spatial enhancement factor t_enhance : int @@ -76,9 +83,11 @@ def __init__( self.s_enhance = s_enhance self.t_enhance = t_enhance if isinstance(data, tuple): - data = Sup3rDataset(low_res=data[0], high_res=data[1]) + data = {'low_res': data[0], 'high_res': data[1]} + if isinstance(data, dict): + data = Sup3rDataset(**data) msg = ( - 'The DualRasterizer requires either a data tuple with two ' + 'The DualRasterizer requires a data tuple or dictionary with two ' 'members, low and high resolution in that order, or a ' f'Sup3rDataset instance. Received {type(data)}.' ) @@ -130,7 +139,8 @@ def __init__( self.update_hr_data() super().__init__(data=(self.lr_data, self.hr_data)) - self.check_regridded_lr_data() + if run_qa: + self.check_regridded_lr_data() if lr_cache_kwargs is not None: Cacher(self.lr_data, lr_cache_kwargs) @@ -200,7 +210,7 @@ def update_lr_data(self): lr_coords_new = { Dimension.LATITUDE: self.lr_lat_lon[..., 0], Dimension.LONGITUDE: self.lr_lat_lon[..., 1], - Dimension.TIME: self.lr_data.indexes['time'][ + Dimension.TIME: self.lr_data.indexes[Dimension.TIME][ : self.lr_required_shape[2] ], } @@ -213,7 +223,7 @@ def check_regridded_lr_data(self): """Check for NaNs after regridding and do NN fill if needed.""" fill_feats = [] logger.info('Checking for NaNs after regridding') - qa_info = self.lr_data.qa() + qa_info = self.lr_data.qa(stats=['nan_perc']) for f in self.lr_data.features: nan_perc = qa_info[f]['nan_perc'] if nan_perc > 0: diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index dbf75f61f..e576c1521 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -7,7 +7,6 @@ from typing import Dict, Optional, Tuple from warnings import warn -import dask.array as da import numpy as np from sup3r.preprocessing.base import Container @@ -197,7 +196,7 @@ def _reshape_samples(self, samples): # (lats, lons, batch_size, times, feats) out = np.reshape(samples, new_shape) # (batch_size, lats, lons, times, feats) - return compute_if_dask(np.transpose(out, axes=(2, 0, 1, 3, 4))) + return np.transpose(out, axes=(2, 0, 1, 3, 4)) def _stack_samples(self, samples): """Used to build batch arrays in the case of independent time samples @@ -222,25 +221,27 @@ def _stack_samples(self, samples): (batch_size, samp_shape[0], samp_shape[1], samp_shape[2], n_feats) """ if isinstance(samples[0], tuple): - lr = da.stack([s[0] for s in samples], axis=0) - hr = da.stack([s[1] for s in samples], axis=0) + lr = np.stack([s[0] for s in samples], axis=0) + hr = np.stack([s[1] for s in samples], axis=0) return (lr, hr) - return da.stack(samples, axis=0) + return np.stack(samples, axis=0) def _fast_batch(self): """Get batch of samples with adjacent time slices.""" out = self.data.sample(self.get_sample_index(n_obs=self.batch_size)) + out = compute_if_dask(out) if isinstance(out, tuple): return tuple(self._reshape_samples(o) for o in out) return self._reshape_samples(out) def _slow_batch(self): """Get batch of samples with random time slices.""" - samples = [ + out = [ self.data.sample(self.get_sample_index(n_obs=1)) for _ in range(self.batch_size) ] - return self._stack_samples(samples) + out = compute_if_dask(out) + return self._stack_samples(out) def _fast_batch_possible(self): return self.batch_size * self.sample_shape[2] <= self.data.shape[2] diff --git 
a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index bf259dff3..56b7c4fd5 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -16,7 +16,10 @@ class DualSampler(Sampler): """Sampler for sampling from paired (or dual) datasets. Pairs consist of - low and high resolution data, which are contained by a Sup3rDataset.""" + low and high resolution data, which are contained by a Sup3rDataset. This + can also include extra observation data on the same grid as the + high-resolution data which has NaNs at points where observation data + doesn't exist. This will be used in an additional content loss term.""" def __init__( self, @@ -32,7 +35,7 @@ def __init__( ---------- data : Sup3rDataset A :class:`~sup3r.preprocessing.base.Sup3rDataset` instance with - low-res and high-res data members + low-res and high-res data members, and optionally an obs member. sample_shape : tuple Size of arrays to sample from the high-res data. The sample shape for the low-res sampler will be determined from the enhancement @@ -57,10 +60,16 @@ def __init__( """ msg = ( f'{self.__class__.__name__} requires a Sup3rDataset object ' - 'with `.low_res` and `.high_res` data members, in that order' + 'with `.low_res` and `.high_res` data members, and optionally an ' + '`.obs` member, in that order' ) - assert hasattr(data, 'low_res') and hasattr(data, 'high_res'), msg - assert data.low_res == data[0] and data.high_res == data[1], msg + dnames = ['low_res', 'high_res', 'obs'][:len(data)] + check = ( + hasattr(data, dname) and getattr(data, dname) == data[i] + for i, dname in enumerate(dnames) + ) + assert check, msg + super().__init__( data=data, sample_shape=sample_shape, batch_size=batch_size ) @@ -109,16 +118,17 @@ def check_for_consistent_shapes(self): self.lr_data.shape[2] * self.t_enhance, ) msg = ( - f'hr_data.shape {self.hr_data.shape} and enhanced ' + f'hr_data.shape {self.hr_data.shape[:-1]} and enhanced ' f'lr_data.shape {enhanced_shape} are not compatible with ' 'the given enhancement factors' ) - assert self.hr_data.shape[:3] == enhanced_shape, msg + assert self.hr_data.shape[:-1] == enhanced_shape, msg def get_sample_index(self, n_obs=None): """Get paired sample index, consisting of index for the low res sample and the index for the high res sample with the same spatiotemporal - extent.""" + extent. Optionally includes an extra high res index if the sample data + includes observation data.""" n_obs = n_obs or self.batch_size spatial_slice = uniform_box_sampler( self.lr_data.shape, self.lr_sample_shape[:2] @@ -135,5 +145,8 @@ def get_sample_index(self, n_obs=None): slice(s.start * self.t_enhance, s.stop * self.t_enhance) for s in lr_index[2:-1] ] + obs_index = (*hr_index, self.hr_out_features) hr_index = (*hr_index, self.hr_features) - return (lr_index, hr_index) + + sample_index = (lr_index, hr_index, obs_index) + return sample_index[:len(self.data)] diff --git a/sup3r/preprocessing/utilities.py b/sup3r/preprocessing/utilities.py index 8a51fadbf..db0dd5909 100644 --- a/sup3r/preprocessing/utilities.py +++ b/sup3r/preprocessing/utilities.py @@ -228,6 +228,8 @@ def compute_if_dask(arr): compute_if_dask(arr.stop), compute_if_dask(arr.step), ) + if isinstance(arr, (tuple, list)): + return type(arr)(compute_if_dask(a) for a in arr) return arr.compute() if hasattr(arr, 'compute') else arr @@ -409,6 +411,7 @@ def parse_keys( order. 
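# [editor's sketch] Standalone illustration of the paired-index logic in
# ``DualSampler.get_sample_index`` above: draw a random low-res box, then scale
# it by the enhancement factors to get the matching high-res box. The helper
# below is illustrative, not the sup3r implementation.
import numpy as np


def paired_sample_index(lr_shape, lr_sample_shape, s_enhance, t_enhance, seed=0):
    """Return matching (lr_index, hr_index) slice tuples for one paired sample."""
    rng = np.random.default_rng(seed)
    lr_index, hr_index = [], []
    enhance = (s_enhance, s_enhance, t_enhance)
    for dim_size, samp_size, enh in zip(lr_shape, lr_sample_shape, enhance):
        start = int(rng.integers(0, dim_size - samp_size + 1))
        lr_slice = slice(start, start + samp_size)
        lr_index.append(lr_slice)
        hr_index.append(slice(lr_slice.start * enh, lr_slice.stop * enh))
    return tuple(lr_index), tuple(hr_index)


# a 4x4x6 low-res sample and the matching 12x12x24 high-res sample
lr_idx, hr_idx = paired_sample_index(
    lr_shape=(20, 20, 100), lr_sample_shape=(4, 4, 6), s_enhance=3, t_enhance=4
)
print(lr_idx, hr_idx)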
If keys is empty then we just want to return the coordinate data, so features will be set to just the coordinate names.""" + keys = list(keys) if isinstance(keys, set) else keys keys = keys if isinstance(keys, tuple) else (keys,) has_feats = is_type_of(keys[0], str) diff --git a/sup3r/utilities/pytest/helpers.py b/sup3r/utilities/pytest/helpers.py index 3f9b320a5..bb9912c9f 100644 --- a/sup3r/utilities/pytest/helpers.py +++ b/sup3r/utilities/pytest/helpers.py @@ -107,9 +107,17 @@ class DummySampler(Sampler): """Dummy container with random data.""" def __init__( - self, sample_shape, data_shape, features, batch_size, feature_sets=None + self, + sample_shape, + data_shape, + features, + batch_size, + feature_sets=None, + chunk_shape=None, ): data = make_fake_dset(data_shape, features=features) + if chunk_shape is not None: + data = data.chunk(chunk_shape) super().__init__( Sup3rDataset(high_res=data), sample_shape, @@ -314,10 +322,7 @@ def make_collect_chunks(td): out_files = [] for t, slice_hr in enumerate(t_slices_hr): for s, (s1_hr, s2_hr) in enumerate(product(s_slices_hr, s_slices_hr)): - out_file = out_pattern.format( - t=str(t).zfill(6), - s=str(s).zfill(6) - ) + out_file = out_pattern.format(t=str(t).zfill(6), s=str(s).zfill(6)) out_files.append(out_file) OutputHandlerH5._write_output( data[s1_hr, s2_hr, slice_hr, :], @@ -330,15 +335,7 @@ def make_collect_chunks(td): gids=gids[s1_hr, s2_hr], ) - return ( - out_files, - data, - ws_true, - wd_true, - features, - hr_lat_lon, - hr_times - ) + return (out_files, data, ws_true, wd_true, features, hr_lat_lon, hr_times) def make_fake_h5_chunks(td): @@ -436,7 +433,7 @@ def make_fake_h5_chunks(td): s_slices_lr, s_slices_hr, low_res_lat_lon, - low_res_times + low_res_times, ) diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 9b7be514a..f0f64246a 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -26,7 +26,7 @@ def preprocess_datasets(dset): dset.indexes['time'], 'to_datetimeindex' ): dset['time'] = dset.indexes['time'].to_datetimeindex() - ti = dset['time'].astype(int) + ti = dset['time'].astype('int64') dset['time'] = ti if 'latitude' in dset.dims: dset = dset.swap_dims({'latitude': 'south_north'}) @@ -91,6 +91,8 @@ def stop(self): @property def elapsed(self): """Elapsed time between start and stop.""" + if self._stop is None: + return time.time() - self._start return self._stop - self._start @property diff --git a/tests/batch_handlers/test_bh_general.py b/tests/batch_handlers/test_bh_general.py index b8f10fa8d..616eb1833 100644 --- a/tests/batch_handlers/test_bh_general.py +++ b/tests/batch_handlers/test_bh_general.py @@ -1,14 +1,14 @@ """Smoke tests for batcher objects. Just make sure things run without errors""" - import copy +import os +import time +from tempfile import TemporaryDirectory import numpy as np import pytest from scipy.ndimage import gaussian_filter -from sup3r.preprocessing import ( - BatchHandler, -) +from sup3r.preprocessing import BatchHandler, DataHandler from sup3r.preprocessing.base import Container from sup3r.utilities.pytest.helpers import ( BatchHandlerTesterFactory, @@ -17,6 +17,7 @@ ) from sup3r.utilities.utilities import ( RANDOM_GENERATOR, + Timer, spatial_coarsening, temporal_coarsening, ) @@ -29,6 +30,88 @@ BatchHandlerTester = BatchHandlerTesterFactory(BatchHandler, SamplerTester) +def test_batch_sampling_workers(): + """Check that it is faster to sample batches with max_workers > 1 than with + max_workers = 1. 
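# [editor's sketch] Why the time index above is cast with ``astype('int64')``
# rather than plain ``int``: datetime64[ns] values are nanoseconds since the
# epoch, which overflow a 32-bit integer and silently wrap if encoded as int32.
import numpy as np

t = np.datetime64('2015-01-01T00:00:00', 'ns')
as_int64 = t.astype('int64')         # 1420070400000000000 ns, preserved exactly
as_int32 = as_int64.astype('int32')  # wraps to a meaningless value
print(as_int64, as_int32)
assert int(as_int64) != int(as_int32)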
This does not include enqueueing and dequeueing.""" + + timer = Timer() + sample_shape = (100, 100, 30) + chunk_shape = ( + 2 * sample_shape[0], + 2 * sample_shape[1], + 2 * sample_shape[-1], + ) + n_obs = 10 + max_workers = 10 + n_batches = 50 + n_epochs = 3 + chunks = dict(zip(['south_north', 'west_east', 'time'], chunk_shape)) + + with TemporaryDirectory() as td: + ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m']) + ds.to_netcdf(os.path.join(td, 'test.nc')) + ds = DataHandler(os.path.join(td, 'test.nc'), chunks=chunks) + + batcher = BatchHandler( + [ds], + n_batches=n_batches, + batch_size=n_obs, + sample_shape=sample_shape, + max_workers=max_workers, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, + ) + timer.start() + queue_time = 0 + for _ in range(n_epochs): + batches = batcher.sample_batches(n_batches) + batches = [batch.result() for batch in batches] + queue_start = time.time() + for batch in batches: + batcher.queue.enqueue(batch) + _ = batcher.queue.dequeue() + queue_time += (time.time() - queue_start) + timer.stop() + parallel_time = timer.elapsed / (n_batches * n_epochs) + parallel_queue_time = queue_time / (n_batches * n_epochs) + batcher.stop() + + batcher = BatchHandler( + [ds], + n_batches=n_batches, + batch_size=n_obs, + sample_shape=sample_shape, + max_workers=1, + means={'u_100m': 0, 'v_100m': 0}, + stds={'u_100m': 1, 'v_100m': 1}, + ) + timer.start() + queue_time = 0 + for _ in range(n_epochs): + batches = batcher.sample_batches(n_batches) + queue_start = time.time() + for batch in batches: + batcher.queue.enqueue(batch) + _ = batcher.queue.dequeue() + queue_time += time.time() - queue_start + timer.stop() + serial_time = timer.elapsed / (n_batches * n_epochs) + serial_queue_time = queue_time / (n_batches * n_epochs) + batcher.stop() + + print( + 'Elapsed total time (serial / parallel): {} / {}'.format( + serial_time, parallel_time + ) + ) + print( + 'Elapsed queue time (serial / parallel): {} / {}'.format( + serial_queue_time, parallel_queue_time + ) + ) + assert serial_time > parallel_time + + def test_eager_vs_lazy(): """Make sure eager and lazy loading agree. 
We use queue_cap = 0 here so there is no disagreement that results from dequeuing vs direct batch diff --git a/tests/batch_handlers/test_bh_h5_cc.py b/tests/batch_handlers/test_bh_h5_cc.py index 0fac7520a..942e52eed 100644 --- a/tests/batch_handlers/test_bh_h5_cc.py +++ b/tests/batch_handlers/test_bh_h5_cc.py @@ -1,6 +1,5 @@ """pytests for H5 climate change data batch handlers""" - import matplotlib.pyplot as plt import numpy as np import pytest @@ -370,10 +369,11 @@ def test_surf_min_max_vars(): assert batch.low_res.shape[-1] == len(surf_features) # compare daily avg temp vs min and max - assert (batch.low_res[..., 0] > batch.low_res[..., 2]).all() - assert (batch.low_res[..., 0] < batch.low_res[..., 3]).all() + blr = batch.low_res.numpy() + assert (blr[..., 0] > blr[..., 2]).all() + assert (blr[..., 0] < blr[..., 3]).all() # compare daily avg rh vs min and max - assert (batch.low_res[..., 1] > batch.low_res[..., 4]).all() - assert (batch.low_res[..., 1] < batch.low_res[..., 5]).all() + assert (blr[..., 1] > blr[..., 4]).all() + assert (blr[..., 1] < blr[..., 5]).all() batcher.stop() diff --git a/tests/conftest.py b/tests/conftest.py index 9c23755cf..2f52610da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,6 +39,10 @@ def pytest_configure(config): # pylint: disable=unused-argument # noqa: ARG001 pytest.ST_FP_DISC = os.path.join(TEST_DATA_DIR, 'config_disc_st_test.json') pytest.S_FP_DISC = os.path.join(TEST_DATA_DIR, 'config_disc_s_test.json') + pytest.ST_FP_DISC_PROD = os.path.join( + CONFIG_DIR, 'spatiotemporal/disc.json' + ) + pytest.FPS_GCM = [ os.path.join(TEST_DATA_DIR, 'ua_test.nc'), os.path.join(TEST_DATA_DIR, 'va_test.nc'), diff --git a/tests/data/extract_raster_wtk.py b/tests/data/extract_raster_wtk.py deleted file mode 100644 index ed826d134..000000000 --- a/tests/data/extract_raster_wtk.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Script to extract data subset in raster shape from flattened WTK h5 files. - -TODO: Is this worth keeping for any reason? 
-""" - -import matplotlib.pyplot as plt -from rex import init_logger -from rex.outputs import Outputs -from rex.resource_extraction.resource_extraction import WindX - -if __name__ == '__main__': - init_logger('rex', log_level='DEBUG') - - res_fp = '/datasets/WIND/conus/v1.0.0/wtk_conus_2013.h5' - fout = './test_wtk_co_2013.h5' - - dsets = ['windspeed_80m', 'windspeed_100m', 'winddirection_80m', - 'winddirection_100m', 'temperature_100m', 'pressure_100m'] - - target = (39.0, -105.15) - shape = (20, 20) - - with WindX(res_fp) as res: - meta = res.meta - raster_index_2d = res.get_raster_index(target, shape, max_delta=20) - - for d in ('elevation', 'latitude', 'longitude'): - data = meta[d].values[raster_index_2d] - a = plt.imshow(data) - plt.colorbar(a, label=d) - plt.savefig(d + '.png') - plt.close() - - raster_index = sorted(raster_index_2d.ravel()) - - attrs = {k: res.resource.attrs[k] for k in dsets} - chunks = dict.fromkeys(dsets) - dtypes = {k: res.resource.dtypes[k] for k in dsets} - meta = meta.iloc[raster_index].reset_index(drop=True) - time_index = res.time_index - shapes = {k: (len(time_index), len(meta)) for k in dsets} - print(shapes) - - Outputs.init_h5(fout, dsets, shapes, attrs, chunks, dtypes, meta, - time_index=time_index) - - with Outputs(fout, mode='a') as f: - for d in dsets: - f[d] = res[d, :, raster_index] - - with Outputs(fout, mode='r') as f: - meta = f.meta - for d in dsets: - data = f[d].mean(axis=0) - data = data.reshape(shape) - a = plt.imshow(data) - plt.colorbar(a, label=d) - plt.savefig(d + '.png') - plt.close() diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 76d268762..944c0bf2c 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -802,8 +802,8 @@ def test_slicing_auto_boundary_pad(input_files, spatial_pad): fwp.strategy.ti_slices[t_idx], ) - assert chunk.input_data.shape[0] > 3 - assert chunk.input_data.shape[1] > 3 + assert chunk.input_data.shape[0] > strategy.fwp_slicer.min_width[0] + assert chunk.input_data.shape[1] > strategy.fwp_slicer.min_width[1] input_data = chunk.input_data.copy() if spatial_pad > 0: slices = [ diff --git a/tests/rasterizers/test_dual.py b/tests/rasterizers/test_dual.py index ed11feecb..725c2fdb0 100644 --- a/tests/rasterizers/test_dual.py +++ b/tests/rasterizers/test_dual.py @@ -30,7 +30,9 @@ def test_dual_rasterizer_shapes(full_shape=(20, 20)): ) pair_rasterizer = DualRasterizer( - (lr_container.data, hr_container.data), s_enhance=2, t_enhance=1 + {'low_res': lr_container.data, 'high_res': hr_container.data}, + s_enhance=2, + t_enhance=1, ) assert pair_rasterizer.lr_data.shape == ( pair_rasterizer.hr_data.shape[0] // 2, @@ -63,7 +65,9 @@ def test_dual_nan_fill(full_shape=(20, 20)): assert np.isnan(lr_container.data.as_array()).any() pair_rasterizer = DualRasterizer( - (lr_container.data, hr_container.data), s_enhance=1, t_enhance=1 + {'low_res': lr_container.data, 'high_res': hr_container.data}, + s_enhance=1, + t_enhance=1, ) assert not np.isnan(pair_rasterizer.lr_data.as_array()).any() @@ -89,7 +93,7 @@ def test_regrid_caching(full_shape=(20, 20)): lr_cache_pattern = os.path.join(td, 'lr_{feature}.h5') hr_cache_pattern = os.path.join(td, 'hr_{feature}.h5') pair_rasterizer = DualRasterizer( - (lr_container.data, hr_container.data), + {'low_res': lr_container.data, 'high_res': hr_container.data}, s_enhance=2, t_enhance=1, lr_cache_kwargs={'cache_pattern': lr_cache_pattern}, diff --git a/tests/training/test_train_dual.py 
b/tests/training/test_train_dual.py index f92443961..29e38a32b 100644 --- a/tests/training/test_train_dual.py +++ b/tests/training/test_train_dual.py @@ -66,7 +66,7 @@ def test_train_h5_nc( # time indices conflict with t_enhance with pytest.raises(AssertionError): dual_rasterizer = DualRasterizer( - data=(lr_handler.data, hr_handler.data), + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, s_enhance=s_enhance, t_enhance=t_enhance, ) @@ -78,7 +78,7 @@ def test_train_h5_nc( ) dual_rasterizer = DualRasterizer( - data=(lr_handler.data, hr_handler.data), + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, s_enhance=s_enhance, t_enhance=t_enhance, ) @@ -158,7 +158,7 @@ def test_train_coarse_h5( ) dual_rasterizer = DualRasterizer( - data=(lr_handler.data, hr_handler.data), + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, s_enhance=s_enhance, t_enhance=t_enhance, ) diff --git a/tests/training/test_train_dual_with_obs.py b/tests/training/test_train_dual_with_obs.py new file mode 100644 index 000000000..2399b21ba --- /dev/null +++ b/tests/training/test_train_dual_with_obs.py @@ -0,0 +1,232 @@ +"""Test the training of GANs with dual data handler""" + +import itertools +import os +import tempfile + +import numpy as np +import pytest + +from sup3r.models import Sup3rGanWithObs +from sup3r.preprocessing import ( + Container, + DataHandler, + DualBatchHandler, + DualRasterizer, +) +from sup3r.preprocessing.samplers import DualSampler +from sup3r.utilities.pytest.helpers import BatchHandlerTesterFactory + +TARGET_COORD = (39.01, -105.15) +FEATURES = ['u_100m', 'v_100m'] + + +DualBatchHandlerWithObsTester = BatchHandlerTesterFactory( + DualBatchHandler, DualSampler +) + + +@pytest.mark.parametrize( + [ + 'fp_gen', + 'fp_disc', + 's_enhance', + 't_enhance', + 'sample_shape', + 'mode', + ], + [ + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'lazy'), + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'eager'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'lazy'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'eager'), + ], +) +def test_train_h5_nc( + fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, mode, n_epoch=2 +): + """Test model training with a dual data handler / batch handler with h5 and + era as hr / lr datasets. 
Tests both spatiotemporal and spatial models.""" + + lr = 1e-5 + kwargs = { + 'features': FEATURES, + 'target': TARGET_COORD, + 'shape': (20, 20), + } + hr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + time_slice=slice(None, None, 1), + ) + + lr_handler = DataHandler( + pytest.FP_ERA, + features=FEATURES, + time_slice=slice(None, None, t_enhance), + ) + + dual_rasterizer = DualRasterizer( + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, + s_enhance=s_enhance, + t_enhance=t_enhance, + ) + obs_data = dual_rasterizer.high_res.copy() + for feat in FEATURES: + tmp = np.full(obs_data[feat].shape, np.nan) + lat_ids = list(range(0, 20, 4)) + lon_ids = list(range(0, 20, 4)) + for ilat, ilon in itertools.product(lat_ids, lon_ids): + tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] + obs_data[feat] = (obs_data[feat].dims, tmp) + + dual_with_obs = Container( + data={ + 'low_res': dual_rasterizer.low_res, + 'high_res': dual_rasterizer.high_res, + 'obs': obs_data, + } + ) + + batch_handler = DualBatchHandlerWithObsTester( + train_containers=[dual_with_obs], + val_containers=[], + sample_shape=sample_shape, + batch_size=3, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=3, + mode=mode, + ) + + for batch in batch_handler: + assert hasattr(batch, 'obs') + assert not np.isnan(batch.obs).all() + assert np.isnan(batch.obs).any() + + Sup3rGanWithObs.seed() + model = Sup3rGanWithObs( + fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' + ) + + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': n_epoch, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batch_handler, **model_kwargs) + + tlossg = model.history['train_loss_gen'].values + tlosso = model.history['train_loss_obs'].values + assert np.sum(np.diff(tlossg)) < 0 + assert np.sum(np.diff(tlosso)) < 0 + + +@pytest.mark.parametrize( + [ + 'fp_gen', + 'fp_disc', + 's_enhance', + 't_enhance', + 'sample_shape', + 'mode', + ], + [ + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'lazy'), + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'eager'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'lazy'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'eager'), + ], +) +def test_train_coarse_h5( + fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, mode, n_epoch=2 +): + """Test model training with a dual data handler / batch handler with + additional sparse observation data used in extra content loss term. 
Tests + both spatiotemporal and spatial models.""" + + lr = 1e-5 + kwargs = { + 'features': FEATURES, + 'target': TARGET_COORD, + 'shape': (20, 20), + } + hr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + time_slice=slice(None, None, 1), + ) + + lr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + hr_spatial_coarsen=s_enhance, + time_slice=slice(None, None, t_enhance), + ) + + dual_rasterizer = DualRasterizer( + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, + s_enhance=s_enhance, + t_enhance=t_enhance, + ) + obs_data = dual_rasterizer.high_res.copy() + for feat in FEATURES: + tmp = np.full(obs_data[feat].shape, np.nan) + lat_ids = list(range(0, 20, 4)) + lon_ids = list(range(0, 20, 4)) + for ilat, ilon in itertools.product(lat_ids, lon_ids): + tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] + obs_data[feat] = (obs_data[feat].dims, tmp) + + dual_with_obs = Container( + data={ + 'low_res': dual_rasterizer.low_res, + 'high_res': dual_rasterizer.high_res, + 'obs': obs_data, + } + ) + + batch_handler = DualBatchHandlerWithObsTester( + train_containers=[dual_with_obs], + val_containers=[], + sample_shape=sample_shape, + batch_size=3, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=3, + mode=mode, + ) + + for batch in batch_handler: + assert hasattr(batch, 'obs') + assert not np.isnan(batch.obs).all() + assert np.isnan(batch.obs).any() + + Sup3rGanWithObs.seed() + model = Sup3rGanWithObs( + fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' + ) + + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': n_epoch, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batch_handler, **model_kwargs) + + tlossg = model.history['train_loss_gen'].values + tlosso = model.history['train_loss_obs'].values + assert np.sum(np.diff(tlossg)) < 0 + assert np.sum(np.diff(tlosso)) < 0 diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 81d1ebd63..954bc3cb2 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -113,10 +113,13 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): assert 'OptmGen/learning_rate' in model.history assert 'OptmDisc/learning_rate' in model.history - msg = ('Could not find OptmGen states in columns: ' - f'{sorted(model.history.columns)}') - check = [col.startswith('OptmGen/Adam/v') - for col in model.history.columns] + msg = ( + 'Could not find OptmGen states in columns: ' + f'{sorted(model.history.columns)}' + ) + check = [ + col.startswith('OptmGen/Adam/v') for col in model.history.columns + ] assert any(check), msg assert 'config_generator' in loaded.meta
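# [editor's sketch] The ``Sup3rGanWithObs`` tests above exercise an extra
# content-loss term computed only where sparse observations exist (the non-NaN
# points in ``batch.obs``). A minimal sketch of such a masked loss follows; it
# is an assumption about the general idea, not the model's actual loss code.
import numpy as np


def masked_obs_loss(generated, obs):
    """Mean absolute error over grid points with valid (non-NaN) observations."""
    mask = ~np.isnan(obs)
    if not mask.any():
        return 0.0
    return float(np.mean(np.abs(generated[mask] - obs[mask])))


obs = np.full((20, 20), np.nan)
obs[::4, ::4] = 1.0                     # sparse observation points, as in the tests
generated = np.full((20, 20), 1.25)     # stand-in for generator output
print(masked_obs_loss(generated, obs))  # 0.25 -- only the 25 observed points count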