Posted to reviews@spark.apache.org by "ueshin (via GitHub)" <gi...@apache.org> on 2023/07/31 18:37:32 UTC

[GitHub] [spark] ueshin commented on a diff in pull request #42191: [SPARK-44559][PYTHON] Improve error messages for Python UDTF arrow cast

ueshin commented on code in PR #42191:
URL: https://github.com/apache/spark/pull/42191#discussion_r1279727112


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -538,6 +538,73 @@ def _create_batch(self, series):
 
         return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
 
+    def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
+        """
+        Override the `_create_array` method in the superclass to create an Arrow Array
+        from a given pandas.Series and an arrow type. The difference here is that we always
+        use arrow cast when creating the arrow array. Also, the error messages are specific
+        to arrow-optimized Python UDTFs.
+
+        Parameters
+        ----------
+        series : pandas.Series
+            A single series
+        arrow_type : pyarrow.DataType, optional
+            If None, pyarrow's inferred type will be used
+        spark_type : DataType, optional
+            If None, spark type converted from arrow_type will be used
+        arrow_cast: bool, optional
+            Whether to apply Arrow casting when the user-specified return type mismatches the
+            actual return values.
+
+        Returns
+        -------
+        pyarrow.Array
+        """
+        import pyarrow as pa
+        from pandas.api.types import is_categorical_dtype
+
+        if is_categorical_dtype(series.dtype):
+            series = series.astype(series.dtypes.categories.dtype)
+
+        if arrow_type is not None:
+            dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
+            # TODO(SPARK-43579): cache the converter for reuse
+            conv = _create_converter_from_pandas(
+                dt, timezone=self._timezone, error_on_duplicated_field_names=False
+            )
+            series = conv(series)
+
+        if hasattr(series.array, "__arrow_array__"):
+            mask = None
+        else:
+            mask = series.isnull()
+
+        try:
+            try:
+                return pa.Array.from_pandas(
+                    series, mask=mask, type=arrow_type, safe=self._safecheck
+                )
+            except pa.lib.ArrowException:
+                if arrow_cast:
+                    return pa.Array.from_pandas(series, mask=mask).cast(
+                        target_type=arrow_type, safe=self._safecheck
+                    )
+                else:
+                    raise
+        except pa.lib.ArrowException:

Review Comment:
   We might need to revisit the error messages here and for UDFs later.
   It would be odd for the error messages to be completely different in such similar cases.
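
   One possible way to keep the two in sync, as a rough sketch only: build the
   message in a single shared helper and pass in a label for the caller. The
   helper name `arrow_cast_error_message`, its parameters, and the message
   wording below are all hypothetical and not existing PySpark code.

```python
# Hypothetical helper, not part of PySpark: a single place that builds the
# cast-failure message, so the UDF and UDTF paths differ only in the label.
def arrow_cast_error_message(kind, col_name, from_type, to_type):
    # `kind` is a caller-supplied label such as "Python UDF" or "Python UDTF".
    return (
        f"Cannot convert the output value of the column '{col_name}' "
        f"with type '{from_type}' to the specified return type '{to_type}' "
        f"of the {kind}."
    )


# Both call sites share the wording and differ only in the label.
udf_msg = arrow_cast_error_message("Python UDF", "x", "object", "int32")
udtf_msg = arrow_cast_error_message("Python UDTF", "x", "object", "int32")
```

   With something like this, each `except pa.lib.ArrowException` branch would
   only need to choose the label, so the wording could not drift apart.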



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

