Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/27 00:27:15 UTC

[spark] branch branch-3.4 updated: [SPARK-42911][PYTHON][3.4] Introduce more basic exceptions

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 1b95b4d6cfc [SPARK-42911][PYTHON][3.4] Introduce more basic exceptions
1b95b4d6cfc is described below

commit 1b95b4d6cfc13db031c9f31729e7b551207a0cc3
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Mon Mar 27 09:26:47 2023 +0900

    [SPARK-42911][PYTHON][3.4] Introduce more basic exceptions
    
    ### What changes were proposed in this pull request?
    
    Introduces the following basic exceptions; a quick hierarchy check follows the list.
    
    - ArithmeticException
    - ArrayIndexOutOfBoundsException
    - DateTimeException
    - NumberFormatException
    - SparkRuntimeException
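    
    `NumberFormatException` subclasses `IllegalArgumentException`, mirroring the JVM hierarchy, and each class also gets `captured` (Py4J) and `connect` variants. A quick sanity check (sketch):
    
    ```py
    >>> from pyspark.errors import IllegalArgumentException, NumberFormatException
    >>> issubclass(NumberFormatException, IllegalArgumentException)
    True
    ```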
    
    ### Why are the changes needed?
    
    Spark throws more exception types than PySpark currently captures.
    
    We should introduce more basic exceptions; otherwise users still see the raw `Py4JJavaError` or `SparkConnectGrpcException`:
    
    ```py
    >>> spark.conf.set("spark.sql.ansi.enabled", True)
    >>> spark.sql("select 1/0")
    DataFrame[(1 / 0): double]
    >>> spark.sql("select 1/0").show()
    Traceback (most recent call last):
    ...
    py4j.protocol.Py4JJavaError: An error occurred while calling o44.showString.
    : org.apache.spark.SparkArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
    == SQL(line 1, position 8) ==
    select 1/0
           ^^^
    
            at org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(QueryExecutionErrors.scala:225)
    ... JVM's stacktrace
    ```
    
    ```py
    >>> spark.sql("select 1/0").show()
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.connect.SparkConnectGrpcException: (org.apache.spark.SparkArithmeticException) [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
    == SQL(line 1, position 8) ==
    select 1/0
           ^^^
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Errors now surface as the specific exception classes above, and the message is more readable.
    
    ```py
    >>> spark.sql("select 1/0").show()
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
    == SQL(line 1, position 8) ==
    select 1/0
           ^^^
    ```
    
    or
    
    ```py
    >>> spark.sql("select 1/0").show()
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.connect.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
    == SQL(line 1, position 8) ==
    select 1/0
           ^^^
    ```
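    
    Since the new classes are exported from `pyspark.errors`, user code can catch them directly instead of parsing `Py4JJavaError` text. A minimal sketch, assuming an ANSI-enabled session (`getErrorClass()` is the accessor `PySparkException` already provides):
    
    ```py
    >>> from pyspark.errors import ArithmeticException
    >>> spark.conf.set("spark.sql.ansi.enabled", True)
    >>> try:
    ...     spark.sql("select 1/0").show()
    ... except ArithmeticException as e:
    ...     print(e.getErrorClass())
    ...
    DIVIDE_BY_ZERO
    ```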
    
    ### How was this patch tested?
    
    Added the related tests.
    
    Closes #40547 from ueshin/issues/SPARK-42911/3.4/exceptions.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 dev/sparktestsupport/modules.py                    |  2 +
 python/pyspark/errors/__init__.py                  | 10 ++++
 python/pyspark/errors/exceptions/base.py           | 36 ++++++++++-
 python/pyspark/errors/exceptions/captured.py       | 52 +++++++++++++++-
 python/pyspark/errors/exceptions/connect.py        | 66 +++++++++++++++++----
 .../sql/tests/connect/test_parity_errors.py        | 36 +++++++++++
 python/pyspark/sql/tests/test_errors.py            | 69 ++++++++++++++++++++++
 7 files changed, 255 insertions(+), 16 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2b202bc333e..29bc39e14bf 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -476,6 +476,7 @@ pyspark_sql = Module(
         "pyspark.sql.tests.test_context",
         "pyspark.sql.tests.test_dataframe",
         "pyspark.sql.tests.test_datasources",
+        "pyspark.sql.tests.test_errors",
         "pyspark.sql.tests.test_functions",
         "pyspark.sql.tests.test_group",
         "pyspark.sql.tests.pandas.test_pandas_cogrouped_map",
@@ -524,6 +525,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_connect_function",
         "pyspark.sql.tests.connect.test_connect_column",
         "pyspark.sql.tests.connect.test_parity_datasources",
+        "pyspark.sql.tests.connect.test_parity_errors",
         "pyspark.sql.tests.connect.test_parity_catalog",
         "pyspark.sql.tests.connect.test_parity_conf",
         "pyspark.sql.tests.connect.test_parity_serde",
diff --git a/python/pyspark/errors/__init__.py b/python/pyspark/errors/__init__.py
index 94117fc5160..e4143d4eb87 100644
--- a/python/pyspark/errors/__init__.py
+++ b/python/pyspark/errors/__init__.py
@@ -24,10 +24,15 @@ from pyspark.errors.exceptions.base import (  # noqa: F401
     TempTableAlreadyExistsException,
     ParseException,
     IllegalArgumentException,
+    ArithmeticException,
+    ArrayIndexOutOfBoundsException,
+    DateTimeException,
+    NumberFormatException,
     StreamingQueryException,
     QueryExecutionException,
     PythonException,
     UnknownException,
+    SparkRuntimeException,
     SparkUpgradeException,
     PySparkTypeError,
     PySparkValueError,
@@ -41,10 +46,15 @@ __all__ = [
     "TempTableAlreadyExistsException",
     "ParseException",
     "IllegalArgumentException",
+    "ArithmeticException",
+    "ArrayIndexOutOfBoundsException",
+    "DateTimeException",
+    "NumberFormatException",
     "StreamingQueryException",
     "QueryExecutionException",
     "PythonException",
     "UnknownException",
+    "SparkRuntimeException",
     "SparkUpgradeException",
     "PySparkTypeError",
     "PySparkValueError",
diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py
index fa66b80ac3a..31b69650972 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -108,7 +108,7 @@ class TempTableAlreadyExistsException(AnalysisException):
     """
 
 
-class ParseException(PySparkException):
+class ParseException(AnalysisException):
     """
     Failed to parse a SQL command.
     """
@@ -120,6 +120,30 @@ class IllegalArgumentException(PySparkException):
     """
 
 
+class ArithmeticException(PySparkException):
+    """
+    Arithmetic exception thrown from Spark with an error class.
+    """
+
+
+class ArrayIndexOutOfBoundsException(PySparkException):
+    """
+    Array index out of bounds exception thrown from Spark with an error class.
+    """
+
+
+class DateTimeException(PySparkException):
+    """
+    Datetime exception thrown from Spark with an error class.
+    """
+
+
+class NumberFormatException(IllegalArgumentException):
+    """
+    Number format exception thrown from Spark with an error class.
+    """
+
+
 class StreamingQueryException(PySparkException):
     """
     Exception that stopped a :class:`StreamingQuery`.
@@ -138,9 +162,9 @@ class PythonException(PySparkException):
     """
 
 
-class UnknownException(PySparkException):
+class SparkRuntimeException(PySparkException):
     """
-    None of the above exceptions.
+    Runtime exception thrown from Spark with an error class.
     """
 
 
@@ -150,6 +174,12 @@ class SparkUpgradeException(PySparkException):
     """
 
 
+class UnknownException(PySparkException):
+    """
+    None of the above exceptions.
+    """
+
+
 class PySparkValueError(PySparkException, ValueError):
     """
     Wrapper class for ValueError to support error classes.
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 1764ed7d02c..8415f42e383 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -25,10 +25,15 @@ from pyspark import SparkContext
 from pyspark.errors.exceptions.base import (
     AnalysisException as BaseAnalysisException,
     IllegalArgumentException as BaseIllegalArgumentException,
+    ArithmeticException as BaseArithmeticException,
+    ArrayIndexOutOfBoundsException as BaseArrayIndexOutOfBoundsException,
+    DateTimeException as BaseDateTimeException,
+    NumberFormatException as BaseNumberFormatException,
     ParseException as BaseParseException,
     PySparkException,
     PythonException as BasePythonException,
     QueryExecutionException as BaseQueryExecutionException,
+    SparkRuntimeException as BaseSparkRuntimeException,
     SparkUpgradeException as BaseSparkUpgradeException,
     StreamingQueryException as BaseStreamingQueryException,
     UnknownException as BaseUnknownException,
@@ -122,8 +127,19 @@ def convert_exception(e: Py4JJavaError) -> CapturedException:
         return StreamingQueryException(origin=e)
     elif is_instance_of(gw, e, "org.apache.spark.sql.execution.QueryExecutionException"):
         return QueryExecutionException(origin=e)
+    # Order matters. NumberFormatException inherits IllegalArgumentException.
+    elif is_instance_of(gw, e, "java.lang.NumberFormatException"):
+        return NumberFormatException(origin=e)
     elif is_instance_of(gw, e, "java.lang.IllegalArgumentException"):
         return IllegalArgumentException(origin=e)
+    elif is_instance_of(gw, e, "java.lang.ArithmeticException"):
+        return ArithmeticException(origin=e)
+    elif is_instance_of(gw, e, "java.lang.ArrayIndexOutOfBoundsException"):
+        return ArrayIndexOutOfBoundsException(origin=e)
+    elif is_instance_of(gw, e, "java.time.DateTimeException"):
+        return DateTimeException(origin=e)
+    elif is_instance_of(gw, e, "org.apache.spark.SparkRuntimeException"):
+        return SparkRuntimeException(origin=e)
     elif is_instance_of(gw, e, "org.apache.spark.SparkUpgradeException"):
         return SparkUpgradeException(origin=e)
 
@@ -187,7 +203,7 @@ class AnalysisException(CapturedException, BaseAnalysisException):
     """
 
 
-class ParseException(CapturedException, BaseParseException):
+class ParseException(AnalysisException, BaseParseException):
     """
     Failed to parse a SQL command.
     """
@@ -217,9 +233,33 @@ class PythonException(CapturedException, BasePythonException):
     """
 
 
-class UnknownException(CapturedException, BaseUnknownException):
+class ArithmeticException(CapturedException, BaseArithmeticException):
     """
-    None of the above exceptions.
+    Arithmetic exception.
+    """
+
+
+class ArrayIndexOutOfBoundsException(CapturedException, BaseArrayIndexOutOfBoundsException):
+    """
+    Array index out of bounds exception.
+    """
+
+
+class DateTimeException(CapturedException, BaseDateTimeException):
+    """
+    Datetime exception.
+    """
+
+
+class NumberFormatException(IllegalArgumentException, BaseNumberFormatException):
+    """
+    Number format exception.
+    """
+
+
+class SparkRuntimeException(CapturedException, BaseSparkRuntimeException):
+    """
+    Runtime exception.
     """
 
 
@@ -227,3 +267,9 @@ class SparkUpgradeException(CapturedException, BaseSparkUpgradeException):
     """
     Exception thrown because of Spark upgrade.
     """
+
+
+class UnknownException(CapturedException, BaseUnknownException):
+    """
+    None of the above exceptions.
+    """
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index f5f1d42ca5d..43fee1f0af9 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -21,11 +21,16 @@ from typing import Dict, Optional, TYPE_CHECKING
 from pyspark.errors.exceptions.base import (
     AnalysisException as BaseAnalysisException,
     IllegalArgumentException as BaseIllegalArgumentException,
+    ArithmeticException as BaseArithmeticException,
+    ArrayIndexOutOfBoundsException as BaseArrayIndexOutOfBoundsException,
+    DateTimeException as BaseDateTimeException,
+    NumberFormatException as BaseNumberFormatException,
     ParseException as BaseParseException,
     PySparkException,
     PythonException as BasePythonException,
     StreamingQueryException as BaseStreamingQueryException,
     QueryExecutionException as BaseQueryExecutionException,
+    SparkRuntimeException as BaseSparkRuntimeException,
     SparkUpgradeException as BaseSparkUpgradeException,
 )
 
@@ -53,8 +58,19 @@ def convert_exception(info: "ErrorInfo", message: str) -> SparkConnectException:
         return StreamingQueryException(message)
     elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
         return QueryExecutionException(message)
+    # Order matters. NumberFormatException inherits IllegalArgumentException.
+    elif "java.lang.NumberFormatException" in classes:
+        return NumberFormatException(message)
     elif "java.lang.IllegalArgumentException" in classes:
         return IllegalArgumentException(message)
+    elif "java.lang.ArithmeticException" in classes:
+        return ArithmeticException(message)
+    elif "java.lang.ArrayIndexOutOfBoundsException" in classes:
+        return ArrayIndexOutOfBoundsException(message)
+    elif "java.time.DateTimeException" in classes:
+        return DateTimeException(message)
+    elif "org.apache.spark.SparkRuntimeException" in classes:
+        return SparkRuntimeException(message)
     elif "org.apache.spark.SparkUpgradeException" in classes:
         return SparkUpgradeException(message)
     elif "org.apache.spark.api.python.PythonException" in classes:
@@ -91,41 +107,71 @@ class SparkConnectGrpcException(SparkConnectException):
 
 class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
     """
-    Failed to analyze a SQL query plan from Spark Connect server.
+    Failed to analyze a SQL query plan, thrown from Spark Connect.
     """
 
 
-class ParseException(SparkConnectGrpcException, BaseParseException):
+class ParseException(AnalysisException, BaseParseException):
     """
-    Failed to parse a SQL command from Spark Connect server.
+    Failed to parse a SQL command, thrown from Spark Connect.
     """
 
 
 class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException):
     """
-    Passed an illegal or inappropriate argument from Spark Connect server.
+    Passed an illegal or inappropriate argument, thrown from Spark Connect.
     """
 
 
 class StreamingQueryException(SparkConnectGrpcException, BaseStreamingQueryException):
     """
-    Exception that stopped a :class:`StreamingQuery` from Spark Connect server.
+    Exception that stopped a :class:`StreamingQuery` thrown from Spark Connect.
     """
 
 
 class QueryExecutionException(SparkConnectGrpcException, BaseQueryExecutionException):
     """
-    Failed to execute a query from Spark Connect server.
+    Failed to execute a query, thrown from Spark Connect.
     """
 
 
-class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException):
+class PythonException(SparkConnectGrpcException, BasePythonException):
     """
-    Exception thrown because of Spark upgrade from Spark Connect
+    Exceptions thrown from Spark Connect.
     """
 
 
-class PythonException(SparkConnectGrpcException, BasePythonException):
+class ArithmeticException(SparkConnectGrpcException, BaseArithmeticException):
+    """
+    Arithmetic exception thrown from Spark Connect.
+    """
+
+
+class ArrayIndexOutOfBoundsException(SparkConnectGrpcException, BaseArrayIndexOutOfBoundsException):
+    """
+    Array index out of bounds exception thrown from Spark Connect.
+    """
+
+
+class DateTimeException(SparkConnectGrpcException, BaseDateTimeException):
+    """
+    Datetime exception thrown from Spark Connect.
+    """
+
+
+class NumberFormatException(IllegalArgumentException, BaseNumberFormatException):
+    """
+    Number format exception thrown from Spark Connect.
+    """
+
+
+class SparkRuntimeException(SparkConnectGrpcException, BaseSparkRuntimeException):
+    """
+    Runtime exception thrown from Spark Connect.
+    """
+
+
+class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException):
     """
-    Exceptions thrown from Spark Connect server.
+    Exception thrown because of Spark upgrade from Spark Connect.
     """
diff --git a/python/pyspark/sql/tests/connect/test_parity_errors.py b/python/pyspark/sql/tests/connect/test_parity_errors.py
new file mode 100644
index 00000000000..37f5b904b3a
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_errors.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.tests.test_errors import ErrorsTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class ErrorsParityTests(ErrorsTestsMixin, ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_errors import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_errors.py b/python/pyspark/sql/tests/test_errors.py
new file mode 100644
index 00000000000..2ae6ef564c5
--- /dev/null
+++ b/python/pyspark/sql/tests/test_errors.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.errors import (
+    ArithmeticException,
+    ArrayIndexOutOfBoundsException,
+    DateTimeException,
+    NumberFormatException,
+    SparkRuntimeException,
+)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class ErrorsTestsMixin:
+    def test_arithmetic_exception(self):
+        with self.assertRaises(ArithmeticException):
+            with self.sql_conf({"spark.sql.ansi.enabled": True}):
+                self.spark.sql("select 1/0").show()
+
+    def test_array_index_out_of_bounds_exception(self):
+        with self.assertRaises(ArrayIndexOutOfBoundsException):
+            with self.sql_conf({"spark.sql.ansi.enabled": True}):
+                self.spark.sql("select array(1, 2)[2]").show()
+
+    def test_date_time_exception(self):
+        with self.assertRaises(DateTimeException):
+            with self.sql_conf({"spark.sql.ansi.enabled": True}):
+                self.spark.sql("select unix_timestamp('2023-01-01', 'dd-MM-yyyy')").show()
+
+    def test_number_format_exception(self):
+        with self.assertRaises(NumberFormatException):
+            with self.sql_conf({"spark.sql.ansi.enabled": True}):
+                self.spark.sql("select cast('abc' as double)").show()
+
+    def test_spark_runtime_exception(self):
+        with self.assertRaises(SparkRuntimeException):
+            with self.sql_conf({"spark.sql.ansi.enabled": True}):
+                self.spark.sql("select cast('abc' as boolean)").show()
+
+
+class ErrorsTests(ReusedSQLTestCase, ErrorsTestsMixin):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.test_errors import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

