You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/06/26 18:36:03 UTC

spark git commit: [SPARK-20213][SQL][FOLLOW-UP] introduce SQLExecution.ignoreNestedExecutionId

Repository: spark
Updated Branches:
  refs/heads/master 9e50a1d37 -> c22810004


[SPARK-20213][SQL][FOLLOW-UP] introduce SQLExecution.ignoreNestedExecutionId

## What changes were proposed in this pull request?

In https://github.com/apache/spark/pull/18064, to work around the nested SQL execution id issue, we introduced several internal methods in `Dataset`, such as `collectInternal`, `countInternal`, and `showInternal`, to avoid setting a nested execution id.

However, this approach has poor extensibility: whenever we hit another nested execution id case, we may need to add yet more internal methods to `Dataset`.

Our goal is to ignore the nested execution id in some cases, and there is a better way to achieve this: introducing `SQLExecution.ignoreNestedExecutionId`. Whenever we find a place that needs to ignore the nested execution, we can simply wrap the action with `SQLExecution.ignoreNestedExecutionId`, which is more extensible than the previous approach.

The idea comes from https://github.com/apache/spark/pull/17540/files#diff-ab49028253e599e6e74cc4f4dcb2e3a8R57 by rdblue

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <we...@databricks.com>

Closes #18419 from cloud-fan/follow.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c2281000
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c2281000
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c2281000

Branch: refs/heads/master
Commit: c22810004fb2db249be6477c9801d09b807af851
Parents: 9e50a1d
Author: Wenchen Fan <we...@databricks.com>
Authored: Tue Jun 27 02:35:51 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Jun 27 02:35:51 2017 +0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala    | 39 ++------------------
 .../spark/sql/execution/SQLExecution.scala      | 39 ++++++++++++++++++--
 .../execution/command/AnalyzeTableCommand.scala |  5 ++-
 .../spark/sql/execution/command/cache.scala     | 19 +++++-----
 .../datasources/csv/CSVDataSource.scala         |  6 ++-
 .../datasources/jdbc/JDBCRelation.scala         | 14 +++----
 .../spark/sql/execution/streaming/console.scala | 13 +++++--
 .../spark/sql/execution/streaming/memory.scala  | 33 +++++++++--------
 8 files changed, 89 insertions(+), 79 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6e66e92..268a37f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -246,13 +246,8 @@ class Dataset[T] private[sql](
       _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
     val numRows = _numRows.max(0)
     val takeResult = toDF().take(numRows + 1)
-    showString(takeResult, numRows, truncate, vertical)
-  }
-
-  private def showString(
-      dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = {
-    val hasMoreData = dataWithOneMoreRow.length > numRows
-    val data = dataWithOneMoreRow.take(numRows)
+    val hasMoreData = takeResult.length > numRows
+    val data = takeResult.take(numRows)
 
     lazy val timeZone =
       DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
@@ -688,19 +683,6 @@ class Dataset[T] private[sql](
     println(showString(numRows, truncate = 0))
   }
 
-  // An internal version of `show`, which won't set execution id and trigger listeners.
-  private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = {
-    val numRows = _numRows.max(0)
-    val takeResult = toDF().takeInternal(numRows + 1)
-
-    if (truncate) {
-      println(showString(takeResult, numRows, truncate = 20, vertical = false))
-    } else {
-      println(showString(takeResult, numRows, truncate = 0, vertical = false))
-    }
-  }
-  // scalastyle:on println
-
   /**
    * Displays the Dataset in a tabular form. For example:
    * {{{
@@ -2467,11 +2449,6 @@ class Dataset[T] private[sql](
    */
   def take(n: Int): Array[T] = head(n)
 
-  // An internal version of `take`, which won't set execution id and trigger listeners.
-  private[sql] def takeInternal(n: Int): Array[T] = {
-    collectFromPlan(limit(n).queryExecution.executedPlan)
-  }
-
   /**
    * Returns the first `n` rows in the Dataset as a list.
    *
@@ -2496,11 +2473,6 @@ class Dataset[T] private[sql](
    */
   def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)
 
-  // An internal version of `collect`, which won't set execution id and trigger listeners.
-  private[sql] def collectInternal(): Array[T] = {
-    collectFromPlan(queryExecution.executedPlan)
-  }
-
   /**
    * Returns a Java list that contains all rows in this Dataset.
    *
@@ -2542,11 +2514,6 @@ class Dataset[T] private[sql](
     plan.executeCollect().head.getLong(0)
   }
 
-  // An internal version of `count`, which won't set execution id and trigger listeners.
-  private[sql] def countInternal(): Long = {
-    groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0)
-  }
-
   /**
    * Returns a new Dataset that has exactly `numPartitions` partitions.
    *
@@ -2792,7 +2759,7 @@ class Dataset[T] private[sql](
     createTempViewCommand(viewName, replace = true, global = true)
   }
 
-  private[spark] def createTempViewCommand(
+  private def createTempViewCommand(
       viewName: String,
       replace: Boolean,
       global: Boolean): CreateViewCommand = {

http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index bb206e8..ca8bed5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -29,6 +29,8 @@ object SQLExecution {
 
   val EXECUTION_ID_KEY = "spark.sql.execution.id"
 
+  private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"
+
   private val _nextExecutionId = new AtomicLong(0)
 
   private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -42,8 +44,11 @@ object SQLExecution {
   private val testing = sys.props.contains("spark.testing")
 
   private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
+    val sc = sparkSession.sparkContext
+    val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null
+    val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null
     // only throw an exception during tests. a missing execution ID should not fail a job.
-    if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) {
+    if (testing && !isNestedExecution && !hasExecutionId) {
       // Attention testers: when a test fails with this exception, it means that the action that
       // started execution of a query didn't call withNewExecutionId. The execution ID should be
       // set by calling withNewExecutionId in the action that begins execution, like
@@ -65,7 +70,7 @@ object SQLExecution {
       val executionId = SQLExecution.nextExecutionId
       sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
       executionIdToQueryExecution.put(executionId, queryExecution)
-      val r = try {
+      try {
         // sparkContext.getCallSite() would first try to pick up any call site that was previously
         // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
         // streaming queries would give us call site like "run at <unknown>:0"
@@ -84,7 +89,15 @@ object SQLExecution {
         executionIdToQueryExecution.remove(executionId)
         sc.setLocalProperty(EXECUTION_ID_KEY, null)
       }
-      r
+    } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
+      // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the
+      // `body`, so that Spark jobs issued in the `body` won't be tracked.
+      try {
+        sc.setLocalProperty(EXECUTION_ID_KEY, null)
+        body
+      } finally {
+        sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
+      }
     } else {
       // Don't support nested `withNewExecutionId`. This is an example of the nested
       // `withNewExecutionId`:
@@ -100,7 +113,9 @@ object SQLExecution {
       // all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
       //
       // A real case is the `DataFrame.count` method.
-      throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set")
+      throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
+        "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
+        "jobs issued by the nested execution.")
     }
   }
 
@@ -118,4 +133,20 @@ object SQLExecution {
       sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
     }
   }
+
+  /**
+   * Wrap an action which may have nested execution id. This method can be used to run an execution
+   * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that,
+   * all Spark jobs issued in the body won't be tracked in UI.
+   */
+  def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
+    val sc = sparkSession.sparkContext
+    val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
+    try {
+      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
+      body
+    } finally {
+      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
index 06e588f..13b8faf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.internal.SessionState
 
 
@@ -58,7 +59,9 @@ case class AnalyzeTableCommand(
     // 2. when total size is changed, `oldRowCount` becomes invalid.
     // This is to make sure that we only record the right statistics.
     if (!noscan) {
-      val newRowCount = sparkSession.table(tableIdentWithDB).countInternal()
+      val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
+        sparkSession.table(tableIdentWithDB).count()
+      }
       if (newRowCount >= 0 && newRowCount != oldRowCount) {
         newStats = if (newStats.isDefined) {
           newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))

http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index 184d038..d36eb75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SQLExecution
 
 case class CacheTableCommand(
     tableIdent: TableIdentifier,
@@ -33,16 +34,16 @@ case class CacheTableCommand(
   override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
-    plan.foreach { logicalPlan =>
-      Dataset.ofRows(sparkSession, logicalPlan)
-        .createTempViewCommand(tableIdent.quotedString, replace = false, global = false)
-        .run(sparkSession)
-    }
-    sparkSession.catalog.cacheTable(tableIdent.quotedString)
+    SQLExecution.ignoreNestedExecutionId(sparkSession) {
+      plan.foreach { logicalPlan =>
+        Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
+      }
+      sparkSession.catalog.cacheTable(tableIdent.quotedString)
 
-    if (!isLazy) {
-      // Performs eager caching
-      sparkSession.table(tableIdent).countInternal()
+      if (!isLazy) {
+        // Performs eager caching
+        sparkSession.table(tableIdent).count()
+      }
     }
 
     Seq.empty[Row]

http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index eadc6c9..99133bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -32,6 +32,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -144,8 +145,9 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    val maybeFirstLine =
-      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption
+    val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
+      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
+    }
     inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index a06f1ce..b11da70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.Partition
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
@@ -129,14 +130,11 @@ private[sql] case class JDBCRelation(
   }
 
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
-    import scala.collection.JavaConverters._
-
-    val options = jdbcOptions.asProperties.asScala +
-      ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table)
-    val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
-
-    new JdbcRelationProvider().createRelation(
-      data.sparkSession.sqlContext, mode, options.toMap, data)
+    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+      data.write
+        .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
+        .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
+    }
   }
 
   override def toString: String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
index 9e889ff..6fa7c11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.types.StructType
 
 class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
@@ -47,9 +48,11 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
     println(batchIdStr)
     println("-------------------------------------------")
     // scalastyle:off println
-    data.sparkSession.createDataFrame(
-      data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema)
-      .showInternal(numRowsToShow, isTruncated)
+    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+      data.sparkSession.createDataFrame(
+        data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
+        .show(numRowsToShow, isTruncated)
+    }
   }
 }
 
