Posted to commits@spark.apache.org by gu...@apache.org on 2018/01/16 11:21:28 UTC

spark git commit: [SPARK-22978][PYSPARK] Register Vectorized UDFs for SQL Statement

Repository: spark
Updated Branches:
  refs/heads/master 66217dac4 -> b85eb946a


[SPARK-22978][PYSPARK] Register Vectorized UDFs for SQL Statement

## What changes were proposed in this pull request?
Register vectorized (pandas) UDFs so that they can be invoked from SQL statements. For example,

```Python
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf("integer", PandasUDFType.SCALAR)
... def add_one(x):
...     return x + 1
...
>>> _ = spark.udf.register("add_one", add_one)
>>> spark.sql("SELECT add_one(id) FROM range(3)").collect()
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
```
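
With this change, registration also validates its inputs: passing an explicit `returnType` together with an existing `UserDefinedFunction` raises a `TypeError`, and only row-at-a-time (`SQL_BATCHED_UDF`) and scalar pandas (`SQL_PANDAS_SCALAR_UDF`) UDFs may be registered. A minimal sketch of these error paths, assuming a running SparkSession named `spark` with pandas and pyarrow installed (the UDF names are illustrative):

```Python
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> from pyspark.sql.types import IntegerType

>>> add_one = pandas_udf(lambda x: x + 1, "integer", PandasUDFType.SCALAR)

>>> # returnType must be omitted (None) when f is already a UserDefinedFunction
>>> spark.udf.register("add_one", add_one, IntegerType())
Traceback (most recent call last):
    ...
TypeError: Invalid returnType: None is expected when f is a UserDefinedFunction, ...

>>> # only row-at-a-time and scalar pandas UDFs can be registered
>>> group_map = pandas_udf(lambda pdf: pdf, "id long", PandasUDFType.GROUP_MAP)
>>> spark.udf.register("group_map", group_map)
Traceback (most recent call last):
    ...
ValueError: Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF
```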

## How was this patch tested?
Added test cases

Author: gatorsmile <ga...@gmail.com>

Closes #20171 from gatorsmile/supportVectorizedUDF.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b85eb946
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b85eb946
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b85eb946

Branch: refs/heads/master
Commit: b85eb946ac298e711dad25db0d04eee41d7fd236
Parents: 66217da
Author: gatorsmile <ga...@gmail.com>
Authored: Tue Jan 16 20:20:33 2018 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Tue Jan 16 20:20:33 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/catalog.py | 75 ++++++++++++++++++++++++++-----------
 python/pyspark/sql/context.py | 51 +++++++++++++++++--------
 python/pyspark/sql/tests.py   | 76 ++++++++++++++++++++++++++++++++------
 3 files changed, 155 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b85eb946/python/pyspark/sql/catalog.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 1566031..35fbe9e 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -226,18 +226,23 @@ class Catalog(object):
 
     @ignore_unicode_prefix
     @since(2.0)
-    def registerFunction(self, name, f, returnType=StringType()):
+    def registerFunction(self, name, f, returnType=None):
         """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
-        as a UDF. The registered UDF can be used in SQL statement.
+        as a UDF. The registered UDF can be used in SQL statements.
 
-        In addition to a name and the function itself, the return type can be optionally specified.
-        When the return type is not given it default to a string and conversion will automatically
-        be done.  For any other return type, the produced object must match the specified type.
+        :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`.
 
-        :param name: name of the UDF
-        :param f: a Python function, or a wrapped/native UserDefinedFunction
-        :param returnType: a :class:`pyspark.sql.types.DataType` object
-        :return: a wrapped :class:`UserDefinedFunction`
+        In addition to a name and the function itself, `returnType` can be optionally specified.
+        1) When f is a Python function, `returnType` defaults to a string, and the produced object
+        must match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the
+        return type of the given UDF as the return type of the registered UDF. In this case
+        `returnType` must be left as None; passing any other value raises a ``TypeError``.
+
+        :param name: name of the UDF in SQL statements.
+        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
+            row-at-a-time or vectorized.
+        :param returnType: the return type of the registered UDF.
+        :return: a wrapped/native :class:`UserDefinedFunction`
 
         >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
         >>> spark.sql("SELECT stringLengthString('test')").collect()
@@ -256,27 +261,55 @@ class Catalog(object):
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
 
+        >>> from pyspark.sql.types import IntegerType
+        >>> from pyspark.sql.functions import udf
+        >>> slen = udf(lambda s: len(s), IntegerType())
+        >>> _ = spark.udf.register("slen", slen)
+        >>> spark.sql("SELECT slen('test')").collect()
+        [Row(slen(test)=4)]
+
         >>> import random
         >>> from pyspark.sql.functions import udf
-        >>> from pyspark.sql.types import IntegerType, StringType
+        >>> from pyspark.sql.types import IntegerType
         >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
-        >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
+        >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf)
         >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
-        [Row(random_udf()=u'82')]
-        >>> spark.range(1).select(newRandom_udf()).collect()  # doctest: +SKIP
-        [Row(random_udf()=u'62')]
+        [Row(random_udf()=82)]
+        >>> spark.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
+        [Row(<lambda>()=26)]
+
+        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
+        ... def add_one(x):
+        ...     return x + 1
+        ...
+        >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
+        >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
+        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
         """
 
         # This is to check whether the input function is a wrapped/native UserDefinedFunction
         if hasattr(f, 'asNondeterministic'):
-            udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
-                                      evalType=PythonEvalType.SQL_BATCHED_UDF,
-                                      deterministic=f.deterministic)
+            if returnType is not None:
+                raise TypeError(
+                    "Invalid returnType: None is expected when f is a UserDefinedFunction, "
+                    "but got %s." % returnType)
+            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
+                                  PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
+                raise ValueError(
+                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
+            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
+                                               evalType=f.evalType,
+                                               deterministic=f.deterministic)
+            return_udf = f
         else:
-            udf = UserDefinedFunction(f, returnType=returnType, name=name,
-                                      evalType=PythonEvalType.SQL_BATCHED_UDF)
-        self._jsparkSession.udf().registerPython(name, udf._judf)
-        return udf._wrapped()
+            if returnType is None:
+                returnType = StringType()
+            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
+                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
+            return_udf = register_udf._wrapped()
+        self._jsparkSession.udf().registerPython(name, register_udf._judf)
+        return return_udf
 
     @since(2.0)
     def isCached(self, tableName):

http://git-wip-us.apache.org/repos/asf/spark/blob/b85eb946/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index b8d86cc..8547909 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -174,18 +174,23 @@ class SQLContext(object):
 
     @ignore_unicode_prefix
     @since(1.2)
-    def registerFunction(self, name, f, returnType=StringType()):
+    def registerFunction(self, name, f, returnType=None):
         """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
-        as a UDF. The registered UDF can be used in SQL statement.
+        as a UDF. The registered UDF can be used in SQL statements.
 
-        In addition to a name and the function itself, the return type can be optionally specified.
-        When the return type is not given it default to a string and conversion will automatically
-        be done.  For any other return type, the produced object must match the specified type.
+        :func:`sqlContext.udf.register` is an alias for :func:`sqlContext.registerFunction`.
 
-        :param name: name of the UDF
-        :param f: a Python function, or a wrapped/native UserDefinedFunction
-        :param returnType: a :class:`pyspark.sql.types.DataType` object
-        :return: a wrapped :class:`UserDefinedFunction`
+        In addition to a name and the function itself, `returnType` can be optionally specified.
+        1) When f is a Python function, `returnType` defaults to a string, and the produced object
+        must match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the
+        return type of the given UDF as the return type of the registered UDF. In this case
+        `returnType` must be left as None; passing any other value raises a ``TypeError``.
+
+        :param name: name of the UDF in SQL statements.
+        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
+            row-at-a-time or vectorized.
+        :param returnType: the return type of the registered UDF.
+        :return: a wrapped/native :class:`UserDefinedFunction`
 
         >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
         >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
@@ -204,15 +209,31 @@ class SQLContext(object):
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
 
+        >>> from pyspark.sql.types import IntegerType
+        >>> from pyspark.sql.functions import udf
+        >>> slen = udf(lambda s: len(s), IntegerType())
+        >>> _ = sqlContext.udf.register("slen", slen)
+        >>> sqlContext.sql("SELECT slen('test')").collect()
+        [Row(slen(test)=4)]
+
         >>> import random
         >>> from pyspark.sql.functions import udf
-        >>> from pyspark.sql.types import IntegerType, StringType
+        >>> from pyspark.sql.types import IntegerType
         >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
-        >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
+        >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf)
         >>> sqlContext.sql("SELECT random_udf()").collect()  # doctest: +SKIP
-        [Row(random_udf()=u'82')]
-        >>> sqlContext.range(1).select(newRandom_udf()).collect()  # doctest: +SKIP
-        [Row(random_udf()=u'62')]
+        [Row(random_udf()=82)]
+        >>> sqlContext.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
+        [Row(<lambda>()=26)]
+
+        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
+        ... def add_one(x):
+        ...     return x + 1
+        ...
+        >>> _ = sqlContext.udf.register("add_one", add_one)  # doctest: +SKIP
+        >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
+        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
         """
         return self.sparkSession.catalog.registerFunction(name, f, returnType)
 
@@ -575,7 +596,7 @@ class UDFRegistration(object):
     def __init__(self, sqlContext):
         self.sqlContext = sqlContext
 
-    def register(self, name, f, returnType=StringType()):
+    def register(self, name, f, returnType=None):
         return self.sqlContext.registerFunction(name, f, returnType)
 
     def registerJavaFunction(self, name, javaClassName, returnType=None):

http://git-wip-us.apache.org/repos/asf/spark/blob/b85eb946/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 80a94a9..8906618 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -380,12 +380,25 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(4, res[0])
 
     def test_udf3(self):
-        twoargs = self.spark.catalog.registerFunction(
-            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
-        self.assertEqual(twoargs.deterministic, True)
+        two_args = self.spark.catalog.registerFunction(
+            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y))
+        self.assertEqual(two_args.deterministic, True)
+        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], u'5')
+
+    def test_udf_registration_return_type_none(self):
+        two_args = self.spark.catalog.registerFunction(
+            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None)
+        self.assertEqual(two_args.deterministic, True)
         [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
         self.assertEqual(row[0], 5)
 
+    def test_udf_registration_return_type_not_none(self):
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, "Invalid returnType"):
+                self.spark.catalog.registerFunction(
+                    "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType())
+
     def test_nondeterministic_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
         from pyspark.sql.functions import udf
@@ -402,12 +415,12 @@ class SQLTests(ReusedSQLTestCase):
         from pyspark.sql.functions import udf
         random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
         self.assertEqual(random_udf.deterministic, False)
-        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
+        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
         self.assertEqual(random_udf1.deterministic, False)
         [row] = self.spark.sql("SELECT randInt()").collect()
-        self.assertEqual(row[0], "6")
+        self.assertEqual(row[0], 6)
         [row] = self.spark.range(1).select(random_udf1()).collect()
-        self.assertEqual(row[0], "6")
+        self.assertEqual(row[0], 6)
         [row] = self.spark.range(1).select(random_udf()).collect()
         self.assertEqual(row[0], 6)
         # render_doc() reproduces the help() exception without printing output
@@ -3691,7 +3704,7 @@ class VectorizedUDFTests(ReusedSQLTestCase):
         ReusedSQLTestCase.tearDownClass()
 
     @property
-    def random_udf(self):
+    def nondeterministic_vectorized_udf(self):
         from pyspark.sql.functions import pandas_udf
 
         @pandas_udf('double')
@@ -3726,6 +3739,21 @@ class VectorizedUDFTests(ReusedSQLTestCase):
                         bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_register_nondeterministic_vectorized_udf_basic(self):
+        from pyspark.sql.functions import pandas_udf
+        from pyspark.rdd import PythonEvalType
+        import random
+        random_pandas_udf = pandas_udf(
+            lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic()
+        self.assertEqual(random_pandas_udf.deterministic, False)
+        self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
+        nondeterministic_pandas_udf = self.spark.catalog.registerFunction(
+            "randomPandasUDF", random_pandas_udf)
+        self.assertEqual(nondeterministic_pandas_udf.deterministic, False)
+        self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
+        [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect()
+        self.assertEqual(row[0], 7)
+
     def test_vectorized_udf_null_boolean(self):
         from pyspark.sql.functions import pandas_udf, col
         data = [(True,), (True,), (None,), (False,)]
@@ -4085,14 +4113,14 @@ class VectorizedUDFTests(ReusedSQLTestCase):
         finally:
             self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
-    def test_nondeterministic_udf(self):
+    def test_nondeterministic_vectorized_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
         from pyspark.sql.functions import udf, pandas_udf, col
 
         @pandas_udf('double')
         def plus_ten(v):
             return v + 10
-        random_udf = self.random_udf
+        random_udf = self.nondeterministic_vectorized_udf
 
         df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
         result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()
@@ -4100,11 +4128,11 @@ class VectorizedUDFTests(ReusedSQLTestCase):
         self.assertEqual(random_udf.deterministic, False)
         self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))
 
-    def test_nondeterministic_udf_in_aggregate(self):
+    def test_nondeterministic_vectorized_udf_in_aggregate(self):
         from pyspark.sql.functions import pandas_udf, sum
 
         df = self.spark.range(10)
-        random_udf = self.random_udf
+        random_udf = self.nondeterministic_vectorized_udf
 
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
@@ -4112,6 +4140,23 @@ class VectorizedUDFTests(ReusedSQLTestCase):
             with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
                 df.agg(sum(random_udf(df.id))).collect()
 
+    def test_register_vectorized_udf_basic(self):
+        from pyspark.rdd import PythonEvalType
+        from pyspark.sql.functions import pandas_udf, col, expr
+        df = self.spark.range(10).select(
+            col('id').cast('int').alias('a'),
+            col('id').cast('int').alias('b'))
+        original_add = pandas_udf(lambda x, y: x + y, IntegerType())
+        self.assertEqual(original_add.deterministic, True)
+        self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
+        new_add = self.spark.catalog.registerFunction("add1", original_add)
+        res1 = df.select(new_add(col('a'), col('b')))
+        res2 = self.spark.sql(
+            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t")
+        expected = df.select(expr('a + b'))
+        self.assertEquals(expected.collect(), res1.collect())
+        self.assertEquals(expected.collect(), res2.collect())
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class GroupbyApplyTests(ReusedSQLTestCase):
@@ -4147,6 +4192,15 @@ class GroupbyApplyTests(ReusedSQLTestCase):
         expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
         self.assertFramesEqual(expected, result)
 
+    def test_register_group_map_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP)
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or '
+                                                     'SQL_PANDAS_SCALAR_UDF'):
+                self.spark.catalog.registerFunction("foo_udf", foo_udf)
+
     def test_decorator(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
         df = self.data

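The new `test_register_vectorized_udf_basic` case captures the end-to-end flow this patch enables. A condensed, standalone sketch of that test, assuming a SparkSession `spark` with pandas and pyarrow installed (the names follow the test but the snippet itself is illustrative):

```Python
>>> from pyspark.sql.functions import pandas_udf
>>> from pyspark.sql.types import IntegerType

>>> # A deterministic scalar pandas UDF, registered under the SQL name "add1"
>>> original_add = pandas_udf(lambda x, y: x + y, IntegerType())
>>> new_add = spark.udf.register("add1", original_add)
>>> (new_add.deterministic, new_add.evalType == original_add.evalType)
(True, True)

>>> # Once registered, the same UDF is callable from SQL
>>> [r[0] for r in spark.sql("SELECT add1(id, id) FROM range(3)").collect()]
[0, 2, 4]
```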
