You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/12/01 00:05:13 UTC

[GitHub] [spark] ueshin commented on a change in pull request #34439: [SPARK-37095][PYTHON] Inline type hints for files in python/pyspark/broadcast.py

ueshin commented on a change in pull request #34439:
URL: https://github.com/apache/spark/pull/34439#discussion_r759746730



##########
File path: python/pyspark/broadcast.py
##########
@@ -62,35 +81,44 @@ class Broadcast(object):
     >>> large_broadcast = sc.broadcast(range(10000))
     """
 
-    def __init__(self, sc=None, value=None, pickle_registry=None, path=None, sock_file=None):
+    def __init__(
+        self,
+        sc: Optional["SparkContext"] = None,
+        value: Optional[T] = None,
+        pickle_registry: Optional["BroadcastPickleRegistry"] = None,
+        path: Optional[Any] = None,
+        sock_file: Optional[Any] = None,
+    ):
         """
         Should not be called directly by users -- use :meth:`SparkContext.broadcast`
         instead.
         """
         if sc is not None:
             # we're on the driver.  We want the pickled data to end up in a file (maybe encrypted)
-            f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
+            f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)  # type: ignore[attr-defined]
             self._path = f.name
-            self._sc = sc
-            self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
-            if sc._encryption_enabled:
+            self._sc: Optional["SparkContext"] = sc
+            self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)  # type: ignore[attr-defined]
+            if sc._encryption_enabled:  # type: ignore[attr-defined]
                 # with encryption, we ask the jvm to do the encryption for us, we send it data
                 # over a socket
                 port, auth_secret = self._python_broadcast.setupEncryptionServer()
                 (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
-                broadcast_out = ChunkedStream(encryption_sock_file, 8192)
+                broadcast_out: Union[ChunkedStream, IO[bytes]] = ChunkedStream(
+                    encryption_sock_file, 8192
+                )
             else:
                 # no encryption, we can just write pickled data directly to the file from python
                 broadcast_out = f
-            self.dump(value, broadcast_out)
-            if sc._encryption_enabled:
+            self.dump(cast(T, value), broadcast_out)

Review comment:
       I'm fine with adding the overloads.

##########
File path: python/pyspark/broadcast.py
##########
@@ -175,27 +205,27 @@ def destroy(self, blocking=False):
         self._jbroadcast.destroy(blocking)
         os.unlink(self._path)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple[Callable[[int], "Broadcast[T]"], Tuple[int]]:
         if self._jbroadcast is None:
             raise RuntimeError("Broadcast can only be serialized in driver")
-        self._pickle_registry.add(self)
+        cast(Any, self._pickle_registry).add(self)
         return _from_id, (self._jbroadcast.id(),)
 
 
 class BroadcastPickleRegistry(threading.local):
     """Thread-local registry for broadcast variables that have been pickled"""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.__dict__.setdefault("_registry", set())
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Broadcast]:

Review comment:
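       Should this be `Iterator[Broadcast[Any]]`? `Broadcast` is generic, so the type parameter should be spelled out rather than left implicit.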

##########
File path: python/pyspark/broadcast.py
##########
@@ -175,27 +205,27 @@ def destroy(self, blocking=False):
         self._jbroadcast.destroy(blocking)
         os.unlink(self._path)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple[Callable[[int], "Broadcast[T]"], Tuple[int]]:
         if self._jbroadcast is None:
             raise RuntimeError("Broadcast can only be serialized in driver")
-        self._pickle_registry.add(self)
+        cast(Any, self._pickle_registry).add(self)

Review comment:
       Could this cast to `BroadcastPickleRegistry` instead of `Any`, so that the `add` call is still type-checked?

##########
File path: python/pyspark/broadcast.py
##########
@@ -175,27 +205,27 @@ def destroy(self, blocking=False):
         self._jbroadcast.destroy(blocking)
         os.unlink(self._path)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple[Callable[[int], "Broadcast[T]"], Tuple[int]]:
         if self._jbroadcast is None:
             raise RuntimeError("Broadcast can only be serialized in driver")
-        self._pickle_registry.add(self)
+        cast(Any, self._pickle_registry).add(self)
         return _from_id, (self._jbroadcast.id(),)
 
 
 class BroadcastPickleRegistry(threading.local):
     """Thread-local registry for broadcast variables that have been pickled"""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.__dict__.setdefault("_registry", set())
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Broadcast]:
         for bcast in self._registry:
             yield bcast
 
-    def add(self, bcast):
+    def add(self, bcast: Broadcast) -> None:

Review comment:
       Should the `bcast` parameter be annotated as `Broadcast[Any]` rather than a bare `Broadcast`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org