@@ -79,7 +82,9 @@ class ConsoleSinkProvider extends StreamSinkProvider
 
     // Truncate the displayed data if it is too long, by default it is true
     val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
-    data.showInternal(numRowsToShow, isTruncated)
+    SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
+      data.show(numRowsToShow, isTruncated)
+    }
 
     ConsoleRelation(sqlContext, data)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index a5dac46..198a342 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -193,21 +194,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
     }
     if (notCommitted) {
       logDebug(s"Committing batch $batchId to $this")
-      outputMode match {
-        case Append | Update =>
-          val rows = AddedData(batchId, data.collectInternal())
-          synchronized { batches += rows }
-
-        case Complete =>
-          val rows = AddedData(batchId, data.collectInternal())
-          synchronized {
-            batches.clear()
-            batches += rows
-          }
-
-        case _ =>
-          throw new IllegalArgumentException(
-            s"Output mode $outputMode is not supported by MemorySink")
+      SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+        outputMode match {
+          case Append | Update =>
+            val rows = AddedData(batchId, data.collect())
+            synchronized { batches += rows }
+
+          case Complete =>
+            val rows = AddedData(batchId, data.collect())
+            synchronized {
+              batches.clear()
+              batches += rows
+            }
+
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Output mode $outputMode is not supported by MemorySink")
+        }
       }
     } else {
       logDebug(s"Skipping already committed batch: $batchId")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org