Posted to commits@spark.apache.org by we...@apache.org on 2017/06/26 18:36:03 UTC
spark git commit: [SPARK-20213][SQL][FOLLOW-UP] introduce SQLExecution.ignoreNestedExecutionId
Repository: spark
Updated Branches:
refs/heads/master 9e50a1d37 -> c22810004
[SPARK-20213][SQL][FOLLOW-UP] introduce SQLExecution.ignoreNestedExecutionId
## What changes were proposed in this pull request?
In https://github.com/apache/spark/pull/18064, to work around the nested SQL execution id issue, we introduced several internal methods in `Dataset`, such as `collectInternal`, `countInternal`, and `showInternal`, to avoid nested execution ids.
However, this approach does not extend well: each new nested execution id case may require yet another internal method in `Dataset`.
Our goal is to ignore the nested execution id in certain cases, and we can achieve it more cleanly by introducing `SQLExecution.ignoreNestedExecutionId`. Wherever we find a place that needs to ignore the nested execution, we simply wrap the action with `SQLExecution.ignoreNestedExecutionId` (see the sketch below); this extends to new cases much more easily than the previous approach.
The idea comes from https://github.com/apache/spark/pull/17540/files#diff-ab49028253e599e6e74cc4f4dcb2e3a8R57 by rdblue.
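As an illustration, here is a minimal sketch of the usage pattern, modeled on the `AnalyzeTableCommand` change in the diff below. The `eagerRowCount` helper is hypothetical; only `SQLExecution.ignoreNestedExecutionId` itself is introduced by this commit:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

// Hypothetical helper: we are already inside a SQL execution (e.g. while
// running a command), but still need to call an action such as `count()`,
// which would otherwise trip the "execution id is already set" check.
def eagerRowCount(sparkSession: SparkSession, tableName: String): Long = {
  SQLExecution.ignoreNestedExecutionId(sparkSession) {
    // Spark jobs issued inside this block carry no execution id, so they
    // are not tracked in the SQL tab of the web UI.
    sparkSession.table(tableName).count()
  }
}
```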
## How was this patch tested?
Existing tests.
Author: Wenchen Fan <we...@databricks.com>
Closes #18419 from cloud-fan/follow.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c2281000
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c2281000
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c2281000
Branch: refs/heads/master
Commit: c22810004fb2db249be6477c9801d09b807af851
Parents: 9e50a1d
Author: Wenchen Fan <we...@databricks.com>
Authored: Tue Jun 27 02:35:51 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Jun 27 02:35:51 2017 +0800
----------------------------------------------------------------------
.../scala/org/apache/spark/sql/Dataset.scala | 39 ++------------------
.../spark/sql/execution/SQLExecution.scala | 39 ++++++++++++++++++--
.../execution/command/AnalyzeTableCommand.scala | 5 ++-
.../spark/sql/execution/command/cache.scala | 19 +++++-----
.../datasources/csv/CSVDataSource.scala | 6 ++-
.../datasources/jdbc/JDBCRelation.scala | 14 +++----
.../spark/sql/execution/streaming/console.scala | 13 +++++--
.../spark/sql/execution/streaming/memory.scala | 33 +++++++++--------
8 files changed, 89 insertions(+), 79 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6e66e92..268a37f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -246,13 +246,8 @@ class Dataset[T] private[sql](
_numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
val numRows = _numRows.max(0)
val takeResult = toDF().take(numRows + 1)
- showString(takeResult, numRows, truncate, vertical)
- }
-
- private def showString(
- dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = {
- val hasMoreData = dataWithOneMoreRow.length > numRows
- val data = dataWithOneMoreRow.take(numRows)
+ val hasMoreData = takeResult.length > numRows
+ val data = takeResult.take(numRows)
lazy val timeZone =
DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
@@ -688,19 +683,6 @@ class Dataset[T] private[sql](
println(showString(numRows, truncate = 0))
}
- // An internal version of `show`, which won't set execution id and trigger listeners.
- private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = {
- val numRows = _numRows.max(0)
- val takeResult = toDF().takeInternal(numRows + 1)
-
- if (truncate) {
- println(showString(takeResult, numRows, truncate = 20, vertical = false))
- } else {
- println(showString(takeResult, numRows, truncate = 0, vertical = false))
- }
- }
- // scalastyle:on println
-
/**
* Displays the Dataset in a tabular form. For example:
* {{{
@@ -2467,11 +2449,6 @@ class Dataset[T] private[sql](
*/
def take(n: Int): Array[T] = head(n)
- // An internal version of `take`, which won't set execution id and trigger listeners.
- private[sql] def takeInternal(n: Int): Array[T] = {
- collectFromPlan(limit(n).queryExecution.executedPlan)
- }
-
/**
* Returns the first `n` rows in the Dataset as a list.
*
@@ -2496,11 +2473,6 @@ class Dataset[T] private[sql](
*/
def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)
- // An internal version of `collect`, which won't set execution id and trigger listeners.
- private[sql] def collectInternal(): Array[T] = {
- collectFromPlan(queryExecution.executedPlan)
- }
-
/**
* Returns a Java list that contains all rows in this Dataset.
*
@@ -2542,11 +2514,6 @@ class Dataset[T] private[sql](
plan.executeCollect().head.getLong(0)
}
- // An internal version of `count`, which won't set execution id and trigger listeners.
- private[sql] def countInternal(): Long = {
- groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0)
- }
-
/**
* Returns a new Dataset that has exactly `numPartitions` partitions.
*
@@ -2792,7 +2759,7 @@ class Dataset[T] private[sql](
createTempViewCommand(viewName, replace = true, global = true)
}
- private[spark] def createTempViewCommand(
+ private def createTempViewCommand(
viewName: String,
replace: Boolean,
global: Boolean): CreateViewCommand = {
http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index bb206e8..ca8bed5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -29,6 +29,8 @@ object SQLExecution {
val EXECUTION_ID_KEY = "spark.sql.execution.id"
+ private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"
+
private val _nextExecutionId = new AtomicLong(0)
private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -42,8 +44,11 @@ object SQLExecution {
private val testing = sys.props.contains("spark.testing")
private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
+ val sc = sparkSession.sparkContext
+ val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null
+ val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null
// only throw an exception during tests. a missing execution ID should not fail a job.
- if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) {
+ if (testing && !isNestedExecution && !hasExecutionId) {
// Attention testers: when a test fails with this exception, it means that the action that
// started execution of a query didn't call withNewExecutionId. The execution ID should be
// set by calling withNewExecutionId in the action that begins execution, like
@@ -65,7 +70,7 @@ object SQLExecution {
val executionId = SQLExecution.nextExecutionId
sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
executionIdToQueryExecution.put(executionId, queryExecution)
- val r = try {
+ try {
// sparkContext.getCallSite() would first try to pick up any call site that was previously
// set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
// streaming queries would give us call site like "run at <unknown>:0"
@@ -84,7 +89,15 @@ object SQLExecution {
executionIdToQueryExecution.remove(executionId)
sc.setLocalProperty(EXECUTION_ID_KEY, null)
}
- r
+ } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
+ // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the
+ // `body`, so that Spark jobs issued in the `body` won't be tracked.
+ try {
+ sc.setLocalProperty(EXECUTION_ID_KEY, null)
+ body
+ } finally {
+ sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
+ }
} else {
// Don't support nested `withNewExecutionId`. This is an example of the nested
// `withNewExecutionId`:
@@ -100,7 +113,9 @@ object SQLExecution {
// all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
//
// A real case is the `DataFrame.count` method.
- throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set")
+ throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
+ "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
+ "jobs issued by the nested execution.")
}
}
@@ -118,4 +133,20 @@ object SQLExecution {
sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
}
}
+
+ /**
+ * Wrap an action which may have nested execution id. This method can be used to run an execution
+ * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that,
+ * all Spark jobs issued in the body won't be tracked in UI.
+ */
+ def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
+ val sc = sparkSession.sparkContext
+ val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
+ try {
+ sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
+ body
+ } finally {
+ sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
index 06e588f..13b8faf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.internal.SessionState
@@ -58,7 +59,9 @@ case class AnalyzeTableCommand(
// 2. when total size is changed, `oldRowCount` becomes invalid.
// This is to make sure that we only record the right statistics.
if (!noscan) {
- val newRowCount = sparkSession.table(tableIdentWithDB).countInternal()
+ val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
+ sparkSession.table(tableIdentWithDB).count()
+ }
if (newRowCount >= 0 && newRowCount != oldRowCount) {
newStats = if (newStats.isDefined) {
newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))
http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index 184d038..d36eb75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SQLExecution
case class CacheTableCommand(
tableIdent: TableIdentifier,
@@ -33,16 +34,16 @@ case class CacheTableCommand(
override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
override def run(sparkSession: SparkSession): Seq[Row] = {
- plan.foreach { logicalPlan =>
- Dataset.ofRows(sparkSession, logicalPlan)
- .createTempViewCommand(tableIdent.quotedString, replace = false, global = false)
- .run(sparkSession)
- }
- sparkSession.catalog.cacheTable(tableIdent.quotedString)
+ SQLExecution.ignoreNestedExecutionId(sparkSession) {
+ plan.foreach { logicalPlan =>
+ Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
+ }
+ sparkSession.catalog.cacheTable(tableIdent.quotedString)
- if (!isLazy) {
- // Performs eager caching
- sparkSession.table(tableIdent).countInternal()
+ if (!isLazy) {
+ // Performs eager caching
+ sparkSession.table(tableIdent).count()
+ }
}
Seq.empty[Row]
http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index eadc6c9..99133bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -32,6 +32,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
@@ -144,8 +145,9 @@ object TextInputCSVDataSource extends CSVDataSource {
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): StructType = {
val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
- val maybeFirstLine =
- CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption
+ val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
+ CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
+ }
inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index a06f1ce..b11da70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -129,14 +130,11 @@ private[sql] case class JDBCRelation(
}
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
- import scala.collection.JavaConverters._
-
- val options = jdbcOptions.asProperties.asScala +
- ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table)
- val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
-
- new JdbcRelationProvider().createRelation(
- data.sparkSession.sqlContext, mode, options.toMap, data)
+ SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+ data.write
+ .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
+ .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
+ }
}
override def toString: String = {
http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
index 9e889ff..6fa7c11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.types.StructType
class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
@@ -47,9 +48,11 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
println(batchIdStr)
println("-------------------------------------------")
// scalastyle:off println
- data.sparkSession.createDataFrame(
- data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema)
- .showInternal(numRowsToShow, isTruncated)
+ SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+ data.sparkSession.createDataFrame(
+ data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
+ .show(numRowsToShow, isTruncated)
+ }
}
}
@@ -79,7 +82,9 @@ class ConsoleSinkProvider extends StreamSinkProvider
// Truncate the displayed data if it is too long, by default it is true
val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
- data.showInternal(numRowsToShow, isTruncated)
+ SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
+ data.show(numRowsToShow, isTruncated)
+ }
ConsoleRelation(sqlContext, data)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c2281000/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index a5dac46..198a342 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -193,21 +194,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
}
if (notCommitted) {
logDebug(s"Committing batch $batchId to $this")
- outputMode match {
- case Append | Update =>
- val rows = AddedData(batchId, data.collectInternal())
- synchronized { batches += rows }
-
- case Complete =>
- val rows = AddedData(batchId, data.collectInternal())
- synchronized {
- batches.clear()
- batches += rows
- }
-
- case _ =>
- throw new IllegalArgumentException(
- s"Output mode $outputMode is not supported by MemorySink")
+ SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
+ outputMode match {
+ case Append | Update =>
+ val rows = AddedData(batchId, data.collect())
+ synchronized { batches += rows }
+
+ case Complete =>
+ val rows = AddedData(batchId, data.collect())
+ synchronized {
+ batches.clear()
+ batches += rows
+ }
+
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Output mode $outputMode is not supported by MemorySink")
+ }
}
} else {
logDebug(s"Skipping already committed batch: $batchId")