
Commit

read variables & times based on mode
sethmcg committed Dec 2, 2024
1 parent 67fcc6a commit 7b44959
Showing 2 changed files with 47 additions and 36 deletions.
5 changes: 0 additions & 5 deletions credit/data_downscaling.py
@@ -282,10 +282,6 @@ def __post_init__(self):
def __getitem__(self, index):
items = {k:self.datasets[k][index] for k in self.datasets.keys()}

## okay, need to put static into a dict with key 'static'



## combine results by use
# result = dict()
# loop on use:
@@ -295,7 +291,6 @@ def __getitem__(self, index):


# transforms to tensor
# (includes unstacking z-dim)
# applies normalization
# (and any other transformations)
pass
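
The commented-out body above is still a stub. As a rough illustration of the "combine results by use" step it describes, a sketch like the following would merge the per-DataMap dicts (the helper name and dict layout are assumptions, not the repository's implementation):

    # Hypothetical sketch only: merge the {use: {var: array}} dicts returned
    # by each DataMap into one result dict; "dates" carries the time window
    # rather than variable arrays, so it is passed through unmerged.
    def combine_by_use(items):
        result = {}
        for name, data in items.items():
            for use, variables in data.items():
                if use == "dates":
                    result["dates"] = variables  # same window for every dataset
                else:
                    result.setdefault(use, {}).update(variables)
        return result
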
78 changes: 47 additions & 31 deletions credit/datamap.py
@@ -185,6 +185,8 @@ class DataMap:
forecast_len: number of output timesteps
first_date: restrict dataset to timesteps >= this point in time
last_date: restrict dataset to timesteps <= this point in time
mode: which variables to return by use type:
"train" = all; "init" = all but diagnostic; "infer" = static + boundary
first_date and last_date default to None, which means use the
first/last timestep in the dataset. Note that they must be
@@ -204,6 +206,7 @@ class DataMap:
forecast_len: int = 1
first_date: str = None
last_date: str = None
mode: str = "train"

def __post_init__(self):
super().__init__()
@@ -272,16 +275,6 @@ def __post_init__(self):

nc0.close()

## get last timestep index in each file
## do this in a loop to avoid many-many open filehandles at once
file_lens = list()
for f in self.filepaths:
ncf = nc.Dataset(f)
file_lens.append(len(ncf.variables["time"]))
ncf.close()

self.ends = list(np.cumsum(file_lens) - 1)

if(self.first_date is None):
self.first = 0
else:
@@ -293,6 +286,20 @@ def __post_init__(self):
self.last = self.date2tindex(self.last_date)

self.length = self.last - self.first + 1 - (self.sample_len - 1)

## get last timestep index in each file
## do this in a loop to avoid many-many open filehandles at once

self.ends = list()
cumlen = -1
for f in self.filepaths:
ncf = nc.Dataset(f)
cumlen = cumlen + len(ncf.variables["time"])
self.ends.append(cumlen)
ncf.close()
## file opens are slow; stop early if possible
if cumlen > self.last:
break

# end of __post_init__
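
The loop above replaces the old eager scan with one that stops as soon as the cumulative length covers self.last, which is also why it had to move below the first/last computation that the early-stop test depends on. The same pattern as a standalone sketch (the function name and signature are illustrative; assumes the module's netCDF4-as-nc import):

    import netCDF4 as nc

    def file_end_indices(filepaths, last):
        """Last global timestep index contained in each file, in file order.

        Stops scanning once `last` is covered: opening netCDF files is slow,
        and files past the last usable timestep can never be needed.
        """
        ends = []
        cumlen = -1
        for path in filepaths:
            with nc.Dataset(path) as ds:
                cumlen += len(ds.variables["time"])
            ends.append(cumlen)
            if cumlen > last:
                break
        return ends
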

@@ -326,14 +333,16 @@ def __len__(self):

def __getitem__(self, index):
if self.dim == "static":
return {static:self.data}
return {"static":self.data}

# error if index is not int
# error if index > length-1
# error if index < 0 - does not support direct slicing / negative indexing

start = index + self.first
finish = start + self.sample_len - 1
if index < 0 or index > self.length-1:
raise(IndexError())

start = index + self.first + 1
if self.mode == "train":
finish = start + self.sample_len - 1
else:
finish = start + self.history_len - 1

# get segment (which file) and subindex (within file) for start & finish
# subindexes are all negative, but that works fine & makes math simpler
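
The rewritten indexing above first bounds-checks the index, then sizes the window by mode: the full sample for training, history steps only for initialization or inference. A minimal restatement of the arithmetic (assuming sample_len spans history plus forecast steps; endpoints inclusive):

    def sample_window(index, first, sample_len, history_len, mode):
        start = index + first + 1
        if mode == "train":
            finish = start + sample_len - 1   # history + forecast timesteps
        else:
            finish = start + history_len - 1  # init/infer: history only
        return start, finish

For example, index=0, first=0, sample_len=3, history_len=2 gives (1, 3) in train mode and (1, 2) otherwise.
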
@@ -355,22 +364,31 @@ def __getitem__(self, index):
a1 = data1[use][var]
a2 = data2[use][var]
result[use][var] = np.concatenate((a1,a2))


result["dates"] = {"start":self.sindex2date(start-self.first),
"finish":self.sindex2date(finish-self.first),
}
return result
pass

## If needed for speed / efficiency, we could add a "mode"
## attribute to the DataMap that read() would use to decide which
## variables to read in:
## training = everything
## initialize = boundary & prognostic (skip diagnostic)
## inference = boundary vars only

def read(self, segment, start, finish):
'''open file & read data from start to finish for each variable in varlist'''
'''open file & read data from start to finish for needed variables'''

# Note: static DataMaps never call read; they short-circuit in getitem
match self.mode:
case "train":
uses = ("boundary","prognostic","diagnostic")
case "init":
uses = ("boundary","prognostic")
case "infer":
uses = ("boundary",)
case _:
raise ValueError("invalid DataMap mode")

ds = nc.Dataset(self.filepaths[segment])
data = dict()
for use in self.vardict.keys():
for use in uses:
data[use] = dict()
for var in self.vardict[use]:
if self.dim == "3D":
@@ -386,9 +404,7 @@ def read(self, segment, start, finish):
data[use][var] = ds[var][start:finish,:,:]
ds.close()
return data

## normalization (except for static) & structural transformations
## (split to hist/fore, unstack z, concatenate variables to
## tensor) happen in parent class. DataMap just gets you data
## from file(s).

## todo: normalization
## todo: splitting to input / target
## concatenation to tensor happens in dataset class
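
The match statement in read() above is the concrete form of the mode semantics documented in the class docstring ("train" = all uses, "init" = all but diagnostic, "infer" = boundary only, with static handled in __getitem__). An equivalent table-driven sketch, in case the mapping ever needs to live in one importable place (the constant and helper are illustrations, not part of the module):

    # Mode -> variable use types read from file; "static" never reaches
    # read() because static DataMaps short-circuit in __getitem__.
    USES_BY_MODE = {
        "train": ("boundary", "prognostic", "diagnostic"),
        "init":  ("boundary", "prognostic"),
        "infer": ("boundary",),
    }

    def uses_for_mode(mode):
        try:
            return USES_BY_MODE[mode]
        except KeyError:
            raise ValueError("invalid DataMap mode")

A dict lookup behaves identically to the match statement and fails the same way on unknown modes.
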
