Skip to content

Commit

Permalink
fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisholder committed Mar 7, 2025
1 parent 26c309d commit 9f2e010
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions aeon/classification/distance_based/_time_series_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from aeon.classification.base import BaseClassifier
from aeon.distances import pairwise_distance
from aeon.utils.validation import check_n_jobs
from aeon.utils._threading import threaded

WEIGHTS_SUPPORTED = ["uniform", "distance"]

Expand Down Expand Up @@ -48,11 +48,10 @@ class KNeighborsTimeSeriesClassifier(BaseClassifier):
n_timepoints)`` as input and returns a float.
distance_params : dict, default = None
Dictionary for metric parameters for the case that distance is a str.
n_jobs : int, default = None
The number of parallel jobs to run for neighbors search.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors.
for more details. Parameter for compatibility purposes, still unimplemented.
n_jobs : int, default=1
The number of jobs to run in parallel. If -1, then the number of jobs is set
to the number of CPU cores. If 1, then the function is executed in a single
thread. If greater than 1, then the function is executed in parallel.
Examples
--------
Expand Down Expand Up @@ -164,10 +163,11 @@ def _predict(self, X):
"""
self._check_is_fitted()

indexes = self.kneighbors(X, return_distance=False)[:, 0]
indexes = self.kneighbors(X, return_distance=False, n_jobs=self.n_jobs)[:, 0]
return self.classes_[self.y_[indexes]]

def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
@threaded
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, n_jobs=1):
"""Find the K-neighbors of a point.
Returns indices of and distances to the neighbors of each point.
Expand All @@ -184,6 +184,10 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
passed to the constructor.
return_distance : bool, default=True
Whether or not to return the distances.
n_jobs : int, default=1
The number of jobs to run in parallel. If -1, then the number of jobs is set
to the number of CPU cores. If 1, then the function is executed in a single
thread. If greater than 1, then the function is executed in parallel.
Returns
-------
Expand All @@ -194,8 +198,6 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
Indices of the nearest points in the population matrix.
"""
self._check_is_fitted()
n_jobs = check_n_jobs(self.n_jobs)

if n_neighbors is None:
n_neighbors = self.n_neighbors
elif n_neighbors <= 0:
Expand Down

0 comments on commit 9f2e010

Please sign in to comment.