You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by br...@apache.org on 2020/03/19 01:08:43 UTC

[spark] branch master updated: [SPARK-31178][SQL] Prevent V2 exec nodes from executing multiple times

This is an automated email from the ASF dual-hosted git repository.

brkyvz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4237251  [SPARK-31178][SQL] Prevent V2 exec nodes from executing multiple times
4237251 is described below

commit 4237251861c79f3176de7cf5232f0388ec5d946e
Author: Burak Yavuz <br...@gmail.com>
AuthorDate: Wed Mar 18 18:07:24 2020 -0700

    [SPARK-31178][SQL] Prevent V2 exec nodes from executing multiple times
    
    ### What changes were proposed in this pull request?
    
    This PR prevents V2 DataSource exec nodes from being executed multiple times when `collect()` is called on them. For V1 DataSources, commands are executed as a `RunnableCommand`, and the result is cached as part of the `ExecutedCommandExec` node. We make all the data writing commands extend `V2CommandExec` so that they are executed only once as well.
    
    ### Why are the changes needed?
    
    Otherwise, calling `collect()` on a SQL command that inserts data or creates a table causes the command to be executed multiple times.
    
    ### Does this PR introduce any user-facing change?
    
    Fixes a bug
    
    ### How was this patch tested?
    
    Unit tests
    
    Closes #27941 from brkyvz/doubleInsert.
    
    Authored-by: Burak Yavuz <br...@gmail.com>
    Signed-off-by: Burak Yavuz <br...@gmail.com>
---
 .../datasources/v2/ShowNamespacesExec.scala        |  4 +-
 .../execution/datasources/v2/ShowTablesExec.scala  |  4 +-
 .../datasources/v2/V1FallbackWriters.scala         | 10 ++--
 .../execution/datasources/v2/V2CommandExec.scala   |  6 +-
 .../datasources/v2/WriteToDataSourceV2Exec.scala   | 28 ++++-----
 .../spark/sql/connector/DataSourceV2SQLSuite.scala | 68 ++++++++++++++++++++++
 .../spark/sql/connector/InsertIntoTests.scala      | 17 ++++++
 7 files changed, 112 insertions(+), 25 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
index fe3ab80..6f96848 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRowWithSchem
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces
+import org.apache.spark.sql.execution.LeafExecNode
 
 /**
  * Physical plan node for showing namespaces.
@@ -33,8 +34,7 @@ case class ShowNamespacesExec(
     output: Seq[Attribute],
     catalog: SupportsNamespaces,
     namespace: Seq[String],
-    pattern: Option[String])
-    extends V2CommandExec {
+    pattern: Option[String]) extends V2CommandExec with LeafExecNode {
 
   override protected def run(): Seq[InternalRow] = {
     val namespaces = if (namespace.nonEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala
index 995b008..c740e0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRowWithSchem
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper
 import org.apache.spark.sql.connector.catalog.TableCatalog
+import org.apache.spark.sql.execution.LeafExecNode
 
 /**
  * Physical plan node for showing tables.
@@ -33,8 +34,7 @@ case class ShowTablesExec(
     output: Seq[Attribute],
     catalog: TableCatalog,
     namespace: Seq[String],
-    pattern: Option[String])
-    extends V2CommandExec {
+    pattern: Option[String]) extends V2CommandExec with LeafExecNode {
   override protected def run(): Seq[InternalRow] = {
     val rows = new ArrayBuffer[InternalRow]()
     val encoder = RowEncoder(schema).resolveAndBind()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
index f973000..7502a87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
@@ -41,7 +41,7 @@ case class AppendDataExecV1(
     writeOptions: CaseInsensitiveStringMap,
     plan: LogicalPlan) extends V1FallbackWriters {
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  override protected def run(): Seq[InternalRow] = {
     writeWithV1(newWriteBuilder().buildForV1Write())
   }
 }
@@ -67,7 +67,7 @@ case class OverwriteByExpressionExecV1(
     filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
   }
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  override protected def run(): Seq[InternalRow] = {
     newWriteBuilder() match {
       case builder: SupportsTruncate if isTruncate(deleteWhere) =>
         writeWithV1(builder.truncate().asV1Builder.buildForV1Write())
@@ -82,7 +82,7 @@ case class OverwriteByExpressionExecV1(
 }
 
 /** Some helper interfaces that use V2 write semantics through the V1 writer interface. */
