Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/07 01:54:37 UTC
[spark] branch master updated: [SPARK-43906][PYTHON][CONNECT] Implement the file support in SparkSession.addArtifacts
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6fd1c649c72 [SPARK-43906][PYTHON][CONNECT] Implement the file support in SparkSession.addArtifacts
6fd1c649c72 is described below
commit 6fd1c649c72d4b53ecf83c1643d38002d80c9288
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Wed Jun 7 10:54:24 2023 +0900
[SPARK-43906][PYTHON][CONNECT] Implement the file support in SparkSession.addArtifacts
### What changes were proposed in this pull request?
This PR proposes to add support for regular files in `SparkSession.addArtifacts`.
### Why are the changes needed?
So users can add regular files to be distributed to the worker nodes.
### Does this PR introduce _any_ user-facing change?
Yes, it adds support for arbitrary regular files in `SparkSession.addArtifacts`.
### How was this patch tested?
Added a couple of unit tests.
Also manually tested in `local-cluster`:
```bash
./sbin/start-connect-server.sh --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar` --master "local-cluster[2,2,1024]"
./bin/pyspark --remote "sc://localhost:15002"
```
```python
import os
import tempfile
from pyspark.sql.functions import udf
from pyspark import SparkFiles

with tempfile.TemporaryDirectory() as d:
    file_path = os.path.join(d, "my_file.txt")
    with open(file_path, "w") as f:
        f.write("Hello world!!")

    @udf("string")
    def func(x):
        with open(
            os.path.join(SparkFiles.getRootDirectory(), "my_file.txt"), "r"
        ) as my_file:
            return my_file.read().strip()

    spark.addArtifacts(file_path, file=True)
    spark.range(1).select(func("id")).show()
```
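For reference, the distributed file can also be resolved with `SparkFiles.get`, which returns the absolute path of a file added through `addFile`; a minimal variant of the UDF above, assuming the same `my_file.txt` has already been added with `file=True`:
```python
from pyspark import SparkFiles
from pyspark.sql.functions import udf

@udf("string")
def read_my_file(x):
    # SparkFiles.get resolves the bare file name to its absolute path
    # in the download root on whichever node runs the task.
    with open(SparkFiles.get("my_file.txt"), "r") as my_file:
        return my_file.read().strip()
```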
Closes #41415 from HyukjinKwon/addFile.
Authored-by: Hyukjin Kwon <gu...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../artifact/SparkConnectArtifactManager.scala | 6 +++--
python/pyspark/sql/connect/client/artifact.py | 21 +++++++++++----
python/pyspark/sql/connect/client/core.py | 4 +--
python/pyspark/sql/connect/session.py | 13 ++++++---
.../sql/tests/connect/client/test_artifact.py | 31 +++++++++++++++++++---
5 files changed, 58 insertions(+), 17 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index 604108f68d2..47c48d8e083 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -97,6 +97,7 @@ class SparkConnectArtifactManager private[connect] {
* @param session
* @param remoteRelativePath
* @param serverLocalStagingPath
+ * @param fragment
*/
private[connect] def addArtifact(
sessionHolder: SessionHolder,
@@ -135,8 +136,7 @@ class SparkConnectArtifactManager private[connect] {
// previously added,
if (Files.exists(target)) {
throw new RuntimeException(
- s"Duplicate Jar: $remoteRelativePath. " +
- s"Jars cannot be overwritten.")
+ s"Duplicate file: $remoteRelativePath. Files cannot be overwritten.")
}
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
@@ -154,6 +154,8 @@ class SparkConnectArtifactManager private[connect] {
val canonicalUri =
fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri)
sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
+ } else if (remoteRelativePath.startsWith(s"files${File.separator}")) {
+ sessionHolder.session.sparkContext.addFile(target.toString)
}
}
}
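For illustration, the server-side dispatch above keys off the prefix of the session-relative artifact path; a hedged Python sketch of the same branching (the `dispatch_artifact` helper and its signature are hypothetical, the actual implementation is the Scala above, and the jar/pyfile branches are elided):
```python
import os

def dispatch_artifact(spark_context, remote_relative_path: str, target: str) -> None:
    # Hypothetical sketch of the prefix-based routing in
    # SparkConnectArtifactManager.addArtifact.
    if remote_relative_path.startswith("archives" + os.sep):
        # Archives keep their optional URI fragment and are
        # unpacked on the executors automatically.
        spark_context.addArchive(target)
    elif remote_relative_path.startswith("files" + os.sep):
        # New in this commit: plain files are registered via addFile.
        spark_context.addFile(target)
```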
diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index 64f89119e4f..9a848bd96b8 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -39,6 +39,7 @@ import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
JAR_PREFIX: str = "jars"
PYFILE_PREFIX: str = "pyfiles"
ARCHIVE_PREFIX: str = "archives"
+FILE_PREFIX: str = "files"
class LocalData(metaclass=abc.ABCMeta):
@@ -107,6 +108,10 @@ def new_archive_artifact(file_name: str, storage: LocalData) -> Artifact:
return _new_artifact(ARCHIVE_PREFIX, "", file_name, storage)
+def new_file_artifact(file_name: str, storage: LocalData) -> Artifact:
+ return _new_artifact(FILE_PREFIX, "", file_name, storage)
+
+
def _new_artifact(
prefix: str, required_suffix: str, file_name: str, storage: LocalData
) -> Artifact:
@@ -141,7 +146,9 @@ class ArtifactManager:
self._stub = grpc_lib.SparkConnectServiceStub(channel)
self._session_id = session_id
- def _parse_artifacts(self, path_or_uri: str, pyfile: bool, archive: bool) -> List[Artifact]:
+ def _parse_artifacts(
+ self, path_or_uri: str, pyfile: bool, archive: bool, file: bool
+ ) -> List[Artifact]:
# Currently only local files with .jar extension is supported.
parsed = urlparse(path_or_uri)
# Check if it is a file from the scheme
@@ -180,6 +187,8 @@ class ArtifactManager:
name = f"{name}#{parsed.fragment}"
artifact = new_archive_artifact(name, LocalFile(local_path))
+ elif file:
+ artifact = new_file_artifact(name, LocalFile(local_path))
elif name.endswith(".jar"):
artifact = new_jar_artifact(name, LocalFile(local_path))
else:
@@ -188,11 +197,13 @@ class ArtifactManager:
raise RuntimeError(f"Unsupported scheme: {parsed.scheme}")
def _create_requests(
- self, *path: str, pyfile: bool, archive: bool
+ self, *path: str, pyfile: bool, archive: bool, file: bool
) -> Iterator[proto.AddArtifactsRequest]:
"""Separated for the testing purpose."""
return self._add_artifacts(
- chain(*(self._parse_artifacts(p, pyfile=pyfile, archive=archive) for p in path))
+ chain(
+ *(self._parse_artifacts(p, pyfile=pyfile, archive=archive, file=file) for p in path)
+ )
)
def _retrieve_responses(
@@ -201,13 +212,13 @@ class ArtifactManager:
"""Separated for the testing purpose."""
return self._stub.AddArtifacts(requests)
- def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None:
+ def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None:
"""
Add a single artifact to the session.
Currently only local files with .jar extension is supported.
"""
requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(
- *path, pyfile=pyfile, archive=archive
+ *path, pyfile=pyfile, archive=archive, file=file
)
response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []
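As a rough sketch of the client-side naming, `new_file_artifact` stores the payload under the new `files/` prefix so the server-side prefix check can route it to `addFile` (assuming `_new_artifact` joins the prefix and the file name, as it does for the other prefixes):
```python
import os

FILE_PREFIX = "files"  # mirrors the constant added above

def file_artifact_path(file_name: str) -> str:
    # Sketch: a file added with file=True travels as "files/<name>",
    # e.g. "files/my_file.txt", matching the server's prefix check.
    return os.path.join(FILE_PREFIX, file_name)

assert file_artifact_path("my_file.txt") == os.path.join("files", "my_file.txt")
```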
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 8da649e7765..b2071641a26 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1251,8 +1251,8 @@ class SparkConnectClient(object):
else:
raise SparkConnectGrpcException(str(rpc_error)) from None
- def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None:
- self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive)
+ def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None:
+ self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file)
class RetryState:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 2d58ce1daf0..341db448955 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -613,7 +613,9 @@ class SparkSession:
"""
return self._client
- def addArtifacts(self, *path: str, pyfile: bool = False, archive: bool = False) -> None:
+ def addArtifacts(
+ self, *path: str, pyfile: bool = False, archive: bool = False, file: bool = False
+ ) -> None:
"""
Add artifact(s) to the client session. Currently only local files are supported.
@@ -630,10 +632,13 @@ class SparkSession:
archive : bool
Whether to add them as archives such as .zip, .jar, .tar.gz, .tgz, or .tar files.
The archives are unpacked on the executor side automatically.
+ file : bool
+ Add a file to be downloaded with this Spark job on every node.
+ The ``path`` passed can only be a local file for now.
"""
- if pyfile and archive:
- raise ValueError("'pyfile' and 'archive' cannot be True together.")
- self._client.add_artifacts(*path, pyfile=pyfile, archive=archive)
+ if sum([file, pyfile, archive]) > 1:
+ raise ValueError("'pyfile', 'archive' and/or 'file' cannot be True together.")
+ self._client.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file)
addArtifact = addArtifacts
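The three flags stay mutually exclusive; an illustrative call against a Connect session (assuming `spark` is a remote `SparkSession`, as in the manual test above):
```python
# Combining flags now raises ValueError up front:
try:
    spark.addArtifacts("my_file.txt", file=True, archive=True)
except ValueError as e:
    print(e)  # 'pyfile', 'archive' and/or 'file' cannot be True together.
```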
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index 2bff3fd5bc4..1725e0f6e4c 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -49,7 +49,9 @@ class ArtifactTests(ReusedConnectTestCase):
file_name = "smallJar"
small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
response = self.artifact_manager._retrieve_responses(
- self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False)
+ self.artifact_manager._create_requests(
+ small_jar_path, pyfile=False, archive=False, file=False
+ )
)
self.assertTrue(response.artifacts[0].name.endswith(f"{file_name}.jar"))
@@ -59,7 +61,9 @@ class ArtifactTests(ReusedConnectTestCase):
small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
requests = list(
- self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False)
+ self.artifact_manager._create_requests(
+ small_jar_path, pyfile=False, archive=False, file=False
+ )
)
self.assertEqual(len(requests), 1)
@@ -83,7 +87,9 @@ class ArtifactTests(ReusedConnectTestCase):
large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
requests = list(
- self.artifact_manager._create_requests(large_jar_path, pyfile=False, archive=False)
+ self.artifact_manager._create_requests(
+ large_jar_path, pyfile=False, archive=False, file=False
+ )
)
# Expected chunks = roundUp( file_size / chunk_size) = 12
# File size of `junitLargeJar.jar` is 384581 bytes.
@@ -117,7 +123,7 @@ class ArtifactTests(ReusedConnectTestCase):
requests = list(
self.artifact_manager._create_requests(
- small_jar_path, small_jar_path, pyfile=False, archive=False
+ small_jar_path, small_jar_path, pyfile=False, archive=False, file=False
)
)
# Single request containing 2 artifacts.
@@ -160,6 +166,7 @@ class ArtifactTests(ReusedConnectTestCase):
small_jar_path,
pyfile=False,
archive=False,
+ file=False,
)
)
# There are a total of 14 requests.
@@ -271,6 +278,22 @@ class ArtifactTests(ReusedConnectTestCase):
self.spark.addArtifacts(f"{archive_path}.zip#my_files", archive=True)
self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "hello world!")
+ def test_add_file(self):
+ with tempfile.TemporaryDirectory() as d:
+ file_path = os.path.join(d, "my_file.txt")
+ with open(file_path, "w") as f:
+ f.write("Hello world!!")
+
+ @udf("string")
+ def func(x):
+ with open(
+ os.path.join(SparkFiles.getRootDirectory(), "my_file.txt"), "r"
+ ) as my_file:
+ return my_file.read().strip()
+
+ self.spark.addArtifacts(file_path, file=True)
+ self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "Hello world!!")
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.client.test_artifact import * # noqa: F401