Posted to commits@spark.apache.org by rx...@apache.org on 2017/01/11 23:01:02 UTC

spark git commit: [SPARK-19132][SQL] Add test cases for row size estimation and aggregate estimation

Repository: spark
Updated Branches:
  refs/heads/master 66fe819ad -> 43fa21b3e


[SPARK-19132][SQL] Add test cases for row size estimation and aggregate estimation

## What changes were proposed in this pull request?

This PR adds more test cases for project estimation (including row size estimation) and aggregate estimation.

## How was this patch tested?

Add test cases.

Author: wangzhenhua <wa...@huawei.com>

Closes #16551 from wzhfy/addTests.
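
For readers skimming the diff below, the row-count rule these tests exercise
can be summarized with a minimal Scala sketch. This is an illustration only,
not the code in this commit; `estimateAggOutputRows`, `ndvs`, and
`childRowCount` are hypothetical stand-ins for the per-column distinct counts
and the child row count that Spark tracks in Statistics.

    // No group-by columns: exactly one output row (a global aggregate).
    // Otherwise: product of the group-by columns' distinct counts,
    // capped by the child's row count.
    def estimateAggOutputRows(ndvs: Seq[BigInt], childRowCount: BigInt): BigInt =
      if (ndvs.isEmpty) 1 else ndvs.product.min(childRowCount)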


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/43fa21b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/43fa21b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/43fa21b3

Branch: refs/heads/master
Commit: 43fa21b3e62ee108bcecb74398f431f08c6b625c
Parents: 66fe819
Author: wangzhenhua <wa...@huawei.com>
Authored: Wed Jan 11 15:00:58 2017 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Jan 11 15:00:58 2017 -0800

----------------------------------------------------------------------
 .../statsEstimation/AggregateEstimation.scala   |  14 +-
 .../statsEstimation/EstimationUtils.scala       |  11 +-
 .../statsEstimation/ProjectEstimation.scala     |   2 +-
 .../statsEstimation/AggEstimationSuite.scala    | 135 -------------------
 .../AggregateEstimationSuite.scala              | 116 ++++++++++++++++
 .../ProjectEstimationSuite.scala                | 119 +++++++++++++---
 .../StatsEstimationTestBase.scala               |  11 +-
 7 files changed, 248 insertions(+), 160 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index af67343..21e94fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -41,13 +41,19 @@ object AggregateEstimation {
       var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(
         (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount)
 
-      // Here we set another upper bound for the number of output rows: it must not be larger than
-      // child's number of rows.
-      outputRows = outputRows.min(childStats.rowCount.get)
+      outputRows = if (agg.groupingExpressions.isEmpty) {
+        // If there are no group-by columns, the output is a single row containing values of
+        // aggregate functions: results for non-empty input, or initial values for empty input.
+        1
+      } else {
+        // Here we set another upper bound for the number of output rows: it must not be larger than
+        // child's number of rows.
+        outputRows.min(childStats.rowCount.get)
+      }
 
       val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
       Some(Statistics(
-        sizeInBytes = outputRows * getRowSize(agg.output, outputAttrStats),
+        sizeInBytes = getOutputSize(agg.output, outputAttrStats, outputRows),
         rowCount = Some(outputRows),
         attributeStats = outputAttrStats,
         isBroadcastable = childStats.isBroadcastable))
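
The practical effect of this hunk, in terms of the sketch given under the
commit message above (hypothetical helper, not the code in this diff): a
global aggregate over an empty table now estimates one output row instead of
zero, matching what e.g. COUNT(*) actually returns, while grouped aggregates
keep the child-row-count upper bound.

    estimateAggOutputRows(ndvs = Nil, childRowCount = 0)        // 1: global agg, empty input
    estimateAggOutputRows(ndvs = Seq(2, 4), childRowCount = 4)  // 4: product 8 capped at 4
    estimateAggOutputRows(ndvs = Seq(2, 2), childRowCount = 6)  // 4: product below the cap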

http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index c7eb6f0..cf4452d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -35,10 +35,13 @@ object EstimationUtils {
     AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
   }
 
-  def getRowSize(attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat]): Long = {
+  def getOutputSize(
+      attributes: Seq[Attribute],
+      attrStats: AttributeMap[ColumnStat],
+      outputRowCount: BigInt): BigInt = {
     // We assign a generic overhead to each Row object; the actual overhead differs
     // between Row formats.
-    8 + attributes.map { attr =>
+    val sizePerRow = 8 + attributes.map { attr =>
       if (attrStats.contains(attr)) {
         attr.dataType match {
           case StringType =>
@@ -51,5 +54,9 @@ object EstimationUtils {
         attr.dataType.defaultSize
       }
     }.sum
+
+    // Output size can't be zero, otherwise sizeInBytes of BinaryNode would also be zero
+    // (the default statistics computation takes the product of the children's sizes).
+    if (outputRowCount > 0) outputRowCount * sizePerRow else 1
   }
 }
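
A self-contained rendering of the size computation above, with worked values
(a sketch under this file's assumptions: an 8-byte generic row overhead, and
string columns counted as avgLen plus 8 + 4 bytes of object and offset
overhead; `outputSize` and its parameters are illustrative names, not Spark's):

    // Each column contributes its average length; strings carry extra overhead.
    def outputSize(avgLens: Seq[(Long, Boolean)], rowCount: BigInt): BigInt = {
      val sizePerRow = 8L + avgLens.map { case (avgLen, isString) =>
        if (isString) avgLen + 8 + 4 else avgLen
      }.sum
      if (rowCount > 0) rowCount * sizePerRow else 1  // clamp: size is never zero
    }

    outputSize(Seq((4L, false), (4L, false)), rowCount = 2)  // 2 * (8 + 4 + 4) = 32
    outputSize(Seq((3L, true)), rowCount = 0)                // 1: clamped for empty output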

http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
index 69c546b..50b869a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
@@ -36,7 +36,7 @@ object ProjectEstimation {
       val outputAttrStats =
         getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
       Some(childStats.copy(
-        sizeInBytes = childStats.rowCount.get * getRowSize(project.output, outputAttrStats),
+        sizeInBytes = getOutputSize(project.output, outputAttrStats, childStats.rowCount.get),
         attributeStats = outputAttrStats))
     } else {
       None

http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
deleted file mode 100644
index ff79122..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.statsEstimation
-
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.Count
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
-
-
-class AggEstimationSuite extends StatsEstimationTestBase {
-
-  /** Columns for testing */
-  private val columnInfo: Map[Attribute, ColumnStat] =
-    Map(
-      attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
-        avgLen = 4, maxLen = 4),
-      attr("key12") -> ColumnStat(distinctCount = 1, min = Some(10), max = Some(10), nullCount = 0,
-        avgLen = 4, maxLen = 4),
-      attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
-        avgLen = 4, maxLen = 4),
-      attr("key22") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0,
-        avgLen = 4, maxLen = 4),
-      attr("key31") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
-        avgLen = 4, maxLen = 4),
-      attr("key32") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0,
-        avgLen = 4, maxLen = 4))
-
-  private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
-  private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
-    columnInfo.map(kv => kv._1.name -> kv)
-
-  test("empty group-by column") {
-    val colNames = Seq("key11", "key12")
-    // Suppose table1 has 2 records: (1, 10), (2, 10)
-    val table1 = StatsTestPlan(
-      outputList = colNames.map(nameToAttr),
-      stats = Statistics(
-        sizeInBytes = 2 * (4 + 4),
-        rowCount = Some(2),
-        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
-
-    checkAggStats(
-      child = table1,
-      colNames = Nil,
-      expectedRowCount = 1)
-  }
-
-  test("there's a primary key in group-by columns") {
-    val colNames = Seq("key11", "key12")
-    // Suppose table1 has 2 records: (1, 10), (2, 10)
-    val table1 = StatsTestPlan(
-      outputList = colNames.map(nameToAttr),
-      stats = Statistics(
-        sizeInBytes = 2 * (4 + 4),
-        rowCount = Some(2),
-        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
-
-    checkAggStats(
-      child = table1,
-      colNames = colNames,
-      // Column key11 is a primary key, so row count = ndv of key11 = child's row count
-      expectedRowCount = table1.stats.rowCount.get)
-  }
-
-  test("the product of ndv's of group-by columns is too large") {
-    val colNames = Seq("key21", "key22")
-    // Suppose table2 has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
-    val table2 = StatsTestPlan(
-      outputList = colNames.map(nameToAttr),
-      stats = Statistics(
-        sizeInBytes = 4 * (4 + 4),
-        rowCount = Some(4),
-        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
-
-    checkAggStats(
-      child = table2,
-      colNames = colNames,
-      // Use child's row count as an upper bound
-      expectedRowCount = table2.stats.rowCount.get)
-  }
-
-  test("data contains all combinations of distinct values of group-by columns.") {
-    val colNames = Seq("key31", "key32")
-    // Suppose table3 has 6 records: (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10)
-    val table3 = StatsTestPlan(
-      outputList = colNames.map(nameToAttr),
-      stats = Statistics(
-        sizeInBytes = 6 * (4 + 4),
-        rowCount = Some(6),
-        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
-
-    checkAggStats(
-      child = table3,
-      colNames = colNames,
-      // Row count = product of ndv
-      expectedRowCount = nameToColInfo("key31")._2.distinctCount * nameToColInfo("key32")._2
-        .distinctCount)
-  }
-
-  private def checkAggStats(
-      child: LogicalPlan,
-      colNames: Seq[String],
-      expectedRowCount: BigInt): Unit = {
-
-    val columns = colNames.map(nameToAttr)
-    val testAgg = Aggregate(
-      groupingExpressions = columns,
-      aggregateExpressions = columns :+ Alias(Count(Literal(1)), "cnt")(),
-      child = child)
-
-    val expectedAttrStats = AttributeMap(colNames.map(nameToColInfo))
-    val expectedStats = Statistics(
-      sizeInBytes = expectedRowCount * getRowSize(testAgg.output, expectedAttrStats),
-      rowCount = Some(expectedRowCount),
-      attributeStats = expectedAttrStats)
-
-    assert(testAgg.stats(conf) == expectedStats)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
new file mode 100644
index 0000000..41a4bc3
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+
+
+class AggregateEstimationSuite extends StatsEstimationTestBase {
+
+  /** Columns for testing */
+  private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
+    attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+      avgLen = 4, maxLen = 4),
+    attr("key12") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0,
+      avgLen = 4, maxLen = 4),
+    attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+      avgLen = 4, maxLen = 4),
+    attr("key22") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0,
+      avgLen = 4, maxLen = 4),
+    attr("key31") -> ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0,
+      avgLen = 4, maxLen = 4)
+  ))
+
+  private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
+  private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
+    columnInfo.map(kv => kv._1.name -> kv)
+
+  test("set an upper bound if the product of ndv's of group-by columns is too large") {
+    // Suppose table1 (key11 int, key12 int) has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
+    checkAggStats(
+      tableColumns = Seq("key11", "key12"),
+      tableRowCount = 4,
+      groupByColumns = Seq("key11", "key12"),
+      // Use child's row count as an upper bound
+      expectedOutputRowCount = 4)
+  }
+
+  test("data contains all combinations of distinct values of group-by columns.") {
+    // Suppose table2 (key21 int, key22 int) has 6 records:
+    // (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10)
+    checkAggStats(
+      tableColumns = Seq("key21", "key22"),
+      tableRowCount = 6,
+      groupByColumns = Seq("key21", "key22"),
+      // Row count = product of ndv
+      expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount * nameToColInfo("key22")._2
+        .distinctCount)
+  }
+
+  test("empty group-by column") {
+    // Suppose table1 (key11 int, key12 int) has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
+    checkAggStats(
+      tableColumns = Seq("key11", "key12"),
+      tableRowCount = 4,
+      groupByColumns = Nil,
+      expectedOutputRowCount = 1)
+  }
+
+  test("aggregate on empty table - with or without group-by column") {
+    // Suppose table3 (key31 int) is an empty table
+    // Return a single row without group-by column
+    checkAggStats(
+      tableColumns = Seq("key31"),
+      tableRowCount = 0,
+      groupByColumns = Nil,
+      expectedOutputRowCount = 1)
+    // Return empty result with group-by column
+    checkAggStats(
+      tableColumns = Seq("key31"),
+      tableRowCount = 0,
+      groupByColumns = Seq("key31"),
+      expectedOutputRowCount = 0)
+  }
+
+  private def checkAggStats(
+      tableColumns: Seq[String],
+      tableRowCount: BigInt,
+      groupByColumns: Seq[String],
+      expectedOutputRowCount: BigInt): Unit = {
+    val attributes = groupByColumns.map(nameToAttr)
+    // Construct an Aggregate for testing
+    val testAgg = Aggregate(
+      groupingExpressions = attributes,
+      aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(),
+      child = StatsTestPlan(
+        outputList = tableColumns.map(nameToAttr),
+        rowCount = tableRowCount,
+        attributeStats = AttributeMap(tableColumns.map(nameToColInfo))))
+
+    val expectedAttrStats = AttributeMap(groupByColumns.map(nameToColInfo))
+    val expectedStats = Statistics(
+      sizeInBytes = getOutputSize(testAgg.output, expectedAttrStats, expectedOutputRowCount),
+      rowCount = Some(expectedOutputRowCount),
+      attributeStats = expectedAttrStats)
+
+    assert(testAgg.stats(conf) == expectedStats)
+  }
+}
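
A quick cross-check of the expectations above, assuming the distinctCount
values in columnInfo and the 8-byte row overhead: table1's NDV product is
2 * 4 = 8, which exceeds its 4 rows, so the upper bound applies; table2's
product is 2 * 2 = 4, below its 6-row bound, so the product itself is the
estimate. In the empty group-by cases the aggregate's only output is the cnt
column, which has no column stats and so contributes LongType's default size:
the expected sizeInBytes works out to 1 * (8 + 8) = 16 for the single-row
result.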

http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
index a613f0f..ae102a4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -17,35 +17,122 @@
 
 package org.apache.spark.sql.catalyst.statsEstimation
 
+import java.sql.{Date, Timestamp}
+
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types._
 
 
 class ProjectEstimationSuite extends StatsEstimationTestBase {
 
-  test("estimate project with alias") {
-    val ar1 = AttributeReference("key1", IntegerType)()
-    val ar2 = AttributeReference("key2", IntegerType)()
-    val colStat1 = ColumnStat(2, Some(1), Some(2), 0, 4, 4)
-    val colStat2 = ColumnStat(1, Some(10), Some(10), 0, 4, 4)
+  test("project with alias") {
+    val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 2, min = Some(1),
+      max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4))
+    val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = 1, min = Some(10),
+      max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4))
 
     val child = StatsTestPlan(
       outputList = Seq(ar1, ar2),
-      stats = Statistics(
-        sizeInBytes = 2 * (4 + 4),
-        rowCount = Some(2),
-        attributeStats = AttributeMap(Seq(ar1 -> colStat1, ar2 -> colStat2))))
+      rowCount = 2,
+      attributeStats = AttributeMap(Seq(ar1 -> colStat1, ar2 -> colStat2)))
 
-    val project = Project(Seq(ar1, Alias(ar2, "abc")()), child)
+    val proj = Project(Seq(ar1, Alias(ar2, "abc")()), child)
     val expectedColStats = Seq("key1" -> colStat1, "abc" -> colStat2)
-    val expectedAttrStats = toAttributeMap(expectedColStats, project)
-    // The number of rows won't change for project.
+    val expectedAttrStats = toAttributeMap(expectedColStats, proj)
     val expectedStats = Statistics(
-      sizeInBytes = 2 * getRowSize(project.output, expectedAttrStats),
+      sizeInBytes = 2 * (8 + 4 + 4),
       rowCount = Some(2),
       attributeStats = expectedAttrStats)
-    assert(project.stats(conf) == expectedStats)
+    assert(proj.stats(conf) == expectedStats)
+  }
+
+  test("project on empty table") {
+    val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 0, min = None, max = None,
+      nullCount = 0, avgLen = 4, maxLen = 4))
+    val child = StatsTestPlan(
+      outputList = Seq(ar1),
+      rowCount = 0,
+      attributeStats = AttributeMap(Seq(ar1 -> colStat1)))
+    checkProjectStats(
+      child = child,
+      projectAttrMap = child.attributeStats,
+      expectedSize = 1,
+      expectedRowCount = 0)
+  }
+
+  test("test row size estimation") {
+    val dec1 = new java.math.BigDecimal("1.000000000000000000")
+    val dec2 = new java.math.BigDecimal("8.000000000000000000")
+    val d1 = Date.valueOf("2016-05-08")
+    val d2 = Date.valueOf("2016-05-09")
+    val t1 = Timestamp.valueOf("2016-05-08 00:00:01")
+    val t2 = Timestamp.valueOf("2016-05-09 00:00:02")
+
+    val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
+      AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2,
+        min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1),
+      AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2,
+        min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1),
+      AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2,
+        min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2),
+      AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2,
+        min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4),
+      AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2,
+        min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8),
+      AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2,
+        min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8),
+      AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2,
+        min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4),
+      AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2,
+        min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16),
+      AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2,
+        min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3),
+      AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2,
+        min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3),
+      AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2,
+        min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4),
+      AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2,
+        min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8)
+    ))
+    val columnSizes = columnInfo.map { case (attr, colStat) =>
+      (attr, attr.dataType match {
+        case StringType => colStat.avgLen + 8 + 4
+        case _ => colStat.avgLen
+      })
+    }
+    val child = StatsTestPlan(
+      outputList = columnInfo.keys.toSeq,
+      rowCount = 2,
+      attributeStats = columnInfo)
+
+    // Row with single column
+    columnInfo.keys.foreach { attr =>
+      checkProjectStats(
+        child = child,
+        projectAttrMap = AttributeMap(attr -> columnInfo(attr) :: Nil),
+        expectedSize = 2 * (8 + columnSizes(attr)),
+        expectedRowCount = 2)
+    }
+
+    // Row with multiple columns
+    checkProjectStats(
+      child = child,
+      projectAttrMap = columnInfo,
+      expectedSize = 2 * (8 + columnSizes.values.sum),
+      expectedRowCount = 2)
+  }
+
+  private def checkProjectStats(
+      child: LogicalPlan,
+      projectAttrMap: AttributeMap[ColumnStat],
+      expectedSize: BigInt,
+      expectedRowCount: BigInt): Unit = {
+    val proj = Project(projectAttrMap.keys.toSeq, child)
+    val expectedStats = Statistics(
+      sizeInBytes = expectedSize,
+      rowCount = Some(expectedRowCount),
+      attributeStats = projectAttrMap)
+    assert(proj.stats(conf) == expectedStats)
   }
 }
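
Worked numbers for the row-size test above, assuming the suite's avgLen values
and the 8-byte per-row overhead:

    // cint    single-column row:  2 * (8 + 4)         =  24 bytes
    // cstring single-column row:  2 * (8 + 3 + 8 + 4) =  46 bytes
    // all 12 columns: columnSizes sum to 74, so 2 * (8 + 74) = 164 bytes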

http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index 0635309..e6adb67 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -42,7 +42,14 @@ class StatsEstimationTestBase extends SparkFunSuite {
 /**
  * This class is used for unit-testing. It's a logical plan whose output and stats are passed in.
  */
-protected case class StatsTestPlan(outputList: Seq[Attribute], stats: Statistics) extends LeafNode {
+protected case class StatsTestPlan(
+    outputList: Seq[Attribute],
+    rowCount: BigInt,
+    attributeStats: AttributeMap[ColumnStat]) extends LeafNode {
   override def output: Seq[Attribute] = outputList
-  override def computeStats(conf: CatalystConf): Statistics = stats
+  override def computeStats(conf: CatalystConf): Statistics = Statistics(
+    // sizeInBytes in the stats of StatsTestPlan is not used in CBO estimation, so we use a fake value
+    sizeInBytes = Int.MaxValue,
+    rowCount = Some(rowCount),
+    attributeStats = attributeStats)
 }
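
A minimal sketch of how a new test might construct this plan (hypothetical
column and values; `attr` and `ColumnStat` are the helper and types used in
the suites above):

    val a = attr("c1")
    val plan = StatsTestPlan(
      outputList = Seq(a),
      rowCount = 10,
      attributeStats = AttributeMap(Seq(a -> ColumnStat(distinctCount = 10,
        min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4))))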

