Skip to content

Commit

Permalink
Merge pull request #5 from alamb/alamb/opt-gcd-moar
Browse files Browse the repository at this point in the history
Use try_binary to make gcd even faster
  • Loading branch information
jayzhan211 authored Feb 25, 2025
2 parents 1394240 + e8d36eb commit 1caec80
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions datafusion/functions/src/math/gcd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{new_null_array, ArrayRef, AsArray, Int64Array};
use arrow::array::{new_null_array, ArrayRef, AsArray, Int64Array, PrimitiveArray};
use arrow::compute::try_binary;
use arrow::datatypes::{DataType, Int64Type};
use arrow::error::ArrowError;
use std::any::Any;
use std::mem::swap;
use std::sync::Arc;

use arrow::datatypes::{DataType, Int64Type};

use datafusion_common::{
arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result,
ScalarValue,
exec_err, internal_datafusion_err, internal_err, Result, ScalarValue,
};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Expand Down Expand Up @@ -113,17 +112,20 @@ impl ScalarUDFImpl for GcdFunc {
}

fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result<ColumnarValue> {
let result: Result<Int64Array> = a
.as_primitive::<Int64Type>()
.iter()
.zip(b.as_primitive::<Int64Type>().iter())
.map(|(a, b)| match (a, b) {
(Some(a), Some(b)) => Ok(Some(compute_gcd(a, b)?)),
_ => Ok(None),
let a = a.as_primitive::<Int64Type>();
let b = b.as_primitive::<Int64Type>();
if a.len() != b.len() {
return internal_err!(
"Length of arguments for function gcd do not match: {} vs {}",
a.len(),
b.len()
);
}
try_binary(a, b, compute_gcd)
.map(|arr: PrimitiveArray<Int64Type>| {
ColumnarValue::Array(Arc::new(arr) as ArrayRef)
})
.collect();

result.map(|arr| ColumnarValue::Array(Arc::new(arr) as ArrayRef))
.map_err(Into::into) // convert ArrowError to DataFusionError
}

fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option<i64>) -> Result<ColumnarValue> {
Expand Down Expand Up @@ -171,14 +173,12 @@ pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
}

/// Computes greatest common divisor using Binary GCD algorithm.
pub fn compute_gcd(x: i64, y: i64) -> Result<i64> {
pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
let a = x.unsigned_abs();
let b = y.unsigned_abs();
let r = unsigned_gcd(a, b);
// gcd(i64::MIN, i64::MIN) = i64::MIN.unsigned_abs() cannot fit into i64
r.try_into().map_err(|_| {
arrow_datafusion_err!(ArrowError::ComputeError(format!(
"Signed integer overflow in GCD({x}, {y})"
)))
ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})"))
})
}

0 comments on commit 1caec80

Please sign in to comment.