Posted to commits@spark.apache.org by da...@apache.org on 2016/02/24 21:45:27 UTC

spark git commit: [SPARK-13467] [PYSPARK] abstract python function to simplify pyspark code

Repository: spark
Updated Branches:
  refs/heads/master f92f53fae -> a60f91284


[SPARK-13467] [PYSPARK] abstract python function to simplify pyspark code

## What changes were proposed in this pull request?

When we pass a Python function to the JVM side, we also need to send its context, e.g. `envVars`, `pythonIncludes`, `pythonExec`, etc. However, it is cumbersome to pass so many parameters around in many places. This PR abstracts the Python function together with its context into a single object, which simplifies some PySpark code and makes the logic clearer.
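
Roughly, the new flow looks like the sketch below. This is not part of the commit; it is a minimal illustration that mirrors the `_wrap_function` helper and the new `PythonRDD`/`PythonFunction` signatures in the diff, assuming an active `SparkContext` and a doubling function chosen purely for the example.

```python
from pyspark import SparkContext
from pyspark.rdd import _wrap_function
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer

sc = SparkContext.getOrCreate()
parent = sc.parallelize(range(4))

# An RDD pipeline function takes (partition_index, iterator).
func = lambda split, it: (x * 2 for x in it)
ser = AutoBatchedSerializer(PickleSerializer())

# The callable, its serializers and the driver-side context (envVars,
# pythonIncludes, pythonExec, pythonVer, broadcasts, accumulator) are bundled
# into a single JVM-side PythonFunction object.
wrapped = _wrap_function(sc, func, parent._jrdd_deserializer, ser)

# That one object is all the JVM PythonRDD constructor needs besides the
# parent RDD and the partitioning flag.
python_rdd = sc._jvm.PythonRDD(parent._jrdd.rdd(), wrapped, True)
jrdd = python_rdd.asJavaRDD()
```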

## How was this patch tested?

By existing unit tests.

Author: Wenchen Fan <we...@databricks.com>

Closes #11342 from cloud-fan/python-clean.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a60f9128
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a60f9128
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a60f9128

Branch: refs/heads/master
Commit: a60f91284ceee64de13f04559ec19c13a820a133
Parents: f92f53f
Author: Wenchen Fan <we...@databricks.com>
Authored: Wed Feb 24 12:44:54 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Feb 24 12:44:54 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 37 ++++++++++++--------
 python/pyspark/rdd.py                           | 23 +++++++-----
 python/pyspark/sql/context.py                   |  2 +-
 python/pyspark/sql/functions.py                 |  8 ++---
 .../org/apache/spark/sql/UDFRegistration.scala  |  8 ++---
 .../python/BatchPythonEvaluation.scala          |  8 +----
 .../spark/sql/execution/python/PythonUDF.scala  | 13 ++-----
 .../python/UserDefinedPythonFunction.scala      | 15 ++------
 8 files changed, 51 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f12e2df..05d1c31 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -42,14 +42,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 private[spark] class PythonRDD(
     parent: RDD[_],
-    command: Array[Byte],
-    envVars: JMap[String, String],
-    pythonIncludes: JList[String],
-    preservePartitoning: Boolean,
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]])
+    func: PythonFunction,
+    preservePartitoning: Boolean)
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
@@ -64,29 +58,37 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(
-      command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator,
-      bufferSize, reuse_worker)
+    val runner = new PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
 
-
 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for a Python function, contains all necessary context to run the function in Python
+ * runner.
  */
-private[spark] class PythonRunner(
+private[spark] case class PythonFunction(
     command: Array[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]],
+    accumulator: Accumulator[JList[Array[Byte]]])
+
+/**
+ * A helper class to run Python UDFs in Spark.
+ */
+private[spark] class PythonRunner(
+    func: PythonFunction,
     bufferSize: Int,
     reuse_worker: Boolean)
   extends Logging {
 
+  private val envVars = func.envVars
+  private val pythonExec = func.pythonExec
+  private val accumulator = func.accumulator
+
   def compute(
       inputIterator: Iterator[_],
       partitionIndex: Int,
@@ -225,6 +227,11 @@ private[spark] class PythonRunner(
 
     @volatile private var _exception: Exception = null
 
+    private val pythonVer = func.pythonVer
+    private val pythonIncludes = func.pythonIncludes
+    private val broadcastVars = func.broadcastVars
+    private val command = func.command
+
     setDaemon(true)
 
     /** Contains the exception thrown while writing the parent iterator to the Python process. */

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 4eaf589..37574ce 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2309,7 +2309,7 @@ class RDD(object):
                 yield row
 
 
-def _prepare_for_python_RDD(sc, command, obj=None):
+def _prepare_for_python_RDD(sc, command):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps(command)
@@ -2329,6 +2329,15 @@ def _prepare_for_python_RDD(sc, command, obj=None):
     return pickled_command, broadcast_vars, env, includes
 
 
+def _wrap_function(sc, func, deserializer, serializer, profiler=None):
+    assert deserializer, "deserializer should not be empty"
+    assert serializer, "serializer should not be empty"
+    command = (func, profiler, deserializer, serializer)
+    pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+                                  sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
 class PipelinedRDD(RDD):
 
     """
@@ -2390,14 +2399,10 @@ class PipelinedRDD(RDD):
         else:
             profiler = None
 
-        command = (self.func, profiler, self._prev_jrdd_deserializer,
-                   self._jrdd_deserializer)
-        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
-        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-                                             bytearray(pickled_cmd),
-                                             env, includes, self.preservesPartitioning,
-                                             self.ctx.pythonExec, self.ctx.pythonVer,
-                                             bvars, self.ctx._javaAccumulator)
+        wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
+                                      self._jrdd_deserializer, profiler)
+        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
+                                             self.preservesPartitioning)
         self._jrdd_val = python_rdd.asJavaRDD()
 
         if profiler:

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 89bf144..87e32c0 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -29,7 +29,7 @@ else:
 from py4j.protocol import Py4JError
 
 from pyspark import since
-from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
 from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6894c27..b30cc67 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@ if sys.version < "3":
     from itertools import imap as map
 
 from pyspark import since, SparkContext
-from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import _wrap_function, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1645,16 +1645,14 @@ class UserDefinedFunction(object):
         f, returnType = self.func, self.returnType  # put them in closure `func`
         func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
         ser = AutoBatchedSerializer(PickleSerializer())
-        command = (func, None, ser, ser)
         sc = SparkContext.getOrCreate()
-        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+        wrapped_func = _wrap_function(sc, func, ser, ser)
         ctx = SQLContext.getOrCreate(sc)
         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
         if name is None:
             name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer,
-            broadcast_vars, sc._javaAccumulator, jdt)
+            name, wrapped_func, jdt)
         return judf
 
     def __del__(self):

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index ecfc170..de01cbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -43,10 +43,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
       s"""
         | Registering new PythonUDF:
         | name: $name
-        | command: ${udf.command.toSeq}
-        | envVars: ${udf.envVars}
-        | pythonIncludes: ${udf.pythonIncludes}
-        | pythonExec: ${udf.pythonExec}
+        | command: ${udf.func.command.toSeq}
+        | envVars: ${udf.func.envVars}
+        | pythonIncludes: ${udf.func.pythonIncludes}
+        | pythonExec: ${udf.func.pythonExec}
         | dataType: ${udf.dataType}
       """.stripMargin)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index 00df019..c65a7bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -76,13 +76,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
 
       // Output iterator for results from Python.
       val outputIterator = new PythonRunner(
-        udf.command,
-        udf.envVars,
-        udf.pythonIncludes,
-        udf.pythonExec,
-        udf.pythonVer,
-        udf.broadcastVars,
-        udf.accumulator,
+        udf.func,
         bufferSize,
         reuseWorker
       ).compute(inputIterator, context.partitionId(), context)

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 9aff0be..0aa2785 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -17,9 +17,8 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.{Accumulator, Logging}
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.Logging
+import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
 import org.apache.spark.sql.types.DataType
 
@@ -28,13 +27,7 @@ import org.apache.spark.sql.types.DataType
  */
 case class PythonUDF(
     name: String,
-    command: Array[Byte],
-    envVars: java.util.Map[String, String],
-    pythonIncludes: java.util.List[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[java.util.List[Array[Byte]]],
+    func: PythonFunction,
     dataType: DataType,
     children: Seq[Expression])
   extends Expression with Unevaluable with NonSQLExpression with Logging {

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 79ac1c8..d301874 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.Accumulator
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.types.DataType
@@ -29,18 +27,11 @@ import org.apache.spark.sql.types.DataType
  */
 case class UserDefinedPythonFunction(
     name: String,
-    command: Array[Byte],
-    envVars: java.util.Map[String, String],
-    pythonIncludes: java.util.List[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[java.util.List[Array[Byte]]],
+    func: PythonFunction,
     dataType: DataType) {
 
   def builder(e: Seq[Expression]): PythonUDF = {
-    PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
-      accumulator, dataType, e)
+    PythonUDF(name, func, dataType, e)
   }
 
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */

