You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/02/17 05:09:18 UTC

spark git commit: [SPARK-18120][SPARK-19557][SQL] Call QueryExecutionListener callback methods for DataFrameWriter methods

Repository: spark
Updated Branches:
  refs/heads/master 21fde57f1 -> 54d23599d


[SPARK-18120][SPARK-19557][SQL] Call QueryExecutionListener callback methods for DataFrameWriter methods

## What changes were proposed in this pull request?

Currently we only notify `QueryExecutionListener` for some `Dataset` operations, e.g. `collect` and `take`. We should also notify it for `DataFrameWriter` operations.

## How was this patch tested?

new regression test

close https://github.com/apache/spark/pull/16664

Author: Wenchen Fan <we...@databricks.com>

Closes #16962 from cloud-fan/insert.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/54d23599
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/54d23599
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/54d23599

Branch: refs/heads/master
Commit: 54d23599df7c28a7685416ced6ad8fcde047e534
Parents: 21fde57
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Feb 16 21:09:14 2017 -0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Feb 16 21:09:14 2017 -0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/DataFrameWriter.scala  | 49 ++++++++++++-----
 .../datasources/SaveIntoDataSourceCommand.scala | 52 ++++++++++++++++++
 .../spark/sql/util/DataFrameCallbackSuite.scala | 57 +++++++++++++++++++-
 3 files changed, 142 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/54d23599/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index cdae8ea..3939251 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -25,9 +25,9 @@ import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
 import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
 
@@ -211,13 +211,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     }
 
     assertNotBucketed("save")
-    val dataSource = DataSource(
-      df.sparkSession,
-      className = source,
-      partitionColumns = partitioningColumns.getOrElse(Nil),
-      options = extraOptions.toMap)
 
-    dataSource.write(mode, df)
+    runCommand(df.sparkSession, "save") {
+      SaveIntoDataSourceCommand(
+        query = df.logicalPlan,
+        provider = source,
+        partitionColumns = partitioningColumns.getOrElse(Nil),
+        options = extraOptions.toMap,
+        mode = mode)
+    }
   }
 
   /**
@@ -260,13 +262,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       )
     }
 
-    df.sparkSession.sessionState.executePlan(
+    runCommand(df.sparkSession, "insertInto") {
       InsertIntoTable(
         table = UnresolvedRelation(tableIdent),
         partition = Map.empty[String, Option[String]],
         query = df.logicalPlan,
         overwrite = mode == SaveMode.Overwrite,
-        ifNotExists = false)).toRdd
+        ifNotExists = false)
+    }
   }
 
   private def getBucketSpec: Option[BucketSpec] = {
@@ -389,10 +392,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       schema = new StructType,
       provider = Some(source),
       partitionColumnNames = partitioningColumns.getOrElse(Nil),
-      bucketSpec = getBucketSpec
-    )
-    df.sparkSession.sessionState.executePlan(
-      CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd
+      bucketSpec = getBucketSpec)
+
+    runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
   }
 
   /**
@@ -573,6 +575,25 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     format("csv").save(path)
   }
 
+  /**
+   * Runs `command` as the DataFrameWriter action called `name`, timing the execution and
+   * notifying the session's listener manager of the success or failure.
+   */
+  private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
+    val execution = session.sessionState.executePlan(command)
+    try {
+      val startNs = System.nanoTime()
+      // Forcing `QueryExecution.toRdd` is what actually triggers the command.
+      execution.toRdd
+      val elapsedNs = System.nanoTime() - startNs
+      session.listenerManager.onSuccess(name, execution, elapsedNs)
+    } catch {
+      case failure: Exception =>
+        session.listenerManager.onFailure(name, execution, failure)
+        throw failure
+    }
+  }
+
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
   ///////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/54d23599/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
new file mode 100644
index 0000000..6f19ea1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.command.RunnableCommand
+
+/**
+ * Writes the result of `query` out to the data source identified by `provider`.
+ *
+ * Note that this command is not the same as [[InsertIntoDataSourceCommand]]. This one writes
+ * through `CreatableRelationProvider.createRelation`, whereas [[InsertIntoDataSourceCommand]]
+ * goes through `InsertableRelation.insert`. Ideally both data source interfaces would behave
+ * identically, but they are already published and existing implementations may differ, so the
+ * two commands have to stay separate.
+ */
+case class SaveIntoDataSourceCommand(
+    query: LogicalPlan,
+    provider: String,
+    partitionColumns: Seq[String],
+    options: Map[String, String],
+    mode: SaveMode) extends RunnableCommand {
+
+  override protected def innerChildren: Seq[QueryPlan[_]] = query :: Nil
+
+  override def run(sparkSession: SparkSession): Seq[Row] = {
+    val target = DataSource(
+      sparkSession,
+      className = provider,
+      partitionColumns = partitionColumns,
+      options = options)
+    target.write(mode, Dataset.ofRows(sparkSession, query))
+    Nil
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/54d23599/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 3ae5ce6..9f27d06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.util
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
-import org.apache.spark.sql.{functions, QueryTest}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.{functions, AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project}
 import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.datasources.{CreateTable, SaveIntoDataSourceCommand}
 import org.apache.spark.sql.test.SharedSQLContext
 
 class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
@@ -159,4 +161,55 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     spark.listenerManager.unregister(listener)
   }
+
+  test("execute callback functions for DataFrameWriter") {
+    // Collect what the listener is told: (action name, logical plan) for each success and
+    // (action name, exception) for each failure.
+    val commands = ArrayBuffer.empty[(String, LogicalPlan)]
+    val exceptions = ArrayBuffer.empty[(String, Exception)]
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        exceptions += funcName -> exception
+      }
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        commands += funcName -> qe.logical
+      }
+    }
+    spark.listenerManager.register(listener)
+
+    // Unregister in `finally` so the listener is removed from the suite's `spark` session even
+    // when an assertion below fails; otherwise it would keep receiving events from later tests.
+    try {
+      // `save` should be reported with a SaveIntoDataSourceCommand plan.
+      withTempPath { path =>
+        spark.range(10).write.format("json").save(path.getCanonicalPath)
+        assert(commands.length == 1)
+        assert(commands.head._1 == "save")
+        assert(commands.head._2.isInstanceOf[SaveIntoDataSourceCommand])
+        assert(commands.head._2.asInstanceOf[SaveIntoDataSourceCommand].provider == "json")
+      }
+
+      // `insertInto` should be reported with an InsertIntoTable plan targeting the table.
+      withTable("tab") {
+        sql("CREATE TABLE tab(i long) using parquet")
+        spark.range(10).write.insertInto("tab")
+        assert(commands.length == 2)
+        assert(commands(1)._1 == "insertInto")
+        assert(commands(1)._2.isInstanceOf[InsertIntoTable])
+        assert(commands(1)._2.asInstanceOf[InsertIntoTable].table
+          .asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab")
+      }
+
+      // `saveAsTable` should be reported with a CreateTable plan carrying the partition columns.
+      withTable("tab") {
+        spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
+        assert(commands.length == 3)
+        assert(commands(2)._1 == "saveAsTable")
+        assert(commands(2)._2.isInstanceOf[CreateTable])
+        assert(commands(2)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p"))
+      }
+
+      // A failing write should be reported through `onFailure` with the thrown exception.
+      withTable("tab") {
+        sql("CREATE TABLE tab(i long) using parquet")
+        val e = intercept[AnalysisException] {
+          spark.range(10).select($"id", $"id").write.insertInto("tab")
+        }
+        assert(exceptions.length == 1)
+        assert(exceptions.head._1 == "insertInto")
+        assert(exceptions.head._2 == e)
+      }
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org