Posted to commits@spark.apache.org by ue...@apache.org on 2018/01/18 05:51:12 UTC

spark git commit: [SPARK-23122][PYTHON][SQL] Deprecate register* for UDFs in SQLContext and Catalog in PySpark

Repository: spark
Updated Branches:
  refs/heads/master 021947020 -> 39d244d92


[SPARK-23122][PYTHON][SQL] Deprecate register* for UDFs in SQLContext and Catalog in PySpark

## What changes were proposed in this pull request?

This PR proposes to deprecate `register*` for UDFs in `SQLContext` and `Catalog` in Spark 2.3.0.

These are inconsistent with the Scala / Java APIs, and they basically do the same thing as `spark.udf.register*`.

Also, this PR moves the logic from `[sqlContext|spark.catalog].register*` to `spark.udf.register*` and reuses the docstrings.

This PR also handles minor doc corrections and includes https://github.com/apache/spark/pull/20158
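
For illustration, a minimal sketch of the user-facing change, assuming a running `SparkSession` named `spark` and its `SQLContext` wrapper `sqlContext` (the function names below are illustrative only):

```python
from pyspark.sql.types import IntegerType

# Before (deprecated in 2.3.0): registering through SQLContext / Catalog.
sqlContext.registerFunction("strlen_old", lambda s: len(s), IntegerType())
spark.catalog.registerFunction("strlen_old2", lambda s: len(s), IntegerType())

# After: the single entry point, consistent with the Scala / Java APIs.
strlen = spark.udf.register("strlen", lambda s: len(s), IntegerType())
spark.sql("SELECT strlen('test')").collect()  # [Row(strlen(test)=4)]
```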

## How was this patch tested?

Manually tested, manually checked the API documentation, and added tests to check that the deprecated APIs call their aliases correctly.
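
As a rough sketch of what such a check could look like (illustrative only, not the PR's actual test code; assumes the same `spark` / `sqlContext` globals as above):

```python
import warnings
from pyspark.sql.types import IntegerType

# The deprecated SQLContext API should still work, emit a DeprecationWarning,
# and delegate to spark.udf.register under the hood.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sqlContext.registerFunction("one_arg", lambda x: len(x), IntegerType())
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

[row] = sqlContext.sql("SELECT one_arg('test')").collect()
assert row[0] == 4
```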

Author: hyukjinkwon <gu...@gmail.com>

Closes #20288 from HyukjinKwon/deprecate-udf.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/39d244d9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/39d244d9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/39d244d9

Branch: refs/heads/master
Commit: 39d244d921d8d2d3ed741e8e8f1175515a74bdbd
Parents: 0219470
Author: hyukjinkwon <gu...@gmail.com>
Authored: Thu Jan 18 14:51:05 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Thu Jan 18 14:51:05 2018 +0900

----------------------------------------------------------------------
 dev/sparktestsupport/modules.py |   1 +
 python/pyspark/sql/catalog.py   |  91 ++----------------
 python/pyspark/sql/context.py   | 137 ++++----------------------
 python/pyspark/sql/functions.py |   4 +-
 python/pyspark/sql/group.py     |   3 +-
 python/pyspark/sql/session.py   |   6 +-
 python/pyspark/sql/tests.py     |  20 ++++
 python/pyspark/sql/udf.py       | 182 ++++++++++++++++++++++++++++++++++-
 8 files changed, 234 insertions(+), 210 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/39d244d9/dev/sparktestsupport/modules.py
----------------------------------------------------------------------
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 7164180..b900f0b 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -400,6 +400,7 @@ pyspark_sql = Module(
         "pyspark.sql.functions",
         "pyspark.sql.readwriter",
         "pyspark.sql.streaming",
+        "pyspark.sql.udf",
         "pyspark.sql.window",
         "pyspark.sql.tests",
     ]

http://git-wip-us.apache.org/repos/asf/spark/blob/39d244d9/python/pyspark/sql/catalog.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 35fbe9e..6aef0f2 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -224,92 +224,17 @@ class Catalog(object):
         """
         self._jcatalog.dropGlobalTempView(viewName)
 
-    @ignore_unicode_prefix
     @since(2.0)
     def registerFunction(self, name, f, returnType=None):
-        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
-        as a UDF. The registered UDF can be used in SQL statements.
-
-        :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`.
-
-        In addition to a name and the function itself, `returnType` can be optionally specified.
-        1) When f is a Python function, `returnType` defaults to a string. The produced object must
-        match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
-        type of the given UDF as the return type of the registered UDF. The input parameter
-        `returnType` is None by default. If given by users, the value must be None.
-
-        :param name: name of the UDF in SQL statements.
-        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
-            row-at-a-time or vectorized.
-        :param returnType: the return type of the registered UDF.
-        :return: a wrapped/native :class:`UserDefinedFunction`
-
-        >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
-        >>> spark.sql("SELECT stringLengthString('test')").collect()
-        [Row(stringLengthString(test)=u'4')]
-
-        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
-        [Row(stringLengthString(text)=u'3')]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
-        >>> spark.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
-        >>> spark.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> from pyspark.sql.functions import udf
-        >>> slen = udf(lambda s: len(s), IntegerType())
-        >>> _ = spark.udf.register("slen", slen)
-        >>> spark.sql("SELECT slen('test')").collect()
-        [Row(slen(test)=4)]
-
-        >>> import random
-        >>> from pyspark.sql.functions import udf
-        >>> from pyspark.sql.types import IntegerType
-        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
-        >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf)
-        >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
-        [Row(random_udf()=82)]
-        >>> spark.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
-        [Row(<lambda>()=26)]
-
-        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
-        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
-        ... def add_one(x):
-        ...     return x + 1
-        ...
-        >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
-        >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
-        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
-        """
+        """An alias for :func:`spark.udf.register`.
+        See :meth:`pyspark.sql.UDFRegistration.register`.
 
-        # This is to check whether the input function is a wrapped/native UserDefinedFunction
-        if hasattr(f, 'asNondeterministic'):
-            if returnType is not None:
-                raise TypeError(
-                    "Invalid returnType: None is expected when f is a UserDefinedFunction, "
-                    "but got %s." % returnType)
-            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
-                                  PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
-                raise ValueError(
-                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
-            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
-                                               evalType=f.evalType,
-                                               deterministic=f.deterministic)
-            return_udf = f
-        else:
-            if returnType is None:
-                returnType = StringType()
-            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
-                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
-            return_udf = register_udf._wrapped()
-        self._jsparkSession.udf().registerPython(name, register_udf._judf)
-        return return_udf
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
+        """
+        warnings.warn(
+            "Deprecated in 2.3.0. Use spark.udf.register instead.",
+            DeprecationWarning)
+        return self._sparkSession.udf.register(name, f, returnType)
 
     @since(2.0)
     def isCached(self, tableName):

http://git-wip-us.apache.org/repos/asf/spark/blob/39d244d9/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 8547909..cc1cd1a 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -29,9 +29,10 @@ from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
 from pyspark.sql.types import IntegerType, Row, StringType
+from pyspark.sql.udf import UDFRegistration
 from pyspark.sql.utils import install_exception_handler
 
-__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
+__all__ = ["SQLContext", "HiveContext"]
 
 
 class SQLContext(object):
@@ -147,7 +148,7 @@ class SQLContext(object):
 
         :return: :class:`UDFRegistration`
         """
-        return UDFRegistration(self)
+        return self.sparkSession.udf
 
     @since(1.4)
     def range(self, start, end=None, step=1, numPartitions=None):
@@ -172,113 +173,29 @@ class SQLContext(object):
         """
         return self.sparkSession.range(start, end, step, numPartitions)
 
-    @ignore_unicode_prefix
     @since(1.2)
     def registerFunction(self, name, f, returnType=None):
-        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
-        as a UDF. The registered UDF can be used in SQL statements.
-
-        :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`.
-
-        In addition to a name and the function itself, `returnType` can be optionally specified.
-        1) When f is a Python function, `returnType` defaults to a string. The produced object must
-        match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
-        type of the given UDF as the return type of the registered UDF. The input parameter
-        `returnType` is None by default. If given by users, the value must be None.
-
-        :param name: name of the UDF in SQL statements.
-        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
-            row-at-a-time or vectorized.
-        :param returnType: the return type of the registered UDF.
-        :return: a wrapped/native :class:`UserDefinedFunction`
-
-        >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
-        >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
-        [Row(stringLengthString(test)=u'4')]
-
-        >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
-        [Row(stringLengthString(text)=u'3')]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
-        >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
-        >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> from pyspark.sql.functions import udf
-        >>> slen = udf(lambda s: len(s), IntegerType())
-        >>> _ = sqlContext.udf.register("slen", slen)
-        >>> sqlContext.sql("SELECT slen('test')").collect()
-        [Row(slen(test)=4)]
-
-        >>> import random
-        >>> from pyspark.sql.functions import udf
-        >>> from pyspark.sql.types import IntegerType
-        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
-        >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf)
-        >>> sqlContext.sql("SELECT random_udf()").collect()  # doctest: +SKIP
-        [Row(random_udf()=82)]
-        >>> sqlContext.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
-        [Row(<lambda>()=26)]
-
-        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
-        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
-        ... def add_one(x):
-        ...     return x + 1
-        ...
-        >>> _ = sqlContext.udf.register("add_one", add_one)  # doctest: +SKIP
-        >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
-        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
+        """An alias for :func:`spark.udf.register`.
+        See :meth:`pyspark.sql.UDFRegistration.register`.
+
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
         """
-        return self.sparkSession.catalog.registerFunction(name, f, returnType)
+        warnings.warn(
+            "Deprecated in 2.3.0. Use spark.udf.register instead.",
+            DeprecationWarning)
+        return self.sparkSession.udf.register(name, f, returnType)
 
-    @ignore_unicode_prefix
     @since(2.1)
     def registerJavaFunction(self, name, javaClassName, returnType=None):
-        """Register a java UDF so it can be used in SQL statements.
-
-        In addition to a name and the function itself, the return type can be optionally specified.
-        When the return type is not specified we would infer it via reflection.
-        :param name:  name of the UDF
-        :param javaClassName: fully qualified name of java class
-        :param returnType: a :class:`pyspark.sql.types.DataType` object
-
-        >>> sqlContext.registerJavaFunction("javaStringLength",
-        ...   "test.org.apache.spark.sql.JavaStringLength", IntegerType())
-        >>> sqlContext.sql("SELECT javaStringLength('test')").collect()
-        [Row(UDF:javaStringLength(test)=4)]
-        >>> sqlContext.registerJavaFunction("javaStringLength2",
-        ...   "test.org.apache.spark.sql.JavaStringLength")
-        >>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
-        [Row(UDF:javaStringLength2(test)=4)]
+        """An alias for :func:`spark.udf.registerJavaFunction`.
+        See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`.
 
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead.
         """
-        jdt = None
-        if returnType is not None:
-            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
-        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
-
-    @ignore_unicode_prefix
-    @since(2.3)
-    def registerJavaUDAF(self, name, javaClassName):
-        """Register a java UDAF so it can be used in SQL statements.
-
-        :param name:  name of the UDAF
-        :param javaClassName: fully qualified name of java class
-
-        >>> sqlContext.registerJavaUDAF("javaUDAF",
-        ...   "test.org.apache.spark.sql.MyDoubleAvg")
-        >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
-        >>> df.registerTempTable("df")
-        >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
-        [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
-        """
-        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
+        warnings.warn(
+            "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.",
+            DeprecationWarning)
+        return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType)
 
     # TODO(andrew): delete this once we refactor things to take in SparkSession
     def _inferSchema(self, rdd, samplingRatio=None):
@@ -590,24 +507,6 @@ class HiveContext(SQLContext):
         self._ssql_ctx.refreshTable(tableName)
 
 
-class UDFRegistration(object):
-    """Wrapper for user-defined function registration."""
-
-    def __init__(self, sqlContext):
-        self.sqlContext = sqlContext
-
-    def register(self, name, f, returnType=None):
-        return self.sqlContext.registerFunction(name, f, returnType)
-
-    def registerJavaFunction(self, name, javaClassName, returnType=None):
-        self.sqlContext.registerJavaFunction(name, javaClassName, returnType)
-
-    def registerJavaUDAF(self, name, javaClassName):
-        self.sqlContext.registerJavaUDAF(name, javaClassName)
-
-    register.__doc__ = SQLContext.registerFunction.__doc__
-
-
 def _test():
     import os
     import doctest

http://git-wip-us.apache.org/repos/asf/spark/blob/39d244d9/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f7b3f29..988c1d2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2103,7 +2103,7 @@ def udf(f=None, returnType=StringType()):
     >>> import random
     >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
 
-    .. note:: The user-defined functions do not support conditional expressions or short curcuiting
+    .. note:: The user-defined functions do not support conditional expressions or short circuiting
         in boolean expressions and it ends up with being executed all internally. If the functions
         can fail on special rows, the workaround is to incorporate the condition into the functions.
 
@@ -2231,7 +2231,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     ...     return pd.Series(np.random.randn(len(v))
     >>> random = random.asNondeterministic()  # doctest: +SKIP
 
-    .. note:: The user-defined functions do not support conditional expressions or short curcuiting
+    .. note:: The user-defined functions do not support conditional expressions or short circuiting
         in boolean expressions and it ends up with being executed all internally. If the functions
         can fail on special rows, the workaround is to incorporate the condition into the functions.
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/39d244d9/python/pyspark/sql/group.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 09fae46..22061b8 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -212,7 +212,8 @@ class GroupedData(object):
         This function does not support partial aggregation, and requires shuffling all the data in
         the :class:`DataFrame`.
 
-        :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
+        :param udf: a group map user-defined function returned by
+            :meth:`pyspark.sql.functions.pandas_udf`.
 
         >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
         >>> df = spark.createDataFrame(

http://git-wip-us.apache.org/repos/asf/spark/blob/39d244d9/python/pyspark/sql/session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 604021c..6c84023 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -29,7 +29,6 @@ else:
 
 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
-from pyspark.sql.catalog import Catalog
 from pyspark.sql.conf import RuntimeConfig
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
@@ -280,6 +279,7 @@ class SparkSession(object):
 
         :return: :class:`Catalog`
         """
+        from pyspark.sql.catalog import Catalog
         if not hasattr(self, "_catalog"):
             self._catalog = Catalog(self)
         return self._catalog
@@ -291,8 +291,8 @@ class SparkSession(object):
 
         :return: :class:`UDFRegistration`
         """
-        from pyspark.sql.context import UDFRegistration
-        return UDFRegistration(self._wrapped)
+        from pyspark.sql.udf import UDFRegistration
+        return UDFRegistration(self)
 
     @since(2.0)
     def range(self, start, end=None, step=1, numPartitions=None):

http://git-wip-us.apache.org/repos/asf/spark/blob/39d244d9/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8906618..f84aa3d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -372,6 +372,12 @@ class SQLTests(ReusedSQLTestCase):
         [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
         self.assertEqual(row[0], 5)
 
+        # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
+        sqlContext = self.spark._wrapped
+        sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
+        [row] = sqlContext.sql("SELECT oneArg('test')").collect()
+        self.assertEqual(row[0], 4)
+
     def test_udf2(self):
         self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
         self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
@@ -577,11 +583,25 @@ class SQLTests(ReusedSQLTestCase):
             df.select(add_three("id").alias("plus_three")).collect()
         )
 
+        # This is to check if a 'SQLContext.udf' can call its alias.
+        sqlContext = self.spark._wrapped
+        add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
+
+        self.assertListEqual(
+            df.selectExpr("add_four(id) AS plus_four").collect(),
+            df.select(add_four("id").alias("plus_four")).collect()
+        )
+
     def test_non_existed_udf(self):
         spark = self.spark
         self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
                                 lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
 
+        # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
+        sqlContext = spark._wrapped
+        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
+                                lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
+
     def test_non_existed_udaf(self):
         spark = self.spark
         self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",

http://git-wip-us.apache.org/repos/asf/spark/blob/39d244d9/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 5e80ab9..1943bb7 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -19,11 +19,13 @@ User-defined function related classes and functions
 """
 import functools
 
-from pyspark import SparkContext
-from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
+from pyspark import SparkContext, since
+from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string
 
+__all__ = ["UDFRegistration"]
+
 
 def _wrap_function(sc, func, returnType):
     command = (func, returnType)
@@ -181,3 +183,179 @@ class UserDefinedFunction(object):
         """
         self.deterministic = False
         return self
+
+
+class UDFRegistration(object):
+    """
+    Wrapper for user-defined function registration. This instance can be accessed by
+    :attr:`spark.udf` or :attr:`sqlContext.udf`.
+
+    .. versionadded:: 1.3.1
+    """
+
+    def __init__(self, sparkSession):
+        self.sparkSession = sparkSession
+
+    @ignore_unicode_prefix
+    @since("1.3.1")
+    def register(self, name, f, returnType=None):
+        """Registers a Python function (including lambda function) or a user-defined function
+        in SQL statements.
+
+        :param name: name of the user-defined function in SQL statements.
+        :param f: a Python function, or a user-defined function. The user-defined function can
+            be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
+            :meth:`pyspark.sql.functions.pandas_udf`.
+        :param returnType: the return type of the registered user-defined function.
+        :return: a user-defined function.
+
+        `returnType` can be optionally specified when `f` is a Python function but not
+        when `f` is a user-defined function. Please see below.
+
+        1. When `f` is a Python function:
+
+            `returnType` defaults to string type and can be optionally specified. The produced
+            object must match the specified type. In this case, this API works as if
+            `register(name, f, returnType=StringType())`.
+
+            >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
+            >>> spark.sql("SELECT stringLengthString('test')").collect()
+            [Row(stringLengthString(test)=u'4')]
+
+            >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+            [Row(stringLengthString(text)=u'3')]
+
+            >>> from pyspark.sql.types import IntegerType
+            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+            >>> spark.sql("SELECT stringLengthInt('test')").collect()
+            [Row(stringLengthInt(test)=4)]
+
+            >>> from pyspark.sql.types import IntegerType
+            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+            >>> spark.sql("SELECT stringLengthInt('test')").collect()
+            [Row(stringLengthInt(test)=4)]
+
+        2. When `f` is a user-defined function:
+
+            Spark uses the return type of the given user-defined function as the return type of
+            the registered user-defined function. `returnType` should not be specified.
+            In this case, this API works as if `register(name, f)`.
+
+            >>> from pyspark.sql.types import IntegerType
+            >>> from pyspark.sql.functions import udf
+            >>> slen = udf(lambda s: len(s), IntegerType())
+            >>> _ = spark.udf.register("slen", slen)
+            >>> spark.sql("SELECT slen('test')").collect()
+            [Row(slen(test)=4)]
+
+            >>> import random
+            >>> from pyspark.sql.functions import udf
+            >>> from pyspark.sql.types import IntegerType
+            >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+            >>> new_random_udf = spark.udf.register("random_udf", random_udf)
+            >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
+            [Row(random_udf()=82)]
+
+            >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+            >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
+            ... def add_one(x):
+            ...     return x + 1
+            ...
+            >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
+            >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
+            [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
+
+            .. note:: Registration of a user-defined function (case 2.) was added in
+                Spark 2.3.0.
+        """
+
+        # This is to check whether the input function is a user-defined function or a
+        # Python function.
+        if hasattr(f, 'asNondeterministic'):
+            if returnType is not None:
+                raise TypeError(
+                    "Invalid returnType: data type can not be specified when f is"
+                    "a user-defined function, but got %s." % returnType)
+            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
+                                  PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
+                raise ValueError(
+                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
+            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
+                                               evalType=f.evalType,
+                                               deterministic=f.deterministic)
+            return_udf = f
+        else:
+            if returnType is None:
+                returnType = StringType()
+            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
+                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
+            return_udf = register_udf._wrapped()
+        self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
+        return return_udf
+
+    @ignore_unicode_prefix
+    @since(2.3)
+    def registerJavaFunction(self, name, javaClassName, returnType=None):
+        """Register a Java user-defined function so it can be used in SQL statements.
+
+        In addition to a name and the function itself, the return type can be optionally specified.
+        When the return type is not specified, we would infer it via reflection.
+
+        :param name: name of the user-defined function
+        :param javaClassName: fully qualified name of java class
+        :param returnType: a :class:`pyspark.sql.types.DataType` object
+
+        >>> from pyspark.sql.types import IntegerType
+        >>> spark.udf.registerJavaFunction(
+        ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
+        >>> spark.sql("SELECT javaStringLength('test')").collect()
+        [Row(UDF:javaStringLength(test)=4)]
+        >>> spark.udf.registerJavaFunction(
+        ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
+        >>> spark.sql("SELECT javaStringLength2('test')").collect()
+        [Row(UDF:javaStringLength2(test)=4)]
+        """
+
+        jdt = None
+        if returnType is not None:
+            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
+        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
+
+    @ignore_unicode_prefix
+    @since(2.3)
+    def registerJavaUDAF(self, name, javaClassName):
+        """Register a Java user-defined aggregate function so it can be used in SQL statements.
+
+        :param name: name of the user-defined aggregate function
+        :param javaClassName: fully qualified name of java class
+
+        >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
+        >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
+        >>> df.registerTempTable("df")
+        >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
+        [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
+        """
+
+        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
+
+
+def _test():
+    import doctest
+    from pyspark.sql import SparkSession
+    import pyspark.sql.udf
+    globs = pyspark.sql.udf.__dict__.copy()
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.udf tests")\
+        .getOrCreate()
+    globs['spark'] = spark
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.udf, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+    spark.stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org