You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/10/31 10:14:08 UTC

spark git commit: [SPARK-22310][SQL] Refactor join estimation to incorporate estimation logic for different kinds of statistics

Repository: spark
Updated Branches:
  refs/heads/master aa6db57e3 -> 59589bc65


[SPARK-22310][SQL] Refactor join estimation to incorporate estimation logic for different kinds of statistics

## What changes were proposed in this pull request?

The current join estimation logic is based only on basic column statistics (such as ndv). If we want to add estimation for other kinds of statistics (such as histograms), they are not easy to incorporate into the current algorithm:
1. When we have multiple pairs of join keys, the current algorithm computes cardinality with a single formula. But if different join keys have different kinds of stats, the computation logic for each pair of join keys becomes different, so the previous formula no longer applies.
2. Currently it computes cardinality and updates the join keys' column stats separately. It's better to do these two steps together, since both the computation logic and the update logic differ for different kinds of stats.

## How was this patch tested?

This is a pure refactoring; it is covered by existing tests.

Author: Zhenhua Wang <wa...@huawei.com>

Closes #19531 from wzhfy/join_est_refactor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/59589bc6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/59589bc6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/59589bc6

Branch: refs/heads/master
Commit: 59589bc6545b6665432febfa9ee4891a96d119c4
Parents: aa6db57
Author: Zhenhua Wang <wa...@huawei.com>
Authored: Tue Oct 31 11:13:48 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Oct 31 11:13:48 2017 +0100

----------------------------------------------------------------------
 .../statsEstimation/BasicStatsPlanVisitor.scala |   4 +-
 .../statsEstimation/JoinEstimation.scala        | 172 +++++++++----------
 2 files changed, 85 insertions(+), 91 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/59589bc6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
index 4cff72d..ca0775a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
-import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types.LongType
 
 /**
  * An [[LogicalPlanVisitor]] that computes a the statistics used in a cost-based optimizer.
@@ -54,7 +52,7 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
   override def visitIntersect(p: Intersect): Statistics = fallback(p)
 
   override def visitJoin(p: Join): Statistics = {
-    JoinEstimation.estimate(p).getOrElse(fallback(p))
+    JoinEstimation(p).estimate.getOrElse(fallback(p))
   }
 
   override def visitLocalLimit(p: LocalLimit): Statistics = fallback(p)

http://git-wip-us.apache.org/repos/asf/spark/blob/59589bc6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index dcbe36d..b073108 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -28,60 +28,58 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
 
 
-object JoinEstimation extends Logging {
+case class JoinEstimation(join: Join) extends Logging {
+
+  private val leftStats = join.left.stats
+  private val rightStats = join.right.stats
+
   /**
    * Estimate statistics after join. Return `None` if the join type is not supported, or we don't
    * have enough statistics for estimation.
    */
