You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/04/15 16:40:07 UTC

[arrow-datafusion] branch main updated: update count_wildcard_rule for more scenario (#6010)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ed655e4dc update count_wildcard_rule for more scenario (#6010)
6ed655e4dc is described below

commit 6ed655e4dc5ad89436ef94d1e441fdcf0ce8800e
Author: zhenxing jiang <ji...@gmail.com>
AuthorDate: Sun Apr 16 00:40:01 2023 +0800

    update count_wildcard_rule for more scenario (#6010)
---
 datafusion/core/tests/dataframe.rs                 | 235 +++++++++++--
 .../optimizer/src/analyzer/count_wildcard_rule.rs  | 372 ++++++++++++++++++++-
 datafusion/optimizer/src/test/mod.rs               |  12 +
 3 files changed, 576 insertions(+), 43 deletions(-)

diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 4b2daa100b..82c5a35b7f 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -32,24 +32,169 @@ use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::JoinType;
 use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
+use datafusion::test_util::parquet_test_data;
 use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
+use datafusion_common::ScalarValue;
 use datafusion_expr::expr::{GroupingSet, Sort};
-use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
+use datafusion_expr::utils::COUNT_STAR_EXPANSION;
+use datafusion_expr::Expr::Wildcard;
+use datafusion_expr::{
+    avg, col, count, exists, expr, in_subquery, lit, max, scalar_subquery, sum,
+    AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
+    WindowFrameUnits, WindowFunction,
+};
 
 #[tokio::test]
-async fn count_wildcard() -> Result<()> {
-    let ctx = SessionContext::new();
-    let testdata = datafusion::test_util::parquet_test_data();
+async fn test_count_wildcard_on_sort() -> Result<()> {
+    let ctx = create_join_context()?;
 
-    ctx.register_parquet(
-        "alltypes_tiny_pages",
-        &format!("{testdata}/alltypes_tiny_pages.parquet"),
-        ParquetReadOptions::default(),
-    )
-    .await?;
+    let sql_results = ctx
+        .sql("select b,count(*) from t1 group by b order by count(*)")
+        .await?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .aggregate(vec![col("b")], vec![count(Wildcard)])?
+        .sort(vec![count(Wildcard).sort(true, false)])?
+        .explain(false, false)?
+        .collect()
+        .await?;
+    // make sure the SQL plan is the same as the DataFrame plan
+    assert_eq!(
+        pretty_format_batches(&sql_results)?.to_string(),
+        pretty_format_batches(&df_results)?.to_string()
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_where_in() -> Result<()> {
+    let ctx = create_join_context()?;
+    let sql_results = ctx
+        .sql("SELECT a,b FROM t1 WHERE a in (SELECT count(*) FROM t2)")
+        .await?
+        .explain(false, false)?
+        .collect()
+        .await?;
 
+    // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1
+    // https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
+    // To compare the SQL and DataFrame logical plans, we therefore need to create a new SessionContext here
+    let ctx = create_join_context()?;
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .filter(in_subquery(
+            col("a"),
+            Arc::new(
+                ctx.table("t2")
+                    .await?
+                    .aggregate(vec![], vec![count(Expr::Wildcard)])?
+                    .select(vec![count(Expr::Wildcard)])?
+                    .into_unoptimized_plan(),
+                // Usually, into_optimized_plan() should be used here, but due to
+                // https://github.com/apache/arrow-datafusion/issues/5771,
+                // subqueries in SQL cannot be optimized, resulting in differences in logical_plan. Therefore, into_unoptimized_plan() is temporarily used here.
+            ),
+        ))?
+        .select(vec![col("a"), col("b")])?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    // make sure the SQL plan is the same as the DataFrame plan
+    assert_eq!(
+        pretty_format_batches(&sql_results)?.to_string(),
+        pretty_format_batches(&df_results)?.to_string()
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_where_exist() -> Result<()> {
+    let ctx = create_join_context()?;
     let sql_results = ctx
-        .sql("select count(*) from alltypes_tiny_pages")
+        .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)")
+        .await?
+        .explain(false, false)?
+        .collect()
+        .await?;
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .filter(exists(Arc::new(
+            ctx.table("t2")
+                .await?
+                .aggregate(vec![], vec![count(Expr::Wildcard)])?
+                .select(vec![count(Expr::Wildcard)])?
+                .into_unoptimized_plan(),
+            // Usually, into_optimized_plan() should be used here, but due to
+            // https://github.com/apache/arrow-datafusion/issues/5771,
+            // subqueries in SQL cannot be optimized, resulting in differences in logical_plan. Therefore, into_unoptimized_plan() is temporarily used here.
+        )))?
+        .select(vec![col("a"), col("b")])?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    // make sure the SQL plan is the same as the DataFrame plan
+    assert_eq!(
+        pretty_format_batches(&sql_results)?.to_string(),
+        pretty_format_batches(&df_results)?.to_string()
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_window() -> Result<()> {
+    let ctx = create_join_context()?;
+
+    let sql_results = ctx
+        .sql("select COUNT(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)  from t1")
+        .await?
+        .explain(false, false)?
+        .collect()
+        .await?;
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .select(vec![Expr::WindowFunction(expr::WindowFunction::new(
+            WindowFunction::AggregateFunction(AggregateFunction::Count),
+            vec![Expr::Wildcard],
+            vec![],
+            vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
+            WindowFrame {
+                units: WindowFrameUnits::Range,
+                start_bound: WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
+                end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
+            },
+        ))])?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    // make sure the SQL plan is the same as the DataFrame plan
+    assert_eq!(
+        pretty_format_batches(&df_results)?.to_string(),
+        pretty_format_batches(&sql_results)?.to_string()
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_aggregate() -> Result<()> {
+    let ctx = create_join_context()?;
+    register_alltypes_tiny_pages_parquet(&ctx).await?;
+
+    let sql_results = ctx
+        .sql("select count(*) from t1")
         .await?
         .select(vec![count(Expr::Wildcard)])?
         .explain(false, false)?
@@ -58,7 +203,7 @@ async fn count_wildcard() -> Result<()> {
 
     // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node.
     let df_results = ctx
-        .table("alltypes_tiny_pages")
+        .table("t1")
         .await?
         .aggregate(vec![], vec![count(Expr::Wildcard)])?
         .select(vec![count(Expr::Wildcard)])?
@@ -72,24 +217,51 @@ async fn count_wildcard() -> Result<()> {
         pretty_format_batches(&df_results)?.to_string()
     );
 
-    let results = ctx
-        .table("alltypes_tiny_pages")
+    Ok(())
+}
+#[tokio::test]
+async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
+    let ctx = create_join_context()?;
+
+    let sql_results = ctx
+        .sql("select a,b from t1 where (select count(*) from t2 where t1.a = t2.a)>0;")
         .await?
-        .aggregate(vec![], vec![count(Expr::Wildcard)])?
+        .explain(false, false)?
         .collect()
         .await?;
 
-    let expected = vec![
-        "+-----------------+",
-        "| COUNT(UInt8(1)) |",
-        "+-----------------+",
-        "| 7300            |",
-        "+-----------------+",
-    ];
-    assert_batches_sorted_eq!(expected, &results);
+    // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1
+    // https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
+    // To compare the SQL and DataFrame logical plans, we therefore need to create a new SessionContext here
+    let ctx = create_join_context()?;
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .filter(
+            scalar_subquery(Arc::new(
+                ctx.table("t2")
+                    .await?
+                    .filter(col("t1.a").eq(col("t2.a")))?
+                    .aggregate(vec![], vec![count(lit(COUNT_STAR_EXPANSION))])?
+                    .select(vec![count(lit(COUNT_STAR_EXPANSION))])?
+                    .into_unoptimized_plan(),
+            ))
+            .gt(lit(ScalarValue::UInt8(Some(0)))),
+        )?
+        .select(vec![col("t1.a"), col("t1.b")])?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    // make sure the SQL plan is the same as the DataFrame plan
+    assert_eq!(
+        pretty_format_batches(&sql_results)?.to_string(),
+        pretty_format_batches(&df_results)?.to_string()
+    );
 
     Ok(())
 }
+
 #[tokio::test]
 async fn describe() -> Result<()> {
     let ctx = SessionContext::new();
@@ -229,7 +401,7 @@ async fn sort_on_unprojected_columns() -> Result<()> {
     let results = df.collect().await.unwrap();
 
     #[rustfmt::skip]
-    let expected = vec![
+        let expected = vec![
         "+-----+",
         "| a   |",
         "+-----+",
@@ -275,7 +447,7 @@ async fn sort_on_distinct_columns() -> Result<()> {
     let results = df.collect().await.unwrap();
 
     #[rustfmt::skip]
-    let expected = vec![
+        let expected = vec![
         "+-----+",
         "| a   |",
         "+-----+",
@@ -417,7 +589,7 @@ async fn filter_with_alias_overwrite() -> Result<()> {
     let results = df.collect().await.unwrap();
 
     #[rustfmt::skip]
-    let expected = vec![
+        let expected = vec![
         "+------+",
         "| a    |",
         "+------+",
@@ -1047,3 +1219,14 @@ async fn table_with_nested_types(n: usize) -> Result<DataFrame> {
     ctx.register_batch("shapes", batch)?;
     ctx.table("shapes").await
 }
+
+pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Result<()> {
+    let testdata = parquet_test_data();
+    ctx.register_parquet(
+        "alltypes_tiny_pages",
+        &format!("{testdata}/alltypes_tiny_pages.parquet"),
+        ParquetReadOptions::default(),
+    )
+    .await?;
+    Ok(())
+}
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index ecd00d7ac1..ed48da7fde 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -16,14 +16,22 @@
 // under the License.
 
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::{Transformed, TreeNode};
-use datafusion_common::Result;
+use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
+use datafusion_common::{Column, DFField, DFSchema, DFSchemaRef, Result};
 use datafusion_expr::expr::AggregateFunction;
 use datafusion_expr::utils::COUNT_STAR_EXPANSION;
-use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window};
+use datafusion_expr::Expr::{Exists, InSubquery, ScalarSubquery};
+use datafusion_expr::{
+    aggregate_function, count, expr, lit, window_function, Aggregate, Expr, Filter,
+    LogicalPlan, Projection, Sort, Subquery, Window,
+};
+use std::string::ToString;
+use std::sync::Arc;
 
 use crate::analyzer::AnalyzerRule;
 
+pub const COUNT_STAR: &str = "COUNT(*)";
+
 /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
 /// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473.
 #[derive(Default)]
@@ -46,35 +54,116 @@ impl AnalyzerRule for CountWildcardRule {
 }
 
 fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+    let mut rewriter = CountWildcardRewriter {};
     match plan {
         LogicalPlan::Window(window) => {
-            let window_expr = handle_wildcard(&window.window_expr);
+            let window_expr = window
+                .window_expr
+                .iter()
+                .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+                .collect::<Vec<Expr>>();
+
             Ok(Transformed::Yes(LogicalPlan::Window(Window {
                 input: window.input.clone(),
                 window_expr,
-                schema: window.schema,
+                schema: rewrite_schema(&window.schema),
             })))
         }
         LogicalPlan::Aggregate(agg) => {
-            let aggr_expr = handle_wildcard(&agg.aggr_expr);
+            let aggr_expr = agg
+                .aggr_expr
+                .iter()
+                .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+                .collect();
+
             Ok(Transformed::Yes(LogicalPlan::Aggregate(
                 Aggregate::try_new_with_schema(
                     agg.input.clone(),
                     agg.group_expr.clone(),
                     aggr_expr,
-                    agg.schema,
+                    rewrite_schema(&agg.schema),
                 )?,
             )))
         }
+        LogicalPlan::Sort(Sort { expr, input, fetch }) => {
+            let sort_expr = expr
+                .iter()
+                .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+                .collect();
+            Ok(Transformed::Yes(LogicalPlan::Sort(Sort {
+                expr: sort_expr,
+                input,
+                fetch,
+            })))
+        }
+        LogicalPlan::Projection(projection) => {
+            let projection_expr = projection
+                .expr
+                .iter()
+                .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
+                .collect();
+            Ok(Transformed::Yes(LogicalPlan::Projection(
+                Projection::try_new_with_schema(
+                    projection_expr,
+                    projection.input,
+                    // rewrite_schema(projection.schema.clone()),
+                    rewrite_schema(&projection.schema),
+                )?,
+            )))
+        }
+        LogicalPlan::Filter(Filter {
+            predicate, input, ..
+        }) => {
+            let predicate = predicate.rewrite(&mut rewriter).unwrap();
+            Ok(Transformed::Yes(LogicalPlan::Filter(
+                Filter::try_new(predicate, input).unwrap(),
+            )))
+        }
+
         _ => Ok(Transformed::No(plan)),
     }
 }
 
-// handle Count(Expr:Wildcard) with DataFrame API
-pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
-    exprs
-        .iter()
-        .map(|expr| match expr {
+struct CountWildcardRewriter {}
+
+impl TreeNodeRewriter for CountWildcardRewriter {
+    type N = Expr;
+
+    fn mutate(&mut self, old_expr: Expr) -> Result<Expr> {
+        let new_expr = match old_expr.clone() {
+            Expr::Column(Column { name, relation }) if name.contains(COUNT_STAR) => {
+                Expr::Column(Column {
+                    name: name.replace(
+                        COUNT_STAR,
+                        count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
+                    ),
+                    relation: relation.clone(),
+                })
+            }
+            Expr::WindowFunction(expr::WindowFunction {
+                fun:
+                    window_function::WindowFunction::AggregateFunction(
+                        aggregate_function::AggregateFunction::Count,
+                    ),
+                args,
+                partition_by,
+                order_by,
+                window_frame,
+            }) if args.len() == 1 => match args[0] {
+                Expr::Wildcard => {
+                    Expr::WindowFunction(datafusion_expr::expr::WindowFunction {
+                        fun: window_function::WindowFunction::AggregateFunction(
+                            aggregate_function::AggregateFunction::Count,
+                        ),
+                        args: vec![lit(COUNT_STAR_EXPANSION)],
+                        partition_by,
+                        order_by,
+                        window_frame,
+                    })
+                }
+
+                _ => old_expr,
+            },
             Expr::AggregateFunction(AggregateFunction {
                 fun: aggregate_function::AggregateFunction::Count,
                 args,
@@ -84,12 +173,261 @@ pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
                 Expr::Wildcard => Expr::AggregateFunction(AggregateFunction {
                     fun: aggregate_function::AggregateFunction::Count,
                     args: vec![lit(COUNT_STAR_EXPANSION)],
-                    distinct: *distinct,
-                    filter: filter.clone(),
+                    distinct,
+                    filter,
                 }),
-                _ => expr.clone(),
+                _ => old_expr,
             },
-            _ => expr.clone(),
+
+            ScalarSubquery(Subquery {
+                subquery,
+                outer_ref_columns,
+            }) => {
+                let new_plan = subquery
+                    .as_ref()
+                    .clone()
+                    .transform_down(&analyze_internal)
+                    .unwrap();
+                ScalarSubquery(Subquery {
+                    subquery: Arc::new(new_plan),
+                    outer_ref_columns,
+                })
+            }
+            InSubquery {
+                expr,
+                subquery,
+                negated,
+            } => {
+                let new_plan = subquery
+                    .subquery
+                    .as_ref()
+                    .clone()
+                    .transform_down(&analyze_internal)
+                    .unwrap();
+
+                InSubquery {
+                    expr,
+                    subquery: Subquery {
+                        subquery: Arc::new(new_plan),
+                        outer_ref_columns: subquery.outer_ref_columns,
+                    },
+                    negated,
+                }
+            }
+            Exists { subquery, negated } => {
+                let new_plan = subquery
+                    .subquery
+                    .as_ref()
+                    .clone()
+                    .transform_down(&analyze_internal)
+                    .unwrap();
+
+                Exists {
+                    subquery: Subquery {
+                        subquery: Arc::new(new_plan),
+                        outer_ref_columns: subquery.outer_ref_columns,
+                    },
+                    negated,
+                }
+            }
+            _ => old_expr,
+        };
+        Ok(new_expr)
+    }
+}
+fn rewrite_schema(schema: &DFSchema) -> DFSchemaRef {
+    let new_fields = schema
+        .fields()
+        .iter()
+        .map(|field| {
+            let mut name = field.field().name().clone();
+            if name.contains(COUNT_STAR) {
+                name = name.replace(
+                    COUNT_STAR,
+                    count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
+                );
+            }
+            DFField::new(
+                field.qualifier().cloned(),
+                &name,
+                field.data_type().clone(),
+                field.is_nullable(),
+            )
         })
-        .collect()
+        .collect::<Vec<DFField>>();
+    DFSchemaRef::new(
+        DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(),
+    )
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test::*;
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::expr::Sort;
+    use datafusion_expr::{
+        col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder,
+        max, scalar_subquery, AggregateFunction, Expr, WindowFrame, WindowFrameBound,
+        WindowFrameUnits, WindowFunction,
+    };
+
+    fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+        assert_analyzed_plan_eq_display_indent(
+            Arc::new(CountWildcardRule::new()),
+            plan,
+            expected,
+        )
+    }
+
+    #[test]
+    fn test_count_wildcard_on_sort() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("b")], vec![count(Expr::Wildcard)])?
+            .project(vec![count(Expr::Wildcard)])?
+            .sort(vec![count(Expr::Wildcard).sort(true, false)])?
+            .build()?;
+        let expected = "Sort: COUNT(UInt8(1)) ASC NULLS LAST [COUNT(UInt8(1)):Int64;N]\
+          \n  Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+          \n    Aggregate: groupBy=[[test.b]], aggr=[[COUNT(UInt8(1))]] [b:UInt32, COUNT(UInt8(1)):Int64;N]\
+          \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn test_count_wildcard_on_where_in() -> Result<()> {
+        let table_scan_t1 = test_table_scan_with_name("t1")?;
+        let table_scan_t2 = test_table_scan_with_name("t2")?;
+
+        let plan = LogicalPlanBuilder::from(table_scan_t1)
+            .filter(in_subquery(
+                col("a"),
+                Arc::new(
+                    LogicalPlanBuilder::from(table_scan_t2)
+                        .aggregate(Vec::<Expr>::new(), vec![count(Expr::Wildcard)])?
+                        .project(vec![count(Expr::Wildcard)])?
+                        .build()?,
+                ),
+            ))?
+            .build()?;
+
+        let expected = "Filter: t1.a IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
+              \n  Subquery: [COUNT(UInt8(1)):Int64;N]\
+              \n    Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+              \n      Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\
+              \n        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
+              \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn test_count_wildcard_on_where_exists() -> Result<()> {
+        let table_scan_t1 = test_table_scan_with_name("t1")?;
+        let table_scan_t2 = test_table_scan_with_name("t2")?;
+
+        let plan = LogicalPlanBuilder::from(table_scan_t1)
+            .filter(exists(Arc::new(
+                LogicalPlanBuilder::from(table_scan_t2)
+                    .aggregate(Vec::<Expr>::new(), vec![count(Expr::Wildcard)])?
+                    .project(vec![count(Expr::Wildcard)])?
+                    .build()?,
+            )))?
+            .build()?;
+
+        let expected = "Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
+          \n  Subquery: [COUNT(UInt8(1)):Int64;N]\
+          \n    Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+          \n      Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\
+          \n        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
+          \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
+        let table_scan_t1 = test_table_scan_with_name("t1")?;
+        let table_scan_t2 = test_table_scan_with_name("t2")?;
+
+        let plan = LogicalPlanBuilder::from(table_scan_t1)
+            .filter(
+                scalar_subquery(Arc::new(
+                    LogicalPlanBuilder::from(table_scan_t2)
+                        .filter(col("t1.a").eq(col("t2.a")))?
+                        .aggregate(
+                            Vec::<Expr>::new(),
+                            vec![count(lit(COUNT_STAR_EXPANSION))],
+                        )?
+                        .project(vec![count(lit(COUNT_STAR_EXPANSION))])?
+                        .build()?,
+                ))
+                .gt(lit(ScalarValue::UInt8(Some(0)))),
+            )?
+            .project(vec![col("t1.a"), col("t1.b")])?
+            .build()?;
+
+        let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\
+              \n  Filter: (<subquery>) > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\
+              \n    Subquery: [COUNT(UInt8(1)):Int64;N]\
+              \n      Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+              \n        Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\
+              \n          Filter: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32]\
+              \n            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
+              \n    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
+        assert_plan_eq(&plan, expected)
+    }
+    #[test]
+    fn test_count_wildcard_on_window() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .window(vec![Expr::WindowFunction(expr::WindowFunction::new(
+                WindowFunction::AggregateFunction(AggregateFunction::Count),
+                vec![Expr::Wildcard],
+                vec![],
+                vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
+                WindowFrame {
+                    units: WindowFrameUnits::Range,
+                    start_bound: WindowFrameBound::Preceding(ScalarValue::UInt32(Some(
+                        6,
+                    ))),
+                    end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
+                },
+            ))])?
+            .project(vec![count(Expr::Wildcard)])?
+            .build()?;
+
+        let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+              \n  WindowAggr: windowExpr=[[COUNT(UInt8(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, COUNT(UInt8(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\
+              \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn test_count_wildcard_on_aggregate() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(Vec::<Expr>::new(), vec![count(Expr::Wildcard)])?
+            .project(vec![count(Expr::Wildcard)])?
+            .build()?;
+
+        let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+              \n  Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\
+              \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn test_count_wildcard_on_nesting() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(Vec::<Expr>::new(), vec![max(count(Expr::Wildcard))])?
+            .project(vec![count(Expr::Wildcard)])?
+            .build()?;
+
+        let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\
+          \n  Aggregate: groupBy=[[]], aggr=[[MAX(COUNT(UInt8(1)))]] [MAX(COUNT(UInt8(1))):Int64;N]\
+          \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        assert_plan_eq(&plan, expected)
+    }
 }
diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs
index 439f44151e..67d342b4cb 100644
--- a/datafusion/optimizer/src/test/mod.rs
+++ b/datafusion/optimizer/src/test/mod.rs
@@ -121,7 +121,19 @@ pub fn assert_analyzed_plan_eq(
 
     Ok(())
 }
+pub fn assert_analyzed_plan_eq_display_indent(
+    rule: Arc<dyn AnalyzerRule + Send + Sync>,
+    plan: &LogicalPlan,
+    expected: &str,
+) -> Result<()> {
+    let options = ConfigOptions::default();
+    let analyzed_plan =
+        Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options)?;
+    let formatted_plan = format!("{}", analyzed_plan.display_indent_schema());
+    assert_eq!(formatted_plan, expected);
 
+    Ok(())
+}
 pub fn assert_optimized_plan_eq(
     rule: Arc<dyn OptimizerRule + Send + Sync>,
     plan: &LogicalPlan,