You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2014/10/27 00:10:16 UTC

git commit: [SPARK-3537][SPARK-3914][SQL] Refines in-memory columnar table statistics

Repository: spark
Updated Branches:
  refs/heads/master 879a16585 -> 2838bf8aa


[SPARK-3537][SPARK-3914][SQL] Refines in-memory columnar table statistics

This PR refines in-memory columnar table statistics:

1. adds 2 more statistics for in-memory table columns: `count` and `sizeInBytes`
1. adds filter pushdown support for `IS NULL` and `IS NOT NULL`.
1. caches and propagates statistics in `InMemoryRelation` once the underlying cached RDD is materialized.

   Statistics are collected on the driver side via an accumulator.

This PR also fixes SPARK-3914 by properly propagating in-memory statistics.

Author: Cheng Lian <li...@databricks.com>

Closes #2860 from liancheng/propagates-in-mem-stats and squashes the following commits:

0cc5271 [Cheng Lian] Restricts visibility of o.a.s.s.c.p.l.Statistics
c5ff904 [Cheng Lian] Fixes test table name conflict
a8c818d [Cheng Lian] Refines tests
1d01074 [Cheng Lian] Bug fix: shouldn't call STRING.actualSize on null string value
7dc6a34 [Cheng Lian] Adds more in-memory table statistics and propagates them properly


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2838bf8a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2838bf8a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2838bf8a

Branch: refs/heads/master
Commit: 2838bf8aadd5228829c1a869863bc4da7877fdfb
Parents: 879a165
Author: Cheng Lian <li...@databricks.com>
Authored: Sun Oct 26 16:10:09 2014 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Sun Oct 26 16:10:09 2014 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/AttributeMap.scala |  10 +-
 .../catalyst/plans/logical/LogicalPlan.scala    |  31 +++--
 .../apache/spark/sql/columnar/ColumnStats.scala | 122 ++++++++++---------
 .../columnar/InMemoryColumnarTableScan.scala    | 101 +++++++++------
 .../spark/sql/execution/ExistingRDD.scala       |  11 +-
 .../spark/sql/parquet/ParquetRelation.scala     |   3 +-
 .../org/apache/spark/sql/CachedTableSuite.scala |  11 +-
 .../scala/org/apache/spark/sql/TestData.scala   |  16 +--
 .../spark/sql/columnar/ColumnStatsSuite.scala   |   6 +
 .../columnar/PartitionBatchPruningSuite.scala   |  76 +++++++-----
 .../spark/sql/execution/PlannerSuite.scala      |  20 +++
 11 files changed, 240 insertions(+), 167 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 8364379..82e760b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,8 +23,7 @@ package org.apache.spark.sql.catalyst.expressions
  * of the name, or the expected nullability).
  */
 object AttributeMap {
-  def apply[A](kvs: Seq[(Attribute, A)]) =
-    new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap)
+  def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
 }
 
 class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
@@ -32,10 +31,9 @@ class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
 
   override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
 
-  override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] =
-    (baseMap.map(_._2) + kv).toMap
+  override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv
 
-  override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator
+  override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
 
-  override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap
+  override def -(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 882e9c6..ed578e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -26,25 +26,24 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.catalyst.trees
 
+/**
+ * Estimates of various statistics.  The default estimation logic simply lazily multiplies the
+ * corresponding statistic produced by the children.  To override this behavior, override
+ * `statistics` and assign it an overridden version of `Statistics`.
+ *
+ * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the
+ * performance of the implementations.  The reason is that estimations might get triggered in
+ * performance-critical processes, such as query plan planning.
+ *
+ * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
+ *                    defaults to the product of children's `sizeInBytes`.
+ */
+private[sql] case class Statistics(sizeInBytes: BigInt)
+
 abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
   self: Product =>
 
-  /**
-   * Estimates of various statistics.  The default estimation logic simply lazily multiplies the
-   * corresponding statistic produced by the children.  To override this behavior, override
-   * `statistics` and assign it an overriden version of `Statistics`.
-   *
-   * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the
-   * performance of the implementations.  The reason is that estimations might get triggered in
-   * performance-critical processes, such as query plan planning.
-   *
-   * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
-   *                    defaults to the product of children's `sizeInBytes`.
-   */
-  case class Statistics(
-    sizeInBytes: BigInt
-  )
-  lazy val statistics: Statistics = {
+  def statistics: Statistics = {
     if (children.size == 0) {
       throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index b34ab25..b9f9f82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -24,11 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, Attri
 import org.apache.spark.sql.catalyst.types._
 
 private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
-  val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)()
-  val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)()
-  val nullCount =  AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
+  val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)()
+  val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)()
+  val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
+  val count = AttributeReference(a.name + ".count", IntegerType, nullable = false)()
+  val sizeInBytes = AttributeReference(a.name + ".sizeInBytes", LongType, nullable = false)()
 
