You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/07 01:54:37 UTC

[spark] branch master updated: [SPARK-43906][PYTHON][CONNECT] Implement the file support in SparkSession.addArtifacts

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6fd1c649c72 [SPARK-43906][PYTHON][CONNECT] Implement the file support in SparkSession.addArtifacts
6fd1c649c72 is described below

commit 6fd1c649c72d4b53ecf83c1643d38002d80c9288
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Wed Jun 7 10:54:24 2023 +0900

    [SPARK-43906][PYTHON][CONNECT] Implement the file support in SparkSession.addArtifacts
    
    ### What changes were proposed in this pull request?
    
    This PR proposes adding support for regular files in `SparkSession.addArtifacts`.
    
    ### Why are the changes needed?
    
    So that users can add regular files and make them available on the worker nodes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it adds support for arbitrary regular files in `SparkSession.addArtifacts` via the new `file` parameter.
    
    ### How was this patch tested?
    
    Added a couple of unit tests.
    
    Also manually tested in `local-cluster`:
    
    ```bash
    ./sbin/start-connect-server.sh --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar` --master "local-cluster[2,2,1024]"
    ./bin/pyspark --remote "sc://localhost:15002"
    ```
    
    ```python
    import os
    import tempfile
    from pyspark.sql.functions import udf
    from pyspark import SparkFiles
    
    with tempfile.TemporaryDirectory() as d:
        file_path = os.path.join(d, "my_file.txt")
        with open(file_path, "w") as f:
            f.write("Hello world!!")
        @udf("string")
        def func(x):
            with open(
                os.path.join(SparkFiles.getRootDirectory(), "my_file.txt"), "r"
            ) as my_file:
                return my_file.read().strip()
        spark.addArtifacts(file_path, file=True)
        spark.range(1).select(func("id")).show()
    ```
    
    Closes #41415 from HyukjinKwon/addFile.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../artifact/SparkConnectArtifactManager.scala     |  6 +++--
 python/pyspark/sql/connect/client/artifact.py      | 21 +++++++++++----
 python/pyspark/sql/connect/client/core.py          |  4 +--
 python/pyspark/sql/connect/session.py              | 13 ++++++---
 .../sql/tests/connect/client/test_artifact.py      | 31 +++++++++++++++++++---
 5 files changed, 58 insertions(+), 17 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index 604108f68d2..47c48d8e083 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -97,6 +97,7 @@ class SparkConnectArtifactManager private[connect] {
    * @param session
    * @param remoteRelativePath
    * @param serverLocalStagingPath
+   * @param fragment
    */
   private[connect] def addArtifact(
       sessionHolder: SessionHolder,
@@ -135,8 +136,7 @@ class SparkConnectArtifactManager private[connect] {
       // previously added,
       if (Files.exists(target)) {
         throw new RuntimeException(
-          s"Duplicate Jar: $remoteRelativePath. " +
-            s"Jars cannot be overwritten.")
+          s"Duplicate file: $remoteRelativePath. Files cannot be overwritten.")
       }
       Files.move(serverLocalStagingPath, target)
       if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
@@ -154,6 +154,8 @@ class SparkConnectArtifactManager private[connect] {
         val canonicalUri =
           fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri)
         sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
+      } else if (remoteRelativePath.startsWith(s"files${File.separator}")) {
+        sessionHolder.session.sparkContext.addFile(target.toString)
       }
     }
   }
diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index 64f89119e4f..9a848bd96b8 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -39,6 +39,7 @@ import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
 JAR_PREFIX: str = "jars"
 PYFILE_PREFIX: str = "pyfiles"
 ARCHIVE_PREFIX: str = "archives"
+FILE_PREFIX: str = "files"
 
 
 class LocalData(metaclass=abc.ABCMeta):
@@ -107,6 +108,10 @@ def new_archive_artifact(file_name: str, storage: LocalData) -> Artifact:
     return _new_artifact(ARCHIVE_PREFIX, "", file_name, storage)
 
 
+def new_file_artifact(file_name: str, storage: LocalData) -> Artifact:
+    return _new_artifact(FILE_PREFIX, "", file_name, storage)
+
+
 def _new_artifact(
     prefix: str, required_suffix: str, file_name: str, storage: LocalData
 ) -> Artifact:
@@ -141,7 +146,9 @@ class ArtifactManager:
         self._stub = grpc_lib.SparkConnectServiceStub(channel)
         self._session_id = session_id
 
-    def _parse_artifacts(self, path_or_uri: str, pyfile: bool, archive: bool) -> List[Artifact]:
+    def _parse_artifacts(
+        self, path_or_uri: str, pyfile: bool, archive: bool, file: bool
+    ) -> List[Artifact]:
         # Currently only local files with .jar extension is supported.
         parsed = urlparse(path_or_uri)
         # Check if it is a file from the scheme
@@ -180,6 +187,8 @@ class ArtifactManager:
                     name = f"{name}#{parsed.fragment}"
 
                 artifact = new_archive_artifact(name, LocalFile(local_path))
+            elif file:
+                artifact = new_file_artifact(name, LocalFile(local_path))
             elif name.endswith(".jar"):
                 artifact = new_jar_artifact(name, LocalFile(local_path))
             else:
@@ -188,11 +197,13 @@ class ArtifactManager:
         raise RuntimeError(f"Unsupported scheme: {parsed.scheme}")
 
     def _create_requests(
-        self, *path: str, pyfile: bool, archive: bool
+        self, *path: str, pyfile: bool, archive: bool, file: bool
     ) -> Iterator[proto.AddArtifactsRequest]:
         """Separated for the testing purpose."""
         return self._add_artifacts(
-            chain(*(self._parse_artifacts(p, pyfile=pyfile, archive=archive) for p in path))
+            chain(
+                *(self._parse_artifacts(p, pyfile=pyfile, archive=archive, file=file) for p in path)
+            )
         )
 
     def _retrieve_responses(
@@ -201,13 +212,13 @@ class ArtifactManager:
         """Separated for the testing purpose."""
         return self._stub.AddArtifacts(requests)
 
-    def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None:
+    def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None:
         """
         Add a single artifact to the session.
         Currently only local files with .jar extension is supported.
         """
         requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(
-            *path, pyfile=pyfile, archive=archive
+            *path, pyfile=pyfile, archive=archive, file=file
         )
         response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
         summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 8da649e7765..b2071641a26 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1251,8 +1251,8 @@ class SparkConnectClient(object):
         else:
             raise SparkConnectGrpcException(str(rpc_error)) from None
 
-    def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None:
-        self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive)
+    def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None:
+        self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file)
 
 
 class RetryState:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 2d58ce1daf0..341db448955 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -613,7 +613,9 @@ class SparkSession:
         """
         return self._client
 
-    def addArtifacts(self, *path: str, pyfile: bool = False, archive: bool = False) -> None:
+    def addArtifacts(
+        self, *path: str, pyfile: bool = False, archive: bool = False, file: bool = False
+    ) -> None:
         """
         Add artifact(s) to the client session. Currently only local files are supported.
 
@@ -630,10 +632,13 @@ class SparkSession:
         archive : bool
             Whether to add them as archives such as .zip, .jar, .tar.gz, .tgz, or .tar files.
             The archives are unpacked on the executor side automatically.
+        file : bool
+            Add a file to be downloaded with this Spark job on every node.
+            The ``path`` passed can only be a local file for now.
         """
-        if pyfile and archive:
-            raise ValueError("'pyfile' and 'archive' cannot be True together.")
-        self._client.add_artifacts(*path, pyfile=pyfile, archive=archive)
+        if sum([file, pyfile, archive]) > 1:
+            raise ValueError("'pyfile', 'archive' and/or 'file' cannot be True together.")
+        self._client.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file)
 
     addArtifact = addArtifacts
 
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index 2bff3fd5bc4..1725e0f6e4c 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -49,7 +49,9 @@ class ArtifactTests(ReusedConnectTestCase):
         file_name = "smallJar"
         small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
         response = self.artifact_manager._retrieve_responses(
-            self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False)
+            self.artifact_manager._create_requests(
+                small_jar_path, pyfile=False, archive=False, file=False
+            )
         )
         self.assertTrue(response.artifacts[0].name.endswith(f"{file_name}.jar"))
 
@@ -59,7 +61,9 @@ class ArtifactTests(ReusedConnectTestCase):
         small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
 
         requests = list(
-            self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False)
+            self.artifact_manager._create_requests(
+                small_jar_path, pyfile=False, archive=False, file=False
+            )
         )
         self.assertEqual(len(requests), 1)
 
@@ -83,7 +87,9 @@ class ArtifactTests(ReusedConnectTestCase):
         large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
 
         requests = list(
-            self.artifact_manager._create_requests(large_jar_path, pyfile=False, archive=False)
+            self.artifact_manager._create_requests(
+                large_jar_path, pyfile=False, archive=False, file=False
+            )
         )
         # Expected chunks = roundUp( file_size / chunk_size) = 12
         # File size of `junitLargeJar.jar` is 384581 bytes.
@@ -117,7 +123,7 @@ class ArtifactTests(ReusedConnectTestCase):
 
         requests = list(
             self.artifact_manager._create_requests(
-                small_jar_path, small_jar_path, pyfile=False, archive=False
+                small_jar_path, small_jar_path, pyfile=False, archive=False, file=False
             )
         )
         # Single request containing 2 artifacts.
@@ -160,6 +166,7 @@ class ArtifactTests(ReusedConnectTestCase):
                 small_jar_path,
                 pyfile=False,
                 archive=False,
+                file=False,
             )
         )
         # There are a total of 14 requests.
@@ -271,6 +278,22 @@ class ArtifactTests(ReusedConnectTestCase):
             self.spark.addArtifacts(f"{archive_path}.zip#my_files", archive=True)
             self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "hello world!")
 
+    def test_add_file(self):
+        with tempfile.TemporaryDirectory() as d:
+            file_path = os.path.join(d, "my_file.txt")
+            with open(file_path, "w") as f:
+                f.write("Hello world!!")
+
+            @udf("string")
+            def func(x):
+                with open(
+                    os.path.join(SparkFiles.getRootDirectory(), "my_file.txt"), "r"
+                ) as my_file:
+                    return my_file.read().strip()
+
+            self.spark.addArtifacts(file_path, file=True)
+            self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "Hello world!!")
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.client.test_artifact import *  # noqa: F401


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org