Posted to commits@spark.apache.org by ue...@apache.org on 2021/10/12 20:37:12 UTC
[spark] branch master updated: [SPARK-36951][PYTHON] Inline type hints for python/pyspark/sql/column.py
This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3ba57f5 [SPARK-36951][PYTHON] Inline type hints for python/pyspark/sql/column.py
3ba57f5 is described below
commit 3ba57f5edc5594ee676249cd309b8f0d8248462e
Author: Xinrong Meng <xi...@databricks.com>
AuthorDate: Tue Oct 12 13:36:22 2021 -0700
[SPARK-36951][PYTHON] Inline type hints for python/pyspark/sql/column.py
### What changes were proposed in this pull request?
Inline type hints for python/pyspark/sql/column.py
### Why are the changes needed?
Currently, the type hints for python/pyspark/sql/column.py live in a separate stub file, python/pyspark/sql/column.pyi, and stub files don't support type checking within function bodies. So we inline the type hints into column.py to support that.
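For context, a minimal sketch of the difference (a generic example, not code from this patch): with a stub file, a type checker such as mypy validates call sites against the .pyi but skips the implementation body entirely; once the same hints are inlined, the body is checked too.

    # mod.pyi (stub): the signature is enforced at call sites only
    # def first_upper(s: Optional[str]) -> str: ...

    # mod.py with the hint inlined: mypy now checks the body as well,
    # e.g. that the None case is narrowed before calling .upper()
    from typing import Optional

    def first_upper(s: Optional[str]) -> str:
        if s is None:
            return ""
        return s.upper()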
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing test.
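(Illustrative only, not from the patch: the kind of ordinary Column expression the existing tests exercise, which the inlined hints now let a type checker validate end to end.)

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(2, "Alice")], ["age", "name"])
    # __add__, __lt__ and alias() now carry inline signatures, so this
    # whole expression type-checks without consulting a stub file.
    df.select((col("age") + 1).alias("age_plus_one")).filter(col("age") < 10).show()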
Closes #34226 from xinrong-databricks/inline_column.
Authored-by: Xinrong Meng <xi...@databricks.com>
Signed-off-by: Takuya UESHIN <ue...@databricks.com>
---
 python/pyspark/sql/column.py      | 236 ++++++++++++++++++++++++++++----------
 python/pyspark/sql/column.pyi     | 118 -------------------
 python/pyspark/sql/dataframe.py   |  12 +-
 python/pyspark/sql/functions.py   |   3 +-
 python/pyspark/sql/observation.py |   5 +-
 python/pyspark/sql/window.py      |   4 +-
 6 files changed, 190 insertions(+), 188 deletions(-)
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index c46b0eb..a3e3e9e 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -18,25 +18,43 @@
 import sys
 import json
 import warnings
+from typing import (
+    cast,
+    overload,
+    Any,
+    Callable,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    TYPE_CHECKING,
+    Union
+)
+
+from py4j.java_gateway import JavaObject
 from pyspark import copy_func
 from pyspark.context import SparkContext
 from pyspark.sql.types import DataType, StructField, StructType, IntegerType, StringType
+if TYPE_CHECKING:
+    from pyspark.sql._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral
+    from pyspark.sql.window import WindowSpec
+
 __all__ = ["Column"]
-def _create_column_from_literal(literal):
-    sc = SparkContext._active_spark_context
+def _create_column_from_literal(literal: Union["LiteralType", "DecimalLiteral"]) -> "Column":
+    sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     return sc._jvm.functions.lit(literal)
-def _create_column_from_name(name):
-    sc = SparkContext._active_spark_context
+def _create_column_from_name(name: str) -> "Column":
+    sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     return sc._jvm.functions.col(name)
-def _to_java_column(col):
+def _to_java_column(col: "ColumnOrName") -> JavaObject:
     if isinstance(col, Column):
         jcol = col._jc
     elif isinstance(col, str):
@@ -50,7 +68,11 @@ def _to_java_column(col):
     return jcol
-def _to_seq(sc, cols, converter=None):
+def _to_seq(
+    sc: SparkContext,
+    cols: Iterable["ColumnOrName"],
+    converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None,
+) -> JavaObject:
     """
     Convert a list of Column (or names) into a JVM Seq of Column.
@@ -59,10 +81,14 @@ def _to_seq(sc, cols, converter=None):
     """
     if converter:
         cols = [converter(c) for c in cols]
-    return sc._jvm.PythonUtils.toSeq(cols)
+    return sc._jvm.PythonUtils.toSeq(cols)  # type: ignore[attr-defined]
-def _to_list(sc, cols, converter=None):
+def _to_list(
+    sc: SparkContext,
+    cols: List["ColumnOrName"],
+    converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None,
+) -> JavaObject:
     """
     Convert a list of Column (or names) into a JVM (Scala) List of Column.
@@ -71,30 +97,37 @@ def _to_list(sc, cols, converter=None):
     """
     if converter:
         cols = [converter(c) for c in cols]
-    return sc._jvm.PythonUtils.toList(cols)
+    return sc._jvm.PythonUtils.toList(cols)  # type: ignore[attr-defined]
-def _unary_op(name, doc="unary operator"):
+def _unary_op(
+    name: str,
+    doc: str = "unary operator",
+) -> Callable[["Column"], "Column"]:
     """ Create a method for given unary operator """
-    def _(self):
+    def _(self: "Column") -> "Column":
         jc = getattr(self._jc, name)()
         return Column(jc)
     _.__doc__ = doc
     return _
-def _func_op(name, doc=''):
-    def _(self):
-        sc = SparkContext._active_spark_context
+def _func_op(name: str, doc: str = '') -> Callable[["Column"], "Column"]:
+    def _(self: "Column") -> "Column":
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         jc = getattr(sc._jvm.functions, name)(self._jc)
         return Column(jc)
     _.__doc__ = doc
     return _
-def _bin_func_op(name, reverse=False, doc="binary function"):
-    def _(self, other):
-        sc = SparkContext._active_spark_context
+def _bin_func_op(
+    name: str,
+    reverse: bool = False,
+    doc: str = "binary function",
+) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"]:
+    def _(self: "Column", other: Union["Column", "LiteralType", "DecimalLiteral"]) -> "Column":
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         fn = getattr(sc._jvm.functions, name)
         jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
         njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
@@ -103,10 +136,19 @@ def _bin_func_op(name, reverse=False, doc="binary function"):
     return _
-def _bin_op(name, doc="binary operator"):
+def _bin_op(
+    name: str,
+    doc: str = "binary operator",
+) -> Callable[
+    ["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]],
+    "Column"
+]:
     """ Create a method for given binary operator
     """
-    def _(self, other):
+    def _(
+        self: "Column",
+        other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
+    ) -> "Column":
         jc = other._jc if isinstance(other, Column) else other
         njc = getattr(self._jc, name)(jc)
         return Column(njc)
@@ -114,10 +156,13 @@ def _bin_op(name, doc="binary operator"):
     return _
-def _reverse_op(name, doc="binary operator"):
+def _reverse_op(
+    name: str,
+    doc: str = "binary operator",
+) -> Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"]:
     """ Create a method for binary operator (this object is on right side)
     """
-    def _(self, other):
+    def _(self: "Column", other: Union["LiteralType", "DecimalLiteral"]) -> "Column":
         jother = _create_column_from_literal(other)
         jc = getattr(jother, name)(self._jc)
         return Column(jc)
@@ -144,29 +189,81 @@ class Column(object):
     .. versionadded:: 1.3.0
     """
-    def __init__(self, jc):
+    def __init__(self, jc: JavaObject) -> None:
         self._jc = jc
     # arithmetic operators
     __neg__ = _func_op("negate")
-    __add__ = _bin_op("plus")
-    __sub__ = _bin_op("minus")
-    __mul__ = _bin_op("multiply")
-    __div__ = _bin_op("divide")
-    __truediv__ = _bin_op("divide")
-    __mod__ = _bin_op("mod")
-    __radd__ = _bin_op("plus")
-    __rsub__ = _reverse_op("minus")
-    __rmul__ = _bin_op("multiply")
-    __rdiv__ = _reverse_op("divide")
-    __rtruediv__ = _reverse_op("divide")
-    __rmod__ = _reverse_op("mod")
+    __add__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("plus")
+    )
+    __sub__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("minus")
+    )
+    __mul__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("multiply")
+    )
+    __div__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("divide")
+    )
+    __truediv__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("divide")
+    )
+    __mod__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("mod")
+    )
+    __radd__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("plus")
+    )
+    __rsub__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("minus")
+    )
+    __rmul__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("multiply")
+    )
+    __rdiv__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("divide")
+    )
+    __rtruediv__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("divide")
+    )
+    __rmod__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("mod")
+    )
+
     __pow__ = _bin_func_op("pow")
-    __rpow__ = _bin_func_op("pow", reverse=True)
+    __rpow__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_func_op("pow", reverse=True)
+    )
     # logistic operators
-    __eq__ = _bin_op("equalTo")
-    __ne__ = _bin_op("notEqual")
+    def __eq__(  # type: ignore[override]
+        self,
+        other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
+    ) -> "Column":
+        """binary function"""
+        return _bin_op("equalTo")(self, other)
+
+    def __ne__(  # type: ignore[override]
+        self,
+        other: Any,
+    ) -> "Column":
+        """binary function"""
+        return _bin_op("notEqual")(self, other)
+
     __lt__ = _bin_op("lt")
     __le__ = _bin_op("leq")
     __ge__ = _bin_op("geq")
@@ -243,7 +340,7 @@ class Column(object):
     __ror__ = _bin_op("or")
     # container operators
-    def __contains__(self, item):
+    def __contains__(self, item: Any) -> None:
         raise ValueError("Cannot apply 'in' operator against a column: please use 'contains' "
                          "in a string column or 'array_contains' function for an array column.")
@@ -301,7 +398,7 @@ class Column(object):
     bitwiseAND = _bin_op("bitwiseAND", _bitwiseAND_doc)
     bitwiseXOR = _bin_op("bitwiseXOR", _bitwiseXOR_doc)
-    def getItem(self, key):
+    def getItem(self, key: Any) -> "Column":
         """
         An expression that gets an item at position ``ordinal`` out of a list,
         or gets an item by key out of a dict.
@@ -327,7 +424,7 @@ class Column(object):
             )
         return self[key]
-    def getField(self, name):
+    def getField(self, name: Any) -> "Column":
         """
         An expression that gets a field by name in a :class:`StructType`.
@@ -359,7 +456,7 @@ class Column(object):
             )
         return self[name]
-    def withField(self, fieldName, col):
+    def withField(self, fieldName: str, col: "Column") -> "Column":
         """
         An expression that adds/replaces a field in :class:`StructType` by name.
@@ -391,7 +488,7 @@ class Column(object):
         return Column(self._jc.withField(fieldName, col._jc))
-    def dropFields(self, *fieldNames):
+    def dropFields(self, *fieldNames: str) -> "Column":
         """
         An expression that drops fields in :class:`StructType` by name.
         This is a no-op if schema doesn't contain field name(s).
@@ -441,17 +538,17 @@ class Column(object):
         +--------------+
         """
-        sc = SparkContext._active_spark_context
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         jc = self._jc.dropFields(_to_seq(sc, fieldNames))
         return Column(jc)
-    def __getattr__(self, item):
+    def __getattr__(self, item: Any) -> "Column":
         if item.startswith("__"):
             raise AttributeError(item)
         return self[item]
-    def __getitem__(self, k):
+    def __getitem__(self, k: Any) -> "Column":
         if isinstance(k, slice):
             if k.step is not None:
                 raise ValueError("slice with step is not supported.")
@@ -459,7 +556,7 @@ class Column(object):
         else:
             return _bin_op("apply")(self, k)
-    def __iter__(self):
+    def __iter__(self) -> None:
         raise TypeError("Column is not iterable")
     # string methods
@@ -565,7 +662,15 @@ class Column(object):
     startswith = _bin_op("startsWith", _startswith_doc)
     endswith = _bin_op("endsWith", _endswith_doc)
-    def substr(self, startPos, length):
+    @overload
+    def substr(self, startPos: int, length: int) -> "Column":
+        ...
+
+    @overload
+    def substr(self, startPos: "Column", length: "Column") -> "Column":
+        ...
+
+    def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) -> "Column":
         """
         Return a :class:`Column` which is a substring of the column.
@@ -594,12 +699,12 @@ class Column(object):
         if isinstance(startPos, int):
             jc = self._jc.substr(startPos, length)
         elif isinstance(startPos, Column):
-            jc = self._jc.substr(startPos._jc, length._jc)
+            jc = self._jc.substr(cast("Column", startPos)._jc, cast("Column", length)._jc)
         else:
             raise TypeError("Unexpected type: %s" % type(startPos))
         return Column(jc)
-    def isin(self, *cols):
+    def isin(self, *cols: Any) -> "Column":
         """
         A boolean expression that is evaluated to true if the value of this
         expression is contained by the evaluated values of the arguments.
@@ -614,9 +719,12 @@ class Column(object):
         [Row(age=2, name='Alice')]
         """
         if len(cols) == 1 and isinstance(cols[0], (list, set)):
-            cols = cols[0]
-        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
-        sc = SparkContext._active_spark_context
+            cols = cast(Tuple, cols[0])
+        cols = cast(
+            Tuple,
+            [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
+        )
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         jc = getattr(self._jc, "isin")(_to_seq(sc, cols))
         return Column(jc)
@@ -730,7 +838,7 @@ class Column(object):
     isNull = _unary_op("isNull", _isNull_doc)
     isNotNull = _unary_op("isNotNull", _isNotNull_doc)
-    def alias(self, *alias, **kwargs):
+    def alias(self, *alias: str, **kwargs: Any) -> "Column":
         """
         Returns this column aliased with a new name or names (in the case of expressions that
         return more than one column, such as explode).
@@ -763,7 +871,7 @@ class Column(object):
         metadata = kwargs.pop('metadata', None)
         assert not kwargs, 'Unexpected kwargs where passed: %s' % kwargs
-        sc = SparkContext._active_spark_context
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         if len(alias) == 1:
             if metadata:
                 jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(
@@ -778,7 +886,7 @@ class Column(object):
     name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.")
-    def cast(self, dataType):
+    def cast(self, dataType: Union[DataType, str]) -> "Column":
         """
         Casts the column into type ``dataType``.
@@ -804,7 +912,11 @@ class Column(object):
     astype = copy_func(cast, sinceversion=1.4, doc=":func:`astype` is an alias for :func:`cast`.")
-    def between(self, lowerBound, upperBound):
+    def between(
+        self,
+        lowerBound: Union["Column", "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
+        upperBound: Union["Column", "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
+    ) -> "Column":
         """
         True if the current column is between the lower bound and upper bound, inclusive.
@@ -822,7 +934,7 @@ class Column(object):
         """
         return (self >= lowerBound) & (self <= upperBound)
-    def when(self, condition, value):
+    def when(self, condition: "Column", value: Any) -> "Column":
         """
         Evaluates a list of conditions and returns one of multiple possible result expressions.
         If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
@@ -857,7 +969,7 @@ class Column(object):
         jc = self._jc.when(condition._jc, v)
         return Column(jc)
-    def otherwise(self, value):
+    def otherwise(self, value: Any) -> "Column":
         """
         Evaluates a list of conditions and returns one of multiple possible result expressions.
         If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
@@ -888,7 +1000,7 @@ class Column(object):
         jc = self._jc.otherwise(v)
         return Column(jc)
-    def over(self, window):
+    def over(self, window: "WindowSpec") -> "Column":
         """
         Define a windowing column.
@@ -924,16 +1036,16 @@ class Column(object):
         jc = self._jc.over(window._jspec)
         return Column(jc)
-    def __nonzero__(self):
+    def __nonzero__(self) -> None:
         raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
                          "'~' for 'not' when building DataFrame boolean expressions.")
     __bool__ = __nonzero__
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "Column<'%s'>" % self._jc.toString()
-def _test():
+def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession
     import pyspark.sql.column
diff --git a/python/pyspark/sql/column.pyi b/python/pyspark/sql/column.pyi
deleted file mode 100644
index 36c1bcc..0000000
--- a/python/pyspark/sql/column.pyi
+++ /dev/null
@@ -1,118 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import Any, Union
-
-from pyspark.sql._typing import LiteralType, DecimalLiteral, DateTimeLiteral
-from pyspark.sql.types import (  # noqa: F401
-    DataType,
-    StructField,
-    StructType,
-    IntegerType,
-    StringType,
-)
-from pyspark.sql.window import WindowSpec
-
-from py4j.java_gateway import JavaObject  # type: ignore[import]
-
-class Column:
-    def __init__(self, jc: JavaObject) -> None: ...
-    def __neg__(self) -> Column: ...
-    def __add__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __sub__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __mul__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __div__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __truediv__(
-        self, other: Union[Column, LiteralType, DecimalLiteral]
-    ) -> Column: ...
-    def __mod__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __radd__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rsub__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rmul__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rdiv__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rtruediv__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rmod__(self, other: Union[bool, int, float, DecimalLiteral]) -> Column: ...
-    def __pow__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __rpow__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __eq__(self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]) -> Column: ...  # type: ignore[override]
-    def __ne__(self, other: Any) -> Column: ...  # type: ignore[override]
-    def __lt__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def __le__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def __ge__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def __gt__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def eqNullSafe(
-        self, other: Union[Column, LiteralType, DecimalLiteral]
-    ) -> Column: ...
-    def __and__(self, other: Column) -> Column: ...
-    def __or__(self, other: Column) -> Column: ...
-    def __invert__(self) -> Column: ...
-    def __rand__(self, other: Column) -> Column: ...
-    def __ror__(self, other: Column) -> Column: ...
-    def __contains__(self, other: Any) -> Column: ...
-    def __getitem__(self, other: Any) -> Column: ...
-    def bitwiseOR(self, other: Union[Column, int]) -> Column: ...
-    def bitwiseAND(self, other: Union[Column, int]) -> Column: ...
-    def bitwiseXOR(self, other: Union[Column, int]) -> Column: ...
-    def getItem(self, key: Any) -> Column: ...
-    def getField(self, name: Any) -> Column: ...
-    def withField(self, fieldName: str, col: Column) -> Column: ...
-    def dropFields(self, *fieldNames: str) -> Column: ...
-    def __getattr__(self, item: Any) -> Column: ...
-    def __iter__(self) -> None: ...
-    def rlike(self, item: str) -> Column: ...
-    def like(self, item: str) -> Column: ...
-    def startswith(self, item: Union[str, Column]) -> Column: ...
-    def endswith(self, item: Union[str, Column]) -> Column: ...
-    @overload
-    def substr(self, startPos: int, length: int) -> Column: ...
-    @overload
-    def substr(self, startPos: Column, length: Column) -> Column: ...
-    def __getslice__(self, startPos: int, length: int) -> Column: ...
-    def isin(self, *cols: Any) -> Column: ...
-    def asc(self) -> Column: ...
-    def asc_nulls_first(self) -> Column: ...
-    def asc_nulls_last(self) -> Column: ...
-    def desc(self) -> Column: ...
-    def desc_nulls_first(self) -> Column: ...
-    def desc_nulls_last(self) -> Column: ...
-    def isNull(self) -> Column: ...
-    def isNotNull(self) -> Column: ...
-    def alias(self, *alias: str, **kwargs: Any) -> Column: ...
-    def name(self, *alias: str) -> Column: ...
-    def cast(self, dataType: Union[DataType, str]) -> Column: ...
-    def astype(self, dataType: Union[DataType, str]) -> Column: ...
-    def between(
-        self,
-        lowerBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral],
-        upperBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral],
-    ) -> Column: ...
-    def when(self, condition: Column, value: Any) -> Column: ...
-    def otherwise(self, value: Any) -> Column: ...
-    def over(self, window: WindowSpec) -> Column: ...
-    def __nonzero__(self) -> None: ...
-    def __bool__(self) -> None: ...
-    def contains(self, item: Any) -> Column: ...
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 339f8f8..223f041 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1279,7 +1279,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 raise ValueError("Weights must be positive. Found weight value: %s" % w)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         rdd_array = self._jdf.randomSplit(
-            _to_list(self.sql_ctx._sc, weights), int(seed)  # type: ignore[attr-defined]
+            _to_list(
+                self.sql_ctx._sc,  # type: ignore[attr-defined]
+                cast(List["ColumnOrName"], weights)
+            ),
+            int(seed)
         )
         return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
@@ -1674,7 +1678,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             raise ValueError("should sort by at least one column")
         if len(cols) == 1 and isinstance(cols[0], list):
            cols = cols[0]
-        jcols = [_to_java_column(c) for c in cols]
+        jcols = [_to_java_column(cast("ColumnOrName", c)) for c in cols]
         ascending = kwargs.get('ascending', True)
         if isinstance(ascending, (bool, int)):
             if not ascending:
@@ -2723,7 +2727,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             for c in col:
                 if not isinstance(c, str):
                     raise TypeError("columns should be strings, but got %r" % type(c))
-            col = _to_list(self._sc, col)
+            col = _to_list(self._sc, cast(List["ColumnOrName"], col))
         if not isinstance(probabilities, (list, tuple)):
             raise TypeError("probabilities should be a list or tuple")
@@ -2732,7 +2736,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         for p in probabilities:
             if not isinstance(p, (float, int)) or p < 0 or p > 1:
                 raise ValueError("probabilities should be numerical (float, int) in [0,1].")
-        probabilities = _to_list(self._sc, probabilities)
+        probabilities = _to_list(self._sc, cast(List["ColumnOrName"], probabilities))
         if not isinstance(relativeError, (float, int)):
             raise TypeError("relativeError should be numerical (float, int)")
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 7e0d015..717eaec 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -24,6 +24,7 @@ import functools
 import warnings
 from typing import (
     Any,
+    cast,
     Callable,
     Dict,
     List,
@@ -1770,7 +1771,7 @@ def log(arg1: Union["ColumnOrName", float], arg2: Optional["ColumnOrName"] = Non
     """
     sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     if arg2 is None:
-        jc = sc._jvm.functions.log(_to_java_column(arg1))
+        jc = sc._jvm.functions.log(_to_java_column(cast("ColumnOrName", arg1)))
     else:
         jc = sc._jvm.functions.log(arg1, _to_java_column(arg2))
     return Column(jc)
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index 48d8176..f60e580 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, TYPE_CHECKING
 from pyspark.sql import column
 from pyspark.sql.column import Column
@@ -22,6 +22,9 @@ from pyspark.sql.dataframe import DataFrame
 __all__ = ["Observation"]
+if TYPE_CHECKING:
+    from pyspark import SparkContext  # noqa: F401
+
 class Observation:
     """Class to observe (named) metrics on a :class:`DataFrame`.
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index 3054273..f1b03ab 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -16,7 +16,7 @@
 #
 import sys
-from typing import List, Tuple, TYPE_CHECKING, Union
+from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union
 from pyspark import since, SparkContext
 from pyspark.sql.column import _to_seq, _to_java_column  # type: ignore[attr-defined]
@@ -35,7 +35,7 @@ def _to_java_cols(
     sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     if len(cols) == 1 and isinstance(cols[0], list):
         cols = cols[0]  # type: ignore[assignment]
-    return _to_seq(sc, cols, _to_java_column)
+    return _to_seq(sc, cast(Iterable["ColumnOrName"], cols), _to_java_column)
 class Window(object):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org