Skip to content

Commit

Permalink
Resolved conflicts with boost query
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-au-922 committed Apr 24, 2024
2 parents 1d88abf + c74990a commit ad6b7e7
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 22 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,21 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.9]
python-version: ["3.12"]
allow-prereleases: [false]
include:
- os: ubuntu-latest
python-version: "3.12"
python-version: "3.13"
allow-prereleases: true
- os: ubuntu-latest
python-version: "3.11"
allow-prereleases: false
- os: ubuntu-latest
python-version: "3.10"
allow-prereleases: false
- os: ubuntu-latest
python-version: 3.8
allow-prereleases: false
runs-on: "${{ matrix.os }}"
steps:
- name: Harden Runner
Expand All @@ -72,6 +77,7 @@ jobs:
- uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # 5.1.0
with:
python-version: ${{ matrix.python-version }}
allow-prereleases: ${{ matrix.allow-prereleases }}

- uses: dtolnay/rust-toolchain@bb45937a053e097f8591208d8e74c90db1873d07
with:
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import nox


@nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12"])
@nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"])
def test(session):
session.install("-rrequirements-dev.txt")
session.install("-e", ".", "--no-build-isolation")
Expand Down
72 changes: 54 additions & 18 deletions src/query.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::{get_field, make_term, to_pyerr, Schema};
use pyo3::{
exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple,
exceptions,
prelude::*,
types::{PyAny, PyFloat, PyString, PyTuple},
};
use tantivy as tv;

Expand Down Expand Up @@ -135,6 +137,57 @@ impl Query {
})
}

/// Construct a Tantivy `BooleanQuery` from `(Occur, Query)` pairs.
///
/// Every pair is converted with `Occur::into`, and the boxed query held by
/// each `Query` is handed over to the new boolean query.
#[staticmethod]
#[pyo3(signature = (subqueries))]
pub(crate) fn boolean_query(
    subqueries: Vec<(Occur, Query)>,
) -> PyResult<Query> {
    // `into_iter()` yields owned `Query` values, so the boxed inner query
    // can be moved out directly. Cloning it first would only duplicate a
    // value that is dropped right afterwards.
    let dyn_subqueries = subqueries
        .into_iter()
        .map(|(occur, query)| (occur.into(), query.inner))
        .collect::<Vec<_>>();

    let inner = tv::query::BooleanQuery::from(dyn_subqueries);

    Ok(Query {
        inner: Box::new(inner),
    })
}

/// Construct a Tantivy's DisjunctionMaxQuery
#[staticmethod]
pub(crate) fn disjunction_max_query(
subqueries: Vec<Query>,
tie_breaker: Option<&PyFloat>,
) -> PyResult<Query> {
let inner_queries: Vec<Box<dyn tv::query::Query>> = subqueries
.iter()
.map(|query| query.inner.box_clone())
.collect();

let dismax_query = if let Some(tie_breaker) = tie_breaker {
tv::query::DisjunctionMaxQuery::with_tie_breaker(
inner_queries,
tie_breaker.extract::<f32>()?,
)
} else {
tv::query::DisjunctionMaxQuery::new(inner_queries)
};

Ok(Query {
inner: Box::new(dismax_query),
})
}

/// Wrap `query` in a Tantivy `BoostQuery` carrying the given `boost`.
///
/// The boxed inner query of `query` is moved into the new `BoostQuery`.
#[staticmethod]
#[pyo3(signature = (query, boost))]
pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult<Query> {
    let boosted: Box<dyn tv::query::Query> =
        Box::new(tv::query::BoostQuery::new(query.inner, boost));
    Ok(Query { inner: boosted })
}

#[staticmethod]
#[pyo3(signature = (schema, field_name, regex_pattern))]
pub(crate) fn regex_query(
Expand All @@ -153,21 +206,4 @@ impl Query {
Err(e) => Err(to_pyerr(e)),
}
}

/// Construct a Tantivy `BooleanQuery` from `(Occur, Query)` pairs.
///
/// Every pair is converted with `Occur::into`, and the boxed query held by
/// each `Query` is handed over to the new boolean query.
#[staticmethod]
#[pyo3(signature = (subqueries))]
pub(crate) fn boolean_query(
    subqueries: Vec<(Occur, Query)>,
) -> PyResult<Query> {
    // `into_iter()` yields owned `Query` values, so the boxed inner query
    // can be moved out directly. Cloning it first would only duplicate a
    // value that is dropped right afterwards.
    let dyn_subqueries = subqueries
        .into_iter()
        .map(|(occur, query)| (occur.into(), query.inner))
        .collect::<Vec<_>>();

    let inner = tv::query::BooleanQuery::from(dyn_subqueries);

    Ok(Query {
        inner: Box::new(inner),
    })
}
}
9 changes: 9 additions & 0 deletions tantivy/tantivy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,15 @@ class Query:
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
pass

@staticmethod
def disjunction_max_query(
    subqueries: Sequence[Query],
    tie_breaker: Optional[float] = None,
) -> Query:
    """Build a disjunction-max query over ``subqueries``.

    ``tie_breaker`` is optional; when omitted the query is built without one.
    """

