Posted to commits@spark.apache.org by gu...@apache.org on 2021/10/13 00:36:28 UTC
[spark] branch master updated: [SPARK-36961][PYTHON] Use PEP526 style variable type hints
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 973f04e [SPARK-36961][PYTHON] Use PEP526 style variable type hints
973f04e is described below
commit 973f04eea7140dc61457cc12e74d5e7e333013db
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Wed Oct 13 09:35:45 2021 +0900
[SPARK-36961][PYTHON] Use PEP526 style variable type hints
### What changes were proposed in this pull request?
Uses PEP 526 style variable type hints.
### Why are the changes needed?
Now that we have started using newer Python syntax in the code base, we should use PEP 526 style variable type hints.
- https://www.python.org/dev/peps/pep-0526/
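
For illustration only (not part of this patch; the names below are made up), a minimal sketch contrasting the older type-comment style with PEP 526 variable annotations, including the declare-before-assign pattern this change adopts for branch-dependent values:

```python
from typing import Dict, List

# Pre-PEP 526: the type is carried in a trailing comment that only
# static checkers read; it is not part of the runtime syntax.
counts_old = {}  # type: Dict[str, int]

# PEP 526: the annotation is part of the assignment itself.
counts: Dict[str, int] = {}

# PEP 526 also allows annotating a name before it is bound, which is useful
# when the value is assigned in different branches (a pattern used throughout
# this patch, e.g. `column_labels: Optional[List[Label]]`).
labels: List[str]
if counts:
    labels = list(counts)
else:
    labels = []

class Example:
    def __init__(self, values: List[int]) -> None:
        # Instance attributes can be annotated at the point of assignment too.
        self._values: List[int] = values
```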
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
Closes #34227 from ueshin/issues/SPARK-36961/pep526.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/pandas/accessors.py | 8 ++--
python/pyspark/pandas/categorical.py | 6 ++-
python/pyspark/pandas/config.py | 6 +--
python/pyspark/pandas/frame.py | 73 +++++++++++++++---------------
python/pyspark/pandas/generic.py | 3 +-
python/pyspark/pandas/groupby.py | 16 +++----
python/pyspark/pandas/indexes/base.py | 28 ++++++------
python/pyspark/pandas/indexes/multi.py | 10 ++--
python/pyspark/pandas/indexing.py | 12 ++---
python/pyspark/pandas/internal.py | 40 ++++++++--------
python/pyspark/pandas/mlflow.py | 4 +-
python/pyspark/pandas/namespace.py | 55 +++++++++++++---------
python/pyspark/pandas/series.py | 20 ++++----
python/pyspark/pandas/sql_processor.py | 6 +--
python/pyspark/pandas/typedef/typehints.py | 16 +++----
python/pyspark/pandas/utils.py | 23 ++++++----
python/pyspark/pandas/window.py | 12 +++--
python/pyspark/sql/pandas/conversion.py | 2 +-
python/pyspark/sql/pandas/types.py | 3 +-
19 files changed, 186 insertions(+), 157 deletions(-)
diff --git a/python/pyspark/pandas/accessors.py b/python/pyspark/pandas/accessors.py
index e69a86e..c54f21d 100644
--- a/python/pyspark/pandas/accessors.py
+++ b/python/pyspark/pandas/accessors.py
@@ -343,7 +343,7 @@ class PandasOnSparkFrameMethods(object):
original_func = func
func = lambda o: original_func(o, *args, **kwds)
- self_applied = DataFrame(self._psdf._internal.resolved_copy) # type: DataFrame
+ self_applied: DataFrame = DataFrame(self._psdf._internal.resolved_copy)
if should_infer_schema:
# Here we execute with the first 1000 to get the return type.
@@ -356,7 +356,7 @@ class PandasOnSparkFrameMethods(object):
"The given function should return a frame; however, "
"the return type was %s." % type(applied)
)
- psdf = ps.DataFrame(applied) # type: DataFrame
+ psdf: DataFrame = DataFrame(applied)
if len(pdf) <= limit:
return psdf
@@ -632,7 +632,7 @@ class PandasOnSparkFrameMethods(object):
[field.struct_field for field in index_fields + data_fields]
)
- self_applied = DataFrame(self._psdf._internal.resolved_copy) # type: DataFrame
+ self_applied: DataFrame = DataFrame(self._psdf._internal.resolved_copy)
output_func = GroupBy._make_pandas_df_builder_func(
self_applied, func, return_schema, retain_index=True
@@ -893,7 +893,7 @@ class PandasOnSparkSeriesMethods(object):
limit = ps.get_option("compute.shortcut_limit")
pser = self._psser.head(limit + 1)._to_internal_pandas()
transformed = pser.transform(func)
- psser = Series(transformed) # type: Series
+ psser: Series = Series(transformed)
field = psser._internal.data_fields[0].normalize_spark_type()
else:
diff --git a/python/pyspark/pandas/categorical.py b/python/pyspark/pandas/categorical.py
index fa11228..d580253 100644
--- a/python/pyspark/pandas/categorical.py
+++ b/python/pyspark/pandas/categorical.py
@@ -239,8 +239,9 @@ class CategoricalAccessor(object):
FutureWarning,
)
+ categories: List[Any]
if is_list_like(new_categories):
- categories = list(new_categories) # type: List
+ categories = list(new_categories)
else:
categories = [new_categories]
@@ -433,8 +434,9 @@ class CategoricalAccessor(object):
FutureWarning,
)
+ categories: List[Any]
if is_list_like(removals):
- categories = [cat for cat in removals if cat is not None] # type: List
+ categories = [cat for cat in removals if cat is not None]
elif removals is None:
categories = []
else:
diff --git a/python/pyspark/pandas/config.py b/python/pyspark/pandas/config.py
index 9d52350..b9f94a8 100644
--- a/python/pyspark/pandas/config.py
+++ b/python/pyspark/pandas/config.py
@@ -116,7 +116,7 @@ class Option:
# See the examples below:
# >>> from pyspark.pandas.config import show_options
# >>> show_options()
-_options = [
+_options: List[Option] = [
Option(
key="display.max_rows",
doc=(
@@ -246,9 +246,9 @@ _options = [
default="plotly",
types=str,
),
-] # type: List[Option]
+]
-_options_dict = dict(zip((option.key for option in _options), _options)) # type: Dict[str, Option]
+_options_dict: Dict[str, Option] = dict(zip((option.key for option in _options), _options))
_key_format = "pandas_on_Spark.{}".format
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 7a4817d..97522cb 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -437,8 +437,9 @@ class DataFrame(Frame, Generic[T]):
4 2 5 4 3 9
"""
- @no_type_check
- def __init__(self, data=None, index=None, columns=None, dtype=None, copy=False):
+ def __init__( # type: ignore[no-untyped-def]
+ self, data=None, index=None, columns=None, dtype=None, copy=False
+ ):
if isinstance(data, InternalFrame):
assert index is None
assert columns is None
@@ -535,7 +536,7 @@ class DataFrame(Frame, Generic[T]):
not_same_anchor = requires_same_anchor and not same_anchor(internal, psser)
if renamed or not_same_anchor:
- psdf = DataFrame(self._internal.select_column(old_label)) # type: DataFrame
+ psdf: DataFrame = DataFrame(self._internal.select_column(old_label))
psser._update_anchor(psdf)
psser = None
else:
@@ -1261,7 +1262,7 @@ class DataFrame(Frame, Generic[T]):
)
with option_context("compute.default_index_type", "distributed"):
- psdf = DataFrame(GroupBy._spark_groupby(self, func)) # type: DataFrame
+ psdf: DataFrame = DataFrame(GroupBy._spark_groupby(self, func))
# The codes below basically converts:
#
@@ -2474,9 +2475,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
else:
return pdf_or_pser
- self_applied = DataFrame(self._internal.resolved_copy) # type: "DataFrame"
+ self_applied: DataFrame = DataFrame(self._internal.resolved_copy)
- column_labels = None # type: Optional[List[Label]]
+ column_labels: Optional[List[Label]] = None
if should_infer_schema:
# Here we execute with the first 1000 to get the return type.
# If the records were less than 1000, it uses pandas API directly for a shortcut.
@@ -2588,7 +2589,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
column_labels=column_labels,
)
- result = DataFrame(internal) # type: "DataFrame"
+ result: DataFrame = DataFrame(internal)
if should_return_series:
return first_series(result)
else:
@@ -2723,7 +2724,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
limit = get_option("compute.shortcut_limit")
pdf = self.head(limit + 1)._to_internal_pandas()
transformed = pdf.transform(func, axis, *args, **kwargs)
- psdf = DataFrame(transformed) # type: "DataFrame"
+ psdf: DataFrame = DataFrame(transformed)
if len(pdf) <= limit:
return psdf
@@ -2936,7 +2937,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
internal = self._internal.with_filter(reduce(lambda x, y: x & y, rows))
if len(key) == self._internal.index_level:
- psdf = DataFrame(internal) # type: DataFrame
+ psdf: DataFrame = DataFrame(internal)
pdf = psdf.head(2)._to_internal_pandas()
if len(pdf) == 0:
raise KeyError(key)
@@ -3555,8 +3556,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
2014 10 31
"""
inplace = validate_bool_kwarg(inplace, "inplace")
+ key_list: List[Label]
if is_name_like_tuple(keys):
- key_list = [cast(Label, keys)] # type: List[Label]
+ key_list = [cast(Label, keys)]
elif is_name_like_value(keys):
key_list = [(keys,)]
else:
@@ -5218,9 +5220,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
elif how not in ("any", "all"):
raise ValueError("invalid how option: {h}".format(h=how))
+ labels: Optional[List[Label]]
if subset is not None:
if isinstance(subset, str):
- labels = [(subset,)] # type: Optional[List[Label]]
+ labels = [(subset,)]
elif isinstance(subset, tuple):
labels = [subset]
else:
@@ -5284,7 +5287,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
internal = internal.with_filter(cond)
- psdf = DataFrame(internal)
+ psdf: DataFrame = DataFrame(internal)
null_counts = []
for label in internal.column_labels:
@@ -5996,6 +5999,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
if fill_value is not None and isinstance(fill_value, (int, float)):
sdf = sdf.fillna(fill_value)
+ psdf: DataFrame
if index is not None:
index_columns = [self._internal.spark_column_name_for(label) for label in index]
index_fields = [self._internal.field_for(label) for label in index]
@@ -6034,7 +6038,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
data_spark_columns=[scol_for(sdf, col) for col in data_columns],
column_label_names=column_label_names,
)
- psdf = DataFrame(internal) # type: "DataFrame"
+ psdf = DataFrame(internal)
else:
column_labels = [tuple(list(values[0]) + [column]) for column in data_columns]
column_label_names = ([cast(Optional[Name], None)] * len(values[0])) + [columns]
@@ -6062,7 +6066,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
index_values = values[-1]
else:
index_values = values
- index_map = OrderedDict() # type: Dict[str, Optional[Label]]
+ index_map: Dict[str, Optional[Label]] = OrderedDict()
for i, index_value in enumerate(index_values):
colname = SPARK_INDEX_NAME_FORMAT(i)
sdf = sdf.withColumn(colname, SF.lit(index_value))
@@ -6257,10 +6261,11 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
)
)
+ column_label_names: Optional[List]
if isinstance(columns, pd.Index):
column_label_names = [
name if is_name_like_tuple(name) else (name,) for name in columns.names
- ] # type: Optional[List]
+ ]
else:
column_label_names = None
@@ -9008,7 +9013,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
"shape (1,{}) doesn't match the shape (1,{})".format(len(col), level)
)
fill_value = np.nan if fill_value is None else fill_value
- scols_or_pssers = [] # type: List[Union[Series, Column]]
+ scols_or_pssers: List[Union[Series, Column]] = []
labels = []
for label in label_columns:
if label in self._internal.column_labels:
@@ -9437,7 +9442,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
).with_filter(SF.lit(False))
)
- column_labels = defaultdict(dict) # type: Union[defaultdict, OrderedDict]
+ column_labels: Union[defaultdict, OrderedDict] = defaultdict(dict)
index_values = set()
should_returns_series = False
for label in self._internal.column_labels:
@@ -9498,7 +9503,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
data_spark_columns=[scol_for(sdf, col) for col in data_columns],
column_label_names=column_label_names,
)
- psdf = DataFrame(internal) # type: "DataFrame"
+ psdf: DataFrame = DataFrame(internal)
if should_returns_series:
return first_series(psdf)
@@ -10181,11 +10186,6 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
) -> Tuple[Callable[[Any], Any], Dtype, DataType]:
if isinstance(mapper, dict):
mapper_dict = cast(dict, mapper)
- if len(mapper_dict) == 0:
- if errors == "raise":
- raise KeyError("Index include label which is not in the `mapper`.")
- else:
- return DataFrame(self._internal)
type_set = set(map(lambda x: type(x), mapper_dict.values()))
if len(type_set) > 1:
@@ -10439,15 +10439,16 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
v: Union[Any, Sequence[Any], Dict[Name, Any], Callable[[Name], Any]],
curnames: List[Name],
) -> List[Label]:
+ newnames: List[Name]
if is_scalar(v):
- newnames = [cast(Any, v)] # type: List[Name]
+ newnames = [cast(Name, v)]
elif is_list_like(v) and not is_dict_like(v):
- newnames = list(cast(Sequence[Any], v))
+ newnames = list(cast(Sequence[Name], v))
elif is_dict_like(v):
- v_dict = cast(Dict[Name, Any], v)
+ v_dict = cast(Dict[Name, Name], v)
newnames = [v_dict[name] if name in v_dict else name for name in curnames]
elif callable(v):
- v_callable = cast(Callable[[Name], Any], v)
+ v_callable = cast(Callable[[Name], Name], v)
newnames = [v_callable(name) for name in curnames]
else:
raise ValueError(
@@ -10647,7 +10648,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
)
cond = reduce(lambda x, y: x | y, conds)
- psdf = DataFrame(self._internal.with_filter(cond)) # type: "DataFrame"
+ psdf: DataFrame = DataFrame(self._internal.with_filter(cond))
return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmax()))
@@ -10719,7 +10720,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
)
cond = reduce(lambda x, y: x | y, conds)
- psdf = DataFrame(self._internal.with_filter(cond)) # type: "DataFrame"
+ psdf: DataFrame = DataFrame(self._internal.with_filter(cond))
return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmin()))
@@ -10912,7 +10913,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
"accuracy must be an integer; however, got [%s]" % type(accuracy).__name__
)
- qq = list(q) if isinstance(q, Iterable) else q # type: Union[float, List[float]]
+ qq: Union[float, List[float]] = list(q) if isinstance(q, Iterable) else q
for v in qq if isinstance(qq, list) else [qq]:
if not isinstance(v, float):
@@ -10944,9 +10945,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
# |[[0.25, 2, 6], [0.5, 3, 7], [0.75, 4, 8]]|
# +-----------------------------------------+
- percentile_cols = []
- percentile_col_names = []
- column_labels = []
+ percentile_cols: List[Column] = []
+ percentile_col_names: List[str] = []
+ column_labels: List[Label] = []
for label, column in zip(
self._internal.column_labels, self._internal.data_spark_column_names
):
@@ -10974,7 +10975,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
# |[2, 3, 4]|[6, 7, 8]|
# +---------+---------+
- cols_dict = OrderedDict() # type: OrderedDict
+ cols_dict: Dict[str, List[Column]] = OrderedDict()
for column in percentile_col_names:
cols_dict[column] = list()
for i in range(len(qq)):
@@ -11357,7 +11358,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
if not is_name_like_value(column):
raise TypeError("column must be a scalar")
- psdf = DataFrame(self._internal.resolved_copy) # type: "DataFrame"
+ psdf: DataFrame = DataFrame(self._internal.resolved_copy)
psser = psdf[column]
if not isinstance(psser, Series):
raise ValueError(
@@ -11422,7 +11423,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
return scol
- new_column_labels = [] # type: List[Label]
+ new_column_labels: List[Label] = []
for label in self._internal.column_labels:
# Filtering out only columns of numeric and boolean type column.
dtype = self._psser_for(label).spark.data_type
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 9c5c03d..75fa88d 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -2254,10 +2254,11 @@ class Frame(object, metaclass=ABCMeta):
2.0 2 5
NaN 1 4
"""
+ new_by: List[Union[Label, ps.Series]]
if isinstance(by, ps.DataFrame):
raise ValueError("Grouper for '{}' not 1-dimensional".format(type(by).__name__))
elif isinstance(by, ps.Series):
- new_by = [by] # type: List[Union[Label, ps.Series]]
+ new_by = [by]
elif is_name_like_tuple(by):
if isinstance(self, ps.Series):
raise KeyError(by)
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 57c2281..471f90d 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -293,9 +293,9 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
agg_cols = [col.name for col in self._agg_columns]
func_or_funcs = OrderedDict([(col, func_or_funcs) for col in agg_cols])
- psdf = DataFrame(
+ psdf: DataFrame = DataFrame(
GroupBy._spark_groupby(self._psdf, func_or_funcs, self._groupkeys)
- ) # type: DataFrame
+ )
if self._dropna:
psdf = DataFrame(
@@ -1458,10 +1458,10 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
def _prepare_group_map_apply(
psdf: DataFrame, groupkeys: List[Series], agg_columns: List[Series]
) -> Tuple[DataFrame, List[Label], List[str]]:
- groupkey_labels = [
+ groupkey_labels: List[Label] = [
verify_temp_column_name(psdf, "__groupkey_{}__".format(i))
for i in range(len(groupkeys))
- ] # type: List[Label]
+ ]
psdf = psdf[[s.rename(label) for s, label in zip(groupkeys, groupkey_labels)] + agg_columns]
groupkey_names = [label if len(label) > 1 else label[0] for label in groupkey_labels]
return DataFrame(psdf._internal.resolved_copy), groupkey_labels, groupkey_names
@@ -2270,7 +2270,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
limit = get_option("compute.shortcut_limit")
pdf = psdf.head(limit + 1)._to_internal_pandas()
pdf = pdf.groupby(groupkey_names).transform(func, *args, **kwargs)
- psdf_from_pandas = DataFrame(pdf) # type: DataFrame
+ psdf_from_pandas: DataFrame = DataFrame(pdf)
return_schema = force_decimal_precision_scale(
as_nullable_spark_type(
psdf_from_pandas._internal.spark_frame.drop(*HIDDEN_COLUMNS).schema
@@ -2614,7 +2614,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
data_fields=[psser._internal.data_fields[0] for psser in agg_columns],
column_label_names=self._psdf._internal.column_label_names,
)
- psdf = DataFrame(internal) # type: DataFrame
+ psdf: DataFrame = DataFrame(internal)
if len(psdf._internal.column_labels) > 0:
stat_exprs = []
@@ -3351,8 +3351,8 @@ def normalize_keyword_aggregation(
kwargs = OrderedDict(sorted(kwargs.items()))
# TODO(Py35): When we drop python 3.5, change this to defaultdict(list)
- aggspec = OrderedDict() # type: Dict[Union[Any, Tuple], List[str]]
- order = [] # type: List[Tuple]
+ aggspec: Dict[Union[Any, Tuple], List[str]] = OrderedDict()
+ order: List[Tuple] = []
columns, pairs = zip(*kwargs.items())
for column, aggfunc in pairs:
diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py
index 369934a..4564290 100644
--- a/python/pyspark/pandas/indexes/base.py
+++ b/python/pyspark/pandas/indexes/base.py
@@ -909,9 +909,7 @@ class Index(IndexOpsMixin):
field = field.copy(name=name_like_string(name))
elif self._internal.index_level == 1:
name = self.name
- column_labels = [
- name if is_name_like_tuple(name) else (name,)
- ] # type: List[Optional[Label]]
+ column_labels: List[Optional[Label]] = [name if is_name_like_tuple(name) else (name,)]
internal = self._internal.copy(
column_labels=column_labels,
data_spark_columns=[scol],
@@ -2181,7 +2179,7 @@ class Index(IndexOpsMixin):
elif repeats < 0:
raise ValueError("negative dimensions are not allowed")
- psdf = DataFrame(self._internal.resolved_copy) # type: DataFrame
+ psdf: DataFrame = DataFrame(self._internal.resolved_copy)
if repeats == 0:
return DataFrame(psdf._internal.with_filter(SF.lit(False))).index
else:
@@ -2315,9 +2313,10 @@ class Index(IndexOpsMixin):
sort = True if sort is None else sort
sort = validate_bool_kwarg(sort, "sort")
+ other_idx: Index
if isinstance(self, MultiIndex):
if isinstance(other, MultiIndex):
- other_idx = other # type: Index
+ other_idx = other
elif isinstance(other, list) and all(isinstance(item, tuple) for item in other):
other_idx = MultiIndex.from_tuples(other)
else:
@@ -2406,27 +2405,30 @@ class Index(IndexOpsMixin):
"""
from pyspark.pandas.indexes.multi import MultiIndex
+ other_idx: Index
if isinstance(other, DataFrame):
raise ValueError("Index data must be 1-dimensional")
elif isinstance(other, MultiIndex):
# Always returns a no-named empty Index if `other` is MultiIndex.
return self._psdf.head(0).index.rename(None)
elif isinstance(other, Index):
- spark_frame_other = other.to_frame().to_spark()
- keep_name = self.name == other.name
+ other_idx = other
+ spark_frame_other = other_idx.to_frame().to_spark()
+ keep_name = self.name == other_idx.name
elif isinstance(other, Series):
- spark_frame_other = other.to_frame().to_spark()
+ other_idx = Index(other)
+ spark_frame_other = other_idx.to_frame().to_spark()
keep_name = True
elif is_list_like(other):
- other = Index(other)
- if isinstance(other, MultiIndex):
- return other.to_frame().head(0).index
- spark_frame_other = other.to_frame().to_spark()
+ other_idx = Index(other)
+ if isinstance(other_idx, MultiIndex):
+ return other_idx.to_frame().head(0).index
+ spark_frame_other = other_idx.to_frame().to_spark()
keep_name = True
else:
raise TypeError("Input must be Index or array-like")
- index_fields = self._index_fields_for_union_like(other, func_name="intersection")
+ index_fields = self._index_fields_for_union_like(other_idx, func_name="intersection")
spark_frame_self = self.to_frame(name=SPARK_DEFAULT_INDEX_NAME).to_spark()
spark_frame_intersected = spark_frame_self.intersect(spark_frame_other)
diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py
index 896ea2a..a3875d54e 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -1069,9 +1069,7 @@ class MultiIndex(Index):
"index {} is out of bounds for axis 0 with size {}".format(loc, length)
)
- index_name = [
- (name,) for name in self._internal.index_spark_column_names
- ] # type: List[Label]
+ index_name: List[Label] = [(name,) for name in self._internal.index_spark_column_names]
sdf_before = self.to_frame(name=index_name)[:loc].to_spark()
sdf_middle = Index([item]).to_frame(name=index_name).to_spark()
sdf_after = self.to_frame(name=index_name)[loc:].to_spark()
@@ -1150,7 +1148,7 @@ class MultiIndex(Index):
index_fields = self._index_fields_for_union_like(other, func_name="intersection")
- default_name = [SPARK_INDEX_NAME_FORMAT(i) for i in range(self.nlevels)] # type: List
+ default_name: List[Name] = [SPARK_INDEX_NAME_FORMAT(i) for i in range(self.nlevels)]
spark_frame_self = self.to_frame(name=default_name).to_spark()
spark_frame_intersected = spark_frame_self.intersect(spark_frame_other)
if keep_name:
@@ -1160,7 +1158,9 @@ class MultiIndex(Index):
internal = InternalFrame(
spark_frame=spark_frame_intersected,
- index_spark_columns=[scol_for(spark_frame_intersected, col) for col in default_name],
+ index_spark_columns=[
+ scol_for(spark_frame_intersected, cast(str, col)) for col in default_name
+ ],
index_names=index_names,
index_fields=index_fields,
)
diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py
index 8a3c335..cf6d535 100644
--- a/python/pyspark/pandas/indexing.py
+++ b/python/pyspark/pandas/indexing.py
@@ -551,6 +551,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
)
psdf = DataFrame(internal)
+ psdf_or_psser: Union[DataFrame, Series]
if returns_series:
psdf_or_psser = first_series(psdf)
if series_name is not None and series_name != psdf_or_psser.name:
@@ -709,7 +710,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
return
cond, limit, remaining_index = self._select_rows(rows_sel)
- missing_keys = [] # type: Optional[List[Name]]
+ missing_keys: List[Name] = []
_, data_spark_columns, _, _, _ = self._select_cols(cols_sel, missing_keys=missing_keys)
if cond is None:
@@ -1033,7 +1034,7 @@ class LocIndexer(LocIndexerLike):
stop = [row[1] for row in start_and_stop if row[0] == stop]
stop = stop[-1] if len(stop) > 0 else None
- conds = [] # type: List[Column]
+ conds: List[Column] = []
if start is not None:
conds.append(F.col(NATURAL_ORDER_COLUMN_NAME) >= SF.lit(start).cast(LongType()))
if stop is not None:
@@ -1200,6 +1201,7 @@ class LocIndexer(LocIndexerLike):
return self._get_from_multiindex_column((str(key),), missing_keys, labels, recursed + 1)
else:
returns_series = all(lbl is None or len(lbl) == 0 for _, lbl in labels)
+ series_name: Optional[Name]
if returns_series:
label_set = set(label for label, _ in labels)
assert len(label_set) == 1
@@ -1208,7 +1210,7 @@ class LocIndexer(LocIndexerLike):
data_spark_columns = [self._internal.spark_column_for(label)]
data_fields = [self._internal.field_for(label)]
if label is None:
- series_name = None # type: Name
+ series_name = None
else:
if recursed > 0:
label = label[:-recursed]
@@ -1246,9 +1248,7 @@ class LocIndexer(LocIndexerLike):
bool,
Optional[Name],
]:
- column_labels = [
- (self._internal.spark_frame.select(cols_sel).columns[0],)
- ] # type: List[Label]
+ column_labels: List[Label] = [(self._internal.spark_frame.select(cols_sel).columns[0],)]
data_spark_columns = [cols_sel]
return column_labels, data_spark_columns, None, True, None
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index fbcb60b..2e808ce 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -664,14 +664,14 @@ class InternalFrame(object):
NATURAL_ORDER_COLUMN_NAME, F.monotonically_increasing_id()
)
- self._sdf = spark_frame # type: SparkDataFrame
+ self._sdf: SparkDataFrame = spark_frame
# index_spark_columns
assert all(
isinstance(index_scol, Column) for index_scol in index_spark_columns
), index_spark_columns
- self._index_spark_columns = index_spark_columns # type: List[Column]
+ self._index_spark_columns: List[Column] = index_spark_columns
# data_spark_columns
if data_spark_columns is None:
@@ -684,10 +684,10 @@ class InternalFrame(object):
)
and col not in HIDDEN_COLUMNS
]
- self._data_spark_columns = data_spark_columns # type: List[Column]
else:
assert all(isinstance(scol, Column) for scol in data_spark_columns)
- self._data_spark_columns = data_spark_columns
+
+ self._data_spark_columns: List[Column] = data_spark_columns
# fields
if index_fields is None:
@@ -755,7 +755,7 @@ class InternalFrame(object):
for index_field, struct_field in zip(index_fields, struct_fields)
), (index_fields, struct_fields)
- self._index_fields = index_fields # type: List[InternalField]
+ self._index_fields: List[InternalField] = index_fields
assert all(
isinstance(ops.dtype, Dtype.__args__) # type: ignore[attr-defined]
@@ -773,7 +773,7 @@ class InternalFrame(object):
for data_field, struct_field in zip(data_fields, struct_fields)
), (data_fields, struct_fields)
- self._data_fields = data_fields # type: List[InternalField]
+ self._data_fields: List[InternalField] = data_fields
# index_names
if not index_names:
@@ -787,13 +787,11 @@ class InternalFrame(object):
is_name_like_tuple(index_name, check_type=True) for index_name in index_names
), index_names
- self._index_names = index_names # type: List[Optional[Label]]
+ self._index_names: List[Optional[Label]] = index_names
# column_labels
if column_labels is None:
- self._column_labels = [
- (col,) for col in spark_frame.select(self._data_spark_columns).columns
- ] # type: List[Label]
+ column_labels = [(col,) for col in spark_frame.select(self._data_spark_columns).columns]
else:
assert len(column_labels) == len(self._data_spark_columns), (
len(column_labels),
@@ -808,13 +806,12 @@ class InternalFrame(object):
for column_label in column_labels
), column_labels
assert len(set(len(label) for label in column_labels)) <= 1, column_labels
- self._column_labels = column_labels
+
+ self._column_labels: List[Label] = column_labels
# column_label_names
if column_label_names is None:
- self._column_label_names = [None] * column_labels_level(
- self._column_labels
- ) # type: List[Optional[Label]]
+ column_label_names = [None] * column_labels_level(self._column_labels)
else:
if len(self._column_labels) > 0:
assert len(column_label_names) == column_labels_level(self._column_labels), (
@@ -827,7 +824,8 @@ class InternalFrame(object):
is_name_like_tuple(column_label_name, check_type=True)
for column_label_name in column_label_names
), column_label_names
- self._column_label_names = column_label_names
+
+ self._column_label_names: List[Optional[Label]] = column_label_names
@staticmethod
def attach_default_index(
@@ -1439,18 +1437,20 @@ class InternalFrame(object):
:return: the created immutable DataFrame
"""
- index_names = [
+ index_names: List[Optional[Label]] = [
name if name is None or isinstance(name, tuple) else (name,) for name in pdf.index.names
- ] # type: List[Optional[Label]]
+ ]
columns = pdf.columns
+ column_labels: List[Label]
if isinstance(columns, pd.MultiIndex):
- column_labels = columns.tolist() # type: List[Label]
+ column_labels = columns.tolist()
else:
column_labels = [(col,) for col in columns]
- column_label_names = [
+
+ column_label_names: List[Optional[Label]] = [
name if name is None or isinstance(name, tuple) else (name,) for name in columns.names
- ] # type: List[Optional[Label]]
+ ]
prefer_timestamp_ntz = is_timestamp_ntz_preferred()
diff --git a/python/pyspark/pandas/mlflow.py b/python/pyspark/pandas/mlflow.py
index 719db40..590d589 100644
--- a/python/pyspark/pandas/mlflow.py
+++ b/python/pyspark/pandas/mlflow.py
@@ -98,9 +98,9 @@ class PythonModelWrapper(object):
# However, this is only possible with spark >= 3.0
# s = F.struct(*data.columns)
# return_col = self._model_udf(s)
- column_labels = [
+ column_labels: List[Label] = [
(col,) for col in data._internal.spark_frame.select(return_col).columns
- ] # type: List[Label]
+ ]
internal = data._internal.copy(
column_labels=column_labels, data_spark_columns=[return_col], data_fields=None
)
diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py
index 2d62dea..645375f 100644
--- a/python/pyspark/pandas/namespace.py
+++ b/python/pyspark/pandas/namespace.py
@@ -24,6 +24,7 @@ from typing import ( # noqa: F401 (SPARK-34943)
Dict,
List,
Optional,
+ Set,
Sized,
Tuple,
Type,
@@ -327,9 +328,10 @@ def read_csv(
reader.options(**options)
+ column_labels: Dict[Any, str]
if isinstance(names, str):
sdf = reader.schema(names).csv(path)
- column_labels = OrderedDict((col, col) for col in sdf.columns) # type: Dict[Any, str]
+ column_labels = OrderedDict((col, col) for col in sdf.columns)
else:
sdf = reader.csv(path)
if is_list_like(names):
@@ -349,11 +351,12 @@ def read_csv(
column_labels = OrderedDict((col, col) for col in sdf.columns)
if usecols is not None:
+ missing: List[Union[int, str]]
if callable(usecols):
column_labels = OrderedDict(
(label, col) for label, col in column_labels.items() if usecols(label)
)
- missing = [] # type: List[Union[int, str]]
+ missing = []
elif all(isinstance(col, int) for col in usecols):
usecols_ints = cast(List[int], usecols)
new_column_labels = OrderedDict(
@@ -397,6 +400,8 @@ def read_csv(
if nrows is not None:
sdf = sdf.limit(nrows)
+ index_spark_column_names: List[str]
+ index_names: List[Label]
if index_col is not None:
if isinstance(index_col, (str, int)):
index_col = [index_col]
@@ -404,7 +409,7 @@ def read_csv(
if col not in column_labels:
raise KeyError(col)
index_spark_column_names = [column_labels[col] for col in index_col]
- index_names = [(col,) for col in index_col] # type: List[Label]
+ index_names = [(col,) for col in index_col]
column_labels = OrderedDict(
(label, col) for label, col in column_labels.items() if label not in index_col
)
@@ -412,7 +417,7 @@ def read_csv(
index_spark_column_names = []
index_names = []
- psdf = DataFrame(
+ psdf: DataFrame = DataFrame(
InternalFrame(
spark_frame=sdf,
index_spark_columns=[scol_for(sdf, col) for col in index_spark_column_names],
@@ -422,7 +427,7 @@ def read_csv(
],
data_spark_columns=[scol_for(sdf, col) for col in column_labels.values()],
)
- ) # type: DataFrame
+ )
if dtype is not None:
if isinstance(dtype, dict):
@@ -1378,11 +1383,11 @@ def read_sql_table(
reader.options(**options)
sdf = reader.format("jdbc").load()
index_spark_columns, index_names = _get_index_map(sdf, index_col)
- psdf = DataFrame(
+ psdf: DataFrame = DataFrame(
InternalFrame(
spark_frame=sdf, index_spark_columns=index_spark_columns, index_names=index_names
)
- ) # type: DataFrame
+ )
if columns is not None:
if isinstance(columns, str):
columns = [columns]
@@ -2246,10 +2251,13 @@ def concat(
raise ValueError("Only can inner (intersect) or outer (union) join the other axis.")
axis = validate_axis(axis)
+ psdf: DataFrame
if axis == 1:
- psdfs = [obj.to_frame() if isinstance(obj, Series) else obj for obj in objs]
+ psdfs: List[DataFrame] = [
+ obj.to_frame() if isinstance(obj, Series) else obj for obj in objs
+ ]
- level = min(psdf._internal.column_labels_level for psdf in psdfs)
+ level: int = min(psdf._internal.column_labels_level for psdf in psdfs)
psdfs = [
DataFrame._index_normalized_frame(level, psdf)
if psdf._internal.column_labels_level > level
@@ -2258,7 +2266,7 @@ def concat(
]
concat_psdf = psdfs[0]
- column_labels = concat_psdf._internal.column_labels.copy()
+ column_labels: List[Label] = concat_psdf._internal.column_labels.copy()
psdfs_not_same_anchor = []
for psdf in psdfs[1:]:
@@ -2323,17 +2331,19 @@ def concat(
# DataFrame, Series ... & Series, Series ...
# In this case, we should return DataFrame.
- new_objs = []
+ new_objs: List[DataFrame] = []
num_series = 0
series_names = set()
for obj in objs:
if isinstance(obj, Series):
num_series += 1
series_names.add(obj.name)
- obj = obj.to_frame(DEFAULT_SERIES_NAME)
- new_objs.append(obj)
+ new_objs.append(obj.to_frame(DEFAULT_SERIES_NAME))
+ else:
+ assert isinstance(obj, DataFrame)
+ new_objs.append(obj)
- column_labels_levels = set(obj._internal.column_labels_level for obj in new_objs)
+ column_labels_levels: Set[int] = set(obj._internal.column_labels_level for obj in new_objs)
if len(column_labels_levels) != 1:
raise ValueError("MultiIndex columns should have the same levels")
@@ -2354,8 +2364,9 @@ def concat(
)
column_labels_of_psdfs = [psdf._internal.column_labels for psdf in new_objs]
+ index_names_of_psdfs: List[List[Optional[Label]]]
if ignore_index:
- index_names_of_psdfs = [[] for _ in new_objs] # type: List
+ index_names_of_psdfs = [[] for _ in new_objs]
else:
index_names_of_psdfs = [psdf._internal.index_names for psdf in new_objs]
@@ -2446,7 +2457,7 @@ def concat(
index_names = psdfs[0]._internal.index_names
index_fields = psdfs[0]._internal.index_fields
- result_psdf = DataFrame(
+ result_psdf: DataFrame = DataFrame(
psdfs[0]._internal.copy(
spark_frame=concatenated,
index_spark_columns=[scol_for(concatenated, col) for col in index_spark_column_names],
@@ -2457,7 +2468,7 @@ def concat(
],
data_fields=None, # TODO: dtypes?
)
- ) # type: DataFrame
+ )
if should_return_series:
# If all input were Series, we should return Series.
@@ -3053,7 +3064,7 @@ def merge_asof(
if os is None:
return []
elif is_name_like_tuple(os):
- return [os] # type: ignore
+ return [cast(Label, os)]
elif is_name_like_value(os):
return [(os,)]
else:
@@ -3470,6 +3481,8 @@ def read_orc(
def _get_index_map(
sdf: SparkDataFrame, index_col: Optional[Union[str, List[str]]] = None
) -> Tuple[Optional[List[Column]], Optional[List[Label]]]:
+ index_spark_columns: Optional[List[Column]]
+ index_names: Optional[List[Label]]
if index_col is not None:
if isinstance(index_col, str):
index_col = [index_col]
@@ -3477,10 +3490,8 @@ def _get_index_map(
for col in index_col:
if col not in sdf_columns:
raise KeyError(col)
- index_spark_columns = [
- scol_for(sdf, col) for col in index_col
- ] # type: Optional[List[Column]]
- index_names = [(col,) for col in index_col] # type: Optional[List[Label]]
+ index_spark_columns = [scol_for(sdf, col) for col in index_col]
+ index_names = [(col,) for col in index_col]
else:
index_spark_columns = None
index_names = None
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 07777ca..9e20525 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -387,18 +387,21 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
Copy input data
"""
- @no_type_check
- def __init__(self, data=None, index=None, dtype=None, name=None, copy=False, fastpath=False):
+ def __init__( # type: ignore[no-untyped-def]
+ self, data=None, index=None, dtype=None, name=None, copy=False, fastpath=False
+ ):
assert data is not None
+ self._anchor: DataFrame
+ self._col_label: Label
if isinstance(data, DataFrame):
assert dtype is None
assert name is None
assert not copy
assert not fastpath
- self._anchor = data # type: DataFrame
- self._col_label = index # type: Label
+ self._anchor = data
+ self._col_label = index
else:
if isinstance(data, pd.Series):
assert index is None
@@ -1145,7 +1148,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
data_fields=[field],
column_label_names=None,
)
- psdf = DataFrame(internal) # type: DataFrame
+ psdf: DataFrame = DataFrame(internal)
if kwargs.get("inplace", False):
self._col_label = index
@@ -5019,7 +5022,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
if not self.index.sort_values().equals(other.index.sort_values()):
raise ValueError("matrices are not aligned")
- other_copy = other.copy() # type: DataFrame
+ other_copy: DataFrame = other.copy()
column_labels = other_copy._internal.column_labels
self_column_label = verify_temp_column_name(other_copy, "__self_column__")
@@ -5215,7 +5218,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
# The data is expected to be small so it's fine to transpose/use default index.
with ps.option_context("compute.default_index_type", "distributed", "compute.max_rows", 1):
- psdf = ps.DataFrame(sdf) # type: DataFrame
+ psdf: DataFrame = DataFrame(sdf)
psdf.columns = pd.Index(where)
return first_series(psdf.transpose()).rename(self.name)
@@ -5808,6 +5811,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
>>> reset_option("compute.ops_on_diff_frames")
"""
+ combined: DataFrame
if same_anchor(self, other):
self_column_label = verify_temp_column_name(other.to_frame(), "__self_column__")
other_column_label = verify_temp_column_name(self.to_frame(), "__other_column__")
@@ -5815,7 +5819,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
self._internal.with_new_columns(
[self.rename(self_column_label), other.rename(other_column_label)]
)
- ) # type: DataFrame
+ )
else:
if not self.index.equals(other.index):
raise ValueError("Can only compare identically-labeled Series objects")
diff --git a/python/pyspark/pandas/sql_processor.py b/python/pyspark/pandas/sql_processor.py
index 98b2fd3..9b51ef1 100644
--- a/python/pyspark/pandas/sql_processor.py
+++ b/python/pyspark/pandas/sql_processor.py
@@ -257,14 +257,14 @@ class SQLProcessor(object):
# All the temporary views created when executing this statement
# The key is the name of the variable in {}
# The value is the cached Spark Dataframe.
- self._temp_views = {} # type: Dict[str, SDataFrame]
+ self._temp_views: Dict[str, SDataFrame] = {}
# All the other variables, converted to a normalized form.
# The normalized form is typically a string
- self._cached_vars = {} # type: Dict[str, Any]
+ self._cached_vars: Dict[str, Any] = {}
# The SQL statement after:
# - all the dataframes have been registered as temporary views
# - all the values have been converted normalized to equivalent SQL representations
- self._normalized_statement = None # type: Optional[str]
+ self._normalized_statement: Optional[str] = None
self._session = session
def execute(self, index_col: Optional[Union[str, List[str]]]) -> DataFrame:
diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py
index 288273a..9f61995 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -41,11 +41,12 @@ import pandas as pd
from pandas.api.types import CategoricalDtype, pandas_dtype
from pandas.api.extensions import ExtensionDtype
+extension_dtypes: Tuple[type, ...]
try:
from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
extension_dtypes_available = True
- extension_dtypes = (Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype) # type: Tuple
+ extension_dtypes = (Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype)
try:
from pandas import BooleanDtype, StringDtype
@@ -665,10 +666,9 @@ def create_type_for_series_type(param: Any) -> Type[SeriesType]:
"""
from pyspark.pandas.typedef import NameTypeHolder
+ new_class: Type[NameTypeHolder]
if isinstance(param, ExtensionDtype):
- new_class = type(
- NameTypeHolder.short_name, (NameTypeHolder,), {}
- ) # type: Type[NameTypeHolder]
+ new_class = type(NameTypeHolder.short_name, (NameTypeHolder,), {})
new_class.tpe = param
else:
new_class = param.type if isinstance(param, np.dtype) else param
@@ -815,9 +815,9 @@ def _new_type_holders(
# DataFrame["id": int, "A": int]
new_params = []
for param in params:
- new_param = type(
+ new_param: Type[Union[NameTypeHolder, IndexNameTypeHolder]] = type(
holder_clazz.short_name, (holder_clazz,), {}
- ) # type: Type[Union[NameTypeHolder, IndexNameTypeHolder]]
+ )
new_param.name = param.start
if isinstance(param.stop, ExtensionDtype):
new_param.tpe = param.stop
@@ -830,9 +830,9 @@ def _new_type_holders(
# DataFrame[float, float]
new_types = []
for param in params:
- new_type = type(
+ new_type: Type[Union[NameTypeHolder, IndexNameTypeHolder]] = type(
holder_clazz.short_name, (holder_clazz,), {}
- ) # type: Type[Union[NameTypeHolder, IndexNameTypeHolder]]
+ )
if isinstance(param, ExtensionDtype):
new_type.tpe = param
else:
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index 87c28b8..bca327b 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -385,11 +385,11 @@ def align_diff_frames(
# 2. Apply the given function to transform the columns in a batch and keep the new columns.
combined_column_labels = combined._internal.column_labels
- that_columns_to_apply = [] # type: List[Label]
- this_columns_to_apply = [] # type: List[Label]
- additional_that_columns = [] # type: List[Label]
- columns_to_keep = [] # type: List[Union[Series, Column]]
- column_labels_to_keep = [] # type: List[Label]
+ that_columns_to_apply: List[Label] = []
+ this_columns_to_apply: List[Label] = []
+ additional_that_columns: List[Label] = []
+ columns_to_keep: List[Union[Series, Column]] = []
+ column_labels_to_keep: List[Label] = []
for combined_label in combined_column_labels:
for common_label in common_column_labels:
@@ -419,22 +419,24 @@ def align_diff_frames(
# Should extract columns to apply and do it in a batch in case
# it adds new columns for example.
+ columns_applied: List[Union[Series, Column]]
+ column_labels_applied: List[Label]
if len(this_columns_to_apply) > 0 or len(that_columns_to_apply) > 0:
psser_set, column_labels_set = zip(
*resolve_func(combined, this_columns_to_apply, that_columns_to_apply)
)
- columns_applied = list(psser_set) # type: List[Union[Series, Column]]
- column_labels_applied = list(column_labels_set) # type: List[Label]
+ columns_applied = list(psser_set)
+ column_labels_applied = list(column_labels_set)
else:
columns_applied = []
column_labels_applied = []
- applied = DataFrame(
+ applied: DataFrame = DataFrame(
combined._internal.with_new_columns(
columns_applied + columns_to_keep,
column_labels=column_labels_applied + column_labels_to_keep,
)
- ) # type: DataFrame
+ )
# 3. Restore the names back and deduplicate columns.
this_labels = OrderedDict()
@@ -620,8 +622,9 @@ def name_like_string(name: Optional[Name]) -> str:
>>> name_like_string(name)
'(a, b, c)'
"""
+ label: Label
if name is None:
- label = ("__none__",) # type: Label
+ label = ("__none__",)
elif is_list_like(name):
label = tuple([str(n) for n in name])
else:
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index 0cc6c67..675fad0d 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -22,6 +22,7 @@ from typing import ( # noqa: F401 (SPARK-34943)
Generic,
List,
Optional,
+ cast,
)
from pyspark.sql import Window
@@ -178,9 +179,12 @@ class Rolling(RollingLike[FrameLike]):
raise AttributeError(item)
def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> FrameLike:
- return self._psdf_or_psser._apply_series_op(
- lambda psser: psser._with_new_scol(func(psser.spark.column)), # TODO: dtype?
- should_resolve=True,
+ return cast(
+ FrameLike,
+ self._psdf_or_psser._apply_series_op(
+ lambda psser: psser._with_new_scol(func(psser.spark.column)), # TODO: dtype?
+ should_resolve=True,
+ ),
)
def count(self) -> FrameLike:
@@ -681,7 +685,7 @@ class RollingGroupby(RollingLike[FrameLike]):
# Here we need to include grouped key as an index, and shift previous index.
# [index_column0, index_column1] -> [grouped key, index_column0, index_column1]
- new_index_scols = [] # type: List[Column]
+ new_index_scols: List[Column] = []
new_index_spark_column_names = []
new_index_names = []
new_index_fields = []
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 354d3a9..a9700df 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -169,7 +169,7 @@ class PandasConversionMixin(object):
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
column_counter = Counter(self.columns)
- dtype = [None] * len(self.schema) # type: List[Optional[Type]]
+ dtype: List[Optional[Type]] = [None] * len(self.schema)
for fieldIdx, field in enumerate(self.schema):
# For duplicate column name, we use `iloc` to access it.
if column_counter[field.name] > 1:
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 44253bf..c52cbc7 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -102,8 +102,9 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
from distutils.version import LooseVersion
import pyarrow as pa
import pyarrow.types as types
+ spark_type: DataType
if types.is_boolean(at):
- spark_type = BooleanType() # type: DataType
+ spark_type = BooleanType()
elif types.is_int8(at):
spark_type = ByteType()
elif types.is_int16(at):