Skip to content

Commit

Permalink
Fix mypy warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
Giulia Baldini committed May 31, 2024
1 parent d2a6413 commit 66dcbbc
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions bico/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ def _fit(
has_bico_obj = getattr(self, "bico_obj_", None)
first_call = not (partial and has_bico_obj)

X = np.array(X, dtype=np.float64, order="C", copy=False)
assert isinstance(X, np.ndarray)
_X = np.array(X, dtype=np.float64, order="C", copy=False)
assert isinstance(_X, np.ndarray)

if first_call:
self.n_features_in_ = X.shape[1]
self.n_features_in_: int = _X.shape[1]
_seed = int(time.time()) if self.random_state is None else self.random_state

# In Melanie's thesis, p = d
Expand All @@ -181,10 +181,10 @@ def _fit(
c_random_state = ctypes.c_size_t(_seed)

_DLL.init.restype = ctypes.POINTER(ctypes.c_void_p)
self.bico_obj_ = _DLL.init(c_d, c_k, c_p, c_m, c_random_state)
self.bico_obj_: Any = _DLL.init(c_d, c_k, c_p, c_m, c_random_state)

c_array = X.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
c_n = ctypes.c_uint(X.shape[0])
c_array = _X.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
c_n = ctypes.c_uint(_X.shape[0])
_DLL.addData(self.bico_obj_, c_array, c_n)

if not partial or fit_coreset:
Expand All @@ -203,7 +203,7 @@ def fit_predict(
self._fit(X, partial=False, fit_coreset=True)
return self.labels_

def predict(self, X: Sequence[Sequence[float]]) -> np.ndarray:
def predict(self, X: Sequence[Sequence[float]]) -> Any:
self._fit_coreset()

if self.coreset_estimator is None:
Expand Down

0 comments on commit 66dcbbc

Please sign in to comment.