You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/01/13 09:19:18 UTC
[spark] branch master updated: [SPARK-37095][PYTHON] Inline type hints for files in python/pyspark/broadcast.py
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 79d42a1 [SPARK-37095][PYTHON] Inline type hints for files in python/pyspark/broadcast.py
79d42a1 is described below
commit 79d42a1fa6cc354b34c790ea39573735f8363137
Author: dch nguyen <dg...@viettel.com.vn>
AuthorDate: Thu Jan 13 18:18:07 2022 +0900
[SPARK-37095][PYTHON] Inline type hints for files in python/pyspark/broadcast.py
Lead-authored-by: dchvn nguyen <dgd_contributor@viettel.com.vn>
Co-authored-by: zero323 <mszymkiewicz@gmail.com>
### What changes were proposed in this pull request?
Inline type hints for python/pyspark/broadcast.py
### Why are the changes needed?
We can take advantage of static type checking within the functions by inlining the type hints.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests
Closes #34439 from dchvn/SPARK-37095.
Authored-by: dch nguyen <dg...@viettel.com.vn>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/broadcast.py | 81 ++++++++++++++++++++++++++++++++++----------
python/pyspark/broadcast.pyi | 48 --------------------------
2 files changed, 64 insertions(+), 65 deletions(-)
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 903e4ea..edd282d 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -21,20 +21,40 @@ import sys
from tempfile import NamedTemporaryFile
import threading
import pickle
+from typing import (
+ overload,
+ Any,
+ Callable,
+ Dict,
+ Generic,
+ IO,
+ Iterator,
+ Optional,
+ Tuple,
+ TypeVar,
+ TYPE_CHECKING,
+ Union,
+)
+from typing.io import BinaryIO # type: ignore[import]
from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import ChunkedStream, pickle_protocol
from pyspark.util import print_exec
+if TYPE_CHECKING:
+ from pyspark import SparkContext
+
__all__ = ["Broadcast"]
+T = TypeVar("T")
+
# Holds broadcasted data received from Java, keyed by its id.
-_broadcastRegistry = {}
+_broadcastRegistry: Dict[int, "Broadcast[Any]"] = {}
-def _from_id(bid):
+def _from_id(bid: int) -> "Broadcast[Any]":
from pyspark.broadcast import _broadcastRegistry
if bid not in _broadcastRegistry:
@@ -42,7 +62,7 @@ def _from_id(bid):
return _broadcastRegistry[bid]
-class Broadcast:
+class Broadcast(Generic[T]):
"""
A broadcast variable created with :meth:`SparkContext.broadcast`.
@@ -62,7 +82,31 @@ class Broadcast:
>>> large_broadcast = sc.broadcast(range(10000))
"""
- def __init__(self, sc=None, value=None, pickle_registry=None, path=None, sock_file=None):
+ @overload # On driver
+ def __init__(
+ self: "Broadcast[T]",
+ sc: "SparkContext",
+ value: T,
+ pickle_registry: "BroadcastPickleRegistry",
+ ):
+ ...
+
+ @overload # On worker without decryption server
+ def __init__(self: "Broadcast[Any]", *, path: str):
+ ...
+
+ @overload # On worker with decryption server
+ def __init__(self: "Broadcast[Any]", *, sock_file: str):
+ ...
+
+ def __init__(
+ self,
+ sc: Optional["SparkContext"] = None,
+ value: Optional[T] = None,
+ pickle_registry: Optional["BroadcastPickleRegistry"] = None,
+ path: Optional[str] = None,
+ sock_file: Optional[BinaryIO] = None,
+ ):
"""
Should not be called directly by users -- use :meth:`SparkContext.broadcast`
instead.
@@ -71,8 +115,10 @@ class Broadcast:
# we're on the driver. We want the pickled data to end up in a file (maybe encrypted)
f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
self._path = f.name
- self._sc = sc
+ self._sc: Optional["SparkContext"] = sc
+ assert sc._jvm is not None
self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
+ broadcast_out: Union[ChunkedStream, IO[bytes]]
if sc._encryption_enabled:
# with encryption, we ask the jvm to do the encryption for us, we send it data
# over a socket
@@ -82,7 +128,7 @@ class Broadcast:
else:
# no encryption, we can just write pickled data directly to the file from python
broadcast_out = f
- self.dump(value, broadcast_out)
+ self.dump(value, broadcast_out) # type: ignore[arg-type]
if sc._encryption_enabled:
self._python_broadcast.waitTillDataReceived()
self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
@@ -102,7 +148,7 @@ class Broadcast:
assert path is not None
self._path = path
- def dump(self, value, f):
+ def dump(self, value: T, f: BinaryIO) -> None:
try:
pickle.dump(value, f, pickle_protocol)
except pickle.PickleError:
@@ -113,11 +159,11 @@ class Broadcast:
raise pickle.PicklingError(msg)
f.close()
- def load_from_path(self, path):
+ def load_from_path(self, path: str) -> T:
with open(path, "rb", 1 << 20) as f:
return self.load(f)
- def load(self, file):
+ def load(self, file: BinaryIO) -> T:
# "file" could also be a socket
gc.disable()
try:
@@ -126,7 +172,7 @@ class Broadcast:
gc.enable()
@property
- def value(self):
+ def value(self) -> T:
"""Return the broadcasted value"""
if not hasattr(self, "_value") and self._path is not None:
# we only need to decrypt it here when encryption is enabled and
@@ -140,7 +186,7 @@ class Broadcast:
self._value = self.load_from_path(self._path)
return self._value
- def unpersist(self, blocking=False):
+ def unpersist(self, blocking: bool = False) -> None:
"""
Delete cached copies of this broadcast on the executors. If the
broadcast is used after this is called, it will need to be
@@ -155,7 +201,7 @@ class Broadcast:
raise RuntimeError("Broadcast can only be unpersisted in driver")
self._jbroadcast.unpersist(blocking)
- def destroy(self, blocking=False):
+ def destroy(self, blocking: bool = False) -> None:
"""
Destroy all data and metadata related to this broadcast variable.
Use this with caution; once a broadcast variable has been destroyed,
@@ -175,9 +221,10 @@ class Broadcast:
self._jbroadcast.destroy(blocking)
os.unlink(self._path)
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Callable[[int], "Broadcast[T]"], Tuple[int]]:
if self._jbroadcast is None:
raise RuntimeError("Broadcast can only be serialized in driver")
+ assert self._pickle_registry is not None
self._pickle_registry.add(self)
return _from_id, (self._jbroadcast.id(),)
@@ -185,17 +232,17 @@ class Broadcast:
class BroadcastPickleRegistry(threading.local):
"""Thread-local registry for broadcast variables that have been pickled"""
- def __init__(self):
+ def __init__(self) -> None:
self.__dict__.setdefault("_registry", set())
- def __iter__(self):
+ def __iter__(self) -> Iterator[Broadcast[Any]]:
for bcast in self._registry:
yield bcast
- def add(self, bcast):
+ def add(self, bcast: Broadcast[Any]) -> None:
self._registry.add(bcast)
- def clear(self):
+ def clear(self) -> None:
self._registry.clear()
diff --git a/python/pyspark/broadcast.pyi b/python/pyspark/broadcast.pyi
deleted file mode 100644
index 944cb06..0000000
--- a/python/pyspark/broadcast.pyi
+++ /dev/null
@@ -1,48 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import threading
-from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
-
-T = TypeVar("T")
-
-_broadcastRegistry: Dict[int, Broadcast]
-
-class Broadcast(Generic[T]):
- def __init__(
- self,
- sc: Optional[Any] = ...,
- value: Optional[T] = ...,
- pickle_registry: Optional[Any] = ...,
- path: Optional[Any] = ...,
- sock_file: Optional[Any] = ...,
- ) -> None: ...
- def dump(self, value: T, f: Any) -> None: ...
- def load_from_path(self, path: Any) -> T: ...
- def load(self, file: Any) -> T: ...
- @property
- def value(self) -> T: ...
- def unpersist(self, blocking: bool = ...) -> None: ...
- def destroy(self, blocking: bool = ...) -> None: ...
- def __reduce__(self) -> Tuple[Callable[[int], T], Tuple[int]]: ...
-
-class BroadcastPickleRegistry(threading.local):
- def __init__(self) -> None: ...
- def __iter__(self) -> None: ...
- def add(self, bcast: Any) -> None: ...
- def clear(self) -> None: ...
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org