-  def estimate(join: Join): Option[Statistics] = {
+  def estimate: Option[Statistics] = {
     join.joinType match {
       case Inner | Cross | LeftOuter | RightOuter | FullOuter =>
-        InnerOuterEstimation(join).doEstimate()
+        estimateInnerOuterJoin()
       case LeftSemi | LeftAnti =>
-        LeftSemiAntiEstimation(join).doEstimate()
+        estimateLeftSemiAntiJoin()
       case _ =>
         logDebug(s"[CBO] Unsupported join type: ${join.joinType}")
         None
     }
   }
-}
-
-case class InnerOuterEstimation(join: Join) extends Logging {
-
-  private val leftStats = join.left.stats
-  private val rightStats = join.right.stats
 
   /**
    * Estimate output size and number of rows after a join operator, and update output column stats.
    */
-  def doEstimate(): Option[Statistics] = join match {
+  private def estimateInnerOuterJoin(): Option[Statistics] = join match {
     case _ if !rowCountsExist(join.left, join.right) =>
       None
 
     case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
       // 1. Compute join selectivity
       val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
-      val selectivity = joinSelectivity(joinKeyPairs)
+      val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs)
 
       // 2. Estimate the number of output rows
       val leftRows = leftStats.rowCount.get
       val rightRows = rightStats.rowCount.get
-      val innerJoinedRows = ceil(BigDecimal(leftRows * rightRows) * selectivity)
 
       // Make sure outputRows won't be too small based on join type.
       val outputRows = joinType match {
         case LeftOuter =>
           // All rows from left side should be in the result.
-          leftRows.max(innerJoinedRows)
+          leftRows.max(numInnerJoinedRows)
         case RightOuter =>
           // All rows from right side should be in the result.
-          rightRows.max(innerJoinedRows)
+          rightRows.max(numInnerJoinedRows)
         case FullOuter =>
           // T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B)
-          leftRows.max(innerJoinedRows) + rightRows.max(innerJoinedRows) - innerJoinedRows
+          leftRows.max(numInnerJoinedRows) + rightRows.max(numInnerJoinedRows) - numInnerJoinedRows
         case _ =>
+          assert(joinType == Inner || joinType == Cross)
           // Don't change for inner or cross join
-          innerJoinedRows
+          numInnerJoinedRows
       }
 
       // 3. Update statistics based on the output of join
@@ -93,7 +91,7 @@ case class InnerOuterEstimation(join: Join) extends Logging {
       val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) {
         // The output is empty, we don't need to keep column stats.
         Nil
-      } else if (selectivity == 0) {
+      } else if (numInnerJoinedRows == 0) {
         joinType match {
           // For outer joins, if the join selectivity is 0, the number of output rows is the
           // same as that of the outer side. And column stats of join keys from the outer side
@@ -113,26 +111,28 @@ case class InnerOuterEstimation(join: Join) extends Logging {
               val oriColStat = inputAttrStats(a)
               (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows))
             }
-          case _ => Nil
+          case _ =>
+            assert(joinType == Inner || joinType == Cross)
+            Nil
         }
-      } else if (selectivity == 1) {
+      } else if (numInnerJoinedRows == leftRows * rightRows) {
         // Cartesian product, just propagate the original column stats
         inputAttrStats.toSeq
       } else {
-        val joinKeyStats = getIntersectedStats(joinKeyPairs)
         join.joinType match {
           // For outer joins, don't update column stats from the outer side.
           case LeftOuter =>
             fromLeft.map(a => (a, inputAttrStats(a))) ++
-              updateAttrStats(outputRows, fromRight, inputAttrStats, joinKeyStats)
+              updateOutputStats(outputRows, fromRight, inputAttrStats, keyStatsAfterJoin)
           case RightOuter =>
-            updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++
+            updateOutputStats(outputRows, fromLeft, inputAttrStats, keyStatsAfterJoin) ++
               fromRight.map(a => (a, inputAttrStats(a)))
           case FullOuter =>
             inputAttrStats.toSeq
           case _ =>
+            assert(joinType == Inner || joinType == Cross)
             // Update column stats from both sides for inner or cross join.
-            updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats)
+            updateOutputStats(outputRows, attributesWithStat, inputAttrStats, keyStatsAfterJoin)
         }
       }
 
@@ -157,64 +157,90 @@ case class InnerOuterEstimation(join: Join) extends Logging {
   // scalastyle:off
   /**
    * The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula:
-   * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of
-   * that column. The underlying assumption for this formula is: each value of the smaller domain
-   * is included in the larger domain.
-   * Generally, inner join with multiple join keys can also be estimated based on the above
-   * formula:
+   * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)),
+   * where V is the number of distinct values (ndv) of that column. The underlying assumption for
+   * this formula is: each value of the smaller domain is included in the larger domain.
+   *
+   * Generally, inner join with multiple join keys can be estimated based on the above formula:
    * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn)))
    * However, the denominator can become very large and excessively reduce the result, so we use a
    * conservative strategy to take only the largest max(V(A.ki), V(B.ki)) as the denominator.
+   *
+   * That is, join estimation is based on the most selective join keys. We follow this strategy
+   * when different types of column statistics are available. E.g., if card1 is the cardinality
+   * estimated by ndv of join key A.k1 and B.k1, card2 is the cardinality estimated by histograms
+   * of join key A.k2 and B.k2, then the result cardinality would be min(card1, card2).
+   *
+   * @param keyPairs pairs of join keys
+   *
+   * @return join cardinality, and column stats for join keys after the join
    */
   // scalastyle:on
-  def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = {
-    var ndvDenom: BigInt = -1
+  private def computeCardinalityAndStats(keyPairs: Seq[(AttributeReference, AttributeReference)])
+    : (BigInt, AttributeMap[ColumnStat]) = {
+    // If there's no column stats available for join keys, estimate as cartesian product.
+    var joinCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get
+    val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]()
     var i = 0
-    while(i < joinKeyPairs.length && ndvDenom != 0) {
-      val (leftKey, rightKey) = joinKeyPairs(i)
+    while(i < keyPairs.length && joinCard != 0) {
+      val (leftKey, rightKey) = keyPairs(i)
       // Check if the two sides are disjoint
-      val leftKeyStats = leftStats.attributeStats(leftKey)
-      val rightKeyStats = rightStats.attributeStats(rightKey)
-      val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
-      val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
+      val leftKeyStat = leftStats.attributeStats(leftKey)
+      val rightKeyStat = rightStats.attributeStats(rightKey)
+      val lInterval = ValueInterval(leftKeyStat.min, leftKeyStat.max, leftKey.dataType)
+      val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType)
       if (ValueInterval.isIntersected(lInterval, rInterval)) {
-        // Get the largest ndv among pairs of join keys
-        val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount)
-        if (maxNdv > ndvDenom) ndvDenom = maxNdv
+        val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
+        val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax)
+        keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat)
+        // Return cardinality estimated from the most selective join keys.
+        if (card < joinCard) joinCard = card
       } else {
-        // Set ndvDenom to zero to indicate that this join should have no output
-        ndvDenom = 0
+        // One of the join key pairs is disjoint, thus the two sides of join is disjoint.
+        joinCard = 0
       }
       i += 1
     }
