Posted to commits@spark.apache.org by rx...@apache.org on 2017/02/09 18:07:13 UTC

spark git commit: [SPARK-19514] Making range interruptible.

Repository: spark
Updated Branches:
  refs/heads/master 3fc8e8caf -> 4064574d0


[SPARK-19514] Making range interruptible.

## What changes were proposed in this pull request?

Previously the range operator could not be interrupted. For example, using DAGScheduler.cancelStage(...) on a query containing a range might have been ineffective.

This change adds periodic checks of TaskContext.isInterrupted to the codegen version, and wraps the non-codegen version's iterator in an InterruptibleIterator.

I benchmarked the performance of the codegen version on the sample query `spark.range(1000L * 1000 * 1000 * 10).count()` and found no measurable difference.
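
For reference, a cancellation like the one this patch makes effective could be triggered from the driver roughly as follows. This is a minimal sketch, not part of the patch; it assumes a running `SparkSession` named `spark` and uses the standard job-group API:

    import scala.concurrent.Future
    import scala.concurrent.ExecutionContext.Implicits.global

    // Run the long query on a separate thread, tagged with a job group.
    val query = Future {
      spark.sparkContext.setJobGroup("range-demo", "interruptible range",
        interruptOnCancel = true)
      spark.range(1000L * 1000 * 1000 * 10).count()
    }

    Thread.sleep(1000)  // give some tasks time to start
    // Cancelling the group marks running tasks as interrupted; the range
    // operator now observes this via TaskContext.isInterrupted.
    spark.sparkContext.cancelJobGroup("range-demo")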

## How was this patch tested?

Adds a unit test.

Author: Ala Luszczak <al...@databricks.com>

Closes #16872 from ala/SPARK-19514b.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4064574d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4064574d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4064574d

Branch: refs/heads/master
Commit: 4064574d031215fcfdf899a1ee9f3b6fecb1bfc9
Parents: 3fc8e8c
Author: Ala Luszczak <al...@databricks.com>
Authored: Thu Feb 9 19:07:06 2017 +0100
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Feb 9 19:07:06 2017 +0100

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     |  8 +++--
 .../sql/execution/basicPhysicalOperators.scala  | 12 +++++--
 .../apache/spark/sql/DataFrameRangeSuite.scala  | 38 +++++++++++++++++++-
 3 files changed, 52 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4064574d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 04b812e..374d714 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -23,14 +23,14 @@ import java.util.{Map => JavaMap}
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.language.existentials
 import scala.util.control.NonFatal
 
 import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler}
 import org.codehaus.janino.util.ClassFile
-import scala.language.existentials
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
@@ -933,7 +933,9 @@ object CodeGenerator extends Logging {
       classOf[UnsafeArrayData].getName,
       classOf[MapData].getName,
       classOf[UnsafeMapData].getName,
-      classOf[Expression].getName
+      classOf[Expression].getName,
+      classOf[TaskContext].getName,
+      classOf[TaskKilledException].getName
     ))
     evaluator.setExtendedClass(classOf[GeneratedClass])
 

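The two classes added to the import list above become visible to Janino-generated code by their simple names. A minimal standalone sketch of that mechanism, assuming only the Janino dependency (illustrative; this is not Spark's actual generated code):

    import org.codehaus.janino.ClassBodyEvaluator

    val evaluator = new ClassBodyEvaluator()
    // Default imports let the generated Java body use simple class names.
    evaluator.setDefaultImports(Array(
      "org.apache.spark.TaskContext",
      "org.apache.spark.TaskKilledException"))
    evaluator.cook(
      """public void checkKilled() {
        |  TaskContext ctx = TaskContext.get();  // resolves via default import
        |  if (ctx != null && ctx.isInterrupted()) {
        |    throw new TaskKilledException();
        |  }
        |}""".stripMargin)
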
http://git-wip-us.apache.org/repos/asf/spark/blob/4064574d/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 792fb3e..649c21b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
 
-import org.apache.spark.SparkException
+import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext}
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -363,6 +363,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     val ev = ExprCode("", "false", value)
     val BigInt = classOf[java.math.BigInteger].getName
 
+    val taskContext = ctx.freshName("taskContext")
+    ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();")
+
     // In order to periodically update the metrics without inflicting performance penalty, this
     // operator produces elements in batches. After a batch is complete, the metrics are updated
     // and a new batch is started.
@@ -443,6 +446,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |     if (shouldStop()) return;
       |   }
       |
+      |   if ($taskContext.isInterrupted()) {
+      |     throw new TaskKilledException();
+      |   }
+      |
       |   long $nextBatchTodo;
       |   if ($numElementsTodo > ${batchSize}L) {
       |     $nextBatchTodo = ${batchSize}L;
@@ -482,7 +489,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
         val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
         val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
 
-        new Iterator[InternalRow] {
+        val iter = new Iterator[InternalRow] {
           private[this] var number: Long = safePartitionStart
           private[this] var overflow: Boolean = false
 
@@ -511,6 +518,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
             unsafeRow
           }
         }
+        new InterruptibleIterator(TaskContext.get(), iter)
       }
   }
 

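Taken together, the codegen changes above amount to a batched loop with one interruption check per batch, so the per-row cost stays negligible. A hand-simplified Scala rendering of that loop (names are illustrative; metric updates and unsafe-row plumbing are elided):

    import org.apache.spark.{TaskContext, TaskKilledException}

    def produceRange(start: Long, numElements: Long, batchSize: Long)
                    (emit: Long => Unit): Unit = {
      // Assumes we are inside a running task, so TaskContext.get() is non-null.
      val taskContext = TaskContext.get()
      var number = start
      var numElementsTodo = numElements
      while (numElementsTodo > 0) {
        // One check per batch, mirroring the generated code above.
        if (taskContext.isInterrupted()) {
          throw new TaskKilledException()
        }
        val nextBatchTodo = math.min(numElementsTodo, batchSize)
        var i = 0L
        while (i < nextBatchTodo) {
          emit(number)
          number += 1
          i += 1
        }
        numElementsTodo -= nextBatchTodo
      }
    }
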
http://git-wip-us.apache.org/repos/asf/spark/blob/4064574d/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
index 6d2d776..3ebfd9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -17,14 +17,20 @@
 
 package org.apache.spark.sql
 
+import scala.concurrent.duration._
 import scala.math.abs
 import scala.util.Random
 
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.SparkException
+import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerTaskStart}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 
-class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
+
+class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
 
   test("SPARK-7150 range api") {
     // numSlice is greater than length
@@ -127,4 +133,34 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("Cancelling stage in a query with Range.") {
+    val listener = new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+        Thread.sleep(100)
+        sparkContext.cancelStage(taskStart.stageId)
+      }
+    }
+
+    sparkContext.addSparkListener(listener)
+    for (codegen <- Seq(true, false)) {
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+        val ex = intercept[SparkException] {
+          spark.range(100000L).crossJoin(spark.range(100000L))
+            .toDF("a", "b").agg(sum("a"), sum("b")).collect()
+        }
+        ex.getCause() match {
+          case null =>
+            assert(ex.getMessage().contains("cancelled"))
+          case cause: SparkException =>
+            assert(cause.getMessage().contains("cancelled"))
+          case cause: Throwable =>
+            fail("Expected the casue to be SparkException, got " + cause.toString() + " instead.")
+        }
+      }
+      eventually(timeout(20.seconds)) {
+        assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+      }
+    }
+  }
 }

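The test expects cancellation to surface at the driver as a SparkException, either directly or as the cause, with "cancelled" in the message, and then waits until no tasks are left running. On the task side, the non-codegen path produces this through the InterruptibleIterator wrapper added above; its contract is essentially the following (a simplified sketch, not the exact org.apache.spark.InterruptibleIterator source):

    import org.apache.spark.{TaskContext, TaskKilledException}

    class InterruptibleIteratorSketch[T](context: TaskContext, delegate: Iterator[T])
        extends Iterator[T] {
      override def hasNext: Boolean = {
        // Every hasNext() consults the task's interruption flag before delegating.
        if (context.isInterrupted()) {
          throw new TaskKilledException()
        }
        delegate.hasNext
      }
      override def next(): T = delegate.next()
    }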
