Posted to commits@spark.apache.org by rx...@apache.org on 2016/12/20 23:51:24 UTC

spark git commit: [SPARK-18576][PYTHON] Add basic TaskContext information to PySpark

Repository: spark
Updated Branches:
  refs/heads/master caed89321 -> 047a9d92c


[SPARK-18576][PYTHON] Add basic TaskContext information to PySpark

## What changes were proposed in this pull request?

Adds basic TaskContext information (stage ID, partition ID, attempt number, and task attempt ID) to PySpark via a new `TaskContext` class available on the workers.
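
For illustration only, a minimal sketch of how the new API can be used from worker-side code (this example is not part of the patch; it assumes a local `SparkContext`, and the function passed to `map` runs on the executors, where `TaskContext.get()` returns a populated context):

    from pyspark import SparkContext, TaskContext

    sc = SparkContext("local[2]", "taskcontext-demo")

    def describe(x):
        # On a worker, TaskContext.get() returns the per-task singleton;
        # on the driver it returns None.
        tc = TaskContext.get()
        return (x, tc.stageId(), tc.partitionId(),
                tc.attemptNumber(), tc.taskAttemptId())

    print(sc.parallelize(range(4), 2).map(describe).collect())
    sc.stop()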

## How was this patch tested?

New unit tests added to `tests.py`, plus the existing unit tests.

Author: Holden Karau <ho...@us.ibm.com>

Closes #16211 from holdenk/SPARK-18576-pyspark-taskcontext.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/047a9d92
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/047a9d92
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/047a9d92

Branch: refs/heads/master
Commit: 047a9d92caa1f0af2305e4afeba8339abf32518b
Parents: caed893
Author: Holden Karau <ho...@us.ibm.com>
Authored: Tue Dec 20 15:51:21 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Dec 20 15:51:21 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala |  5 ++
 python/pyspark/__init__.py                      |  5 +-
 python/pyspark/taskcontext.py                   | 90 ++++++++++++++++++++
 python/pyspark/tests.py                         | 65 ++++++++++++++
 python/pyspark/worker.py                        |  6 ++
 5 files changed, 170 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/047a9d92/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0ca91b9..04ae97e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -275,6 +275,11 @@ private[spark] class PythonRunner(
         dataOut.writeInt(partitionIndex)
         // Python version of driver
         PythonRDD.writeUTF(pythonVer, dataOut)
+        // Write out the TaskContextInfo
+        dataOut.writeInt(context.stageId())
+        dataOut.writeInt(context.partitionId())
+        dataOut.writeInt(context.attemptNumber())
+        dataOut.writeLong(context.taskAttemptId())
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
         // Python includes (*.zip and *.egg files)
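
Note that these writes must line up, in the same order, with the reads added to python/pyspark/worker.py further down in this diff. As a rough, hedged illustration (not Spark code), the values travel as big-endian 4-byte ints followed by one 8-byte long, which is essentially what PySpark's read_int/read_long helpers unpack:

    import struct

    def read_int(stream):
        # Mirrors java.io.DataOutputStream.writeInt: 4-byte big-endian signed int.
        return struct.unpack("!i", stream.read(4))[0]

    def read_long(stream):
        # Mirrors java.io.DataOutputStream.writeLong: 8-byte big-endian signed long.
        return struct.unpack("!q", stream.read(8))[0]

    # Read order must mirror the write order above:
    # stageId, partitionId, attemptNumber (ints), then taskAttemptId (long).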

http://git-wip-us.apache.org/repos/asf/spark/blob/047a9d92/python/pyspark/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 5f93586..9331e74 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -34,6 +34,8 @@ Public classes:
       Access files shipped with jobs.
   - :class:`StorageLevel`:
       Finer-grained cache persistence levels.
+  - :class:`TaskContext`:
+      Information about the currently running task, available on the workers (experimental).
 
 """
 
@@ -49,6 +51,7 @@ from pyspark.accumulators import Accumulator, AccumulatorParam
 from pyspark.broadcast import Broadcast
 from pyspark.serializers import MarshalSerializer, PickleSerializer
 from pyspark.status import *
+from pyspark.taskcontext import TaskContext
 from pyspark.profiler import Profiler, BasicProfiler
 from pyspark.version import __version__
 
@@ -106,5 +109,5 @@ from pyspark.sql import SQLContext, HiveContext, Row
 __all__ = [
     "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
     "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
-    "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler",
+    "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
 ]

http://git-wip-us.apache.org/repos/asf/spark/blob/047a9d92/python/pyspark/taskcontext.py
----------------------------------------------------------------------
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
new file mode 100644
index 0000000..e5218d9
--- /dev/null
+++ b/python/pyspark/taskcontext.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+
+class TaskContext(object):
+
+    """
+    .. note:: Experimental
+
+    Contextual information about a task which can be read or mutated during
+    execution. To access the TaskContext for a running task, use:
+    L{TaskContext.get()}.
+    """
+
+    _taskContext = None
+
+    _attemptNumber = None
+    _partitionId = None
+    _stageId = None
+    _taskAttemptId = None
+
+    def __new__(cls):
+        """Even if users construct TaskContext instead of using get, give them the singleton."""
+        taskContext = cls._taskContext
+        if taskContext is not None:
+            return taskContext
+        cls._taskContext = taskContext = object.__new__(cls)
+        return taskContext
+
+    def __init__(self):
+        """Construct a TaskContext, use get instead"""
+        pass
+
+    @classmethod
+    def _getOrCreate(cls):
+        """Internal function to get or create global TaskContext."""
+        if cls._taskContext is None:
+            cls._taskContext = TaskContext()
+        return cls._taskContext
+
+    @classmethod
+    def get(cls):
+        """
+        Return the currently active TaskContext. This can be called inside of
+        user functions to access contextual information about running tasks.
+
+        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
+        """
+        return cls._taskContext
+
+    def stageId(self):
+        """The ID of the stage that this task belong to."""
+        return self._stageId
+
+    def partitionId(self):
+        """
+        The ID of the RDD partition that is computed by this task.
+        """
+        return self._partitionId
+
+    def attemptNumber(self):
+        """"
+        How many times this task has been attempted.  The first task attempt will be assigned
+        attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
+        """
+        return self._attemptNumber
+
+    def taskAttemptId(self):
+        """
+        An ID that is unique to this task attempt (within the same SparkContext, no two task
+        attempts will share the same attempt ID).  This is roughly equivalent to Hadoop's
+        TaskAttemptID.
+        """
+        return self._taskAttemptId
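
As a side note, a small sketch of the singleton behaviour that __new__ provides above (illustrative only, run outside of a task; every field stays None unless worker.py has populated the context):

    from pyspark.taskcontext import TaskContext

    a = TaskContext()               # __new__ caches and returns the singleton
    b = TaskContext()
    c = TaskContext._getOrCreate()  # internal helper used by worker.py
    assert a is b and b is c        # every construction path yields the same instance
    print(TaskContext.get() is a)   # True: get() returns the cached instance once one exists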

http://git-wip-us.apache.org/repos/asf/spark/blob/047a9d92/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index fe314c5..c383d9a 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -69,6 +69,7 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer,
 from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
 from pyspark import shuffle
 from pyspark.profiler import BasicProfiler
+from pyspark.taskcontext import TaskContext
 
 _have_scipy = False
 _have_numpy = False
@@ -478,6 +479,70 @@ class AddFileTests(PySparkTestCase):
         self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
 
 
+class TaskContextTests(PySparkTestCase):
+
+    def setUp(self):
+        self._old_sys_path = list(sys.path)
+        class_name = self.__class__.__name__
+        # Allow retries even though they are normally disabled in local mode
+        self.sc = SparkContext('local[4, 2]', class_name)
+
+    def test_stage_id(self):
+        """Test the stage ids are available and incrementing as expected."""
+        rdd = self.sc.parallelize(range(10))
+        stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+        stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+        # Test using the constructor directly rather than get()
+        stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
+        self.assertEqual(stage1 + 1, stage2)
+        self.assertEqual(stage1 + 2, stage3)
+        self.assertEqual(stage2 + 1, stage3)
+
+    def test_partition_id(self):
+        """Test the partition id."""
+        rdd1 = self.sc.parallelize(range(10), 1)
+        rdd2 = self.sc.parallelize(range(10), 2)
+        pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
+        pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
+        self.assertEqual(0, pids1[0])
+        self.assertEqual(0, pids1[9])
+        self.assertEqual(0, pids2[0])
+        self.assertEqual(1, pids2[9])
+
+    def test_attempt_number(self):
+        """Verify the attempt numbers are correctly reported."""
+        rdd = self.sc.parallelize(range(10))
+        # Verify a simple job with no failures
+        attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
+        self.assertEqual([0] * len(attempt_numbers), attempt_numbers)
+
+        def fail_on_first(x):
+            """Fail on the first attempt so we get a positive attempt number"""
+            tc = TaskContext.get()
+            attempt_number = tc.attemptNumber()
+            partition_id = tc.partitionId()
+            attempt_id = tc.taskAttemptId()
+            if attempt_number == 0 and partition_id == 0:
+                raise Exception("Failing on first attempt")
+            else:
+                return [x, partition_id, attempt_number, attempt_id]
+        result = rdd.map(fail_on_first).collect()
+        # The first partition should be re-submitted (attempt 1); other partitions stay at attempt 0
+        self.assertEqual([0, 0, 1], result[0][0:3])
+        self.assertEqual([9, 3, 0], result[9][0:3])
+        first_partition = [x for x in result if x[1] == 0]
+        self.assertTrue(all(x[2] == 1 for x in first_partition))
+        other_partitions = [x for x in result if x[1] != 0]
+        self.assertTrue(all(x[2] == 0 for x in other_partitions))
+        # The task attempt id should be different
+        self.assertTrue(result[0][3] != result[9][3])
+
+    def test_tc_on_driver(self):
+        """Verify that getting the TaskContext on the driver returns None."""
+        tc = TaskContext.get()
+        self.assertTrue(tc is None)
+
+
 class RDDTests(ReusedPySparkTestCase):
 
     def test_range(self):
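
The new TaskContextTests can also be run on their own with the standard unittest machinery; a hedged sketch (assumes a working PySpark environment, i.e. SPARK_HOME set and py4j on the PYTHONPATH):

    import unittest
    from pyspark.tests import TaskContextTests

    # Load and run only the TaskContext test cases added in this patch.
    suite = unittest.TestLoader().loadTestsFromTestCase(TaskContextTests)
    unittest.TextTestRunner(verbosity=2).run(suite)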

http://git-wip-us.apache.org/repos/asf/spark/blob/047a9d92/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0918282..25ee475 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,6 +27,7 @@ import traceback
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
@@ -125,6 +126,11 @@ def main(infile, outfile):
                             ("%d.%d" % sys.version_info[:2], version))
 
         # initialize global state
+        taskContext = TaskContext._getOrCreate()
+        taskContext._stageId = read_int(infile)
+        taskContext._partitionId = read_int(infile)
+        taskContext._attemptNumber = read_int(infile)
+        taskContext._taskAttemptId = read_long(infile)
         shuffle.MemoryBytesSpilled = 0
         shuffle.DiskBytesSpilled = 0
         _accumulatorRegistry.clear()

