Skip to content

Commit 5d0a1b6

Browse files
authored
Merge pull request #16 from aadya940/msm-predict
[MNT] Update MarkovSwitchingModels `predict` method for lags
2 parents 3b980fe + c429a0d commit 5d0a1b6

File tree

1 file changed

+42
-13
lines changed

1 file changed

+42
-13
lines changed

chainopy/markov_switching.py

+42-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from typing import List
2+
from typing import List, Tuple
33
from statsmodels.tsa.ar_model import AutoReg
44

55
from .markov_chain import MarkovChain
@@ -104,31 +104,60 @@ def _learn_models(
104104
X = ts_data[_regime_sequence == self.regimes[i]]
105105
self.models[self.regimes[i]] = AutoReg(X, lags=lags).fit()
106106

107-
def predict(self, start_regime: str, steps: int = 1) -> np.ndarray:
107+
def predict(
108+
self, start_regime: str, steps: int = 1
109+
) -> Tuple[np.ndarray, np.ndarray]:
108110
"""
109-
Predicts the target values for given number of steps into the future.
111+
Predicts the target values for a given number of steps into the future.
110112
111113
Parameters
112114
----------
113-
start_regime : str
114-
Regime at the start of the prediction
115-
steps : int, optional
115+
start_regime: str
116+
Regime at the start of the prediction.
117+
steps: int, optional
116118
Number of steps into the future to predict, by default 1.
117119
118120
Returns
119121
-------
120-
np.ndarray
121-
Array of predicted Target Values for each feature for each step.
122+
Tuple[np.ndarray, np.ndarray]
123+
Tuple containing the array of predicted target values and the predicted regime sequence.
122124
"""
123125
predictions = np.zeros(steps, dtype=np.float32)
124-
regime_predictions = []
125-
current_regime = start_regime
126-
regime_predictions = self._markov_chain.simulate(current_regime, steps)
126+
regime_predictions = self._markov_chain.simulate(start_regime, steps)
127+
128+
# Initialize current values for each regime
129+
current_values = {
130+
regime: list(
131+
self.models[regime].model.endog[
132+
-len(self.models[regime].model.ar_lags) :
133+
]
134+
)
135+
for regime in self.regimes
136+
}
137+
127138
for i, regime in enumerate(regime_predictions):
128-
_model = self.models[regime]
129-
prediction = _model.model.predict(_model.params)[-1]
139+
model = self.models[regime]
140+
available_lags = len(current_values[regime])
141+
142+
if available_lags < len(model.model.ar_lags):
143+
# Use all available data if there are fewer data points than lags
144+
start_index = -available_lags
145+
prediction = model.model.predict(
146+
model.params, start=start_index, end=-1
147+
)[0]
148+
else:
149+
# Use the full lag window if enough data points are available
150+
prediction = model.predict(
151+
start=len(model.model.endog), end=len(model.model.endog)
152+
)[0]
153+
130154
predictions[i] = prediction
131155

156+
# Update current values for the regime with the new prediction
157+
current_values[regime].append(prediction)
158+
if len(current_values[regime]) > len(model.model.ar_lags):
159+
current_values[regime].pop(0)
160+
132161
return predictions, np.array(regime_predictions)
133162

134163
def evaluate(self, ts_test, ts_pred):

0 commit comments

Comments
 (0)