You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/24 10:13:30 UTC
[spark] branch master updated: [SPARK-42911][PYTHON] Introduce more basic exceptions
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fa9b6c370c8 [SPARK-42911][PYTHON] Introduce more basic exceptions
fa9b6c370c8 is described below
commit fa9b6c370c85ca65e92171562c52379b80e9c796
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Fri Mar 24 19:13:14 2023 +0900
[SPARK-42911][PYTHON] Introduce more basic exceptions
### What changes were proposed in this pull request?
Introduces more basic exceptions.
- ArithmeticException
- ArrayIndexOutOfBoundsException
- DateTimeException
- NumberFormatException
- SparkRuntimeException
### Why are the changes needed?
There are more exceptions that Spark throws but PySpark doesn't capture.
We should introduce more basic exceptions; otherwise we still see `Py4JJavaError` or `SparkConnectGrpcException`.
```py
>>> spark.conf.set("spark.sql.ansi.enabled", True)
>>> spark.sql("select 1/0")
DataFrame[(1 / 0): double]
>>> spark.sql("select 1/0").show()
Traceback (most recent call last):
...
py4j.protocol.Py4JJavaError: An error occurred while calling o44.showString.
: org.apache.spark.SparkArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== SQL(line 1, position 8) ==
select 1/0
^^^
at org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(QueryExecutionErrors.scala:225)
... JVM's stacktrace
```
```py
>>> spark.sql("select 1/0").show()
Traceback (most recent call last):
...
pyspark.errors.exceptions.connect.SparkConnectGrpcException: (org.apache.spark.SparkArithmeticException) [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== SQL(line 1, position 8) ==
select 1/0
^^^
```
### Does this PR introduce _any_ user-facing change?
Yes. The error message becomes more readable, as shown below.
```py
>>> spark.sql("select 1/0").show()
Traceback (most recent call last):
...
pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== SQL(line 1, position 8) ==
select 1/0
^^^
```
or
```py
>>> spark.sql("select 1/0").show()
Traceback (most recent call last):
...
pyspark.errors.exceptions.connect.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== SQL(line 1, position 8) ==
select 1/0
^^^
```
### How was this patch tested?
Added the related tests.
Closes #40538 from ueshin/issues/SPARK-42911/exceptions.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
dev/sparktestsupport/modules.py | 2 +
python/pyspark/errors/__init__.py | 10 ++++
python/pyspark/errors/exceptions/base.py | 36 ++++++++++-
python/pyspark/errors/exceptions/captured.py | 52 +++++++++++++++-
python/pyspark/errors/exceptions/connect.py | 66 +++++++++++++++++----
.../sql/tests/connect/test_parity_errors.py | 36 +++++++++++
python/pyspark/sql/tests/test_errors.py | 69 ++++++++++++++++++++++
7 files changed, 255 insertions(+), 16 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index c3c3b415a1f..11257841bce 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -476,6 +476,7 @@ pyspark_sql = Module(
"pyspark.sql.tests.test_context",
"pyspark.sql.tests.test_dataframe",
"pyspark.sql.tests.test_datasources",
+ "pyspark.sql.tests.test_errors",
"pyspark.sql.tests.test_functions",
"pyspark.sql.tests.test_group",
"pyspark.sql.tests.pandas.test_pandas_cogrouped_map",
@@ -754,6 +755,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_connect_function",
"pyspark.sql.tests.connect.test_connect_column",
"pyspark.sql.tests.connect.test_parity_datasources",
+ "pyspark.sql.tests.connect.test_parity_errors",
"pyspark.sql.tests.connect.test_parity_catalog",
"pyspark.sql.tests.connect.test_parity_conf",
"pyspark.sql.tests.connect.test_parity_serde",
diff --git a/python/pyspark/errors/__init__.py b/python/pyspark/errors/__init__.py
index 94117fc5160..e4143d4eb87 100644
--- a/python/pyspark/errors/__init__.py
+++ b/python/pyspark/errors/__init__.py
@@ -24,10 +24,15 @@ from pyspark.errors.exceptions.base import ( # noqa: F401
TempTableAlreadyExistsException,
ParseException,
IllegalArgumentException,
+ ArithmeticException,
+ ArrayIndexOutOfBoundsException,
+ DateTimeException,
+ NumberFormatException,
StreamingQueryException,
QueryExecutionException,
PythonException,
UnknownException,
+ SparkRuntimeException,
SparkUpgradeException,
PySparkTypeError,
PySparkValueError,
@@ -41,10 +46,15 @@ __all__ = [
"TempTableAlreadyExistsException",
"ParseException",
"IllegalArgumentException",
+ "ArithmeticException",
+ "ArrayIndexOutOfBoundsException",
+ "DateTimeException",
+ "NumberFormatException",
"StreamingQueryException",
"QueryExecutionException",
"PythonException",
"UnknownException",
+ "SparkRuntimeException",
"SparkUpgradeException",
"PySparkTypeError",
"PySparkValueError",
diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py
index fa66b80ac3a..31b69650972 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -108,7 +108,7 @@ class TempTableAlreadyExistsException(AnalysisException):
"""
-class ParseException(PySparkException):
+class ParseException(AnalysisException):
"""
Failed to parse a SQL command.
"""
@@ -120,6 +120,30 @@ class IllegalArgumentException(PySparkException):
"""
+class ArithmeticException(PySparkException):
+ """
+ Arithmetic exception thrown from Spark with an error class.
+ """
+
+
+class ArrayIndexOutOfBoundsException(PySparkException):
+ """
+ Array index out of bounds exception thrown from Spark with an error class.
+ """
+
+
+class DateTimeException(PySparkException):
+ """
+ Datetime exception thrown from Spark with an error class.
+ """
+
+
+class NumberFormatException(IllegalArgumentException):
+ """
+ Number format exception thrown from Spark with an error class.
+ """
+
+
class StreamingQueryException(PySparkException):
"""
Exception that stopped a :class:`StreamingQuery`.
@@ -138,9 +162,9 @@ class PythonException(PySparkException):
"""
-class UnknownException(PySparkException):
+class SparkRuntimeException(PySparkException):
"""
- None of the above exceptions.
+ Runtime exception thrown from Spark with an error class.
"""
@@ -150,6 +174,12 @@ class SparkUpgradeException(PySparkException):
"""
+class UnknownException(PySparkException):
+ """
+ None of the above exceptions.
+ """
+
+
class PySparkValueError(PySparkException, ValueError):
"""
Wrapper class for ValueError to support error classes.
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 6313665b3fe..d1b57997f99 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -25,10 +25,15 @@ from pyspark import SparkContext
from pyspark.errors.exceptions.base import (
AnalysisException as BaseAnalysisException,
IllegalArgumentException as BaseIllegalArgumentException,
+ ArithmeticException as BaseArithmeticException,
+ ArrayIndexOutOfBoundsException as BaseArrayIndexOutOfBoundsException,
+ DateTimeException as BaseDateTimeException,
+ NumberFormatException as BaseNumberFormatException,
ParseException as BaseParseException,
PySparkException,
PythonException as BasePythonException,
QueryExecutionException as BaseQueryExecutionException,
+ SparkRuntimeException as BaseSparkRuntimeException,
SparkUpgradeException as BaseSparkUpgradeException,
StreamingQueryException as BaseStreamingQueryException,
UnknownException as BaseUnknownException,
@@ -129,8 +134,19 @@ def convert_exception(e: Py4JJavaError) -> CapturedException:
return StreamingQueryException(origin=e)
elif is_instance_of(gw, e, "org.apache.spark.sql.execution.QueryExecutionException"):
return QueryExecutionException(origin=e)
+ # Order matters. NumberFormatException inherits IllegalArgumentException.
+ elif is_instance_of(gw, e, "java.lang.NumberFormatException"):
+ return NumberFormatException(origin=e)
elif is_instance_of(gw, e, "java.lang.IllegalArgumentException"):
return IllegalArgumentException(origin=e)
+ elif is_instance_of(gw, e, "java.lang.ArithmeticException"):
+ return ArithmeticException(origin=e)
+ elif is_instance_of(gw, e, "java.lang.ArrayIndexOutOfBoundsException"):
+ return ArrayIndexOutOfBoundsException(origin=e)
+ elif is_instance_of(gw, e, "java.time.DateTimeException"):
+ return DateTimeException(origin=e)
+ elif is_instance_of(gw, e, "org.apache.spark.SparkRuntimeException"):
+ return SparkRuntimeException(origin=e)
elif is_instance_of(gw, e, "org.apache.spark.SparkUpgradeException"):
return SparkUpgradeException(origin=e)
@@ -194,7 +210,7 @@ class AnalysisException(CapturedException, BaseAnalysisException):
"""
-class ParseException(CapturedException, BaseParseException):
+class ParseException(AnalysisException, BaseParseException):
"""
Failed to parse a SQL command.
"""
@@ -224,9 +240,33 @@ class PythonException(CapturedException, BasePythonException):
"""
-class UnknownException(CapturedException, BaseUnknownException):
+class ArithmeticException(CapturedException, BaseArithmeticException):
"""
- None of the above exceptions.
+ Arithmetic exception.
+ """
+
+
+class ArrayIndexOutOfBoundsException(CapturedException, BaseArrayIndexOutOfBoundsException):
+ """
+ Array index out of bounds exception.
+ """
+
+
+class DateTimeException(CapturedException, BaseDateTimeException):
+ """
+ Datetime exception.
+ """
+
+
+class NumberFormatException(IllegalArgumentException, BaseNumberFormatException):
+ """
+ Number format exception.
+ """
+
+
+class SparkRuntimeException(CapturedException, BaseSparkRuntimeException):
+ """
+ Runtime exception.
"""
@@ -234,3 +274,9 @@ class SparkUpgradeException(CapturedException, BaseSparkUpgradeException):
"""
Exception thrown because of Spark upgrade.
"""
+
+
+class UnknownException(CapturedException, BaseUnknownException):
+ """
+ None of the above exceptions.
+ """
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index f5f1d42ca5d..43fee1f0af9 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -21,11 +21,16 @@ from typing import Dict, Optional, TYPE_CHECKING
from pyspark.errors.exceptions.base import (
AnalysisException as BaseAnalysisException,
IllegalArgumentException as BaseIllegalArgumentException,
+ ArithmeticException as BaseArithmeticException,
+ ArrayIndexOutOfBoundsException as BaseArrayIndexOutOfBoundsException,
+ DateTimeException as BaseDateTimeException,
+ NumberFormatException as BaseNumberFormatException,
ParseException as BaseParseException,
PySparkException,
PythonException as BasePythonException,
StreamingQueryException as BaseStreamingQueryException,
QueryExecutionException as BaseQueryExecutionException,
+ SparkRuntimeException as BaseSparkRuntimeException,
SparkUpgradeException as BaseSparkUpgradeException,
)
@@ -53,8 +58,19 @@ def convert_exception(info: "ErrorInfo", message: str) -> SparkConnectException:
return StreamingQueryException(message)
elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
return QueryExecutionException(message)
+ # Order matters. NumberFormatException inherits IllegalArgumentException.
+ elif "java.lang.NumberFormatException" in classes:
+ return NumberFormatException(message)
elif "java.lang.IllegalArgumentException" in classes:
return IllegalArgumentException(message)
+ elif "java.lang.ArithmeticException" in classes:
+ return ArithmeticException(message)
+ elif "java.lang.ArrayIndexOutOfBoundsException" in classes:
+ return ArrayIndexOutOfBoundsException(message)
+ elif "java.time.DateTimeException" in classes:
+ return DateTimeException(message)
+ elif "org.apache.spark.SparkRuntimeException" in classes:
+ return SparkRuntimeException(message)
elif "org.apache.spark.SparkUpgradeException" in classes:
return SparkUpgradeException(message)
elif "org.apache.spark.api.python.PythonException" in classes:
@@ -91,41 +107,71 @@ class SparkConnectGrpcException(SparkConnectException):
class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
"""
- Failed to analyze a SQL query plan from Spark Connect server.
+ Failed to analyze a SQL query plan, thrown from Spark Connect.
"""
-class ParseException(SparkConnectGrpcException, BaseParseException):
+class ParseException(AnalysisException, BaseParseException):
"""
- Failed to parse a SQL command from Spark Connect server.
+ Failed to parse a SQL command, thrown from Spark Connect.
"""
class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException):
"""
- Passed an illegal or inappropriate argument from Spark Connect server.
+ Passed an illegal or inappropriate argument, thrown from Spark Connect.
"""
class StreamingQueryException(SparkConnectGrpcException, BaseStreamingQueryException):
"""
- Exception that stopped a :class:`StreamingQuery` from Spark Connect server.
+ Exception that stopped a :class:`StreamingQuery` thrown from Spark Connect.
"""
class QueryExecutionException(SparkConnectGrpcException, BaseQueryExecutionException):
"""
- Failed to execute a query from Spark Connect server.
+ Failed to execute a query, thrown from Spark Connect.
"""
-class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException):
+class PythonException(SparkConnectGrpcException, BasePythonException):
"""
- Exception thrown because of Spark upgrade from Spark Connect
+ Exceptions thrown from Spark Connect.
"""
-class PythonException(SparkConnectGrpcException, BasePythonException):
+class ArithmeticException(SparkConnectGrpcException, BaseArithmeticException):
+ """
+ Arithmetic exception thrown from Spark Connect.
+ """
+
+
+class ArrayIndexOutOfBoundsException(SparkConnectGrpcException, BaseArrayIndexOutOfBoundsException):
+ """
+ Array index out of bounds exception thrown from Spark Connect.
+ """
+
+
+class DateTimeException(SparkConnectGrpcException, BaseDateTimeException):
+ """
+ Datetime exception thrown from Spark Connect.
+ """
+
+
+class NumberFormatException(IllegalArgumentException, BaseNumberFormatException):
+ """
+ Number format exception thrown from Spark Connect.
+ """
+
+
+class SparkRuntimeException(SparkConnectGrpcException, BaseSparkRuntimeException):
+ """
+ Runtime exception thrown from Spark Connect.
+ """
+
+
+class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException):
"""
- Exceptions thrown from Spark Connect server.
+ Exception thrown because of Spark upgrade from Spark Connect.
"""
diff --git a/python/pyspark/sql/tests/connect/test_parity_errors.py b/python/pyspark/sql/tests/connect/test_parity_errors.py
new file mode 100644
index 00000000000..37f5b904b3a
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_errors.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.tests.test_errors import ErrorsTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class ErrorsParityTests(ErrorsTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_parity_errors import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_errors.py b/python/pyspark/sql/tests/test_errors.py
new file mode 100644
index 00000000000..2ae6ef564c5
--- /dev/null
+++ b/python/pyspark/sql/tests/test_errors.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.errors import (
+ ArithmeticException,
+ ArrayIndexOutOfBoundsException,
+ DateTimeException,
+ NumberFormatException,
+ SparkRuntimeException,
+)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class ErrorsTestsMixin:
+ def test_arithmetic_exception(self):
+ with self.assertRaises(ArithmeticException):
+ with self.sql_conf({"spark.sql.ansi.enabled": True}):
+ self.spark.sql("select 1/0").show()
+
+ def test_array_index_out_of_bounds_exception(self):
+ with self.assertRaises(ArrayIndexOutOfBoundsException):
+ with self.sql_conf({"spark.sql.ansi.enabled": True}):
+ self.spark.sql("select array(1, 2)[2]").show()
+
+ def test_date_time_exception(self):
+ with self.assertRaises(DateTimeException):
+ with self.sql_conf({"spark.sql.ansi.enabled": True}):
+ self.spark.sql("select unix_timestamp('2023-01-01', 'dd-MM-yyyy')").show()
+
+ def test_number_format_exception(self):
+ with self.assertRaises(NumberFormatException):
+ with self.sql_conf({"spark.sql.ansi.enabled": True}):
+ self.spark.sql("select cast('abc' as double)").show()
+
+ def test_spark_runtime_exception(self):
+ with self.assertRaises(SparkRuntimeException):
+ with self.sql_conf({"spark.sql.ansi.enabled": True}):
+ self.spark.sql("select cast('abc' as boolean)").show()
+
+
+class ErrorsTests(ReusedSQLTestCase, ErrorsTestsMixin):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.test_errors import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org