Posted to commits@spark.apache.org by hv...@apache.org on 2023/03/03 11:53:50 UTC

[spark] branch branch-3.4 updated: [SPARK-42653][CONNECT] Artifact transfer from Scala/JVM client to Server

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 24a602dd173 [SPARK-42653][CONNECT] Artifact transfer from Scala/JVM client to Server
24a602dd173 is described below

commit 24a602dd173a914b02e8047731bd6f95e4810fd0
Author: vicennial <ve...@databricks.com>
AuthorDate: Fri Mar 3 07:52:57 2023 -0400

    [SPARK-42653][CONNECT] Artifact transfer from Scala/JVM client to Server
    
    ### What changes were proposed in this pull request?
    
    This PR introduces a mechanism to transfer artifacts (currently, local `.jar` and `.class` files) from a Spark Connect JVM/Scala client to the Spark Connect server. The mechanism follows the protocol defined in https://github.com/apache/spark/pull/40147 and supports batching (for multiple "small" artifacts) and chunking (for large artifacts).
    
    Note: Server-side artifact handling is not covered in this PR.
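    
    To make the batching/chunking split concrete, the following is a small, self-contained sketch (illustrative only, not code from this patch) of the decision the client-side `ArtifactManager` makes, using the same 32 KiB threshold as the implementation below:
    
    ```scala
    // Illustrative sketch: how artifact sizes map to AddArtifactsRequest payloads.
    object TransferPlanSketch {
      final val ChunkSize: Long = 32 * 1024 // 32 KiB, matching the client's CHUNK_SIZE
    
      /** Describe the sequence of payloads the client would send for the given artifact sizes. */
      def plan(artifactSizes: Seq[Long]): Seq[String] = {
        val payloads = Seq.newBuilder[String]
        var batchCount = 0
        var batchBytes = 0L
    
        def flushBatch(): Unit = if (batchCount > 0) {
          payloads += s"batch of $batchCount artifact(s) ($batchBytes bytes)"
          batchCount = 0
          batchBytes = 0L
        }
    
        artifactSizes.foreach { size =>
          if (size > ChunkSize) {
            // A large artifact is never mixed into a batch: flush, then send it chunked.
            flushBatch()
            val numChunks = (size + ChunkSize - 1) / ChunkSize
            payloads += s"chunked artifact ($numChunks chunks)"
          } else {
            // Small artifacts are batched; flush first if the batch would overflow.
            if (batchBytes + size > ChunkSize) flushBatch()
            batchCount += 1
            batchBytes += size
          }
        }
        flushBatch()
        payloads.result()
      }
    
      def main(args: Array[String]): Unit = {
        // Mirrors the "Mix of SingleChunkArtifact and chunked artifact" test in ArtifactSuite:
        // a 424-byte class file, the 384581-byte junitLargeJar.jar, another 424-byte class
        // file and a 787-byte jar yield: batch(1), chunked(12 chunks), batch(2).
        println(plan(Seq(424L, 384581L, 424L, 787L)))
      }
    }
    ```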
    
    ### Why are the changes needed?
    
    In the decoupled client-server architecture of Spark Connect, a remote client may reference a local JAR or a newly defined class in a UDF that is not present on the server. To handle these cases of missing "artifacts", this PR implements a mechanism to transfer artifacts from the client to the server, following the protocol defined in https://github.com/apache/spark/pull/40147.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users can use the `addArtifact` and `addArtifacts` methods (via a `SparkSession` instance) to transfer local files with `.jar` and `.class` extensions.
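    
    A minimal usage sketch (the file paths are placeholders, and how the `SparkSession` is obtained is outside the scope of this patch):
    
    ```scala
    import java.net.URI
    
    import org.apache.spark.sql.SparkSession
    
    object AddArtifactExample {
      // `spark` is assumed to be an already-constructed Spark Connect session;
      // the paths below are illustrative placeholders.
      def registerUdfArtifacts(spark: SparkSession): Unit = {
        spark.addArtifact("/tmp/my-udfs.jar")                 // local .jar by path
        spark.addArtifact(new URI("file:///tmp/MyUdf.class")) // local .class by URI
        spark.addArtifacts(                                   // several URIs at once
          new URI("file:///tmp/helpers.jar"),
          new URI("file:///tmp/AnotherUdf.class"))
      }
    }
    ```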
    
    ### How was this patch tested?
    
    Unit tests - located in `ArtifactSuite`.
    
    Closes #40256 from vicennial/SPARK-42653.
    
    Authored-by: vicennial <ve...@databricks.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
    (cherry picked from commit 8a0d6261c00d35cb174d2a68142f05aae364f59b)
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |  32 +++
 .../spark/sql/connect/client/ArtifactManager.scala | 305 +++++++++++++++++++++
 .../sql/connect/client/SparkConnectClient.scala    |  23 ++
 .../test/resources/artifact-tests/crc/README.md    |   5 +
 .../resources/artifact-tests/crc/junitLargeJar.txt |  12 +
 .../artifact-tests/crc/smallClassFile.txt          |   1 +
 .../artifact-tests/crc/smallClassFileDup.txt       |   1 +
 .../test/resources/artifact-tests/crc/smallJar.txt |   1 +
 .../resources/artifact-tests/junitLargeJar.jar     | Bin 0 -> 384581 bytes
 .../resources/artifact-tests/smallClassFile.class  | Bin 0 -> 424 bytes
 .../artifact-tests/smallClassFileDup.class         | Bin 0 -> 424 bytes
 .../src/test/resources/artifact-tests/smallJar.jar | Bin 0 -> 787 bytes
 .../apache/spark/sql/PlanGenerationTestSuite.scala |  22 +-
 .../spark/sql/connect/client/ArtifactSuite.scala   | 249 +++++++++++++++++
 .../connect/client/SparkConnectClientSuite.scala   |  23 +-
 .../sql/connect/client/util/ConnectFunSuite.scala  |  36 ++-
 .../sql/connect/service/SparkConnectService.scala  |  23 ++
 17 files changed, 710 insertions(+), 23 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index d463af68832..a8a88d63b1a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql
 
 import java.io.Closeable
+import java.net.URI
 import java.util.concurrent.TimeUnit._
 import java.util.concurrent.atomic.AtomicLong
 
@@ -417,6 +418,37 @@ class SparkSession private[sql] (
     execute(command)
   }
 
+  /**
+   * Add a single artifact to the client session.
+   *
+   * Currently only local files with extensions .jar and .class are supported.
+   *
+   * @since 3.4.0
+   */
+  @Experimental
+  def addArtifact(path: String): Unit = client.addArtifact(path)
+
+  /**
+   * Add a single artifact to the client session.
+   *
+   * Currently only local files with extensions .jar and .class are supported.
+   *
+   * @since 3.4.0
+   */
+  @Experimental
+  def addArtifact(uri: URI): Unit = client.addArtifact(uri)
+
+  /**
+   * Add one or more artifacts to the session.
+   *
+   * Currently only local files with extensions .jar and .class are supported.
+   *
+   * @since 3.4.0
+   */
+  @Experimental
+  @scala.annotation.varargs
+  def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri)
+
   /**
    * This resets the plan id generator so we can produce plans that are comparable.
    *
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
new file mode 100644
index 00000000000..ead500a53e6
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import java.io.InputStream
+import java.net.URI
+import java.nio.file.{Files, Path, Paths}
+import java.util.zip.{CheckedInputStream, CRC32}
+
+import scala.collection.mutable
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
+import scala.util.control.NonFatal
+
+import Artifact._
+import com.google.protobuf.ByteString
+import io.grpc.ManagedChannel
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.AddArtifactsResponse
+import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * The Artifact Manager is responsible for handling and transferring artifacts from the local
+ * client to the server (local/remote).
+ * @param userContext the user context attached to every [[proto.AddArtifactsRequest]]
+ * @param channel the gRPC channel used to reach the Spark Connect server
+ */
+class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
+  // Using the midpoint recommendation of 32KiB for chunk size as specified in
+  // https://github.com/grpc/grpc.github.io/issues/371.
+  private val CHUNK_SIZE: Int = 32 * 1024
+
+  private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel)
+
+  /**
+   * Add a single artifact to the session.
+   *
+   * Currently only local files with extensions .jar and .class are supported.
+   */
+  def addArtifact(path: String): Unit = {
+    addArtifact(Utils.resolveURI(path))
+  }
+
+  private def parseArtifacts(uri: URI): Seq[Artifact] = {
+    // Currently only local files with extensions .jar and .class are supported.
+    uri.getScheme match {
+      case "file" =>
+        val path = Paths.get(uri)
+        val artifact = path.getFileName.toString match {
+          case jar if jar.endsWith(".jar") =>
+            newJarArtifact(path.getFileName, new LocalFile(path))
+          case cf if cf.endsWith(".class") =>
+            newClassArtifact(path.getFileName, new LocalFile(path))
+          case other =>
+            throw new UnsupportedOperationException(s"Unsupported file format: $other")
+        }
+        Seq[Artifact](artifact)
+
+      case other =>
+        throw new UnsupportedOperationException(s"Unsupported scheme: $other")
+    }
+  }
+
+  /**
+   * Add a single artifact to the session.
+   *
+   * Currently only local files with extensions .jar and .class are supported.
+   */
+  def addArtifact(uri: URI): Unit = addArtifacts(parseArtifacts(uri))
+
+  /**
+   * Add multiple artifacts to the session.
+   *
+   * Currently only local files with extensions .jar and .class are supported.
+   */
+  def addArtifacts(uris: Seq[URI]): Unit = addArtifacts(uris.flatMap(parseArtifacts))
+
+  /**
+   * Add a number of artifacts to the session.
+   */
+  private def addArtifacts(artifacts: Iterable[Artifact]): Unit = {
+    val promise = Promise[Seq[ArtifactSummary]]
+    val responseHandler = new StreamObserver[proto.AddArtifactsResponse] {
+      private val summaries = mutable.Buffer.empty[ArtifactSummary]
+      override def onNext(v: AddArtifactsResponse): Unit = {
+        v.getArtifactsList.forEach { summary =>
+          summaries += summary
+        }
+      }
+      override def onError(throwable: Throwable): Unit = {
+        promise.failure(throwable)
+      }
+      override def onCompleted(): Unit = {
+        promise.success(summaries.toSeq)
+      }
+    }
+    val stream = stub.addArtifacts(responseHandler)
+    val currentBatch = mutable.Buffer.empty[Artifact]
+    var currentBatchSize = 0L
+
+    def addToBatch(dep: Artifact, size: Long): Unit = {
+      currentBatch += dep
+      currentBatchSize += size
+    }
+
+    def writeBatch(): Unit = {
+      addBatchedArtifacts(currentBatch.toSeq, stream)
+      currentBatch.clear()
+      currentBatchSize = 0
+    }
+
+    artifacts.iterator.foreach { artifact =>
+      val data = artifact.storage
+      val size = data.size
+      if (size > CHUNK_SIZE) {
+        // Payload can either be a batch OR a single chunked artifact. Write batch if non-empty
+        // before chunking current artifact.
+        if (currentBatch.nonEmpty) {
+          writeBatch()
+        }
+        addChunkedArtifact(artifact, stream)
+      } else {
+        if (currentBatchSize + size > CHUNK_SIZE) {
+          writeBatch()
+        }
+        addToBatch(artifact, size)
+      }
+    }
+    if (currentBatch.nonEmpty) {
+      writeBatch()
+    }
+    stream.onCompleted()
+    ThreadUtils.awaitResult(promise.future, Duration.Inf)
+    // TODO(SPARK-42658): Handle responses containing CRC failures.
+  }
+
+  /**
+   * Add a batch of artifacts to the stream. All the artifacts in this call are packaged into a
+   * single [[proto.AddArtifactsRequest]].
+   */
+  private def addBatchedArtifacts(
+      artifacts: Seq[Artifact],
+      stream: StreamObserver[proto.AddArtifactsRequest]): Unit = {
+    val builder = proto.AddArtifactsRequest
+      .newBuilder()
+      .setUserContext(userContext)
+    artifacts.foreach { artifact =>
+      val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
+      try {
+        val data = proto.AddArtifactsRequest.ArtifactChunk
+          .newBuilder()
+          .setData(ByteString.readFrom(in))
+          .setCrc(in.getChecksum.getValue)
+
+        builder.getBatchBuilder
+          .addArtifactsBuilder()
+          .setName(artifact.path.toString)
+          .setData(data)
+          .build()
+      } catch {
+        case NonFatal(e) =>
+          stream.onError(e)
+          throw e
+      } finally {
+        in.close()
+      }
+    }
+    stream.onNext(builder.build())
+  }
+
+  /**
+   * Read up to `CHUNK_SIZE` bytes from an [[InputStream]] and convert them to a
+   * protobuf-compatible [[ByteString]].
+   * @param in the input stream to read from
+   * @return the next chunk of data, or an empty [[ByteString]] once the stream is exhausted
+   */
+  private def readNextChunk(in: InputStream): ByteString = {
+    val buf = new Array[Byte](CHUNK_SIZE)
+    var bytesRead = 0
+    var count = 0
+    while (count != -1 && bytesRead < CHUNK_SIZE) {
+      count = in.read(buf, bytesRead, CHUNK_SIZE - bytesRead)
+      if (count != -1) {
+        bytesRead += count
+      }
+    }
+    if (bytesRead == 0) ByteString.empty()
+    else ByteString.copyFrom(buf, 0, bytesRead)
+  }
+
+  /**
+   * Add an artifact in chunks to the stream. The artifact's data is spread out over multiple
+   * [[proto.AddArtifactsRequest requests]].
+   */
+  private def addChunkedArtifact(
+      artifact: Artifact,
+      stream: StreamObserver[proto.AddArtifactsRequest]): Unit = {
+    val builder = proto.AddArtifactsRequest
+      .newBuilder()
+      .setUserContext(userContext)
+
+    val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
+    try {
+      // First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`).
+      // Subsequent RPCs contain the `ArtifactChunk` payload (`chunk`).
+      val artifactChunkBuilder = proto.AddArtifactsRequest.ArtifactChunk.newBuilder()
+      var dataChunk = readNextChunk(in)
+      // Integer division that rounds up to the nearest whole number.
+      def getNumChunks(size: Long): Long = (size + (CHUNK_SIZE - 1)) / CHUNK_SIZE
+
+      builder.getBeginChunkBuilder
+        .setName(artifact.path.toString)
+        .setTotalBytes(artifact.size)
+        .setNumChunks(getNumChunks(artifact.size))
+        .setInitialChunk(
+          artifactChunkBuilder
+            .setData(dataChunk)
+            .setCrc(in.getChecksum.getValue))
+      stream.onNext(builder.build())
+      in.getChecksum.reset()
+      builder.clearBeginChunk()
+
+      dataChunk = readNextChunk(in)
+      // Consume stream in chunks until there is no data left to read.
+      while (!dataChunk.isEmpty) {
+        artifactChunkBuilder.setData(dataChunk).setCrc(in.getChecksum.getValue)
+        builder.setChunk(artifactChunkBuilder.build())
+        stream.onNext(builder.build())
+        in.getChecksum.reset()
+        builder.clearChunk()
+        dataChunk = readNextChunk(in)
+      }
+    } catch {
+      case NonFatal(e) =>
+        stream.onError(e)
+        throw e
+    } finally {
+      in.close()
+    }
+  }
+}
+
+class Artifact private (val path: Path, val storage: LocalData) {
+  require(!path.isAbsolute, s"Bad path: $path")
+
+  lazy val size: Long = storage match {
+    case localData: LocalData => localData.size
+  }
+}
+
+object Artifact {
+  val CLASS_PREFIX: Path = Paths.get("classes")
+  val JAR_PREFIX: Path = Paths.get("jars")
+
+  def newJarArtifact(fileName: Path, storage: LocalData): Artifact = {
+    newArtifact(JAR_PREFIX, ".jar", fileName, storage)
+  }
+
+  def newClassArtifact(fileName: Path, storage: LocalData): Artifact = {
+    newArtifact(CLASS_PREFIX, ".class", fileName, storage)
+  }
+
+  private def newArtifact(
+      prefix: Path,
+      requiredSuffix: String,
+      fileName: Path,
+      storage: LocalData): Artifact = {
+    require(!fileName.isAbsolute)
+    require(fileName.toString.endsWith(requiredSuffix))
+    new Artifact(prefix.resolve(fileName), storage)
+  }
+
+  /**
+   * Payload stored on this machine.
+   */
+  sealed trait LocalData {
+    def stream: InputStream
+    def size: Long
+  }
+
+  /**
+   * Payload stored in a local file.
+   */
+  class LocalFile(val path: Path) extends LocalData {
+    override def size: Long = Files.size(path)
+    override def stream: InputStream = Files.newInputStream(path)
+  }
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index cdc0b381a44..599aab441de 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -36,6 +36,8 @@ private[sql] class SparkConnectClient(
 
   private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
 
+  private[client] val artifactManager: ArtifactManager = new ArtifactManager(userContext, channel)
+
   /**
    * Placeholder method.
    * @return
@@ -147,6 +149,27 @@ private[sql] class SparkConnectClient(
     analyze(request)
   }
 
+  /**
+   * Add a single artifact to the client session.
+   *
+   * Currently only local files with extensions .jar and .class are supported.
+   */
+  def addArtifact(path: String): Unit = artifactManager.addArtifact(path)
+
+  /**
+   * Add a single artifact to the client session.
+   *
+   * Currently only local files with extensions .jar and .class are supported.
+   */
+  def addArtifact(uri: URI): Unit = artifactManager.addArtifact(uri)
+
+  /**
+   * Add multiple artifacts to the session.
+   *
+   * Currently only local files with extensions .jar and .class are supported.
+   */
+  def addArtifacts(uri: Seq[URI]): Unit = artifactManager.addArtifacts(uri)
+
   /**
    * Shutdown the client's connection to the server.
    */
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/README.md b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/README.md
new file mode 100644
index 00000000000..df9af410644
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/README.md
@@ -0,0 +1,5 @@
+The CRCs for a specific file are stored in a text file with the same name (excluding the original extension).
+
+The CRCs are calculated for data chunks of `32768 bytes` (individual CRCs) and are newline delimited.
+
+The CRCs were calculated using https://simplycalc.com/crc32-file.php
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/junitLargeJar.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/junitLargeJar.txt
new file mode 100644
index 00000000000..3e89631dea5
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/junitLargeJar.txt
@@ -0,0 +1,12 @@
+902183889
+2415704507
+1084811487
+1951510
+1158852476
+2003120166
+3026803842
+3850244775
+3409267044
+652109216
+104029242
+3019434266
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFile.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFile.txt
new file mode 100644
index 00000000000..531f98ce9a2
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFile.txt
@@ -0,0 +1 @@
+1935693963
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFileDup.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFileDup.txt
new file mode 100644
index 00000000000..531f98ce9a2
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFileDup.txt
@@ -0,0 +1 @@
+1935693963
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallJar.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallJar.txt
new file mode 100644
index 00000000000..df32adcce7a
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallJar.txt
@@ -0,0 +1 @@
+1631702900
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/junitLargeJar.jar b/connector/connect/client/jvm/src/test/resources/artifact-tests/junitLargeJar.jar
new file mode 100755
index 00000000000..6da55d8b852
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/junitLargeJar.jar differ
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFile.class b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFile.class
new file mode 100755
index 00000000000..e796030e471
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFile.class differ
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFileDup.class b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFileDup.class
new file mode 100755
index 00000000000..e796030e471
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFileDup.class differ
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/smallJar.jar b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallJar.jar
new file mode 100755
index 00000000000..3c4930e8e95
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallJar.jar differ
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
old mode 100644
new mode 100755
index 67dc92a7472..6e9583ae725
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -68,27 +68,7 @@ class PlanGenerationTestSuite
   // Borrowed from SparkFunSuite
   private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"
 
-  // Borrowed from SparkFunSuite
-  private def getWorkspaceFilePath(first: String, more: String*): Path = {
-    if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) {
-      fail("spark.test.home or SPARK_HOME is not set.")
-    }
-    val sparkHome = sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
-    java.nio.file.Paths.get(sparkHome, first +: more: _*)
-  }
-
-  protected val baseResourcePath: Path = {
-    getWorkspaceFilePath(
-      "connector",
-      "connect",
-      "common",
-      "src",
-      "test",
-      "resources",
-      "query-tests").toAbsolutePath
-  }
-
-  protected val queryFilePath: Path = baseResourcePath.resolve("queries")
+  protected val queryFilePath: Path = commonResourcePath.resolve("queries")
 
   // A relative path to /connector/connect/server, used by `ProtoToParsedPlanTestSuite` to run
   // with the datasource.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
new file mode 100644
index 00000000000..adb2b3f1908
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import java.io.InputStream
+import java.nio.file.{Files, Path, Paths}
+import java.util.concurrent.TimeUnit
+
+import collection.JavaConverters._
+import com.google.protobuf.ByteString
+import io.grpc.{ManagedChannel, Server}
+import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.AddArtifactsRequest
+import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+
+class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
+
+  private var client: SparkConnectClient = _
+  private var service: DummySparkConnectService = _
+  private var server: Server = _
+  private var artifactManager: ArtifactManager = _
+  private var channel: ManagedChannel = _
+
+  private def startDummyServer(): Unit = {
+    service = new DummySparkConnectService()
+    server = InProcessServerBuilder
+      .forName(getClass.getName)
+      .addService(service)
+      .build()
+    server.start()
+  }
+
+  private def createArtifactManager(): Unit = {
+    channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build()
+    artifactManager = new ArtifactManager(proto.UserContext.newBuilder().build(), channel)
+  }
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    startDummyServer()
+    createArtifactManager()
+    client = null
+  }
+
+  override def afterEach(): Unit = {
+    if (server != null) {
+      server.shutdownNow()
+      assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown")
+    }
+
+    if (channel != null) {
+      channel.shutdownNow()
+    }
+
+    if (client != null) {
+      client.shutdown()
+    }
+  }
+
+  private val CHUNK_SIZE: Int = 32 * 1024
+  protected def artifactFilePath: Path = baseResourcePath.resolve("artifact-tests")
+  protected def artifactCrcPath: Path = artifactFilePath.resolve("crc")
+
+  private def getCrcValues(filePath: Path): Seq[Long] = {
+    val fileName = filePath.getFileName.toString
+    val crcFileName = fileName.split('.').head + ".txt"
+    Files
+      .readAllLines(artifactCrcPath.resolve(crcFileName))
+      .asScala
+      .map(_.toLong)
+  }
+
+  /**
+   * Check if the data sent to the server (stored in `artifactChunk`) is equivalent to the local
+   * data at `localPath`.
+   * @param artifactChunk the artifact chunk received by the dummy server
+   * @param localPath path of the local file to compare against
+   */
+  private def assertFileDataEquality(
+      artifactChunk: AddArtifactsRequest.ArtifactChunk,
+      localPath: Path): Unit = {
+    val localData = ByteString.readFrom(Files.newInputStream(localPath))
+    val expectedCrc = getCrcValues(localPath).head
+    assert(artifactChunk.getData == localData)
+    assert(artifactChunk.getCrc == expectedCrc)
+  }
+
+  private def singleChunkArtifactTest(path: String): Unit = {
+    test(s"Single Chunk Artifact - $path") {
+      val artifactPath = artifactFilePath.resolve(path)
+      artifactManager.addArtifact(artifactPath.toString)
+
+      val receivedRequests = service.getAndClearLatestAddArtifactRequests()
+      // Single `AddArtifactRequest`
+      assert(receivedRequests.size == 1)
+
+      val request = receivedRequests.head
+      assert(request.hasBatch)
+
+      val batch = request.getBatch
+      // Single artifact in batch
+      assert(batch.getArtifactsList.size() == 1)
+
+      val singleChunkArtifact = batch.getArtifacts(0)
+      val namePrefix = artifactPath.getFileName.toString match {
+        case jar if jar.endsWith(".jar") => "jars"
+        case cf if cf.endsWith(".class") => "classes"
+      }
+      assert(singleChunkArtifact.getName.equals(namePrefix + "/" + path))
+      assertFileDataEquality(singleChunkArtifact.getData, artifactPath)
+    }
+  }
+
+  singleChunkArtifactTest("smallClassFile.class")
+
+  singleChunkArtifactTest("smallJar.jar")
+
+  private def readNextChunk(in: InputStream): ByteString = {
+    val buf = new Array[Byte](CHUNK_SIZE)
+    var bytesRead = 0
+    var count = 0
+    while (count != -1 && bytesRead < CHUNK_SIZE) {
+      count = in.read(buf, bytesRead, CHUNK_SIZE - bytesRead)
+      if (count != -1) {
+        bytesRead += count
+      }
+    }
+    if (bytesRead == 0) ByteString.empty()
+    else ByteString.copyFrom(buf, 0, bytesRead)
+  }
+
+  /**
+   * Read the file at `filePath` in chunks of `CHUNK_SIZE` bytes and verify that the data and
+   * CRC of each chunk match the corresponding server-side [[AddArtifactsRequest.ArtifactChunk]]
+   * in `chunks`.
+   * @param filePath the local file the chunks were generated from
+   * @param chunks the artifact chunks received by the dummy server
+   */
+  private def checkChunksDataAndCrc(
+      filePath: Path,
+      chunks: Seq[AddArtifactsRequest.ArtifactChunk]): Unit = {
+    val in = Files.newInputStream(filePath)
+    val crcs = getCrcValues(filePath)
+    chunks.zip(crcs).foreach { case (chunk, expectedCrc) =>
+      val expectedData = readNextChunk(in)
+      // Assert equality rather than just computing it, so mismatches fail the test.
+      assert(chunk.getData == expectedData)
+      assert(chunk.getCrc == expectedCrc)
+    }
+  }
+
+  test("Chunked Artifact - junitLargeJar.jar") {
+    val artifactPath = artifactFilePath.resolve("junitLargeJar.jar")
+    artifactManager.addArtifact(artifactPath.toString)
+    // Expected chunks = roundUp( file_size / chunk_size) = 12
+    // File size of `junitLargeJar.jar` is 384581 bytes.
+    val expectedChunks = (384581 + (CHUNK_SIZE - 1)) / CHUNK_SIZE
+    val receivedRequests = service.getAndClearLatestAddArtifactRequests()
+    assert(384581 == Files.size(artifactPath))
+    assert(receivedRequests.size == expectedChunks)
+    assert(receivedRequests.head.hasBeginChunk)
+    val beginChunkRequest = receivedRequests.head.getBeginChunk
+    assert(beginChunkRequest.getName == "jars/junitLargeJar.jar")
+    assert(beginChunkRequest.getTotalBytes == 384581)
+    assert(beginChunkRequest.getNumChunks == expectedChunks)
+    val dataChunks = Seq(beginChunkRequest.getInitialChunk) ++
+      receivedRequests.drop(1).map(_.getChunk)
+    checkChunksDataAndCrc(artifactPath, dataChunks)
+  }
+
+  test("Batched SingleChunkArtifacts") {
+    val file1 = artifactFilePath.resolve("smallClassFile.class").toUri
+    val file2 = artifactFilePath.resolve("smallJar.jar").toUri
+    artifactManager.addArtifacts(Seq(file1, file2))
+    val receivedRequests = service.getAndClearLatestAddArtifactRequests()
+    // Single request containing 2 artifacts.
+    assert(receivedRequests.size == 1)
+
+    val request = receivedRequests.head
+    assert(request.hasBatch)
+
+    val batch = request.getBatch
+    assert(batch.getArtifactsList.size() == 2)
+
+    val artifacts = batch.getArtifactsList
+    assert(artifacts.get(0).getName == "classes/smallClassFile.class")
+    assert(artifacts.get(1).getName == "jars/smallJar.jar")
+
+    assertFileDataEquality(artifacts.get(0).getData, Paths.get(file1))
+    assertFileDataEquality(artifacts.get(1).getData, Paths.get(file2))
+  }
+
+  test("Mix of SingleChunkArtifact and chunked artifact") {
+    val file1 = artifactFilePath.resolve("smallClassFile.class").toUri
+    val file2 = artifactFilePath.resolve("junitLargeJar.jar").toUri
+    val file3 = artifactFilePath.resolve("smallClassFileDup.class").toUri
+    val file4 = artifactFilePath.resolve("smallJar.jar").toUri
+    artifactManager.addArtifacts(Seq(file1, file2, file3, file4))
+    val receivedRequests = service.getAndClearLatestAddArtifactRequests()
+    // There are a total of 14 requests.
+    // The 1st request contains a single artifact - smallClassFile.class (There are no
+    // other artifacts batched with it since the next one is a large multi-chunk artifact)
+    // Requests 2-13 (1-indexed) belong to the transfer of junitLargeJar.jar. This includes
+    // the first "beginning chunk" and the subsequent data chunks.
+    // The last request (14) contains both smallClassFileDup.class and smallJar.jar batched
+    // together.
+    assert(receivedRequests.size == 1 + 12 + 1)
+
+    val firstReqBatch = receivedRequests.head.getBatch.getArtifactsList
+    assert(firstReqBatch.size() == 1)
+    assert(firstReqBatch.get(0).getName == "classes/smallClassFile.class")
+    assertFileDataEquality(firstReqBatch.get(0).getData, Paths.get(file1))
+
+    val secondReq = receivedRequests(1)
+    assert(secondReq.hasBeginChunk)
+    val beginChunkRequest = secondReq.getBeginChunk
+    assert(beginChunkRequest.getName == "jars/junitLargeJar.jar")
+    assert(beginChunkRequest.getTotalBytes == 384581)
+    assert(beginChunkRequest.getNumChunks == 12)
+    // Large artifact data chunks are requests number 3 to 13.
+    val dataChunks = Seq(beginChunkRequest.getInitialChunk) ++
+      receivedRequests.drop(2).dropRight(1).map(_.getChunk)
+    checkChunksDataAndCrc(Paths.get(file2), dataChunks)
+
+    val lastBatch = receivedRequests.last.getBatch
+    assert(lastBatch.getArtifactsCount == 2)
+    val remainingArtifacts = lastBatch.getArtifactsList
+    assert(remainingArtifacts.get(0).getName == "classes/smallClassFileDup.class")
+    assert(remainingArtifacts.get(1).getName == "jars/smallJar.jar")
+
+    assertFileDataEquality(remainingArtifacts.get(0).getData, Paths.get(file3))
+    assertFileDataEquality(remainingArtifacts.get(1).getData, Paths.get(file4))
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
old mode 100644
new mode 100755
index 8cead49de0c..dcb13589206
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -22,9 +22,10 @@ import io.grpc.{Server, StatusRuntimeException}
 import io.grpc.netty.NettyServerBuilder
 import io.grpc.stub.StreamObserver
 import org.scalatest.BeforeAndAfterEach
+import scala.collection.mutable
 
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
+import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
 import org.apache.spark.sql.connect.common.config.ConnectCommon
@@ -181,6 +182,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
 class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
 
   private var inputPlan: proto.Plan = _
+  private val inputArtifactRequests: mutable.ListBuffer[AddArtifactsRequest] =
+    mutable.ListBuffer.empty
 
   private[sql] def getAndClearLatestInputPlan(): proto.Plan = {
     val plan = inputPlan
@@ -188,6 +191,12 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
     plan
   }
 
+  private[sql] def getAndClearLatestAddArtifactRequests(): Seq[AddArtifactsRequest] = {
+    val requests = inputArtifactRequests.toSeq
+    inputArtifactRequests.clear()
+    requests
+  }
+
   override def executePlan(
       request: ExecutePlanRequest,
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
@@ -229,4 +238,16 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
     responseObserver.onNext(response)
     responseObserver.onCompleted()
   }
+
+  override def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse])
+      : StreamObserver[AddArtifactsRequest] = new StreamObserver[AddArtifactsRequest] {
+    override def onNext(v: AddArtifactsRequest): Unit = inputArtifactRequests.append(v)
+
+    override def onError(throwable: Throwable): Unit = responseObserver.onError(throwable)
+
+    override def onCompleted(): Unit = {
+      responseObserver.onNext(proto.AddArtifactsResponse.newBuilder().build())
+      responseObserver.onCompleted()
+    }
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala
old mode 100644
new mode 100755
index 5100fa7d229..1ece0838b1b
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala
@@ -16,9 +16,43 @@
  */
 package org.apache.spark.sql.connect.client.util
 
+import java.nio.file.Path
+
 import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
 
 /**
  * The basic testsuite the client tests should extend from.
  */
-trait ConnectFunSuite extends AnyFunSuite {} // scalastyle:ignore funsuite
+trait ConnectFunSuite extends AnyFunSuite { // scalastyle:ignore funsuite
+
+  // Borrowed from SparkFunSuite
+  protected def getWorkspaceFilePath(first: String, more: String*): Path = {
+    if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) {
+      fail("spark.test.home or SPARK_HOME is not set.")
+    }
+    val sparkHome = sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+    java.nio.file.Paths.get(sparkHome, first +: more: _*)
+  }
+
+  protected val baseResourcePath: Path = {
+    getWorkspaceFilePath(
+      "connector",
+      "connect",
+      "client",
+      "jvm",
+      "src",
+      "test",
+      "resources").toAbsolutePath
+  }
+
+  protected val commonResourcePath: Path = {
+    getWorkspaceFilePath(
+      "connector",
+      "connect",
+      "common",
+      "src",
+      "test",
+      "resources",
+      "query-tests").toAbsolutePath
+  }
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
old mode 100644
new mode 100755
index d6446eae4b7..cd353b6ff60
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -39,6 +39,7 @@ import org.json4s.jackson.JsonMethods.{compact, render}
 import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
 import org.apache.spark.api.python.PythonException
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
@@ -179,6 +180,28 @@ class SparkConnectService(debug: Boolean)
       new SparkConnectConfigHandler(responseObserver).handle(request)
     } catch handleError("config", observer = responseObserver)
   }
+
+  /**
+   * This is the main entry method for all calls to add/transfer artifacts.
+   *
+   * @param responseObserver the observer used to stream [[AddArtifactsResponse]]s back to
+   *   the client
+   * @return a [[StreamObserver]] that consumes the incoming [[AddArtifactsRequest]] stream
+   */
+  override def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse])
+      : StreamObserver[AddArtifactsRequest] = {
+    // TODO: Handle artifact files
+    // No-Op StreamObserver
+    new StreamObserver[AddArtifactsRequest] {
+      override def onNext(v: AddArtifactsRequest): Unit = {}
+
+      override def onError(throwable: Throwable): Unit = responseObserver.onError(throwable)
+
+      override def onCompleted(): Unit = {
+        responseObserver.onNext(proto.AddArtifactsResponse.newBuilder().build())
+        responseObserver.onCompleted()
+      }
+    }
+  }
 }
 
 /**

