From 67fcc6aba379c7802857ac21cfd5c217218523dd Mon Sep 17 00:00:00 2001
From: Seth McGinnis
Date: Thu, 21 Nov 2024 23:30:00 -0700
Subject: [PATCH] dataset object, preliminary version

---
 credit/data_downscaling.py | 126 +++++++++++++++++++++++++++++--------
 1 file changed, 99 insertions(+), 27 deletions(-)

diff --git a/credit/data_downscaling.py b/credit/data_downscaling.py
index 37e67acb..feff1b6e 100644
--- a/credit/data_downscaling.py
+++ b/credit/data_downscaling.py
@@ -1,32 +1,61 @@
-from typing import Optional, Callable, TypedDict, Union, List
+# system tools
 import os
+from typing import Optional, Callable, Dict, TypedDict, Union, List
 from dataclasses import dataclass, field
 from functools import reduce
 from glob import glob
 from itertools import repeat
 from timeit import timeit
+
+# data utils
 import numpy as np
 import xarray as xr
+
+# PyTorch utils
 import torch
 import torch.utils.data
+from torch.utils.data import get_worker_info
+from torch.utils.data.distributed import DistributedSampler
+
+
+Array = Union[np.ndarray, xr.DataArray]
+IMAGE_ATTR_NAMES = ('historical_ERA5_images', 'target_ERA5_images')
+
+
+# def get_forward_data(filename) -> xr.DataArray:
+#     """Lazily opens a netcdf file or zarr store as xr.DataArray
+#     """
+#     if filename[-3:] == '.nc' or filename[-4:] == '.nc4':
+#         dataset = xr.open_dataset(filename)
+#     else:
+#         dataset = xr.open_zarr(filename, consolidated=True)
+#     return dataset
 
-def get_forward_data(filename) -> xr.DataArray:
-    """Lazily opens a Zarr store
+
+def flatten(array):
+    """ flattens a list-of-lists
     """
-    dataset = xr.open_zarr(filename, consolidated=True)
-    return dataset
+    return reduce(lambda a, b: a+b, array)
 
 
-Array = Union[np.ndarray, xr.DataArray]
-IMAGE_ATTR_NAMES = ('historical_ERA5_images', 'target_ERA5_images')
+# ## no longer needed now that everything is in the same file?
+# def lazymerge(zlist, rename=None):
+#     """ merges zarr stores opened lazily with get_forward_data()
+#     """
+#     zarrs = [get_forward_data(z) for z in zlist]
+#     if rename is not None:
+#         oldname = flatten([list(z.keys()) for z in zarrs])
+#         # ^^ this will break on multi-var zarr stores
+#         zarrs = [z.rename_vars({old: new}) for z, old, new in zip(zarrs, oldname, rename)]
+#     return xr.merge(zarrs)
 
 
 class Sample(TypedDict):
     """Simple class for structuring data for the ML model.
 
-    x = input (predictor) data (i.e, C404dataset[historical mask]
-    y = target (predictand) data (i.e, C404dataset[forecast mask]
+    x = predictor (input) data
+    y = predictand (target) data
 
     Using typing.TypedDict gives us several advantages:
       1. Single 'source of truth' for the type and documentation of each example.
@@ -40,24 +69,6 @@ class Sample(TypedDict):
     y: Array
 
 
-def flatten(array):
-
-    """ flattens a list-of-lists
-    """
-    return reduce(lambda a, b: a+b, array)
-
-
-def lazymerge(zlist, rename=None):
-    """ merges zarr stores opened lazily with get_forward_data()
-    """
-    zarrs = [get_forward_data(z) for z in zlist]
-    if rename is not None:
-        oldname = flatten([list(z.keys()) for z in zarrs])
-        # ^^ this will break on multi-var zarr stores
-        zarrs = [z.rename_vars({old: new}) for z, old, new in zip(zarrs, oldname, rename)]
-    return xr.merge(zarrs)
-
-
 # using dataclass decorator avoids lots of self.x=x and gets us free __repr__
 @dataclass
 class CONUS404Dataset(torch.utils.data.Dataset):
@@ -230,3 +241,64 @@ def testC4loader():
     print(testvars)
     cmd = 'c4 = CONUS404Dataset("'+src+'",varnames='+str(testvars)+')'
     print(cmd+"\t"+str(timeit(cmd, globals=globals(), number=1)))
+
+
+#####################
+
+
+@dataclass
+class DownscalingDataset(torch.utils.data.Dataset):
+    ''' pass **conf['data'] as arguments to the constructor
+    '''
+    rootpath: str
+    history_len: int = 2
+    forecast_len: int = 1
+    first_date: Optional[str] = None
+    last_date: Optional[str] = None
+    datasets: Dict = field(default_factory=dict)
+
+    def __post_init__(self):
+        super().__init__()
+
+        ## Replace the datasets dict (which holds configurations for
+        ## the various DataMaps in the dataset) with actual DataMap
+        ## objects initialized from those configurations.  We need to
+        ## pop datasets out of __dict__ because each configuration
+        ## must first be updated with the other class attributes
+        ## (which are common to all datasets).
+
+        dmap_configs = self.__dict__.pop("datasets")
+
+        ## TODO: error if length not > 1
+
+        self.datasets = dict()
+        for k in dmap_configs.keys():
+            dmap_configs[k].update(self.__dict__)
+            self.datasets[k] = DataMap(**dmap_configs[k])
+
+        dlengths = [len(d) for d in self.datasets.values()]
+        self.len = np.max(dlengths)
+        ## TODO: error if any of dlengths is neither self.len nor 1
+
+    def __getitem__(self, index):
+        items = {k: self.datasets[k][index] for k in self.datasets.keys()}
+
+        ## TODO: put the static data into a dict with key 'static'
+
+        ## TODO: combine results by use:
+        # result = dict()
+        # loop on use:
+        #     result[use] = dict
+        #     loop on keys:
+        #         result[use].append items[key][use]
+
+        ## TODO: transform to tensor (includes unstacking the z-dim)
+        ## and apply normalization (and any other transformations)
+        pass
+
+        # Open question: does this method do the tensor transformation,
+        # or do we write a ToTensor object that takes a dict of
+        # variables and does it there?
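
Usage sketch (not part of the patch): a minimal example of how DownscalingDataset might be constructed, assuming conf['data'] mirrors the dataclass fields above. The rootpath, dataset names, and the 'varnames' DataMap option shown here are hypothetical.

    # hypothetical conf['data'] section; keys under "datasets" are assumed DataMap configs
    conf = {
        "data": {
            "rootpath": "/path/to/conus404",        # hypothetical path
            "history_len": 2,
            "forecast_len": 1,
            "first_date": "1990-01-01",
            "last_date": "2020-12-31",
            "datasets": {
                "era5": {"varnames": ["T2", "U10", "V10"]},
                "conus404": {"varnames": ["PREC"]},
            },
        }
    }

    dataset = DownscalingDataset(**conf["data"])
    sample = dataset[0]   # would return a Sample dict once __getitem__ is implemented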
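
One possible reading of the __getitem__ TODO comments (a sketch only, not the implemented behavior): assuming each DataMap item is a dict keyed by use ('x', 'y', 'static') whose values are arrays, the per-dataset items could be regrouped into per-use dicts before any tensor conversion and normalization.

    # hypothetical helper; assumes items = {dataset_name: {use: array, ...}, ...}
    def combine_by_use(items):
        result = {}
        for name, item in items.items():
            for use, arr in item.items():
                result.setdefault(use, {})[name] = arr
        return result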