You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2018/01/27 00:46:56 UTC

spark git commit: [SPARK-23214][SQL] cached data should not carry extra hint info

Repository: spark
Updated Branches:
  refs/heads/master 073744985 -> 5b5447c68


[SPARK-23214][SQL] cached data should not carry extra hint info

## What changes were proposed in this pull request?

This is a regression introduced by https://github.com/apache/spark/pull/19864

When we look up the cache, we should not carry over the hint info, because the cache entry might have been added by a plan having hint info, while the input plan for this lookup may have no hint info, or different hint info.

## How was this patch tested?

A new test in BroadcastJoinSuite: "SPARK-23214: cached data should not carry extra hint info".

Author: Wenchen Fan <we...@databricks.com>

Closes #20394 from cloud-fan/cache.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b5447c6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b5447c6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b5447c6

Branch: refs/heads/master
Commit: 5b5447c68ac79715e2256e487e1212861cdab1fc
Parents: 0737449
Author: Wenchen Fan <we...@databricks.com>
Authored: Fri Jan 26 16:46:51 2018 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Fri Jan 26 16:46:51 2018 -0800

----------------------------------------------------------------------
 .../spark/sql/execution/CacheManager.scala      |  17 +--
 .../execution/columnar/InMemoryRelation.scala   |  27 +++--
 .../org/apache/spark/sql/CachedTableSuite.scala |   4 +-
 .../columnar/InMemoryColumnarQuerySuite.scala   |   2 +-
 .../execution/joins/BroadcastJoinSuite.scala    | 103 ++++++++++++-------
 5 files changed, 94 insertions(+), 59 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5b5447c6/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 432eb59..d68aeb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -169,14 +169,17 @@ class CacheManager extends Logging {
   /** Replaces segments of the given logical plan with cached versions where possible. */
   def useCachedData(plan: LogicalPlan): LogicalPlan = {
     val newPlan = plan transformDown {
+      // Do not lookup the cache by hint node. Hint node is special, we should ignore it when
+      // canonicalizing plans, so that plans which are same except hint can hit the same cache.
+      // However, we also want to keep the hint info after cache lookup. Here we skip the hint
+      // node, so that the returned caching plan won't replace the hint node and drop the hint info
+      // from the original plan.
+      case hint: ResolvedHint => hint
+
       case currentFragment =>
-        lookupCachedData(currentFragment).map { cached =>
-          val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
-          currentFragment match {
-            case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints)
-            case _ => cachedPlan
-          }
-        }.getOrElse(currentFragment)
+        lookupCachedData(currentFragment)
+          .map(_.cachedRepresentation.withOutput(currentFragment.output))
+          .getOrElse(currentFragment)
     }
 
     newPlan transformAllExpressions {

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5447c6/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 51928d9..22e1691 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.LongAccumulator
@@ -62,8 +62,8 @@ case class InMemoryRelation(
     @transient child: SparkPlan,
     tableName: Option[String])(
     @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
-    statsOfPlanToCache: Statistics = null)
+    val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
+    statsOfPlanToCache: Statistics)
   extends logical.LeafNode with MultiInstanceRelation {
 
   override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -73,11 +73,16 @@ case class InMemoryRelation(
   @transient val partitionStatistics = new PartitionStatistics(output)
 
   override def computeStats(): Statistics = {
-    if (batchStats.value == 0L) {
-      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache
-      statsOfPlanToCache
+    if (sizeInBytesStats.value == 0L) {
+      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
+      // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
+      // node. When we lookup the cache with a semantically same plan without hint info, the plan
+      // returned by cache lookup should not have hint info. If we lookup the cache with a
+      // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
+      // care of it and retain the hint info in the lookup input plan.
+      statsOfPlanToCache.copy(hints = HintInfo())
     } else {
-      Statistics(sizeInBytes = batchStats.value.longValue)
+      Statistics(sizeInBytes = sizeInBytesStats.value.longValue)
     }
   }
 
@@ -122,7 +127,7 @@ case class InMemoryRelation(
             rowCount += 1
           }
 
-          batchStats.add(totalSize)
+          sizeInBytesStats.add(totalSize)
 
           val stats = InternalRow.fromSeq(
             columnBuilders.flatMap(_.columnStats.collectedStatistics))
@@ -144,7 +149,7 @@ case class InMemoryRelation(
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
     InMemoryRelation(
       newOutput, useCompression, batchSize, storageLevel, child, tableName)(
-        _cachedColumnBuffers, batchStats, statsOfPlanToCache)
+        _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
   }
 
   override def newInstance(): this.type = {
@@ -156,12 +161,12 @@ case class InMemoryRelation(
       child,
       tableName)(
         _cachedColumnBuffers,
-        batchStats,
+        sizeInBytesStats,
         statsOfPlanToCache).asInstanceOf[this.type]
   }
 
   def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
 
   override protected def otherCopyArgs: Seq[AnyRef] =
-    Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache)
+    Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5447c6/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 1e52445..72fe0f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -368,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     val toBeCleanedAccIds = new HashSet[Long]
 
     val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
-      case i: InMemoryRelation => i.batchStats.id
+      case i: InMemoryRelation => i.sizeInBytesStats.id
     }.head
     toBeCleanedAccIds += accId1
 
     val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
-      case i: InMemoryRelation => i.batchStats.id
+      case i: InMemoryRelation => i.sizeInBytesStats.id
     }.head
     toBeCleanedAccIds += accId2
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5447c6/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 2280da9..dc1766f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -336,7 +336,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(cached, expectedAnswer)
 
     // Check that the right size was calculated.
-    assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
+    assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
   }
 
   test("access primitive-type columns in CachedBatch without whole stage codegen") {

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5447c6/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 1704bc8..bcdee79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,7 +22,8 @@ import scala.reflect.ClassTag
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   private def testBroadcastJoin[T: ClassTag](
       joinType: String,
       forceBroadcast: Boolean = false): SparkPlan = {
-    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-    val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+    val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+    val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
 
     // Comparison at the end is for broadcast left semi join
     val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
@@ -109,30 +110,58 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
     }
   }
 
