Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/27 00:31:25 UTC

[spark] branch branch-3.4 updated: [SPARK-42419][CONNECT][PYTHON] Migrate into error framework for Spark Connect Column API

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new a9220f77411 [SPARK-42419][CONNECT][PYTHON] Migrate into error framework for Spark Connect Column API
a9220f77411 is described below

commit a9220f77411a5a6edee4730d3e3fba04386d14c8
Author: itholic <ha...@databricks.com>
AuthorDate: Mon Feb 27 09:30:55 2023 +0900

    [SPARK-42419][CONNECT][PYTHON] Migrate into error framework for Spark Connect Column API
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to migrate `TypeError` into the PySpark error framework for the Spark Connect Column API.
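    
    For example, the `Column.withField` type check that previously raised a plain `TypeError("fieldName should be a string")` now raises a classified, parameterized `PySparkTypeError` (minimal sketch of the pattern; `_check_field_name` is an illustrative helper, not part of the patch):
    
        from pyspark.errors import PySparkTypeError
    
        def _check_field_name(fieldName):
            # Reject non-str field names with a classified error instead of a free-form TypeError.
            if not isinstance(fieldName, str):
                raise PySparkTypeError(
                    error_class="NOT_STR",
                    message_parameters={
                        "arg_name": "fieldName",
                        "arg_type": type(fieldName).__name__,
                    },
                )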
    
    ### Why are the changes needed?
    
    To improve error messages by leveraging the PySpark error framework for the Spark Connect Column APIs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Fixed & added UTs.
    
    Closes #39991 from itholic/SPARK-42419.
    
    Authored-by: itholic <ha...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
    (cherry picked from commit 86d3db9fc1372a377625c67c2966187ebdf2848e)
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/errors/error_classes.py             | 25 ++++++++
 python/pyspark/sql/column.py                       | 25 +++++---
 python/pyspark/sql/connect/column.py               | 72 +++++++++++++++-------
 .../sql/tests/connect/test_connect_column.py       | 70 +++++++++++++++++----
 .../sql/tests/connect/test_connect_function.py     | 11 ++--
 python/pyspark/sql/tests/test_column.py            | 21 +++++--
 python/pyspark/sql/tests/test_functions.py         | 14 ++++-
 7 files changed, 185 insertions(+), 53 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 00d676b52b8..8c0f79f7d5a 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -99,11 +99,21 @@ ERROR_CLASSES_JSON = """
       "Argument `<arg_name>` should be a DataFrame, got <arg_type>."
     ]
   },
+  "NOT_DATATYPE_OR_STR" : {
+    "message" : [
+      "Argument `<arg_name>` should be a DataType or str, got <arg_type>."
+    ]
+  },
   "NOT_DICT" : {
     "message" : [
       "Argument `<arg_name>` should be a dict, got <arg_type>."
     ]
   },
+  "NOT_EXPRESSION" : {
+    "message" : [
+      "Argument `<arg_name>` should be a Expression, got <arg_type>."
+    ]
+  },
   "NOT_FLOAT_OR_INT" : {
     "message" : [
       "Argument `<arg_name>` should be a float or int, got <arg_type>."
@@ -119,6 +129,11 @@ ERROR_CLASSES_JSON = """
       "Argument `<arg_name>` should be an int, got <arg_type>."
     ]
   },
+  "NOT_ITERABLE" : {
+    "message" : [
+      "<objectName> is not iterable."
+    ]
+  },
   "NOT_LIST_OR_STR_OR_TUPLE" : {
     "message" : [
       "Argument `<arg_name>` should be a list, str or tuple, got <arg_type>."
@@ -129,11 +144,21 @@ ERROR_CLASSES_JSON = """
       "Argument `<arg_name>` should be a list or tuple, got <arg_type>."
     ]
   },
+  "NOT_SAME_TYPE" : {
+    "message" : [
+      "Argument `<arg_name1>` and `<arg_name2>` should be the same type, got <arg_type1> and <arg_type2>."
+    ]
+  },
   "NOT_STR" : {
     "message" : [
       "Argument `<arg_name>` should be a str, got <arg_type>."
     ]
   },
+  "NOT_WINDOWSPEC" : {
+    "message" : [
+      "Argument `<arg_name>` should be a WindowSpec, got <arg_type>."
+    ]
+  },
   "UNSUPPORTED_NUMPY_ARRAY_SCALAR" : {
     "message" : [
       "The type of array scalar '<dtype>' is not supported."
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index bcf6676d5ca..abd28136895 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -35,6 +35,7 @@ from py4j.java_gateway import JavaObject
 
 from pyspark import copy_func
 from pyspark.context import SparkContext
+from pyspark.errors import PySparkTypeError
 from pyspark.sql.types import DataType
 
 if TYPE_CHECKING:
@@ -555,10 +556,16 @@ class Column:
         +---+
         """
         if not isinstance(fieldName, str):
-            raise TypeError("fieldName should be a string")
+            raise PySparkTypeError(
+                error_class="NOT_STR",
+                message_parameters={"arg_name": "fieldName", "arg_type": type(fieldName).__name__},
+            )
 
         if not isinstance(col, Column):
-            raise TypeError("col should be a Column")
+            raise PySparkTypeError(
+                error_class="NOT_COLUMN",
+                message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
+            )
 
         return Column(self._jc.withField(fieldName, col._jc))
 
@@ -844,12 +851,14 @@ class Column:
         [Row(col='Ali'), Row(col='Bob')]
         """
         if type(startPos) != type(length):
-            raise TypeError(
-                "startPos and length must be the same type. "
-                "Got {startPos_t} and {length_t}, respectively.".format(
-                    startPos_t=type(startPos),
-                    length_t=type(length),
-                )
+            raise PySparkTypeError(
+                error_class="NOT_SAME_TYPE",
+                message_parameters={
+                    "arg_name1": "startPos",
+                    "arg_name2": "length",
+                    "arg_type1": type(startPos).__name__,
+                    "arg_type2": type(length).__name__,
+                },
             )
         if isinstance(startPos, int):
             jc = self._jc.substr(startPos, length)
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index a172b884f69..bc8b60beb97 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -31,6 +31,7 @@ from typing import (
     Optional,
 )
 
+from pyspark.errors import PySparkTypeError
 from pyspark.sql.types import DataType
 from pyspark.sql.column import Column as PySparkColumn
 
@@ -98,8 +99,9 @@ def _unary_op(name: str, doc: Optional[str] = "unary function") -> Callable[["Co
 class Column:
     def __init__(self, expr: "Expression") -> None:
         if not isinstance(expr, Expression):
-            raise TypeError(
-                f"Cannot construct column expected Expression, got {expr} ({type(expr)})"
+            raise PySparkTypeError(
+                error_class="NOT_EXPRESSION",
+                message_parameters={"arg_name": "expr", "arg_type": type(expr).__name__},
             )
         self._expr = expr
 
@@ -163,7 +165,10 @@ class Column:
 
     def when(self, condition: "Column", value: Any) -> "Column":
         if not isinstance(condition, Column):
-            raise TypeError("condition should be a Column")
+            raise PySparkTypeError(
+                error_class="NOT_COLUMN",
+                message_parameters={"arg_name": "condition", "arg_type": type(condition).__name__},
+            )
 
         if not isinstance(self._expr, CaseWhen):
             raise TypeError(
@@ -186,12 +191,12 @@ class Column:
 
     def otherwise(self, value: Any) -> "Column":
         if not isinstance(self._expr, CaseWhen):
-            raise TypeError(
+            raise PySparkTypeError(
                 "otherwise() can only be applied on a Column previously generated by when()"
             )
 
         if self._expr._else_value is not None:
-            raise TypeError(
+            raise PySparkTypeError(
                 "otherwise() can only be applied once on a Column previously generated by when()"
             )
 
@@ -218,12 +223,14 @@ class Column:
 
     def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) -> "Column":
         if type(startPos) != type(length):
-            raise TypeError(
-                "startPos and length must be the same type. "
-                "Got {startPos_t} and {length_t}, respectively.".format(
-                    startPos_t=type(startPos),
-                    length_t=type(length),
-                )
+            raise PySparkTypeError(
+                error_class="NOT_SAME_TYPE",
+                message_parameters={
+                    "arg_name1": "startPos",
+                    "arg_name2": "length",
+                    "arg_type1": type(startPos).__name__,
+                    "arg_type2": type(length).__name__,
+                },
             )
 
         if isinstance(length, Column):
@@ -231,14 +238,20 @@ class Column:
         elif isinstance(length, int):
             length_expr = LiteralExpression._from_value(length)
         else:
-            raise TypeError("Unsupported type for substr().")
+            raise PySparkTypeError(
+                error_class="NOT_COLUMN_OR_INT",
+                message_parameters={"arg_name": "length", "arg_type": type(length).__name__},
+            )
 
         if isinstance(startPos, Column):
             start_expr = startPos._expr
         elif isinstance(startPos, int):
             start_expr = LiteralExpression._from_value(startPos)
         else:
-            raise TypeError("Unsupported type for substr().")
+            raise PySparkTypeError(
+                error_class="NOT_COLUMN_OR_INT",
+                message_parameters={"arg_name": "startPos", "arg_type": type(startPos).__name__},
+            )
 
         return Column(UnresolvedFunction("substring", [self._expr, start_expr, length_expr]))
 
@@ -303,7 +316,10 @@ class Column:
         if isinstance(dataType, (DataType, str)):
             return Column(CastExpression(expr=self._expr, data_type=dataType))
         else:
-            raise TypeError("unexpected type: %s" % type(dataType))
+            raise PySparkTypeError(
+                error_class="NOT_DATATYPE_OR_STR",
+                message_parameters={"arg_name": "dataType", "arg_type": type(dataType).__name__},
+            )
 
     cast.__doc__ = PySparkColumn.cast.__doc__
 
@@ -316,8 +332,9 @@ class Column:
         from pyspark.sql.connect.window import WindowSpec
 
         if not isinstance(window, WindowSpec):
-            raise TypeError(
-                f"window should be WindowSpec, but got {type(window).__name__} {window}"
+            raise PySparkTypeError(
+                error_class="NOT_WINDOWSPEC",
+                message_parameters={"arg_name": "window", "arg_type": type(window).__name__},
             )
 
         return Column(WindowExpression(windowFunction=self._expr, windowSpec=window))
@@ -376,12 +393,16 @@ class Column:
 
     def withField(self, fieldName: str, col: "Column") -> "Column":
         if not isinstance(fieldName, str):
-            raise TypeError(
-                f"fieldName should be a string, but got {type(fieldName).__name__} {fieldName}"
+            raise PySparkTypeError(
+                error_class="NOT_STR",
+                message_parameters={"arg_name": "fieldName", "arg_type": type(fieldName).__name__},
             )
 
         if not isinstance(col, Column):
-            raise TypeError(f"col should be a Column, but got {type(col).__name__} {col}")
+            raise PySparkTypeError(
+                error_class="NOT_COLUMN",
+                message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
+            )
 
         return Column(WithField(self._expr, fieldName, col._expr))
 
@@ -391,8 +412,12 @@ class Column:
         dropField: Optional[DropField] = None
         for fieldName in fieldNames:
             if not isinstance(fieldName, str):
-                raise TypeError(
-                    f"fieldName should be a string, but got {type(fieldName).__name__} {fieldName}"
+                raise PySparkTypeError(
+                    error_class="NOT_STR",
+                    message_parameters={
+                        "arg_name": "fieldName",
+                        "arg_type": type(fieldName).__name__,
+                    },
                 )
 
             if dropField is None:
@@ -421,7 +446,10 @@ class Column:
             return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k)))
 
     def __iter__(self) -> None:
-        raise TypeError("Column is not iterable")
+        raise PySparkTypeError(
+            error_class="NOT_ITERABLE",
+            message_parameters={"objectName": "Column"},
+        )
 
     def __nonzero__(self) -> None:
         raise ValueError(
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index a2c786db180..b5d8163f4f7 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -41,6 +41,7 @@ from pyspark.sql.types import (
     DecimalType,
     BooleanType,
 )
+from pyspark.errors import PySparkTypeError
 from pyspark.errors.exceptions.connect import SparkConnectException
 from pyspark.testing.connectutils import should_test_connect
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
@@ -133,6 +134,33 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
             df4.filter(df4.name.isNotNull()).toPandas(),
         )
 
+        # check error
+        with self.assertRaises(PySparkTypeError) as pe:
+            df.name.substr(df.id, 10)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_SAME_TYPE",
+            message_parameters={
+                "arg_name1": "startPos",
+                "arg_name2": "length",
+                "arg_type1": "Column",
+                "arg_type2": "int",
+            },
+        )
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            df.name.substr(10.5, 10.5)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_COLUMN_OR_INT",
+            message_parameters={
+                "arg_name": "length",
+                "arg_type": "float",
+            },
+        )
+
     def test_column_with_null(self):
         # SPARK-41751: test isNull, isNotNull, eqNullSafe
 
@@ -532,6 +560,15 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
                 df.select(df.id.cast(x)).toPandas(), df2.select(df2.id.cast(x)).toPandas()
             )
 
+        with self.assertRaises(PySparkTypeError) as pe:
+            df.id.cast(10)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_DATATYPE_OR_STR",
+            message_parameters={"arg_name": "dataType", "arg_type": "int"},
+        )
+
     def test_isin(self):
         # SPARK-41526: test Column.isin
         query = """
@@ -893,24 +930,33 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
         with self.assertRaises(SparkConnectException):
             cdf.select(cdf.x.dropFields("a", "b", "c", "d")).show()
 
-        with self.assertRaisesRegex(
-            TypeError,
-            "fieldName should be a string",
-        ):
+        with self.assertRaises(PySparkTypeError) as pe:
             cdf.select(cdf.x.withField(CF.col("a"), cdf.e)).show()
 
-        with self.assertRaisesRegex(
-            TypeError,
-            "col should be a Column",
-        ):
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_STR",
+            message_parameters={"arg_name": "fieldName", "arg_type": "Column"},
+        )
+
+        with self.assertRaises(PySparkTypeError) as pe:
             cdf.select(cdf.x.withField("a", 2)).show()
 
-        with self.assertRaisesRegex(
-            TypeError,
-            "fieldName should be a string",
-        ):
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_COLUMN",
+            message_parameters={"arg_name": "col", "arg_type": "int"},
+        )
+
+        with self.assertRaises(PySparkTypeError) as pe:
             cdf.select(cdf.x.dropFields("a", 1, 2)).show()
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_STR",
+            message_parameters={"arg_name": "fieldName", "arg_type": "int"},
+        )
+
         with self.assertRaisesRegex(
             ValueError,
             "dropFields requires at least 1 field",
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 9e499815107..599e595af62 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -892,12 +892,15 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
         ):
             cdf.select(CF.sum("a").over(CW.orderBy("b").rowsBetween(0, (1 << 33)))).show()
 
-        with self.assertRaisesRegex(
-            TypeError,
-            "window should be WindowSpec",
-        ):
+        with self.assertRaises(PySparkTypeError) as pe:
             cdf.select(CF.rank().over(cdf.a))
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_WINDOWSPEC",
+            message_parameters={"arg_name": "window", "arg_type": "Column"},
+        )
+
         # invalid window function
         with self.assertRaises(AnalysisException):
             cdf.select(cdf.b.over(CW.orderBy("b"))).show()
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index be9cb07eccc..da2a4f1de33 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -18,7 +18,7 @@
 
 from pyspark.sql import Column, Row
 from pyspark.sql.types import StructType, StructField, LongType
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PySparkTypeError
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -160,11 +160,22 @@ class ColumnTestsMixin:
         result = df.withColumn("a", df["a"].withField("b", lit(3))).collect()[0].asDict()
         self.assertEqual(3, result["a"]["b"])
 
-        self.assertRaisesRegex(
-            TypeError, "col should be a Column", lambda: df["a"].withField("b", 3)
+        with self.assertRaises(PySparkTypeError) as pe:
+            df["a"].withField("b", 3)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_COLUMN",
+            message_parameters={"arg_name": "col", "arg_type": "int"},
         )
-        self.assertRaisesRegex(
-            TypeError, "fieldName should be a string", lambda: df["a"].withField(col("b"), lit(3))
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            df["a"].withField(col("b"), lit(3))
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_STR",
+            message_parameters={"arg_name": "fieldName", "arg_type": "Column"},
         )
 
     def test_drop_fields(self):
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 8bc2b96cc51..3aec7cc42de 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -390,8 +390,18 @@ class FunctionsTestsMixin:
         ]
 
         df = self.spark.createDataFrame([["nick"]], schema=["name"])
-        self.assertRaisesRegex(
-            TypeError, "must be the same type", lambda: df.select(col("name").substr(0, lit(1)))
+        with self.assertRaises(PySparkTypeError) as pe:
+            df.select(col("name").substr(0, lit(1)))
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_SAME_TYPE",
+            message_parameters={
+                "arg_name1": "startPos",
+                "arg_name2": "length",
+                "arg_type1": "int",
+                "arg_type2": "Column",
+            },
         )
 
         for name in string_functions:

