Posted to github@arrow.apache.org by "yjshen (via GitHub)" <gi...@apache.org> on 2023/04/04 17:22:59 UTC

[GitHub] [arrow-datafusion] yjshen opened a new pull request, #5868: feat: Support SQL filter clause for aggregate expressions

yjshen opened a new pull request, #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868

   # Which issue does this PR close?
   
   
   Closes #.
   
   # Rationale for this change
   
   
   This pull request introduces support for the FILTER (WHERE) clause in aggregate expressions. This feature enables users to filter the rows that are considered for aggregation, similar to how it is done in popular SQL databases such as PostgreSQL, SQLite, Spark, and Hive. 
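As a rough illustration of the semantics described above, the sketch below models `SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) FROM test GROUP BY c1` over in-memory rows. It is a toy model using only the standard library, not DataFusion's implementation; the function name `grouped_sum_with_filter` and the use of `None` to stand in for SQL NULL are assumptions made for the example.

```rust
use std::collections::BTreeMap;

/// Toy model of `SELECT c1, SUM(c2) FILTER (WHERE <pred>) FROM t GROUP BY c1`.
/// Rows are `(c1, c2)` pairs; `None` stands in for SQL NULL.
fn grouped_sum_with_filter(
    rows: &[(i32, i64)],
    pred: impl Fn(i64) -> bool,
) -> BTreeMap<i32, Option<i64>> {
    let mut out: BTreeMap<i32, Option<i64>> = BTreeMap::new();
    for &(key, value) in rows {
        // Every group appears in the output, even if none of its rows pass the filter.
        let slot = out.entry(key).or_insert(None);
        if pred(value) {
            // Only rows that pass the filter reach the aggregate.
            *slot = Some((*slot).unwrap_or(0) + value);
        }
    }
    out
}

fn main() {
    let rows: [(i32, i64); 5] = [(1, 10), (1, 20), (2, 10), (2, 20), (3, 10)];
    for (c1, sum) in grouped_sum_with_filter(&rows, |c2| c2 >= 20) {
        println!("{}: {:?}", c1, sum);
    }
}
```

The filter removes rows only from the aggregate's input, not from the grouping, which is why a group whose rows are all rejected still produces an output row.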
   
   # What changes are included in this PR?
   
   1. Most of the work is in the `physical_plan/aggregates` module.
   2. The mechanism that routes the optional filter through the optimizer and the execution code path was also changed.
   
   
   # Are these changes tested?
   
   New tests were added in `group_by.rs` to cover the FILTER (WHERE) clause with a single filter, `AVG`, multiple filters, `COUNT(DISTINCT ...)`, and aggregation without `GROUP BY`.
   
   
   # Are there any user-facing changes?
   
   Yes, users can now use the FILTER (WHERE) clause in aggregate expressions in their SQL queries, providing more flexibility and control over the aggregation. This change is backward compatible, and existing queries without the FILTER (WHERE) clause should continue to work as expected.
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [arrow-datafusion] alamb commented on pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "alamb (via GitHub)" <gi...@apache.org>.
alamb commented on PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#issuecomment-1584857551

   https://github.com/apache/arrow-datafusion/pull/6616 -- PR to use upstreamed version of parse_sql_dialect




[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "alamb (via GitHub)" <gi...@apache.org>.
alamb commented on code in PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#discussion_r1160159239


##########
datafusion/common/src/config.rs:
##########
@@ -187,6 +187,10 @@ config_namespace! {
         /// When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted)
         pub enable_ident_normalization: bool, default = true
 
+        /// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic,

Review Comment:
   👍



##########
datafusion/core/src/execution/context.rs:
##########
@@ -1833,6 +1860,29 @@ impl From<&SessionState> for TaskContext {
     }
 }
 
+fn create_dialect_from_str(dialect_name: &str) -> Box<dyn Dialect> {

Review Comment:
   What would you think about putting this in a PR upstream in sqlparser-rs? I can do so if you agree.



##########
datafusion/core/src/execution/context.rs:
##########
@@ -1833,6 +1860,29 @@ impl From<&SessionState> for TaskContext {
     }
 }
 
+fn create_dialect_from_str(dialect_name: &str) -> Box<dyn Dialect> {
+    match dialect_name.to_lowercase().as_str() {
+        "generic" => Box::new(GenericDialect),
+        "mysql" => Box::new(MySqlDialect {}),
+        "postgresql" | "postgres" => Box::new(PostgreSqlDialect {}),
+        "hive" => Box::new(HiveDialect {}),
+        "sqlite" => Box::new(SQLiteDialect {}),
+        "snowflake" => Box::new(SnowflakeDialect),
+        "redshift" => Box::new(RedshiftSqlDialect {}),
+        "mssql" => Box::new(MsSqlDialect {}),
+        "clickhouse" => Box::new(ClickHouseDialect {}),
+        "bigquery" => Box::new(BigQueryDialect),
+        "ansi" => Box::new(AnsiDialect {}),
+        _ => {

Review Comment:
   I think it might be better to return an error here rather than calling `println`. For one thing, when running in a server or other distributed context, stdout may not be connected to anything.
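A minimal sketch of the error-returning variant suggested here, using only the standard library. `DialectKind` and `UnknownDialect` are hypothetical stand-ins for the sqlparser dialect types and for DataFusion's error type, so the snippet shows the shape of the change rather than the code that was merged.

```rust
use std::fmt;

/// Hypothetical stand-in for the sqlparser dialect structs.
#[derive(Debug, PartialEq)]
enum DialectKind {
    Generic, MySql, PostgreSql, Hive, SQLite, Snowflake,
    Redshift, MsSql, ClickHouse, BigQuery, Ansi,
}

/// Hypothetical error type carrying the name that could not be resolved.
#[derive(Debug, PartialEq)]
struct UnknownDialect(String);

impl fmt::Display for UnknownDialect {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "unsupported SQL dialect: {}", self.0)
    }
}

impl std::error::Error for UnknownDialect {}

/// Resolves a dialect name case-insensitively; unknown names become an error
/// instead of a message printed to stdout.
fn dialect_from_str(name: &str) -> Result<DialectKind, UnknownDialect> {
    match name.to_lowercase().as_str() {
        "generic" => Ok(DialectKind::Generic),
        "mysql" => Ok(DialectKind::MySql),
        "postgresql" | "postgres" => Ok(DialectKind::PostgreSql),
        "hive" => Ok(DialectKind::Hive),
        "sqlite" => Ok(DialectKind::SQLite),
        "snowflake" => Ok(DialectKind::Snowflake),
        "redshift" => Ok(DialectKind::Redshift),
        "mssql" => Ok(DialectKind::MsSql),
        "clickhouse" => Ok(DialectKind::ClickHouse),
        "bigquery" => Ok(DialectKind::BigQuery),
        "ansi" => Ok(DialectKind::Ansi),
        _ => Err(UnknownDialect(name.to_string())),
    }
}

fn main() {
    match dialect_from_str("oracle") {
        Ok(dialect) => println!("parsed {:?}", dialect),
        // The caller decides how to report the failure.
        Err(e) => eprintln!("{}", e),
    }
}
```

Returning a `Result` lets a server or embedding application surface the problem through its own error channel.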



##########
datafusion/core/src/physical_plan/aggregates/no_grouping.rs:
##########
@@ -172,26 +177,34 @@ fn aggregate_batch(
     batch: &RecordBatch,
     accumulators: &mut [AccumulatorItem],
     expressions: &[Vec<Arc<dyn PhysicalExpr>>],
+    filters: &[Option<Arc<dyn PhysicalExpr>>],
 ) -> Result<usize> {
     let mut allocated = 0usize;
 
     // 1.1 iterate accumulators and respective expressions together
-    // 1.2 evaluate expressions
-    // 1.3 update / merge accumulators with the expressions' values
+    // 1.2 filter the batch if necessary
+    // 1.3 evaluate expressions
+    // 1.4 update / merge accumulators with the expressions' values
 
     // 1.1
     accumulators
         .iter_mut()
         .zip(expressions)
-        .try_for_each(|(accum, expr)| {
+        .zip(filters)
+        .try_for_each(|((accum, expr), filter)| {
             // 1.2
+            let batch = match filter {
+                Some(filter) => batch_filter(batch, filter)?,
+                None => batch.clone(),

Review Comment:
   It would be really nice to figure out how to avoid this clone(). 
   
   Here is one way I found to do so:
   
   ```diff
   diff --git a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
   index 8b770f796..88bab512c 100644
   --- a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
   +++ b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
   @@ -29,6 +29,7 @@ use arrow::record_batch::RecordBatch;
    use datafusion_common::Result;
    use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
    use futures::stream::BoxStream;
   +use std::borrow::Cow;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    
   @@ -101,7 +102,7 @@ impl AggregateStream {
                            let timer = elapsed_compute.timer();
                            let result = aggregate_batch(
                                &this.mode,
   -                            &batch,
   +                            batch,
                                &mut this.accumulators,
                                &this.aggregate_expressions,
                                &this.filter_expressions,
   @@ -174,7 +175,7 @@ impl RecordBatchStream for AggregateStream {
    /// TODO: Make this a member function
    fn aggregate_batch(
        mode: &AggregateMode,
   -    batch: &RecordBatch,
   +    batch: RecordBatch,
        accumulators: &mut [AccumulatorItem],
        expressions: &[Vec<Arc<dyn PhysicalExpr>>],
        filters: &[Option<Arc<dyn PhysicalExpr>>],
   @@ -194,8 +195,8 @@ fn aggregate_batch(
            .try_for_each(|((accum, expr), filter)| {
                // 1.2
                let batch = match filter {
   -                Some(filter) => batch_filter(batch, filter)?,
   -                None => batch.clone(),
   +                Some(filter) => Cow::Owned(batch_filter(&batch, filter)?),
   +                None => Cow::Borrowed(&batch),
                };
                // 1.3
                let values = &expr
   ```
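The `Cow` pattern in the diff above can be tried in isolation. The sketch below is a standard-library-only model in which `Batch` (a `Vec<i64>`) stands in for `RecordBatch` and `Predicate` (a plain function pointer) stands in for the physical filter expression; both names, and `sum_per_aggregate`, are assumptions made for the example.

```rust
use std::borrow::Cow;

/// Stand-in for `RecordBatch`: a single column of values (hypothetical).
type Batch = Vec<i64>;
/// Stand-in for a physical filter expression (hypothetical).
type Predicate = fn(i64) -> bool;

/// Returns a new batch holding only the rows that satisfy `keep`.
fn batch_filter(batch: &Batch, keep: Predicate) -> Batch {
    batch.iter().copied().filter(|v| keep(*v)).collect()
}

/// Sums the batch once per aggregate. Aggregates without a filter borrow the
/// input; only filtered aggregates allocate a new batch.
fn sum_per_aggregate(batch: Batch, filters: &[Option<Predicate>]) -> Vec<i64> {
    filters
        .iter()
        .map(|filter| {
            let rows: Cow<'_, Batch> = match filter {
                Some(keep) => Cow::Owned(batch_filter(&batch, *keep)),
                None => Cow::Borrowed(&batch),
            };
            rows.iter().sum::<i64>()
        })
        .collect()
}

fn at_least_20(v: i64) -> bool {
    v >= 20
}

fn main() {
    let filters = [Some(at_least_20 as Predicate), None];
    println!("{:?}", sum_per_aggregate(vec![10, 20, 30, 40], &filters));
}
```

Both match arms produce the same `Cow` type, so the code after the match does not need to know whether the rows were filtered.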



##########
datafusion/core/tests/sql/group_by.rs:
##########
@@ -905,3 +905,220 @@ async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
     assert_batches_sorted_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn query_group_by_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, true),
+        Field::new("c2", DataType::Int32, true),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![
+                Some(1),
+                Some(1),
+                Some(2),
+                Some(2),
+                Some(3),
+            ])),
+            Arc::new(Int32Array::from(vec![
+                Some(10),
+                Some(20),
+                Some(10),
+                Some(20),
+                Some(10),
+            ])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql =
+        "SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result FROM test GROUP BY c1";
+
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | result |",
+        "+----+--------+",
+        "| 1  | 20     |",
+        "| 2  | 20     |",
+        "| 3  |        |",
+        "+----+--------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+// test avg since it has two state columns
+#[tokio::test]
+async fn query_group_by_avg_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 30, 40])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+
+    let sql =
+        "SELECT c1, AVG(c2) FILTER (WHERE c2 >= 20) AS avg_c2 FROM test GROUP BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | avg_c2 |",
+        "+----+--------+",
+        "| 1  | 20.0   |",
+        "| 2  | 35.0   |",
+        "+----+--------+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn query_group_by_with_multiple_filters() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+        Field::new("c3", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 30, 40])),
+            Arc::new(Int32Array::from(vec![50, 60, 70, 80])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+
+    let sql = "SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2, AVG(c3) FILTER (WHERE c3 <= 70) AS avg_c3 FROM test GROUP BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+----+--------+--------+",
+        "| c1 | sum_c2 | avg_c3 |",
+        "+----+--------+--------+",
+        "| 1  | 20     | 55.0   |",
+        "| 2  | 70     | 70.0   |",
+        "+----+--------+--------+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn query_group_by_distinct_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 20, 30, 40])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql = "SELECT c1, COUNT(DISTINCT c2) FILTER (WHERE c2 >= 20) AS distinct_c2_count FROM test GROUP BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+-------------------+",
+        "| c1 | distinct_c2_count |",
+        "+----+-------------------+",
+        "| 1  | 1                 |",
+        "| 2  | 3                 |",
+        "+----+-------------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn query_without_group_by_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 20, 30, 40])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql = "SELECT SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2 FROM test";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+--------+",
+        "| sum_c2 |",
+        "+--------+",
+        "| 110    |",
+        "+--------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+// count is special cased by `aggregate_statistics`

Review Comment:
   Also, it would be good to test when the filter filters out all rows (aka the input to the aggregate is empty)
   
   ```sql
   SELECT SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 FROM test
   ```
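For reference, standard SQL returns NULL for `SUM` and 0 for `COUNT` when no rows reach the aggregate. The sketch below models that expectation with the standard library only; `filtered_sum` and `filtered_count` are hypothetical helpers, and `None` stands in for NULL.

```rust
/// SUM over the rows that pass `pred`; `None` models SQL NULL for empty input.
fn filtered_sum(values: &[i64], pred: impl Fn(i64) -> bool) -> Option<i64> {
    values
        .iter()
        .copied()
        .filter(|v| pred(*v))
        .reduce(|a, b| a + b)
}

/// COUNT over the rows that pass `pred`; empty input yields 0, not NULL.
fn filtered_count(values: &[i64], pred: impl Fn(i64) -> bool) -> usize {
    values.iter().filter(|v| pred(**v)).count()
}

fn main() {
    let c2: [i64; 5] = [10, 20, 20, 30, 40];
    // The filter rejects every row, so SUM has no input at all.
    println!("{:?}", filtered_sum(&c2, |v| v >= 20_000_000));
    println!("{}", filtered_count(&c2, |v| v >= 20_000_000));
}
```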



##########
datafusion/physical-expr/src/aggregate/mod.rs:
##########
@@ -77,6 +77,9 @@ pub trait AggregateExpr: Send + Sync + Debug {
     /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many.
     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
 
+    /// FILTER (WHERE clause) expression for this aggregate

Review Comment:
   I think it makes more sense to track the filters on the `LogicalPlan::GroupBy` and the relevant `ExecutionPlan`, rather than force every `AggregateExpr` to carry the filter itself, because:
   
   1. The filtering is the same for all aggregates (it doesn't vary by aggregate type), so having the filter on the aggregate seems to be a mismatch.
   2. It forces all user-defined aggregates to do the same thing (carry along a filter), and if they make a mistake they could get wrong answers.
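One way to picture this suggestion is a plan node that stores the filters next to the aggregates and applies them itself. The sketch below uses only the standard library; the `Aggregate` trait, `AggregateNode`, and `Predicate` are hypothetical simplifications and do not mirror DataFusion's actual types.

```rust
/// Hypothetical, simplified aggregate interface with no notion of a filter.
trait Aggregate {
    fn update(&mut self, value: i64);
    fn result(&self) -> Option<i64>;
}

/// SUM: `None` until the first row arrives (models SQL NULL).
struct Sum(Option<i64>);

impl Aggregate for Sum {
    fn update(&mut self, value: i64) {
        self.0 = Some(self.0.unwrap_or(0) + value);
    }
    fn result(&self) -> Option<i64> {
        self.0
    }
}

type Predicate = fn(i64) -> bool;

/// Hypothetical plan node: filters live next to the aggregates, not inside them.
struct AggregateNode {
    aggregates: Vec<Box<dyn Aggregate>>,
    /// One entry per aggregate; `None` means the aggregate has no FILTER clause.
    filters: Vec<Option<Predicate>>,
}

impl AggregateNode {
    fn consume(&mut self, value: i64) {
        for (aggregate, filter) in self.aggregates.iter_mut().zip(&self.filters) {
            // The node applies the filter, so an aggregate never sees rejected rows.
            let passes = match filter {
                Some(pred) => pred(value),
                None => true,
            };
            if passes {
                aggregate.update(value);
            }
        }
    }

    fn results(&self) -> Vec<Option<i64>> {
        self.aggregates.iter().map(|a| a.result()).collect()
    }
}

fn at_least_20(v: i64) -> bool {
    v >= 20
}

fn main() {
    // SUM(c2) FILTER (WHERE c2 >= 20) next to a plain SUM(c2).
    let mut node = AggregateNode {
        aggregates: vec![Box::new(Sum(None)), Box::new(Sum(None))],
        filters: vec![Some(at_least_20 as Predicate), None],
    };
    for value in [10, 20] {
        node.consume(value);
    }
    println!("{:?}", node.results());
}
```

In this arrangement a user-defined aggregate only implements `update` and `result`, so it cannot apply the filter incorrectly.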



##########
datafusion/core/src/execution/context.rs:
##########
@@ -1510,6 +1515,27 @@ impl SessionState {
         Ok(statement)
     }
 
+    /// Convert a SQL string into an AST Statement
+    pub fn sql_to_statement_with_dialect(

Review Comment:
   I think with this change `sql_to_statement` will effectively ignore `ConfigOptions::sql_parser::dialect`, which seems confusing.
   
   Is there a need for a separate `sql_to_statement_with_dialect`, rather than changing `sql_to_statement` to use the configured dialect?



##########
datafusion/core/tests/sql/group_by.rs:
##########
@@ -905,3 +905,220 @@ async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
     assert_batches_sorted_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn query_group_by_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, true),
+        Field::new("c2", DataType::Int32, true),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![
+                Some(1),
+                Some(1),
+                Some(2),
+                Some(2),
+                Some(3),
+            ])),
+            Arc::new(Int32Array::from(vec![
+                Some(10),
+                Some(20),
+                Some(10),
+                Some(20),
+                Some(10),
+            ])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql =

Review Comment:
   I think another important case is a query that contains both an aggregate with a filter and an aggregate without one.
   
   For example:
   
   ```sql
   SELECT c1,
     SUM(c2) FILTER (WHERE c2 >= 20) as result,
     SUM(c2) as result_no_filter
   FROM test GROUP BY c1
   ```



##########
datafusion/core/tests/sql/group_by.rs:
##########
@@ -905,3 +905,220 @@ async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
     assert_batches_sorted_eq!(expected, &actual);
     Ok(())
 }
+

Review Comment:
   What do you think about writing the tests in `aggregate.slt` instead of a `.rs` file?
   
   



##########
datafusion/core/tests/sql/group_by.rs:
##########
@@ -905,3 +905,220 @@ async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
     assert_batches_sorted_eq!(expected, &actual);
     Ok(())
 }
+// count is special cased by `aggregate_statistics`

Review Comment:
   Another test case that I think is important is when the filter is on a different column than the aggregate. 
   
   ```
   postgres=# create table test as values (1, 10), (2, 20), (3, 30);
   SELECT 3
   postgres=# select * from test;
    column1 | column2
   ---------+---------
          1 |      10
          2 |      20
          3 |      30
   (3 rows)
   
   postgres=# select sum(column1) FILTER (WHERE column2 < 30) from test;
    sum
   -----
      3
   (1 row)
   ```
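The PostgreSQL session above can be mirrored with a small column-oriented sketch: the predicate is evaluated on one column to produce a boolean mask, and the mask is then applied to a different column before summing. `less_than` and `masked_sum` are hypothetical helpers built on the standard library only.

```rust
/// Evaluates `column < limit` into a boolean mask, one entry per row.
fn less_than(column: &[i64], limit: i64) -> Vec<bool> {
    column.iter().map(|v| *v < limit).collect()
}

/// Sums `column` over the rows where `mask` is true.
/// `None` models SQL NULL when no row is selected.
fn masked_sum(column: &[i64], mask: &[bool]) -> Option<i64> {
    assert_eq!(column.len(), mask.len(), "mask must cover every row");
    column
        .iter()
        .zip(mask)
        .filter(|(_, keep)| **keep)
        .map(|(v, _)| *v)
        .reduce(|a, b| a + b)
}

fn main() {
    let column1: [i64; 3] = [1, 2, 3];
    let column2: [i64; 3] = [10, 20, 30];
    // sum(column1) FILTER (WHERE column2 < 30)
    let mask = less_than(&column2, 30);
    println!("{:?}", masked_sum(&column1, &mask));
}
```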
   
   
   





[GitHub] [arrow-datafusion] alamb commented on pull request #5868: feat: Support SQL filter clause for aggregate expressions

Posted by "alamb (via GitHub)" <gi...@apache.org>.
alamb commented on PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#issuecomment-1498111630

   I plan to review this PR tomorrow. Thank you @yjshen 




[GitHub] [arrow-datafusion] alamb commented on pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "alamb (via GitHub)" <gi...@apache.org>.
alamb commented on PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#issuecomment-1503081536

   I ran this branch against main using https://github.com/alamb/datafusion-benchmarking and I see no performance difference 👍
   
   ```
   ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
   ┃ Query        ┃ /home/alamb… ┃ /home/alamb… ┃    Change ┃
   ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
   │ QQuery 1     │    1700.14ms │    1674.54ms │ no change │
   │ QQuery 2     │     495.96ms │     476.32ms │ no change │
   │ QQuery 3     │     566.94ms │     574.48ms │ no change │
   │ QQuery 4     │     238.05ms │     233.97ms │ no change │
   │ QQuery 5     │     776.22ms │     747.68ms │ no change │
   │ QQuery 6     │     429.90ms │     434.80ms │ no change │
   │ QQuery 7     │    1366.64ms │    1346.81ms │ no change │
   │ QQuery 8     │     756.20ms │     758.10ms │ no change │
   │ QQuery 9     │    1244.81ms │    1243.00ms │ no change │
   │ QQuery 10    │     879.18ms │     851.17ms │ no change │
   │ QQuery 11    │     412.81ms │     421.18ms │ no change │
   │ QQuery 12    │     346.56ms │     340.06ms │ no change │
   │ QQuery 13    │    1390.35ms │    1373.72ms │ no change │
   │ QQuery 14    │     455.99ms │     447.49ms │ no change │
   │ QQuery 15    │     462.05ms │     446.96ms │ no change │
   │ QQuery 16    │     356.55ms │     350.45ms │ no change │
   │ QQuery 17    │    6630.39ms │    6682.32ms │ no change │
   │ QQuery 18    │    4040.30ms │    4029.81ms │ no change │
   │ QQuery 19    │     753.73ms │     780.20ms │ no change │
   │ QQuery 20    │    1948.00ms │    1903.14ms │ no change │
   │ QQuery 21    │    1865.89ms │    1935.56ms │ no change │
   │ QQuery 22    │     211.15ms │     218.02ms │ no change │
   └──────────────┴──────────────┴──────────────┴───────────┘
   ```




[GitHub] [arrow-datafusion] yjshen commented on a diff in pull request #5868: feat: Support SQL filter clause for aggregate expressions

Posted by "yjshen (via GitHub)" <gi...@apache.org>.
yjshen commented on code in PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#discussion_r1157564945


##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -463,27 +482,20 @@ impl GroupedHashAggregateStream {
                                     accumulator.merge_batch(&values, &mut state_accessor)
                                 }
                             }
-                        })
-                        // 2.5
-                        .and(Ok(()))?;
+                        })?;
                     // normal accumulators
                     group_state
                         .accumulator_set
                         .iter_mut()
                         .zip(normal_values.iter())
-                        .map(|(accumulator, aggr_array)| {
-                            (
-                                accumulator,
-                                aggr_array
-                                    .iter()
-                                    .map(|array| {
-                                        // 2.3
-                                        array.slice(offsets[0], offsets[1] - offsets[0])
-                                    })
-                                    .collect::<Vec<ArrayRef>>(),
-                            )
-                        })
-                        .try_for_each(|(accumulator, values)| {
+                        .zip(normal_filter_values.iter())

Review Comment:
   Filter input batch before merging into normal accumulators.



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -430,22 +453,18 @@ impl GroupedHashAggregateStream {
                 .try_for_each(|(group_idx, offsets)| {
                     let group_state = &mut row_group_states[*group_idx];
                     // 2.2
+                    // Process row accumulators
                     self.row_accumulators
                         .iter_mut()
                         .zip(row_values.iter())
-                        .map(|(accumulator, aggr_array)| {
-                            (
-                                accumulator,
-                                aggr_array
-                                    .iter()
-                                    .map(|array| {
-                                        // 2.3
-                                        array.slice(offsets[0], offsets[1] - offsets[0])
-                                    })
-                                    .collect::<Vec<ArrayRef>>(),
-                            )
-                        })
-                        .try_for_each(|(accumulator, values)| {
+                        .zip(row_filter_values.iter())

Review Comment:
   Filter input batch before merging into row accumulators.
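The idea in the two comments above (slice each group's rows, then drop the rows the filter rejects before they reach the accumulator) can be sketched with plain slices. This is a simplified model using only the standard library, not the `row_hash.rs` code: `sum_groups` is a hypothetical helper, `offsets` holds one `(start, end)` range per group, and `mask` is the already-evaluated filter column.

```rust
/// Sums `values[start..end]` for each group, skipping rows whose mask entry is false.
/// `mask = None` means the aggregate has no FILTER clause.
/// Panics if an offset range lies outside `values` or `mask`.
fn sum_groups(values: &[i64], mask: Option<&[bool]>, offsets: &[(usize, usize)]) -> Vec<i64> {
    offsets
        .iter()
        .map(|&(start, end)| {
            // Slice the rows that belong to this group.
            let group_values = &values[start..end];
            match mask {
                // Apply the same slice to the mask, then keep only selected rows.
                Some(mask) => group_values
                    .iter()
                    .zip(&mask[start..end])
                    .filter(|(_, keep)| **keep)
                    .map(|(v, _)| *v)
                    .sum::<i64>(),
                None => group_values.iter().sum::<i64>(),
            }
        })
        .collect()
}

fn main() {
    let values: [i64; 4] = [10, 20, 30, 40];
    let mask = [false, true, true, true]; // c2 >= 20, evaluated once per batch
    let offsets: [(usize, usize); 2] = [(0, 2), (2, 4)];
    println!("{:?}", sum_groups(&values, Some(&mask[..]), &offsets));
    println!("{:?}", sum_groups(&values, None, &offsets));
}
```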





[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "alamb (via GitHub)" <gi...@apache.org>.
alamb commented on code in PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#discussion_r1161957354


##########
datafusion/core/src/physical_plan/aggregates/no_grouping.rs:
##########
@@ -169,29 +182,37 @@ impl RecordBatchStream for AggregateStream {
 /// TODO: Make this a member function
 fn aggregate_batch(
     mode: &AggregateMode,
-    batch: &RecordBatch,
+    batch: RecordBatch,
     accumulators: &mut [AccumulatorItem],
     expressions: &[Vec<Arc<dyn PhysicalExpr>>],
+    filters: &[Option<Arc<dyn PhysicalExpr>>],
 ) -> Result<usize> {
     let mut allocated = 0usize;
 
     // 1.1 iterate accumulators and respective expressions together
-    // 1.2 evaluate expressions
-    // 1.3 update / merge accumulators with the expressions' values
+    // 1.2 filter the batch if necessary
+    // 1.3 evaluate expressions
+    // 1.4 update / merge accumulators with the expressions' values
 
     // 1.1
     accumulators
         .iter_mut()
         .zip(expressions)
-        .try_for_each(|(accum, expr)| {
+        .zip(filters)
+        .try_for_each(|((accum, expr), filter)| {
             // 1.2
+            let batch = match filter {

Review Comment:
   👍





[GitHub] [arrow-datafusion] jdye64 commented on pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "jdye64 (via GitHub)" <gi...@apache.org>.
jdye64 commented on PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#issuecomment-1499699035

   @alamb thank you for the heads up. I created an issue to make sure we accommodate these changes: [316](https://github.com/apache/arrow-datafusion-python/issues/316)




[GitHub] [arrow-datafusion] yjshen commented on pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "yjshen (via GitHub)" <gi...@apache.org>.
yjshen commented on PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#issuecomment-1502634699

   I benched `aggregate_query_sql.rs` and confirmed that this has no noticeable effect on the current benchmarks.




[GitHub] [arrow-datafusion] yjshen commented on a diff in pull request #5868: feat: Support SQL filter clause for aggregate expressions

Posted by "yjshen (via GitHub)" <gi...@apache.org>.
yjshen commented on code in PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#discussion_r1157563571


##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -430,22 +453,18 @@ impl GroupedHashAggregateStream {
                 .try_for_each(|(group_idx, offsets)| {
                     let group_state = &mut row_group_states[*group_idx];
                     // 2.2
+                    // Process row accumulators
                     self.row_accumulators
                         .iter_mut()
                         .zip(row_values.iter())
-                        .map(|(accumulator, aggr_array)| {
-                            (
-                                accumulator,
-                                aggr_array
-                                    .iter()
-                                    .map(|array| {
-                                        // 2.3
-                                        array.slice(offsets[0], offsets[1] - offsets[0])
-                                    })
-                                    .collect::<Vec<ArrayRef>>(),
-                            )
-                        })
-                        .try_for_each(|(accumulator, values)| {
+                        .zip(row_filter_values.iter())

Review Comment:
   Filter input batch before merging into row accumulators.
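
A minimal sketch of the idea in the comment above, not the code from this PR: each group's rows are sliced out by their offsets, and rows whose filter value is false are dropped before the accumulator sees them. It uses plain slices instead of Arrow arrays, and `slice_and_filter` and its parameters are hypothetical names chosen for illustration.

```rust
/// Slice one group's rows out of a column and keep only the rows whose
/// filter mask is true. `offsets` is the half-open row range `[start, end)`.
/// A `None` mask means the aggregate has no FILTER clause, so all rows pass.
fn slice_and_filter(values: &[i64], mask: Option<&[bool]>, offsets: [usize; 2]) -> Vec<i64> {
    let [start, end] = offsets;
    let group_values = &values[start..end];
    match mask {
        Some(mask) => group_values
            .iter()
            // Pair every value with the filter result for the same row.
            .zip(&mask[start..end])
            .filter(|(_, keep)| **keep)
            .map(|(v, _)| *v)
            .collect(),
        None => group_values.to_vec(),
    }
}
```

With values `[10, 20, 30, 40, 50]`, a mask of `[true, false, true, true, false]`, and offsets `[1, 4]`, only `30` and `40` would reach the accumulator.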





[GitHub] [arrow-datafusion] alamb commented on pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "alamb (via GitHub)" <gi...@apache.org>.
alamb commented on PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#issuecomment-1502144356

   cc @tustvold @mustafasrepo @crepererum and @Dandandan, who I think are all interested in grouping performance




[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "alamb (via GitHub)" <gi...@apache.org>.
alamb commented on code in PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#discussion_r1162016185


##########
datafusion/core/src/execution/context.rs:
##########
@@ -1833,6 +1841,28 @@ impl From<&SessionState> for TaskContext {
     }
 }
 
+fn create_dialect_from_str(dialect_name: &str) -> Result<Box<dyn Dialect>> {

Review Comment:
   I have ported this code upstream here: https://github.com/sqlparser-rs/sqlparser-rs/pull/848
   
   ```suggestion
   // TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/848 is released
   fn create_dialect_from_str(dialect_name: &str) -> Result<Box<dyn Dialect>> {
   ```





[GitHub] [arrow-datafusion] yjshen commented on a diff in pull request #5868: feat: Support SQL filter clause for aggregate expressions

Posted by "yjshen (via GitHub)" <gi...@apache.org>.
yjshen commented on code in PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#discussion_r1157566345


##########
datafusion/core/src/physical_plan/aggregates/no_grouping.rs:
##########
@@ -172,26 +177,34 @@ fn aggregate_batch(
     batch: &RecordBatch,
     accumulators: &mut [AccumulatorItem],
     expressions: &[Vec<Arc<dyn PhysicalExpr>>],
+    filters: &[Option<Arc<dyn PhysicalExpr>>],
 ) -> Result<usize> {
     let mut allocated = 0usize;
 
     // 1.1 iterate accumulators and respective expressions together
-    // 1.2 evaluate expressions
-    // 1.3 update / merge accumulators with the expressions' values
+    // 1.2 filter the batch if necessary

Review Comment:
   Filter input batch before merging into the single state.
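
A rough, dependency-free sketch of the no-grouping path described above, under simplifying assumptions: the batch is a slice of `i64`, a filter is a plain function pointer rather than a `PhysicalExpr`, and `SumAccumulator`, `Filter`, and `apply_filter` are hypothetical stand-ins rather than DataFusion types.

```rust
/// Hypothetical stand-in for an accumulator: sums the values it is given.
#[derive(Debug, Default)]
struct SumAccumulator {
    sum: i64,
}

impl SumAccumulator {
    fn update_batch(&mut self, values: &[i64]) {
        self.sum += values.iter().sum::<i64>();
    }
}

/// Hypothetical stand-in for a filter expression: a per-row predicate.
type Filter = fn(i64) -> bool;

/// Keep only the rows for which the optional filter is true.
/// With no filter, every row is kept.
fn apply_filter(batch: &[i64], filter: Option<Filter>) -> Vec<i64> {
    match filter {
        Some(pred) => batch.iter().copied().filter(|v| pred(*v)).collect(),
        None => batch.to_vec(),
    }
}

/// Pair each accumulator with its own optional filter, so every aggregate
/// sees only the rows that pass its FILTER clause.
fn aggregate_batch(
    batch: &[i64],
    accumulators: &mut [SumAccumulator],
    filters: &[Option<Filter>],
) {
    accumulators
        .iter_mut()
        .zip(filters)
        .for_each(|(accum, filter)| {
            let filtered = apply_filter(batch, *filter);
            accum.update_batch(&filtered);
        });
}
```

Because the filter is applied per accumulator, two aggregates over the same input can see different subsets of rows, which is the behavior the FILTER (WHERE) clause calls for.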





[GitHub] [arrow-datafusion] yjshen merged pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "yjshen (via GitHub)" <gi...@apache.org>.
yjshen merged PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868




[GitHub] [arrow-datafusion] yjshen commented on a diff in pull request #5868: feat: Support SQL filter clause for aggregate expressions

Posted by "yjshen (via GitHub)" <gi...@apache.org>.
yjshen commented on code in PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#discussion_r1157564945


##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -463,27 +482,20 @@ impl GroupedHashAggregateStream {
                                     accumulator.merge_batch(&values, &mut state_accessor)
                                 }
                             }
-                        })
-                        // 2.5
-                        .and(Ok(()))?;
+                        })?;
                     // normal accumulators
                     group_state
                         .accumulator_set
                         .iter_mut()
                         .zip(normal_values.iter())
-                        .map(|(accumulator, aggr_array)| {
-                            (
-                                accumulator,
-                                aggr_array
-                                    .iter()
-                                    .map(|array| {
-                                        // 2.3
-                                        array.slice(offsets[0], offsets[1] - offsets[0])
-                                    })
-                                    .collect::<Vec<ArrayRef>>(),
-                            )
-                        })
-                        .try_for_each(|(accumulator, values)| {
+                        .zip(normal_filter_values.iter())

Review Comment:
   ditto



##########
datafusion/core/src/physical_plan/aggregates/no_grouping.rs:
##########
@@ -172,26 +177,34 @@ fn aggregate_batch(
     batch: &RecordBatch,
     accumulators: &mut [AccumulatorItem],
     expressions: &[Vec<Arc<dyn PhysicalExpr>>],
+    filters: &[Option<Arc<dyn PhysicalExpr>>],
 ) -> Result<usize> {
     let mut allocated = 0usize;
 
     // 1.1 iterate accumulators and respective expressions together
-    // 1.2 evaluate expressions
-    // 1.3 update / merge accumulators with the expressions' values
+    // 1.2 filter the batch if necessary

Review Comment:
   Here is another main change.



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -532,7 +544,7 @@ pub struct RowGroupState {
 }
 
 /// The state of all the groups
-pub struct RowAggregationState {
+pub struct AggregationState {

Review Comment:
   ditto



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -430,22 +453,18 @@ impl GroupedHashAggregateStream {
                 .try_for_each(|(group_idx, offsets)| {
                     let group_state = &mut row_group_states[*group_idx];
                     // 2.2
+                    // Process row accumulators
                     self.row_accumulators
                         .iter_mut()
                         .zip(row_values.iter())
-                        .map(|(accumulator, aggr_array)| {
-                            (
-                                accumulator,
-                                aggr_array
-                                    .iter()
-                                    .map(|array| {
-                                        // 2.3
-                                        array.slice(offsets[0], offsets[1] - offsets[0])
-                                    })
-                                    .collect::<Vec<ArrayRef>>(),
-                            )
-                        })
-                        .try_for_each(|(accumulator, values)| {
+                        .zip(row_filter_values.iter())

Review Comment:
   This is one of the main changes.



##########
datafusion/sql/src/parser.rs:
##########
@@ -134,7 +134,7 @@ impl<'a> DFParser<'a> {
     /// Parse a sql string into one or [`Statement`]s using the
     /// [`GenericDialect`].
     pub fn parse_sql(sql: &str) -> Result<VecDeque<Statement>, ParserError> {
-        let dialect = &GenericDialect {};
+        let dialect = &PostgreSqlDialect {};

Review Comment:
   I am unsure about the best approach to tackle this. One potential solution is to provide a configuration option that users can set at the beginning.
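
One way such a configuration option could be resolved, sketched without any parser dependency. `DialectKind` and `dialect_from_str` are hypothetical names for illustration only; they are not the sqlparser `Dialect` trait or the `create_dialect_from_str` function discussed in this PR, and the accepted spellings are assumptions.

```rust
/// Hypothetical stand-in for a SQL dialect; not the sqlparser type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DialectKind {
    Generic,
    PostgreSql,
    MySql,
}

/// Resolve a user-supplied configuration string to a dialect,
/// ignoring ASCII case and rejecting unknown names.
fn dialect_from_str(name: &str) -> Result<DialectKind, String> {
    match name.to_ascii_lowercase().as_str() {
        "generic" => Ok(DialectKind::Generic),
        "postgresql" | "postgres" => Ok(DialectKind::PostgreSql),
        "mysql" => Ok(DialectKind::MySql),
        other => Err(format!("Unsupported SQL dialect: {}", other)),
    }
}
```

Returning an error for unknown names, rather than silently falling back to a default dialect, would surface configuration typos at session setup instead of at parse time.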



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -516,7 +528,7 @@ impl GroupedHashAggregateStream {
 
 /// The state that is built for each output group.
 #[derive(Debug)]
-pub struct RowGroupState {
+pub struct GroupState {

Review Comment:
   The State actually contains both row-wise state and non-row state.



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -200,19 +214,21 @@ impl GroupedHashAggregateStream {
 
         Ok(GroupedHashAggregateStream {
             schema: Arc::clone(&schema),
-            mode,
-            exec_state,
             input,
-            group_by,
+            mode,
             normal_aggr_expr,
+            normal_aggregate_expressions,
+            normal_filter_expressions,
+            row_aggregate_expressions,
+            row_filter_expressions,
             row_accumulators,
             row_converter,
             row_aggr_schema,
             row_aggr_layout,
+            group_by,
+            aggr_state,
+            exec_state,

Review Comment:
   Group the members based on their usage and arrange them accordingly.





[GitHub] [arrow-datafusion] yjshen commented on pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "yjshen (via GitHub)" <gi...@apache.org>.
yjshen commented on PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#issuecomment-1500537165

   Thank you @alamb for the detailed review! I have made the following updates to the PR based on your feedback: 
   
   1. The optional filter has been moved from aggregate expressions into `AggregateExec`.
   2. All related tests have been moved into `.slt` files. Additional tests were added per your review comments and checked against PostgreSQL results.
   3. Miscellaneous dialect configuration changes were made as per your comments. By the way, it would be great if `create_dialect_from_str` could be moved upstream.
   
   I believe that this PR is now ready for further review. Thank you again, @alamb!




[GitHub] [arrow-datafusion] alamb commented on pull request #5868: feat: Support SQL filter clause for aggregate expressions, add SQL dialect support

Posted by "alamb (via GitHub)" <gi...@apache.org>.
alamb commented on PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#issuecomment-1499547549

   cc @andygrove @Dandandan and @jdye64 

