Commit

dataset object, preliminary version
sethmcg committed Nov 22, 2024
1 parent 0d5c7a2 commit 67fcc6a
Showing 1 changed file with 99 additions and 27 deletions.

credit/data_downscaling.py
@@ -1,32 +1,61 @@
# system tools
import os
from typing import Optional, Callable, TypedDict, Union, List, Dict
from dataclasses import dataclass, field
from functools import reduce
from glob import glob
from itertools import repeat
from timeit import timeit

# data utils
import numpy as np
import xarray as xr

# Pytorch utils
import torch
import torch.utils.data
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler


Array = Union[np.ndarray, xr.DataArray]
IMAGE_ATTR_NAMES = ('historical_ERA5_images', 'target_ERA5_images')


# def get_forward_data(filename) -> xr.DataArray:
# """Lazily opens a netcdf file or zarr store as xr.DataArray
# """
# if filename[-3:] == '.nc' or filename[-4:] == '.nc4':
# dataset = xr.open_dataset(filename)
# else:
# dataset = xr.open_zarr(filename, consolidated=True)
# return dataset

def get_forward_data(filename) -> xr.DataArray:
    """Lazily opens a Zarr store
    """
    dataset = xr.open_zarr(filename, consolidated=True)
    return dataset


def flatten(array):
    """ flattens a list-of-lists
    """
    return reduce(lambda a, b: a+b, array)
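
# Example (illustrative, not part of this commit; the variable names are
# made up): flatten() concatenates with list.__add__, so every element of
# its argument must itself be a list:
#     flatten([["t2m"], ["u10", "v10"]])  # -> ["t2m", "u10", "v10"]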


# ## no longer needed for everything in the same file?
# def lazymerge(zlist, rename=None):
# """ merges zarr stores opened lazily with get_forward_data()
# """
# zarrs = [get_forward_data(z) for z in zlist]
# if rename is not None:
# oldname = flatten([list(z.keys()) for z in zarrs])
# # ^^ this will break on multi-var zarr stores
# zarrs = [z.rename_vars({old: new}) for z, old, new in zip(zarrs, oldname, rename)]
# return xr.merge(zarrs)


class Sample(TypedDict):
"""Simple class for structuring data for the ML model.
    x = predictor (input) data
    y = predictand (target) data
Using typing.TypedDict gives us several advantages:
1. Single 'source of truth' for the type and documentation of each example.
@@ -40,24 +69,6 @@ class Sample(TypedDict):
y: Array
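
# Example (illustrative, not part of this commit): a Sample is a plain
# dict at runtime, so instances are built with ordinary dict syntax:
#     sample: Sample = {"x": np.zeros((2, 64, 64)),
#                       "y": np.zeros((1, 64, 64))}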



# using dataclass decorator avoids lots of self.x=x and gets us free __repr__
@dataclass
class CONUS404Dataset(torch.utils.data.Dataset):
@@ -230,3 +241,64 @@ def testC4loader():
print(testvars)
cmd = 'c4 = CONUS404Dataset("'+src+'",varnames='+str(testvars)+')'
print(cmd+"\t"+str(timeit(cmd, globals=globals(), number=1)))


#####################

@dataclass
class DownscalingDataset(torch.utils.data.Dataset):
''' pass **conf['data'] as arguments to constructor
'''
rootpath: str
history_len: int = 2
forecast_len: int = 1
    first_date: Optional[str] = None
    last_date: Optional[str] = None
datasets: Dict = field(default_factory=dict)

def __post_init__(self):
super().__init__()

## replace the datasets dict (which holds configurations for
## the various DataMaps in the dataset) with actual DataMap
        ## objects initialized from those configurations. Need to pop
## datasets from __dict__ because we need to update each one
## with the other class attributes (which are common to all
## datasets) first.

dmap_configs = self.__dict__.pop("datasets")

        ## error if length not > 1
        if not len(dmap_configs) > 1:
            raise ValueError("DownscalingDataset needs more than one dataset config")

        ## build the DataMaps into a local dict first, so that self.__dict__
        ## holds only the common attributes while the configs are updated
        datasets = dict()
        for k in dmap_configs.keys():
            dmap_configs[k].update(self.__dict__)
            datasets[k] = DataMap(**dmap_configs[k])
        self.datasets = datasets

        dlengths = [len(d) for d in self.datasets.values()]
        self.len = np.max(dlengths)
        # error if any dlengths != self.len or 1
        if any(dl not in (1, self.len) for dl in dlengths):
            raise ValueError("DataMap lengths must all equal the max or 1: " + str(dlengths))
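
    # Usage sketch (illustrative, not part of this commit; the config layout
    # is an assumption based on the docstring's **conf['data'] convention):
    #     dset = DownscalingDataset(**conf["data"])
    #     sample = dset[0]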

def __getitem__(self, index):
        items = {k: ds[index] for k, ds in self.datasets.items()}

## okay, need to put static into a dict with key 'static'



        ## combine results by use:
        # result = dict()
        # for use in uses:
        #     result[use] = dict()
        #     for key in items:
        #         result[use][key] = items[key][use]


# transforms to tensor
# (includes unstacking z-dim)
# applies normalization
# (and any other transformations)
pass

# actually, does it do tensor transformation? Or do we write a
# ToTensor object that takes a dict of variables & does it?
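
# A possible shape for that ToTensor object (an illustrative sketch, not
# part of this commit; the {use: {varname: array}} sample layout is an
# assumption based on the combine-by-use notes in __getitem__ above):

class ToTensor:
    """Converts a {use: {varname: np.ndarray}} sample to torch tensors."""

    def __call__(self, sample):
        # torch.as_tensor avoids a copy when the array is already contiguous
        return {use: {name: torch.as_tensor(np.ascontiguousarray(arr))
                      for name, arr in variables.items()}
                for use, variables in sample.items()}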
