You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/04/04 00:27:17 UTC

spark git commit: [SPARK-19408][SQL] filter estimation on two columns of same table

Repository: spark
Updated Branches:
  refs/heads/master 58c9e6e77 -> e7877fd47


[SPARK-19408][SQL] filter estimation on two columns of same table

## What changes were proposed in this pull request?

In SQL queries, we also see predicate expressions involving two columns, such as "column-1 (op) column-2", where column-1 and column-2 belong to the same table. Note that if column-1 and column-2 belong to different tables, then it is a join operator's work, NOT a filter operator's work.

This PR estimates filter selectivity on two columns of the same table. For example, multiple TPC-H queries have the predicate "WHERE l_commitdate < l_receiptdate".

## How was this patch tested?

We added 8 new test cases to test various logical predicates involving two columns of the same table.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Ron Hu <ro...@huawei.com>
Author: U-CHINA\r00754707 <r0...@R00754707-SC04.china.huawei.com>

Closes #17415 from ron8hu/filterTwoColumns.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e7877fd4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e7877fd4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e7877fd4

Branch: refs/heads/master
Commit: e7877fd4728ed41e440d7c4d8b6b02bd0d9e873e
Parents: 58c9e6e
Author: Ron Hu <ro...@huawei.com>
Authored: Mon Apr 3 17:27:12 2017 -0700
Committer: Xiao Li <ga...@gmail.com>
Committed: Mon Apr 3 17:27:12 2017 -0700

----------------------------------------------------------------------
 .../statsEstimation/FilterEstimation.scala      | 233 ++++++++++++++++++-
 .../statsEstimation/FilterEstimationSuite.scala | 140 ++++++++++-
 2 files changed, 363 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e7877fd4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
old mode 100644
new mode 100755
index b32374c..03c76cd
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -201,6 +201,21 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
       case IsNotNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] =>
         evaluateNullCheck(ar, isNull = false, update)
 
+      case op @ Equality(attrLeft: Attribute, attrRight: Attribute) =>
+        evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update)
+
+      case op @ LessThan(attrLeft: Attribute, attrRight: Attribute) =>
+        evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update)
+
+      case op @ LessThanOrEqual(attrLeft: Attribute, attrRight: Attribute) =>
+        evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update)
+
+      case op @ GreaterThan(attrLeft: Attribute, attrRight: Attribute) =>
+        evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update)
+
+      case op @ GreaterThanOrEqual(attrLeft: Attribute, attrRight: Attribute) =>
+        evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update)
+
       case _ =>
         // TODO: it's difficult to support string operators without advanced statistics.
         // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _)