@staticmethod
def boost_query(
    query: Query,
    boost: float,
) -> Query:
    """Return a new query that wraps ``query`` with the given ``boost`` factor."""


@staticmethod
def regex_query(
    schema: Schema,
    field_name: str,
    regex_pattern: str,
) -> Query:
    """Build a regex query for ``regex_pattern`` on the field ``field_name`` of ``schema``."""
Expand Down
120 changes: 119 additions & 1 deletion tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_parse_query_field_boosts(self, ram_index):
== """Query(BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "winter")), boost=2.3)), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
)

def test_parse_query_field_boosts(self, ram_index):
def test_parse_query_fuzzy_fields(self, ram_index):
query = ram_index.parse_query("winter", fuzzy_fields={"title": (True, 1, False)})
assert (
repr(query)
Expand Down Expand Up @@ -879,6 +879,124 @@ def test_boolean_query(self, ram_index):
(query1, Occur.Must),
])

def test_disjunction_max_query(self, ram_index):
    """Disjunction-max queries match the union of their subqueries' documents."""
    schema = ram_index.schema

    # Expected to match the doc "The Old Man and the Sea".
    sea_query = Query.term_query(schema, "title", "sea")
    # Expected to match the doc "Of Mice and Men".
    mice_query = Query.term_query(schema, "title", "mice")

    # Without a tie breaker, both documents should be returned.
    dismax = Query.disjunction_max_query([sea_query, mice_query])
    hits = ram_index.searcher().search(dismax, 10).hits
    assert len(hits) == 2

    # The optional tie_breaker keyword is accepted and matches the same docs.
    dismax = Query.disjunction_max_query(
        [sea_query, mice_query], tie_breaker=0.5
    )
    hits = ram_index.searcher().search(dismax, 10).hits
    assert len(hits) == 2

    # Anything that is not a Query in the subquery list is rejected.
    with pytest.raises(
        TypeError, match=r"'str' object cannot be converted to 'Query'"
    ):
        Query.disjunction_max_query([sea_query, "not a query"], tie_breaker=0.5)


def test_boost_query(self, ram_index):
    """Check ``Query.boost_query``: repr output, score effects and argument errors."""
    index = ram_index
    sea_term = Query.term_query(index.schema, "title", "sea")

    # Plain boost around a term query.
    boosted = Query.boost_query(sea_term, 2.0)
    expected = """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2))"""
    assert repr(boosted) == expected

    # Boost around a boolean query that itself contains a boosted clause.
    fuzzy_ice = Query.fuzzy_term_query(index.schema, "title", "ice")
    either = Query.boolean_query(
        [
            (Occur.Should, boosted),
            (Occur.Should, fuzzy_ice),
        ]
    )
    boosted = Query.boost_query(either, 2.0)
    expected = """Query(Boost(query=BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2)), (Should, FuzzyTermQuery { term: Term(field=0, type=Str, "ice"), distance: 1, transposition_cost_one: true, prefix: false })] }, boost=2))"""
    assert repr(boosted) == expected

    # Fractional boost values are kept in the repr.
    boosted = Query.boost_query(sea_term, 0.1)
    expected = """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))"""
    assert repr(boosted) == expected

    # A zero boost is accepted and every returned hit scores 0.0.
    boosted = Query.boost_query(sea_term, 0.0)
    expected = """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0))"""
    assert repr(boosted) == expected
    zero_hits = index.searcher().search(boosted, 10).hits
    for hit_score, _ in zero_hits:
        assert hit_score == pytest.approx(0.0)

    # Boosts can be nested.
    inner_boost = Query.boost_query(sea_term, 0.1)
    boosted = Query.boost_query(inner_boost, 0.1)
    expected = """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))"""
    assert repr(boosted) == expected
    nested_hits = index.searcher().search(boosted, 10).hits
    for hit_score, _ in nested_hits:
        # The underlying BM25 score is not known here, so only a loose
        # relative tolerance around 0.01 is checked.
        assert hit_score == pytest.approx(0.01, rel=1)

    # Negative boost values are accepted as well.
    boosted = Query.boost_query(sea_term, -0.1)
    expected = """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))"""
    assert repr(boosted) == expected

    negative_hits = index.searcher().search(boosted, 10).hits
    # The document must still match even though the boost is negative.
    assert len(negative_hits) == 1
    matched_titles = set()
    for hit_score, address in negative_hits:
        # With a negative boost the score is expected to be negative.
        assert hit_score < 0
        matched_titles.update(index.searcher().doc(address)["title"])
    assert matched_titles == {"The Old Man and the Sea"}

    # A non-Query first argument is rejected.
    with pytest.raises(
        TypeError, match=r"'int' object cannot be converted to 'Query'"
    ):
        Query.boost_query(1, 0.1)

    # A non-numeric boost is rejected.
    with pytest.raises(
        TypeError, match=r"argument 'boost': must be real number, not str"
    ):
        Query.boost_query(sea_term, "0.1")

    # Omitting the boost argument is rejected.
    with pytest.raises(
        TypeError,
        match=r"Query.boost_query\(\) missing 1 required positional argument: 'boost'",
    ):
        Query.boost_query(sea_term)


def test_regex_query(self, ram_index):
index = ram_index

Expand Down

0 comments on commit ad6b7e7

Please sign in to comment.