Posted to commits@spark.apache.org by ru...@apache.org on 2023/06/17 05:00:05 UTC

[spark] branch master updated: [SPARK-44052][CONNECT][PS] Add util to get proper Column or DataFrame class for Spark Connect

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b97ce8b9a99 [SPARK-44052][CONNECT][PS] Add util to get proper Column or DataFrame class for Spark Connect
b97ce8b9a99 is described below

commit b97ce8b9a99c570fc57dec967e7e9db3d115c1db
Author: itholic <ha...@databricks.com>
AuthorDate: Sat Jun 17 12:59:48 2023 +0800

    [SPARK-44052][CONNECT][PS] Add util to get proper Column or DataFrame class for Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add utility functions that return the proper Column or DataFrame class depending on whether Spark Connect is in use.
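
    For illustration, a minimal usage sketch (mirroring the call sites updated in the diff below) of resolving the DataFrame class without branching on `is_remote()`:

    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.utils import get_dataframe_class

    spark = SparkSession.builder.getOrCreate()
    sdf = spark.range(3)

    # Resolves to pyspark.sql.DataFrame, or to the Spark Connect DataFrame when running remotely.
    SparkDataFrame = get_dataframe_class()
    assert isinstance(sdf, SparkDataFrame)
    ```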
    
    ### Why are the changes needed?
    
    To eliminate code duplication when obtaining the appropriate Column and DataFrame classes for Spark Connect, the utility functions `get_column_class` and `get_dataframe_class` have been added.
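
    Concretely, a call site that previously branched on `is_remote()` at every use (see `python/pyspark/pandas/data_type_ops/base.py` below) collapses to a single helper call. A minimal before/after sketch:

    ```python
    from typing import Any
    from pyspark.sql.column import Column as PySparkColumn
    from pyspark.sql.utils import is_remote, get_column_class
    from pyspark.pandas.base import column_op

    # Before: the remote/non-remote branch was repeated at every call site.
    def eq_before(left: Any, right: Any) -> Any:
        if is_remote():
            from pyspark.sql.connect.column import Column as ConnectColumn

            Column = ConnectColumn
        else:
            Column = PySparkColumn
        return column_op(Column.__eq__)(left, right)

    # After: the branch lives in one place inside get_column_class().
    def eq_after(left: Any, right: Any) -> Any:
        Column = get_column_class()
        return column_op(Column.__eq__)(left, right)
    ```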
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it's code refactoring.
    
    ### How was this patch tested?
    
    The existing CI should pass.
    
    Closes #41570 from itholic/SPARK-43703.
    
    Authored-by: itholic <ha...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/pandas/data_type_ops/base.py        | 22 ++----
 python/pyspark/pandas/data_type_ops/boolean_ops.py | 38 +++-------
 python/pyspark/pandas/data_type_ops/date_ops.py    | 40 +++-------
 python/pyspark/pandas/data_type_ops/num_ops.py     | 86 +++++-----------------
 python/pyspark/pandas/frame.py                     | 38 +++-------
 python/pyspark/pandas/indexes/base.py              |  2 +-
 python/pyspark/pandas/indexes/multi.py             | 17 ++++-
 python/pyspark/pandas/indexing.py                  | 65 +++-------------
 python/pyspark/pandas/internal.py                  | 43 +++--------
 python/pyspark/pandas/namespace.py                 |  9 +--
 python/pyspark/pandas/series.py                    | 17 ++---
 python/pyspark/pandas/spark/accessors.py           | 23 +-----
 .../tests/connect/indexes/test_parity_base.py      | 10 ++-
 python/pyspark/pandas/utils.py                     |  9 +--
 python/pyspark/sql/utils.py                        | 25 ++++++-
 15 files changed, 131 insertions(+), 313 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 18e792e292f..5d497a55a5f 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -24,7 +24,7 @@ import numpy as np
 import pandas as pd
 from pandas.api.types import CategoricalDtype
 
-from pyspark.sql import functions as F, Column as PySparkColumn
+from pyspark.sql import functions as F
 from pyspark.sql.types import (
     ArrayType,
     BinaryType,
@@ -54,7 +54,7 @@ from pyspark.pandas.typedef.typehints import (
 )
 
 # For supporting Spark Connect
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import get_column_class
 
 if extension_dtypes_available:
     from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
@@ -482,26 +482,16 @@ class DataTypeOps(object, metaclass=ABCMeta):
         else:
             from pyspark.pandas.base import column_op
 
-            if is_remote():
-                from pyspark.sql.connect.column import Column as ConnectColumn
-
-                Column = ConnectColumn
-            else:
-                Column = PySparkColumn  # type: ignore[assignment]
-            return column_op(Column.__eq__)(left, right)  # type: ignore[arg-type]
+            Column = get_column_class()
+            return column_op(Column.__eq__)(left, right)
 
     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
 
         _sanitize_list_like(right)
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__ne__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__ne__)(left, right)
 
     def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
         raise TypeError("Unary ~ can not be applied to %s." % self.pretty_name)
diff --git a/python/pyspark/pandas/data_type_ops/boolean_ops.py b/python/pyspark/pandas/data_type_ops/boolean_ops.py
index d800dbc0714..11f376d6e16 100644
--- a/python/pyspark/pandas/data_type_ops/boolean_ops.py
+++ b/python/pyspark/pandas/data_type_ops/boolean_ops.py
@@ -38,7 +38,7 @@ from pyspark.pandas.typedef.typehints import as_spark_type, extension_dtypes, pa
 from pyspark.sql import functions as F
 from pyspark.sql.column import Column as PySparkColumn
 from pyspark.sql.types import BooleanType, StringType
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import get_column_class
 from pyspark.errors import PySparkValueError
 
 
@@ -331,43 +331,23 @@ class BooleanOps(DataTypeOps):
 
     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__lt__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__lt__)(left, right)
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__le__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__le__)(left, right)
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__ge__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__ge__)(left, right)
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__gt__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__gt__)(left, right)
 
     def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
         return operand._with_new_scol(~operand.spark.column, field=operand._internal.data_fields[0])
diff --git a/python/pyspark/pandas/data_type_ops/date_ops.py b/python/pyspark/pandas/data_type_ops/date_ops.py
index f3cfaa9c403..51d1018a304 100644
--- a/python/pyspark/pandas/data_type_ops/date_ops.py
+++ b/python/pyspark/pandas/data_type_ops/date_ops.py
@@ -23,9 +23,9 @@ import numpy as np
 import pandas as pd
 from pandas.api.types import CategoricalDtype
 
-from pyspark.sql import functions as F, Column as PySparkColumn
+from pyspark.sql import functions as F
 from pyspark.sql.types import BooleanType, DateType, StringType
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import get_column_class
 
 from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
 from pyspark.pandas.base import column_op, IndexOpsMixin
@@ -85,49 +85,29 @@ class DateOps(DataTypeOps):
         from pyspark.pandas.base import column_op
 
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__lt__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__lt__)(left, right)
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
 
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__le__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__le__)(left, right)
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
 
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__ge__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__ge__)(left, right)
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
 
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__gt__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__gt__)(left, right)
 
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         dtype, spark_type = pandas_on_spark_type(dtype)
diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py
index 911228a5265..af5e387c0f2 100644
--- a/python/pyspark/pandas/data_type_ops/num_ops.py
+++ b/python/pyspark/pandas/data_type_ops/num_ops.py
@@ -54,7 +54,7 @@ from pyspark.sql.types import (
 from pyspark.errors import PySparkValueError
 
 # For Supporting Spark Connect
-from pyspark.sql.utils import is_remote, pyspark_column_op
+from pyspark.sql.utils import pyspark_column_op, get_column_class
 
 
 def _non_fractional_astype(
@@ -83,13 +83,8 @@ class NumericOps(DataTypeOps):
             raise TypeError("Addition can not be applied to given types.")
 
         right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__add__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__add__)(left, right)
 
     def sub(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
@@ -97,13 +92,8 @@ class NumericOps(DataTypeOps):
             raise TypeError("Subtraction can not be applied to given types.")
 
         right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__sub__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__sub__)(left, right)
 
     def mod(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
@@ -121,12 +111,7 @@ class NumericOps(DataTypeOps):
         if not is_valid_operand_for_numeric_arithmetic(right):
             raise TypeError("Exponentiation can not be applied to given types.")
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
 
         def pow_func(left: Column, right: Any) -> Column:  # type: ignore[valid-type]
             return (
@@ -143,51 +128,31 @@ class NumericOps(DataTypeOps):
         if not isinstance(right, numbers.Number):
             raise TypeError("Addition can not be applied to given types.")
         right = transform_boolean_operand_to_numeric(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__radd__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__radd__)(left, right)
 
     def rsub(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
         if not isinstance(right, numbers.Number):
             raise TypeError("Subtraction can not be applied to given types.")
         right = transform_boolean_operand_to_numeric(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__rsub__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__rsub__)(left, right)
 
     def rmul(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
         if not isinstance(right, numbers.Number):
             raise TypeError("Multiplication can not be applied to given types.")
         right = transform_boolean_operand_to_numeric(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__rmul__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__rmul__)(left, right)
 
     def rpow(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
         if not isinstance(right, numbers.Number):
             raise TypeError("Exponentiation can not be applied to given types.")
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
 
         def rpow_func(left: Column, right: Any) -> Column:  # type: ignore[valid-type]
             return F.when(F.lit(right == 1), right).otherwise(Column.__rpow__(left, right))
@@ -286,13 +251,8 @@ class IntegralOps(NumericOps):
             raise TypeError("Multiplication can not be applied to given types.")
 
         right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__mul__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__mul__)(left, right)
 
     def truediv(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
@@ -376,13 +336,8 @@ class FractionalOps(NumericOps):
             raise TypeError("Multiplication can not be applied to given types.")
 
         right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return column_op(Column.__mul__)(left, right)  # type: ignore[arg-type]
+        Column = get_column_class()
+        return column_op(Column.__mul__)(left, right)
 
     def truediv(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
@@ -540,12 +495,7 @@ class DecimalOps(FractionalOps):
         if not isinstance(right, numbers.Number):
             raise TypeError("Exponentiation can not be applied to given types.")
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
 
         def rpow_func(left: Column, right: Any) -> Column:  # type: ignore[valid-type]
             return (
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index faf5cd028bc..6f2c8389a4c 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -149,7 +149,7 @@ from pyspark.pandas.typedef.typehints import (
 from pyspark.pandas.plot import PandasOnSparkPlotAccessor
 
 # For supporting Spark Connect
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import get_column_class, get_dataframe_class
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import OptionalPrimitiveType
@@ -527,12 +527,7 @@ class DataFrame(Frame, Generic[T]):
     def __init__(  # type: ignore[no-untyped-def]
         self, data=None, index=None, columns=None, dtype=None, copy=False
     ):
-        if is_remote():
-            from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-
-            SparkDataFrame = ConnectDataFrame
-        else:
-            SparkDataFrame = PySparkDataFrame
+        SparkDataFrame = get_dataframe_class()
         index_assigned = False
         if isinstance(data, InternalFrame):
             assert columns is None
@@ -5504,12 +5499,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         from pyspark.pandas.indexes import MultiIndex
         from pyspark.pandas.series import IndexOpsMixin
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         for k, v in kwargs.items():
             is_invalid_assignee = (
                 not (isinstance(v, (IndexOpsMixin, Column)) or callable(v) or is_scalar(v))
@@ -5540,7 +5530,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                     scol, field = pairs[label[: len(label) - i]]
 
                     name = self._internal.spark_column_name_for(label)
-                    scol = scol.alias(name)  # type: ignore[attr-defined]
+                    scol = scol.alias(name)
                     if field is not None:
                         field = field.copy(name=name)
                     break
@@ -5554,7 +5544,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         for label, (scol, field) in pairs.items():
             if label not in set(i[: len(label)] for i in self._internal.column_labels):
                 name = name_like_string(label)
-                scols.append(scol.alias(name))  # type: ignore[attr-defined]
+                scols.append(scol.alias(name))
                 if field is not None:
                     field = field.copy(name=name)
                 data_fields.append(field)
@@ -7495,12 +7485,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         if na_position not in ("first", "last"):
             raise ValueError("invalid na_position: '{}'".format(na_position))
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         # Mapper: Get a spark column
         # function for (ascending, na_position) combination
         mapper = {
@@ -7509,12 +7494,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             (False, "first"): Column.desc_nulls_first,
             (False, "last"): Column.desc_nulls_last,
         }
-        by = [mapper[(asc, na_position)](scol) for scol, asc in zip(by, ascending)]  # type: ignore
+        by = [mapper[(asc, na_position)](scol) for scol, asc in zip(by, ascending)]
 
         natural_order_scol = F.col(NATURAL_ORDER_COLUMN_NAME)
 
         if keep == "last":
-            natural_order_scol = Column.desc(natural_order_scol)  # type: ignore
+            natural_order_scol = Column.desc(natural_order_scol)
         elif keep == "all":
             raise NotImplementedError("`keep`=all is not implemented yet.")
         elif keep != "first":
@@ -13664,12 +13649,7 @@ def _reduce_spark_multi(sdf: PySparkDataFrame, aggs: List[PySparkColumn]) -> Any
     """
     Performs a reduction on a spark DataFrame, the functions being known SQL aggregate functions.
     """
-    if is_remote():
-        from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-
-        SparkDataFrame = ConnectDataFrame
-    else:
-        SparkDataFrame = PySparkDataFrame  # type: ignore[assignment]
+    SparkDataFrame = get_dataframe_class()
     assert isinstance(sdf, SparkDataFrame)
     sdf0 = sdf.agg(*aggs)
     lst = sdf0.limit(2).toPandas()
diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py
index 35f52012944..a8fd07aa2a7 100644
--- a/python/pyspark/pandas/indexes/base.py
+++ b/python/pyspark/pandas/indexes/base.py
@@ -1822,7 +1822,7 @@ class Index(IndexOpsMixin):
         # when self._scol has name of '__index_level_0__'
         index_value_column_format = "__index_value_{}__"
 
-        sdf = self._internal._sdf  # type: ignore[has-type]
+        sdf = self._internal._sdf
         index_value_column_names = [
             verify_temp_column_name(sdf, index_value_column_format.format(i))
             for i in range(self._internal.index_level)
diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py
index 93a323cd5b9..dd93e31d023 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -21,8 +21,9 @@ from typing import Any, Callable, Iterator, List, Optional, Tuple, Union, cast,
 import pandas as pd
 from pandas.api.types import is_hashable, is_list_like  # type: ignore[attr-defined]
 
-from pyspark.sql import functions as F, Column, Window
+from pyspark.sql import functions as F, Column as PySparkColumn, Window
 from pyspark.sql.types import DataType
+from pyspark.sql.utils import get_column_class
 
 # For running doctests and reference resolution in PyCharm.
 from pyspark import pandas as ps
@@ -136,7 +137,7 @@ class MultiIndex(Index):
         raise TypeError("TypeError: cannot perform __abs__ with this index type: MultiIndex")
 
     def _with_new_scol(
-        self, scol: Column, *, field: Optional[InternalField] = None
+        self, scol: PySparkColumn, *, field: Optional[InternalField] = None
     ) -> "MultiIndex":
         raise NotImplementedError("Not supported for type MultiIndex")
 
@@ -497,7 +498,10 @@ class MultiIndex(Index):
     @staticmethod
     def _comparator_for_monotonic_increasing(
         data_type: DataType,
-    ) -> Callable[[Column, Column, Callable[[Column, Column], Column]], Column]:
+    ) -> Callable[
+        [PySparkColumn, PySparkColumn, Callable[[PySparkColumn, PySparkColumn], PySparkColumn]],
+        PySparkColumn,
+    ]:
         return compare_disallow_null
 
     def _is_monotonic(self, order: str) -> bool:
@@ -511,6 +515,7 @@ class MultiIndex(Index):
 
         cond = F.lit(True)
         has_not_null = F.lit(True)
+        Column = get_column_class()
         for scol in self._internal.index_spark_columns[::-1]:
             data_type = self._internal.spark_type_for(scol)
             prev = F.lag(scol, 1).over(window)
@@ -545,7 +550,10 @@ class MultiIndex(Index):
     @staticmethod
     def _comparator_for_monotonic_decreasing(
         data_type: DataType,
-    ) -> Callable[[Column, Column, Callable[[Column, Column], Column]], Column]:
+    ) -> Callable[
+        [PySparkColumn, PySparkColumn, Callable[[PySparkColumn, PySparkColumn], PySparkColumn]],
+        PySparkColumn,
+    ]:
         return compare_disallow_null
 
     def _is_monotonic_decreasing(self) -> Series:
@@ -553,6 +561,7 @@ class MultiIndex(Index):
 
         cond = F.lit(True)
         has_not_null = F.lit(True)
+        Column = get_column_class()
         for scol in self._internal.index_spark_columns[::-1]:
             data_type = self._internal.spark_type_for(scol)
             prev = F.lag(scol, 1).over(window)
diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py
index 0664bffd9ce..cf8a2c0a363 100644
--- a/python/pyspark/pandas/indexing.py
+++ b/python/pyspark/pandas/indexing.py
@@ -52,7 +52,7 @@ from pyspark.pandas.utils import (
 )
 
 # For Supporting Spark Connect
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import get_column_class
 
 if TYPE_CHECKING:
     from pyspark.pandas.frame import DataFrame
@@ -261,18 +261,13 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
         """
         from pyspark.pandas.series import Series
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if rows_sel is None:
             return None, None, None
         elif isinstance(rows_sel, Series):
             return self._select_rows_by_series(rows_sel)
         elif isinstance(rows_sel, Column):
-            return self._select_rows_by_spark_column(rows_sel)  # type: ignore[arg-type]
+            return self._select_rows_by_spark_column(rows_sel)
         elif isinstance(rows_sel, slice):
             if rows_sel == slice(None):
                 # If slice is None - select everything, so nothing to do
@@ -313,12 +308,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
         """
         from pyspark.pandas.series import Series
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if cols_sel is None:
             column_labels = self._internal.column_labels
             data_spark_columns = self._internal.data_spark_columns
@@ -327,9 +317,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
         elif isinstance(cols_sel, Series):
             return self._select_cols_by_series(cols_sel, missing_keys)
         elif isinstance(cols_sel, Column):
-            return self._select_cols_by_spark_column(
-                cols_sel, missing_keys  # type: ignore[arg-type]
-            )
+            return self._select_cols_by_spark_column(cols_sel, missing_keys)
         elif isinstance(cols_sel, slice):
             if cols_sel == slice(None):
                 # If slice is None - select everything, so nothing to do
@@ -593,12 +581,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
         from pyspark.pandas.frame import DataFrame
         from pyspark.pandas.series import Series, first_series
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if self._is_series:
             if (
                 isinstance(key, Series)
@@ -1139,16 +1122,9 @@ class LocIndexer(LocIndexerLike):
                     )
                 )[::-1]:
                     compare = MultiIndex._comparator_for_monotonic_increasing(dt)
-                    if is_remote():
-                        from pyspark.sql.connect.column import Column as ConnectColumn
-
-                        Column = ConnectColumn
-                    else:
-                        Column = PySparkColumn  # type: ignore[assignment]
+                    Column = get_column_class()
                     cond = F.when(scol.eqNullSafe(F.lit(value).cast(dt)), cond).otherwise(
-                        compare(
-                            scol, F.lit(value).cast(dt), Column.__gt__  # type: ignore[arg-type]
-                        )
+                        compare(scol, F.lit(value).cast(dt), Column.__gt__)
                     )
                 conds.append(cond)
             if stop is not None:
@@ -1161,16 +1137,9 @@ class LocIndexer(LocIndexerLike):
                     )
                 )[::-1]:
                     compare = MultiIndex._comparator_for_monotonic_increasing(dt)
-                    if is_remote():
-                        from pyspark.sql.connect.column import Column as ConnectColumn
-
-                        Column = ConnectColumn
-                    else:
-                        Column = PySparkColumn  # type: ignore[assignment]
+                    Column = get_column_class()
                     cond = F.when(scol.eqNullSafe(F.lit(value).cast(dt)), cond).otherwise(
-                        compare(
-                            scol, F.lit(value).cast(dt), Column.__lt__  # type: ignore[arg-type]
-                        )
+                        compare(scol, F.lit(value).cast(dt), Column.__lt__)
                     )
                 conds.append(cond)
 
@@ -1328,12 +1297,7 @@ class LocIndexer(LocIndexerLike):
     ]:
         from pyspark.pandas.series import Series
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if all(isinstance(key, Series) for key in cols_sel):
             column_labels = [key._column_label for key in cols_sel]
             data_spark_columns = [key.spark.column for key in cols_sel]
@@ -1837,12 +1801,7 @@ class iLocIndexer(LocIndexerLike):
             )
 
     def __setitem__(self, key: Any, value: Any) -> None:
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if not isinstance(value, Column) and is_list_like(value):
             iloc_item = self[key]
             if not is_list_like(key) or not is_list_like(iloc_item):
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index e4d8b4dbe5a..e025d91e7b7 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -42,7 +42,7 @@ from pyspark.sql.types import (  # noqa: F401
 from pyspark.sql.utils import is_timestamp_ntz_preferred
 
 # For supporting Spark Connect
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import is_remote, get_column_class, get_dataframe_class
 
 # For running doctests and reference resolution in PyCharm.
 from pyspark import pandas as ps
@@ -624,12 +624,7 @@ class InternalFrame:
         >>> internal.column_label_names
         [('column_labels_a',), ('column_labels_b',)]
         """
-        if is_remote():
-            from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-
-            SparkDataFrame = ConnectDataFrame
-        else:
-            SparkDataFrame = PySparkDataFrame  # type: ignore[assignment]
+        SparkDataFrame = get_dataframe_class()
         assert isinstance(spark_frame, SparkDataFrame)
         assert not spark_frame.isStreaming, "pandas-on-Spark does not support Structured Streaming."
 
@@ -682,12 +677,7 @@ class InternalFrame:
         self._sdf = spark_frame
 
         # index_spark_columns
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn
+        Column = get_column_class()
         assert all(
             isinstance(index_scol, Column) for index_scol in index_spark_columns
         ), index_spark_columns
@@ -761,7 +751,7 @@ class InternalFrame:
             ]
 
         assert all(
-            isinstance(ops.dtype, Dtype.__args__)
+            isinstance(ops.dtype, Dtype.__args__)  # type: ignore[attr-defined]
             and (
                 ops.dtype == np.dtype("object")
                 or as_spark_type(ops.dtype, raise_error=False) is not None
@@ -795,7 +785,7 @@ class InternalFrame:
         self._index_fields: List[InternalField] = index_fields
 
         assert all(
-            isinstance(ops.dtype, Dtype.__args__)
+            isinstance(ops.dtype, Dtype.__args__)  # type: ignore[attr-defined]
             and (
                 ops.dtype == np.dtype("object")
                 or as_spark_type(ops.dtype, raise_error=False) is not None
@@ -1000,12 +990,7 @@ class InternalFrame:
 
     def spark_column_name_for(self, label_or_scol: Union[Label, PySparkColumn]) -> str:
         """Return the actual Spark column name for the given column label."""
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if isinstance(label_or_scol, Column):
             return self.spark_frame.select(label_or_scol).columns[0]
         else:
@@ -1013,12 +998,7 @@ class InternalFrame:
 
     def spark_type_for(self, label_or_scol: Union[Label, PySparkColumn]) -> DataType:
         """Return DataType for the given column label."""
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if isinstance(label_or_scol, Column):
             return self.spark_frame.select(label_or_scol).schema[0].dataType
         else:
@@ -1026,12 +1006,7 @@ class InternalFrame:
 
     def spark_column_nullable_for(self, label_or_scol: Union[Label, PySparkColumn]) -> bool:
         """Return nullability for the given column label."""
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if isinstance(label_or_scol, Column):
             return self.spark_frame.select(label_or_scol).schema[0].nullable
         else:
@@ -1048,7 +1023,7 @@ class InternalFrame:
     @property
     def spark_frame(self) -> PySparkDataFrame:
         """Return the managed Spark DataFrame."""
-        return self._sdf  # type: ignore[has-type]
+        return self._sdf
 
     @lazy_property
     def data_spark_column_names(self) -> List[str]:
diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py
index e90a4fafae9..3563a6d81b4 100644
--- a/python/pyspark/pandas/namespace.py
+++ b/python/pyspark/pandas/namespace.py
@@ -96,7 +96,7 @@ from pyspark.pandas.indexes import Index, DatetimeIndex, TimedeltaIndex
 from pyspark.pandas.indexes.multi import MultiIndex
 
 # For Supporting Spark Connect
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import get_column_class
 
 __all__ = [
     "from_pandas",
@@ -3434,12 +3434,7 @@ def merge_asof(
     else:
         on = None
 
-    if is_remote():
-        from pyspark.sql.connect.column import Column as ConnectColumn
-
-        Column = ConnectColumn
-    else:
-        Column = PySparkColumn  # type: ignore[assignment]
+    Column = get_column_class()
     if tolerance is not None and not isinstance(tolerance, Column):
         tolerance = F.lit(tolerance)
 
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index ca9d39bb695..0f1e814946a 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -70,7 +70,7 @@ from pyspark.sql.types import (
     TimestampType,
 )
 from pyspark.sql.window import Window
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import get_column_class
 
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
 from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T
@@ -4222,12 +4222,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         if self._internal.index_level > 1:
             raise NotImplementedError("rank do not support MultiIndex now")
 
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if ascending:
             asc_func = Column.asc
         else:
@@ -4236,8 +4231,8 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         if method == "first":
             window = (
                 Window.orderBy(
-                    asc_func(self.spark.column),  # type: ignore[arg-type]
-                    asc_func(F.col(NATURAL_ORDER_COLUMN_NAME)),  # type: ignore[arg-type]
+                    asc_func(self.spark.column),
+                    asc_func(F.col(NATURAL_ORDER_COLUMN_NAME)),
                 )
                 .partitionBy(*part_cols)
                 .rowsBetween(Window.unboundedPreceding, Window.currentRow)
@@ -4245,7 +4240,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
             scol = F.row_number().over(window)
         elif method == "dense":
             window = (
-                Window.orderBy(asc_func(self.spark.column))  # type: ignore[arg-type]
+                Window.orderBy(asc_func(self.spark.column))
                 .partitionBy(*part_cols)
                 .rowsBetween(Window.unboundedPreceding, Window.currentRow)
             )
@@ -4258,7 +4253,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
             elif method == "max":
                 stat_func = F.max
             window1 = (
-                Window.orderBy(asc_func(self.spark.column))  # type: ignore[arg-type]
+                Window.orderBy(asc_func(self.spark.column))
                 .partitionBy(*part_cols)
                 .rowsBetween(Window.unboundedPreceding, Window.currentRow)
             )
diff --git a/python/pyspark/pandas/spark/accessors.py b/python/pyspark/pandas/spark/accessors.py
index e3098bb47a2..f55f70e0092 100644
--- a/python/pyspark/pandas/spark/accessors.py
+++ b/python/pyspark/pandas/spark/accessors.py
@@ -30,7 +30,7 @@ from pyspark.pandas._typing import IndexOpsLike
 from pyspark.pandas.internal import InternalField
 
 # For Supporting Spark Connect
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import get_column_class, get_dataframe_class
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import OptionalPrimitiveType
@@ -119,12 +119,7 @@ class SparkIndexOpsMethods(Generic[IndexOpsLike], metaclass=ABCMeta):
         if isinstance(self._data, MultiIndex):
             raise NotImplementedError("MultiIndex does not support spark.transform yet.")
         output = func(self._data.spark.column)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if not isinstance(output, Column):
             raise ValueError(
                 "The output of the function [%s] should be of a "
@@ -200,12 +195,7 @@ class SparkSeriesMethods(SparkIndexOpsMethods["ps.Series"]):
         from pyspark.pandas.internal import HIDDEN_COLUMNS
 
         output = func(self._data.spark.column)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
+        Column = get_column_class()
         if not isinstance(output, Column):
             raise ValueError(
                 "The output of the function [%s] should be of a "
@@ -952,12 +942,7 @@ class SparkFrameMethods:
         2  3      1
         """
         output = func(self.frame(index_col))
-        if is_remote():
-            from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-
-            SparkDataFrame = ConnectDataFrame
-        else:
-            SparkDataFrame = PySparkDataFrame  # type: ignore[assignment]
+        SparkDataFrame = get_dataframe_class()
         if not isinstance(output, SparkDataFrame):
             raise ValueError(
                 "The output of the function [%s] should be of a "
diff --git a/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py b/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
index d18a20d7290..b1e185389f3 100644
--- a/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
+++ b/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
@@ -35,14 +35,16 @@ class IndexesParityTests(
     def test_append(self):
         super().test_append()
 
+    @unittest.skip(
+        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client."
+    )
+    def test_monotonic(self):
+        super().test_monotonic()
+
     @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
     def test_factorize(self):
         super().test_factorize()
 
-    @unittest.skip("TODO(SPARK-43703): Enable IndexesParityTests.test_monotonic.")
-    def test_monotonic(self):
-        super().test_monotonic()
-
     @unittest.skip("TODO(SPARK-43704): Enable IndexesParityTests.test_to_series.")
     def test_to_series(self):
         super().test_to_series()
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index f9805a91db6..c66b3359e77 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -39,7 +39,7 @@ import warnings
 
 from pyspark.sql import functions as F, Column, DataFrame as PySparkDataFrame, SparkSession
 from pyspark.sql.types import DoubleType
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import is_remote, get_dataframe_class
 from pyspark.errors import PySparkTypeError
 import pandas as pd
 from pandas.api.types import is_list_like  # type: ignore[attr-defined]
@@ -918,12 +918,7 @@ def verify_temp_column_name(
         )
         column_name = column_name_or_label
 
-    if is_remote():
-        from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-
-        SparkDataFrame = ConnectDataFrame
-    else:
-        SparkDataFrame = PySparkDataFrame  # type: ignore[assignment]
+    SparkDataFrame = get_dataframe_class()
     assert isinstance(df, SparkDataFrame), type(df)
     assert (
         column_name not in df.columns
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index f5a5c88b8d3..7ecfa65dcd1 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -16,7 +16,7 @@
 #
 import functools
 import os
-from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union
+from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union, Type
 
 from py4j.java_collections import JavaArray
 from py4j.java_gateway import (
@@ -45,6 +45,7 @@ from pyspark.find_spark_home import _find_spark_home
 if TYPE_CHECKING:
     from pyspark.sql.session import SparkSession
     from pyspark.sql.dataframe import DataFrame
+    from pyspark.sql.column import Column
     from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
 
 has_numpy = False
@@ -259,3 +260,25 @@ def pyspark_column_op(
         fillna = None
     # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
     return result.fillna(fillna) if fillna is not None else result
+
+
+def get_column_class() -> Type["Column"]:
+    from pyspark.sql.column import Column as PySparkColumn
+
+    if is_remote():
+        from pyspark.sql.connect.column import Column as ConnectColumn
+
+        return ConnectColumn  # type: ignore[return-value]
+    else:
+        return PySparkColumn
+
+
+def get_dataframe_class() -> Type["DataFrame"]:
+    from pyspark.sql.dataframe import DataFrame as PySparkDataFrame
+
+    if is_remote():
+        from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+
+        return ConnectDataFrame  # type: ignore[return-value]
+    else:
+        return PySparkDataFrame

