Posted to dev@spark.apache.org by 万昆 <wa...@163.com> on 2023/01/04 03:34:31 UTC

Could we reorder the second aggregate node and the expand node when rewriting multiple distinct

Hello,
  The Spark SQL rule RewriteDistinctAggregates rewrites multiple distinct expressions into two Aggregate nodes and an Expand node.
The following is the example from the class documentation. I wonder if we can reorder the second Aggregate node and the Expand node so that the Expand generates fewer records?
Thanks


Second example: aggregate function without distinct and with filter clauses (in sql):
   SELECT
     COUNT(DISTINCT cat1) as cat1_cnt,
     COUNT(DISTINCT cat2) as cat2_cnt,
     SUM(value) FILTER (WHERE id > 1) AS total
   FROM
     data
   GROUP BY
     key

This translates to the following (pseudo) logical plan:

 Aggregate(
    key = ['key]
    functions = [COUNT(DISTINCT 'cat1),
                 COUNT(DISTINCT 'cat2),
                 sum('value) with FILTER('id > 1)]
    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
   LocalTableScan [...]

This rule rewrites this logical plan to the following (pseudo) logical plan:

 Aggregate(
    key = ['key]
    functions = [count(if (('gid = 1)) 'cat1 else null),
                 count(if (('gid = 2)) 'cat2 else null),
                 first(if (('gid = 0)) 'total else null) ignore nulls]
    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
   Aggregate(
      key = ['key, 'cat1, 'cat2, 'gid]
      functions = [sum('value) with FILTER('id > 1)]
      output = ['key, 'cat1, 'cat2, 'gid, 'total])
     Expand(
        projections = [('key, null, null, 0, cast('value as bigint), 'id),
                       ('key, 'cat1, null, 1, null, null),
                       ('key, null, 'cat2, 2, null, null)]
        output = ['key, 'cat1, 'cat2, 'gid, 'value, 'id])
       LocalTableScan [...]
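
To make the three stages of the plan above concrete, here is a minimal sketch in plain Python (not Spark code) that mimics Expand, the inner Aggregate and the outer Aggregate, and compares the result with a direct evaluation of the query. The sample rows and all function names are assumptions made up for illustration.

```python
# Hypothetical sample rows: (key, cat1, cat2, value, id)
rows = [
    ("a", "x", "p", 10, 1),
    ("a", "x", "q", 20, 2),
    ("a", "y", "q", 30, 3),
    ("b", "x", "p", 5, 2),
]

def expand(rows):
    """One output row per projection: gid 0 keeps value/id, gid 1 cat1, gid 2 cat2."""
    out = []
    for key, cat1, cat2, value, id_ in rows:
        out.append((key, None, None, 0, value, id_))
        out.append((key, cat1, None, 1, None, None))
        out.append((key, None, cat2, 2, None, None))
    return out

def inner_aggregate(expanded):
    """Group by (key, cat1, cat2, gid); total = sum(value) FILTER (id > 1)."""
    groups = {}
    for key, cat1, cat2, gid, value, id_ in expanded:
        g = (key, cat1, cat2, gid)
        groups.setdefault(g, None)  # the group exists even if nothing passes the filter
        if value is not None and id_ is not None and id_ > 1:
            groups[g] = (groups[g] or 0) + value
    return [(k, c1, c2, gid, total) for (k, c1, c2, gid), total in groups.items()]

def outer_aggregate(partial):
    """Group by key; count cat1/cat2 per gid, take the first non-null total of gid 0."""
    result = {}
    for key, cat1, cat2, gid, total in partial:
        r = result.setdefault(key, {"cat1_cnt": 0, "cat2_cnt": 0, "total": None})
        if gid == 1 and cat1 is not None:
            r["cat1_cnt"] += 1
        elif gid == 2 and cat2 is not None:
            r["cat2_cnt"] += 1
        elif gid == 0 and total is not None and r["total"] is None:
            r["total"] = total
    return result

def direct(rows):
    """Reference evaluation of the original query, without any rewrite."""
    acc = {}
    for key, cat1, cat2, value, id_ in rows:
        a = acc.setdefault(key, {"c1": set(), "c2": set(), "total": None})
        if cat1 is not None:
            a["c1"].add(cat1)
        if cat2 is not None:
            a["c2"].add(cat2)
        if id_ > 1:
            a["total"] = (a["total"] or 0) + value
    return {k: {"cat1_cnt": len(a["c1"]), "cat2_cnt": len(a["c2"]), "total": a["total"]}
            for k, a in acc.items()}

rewritten = outer_aggregate(inner_aggregate(expand(rows)))
print(rewritten == direct(rows))  # True for this sample
```

Note that the inner Aggregate is what removes duplicate cat1 and cat2 values within a key, so the plain count in the outer Aggregate ends up counting distinct values.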

Could we instead have the rule produce the following plan:

 Aggregate(
    key = ['key]
    functions = [count(if (('gid = 1)) 'cat1 else null),
                 count(if (('gid = 2)) 'cat2 else null),
                 first(if (('gid = 0)) 'total else null) ignore nulls]
    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
   Expand(
     projections = [('key, 'total, null, null, 0, cast('value as bigint)),
                    ('key, 'total, 'cat1, null, 1, null),
                    ('key, 'total, null, 'cat2, 2, null)]
     output = ['key, 'total, 'cat1, 'cat2, 'gid, 'value])
     Aggregate(
        key = ['key, 'cat1, 'cat2]
        functions = [sum('value) with FILTER('id > 1)]
        output = ['key, 'cat1, 'cat2, 'total])
       LocalTableScan [...]
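
As a rough illustration of the "fewer records" argument, the sketch below counts how many rows Expand would emit in each ordering. It is plain Python rather than Spark code, the sample rows and function names are assumptions, and it counts rows only: it does not check that the two plans return the same results.

```python
# Hypothetical sample rows: (key, cat1, cat2, value, id)
rows = [
    ("a", "x", "p", 10, 1),
    ("a", "x", "p", 20, 2),
    ("a", "x", "q", 30, 3),
    ("a", "x", "q", 40, 4),
    ("b", "y", "p", 5, 2),
    ("b", "y", "p", 6, 3),
]

def expand_rows_current(rows, n_projections=3):
    """Expand sits directly above the scan: every input row is projected n times."""
    return n_projections * len(rows)

def expand_rows_proposed(rows, n_projections=3):
    """Expand sits above an Aggregate keyed by (key, cat1, cat2):
    only one row per distinct group is projected n times."""
    groups = {(key, cat1, cat2) for key, cat1, cat2, _value, _id in rows}
    return n_projections * len(groups)

current = expand_rows_current(rows)
proposed = expand_rows_proposed(rows)
print(current, proposed)  # 18 9
```

The saving depends entirely on how many input rows share the same (key, cat1, cat2) combination. If every combination is unique, both orderings emit the same number of rows.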