Skip to content

quick start

Mingli Yuan edited this page Mar 23, 2021 · 3 revisions


pip install wxbtool

Train the default model

Prepare a `.env` file in your current working directory.


Run the following command to start the trainer:

wxb train

The trainer will then train the default model to predict t850 (temperature at the 850 level) three days ahead.

Understanding the default model

The default model is based on ResUNet. It is a Python module that can be divided into a setting part and a model part.

Setting part

class SettingWeyn(Setting):
    """Data settings shared by the default model's 3-day and 5-day variants.

    Declares the spatial resolution, the vertical levels, the variables used
    for input features and output, and the years used for the train / eval /
    test splits. Temporal sampling (step, spans, shift) is set by subclasses.
    """

    def __init__(self):
        super().__init__()
        self.resolution = '5.625deg'    # The spatial resolution of the model

        self.levels = ['300', '500', '700', '850', '1000']  # Which vertical levels to use
        self.height = len(self.levels)                      # How many vertical levels are used

        # The names of variables to load, for both input features and output
        self.vars = ['geopotential', 'toa_incident_solar_radiation', '2m_temperature', 'temperature']

        # The codes of variables in the input features
        self.vars_in = ['z500', 'z1000', 'tau', 't850', 't2m', 'tisr']
        # The codes of variables in the output
        self.vars_out = ['t850']

        # Years used for training
        self.years_train = [
            1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989,
            1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999,
            2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009,
            2010, 2011, 2012, 2013, 2014,
        ]
        # Years used for evaluation
        self.years_eval = [2015, 2016]
        # Years used for testing
        self.years_test = [2017, 2018]

class Setting3d(SettingWeyn):
    """SettingWeyn with temporal sampling for a 72-hour (3-day) forecast."""

    def __init__(self):
        # Needed so the inherited resolution/levels/vars/years_* attributes are set.
        super().__init__()
        self.step = 8                   # Hours between consecutive time steps of the sampled sequence
        self.input_span = 3             # Number of time steps fed to the model as input
        self.pred_span = 1              # Number of time steps to predict
        self.pred_shift = 72            # Hours between the end of the input span and the start of the prediction span

class Setting5d(SettingWeyn):
    """SettingWeyn with temporal sampling for a 120-hour (5-day) forecast."""

    def __init__(self):
        # Needed so the inherited resolution/levels/vars/years_* attributes are set.
        super().__init__()
        self.step = 8                   # Hours between consecutive time steps of the sampled sequence
        self.input_span = 3             # Number of time steps fed to the model as input
        self.pred_span = 1              # Number of time steps to predict
        self.pred_shift = 120           # Hours between the end of the input span and the start of the prediction span

Model part

class ResUNetModel(Spec):
    """ResUNet-based model that predicts t850 from the inputs described by ``setting``.

    The last (longitude) axis is padded by one column on each side with
    wrap-around values before running the network, and the padding columns
    are cropped from the output afterwards.
    """

    def __init__(self, setting):
        super().__init__(setting)
        self.name = 't850d3sm-weyn'
        # Input channels: input_span frames of (len(vars) + 2) channels each,
        # plus the constant fields and two more channels; the meaning of the
        # two "+ 2" terms is defined by Spec (NOTE(review): confirm there).
        self.resunet = resunet(setting.input_span * (len(setting.vars) + 2) + self.constant_size + 2, 1,
                               spatial=(32, 64 + 2), layers=5, ratio=-1,
                               vblks=[2, 2, 2, 2, 2], hblks=[1, 1, 1, 1, 1],
                               scales=[-1, -1, -1, -1, -1], factors=[1, 1, 1, 1, 1],
                               block=SEBottleneck, relu=CappingRelu(), final_normalized=False)

    def forward(self, **kwargs):
        import torch as th

        _, x = self.get_inputs(**kwargs)
        constant = self.get_augmented_constant(x)
        # Append the constant fields as extra channels.
        x = th.cat((x, constant), dim=1)
        # Wrap-around padding on the last axis: prepend the last column and
        # append the first one (64 -> 66 columns, matching spatial=(32, 64 + 2)).
        x = th.cat((x[:, :, :, -1:], x, x[:, :, :, :1]), dim=3)

        output = self.resunet(x)

        # Crop the two padding columns again.
        output = output[:, :, :, 1:-1]
        return {
            't850': output
        }
Clone this wiki locally