Posted to reviews@spark.apache.org by "xinrong-meng (via GitHub)" <gi...@apache.org> on 2023/07/28 22:15:10 UTC

[GitHub] [spark] xinrong-meng commented on a diff in pull request #42191: [SPARK-44559][PYTHON] Improve error messages for Python UDTF arrow cast

xinrong-meng commented on code in PR #42191:
URL: https://github.com/apache/spark/pull/42191#discussion_r1278119601


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -538,6 +538,73 @@ def _create_batch(self, series):
 
         return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
 
+    def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
+        """
+        Override the `_create_array` method in the superclass to create an Arrow Array
+        from a given pandas.Series and an Arrow type. The difference here is that an
+        Arrow cast can be applied as a fallback when the strict conversion fails, and
+        the error messages are specific to arrow-optimized Python UDTFs.
+
+        Parameters
+        ----------
+        series : pandas.Series
+            A single series
+        arrow_type : pyarrow.DataType, optional
+            If None, pyarrow's inferred type will be used
+        spark_type : DataType, optional
+            If None, the Spark type converted from arrow_type will be used
+        arrow_cast : bool, optional
+            Whether to apply Arrow casting when the user-specified return type does
+            not match the types of the actual return values.
+
+        Returns
+        -------
+        pyarrow.Array
+        """
+        import pyarrow as pa
+        from pandas.api.types import is_categorical_dtype
+
+        if is_categorical_dtype(series.dtype):
+            series = series.astype(series.dtypes.categories.dtype)
+
+        if arrow_type is not None:
+            dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
+            # TODO(SPARK-43579): cache the converter for reuse
+            conv = _create_converter_from_pandas(
+                dt, timezone=self._timezone, error_on_duplicated_field_names=False
+            )
+            series = conv(series)
+
+        if hasattr(series.array, "__arrow_array__"):
+            mask = None
+        else:
+            mask = series.isnull()
+
+        try:
+            try:
+                return pa.Array.from_pandas(
+                    series, mask=mask, type=arrow_type, safe=self._safecheck
+                )
+            except pa.lib.ArrowException:
+                if arrow_cast:
+                    return pa.Array.from_pandas(series, mask=mask).cast(
+                        target_type=arrow_type, safe=self._safecheck
+                    )
+                else:
+                    raise
+        except pa.lib.ArrowException:

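The hunk above implements a strict-then-cast fallback: the conversion is first attempted with the declared Arrow type as-is, and only when that raises (and `arrow_cast` is set) is the array rebuilt with PyArrow's inferred type and cast to the target. A minimal standalone sketch of the pattern, with hypothetical string inputs declared as int64 and assuming only pandas and pyarrow:

```
import pandas as pd
import pyarrow as pa

series = pd.Series(["1", "2", "3"])  # hypothetical: strings returned where ints were declared
arrow_type = pa.int64()              # the user-declared return type

try:
    # Strict path: the values must already match the declared type.
    arr = pa.Array.from_pandas(series, type=arrow_type, safe=True)
except pa.lib.ArrowException:
    # Fallback path (arrow_cast=True): build with the inferred type, then cast.
    arr = pa.Array.from_pandas(series).cast(target_type=arrow_type, safe=True)

print(arr.type)  # int64
```

The outer `try`/`except pa.lib.ArrowException` (cut off above at the commented line) presumably exists to attach the UDTF-specific error message mentioned in the docstring.
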
Review Comment:
   Do we intentionally ignore the potential error messages from `_safecheck`? For example:
   
    ```
                       " It can be caused by overflows or other "
                       "unsafe conversions warned by Arrow. Arrow safe type check "
                       "can be disabled by using SQL config "
                       "`spark.sql.execution.pandas.convertToArrowArraySafely`."
   ```
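   
   A minimal sketch of the overflow case that message warns about, with a hypothetical out-of-range value; `safe` below is what `self._safecheck` (the `spark.sql.execution.pandas.convertToArrowArraySafely` config) controls:
   
    ```
    import pandas as pd
    import pyarrow as pa

    series = pd.Series([300])  # hypothetical value outside the int8 range

    # Safe check on: raises pa.lib.ArrowInvalid, a pa.lib.ArrowException.
    try:
        pa.Array.from_pandas(series, type=pa.int8(), safe=True)
    except pa.lib.ArrowException as e:
        print(type(e).__name__)  # ArrowInvalid

    # Safe check off: the unsafe conversion silently wraps 300 to 44.
    print(pa.Array.from_pandas(series, type=pa.int8(), safe=False))
    ```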


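For completeness, a hypothetical end-to-end example that would exercise this serializer path, assuming Spark 3.5's arrow-optimized Python UDTFs (`useArrow=True`); the class name and values are illustrative:

```
import pyspark.sql.functions as sf
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

@sf.udtf(returnType="x: int", useArrow=True)
class Echo:
    def eval(self, v):
        # Hypothetical mismatch: yield a string where an int was declared, so
        # _create_array must either arrow-cast "1" to 1 or raise the
        # UDTF-specific error message this PR improves.
        yield (v,)

Echo(sf.lit("1")).show()
```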

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org