Posted to commits@spark.apache.org by li...@apache.org on 2018/01/27 00:46:56 UTC
spark git commit: [SPARK-23214][SQL] cached data should not carry extra hint info
Repository: spark
Updated Branches:
refs/heads/master 073744985 -> 5b5447c68
[SPARK-23214][SQL] cached data should not carry extra hint info
## What changes were proposed in this pull request?
This is a regression introduced by https://github.com/apache/spark/pull/19864.
When we look up the cache, we should not carry over the hint info: the cache entry may have been added by a plan that carried hint info, while the input plan for this lookup may have no hint info, or different hint info.
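For illustration, a minimal sketch of the broken scenario (a hypothetical local session; `broadcast` and `cache` are used exactly as in the new test below):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

val spark = SparkSession.builder().master("local[2]").getOrCreate()
import spark.implicits._

val df1 = Seq((1, "a"), (2, "b")).toDF("key", "value")
val df2 = Seq((1, "x"), (2, "y")).toDF("key", "value")

// The cache entry is built from a plan whose root is a broadcast hint node.
broadcast(df2).cache()

// This lookup uses a semantically equal plan *without* the hint. Before this
// fix, the cached relation's stats still carried the hint, so df2 could be
// broadcast even though this query never asked for it.
val joined = df1.join(df2, Seq("key"), "inner")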
## How was this patch tested?
A new test.
Author: Wenchen Fan <we...@databricks.com>
Closes #20394 from cloud-fan/cache.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b5447c6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b5447c6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b5447c6
Branch: refs/heads/master
Commit: 5b5447c68ac79715e2256e487e1212861cdab1fc
Parents: 0737449
Author: Wenchen Fan <we...@databricks.com>
Authored: Fri Jan 26 16:46:51 2018 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Fri Jan 26 16:46:51 2018 -0800
----------------------------------------------------------------------
.../spark/sql/execution/CacheManager.scala | 17 +--
.../execution/columnar/InMemoryRelation.scala | 27 +++--
.../org/apache/spark/sql/CachedTableSuite.scala | 4 +-
.../columnar/InMemoryColumnarQuerySuite.scala | 2 +-
.../execution/joins/BroadcastJoinSuite.scala | 103 ++++++++++++-------
5 files changed, 94 insertions(+), 59 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5447c6/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 432eb59..d68aeb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -169,14 +169,17 @@ class CacheManager extends Logging {
/** Replaces segments of the given logical plan with cached versions where possible. */
def useCachedData(plan: LogicalPlan): LogicalPlan = {
val newPlan = plan transformDown {
+ // Do not look up the cache at a hint node. Hint nodes are special: we ignore them when
+ // canonicalizing plans, so that plans that differ only in hints can hit the same cache
+ // entry. However, we also want to keep the hint info after the cache lookup. Here we skip
+ // the hint node, so that the returned cached plan won't replace the hint node and drop the
+ // hint info from the original plan.
+ case hint: ResolvedHint => hint
+
case currentFragment =>
- lookupCachedData(currentFragment).map { cached =>
- val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
- currentFragment match {
- case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints)
- case _ => cachedPlan
- }
- }.getOrElse(currentFragment)
+ lookupCachedData(currentFragment)
+ .map(_.cachedRepresentation.withOutput(currentFragment.output))
+ .getOrElse(currentFragment)
}
newPlan transformAllExpressions {
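As an aside on why the identity case above suffices: `transformDown` is a pre-order traversal, so returning the hint node unchanged still lets the rule recurse into the hint's child, where the cached plan can be substituted. A self-contained toy sketch of that mechanic (toy types, not Spark's actual tree classes):

sealed trait Node { def mapChildren(f: Node => Node): Node }
case class Hint(child: Node) extends Node {
  def mapChildren(f: Node => Node): Node = Hint(f(child))
}
case class Scan(name: String) extends Node {
  def mapChildren(f: Node => Node): Node = this
}
case class Cached(name: String) extends Node {
  def mapChildren(f: Node => Node): Node = this
}

// Pre-order: apply the rule to the node first, then to its (new) children.
def transformDown(n: Node)(rule: PartialFunction[Node, Node]): Node = {
  val applied = rule.applyOrElse(n, (x: Node) => x)
  applied.mapChildren(c => transformDown(c)(rule))
}

val rewritten = transformDown(Hint(Scan("t"))) {
  case h: Hint => h          // keep the hint wrapper as-is
  case Scan(n) => Cached(n)  // substitute the "cached" version of the scan
}
assert(rewritten == Hint(Cached("t"))) // hint survives, child is replaced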
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5447c6/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 51928d9..22e1691 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.LongAccumulator
@@ -62,8 +62,8 @@ case class InMemoryRelation(
@transient child: SparkPlan,
tableName: Option[String])(
@transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
- val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
- statsOfPlanToCache: Statistics = null)
+ val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
+ statsOfPlanToCache: Statistics)
extends logical.LeafNode with MultiInstanceRelation {
override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -73,11 +73,16 @@ case class InMemoryRelation(
@transient val partitionStatistics = new PartitionStatistics(output)
override def computeStats(): Statistics = {
- if (batchStats.value == 0L) {
- // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache
- statsOfPlanToCache
+ if (sizeInBytesStats.value == 0L) {
+ // The underlying columnar RDD hasn't been materialized yet, so use the stats of the plan
+ // being cached. Note that we should drop the hint info here: we may cache a plan whose root
+ // node is a hint node. When we look up the cache with a semantically equal plan that has no
+ // hint info, the plan returned by the cache lookup should not have hint info either. If we
+ // look up the cache with a semantically equal plan that has different hint info,
+ // `CacheManager.useCachedData` takes care of retaining the hint info of the lookup plan.
+ statsOfPlanToCache.copy(hints = HintInfo())
} else {
- Statistics(sizeInBytes = batchStats.value.longValue)
+ Statistics(sizeInBytes = sizeInBytesStats.value.longValue)
}
}
@@ -122,7 +127,7 @@ case class InMemoryRelation(
rowCount += 1
}
- batchStats.add(totalSize)
+ sizeInBytesStats.add(totalSize)
val stats = InternalRow.fromSeq(
columnBuilders.flatMap(_.columnStats.collectedStatistics))
@@ -144,7 +149,7 @@ case class InMemoryRelation(
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
InMemoryRelation(
newOutput, useCompression, batchSize, storageLevel, child, tableName)(
- _cachedColumnBuffers, batchStats, statsOfPlanToCache)
+ _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
}
override def newInstance(): this.type = {
@@ -156,12 +161,12 @@ case class InMemoryRelation(
child,
tableName)(
_cachedColumnBuffers,
- batchStats,
+ sizeInBytesStats,
statsOfPlanToCache).asInstanceOf[this.type]
}
def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
override protected def otherCopyArgs: Seq[AnyRef] =
- Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache)
+ Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
}
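For context, `Statistics` in this Spark version carries a `hints: HintInfo` field (essentially a broadcast flag) that join planning consults. A hedged sketch of what the `copy(hints = HintInfo())` above accomplishes, with toy values (the `broadcast = true` constructor argument reflects this version's `HintInfo` as best understood):

import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}

// Stats as recorded when a broadcast-hinted plan was cached (toy size).
val statsOfPlanToCache =
  Statistics(sizeInBytes = 1024, hints = HintInfo(broadcast = true))

// What InMemoryRelation.computeStats now returns before materialization:
// the same size estimate, but with a fresh, empty HintInfo.
val exposed = statsOfPlanToCache.copy(hints = HintInfo())
assert(!exposed.hints.broadcast) // the hint no longer leaks to other queries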
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5447c6/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 1e52445..72fe0f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -368,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
val toBeCleanedAccIds = new HashSet[Long]
val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
- case i: InMemoryRelation => i.batchStats.id
+ case i: InMemoryRelation => i.sizeInBytesStats.id
}.head
toBeCleanedAccIds += accId1
val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
- case i: InMemoryRelation => i.batchStats.id
+ case i: InMemoryRelation => i.sizeInBytesStats.id
}.head
toBeCleanedAccIds += accId2
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5447c6/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 2280da9..dc1766f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -336,7 +336,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(cached, expectedAnswer)
// Check that the right size was calculated.
- assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
+ assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
}
test("access primitive-type columns in CachedBatch without whole stage codegen") {
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5447c6/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 1704bc8..bcdee79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,7 +22,8 @@ import scala.reflect.ClassTag
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
private def testBroadcastJoin[T: ClassTag](
joinType: String,
forceBroadcast: Boolean = false): SparkPlan = {
- val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
- val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+ val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+ val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
// Comparison at the end is for broadcast left semi join
val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
@@ -109,30 +110,58 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
}
}
- test("broadcast hint is retained after using the cached data") {
+ test("SPARK-23192: broadcast hint should be retained after using the cached data") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
- val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
- val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
- df2.cache()
- val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
- val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
- case b: BroadcastHashJoinExec => b
- }.size
- assert(numBroadCastHashJoin === 1)
+ try {
+ val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+ val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+ df2.cache()
+ val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
+ val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+ case b: BroadcastHashJoinExec => b
+ }.size
+ assert(numBroadCastHashJoin === 1)
+ } finally {
+ spark.catalog.clearCache()
+ }
+ }
+ }
+
+ test("SPARK-23214: cached data should not carry extra hint info") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ try {
+ val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+ val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+ broadcast(df2).cache()
+
+ val df3 = df1.join(df2, Seq("key"), "inner")
+ val numCachedPlan = df3.queryExecution.executedPlan.collect {
+ case i: InMemoryTableScanExec => i
+ }.size
+ // df2 should be cached.
+ assert(numCachedPlan === 1)
+
+ val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+ case b: BroadcastHashJoinExec => b
+ }.size
+ // df2 should not be broadcasted.
+ assert(numBroadCastHashJoin === 0)
+ } finally {
+ spark.catalog.clearCache()
+ }
}
}
test("broadcast hint isn't propagated after a join") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
- val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
- val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+ val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+ val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key"))
- val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value")
+ val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value")
val df5 = df4.join(df3, Seq("key"), "inner")
- val plan =
- EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
+ val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1)
@@ -140,30 +169,30 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
}
private def assertBroadcastJoin(df : Dataset[Row]) : Unit = {
- val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+ val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val joined = df1.join(df, Seq("key"), "inner")
- val plan =
- EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
+ val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
}
test("broadcast hint programming API") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
- val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value")
+ val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value")
val broadcasted = broadcast(df2)
- val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value")
-
- val cases = Seq(broadcasted.limit(2),
- broadcasted.filter("value < 10"),
- broadcasted.sample(true, 0.5),
- broadcasted.distinct(),
- broadcasted.groupBy("value").agg(min($"key").as("key")),
- // except and intersect are semi/anti-joins which won't return more data then
- // their left argument, so the broadcast hint should be propagated here
- broadcasted.except(df3),
- broadcasted.intersect(df3))
+ val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value")
+
+ val cases = Seq(
+ broadcasted.limit(2),
+ broadcasted.filter("value < 10"),
+ broadcasted.sample(true, 0.5),
+ broadcasted.distinct(),
+ broadcasted.groupBy("value").agg(min($"key").as("key")),
+ // except and intersect are semi/anti-joins which won't return more data than
+ // their left argument, so the broadcast hint should be propagated here
+ broadcasted.except(df3),
+ broadcasted.intersect(df3))
cases.foreach(assertBroadcastJoin)
}
@@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
test("Shouldn't change broadcast join buildSide if user clearly specified") {
withTempView("t1", "t2") {
- spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
- spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
- .createTempView("t2")
+ Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+ Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
@@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
test("Shouldn't bias towards build right if user didn't specify") {
withTempView("t1", "t2") {
- spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
- spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
- .createTempView("t2")
+ Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+ Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes