You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by "alamb (via GitHub)" <gi...@apache.org> on 2023/04/14 18:27:40 UTC

[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #6010: update count_wildcard_rule for more scenario

alamb commented on code in PR #6010:
URL: https://github.com/apache/arrow-datafusion/pull/6010#discussion_r1167161858


##########
datafusion/core/tests/dataframe.rs:
##########
@@ -32,24 +32,179 @@ use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::JoinType;
 use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
+use datafusion::test_util::parquet_test_data;
 use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
+use datafusion_common::ScalarValue;
 use datafusion_expr::expr::{GroupingSet, Sort};
-use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
+use datafusion_expr::utils::COUNT_STAR_EXPANSION;
+use datafusion_expr::Expr::{ScalarSubquery, Wildcard};
+use datafusion_expr::{
+    avg, col, count, expr, lit, max, sum, AggregateFunction, Expr, ExprSchemable,
+    Subquery, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction,
+};
 
 #[tokio::test]
-async fn count_wildcard() -> Result<()> {
-    let ctx = SessionContext::new();
-    let testdata = datafusion::test_util::parquet_test_data();
+async fn test_count_wildcard_on_sort() -> Result<()> {
+    let ctx = create_join_context()?;
 
-    ctx.register_parquet(
-        "alltypes_tiny_pages",
-        &format!("{testdata}/alltypes_tiny_pages.parquet"),
-        ParquetReadOptions::default(),
-    )
-    .await?;
+    let sql_results = ctx
+        .sql("select b,count(*) from t1 group by b order by count(*)")
+        .await?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .aggregate(vec![col("b")], vec![count(Wildcard)])?
+        .sort(vec![count(Wildcard).sort(true, false)])?
+        .explain(false, false)?
+        .collect()
+        .await?;
+    //make sure sql plan same with df plan
+    assert_eq!(
+        pretty_format_batches(&sql_results)?.to_string(),
+        pretty_format_batches(&df_results)?.to_string()
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_where_in() -> Result<()> {
+    let ctx = create_join_context()?;
+    let sql_results = ctx
+        .sql("SELECT a,b FROM t1 WHERE a in (SELECT count(*) FROM t2)")
+        .await?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1
+    // https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
+    // for compare difference betwwen sql and df logical plan, we need to create a new SessionContext here
+    let ctx = create_join_context()?;
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .filter(Expr::InSubquery {
+            expr: Box::new(col("a")),
+            subquery: Subquery {

Review Comment:
   I think you could use `in_subquery` to simplify this test: https://docs.rs/datafusion/latest/datafusion/prelude/fn.in_subquery.html



##########
datafusion/core/tests/dataframe.rs:
##########
@@ -32,24 +32,179 @@ use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::JoinType;
 use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
+use datafusion::test_util::parquet_test_data;
 use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
+use datafusion_common::ScalarValue;
 use datafusion_expr::expr::{GroupingSet, Sort};
-use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
+use datafusion_expr::utils::COUNT_STAR_EXPANSION;
+use datafusion_expr::Expr::{ScalarSubquery, Wildcard};
+use datafusion_expr::{
+    avg, col, count, expr, lit, max, sum, AggregateFunction, Expr, ExprSchemable,
+    Subquery, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction,
+};
 
 #[tokio::test]
-async fn count_wildcard() -> Result<()> {
-    let ctx = SessionContext::new();
-    let testdata = datafusion::test_util::parquet_test_data();
+async fn test_count_wildcard_on_sort() -> Result<()> {
+    let ctx = create_join_context()?;
 
-    ctx.register_parquet(
-        "alltypes_tiny_pages",
-        &format!("{testdata}/alltypes_tiny_pages.parquet"),
-        ParquetReadOptions::default(),
-    )
-    .await?;
+    let sql_results = ctx
+        .sql("select b,count(*) from t1 group by b order by count(*)")
+        .await?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .aggregate(vec![col("b")], vec![count(Wildcard)])?
+        .sort(vec![count(Wildcard).sort(true, false)])?
+        .explain(false, false)?
+        .collect()
+        .await?;
+    //make sure sql plan same with df plan
+    assert_eq!(
+        pretty_format_batches(&sql_results)?.to_string(),
+        pretty_format_batches(&df_results)?.to_string()
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_where_in() -> Result<()> {
+    let ctx = create_join_context()?;
+    let sql_results = ctx
+        .sql("SELECT a,b FROM t1 WHERE a in (SELECT count(*) FROM t2)")
+        .await?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1
+    // https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
+    // for compare difference betwwen sql and df logical plan, we need to create a new SessionContext here
+    let ctx = create_join_context()?;
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .filter(Expr::InSubquery {
+            expr: Box::new(col("a")),
+            subquery: Subquery {
+                subquery: Arc::new(
+                    ctx.table("t2")
+                        .await?
+                        .aggregate(vec![], vec![count(Expr::Wildcard)])?
+                        .select(vec![count(Expr::Wildcard)])?
+                        .into_unoptimized_plan(),
+                    // Usually, into_optimized_plan() should be used here, but due to
+                    // https://github.com/apache/arrow-datafusion/issues/5771,
+                    // subqueries in SQL cannot be optimized, resulting in differences in logical_plan. Therefore, into_unoptimized_plan() is temporarily used here.
+                ),
+                outer_ref_columns: vec![],
+            },
+            negated: false,
+        })?
+        .select(vec![col("a"), col("b")])?
+        .explain(false, false)?
+        .collect()
+        .await?;
+
+    // make sure sql plan same with df plan
+    assert_eq!(
+        pretty_format_batches(&sql_results)?.to_string(),
+        pretty_format_batches(&df_results)?.to_string()
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_count_wildcard_on_where_exist() -> Result<()> {
+    let ctx = create_join_context()?;
+    let sql_results = ctx
+        .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)")
+        .await?
+        .explain(false, false)?
+        .collect()
+        .await?;
+    let df_results = ctx
+        .table("t1")
+        .await?
+        .filter(Expr::Exists {

Review Comment:
   Likewise, there is an `exists` function to help: https://docs.rs/datafusion/latest/datafusion/prelude/fn.exists.html
   
   There are also several other places in this PR where this or similar functions could be used when creating subqueries.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org