|
1 | 1 | import numpy as np
|
2 |
| -from typing import List |
| 2 | +from typing import List, Tuple |
3 | 3 | from statsmodels.tsa.ar_model import AutoReg
|
4 | 4 |
|
5 | 5 | from .markov_chain import MarkovChain
|
@@ -104,31 +104,60 @@ def _learn_models(
|
104 | 104 | X = ts_data[_regime_sequence == self.regimes[i]]
|
105 | 105 | self.models[self.regimes[i]] = AutoReg(X, lags=lags).fit()
|
106 | 106 |
|
107 |
def predict(
    self, start_regime: str, steps: int = 1
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Predicts the target values for a given number of steps into the future.

    A regime path is first simulated with the Markov chain. Every step of
    that path is then filled with a forecast from the autoregressive model
    of the regime active at that step. Each regime model advances along its
    own timeline: the k-th time a regime occurs in the simulated path, the
    k-th out-of-sample forecast of that regime's model is used, so repeated
    visits to a regime build on its earlier forecasts instead of returning
    the same one-step-ahead value every time.

    Parameters
    ----------
    start_regime: str
        Regime at the start of the prediction.
    steps: int, optional
        Number of steps into the future to predict, by default 1.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Tuple containing the array of predicted target values (float32, one
        entry per step) and the predicted regime sequence.

    Raises
    ------
    KeyError
        If the simulated path contains a regime without a fitted model.
    """
    predictions = np.zeros(steps, dtype=np.float32)
    regime_predictions = self._markov_chain.simulate(start_regime, steps)

    # Group the step indices by regime so each model is forecast exactly
    # once, for as many steps as the regime is visited.
    steps_by_regime = {}
    for step, regime in enumerate(regime_predictions):
        steps_by_regime.setdefault(regime, []).append(step)

    for regime, step_indices in steps_by_regime.items():
        # A multi-step forecast feeds earlier forecasts back in as lagged
        # values, which is the recursion a hand-rolled lag window would
        # have to reproduce (including trend terms and non-contiguous lags).
        forecast = self.models[regime].forecast(steps=len(step_indices))
        predictions[step_indices] = np.asarray(forecast, dtype=np.float32)

    return predictions, np.array(regime_predictions)
|
133 | 162 |
|
134 | 163 | def evaluate(self, ts_test, ts_pred):
|
|
0 commit comments