Posted to commits@spark.apache.org by xi...@apache.org on 2022/08/25 18:30:40 UTC

[spark] branch master updated: [SPARK-40130][PYTHON] Support NumPy scalars in built-in functions

This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fde833c5326 [SPARK-40130][PYTHON] Support NumPy scalars in built-in functions
fde833c5326 is described below

commit fde833c532630092204dc54299702676e1de8b74
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Thu Aug 25 11:30:18 2022 -0700

    [SPARK-40130][PYTHON] Support NumPy scalars in built-in functions
    
    ### What changes were proposed in this pull request?
    Support NumPy scalars in built-in functions by introducing a Py4J input converter, `NumpyScalarConverter`.
    
    Specifically,
    - `np.int8, np.int16, np.int32, np.int64` are mapped to Spark `int/bigint`.
    - `np.float32, np.float64` are mapped to Spark `double`.
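    
    As a quick illustration of the conversion (not part of the patch itself), `NumpyScalarConverter.convert` unwraps a NumPy scalar into the corresponding Python built-in via `numpy.generic.item()`, which Py4J then handles like any other Python scalar:
    ```py
    >>> import numpy as np
    >>> type(np.int64(1).item())
    <class 'int'>
    >>> type(np.float32(1.0).item())
    <class 'float'>
    ```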
    
    Note that Spark `int` tops out at 2147483647; values beyond that map to `bigint`:
    ```py
    >>> df.select(lit(np.int64(max_int + 1))).dtypes
    [('2147483648', 'bigint')]
    >>> df.select(lit(np.int64(max_int))).dtypes
    [('2147483647', 'int')]
    ```
    
    ### Why are the changes needed?
    As part of [SPARK-39405](https://issues.apache.org/jira/browse/SPARK-39405) for NumPy support in SQL.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. NumPy scalars are now accepted by built-in functions wherever an input parameter takes a scalar;
    affected functions include `lit`, `when`, `array_contains`, `array_position`, `element_at`, and `array_remove`.
    
    Take `lit` for example:
    ```py
    >>> df.select(lit(np.int8(1))).dtypes
    [('1', 'int')]
    >>> df.select(lit(np.float32(1))).dtypes
    [('1.0', 'double')]
    ```
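    
    The array functions behave the same way. A minimal illustrative snippet (not from the patch; assumes `np` is NumPy and `spark` is an active session):
    ```py
    >>> from pyspark.sql.functions import array_remove
    >>> df = spark.createDataFrame([([1, 2, 3],)], ["data"])
    >>> df.select(array_remove(df.data, np.int64(2)).alias("r")).collect()
    [Row(r=[1, 3])]
    ```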
    
    ### How was this patch tested?
    Unit test `test_np_scalar_input` added to `python/pyspark/sql/tests/test_functions.py`.
    
    Closes #37560 from xinrong-meng/builtin_np.
    
    Authored-by: Xinrong Meng <xi...@apache.org>
    Signed-off-by: Xinrong Meng <xi...@apache.org>
---
 python/pyspark/sql/tests/test_functions.py | 29 +++++++++++++++++++++++++++++
 python/pyspark/sql/types.py                | 13 +++++++++++++
 python/pyspark/sql/utils.py                |  9 +++++++++
 3 files changed, 51 insertions(+)

diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 71c6bc33dbb..102ebef8317 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -20,6 +20,7 @@ from inspect import getmembers, isfunction
 from itertools import chain
 import re
 import math
+import unittest
 
 from py4j.protocol import Py4JJavaError
 from pyspark.sql import Row, Window, types
@@ -55,6 +56,7 @@ from pyspark.sql.functions import (
 )
 from pyspark.sql import functions
 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.utils import have_numpy
 
 
 class FunctionsTests(ReusedSQLTestCase):
@@ -974,6 +976,33 @@ class FunctionsTests(ReusedSQLTestCase):
             )
         )
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_np_scalar_input(self):
+        import numpy as np
+        from pyspark.sql.functions import array_contains, array_position
+
+        df = self.spark.createDataFrame([([1, 2, 3],), ([],)], ["data"])
+        for dtype in [np.int8, np.int16, np.int32, np.int64]:
+            self.assertEqual(df.select(lit(dtype(1))).dtypes, [("1", "int")])
+            res = df.select(array_contains(df.data, dtype(1)).alias("b")).collect()
+            self.assertEqual([Row(b=True), Row(b=False)], res)
+            res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
+            self.assertEqual([Row(c=1), Row(c=0)], res)
+
+        # java.lang.Integer max: 2147483647
+        max_int = 2147483647
+        # Convert int to bigint automatically
+        self.assertEqual(df.select(lit(np.int32(max_int))).dtypes, [("2147483647", "int")])
+        self.assertEqual(df.select(lit(np.int64(max_int + 1))).dtypes, [("2147483648", "bigint")])
+
+        df = self.spark.createDataFrame([([1.0, 2.0, 3.0],), ([],)], ["data"])
+        for dtype in [np.float32, np.float64]:
+            self.assertEqual(df.select(lit(dtype(1))).dtypes, [("1.0", "double")])
+            res = df.select(array_contains(df.data, dtype(1)).alias("b")).collect()
+            self.assertEqual([Row(b=True), Row(b=False)], res)
+            res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
+            self.assertEqual([Row(c=1), Row(c=0)], res)
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 8255aca8f52..c1e6a738bc6 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -49,6 +49,10 @@ from py4j.protocol import register_input_converter
 from py4j.java_gateway import GatewayClient, JavaClass, JavaObject
 
 from pyspark.serializers import CloudPickleSerializer
+from pyspark.sql.utils import has_numpy
+
+if has_numpy:
+    import numpy as np
 
 T = TypeVar("T")
 U = TypeVar("U")
@@ -2256,11 +2260,20 @@ class DayTimeIntervalTypeConverter:
         )
 
 
+class NumpyScalarConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.generic)
+
+    def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
+        return obj.item()
+
+
 # datetime is a subclass of date, we should register DatetimeConverter first
 register_input_converter(DatetimeNTZConverter())
 register_input_converter(DatetimeConverter())
 register_input_converter(DateConverter())
 register_input_converter(DayTimeIntervalTypeConverter())
+register_input_converter(NumpyScalarConverter())
 
 
 def _test() -> None:
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index e4a0299164e..2ff13cd2bba 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -29,6 +29,15 @@ from py4j.protocol import Py4JJavaError
 from pyspark import SparkContext
 from pyspark.find_spark_home import _find_spark_home
 
+has_numpy = False
+try:
+    import numpy as np  # noqa: F401
+
+    has_numpy = True
+except ImportError:
+    pass
+
+
 if TYPE_CHECKING:
     from pyspark.sql.session import SparkSession
     from pyspark.sql.dataframe import DataFrame

