diff --git a/activeft/sift.py b/activeft/sift.py index 22758d7..a76806b 100644 --- a/activeft/sift.py +++ b/activeft/sift.py @@ -56,6 +56,7 @@ def __init__( lambda_: float = 0.01, fast: bool = False, also_query_opposite: bool = True, + alpha: float | None = None, only_faiss: bool = False, device: torch.device | None = None, ): @@ -65,11 +66,13 @@ def __init__( :param lambda_: Value of the lambda parameter of SIFT. Ignored if `acquisition_function` is set. :param fast: Whether to use the SIFT-Fast. Ignored if `acquisition_function` is set. :param also_query_opposite: If using an inner product index, setting this to `True` will also query the opposite of the query embeddings, pre-selecting points with high *absolute* inner product. + :param alpha: Adaptive stopping criterion. Does not apply stopping criterion if set to `None`. :param only_faiss: Whether to only use Faiss for search. :param device: Device to use for computation. """ self.index = index self.also_query_opposite = also_query_opposite + self.alpha = alpha self.only_faiss = only_faiss self.device = ( device @@ -100,6 +103,11 @@ def __init__( noise_std=np.sqrt(lambda_), ) + if self.alpha is not None: + assert isinstance(self.acquisition_function, LazyVTL) or isinstance( + self.acquisition_function, VTL + ), "Adaptive SIFT can only be used with VTL acquisition function." + def search( self, query: np.ndarray, @@ -189,11 +197,22 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: acquisition_function=self.acquisition_function, device=self.device, ).next() - return ( - np.array(values), - np.array(I[i][sub_indexes]), - np.array(V[i][sub_indexes]), - ) + values = np.array(values) + indices = np.array(I[i][sub_indexes]) + embeddings = np.array(V[i][sub_indexes]) + + # Adaptive SIFT + if self.alpha is not None: + uncertainty = np.sqrt(-values) + iteration = np.arange(uncertainty.shape[0]) + 1 + stopped = uncertainty > 1 / (self.alpha * iteration) + values, indices, embeddings = ( + values[~stopped], + indices[~stopped], + embeddings[~stopped], + ) + + return values, indices, embeddings t_start = time.time() resulting_values = [] @@ -210,9 +229,12 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: resulting_embeddings.append(embeddings) t_sift = time.time() - t_start retrieval_time = RetrievalTime(faiss=t_faiss, sift=t_sift) + dtype = ( + None if self.alpha is None else object + ) # Array of adaptive SIFT might have inconsistent lengths return ( - np.array(resulting_values), - np.array(resulting_indices), - np.array(resulting_embeddings), + np.array(resulting_values, dtype=dtype), + np.array(resulting_indices, dtype=dtype), + np.array(resulting_embeddings, dtype=dtype), retrieval_time, )