Posted to commits@spark.apache.org by ue...@apache.org on 2021/10/12 20:37:12 UTC

[spark] branch master updated: [SPARK-36951][PYTHON] Inline type hints for python/pyspark/sql/column.py

This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3ba57f5  [SPARK-36951][PYTHON] Inline type hints for python/pyspark/sql/column.py
3ba57f5 is described below

commit 3ba57f5edc5594ee676249cd309b8f0d8248462e
Author: Xinrong Meng <xi...@databricks.com>
AuthorDate: Tue Oct 12 13:36:22 2021 -0700

    [SPARK-36951][PYTHON] Inline type hints for python/pyspark/sql/column.py
    
    ### What changes were proposed in this pull request?
    Inline type hints for python/pyspark/sql/column.py
    
    ### Why are the changes needed?
    Currently, the type hints for python/pyspark/sql/column.py live in the stub file python/pyspark/sql/column.pyi, and stub files don't support type checking within function bodies. So we inline the type hints into column.py to support that (see the sketch after the diffstat below).
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing test.
    
    Closes #34226 from xinrong-databricks/inline_column.
    
    Authored-by: Xinrong Meng <xi...@databricks.com>
    Signed-off-by: Takuya UESHIN <ue...@databricks.com>
---
 python/pyspark/sql/column.py      | 236 ++++++++++++++++++++++++++++----------
 python/pyspark/sql/column.pyi     | 118 -------------------
 python/pyspark/sql/dataframe.py   |  12 +-
 python/pyspark/sql/functions.py   |   3 +-
 python/pyspark/sql/observation.py |   5 +-
 python/pyspark/sql/window.py      |   4 +-
 6 files changed, 190 insertions(+), 188 deletions(-)
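
To illustrate the motivation: when a .pyi stub exists, a checker such as mypy reads only the declared signatures and skips the bodies in the corresponding .py file; inlining the hints makes the bodies checkable too. A minimal sketch with a hypothetical module, not taken from this patch:

    # greeting.pyi (stub): declares the signature only; the body in
    # greeting.py would not be type checked.
    #     def greet(name: str) -> str: ...

    # greeting.py with the hint inlined: mypy now checks the body and
    # flags the bad return below.
    def greet(name: str) -> str:
        size = len(name)
        return size  # error: incompatible return value type (got "int", expected "str")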

diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index c46b0eb..a3e3e9e 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -18,25 +18,43 @@
 import sys
 import json
 import warnings
+from typing import (
+    cast,
+    overload,
+    Any,
+    Callable,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    TYPE_CHECKING,
+    Union
+)
+
+from py4j.java_gateway import JavaObject
 
 from pyspark import copy_func
 from pyspark.context import SparkContext
 from pyspark.sql.types import DataType, StructField, StructType, IntegerType, StringType
 
+if TYPE_CHECKING:
+    from pyspark.sql._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral
+    from pyspark.sql.window import WindowSpec
+
 __all__ = ["Column"]
 
 
-def _create_column_from_literal(literal):
-    sc = SparkContext._active_spark_context
+def _create_column_from_literal(literal: Union["LiteralType", "DecimalLiteral"]) -> "Column":
+    sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     return sc._jvm.functions.lit(literal)
 
 
-def _create_column_from_name(name):
-    sc = SparkContext._active_spark_context
+def _create_column_from_name(name: str) -> "Column":
+    sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     return sc._jvm.functions.col(name)
 
 
-def _to_java_column(col):
+def _to_java_column(col: "ColumnOrName") -> JavaObject:
     if isinstance(col, Column):
         jcol = col._jc
     elif isinstance(col, str):
@@ -50,7 +68,11 @@ def _to_java_column(col):
     return jcol
 
 
-def _to_seq(sc, cols, converter=None):
+def _to_seq(
+    sc: SparkContext,
+    cols: Iterable["ColumnOrName"],
+    converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None,
+) -> JavaObject:
     """
     Convert a list of Column (or names) into a JVM Seq of Column.
 
@@ -59,10 +81,14 @@ def _to_seq(sc, cols, converter=None):
     """
     if converter:
         cols = [converter(c) for c in cols]
-    return sc._jvm.PythonUtils.toSeq(cols)
+    return sc._jvm.PythonUtils.toSeq(cols)  # type: ignore[attr-defined]
 
 
-def _to_list(sc, cols, converter=None):
+def _to_list(
+    sc: SparkContext,
+    cols: List["ColumnOrName"],
+    converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None,
+) -> JavaObject:
     """
     Convert a list of Column (or names) into a JVM (Scala) List of Column.
 
@@ -71,30 +97,37 @@ def _to_list(sc, cols, converter=None):
     """
     if converter:
         cols = [converter(c) for c in cols]
-    return sc._jvm.PythonUtils.toList(cols)
+    return sc._jvm.PythonUtils.toList(cols)  # type: ignore[attr-defined]
 
 
-def _unary_op(name, doc="unary operator"):
+def _unary_op(
+    name: str,
+    doc: str = "unary operator",
+) -> Callable[["Column"], "Column"]:
     """ Create a method for given unary operator """
-    def _(self):
+    def _(self: "Column") -> "Column":
         jc = getattr(self._jc, name)()
         return Column(jc)
     _.__doc__ = doc
     return _
 
 
-def _func_op(name, doc=''):
-    def _(self):
-        sc = SparkContext._active_spark_context
+def _func_op(name: str, doc: str = '') -> Callable[["Column"], "Column"]:
+    def _(self: "Column") -> "Column":
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         jc = getattr(sc._jvm.functions, name)(self._jc)
         return Column(jc)
     _.__doc__ = doc
     return _
 
 
-def _bin_func_op(name, reverse=False, doc="binary function"):
-    def _(self, other):
-        sc = SparkContext._active_spark_context
+def _bin_func_op(
+    name: str,
+    reverse: bool = False,
+    doc: str = "binary function",
+) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"]:
+    def _(self: "Column", other: Union["Column", "LiteralType", "DecimalLiteral"]) -> "Column":
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         fn = getattr(sc._jvm.functions, name)
         jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
         njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
@@ -103,10 +136,19 @@ def _bin_func_op(name, reverse=False, doc="binary function"):
     return _
 
 
-def _bin_op(name, doc="binary operator"):
+def _bin_op(
+    name: str,
+    doc: str = "binary operator",
+) -> Callable[
+    ["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]],
+    "Column"
+]:
     """ Create a method for given binary operator
     """
-    def _(self, other):
+    def _(
+        self: "Column",
+        other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
+    ) -> "Column":
         jc = other._jc if isinstance(other, Column) else other
         njc = getattr(self._jc, name)(jc)
         return Column(njc)
@@ -114,10 +156,13 @@ def _bin_op(name, doc="binary operator"):
     return _
 
 
-def _reverse_op(name, doc="binary operator"):
+def _reverse_op(
+    name: str,
+    doc: str = "binary operator",
+) -> Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"]:
     """ Create a method for binary operator (this object is on right side)
     """
-    def _(self, other):
+    def _(self: "Column", other: Union["LiteralType", "DecimalLiteral"]) -> "Column":
         jother = _create_column_from_literal(other)
         jc = getattr(jother, name)(self._jc)
         return Column(jc)
@@ -144,29 +189,81 @@ class Column(object):
     .. versionadded:: 1.3.0
     """
 
-    def __init__(self, jc):
+    def __init__(self, jc: JavaObject) -> None:
         self._jc = jc
 
     # arithmetic operators
     __neg__ = _func_op("negate")
-    __add__ = _bin_op("plus")
-    __sub__ = _bin_op("minus")
-    __mul__ = _bin_op("multiply")
-    __div__ = _bin_op("divide")
-    __truediv__ = _bin_op("divide")
-    __mod__ = _bin_op("mod")
-    __radd__ = _bin_op("plus")
-    __rsub__ = _reverse_op("minus")
-    __rmul__ = _bin_op("multiply")
-    __rdiv__ = _reverse_op("divide")
-    __rtruediv__ = _reverse_op("divide")
-    __rmod__ = _reverse_op("mod")
+    __add__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("plus")
+    )
+    __sub__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("minus")
+    )
+    __mul__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("multiply")
+    )
+    __div__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("divide")
+    )
+    __truediv__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("divide")
+    )
+    __mod__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("mod")
+    )
+    __radd__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("plus")
+    )
+    __rsub__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("minus")
+    )
+    __rmul__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("multiply")
+    )
+    __rdiv__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("divide")
+    )
+    __rtruediv__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("divide")
+    )
+    __rmod__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("mod")
+    )
+
     __pow__ = _bin_func_op("pow")
-    __rpow__ = _bin_func_op("pow", reverse=True)
+    __rpow__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_func_op("pow", reverse=True)
+    )
 
     # logistic operators
-    __eq__ = _bin_op("equalTo")
-    __ne__ = _bin_op("notEqual")
+    def __eq__(   # type: ignore[override]
+        self,
+        other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
+    ) -> "Column":
+        """binary function"""
+        return _bin_op("equalTo")(self, other)
+
+    def __ne__(   # type: ignore[override]
+        self,
+        other: Any,
+    ) -> "Column":
+        """binary function"""
+        return _bin_op("notEqual")(self, other)
+
     __lt__ = _bin_op("lt")
     __le__ = _bin_op("leq")
     __ge__ = _bin_op("geq")
@@ -243,7 +340,7 @@ class Column(object):
     __ror__ = _bin_op("or")
 
     # container operators
-    def __contains__(self, item):
+    def __contains__(self, item: Any) -> None:
         raise ValueError("Cannot apply 'in' operator against a column: please use 'contains' "
                          "in a string column or 'array_contains' function for an array column.")
 
@@ -301,7 +398,7 @@ class Column(object):
     bitwiseAND = _bin_op("bitwiseAND", _bitwiseAND_doc)
     bitwiseXOR = _bin_op("bitwiseXOR", _bitwiseXOR_doc)
 
-    def getItem(self, key):
+    def getItem(self, key: Any) -> "Column":
         """
         An expression that gets an item at position ``ordinal`` out of a list,
         or gets an item by key out of a dict.
@@ -327,7 +424,7 @@ class Column(object):
             )
         return self[key]
 
-    def getField(self, name):
+    def getField(self, name: Any) -> "Column":
         """
         An expression that gets a field by name in a :class:`StructType`.
 
@@ -359,7 +456,7 @@ class Column(object):
             )
         return self[name]
 
-    def withField(self, fieldName, col):
+    def withField(self, fieldName: str, col: "Column") -> "Column":
         """
         An expression that adds/replaces a field in :class:`StructType` by name.
 
@@ -391,7 +488,7 @@ class Column(object):
 
         return Column(self._jc.withField(fieldName, col._jc))
 
-    def dropFields(self, *fieldNames):
+    def dropFields(self, *fieldNames: str) -> "Column":
         """
         An expression that drops fields in :class:`StructType` by name.
         This is a no-op if schema doesn't contain field name(s).
@@ -441,17 +538,17 @@ class Column(object):
         +--------------+
 
         """
-        sc = SparkContext._active_spark_context
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
 
         jc = self._jc.dropFields(_to_seq(sc, fieldNames))
         return Column(jc)
 
-    def __getattr__(self, item):
+    def __getattr__(self, item: Any) -> "Column":
         if item.startswith("__"):
             raise AttributeError(item)
         return self[item]
 
-    def __getitem__(self, k):
+    def __getitem__(self, k: Any) -> "Column":
         if isinstance(k, slice):
             if k.step is not None:
                 raise ValueError("slice with step is not supported.")
@@ -459,7 +556,7 @@ class Column(object):
         else:
             return _bin_op("apply")(self, k)
 
-    def __iter__(self):
+    def __iter__(self) -> None:
         raise TypeError("Column is not iterable")
 
     # string methods
@@ -565,7 +662,15 @@ class Column(object):
     startswith = _bin_op("startsWith", _startswith_doc)
     endswith = _bin_op("endsWith", _endswith_doc)
 
-    def substr(self, startPos, length):
+    @overload
+    def substr(self, startPos: int, length: int) -> "Column":
+        ...
+
+    @overload
+    def substr(self, startPos: "Column", length: "Column") -> "Column":
+        ...
+
+    def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) -> "Column":
         """
         Return a :class:`Column` which is a substring of the column.
 
@@ -594,12 +699,12 @@ class Column(object):
         if isinstance(startPos, int):
             jc = self._jc.substr(startPos, length)
         elif isinstance(startPos, Column):
-            jc = self._jc.substr(startPos._jc, length._jc)
+            jc = self._jc.substr(cast("Column", startPos)._jc, cast("Column", length)._jc)
         else:
             raise TypeError("Unexpected type: %s" % type(startPos))
         return Column(jc)
 
-    def isin(self, *cols):
+    def isin(self, *cols: Any) -> "Column":
         """
         A boolean expression that is evaluated to true if the value of this
         expression is contained by the evaluated values of the arguments.
@@ -614,9 +719,12 @@ class Column(object):
         [Row(age=2, name='Alice')]
         """
         if len(cols) == 1 and isinstance(cols[0], (list, set)):
-            cols = cols[0]
-        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
-        sc = SparkContext._active_spark_context
+            cols = cast(Tuple, cols[0])
+        cols = cast(
+            Tuple,
+            [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
+        )
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         jc = getattr(self._jc, "isin")(_to_seq(sc, cols))
         return Column(jc)
 
@@ -730,7 +838,7 @@ class Column(object):
     isNull = _unary_op("isNull", _isNull_doc)
     isNotNull = _unary_op("isNotNull", _isNotNull_doc)
 
-    def alias(self, *alias, **kwargs):
+    def alias(self, *alias: str, **kwargs: Any) -> "Column":
         """
         Returns this column aliased with a new name or names (in the case of expressions that
         return more than one column, such as explode).
@@ -763,7 +871,7 @@ class Column(object):
         metadata = kwargs.pop('metadata', None)
         assert not kwargs, 'Unexpected kwargs where passed: %s' % kwargs
 
-        sc = SparkContext._active_spark_context
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         if len(alias) == 1:
             if metadata:
                 jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(
@@ -778,7 +886,7 @@ class Column(object):
 
     name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.")
 
-    def cast(self, dataType):
+    def cast(self, dataType: Union[DataType, str]) -> "Column":
         """
         Casts the column into type ``dataType``.
 
@@ -804,7 +912,11 @@ class Column(object):
 
     astype = copy_func(cast, sinceversion=1.4, doc=":func:`astype` is an alias for :func:`cast`.")
 
-    def between(self, lowerBound, upperBound):
+    def between(
+        self,
+        lowerBound: Union["Column", "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
+        upperBound: Union["Column", "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
+    ) -> "Column":
         """
         True if the current column is between the lower bound and upper bound, inclusive.
 
@@ -822,7 +934,7 @@ class Column(object):
         """
         return (self >= lowerBound) & (self <= upperBound)
 
-    def when(self, condition, value):
+    def when(self, condition: "Column", value: Any) -> "Column":
         """
         Evaluates a list of conditions and returns one of multiple possible result expressions.
         If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
@@ -857,7 +969,7 @@ class Column(object):
         jc = self._jc.when(condition._jc, v)
         return Column(jc)
 
-    def otherwise(self, value):
+    def otherwise(self, value: Any) -> "Column":
         """
         Evaluates a list of conditions and returns one of multiple possible result expressions.
         If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
@@ -888,7 +1000,7 @@ class Column(object):
         jc = self._jc.otherwise(v)
         return Column(jc)
 
-    def over(self, window):
+    def over(self, window: "WindowSpec") -> "Column":
         """
         Define a windowing column.
 
@@ -924,16 +1036,16 @@ class Column(object):
         jc = self._jc.over(window._jspec)
         return Column(jc)
 
-    def __nonzero__(self):
+    def __nonzero__(self) -> None:
         raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
                          "'~' for 'not' when building DataFrame boolean expressions.")
     __bool__ = __nonzero__
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "Column<'%s'>" % self._jc.toString()
 
 
-def _test():
+def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession
     import pyspark.sql.column
diff --git a/python/pyspark/sql/column.pyi b/python/pyspark/sql/column.pyi
deleted file mode 100644
index 36c1bcc..0000000
--- a/python/pyspark/sql/column.pyi
+++ /dev/null
@@ -1,118 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import Any, Union
-
-from pyspark.sql._typing import LiteralType, DecimalLiteral, DateTimeLiteral
-from pyspark.sql.types import (  # noqa: F401
-    DataType,
-    StructField,
-    StructType,
-    IntegerType,
-    StringType,
-)
-from pyspark.sql.window import WindowSpec
-
-from py4j.java_gateway import JavaObject  # type: ignore[import]
-
-class Column:
-    def __init__(self, jc: JavaObject) -> None: ...
-    def __neg__(self) -> Column: ...
-    def __add__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __sub__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __mul__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __div__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __truediv__(
-        self, other: Union[Column, LiteralType, DecimalLiteral]
-    ) -> Column: ...
-    def __mod__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __radd__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rsub__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rmul__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rdiv__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rtruediv__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rmod__(self, other: Union[bool, int, float, DecimalLiteral]) -> Column: ...
-    def __pow__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __rpow__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __eq__(self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]) -> Column: ...  # type: ignore[override]
-    def __ne__(self, other: Any) -> Column: ...  # type: ignore[override]
-    def __lt__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def __le__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def __ge__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def __gt__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def eqNullSafe(
-        self, other: Union[Column, LiteralType, DecimalLiteral]
-    ) -> Column: ...
-    def __and__(self, other: Column) -> Column: ...
-    def __or__(self, other: Column) -> Column: ...
-    def __invert__(self) -> Column: ...
-    def __rand__(self, other: Column) -> Column: ...
-    def __ror__(self, other: Column) -> Column: ...
-    def __contains__(self, other: Any) -> Column: ...
-    def __getitem__(self, other: Any) -> Column: ...
-    def bitwiseOR(self, other: Union[Column, int]) -> Column: ...
-    def bitwiseAND(self, other: Union[Column, int]) -> Column: ...
-    def bitwiseXOR(self, other: Union[Column, int]) -> Column: ...
-    def getItem(self, key: Any) -> Column: ...
-    def getField(self, name: Any) -> Column: ...
-    def withField(self, fieldName: str, col: Column) -> Column: ...
-    def dropFields(self, *fieldNames: str) -> Column: ...
-    def __getattr__(self, item: Any) -> Column: ...
-    def __iter__(self) -> None: ...
-    def rlike(self, item: str) -> Column: ...
-    def like(self, item: str) -> Column: ...
-    def startswith(self, item: Union[str, Column]) -> Column: ...
-    def endswith(self, item: Union[str, Column]) -> Column: ...
-    @overload
-    def substr(self, startPos: int, length: int) -> Column: ...
-    @overload
-    def substr(self, startPos: Column, length: Column) -> Column: ...
-    def __getslice__(self, startPos: int, length: int) -> Column: ...
-    def isin(self, *cols: Any) -> Column: ...
-    def asc(self) -> Column: ...
-    def asc_nulls_first(self) -> Column: ...
-    def asc_nulls_last(self) -> Column: ...
-    def desc(self) -> Column: ...
-    def desc_nulls_first(self) -> Column: ...
-    def desc_nulls_last(self) -> Column: ...
-    def isNull(self) -> Column: ...
-    def isNotNull(self) -> Column: ...
-    def alias(self, *alias: str, **kwargs: Any) -> Column: ...
-    def name(self, *alias: str) -> Column: ...
-    def cast(self, dataType: Union[DataType, str]) -> Column: ...
-    def astype(self, dataType: Union[DataType, str]) -> Column: ...
-    def between(
-        self,
-        lowerBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral],
-        upperBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral],
-    ) -> Column: ...
-    def when(self, condition: Column, value: Any) -> Column: ...
-    def otherwise(self, value: Any) -> Column: ...
-    def over(self, window: WindowSpec) -> Column: ...
-    def __nonzero__(self) -> None: ...
-    def __bool__(self) -> None: ...
-    def contains(self, item: Any) -> Column: ...
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 339f8f8..223f041 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1279,7 +1279,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 raise ValueError("Weights must be positive. Found weight value: %s" % w)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         rdd_array = self._jdf.randomSplit(
-            _to_list(self.sql_ctx._sc, weights), int(seed)  # type: ignore[attr-defined]
+            _to_list(
+                self.sql_ctx._sc,  # type: ignore[attr-defined]
+                cast(List["ColumnOrName"], weights)
+            ),
+            int(seed)
         )
         return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
 
@@ -1674,7 +1678,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             raise ValueError("should sort by at least one column")
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]
-        jcols = [_to_java_column(c) for c in cols]
+        jcols = [_to_java_column(cast("ColumnOrName", c)) for c in cols]
         ascending = kwargs.get('ascending', True)
         if isinstance(ascending, (bool, int)):
             if not ascending:
@@ -2723,7 +2727,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         for c in col:
             if not isinstance(c, str):
                 raise TypeError("columns should be strings, but got %r" % type(c))
-        col = _to_list(self._sc, col)
+        col = _to_list(self._sc, cast(List["ColumnOrName"], col))
 
         if not isinstance(probabilities, (list, tuple)):
             raise TypeError("probabilities should be a list or tuple")
@@ -2732,7 +2736,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         for p in probabilities:
             if not isinstance(p, (float, int)) or p < 0 or p > 1:
                 raise ValueError("probabilities should be numerical (float, int) in [0,1].")
-        probabilities = _to_list(self._sc, probabilities)
+        probabilities = _to_list(self._sc, cast(List["ColumnOrName"], probabilities))
 
         if not isinstance(relativeError, (float, int)):
             raise TypeError("relativeError should be numerical (float, int)")
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 7e0d015..717eaec 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -24,6 +24,7 @@ import functools
 import warnings
 from typing import (
     Any,
+    cast,
     Callable,
     Dict,
     List,
@@ -1770,7 +1771,7 @@ def log(arg1: Union["ColumnOrName", float], arg2: Optional["ColumnOrName"] = Non
     """
     sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     if arg2 is None:
-        jc = sc._jvm.functions.log(_to_java_column(arg1))
+        jc = sc._jvm.functions.log(_to_java_column(cast("ColumnOrName", arg1)))
     else:
         jc = sc._jvm.functions.log(arg1, _to_java_column(arg2))
     return Column(jc)
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index 48d8176..f60e580 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from pyspark.sql import column
 from pyspark.sql.column import Column
@@ -22,6 +22,9 @@ from pyspark.sql.dataframe import DataFrame
 
 __all__ = ["Observation"]
 
+if TYPE_CHECKING:
+    from pyspark import SparkContext  # noqa: F401
+
 
 class Observation:
     """Class to observe (named) metrics on a :class:`DataFrame`.
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index 3054273..f1b03ab 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -16,7 +16,7 @@
 #
 
 import sys
-from typing import List, Tuple, TYPE_CHECKING, Union
+from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union
 
 from pyspark import since, SparkContext
 from pyspark.sql.column import _to_seq, _to_java_column  # type: ignore[attr-defined]
@@ -35,7 +35,7 @@ def _to_java_cols(
     sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     if len(cols) == 1 and isinstance(cols[0], list):
         cols = cols[0]  # type: ignore[assignment]
-    return _to_seq(sc, cols, _to_java_column)
+    return _to_seq(sc, cast(Iterable["ColumnOrName"], cols), _to_java_column)
 
 
 class Window(object):
