You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2019/01/22 06:55:12 UTC

[spark] branch master updated: [SPARK-25811][PYSPARK] Raise a proper error when unsafe cast is detected by PyArrow

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f92d276  [SPARK-25811][PYSPARK] Raise a proper error when unsafe cast is detected by PyArrow
f92d276 is described below

commit f92d2766535d882b17f6d3b061d1df57bc84a90e
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Tue Jan 22 14:54:41 2019 +0800

    [SPARK-25811][PYSPARK] Raise a proper error when unsafe cast is detected by PyArrow
    
    ## What changes were proposed in this pull request?
    
    Since 0.11.0, PyArrow supports raising an error on unsafe casts ([PR](https://github.com/apache/arrow/pull/2504)). We should use it to raise a proper error for pandas UDF users when such a cast is detected.
    
    Added a SQL config `spark.sql.execution.pandas.arrowSafeTypeConversion` to enable or disable Arrow's safe type check (disabled by default).
    
    ## How was this patch tested?
    
    Added tests and tested manually.
    
    Closes #22807 from viirya/SPARK-25811.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 docs/sql-migration-guide-upgrade.md                | 48 ++++++++++++++++++
 python/pyspark/serializers.py                      | 19 +++++--
 python/pyspark/sql/session.py                      |  3 +-
 python/pyspark/sql/tests/test_pandas_udf.py        | 58 ++++++++++++++++++++++
 python/pyspark/worker.py                           |  4 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    | 12 +++++
 .../spark/sql/execution/arrow/ArrowUtils.scala     |  4 +-
 7 files changed, 141 insertions(+), 7 deletions(-)

diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md
index 5d3d4c6..3d1b804 100644
--- a/docs/sql-migration-guide-upgrade.md
+++ b/docs/sql-migration-guide-upgrade.md
@@ -41,6 +41,54 @@ displayTitle: Spark SQL Upgrading Guide
 
   - Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Set JSON option `inferTimestamp` to `false` to disable such type inferring.
 
+  - In PySpark, when Arrow optimization is enabled and the PyArrow version is 0.11.0 or higher, Arrow can perform safe type conversion when converting a Pandas.Series to an Arrow array during serialization. In that mode, Arrow raises an error when it detects an unsafe type conversion such as overflow. Set `spark.sql.execution.pandas.arrowSafeTypeConversion` to true to enable it; the default is false. PySpark's behavior for each PyArrow version is illustrated in the table below:
+  <table class="table">
+        <tr>
+          <th>
+            <b>PyArrow version</b>
+          </th>
+          <th>
+            <b>Integer Overflow</b>
+          </th>
+          <th>
+            <b>Floating Point Truncation</b>
+          </th>
+        </tr>
+        <tr>
+          <th>
+            <b>version < 0.11.0</b>
+          </th>
+          <th>
+            <b>Raise error</b>
+          </th>
+          <th>
+            <b>Silently allows</b>
+          </th>
+        </tr>
+        <tr>
+          <th>
+            <b>version >= 0.11.0, arrowSafeTypeConversion=false</b>
+          </th>
+          <th>
+            <b>Silent overflow</b>
+          </th>
+          <th>
+            <b>Silently allows</b>
+          </th>
+        </tr>
+        <tr>
+          <th>
+            <b>version >= 0.11.0, arrowSafeTypeConversion=true</b>
+          </th>
+          <th>
+            <b>Raise error</b>
+          </th>
+          <th>
+            <b>Raise error</b>
+          </th>
+        </tr>
+  </table>
+
   - In Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(Any, DataType)` gets a Scala closure with primitive-type argument, the returned UDF will return null if the input values is null. Since Spark 3.0, the UDF will return the default value of the Java type if the input value is null. For example, `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` will return null in Spark 2.4 and earlier if column `x` is null, and return 0 in Spark 3.0. This behavior change is int [...]
 
 ## Upgrading From Spark SQL 2.3 to 2.4
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fd46952..741dfb2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -245,7 +245,7 @@ class ArrowStreamSerializer(Serializer):
         return "ArrowStreamSerializer"
 
 
-def _create_batch(series, timezone):
+def _create_batch(series, timezone, safecheck):
     """
     Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
 
@@ -284,7 +284,17 @@ def _create_batch(series, timezone):
         elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
             # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
             return pa.Array.from_pandas(s, mask=mask, type=t)
-        return pa.Array.from_pandas(s, mask=mask, type=t, safe=False)
+
+        try:
+            array = pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck)
+        except pa.ArrowException as e:
+            error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
+                        "Array (%s). It can be caused by overflows or other unsafe " + \
+                        "conversions warned by Arrow. Arrow safe type check can be " + \
+                        "disabled by using SQL config " + \
+                        "`spark.sql.execution.pandas.arrowSafeTypeConversion`."
+            raise RuntimeError(error_msg % (s.dtype, t), e)
+        return array
 
     arrs = [create_array(s, t) for s, t in series]
     return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
@@ -295,9 +305,10 @@ class ArrowStreamPandasSerializer(Serializer):
     Serializes Pandas.Series as Arrow data with Arrow streaming format.
     """
 
-    def __init__(self, timezone):
+    def __init__(self, timezone, safecheck):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
+        self._safecheck = safecheck
 
     def arrow_to_pandas(self, arrow_column):
         from pyspark.sql.types import from_arrow_type, \
@@ -317,7 +328,7 @@ class ArrowStreamPandasSerializer(Serializer):
         writer = None
         try:
             for series in iterator:
-                batch = _create_batch(series, self._timezone)
+                batch = _create_batch(series, self._timezone, self._safecheck)
                 if writer is None:
                     write_int(SpecialLengths.START_ARROW_STREAM, stream)
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 6f4b327..bdf1701 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -556,8 +556,9 @@ class SparkSession(object):
         pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
 
         # Create Arrow record batches
+        safecheck = self._wrapped._conf.arrowSafeTypeConversion()
         batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
-                                 timezone)
+                                 timezone, safecheck)
                    for pdf_slice in pdf_slices]
 
         # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py
index d4d9679..fd6d4e1 100644
--- a/python/pyspark/sql/tests/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/test_pandas_udf.py
@@ -197,6 +197,64 @@ class PandasUDFTests(ReusedSQLTestCase):
             ).collect
         )
 
+    def test_pandas_udf_detect_unsafe_type_conversion(self):
+        from distutils.version import LooseVersion
+        import pandas as pd
+        import numpy as np
+        import pyarrow as pa
+
+        values = [1.0] * 3
+        pdf = pd.DataFrame({'A': values})
+        df = self.spark.createDataFrame(pdf).repartition(1)
+
+        @pandas_udf(returnType="int")
+        def udf(column):
+            return pd.Series(np.linspace(0, 1, 3))
+
+        # Since 0.11.0, PyArrow supports the feature to raise an error for unsafe cast.
+        if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
+            with self.sql_conf({
+                    "spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
+                with self.assertRaisesRegexp(Exception,
+                                             "Exception thrown when converting pandas.Series"):
+                    df.select(['A']).withColumn('udf', udf('A')).collect()
+
+        # Disabling Arrow safe type check.
+        with self.sql_conf({
+                "spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+            df.select(['A']).withColumn('udf', udf('A')).collect()
+
+    def test_pandas_udf_arrow_overflow(self):
+        from distutils.version import LooseVersion
+        import pandas as pd
+        import pyarrow as pa
+
+        df = self.spark.range(0, 1)
+
+        @pandas_udf(returnType="byte")
+        def udf(column):
+            return pd.Series([128])
+
+        # Arrow 0.11.0+ allows enabling or disabling safe type check.
+        if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
+            # When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
+            with self.sql_conf({
+                    "spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
+                with self.assertRaisesRegexp(Exception,
+                                             "Exception thrown when converting pandas.Series"):
+                    df.withColumn('udf', udf('id')).collect()
+
+            # Disabling safe type check, let Arrow do the cast anyway.
+            with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+                df.withColumn('udf', udf('id')).collect()
+        else:
+            # SQL config `arrowSafeTypeConversion` no matters for older Arrow.
+            # Overflow cast causes an error.
+            with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+                with self.assertRaisesRegexp(Exception,
+                                             "Integer value out of bounds"):
+                    df.withColumn('udf', udf('id')).collect()
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf import *
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 1e7424a..01934a0 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -252,7 +252,9 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
         timezone = runner_conf.get("spark.sql.session.timeZone", None)
-        ser = ArrowStreamPandasSerializer(timezone)
+        safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
+                                    "false").lower() == 'true'
+        ser = ArrowStreamPandasSerializer(timezone, safecheck)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ebc8c37..6b301c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1324,6 +1324,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val PANDAS_ARROW_SAFE_TYPE_CONVERSION =
+    buildConf("spark.sql.execution.pandas.arrowSafeTypeConversion")
+      .internal()
+      .doc("When true, Arrow will perform safe type conversion when converting " +
+        "Pandas.Series to Arrow array during serialization. Arrow will raise errors " +
+        "when detecting unsafe type conversion like overflow. When false, disabling Arrow's type " +
+        "check and do type conversions anyway. This config only works for Arrow 0.11.0+.")
+      .booleanConf
+      .createWithDefault(false)
+
   val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter")
     .internal()
     .doc("When true, the apply function of the rule verifies whether the right node of the" +
@@ -1998,6 +2008,8 @@ class SQLConf extends Serializable with Logging {
   def pandasGroupedMapAssignColumnsByName: Boolean =
     getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME)
 
+  def arrowSafeTypeConversion: Boolean = getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION)
+
   def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
 
   def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala
index b1e8fb3..7de6256 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala
@@ -133,6 +133,8 @@ object ArrowUtils {
     }
     val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
       conf.pandasGroupedMapAssignColumnsByName.toString)
-    Map(timeZoneConf ++ pandasColsByName: _*)
+    val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
+      conf.arrowSafeTypeConversion.toString)
+    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org