You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ya...@apache.org on 2019/03/21 02:22:18 UTC

[spark] branch master updated: [SPARK-26894][SQL] Handle Alias as well in AggregateEstimation to propagate child stats

This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b1857a4  [SPARK-26894][SQL] Handle Alias as well in AggregateEstimation to propagate child stats
b1857a4 is described below

commit b1857a4d7dfe17663f8adccd7825d890ae70d2a1
Author: Venkata krishnan Sowrirajan <vs...@qubole.com>
AuthorDate: Thu Mar 21 11:21:56 2019 +0900

    [SPARK-26894][SQL] Handle Alias as well in AggregateEstimation to propagate child stats
    
    ## What changes were proposed in this pull request?
    
    Currently, aliases are not handled in AggregateEstimation, so stats for aliased attributes are not propagated. This prevents CBO join reordering from producing optimal join plans. ProjectEstimation already handles aliases; AggregateEstimation needs the same logic to propagate stats properly when CBO is enabled.
    
    ## How was this patch tested?
    
    This patch was manually tested using query Q83 of the TPC-DS benchmark (scale 1000). A unit test covering aliases in Aggregate was also added to AggregateEstimationSuite.
    
    Closes #23803 from venkata91/aggstats.
    
    Authored-by: Venkata krishnan Sowrirajan <vs...@qubole.com>
    Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
---
 .../statsEstimation/AggregateEstimation.scala      |  7 +++++--
 .../logical/statsEstimation/EstimationUtils.scala  | 14 ++++++++++++-
 .../statsEstimation/ProjectEstimation.scala        | 10 +++------
 .../statsEstimation/AggregateEstimationSuite.scala | 24 ++++++++++++++++++++++
 4 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index eb56ab4..0606d0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
 
 
@@ -52,7 +52,10 @@ object AggregateEstimation {
         outputRows.min(childStats.rowCount.get)
       }
 
-      val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
+      val aliasStats = EstimationUtils.getAliasStats(agg.expressions, childStats.attributeStats)
+
+      val outputAttrStats = getOutputMap(
+        AttributeMap(childStats.attributeStats.toSeq ++ aliasStats), agg.output)
       Some(Statistics(
         sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
         rowCount = Some(outputRows),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 211a2a0..11d2f02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 import scala.collection.mutable.ArrayBuffer
 import scala.math.BigDecimal.RoundingMode
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Expression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.{DecimalType, _}
 
@@ -71,6 +71,18 @@ object EstimationUtils {
     AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
   }
 
+  /**
+   * Returns column stats for aliases that directly wrap a child attribute with known stats.
+   */
+  def getAliasStats(
+      expressions: Seq[Expression],
+      attributeStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = {
+    expressions.collect {
+      case alias @ Alias(attr: Attribute, _) if attributeStats.contains(attr) =>
+        alias.toAttribute -> attributeStats(attr)
+    }
+  }
+
   def getSizePerRow(
       attributes: Seq[Attribute],
       attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
index 489eb90..6925423 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
@@ -26,14 +26,10 @@ object ProjectEstimation {
   def estimate(project: Project): Option[Statistics] = {
     if (rowCountsExist(project.child)) {
       val childStats = project.child.stats
-      val inputAttrStats = childStats.attributeStats
-      // Match alias with its child's column stat
-      val aliasStats = project.expressions.collect {
-        case alias @ Alias(attr: Attribute, _) if inputAttrStats.contains(attr) =>
-          alias.toAttribute -> inputAttrStats(attr)
-      }
+      val aliasStats = EstimationUtils.getAliasStats(project.expressions, childStats.attributeStats)
+
       val outputAttrStats =
-        getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
+        getOutputMap(AttributeMap(childStats.attributeStats.toSeq ++ aliasStats), project.output)
       Some(childStats.copy(
         sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats),
         attributeStats = outputAttrStats))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
index 8213d56..dfa6e46 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
@@ -45,6 +45,30 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest {
   private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
     columnInfo.map(kv => kv._1.name -> kv)
 
+  test("SPARK-26894: propagate child stats for aliases in Aggregate") {
+    val tableColumns = Seq("key11", "key12")
+    val groupByColumns = Seq("key11")
+    val attributes = groupByColumns.map(nameToAttr)
+
+    val rowCount = 2
+    val child = StatsTestPlan(
+      outputList = tableColumns.map(nameToAttr),
+      rowCount,
+      // rowCount * (overhead + column size)
+      size = Some(4 * (8 + 4)),
+      attributeStats = AttributeMap(tableColumns.map(nameToColInfo)))
+
+    val testAgg = Aggregate(
+      groupingExpressions = attributes,
+      aggregateExpressions = Seq(Alias(nameToAttr("key12"), "abc")()),
+      child)
+
+    val expectedColStats = Seq("abc" -> nameToColInfo("key12")._2)
+    val expectedAttrStats = toAttributeMap(expectedColStats, testAgg)
+
+    assert(testAgg.stats.attributeStats == expectedAttrStats)
+  }
+
   test("set an upper bound if the product of ndv's of group-by columns is too large") {
     // Suppose table1 (key11 int, key12 int) has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
     checkAggStats(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org