-  test("broadcast hint is retained after using the cached data") {
+  test("SPARK-23192: broadcast hint should be retained after using the cached data") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
-      df2.cache()
-      val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
-      val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
-        case b: BroadcastHashJoinExec => b
-      }.size
-      assert(numBroadCastHashJoin === 1)
+      try {
+        val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+        val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+        df2.cache()
+        val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
+        val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+          case b: BroadcastHashJoinExec => b
+        }.size
+        assert(numBroadCastHashJoin === 1)
+      } finally {
+        spark.catalog.clearCache()
+      }
+    }
+  }
+
+  test("SPARK-23214: cached data should not carry extra hint info") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      try {
+        val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+        val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+        broadcast(df2).cache()
+
+        val df3 = df1.join(df2, Seq("key"), "inner")
+        val numCachedPlan = df3.queryExecution.executedPlan.collect {
+          case i: InMemoryTableScanExec => i
+        }.size
+        // df2 should be cached.
+        assert(numCachedPlan === 1)
+
+        val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+          case b: BroadcastHashJoinExec => b
+        }.size
+        // df2 should not be broadcasted.
+        assert(numBroadCastHashJoin === 0)
+      } finally {
+        spark.catalog.clearCache()
+      }
     }
   }
 
   test("broadcast hint isn't propagated after a join") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+      val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+      val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
       val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key"))
 
-      val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value")
+      val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value")
       val df5 = df4.join(df3, Seq("key"), "inner")
 
-      val plan =
-        EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
+      val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
 
       assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
       assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1)
@@ -140,30 +169,30 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   }
 
   private def assertBroadcastJoin(df : Dataset[Row]) : Unit = {
-    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+    val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
     val joined = df1.join(df, Seq("key"), "inner")
 
-    val plan =
-      EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
+    val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
 
     assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
   }
 
   test("broadcast hint programming API") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value")
+      val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value")
       val broadcasted = broadcast(df2)
-      val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value")
-
-      val cases = Seq(broadcasted.limit(2),
-                      broadcasted.filter("value < 10"),
-                      broadcasted.sample(true, 0.5),
-                      broadcasted.distinct(),
-                      broadcasted.groupBy("value").agg(min($"key").as("key")),
-                      // except and intersect are semi/anti-joins which won't return more data then
-                      // their left argument, so the broadcast hint should be propagated here
-                      broadcasted.except(df3),
-                      broadcasted.intersect(df3))
+      val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value")
+
+      val cases = Seq(
+        broadcasted.limit(2),
+        broadcasted.filter("value < 10"),
+        broadcasted.sample(true, 0.5),
+        broadcasted.distinct(),
+        broadcasted.groupBy("value").agg(min($"key").as("key")),
+        // except and intersect are semi/anti-joins which won't return more data then
+        // their left argument, so the broadcast hint should be propagated here
+        broadcasted.except(df3),
+        broadcasted.intersect(df3))
 
       cases.foreach(assertBroadcastJoin)
     }
@@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   test("Shouldn't change broadcast join buildSide if user clearly specified") {
 
     withTempView("t1", "t2") {
-      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
-      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
-        .createTempView("t2")
+      Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+      Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
 
       val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
       val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
@@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   test("Shouldn't bias towards build right if user didn't specify") {
 
     withTempView("t1", "t2") {
-      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
-      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
-        .createTempView("t2")
+      Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+      Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
 
       val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
       val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org