Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/24 02:44:26 UTC
[spark] branch master updated: [SPARK-41222][CONNECT][PYTHON] Unify the typing definitions
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 381dd7943e5 [SPARK-41222][CONNECT][PYTHON] Unify the typing definitions
381dd7943e5 is described below
commit 381dd7943e52483b1a10cb6d15c980e375631052
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Thu Nov 24 10:44:08 2022 +0800
[SPARK-41222][CONNECT][PYTHON] Unify the typing definitions
### What changes were proposed in this pull request?
1, remove `typing/__init__.pyi`
2, rename `ColumnOrString` to `ColumnOrName`, to be consistent with PySpark
### Why are the changes needed?
1, there are two typing files now, `_typing.py` and `typing/__init__.pyi`, and they are imported in different places, which is very confusing;
2, the definitions of `LiteralType` differ between the two files, and the old one in `_typing.py` was never used;
3, both `ColumnOrString` and `ColumnOrName` are in use now (see the sketch below);
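
To make the unification concrete, here is a minimal, self-contained sketch of the consolidated aliases and how a caller is expected to use them (the stand-in `Column` class and the `describe` helper are illustrative only, not part of this patch):

    from typing import Union

    class Column:  # stand-in for pyspark.sql.connect.column.Column
        ...

    # the aliases this patch consolidates into python/pyspark/sql/connect/_typing.py
    ColumnOrName = Union[Column, str]
    PrimitiveType = Union[bool, float, int, str]
    LiteralType = PrimitiveType

    def describe(col: ColumnOrName) -> str:
        # a plain string is treated as a column name; anything else
        # must already be a Column expression
        return f"column name {col!r}" if isinstance(col, str) else "Column expression"

    print(describe("id"))      # column name 'id'
    print(describe(Column()))  # Column expression
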
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
existing UTs
Closes #38757 from zhengruifeng/connect_typing.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/sql/connect/_typing.py | 41 +++++++++++++++++++++++---
python/pyspark/sql/connect/client.py | 25 +++++++---------
python/pyspark/sql/connect/column.py | 5 +++-
python/pyspark/sql/connect/dataframe.py | 12 ++++----
python/pyspark/sql/connect/function_builder.py | 10 +++++--
python/pyspark/sql/connect/plan.py | 4 +--
python/pyspark/sql/connect/readwriter.py | 5 +---
python/pyspark/sql/connect/typing/__init__.pyi | 35 ----------------------
8 files changed, 66 insertions(+), 71 deletions(-)
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 4e69b2e4aa5..262d71fcea1 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -14,8 +14,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import Union
-from datetime import date, time, datetime
-PrimitiveType = Union[str, int, bool, float]
-LiteralType = Union[PrimitiveType, Union[date, time, datetime]]
+import sys
+
+if sys.version_info >= (3, 8):
+    from typing import Protocol
+else:
+    from typing_extensions import Protocol
+
+from typing import Union, Optional
+import datetime
+import decimal
+
+from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column
+from pyspark.sql.connect.function_builder import UserDefinedFunction
+
+ExpressionOrString = Union[Expression, str]
+
+ColumnOrName = Union[Column, str]
+
+PrimitiveType = Union[bool, float, int, str]
+
+OptionalPrimitiveType = Optional[PrimitiveType]
+
+LiteralType = PrimitiveType
+
+DecimalLiteral = decimal.Decimal
+
+DateTimeLiteral = Union[datetime.datetime, datetime.date]
+
+
+class FunctionBuilderCallable(Protocol):
+    def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression:
+        ...
+
+
+class UserDefinedFunctionCallable(Protocol):
+    def __call__(self, *_: ColumnOrName) -> UserDefinedFunction:
+        ...
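
Context for the Protocol classes added above: a `Protocol` types these callables structurally, so any function whose signature matches `__call__` satisfies it without inheriting from it. A minimal self-contained sketch of the same pattern, using stand-in classes rather than the real connect types:

    import sys
    from typing import Union

    if sys.version_info >= (3, 8):
        from typing import Protocol
    else:
        from typing_extensions import Protocol

    class Expression:  # stand-in
        ...

    ExpressionOrString = Union[Expression, str]

    class FunctionBuilderCallable(Protocol):
        def __call__(self, *_: ExpressionOrString) -> Expression:
            ...

    def _build(*args: ExpressionOrString) -> Expression:
        return Expression()

    # _build matches the protocol structurally; a type checker accepts this:
    f: FunctionBuilderCallable = _build
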
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index fdcf34b7a47..deb9ef6f3be 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -18,7 +18,6 @@
import logging
import os
-import typing
import urllib.parse
import uuid
@@ -35,9 +34,7 @@ from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.plan import SQL, Range
from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType
-from typing import Optional, Any, Union
-
-NumericType = typing.Union[int, float]
+from typing import Iterable, Optional, Any, Union, List, Tuple, Dict
logging.basicConfig(level=logging.INFO)
@@ -74,7 +71,7 @@ class ChannelBuilder:
        # Python's built-in parser.
        tmp_url = "http" + url[2:]
        self.url = urllib.parse.urlparse(tmp_url)
-        self.params: typing.Dict[str, str] = {}
+        self.params: Dict[str, str] = {}
        if len(self.url.path) > 0 and self.url.path != "/":
            raise AttributeError(
                f"Path component for connection URI must be empty: {self.url.path}"
@@ -102,7 +99,7 @@
                f"Target destination {self.url.netloc} does not match '<host>:<port>' pattern"
            )

-    def metadata(self) -> typing.Iterable[typing.Tuple[str, str]]:
+    def metadata(self) -> Iterable[Tuple[str, str]]:
        """
        Builds the GRPC specific metadata list to be injected into the request. All
        parameters will be converted to metadata except ones that are explicitly used
@@ -198,7 +195,7 @@ class ChannelBuilder:
class MetricValue:
-    def __init__(self, name: str, value: NumericType, type: str):
+    def __init__(self, name: str, value: Union[int, float], type: str):
        self._name = name
        self._type = type
        self._value = value
@@ -211,7 +208,7 @@ class MetricValue:
        return self._name

    @property
-    def value(self) -> NumericType:
+    def value(self) -> Union[int, float]:
        return self._value

    @property
@@ -220,7 +217,7 @@ class MetricValue:
class PlanMetrics:
-    def __init__(self, name: str, id: int, parent: int, metrics: typing.List[MetricValue]):
+    def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]):
        self._name = name
        self._id = id
        self._parent_id = parent
@@ -242,7 +239,7 @@ class PlanMetrics:
        return self._parent_id

    @property
-    def metrics(self) -> typing.List[MetricValue]:
+    def metrics(self) -> List[MetricValue]:
        return self._metrics
@@ -252,7 +249,7 @@ class AnalyzeResult:
        self.explain_string = explain

    @classmethod
-    def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
+    def fromProto(cls, pb: Any) -> "AnalyzeResult":
        return AnalyzeResult(pb.schema, pb.explain_string)
@@ -306,9 +303,7 @@ class RemoteSparkSession(object):
        self._execute_and_fetch(req)
        return name

-    def _build_metrics(
-        self, metrics: "pb2.ExecutePlanResponse.Metrics"
-    ) -> typing.List[PlanMetrics]:
+    def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
        return [
            PlanMetrics(
                x.name,
@@ -450,7 +445,7 @@ class RemoteSparkSession(object):
            return rd.read_pandas()
        return None

-    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]:
+    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> Optional[pandas.DataFrame]:
        import pandas as pd

        m: Optional[pb2.ExecutePlanResponse.Metrics] = None
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index c4ffc54c20b..36f38e0ded2 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -22,7 +22,6 @@ import decimal
import datetime
import pyspark.sql.connect.proto as proto
-from pyspark.sql.connect._typing import PrimitiveType
if TYPE_CHECKING:
    from pyspark.sql.connect.client import RemoteSparkSession
@@ -33,6 +32,8 @@ def _bin_op(
    name: str, doc: str = "binary function", reverse: bool = False
) -> Callable[["Column", Any], "Expression"]:
    def _(self: "Column", other: Any) -> "Expression":
+        from pyspark.sql.connect._typing import PrimitiveType
+
        if isinstance(other, get_args(PrimitiveType)):
            other = LiteralExpression(other)
        if not reverse:
@@ -70,6 +71,8 @@ class Expression(object):
        """Returns a binary expression with the current column as the left
        side and the other expression as the right side.
        """
+        from pyspark.sql.connect._typing import PrimitiveType
+
        if isinstance(other, get_args(PrimitiveType)):
            other = LiteralExpression(other)
        return ScalarFunctionExpression("==", self, other)
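
Context for the two column.py hunks above: the `_typing` import moves inside the function bodies because `_typing.py` now imports `Column` and `Expression` from this module, so a module-level import here would complete a circular import. The `get_args` idiom used in those branches works as follows (self-contained):

    from typing import Union, get_args

    PrimitiveType = Union[bool, float, int, str]

    # get_args unpacks the Union into a tuple of runtime types,
    # which isinstance accepts directly
    assert get_args(PrimitiveType) == (bool, float, int, str)
    assert isinstance(3.14, get_args(PrimitiveType))
    assert not isinstance(object(), get_args(PrimitiveType))
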
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 7b42bdf747b..ff14945db0f 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -44,11 +44,9 @@ from pyspark.sql.types import (
)

if TYPE_CHECKING:
-    from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString, LiteralType
+    from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString, LiteralType
    from pyspark.sql.connect.client import RemoteSparkSession

-ColumnOrName = Union[Column, str]
-

class GroupingFrame(object):
@@ -308,7 +306,7 @@ class DataFrame(object):
            plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
        )

-    def drop(self, *cols: "ColumnOrString") -> "DataFrame":
+    def drop(self, *cols: "ColumnOrName") -> "DataFrame":
        _cols = list(cols)
        if any(not isinstance(c, (str, Column)) for c in _cols):
            raise TypeError(
@@ -342,7 +340,7 @@
        """
        return self.head()

-    def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
+    def groupBy(self, *cols: "ColumnOrName") -> GroupingFrame:
        return GroupingFrame(self, *cols)

    @overload
@@ -414,13 +412,13 @@ class DataFrame(object):
    def offset(self, n: int) -> "DataFrame":
        return DataFrame.withPlan(plan.Offset(child=self._plan, offset=n), session=self._session)

-    def sort(self, *cols: "ColumnOrString") -> "DataFrame":
+    def sort(self, *cols: "ColumnOrName") -> "DataFrame":
        """Sort by a specific column"""
        return DataFrame.withPlan(
            plan.Sort(self._plan, columns=list(cols), is_global=True), session=self._session
        )

-    def sortWithinPartitions(self, *cols: "ColumnOrString") -> "DataFrame":
+    def sortWithinPartitions(self, *cols: "ColumnOrName") -> "DataFrame":
        """Sort within each partition by a specific column"""
        return DataFrame.withPlan(
            plan.Sort(self._plan, columns=list(cols), is_global=False), session=self._session
diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py
index e116e493954..4a2688d6a0d 100644
--- a/python/pyspark/sql/connect/function_builder.py
+++ b/python/pyspark/sql/connect/function_builder.py
@@ -28,9 +28,13 @@ from pyspark.sql.connect.column import (
if TYPE_CHECKING:
-    from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
+    from pyspark.sql.connect._typing import (
+        ColumnOrName,
+        ExpressionOrString,
+        FunctionBuilderCallable,
+        UserDefinedFunctionCallable,
+    )
    from pyspark.sql.connect.client import RemoteSparkSession
-    from pyspark.sql.connect.typing import FunctionBuilderCallable, UserDefinedFunctionCallable
def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression:
@@ -103,7 +107,7 @@ class UserDefinedFunction(Expression):
def _create_udf(
    function: Any, return_type: Union[str, pyspark.sql.types.DataType]
) -> "UserDefinedFunctionCallable":
-    def wrapper(*cols: "ColumnOrString") -> UserDefinedFunction:
+    def wrapper(*cols: "ColumnOrName") -> UserDefinedFunction:
        return UserDefinedFunction(func=function, return_type=return_type, args=cols)

    return wrapper
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index ffb0ce080b3..8aadc3dc4fa 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.column import (
if TYPE_CHECKING:
-    from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
+    from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString
    from pyspark.sql.connect.client import RemoteSparkSession
@@ -58,7 +58,7 @@ class LogicalPlan(object):
        return exp

    def to_attr_or_expression(
-        self, col: "ColumnOrString", session: "RemoteSparkSession"
+        self, col: "ColumnOrName", session: "RemoteSparkSession"
    ) -> proto.Expression:
        """Returns either an instance of an unresolved attribute or the serialized
        expression value of the column."""
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 66e48eeab76..27aa023ae47 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -18,17 +18,14 @@
from typing import Dict, Optional

-from pyspark.sql.connect.column import PrimitiveType
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import Read, DataSource
from pyspark.sql.utils import to_str

-
-OptionalPrimitiveType = Optional[PrimitiveType]
-
from typing import TYPE_CHECKING

if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import OptionalPrimitiveType
    from pyspark.sql.connect.client import RemoteSparkSession
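
Context for the readwriter.py hunk above: `OptionalPrimitiveType` now lives under `if TYPE_CHECKING:`, so it is imported by static type checkers only and the annotation must be written as a string. A minimal sketch of the pattern (the `option` function below is a stand-in, not the real reader API):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated by type checkers only; skipped at runtime, so it adds
        # no import cost and cannot create an import cycle
        from pyspark.sql.connect._typing import OptionalPrimitiveType

    def option(key: str, value: "OptionalPrimitiveType") -> None:
        # quoted annotation: resolved only during type checking
        print(key, "=", "" if value is None else str(value))

    option("header", True)
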
diff --git a/python/pyspark/sql/connect/typing/__init__.pyi b/python/pyspark/sql/connect/typing/__init__.pyi
deleted file mode 100644
index 43cc28701da..00000000000
--- a/python/pyspark/sql/connect/typing/__init__.pyi
+++ /dev/null
@@ -1,35 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from typing_extensions import Protocol
-from typing import Union
-from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column
-from pyspark.sql.connect.function_builder import UserDefinedFunction
-
-ExpressionOrString = Union[str, Expression]
-
-ColumnOrString = Union[str, Column]
-
-PrimitiveType = Union[bool, float, int, str]
-
-LiteralType = PrimitiveType
-
-class FunctionBuilderCallable(Protocol):
-    def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression: ...
-
-class UserDefinedFunctionCallable(Protocol):
-    def __call__(self, *_: ColumnOrString) -> UserDefinedFunction: ...