@@ -257,7 +272,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
   /**
    * Returns a percentage of rows meeting a binary comparison expression.
    *
-   * @param op a binary comparison operator uch as =, <, <=, >, >=
+   * @param op a binary comparison operator such as =, <, <=, >, >=
    * @param attr an Attribute (or a column)
    * @param literal a literal value (or constant)
    * @param update a boolean flag to specify if we need to update ColumnStat of a given column
@@ -448,7 +463,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
    * Returns a percentage of rows meeting a binary comparison expression.
    * This method evaluate expression for Numeric/Date/Timestamp/Boolean columns.
    *
-   * @param op a binary comparison operator uch as =, <, <=, >, >=
+   * @param op a binary comparison operator such as =, <, <=, >, >=
    * @param attr an Attribute (or a column)
    * @param literal a literal value (or constant)
    * @param update a boolean flag to specify if we need to update ColumnStat of a given column
@@ -550,6 +565,220 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
     Some(percent.toDouble)
   }
 
+  /**
+   * Returns a percentage of rows meeting a binary comparison expression containing two columns.
+   * In SQL queries, we also see predicate expressions involving two columns
+   * such as "column-1 (op) column-2" where column-1 and column-2 belong to the same table.
+   * Note that, if column-1 and column-2 belong to different tables, then it is a join
+   * operator's work, NOT a filter operator's work.
+   *
+   * @param op a binary comparison operator, including =, <=>, <, <=, >, >=
+   * @param attrLeft the left Attribute (or a column)
+   * @param attrRight the right Attribute (or a column)
+   * @param update a boolean flag to specify if we need to update ColumnStat of the given columns
+   *               for subsequent conditions
+   * @return an optional double value to show the percentage of rows meeting a given condition
+   */
+  def evaluateBinaryForTwoColumns(
+      op: BinaryComparison,
+      attrLeft: Attribute,
+      attrRight: Attribute,
+      update: Boolean): Option[Double] = {
+
+    if (!colStatsMap.contains(attrLeft)) {
+      logDebug("[CBO] No statistics for " + attrLeft)
+      return None
+    }
+    if (!colStatsMap.contains(attrRight)) {
+      logDebug("[CBO] No statistics for " + attrRight)
+      return None
+    }
+
+    attrLeft.dataType match {
+      case StringType | BinaryType =>
+        // TODO: It is difficult to support other binary comparisons for String/Binary
+        // type without min/max and advanced statistics like histogram.
+        logDebug("[CBO] No range comparison statistics for String/Binary type " + attrLeft)
+        return None
+      case _ =>
+    }
+
+    val colStatLeft = colStatsMap(attrLeft)
+    val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
+      .asInstanceOf[NumericRange]
+    val maxLeft = BigDecimal(statsRangeLeft.max)
+    val minLeft = BigDecimal(statsRangeLeft.min)
+
+    val colStatRight = colStatsMap(attrRight)
+    val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType)
+      .asInstanceOf[NumericRange]
+    val maxRight = BigDecimal(statsRangeRight.max)
+    val minRight = BigDecimal(statsRangeRight.min)
+
+    // determine the overlapping degree between the value ranges of the two columns
+    val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
+    val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
+      // Left < Right or Left <= Right
+      // - no overlap:
+      //      minRight           maxRight     minLeft       maxLeft
+      // --------+------------------+------------+-------------+------->
+      // - complete overlap: (If null values exist, we set it to partial overlap.)
+      //      minLeft            maxLeft      minRight      maxRight
+      // --------+------------------+------------+-------------+------->
+      case _: LessThan =>
+        (minLeft >= maxRight, (maxLeft < minRight) && allNotNull)
+      case _: LessThanOrEqual =>
+        (minLeft > maxRight, (maxLeft <= minRight) && allNotNull)
+
+      // Left > Right or Left >= Right
+      // - no overlap:
+      //      minLeft            maxLeft      minRight      maxRight
+      // --------+------------------+------------+-------------+------->
+      // - complete overlap: (If null values exist, we set it to partial overlap.)
+      //      minRight           maxRight     minLeft       maxLeft
+      // --------+------------------+------------+-------------+------->
+      case _: GreaterThan =>
+        (maxLeft <= minRight, (minLeft > maxRight) && allNotNull)
+      case _: GreaterThanOrEqual =>
+        (maxLeft < minRight, (minLeft >= maxRight) && allNotNull)
+
+      // Left = Right or Left <=> Right
+      // - no overlap:
+      //      minLeft            maxLeft      minRight      maxRight
+      // --------+------------------+------------+-------------+------->
+      //      minRight           maxRight     minLeft       maxLeft
+      // --------+------------------+------------+-------------+------->
+      // - complete overlap:
+      //      minLeft            maxLeft
+      //      minRight           maxRight
+      // --------+------------------+------->
+      case _: EqualTo =>
+        ((maxLeft < minRight) || (maxRight < minLeft),
+          (minLeft == minRight) && (maxLeft == maxRight) && allNotNull
+          && (colStatLeft.distinctCount == colStatRight.distinctCount)
+        )
+      case _: EqualNullSafe =>
+        // For null-safe equality, we use a very restrictive condition to evaluate its overlap.
+        // If null values exist, we set it to partial overlap.
+        (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull,
+          (minLeft == minRight) && (maxLeft == maxRight) && allNotNull
+            && (colStatLeft.distinctCount == colStatRight.distinctCount)
+        )
+    }
+
+    var percent = BigDecimal(1.0)
+    if (noOverlap) {
+      percent = 0.0
+    } else if (completeOverlap) {
+      percent = 1.0
+    } else {
+      // For partial overlap, we use an empirical value 1/3 as suggested by the book
+      // "Database Systems: The Complete Book".
+      percent = 1.0 / 3.0
+
+      if (update) {
+        // Need to adjust new min/max after the filter condition is applied
+
+        val ndvLeft = BigDecimal(colStatLeft.distinctCount)
+        var newNdvLeft = (ndvLeft * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
+        if (newNdvLeft < 1) newNdvLeft = 1
+        val ndvRight = BigDecimal(colStatRight.distinctCount)
+        var newNdvRight = (ndvRight * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
+        if (newNdvRight < 1) newNdvRight = 1
+
+        var newMaxLeft = colStatLeft.max
+        var newMinLeft = colStatLeft.min
+        var newMaxRight = colStatRight.max
+        var newMinRight = colStatRight.min
+
+        op match {
+          case _: LessThan | _: LessThanOrEqual =>
+            // the left side should be less than the right side.
+            // If not, we need to adjust it to narrow the range.
+            // Left < Right or Left <= Right
+            //      minRight     <     minLeft
+            // --------+******************+------->
+            //              filtered      ^
+            //                            |
+            //                        newMinRight
+            //
+            //      maxRight     <     maxLeft
+            // --------+******************+------->
+            //         ^    filtered
+            //         |
+            //     newMaxLeft
+            if (minLeft > minRight) newMinRight = colStatLeft.min
+            if (maxLeft > maxRight) newMaxLeft = colStatRight.max
+
+          case _: GreaterThan | _: GreaterThanOrEqual =>
+            // the left side should be greater than the right side.
+            // If not, we need to adjust it to narrow the range.
+            // Left > Right or Left >= Right
+            //      minLeft     <      minRight
+            // --------+******************+------->
+            //              filtered      ^
+            //                            |
+            //                        newMinLeft
+            //
+            //      maxLeft     <      maxRight
+            // --------+******************+------->
+            //         ^    filtered
+            //         |
+            //     newMaxRight
+            if (minLeft < minRight) newMinLeft = colStatRight.min
+            if (maxLeft < maxRight) newMaxRight = colStatLeft.max
+
+          case _: EqualTo | _: EqualNullSafe =>
+            // need to set new min to the larger min value, and
+            // set the new max to the smaller max value.
+            // Left = Right or Left <=> Right
+            //      minLeft     <      minRight
+            // --------+******************+------->
+            //              filtered      ^
+            //                            |
+            //                        newMinLeft
+            //
+            //      minRight    <=     minLeft
+            // --------+******************+------->
+            //              filtered      ^
+            //                            |
+            //                        newMinRight
+            //
+            //      maxLeft     <      maxRight
+            // --------+******************+------->
+            //         ^    filtered
+            //         |
+            //     newMaxRight
+            //
+            //      maxRight    <=     maxLeft
+            // --------+******************+------->
+            //         ^    filtered
+            //         |
+            //     newMaxLeft
+          if (minLeft < minRight) {
+            newMinLeft = colStatRight.min
+          } else {
+            newMinRight = colStatLeft.min
+          }
+          if (maxLeft < maxRight) {
+            newMaxRight = colStatLeft.max
+          } else {
+            newMaxLeft = colStatRight.max
+          }
+        }
+
+        val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft,
+          max = newMaxLeft)
+        colStatsMap(attrLeft) = newStatsLeft
+        val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight,
+          max = newMaxRight)
+        colStatsMap(attrRight) = newStatsRight
+      }
+    }
+
+    Some(percent.toDouble)
+  }
+
 }
 
 class ColumnStatsMap {

http://git-wip-us.apache.org/repos/asf/spark/blob/e7877fd4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
old mode 100644
new mode 100755
index 1966c96..cffb0d8
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -33,49 +33,74 @@ import org.apache.spark.sql.types._
 class FilterEstimationSuite extends StatsEstimationTestBase {
 
   // Suppose our test table has 10 rows and 6 columns.
-  // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
+  // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
   // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
   val attrInt = AttributeReference("cint", IntegerType)()
   val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
     nullCount = 0, avgLen = 4, maxLen = 4)
 
-  // only 2 values
+  // column cbool has only 2 distinct values
   val attrBool = AttributeReference("cbool", BooleanType)()
   val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
     nullCount = 0, avgLen = 1, maxLen = 1)
 
