Skip to content

Commit

Permalink
Adaptive SIFT (#93)
Browse files Browse the repository at this point in the history
* adaptive sift

* fix dtype
  • Loading branch information
jonhue authored Nov 23, 2024
1 parent a3ad7e5 commit 908e363
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions activeft/sift.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
lambda_: float = 0.01,
fast: bool = False,
also_query_opposite: bool = True,
alpha: float | None = None,
only_faiss: bool = False,
device: torch.device | None = None,
):
Expand All @@ -65,11 +66,13 @@ def __init__(
:param lambda_: Value of the lambda parameter of SIFT. Ignored if `acquisition_function` is set.
:param fast: Whether to use the SIFT-Fast. Ignored if `acquisition_function` is set.
:param also_query_opposite: If using an inner product index, setting this to `True` will also query the opposite of the query embeddings, pre-selecting points with high *absolute* inner product.
:param alpha: Adaptive stopping criterion. Does not apply stopping criterion if set to `None`.
:param only_faiss: Whether to only use Faiss for search.
:param device: Device to use for computation.
"""
self.index = index
self.also_query_opposite = also_query_opposite
self.alpha = alpha
self.only_faiss = only_faiss
self.device = (
device
Expand Down Expand Up @@ -100,6 +103,11 @@ def __init__(
noise_std=np.sqrt(lambda_),
)

if self.alpha is not None:
assert isinstance(self.acquisition_function, LazyVTL) or isinstance(
self.acquisition_function, VTL
), "Adaptive SIFT can only be used with VTL acquisition function."

def search(
self,
query: np.ndarray,
Expand Down Expand Up @@ -189,11 +197,22 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
acquisition_function=self.acquisition_function,
device=self.device,
).next()
return (
np.array(values),
np.array(I[i][sub_indexes]),
np.array(V[i][sub_indexes]),
)
values = np.array(values)
indices = np.array(I[i][sub_indexes])
embeddings = np.array(V[i][sub_indexes])

# Adaptive SIFT
if self.alpha is not None:
uncertainty = np.sqrt(-values)
iteration = np.arange(uncertainty.shape[0]) + 1
stopped = uncertainty > 1 / (self.alpha * iteration)
values, indices, embeddings = (
values[~stopped],
indices[~stopped],
embeddings[~stopped],
)

return values, indices, embeddings

t_start = time.time()
resulting_values = []
Expand All @@ -210,9 +229,12 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
resulting_embeddings.append(embeddings)
t_sift = time.time() - t_start
retrieval_time = RetrievalTime(faiss=t_faiss, sift=t_sift)
dtype = (
None if self.alpha is None else object
) # Array of adaptive SIFT might have inconsistent lengths
return (
np.array(resulting_values),
np.array(resulting_indices),
np.array(resulting_embeddings),
np.array(resulting_values, dtype=dtype),
np.array(resulting_indices, dtype=dtype),
np.array(resulting_embeddings, dtype=dtype),
retrieval_time,
)

0 comments on commit 908e363

Please sign in to comment.