Skip to content

Commit

Permalink
Expose Tantivy's DisjunctionMaxQuery (#244)
Browse files Browse the repository at this point in the history
Co-authored-by: Caleb Hattingh <caleb.hattingh@gmail.com>
  • Loading branch information
aecio and cjrh authored Apr 24, 2024
1 parent 7651d2b commit deb88cc
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 2 deletions.
29 changes: 28 additions & 1 deletion src/query.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::{make_term, Schema};
use pyo3::{
exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple,
exceptions,
prelude::*,
types::{PyAny, PyFloat, PyString, PyTuple},
};
use tantivy as tv;

Expand Down Expand Up @@ -151,4 +153,29 @@ impl Query {
inner: Box::new(inner),
})
}

/// Construct a Tantivy's DisjunctionMaxQuery
#[staticmethod]
pub(crate) fn disjunction_max_query(
subqueries: Vec<Query>,
tie_breaker: Option<&PyFloat>,
) -> PyResult<Query> {
let inner_queries: Vec<Box<dyn tv::query::Query>> = subqueries
.iter()
.map(|query| query.inner.box_clone())
.collect();

let dismax_query = if let Some(tie_breaker) = tie_breaker {
tv::query::DisjunctionMaxQuery::with_tie_breaker(
inner_queries,
tie_breaker.extract::<f32>()?,
)
} else {
tv::query::DisjunctionMaxQuery::new(inner_queries)
};

Ok(Query {
inner: Box::new(dismax_query),
})
}
}
5 changes: 5 additions & 0 deletions tantivy/tantivy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ class Query:
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
pass

@staticmethod
def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query:
pass


class Order(Enum):
Asc = 1
Desc = 2
Expand Down
23 changes: 22 additions & 1 deletion tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,4 +877,25 @@ def test_boolean_query(self, ram_index):
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
Query.boolean_query([
(query1, Occur.Must),
])
])

def test_disjunction_max_query(self, ram_index):
index = ram_index

# query1 should match the doc: "The Old Man and the Sea"
query1 = Query.term_query(index.schema, "title", "sea")
# query2 should matches the doc: "Of Mice and Men"
query2 = Query.term_query(index.schema, "title", "mice")
# the disjunction max query should match both docs.
query = Query.disjunction_max_query([query1, query2])

result = index.searcher().search(query, 10)
assert len(result.hits) == 2

# the disjunction max query should also take a tie_breaker parameter
query = Query.disjunction_max_query([query1, query2], tie_breaker=0.5)
result = index.searcher().search(query, 10)
assert len(result.hits) == 2

with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"):
query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5)

0 comments on commit deb88cc

Please sign in to comment.