Posted to commits@spark.apache.org by gu...@apache.org on 2022/01/13 09:19:18 UTC

[spark] branch master updated: [SPARK-37095][PYTHON] Inline type hints for files in python/pyspark/broadcast.py

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 79d42a1  [SPARK-37095][PYTHON] Inline type hints for files in python/pyspark/broadcast.py
79d42a1 is described below

commit 79d42a1fa6cc354b34c790ea39573735f8363137
Author: dch nguyen <dg...@viettel.com.vn>
AuthorDate: Thu Jan 13 18:18:07 2022 +0900

    [SPARK-37095][PYTHON] Inline type hints for files in python/pyspark/broadcast.py
    
    Lead-authored-by: dchvn nguyen <dgd_contributor@viettel.com.vn>
    Co-authored-by: zero323 <mszymkiewicz@gmail.com>
    
    ### What changes were proposed in this pull request?
    Inline type hints for python/pyspark/broadcast.py.

    ### Why are the changes needed?
    We can take advantage of static type checking within the functions by inlining the type hints (a short illustrative sketch follows the diffstat below).

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Existing tests.
    
    Closes #34439 from dchvn/SPARK-37095.
    
    Authored-by: dch nguyen <dg...@viettel.com.vn>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/broadcast.py  | 81 ++++++++++++++++++++++++++++++++++----------
 python/pyspark/broadcast.pyi | 48 --------------------------
 2 files changed, 64 insertions(+), 65 deletions(-)
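
For context on why inlining matters: with a separate .pyi stub, mypy checks only callers of broadcast.py, while the module's own function bodies are skipped; once the hints live in the .py file itself, the bodies are checked as well. Below is a minimal sketch of the Generic[T] pattern this commit applies, using a hypothetical Box class rather than Broadcast itself:

    from typing import Generic, TypeVar

    T = TypeVar("T")

    class Box(Generic[T]):  # hypothetical stand-in for Broadcast[T]
        def __init__(self, value: T) -> None:
            self._value = value

        @property
        def value(self) -> T:
            # With inline hints, mypy verifies this body too: returning
            # anything that is not T would be flagged here, not only at
            # call sites as with a stub file.
            return self._value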

diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 903e4ea..edd282d 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -21,20 +21,40 @@ import sys
 from tempfile import NamedTemporaryFile
 import threading
 import pickle
+from typing import (
+    overload,
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    IO,
+    Iterator,
+    Optional,
+    Tuple,
+    TypeVar,
+    TYPE_CHECKING,
+    Union,
+)
+from typing.io import BinaryIO  # type: ignore[import]
 
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import ChunkedStream, pickle_protocol
 from pyspark.util import print_exec
 
+if TYPE_CHECKING:
+    from pyspark import SparkContext
+
 
 __all__ = ["Broadcast"]
 
+T = TypeVar("T")
+
 
 # Holds broadcasted data received from Java, keyed by its id.
-_broadcastRegistry = {}
+_broadcastRegistry: Dict[int, "Broadcast[Any]"] = {}
 
 
-def _from_id(bid):
+def _from_id(bid: int) -> "Broadcast[Any]":
     from pyspark.broadcast import _broadcastRegistry
 
     if bid not in _broadcastRegistry:
@@ -42,7 +62,7 @@ def _from_id(bid):
     return _broadcastRegistry[bid]
 
 
-class Broadcast:
+class Broadcast(Generic[T]):
 
     """
     A broadcast variable created with :meth:`SparkContext.broadcast`.
@@ -62,7 +82,31 @@ class Broadcast:
     >>> large_broadcast = sc.broadcast(range(10000))
     """
 
-    def __init__(self, sc=None, value=None, pickle_registry=None, path=None, sock_file=None):
+    @overload  # On driver
+    def __init__(
+        self: "Broadcast[T]",
+        sc: "SparkContext",
+        value: T,
+        pickle_registry: "BroadcastPickleRegistry",
+    ):
+        ...
+
+    @overload  # On worker without decryption server
+    def __init__(self: "Broadcast[Any]", *, path: str):
+        ...
+
+    @overload  # On worker with decryption server
+    def __init__(self: "Broadcast[Any]", *, sock_file: str):
+        ...
+
+    def __init__(
+        self,
+        sc: Optional["SparkContext"] = None,
+        value: Optional[T] = None,
+        pickle_registry: Optional["BroadcastPickleRegistry"] = None,
+        path: Optional[str] = None,
+        sock_file: Optional[BinaryIO] = None,
+    ):
         """
         Should not be called directly by users -- use :meth:`SparkContext.broadcast`
         instead.
@@ -71,8 +115,10 @@ class Broadcast:
             # we're on the driver.  We want the pickled data to end up in a file (maybe encrypted)
             f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
             self._path = f.name
-            self._sc = sc
+            self._sc: Optional["SparkContext"] = sc
+            assert sc._jvm is not None
             self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
+            broadcast_out: Union[ChunkedStream, IO[bytes]]
             if sc._encryption_enabled:
                 # with encryption, we ask the jvm to do the encryption for us, we send it data
                 # over a socket
@@ -82,7 +128,7 @@ class Broadcast:
             else:
                 # no encryption, we can just write pickled data directly to the file from python
                 broadcast_out = f
-            self.dump(value, broadcast_out)
+            self.dump(value, broadcast_out)  # type: ignore[arg-type]
             if sc._encryption_enabled:
                 self._python_broadcast.waitTillDataReceived()
             self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
@@ -102,7 +148,7 @@ class Broadcast:
                 assert path is not None
                 self._path = path
 
-    def dump(self, value, f):
+    def dump(self, value: T, f: BinaryIO) -> None:
         try:
             pickle.dump(value, f, pickle_protocol)
         except pickle.PickleError:
@@ -113,11 +159,11 @@ class Broadcast:
             raise pickle.PicklingError(msg)
         f.close()
 
-    def load_from_path(self, path):
+    def load_from_path(self, path: str) -> T:
         with open(path, "rb", 1 << 20) as f:
             return self.load(f)
 
-    def load(self, file):
+    def load(self, file: BinaryIO) -> T:
         # "file" could also be a socket
         gc.disable()
         try:
@@ -126,7 +172,7 @@ class Broadcast:
             gc.enable()
 
     @property
-    def value(self):
+    def value(self) -> T:
         """Return the broadcasted value"""
         if not hasattr(self, "_value") and self._path is not None:
             # we only need to decrypt it here when encryption is enabled and
@@ -140,7 +186,7 @@ class Broadcast:
                 self._value = self.load_from_path(self._path)
         return self._value
 
-    def unpersist(self, blocking=False):
+    def unpersist(self, blocking: bool = False) -> None:
         """
         Delete cached copies of this broadcast on the executors. If the
         broadcast is used after this is called, it will need to be
@@ -155,7 +201,7 @@ class Broadcast:
             raise RuntimeError("Broadcast can only be unpersisted in driver")
         self._jbroadcast.unpersist(blocking)
 
-    def destroy(self, blocking=False):
+    def destroy(self, blocking: bool = False) -> None:
         """
         Destroy all data and metadata related to this broadcast variable.
         Use this with caution; once a broadcast variable has been destroyed,
@@ -175,9 +221,10 @@ class Broadcast:
         self._jbroadcast.destroy(blocking)
         os.unlink(self._path)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple[Callable[[int], "Broadcast[T]"], Tuple[int]]:
         if self._jbroadcast is None:
             raise RuntimeError("Broadcast can only be serialized in driver")
+        assert self._pickle_registry is not None
         self._pickle_registry.add(self)
         return _from_id, (self._jbroadcast.id(),)
 
@@ -185,17 +232,17 @@ class Broadcast:
 class BroadcastPickleRegistry(threading.local):
     """Thread-local registry for broadcast variables that have been pickled"""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.__dict__.setdefault("_registry", set())
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Broadcast[Any]]:
         for bcast in self._registry:
             yield bcast
 
-    def add(self, bcast):
+    def add(self, bcast: Broadcast[Any]) -> None:
         self._registry.add(bcast)
 
-    def clear(self):
+    def clear(self) -> None:
         self._registry.clear()
 
 
diff --git a/python/pyspark/broadcast.pyi b/python/pyspark/broadcast.pyi
deleted file mode 100644
index 944cb06..0000000
--- a/python/pyspark/broadcast.pyi
+++ /dev/null
@@ -1,48 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import threading
-from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
-
-T = TypeVar("T")
-
-_broadcastRegistry: Dict[int, Broadcast]
-
-class Broadcast(Generic[T]):
-    def __init__(
-        self,
-        sc: Optional[Any] = ...,
-        value: Optional[T] = ...,
-        pickle_registry: Optional[Any] = ...,
-        path: Optional[Any] = ...,
-        sock_file: Optional[Any] = ...,
-    ) -> None: ...
-    def dump(self, value: T, f: Any) -> None: ...
-    def load_from_path(self, path: Any) -> T: ...
-    def load(self, file: Any) -> T: ...
-    @property
-    def value(self) -> T: ...
-    def unpersist(self, blocking: bool = ...) -> None: ...
-    def destroy(self, blocking: bool = ...) -> None: ...
-    def __reduce__(self) -> Tuple[Callable[[int], T], Tuple[int]]: ...
-
-class BroadcastPickleRegistry(threading.local):
-    def __init__(self) -> None: ...
-    def __iter__(self) -> None: ...
-    def add(self, bcast: Any) -> None: ...
-    def clear(self) -> None: ...
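
A hedged usage sketch of the resulting API: with the class declared as Broadcast(Generic[T]), a checker can track the element type end to end. The snippet below assumes a working pyspark installation and a local SparkContext; the inferred Broadcast[...] type depends on SparkContext.broadcast being annotated elsewhere in this inlining effort, so treat the inference as illustrative rather than guaranteed by this commit alone.

    from pyspark import SparkContext

    sc = SparkContext("local[1]", "broadcast-typing-demo")
    b = sc.broadcast([1, 2, 3])    # a Broadcast[List[int]] under the inlined hints
    total = sum(b.value)           # OK: b.value is typed as List[int]
    # label: str = b.value        # mypy would reject this: List[int] is not str
    b.unpersist()
    sc.stop()

One design note visible in the diff: the three @overload signatures on __init__ mirror the three construction paths (on the driver, on a worker reading from a path, and on a worker reading from a decryption socket), so each caller sees only the keyword arguments valid for its context.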
