-
Notifications
You must be signed in to change notification settings - Fork 3
quick start
Mingli Yuan edited this page Mar 23, 2021
·
3 revisions
pip install wxbtool
Prepare the .env file in your current working directory
echo 'WXBHOME=/YOUR-WEATHERBENCH_HOME/' > .env
Run the following command to start the trainer:
wxb train
The trainer will then train the default model, which predicts t850 three days ahead.
The default model is based on ResUNet. It is a Python module that can be separated into a setting part and a model part:
class SettingWeyn(Setting):
    """Base setting: resolution, vertical levels, variables and year splits."""

    def __init__(self):
        super().__init__()

        # Spatial resolution of the model.
        self.resolution = '5.625deg'

        # Vertical levels to select, and how many of them were selected.
        self.levels = ['300', '500', '700', '850', '1000']
        self.height = len(self.levels)

        # Names of the variables to select; covers both the input features
        # and the output.
        self.vars = [
            'geopotential',
            'toa_incident_solar_radiation',
            '2m_temperature',
            'temperature',
        ]

        # Variable codes used as input features.
        self.vars_in = ['z500', 'z1000', 'tau', 't850', 't2m', 'tisr']

        # Variable codes produced as output.
        self.vars_out = ['t850']

        # Temporal scope for training: every year from 1980 through 2014.
        self.years_train = list(range(1980, 2015))

        # Temporal scope for evaluation.
        self.years_eval = [2015, 2016]

        # Temporal scope for testing.
        self.years_test = [2017, 2018]
class Setting3d(SettingWeyn):
    """Temporal setting whose prediction starts 72 hours after the input span."""

    def __init__(self):
        super().__init__()

        # Length, in hours, of one temporal step by which all features are
        # organized.
        self.step = 8

        # Number of steps that make up one input.
        self.input_span = 3

        # Number of steps that make up one prediction.
        self.pred_span = 1

        # Hours between the end of the input span and the beginning of the
        # prediction span (three days).
        self.pred_shift = 3 * 24
class Setting5d(SettingWeyn):
    """Temporal setting whose prediction starts 120 hours after the input span."""

    def __init__(self):
        super().__init__()

        # Length, in hours, of one temporal step by which all features are
        # organized.
        self.step = 8

        # Number of steps that make up one input.
        self.input_span = 3

        # Number of steps that make up one prediction.
        self.pred_span = 1

        # Hours between the end of the input span and the beginning of the
        # prediction span (five days).
        self.pred_shift = 5 * 24
class ResUNetModel(Spec):
    """Model spec that wraps a ResUNet and returns a 't850' prediction."""

    def __init__(self, setting):
        """Build the ResUNet backbone from the given setting.

        Args:
            setting: setting object; ``input_span`` and ``vars`` are read
                here to size the network input.
        """
        super().__init__(setting)
        self.name = 't850d3sm-weyn'

        # Number of input channels fed to the network.
        # NOTE(review): `len(setting.vars) + 2` and the trailing `+ 2` are
        # taken as-is; presumably they account for extra feature and
        # augmented-constant channels -- confirm against get_inputs /
        # get_augmented_constant.
        channels_in = (
            setting.input_span * (len(setting.vars) + 2)
            + self.constant_size
            + 2
        )

        # The spatial width is 64 + 2 because forward() adds one wrap-around
        # column on each side of the last dimension before calling the net.
        self.resunet = resunet(
            channels_in,
            1,
            spatial=(32, 64+2),
            layers=5,
            ratio=-1,
            vblks=[2, 2, 2, 2, 2],
            hblks=[1, 1, 1, 1, 1],
            scales=[-1, -1, -1, -1, -1],
            factors=[1, 1, 1, 1, 1],
            block=SEBottleneck,
            relu=CappingRelu(),
            final_normalized=False,
        )

    def forward(self, **kwargs):
        """Run the network on the given variables and return the prediction.

        Args:
            **kwargs: named input tensors; must contain 'temperature', whose
                first dimension is used as the batch size.

        Returns:
            A dict with the single key 't850' holding the network output
            with the two padding columns removed.
        """
        n_batch = kwargs['temperature'].size()[0]
        self.update_da_status(n_batch)

        # Only the second element returned by get_inputs is used here.
        _, features = self.get_inputs(**kwargs)
        constant = self.get_augmented_constant(features)
        stacked = th.cat((features, constant), dim=1)

        # Wrap-around padding along the last dimension: column 63 is copied
        # to the left edge and column 0 to the right edge.
        # NOTE(review): presumably the last dimension is longitude with 64
        # grid points -- confirm.
        left_edge = stacked[:, :, :, 63:64]
        right_edge = stacked[:, :, :, 0:1]
        padded = th.cat((left_edge, stacked, right_edge), dim=3)

        prediction = self.resunet(padded)

        # Drop the padding columns again, keeping columns 1..64.
        return {
            't850': prediction[:, :, :, 1:65]
        }