You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2017/01/11 23:01:02 UTC
spark git commit: [SPARK-19132][SQL] Add test cases for row size estimation and aggregate estimation
Repository: spark
Updated Branches:
refs/heads/master 66fe819ad -> 43fa21b3e
[SPARK-19132][SQL] Add test cases for row size estimation and aggregate estimation
## What changes were proposed in this pull request?
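In this PR, we add more test cases for project and aggregate estimation, including row size estimation for columns of various data types, aggregates without group-by columns, and empty input tables. The PR also changes the estimation logic in two places. First, an aggregate with no group-by columns is now always estimated to return exactly one row; previously the estimate was capped by the child's row count, giving zero rows for an empty input. Second, `EstimationUtils.getRowSize` is replaced by `getOutputSize`, which returns 1 byte instead of 0 when the estimated row count is zero. This keeps the `sizeInBytes` of a parent binary node (computed as the product of its children's sizes) from collapsing to zero.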
## How was this patch tested?
Added new unit tests in `AggregateEstimationSuite` (replacing `AggEstimationSuite`) and `ProjectEstimationSuite`.
Author: wangzhenhua <wa...@huawei.com>
Closes #16551 from wzhfy/addTests.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/43fa21b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/43fa21b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/43fa21b3
Branch: refs/heads/master
Commit: 43fa21b3e62ee108bcecb74398f431f08c6b625c
Parents: 66fe819
Author: wangzhenhua <wa...@huawei.com>
Authored: Wed Jan 11 15:00:58 2017 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Jan 11 15:00:58 2017 -0800
----------------------------------------------------------------------
.../statsEstimation/AggregateEstimation.scala | 14 +-
.../statsEstimation/EstimationUtils.scala | 11 +-
.../statsEstimation/ProjectEstimation.scala | 2 +-
.../statsEstimation/AggEstimationSuite.scala | 135 -------------------
.../AggregateEstimationSuite.scala | 116 ++++++++++++++++
.../ProjectEstimationSuite.scala | 119 +++++++++++++---
.../StatsEstimationTestBase.scala | 11 +-
7 files changed, 248 insertions(+), 160 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index af67343..21e94fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -41,13 +41,19 @@ object AggregateEstimation {
var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(
(res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount)
- // Here we set another upper bound for the number of output rows: it must not be larger than
- // child's number of rows.
- outputRows = outputRows.min(childStats.rowCount.get)
+ outputRows = if (agg.groupingExpressions.isEmpty) {
+ // If there's no group-by columns, the output is a single row containing values of aggregate
+ // functions: aggregated results for non-empty input or initial values for empty input.
+ 1
+ } else {
+ // Here we set another upper bound for the number of output rows: it must not be larger than
+ // child's number of rows.
+ outputRows.min(childStats.rowCount.get)
+ }
val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
Some(Statistics(
- sizeInBytes = outputRows * getRowSize(agg.output, outputAttrStats),
+ sizeInBytes = getOutputSize(agg.output, outputAttrStats, outputRows),
rowCount = Some(outputRows),
attributeStats = outputAttrStats,
isBroadcastable = childStats.isBroadcastable))
http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index c7eb6f0..cf4452d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -35,10 +35,13 @@ object EstimationUtils {
AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
}
- def getRowSize(attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat]): Long = {
+ def getOutputSize(
+ attributes: Seq[Attribute],
+ attrStats: AttributeMap[ColumnStat],
+ outputRowCount: BigInt): BigInt = {
// We assign a generic overhead for a Row object, the actual overhead is different for different
// Row format.
- 8 + attributes.map { attr =>
+ val sizePerRow = 8 + attributes.map { attr =>
if (attrStats.contains(attr)) {
attr.dataType match {
case StringType =>
@@ -51,5 +54,9 @@ object EstimationUtils {
attr.dataType.defaultSize
}
}.sum
+
+ // Output size can't be zero, or sizeInBytes of BinaryNode will also be zero
+ // (simple computation of statistics returns product of children).
+ if (outputRowCount > 0) outputRowCount * sizePerRow else 1
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
index 69c546b..50b869a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
@@ -36,7 +36,7 @@ object ProjectEstimation {
val outputAttrStats =
getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
Some(childStats.copy(
- sizeInBytes = childStats.rowCount.get * getRowSize(project.output, outputAttrStats),
+ sizeInBytes = getOutputSize(project.output, outputAttrStats, childStats.rowCount.get),
attributeStats = outputAttrStats))
} else {
None
http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
deleted file mode 100644
index ff79122..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.statsEstimation
-
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.Count
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
-
-
-class AggEstimationSuite extends StatsEstimationTestBase {
-
- /** Columns for testing */
- private val columnInfo: Map[Attribute, ColumnStat] =
- Map(
- attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key12") -> ColumnStat(distinctCount = 1, min = Some(10), max = Some(10), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key22") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key31") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key32") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0,
- avgLen = 4, maxLen = 4))
-
- private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
- private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
- columnInfo.map(kv => kv._1.name -> kv)
-
- test("empty group-by column") {
- val colNames = Seq("key11", "key12")
- // Suppose table1 has 2 records: (1, 10), (2, 10)
- val table1 = StatsTestPlan(
- outputList = colNames.map(nameToAttr),
- stats = Statistics(
- sizeInBytes = 2 * (4 + 4),
- rowCount = Some(2),
- attributeStats = AttributeMap(colNames.map(nameToColInfo))))
-
- checkAggStats(
- child = table1,
- colNames = Nil,
- expectedRowCount = 1)
- }
-
- test("there's a primary key in group-by columns") {
- val colNames = Seq("key11", "key12")
- // Suppose table1 has 2 records: (1, 10), (2, 10)
- val table1 = StatsTestPlan(
- outputList = colNames.map(nameToAttr),
- stats = Statistics(
- sizeInBytes = 2 * (4 + 4),
- rowCount = Some(2),
- attributeStats = AttributeMap(colNames.map(nameToColInfo))))
-
- checkAggStats(
- child = table1,
- colNames = colNames,
- // Column key11 a primary key, so row count = ndv of key11 = child's row count
- expectedRowCount = table1.stats.rowCount.get)
- }
-
- test("the product of ndv's of group-by columns is too large") {
- val colNames = Seq("key21", "key22")
- // Suppose table2 has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
- val table2 = StatsTestPlan(
- outputList = colNames.map(nameToAttr),
- stats = Statistics(
- sizeInBytes = 4 * (4 + 4),
- rowCount = Some(4),
- attributeStats = AttributeMap(colNames.map(nameToColInfo))))
-
- checkAggStats(
- child = table2,
- colNames = colNames,
- // Use child's row count as an upper bound
- expectedRowCount = table2.stats.rowCount.get)
- }
-
- test("data contains all combinations of distinct values of group-by columns.") {
- val colNames = Seq("key31", "key32")
- // Suppose table3 has 6 records: (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10)
- val table3 = StatsTestPlan(
- outputList = colNames.map(nameToAttr),
- stats = Statistics(
- sizeInBytes = 6 * (4 + 4),
- rowCount = Some(6),
- attributeStats = AttributeMap(colNames.map(nameToColInfo))))
-
- checkAggStats(
- child = table3,
- colNames = colNames,
- // Row count = product of ndv
- expectedRowCount = nameToColInfo("key31")._2.distinctCount * nameToColInfo("key32")._2
- .distinctCount)
- }
-
- private def checkAggStats(
- child: LogicalPlan,
- colNames: Seq[String],
- expectedRowCount: BigInt): Unit = {
-
- val columns = colNames.map(nameToAttr)
- val testAgg = Aggregate(
- groupingExpressions = columns,
- aggregateExpressions = columns :+ Alias(Count(Literal(1)), "cnt")(),
- child = child)
-
- val expectedAttrStats = AttributeMap(colNames.map(nameToColInfo))
- val expectedStats = Statistics(
- sizeInBytes = expectedRowCount * getRowSize(testAgg.output, expectedAttrStats),
- rowCount = Some(expectedRowCount),
- attributeStats = expectedAttrStats)
-
- assert(testAgg.stats(conf) == expectedStats)
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
new file mode 100644
index 0000000..41a4bc3
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+
+
+class AggregateEstimationSuite extends StatsEstimationTestBase {
+
+ /** Columns for testing */
+ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
+ attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+ avgLen = 4, maxLen = 4),
+ attr("key12") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0,
+ avgLen = 4, maxLen = 4),
+ attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+ avgLen = 4, maxLen = 4),
+ attr("key22") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0,
+ avgLen = 4, maxLen = 4),
+ attr("key31") -> ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0,
+ avgLen = 4, maxLen = 4)
+ ))
+
+ private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
+ private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
+ columnInfo.map(kv => kv._1.name -> kv)
+
+ test("set an upper bound if the product of ndv's of group-by columns is too large") {
+ // Suppose table1 (key11 int, key12 int) has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
+ checkAggStats(
+ tableColumns = Seq("key11", "key12"),
+ tableRowCount = 4,
+ groupByColumns = Seq("key11", "key12"),
+ // Use child's row count as an upper bound
+ expectedOutputRowCount = 4)
+ }
+
+ test("data contains all combinations of distinct values of group-by columns.") {
+ // Suppose table2 (key21 int, key22 int) has 6 records:
+ // (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10)
+ checkAggStats(
+ tableColumns = Seq("key21", "key22"),
+ tableRowCount = 6,
+ groupByColumns = Seq("key21", "key22"),
+ // Row count = product of ndv
+ expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount * nameToColInfo("key22")._2
+ .distinctCount)
+ }
+
+ test("empty group-by column") {
+ // Suppose table1 (key11 int, key12 int) has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
+ checkAggStats(
+ tableColumns = Seq("key11", "key12"),
+ tableRowCount = 4,
+ groupByColumns = Nil,
+ expectedOutputRowCount = 1)
+ }
+
+ test("aggregate on empty table - with or without group-by column") {
+ // Suppose table3 (key31 int) is an empty table
+ // Return a single row without group-by column
+ checkAggStats(
+ tableColumns = Seq("key31"),
+ tableRowCount = 0,
+ groupByColumns = Nil,
+ expectedOutputRowCount = 1)
+ // Return empty result with group-by column
+ checkAggStats(
+ tableColumns = Seq("key31"),
+ tableRowCount = 0,
+ groupByColumns = Seq("key31"),
+ expectedOutputRowCount = 0)
+ }
+
+ private def checkAggStats(
+ tableColumns: Seq[String],
+ tableRowCount: BigInt,
+ groupByColumns: Seq[String],
+ expectedOutputRowCount: BigInt): Unit = {
+ val attributes = groupByColumns.map(nameToAttr)
+ // Construct an Aggregate for testing
+ val testAgg = Aggregate(
+ groupingExpressions = attributes,
+ aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(),
+ child = StatsTestPlan(
+ outputList = tableColumns.map(nameToAttr),
+ rowCount = tableRowCount,
+ attributeStats = AttributeMap(tableColumns.map(nameToColInfo))))
+
+ val expectedAttrStats = AttributeMap(groupByColumns.map(nameToColInfo))
+ val expectedStats = Statistics(
+ sizeInBytes = getOutputSize(testAgg.output, expectedAttrStats, expectedOutputRowCount),
+ rowCount = Some(expectedOutputRowCount),
+ attributeStats = expectedAttrStats)
+
+ assert(testAgg.stats(conf) == expectedStats)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
index a613f0f..ae102a4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -17,35 +17,122 @@
package org.apache.spark.sql.catalyst.statsEstimation
+import java.sql.{Date, Timestamp}
+
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types._
class ProjectEstimationSuite extends StatsEstimationTestBase {
- test("estimate project with alias") {
- val ar1 = AttributeReference("key1", IntegerType)()
- val ar2 = AttributeReference("key2", IntegerType)()
- val colStat1 = ColumnStat(2, Some(1), Some(2), 0, 4, 4)
- val colStat2 = ColumnStat(1, Some(10), Some(10), 0, 4, 4)
+ test("project with alias") {
+ val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 2, min = Some(1),
+ max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4))
+ val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = 1, min = Some(10),
+ max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4))
val child = StatsTestPlan(
outputList = Seq(ar1, ar2),
- stats = Statistics(
- sizeInBytes = 2 * (4 + 4),
- rowCount = Some(2),
- attributeStats = AttributeMap(Seq(ar1 -> colStat1, ar2 -> colStat2))))
+ rowCount = 2,
+ attributeStats = AttributeMap(Seq(ar1 -> colStat1, ar2 -> colStat2)))
- val project = Project(Seq(ar1, Alias(ar2, "abc")()), child)
+ val proj = Project(Seq(ar1, Alias(ar2, "abc")()), child)
val expectedColStats = Seq("key1" -> colStat1, "abc" -> colStat2)
- val expectedAttrStats = toAttributeMap(expectedColStats, project)
- // The number of rows won't change for project.
+ val expectedAttrStats = toAttributeMap(expectedColStats, proj)
val expectedStats = Statistics(
- sizeInBytes = 2 * getRowSize(project.output, expectedAttrStats),
+ sizeInBytes = 2 * (8 + 4 + 4),
rowCount = Some(2),
attributeStats = expectedAttrStats)
- assert(project.stats(conf) == expectedStats)
+ assert(proj.stats(conf) == expectedStats)
+ }
+
+ test("project on empty table") {
+ val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 0, min = None, max = None,
+ nullCount = 0, avgLen = 4, maxLen = 4))
+ val child = StatsTestPlan(
+ outputList = Seq(ar1),
+ rowCount = 0,
+ attributeStats = AttributeMap(Seq(ar1 -> colStat1)))
+ checkProjectStats(
+ child = child,
+ projectAttrMap = child.attributeStats,
+ expectedSize = 1,
+ expectedRowCount = 0)
+ }
+
+ test("test row size estimation") {
+ val dec1 = new java.math.BigDecimal("1.000000000000000000")
+ val dec2 = new java.math.BigDecimal("8.000000000000000000")
+ val d1 = Date.valueOf("2016-05-08")
+ val d2 = Date.valueOf("2016-05-09")
+ val t1 = Timestamp.valueOf("2016-05-08 00:00:01")
+ val t2 = Timestamp.valueOf("2016-05-09 00:00:02")
+
+ val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
+ AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2,
+ min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1),
+ AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2,
+ min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1),
+ AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2,
+ min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2),
+ AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2,
+ min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4),
+ AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2,
+ min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8),
+ AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2,
+ min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8),
+ AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2,
+ min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4),
+ AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2,
+ min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16),
+ AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2,
+ min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3),
+ AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2,
+ min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3),
+ AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2,
+ min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4),
+ AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2,
+ min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8)
+ ))
+ val columnSizes = columnInfo.map { case (attr, colStat) =>
+ (attr, attr.dataType match {
+ case StringType => colStat.avgLen + 8 + 4
+ case _ => colStat.avgLen
+ })
+ }
+ val child = StatsTestPlan(
+ outputList = columnInfo.keys.toSeq,
+ rowCount = 2,
+ attributeStats = columnInfo)
+
+ // Row with single column
+ columnInfo.keys.foreach { attr =>
+ checkProjectStats(
+ child = child,
+ projectAttrMap = AttributeMap(attr -> columnInfo(attr) :: Nil),
+ expectedSize = 2 * (8 + columnSizes(attr)),
+ expectedRowCount = 2)
+ }
+
+ // Row with multiple columns
+ checkProjectStats(
+ child = child,
+ projectAttrMap = columnInfo,
+ expectedSize = 2 * (8 + columnSizes.values.sum),
+ expectedRowCount = 2)
+ }
+
+ private def checkProjectStats(
+ child: LogicalPlan,
+ projectAttrMap: AttributeMap[ColumnStat],
+ expectedSize: BigInt,
+ expectedRowCount: BigInt): Unit = {
+ val proj = Project(projectAttrMap.keys.toSeq, child)
+ val expectedStats = Statistics(
+ sizeInBytes = expectedSize,
+ rowCount = Some(expectedRowCount),
+ attributeStats = projectAttrMap)
+ assert(proj.stats(conf) == expectedStats)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/43fa21b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index 0635309..e6adb67 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -42,7 +42,14 @@ class StatsEstimationTestBase extends SparkFunSuite {
/**
* This class is used for unit-testing. It's a logical plan whose output and stats are passed in.
*/
-protected case class StatsTestPlan(outputList: Seq[Attribute], stats: Statistics) extends LeafNode {
+protected case class StatsTestPlan(
+ outputList: Seq[Attribute],
+ rowCount: BigInt,
+ attributeStats: AttributeMap[ColumnStat]) extends LeafNode {
override def output: Seq[Attribute] = outputList
- override def computeStats(conf: CatalystConf): Statistics = stats
+ override def computeStats(conf: CatalystConf): Statistics = Statistics(
+ // sizeInBytes in stats of StatsTestPlan is useless in cbo estimation, we just use a fake value
+ sizeInBytes = Int.MaxValue,
+ rowCount = Some(rowCount),
+ attributeStats = attributeStats)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org