Posted to commits@spark.apache.org by gu...@apache.org on 2019/07/11 00:33:25 UTC

[spark] branch master updated: [SPARK-28234][CORE][PYTHON] Add python and JavaSparkContext support to get resources

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f84cca2  [SPARK-28234][CORE][PYTHON] Add python and JavaSparkContext support to get resources
f84cca2 is described below

commit f84cca2d84d67cee2877092d0354cf111c95eb8e
Author: Thomas Graves <tg...@nvidia.com>
AuthorDate: Thu Jul 11 09:32:58 2019 +0900

    [SPARK-28234][CORE][PYTHON] Add python and JavaSparkContext support to get resources
    
    ## What changes were proposed in this pull request?
    
    Add Python API support and JavaSparkContext support for resources(). The JavaSparkContext support is needed so that the call translates properly into Python through the Py4J bridge.
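    
    A minimal sketch of the resulting Python API on a build that includes this change (the master, app name, and helper function below are illustrative; the resource dicts are simply empty unless resources are configured as in the new tests):
    
        from pyspark import SparkContext, TaskContext

        sc = SparkContext("local[2]", "resources-example")

        # Driver side: sc.resources is a dict of resource name -> ResourceInformation
        for name, info in sc.resources.items():
            print(name, info.name, info.addresses)

        # Task side: TaskContext.get().resources() returns the resources assigned
        # to the running task (an empty dict when none are configured).
        def task_resources(_):
            return {n: r.addresses for n, r in TaskContext.get().resources().items()}

        print(sc.parallelize(range(2), 2).map(task_resources).collect())
        sc.stop()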
    
    ## How was this patch tested?
    
    Unit tests added, plus manual testing in local-cluster mode and on YARN.
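    
    For reference, a sketch of the configuration the new tests exercise: a one-line discovery script that reports a single GPU with address "0", wired up through the resource configs. Combining the driver- and executor-side settings in one context, the app name, and the owner-only permissions are simplifications of the test code, not a prescribed setup:
    
        import os, stat, tempfile
        from pyspark import SparkConf, SparkContext

        # Discovery script that reports one GPU, as in the new tests.
        script = tempfile.NamedTemporaryFile(delete=False)
        script.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
        script.close()
        os.chmod(script.name, stat.S_IRWXU)

        conf = (SparkConf()
                .set("spark.driver.resource.gpu.amount", "1")
                .set("spark.driver.resource.gpu.discoveryScript", script.name)
                .set("spark.task.resource.gpu.amount", "1")
                .set("spark.executor.resource.gpu.amount", "1")
                .set("spark.executor.resource.gpu.discoveryScript", script.name))

        sc = SparkContext("local-cluster[2,1,1024]", "resources-demo", conf=conf)
        print(sc.resources["gpu"].addresses)   # expected: ['0']
        sc.stop()
        os.unlink(script.name)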
    
    Closes #25087 from tgravescs/SPARK-28234-python.
    
    Authored-by: Thomas Graves <tg...@nvidia.com>
    Signed-off-by: HyukjinKwon <gu...@apache.org>
---
 .../apache/spark/api/java/JavaSparkContext.scala   |  3 ++
 .../org/apache/spark/api/python/PythonRunner.scala | 10 +++++
 python/pyspark/__init__.py                         |  3 +-
 python/pyspark/context.py                          | 12 ++++++
 python/pyspark/resourceinformation.py              | 43 ++++++++++++++++++++++
 python/pyspark/taskcontext.py                      |  8 ++++
 python/pyspark/tests/test_context.py               | 35 +++++++++++++++++-
 python/pyspark/tests/test_taskcontext.py           | 38 +++++++++++++++++++
 python/pyspark/worker.py                           | 11 ++++++
 9 files changed, 161 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index c5ef190..330c2f6 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -35,6 +35,7 @@ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD}
+import org.apache.spark.resource.ResourceInformation
 
 /**
  * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
@@ -114,6 +115,8 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable {
 
   def appName: String = sc.appName
 
+  def resources: JMap[String, ResourceInformation] = sc.resources.asJava
+
   def jars: util.List[String] = sc.jars.asJava
 
   def startTime: java.lang.Long = sc.startTime
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 414d208..dc6c596 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -281,6 +281,16 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         dataOut.writeInt(context.partitionId())
         dataOut.writeInt(context.attemptNumber())
         dataOut.writeLong(context.taskAttemptId())
+        val resources = context.resources()
+        dataOut.writeInt(resources.size)
+        resources.foreach { case (k, v) =>
+          PythonRDD.writeUTF(k, dataOut)
+          PythonRDD.writeUTF(v.name, dataOut)
+          dataOut.writeInt(v.addresses.size)
+          v.addresses.foreach { addr =>
+            PythonRDD.writeUTF(addr, dataOut)
+          }
+        }
         val localProps = context.getLocalProperties.asScala
         dataOut.writeInt(localProps.size)
         localProps.foreach { case (k, v) =>
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index ee153af..70c0b27 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -54,6 +54,7 @@ from pyspark.files import SparkFiles
 from pyspark.storagelevel import StorageLevel
 from pyspark.accumulators import Accumulator, AccumulatorParam
 from pyspark.broadcast import Broadcast
+from pyspark.resourceinformation import ResourceInformation
 from pyspark.serializers import MarshalSerializer, PickleSerializer
 from pyspark.status import *
 from pyspark.taskcontext import TaskContext, BarrierTaskContext, BarrierTaskInfo
@@ -118,5 +119,5 @@ __all__ = [
     "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
     "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
     "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
-    "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo",
+    "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", "ResourceInformation",
 ]
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 69020e6..8d28488 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -37,6 +37,7 @@ from pyspark.java_gateway import launch_gateway, local_connect_and_auth
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
     PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
 from pyspark.storagelevel import StorageLevel
+from pyspark.resourceinformation import ResourceInformation
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.traceback_utils import CallSite, first_spark_call
 from pyspark.status import StatusTracker
@@ -1107,6 +1108,17 @@ class SparkContext(object):
         conf.setAll(self._conf.getAll())
         return conf
 
+    @property
+    def resources(self):
+        resources = {}
+        jresources = self._jsc.resources()
+        for x in jresources:
+            name = jresources[x].name()
+            jaddresses = jresources[x].addresses()
+            addrs = [addr for addr in jaddresses]
+            resources[name] = ResourceInformation(name, addrs)
+        return resources
+
 
 def _test():
     import atexit
diff --git a/python/pyspark/resourceinformation.py b/python/pyspark/resourceinformation.py
new file mode 100644
index 0000000..aaed213
--- /dev/null
+++ b/python/pyspark/resourceinformation.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+class ResourceInformation(object):
+
+    """
+    .. note:: Evolving
+
+    Class to hold information about a type of Resource. A resource could be a GPU, FPGA, etc.
+    The array of addresses is resource specific, and it is up to the user to interpret them.
+
+    One example is GPUs, where the addresses would be the indices of the GPUs.
+
+    :param name: the name of the resource
+    :param addresses: an array of strings describing the addresses of the resource
+    """
+
+    def __init__(self, name, addresses):
+        self._name = name
+        self._addresses = addresses
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def addresses(self):
+        return self._addresses
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 6d28491..790de0b 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -38,6 +38,7 @@ class TaskContext(object):
     _stageId = None
     _taskAttemptId = None
     _localProperties = None
+    _resources = None
 
     def __new__(cls):
         """Even if users construct TaskContext instead of using get, give them the singleton."""
@@ -95,6 +96,13 @@ class TaskContext(object):
         """
         return self._localProperties.get(key, None)
 
+    def resources(self):
+        """
+        Resources allocated to the task. The key is the resource name and the value is information
+        about the resource.
+        """
+        return self._resources
+
 
 BARRIER_FUNCTION = 1
 
diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py
index 4048ac5..bcd5d06 100644
--- a/python/pyspark/tests/test_context.py
+++ b/python/pyspark/tests/test_context.py
@@ -16,13 +16,14 @@
 #
 import os
 import shutil
+import stat
 import tempfile
 import threading
 import time
 import unittest
 from collections import namedtuple
 
-from pyspark import SparkFiles, SparkContext
+from pyspark import SparkConf, SparkFiles, SparkContext
 from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME
 
 
@@ -256,6 +257,38 @@ class ContextTests(unittest.TestCase):
             SparkContext(gateway=mock_insecure_gateway)
         self.assertIn("insecure Py4j gateway", str(context.exception))
 
+    def test_resources(self):
+        """Test the resources are empty by default."""
+        with SparkContext() as sc:
+            resources = sc.resources
+            self.assertEqual(len(resources), 0)
+
+
+class ContextTestsWithResources(unittest.TestCase):
+
+    def setUp(self):
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
+        self.tempFile.close()
+        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
+                 stat.S_IROTH | stat.S_IXOTH)
+        conf = SparkConf().set("spark.driver.resource.gpu.amount", "1")
+        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
+        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)
+
+    def test_resources(self):
+        """Test the resources are available."""
+        resources = self.sc.resources
+        self.assertEqual(len(resources), 1)
+        self.assertTrue('gpu' in resources)
+        self.assertEqual(resources['gpu'].name, 'gpu')
+        self.assertEqual(resources['gpu'].addresses, ['0'])
+
+    def tearDown(self):
+        os.unlink(self.tempFile.name)
+        self.sc.stop()
+
 
 if __name__ == "__main__":
     from pyspark.tests.test_context import *
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index d7d1d80..66357b6 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -16,7 +16,9 @@
 #
 import os
 import random
+import stat
 import sys
+import tempfile
 import time
 import unittest
 
@@ -43,6 +45,15 @@ class TaskContextTests(PySparkTestCase):
         self.assertEqual(stage1 + 2, stage3)
         self.assertEqual(stage2 + 1, stage3)
 
+    def test_resources(self):
+        """Test the resources are empty by default."""
+        rdd = self.sc.parallelize(range(10))
+        resources1 = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
+        # Test using the constructor directly rather than the get()
+        resources2 = rdd.map(lambda x: TaskContext().resources()).take(1)[0]
+        self.assertEqual(len(resources1), 0)
+        self.assertEqual(len(resources2), 0)
+
     def test_partition_id(self):
         """Test the partition id."""
         rdd1 = self.sc.parallelize(range(10), 1)
@@ -174,6 +185,33 @@ class TaskContextTestsWithWorkerReuse(unittest.TestCase):
         self.sc.stop()
 
 
+class TaskContextTestsWithResources(unittest.TestCase):
+
+    def setUp(self):
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
+        self.tempFile.close()
+        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
+                 stat.S_IROTH | stat.S_IXOTH)
+        conf = SparkConf().set("spark.task.resource.gpu.amount", "1")
+        conf = conf.set("spark.executor.resource.gpu.amount", "1")
+        conf = conf.set("spark.executor.resource.gpu.discoveryScript", self.tempFile.name)
+        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)
+
+    def test_resources(self):
+        """Test the resources are available."""
+        rdd = self.sc.parallelize(range(10))
+        resources = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
+        self.assertEqual(len(resources), 1)
+        self.assertTrue('gpu' in resources)
+        self.assertEqual(resources['gpu'].name, 'gpu')
+        self.assertEqual(resources['gpu'].addresses, ['0'])
+
+    def tearDown(self):
+        os.unlink(self.tempFile.name)
+        self.sc.stop()
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_taskcontext import *
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b34abd0..7f38c27 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -35,6 +35,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.files import SparkFiles
+from pyspark.resourceinformation import ResourceInformation
 from pyspark.rdd import PythonEvalType
 from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
@@ -435,6 +436,16 @@ def main(infile, outfile):
         taskContext._partitionId = read_int(infile)
         taskContext._attemptNumber = read_int(infile)
         taskContext._taskAttemptId = read_long(infile)
+        taskContext._resources = {}
+        for r in range(read_int(infile)):
+            key = utf8_deserializer.loads(infile)
+            name = utf8_deserializer.loads(infile)
+            addresses = []
+            # read the addresses reported for this resource
+            for a in range(read_int(infile)):
+                addresses.append(utf8_deserializer.loads(infile))
+            taskContext._resources[key] = ResourceInformation(name, addresses)
+
         taskContext._localProperties = dict()
         for i in range(read_int(infile)):
             k = utf8_deserializer.loads(infile)

