You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/01 01:24:54 UTC
[spark] branch master updated: [SPARK-43516][ML][FOLLOW-UP] Make `pyspark.mlv2` module supports python < 3.9
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 33a76811b23 [SPARK-43516][ML][FOLLOW-UP] Make `pyspark.mlv2` module supports python < 3.9
33a76811b23 is described below
commit 33a76811b23f2249cf9343fdc4ef654d12bd23b5
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Thu Jun 1 10:24:39 2023 +0900
[SPARK-43516][ML][FOLLOW-UP] Make `pyspark.mlv2` module supports python < 3.9
### What changes were proposed in this pull request?
Make the `pyspark.mlv2` module support Python < 3.9.
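We need to change some type hint definitions to make them compatible with Python < 3.9: subscripted builtin generics such as `list[str]`, `tuple[str, str]` and `dict[str, Any]`, as well as subscripted `collections.abc.Callable`, are only supported from Python 3.9 onward, so they are replaced with `typing.List`, `typing.Tuple`, `typing.Dict` and `typing.Callable`.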
### Why are the changes needed?
PySpark master still needs to support Python 3.7 and Python 3.8.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Manually ran the `pyspark.mlv2` tests against Python 3.7 or Python 3.8.
Closes #41405 from WeichenXu123/fix-tpye-hints.
Authored-by: Weichen Xu <we...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/mlv2/base.py | 5 +++--
python/pyspark/mlv2/feature.py | 7 +++----
python/pyspark/mlv2/summarizer.py | 8 ++++----
python/pyspark/mlv2/util.py | 7 +++----
4 files changed, 13 insertions(+), 14 deletions(-)
diff --git a/python/pyspark/mlv2/base.py b/python/pyspark/mlv2/base.py
index 4c0d4652928..dc503db71c0 100644
--- a/python/pyspark/mlv2/base.py
+++ b/python/pyspark/mlv2/base.py
@@ -16,7 +16,6 @@
#
from abc import ABCMeta, abstractmethod
-from collections.abc import Callable
import pandas as pd
@@ -28,6 +27,8 @@ from typing import (
TypeVar,
Union,
TYPE_CHECKING,
+ Tuple,
+ Callable,
)
from pyspark import since
@@ -123,7 +124,7 @@ class Transformer(Params, metaclass=ABCMeta):
"""
raise NotImplementedError()
- def _output_columns(self) -> list[tuple[str, str]]:
+ def _output_columns(self) -> List[Tuple[str, str]]:
"""
Return a list of output transformed columns, each elements in the list
is a tuple of (column_name, column_spark_type)
diff --git a/python/pyspark/mlv2/feature.py b/python/pyspark/mlv2/feature.py
index 6bbcdf7eaac..cecff362823 100644
--- a/python/pyspark/mlv2/feature.py
+++ b/python/pyspark/mlv2/feature.py
@@ -15,10 +15,9 @@
# limitations under the License.
#
-from collections.abc import Callable
import numpy as np
import pandas as pd
-from typing import Any, Union
+from typing import Any, Union, List, Tuple, Callable
from pyspark.sql import DataFrame
from pyspark.mlv2.base import Estimator, Model
@@ -60,7 +59,7 @@ class MaxAbsScalerModel(Model, HasInputCol, HasOutputCol):
def _input_column_name(self) -> str:
return self.getInputCol()
- def _output_columns(self) -> list[tuple[str, str]]:
+ def _output_columns(self) -> List[Tuple[str, str]]:
return [(self.getOutputCol(), "array<double>")]
def _get_transform_fn(self) -> Callable[["pd.Series"], Any]:
@@ -108,7 +107,7 @@ class StandardScalerModel(Model, HasInputCol, HasOutputCol):
def _input_column_name(self) -> str:
return self.getInputCol()
- def _output_columns(self) -> list[tuple[str, str]]:
+ def _output_columns(self) -> List[Tuple[str, str]]:
return [(self.getOutputCol(), "array<double>")]
def _get_transform_fn(self) -> Callable[["pd.Series"], Any]:
diff --git a/python/pyspark/mlv2/summarizer.py b/python/pyspark/mlv2/summarizer.py
index 6bf03c26e19..18776b71eb2 100644
--- a/python/pyspark/mlv2/summarizer.py
+++ b/python/pyspark/mlv2/summarizer.py
@@ -17,7 +17,7 @@
import numpy as np
import pandas as pd
-from typing import Any, Union
+from typing import Any, Union, List, Dict
from pyspark.sql import DataFrame
from pyspark.mlv2.util import aggregate_dataframe
@@ -46,7 +46,7 @@ class SummarizerAggState:
self.max_values = np.maximum(self.max_values, state.max_values)
return self
- def to_result(self, metrics: list[str]) -> dict[str, Any]:
+ def to_result(self, metrics: List[str]) -> Dict[str, Any]:
result = {}
for metric in metrics:
@@ -75,8 +75,8 @@ class SummarizerAggState:
def summarize_dataframe(
- dataframe: Union["DataFrame", "pd.DataFrame"], column: str, metrics: list[str]
-) -> dict[str, Any]:
+ dataframe: Union["DataFrame", "pd.DataFrame"], column: str, metrics: List[str]
+) -> Dict[str, Any]:
"""
Summarize an array type column over a spark dataframe or a pandas dataframe
diff --git a/python/pyspark/mlv2/util.py b/python/pyspark/mlv2/util.py
index de2ffb3d7c1..9aebb3fa9a3 100644
--- a/python/pyspark/mlv2/util.py
+++ b/python/pyspark/mlv2/util.py
@@ -16,8 +16,7 @@
#
import pandas as pd
-from collections.abc import Callable, Iterable
-from typing import Any, Union
+from typing import Any, Union, List, Tuple, Callable, Iterable
from pyspark import cloudpickle
from pyspark.sql import DataFrame
@@ -26,7 +25,7 @@ from pyspark.sql.functions import col, pandas_udf
def aggregate_dataframe(
dataframe: Union["DataFrame", "pd.DataFrame"],
- input_col_names: list[str],
+ input_col_names: List[str],
local_agg_fn: Callable[["pd.DataFrame"], Any],
merge_agg_state: Callable[[Any, Any], Any],
agg_state_to_result: Callable[[Any], Any],
@@ -115,7 +114,7 @@ def transform_dataframe_column(
dataframe: Union["DataFrame", "pd.DataFrame"],
input_col_name: str,
transform_fn: Callable[["pd.Series"], Any],
- output_cols: list[tuple[str, str]],
+ output_cols: List[Tuple[str, str]],
) -> Union["DataFrame", "pd.DataFrame"]:
"""
Transform specified column of the input spark dataframe or pandas dataframe,
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org