You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/08/03 17:06:01 UTC

[spark] branch branch-3.1 updated: [SPARK-39952][SQL] SaveIntoDataSourceCommand should recache result relation

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 9605dde3c04 [SPARK-39952][SQL] SaveIntoDataSourceCommand should recache result relation
9605dde3c04 is described below

commit 9605dde3c041b8b3986dc1dfd61874e3fa8eae68
Author: ulysses-you <ul...@gmail.com>
AuthorDate: Thu Aug 4 01:03:45 2022 +0800

    [SPARK-39952][SQL] SaveIntoDataSourceCommand should recache result relation
    
    ### What changes were proposed in this pull request?
    
    Call `recacheByPlan` on the result relation inside `SaveIntoDataSourceCommand`.
    
    ### Why are the changes needed?
    
    The behavior of `SaveIntoDataSourceCommand` is similar to `InsertIntoDataSourceCommand`, which supports appending or overwriting data. To keep data consistent, we should always call `recacheByPlan` on the result relation after the write.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this is a bug fix.
    
    ### How was this patch tested?
    
    Added a test.
    
    Closes #37380 from ulysses-you/refresh.
    
    Authored-by: ulysses-you <ul...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 5fe0b245f7891a05bc4e1e641fd0aa9130118ea4)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../datasources/SaveIntoDataSourceCommand.scala    | 12 ++++-
 .../SaveIntoDataSourceCommandSuite.scala           | 61 +++++++++++++++++++++-
 2 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index 49e77f618f2..e97cf74549a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -42,9 +44,22 @@ case class SaveIntoDataSourceCommand(
   override def innerChildren: Seq[QueryPlan[_]] = Seq(query)

   override def run(sparkSession: SparkSession): Seq[Row] = {
-    dataSource.createRelation(
+    val relation = dataSource.createRelation(
       sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query))

+    // Refresh cached plans that read the relation just written to, so cached queries observe
+    // the appended/overwritten data (the same post-write step `InsertIntoDataSourceCommand` does).
+    try {
+      val logicalRelation = LogicalRelation(
+        relation, relation.schema.toAttributes, catalogTable = None, isStreaming = false)
+      sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
+    } catch {
+      case NonFatal(e) =>
+        // Best effort: some data sources cannot return a valid relation (e.g.
+        // `KafkaSourceProvider`), and the write itself has already succeeded at this point.
+        logDebug(s"Skipped recaching the result relation of ${dataSource.getClass.getName}", e)
+    }
+
     Seq.empty[Row]
   }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala
index e843d1d3284..e68d6561fb8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.execution.datasources
 
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, TableScan}
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
 
-class SaveIntoDataSourceCommandSuite extends SharedSparkSession {
+class SaveIntoDataSourceCommandSuite extends QueryTest with SharedSparkSession {
 
   test("simpleString is redacted") {
     val URL = "connection.url"
@@ -41,4 +44,69 @@ class SaveIntoDataSourceCommandSuite extends SharedSparkSession {
     assert(!logicalPlanString.contains(PASS))
     assert(logicalPlanString.contains(DRIVER))
   }
+
+  test("SPARK-39952: SaveIntoDataSourceCommand should recache result relation") {
+    val provider = classOf[FakeV1DataSource].getName
+
+    def saveIntoDataSource(numRows: Int): Unit = {
+      spark.range(numRows)
+        .write
+        .mode("append")
+        .format(provider)
+        .save()
+    }
+
+    def loadData: DataFrame = {
+      spark.read
+        .format(provider)
+        .load()
+    }
+
+    try {
+      saveIntoDataSource(1)
+      val cached = loadData.cache()
+      checkAnswer(cached, Row(0))
+
+      // Without recaching, `loadData` would be answered from the stale cached data above.
+      saveIntoDataSource(2)
+      checkAnswer(loadData, Row(0) :: Row(1) :: Nil)
+    } finally {
+      // Reset the global fake storage even when an assertion fails, so it cannot leak
+      // into other tests running in the same JVM.
+      FakeV1DataSource.data = null
+    }
+  }
+}
+
+/** Test-only global storage shared by [[FakeV1DataSource]] writes and [[FakeRelation]] scans. */
+object FakeV1DataSource {
+  var data: RDD[Row] = _
+}
+
+/**
+ * Minimal V1 data source. A write ignores `mode` and replaces [[FakeV1DataSource.data]] with the
+ * incoming rows; a read scans whatever was written last.
+ */
+class FakeV1DataSource extends RelationProvider with CreatableRelationProvider {
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    FakeRelation()
+  }
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    FakeV1DataSource.data = data.rdd
+    FakeRelation()
+  }
+}
+
+/** Relation with a single `id: Long` column whose scan returns the rows last written. */
+case class FakeRelation() extends BaseRelation with TableScan {
+  override def sqlContext: SQLContext = SparkSession.active.sqlContext
+  override def schema: StructType = StructType(Seq(StructField("id", LongType)))
+  override def buildScan(): RDD[Row] = FakeV1DataSource.data
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org