+    (joinCard, AttributeMap(keyStatsAfterJoin.toSeq))
+  }
 
-    if (ndvDenom < 0) {
-      // We can't find any join key pairs with column stats, estimate it as cartesian join.
-      1
-    } else if (ndvDenom == 0) {
-      // One of the join key pairs is disjoint, thus the two sides of join is disjoint.
-      0
-    } else {
-      1 / BigDecimal(ndvDenom)
-    }
+  /** Returns join cardinality and the column stat for this pair of join keys. */
+  private def computeByNdv(
+      leftKey: AttributeReference,
+      rightKey: AttributeReference,
+      newMin: Option[Any],
+      newMax: Option[Any]): (BigInt, ColumnStat) = {
+    val leftKeyStat = leftStats.attributeStats(leftKey)
+    val rightKeyStat = rightStats.attributeStats(rightKey)
+    val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount)
+    // Compute cardinality by the basic formula.
+    val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv)
+
+    // Get the intersected column stat.
+    val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount)
+    val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen)
+    val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2
+    val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
+
+    (ceil(card), newStats)
   }
 
   /**
    * Propagate or update column stats for output attributes.
    */
-  private def updateAttrStats(
+  private def updateOutputStats(
       outputRows: BigInt,
-      attributes: Seq[Attribute],
+      output: Seq[Attribute],
       oldAttrStats: AttributeMap[ColumnStat],
-      joinKeyStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = {
+      keyStatsAfterJoin: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = {
     val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
     val leftRows = leftStats.rowCount.get
     val rightRows = rightStats.rowCount.get
 
-    attributes.foreach { a =>
+    output.foreach { a =>
       // check if this attribute is a join key
-      if (joinKeyStats.contains(a)) {
-        outputAttrStats += a -> joinKeyStats(a)
+      if (keyStatsAfterJoin.contains(a)) {
+        outputAttrStats += a -> keyStatsAfterJoin(a)
       } else {
         val oldColStat = oldAttrStats(a)
         val oldNdv = oldColStat.distinctCount
@@ -231,34 +257,6 @@ case class InnerOuterEstimation(join: Join) extends Logging {
     outputAttrStats
   }
 
-  /** Get intersected column stats for join keys. */
-  private def getIntersectedStats(joinKeyPairs: Seq[(AttributeReference, AttributeReference)])
-    : AttributeMap[ColumnStat] = {
-
-    val intersectedStats = new mutable.HashMap[Attribute, ColumnStat]()
-    joinKeyPairs.foreach { case (leftKey, rightKey) =>
-      val leftKeyStats = leftStats.attributeStats(leftKey)
-      val rightKeyStats = rightStats.attributeStats(rightKey)
-      val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
-      val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
-      // When we reach here, join selectivity is not zero, so each pair of join keys should be
-      // intersected.
-      assert(ValueInterval.isIntersected(lInterval, rInterval))
-
-      // Update intersected column stats
-      assert(leftKey.dataType.sameType(rightKey.dataType))
-      val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
-      val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
-      val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen)
-      val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2
-      val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
-
-      intersectedStats.put(leftKey, newStats)
-      intersectedStats.put(rightKey, newStats)
-    }
-    AttributeMap(intersectedStats.toSeq)
-  }
-
   private def extractJoinKeysWithColStats(
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = {
@@ -270,10 +268,8 @@ case class InnerOuterEstimation(join: Join) extends Logging {
         if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk)
     }
   }
-}
 
-case class LeftSemiAntiEstimation(join: Join) {
-  def doEstimate(): Option[Statistics] = {
+  private def estimateLeftSemiAntiJoin(): Option[Statistics] = {
     // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic
     // column stats. Now we just propagate the statistics from left side. We should do more
     // accurate estimation when advanced stats (e.g. histograms) are available.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org