Skip to content

Commit

Permalink
fix bilinear basis for miso systems
Browse files Browse the repository at this point in the history
  • Loading branch information
wilsonrljr committed Feb 2, 2025
1 parent bc87dd5 commit 46bed4d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 19 deletions.
32 changes: 21 additions & 11 deletions sysidentpy/basis_function/_bilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings

from itertools import combinations_with_replacement
from itertools import combinations_with_replacement, chain
from typing import Optional
import numpy as np

Expand Down Expand Up @@ -101,19 +101,29 @@ def fit(
"In this case, you have a linear polynomial model.",
stacklevel=2,
)
else:
ny = self.get_max_ylag(ylag)
nx = self.get_max_xlag(xlag)
combination_ylag = list(
combinations_with_replacement(list(range(1, ny + 1)), self.degree)
)
combination_xlag = list(

ny = self.get_max_ylag(ylag)
combination_ylag = list(
combinations_with_replacement(list(range(1, ny + 1)), self.degree)
)
if isinstance(xlag, int):
xlag = [xlag]

combination_xlag = []
ni = 0
for lag in xlag:
nx = self.get_max_xlag(lag)
combination_lag = list(
combinations_with_replacement(
list(range(ny + 1, nx + ny + 1)), self.degree
list(range(ny + 1 + ni, nx + ny + 1 + ni)), self.degree
)
)
combinations_xy = combination_xlag + combination_ylag
combination_list = list(set(combination_list) - set(combinations_xy))
combination_xlag.append(combination_lag)
ni += nx

combination_xlag = list(chain.from_iterable(combination_xlag))
combinations_xy = combination_xlag + combination_ylag
combination_list = list(set(combination_list) - set(combinations_xy))

if predefined_regressors is not None:
combination_list = [
Expand Down
24 changes: 16 additions & 8 deletions sysidentpy/basis_function/basis_function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,29 @@ def get_max_ylag(self, ylag: int = 1):
return ny

def get_max_xlag(self, xlag: int = 1):
"""Get maximum xlag.
"""Get maximum value from various xlag structures.
Parameters
----------
xlag : ndarray of int
The range of lags according to user definition.
xlag : int, list of int, or nested list of int
Input that can be a single integer, a list, or a nested list.
Returns
-------
nx : list
Maximum value of xlag.
int
Maximum value found.
"""
nx = np.max(list(chain.from_iterable([[np.array(xlag, dtype=object)]])))
return nx
if isinstance(xlag, int): # Case 1: Single integer
return xlag

if isinstance(xlag, list):
# Case 2: Flat list of integers
if all(isinstance(i, int) for i in xlag):
return max(xlag)
# Case 3: Nested list
return max(chain.from_iterable(xlag))

raise ValueError("Unsupported data type for xlag")

def get_iterable_list(
self, ylag: int = 1, xlag: int = 1, model_type: str = "NARMAX"
Expand Down

0 comments on commit 46bed4d

Please sign in to comment.