-sealed trait V1FallbackWriters extends SupportsV1Write {
+sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write {
   override def output: Seq[Attribute] = Nil
   override final def children: Seq[SparkPlan] = Nil
 
@@ -115,8 +115,8 @@ trait SupportsV1Write extends SparkPlan {
   // TODO: We should be able to work on SparkPlans at this point.
   def plan: LogicalPlan
 
-  protected def writeWithV1(relation: InsertableRelation): RDD[InternalRow] = {
+  protected def writeWithV1(relation: InsertableRelation): Seq[InternalRow] = {
     relation.insert(Dataset.ofRows(sqlContext.sparkSession, plan), overwrite = false)
-    sparkContext.emptyRDD
+    Nil
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala
index a1f685d..4be4a6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala
@@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.LeafExecNode
+import org.apache.spark.sql.execution.SparkPlan
 
 /**
  * A physical operator that executes run() and saves the result to prevent multiple executions.
  * Any V2 commands that do not require triggering a spark job should extend this class.
  */
-abstract class V2CommandExec extends LeafExecNode {
+abstract class V2CommandExec extends SparkPlan {
 
   /**
    * Abstract method that each concrete command needs to implement to compute the result.
@@ -53,4 +53,6 @@ abstract class V2CommandExec extends LeafExecNode {
   protected override def doExecute(): RDD[InternalRow] = {
     sqlContext.sparkContext.parallelize(result, 1)
   }
+
+  override def children: Seq[SparkPlan] = Nil
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index e360a9e..616e18e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -70,10 +70,10 @@ case class CreateTableAsSelectExec(
 
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  override protected def run(): Seq[InternalRow] = {
     if (catalog.tableExists(ident)) {
       if (ifNotExists) {
-        return sparkContext.parallelize(Seq.empty, 1)
+        return Nil
       }
 
       throw new TableAlreadyExistsException(ident)
@@ -125,10 +125,10 @@ case class AtomicCreateTableAsSelectExec(
     writeOptions: CaseInsensitiveStringMap,
     ifNotExists: Boolean) extends AtomicTableWriteExec {
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  override protected def run(): Seq[InternalRow] = {
     if (catalog.tableExists(ident)) {
       if (ifNotExists) {
-        return sparkContext.parallelize(Seq.empty, 1)
+        return Nil
       }
 
       throw new TableAlreadyExistsException(ident)
@@ -161,7 +161,7 @@ case class ReplaceTableAsSelectExec(
 
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  override protected def run(): Seq[InternalRow] = {
     // Note that this operation is potentially unsafe, but these are the strict semantics of
     // RTAS if the catalog does not support atomic operations.
     //
@@ -225,7 +225,7 @@ case class AtomicReplaceTableAsSelectExec(
     writeOptions: CaseInsensitiveStringMap,
     orCreate: Boolean) extends AtomicTableWriteExec {
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  override protected def run(): Seq[InternalRow] = {
     val schema = query.schema.asNullable
     val staged = if (orCreate) {
       catalog.stageCreateOrReplace(
@@ -255,7 +255,7 @@ case class AppendDataExec(
     writeOptions: CaseInsensitiveStringMap,
     query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  override protected def run(): Seq[InternalRow] = {
     writeWithV2(newWriteBuilder().buildForBatch())
   }
 }
@@ -280,7 +280,7 @@ case class OverwriteByExpressionExec(
     filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
   }
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  override protected def run(): Seq[InternalRow] = {
     newWriteBuilder() match {
       case builder: SupportsTruncate if isTruncate(deleteWhere) =>
         writeWithV2(builder.truncate().buildForBatch())
@@ -308,7 +308,7 @@ case class OverwritePartitionsDynamicExec(
     writeOptions: CaseInsensitiveStringMap,
     query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  override protected def run(): Seq[InternalRow] = {
     newWriteBuilder() match {
       case builder: SupportsDynamicOverwrite =>
         writeWithV2(builder.overwriteDynamicPartitions().buildForBatch())
@@ -325,7 +325,7 @@ case class WriteToDataSourceV2Exec(
 
   def writeOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty()
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  override protected def run(): Seq[InternalRow] = {
     writeWithV2(batchWrite)
   }
 }
@@ -350,7 +350,7 @@ trait BatchWriteHelper {
 /**
  * The base physical plan for writing data into data source v2.
  */
-trait V2TableWriteExec extends UnaryExecNode {
+trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
   def query: SparkPlan
 
   var commitProgress: Option[StreamWriterCommitProgress] = None
@@ -358,7 +358,7 @@ trait V2TableWriteExec extends UnaryExecNode {
   override def child: SparkPlan = query
   override def output: Seq[Attribute] = Nil
 
-  protected def writeWithV2(batchWrite: BatchWrite): RDD[InternalRow] = {
+  protected def writeWithV2(batchWrite: BatchWrite): Seq[InternalRow] = {
     val rdd: RDD[InternalRow] = {
       val tempRdd = query.execute()
       // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
@@ -415,7 +415,7 @@ trait V2TableWriteExec extends UnaryExecNode {
         }
     }
 
-    sparkContext.emptyRDD
+    Nil
   }
 }
 
@@ -485,7 +485,7 @@ private[v2] trait AtomicTableWriteExec extends V2TableWriteExec with SupportsV1W
   protected def writeToStagedTable(
       stagedTable: StagedTable,
       writeOptions: CaseInsensitiveStringMap,
-      ident: Identifier): RDD[InternalRow] = {
+      ident: Identifier): Seq[InternalRow] = {
     Utils.tryWithSafeFinallyAndFailureCallbacks({
       stagedTable match {
         case table: SupportsWrite =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index ba4200d..07e0959 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -319,6 +319,37 @@ class DataSourceV2SQLSuite
     }
   }
 
+  test("CreateTableAsSelect: do not double execute on collect(), take() and other queries") {
+    val basicCatalog = catalog("testcat").asTableCatalog
+    val atomicCatalog = catalog("testcat_atomic").asTableCatalog
+    val basicIdentifier = "testcat.table_name"
+    val atomicIdentifier = "testcat_atomic.table_name"
+
+    Seq((basicCatalog, basicIdentifier), (atomicCatalog, atomicIdentifier)).foreach {
+      case (catalog, identifier) =>
+        val df = spark.sql(s"CREATE TABLE $identifier USING foo AS SELECT id, data FROM source")
+
+        df.collect()
+        df.take(5)
+        df.tail(5)
+        df.where("true").collect()
+        df.where("true").take(5)
+        df.where("true").tail(5)
+
+        val table = catalog.loadTable(Identifier.of(Array(), "table_name"))
+
+        assert(table.name == identifier)
+        assert(table.partitioning.isEmpty)
+        assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
+        assert(table.schema == new StructType()
+          .add("id", LongType)
+          .add("data", StringType))
+
+        val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+        checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
+    }
+  }
+
   test("ReplaceTableAsSelect: basic v2 implementation.") {
     val basicCatalog = catalog("testcat").asTableCatalog
     val atomicCatalog = catalog("testcat_atomic").asTableCatalog
@@ -346,6 +377,43 @@ class DataSourceV2SQLSuite
     }
   }
 
+  Seq("REPLACE", "CREATE OR REPLACE").foreach { cmd =>
+    test(s"ReplaceTableAsSelect: do not double execute $cmd on collect()") {
+      val basicCatalog = catalog("testcat").asTableCatalog
+      val atomicCatalog = catalog("testcat_atomic").asTableCatalog
+      val basicIdentifier = "testcat.table_name"
+      val atomicIdentifier = "testcat_atomic.table_name"
+
+      Seq((basicCatalog, basicIdentifier), (atomicCatalog, atomicIdentifier)).foreach {
+        case (catalog, identifier) =>
+          spark.sql(s"CREATE TABLE $identifier USING foo AS SELECT id, data FROM source")
+          val originalTable = catalog.loadTable(Identifier.of(Array(), "table_name"))
+
+          val df = spark.sql(s"$cmd TABLE $identifier USING foo AS SELECT id FROM source")
+
+          df.collect()
+          df.take(5)
+          df.tail(5)
+          df.where("true").collect()
+          df.where("true").take(5)
+          df.where("true").tail(5)
+
+          val replacedTable = catalog.loadTable(Identifier.of(Array(), "table_name"))
+
+          assert(replacedTable != originalTable, "Table should have been replaced.")
+          assert(replacedTable.name == identifier)
+          assert(replacedTable.partitioning.isEmpty)
+          assert(replacedTable.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
+          assert(replacedTable.schema == new StructType().add("id", LongType))
+
+          val rdd = spark.sparkContext.parallelize(replacedTable.asInstanceOf[InMemoryTable].rows)
+          checkAnswer(
+            spark.internalCreateDataFrame(rdd, replacedTable.schema),
+            spark.table("source").select("id"))
+      }
+    }
+  }
+
   test("ReplaceTableAsSelect: Non-atomic catalog drops the table if the write fails.") {
     spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT id, data FROM source")
     val testCatalog = catalog("testcat").asTableCatalog
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala
index 0fd6cf1..b88ad52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala
@@ -463,5 +463,22 @@ trait InsertIntoSQLOnlyTests
         }
       }
     }
+
+    test("do not double insert on INSERT INTO collect()") {
+      val t1 = s"${catalogAndNamespace}tbl"
+      withTableAndData(t1) { view =>
+        sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
+        val df = sql(s"INSERT INTO TABLE $t1 SELECT * FROM $view")
+
+        df.collect()
+        df.take(5)
+        df.tail(5)
+        df.where("true").collect()
+        df.where("true").take(5)
+        df.where("true").tail(5)
+
+        verifyTable(t1, spark.table(view))
+      }
+    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org