Posted to commits@spark.apache.org by we...@apache.org on 2017/02/24 04:18:26 UTC

spark git commit: [SPARK-17075][SQL] implemented filter estimation

Repository: spark
Updated Branches:
  refs/heads/master 2f69e3f60 -> d7e43b613


[SPARK-17075][SQL] implemented filter estimation

## What changes were proposed in this pull request?

We traverse the predicate and evaluate its logical expressions to compute the selectivity of a FILTER operator.
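
At a high level, each supported leaf predicate (=, <, <=, >, >=, IN, IS NULL, IS NOT NULL) is mapped to a selectivity in [0, 1] using column statistics, and compound predicates combine those values assuming independent conditions. A minimal standalone sketch of the combination rules used here (toy classes, not the actual Spark expressions):

    sealed trait Cond
    case class Leaf(sel: Double) extends Cond        // a supported base predicate
    case class And(l: Cond, r: Cond) extends Cond
    case class Or(l: Cond, r: Cond) extends Cond
    case class Not(c: Cond) extends Cond

    def selectivity(c: Cond): Double = c match {
      case Leaf(s)   => s
      case And(l, r) => selectivity(l) * selectivity(r)
      case Or(l, r)  =>
        math.min(1.0, selectivity(l) + selectivity(r) - selectivity(l) * selectivity(r))
      case Not(x)    => 1.0 - selectivity(x)
    }

    // e.g. (c > 40 AND c <= 50) with per-condition selectivities 0.6 and 0.1:
    println(selectivity(And(Leaf(0.6), Leaf(0.1))))  // ~0.06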

## How was this patch tested?

We add a new test suite that exercises various logical operators.
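
Assuming the standard Spark sbt setup, the new suite can typically be run on its own with
something like:

    build/sbt "catalyst/testOnly *FilterEstimationSuite"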

Author: Ron Hu <ro...@huawei.com>

Closes #16395 from ron8hu/filterSelectivity.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d7e43b61
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d7e43b61
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d7e43b61

Branch: refs/heads/master
Commit: d7e43b613aeb6d85b8522d7c16c6e0807843c964
Parents: 2f69e3f
Author: Ron Hu <ro...@huawei.com>
Authored: Thu Feb 23 20:18:21 2017 -0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Feb 23 20:18:21 2017 -0800

----------------------------------------------------------------------
 .../plans/logical/basicLogicalOperators.scala   |  10 +-
 .../statsEstimation/FilterEstimation.scala      | 511 +++++++++++++++++++
 .../plans/logical/statsEstimation/Range.scala   |  16 +
 .../statsEstimation/FilterEstimationSuite.scala | 403 +++++++++++++++
 4 files changed, 939 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
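
For context, this estimation path only runs when CBO is enabled (the cboEnabled flag, exposed as spark.sql.cbo.enabled in SQLConf) and the child plan carries a row count and column statistics. A hypothetical spark-shell session (table and column names are made up for illustration; exact ANALYZE syntax may differ across versions):

    // Collect per-column statistics first, then enable CBO.
    spark.sql("ANALYZE TABLE events COMPUTE STATISTICS FOR COLUMNS ts, status")
    spark.conf.set("spark.sql.cbo.enabled", "true")

    // The Filter node in this plan now derives rowCount and sizeInBytes from
    // FilterEstimation instead of the default size-only estimate.
    spark.sql("SELECT * FROM events WHERE status = 'ok'").explain(true)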


http://git-wip-us.apache.org/repos/asf/spark/blob/d7e43b61/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ce1c55d..ccebae3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, JoinEstimation, ProjectEstimation}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -129,6 +129,14 @@ case class Filter(condition: Expression, child: LogicalPlan)
       .filterNot(SubqueryExpression.hasCorrelatedSubquery)
     child.constraints.union(predicates.toSet)
   }
+
+  override def computeStats(conf: CatalystConf): Statistics = {
+    if (conf.cboEnabled) {
+      FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf))
+    } else {
+      super.computeStats(conf)
+    }
+  }
 }
 
 abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {

http://git-wip-us.apache.org/repos/asf/spark/blob/d7e43b61/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
new file mode 100644
index 0000000..fcc607a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -0,0 +1,511 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
+
+import java.sql.{Date, Timestamp}
+
+import scala.collection.immutable.{HashSet, Map}
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
+
+  /**
+   * We use a mutable colStats because we need to update the corresponding ColumnStat
+   * for a column after we apply a predicate condition.  For example, column c has
+   * [min, max] value as [0, 100].  In a range condition such as (c > 40 AND c <= 50),
+   * we need to set the column's [min, max] value to [40, 100] after we evaluate the
+   * first condition c > 40.  We need to set the column's [min, max] value to [40, 50]
+   * after we evaluate the second condition c <= 50.
+   */
+  private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty
+
+  /**
+   * Returns an option of Statistics for a Filter logical plan node.
+   * For a given compound expression condition, this method computes filter selectivity
+   * (or the percentage of rows meeting the filter condition), which
+   * is used to compute row count, size in bytes, and the updated statistics after a given
+   * predicate is applied.
+   *
+   * @return Option[Statistics]. Returns None when no statistics have been collected.
+   */
+  def estimate: Option[Statistics] = {
+    // We first copy child node's statistics and then modify it based on filter selectivity.
+    val stats: Statistics = plan.child.stats(catalystConf)
+    if (stats.rowCount.isEmpty) return None
+
+    // save a mutable copy of colStats so that we can later change it recursively
+    mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*)
+
+    // estimate selectivity of this filter predicate
+    val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match {
+      case Some(percent) => percent
+      // for an unsupported condition, fall back to a conservative estimate of 100%
+      case None => 1.0
+    }
+
+    // attributeStats has mapping Attribute-to-ColumnStat.
+    // mutableColStats has mapping ExprId-to-ColumnStat.
+    // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat
+    val exprIdToAttrMap: Map[ExprId, Attribute] =
+      stats.attributeStats.map(kv => (kv._1.exprId, kv._1))
+    // copy mutableColStats contents to an immutable AttributeMap.
+    val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] =
+      mutableColStats.map(kv => exprIdToAttrMap(kv._1) -> kv._2)
+    val newColStats = AttributeMap(mutableAttributeStats.toSeq)
+
+    val filteredRowCount: BigInt =
+      EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity)
+    val filteredSizeInBytes =
+      EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats)
+
+    Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
+      attributeStats = newColStats))
+  }
+
+  /**
+   * Returns a percentage of rows meeting a compound condition in the Filter node.
+   * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT.
+   * For logical AND conditions, we need to update stats after a condition estimation
+   * so that the stats will be more accurate for subsequent estimation.  This is needed for
+   * a range condition such as (c > 40 AND c <= 50).
+   * For logical OR conditions, we do not update stats after a condition estimation.
+   *
+   * @param condition the compound logical expression
+   * @param update a boolean flag to specify if we need to update ColumnStat of a column
+   *               for subsequent conditions
+   * @return an Option[Double] giving the percentage of rows meeting a given condition.
+   *         It returns None if the condition is not supported.
+   */
+  def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
+
+    condition match {
+      case And(cond1, cond2) =>
+        (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update))
+        match {
+          case (Some(p1), Some(p2)) => Some(p1 * p2)
+          case (Some(p1), None) => Some(p1)
+          case (None, Some(p2)) => Some(p2)
+          case (None, None) => None
+        }
+
+      case Or(cond1, cond2) =>
+        // For ease of debugging, we compute percent1 and percent2 in 2 statements.
+        val percent1 = calculateFilterSelectivity(cond1, update = false)
+        val percent2 = calculateFilterSelectivity(cond2, update = false)
+        (percent1, percent2) match {
+          case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2)))
+          case (Some(p1), None) => Some(1.0)
+          case (None, Some(p2)) => Some(1.0)
+          case (None, None) => None
+        }
+
+      case Not(cond) => calculateFilterSelectivity(cond, update = false) match {
+        case Some(percent) => Some(1.0 - percent)
+        // if the inner condition is not supported, we cannot estimate NOT(condition) either
+        case None => None
+      }
+
+      case _ =>
+        calculateSingleCondition(condition, update)
+    }
+  }
+
+  /**
+   * Returns a percentage of rows meeting a single condition in the Filter node.
+   * Currently we only support binary predicates where one side is a column,
+   * and the other is a literal.
+   *
+   * @param condition a single logical expression
+   * @param update a boolean flag to specify if we need to update ColumnStat of a column
+   *               for subsequent conditions
+   * @return an Option[Double] giving the percentage of rows meeting a given condition.
+   *         It returns None if the condition is not supported.
+   */
+  def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = {
+    condition match {
+      // The evaluateBinary method assumes the literal is on the right side of the operator,
+      // so we swap the operands (and flip the comparison) when it is not.
+
+      // EqualTo does not care about the order
+      case op @ EqualTo(ar: AttributeReference, l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op @ EqualTo(l: Literal, ar: AttributeReference) =>
+        evaluateBinary(op, ar, l, update)
+
+      case op @ LessThan(ar: AttributeReference, l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op @ LessThan(l: Literal, ar: AttributeReference) =>
+        evaluateBinary(GreaterThan(ar, l), ar, l, update)
+
+      case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) =>
+        evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update)
+
+      case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op @ GreaterThan(l: Literal, ar: AttributeReference) =>
+        evaluateBinary(LessThan(ar, l), ar, l, update)
+
+      case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) =>
+        evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)
+
+      case In(ar: AttributeReference, expList)
+        if expList.forall(e => e.isInstanceOf[Literal]) =>
+        // Expression [In (value, seq[Literal])] will be replaced with optimized version
+        // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10.
+        // Here we convert In into InSet anyway, because they share the same processing logic.
+        val hSet = expList.map(e => e.eval())
+        evaluateInSet(ar, HashSet() ++ hSet, update)
+
+      case InSet(ar: AttributeReference, set) =>
+        evaluateInSet(ar, set, update)
+
+      case IsNull(ar: AttributeReference) =>
+        evaluateIsNull(ar, isNull = true, update)
+
+      case IsNotNull(ar: AttributeReference) =>
+        evaluateIsNull(ar, isNull = false, update)
+
+      case _ =>
+        // TODO: it's difficult to support string operators without advanced statistics.
+        // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _)
+        // | EndsWith(_, _) are not supported yet
+        logDebug("[CBO] Unsupported filter condition: " + condition)
+        None
+    }
+  }
+
+  /**
+   * Returns a percentage of rows meeting an "IS NULL" or "IS NOT NULL" condition.
+   *
+   * @param attrRef an AttributeReference (or a column)
+   * @param isNull true for an "IS NULL" condition, false for an "IS NOT NULL" condition
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return an optional double value giving the percentage of rows meeting a given condition.
+   *         It returns None if no statistics are collected for a given column.
+   */
+  def evaluateIsNull(
+      attrRef: AttributeReference,
+      isNull: Boolean,
+      update: Boolean)
+    : Option[Double] = {
+    if (!mutableColStats.contains(attrRef.exprId)) {
+      logDebug("[CBO] No statistics for " + attrRef)
+      return None
+    }
+    val aColStat = mutableColStats(attrRef.exprId)
+    val rowCountValue = plan.child.stats(catalystConf).rowCount.get
+    val nullPercent: BigDecimal =
+      if (rowCountValue == 0) 0.0
+      else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue)
+
+    if (update) {
+      val newStats =
+        if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None)
+        else aColStat.copy(nullCount = 0)
+
+      mutableColStats += (attrRef.exprId -> newStats)
+    }
+
+    val percent =
+      if (isNull) {
+        nullPercent.toDouble
+      } else {
+        // IS NOT NULL(column)
+        1.0 - nullPercent.toDouble
+      }
+
+    Some(percent)
+  }
+
+  /**
+   * Returns a percentage of rows meeting a binary comparison expression.
+   *
+   * @param op a binary comparison operator such as =, <, <=, >, >=
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return an optional double value giving the percentage of rows meeting a given condition.
+   *         It returns None if no statistics exist for a given column or the literal value is invalid.
+   */
+  def evaluateBinary(
+      op: BinaryComparison,
+      attrRef: AttributeReference,
+      literal: Literal,
+      update: Boolean)
+    : Option[Double] = {
+    if (!mutableColStats.contains(attrRef.exprId)) {
+      logDebug("[CBO] No statistics for " + attrRef)
+      return None
+    }
+
+    op match {
+      case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update)
+      case _ =>
+        attrRef.dataType match {
+          case _: NumericType | DateType | TimestampType =>
+            evaluateBinaryForNumeric(op, attrRef, literal, update)
+          case StringType | BinaryType =>
+            // TODO: It is difficult to support other binary comparisons for String/Binary
+            // type without min/max and advanced statistics like histogram.
+            logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef)
+            None
+        }
+    }
+  }
+
+  /**
+   * For a SQL data type, its internal data type may be different from its external type.
+   * For DateType, its internal type is Int, and its external data type is Java Date type.
+   * The min/max values in ColumnStat are saved in their corresponding external type.
+   *
+   * @param attrDataType the column data type
+   * @param litValue the literal value
+   * @return an Option of the literal value converted to its external type; None for String/Binary types
+   */
+  def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = {
+    attrDataType match {
+      case DateType =>
+        Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
+      case TimestampType =>
+        Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
+      case StringType | BinaryType =>
+        None
+      case _ =>
+        Some(litValue)
+    }
+  }
+
+  /**
+   * Returns a percentage of rows meeting an equality (=) expression.
+   * This method evaluates the equality predicate for all data types.
+   *
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return an optional double value to show the percentage of rows meeting a given condition
+   */
+  def evaluateEqualTo(
+      attrRef: AttributeReference,
+      literal: Literal,
+      update: Boolean)
+    : Option[Double] = {
+
+    val aColStat = mutableColStats(attrRef.exprId)
+    val ndv = aColStat.distinctCount
+
+    // decide if the value is in [min, max] of the column.
+    // We currently don't store min/max for binary/string type.
+    // Hence, we assume it is in boundary for binary/string type.
+    val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType)
+    val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal)
+
+    if (inBoundary) {
+
+      if (update) {
+        // We update ColumnStat structure after apply this equality predicate.
+        // Set distinctCount to 1.  Set nullCount to 0.
+        // Need to save new min/max using the external type value of the literal
+        val newValue = convertBoundValue(attrRef.dataType, literal.value)
+        val newStats = aColStat.copy(distinctCount = 1, min = newValue,
+          max = newValue, nullCount = 0)
+        mutableColStats += (attrRef.exprId -> newStats)
+      }
+
+      Some(1.0 / ndv.toDouble)
+    } else {
+      Some(0.0)
+    }
+
+  }
+
+  /**
+   * Returns a percentage of rows meeting an "IN" operator expression.
+   * This method evaluates the IN predicate for all data types.
+   *
+   * @param attrRef an AttributeReference (or a column)
+   * @param hSet a set of literal values
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return an optional double value to show the percentage of rows meeting a given condition
+   *         It returns None if no statistics exists for a given column.
+   */
+  def evaluateInSet(
+      attrRef: AttributeReference,
+      hSet: Set[Any],
+      update: Boolean)
+    : Option[Double] = {
+    if (!mutableColStats.contains(attrRef.exprId)) {
+      logDebug("[CBO] No statistics for " + attrRef)
+      return None
+    }
+
+    val aColStat = mutableColStats(attrRef.exprId)
+    val ndv = aColStat.distinctCount
+    val aType = attrRef.dataType
+    var newNdv: Long = 0
+
+    // use [min, max] to filter the original hSet
+    aType match {
+      case _: NumericType | DateType | TimestampType =>
+        val statsRange =
+          Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange]
+
+        // To compare against the column's range, we map the hSet values to BigDecimal.
+        // This also makes it easy to find the min and max of the filtered set.
+        val hSetBigdec = hSet.map(e => BigDecimal(e.toString))
+        val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max)
+        // We use hSetBigdecToAnyMap to help us find the original hSet value.
+        val hSetBigdecToAnyMap: Map[BigDecimal, Any] =
+          hSet.map(e => BigDecimal(e.toString) -> e).toMap
+
+        if (validQuerySet.isEmpty) {
+          return Some(0.0)
+        }
+
+        // Need to save new min/max using the external type value of the literal
+        val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max))
+        val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min))
+
+        // newNdv should not be greater than the old ndv.  For example, a column has only
+        // 2 values, 1 and 6.  For the predicate column IN (1, 2, 3, 4, 5), validQuerySet.size is 5.
+        newNdv = math.min(validQuerySet.size.toLong, ndv.longValue())
+        if (update) {
+          val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
+                max = newMax, nullCount = 0)
+          mutableColStats += (attrRef.exprId -> newStats)
+        }
+
+      // We assume the whole set is within bounds since there is no min/max information for String/Binary types
+      case StringType | BinaryType =>
+        newNdv = math.min(hSet.size.toLong, ndv.longValue())
+        if (update) {
+          val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0)
+          mutableColStats += (attrRef.exprId -> newStats)
+        }
+    }
+
+    // return the filter selectivity.  Without advanced statistics such as histograms,
+    // we have to assume uniform distribution.
+    Some(math.min(1.0, newNdv.toDouble / ndv.toDouble))
+  }
+
+  /**
+   * Returns a percentage of rows meeting a binary comparison expression.
+   * This method evaluates the expression for numeric, date, and timestamp columns only.
+   *
+   * @param op a binary comparison operator such as <, <=, >, >= (EqualTo is handled separately)
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return an optional double value to show the percentage of rows meeting a given condition
+   */
+  def evaluateBinaryForNumeric(
+      op: BinaryComparison,
+      attrRef: AttributeReference,
+      literal: Literal,
+      update: Boolean)
+    : Option[Double] = {
+
+    var percent = 1.0
+    val aColStat = mutableColStats(attrRef.exprId)
+    val ndv = aColStat.distinctCount
+    val statsRange =
+      Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
+
+    // determine the overlapping degree between predicate range and column's range
+    val literalValueBD = BigDecimal(literal.value.toString)
+    val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
+      case _: LessThan =>
+        (literalValueBD <= statsRange.min, literalValueBD > statsRange.max)
+      case _: LessThanOrEqual =>
+        (literalValueBD < statsRange.min, literalValueBD >= statsRange.max)
+      case _: GreaterThan =>
+        (literalValueBD >= statsRange.max, literalValueBD < statsRange.min)
+      case _: GreaterThanOrEqual =>
+        (literalValueBD > statsRange.max, literalValueBD <= statsRange.min)
+    }
+
+    if (noOverlap) {
+      percent = 0.0
+    } else if (completeOverlap) {
+      percent = 1.0
+    } else {
+      // this is the partial overlap case
+      var newMax = aColStat.max
+      var newMin = aColStat.min
+      var newNdv = ndv
+      val literalToDouble = literalValueBD.toDouble
+      val maxToDouble = BigDecimal(statsRange.max).toDouble
+      val minToDouble = BigDecimal(statsRange.min).toDouble
+
+      // Without advanced statistics like histogram, we assume uniform data distribution.
+      // We just prorate the adjusted range over the initial range to compute filter selectivity.
+      // For ease of computation, we convert all relevant numeric values to Double.
+      percent = op match {
+        case _: LessThan =>
+          (literalToDouble - minToDouble) / (maxToDouble - minToDouble)
+        case _: LessThanOrEqual =>
+          if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble
+          else (literalToDouble - minToDouble) / (maxToDouble - minToDouble)
+        case _: GreaterThan =>
+          (maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
+        case _: GreaterThanOrEqual =>
+          if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble
+          else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
+      }
+
+      // Need to save new min/max using the external type value of the literal
+      val newValue = convertBoundValue(attrRef.dataType, literal.value)
+
+      if (update) {
+        op match {
+          case _: GreaterThan => newMin = newValue
+          case _: GreaterThanOrEqual => newMin = newValue
+          case _: LessThan => newMax = newValue
+          case _: LessThanOrEqual => newMax = newValue
+        }
+
+        newNdv = math.max(math.round(ndv.toDouble * percent), 1)
+        val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
+          max = newMax, nullCount = 0)
+
+        mutableColStats += (attrRef.exprId -> newStats)
+      }
+    }
+
+    Some(percent)
+  }
+
+}
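
A worked example of the prorating in evaluateBinaryForNumeric, under the uniform-distribution
assumption the code relies on. The numbers follow the test suite's cint column (10 rows,
range [1, 10], ndv 10) and the predicate cint < 3; a standalone sketch, not Spark code:

    val (minV, maxV, rowCount) = (1.0, 10.0, 10.0)
    val literal = 3.0
    val percent = (literal - minV) / (maxV - minV)  // (3 - 1) / (10 - 1) = 0.222...
    val estRows = math.ceil(rowCount * percent)     // ceil(2.22...) = 3 rows
    // With update = true, the column's stats are narrowed as well:
    // max becomes 3 and ndv becomes math.max(math.round(10 * 0.222), 1) = 2.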

http://git-wip-us.apache.org/repos/asf/spark/blob/d7e43b61/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
index 5aa6b93..4557114 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 import java.math.{BigDecimal => JDecimal}
 import java.sql.{Date, Timestamp}
 
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
 
@@ -57,6 +58,20 @@ object Range {
       n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0
   }
 
+  def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match {
+    case _: DefaultRange => true
+    case _: NullRange => false
+    case n: NumericRange =>
+      val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) {
+        if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
+      } else {
+        assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] ||
+          lit.dataType.isInstanceOf[TimestampType])
+        new JDecimal(lit.value.toString)
+      }
+      n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0
+  }
+
   /**
    * Intersected results of two ranges. This is only for two overlapped ranges.
    * The outputs are the intersected min/max values.
@@ -113,4 +128,5 @@ object Range {
           DateTimeUtils.toJavaTimestamp(n.max.longValue()))
     }
   }
+
 }
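
The new rangeContainsLiteral helper reduces every supported type to a java.math.BigDecimal
comparison (booleans map to 0/1; dates and timestamps compare via their numeric internal
form). A standalone sketch of the containment check, with bounds assumed for illustration:

    import java.math.{BigDecimal => JDecimal}

    val min = new JDecimal(1)   // stand-ins for ColumnStat min/max
    val max = new JDecimal(10)
    def contains(v: JDecimal): Boolean =
      min.compareTo(v) <= 0 && max.compareTo(v) >= 0

    assert(contains(new JDecimal(5)))   // 5 lies inside [1, 10]
    assert(!contains(new JDecimal(0)))  // 0 falls below min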

http://git-wip-us.apache.org/repos/asf/spark/blob/d7e43b61/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
new file mode 100644
index 0000000..f139c9e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+/**
+ * In this test suite, we test predicates containing the following operators:
+ * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN
+ */
+class FilterEstimationSuite extends StatsEstimationTestBase {
+
+  // Suppose our test table has 10 rows and 6 columns.
+  // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
+  // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
+  val arInt = AttributeReference("cint", IntegerType)()
+  val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+    nullCount = 0, avgLen = 4, maxLen = 4)
+
+  // Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
+  val dMin = Date.valueOf("2017-01-01")
+  val dMax = Date.valueOf("2017-01-10")
+  val arDate = AttributeReference("cdate", DateType)()
+  val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
+    nullCount = 0, avgLen = 4, maxLen = 4)
+
+  // Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through
+  // "2017-01-01 10:00:00" for 10 distinct timestamps (or hours).
+  val tsMin = Timestamp.valueOf("2017-01-01 01:00:00")
+  val tsMax = Timestamp.valueOf("2017-01-01 10:00:00")
+  val arTimestamp = AttributeReference("ctimestamp", TimestampType)()
+  val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax),
+    nullCount = 0, avgLen = 8, maxLen = 8)
+
+  // Fourth column cdecimal has 4 values from 0.20 through 0.80 in increments of 0.20.
+  val decMin = new java.math.BigDecimal("0.200000000000000000")
+  val decMax = new java.math.BigDecimal("0.800000000000000000")
+  val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
+  val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
+    nullCount = 0, avgLen = 8, maxLen = 8)
+
+  // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
+  val arDouble = AttributeReference("cdouble", DoubleType)()
+  val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0),
+    nullCount = 0, avgLen = 8, maxLen = 8)
+
+  // Sixth column cstring has 10 String values:
+  // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9"
+  val arString = AttributeReference("cstring", StringType)()
+  val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None,
+    nullCount = 0, avgLen = 2, maxLen = 2)
+
+  test("cint = 2") {
+    validateEstimatedStats(
+      arInt,
+      Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(1L)
+    )
+  }
+
+  test("cint = 0") {
+    // This is an out-of-range case since 0 is outside the range [min, max]
+    validateEstimatedStats(
+      arInt,
+      Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(0L)
+    )
+  }
+
+  test("cint < 3") {
+    validateEstimatedStats(
+      arInt,
+      Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(3L)
+    )
+  }
+
+  test("cint < 0") {
+    // This is a corner case since literal 0 is smaller than min.
+    validateEstimatedStats(
+      arInt,
+      Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(0L)
+    )
+  }
+
+  test("cint <= 3") {
+    validateEstimatedStats(
+      arInt,
+      Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(3L)
+    )
+  }
+
+  test("cint > 6") {
+    validateEstimatedStats(
+      arInt,
+      Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(5L)
+    )
+  }
+
+  test("cint > 10") {
+    // This is a corner case since max value is 10.
+    validateEstimatedStats(
+      arInt,
+      Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(0L)
+    )
+  }
+
+  test("cint >= 6") {
+    validateEstimatedStats(
+      arInt,
+      Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(5L)
+    )
+  }
+
+  test("cint IS NULL") {
+    validateEstimatedStats(
+      arInt,
+      Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 0, min = None, max = None,
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(0L)
+    )
+  }
+
+  test("cint IS NOT NULL") {
+    validateEstimatedStats(
+      arInt,
+      Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(10L)
+    )
+  }
+
+  test("cint > 3 AND cint <= 6") {
+    val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6)))
+    validateEstimatedStats(
+      arInt,
+      Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(4L)
+    )
+  }
+
+  test("cint = 3 OR cint = 6") {
+    val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6)))
+    validateEstimatedStats(
+      arInt,
+      Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(2L)
+    )
+  }
+
+  test("cint IN (3, 4, 5)") {
+    validateEstimatedStats(
+      arInt,
+      Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(3L)
+    )
+  }
+
+  test("cint NOT IN (3, 4, 5)") {
+    validateEstimatedStats(
+      arInt,
+      Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(7L)
+    )
+  }
+
+  test("cdate = cast('2017-01-02' AS DATE)") {
+    val d20170102 = Date.valueOf("2017-01-02")
+    validateEstimatedStats(
+      arDate,
+      Filter(EqualTo(arDate, Literal(d20170102)),
+        childStatsTestPlan(Seq(arDate), 10L)),
+      ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(1L)
+    )
+  }
+
+  test("cdate < cast('2017-01-03' AS DATE)") {
+    val d20170103 = Date.valueOf("2017-01-03")
+    validateEstimatedStats(
+      arDate,
+      Filter(LessThan(arDate, Literal(d20170103)),
+        childStatsTestPlan(Seq(arDate), 10L)),
+      ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(3L)
+    )
+  }
+
+  test("""cdate IN ( cast('2017-01-03' AS DATE),
+      cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") {
+    val d20170103 = Date.valueOf("2017-01-03")
+    val d20170104 = Date.valueOf("2017-01-04")
+    val d20170105 = Date.valueOf("2017-01-05")
+    validateEstimatedStats(
+      arDate,
+      Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
+        childStatsTestPlan(Seq(arDate), 10L)),
+      ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(3L)
+    )
+  }
+
+  test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") {
+    val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00")
+    validateEstimatedStats(
+      arTimestamp,
+      Filter(EqualTo(arTimestamp, Literal(ts2017010102)),
+        childStatsTestPlan(Seq(arTimestamp), 10L)),
+      ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102),
+        nullCount = 0, avgLen = 8, maxLen = 8),
+      Some(1L)
+    )
+  }
+
+  test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") {
+    val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00")
+    validateEstimatedStats(
+      arTimestamp,
+      Filter(LessThan(arTimestamp, Literal(ts2017010103)),
+        childStatsTestPlan(Seq(arTimestamp), 10L)),
+      ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103),
+        nullCount = 0, avgLen = 8, maxLen = 8),
+      Some(3L)
+    )
+  }
+
+  test("cdecimal = 0.400000000000000000") {
+    val dec_0_40 = new java.math.BigDecimal("0.400000000000000000")
+    validateEstimatedStats(
+      arDecimal,
+      Filter(EqualTo(arDecimal, Literal(dec_0_40)),
+        childStatsTestPlan(Seq(arDecimal), 4L)),
+      ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
+        nullCount = 0, avgLen = 8, maxLen = 8),
+      Some(1L)
+    )
+  }
+
+  test("cdecimal < 0.60 ") {
+    val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
+    validateEstimatedStats(
+      arDecimal,
+      Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))),
+        childStatsTestPlan(Seq(arDecimal), 4L)),
+      ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
+        nullCount = 0, avgLen = 8, maxLen = 8),
+      Some(3L)
+    )
+  }
+
+  test("cdouble < 3.0") {
+    validateEstimatedStats(
+      arDouble,
+      Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)),
+      ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
+        nullCount = 0, avgLen = 8, maxLen = 8),
+      Some(3L)
+    )
+  }
+
+  test("cstring = 'A2'") {
+    validateEstimatedStats(
+      arString,
+      Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
+      ColumnStat(distinctCount = 1, min = None, max = None,
+        nullCount = 0, avgLen = 2, maxLen = 2),
+      Some(1L)
+    )
+  }
+
+  // There are no min/max statistics for String type.  We estimate 10 rows returned.
+  test("cstring < 'A2'") {
+    validateEstimatedStats(
+      arString,
+      Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
+      ColumnStat(distinctCount = 10, min = None, max = None,
+        nullCount = 0, avgLen = 2, maxLen = 2),
+      Some(10L)
+    )
+  }
+
+  // This is a corner test case.  We want to test if we can handle the case when the number of
+  // valid values in the IN clause is greater than the number of distinct values for a given column.
+  // For example, column has only 2 distinct values 1 and 6.
+  // The predicate is: column IN (1, 2, 3, 4, 5).
+  test("cint IN (1, 2, 3, 4, 5)") {
+    val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6),
+      nullCount = 0, avgLen = 4, maxLen = 4)
+    val cornerChildStatsTestplan = StatsTestPlan(
+      outputList = Seq(arInt),
+      rowCount = 2L,
+      attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt))
+    )
+    validateEstimatedStats(
+      arInt,
+      Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
+      ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(2L)
+    )
+  }
+
+  private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
+    StatsTestPlan(
+      outputList = outList,
+      rowCount = tableRowCount,
+      attributeStats = AttributeMap(Seq(
+        arInt -> childColStatInt,
+        arDate -> childColStatDate,
+        arTimestamp -> childColStatTimestamp,
+        arDecimal -> childColStatDecimal,
+        arDouble -> childColStatDouble,
+        arString -> childColStatString
+      ))
+    )
+  }
+
+  private def validateEstimatedStats(
+      ar: AttributeReference,
+      filterNode: Filter,
+      expectedColStats: ColumnStat,
+      rowCount: Option[BigInt] = None)
+    : Unit = {
+
+    val expectedRowCount: BigInt = rowCount.getOrElse(0L)
+    val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode)
+    val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats)
+
+    val filteredStats = filterNode.stats(conf)
+    assert(filteredStats.sizeInBytes == expectedSizeInBytes)
+    assert(filteredStats.rowCount == rowCount)
+    ar.dataType match {
+      case DecimalType() =>
+        // Due to the internal transformation for DecimalType within the engine, the new
+        // min/max in ColumnStat may have a different structure even if it contains the right values.
+        // We convert them to Java BigDecimal values so that we can compare the entire object.
+        val generatedColumnStats = filteredStats.attributeStats(ar)
+        val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString)
+        val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString)
+        val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax))
+        assert(outputColStats == expectedColStats)
+      case _ =>
+        // For all other SQL types, we compare the entire object directly.
+        assert(filteredStats.attributeStats(ar) == expectedColStats)
+    }
+  }
+
+}

