Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/24 02:44:26 UTC

[spark] branch master updated: [SPARK-41222][CONNECT][PYTHON] Unify the typing definitions

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 381dd7943e5 [SPARK-41222][CONNECT][PYTHON] Unify the typing definitions
381dd7943e5 is described below

commit 381dd7943e52483b1a10cb6d15c980e375631052
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Thu Nov 24 10:44:08 2022 +0800

    [SPARK-41222][CONNECT][PYTHON] Unify the typing definitions
    
    ### What changes were proposed in this pull request?
    1. Remove `typing/__init__.pyi`.
    2. Rename `ColumnOrString` to `ColumnOrName` to match PySpark.
    
    ### Why are the changes needed?
    1. There are currently two typing files, `_typing.py` and `typing/__init__.pyi`, used by different modules, which is confusing.
    2. Their definitions of `LiteralType` differ, and the old one in `_typing.py` was never used.
    3. Both `ColumnOrString` and `ColumnOrName` are currently in use.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing unit tests.
    
    Closes #38757 from zhengruifeng/connect_typing.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/connect/_typing.py          | 41 +++++++++++++++++++++++---
 python/pyspark/sql/connect/client.py           | 25 +++++++---------
 python/pyspark/sql/connect/column.py           |  5 +++-
 python/pyspark/sql/connect/dataframe.py        | 12 ++++----
 python/pyspark/sql/connect/function_builder.py | 10 +++++--
 python/pyspark/sql/connect/plan.py             |  4 +--
 python/pyspark/sql/connect/readwriter.py       |  5 +---
 python/pyspark/sql/connect/typing/__init__.pyi | 35 ----------------------
 8 files changed, 66 insertions(+), 71 deletions(-)
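
For orientation, a minimal sketch (not part of the commit) of how the unified alias is consumed: a parameter annotated with `ColumnOrName` accepts either a `Column` object or a plain column name, matching the PySpark convention. `Column` below is a stand-in for `pyspark.sql.connect.column.Column` so the example stays self-contained.

    # sketch only: Column is a stand-in for the real
    # pyspark.sql.connect.column.Column class.
    from typing import List, Union

    class Column:
        def __init__(self, name: str) -> None:
            self.name = name

    ColumnOrName = Union[Column, str]

    def to_names(*cols: ColumnOrName) -> List[str]:
        # accept Column objects and bare column names interchangeably
        return [c.name if isinstance(c, Column) else c for c in cols]

    print(to_names("id", Column("value")))  # ['id', 'value']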

diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 4e69b2e4aa5..262d71fcea1 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -14,8 +14,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Union
-from datetime import date, time, datetime
 
-PrimitiveType = Union[str, int, bool, float]
-LiteralType = Union[PrimitiveType, Union[date, time, datetime]]
+import sys
+
+if sys.version_info >= (3, 8):
+    from typing import Protocol
+else:
+    from typing_extensions import Protocol
+
+from typing import Union, Optional
+import datetime
+import decimal
+
+from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column
+from pyspark.sql.connect.function_builder import UserDefinedFunction
+
+ExpressionOrString = Union[Expression, str]
+
+ColumnOrName = Union[Column, str]
+
+PrimitiveType = Union[bool, float, int, str]
+
+OptionalPrimitiveType = Optional[PrimitiveType]
+
+LiteralType = PrimitiveType
+
+DecimalLiteral = decimal.Decimal
+
+DateTimeLiteral = Union[datetime.datetime, datetime.date]
+
+
+class FunctionBuilderCallable(Protocol):
+    def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression:
+        ...
+
+
+class UserDefinedFunctionCallable(Protocol):
+    def __call__(self, *_: ColumnOrName) -> UserDefinedFunction:
+        ...
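
The two `Protocol` classes above carry over from the deleted stub because `typing.Callable` cannot describe "any number of positional arguments, each of one specific type"; a `Protocol` with a `__call__` method can. A self-contained sketch of the pattern (the names below are illustrative, not Spark's):

    # why a Protocol instead of Callable: Callable[..., int] would lose
    # the argument types entirely.
    import sys
    from typing import Union

    if sys.version_info >= (3, 8):
        from typing import Protocol
    else:
        from typing_extensions import Protocol

    class VariadicCallable(Protocol):
        def __call__(self, *_: Union[str, int]) -> int:
            ...

    def total_length(*parts: Union[str, int]) -> int:
        return sum(len(str(p)) for p in parts)

    f: VariadicCallable = total_length  # accepted by mypy
    print(f("a", 42, "bc"))  # 5
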
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index fdcf34b7a47..deb9ef6f3be 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -18,7 +18,6 @@
 
 import logging
 import os
-import typing
 import urllib.parse
 import uuid
 
@@ -35,9 +34,7 @@ from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.connect.plan import SQL, Range
 from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType
 
-from typing import Optional, Any, Union
-
-NumericType = typing.Union[int, float]
+from typing import Iterable, Optional, Any, Union, List, Tuple, Dict
 
 logging.basicConfig(level=logging.INFO)
 
@@ -74,7 +71,7 @@ class ChannelBuilder:
         # Python's built-in parser.
         tmp_url = "http" + url[2:]
         self.url = urllib.parse.urlparse(tmp_url)
-        self.params: typing.Dict[str, str] = {}
+        self.params: Dict[str, str] = {}
         if len(self.url.path) > 0 and self.url.path != "/":
             raise AttributeError(
                 f"Path component for connection URI must be empty: {self.url.path}"
@@ -102,7 +99,7 @@ class ChannelBuilder:
                 f"Target destination {self.url.netloc} does not match '<host>:<port>' pattern"
             )
 
-    def metadata(self) -> typing.Iterable[typing.Tuple[str, str]]:
+    def metadata(self) -> Iterable[Tuple[str, str]]:
         """
         Builds the GRPC specific metadata list to be injected into the request. All
         parameters will be converted to metadata except ones that are explicitly used
@@ -198,7 +195,7 @@ class ChannelBuilder:
 
 
 class MetricValue:
-    def __init__(self, name: str, value: NumericType, type: str):
+    def __init__(self, name: str, value: Union[int, float], type: str):
         self._name = name
         self._type = type
         self._value = value
@@ -211,7 +208,7 @@ class MetricValue:
         return self._name
 
     @property
-    def value(self) -> NumericType:
+    def value(self) -> Union[int, float]:
         return self._value
 
     @property
@@ -220,7 +217,7 @@ class MetricValue:
 
 
 class PlanMetrics:
-    def __init__(self, name: str, id: int, parent: int, metrics: typing.List[MetricValue]):
+    def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]):
         self._name = name
         self._id = id
         self._parent_id = parent
@@ -242,7 +239,7 @@ class PlanMetrics:
         return self._parent_id
 
     @property
-    def metrics(self) -> typing.List[MetricValue]:
+    def metrics(self) -> List[MetricValue]:
         return self._metrics
 
 
@@ -252,7 +249,7 @@ class AnalyzeResult:
         self.explain_string = explain
 
     @classmethod
-    def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
+    def fromProto(cls, pb: Any) -> "AnalyzeResult":
         return AnalyzeResult(pb.schema, pb.explain_string)
 
 
@@ -306,9 +303,7 @@ class RemoteSparkSession(object):
         self._execute_and_fetch(req)
         return name
 
-    def _build_metrics(
-        self, metrics: "pb2.ExecutePlanResponse.Metrics"
-    ) -> typing.List[PlanMetrics]:
+    def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
         return [
             PlanMetrics(
                 x.name,
@@ -450,7 +445,7 @@ class RemoteSparkSession(object):
                 return rd.read_pandas()
         return None
 
-    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]:
+    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> Optional[pandas.DataFrame]:
         import pandas as pd
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
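
Besides dropping the `typing.` prefixes, the hunks above inline the old module-level `NumericType` alias as a plain `Union[int, float]`. For `metadata()`, a hedged sketch of what an `Iterable[Tuple[str, str]]` producer looks like (illustrative only; the exclusion set below is hypothetical, not the real `ChannelBuilder` logic):

    # illustrative: turn connection parameters into gRPC metadata pairs,
    # skipping keys the channel itself consumes (hypothetical set).
    from typing import Dict, Iterable, Tuple

    _RESERVED = {"token", "use_ssl"}

    def params_to_metadata(params: Dict[str, str]) -> Iterable[Tuple[str, str]]:
        return [(k, v) for k, v in params.items() if k not in _RESERVED]

    print(list(params_to_metadata({"user_id": "u1", "token": "t"})))
    # [('user_id', 'u1')]
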
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index c4ffc54c20b..36f38e0ded2 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -22,7 +22,6 @@ import decimal
 import datetime
 
 import pyspark.sql.connect.proto as proto
-from pyspark.sql.connect._typing import PrimitiveType
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.client import RemoteSparkSession
@@ -33,6 +32,8 @@ def _bin_op(
     name: str, doc: str = "binary function", reverse: bool = False
 ) -> Callable[["Column", Any], "Expression"]:
     def _(self: "Column", other: Any) -> "Expression":
+        from pyspark.sql.connect._typing import PrimitiveType
+
         if isinstance(other, get_args(PrimitiveType)):
             other = LiteralExpression(other)
         if not reverse:
@@ -70,6 +71,8 @@ class Expression(object):
         """Returns a binary expression with the current column as the left
         side and the other expression as the right side.
         """
+        from pyspark.sql.connect._typing import PrimitiveType
+
         if isinstance(other, get_args(PrimitiveType)):
             other = LiteralExpression(other)
         return ScalarFunctionExpression("==", self, other)
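
The `PrimitiveType` import moves into the function bodies because `_typing.py` now imports from `column.py`; a second module-level import in the opposite direction would be circular. Both hunks also lean on `typing.get_args`, which unpacks a `Union` alias into the tuple of classes that `isinstance` expects. A minimal sketch of the combination:

    # get_args(Union[...]) yields a tuple of classes, which isinstance
    # accepts directly.
    from typing import Union, get_args

    PrimitiveType = Union[bool, float, int, str]

    def wrap_literal(value: object) -> str:
        # in column.py, the `from ..._typing import PrimitiveType` line
        # sits here, inside the function, so it only runs at call time
        # and the circular import never materializes.
        if isinstance(value, get_args(PrimitiveType)):
            return f"LiteralExpression({value!r})"
        return repr(value)

    print(wrap_literal(3.14))    # LiteralExpression(3.14)
    print(wrap_literal([1, 2]))  # [1, 2]
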
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 7b42bdf747b..ff14945db0f 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -44,11 +44,9 @@ from pyspark.sql.types import (
 )
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString, LiteralType
+    from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString, LiteralType
     from pyspark.sql.connect.client import RemoteSparkSession
 
-ColumnOrName = Union[Column, str]
-
 
 class GroupingFrame(object):
 
@@ -308,7 +306,7 @@ class DataFrame(object):
             plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
         )
 
-    def drop(self, *cols: "ColumnOrString") -> "DataFrame":
+    def drop(self, *cols: "ColumnOrName") -> "DataFrame":
         _cols = list(cols)
         if any(not isinstance(c, (str, Column)) for c in _cols):
             raise TypeError(
@@ -342,7 +340,7 @@ class DataFrame(object):
         """
         return self.head()
 
-    def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
+    def groupBy(self, *cols: "ColumnOrName") -> GroupingFrame:
         return GroupingFrame(self, *cols)
 
     @overload
@@ -414,13 +412,13 @@ class DataFrame(object):
     def offset(self, n: int) -> "DataFrame":
         return DataFrame.withPlan(plan.Offset(child=self._plan, offset=n), session=self._session)
 
-    def sort(self, *cols: "ColumnOrString") -> "DataFrame":
+    def sort(self, *cols: "ColumnOrName") -> "DataFrame":
         """Sort by a specific column"""
         return DataFrame.withPlan(
             plan.Sort(self._plan, columns=list(cols), is_global=True), session=self._session
         )
 
-    def sortWithinPartitions(self, *cols: "ColumnOrString") -> "DataFrame":
+    def sortWithinPartitions(self, *cols: "ColumnOrName") -> "DataFrame":
         """Sort within each partition by a specific column"""
         return DataFrame.withPlan(
             plan.Sort(self._plan, columns=list(cols), is_global=False), session=self._session
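
With the duplicate module-level `ColumnOrName = Union[Column, str]` removed, `dataframe.py` takes the alias from `_typing` under `TYPE_CHECKING`, which is why the annotations stay quoted forward references. A sketch of the pattern (`mypackage` below is hypothetical):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # the type checker sees this import; it never executes at
        # runtime, so it cannot introduce a circular import.
        from mypackage._typing import ColumnOrName  # hypothetical module

    def sort(*cols: "ColumnOrName") -> None:
        # quoted annotation: only evaluated by the type checker
        for c in cols:
            print(c)

    sort("year", "month")
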
diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py
index e116e493954..4a2688d6a0d 100644
--- a/python/pyspark/sql/connect/function_builder.py
+++ b/python/pyspark/sql/connect/function_builder.py
@@ -28,9 +28,13 @@ from pyspark.sql.connect.column import (
 
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
+    from pyspark.sql.connect._typing import (
+        ColumnOrName,
+        ExpressionOrString,
+        FunctionBuilderCallable,
+        UserDefinedFunctionCallable,
+    )
     from pyspark.sql.connect.client import RemoteSparkSession
-    from pyspark.sql.connect.typing import FunctionBuilderCallable, UserDefinedFunctionCallable
 
 
 def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression:
@@ -103,7 +107,7 @@ class UserDefinedFunction(Expression):
 def _create_udf(
     function: Any, return_type: Union[str, pyspark.sql.types.DataType]
 ) -> "UserDefinedFunctionCallable":
-    def wrapper(*cols: "ColumnOrString") -> UserDefinedFunction:
+    def wrapper(*cols: "ColumnOrName") -> UserDefinedFunction:
         return UserDefinedFunction(func=function, return_type=return_type, args=cols)
 
     return wrapper
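
The consolidated `TYPE_CHECKING` import block aside, `_create_udf` shows the payoff of `UserDefinedFunctionCallable`: a factory can return a closure and still give the checker the closure's full variadic signature. A toy sketch (names illustrative):

    # a factory returning a closure whose variadic signature is captured
    # by a Protocol rather than an untyped Callable.
    from typing import Protocol  # Python 3.8+

    class JoinerCallable(Protocol):
        def __call__(self, *parts: str) -> str:
            ...

    def make_joiner(sep: str) -> JoinerCallable:
        def wrapper(*parts: str) -> str:
            return sep.join(parts)
        return wrapper

    dash = make_joiner("-")
    print(dash("2022", "11", "24"))  # 2022-11-24
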
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index ffb0ce080b3..8aadc3dc4fa 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.column import (
 
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
+    from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString
     from pyspark.sql.connect.client import RemoteSparkSession
 
 
@@ -58,7 +58,7 @@ class LogicalPlan(object):
         return exp
 
     def to_attr_or_expression(
-        self, col: "ColumnOrString", session: "RemoteSparkSession"
+        self, col: "ColumnOrName", session: "RemoteSparkSession"
     ) -> proto.Expression:
         """Returns either an instance of an unresolved attribute or the serialized
         expression value of the column."""
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 66e48eeab76..27aa023ae47 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -18,17 +18,14 @@
 
 from typing import Dict, Optional
 
-from pyspark.sql.connect.column import PrimitiveType
 from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.connect.plan import Read, DataSource
 from pyspark.sql.utils import to_str
 
-
-OptionalPrimitiveType = Optional[PrimitiveType]
-
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import OptionalPrimitiveType
     from pyspark.sql.connect.client import RemoteSparkSession
 
 
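`OptionalPrimitiveType` now lives in `_typing.py` and is only needed for annotations here, so a `TYPE_CHECKING` import suffices. The alias is `Optional` because reader/writer option values arrive as primitives or `None` and get normalized to strings; a hedged sketch of that normalization (the helper below is illustrative, standing in for `pyspark.sql.utils.to_str`):

    # illustrative stand-in for the real to_str-based normalization.
    from typing import Dict, Optional, Union

    OptionalPrimitiveType = Optional[Union[bool, float, int, str]]

    def normalize_options(opts: Dict[str, OptionalPrimitiveType]) -> Dict[str, str]:
        out: Dict[str, str] = {}
        for k, v in opts.items():
            if v is None:
                continue  # unset options are dropped
            out[k] = str(v).lower() if isinstance(v, bool) else str(v)
        return out

    print(normalize_options({"header": True, "sep": ",", "mode": None}))
    # {'header': 'true', 'sep': ','}
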
diff --git a/python/pyspark/sql/connect/typing/__init__.pyi b/python/pyspark/sql/connect/typing/__init__.pyi
deleted file mode 100644
index 43cc28701da..00000000000
--- a/python/pyspark/sql/connect/typing/__init__.pyi
+++ /dev/null
@@ -1,35 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from typing_extensions import Protocol
-from typing import Union
-from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column
-from pyspark.sql.connect.function_builder import UserDefinedFunction
-
-ExpressionOrString = Union[str, Expression]
-
-ColumnOrString = Union[str, Column]
-
-PrimitiveType = Union[bool, float, int, str]
-
-LiteralType = PrimitiveType
-
-class FunctionBuilderCallable(Protocol):
-    def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression: ...
-
-class UserDefinedFunctionCallable(Protocol):
-    def __call__(self, *_: ColumnOrString) -> UserDefinedFunction: ...