-  val schema = Seq(lowerBound, upperBound, nullCount)
+  val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes)
 }
 
 private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable {
@@ -45,10 +47,21 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri
  * brings significant performance penalty.
  */
 private[sql] sealed trait ColumnStats extends Serializable {
+  protected var count = 0
+  protected var nullCount = 0
+  protected var sizeInBytes = 0L
+
   /**
    * Gathers statistics information from `row(ordinal)`.
    */
-  def gatherStats(row: Row, ordinal: Int): Unit
+  def gatherStats(row: Row, ordinal: Int): Unit = {
+    if (row.isNullAt(ordinal)) {
+      nullCount += 1
+      // 4 bytes for null position
+      sizeInBytes += 4
+    }
+    count += 1
+  }
 
   /**
    * Column statistics represented as a single row, currently including closed lower bound, closed
@@ -65,163 +78,154 @@ private[sql] class NoopColumnStats extends ColumnStats {
 }
 
 private[sql] class ByteColumnStats extends ColumnStats {
-  var upper = Byte.MinValue
-  var lower = Byte.MaxValue
-  var nullCount = 0
+  protected var upper = Byte.MinValue
+  protected var lower = Byte.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getByte(ordinal)
       if (value > upper) upper = value
       if (value < lower) lower = value
-    } else {
-      nullCount += 1
+      sizeInBytes += BYTE.defaultSize
     }
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class ShortColumnStats extends ColumnStats {
-  var upper = Short.MinValue
-  var lower = Short.MaxValue
-  var nullCount = 0
+  protected var upper = Short.MinValue
+  protected var lower = Short.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getShort(ordinal)
       if (value > upper) upper = value
       if (value < lower) lower = value
-    } else {
-      nullCount += 1
+      sizeInBytes += SHORT.defaultSize
     }
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class LongColumnStats extends ColumnStats {
-  var upper = Long.MinValue
-  var lower = Long.MaxValue
-  var nullCount = 0
+  protected var upper = Long.MinValue
+  protected var lower = Long.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getLong(ordinal)
       if (value > upper) upper = value
       if (value < lower) lower = value
-    } else {
-      nullCount += 1
+      sizeInBytes += LONG.defaultSize
     }
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class DoubleColumnStats extends ColumnStats {
-  var upper = Double.MinValue
-  var lower = Double.MaxValue
-  var nullCount = 0
+  protected var upper = Double.MinValue
+  protected var lower = Double.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getDouble(ordinal)
       if (value > upper) upper = value
       if (value < lower) lower = value
-    } else {
-      nullCount += 1
+      sizeInBytes += DOUBLE.defaultSize
     }
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class FloatColumnStats extends ColumnStats {
-  var upper = Float.MinValue
-  var lower = Float.MaxValue
-  var nullCount = 0
+  protected var upper = Float.MinValue
+  protected var lower = Float.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getFloat(ordinal)
       if (value > upper) upper = value
       if (value < lower) lower = value
-    } else {
-      nullCount += 1
+      sizeInBytes += FLOAT.defaultSize
     }
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class IntColumnStats extends ColumnStats {
-  var upper = Int.MinValue
-  var lower = Int.MaxValue
-  var nullCount = 0
+  protected var upper = Int.MinValue
+  protected var lower = Int.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getInt(ordinal)
       if (value > upper) upper = value
       if (value < lower) lower = value
-    } else {
-      nullCount += 1
+      sizeInBytes += INT.defaultSize
     }
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class StringColumnStats extends ColumnStats {
-  var upper: String = null
-  var lower: String = null
-  var nullCount = 0
+  protected var upper: String = null
+  protected var lower: String = null
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getString(ordinal)
       if (upper == null || value.compareTo(upper) > 0) upper = value
       if (lower == null || value.compareTo(lower) < 0) lower = value
-    } else {
-      nullCount += 1
+      sizeInBytes += STRING.actualSize(row, ordinal)
     }
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class DateColumnStats extends ColumnStats {
-  var upper: Date = null
-  var lower: Date = null
-  var nullCount = 0
+  protected var upper: Date = null
+  protected var lower: Date = null
 
   override def gatherStats(row: Row, ordinal: Int) {
+    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row(ordinal).asInstanceOf[Date]
       if (upper == null || value.compareTo(upper) > 0) upper = value
       if (lower == null || value.compareTo(lower) < 0) lower = value
-    } else {
-      nullCount += 1
+      sizeInBytes += DATE.defaultSize
     }
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class TimestampColumnStats extends ColumnStats {
-  var upper: Timestamp = null
-  var lower: Timestamp = null
-  var nullCount = 0
+  protected var upper: Timestamp = null
+  protected var lower: Timestamp = null
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row(ordinal).asInstanceOf[Timestamp]
       if (upper == null || value.compareTo(upper) > 0) upper = value
       if (lower == null || value.compareTo(lower) < 0) lower = value
-    } else {
-      nullCount += 1
+      sizeInBytes += TIMESTAMP.defaultSize
     }
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 22ab0e2..ee63134 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar
 
 import java.nio.ByteBuffer
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
 import org.apache.spark.storage.StorageLevel
 
@@ -45,15 +47,51 @@ private[sql] case class InMemoryRelation(
     useCompression: Boolean,
     batchSize: Int,
     storageLevel: StorageLevel,
-    child: SparkPlan)
-    (private var _cachedColumnBuffers: RDD[CachedBatch] = null)
+    child: SparkPlan)(
+    private var _cachedColumnBuffers: RDD[CachedBatch] = null,
+    private var _statistics: Statistics = null)
   extends LogicalPlan with MultiInstanceRelation {
 
-  override lazy val statistics =
-    Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
+  private val batchStats =
+    child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
 
   val partitionStatistics = new PartitionStatistics(output)
 
+  private def computeSizeInBytes = {
+    val sizeOfRow: Expression =
+      BindReferences.bindReference(
+        output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add),
+        partitionStatistics.schema)
+
+    batchStats.value.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
+  }
+
+  // Statistics propagation contracts:
+  // 1. Non-null `_statistics` must reflect the actual statistics of the underlying data
+  // 2. Only propagate statistics when `_statistics` is non-null
+  private def statisticsToBePropagated = if (_statistics == null) {
+    val updatedStats = statistics
+    if (_statistics == null) null else updatedStats
+  } else {
+    _statistics
+  }
+
+  override def statistics = if (_statistics == null) {
+    if (batchStats.value.isEmpty) {
+      // Underlying columnar RDD hasn't been materialized, no useful statistics information
+      // available, return the default statistics.
+      Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
+    } else {
+      // Underlying columnar RDD has been materialized, required information has also been collected
+      // via the `batchStats` accumulator, compute the final statistics, and update `_statistics`.
+      _statistics = Statistics(sizeInBytes = computeSizeInBytes)
+      _statistics
+    }
+  } else {
+    // Pre-computed statistics
+    _statistics
+  }
+
   // If the cached column buffers were not passed in, we calculate them in the constructor.
   // As in Spark, the actual work of caching is lazy.
   if (_cachedColumnBuffers == null) {
@@ -91,6 +129,7 @@ private[sql] case class InMemoryRelation(
           val stats = Row.fromSeq(
             columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _))
 
+          batchStats += stats
           CachedBatch(columnBuilders.map(_.build().array()), stats)
         }
 
@@ -104,7 +143,8 @@ private[sql] case class InMemoryRelation(
 
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
     InMemoryRelation(
-      newOutput, useCompression, batchSize, storageLevel, child)(_cachedColumnBuffers)
+      newOutput, useCompression, batchSize, storageLevel, child)(
+      _cachedColumnBuffers, statisticsToBePropagated)
   }
 
   override def children = Seq.empty
@@ -116,7 +156,8 @@ private[sql] case class InMemoryRelation(
       batchSize,
       storageLevel,
       child)(
-      _cachedColumnBuffers).asInstanceOf[this.type]
+      _cachedColumnBuffers,
+      statisticsToBePropagated).asInstanceOf[this.type]
   }
 
   def cachedColumnBuffers = _cachedColumnBuffers
@@ -132,6 +173,8 @@ private[sql] case class InMemoryColumnarTableScan(
 
   override def output: Seq[Attribute] = attributes
 
+  private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
+
   // Returned filter predicate should return false iff it is impossible for the input expression
   // to evaluate to `true' based on statistics collected about this partition batch.
   val buildFilter: PartialFunction[Expression, Expression] = {
@@ -144,44 +187,24 @@ private[sql] case class InMemoryColumnarTableScan(
       buildFilter(lhs) || buildFilter(rhs)
 
     case EqualTo(a: AttributeReference, l: Literal) =>
-      val aStats = relation.partitionStatistics.forAttribute(a)
-      aStats.lowerBound <= l && l <= aStats.upperBound
-
+      statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
     case EqualTo(l: Literal, a: AttributeReference) =>
-      val aStats = relation.partitionStatistics.forAttribute(a)
-      aStats.lowerBound <= l && l <= aStats.upperBound
-
-    case LessThan(a: AttributeReference, l: Literal) =>
-      val aStats = relation.partitionStatistics.forAttribute(a)
-      aStats.lowerBound < l
-
-    case LessThan(l: Literal, a: AttributeReference) =>
-      val aStats = relation.partitionStatistics.forAttribute(a)
-      l < aStats.upperBound
-
-    case LessThanOrEqual(a: AttributeReference, l: Literal) =>
-      val aStats = relation.partitionStatistics.forAttribute(a)
-      aStats.lowerBound <= l
+      statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
-    case LessThanOrEqual(l: Literal, a: AttributeReference) =>
-      val aStats = relation.partitionStatistics.forAttribute(a)
-      l <= aStats.upperBound
+    case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
+    case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
 
-    case GreaterThan(a: AttributeReference, l: Literal) =>
-      val aStats = relation.partitionStatistics.forAttribute(a)
-      l < aStats.upperBound
+    case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l
+    case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound
 
-    case GreaterThan(l: Literal, a: AttributeReference) =>
-      val aStats = relation.partitionStatistics.forAttribute(a)
-      aStats.lowerBound < l
+    case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound
+    case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l
 
-    case GreaterThanOrEqual(a: AttributeReference, l: Literal) =>
-      val aStats = relation.partitionStatistics.forAttribute(a)
-      l <= aStats.upperBound
+    case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
+    case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
 
-    case GreaterThanOrEqual(l: Literal, a: AttributeReference) =>
-      val aStats = relation.partitionStatistics.forAttribute(a)
-      aStats.lowerBound <= l
+    case IsNull(a: Attribute)    => statsFor(a).nullCount > 0
+    case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
   }
 
   val partitionFilters = {

http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 2ddf513..04c51a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -17,16 +17,13 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-
-import scala.reflect.runtime.universe.TypeTag
-
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.{Row, SQLContext}
 
 /**
  * :: DeveloperApi ::
@@ -100,7 +97,7 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ
   override final def newInstance(): this.type = {
     SparkLogicalPlan(
       alreadyPlanned match {
-        case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
+        case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd)
         case _ => sys.error("Multiple instance of the same relation detected.")
       })(sqlContext).asInstanceOf[this.type]
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index 5ae7682..82130b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -22,7 +22,6 @@ import java.io.IOException
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.fs.permission.FsAction
-
 import parquet.hadoop.ParquetOutputFormat
 import parquet.hadoop.metadata.CompressionCodecName
 import parquet.schema.MessageType
@@ -30,7 +29,7 @@ import parquet.schema.MessageType
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException}
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 
 /**
  * Relation that consists of data stored in a Parquet columnar format.

http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index da5a358..1a5d87d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
+import org.apache.spark.sql.columnar._
 import org.apache.spark.sql.test.TestSQLContext._
 import org.apache.spark.storage.{StorageLevel, RDDBlockId}
 
@@ -234,4 +234,13 @@ class CachedTableSuite extends QueryTest {
     uncacheTable("testData")
     assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
   }
+
+  test("InMemoryRelation statistics") {
+    sql("CACHE TABLE testData")
+    table("testData").queryExecution.withCachedData.collect {
+      case cached: InMemoryRelation =>
+        val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum
+        assert(cached.statistics.sizeInBytes === actualSizeInBytes)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 10b7979..1c21afc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -28,40 +28,40 @@ import org.apache.spark.sql.test.TestSQLContext._
 case class TestData(key: Int, value: String)
 
 object TestData {
-  val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize(
-    (1 to 100).map(i => TestData(i, i.toString)))
+  val testData = TestSQLContext.sparkContext.parallelize(
+    (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD
   testData.registerTempTable("testData")
 
   case class LargeAndSmallInts(a: Int, b: Int)
-  val largeAndSmallInts: SchemaRDD =
+  val largeAndSmallInts =
     TestSQLContext.sparkContext.parallelize(
       LargeAndSmallInts(2147483644, 1) ::
       LargeAndSmallInts(1, 2) ::
       LargeAndSmallInts(2147483645, 1) ::
       LargeAndSmallInts(2, 2) ::
       LargeAndSmallInts(2147483646, 1) ::
-      LargeAndSmallInts(3, 2) :: Nil)
+      LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD
   largeAndSmallInts.registerTempTable("largeAndSmallInts")
 
   case class TestData2(a: Int, b: Int)
-  val testData2: SchemaRDD =
+  val testData2 =
     TestSQLContext.sparkContext.parallelize(
       TestData2(1, 1) ::
       TestData2(1, 2) ::
       TestData2(2, 1) ::
       TestData2(2, 2) ::
       TestData2(3, 1) ::
-      TestData2(3, 2) :: Nil)
+      TestData2(3, 2) :: Nil).toSchemaRDD
   testData2.registerTempTable("testData2")
 
   case class BinaryData(a: Array[Byte], b: Int)
-  val binaryData: SchemaRDD =
+  val binaryData =
     TestSQLContext.sparkContext.parallelize(
       BinaryData("12".getBytes(), 1) ::
       BinaryData("22".getBytes(), 5) ::
       BinaryData("122".getBytes(), 3) ::
       BinaryData("121".getBytes(), 2) ::
-      BinaryData("123".getBytes(), 4) :: Nil)
+      BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD
   binaryData.registerTempTable("binaryData")
 
   // TODO: There is no way to express null primitives as case classes currently...

http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 6bdf741..a9f0851 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -61,6 +61,12 @@ class ColumnStatsSuite extends FunSuite {
       assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
       assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
       assertResult(10, "Wrong null count")(stats(2))
+      assertResult(20, "Wrong row count")(stats(3))
+      assertResult(stats(4), "Wrong size in bytes") {
+        rows.map { row =>
+          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+        }.sum
+      }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index f53acc8..9ba3c21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -22,8 +22,6 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
 import org.apache.spark.sql._
 import org.apache.spark.sql.test.TestSQLContext._
 
-case class IntegerData(i: Int)
-
 class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
   val originalColumnBatchSize = columnBatchSize
   val originalInMemoryPartitionPruning = inMemoryPartitionPruning
@@ -31,8 +29,12 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
   override protected def beforeAll(): Unit = {
     // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
     setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
-    val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
-    rawData.registerTempTable("intData")
+
+    val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
+      val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
+      TestData(key, string)
+    }, 5)
+    pruningData.registerTempTable("pruningData")
 
     // Enable in-memory partition pruning
     setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
@@ -44,48 +46,64 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
   }
 
   before {
-    cacheTable("intData")
+    cacheTable("pruningData")
   }
 
   after {
-    uncacheTable("intData")
+    uncacheTable("pruningData")
   }
 
   // Comparisons
-  checkBatchPruning("i = 1", Seq(1), 1, 1)
-  checkBatchPruning("1 = i", Seq(1), 1, 1)
-  checkBatchPruning("i < 12", 1 to 11, 1, 2)
-  checkBatchPruning("i <= 11", 1 to 11, 1, 2)
-  checkBatchPruning("i > 88", 89 to 100, 1, 2)
-  checkBatchPruning("i >= 89", 89 to 100, 1, 2)
-  checkBatchPruning("12 > i", 1 to 11, 1, 2)
-  checkBatchPruning("11 >= i", 1 to 11, 1, 2)
-  checkBatchPruning("88 < i", 89 to 100, 1, 2)
-  checkBatchPruning("89 <= i", 89 to 100, 1, 2)
+  checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1))
+  checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1))
+  checkBatchPruning("SELECT key FROM pruningData WHERE key < 12", 1, 2)(1 to 11)
+  checkBatchPruning("SELECT key FROM pruningData WHERE key <= 11", 1, 2)(1 to 11)
+  checkBatchPruning("SELECT key FROM pruningData WHERE key > 88", 1, 2)(89 to 100)
+  checkBatchPruning("SELECT key FROM pruningData WHERE key >= 89", 1, 2)(89 to 100)
+  checkBatchPruning("SELECT key FROM pruningData WHERE 12 > key", 1, 2)(1 to 11)
+  checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11)
+  checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100)
+  checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100)
+
+  // IS NULL
+  checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) {
+    (1 to 10) ++ (21 to 30) ++ (41 to 50) ++ (61 to 70) ++ (81 to 90)
+  }
+
+  // IS NOT NULL
+  checkBatchPruning("SELECT key FROM pruningData WHERE value IS NOT NULL", 5, 5) {
+    (11 to 20) ++ (31 to 40) ++ (51 to 60) ++ (71 to 80) ++ (91 to 100)
+  }
 
   // Conjunction and disjunction
-  checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3)
-  checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2)
-  checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4)
-  checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2)
+  checkBatchPruning("SELECT key FROM pruningData WHERE key > 8 AND key <= 21", 2, 3)(9 to 21)
+  checkBatchPruning("SELECT key FROM pruningData WHERE key < 2 OR key > 99", 2, 2)(Seq(1, 100))
+  checkBatchPruning("SELECT key FROM pruningData WHERE key < 2 OR (key > 78 AND key < 92)", 3, 4) {
+    Seq(1) ++ (79 to 91)
+  }
 
   // With unsupported predicate
-  checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2)
-  checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10)
+  checkBatchPruning("SELECT key FROM pruningData WHERE NOT (key < 88)", 1, 2)(88 to 100)
+  checkBatchPruning("SELECT key FROM pruningData WHERE key < 12 AND key IS NOT NULL", 1, 2)(1 to 11)
+
+  {
+    val seq = (1 to 30).mkString(", ")
+    checkBatchPruning(s"SELECT key FROM pruningData WHERE NOT (key IN ($seq))", 5, 10)(31 to 100)
+  }
 
   def checkBatchPruning(
-      filter: String,
-      expectedQueryResult: Seq[Int],
+      query: String,
       expectedReadPartitions: Int,
-      expectedReadBatches: Int): Unit = {
+      expectedReadBatches: Int)(
+      expectedQueryResult: => Seq[Int]): Unit = {
 
-    test(filter) {
-      val query = sql(s"SELECT * FROM intData WHERE $filter")
+    test(query) {
+      val schemaRdd = sql(query)
       assertResult(expectedQueryResult.toArray, "Wrong query result") {
-        query.collect().map(_.head).toArray
+        schemaRdd.collect().map(_.head).toArray
       }
 
-      val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect {
+      val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect {
         case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
       }.head
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2838bf8a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index f14ffca..a5af71a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -76,4 +76,24 @@ class PlannerSuite extends FunSuite {
 
     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
   }
+
+  test("InMemoryRelation statistics propagation") {
+    val origThreshold = autoBroadcastJoinThreshold
+    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString)
+
+    testData.limit(3).registerTempTable("tiny")
+    sql("CACHE TABLE tiny")
+
+    val a = testData.as('a)
+    val b = table("tiny").as('b)
+    val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan
+
+    val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
+    val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+
+    assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
+    assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+
+    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org