Posted to github@arrow.apache.org by "yjshen (via GitHub)" <gi...@apache.org> on 2023/03/10 18:57:36 UTC

[GitHub] [arrow-datafusion] yjshen commented on issue #5547: Improve the performance of COUNT DISTINCT queries for high cardinality groups

yjshen commented on issue #5547:
URL: https://github.com/apache/arrow-datafusion/issues/5547#issuecomment-1464255278

   I was suggesting an optimizer rule that rewrites an aggregate containing distinct expressions into a double aggregation, so that no distinct `AggregateExpr`s are left at execution time.
   
   The gist of the idea is to first add the distinct columns as extra grouping columns and compute the non-distinct aggregates over those finer groups, and then run a second round of aggregation to compute the distinct expressions (their inputs have already been deduplicated, because they were grouping columns in the first aggregation).
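
A minimal sketch of that two-round idea in plain Python, for the simplest case of one distinct column plus one regular aggregate. It is only an illustration of the rewrite, not DataFusion or Spark code; the function names `direct` and `double_agg` and the `(key, cat, value)` row layout are assumptions made for this example.

```python
from collections import defaultdict

# Rows are (key, cat, value); the query being modelled is
#   SELECT key, COUNT(DISTINCT cat), SUM(value) FROM rows GROUP BY key

def direct(rows):
    """Baseline: keep a set of seen `cat` values per group."""
    seen = defaultdict(set)
    totals = defaultdict(int)
    for key, cat, value in rows:
        group = seen[key]          # make sure the group exists even if cat is NULL
        if cat is not None:        # COUNT(DISTINCT ...) ignores NULLs
            group.add(cat)
        totals[key] += value
    return {key: (len(seen[key]), totals[key]) for key in seen}

def double_agg(rows):
    """Rewrite: two plain (non-distinct) aggregations."""
    # Round 1: GROUP BY key, cat and SUM(value). Grouping by `cat`
    # deduplicates it; the sum is only a partial result per (key, cat).
    partial = defaultdict(int)
    for key, cat, value in rows:
        partial[(key, cat)] += value

    # Round 2: GROUP BY key with COUNT(cat) and SUM(partial sum).
    out = {}
    for (key, cat), subtotal in partial.items():
        count, total = out.get(key, (0, 0))
        out[key] = (count + int(cat is not None), total + subtotal)
    return out

rows = [("a", "ca1", 10), ("a", "ca1", 5), ("b", "ca1", 13)]
result = double_agg(rows)
```

Both functions return the same mapping for these rows. The second one never needs a per-group set, which is the property the rewrite is after for high-cardinality groups.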
   
   I will paste the Scaladoc for Spark's `RewriteDistinctAggregates` rule below because it contains helpful examples ([source here](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala)).
   
   ```java
   /**
    * This rule rewrites an aggregate query with distinct aggregations into an expanded double
    * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
    * in a separate group. The results are then combined in a second aggregate.
    *
    * First example: query without filter clauses (in scala):
    * {{{
    *   val data = Seq(
    *     ("a", "ca1", "cb1", 10),
    *     ("a", "ca1", "cb2", 5),
    *     ("b", "ca1", "cb1", 13))
    *     .toDF("key", "cat1", "cat2", "value")
    *   data.createOrReplaceTempView("data")
    *
    *   val agg = data.groupBy($"key")
    *     .agg(
    *       count_distinct($"cat1").as("cat1_cnt"),
    *       count_distinct($"cat2").as("cat2_cnt"),
    *       sum($"value").as("total"))
    * }}}
    *
    * This translates to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [COUNT(DISTINCT 'cat1),
    *                 COUNT(DISTINCT 'cat2),
    *                 sum('value)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   LocalTableScan [...]
    * }}}
    *
    * This rule rewrites this logical plan to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [count('cat1) FILTER (WHERE 'gid = 1),
    *                 count('cat2) FILTER (WHERE 'gid = 2),
    *                 first('total) ignore nulls FILTER (WHERE 'gid = 0)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   Aggregate(
    *      key = ['key, 'cat1, 'cat2, 'gid]
    *      functions = [sum('value)]
    *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
    *     Expand(
    *        projections = [('key, null, null, 0, cast('value as bigint)),
    *                       ('key, 'cat1, null, 1, null),
    *                       ('key, null, 'cat2, 2, null)]
    *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
    *       LocalTableScan [...]
    * }}}
    *
    * Second example: aggregate function without distinct and with filter clauses (in sql):
    * {{{
    *   SELECT
    *     COUNT(DISTINCT cat1) as cat1_cnt,
    *     COUNT(DISTINCT cat2) as cat2_cnt,
    *     SUM(value) FILTER (WHERE id > 1) AS total
    *   FROM
    *     data
    *   GROUP BY
    *     key
    * }}}
    *
    * This translates to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [COUNT(DISTINCT 'cat1),
    *                 COUNT(DISTINCT 'cat2),
    *                 sum('value) FILTER (WHERE 'id > 1)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   LocalTableScan [...]
    * }}}
    *
    * This rule rewrites this logical plan to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [count('cat1) FILTER (WHERE 'gid = 1),
    *                 count('cat2) FILTER (WHERE 'gid = 2),
    *                 first('total) ignore nulls FILTER (WHERE 'gid = 0)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   Aggregate(
    *      key = ['key, 'cat1, 'cat2, 'gid]
    *      functions = [sum('value) FILTER (WHERE 'id > 1)]
    *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
    *     Expand(
    *        projections = [('key, null, null, 0, cast('value as bigint), 'id),
    *                       ('key, 'cat1, null, 1, null, null),
    *                       ('key, null, 'cat2, 2, null, null)]
    *        output = ['key, 'cat1, 'cat2, 'gid, 'value, 'id])
    *       LocalTableScan [...]
    * }}}
    *
    * Third example: aggregate function with distinct and filter clauses (in sql):
    * {{{
    *   SELECT
    *     COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
    *     COUNT(DISTINCT cat2) FILTER (WHERE id > 2) as cat2_cnt,
    *     SUM(value) FILTER (WHERE id > 3) AS total
    *   FROM
    *     data
    *   GROUP BY
    *     key
    * }}}
    *
    * This translates to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [COUNT(DISTINCT 'cat1) FILTER (WHERE 'id > 1),
    *                 COUNT(DISTINCT 'cat2) FILTER (WHERE 'id > 2),
    *                 sum('value) FILTER (WHERE 'id > 3)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   LocalTableScan [...]
    * }}}
    *
    * This rule rewrites this logical plan to the following (pseudo) logical plan:
    * {{{
    * Aggregate(
    *    key = ['key]
    *    functions = [count('cat1) FILTER (WHERE 'gid = 1 and 'max_cond1),
    *                 count('cat2) FILTER (WHERE 'gid = 2 and 'max_cond2),
    *                 first('total) ignore nulls FILTER (WHERE 'gid = 0)]
    *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
    *   Aggregate(
    *      key = ['key, 'cat1, 'cat2, 'gid]
    *      functions = [max('cond1), max('cond2), sum('value) FILTER (WHERE 'id > 3)]
    *      output = ['key, 'cat1, 'cat2, 'gid, 'max_cond1, 'max_cond2, 'total])
    *     Expand(
    *        projections = [('key, null, null, 0, null, null, cast('value as bigint), 'id),
    *                       ('key, 'cat1, null, 1, 'id > 1, null, null, null),
    *                       ('key, null, 'cat2, 2, null, 'id > 2, null, null)]
    *        output = ['key, 'cat1, 'cat2, 'gid, 'cond1, 'cond2, 'value, 'id])
    *       LocalTableScan [...]
    * }}}
    */
   ```
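
To make the first example concrete, here is a rough plain-Python simulation of the three operators in the rewritten plan (Expand, inner Aggregate, outer Aggregate) on the same three rows. This is a sketch under assumptions, not Spark's implementation: the function names `expand`, `inner_aggregate` and `outer_aggregate` are made up for illustration, and `None` stands in for SQL NULL.

```python
from collections import defaultdict

DATA = [
    ("a", "ca1", "cb1", 10),
    ("a", "ca1", "cb2", 5),
    ("b", "ca1", "cb1", 13),
]

def expand(rows):
    """Emit one projection per aggregate group: gid 0 carries `value`,
    gid 1 carries `cat1`, gid 2 carries `cat2`; other columns are NULL."""
    for key, cat1, cat2, value in rows:
        yield (key, None, None, 0, value)
        yield (key, cat1, None, 1, None)
        yield (key, None, cat2, 2, None)

def inner_aggregate(expanded):
    """GROUP BY key, cat1, cat2, gid and compute sum(value), skipping NULLs.
    Grouping on cat1/cat2 is what deduplicates the distinct columns."""
    totals = {}
    for key, cat1, cat2, gid, value in expanded:
        group = (key, cat1, cat2, gid)
        totals.setdefault(group, None)      # the sum of no values is NULL
        if value is not None:
            totals[group] = (totals[group] or 0) + value
    return totals

def outer_aggregate(inner):
    """GROUP BY key; each aggregate only looks at rows with its own gid."""
    out = defaultdict(lambda: {"cat1_cnt": 0, "cat2_cnt": 0, "total": None})
    for (key, cat1, cat2, gid), total in inner.items():
        row = out[key]
        if gid == 1 and cat1 is not None:    # count('cat1) FILTER (WHERE gid = 1)
            row["cat1_cnt"] += 1
        elif gid == 2 and cat2 is not None:  # count('cat2) FILTER (WHERE gid = 2)
            row["cat2_cnt"] += 1
        elif gid == 0 and total is not None and row["total"] is None:
            row["total"] = total             # first('total) ignore nulls, gid = 0
    return dict(out)

result = outer_aggregate(inner_aggregate(expand(DATA)))
```

For key `a` this gives one distinct `cat1`, two distinct `cat2` values and a total of 15. The cost is that Expand multiplies the input by the number of aggregate groups (three here), so the rewrite trades extra rows for the removal of per-group distinct sets.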
   
   
   

