You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/04/15 16:40:07 UTC
[arrow-datafusion] branch main updated: update count_wildcard_rule for more scenario (#6010)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 6ed655e4dc update count_wildcard_rule for more scenario (#6010)
6ed655e4dc is described below
commit 6ed655e4dc5ad89436ef94d1e441fdcf0ce8800e
Author: zhenxing jiang <ji...@gmail.com>
AuthorDate: Sun Apr 16 00:40:01 2023 +0800
update count_wildcard_rule for more scenario (#6010)
---
datafusion/core/tests/dataframe.rs | 235 +++++++++++--
.../optimizer/src/analyzer/count_wildcard_rule.rs | 372 ++++++++++++++++++++-
datafusion/optimizer/src/test/mod.rs | 12 +
3 files changed, 576 insertions(+), 43 deletions(-)
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 4b2daa100b..82c5a35b7f 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -32,24 +32,169 @@ use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::prelude::JoinType;
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
+use datafusion::test_util::parquet_test_data;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
+use datafusion_common::ScalarValue;
use datafusion_expr::expr::{GroupingSet, Sort};
-use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
+use datafusion_expr::utils::COUNT_STAR_EXPANSION;
+use datafusion_expr::Expr::Wildcard;
+use datafusion_expr::{
+ avg, col, count, exists, expr, in_subquery, lit, max, scalar_subquery, sum,
+ AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
+ WindowFrameUnits, WindowFunction,
+};
#[tokio::test]
-async fn count_wildcard() -> Result<()> {
- let ctx = SessionContext::new();
- let testdata = datafusion::test_util::parquet_test_data();
+async fn test_count_wildcard_on_sort() -> Result<()> {
+ let ctx = create_join_context()?;
- ctx.register_parquet(
- "alltypes_tiny_pages",
- &format!("{testdata}/alltypes_tiny_pages.parquet"),
- ParquetReadOptions::default(),
- )
- .await?;
+ let sql_results = ctx
+ .sql("select b,count(*) from t1 group by b order by count(*)")
+ .await?
+ .explain(false, false)?
+ .collect()
+ .await?;
+
+ let df_results = ctx
+ .table("t1")
+ .await?
+ .aggregate(vec![col("b")], vec![count(Wildcard)])?
+ .sort(vec![count(Wildcard).sort(true, false)])?
+ .explain(false, false)?
+ .collect()
+ .await?;
+ // make sure the sql plan is the same as the df plan
+ assert_eq!(
+ pretty_format_batches(&sql_results)?.to_string(),
+ pretty_format_batches(&df_results)?.to_string()
+ );
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_where_in() -> Result<()> {
+ let ctx = create_join_context()?;
+ let sql_results = ctx
+ .sql("SELECT a,b FROM t1 WHERE a in (SELECT count(*) FROM t2)")
+ .await?
+ .explain(false, false)?
+ .collect()
+ .await?;
+ // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1
+ // https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
+ // to compare the sql and df logical plans, we need to create a new SessionContext here
+ let ctx = create_join_context()?;
+ let df_results = ctx
+ .table("t1")
+ .await?
+ .filter(in_subquery(
+ col("a"),
+ Arc::new(
+ ctx.table("t2")
+ .await?
+ .aggregate(vec![], vec![count(Expr::Wildcard)])?
+ .select(vec![count(Expr::Wildcard)])?
+ .into_unoptimized_plan(),
+ // Usually, into_optimized_plan() should be used here, but due to
+ // https://github.com/apache/arrow-datafusion/issues/5771,
+ // subqueries in SQL cannot be optimized, resulting in differences in logical_plan. Therefore, into_unoptimized_plan() is temporarily used here.
+ ),
+ ))?
+ .select(vec![col("a"), col("b")])?
+ .explain(false, false)?
+ .collect()
+ .await?;
+
+ // make sure the sql plan is the same as the df plan
+ assert_eq!(
+ pretty_format_batches(&sql_results)?.to_string(),
+ pretty_format_batches(&df_results)?.to_string()
+ );
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_where_exist() -> Result<()> {
+ let ctx = create_join_context()?;
let sql_results = ctx
- .sql("select count(*) from alltypes_tiny_pages")
+ .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)")
+ .await?
+ .explain(false, false)?
+ .collect()
+ .await?;
+ let df_results = ctx
+ .table("t1")
+ .await?
+ .filter(exists(Arc::new(
+ ctx.table("t2")
+ .await?
+ .aggregate(vec![], vec![count(Expr::Wildcard)])?
+ .select(vec![count(Expr::Wildcard)])?
+ .into_unoptimized_plan(),
+ // Usually, into_optimized_plan() should be used here, but due to
+ // https://github.com/apache/arrow-datafusion/issues/5771,
+ // subqueries in SQL cannot be optimized, resulting in differences in logical_plan. Therefore, into_unoptimized_plan() is temporarily used here.
+ )))?
+ .select(vec![col("a"), col("b")])?
+ .explain(false, false)?
+ .collect()
+ .await?;
+
+ // make sure the sql plan is the same as the df plan
+ assert_eq!(
+ pretty_format_batches(&sql_results)?.to_string(),
+ pretty_format_batches(&df_results)?.to_string()
+ );
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_window() -> Result<()> {
+ let ctx = create_join_context()?;
+
+ let sql_results = ctx
+ .sql("select COUNT(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1")
+ .await?
+ .explain(false, false)?
+ .collect()
+ .await?;
+ let df_results = ctx
+ .table("t1")
+ .await?
+ .select(vec![Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Count),
+ vec![Expr::Wildcard],
+ vec![],
+ vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
+ WindowFrame {
+ units: WindowFrameUnits::Range,
+ start_bound: WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
+ end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
+ },
+ ))])?
+ .explain(false, false)?
+ .collect()
+ .await?;
+
+ // make sure the sql plan is the same as the df plan
+ assert_eq!(
+ pretty_format_batches(&df_results)?.to_string(),
+ pretty_format_batches(&sql_results)?.to_string()
+ );
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_aggregate() -> Result<()> {
+ let ctx = create_join_context()?;
+ register_alltypes_tiny_pages_parquet(&ctx).await?;
+
+ let sql_results = ctx
+ .sql("select count(*) from t1")
.await?
.select(vec![count(Expr::Wildcard)])?
.explain(false, false)?
@@ -58,7 +203,7 @@ async fn count_wildcard() -> Result<()> {
// add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node.
let df_results = ctx
- .table("alltypes_tiny_pages")
+ .table("t1")
.await?
.aggregate(vec![], vec![count(Expr::Wildcard)])?
.select(vec![count(Expr::Wildcard)])?
@@ -72,24 +217,51 @@ async fn count_wildcard() -> Result<()> {
pretty_format_batches(&df_results)?.to_string()
);
- let results = ctx
- .table("alltypes_tiny_pages")
+ Ok(())
+}
+#[tokio::test]
+async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
+ let ctx = create_join_context()?;
+
+ let sql_results = ctx
+ .sql("select a,b from t1 where (select count(*) from t2 where t1.a = t2.a)>0;")
.await?
- .aggregate(vec![], vec![count(Expr::Wildcard)])?
+ .explain(false, false)?
.collect()
.await?;
- let expected = vec![
- "+-----------------+",
- "| COUNT(UInt8(1)) |",
- "+-----------------+",
- "| 7300 |",
- "+-----------------+",
- ];
- assert_batches_sorted_eq!(expected, &results);
+ // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1
+ // https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
+ // to compare the sql and df logical plans, we need to create a new SessionContext here
+ let ctx = create_join_context()?;
+ let df_results = ctx
+ .table("t1")
+ .await?
+ .filter(
+ scalar_subquery(Arc::new(
+ ctx.table("t2")
+ .await?
+ .filter(col("t1.a").eq(col("t2.a")))?
+ .aggregate(vec![], vec![count(lit(COUNT_STAR_EXPANSION))])?
+ .select(vec![count(lit(COUNT_STAR_EXPANSION))])?
+ .into_unoptimized_plan(),
+ ))
+ .gt(lit(ScalarValue::UInt8(Some(0)))),
+ )?
+ .select(vec![col("t1.a"), col("t1.b")])?
+ .explain(false, false)?
+ .collect()
+ .await?;
+
+ // make sure the sql plan is the same as the df plan
+ assert_eq!(
+ pretty_format_batches(&sql_results)?.to_string(),
+ pretty_format_batches(&df_results)?.to_string()
+ );
Ok(())
}
+
#[tokio::test]
async fn describe() -> Result<()> {
let ctx = SessionContext::new();
@@ -229,7 +401,7 @@ async fn sort_on_unprojected_columns() -> Result<()> {
let results = df.collect().await.unwrap();
#[rustfmt::skip]
- let expected = vec![
+ let expected = vec![
"+-----+",
"| a |",
"+-----+",
@@ -275,7 +447,7 @@ async fn sort_on_distinct_columns() -> Result<()> {
let results = df.collect().await.unwrap();
#[rustfmt::skip]
- let expected = vec![
+ let expected = vec![
"+-----+",
"| a |",
"+-----+",
@@ -417,7 +589,7 @@ async fn filter_with_alias_overwrite() -> Result<()> {
let results = df.collect().await.unwrap();
#[rustfmt::skip]
- let expected = vec![
+ let expected = vec![
"+------+",
"| a |",
"+------+",
@@ -1047,3 +1219,14 @@ async fn table_with_nested_types(n: usize) -> Result<DataFrame> {
ctx.register_batch("shapes", batch)?;
ctx.table("shapes").await
}
+
+pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Result<()> {
+ let testdata = parquet_test_data();
+ ctx.register_parquet(
+ "alltypes_tiny_pages",
+ &format!("{testdata}/alltypes_tiny_pages.parquet"),
+ ParquetReadOptions::default(),
+ )
+ .await?;
+ Ok(())
+}
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index ecd00d7ac1..ed48da7fde 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -16,14 +16,22 @@
// under the License.
use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::{Transformed, TreeNode};
-use datafusion_common::Result;
+use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
+use datafusion_common::{Column, DFField, DFSchema, DFSchemaRef, Result};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
-use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window};
+use datafusion_expr::Expr::{Exists, InSubquery, ScalarSubquery};
+use datafusion_expr::{
+ aggregate_function, count, expr, lit, window_function, Aggregate, Expr, Filter,
+ LogicalPlan, Projection, Sort, Subquery, Window,
+};
+use std::string::ToString;
+use std::sync::Arc;
use crate::analyzer::AnalyzerRule;
+pub const COUNT_STAR: &str = "COUNT(*)";
+
/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
/// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473.
#[derive(Default)]
@@ -46,35 +54,116 @@ impl AnalyzerRule for CountWildcardRule {
}
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+ let mut rewriter = CountWildcardRewriter {};
match plan {
LogicalPlan::Window(window) => {
- let window_expr = handle_wildcard(&window.window_expr);
+ let window_expr = window
+ .window_expr
+ .iter()
+ .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+ .collect::<Vec<Expr>>();
+
Ok(Transformed::Yes(LogicalPlan::Window(Window {
input: window.input.clone(),
window_expr,
- schema: window.schema,
+ schema: rewrite_schema(&window.schema),
})))
}
LogicalPlan::Aggregate(agg) => {
- let aggr_expr = handle_wildcard(&agg.aggr_expr);
+ let aggr_expr = agg
+ .aggr_expr
+ .iter()
+ .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+ .collect();
+
Ok(Transformed::Yes(LogicalPlan::Aggregate(
Aggregate::try_new_with_schema(
agg.input.clone(),
agg.group_expr.clone(),
aggr_expr,
- agg.schema,
+ rewrite_schema(&agg.schema),
)?,
)))
}
+ LogicalPlan::Sort(Sort { expr, input, fetch }) => {
+ let sort_expr = expr
+ .iter()
+ .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+ .collect();
+ Ok(Transformed::Yes(LogicalPlan::Sort(Sort {
+ expr: sort_expr,
+ input,
+ fetch,
+ })))
+ }
+ LogicalPlan::Projection(projection) => {
+ let projection_expr = projection
+ .expr
+ .iter()
+ .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+ .collect();
+ Ok(Transformed::Yes(LogicalPlan::Projection(
+ Projection::try_new_with_schema(
+ projection_expr,
+ projection.input,
+ // rewrite_schema(projection.schema.clone()),
+ rewrite_schema(&projection.schema),
+ )?,
+ )))
+ }
+ LogicalPlan::Filter(Filter {
+ predicate, input, ..
+ }) => {
+ let predicate = predicate.rewrite(&mut rewriter).unwrap();
+ Ok(Transformed::Yes(LogicalPlan::Filter(
+ Filter::try_new(predicate, input).unwrap(),
+ )))
+ }
+
_ => Ok(Transformed::No(plan)),
}
}
-// handle Count(Expr:Wildcard) with DataFrame API
-pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
- exprs
- .iter()
- .map(|expr| match expr {
+struct CountWildcardRewriter {}
+
+impl TreeNodeRewriter for CountWildcardRewriter {
+ type N = Expr;
+
+ fn mutate(&mut self, old_expr: Expr) -> Result<Expr> {
+ let new_expr = match old_expr.clone() {
+ Expr::Column(Column { name, relation }) if name.contains(COUNT_STAR) => {
+ Expr::Column(Column {
+ name: name.replace(
+ COUNT_STAR,
+ count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
+ ),
+ relation: relation.clone(),
+ })
+ }
+ Expr::WindowFunction(expr::WindowFunction {
+ fun:
+ window_function::WindowFunction::AggregateFunction(
+ aggregate_function::AggregateFunction::Count,
+ ),
+ args,
+ partition_by,
+ order_by,
+ window_frame,
+ }) if args.len() == 1 => match args[0] {
+ Expr::Wildcard => {
+ Expr::WindowFunction(datafusion_expr::expr::WindowFunction {
+ fun: window_function::WindowFunction::AggregateFunction(
+ aggregate_function::AggregateFunction::Count,
+ ),
+ args: vec![lit(COUNT_STAR_EXPANSION)],
+ partition_by,
+ order_by,
+ window_frame,
+ })
+ }
+
+ _ => old_expr,
+ },
Expr::AggregateFunction(AggregateFunction {
fun: aggregate_function::AggregateFunction::Count,
args,
@@ -84,12 +173,261 @@ pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
Expr::Wildcard => Expr::AggregateFunction(AggregateFunction {
fun: aggregate_function::AggregateFunction::Count,
args: vec![lit(COUNT_STAR_EXPANSION)],
- distinct: *distinct,
- filter: filter.clone(),
+ distinct,
+ filter,
}),
- _ => expr.clone(),
+ _ => old_expr,
},
- _ => expr.clone(),
+
+ ScalarSubquery(Subquery {
+ subquery,
+ outer_ref_columns,
+ }) => {
+ let new_plan = subquery
+ .as_ref()
+ .clone()
+ .transform_down(&analyze_internal)
+ .unwrap();
+ ScalarSubquery(Subquery {
+ subquery: Arc::new(new_plan),
+ outer_ref_columns,
+ })
+ }
+ InSubquery {
+ expr,
+ subquery,
+ negated,
+ } => {
+ let new_plan = subquery
+ .subquery
+ .as_ref()
+ .clone()
+ .transform_down(&analyze_internal)
+ .unwrap();
+
+ InSubquery {
+ expr,
+ subquery: Subquery {
+ subquery: Arc::new(new_plan),
+ outer_ref_columns: subquery.outer_ref_columns,
+ },
+ negated,
+ }
+ }
+ Exists { subquery, negated } => {
+ let new_plan = subquery
+ .subquery
+ .as_ref()
+ .clone()
+ .transform_down(&analyze_internal)
+ .unwrap();
+
+ Exists {
+ subquery: Subquery {
+ subquery: Arc::new(new_plan),
+ outer_ref_columns: subquery.outer_ref_columns,
+ },
+ negated,
+ }
+ }
+ _ => old_expr,
+ };
+ Ok(new_expr)
+ }
+}
+fn rewrite_schema(schema: &DFSchema) -> DFSchemaRef {
+ let new_fields = schema
+ .fields()
+ .iter()
+ .map(|field| {
+ let mut name = field.field().name().clone();
+ if name.contains(COUNT_STAR) {
+ name = name.replace(
+ COUNT_STAR,
+ count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
+ );
+ }
+ DFField::new(
+ field.qualifier().cloned(),
+ &name,
+ field.data_type().clone(),
+ field.is_nullable(),
+ )
})
- .collect()
+ .collect::<Vec<DFField>>();
+ DFSchemaRef::new(
+ DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(),
+ )
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::test::*;
+ use datafusion_common::{Result, ScalarValue};
+ use datafusion_expr::expr::Sort;
+ use datafusion_expr::{
+ col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder,
+ max, scalar_subquery, AggregateFunction, Expr, WindowFrame, WindowFrameBound,
+ WindowFrameUnits, WindowFunction,
+ };
+
+ fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ assert_analyzed_plan_eq_display_indent(
+ Arc::new(CountWildcardRule::new()),
+ plan,
+ expected,
+ )
+ }
+
+ #[test]
+ fn test_count_wildcard_on_sort() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(vec![col("b")], vec![count(Expr::Wildcard)])?
+ .project(vec![count(Expr::Wildcard)])?
+ .sort(vec![count(Expr::Wildcard).sort(true, false)])?
+ .build()?;
+ let expected = "Sort: COUNT(UInt8(1)) ASC NULLS LAST [COUNT(UInt8(1)):Int64;N]\
+ \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+ \n Aggregate: groupBy=[[test.b]], aggr=[[COUNT(UInt8(1))]] [b:UInt32, COUNT(UInt8(1)):Int64;N]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn test_count_wildcard_on_where_in() -> Result<()> {
+ let table_scan_t1 = test_table_scan_with_name("t1")?;
+ let table_scan_t2 = test_table_scan_with_name("t2")?;
+
+ let plan = LogicalPlanBuilder::from(table_scan_t1)
+ .filter(in_subquery(
+ col("a"),
+ Arc::new(
+ LogicalPlanBuilder::from(table_scan_t2)
+ .aggregate(Vec::<Expr>::new(), vec![count(Expr::Wildcard)])?
+ .project(vec![count(Expr::Wildcard)])?
+ .build()?,
+ ),
+ ))?
+ .build()?;
+
+ let expected = "Filter: t1.a IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
+ \n Subquery: [COUNT(UInt8(1)):Int64;N]\
+ \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+ \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\
+ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn test_count_wildcard_on_where_exists() -> Result<()> {
+ let table_scan_t1 = test_table_scan_with_name("t1")?;
+ let table_scan_t2 = test_table_scan_with_name("t2")?;
+
+ let plan = LogicalPlanBuilder::from(table_scan_t1)
+ .filter(exists(Arc::new(
+ LogicalPlanBuilder::from(table_scan_t2)
+ .aggregate(Vec::<Expr>::new(), vec![count(Expr::Wildcard)])?
+ .project(vec![count(Expr::Wildcard)])?
+ .build()?,
+ )))?
+ .build()?;
+
+ let expected = "Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
+ \n Subquery: [COUNT(UInt8(1)):Int64;N]\
+ \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+ \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\
+ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
+ let table_scan_t1 = test_table_scan_with_name("t1")?;
+ let table_scan_t2 = test_table_scan_with_name("t2")?;
+
+ let plan = LogicalPlanBuilder::from(table_scan_t1)
+ .filter(
+ scalar_subquery(Arc::new(
+ LogicalPlanBuilder::from(table_scan_t2)
+ .filter(col("t1.a").eq(col("t2.a")))?
+ .aggregate(
+ Vec::<Expr>::new(),
+ vec![count(lit(COUNT_STAR_EXPANSION))],
+ )?
+ .project(vec![count(lit(COUNT_STAR_EXPANSION))])?
+ .build()?,
+ ))
+ .gt(lit(ScalarValue::UInt8(Some(0)))),
+ )?
+ .project(vec![col("t1.a"), col("t1.b")])?
+ .build()?;
+
+ let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\
+ \n Filter: (<subquery>) > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\
+ \n Subquery: [COUNT(UInt8(1)):Int64;N]\
+ \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+ \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\
+ \n Filter: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
+ assert_plan_eq(&plan, expected)
+ }
+ #[test]
+ fn test_count_wildcard_on_window() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .window(vec![Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Count),
+ vec![Expr::Wildcard],
+ vec![],
+ vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
+ WindowFrame {
+ units: WindowFrameUnits::Range,
+ start_bound: WindowFrameBound::Preceding(ScalarValue::UInt32(Some(
+ 6,
+ ))),
+ end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
+ },
+ ))])?
+ .project(vec![count(Expr::Wildcard)])?
+ .build()?;
+
+ let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+ \n WindowAggr: windowExpr=[[COUNT(UInt8(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, COUNT(UInt8(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn test_count_wildcard_on_aggregate() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(Vec::<Expr>::new(), vec![count(Expr::Wildcard)])?
+ .project(vec![count(Expr::Wildcard)])?
+ .build()?;
+
+ let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+ \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn test_count_wildcard_on_nesting() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(Vec::<Expr>::new(), vec![max(count(Expr::Wildcard))])?
+ .project(vec![count(Expr::Wildcard)])?
+ .build()?;
+
+ let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+ \n Aggregate: groupBy=[[]], aggr=[[MAX(COUNT(UInt8(1)))]] [MAX(COUNT(UInt8(1))):Int64;N]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+ assert_plan_eq(&plan, expected)
+ }
}
diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs
index 439f44151e..67d342b4cb 100644
--- a/datafusion/optimizer/src/test/mod.rs
+++ b/datafusion/optimizer/src/test/mod.rs
@@ -121,7 +121,19 @@ pub fn assert_analyzed_plan_eq(
Ok(())
}
+pub fn assert_analyzed_plan_eq_display_indent(
+ rule: Arc<dyn AnalyzerRule + Send + Sync>,
+ plan: &LogicalPlan,
+ expected: &str,
+) -> Result<()> {
+ let options = ConfigOptions::default();
+ let analyzed_plan =
+ Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options)?;
+ let formatted_plan = format!("{}", analyzed_plan.display_indent_schema());
+ assert_eq!(formatted_plan, expected);
+ Ok(())
+}
pub fn assert_optimized_plan_eq(
rule: Arc<dyn OptimizerRule + Send + Sync>,
plan: &LogicalPlan,