Posted to commits@spark.apache.org by da...@apache.org on 2015/07/20 21:14:51 UTC

spark git commit: [SPARK-9114] [SQL] [PySpark] convert returned object from UDF into internal type

Repository: spark
Updated Branches:
  refs/heads/master 02181fb6d -> 9f913c4fd


[SPARK-9114] [SQL] [PySpark] convert returned object from UDF into internal type

This PR also removes the duplicated code between registerFunction and UserDefinedFunction.
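
For context, "internal type" here means Spark SQL's internal representation of a value (for example a UDT's serialized form); the fix wraps the user's function so every returned value is passed through returnType.toInternal() before being pickled back to the JVM. A minimal, self-contained sketch of that wrapping (f and returnType are hypothetical stand-ins; the lambda mirrors the change in functions.py below):

    from pyspark.sql.types import DoubleType

    # Stand-ins: f is the user's Python function, returnType the declared
    # result DataType of the UDF.
    f = lambda x: x + 1.0
    returnType = DoubleType()

    # The wrapper built in UserDefinedFunction._create_judf: each value
    # returned by f is run through returnType.toInternal() before it is
    # serialized back to the JVM.
    func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)

    print(list(func(None, [(1.0,), (2.0,)])))   # [2.0, 3.0]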

cc JoshRosen

Author: Davies Liu <da...@databricks.com>

Closes #7450 from davies/fix_return_type and squashes the following commits:

e80bf9f [Davies Liu] remove debugging code
f94b1f6 [Davies Liu] fix mima
8f9c58b [Davies Liu] convert returned object from UDF into internal type


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9f913c4f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9f913c4f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9f913c4f

Branch: refs/heads/master
Commit: 9f913c4fd6f0f223fd378e453d5b9a87beda1ac4
Parents: 02181fb
Author: Davies Liu <da...@databricks.com>
Authored: Mon Jul 20 12:14:47 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Mon Jul 20 12:14:47 2015 -0700

----------------------------------------------------------------------
 project/MimaExcludes.scala                      |  4 +-
 python/pyspark/sql/context.py                   | 16 ++-----
 python/pyspark/sql/functions.py                 | 15 +++----
 python/pyspark/sql/tests.py                     |  4 +-
 .../org/apache/spark/sql/UDFRegistration.scala  | 44 ++++----------------
 .../apache/spark/sql/UserDefinedFunction.scala  | 10 +++--
 6 files changed, 32 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9f913c4f/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index dd85254..a2595ff 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -69,7 +69,9 @@ object MimaExcludes {
             ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"),
             // local function inside a method
             ProblemFilters.exclude[MissingMethodProblem](
-              "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1")
+              "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24")
           ) ++ Seq(
             // SPARK-8479 Add numNonzeros and numActives to Matrix.
             ProblemFilters.exclude[MissingMethodProblem](

http://git-wip-us.apache.org/repos/asf/spark/blob/9f913c4f/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index c93a15b..abb6522 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -34,6 +34,7 @@ from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.utils import install_exception_handler
+from pyspark.sql.functions import UserDefinedFunction
 
 try:
     import pandas
@@ -191,19 +192,8 @@ class SQLContext(object):
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(_c0=4)]
         """
-        func = lambda _, it: map(lambda x: f(*x), it)
-        ser = AutoBatchedSerializer(PickleSerializer())
-        command = (func, None, ser, ser)
-        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
-        self._ssql_ctx.udf().registerPython(name,
-                                            bytearray(pickled_cmd),
-                                            env,
-                                            includes,
-                                            self._sc.pythonExec,
-                                            self._sc.pythonVer,
-                                            bvars,
-                                            self._sc._javaAccumulator,
-                                            returnType.json())
+        udf = UserDefinedFunction(f, returnType, name)
+        self._ssql_ctx.udf().registerPython(name, udf._judf)
 
     def _inferSchemaFromList(self, data):
         """

http://git-wip-us.apache.org/repos/asf/spark/blob/9f913c4f/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index fd5a3ba..031745a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -801,23 +801,24 @@ class UserDefinedFunction(object):
 
     .. versionadded:: 1.3
     """
-    def __init__(self, func, returnType):
+    def __init__(self, func, returnType, name=None):
         self.func = func
         self.returnType = returnType
         self._broadcast = None
-        self._judf = self._create_judf()
+        self._judf = self._create_judf(name)
 
-    def _create_judf(self):
-        f = self.func  # put it in closure `func`
-        func = lambda _, it: map(lambda x: f(*x), it)
+    def _create_judf(self, name):
+        f, returnType = self.func, self.returnType  # put them in closure `func`
+        func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
         ser = AutoBatchedSerializer(PickleSerializer())
         command = (func, None, ser, ser)
         sc = SparkContext._active_spark_context
         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
-        fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
-        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+        if name is None:
+            name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
+        judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
                                                  sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf
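
The key line is returnType.toInternal(f(*x)): types whose internal storage differs from their Python objects (UDTs, dates, timestamps, ...) are now converted before serialization. A small illustration of toInternal, assuming a PySpark build from this branch:

    from datetime import date
    from pyspark.sql.types import DateType, DoubleType

    # Dates are stored internally as days since the Unix epoch, so a date
    # returned by a UDF must be converted.
    print(DateType().toInternal(date(1970, 1, 2)))   # 1

    # Primitive types pass through unchanged.
    print(DoubleType().toInternal(2.0))              # 2.0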

http://git-wip-us.apache.org/repos/asf/spark/blob/9f913c4f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7a55d80..ea821f4 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -417,12 +417,14 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
     def test_udf_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         df = self.sc.parallelize([row]).toDF()
         self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
         udf = UserDefinedFunction(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
 
     def test_parquet_with_udt(self):
         from pyspark.sql.tests import ExamplePoint

http://git-wip-us.apache.org/repos/asf/spark/blob/9f913c4f/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index d35d37d..7cd7421 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -22,13 +22,10 @@ import java.util.{List => JList, Map => JMap}
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.Try
 
-import org.apache.spark.{Accumulator, Logging}
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.Logging
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import org.apache.spark.sql.execution.PythonUDF
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -40,44 +37,19 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
 
   private val functionRegistry = sqlContext.functionRegistry
 
-  protected[sql] def registerPython(
-      name: String,
-      command: Array[Byte],
-      envVars: JMap[String, String],
-      pythonIncludes: JList[String],
-      pythonExec: String,
-      pythonVer: String,
-      broadcastVars: JList[Broadcast[PythonBroadcast]],
-      accumulator: Accumulator[JList[Array[Byte]]],
-      stringDataType: String): Unit = {
+  protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
     log.debug(
       s"""
         | Registering new PythonUDF:
         | name: $name
-        | command: ${command.toSeq}
-        | envVars: $envVars
-        | pythonIncludes: $pythonIncludes
-        | pythonExec: $pythonExec
-        | dataType: $stringDataType
+        | command: ${udf.command.toSeq}
+        | envVars: ${udf.envVars}
+        | pythonIncludes: ${udf.pythonIncludes}
+        | pythonExec: ${udf.pythonExec}
+        | dataType: ${udf.dataType}
       """.stripMargin)
 
-
-    val dataType = sqlContext.parseDataType(stringDataType)
-
-    def builder(e: Seq[Expression]): PythonUDF =
-      PythonUDF(
-        name,
-        command,
-        envVars,
-        pythonIncludes,
-        pythonExec,
-        pythonVer,
-        broadcastVars,
-        accumulator,
-        dataType,
-        e)
-
-    functionRegistry.registerFunction(name, builder)
+    functionRegistry.registerFunction(name, udf.builder)
   }
 
   // scalastyle:off

http://git-wip-us.apache.org/repos/asf/spark/blob/9f913c4f/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index b14e00a..0f8cd28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Accumulator
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.python.PythonBroadcast
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
 import org.apache.spark.sql.execution.PythonUDF
 import org.apache.spark.sql.types.DataType
 
@@ -66,10 +66,14 @@ private[sql] case class UserDefinedPythonFunction(
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType) {
 
+  def builder(e: Seq[Expression]): PythonUDF = {
+    PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
+      accumulator, dataType, e)
+  }
+
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
   def apply(exprs: Column*): Column = {
-    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
-      broadcastVars, accumulator, dataType, exprs.map(_.expr))
+    val udf = builder(exprs.map(_.expr))
     Column(udf)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org