Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/08/24 00:31:01 UTC

[GitHub] [spark] xinrong-meng opened a new pull request, #37635: [WIP] Support NumPy arrays in built-in functions

xinrong-meng opened a new pull request, #37635:
URL: https://github.com/apache/spark/pull/37635

   ### What changes were proposed in this pull request?
   
   
   
   ### Why are the changes needed?
   
   
   ### Does this PR introduce _any_ user-facing change?
   
   
   
   ### How was this patch tested?
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] xinrong-meng commented on pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on PR #37635:
URL: https://github.com/apache/spark/pull/37635#issuecomment-1244036109

   Thank you all! Merged to master.



[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r956438632


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2268,40 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+
+        plist = obj.tolist()
+        # np.array([]).dtype is dtype('float64') so set float for empty plist
+        ptpe = type(plist[0]) if len(plist) > 0 else float
+        tpe_dict = {
+            int: gateway.jvm.int,

Review Comment:
   Let me know if there is a better approach :)
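The empty-array fallback in this diff depends on two NumPy behaviors. An array built from `[]` with no explicit dtype is `float64`, and `tolist()` on it returns an empty list, so there is no first element to inspect. A quick check, assuming NumPy is installed:

```python
import numpy as np

empty = np.array([])
# An empty array created without an explicit dtype defaults to float64.
print(empty.dtype)  # float64

# tolist() yields an empty Python list, so plist[0] cannot be used to infer a type.
plist = empty.tolist()
print(plist)  # []

# The fallback from the diff: inspect the first element if there is one,
# otherwise assume float (matching the float64 default above).
ptpe = type(plist[0]) if len(plist) > 0 else float
print(ptpe)  # <class 'float'>
```

One caveat of inferring from the first element: `np.array([], dtype=np.int32).tolist()` is also `[]`, so an explicitly typed empty array would still fall back to `float` here.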




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r959088849


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2268,40 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+
+        plist = obj.tolist()
+        # np.array([]).dtype is dtype('float64') so set float for empty plist
+        ptpe = type(plist[0]) if len(plist) > 0 else float
+        tpe_dict = {
+            int: gateway.jvm.int,

Review Comment:
   And I believe we already have the type mapping defined somewhere in the pandas API on Spark, IIRC.




[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [WIP] Support NumPy arrays in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r953211654


##########
python/pyspark/sql/tests/test_functions.py:
##########
@@ -1003,6 +1003,30 @@ def test_np_scalar_input(self):
             res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
             self.assertEqual([Row(c=1), Row(c=0)], res)
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_ndarray_input(self):
+        import numpy as np
+
+        int_arrs = [np.array([1, 2]).astype(t) for t in ["int8", "int16", "int32", "int64"]]
+        for arr in int_arrs:
+            self.assertEqual(
+                [Row(b=[1, 2])], self.spark.range(1).select(lit(arr).alias("b")).collect()
+            )
+
+        float_arrs = [np.array([0.1, 0.2]).astype(t) for t in ["float32", "float64"]]
+
+        self.assertEqual(

Review Comment:
   We cannot write
   ```
   self.assertEqual(Row(b=[0.10000000149011612, 0.20000000298023224]), self.spark.range(1).select(lit(float_arrs[0]).alias("b")).collect())
   ```
   because it fails with:
   ```
   AssertionError: Row(b=[0.10000000149011612, 0.20000000298023224]) != [Row(b=[0.10000000149011612, 0.20000000298023224])]
   ```
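The long decimals in this error come from widening `float32` to double: `tolist()` converts each `float32` element to a Python `float` (a 64-bit double), which exposes the binary approximations of 0.1 and 0.2. A minimal check, assuming NumPy:

```python
import numpy as np

arr = np.array([0.1, 0.2]).astype("float32")
# Each float32 element becomes a Python float (a double) in tolist().
widened = arr.tolist()
print(widened)  # [0.10000000149011612, 0.20000000298023224]

# The widened values no longer equal the double literals 0.1 and 0.2.
print(widened == [0.1, 0.2])  # False
```

Separately, the message shows a bare `Row` on the left and a one-element list on the right, because `collect()` returns a list of `Row`s.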
   




[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r957664207


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2268,40 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)

Review Comment:
   Good point! Added `ndim` check.




[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r958558484


##########
python/pyspark/sql/tests/test_functions.py:
##########
@@ -1003,6 +1003,30 @@ def test_np_scalar_input(self):
             res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
             self.assertEqual([Row(c=1), Row(c=0)], res)
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_ndarray_input(self):
+        import numpy as np
+
+        int_arrs = [np.array([1, 2]).astype(t) for t in ["int8", "int16", "int32", "int64"]]
+        for arr in int_arrs:
+            self.assertEqual(
+                [Row(b=[1, 2])], self.spark.range(1).select(lit(arr).alias("b")).collect()
+            )
+
+        float_arrs = [np.array([0.1, 0.2]).astype(t) for t in ["float32", "float64"]]
+
+        self.assertEqual(
+            [("b", "array<double>")],
+            self.spark.range(1).select(lit(float_arrs[0]).alias("b")).dtypes,
+        )
+        self.assertEqual(

Review Comment:
   Thank you!




[GitHub] [spark] AmplabJenkins commented on pull request #37635: [WIP][SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
AmplabJenkins commented on PR #37635:
URL: https://github.com/apache/spark/pull/37635#issuecomment-1225785726

   Can one of the admins verify this patch?



[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r957688493


##########
python/pyspark/sql/tests/test_functions.py:
##########
@@ -1003,6 +1003,30 @@ def test_np_scalar_input(self):
             res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
             self.assertEqual([Row(c=1), Row(c=0)], res)
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_ndarray_input(self):
+        import numpy as np
+
+        int_arrs = [np.array([1, 2]).astype(t) for t in ["int8", "int16", "int32", "int64"]]
+        for arr in int_arrs:
+            self.assertEqual(
+                [Row(b=[1, 2])], self.spark.range(1).select(lit(arr).alias("b")).collect()
+            )
+
+        float_arrs = [np.array([0.1, 0.2]).astype(t) for t in ["float32", "float64"]]
+
+        self.assertEqual(
+            [("b", "array<double>")],
+            self.spark.range(1).select(lit(float_arrs[0]).alias("b")).dtypes,
+        )
+        self.assertEqual(

Review Comment:
   My understanding is that array functions like `concat` / `array_intersect` would not accept NumPy array input.
   Do they accept Java arrays on the Scala side?




[GitHub] [spark] zhengruifeng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r962407136


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2268,48 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+        plist = obj.tolist()
+        tpe_np_to_java = {

Review Comment:
   Nit: what about moving this dict outside of `convert` so it can be reused?
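A minimal sketch of this suggestion, assuming the mapping is keyed by NumPy dtype and defined once at module level. The names `NUMPY_TO_JAVA_TYPE_NAME` and `java_type_for` are hypothetical. The Java types are stored as name strings because `gateway.jvm.*` objects only exist once a Py4J gateway is running, so the name is resolved against the gateway at call time (assuming Py4J resolves `byte`, `short`, `long` and `float` the same way it resolves `gateway.jvm.int`).

```python
import numpy as np

# Built once at import time and reused by every call (hypothetical mapping).
NUMPY_TO_JAVA_TYPE_NAME = {
    np.dtype("int8"): "byte",
    np.dtype("int16"): "short",
    np.dtype("int32"): "int",
    np.dtype("int64"): "long",
    np.dtype("float32"): "float",
    np.dtype("float64"): "double",
    np.dtype("bool"): "boolean",
}


def java_type_for(dtype, gateway=None):
    """Return the Java element type for a NumPy dtype, or None if unsupported.

    With a Py4J-style gateway, the name is looked up as an attribute of
    gateway.jvm (e.g. gateway.jvm.int); without one, the bare name is returned.
    """
    name = NUMPY_TO_JAVA_TYPE_NAME.get(np.dtype(dtype))
    if name is None or gateway is None:
        return name
    return getattr(gateway.jvm, name)
```

Returning `None` for unknown dtypes leaves the caller free to raise a `TypeError` with a clear message.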




[GitHub] [spark] HyukjinKwon commented on pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on PR #37635:
URL: https://github.com/apache/spark/pull/37635#issuecomment-1237564774

   Merged to master.



[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r964005049


##########
python/pyspark/sql/types.py:
##########
@@ -1447,6 +1447,26 @@ def _from_numpy_type(nt: "np.dtype") -> Optional[DataType]:
     return None
 
 
+def _from_numpy_type_to_java_type(nt: "np.dtype", gateway: JavaGateway) -> Optional[JavaClass]:

Review Comment:
   Did you mean an instance method?




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r955524424


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2268,40 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+
+        plist = obj.tolist()
+        # np.array([]).dtype is dtype('float64') so set float for empty plist
+        ptpe = type(plist[0]) if len(plist) > 0 else float
+        tpe_dict = {
+            int: gateway.jvm.int,

Review Comment:
   Shouldn't we map this type from the NumPy dtype?




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r959088633


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2268,40 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+
+        plist = obj.tolist()
+        # np.array([]).dtype is dtype('float64') so set float for empty plist
+        ptpe = type(plist[0]) if len(plist) > 0 else float
+        tpe_dict = {
+            int: gateway.jvm.int,

Review Comment:
   Hm, unlike `obj.item`, where we have to pass a Python primitive type (so the element precision on the JVM side cannot be specified), here we can give the JVM array an element type of the correct size.
   
   I think it's better to have the correct element type. Ideally, we should make `obj.item` respect the NumPy dtype too.
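To illustrate the precision point: after `tolist()`, every integer dtype collapses to Python `int`, so a mapping keyed on Python types can only ever pick one Java integer type, whereas `obj.dtype` still distinguishes the sizes. A small check, assuming NumPy:

```python
import numpy as np

int_types = ["int8", "int16", "int32", "int64"]
for t in int_types:
    arr = np.array([1, 2]).astype(t)
    # The dtype keeps the element size in bytes...
    print(arr.dtype, arr.dtype.itemsize)
    # ...but tolist() turns every element into a plain Python int.
    print(type(arr.tolist()[0]))
```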




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r963179064


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2288,38 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+        plist = obj.tolist()
+
+        if len(obj) > 0 and isinstance(plist[0], str):
+            jtpe = gateway.jvm.String
+        else:
+            jtpe = _from_numpy_type_to_java_type(obj.dtype, gateway)
+            if jtpe is None:
+                raise TypeError("The type of array scalar is not supported")

Review Comment:
   Oops, yeah. Let's add a negative test.
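A negative test could assert that an array with an unsupported element type is rejected with `TypeError`. The sketch below exercises a hypothetical stand-in, `convert_to_java_array`, rather than PySpark itself, so it runs without a JVM; the supported-dtype set and the error message are assumptions.

```python
import unittest

import numpy as np

# Hypothetical set of dtypes the converter knows how to map to Java.
SUPPORTED_DTYPES = {
    np.dtype(t) for t in ["int8", "int16", "int32", "int64", "float32", "float64", "bool"]
}


def convert_to_java_array(obj):
    """Stand-in converter: reject dtypes that have no Java mapping."""
    if obj.dtype not in SUPPORTED_DTYPES:
        raise TypeError(f"The type of array scalar '{obj.dtype}' is not supported")
    return obj.tolist()


class NegativeNdarrayTest(unittest.TestCase):
    def test_unsupported_dtype(self):
        # complex128 is assumed to have no Java counterpart in this sketch.
        with self.assertRaises(TypeError):
            convert_to_java_array(np.array([1 + 2j]))
```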




[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [WIP][SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r955311510


##########
python/pyspark/sql/tests/test_functions.py:
##########
@@ -1003,6 +1003,30 @@ def test_np_scalar_input(self):
             res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
             self.assertEqual([Row(c=1), Row(c=0)], res)
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_ndarray_input(self):
+        import numpy as np
+
+        int_arrs = [np.array([1, 2]).astype(t) for t in ["int8", "int16", "int32", "int64"]]
+        for arr in int_arrs:
+            self.assertEqual(
+                [Row(b=[1, 2])], self.spark.range(1).select(lit(arr).alias("b")).collect()
+            )
+
+        float_arrs = [np.array([0.1, 0.2]).astype(t) for t in ["float32", "float64"]]
+
+        self.assertEqual(
+            [("b", "array<double>")],
+            self.spark.range(1).select(lit(float_arrs[0]).alias("b")).dtypes,
+        )
+        self.assertEqual(

Review Comment:
   We cannot compare Row equality directly on the result of `collect`; it fails with the error below:
   ```
   AssertionError: Row(b=[0.10000000149011612, 0.20000000298023224]) != [Row(b=[0.10000000149011612, 0.20000000298023224])]
   ```
   
   We compare `dtypes` and the actual data in the Row separately.
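For the data part, one option that does not depend on exact float digits is an approximate element-wise comparison. A minimal sketch, with a plain Python list standing in for the array collected from the `Row` (an assumption; the real test reads it from `collect()`), and a hypothetical helper name `values_match`:

```python
import math


def values_match(collected, expected, rel_tol=1e-6):
    """Element-wise approximate equality for two lists of floats."""
    return len(collected) == len(expected) and all(
        math.isclose(got, want, rel_tol=rel_tol) for got, want in zip(collected, expected)
    )


# Values from a float32 array widened to double (copied from the AssertionError above).
collected = [0.10000000149011612, 0.20000000298023224]
print(collected == [0.1, 0.2])  # False: exact comparison fails
print(values_match(collected, [0.1, 0.2]))  # True: float32 error here is ~1e-8 relative
```

Inside a `unittest.TestCase`, `assertAlmostEqual` on each element serves the same purpose.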
   
   




[GitHub] [spark] xinrong-meng closed pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng closed pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions
URL: https://github.com/apache/spark/pull/37635



[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [WIP][SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r954114300


##########
python/pyspark/sql/types.py:
##########
@@ -2256,11 +2260,47 @@ def convert(self, obj: datetime.timedelta, gateway_client: GatewayClient) -> Jav
         )
 
 
+class NumpyScalarConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.generic)
+
+    def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
+        return obj.item()
+
+
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+
+        plist = obj.tolist()
+        ptpe = type(plist[0]) if len(plist) > 0 else float
+        tpe_dict = {
+            int: gateway.jvm.int,
+            float: gateway.jvm.double,
+            bool: gateway.jvm.boolean,
+            str: gateway.jvm.String,
+        }
+        jarr = gateway.new_array(tpe_dict[ptpe], len(plist))

Review Comment:
   The element type is required to create a Java array, so `tpe_dict` maps Python types to Java types.
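The flow described here can be sketched as: pick the Java element type, allocate the array with Py4J's `gateway.new_array`, then copy the elements in. To keep the sketch runnable without a JVM, `FakeGateway` and `FakeJavaArray` are hypothetical stand-ins exposing only the members used (`jvm.<type>` and `new_array`); in PySpark the function would receive `SparkContext._gateway` instead.

```python
from types import SimpleNamespace


def to_java_array(plist, gateway):
    """Copy a list of Python scalars into a typed Java array via a Py4J-style gateway."""
    # Empty lists fall back to double, mirroring np.array([]) being float64.
    ptpe = type(plist[0]) if len(plist) > 0 else float
    tpe_dict = {
        int: gateway.jvm.int,
        float: gateway.jvm.double,
        bool: gateway.jvm.boolean,
        str: gateway.jvm.String,
    }
    jarr = gateway.new_array(tpe_dict[ptpe], len(plist))
    for i, value in enumerate(plist):
        jarr[i] = value
    return jarr


class FakeJavaArray(list):
    """Hypothetical stand-in for a Py4J JavaArray: a fixed-size list plus its element type."""

    def __init__(self, java_class, length):
        super().__init__([None] * length)
        self.java_class = java_class


class FakeGateway:
    """Hypothetical stand-in for a Py4J JavaGateway (no JVM needed)."""

    jvm = SimpleNamespace(int="int", double="double", boolean="boolean", String="String")

    def new_array(self, java_class, length):
        return FakeJavaArray(java_class, length)


jarr = to_java_array([1, 2], FakeGateway())
print(jarr.java_class, list(jarr))  # int [1, 2]
```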
   




[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r956438477


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2268,40 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+
+        plist = obj.tolist()
+        # np.array([]).dtype is dtype('float64') so set float for empty plist
+        ptpe = type(plist[0]) if len(plist) > 0 else float
+        tpe_dict = {
+            int: gateway.jvm.int,

Review Comment:
   Since `plist = obj.tolist()`, `plist` is a list of Python scalars (see https://numpy.org/doc/stable/reference/generated/numpy.ndarray.tolist.html).
   
   So `tpe_dict` maps Python types to Java types.
   
   That's consistent with `NumpyScalarConverter.convert` which calls `obj.item()`, see https://numpy.org/doc/stable/reference/generated/numpy.ndarray.item.html.
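A quick illustration of the two documented behaviors, assuming NumPy: `tolist()` returns built-in Python scalars, while indexing keeps the NumPy scalar type until `item()` converts it.

```python
import numpy as np

arr = np.array([1.5, 2.5], dtype=np.float32)
# tolist() returns built-in Python scalars, not NumPy scalars.
print([type(x) for x in arr.tolist()])  # [<class 'float'>, <class 'float'>]

# Indexing keeps the NumPy scalar type; item() converts it to a Python scalar.
print(type(arr[0]))  # <class 'numpy.float32'>
print(type(arr[0].item()))  # <class 'float'>
```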




[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r963882160


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2288,38 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+        plist = obj.tolist()
+
+        if len(obj) > 0 and isinstance(plist[0], str):
+            jtpe = gateway.jvm.String
+        else:
+            jtpe = _from_numpy_type_to_java_type(obj.dtype, gateway)
+            if jtpe is None:
+                raise TypeError("The type of array scalar is not supported")

Review Comment:
   Sounds good!
   Improved the TypeError message as well.




[GitHub] [spark] xinrong-meng commented on pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on PR #37635:
URL: https://github.com/apache/spark/pull/37635#issuecomment-1227802459

   May I get a review? Thanks! @HyukjinKwon @ueshin @zhengruifeng 


[GitHub] [spark] zhengruifeng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r955553096


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2268,40 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)

Review Comment:
   Do we need to check the shape here?
   
   The `obj` may be a multi-dimensional array (a tensor), for example:
   
   ```
   In [11]: obj = np.zeros([2,3,4,5])
   
   In [12]: obj.shape
   Out[12]: (2, 3, 4, 5)
   ```
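Below is a minimal sketch of the shape guard under discussion. It assumes the converter should accept only one-dimensional arrays, which is what the later revision of the diff does with `obj.ndim == 1`. `can_convert_1d` is a hypothetical standalone helper, not the PySpark method, and it omits the `has_numpy` check:

```python
import numpy as np


def can_convert_1d(obj: object) -> bool:
    # Hypothetical stand-in for NumpyArrayConverter.can_convert:
    # accept only NumPy arrays with exactly one dimension.
    return isinstance(obj, np.ndarray) and obj.ndim == 1


vector = np.array([1, 2, 3])
tensor = np.zeros([2, 3, 4, 5])

print(can_convert_1d(vector))  # True
print(can_convert_1d(tensor))  # False: shape is (2, 3, 4, 5)
print(can_convert_1d([1, 2]))  # False: plain lists are left to other converters
```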



##########
python/pyspark/sql/tests/test_functions.py:
##########
@@ -1003,6 +1003,30 @@ def test_np_scalar_input(self):
             res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
             self.assertEqual([Row(c=1), Row(c=0)], res)
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_ndarray_input(self):
+        import numpy as np
+
+        int_arrs = [np.array([1, 2]).astype(t) for t in ["int8", "int16", "int32", "int64"]]
+        for arr in int_arrs:
+            self.assertEqual(
+                [Row(b=[1, 2])], self.spark.range(1).select(lit(arr).alias("b")).collect()
+            )
+
+        float_arrs = [np.array([0.1, 0.2]).astype(t) for t in ["float32", "float64"]]
+
+        self.assertEqual(
+            [("b", "array<double>")],
+            self.spark.range(1).select(lit(float_arrs[0]).alias("b")).dtypes,
+        )
+        self.assertEqual(

Review Comment:
   What about adding a few tests that use array functions?
   
   For example: `concat` / `array_intersect`.



[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [WIP][SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r954103989


##########
python/pyspark/sql/types.py:
##########
@@ -2256,11 +2260,47 @@ def convert(self, obj: datetime.timedelta, gateway_client: GatewayClient) -> Jav
         )
 
 
+class NumpyScalarConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.generic)
+
+    def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
+        return obj.item()
+
+
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+
+        plist = obj.tolist()
+        ptpe = type(plist[0]) if len(plist) > 0 else float
+        tpe_dict = {
+            int: gateway.jvm.int,
+            float: gateway.jvm.double,
+            bool: gateway.jvm.boolean,
+            str: gateway.jvm.String,
+        }
+        jarr = gateway.new_array(tpe_dict[ptpe], len(plist))
+        for i in range(len(plist)):
+            jarr[i] = plist[i]
+        return jarr
+
+
 # datetime is a subclass of date, we should register DatetimeConverter first
 register_input_converter(DatetimeNTZConverter())
 register_input_converter(DatetimeConverter())
 register_input_converter(DateConverter())
 register_input_converter(DayTimeIntervalTypeConverter())
+register_input_converter(NumpyScalarConverter())
+# NumPy array satisfies py4j.java_collections.ListConverter,
+# so prepend NumpyArrayConverter
+register_input_converter(NumpyArrayConverter(), prepend=True)

Review Comment:
   ```py
   >>> from py4j.java_collections import ListConverter
   >>> ndarr = np.array([1, 2])
   >>> ListConverter().can_convert(ndarr)
   True
   ```
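The snippet above explains why the diff registers `NumpyArrayConverter` with `prepend=True`. The sketch below models that ordering. It assumes, as the diff's comment implies, that the first registered converter whose `can_convert` returns True handles the object. Every class and function here is hypothetical, and `GenericListConverter` is a simplified stand-in whose check may differ from py4j's real `ListConverter`:

```python
from typing import Any, List

import numpy as np


class GenericListConverter:
    # Hypothetical generic sequence converter; an ndarray passes this check too.
    def can_convert(self, obj: Any) -> bool:
        return hasattr(obj, "__len__") and hasattr(obj, "__getitem__") and not isinstance(obj, str)

    def convert(self, obj: Any) -> str:
        return "generic list"


class ArrayConverter:
    # Hypothetical stand-in for NumpyArrayConverter.
    def can_convert(self, obj: Any) -> bool:
        return isinstance(obj, np.ndarray) and obj.ndim == 1

    def convert(self, obj: Any) -> str:
        return "typed array"


def register(registry: List[Any], converter: Any, prepend: bool = False) -> None:
    # Mirrors the prepend=True/False choice seen in the diff.
    if prepend:
        registry.insert(0, converter)
    else:
        registry.append(converter)


def dispatch(registry: List[Any], obj: Any) -> Any:
    # The first converter that accepts the object wins.
    for converter in registry:
        if converter.can_convert(obj):
            return converter.convert(obj)
    return obj


ndarr = np.array([1, 2])

appended: List[Any] = [GenericListConverter()]
register(appended, ArrayConverter())
print(dispatch(appended, ndarr))  # generic list: the generic converter shadows the array one

prepended: List[Any] = [GenericListConverter()]
register(prepended, ArrayConverter(), prepend=True)
print(dispatch(prepended, ndarr))  # typed array
```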



[GitHub] [spark] zhengruifeng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r957956802


##########
python/pyspark/sql/tests/test_functions.py:
##########
@@ -1003,6 +1003,30 @@ def test_np_scalar_input(self):
             res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
             self.assertEqual([Row(c=1), Row(c=0)], res)
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_ndarray_input(self):
+        import numpy as np
+
+        int_arrs = [np.array([1, 2]).astype(t) for t in ["int8", "int16", "int32", "int64"]]
+        for arr in int_arrs:
+            self.assertEqual(
+                [Row(b=[1, 2])], self.spark.range(1).select(lit(arr).alias("b")).collect()
+            )
+
+        float_arrs = [np.array([0.1, 0.2]).astype(t) for t in ["float32", "float64"]]
+
+        self.assertEqual(
+            [("b", "array<double>")],
+            self.spark.range(1).select(lit(float_arrs[0]).alias("b")).dtypes,
+        )
+        self.assertEqual(

Review Comment:
   Yes, they should be in a `Literal`. Never mind.



[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r963004995


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2268,48 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+        plist = obj.tolist()
+        tpe_np_to_java = {

Review Comment:
   We cannot import `SparkContext` at the module level, and we may want to do a nullability check on `SparkContext._gateway`. So `_from_numpy_type_to_java_type` is introduced instead, for code reuse. Let me know if you have a better idea :)
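Here is a minimal sketch of the design described above. A module-level helper receives the gateway from its caller, so the module itself never imports `SparkContext`. The caller (`convert`) would do the lazy import and the `is not None` check.

The function name, the fake gateway, and most of the dtype-to-Java mapping are assumptions for illustration. Only `int8`/`int16` to `short` comes from this thread:

```python
from types import SimpleNamespace
from typing import Any, Optional

import numpy as np


def from_numpy_type_to_java_type(nt: np.dtype, gateway: Any) -> Optional[Any]:
    # Hypothetical module-level helper: the gateway is passed in by the caller,
    # so no SparkContext import is needed here.
    # int8/int16 -> short follows the thread; the other entries are assumptions.
    mapping = {
        np.dtype("int8"): gateway.jvm.short,
        np.dtype("int16"): gateway.jvm.short,
        np.dtype("int32"): gateway.jvm.int,
        np.dtype("int64"): gateway.jvm.long,
        np.dtype("float32"): gateway.jvm.double,
        np.dtype("float64"): gateway.jvm.double,
        np.dtype("bool"): gateway.jvm.boolean,
    }
    return mapping.get(nt)  # None signals an unsupported dtype


# A fake gateway whose jvm attributes are plain strings, for illustration only.
fake_gateway = SimpleNamespace(
    jvm=SimpleNamespace(short="short", int="int", long="long", double="double", boolean="boolean")
)

print(from_numpy_type_to_java_type(np.dtype("int8"), fake_gateway))  # short
print(from_numpy_type_to_java_type(np.dtype("complex128"), fake_gateway))  # None
```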



[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [WIP][SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r954113915


##########
python/pyspark/sql/types.py:
##########
@@ -2256,11 +2260,47 @@ def convert(self, obj: datetime.timedelta, gateway_client: GatewayClient) -> Jav
         )
 
 
+class NumpyScalarConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.generic)
+
+    def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
+        return obj.item()
+
+
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+
+        plist = obj.tolist()
+        ptpe = type(plist[0]) if len(plist) > 0 else float
+        tpe_dict = {
+            int: gateway.jvm.int,
+            float: gateway.jvm.double,
+            bool: gateway.jvm.boolean,
+            str: gateway.jvm.String,
+        }
+        jarr = gateway.new_array(tpe_dict[ptpe], len(plist))
+        for i in range(len(plist)):
+            jarr[i] = plist[i]
+        return jarr
+
+
 # datetime is a subclass of date, we should register DatetimeConverter first
 register_input_converter(DatetimeNTZConverter())
 register_input_converter(DatetimeConverter())
 register_input_converter(DateConverter())
 register_input_converter(DayTimeIntervalTypeConverter())
+register_input_converter(NumpyScalarConverter())
+# NumPy array satisfies py4j.java_collections.ListConverter,

Review Comment:
   ```py
   >>> from py4j.java_collections import ListConverter
   >>> ndarr = np.array([1, 2])
   >>> ListConverter().can_convert(ndarr)
   True
   ```



[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [WIP][SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r953211654


##########
python/pyspark/sql/tests/test_functions.py:
##########
@@ -1003,6 +1003,30 @@ def test_np_scalar_input(self):
             res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
             self.assertEqual([Row(c=1), Row(c=0)], res)
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_ndarray_input(self):
+        import numpy as np
+
+        int_arrs = [np.array([1, 2]).astype(t) for t in ["int8", "int16", "int32", "int64"]]
+        for arr in int_arrs:
+            self.assertEqual(
+                [Row(b=[1, 2])], self.spark.range(1).select(lit(arr).alias("b")).collect()
+            )
+
+        float_arrs = [np.array([0.1, 0.2]).astype(t) for t in ["float32", "float64"]]
+
+        self.assertEqual(

Review Comment:
   Row equality cannot be compared directly with
   ```
   self.assertEqual(Row(b=[0.10000000149011612, 0.20000000298023224]), self.spark.range(1).select(lit(float_arrs[0]).alias("b")).collect())
   ```
   because it fails with
   ```
   AssertionError: Row(b=[0.10000000149011612, 0.20000000298023224]) != [Row(b=[0.10000000149011612, 0.20000000298023224])]
   ```
   Instead, we compare the `dtypes` and the actual data in the Row.
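The values in the assertion (0.10000000149011612, 0.20000000298023224) are what `float32` 0.1 and 0.2 look like once widened to a double. As a result, literal 0.1/0.2 will not compare equal exactly.

The sketch below reproduces the widening with the standard `struct` module instead of NumPy; that substitution is an assumption made to keep it dependency-free. It also shows a tolerance-based comparison as one option:

```python
import math
import struct


def as_float32(x: float) -> float:
    # Round-trip a Python float (a C double) through 4-byte IEEE 754 storage,
    # then read it back as a double, as happens when float32 data is widened.
    return struct.unpack("f", struct.pack("f", x))[0]


widened = [as_float32(0.1), as_float32(0.2)]
print(widened)  # [0.10000000149011612, 0.20000000298023224]

# Exact comparison with the literals fails; a tolerance-based check passes.
print(widened == [0.1, 0.2])  # False
print(all(math.isclose(a, b, rel_tol=1e-6) for a, b in zip(widened, [0.1, 0.2])))  # True
```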
   



[GitHub] [spark] HyukjinKwon commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r963179305


##########
python/pyspark/sql/types.py:
##########
@@ -1447,6 +1447,26 @@ def _from_numpy_type(nt: "np.dtype") -> Optional[DataType]:
     return None
 
 
+def _from_numpy_type_to_java_type(nt: "np.dtype", gateway: JavaGateway) -> Optional[JavaClass]:

Review Comment:
   You can actually add this as a class attribute of `NumpyArrayConverter`.
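One way to read the class-attribute suggestion is sketched below. It is a hypothetical design, not the merged code. The class attribute stores Java primitive type *names*, which can be defined at import time without a gateway. Each name is then resolved against `gateway.jvm` only when converting.

The class, the fake gateway, and every mapping entry other than `int8`/`int16` to `short` are assumptions:

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional

import numpy as np


class NumpyArrayConverterSketch:
    # Hypothetical class: Java primitive names live in a class attribute, so
    # the table can be built at import time without a running gateway.
    # Only int8/int16 -> "short" comes from the thread; the rest are assumptions.
    JAVA_TYPE_NAMES: Dict[str, str] = {
        "int8": "short",
        "int16": "short",
        "int32": "int",
        "int64": "long",
        "float32": "double",
        "float64": "double",
        "bool": "boolean",
    }

    def java_type(self, arr: np.ndarray, gateway: Any) -> Optional[Any]:
        name = self.JAVA_TYPE_NAMES.get(arr.dtype.name)
        # Resolve the name against the live gateway only at conversion time.
        return None if name is None else getattr(gateway.jvm, name)


# A fake gateway whose jvm attributes are plain strings, for illustration only.
fake_gateway = SimpleNamespace(jvm=SimpleNamespace(short="jvm.short", long="jvm.long"))
converter = NumpyArrayConverterSketch()

print(converter.java_type(np.array([1, 2], dtype="int8"), fake_gateway))  # jvm.short
print(converter.java_type(np.array([1j]), fake_gateway))  # None: complex128 is not in the table
```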



[GitHub] [spark] xinrong-meng commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r962400610


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2268,40 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray)
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+
+        plist = obj.tolist()
+        # np.array([]).dtype is dtype('float64') so set float for empty plist
+        ptpe = type(plist[0]) if len(plist) > 0 else float
+        tpe_dict = {
+            int: gateway.jvm.int,

Review Comment:
   Makes sense!
   
   One limitation is that `np.dtype("int8")` can't be mapped to `gateway.jvm.byte`: creating `jarr` with that type and then assigning each element fails.
   
   `TypeError: 'bytes' object does not support item assignment` is raised at `jarr[i] = plist[i]`.
   
   So both `int8` and `int16` are mapped to `gateway.jvm.short`.
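The failure above is consistent with the `byte[]` array coming back on the Python side as an immutable `bytes` object, though that is an inference from the traceback. The sketch below reproduces the same `TypeError` with a plain `bytes` buffer and the diff's per-element fill loop. A mutable list stands in (hypothetically) for the wider `short[]` array. Every `int8` value fits in a Java `short` (-32768 to 32767):

```python
import numpy as np


def try_assign(buffer, values) -> str:
    # Mimics the per-element fill loop from the diff: buffer[i] = values[i].
    try:
        for i, v in enumerate(values):
            buffer[i] = v
        return "ok"
    except TypeError as exc:
        return str(exc)


int8_values = np.array([1, -2], dtype="int8").tolist()

# An immutable bytes object rejects item assignment with the error from the thread.
print(try_assign(bytes(len(int8_values)), int8_values))
# 'bytes' object does not support item assignment

# A mutable container standing in for short[] accepts the same values.
print(try_assign([0] * len(int8_values), int8_values))  # ok
```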



[GitHub] [spark] xinrong-meng commented on pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
xinrong-meng commented on PR #37635:
URL: https://github.com/apache/spark/pull/37635#issuecomment-1236457435

   Rebased to resolve conflicts. Only [bc90498](https://github.com/apache/spark/pull/37635/commits/bc90498b13d600ea2e146106a51e872b68710b8d) is new after the review.


[GitHub] [spark] itholic commented on a diff in pull request #37635: [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

Posted by GitBox <gi...@apache.org>.
itholic commented on code in PR #37635:
URL: https://github.com/apache/spark/pull/37635#discussion_r963154413


##########
python/pyspark/sql/types.py:
##########
@@ -2268,12 +2288,38 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+        plist = obj.tolist()
+
+        if len(obj) > 0 and isinstance(plist[0], str):
+            jtpe = gateway.jvm.String
+        else:
+            jtpe = _from_numpy_type_to_java_type(obj.dtype, gateway)
+            if jtpe is None:
+                raise TypeError("The type of array scalar is not supported")

Review Comment:
   Can we have a test for this?
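A test for the unsupported-dtype branch could look roughly like this sketch. It exercises hypothetical stand-ins for `_from_numpy_type_to_java_type` and the converter's error path, not PySpark itself. The mapping entries and the choice of `complex128` as the unsupported dtype are assumptions. In the real suite the equivalent check would go through `lit(...)` with a running session:

```python
import unittest
from typing import Optional

import numpy as np


def java_type_name(dtype: np.dtype) -> Optional[str]:
    # Hypothetical stand-in for _from_numpy_type_to_java_type; None means unsupported.
    return {"int8": "short", "int16": "short", "int32": "int", "int64": "long"}.get(dtype.name)


def convert(arr: np.ndarray) -> str:
    # Hypothetical stand-in for the error path in NumpyArrayConverter.convert.
    jtype = java_type_name(arr.dtype)
    if jtype is None:
        raise TypeError("The type of array scalar is not supported")
    return jtype


class UnsupportedDtypeTest(unittest.TestCase):
    def test_unsupported_dtype_raises(self) -> None:
        with self.assertRaisesRegex(TypeError, "not supported"):
            convert(np.array([1 + 2j]))  # complex128 has no mapping

    def test_supported_dtype(self) -> None:
        self.assertEqual(convert(np.array([1, 2], dtype="int16")), "short")
```

Run it with `python -m unittest` against the file that contains it.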


