Suggestion to reduce API surface area
alamb committed Mar 4, 2025
1 parent d19fafa commit 595750d
Showing 2 changed files with 44 additions and 38 deletions.
20 changes: 5 additions & 15 deletions datafusion/core/tests/dataframe/mod.rs
@@ -32,9 +32,7 @@ use arrow::datatypes::{
};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_batches;
use datafusion_functions_aggregate::count::{
count_all, count_all_column, count_all_window, count_all_window_column,
};
use datafusion_functions_aggregate::count::{count_all, count_all_window};
use datafusion_functions_aggregate::expr_fn::{
array_agg, avg, count, count_distinct, max, median, min, sum,
};
@@ -2797,16 +2795,6 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_count_wildcard_shema_name() {
assert_eq!(count_all().schema_name().to_string(), "count(*)");
assert_eq!(count_all_column(), col("count(*)"));
assert_eq!(
count_all_window_column(),
col("count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING")
);
}

#[tokio::test]
async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
let ctx = create_join_context()?;
Expand Down Expand Up @@ -2855,6 +2843,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
// https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
// for compare difference between sql and df logical plan, we need to create a new SessionContext here
let ctx = create_join_context()?;
let agg_expr = count_all();
let agg_expr_col = col(agg_expr.schema_name().to_string());
let df_results = ctx
.table("t1")
.await?
@@ -2863,8 +2853,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
ctx.table("t2")
.await?
.filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))?
.aggregate(vec![], vec![count_all()])?
.select(vec![count_all_column()])?
.aggregate(vec![], vec![agg_expr])?
.select(vec![agg_expr_col])?
.into_unoptimized_plan(),
))
.gt(lit(ScalarValue::UInt8(Some(0)))),
62 changes: 39 additions & 23 deletions datafusion/functions-aggregate/src/count.rs
@@ -48,13 +48,13 @@ use datafusion_common::{
downcast_value, internal_err, not_impl_err, Result, ScalarValue,
};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
col, Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition,
};
use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
};
use datafusion_expr::{
Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition,
};
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
PrimitiveDistinctCountAccumulator,
@@ -82,37 +82,53 @@ pub fn count_distinct(expr: Expr) -> Expr {
))
}

/// Creates aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
/// Alias to count(*) for backward comaptibility
/// Creates aggregation to count all rows.
///
/// In SQL this is `SELECT COUNT(*) ... `
///
/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is
/// aliased to a column named `"count(*)"` for backward compatibility.
///
/// Example
/// ```
/// # use datafusion_functions_aggregate::count::count_all;
/// # use datafusion_expr::col;
/// // create `count(*)` expression
/// let expr = count_all();
/// assert_eq!(expr.schema_name().to_string(), "count(*)");
/// // if you need to refer to this column, use the `schema_name` function
/// let expr = col(expr.schema_name().to_string());
/// ```
pub fn count_all() -> Expr {
count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)")
}

/// Creates window aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
/// Creates window aggregation to count all rows.
///
/// In SQL this is `SELECT COUNT(*) OVER (..) ... `
///
/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
///
/// Example
/// ```
/// # use datafusion_functions_aggregate::count::count_all_window;
/// # use datafusion_expr::col;
/// // create `count(*)` OVER ... window function expression
/// let expr = count_all_window();
/// assert_eq!(
/// expr.schema_name().to_string(),
/// "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
/// );
/// // if you need to refer to this column, use the `schema_name` function
/// let expr = col(expr.schema_name().to_string());
/// ```
pub fn count_all_window() -> Expr {
Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![Expr::Literal(COUNT_STAR_EXPANSION)],
))
}

/// Expr::Column(Count Wildcard Window Function)
/// Could be used in Dataframe API where you need Expr::Column of count wildcard
pub fn count_all_window_column() -> Expr {
col(Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![Expr::Literal(COUNT_STAR_EXPANSION)],
))
.schema_name()
.to_string())
}

/// Expr::Column(Count Wildcard Aggregate Function)
/// Could be used in Dataframe API where you need Expr::Column of count wildcard
pub fn count_all_column() -> Expr {
col(count_all().schema_name().to_string())
}

#[user_doc(
doc_section(label = "General Functions"),
description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",

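The updated test above shows the replacement pattern for the removed `count_all_column` and `count_all_window_column` helpers: build the COUNT(*) expression first, then derive an `Expr::Column` reference from its schema name. A minimal sketch of that pattern follows; the `count_all_refs` wrapper is illustrative only and not part of the crate.

use datafusion_expr::{col, Expr};
use datafusion_functions_aggregate::count::{count_all, count_all_window};

/// Illustrative helper: returns the COUNT(*) aggregate and window expressions
/// together with Expr::Column references derived from their schema names.
fn count_all_refs() -> (Expr, Expr, Expr, Expr) {
    // Aggregate form; `col(...)` over the schema name replaces the removed
    // `count_all_column()` helper.
    let agg_expr = count_all();
    let agg_expr_col = col(agg_expr.schema_name().to_string());

    // Window form; the same pattern replaces `count_all_window_column()`.
    let window_expr = count_all_window();
    let window_expr_col = col(window_expr.schema_name().to_string());

    (agg_expr, agg_expr_col, window_expr, window_expr_col)
}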