Posted to commits@spark.apache.org by ma...@apache.org on 2014/09/04 04:09:54 UTC

git commit: [SPARK-3335] [SQL] [PySpark] support broadcast in Python UDF

Repository: spark
Updated Branches:
  refs/heads/master 248067adb -> c5cbc4923


[SPARK-3335] [SQL] [PySpark] support broadcast in Python UDF

After this patch, broadcast variables can be used in Python UDFs.
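
For example, a minimal usage sketch mirroring the test added in this commit (names like `lookupOne` are illustrative; assumes an existing SparkContext `sc`):

    from pyspark.sql import SQLContext, IntegerType

    sqlCtx = SQLContext(sc)
    lookup = sc.broadcast({"a": 1, "b": 2})

    # The UDF closes over the broadcast variable; before this patch the
    # broadcasts were never shipped to the Python worker, so a UDF that
    # referenced one could not be deserialized there.
    sqlCtx.registerFunction("lookupOne", lambda k: lookup.value.get(k, 0), IntegerType())
    sqlCtx.sql("SELECT lookupOne('a')").collect()  # [Row(c0=1)]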

Author: Davies Liu <da...@gmail.com>

Closes #2243 from davies/udf_broadcast and squashes the following commits:

7b88861 [Davies Liu] support broadcast in UDF


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c5cbc492
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c5cbc492
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c5cbc492

Branch: refs/heads/master
Commit: c5cbc49233193836b321cb6b77ce69dae798570b
Parents: 248067a
Author: Davies Liu <da...@gmail.com>
Authored: Wed Sep 3 19:08:39 2014 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Wed Sep 3 19:08:39 2014 -0700

----------------------------------------------------------------------
 python/pyspark/sql.py                           | 17 ++++++++-------
 python/pyspark/tests.py                         | 22 ++++++++++++++++++++
 .../org/apache/spark/sql/UdfRegistration.scala  |  3 +++
 .../apache/spark/sql/execution/pythonUdfs.scala |  3 ++-
 4 files changed, 36 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c5cbc492/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 4431692..aaa35da 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -942,9 +942,7 @@ class SQLContext:
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
         self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray
-
-        if sqlContext:
-            self._scala_SQLContext = sqlContext
+        self._scala_SQLContext = sqlContext
 
     @property
     def _ssql_ctx(self):
@@ -953,7 +951,7 @@ class SQLContext:
         Subclasses can override this property to provide their own
         JVM Contexts.
         """
-        if not hasattr(self, '_scala_SQLContext'):
+        if self._scala_SQLContext is None:
             self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
         return self._scala_SQLContext
 
@@ -970,23 +968,26 @@ class SQLContext:
         >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
         [Row(c0=4)]
-        >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
-        >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
-        [Row(c0=5)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
+        pickled_command = CloudPickleSerializer().dumps(command)
+        broadcast_vars = ListConverter().convert(
+            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
+            self._sc._gateway._gateway_client)
+        self._sc._pickled_broadcast_vars.clear()
         env = MapConverter().convert(self._sc.environment,
                                      self._sc._gateway._gateway_client)
         includes = ListConverter().convert(self._sc._python_includes,
                                            self._sc._gateway._gateway_client)
         self._ssql_ctx.registerPython(name,
-                                      bytearray(CloudPickleSerializer().dumps(command)),
+                                      bytearray(pickled_command),
                                       env,
                                       includes,
                                       self._sc.pythonExec,
+                                      broadcast_vars,
                                       self._sc._javaAccumulator,
                                       str(returnType))
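
Why the command is pickled before the broadcast list is built: serializing the
UDF's closure is what populates self._sc._pickled_broadcast_vars. In the PySpark
of this era, a Broadcast registers itself with the context's pickle registry when
it is serialized, roughly like this (a simplified sketch, not the verbatim
pyspark/broadcast.py source):

    class Broadcast(object):
        def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
            self.value = value
            self.bid = bid
            self._jbroadcast = java_broadcast
            self._pickle_registry = pickle_registry

        def __reduce__(self):
            # Record that this broadcast was referenced by whatever is being
            # pickled, so the caller can ship the matching Java broadcasts.
            self._pickle_registry.add(self)
            return (_from_id, (self.bid,))

Hence the ordering above: dump the command, snapshot the registry into
broadcast_vars, then clear() it so broadcasts captured by one UDF do not leak
into the next registration.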
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c5cbc492/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f1a75cb..3e74799 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -43,6 +43,7 @@ from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
+from pyspark.sql import SQLContext, IntegerType
 
 _have_scipy = False
 _have_numpy = False
@@ -525,6 +526,27 @@ class TestRDDFunctions(PySparkTestCase):
         self.assertRaises(TypeError, lambda: rdd.histogram(2))
 
 
+class TestSQL(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        self.sqlCtx = SQLContext(self.sc)
+
+    def test_udf(self):
+        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_broadcast_in_udf(self):
+        bar = {"a": "aa", "b": "bb", "c": "abc"}
+        foo = self.sc.broadcast(bar)
+        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.assertEqual("abc", res[0])
+        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        self.assertEqual("", res[0])
+
+
 class TestIO(PySparkTestCase):
 
     def test_stdout_redirection(self):
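
One detail in test_broadcast_in_udf: MYUDF is registered without an explicit
return type, so it falls back to registerFunction's default (StringType() at the
time), which is why both assertions compare against strings. An equivalent
explicit registration would be (a sketch only, assuming that default):

    from pyspark.sql import StringType

    sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '', StringType())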

http://git-wip-us.apache.org/repos/asf/spark/blob/c5cbc492/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 0b48e9e..0ea1105 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.util.{List => JList, Map => JMap}
 
 import org.apache.spark.Accumulator
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf}
 import org.apache.spark.sql.execution.PythonUDF
@@ -38,6 +39,7 @@ protected[sql] trait UDFRegistration {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
+      broadcastVars: JList[Broadcast[Array[Byte]]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {
     log.debug(
@@ -61,6 +63,7 @@ protected[sql] trait UDFRegistration {
         envVars,
         pythonIncludes,
         pythonExec,
+        broadcastVars,
         accumulator,
         dataType,
         e)

http://git-wip-us.apache.org/repos/asf/spark/blob/c5cbc492/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 3dc8be2..0977da3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -42,6 +42,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    broadcastVars: JList[Broadcast[Array[Byte]]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,
     children: Seq[Expression]) extends Expression with SparkLogging {
@@ -145,7 +146,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       udf.pythonIncludes,
       false,
       udf.pythonExec,
-      Seq[Broadcast[Array[Byte]]](),
+      udf.broadcastVars,
       udf.accumulator
     ).mapPartitions { iter =>
       val pickle = new Unpickler