-  // Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
+  // column cdate has 10 values from 2017-01-01 through 2017-01-10.
   val dMin = Date.valueOf("2017-01-01")
   val dMax = Date.valueOf("2017-01-10")
   val attrDate = AttributeReference("cdate", DateType)()
   val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
     nullCount = 0, avgLen = 4, maxLen = 4)
 
-  // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
+  // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
   val decMin = new java.math.BigDecimal("0.200000000000000000")
   val decMax = new java.math.BigDecimal("0.800000000000000000")
   val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
   val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
     nullCount = 0, avgLen = 8, maxLen = 8)
 
-  // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
+  // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
   val attrDouble = AttributeReference("cdouble", DoubleType)()
   val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0),
     nullCount = 0, avgLen = 8, maxLen = 8)
 
-  // Sixth column cstring has 10 String values:
+  // column cstring has 10 String values:
   // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9"
   val attrString = AttributeReference("cstring", StringType)()
   val colStatString = ColumnStat(distinctCount = 10, min = None, max = None,
     nullCount = 0, avgLen = 2, maxLen = 2)
 
+  // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+  // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4
+  // This column is created to test "cint < cint2".
+  val attrInt2 = AttributeReference("cint2", IntegerType)()
+  val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16),
+    nullCount = 0, avgLen = 4, maxLen = 4)
+
+  // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39
+  // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4
+  // This column is created to test "cint = cint3" without overlap at all.
+  val attrInt3 = AttributeReference("cint3", IntegerType)()
+  val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39),
+    nullCount = 0, avgLen = 4, maxLen = 4)
+
+  // column cint4 has values in the range from 1 to 10
+  // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
+  // This column is created to test complete overlap
+  val attrInt4 = AttributeReference("cint4", IntegerType)()
+  val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+    nullCount = 0, avgLen = 4, maxLen = 4)
+
   val attributeMap = AttributeMap(Seq(
     attrInt -> colStatInt,
     attrBool -> colStatBool,
     attrDate -> colStatDate,
     attrDecimal -> colStatDecimal,
     attrDouble -> colStatDouble,
-    attrString -> colStatString))
+    attrString -> colStatString,
+    attrInt2 -> colStatInt2,
+    attrInt3 -> colStatInt3,
+    attrInt4 -> colStatInt4
+  ))
 
   test("true") {
     validateEstimatedStats(
@@ -450,6 +475,89 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
     }
   }
 
