From 8f5b149386da524a1b0eac4078a3704362b99db6 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Thu, 1 Feb 2024 14:18:38 +0100 Subject: [PATCH] Add field_boosts and fuzzy_fields optional parameters to Index::parse_query to expose this QueryParser functionality. --- src/index.rs | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/index.rs b/src/index.rs index 4636d24e6..5195a117b 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,5 +1,7 @@ #![allow(clippy::new_ret_no_self)] +use std::collections::HashMap; + use pyo3::{exceptions, prelude::*, types::PyAny}; use crate::{ @@ -358,14 +360,26 @@ impl Index { /// /// Args: /// query: the query, following the tantivy query language. + /// /// default_fields_names (List[Field]): A list of fields used to search if no /// field is specified in the query. /// - #[pyo3(signature = (query, default_field_names = None))] + /// field_boosts: A dictionary keyed on field names which provides default boosts + /// for the query constructed by this method. + /// + /// fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one) + /// triples making queries constructed by this method fuzzy against the given fields + /// and using the given parameters. + /// `prefix` determines if terms which are prefixes of the given term match the query. + /// `distance` determines the maximum Levenshtein distance between terms matching the query and the given term. + /// `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance. 
+ #[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))] pub fn parse_query( &self, query: &str, default_field_names: Option>, + field_boosts: HashMap, + fuzzy_fields: HashMap, ) -> PyResult { let mut default_fields = vec![]; let schema = self.index.schema(); @@ -394,8 +408,28 @@ impl Index { } } } - let parser = + let mut parser = tv::query::QueryParser::for_index(&self.index, default_fields); + + for (field_name, boost) in field_boosts { + let field = schema.get_field(&field_name).map_err(|_err| { + exceptions::PyValueError::new_err(format!( + "Field `{field_name}` is not defined in the schema." + )) + })?; + parser.set_field_boost(field, boost); + } + + for (field_name, (prefix, distance, transpose_cost_one)) in fuzzy_fields + { + let field = schema.get_field(&field_name).map_err(|_err| { + exceptions::PyValueError::new_err(format!( + "Field `{field_name}` is not defined in the schema." + )) + })?; + parser.set_field_fuzzy(field, prefix, distance, transpose_cost_one); + } + let query = parser.parse_query(query).map_err(to_pyerr)?; Ok(Query { inner: query })