Posted to issues@spark.apache.org by "Kevin Liu (Jira)" <ji...@apache.org> on 2021/12/18 12:37:00 UTC

[jira] [Created] (SPARK-37682) Reduce memory pressure of RewriteDistinctAggregates

Kevin Liu created SPARK-37682:
---------------------------------

             Summary: Reduce memory pressure of RewriteDistinctAggregates
                 Key: SPARK-37682
                 URL: https://issues.apache.org/jira/browse/SPARK-37682
             Project: Spark
          Issue Type: Improvement
          Components: SQL
    Affects Versions: 3.2.0
            Reporter: Kevin Liu


In some cases, the current RewriteDistinctAggregates rule unnecessarily duplicates input data across distinct groups.
This wastes a lot of memory and hurts performance.
We could apply the 'merged column' and 'bit vector' tricks to alleviate the problem. For example:
{code:sql}
SELECT
  COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_filter_cnt_dist,
  COUNT(DISTINCT cat2) FILTER (WHERE id > 2) as cat2_filter_cnt_dist,
  COUNT(DISTINCT IF(value > 5, cat1, null)) as cat1_if_cnt_dist,
  COUNT(DISTINCT id) as id_cnt_dist,
  SUM(DISTINCT value) as id_sum_dist
FROM data
GROUP BY key
{code}
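To pin down the expected results, the query's semantics can be evaluated directly in Python over a tiny table. This is a reference sketch, not Spark code. The rows are hypothetical, made up for illustration, and id and value are assumed never to be null.

{code:python}
# Hypothetical toy rows for the `data` table: (key, id, cat1, cat2, value).
DATA = [
    ("a", 1, "x", "p", 3),
    ("a", 2, "x", "q", 7),
    ("a", 3, "y", "q", 7),
    ("a", 3, "y", None, 9),
    ("b", 1, "z", "p", 6),
    ("b", 4, "z", "r", 2),
]


def reference_query(rows):
    """Evaluate the example query directly: one set of distinct values per aggregate and key."""
    groups = {}
    for key, id_, cat1, cat2, value in rows:
        cat1_f, cat2_f, cat1_if, ids, values = groups.setdefault(
            key, (set(), set(), set(), set(), set())
        )
        # COUNT ignores nulls, so a None value never enters a distinct set.
        if id_ > 1 and cat1 is not None:      # COUNT(DISTINCT cat1) FILTER (WHERE id > 1)
            cat1_f.add(cat1)
        if id_ > 2 and cat2 is not None:      # COUNT(DISTINCT cat2) FILTER (WHERE id > 2)
            cat2_f.add(cat2)
        if value > 5 and cat1 is not None:    # COUNT(DISTINCT IF(value > 5, cat1, null))
            cat1_if.add(cat1)
        ids.add(id_)                          # COUNT(DISTINCT id)
        values.add(value)                     # SUM(DISTINCT value)
    return {
        key: (len(c1), len(c2), len(ci), len(i), sum(v))
        for key, (c1, c2, ci, i, v) in groups.items()
    }


print(reference_query(DATA))
{code}

Running it prints {'a': (2, 1, 2, 3, 19), 'b': (1, 1, 1, 2, 8)}, i.e. (cat1_filter_cnt_dist, cat2_filter_cnt_dist, cat1_if_cnt_dist, id_cnt_dist, id_sum_dist) per key.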

The current rule rewrites the above query into the following (pseudo) logical plan:
{noformat}
Aggregate(
   key = ['key]
   functions = [
       count('cat1) FILTER (WHERE (('gid = 1) AND 'max(id > 1))),
       count('(IF((value > 5), cat1, null))) FILTER (WHERE ('gid = 5)),
       count('cat2) FILTER (WHERE (('gid = 3) AND 'max(id > 2))),
       count('id) FILTER (WHERE ('gid = 2)),
       sum('value) FILTER (WHERE ('gid = 4))
   ]
   output = ['key, 'cat1_filter_cnt_dist, 'cat2_filter_cnt_dist, 'cat1_if_cnt_dist,
             'id_cnt_dist, 'id_sum_dist])
  Aggregate(
     key = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id, 'gid]
     functions = [max('id > 1), max('id > 2)]
     output = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id, 'gid,
               'max(id > 1), 'max(id > 2)])
    Expand(
       projections = [
         ('key, 'cat1, null, null, null, null, 1, ('id > 1), null),
         ('key, null, null, null, null, 'id, 2, null, null),
         ('key, null, null, 'cat2, null, null, 3, null, ('id > 2)),
         ('key, null, 'value, null, null, null, 4, null, null),
         ('key, null, null, null, if (('value > 5)) 'cat1 else null, null, 5, null, null)
       ]
       output = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id,
                 'gid, '(id > 1), '(id > 2)])
      LocalTableScan [...]
{noformat}
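To show where the memory goes, the following Python sketch simulates the plan above on the same hypothetical rows. Expand emits one 9-column row per distinct group, i.e. five rows per input row, before the inner Aggregate deduplicates them. The helper names (sql_max, current_rewrite) are made up for this illustration and are not Spark APIs.

{code:python}
from collections import defaultdict

# Same hypothetical toy rows: (key, id, cat1, cat2, value).
DATA = [
    ("a", 1, "x", "p", 3), ("a", 2, "x", "q", 7), ("a", 3, "y", "q", 7),
    ("a", 3, "y", None, 9), ("b", 1, "z", "p", 6), ("b", 4, "z", "r", 2),
]


def sql_max(a, b):
    """max() that skips nulls, like SQL MAX."""
    if a is None:
        return b
    if b is None:
        return a
    return max(a, b)


def current_rewrite(rows):
    # Expand: five projections (gid 1..5) per input row. Columns:
    # key, cat1, value, cat2, if_cat1, id, gid, (id > 1), (id > 2)
    expanded = []
    for key, id_, cat1, cat2, value in rows:
        if_cat1 = cat1 if value > 5 else None
        expanded += [
            (key, cat1, None, None, None, None, 1, id_ > 1, None),
            (key, None, None, None, None, id_, 2, None, None),
            (key, None, None, cat2, None, None, 3, None, id_ > 2),
            (key, None, value, None, None, None, 4, None, None),
            (key, None, None, None, if_cat1, None, 5, None, None),
        ]

    # Inner Aggregate: group by the first seven columns, max() the two conditions.
    inner = {}
    for row in expanded:
        group, conds = row[:7], row[7:]
        prev = inner.get(group, (None, None))
        inner[group] = tuple(sql_max(p, c) for p, c in zip(prev, conds))

    # Outer Aggregate: each inner row is one distinct value, counted (or summed)
    # under the FILTER on gid and, where the query has one, the max()-ed condition.
    out = defaultdict(lambda: [0, 0, 0, 0, 0])
    for (key, cat1, value, cat2, if_cat1, id_, gid), (gt1, gt2) in inner.items():
        acc = out[key]
        if gid == 1 and gt1 and cat1 is not None:
            acc[0] += 1
        elif gid == 3 and gt2 and cat2 is not None:
            acc[1] += 1
        elif gid == 5 and if_cat1 is not None:
            acc[2] += 1
        elif gid == 2 and id_ is not None:
            acc[3] += 1
        elif gid == 4 and value is not None:
            acc[4] += value
    return {k: tuple(v) for k, v in out.items()}, len(expanded)


result, expand_rows = current_rewrite(DATA)
print(expand_rows, result)
{code}

It prints 30 followed by the same per-key results as the direct evaluation: 6 input rows become 30 expanded rows.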

After applying the 'merged column' and 'bit vector' tricks, the logical plan becomes:
{noformat}
Aggregate(
   key = ['key]
   functions = [
       count(if (NOT (('bit_or(vector_1) & 1) = 0)) 'merged_string_1 else null)
         FILTER (WHERE ('gid = 1)),
       count(if (NOT (('bit_or(vector_1) & 2) = 0)) 'merged_string_1 else null)
         FILTER (WHERE ('gid = 1)),
       count(if (NOT (('bit_or(vector_1) & 1) = 0)) 'merged_string_1 else null)
         FILTER (WHERE ('gid = 2)),
       count('merged_integer_1) FILTER (WHERE ('gid = 3)),
       sum('merged_integer_1) FILTER (WHERE ('gid = 4))
   ]
   output = ['key, 'cat1_filter_cnt_dist, 'cat2_filter_cnt_dist, 'cat1_if_cnt_dist,
             'id_cnt_dist, 'id_sum_dist])
  Aggregate(
     key = ['key, 'merged_string_1, 'merged_integer_1, 'gid]
     functions = [bit_or('vector_1)]
     output = ['key, 'merged_string_1, 'merged_integer_1, 'gid, 'bit_or(vector_1)])
    Expand(
       projections = [
         ('key, 'cat1, null, 1, (if (('id > 1)) 1 else 0 | if (('value > 5)) 2 else 0)),
         ('key, 'cat2, null, 2, if (('id > 2)) 1 else 0),
         ('key, null, 'id, 3, null),
         ('key, null, 'value, 4, null)
       ]
       output = ['key, 'merged_string_1, 'merged_integer_1, 'gid, 'vector_1])
      LocalTableScan [...]
{noformat}
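The same kind of hypothetical Python simulation for the rewritten plan is below. It assumes cat1 and cat2 are strings and id and value are integers, so each pair can share merged_string_1 or merged_integer_1. Both cat1 aggregates read from the single gid 1 row and recover their own condition from the OR-ed vector_1 bits. The function name merged_rewrite is illustrative only.

{code:python}
from collections import defaultdict

# Same hypothetical toy rows: (key, id, cat1, cat2, value).
DATA = [
    ("a", 1, "x", "p", 3), ("a", 2, "x", "q", 7), ("a", 3, "y", "q", 7),
    ("a", 3, "y", None, 9), ("b", 1, "z", "p", 6), ("b", 4, "z", "r", 2),
]


def merged_rewrite(rows):
    # Expand: four projections per input row. Columns:
    # key, merged_string_1, merged_integer_1, gid, vector_1
    expanded = []
    for key, id_, cat1, cat2, value in rows:
        expanded += [
            # gid 1 serves both cat1 aggregates: bit 1 = (id > 1), bit 2 = (value > 5).
            (key, cat1, None, 1, (1 if id_ > 1 else 0) | (2 if value > 5 else 0)),
            (key, cat2, None, 2, 1 if id_ > 2 else 0),  # bit 1 = (id > 2)
            (key, None, id_, 3, None),
            (key, None, value, 4, None),
        ]

    # Inner Aggregate: group by (key, merged_string_1, merged_integer_1, gid)
    # and bit_or(vector_1), skipping nulls like SQL BIT_OR.
    inner = {}
    for key, s, i, gid, vec in expanded:
        group = (key, s, i, gid)
        prev = inner.get(group)
        inner[group] = vec if prev is None else (prev if vec is None else prev | vec)

    # Outer Aggregate: test the OR-ed bits to recover each original condition.
    out = defaultdict(lambda: [0, 0, 0, 0, 0])
    for (key, s, i, gid), bits in inner.items():
        acc = out[key]
        if gid == 1 and s is not None:
            acc[0] += 1 if bits & 1 else 0   # cat1_filter_cnt_dist
            acc[2] += 1 if bits & 2 else 0   # cat1_if_cnt_dist
        elif gid == 2 and s is not None and bits & 1:
            acc[1] += 1                      # cat2_filter_cnt_dist
        elif gid == 3 and i is not None:
            acc[3] += 1                      # id_cnt_dist
        elif gid == 4 and i is not None:
            acc[4] += i                      # id_sum_dist
    return {k: tuple(v) for k, v in out.items()}, len(expanded)


result, expand_rows = merged_rewrite(DATA)
print(expand_rows, result)
{code}

It prints 24 followed by the same per-key results as before: the Expand now emits four 5-column rows per input row instead of five 9-column rows.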

1. merged column: Children of different aggregate functions that have the same data type can share the same projected column (e.g. cat1 and cat2 share 'merged_string_1, and id and value share 'merged_integer_1).
2. bit vector: If the children of multiple aggregate functions are conditional expressions that output the same column when the condition is true and null when it is false (the detailed logic is in RewriteDistinctAggregates.groupDistinctAggExpr of the following GitHub link), these aggregate functions can share one row group and store the result of each condition as a bit in the bit vector column. This reduces the number of rows produced by the expansion (e.g. cat1_filter_cnt_dist and cat1_if_cnt_dist).
If a query has many similar distinct aggregate functions, with or without filters, these tricks can save a large amount of memory and improve performance.
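A minimal planning sketch of how the two tricks could assign row groups, bits and merged columns follows. It is not RewriteDistinctAggregates.groupDistinctAggExpr: DistinctAgg and plan_rewrite are hypothetical, conditions are treated as opaque strings, and every distinct aggregate is assumed to have a single child column (a multi-column child such as COUNT(DISTINCT a, b) would need more than one merged column per type).

{code:python}
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DistinctAgg:
    alias: str
    child: str                # column under DISTINCT, after unwrapping IF(cond, col, null)
    child_type: str           # e.g. "string", "integer"
    condition: Optional[str]  # FILTER predicate or IF condition; None if unconditional


def plan_rewrite(aggs):
    """Map each alias to (gid, merged column, bit mask or None)."""
    gids = {}  # child column -> gid, numbered in order of first appearance
    bits = {}  # (child column, condition) -> bit mask within that gid's vector
    plan = {}
    for agg in aggs:
        # Aggregates over the same child share one row group, whatever their condition.
        gid = gids.setdefault(agg.child, len(gids) + 1)
        bit = None
        if agg.condition is not None:
            key = (agg.child, agg.condition)
            if key not in bits:
                used = sum(1 for child, _ in bits if child == agg.child)
                bits[key] = 1 << used
            bit = bits[key]
        # Each expanded row populates exactly one child value, so all children of one
        # data type can share a single merged column.
        plan[agg.alias] = (gid, f"merged_{agg.child_type}_1", bit)
    return plan


query = [
    DistinctAgg("cat1_filter_cnt_dist", "cat1", "string", "id > 1"),
    DistinctAgg("cat2_filter_cnt_dist", "cat2", "string", "id > 2"),
    DistinctAgg("cat1_if_cnt_dist", "cat1", "string", "value > 5"),
    DistinctAgg("id_cnt_dist", "id", "integer", None),
    DistinctAgg("id_sum_dist", "value", "integer", None),
]
for alias, entry in plan_rewrite(query).items():
    print(alias, entry)
{code}

For the example query it reproduces the second plan's assignments: cat1_filter_cnt_dist and cat1_if_cnt_dist share gid 1 with bits 1 and 2, cat2_filter_cnt_dist gets gid 2 with bit 1, and id_cnt_dist and id_sum_dist get gids 3 and 4 on merged_integer_1.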



--
This message was sent by Atlassian Jira
(v8.20.1#820001)
