You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/27 00:31:25 UTC
[spark] branch branch-3.4 updated: [SPARK-42419][CONNECT][PYTHON] Migrate into error framework for Spark Connect Column API
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new a9220f77411 [SPARK-42419][CONNECT][PYTHON] Migrate into error framework for Spark Connect Column API
a9220f77411 is described below
commit a9220f77411a5a6edee4730d3e3fba04386d14c8
Author: itholic <ha...@databricks.com>
AuthorDate: Mon Feb 27 09:30:55 2023 +0900
[SPARK-42419][CONNECT][PYTHON] Migrate into error framework for Spark Connect Column API
### What changes were proposed in this pull request?
This PR proposes to migrate `TypeError`s into the PySpark error framework for the Spark Connect Column API.
### Why are the changes needed?
To improve errors by leveraging the PySpark error framework for Spark Connect Column APIs.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Fixed existing unit tests and added new ones.
Closes #39991 from itholic/SPARK-42419.
Authored-by: itholic <ha...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit 86d3db9fc1372a377625c67c2966187ebdf2848e)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/errors/error_classes.py | 25 ++++++++
python/pyspark/sql/column.py | 25 +++++---
python/pyspark/sql/connect/column.py | 72 +++++++++++++++-------
.../sql/tests/connect/test_connect_column.py | 70 +++++++++++++++++----
.../sql/tests/connect/test_connect_function.py | 11 ++--
python/pyspark/sql/tests/test_column.py | 21 +++++--
python/pyspark/sql/tests/test_functions.py | 14 ++++-
7 files changed, 185 insertions(+), 53 deletions(-)
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 00d676b52b8..8c0f79f7d5a 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -99,11 +99,21 @@ ERROR_CLASSES_JSON = """
"Argument `<arg_name>` should be a DataFrame, got <arg_type>."
]
},
+ "NOT_DATATYPE_OR_STR" : {
+ "message" : [
+ "Argument `<arg_name>` should be a DataType or str, got <arg_type>."
+ ]
+ },
"NOT_DICT" : {
"message" : [
"Argument `<arg_name>` should be a dict, got <arg_type>."
]
},
+ "NOT_EXPRESSION" : {
+ "message" : [
+ "Argument `<arg_name>` should be an Expression, got <arg_type>."
+ ]
+ },
"NOT_FLOAT_OR_INT" : {
"message" : [
"Argument `<arg_name>` should be a float or int, got <arg_type>."
@@ -119,6 +129,11 @@ ERROR_CLASSES_JSON = """
"Argument `<arg_name>` should be an int, got <arg_type>."
]
},
+ "NOT_ITERABLE" : {
+ "message" : [
+ "<objectName> is not iterable."
+ ]
+ },
"NOT_LIST_OR_STR_OR_TUPLE" : {
"message" : [
"Argument `<arg_name>` should be a list, str or tuple, got <arg_type>."
@@ -129,11 +144,21 @@ ERROR_CLASSES_JSON = """
"Argument `<arg_name>` should be a list or tuple, got <arg_type>."
]
},
+ "NOT_SAME_TYPE" : {
+ "message" : [
+ "Argument `<arg_name1>` and `<arg_name2>` should be the same type, got <arg_type1> and <arg_type2>."
+ ]
+ },
"NOT_STR" : {
"message" : [
"Argument `<arg_name>` should be a str, got <arg_type>."
]
},
+ "NOT_WINDOWSPEC" : {
+ "message" : [
+ "Argument `<arg_name>` should be a WindowSpec, got <arg_type>."
+ ]
+ },
"UNSUPPORTED_NUMPY_ARRAY_SCALAR" : {
"message" : [
"The type of array scalar '<dtype>' is not supported."
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index bcf6676d5ca..abd28136895 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -35,6 +35,7 @@ from py4j.java_gateway import JavaObject
from pyspark import copy_func
from pyspark.context import SparkContext
+from pyspark.errors import PySparkTypeError
from pyspark.sql.types import DataType
if TYPE_CHECKING:
@@ -555,10 +556,16 @@ class Column:
+---+
"""
if not isinstance(fieldName, str):
- raise TypeError("fieldName should be a string")
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "fieldName", "arg_type": type(fieldName).__name__},
+ )
if not isinstance(col, Column):
- raise TypeError("col should be a Column")
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN",
+ message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
+ )
return Column(self._jc.withField(fieldName, col._jc))
@@ -844,12 +851,14 @@ class Column:
[Row(col='Ali'), Row(col='Bob')]
"""
if type(startPos) != type(length):
- raise TypeError(
- "startPos and length must be the same type. "
- "Got {startPos_t} and {length_t}, respectively.".format(
- startPos_t=type(startPos),
- length_t=type(length),
- )
+ raise PySparkTypeError(
+ error_class="NOT_SAME_TYPE",
+ message_parameters={
+ "arg_name1": "startPos",
+ "arg_name2": "length",
+ "arg_type1": type(startPos).__name__,
+ "arg_type2": type(length).__name__,
+ },
)
if isinstance(startPos, int):
jc = self._jc.substr(startPos, length)
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index a172b884f69..bc8b60beb97 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -31,6 +31,7 @@ from typing import (
Optional,
)
+from pyspark.errors import PySparkTypeError
from pyspark.sql.types import DataType
from pyspark.sql.column import Column as PySparkColumn
@@ -98,8 +99,9 @@ def _unary_op(name: str, doc: Optional[str] = "unary function") -> Callable[["Co
class Column:
def __init__(self, expr: "Expression") -> None:
if not isinstance(expr, Expression):
- raise TypeError(
- f"Cannot construct column expected Expression, got {expr} ({type(expr)})"
+ raise PySparkTypeError(
+ error_class="NOT_EXPRESSION",
+ message_parameters={"arg_name": "expr", "arg_type": type(expr).__name__},
)
self._expr = expr
@@ -163,7 +165,10 @@ class Column:
def when(self, condition: "Column", value: Any) -> "Column":
if not isinstance(condition, Column):
- raise TypeError("condition should be a Column")
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN",
+ message_parameters={"arg_name": "condition", "arg_type": type(condition).__name__},
+ )
if not isinstance(self._expr, CaseWhen):
raise TypeError(
@@ -186,12 +191,12 @@ class Column:
def otherwise(self, value: Any) -> "Column":
if not isinstance(self._expr, CaseWhen):
- raise TypeError(
+ raise PySparkTypeError(
"otherwise() can only be applied on a Column previously generated by when()"
)
if self._expr._else_value is not None:
- raise TypeError(
+ raise PySparkTypeError(
"otherwise() can only be applied once on a Column previously generated by when()"
)
@@ -218,12 +223,14 @@ class Column:
def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) -> "Column":
if type(startPos) != type(length):
- raise TypeError(
- "startPos and length must be the same type. "
- "Got {startPos_t} and {length_t}, respectively.".format(
- startPos_t=type(startPos),
- length_t=type(length),
- )
+ raise PySparkTypeError(
+ error_class="NOT_SAME_TYPE",
+ message_parameters={
+ "arg_name1": "startPos",
+ "arg_name2": "length",
+ "arg_type1": type(startPos).__name__,
+ "arg_type2": type(length).__name__,
+ },
)
if isinstance(length, Column):
@@ -231,14 +238,20 @@ class Column:
elif isinstance(length, int):
length_expr = LiteralExpression._from_value(length)
else:
- raise TypeError("Unsupported type for substr().")
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN_OR_INT",
+ message_parameters={"arg_name": "length", "arg_type": type(length).__name__},
+ )
if isinstance(startPos, Column):
start_expr = startPos._expr
elif isinstance(startPos, int):
start_expr = LiteralExpression._from_value(startPos)
else:
- raise TypeError("Unsupported type for substr().")
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN_OR_INT",
+ message_parameters={"arg_name": "startPos", "arg_type": type(startPos).__name__},
+ )
return Column(UnresolvedFunction("substring", [self._expr, start_expr, length_expr]))
@@ -303,7 +316,10 @@ class Column:
if isinstance(dataType, (DataType, str)):
return Column(CastExpression(expr=self._expr, data_type=dataType))
else:
- raise TypeError("unexpected type: %s" % type(dataType))
+ raise PySparkTypeError(
+ error_class="NOT_DATATYPE_OR_STR",
+ message_parameters={"arg_name": "dataType", "arg_type": type(dataType).__name__},
+ )
cast.__doc__ = PySparkColumn.cast.__doc__
@@ -316,8 +332,9 @@ class Column:
from pyspark.sql.connect.window import WindowSpec
if not isinstance(window, WindowSpec):
- raise TypeError(
- f"window should be WindowSpec, but got {type(window).__name__} {window}"
+ raise PySparkTypeError(
+ error_class="NOT_WINDOWSPEC",
+ message_parameters={"arg_name": "window", "arg_type": type(window).__name__},
)
return Column(WindowExpression(windowFunction=self._expr, windowSpec=window))
@@ -376,12 +393,16 @@ class Column:
def withField(self, fieldName: str, col: "Column") -> "Column":
if not isinstance(fieldName, str):
- raise TypeError(
- f"fieldName should be a string, but got {type(fieldName).__name__} {fieldName}"
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "fieldName", "arg_type": type(fieldName).__name__},
)
if not isinstance(col, Column):
- raise TypeError(f"col should be a Column, but got {type(col).__name__} {col}")
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN",
+ message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
+ )
return Column(WithField(self._expr, fieldName, col._expr))
@@ -391,8 +412,12 @@ class Column:
dropField: Optional[DropField] = None
for fieldName in fieldNames:
if not isinstance(fieldName, str):
- raise TypeError(
- f"fieldName should be a string, but got {type(fieldName).__name__} {fieldName}"
+ raise PySparkTypeError(
+ error_class="NOT_STR",
+ message_parameters={
+ "arg_name": "fieldName",
+ "arg_type": type(fieldName).__name__,
+ },
)
if dropField is None:
@@ -421,7 +446,10 @@ class Column:
return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k)))
def __iter__(self) -> None:
- raise TypeError("Column is not iterable")
+ raise PySparkTypeError(
+ error_class="NOT_ITERABLE",
+ message_parameters={"objectName": "Column"},
+ )
def __nonzero__(self) -> None:
raise ValueError(
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index a2c786db180..b5d8163f4f7 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -41,6 +41,7 @@ from pyspark.sql.types import (
DecimalType,
BooleanType,
)
+from pyspark.errors import PySparkTypeError
from pyspark.errors.exceptions.connect import SparkConnectException
from pyspark.testing.connectutils import should_test_connect
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
@@ -133,6 +134,33 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
df4.filter(df4.name.isNotNull()).toPandas(),
)
+ # check error
+ with self.assertRaises(PySparkTypeError) as pe:
+ df.name.substr(df.id, 10)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_SAME_TYPE",
+ message_parameters={
+ "arg_name1": "startPos",
+ "arg_name2": "length",
+ "arg_type1": "Column",
+ "arg_type2": "int",
+ },
+ )
+
+ with self.assertRaises(PySparkTypeError) as pe:
+ df.name.substr(10.5, 10.5)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_COLUMN_OR_INT",
+ message_parameters={
+ "arg_name": "length",
+ "arg_type": "float",
+ },
+ )
+
def test_column_with_null(self):
# SPARK-41751: test isNull, isNotNull, eqNullSafe
@@ -532,6 +560,15 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
df.select(df.id.cast(x)).toPandas(), df2.select(df2.id.cast(x)).toPandas()
)
+ with self.assertRaises(PySparkTypeError) as pe:
+ df.id.cast(10)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_DATATYPE_OR_STR",
+ message_parameters={"arg_name": "dataType", "arg_type": "int"},
+ )
+
def test_isin(self):
# SPARK-41526: test Column.isin
query = """
@@ -893,24 +930,33 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
with self.assertRaises(SparkConnectException):
cdf.select(cdf.x.dropFields("a", "b", "c", "d")).show()
- with self.assertRaisesRegex(
- TypeError,
- "fieldName should be a string",
- ):
+ with self.assertRaises(PySparkTypeError) as pe:
cdf.select(cdf.x.withField(CF.col("a"), cdf.e)).show()
- with self.assertRaisesRegex(
- TypeError,
- "col should be a Column",
- ):
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "fieldName", "arg_type": "Column"},
+ )
+
+ with self.assertRaises(PySparkTypeError) as pe:
cdf.select(cdf.x.withField("a", 2)).show()
- with self.assertRaisesRegex(
- TypeError,
- "fieldName should be a string",
- ):
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_COLUMN",
+ message_parameters={"arg_name": "col", "arg_type": "int"},
+ )
+
+ with self.assertRaises(PySparkTypeError) as pe:
cdf.select(cdf.x.dropFields("a", 1, 2)).show()
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "fieldName", "arg_type": "int"},
+ )
+
with self.assertRaisesRegex(
ValueError,
"dropFields requires at least 1 field",
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 9e499815107..599e595af62 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -892,12 +892,15 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
):
cdf.select(CF.sum("a").over(CW.orderBy("b").rowsBetween(0, (1 << 33)))).show()
- with self.assertRaisesRegex(
- TypeError,
- "window should be WindowSpec",
- ):
+ with self.assertRaises(PySparkTypeError) as pe:
cdf.select(CF.rank().over(cdf.a))
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_WINDOWSPEC",
+ message_parameters={"arg_name": "window", "arg_type": "Column"},
+ )
+
# invalid window function
with self.assertRaises(AnalysisException):
cdf.select(cdf.b.over(CW.orderBy("b"))).show()
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index be9cb07eccc..da2a4f1de33 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -18,7 +18,7 @@
from pyspark.sql import Column, Row
from pyspark.sql.types import StructType, StructField, LongType
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PySparkTypeError
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -160,11 +160,22 @@ class ColumnTestsMixin:
result = df.withColumn("a", df["a"].withField("b", lit(3))).collect()[0].asDict()
self.assertEqual(3, result["a"]["b"])
- self.assertRaisesRegex(
- TypeError, "col should be a Column", lambda: df["a"].withField("b", 3)
+ with self.assertRaises(PySparkTypeError) as pe:
+ df["a"].withField("b", 3)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_COLUMN",
+ message_parameters={"arg_name": "col", "arg_type": "int"},
)
- self.assertRaisesRegex(
- TypeError, "fieldName should be a string", lambda: df["a"].withField(col("b"), lit(3))
+
+ with self.assertRaises(PySparkTypeError) as pe:
+ df["a"].withField(col("b"), lit(3))
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_STR",
+ message_parameters={"arg_name": "fieldName", "arg_type": "Column"},
)
def test_drop_fields(self):
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 8bc2b96cc51..3aec7cc42de 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -390,8 +390,18 @@ class FunctionsTestsMixin:
]
df = self.spark.createDataFrame([["nick"]], schema=["name"])
- self.assertRaisesRegex(
- TypeError, "must be the same type", lambda: df.select(col("name").substr(0, lit(1)))
+ with self.assertRaises(PySparkTypeError) as pe:
+ df.select(col("name").substr(0, lit(1)))
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_SAME_TYPE",
+ message_parameters={
+ "arg_name1": "startPos",
+ "arg_name2": "length",
+ "arg_type1": "int",
+ "arg_type2": "Column",
+ },
)
for name in string_functions:
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org