Posted to commits@spark.apache.org by xi...@apache.org on 2022/08/25 18:30:40 UTC
[spark] branch master updated: [SPARK-40130][PYTHON] Support NumPy scalars in built-in functions
This is an automated email from the ASF dual-hosted git repository.
xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fde833c5326 [SPARK-40130][PYTHON] Support NumPy scalars in built-in functions
fde833c5326 is described below
commit fde833c532630092204dc54299702676e1de8b74
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Thu Aug 25 11:30:18 2022 -0700
[SPARK-40130][PYTHON] Support NumPy scalars in built-in functions
### What changes were proposed in this pull request?
Support NumPy scalars in built-in functions by introducing a Py4J input converter, `NumpyScalarConverter`.
Specifically,
- `np.int8, np.int16, np.int32, np.int64` are mapped to Spark `int/bigint`.
- `np.float32, np.float64` are mapped to Spark `double`.
Note that 2147483648 (one past `java.lang.Integer`'s maximum of 2147483647) is the boundary between Spark `int` and `bigint`:
```py
>>> df.select(lit(np.int64(max_int + 1))).dtypes
[('2147483648', 'bigint')]
>>> df.select(lit(np.int64(max_int))).dtypes
[('2147483647', 'int')]
```
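Under the hood, the converter simply unwraps a NumPy scalar into the corresponding builtin Python value via `np.generic.item()` before it crosses Py4J; a minimal illustration of that unwrapping (plain NumPy, independent of Spark):
```py
>>> import numpy as np
>>> np.int64(2147483648).item(), type(np.int64(2147483648).item())
(2147483648, <class 'int'>)
>>> np.float32(1).item(), type(np.float32(1).item())
(1.0, <class 'float'>)
```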
### Why are the changes needed?
As part of [SPARK-39405](https://issues.apache.org/jira/browse/SPARK-39405) for NumPy support in SQL.
### Does this PR introduce _any_ user-facing change?
Yes. NumPy scalars are now supported in built-in functions whose input parameters accept scalars.
Affected functions include `lit`, `when`, `array_contains`, `array_position`, `element_at`, and `array_remove`.
Take `lit` as an example:
```py
>>> df.select(lit(np.int8(1))).dtypes
[('1', 'int')]
>>> df.select(lit(np.float32(1))).dtypes
[('1.0', 'double')]
```
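The array functions behave the same way; a short sketch mirroring the new unit tests (assumes an active `SparkSession` named `spark`):
```py
>>> import numpy as np
>>> from pyspark.sql.functions import array_contains, array_position
>>> df = spark.createDataFrame([([1, 2, 3],), ([],)], ["data"])
>>> df.select(array_contains(df.data, np.int32(1)).alias("b")).collect()
[Row(b=True), Row(b=False)]
>>> df.select(array_position(df.data, np.int32(1)).alias("c")).collect()
[Row(c=1), Row(c=0)]
```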
### How was this patch tested?
Unit tests.
Closes #37560 from xinrong-meng/builtin_np.
Authored-by: Xinrong Meng <xi...@apache.org>
Signed-off-by: Xinrong Meng <xi...@apache.org>
---
python/pyspark/sql/tests/test_functions.py | 29 +++++++++++++++++++++++++++++
python/pyspark/sql/types.py | 13 +++++++++++++
python/pyspark/sql/utils.py | 9 +++++++++
3 files changed, 51 insertions(+)
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 71c6bc33dbb..102ebef8317 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -20,6 +20,7 @@ from inspect import getmembers, isfunction
from itertools import chain
import re
import math
+import unittest
from py4j.protocol import Py4JJavaError
from pyspark.sql import Row, Window, types
@@ -55,6 +56,7 @@ from pyspark.sql.functions import (
)
from pyspark.sql import functions
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.utils import have_numpy
class FunctionsTests(ReusedSQLTestCase):
@@ -974,6 +976,33 @@ class FunctionsTests(ReusedSQLTestCase):
)
)
+ @unittest.skipIf(not have_numpy, "NumPy not installed")
+ def test_np_scalar_input(self):
+ import numpy as np
+ from pyspark.sql.functions import array_contains, array_position
+
+ df = self.spark.createDataFrame([([1, 2, 3],), ([],)], ["data"])
+ for dtype in [np.int8, np.int16, np.int32, np.int64]:
+ self.assertEqual(df.select(lit(dtype(1))).dtypes, [("1", "int")])
+ res = df.select(array_contains(df.data, dtype(1)).alias("b")).collect()
+ self.assertEqual([Row(b=True), Row(b=False)], res)
+ res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
+ self.assertEqual([Row(c=1), Row(c=0)], res)
+
+ # java.lang.Integer max: 2147483647
+ max_int = 2147483647
+ # Convert int to bigint automatically
+ self.assertEqual(df.select(lit(np.int32(max_int))).dtypes, [("2147483647", "int")])
+ self.assertEqual(df.select(lit(np.int64(max_int + 1))).dtypes, [("2147483648", "bigint")])
+
+ df = self.spark.createDataFrame([([1.0, 2.0, 3.0],), ([],)], ["data"])
+ for dtype in [np.float32, np.float64]:
+ self.assertEqual(df.select(lit(dtype(1))).dtypes, [("1.0", "double")])
+ res = df.select(array_contains(df.data, dtype(1)).alias("b")).collect()
+ self.assertEqual([Row(b=True), Row(b=False)], res)
+ res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
+ self.assertEqual([Row(c=1), Row(c=0)], res)
+
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 8255aca8f52..c1e6a738bc6 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -49,6 +49,10 @@ from py4j.protocol import register_input_converter
from py4j.java_gateway import GatewayClient, JavaClass, JavaObject
from pyspark.serializers import CloudPickleSerializer
+from pyspark.sql.utils import has_numpy
+
+if has_numpy:
+ import numpy as np
T = TypeVar("T")
U = TypeVar("U")
@@ -2256,11 +2260,20 @@ class DayTimeIntervalTypeConverter:
)
+class NumpyScalarConverter:
+ def can_convert(self, obj: Any) -> bool:
+ return has_numpy and isinstance(obj, np.generic)
+
+ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
+ return obj.item()
+
+
# datetime is a subclass of date, we should register DatetimeConverter first
register_input_converter(DatetimeNTZConverter())
register_input_converter(DatetimeConverter())
register_input_converter(DateConverter())
register_input_converter(DayTimeIntervalTypeConverter())
+register_input_converter(NumpyScalarConverter())
def _test() -> None:
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index e4a0299164e..2ff13cd2bba 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -29,6 +29,15 @@ from py4j.protocol import Py4JJavaError
from pyspark import SparkContext
from pyspark.find_spark_home import _find_spark_home
+has_numpy = False
+try:
+ import numpy as np # noqa: F401
+
+ has_numpy = True
+except ImportError:
+ pass
+
+
if TYPE_CHECKING:
from pyspark.sql.session import SparkSession
from pyspark.sql.dataframe import DataFrame