Skip to content

quick start

Mingli Yuan edited this page Mar 23, 2021 · 3 revisions

Install

pip install wxbtool

Train the default model

Prepare the .env file in your current working directory

echo 'WXBHOME=/YOUR-WEATHERBENCH_HOME/' > .env

Run the following command to start the trainer

wxb train

The trainer will then train the default model to predict t850 three days ahead

Understanding the default model

The default model is based on ResUNet. It is a Python module that can be divided into a setting part and a model part

setting part

class SettingWeyn(Setting):
    """Shared base setting: resolution, vertical levels, variables and year splits."""

    def __init__(self):
        super().__init__()

        # Spatial resolution of the model
        self.resolution = '5.625deg'

        # Vertical levels to choose, and how many of them there are
        self.levels = ['300', '500', '700', '850', '1000']
        self.height = len(self.levels)

        # Names of the variables to choose, for both input features and output
        self.vars = [
            'geopotential',
            'toa_incident_solar_radiation',
            '2m_temperature',
            'temperature',
        ]

        # Codes of the variables used as input features
        self.vars_in = ['z500', 'z1000', 'tau', 't850', 't2m', 'tisr']
        # Codes of the variables used as output
        self.vars_out = ['t850']

        # Temporal scope for training: every year from 1980 through 2014
        self.years_train = list(range(1980, 2015))
        # Temporal scope for evaluation
        self.years_eval = [2015, 2016]
        # Temporal scope for testing
        self.years_test = [2017, 2018]


class Setting3d(SettingWeyn):
    """Setting whose prediction span starts 72 hours (3 days) after the input span ends."""

    def __init__(self):
        super().__init__()

        # Length in hours of one temporal step; all features are organized on this step
        self.step = 8
        # Number of temporal steps that make up one input
        self.input_span = 3
        # Number of temporal steps that make up one prediction
        self.pred_span = 1
        # Hours between the end of the input span and the start of the prediction span
        self.pred_shift = 3 * 24


class Setting5d(SettingWeyn):
    """Setting whose prediction span starts 120 hours (5 days) after the input span ends."""

    def __init__(self):
        super().__init__()

        # Length in hours of one temporal step; all features are organized on this step
        self.step = 8
        # Number of temporal steps that make up one input
        self.input_span = 3
        # Number of temporal steps that make up one prediction
        self.pred_span = 1
        # Hours between the end of the input span and the start of the prediction span
        self.pred_shift = 5 * 24

model part

class ResUNetModel(Spec):
    """Model spec that wraps a ResUNet and returns its result under the key 't850'."""

    def __init__(self, setting):
        super().__init__(setting)
        self.name = 't850d3sm-weyn'

        # Number of input channels, derived from the setting and the constant size
        in_channels = setting.input_span * (len(setting.vars) + 2) + self.constant_size + 2
        self.resunet = resunet(
            in_channels, 1,
            spatial=(32, 64+2), layers=5, ratio=-1,
            vblks=[2, 2, 2, 2, 2], hblks=[1, 1, 1, 1, 1],
            scales=[-1, -1, -1, -1, -1], factors=[1, 1, 1, 1, 1],
            block=SEBottleneck, relu=CappingRelu(), final_normalized=False,
        )

    def forward(self, **kwargs):
        """Run the network on the keyword inputs and return ``{'t850': tensor}``.

        ``kwargs`` must contain a 'temperature' entry; its first dimension is
        taken as the batch size.
        """
        self.update_da_status(kwargs['temperature'].size()[0])

        _, features = self.get_inputs(**kwargs)
        # Append the augmented constant along dim 1
        features = th.cat((features, self.get_augmented_constant(features)), dim=1)

        # Wrap-around padding on the last axis: column 63 goes in front and
        # column 0 at the end, widening the axis by two columns
        # NOTE(review): presumably the last axis is longitude with 64 columns - confirm
        left_edge = features[:, :, :, 63:64]
        right_edge = features[:, :, :, 0:1]
        padded = th.cat((left_edge, features, right_edge), dim=3)

        prediction = self.resunet(padded)

        # Keep columns 1..64, which drops the two padding columns provided the
        # network preserves the padded width (TODO confirm)
        return {'t850': prediction[:, :, :, 1:65]}
Clone this wiki locally