+  test("cint = cint2") {
+    // partial overlap case
+    validateEstimatedStats(
+      Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
+      Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+        attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+          nullCount = 0, avgLen = 4, maxLen = 4)),
+      expectedRowCount = 4)
+  }
+
+  test("cint > cint2") {
+    // partial overlap case
+    validateEstimatedStats(
+      Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
+      Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+        attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+          nullCount = 0, avgLen = 4, maxLen = 4)),
+      expectedRowCount = 4)
+  }
+
+  test("cint < cint2") {
+    // partial overlap case
+    validateEstimatedStats(
+      Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
+      Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+        attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(16),
+          nullCount = 0, avgLen = 4, maxLen = 4)),
+      expectedRowCount = 4)
+  }
+
+  test("cint = cint4") {
+    // complete overlap case
+    validateEstimatedStats(
+      Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)),
+      Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+        attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+          nullCount = 0, avgLen = 4, maxLen = 4)),
+      expectedRowCount = 10)
+  }
+
+  test("cint < cint4") {
+    // partial overlap case
+    validateEstimatedStats(
+      Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)),
+      Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+        attrInt4 -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10),
+          nullCount = 0, avgLen = 4, maxLen = 4)),
+      expectedRowCount = 4)
+  }
+
+  test("cint = cint3") {
+    // no records qualify due to no overlap
+    val emptyColStats = Seq[(Attribute, ColumnStat)]()
+    validateEstimatedStats(
+      Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
+      Nil, // set to empty
+      expectedRowCount = 0)
+  }
+
+  test("cint < cint3") {
+    // all table records qualify.
+    validateEstimatedStats(
+      Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
+      Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+        attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39),
+          nullCount = 0, avgLen = 4, maxLen = 4)),
+      expectedRowCount = 10)
+  }
+
+  test("cint > cint3") {
+    // no records qualify due to no overlap
+    validateEstimatedStats(
+      Filter(GreaterThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
+      Nil, // set to empty
+      expectedRowCount = 0)
+  }
+
   private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
     StatsTestPlan(
       outputList = outList,
@@ -491,7 +599,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
         sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap),
         rowCount = Some(expectedRowCount),
         attributeStats = expectedAttributeMap)
-      assert(filter.stats(conf) == expectedStats)
+
+      val filterStats = filter.stats(conf)
+      assert(filterStats.sizeInBytes == expectedStats.sizeInBytes)
+      assert(filterStats.rowCount == expectedStats.rowCount)
+      val rowCountValue = filterStats.rowCount.getOrElse(0)
+      // check the output column stats if the row count is > 0.
+      // When row count is 0, the output is set to empty.
+      if (rowCountValue != 0) {
+        // Need to check attributeStats one by one because we may have multiple output columns.
+        // Due to update operation, the output columns may be in different order.
+        assert(expectedColStats.size == filterStats.attributeStats.size)
+        expectedColStats.foreach { kv =>
+          val filterColumnStat = filterStats.attributeStats.get(kv._1).get
+          assert(filterColumnStat == kv._2)
+        }
+      }
     }
   }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org