diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2bbd0630..1b4b1e58 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,16 +40,21 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: [3.9] + python-version: ["3.12"] + allow-prereleases: [false] include: - os: ubuntu-latest - python-version: "3.12" + python-version: "3.13" + allow-prereleases: true - os: ubuntu-latest python-version: "3.11" + allow-prereleases: false - os: ubuntu-latest python-version: "3.10" + allow-prereleases: false - os: ubuntu-latest python-version: 3.8 + allow-prereleases: false runs-on: "${{ matrix.os }}" steps: - name: Harden Runner @@ -72,6 +77,7 @@ jobs: - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # 5.1.0 with: python-version: ${{ matrix.python-version }} + allow-prereleases: ${{ matrix.allow-prereleases }} - uses: dtolnay/rust-toolchain@bb45937a053e097f8591208d8e74c90db1873d07 with: diff --git a/noxfile.py b/noxfile.py index 61652492..388d359d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,7 +1,7 @@ import nox -@nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12"]) +@nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]) def test(session): session.install("-rrequirements-dev.txt") session.install("-e", ".", "--no-build-isolation") diff --git a/src/query.rs b/src/query.rs index 21ced35e..716e0b8e 100644 --- a/src/query.rs +++ b/src/query.rs @@ -2,7 +2,7 @@ use crate::{make_term, Schema}; use pyo3::{ exceptions, prelude::*, - types::{PyAny, PyString, PyTuple}, + types::{PyAny, PyFloat, PyString, PyTuple}, }; use tantivy as tv; @@ -154,8 +154,33 @@ impl Query { }) } + /// Construct a Tantivy's DisjunctionMaxQuery #[staticmethod] - #[pyo3(signature = (query, boost = 1.0))] + pub(crate) fn disjunction_max_query( + subqueries: Vec, + tie_breaker: Option<&PyFloat>, + ) -> PyResult { + let inner_queries: Vec> = subqueries + .iter() + .map(|query| query.inner.box_clone()) + .collect(); + + let dismax_query = if let Some(tie_breaker) = tie_breaker { + tv::query::DisjunctionMaxQuery::with_tie_breaker( + inner_queries, + tie_breaker.extract::()?, + ) + } else { + tv::query::DisjunctionMaxQuery::new(inner_queries) + }; + + Ok(Query { + inner: Box::new(dismax_query), + }) + } + + #[staticmethod] + #[pyo3(signature = (query, boost))] pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult { let inner = tv::query::BoostQuery::new(query.inner, boost); Ok(Query { diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 0eb745b6..710358c9 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -208,11 +208,16 @@ class Query: @staticmethod def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query: pass + + @staticmethod + def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query: + pass @staticmethod - def boost_query(query: Query, boost: float = 1.0) -> Query: + def boost_query(query: Query, boost: float) -> Query: pass + class Order(Enum): Asc = 1 Desc = 2 diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 48bf60c1..7ff8544c 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -101,7 +101,7 @@ def test_parse_query_field_boosts(self, ram_index): == """Query(BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "winter")), boost=2.3)), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })""" ) - def test_parse_query_field_boosts(self, ram_index): + def test_parse_query_fuzzy_fields(self, ram_index): query = ram_index.parse_query("winter", fuzzy_fields={"title": (True, 1, False)}) assert ( repr(query) @@ -879,6 +879,28 @@ def test_boolean_query(self, ram_index): (query1, Occur.Must), ]) + def test_disjunction_max_query(self, ram_index): + index = ram_index + + # query1 should match the doc: "The Old Man and the Sea" + query1 = Query.term_query(index.schema, "title", "sea") + # query2 should matches the doc: "Of Mice and Men" + query2 = Query.term_query(index.schema, "title", "mice") + # the disjunction max query should match both docs. + query = Query.disjunction_max_query([query1, query2]) + + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + + # the disjunction max query should also take a tie_breaker parameter + query = Query.disjunction_max_query([query1, query2], tie_breaker=0.5) + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + + with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"): + query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5) + + def test_boost_query(self, ram_index): index = ram_index query1 = Query.term_query(index.schema, "title", "sea") @@ -911,13 +933,17 @@ def test_boost_query(self, ram_index): == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))""" ) - boosted_query = Query.boost_query(query1) + boosted_query = Query.boost_query(query1, 0.0) - # Check for default boost values + # Check for zero boost values assert( repr(boosted_query) - == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=1))""" + == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0))""" ) + result = index.searcher().search(boosted_query, 10) + for _score, _ in result.hits: + # the score should be 0.0 + assert _score == pytest.approx(0.0) boosted_query = Query.boost_query( Query.boost_query( @@ -965,3 +991,7 @@ def test_boost_query(self, ram_index): # wrong boost type with pytest.raises(TypeError, match = r"argument 'boost': must be real number, not str"): Query.boost_query(query1, "0.1") + + # no boost type error + with pytest.raises(TypeError, match = r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"): + Query.boost_query(query1)