You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/03/23 04:23:28 UTC

spark git commit: [SPARK-23614][SQL] Fix incorrect reuse exchange when caching is used

Repository: spark
Updated Branches:
  refs/heads/master a649fcf32 -> b2edc30db


[SPARK-23614][SQL] Fix incorrect reuse exchange when caching is used

## What changes were proposed in this pull request?

We should provide customized canonicalized plans for `InMemoryRelation` and `InMemoryTableScanExec`. Otherwise, we can wrongly treat two different cached plans as producing the same result, which then causes an exchange to be wrongly reused.

For a test query like this:
```scala
val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache()
val group1 = cached.groupBy("x").agg(min(col("y")) as "value")
val group2 = cached.groupBy("x").agg(min(col("z")) as "value")
group1.union(group2)
```

Canonicalized plans before:

First exchange:
```
Exchange hashpartitioning(none#0, 5)
+- *(1) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4])
   +- *(1) InMemoryTableScan [none#0, none#1]
         +- InMemoryRelation [x#4253, y#4254, z#4255], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
               +- LocalTableScan [x#4253, y#4254, z#4255]
```

Second exchange:
```
Exchange hashpartitioning(none#0, 5)
+- *(3) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4])
   +- *(3) InMemoryTableScan [none#0, none#1]
         +- InMemoryRelation [x#4253, y#4254, z#4255], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
               +- LocalTableScan [x#4253, y#4254, z#4255]
```

You can see that the canonicalized plans are the same, even though the two `InMemoryTableScan`s use different columns.

Canonicalized plans after:

First exchange:
```
Exchange hashpartitioning(none#0, 5)
+- *(1) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4])
   +- *(1) InMemoryTableScan [none#0, none#1]
         +- InMemoryRelation [none#0, none#1, none#2], true, 10000, StorageLevel(memory, 1 replicas)
               +- LocalTableScan [none#0, none#1, none#2]
```

Second exchange:
```
Exchange hashpartitioning(none#0, 5)
+- *(3) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4])
   +- *(3) InMemoryTableScan [none#0, none#2]
         +- InMemoryRelation [none#0, none#1, none#2], true, 10000, StorageLevel(memory, 1 replicas)
               +- LocalTableScan [none#0, none#1, none#2]
```

## How was this patch tested?

Added unit tests.

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #20831 from viirya/SPARK-23614.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b2edc30d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b2edc30d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b2edc30d

Branch: refs/heads/master
Commit: b2edc30db1dcc6102687d20c158a2700965fdf51
Parents: a649fcf
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Mar 22 21:23:25 2018 -0700
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Mar 22 21:23:25 2018 -0700

----------------------------------------------------------------------
 .../execution/columnar/InMemoryRelation.scala    | 10 ++++++++++
 .../columnar/InMemoryTableScanExec.scala         | 19 +++++++++++++------
 .../org/apache/spark/sql/DatasetSuite.scala      |  9 +++++++++
 .../spark/sql/execution/ExchangeSuite.scala      |  7 +++++++
 4 files changed, 39 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b2edc30d/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 22e1691..2579046 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
 import org.apache.spark.sql.execution.SparkPlan
@@ -68,6 +69,15 @@ case class InMemoryRelation(
 
   override protected def innerChildren: Seq[SparkPlan] = Seq(child)
 
+  override def doCanonicalize(): logical.LogicalPlan =
+    copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)),
+      storageLevel = StorageLevel.NONE,
+      child = child.canonicalized,
+      tableName = None)(
+      _cachedColumnBuffers,
+      sizeInBytesStats,
+      statsOfPlanToCache)
+
   override def producedAttributes: AttributeSet = outputSet
 
   @transient val partitionStatistics = new PartitionStatistics(output)

http://git-wip-us.apache.org/repos/asf/spark/blob/b2edc30d/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index a93e8a1..e73e137 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
-import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.vectorized._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -38,6 +38,11 @@ case class InMemoryTableScanExec(
 
   override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
 
+  override def doCanonicalize(): SparkPlan =
+    copy(attributes = attributes.map(QueryPlan.normalizeExprId(_, relation.output)),
+      predicates = predicates.map(QueryPlan.normalizeExprId(_, relation.output)),
+      relation = relation.canonicalized.asInstanceOf[InMemoryRelation])
+
   override def vectorTypes: Option[Seq[String]] =
     Option(Seq.fill(attributes.length)(
       if (!conf.offHeapColumnVectorEnabled) {
@@ -169,11 +174,13 @@ case class InMemoryTableScanExec(
   override def outputOrdering: Seq[SortOrder] =
     relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
 
-  private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
+  // Keeps relation's partition statistics because we don't serialize relation.
+  private val stats = relation.partitionStatistics
+  private def statsFor(a: Attribute) = stats.forAttribute(a)
 
   // Returned filter predicate should return false iff it is impossible for the input expression
   // to evaluate to `true' based on statistics collected about this partition batch.
-  @transient val buildFilter: PartialFunction[Expression, Expression] = {
+  @transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
     case And(lhs: Expression, rhs: Expression)
       if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) =>
       (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _)
@@ -213,14 +220,14 @@ case class InMemoryTableScanExec(
         l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
   }
 
-  val partitionFilters: Seq[Expression] = {
+  lazy val partitionFilters: Seq[Expression] = {
     predicates.flatMap { p =>
       val filter = buildFilter.lift(p)
       val boundFilter =
         filter.map(
           BindReferences.bindReference(
             _,
-            relation.partitionStatistics.schema,
+            stats.schema,
             allowFailures = true))
 
       boundFilter.foreach(_ =>
@@ -243,7 +250,7 @@ case class InMemoryTableScanExec(
   private def filteredCachedBatches(): RDD[CachedBatch] = {
     // Using these variables here to avoid serialization of entire objects (if referenced directly)
     // within the map Partitions closure.
-    val schema = relation.partitionStatistics.schema
+    val schema = stats.schema
     val schemaIndex = schema.zipWithIndex
     val buffers = relation.cachedColumnBuffers
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b2edc30d/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 49c59cf..9b745be 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1446,8 +1446,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val data = Seq(("a", null))
     checkDataset(data.toDS(), data: _*)
   }
+
+  test("SPARK-23614: Union produces incorrect results when caching is used") {
+    val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache()
+    val group1 = cached.groupBy("x").agg(min(col("y")) as "value")
+    val group2 = cached.groupBy("x").agg(min(col("z")) as "value")
+    checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil)
+  }
 }
 
+case class TestDataUnion(x: Int, y: Int, z: Int)
+
 case class SingleData(id: Int)
 case class DoubleData(id: Int, val1: String)
 case class TripleData(id: Int, val1: String, val2: Long)

http://git-wip-us.apache.org/repos/asf/spark/blob/b2edc30d/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 697d7e6..bde2de5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -125,4 +125,11 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
       assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong))
     }
   }
+
+  test("SPARK-23614: Fix incorrect reuse exchange when caching is used") {
+    val cached = spark.createDataset(Seq((1, 2, 3), (4, 5, 6))).cache()
+    val projection1 = cached.select("_1", "_2").queryExecution.executedPlan
+    val projection2 = cached.select("_1", "_3").queryExecution.executedPlan
+    assert(!projection1.sameResult(projection2))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org