Posted to commits@spark.apache.org by hv...@apache.org on 2023/03/03 11:53:50 UTC
[spark] branch branch-3.4 updated: [SPARK-42653][CONNECT] Artifact transfer from Scala/JVM client to Server
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 24a602dd173 [SPARK-42653][CONNECT] Artifact transfer from Scala/JVM client to Server
24a602dd173 is described below
commit 24a602dd173a914b02e8047731bd6f95e4810fd0
Author: vicennial <ve...@databricks.com>
AuthorDate: Fri Mar 3 07:52:57 2023 -0400
[SPARK-42653][CONNECT] Artifact transfer from Scala/JVM client to Server
### What changes were proposed in this pull request?
This PR introduces a mechanism to transfer artifacts (currently, local `.jar` and `.class` files) from a Spark Connect JVM/Scala client to the server. The mechanism follows the protocol defined in https://github.com/apache/spark/pull/40147 and supports batching (for multiple "small" artifacts) and chunking (for large artifacts).
Note: Server-side artifact handling is not covered in this PR.
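To make the batching/chunking split concrete, the sketch below mirrors the per-artifact decision the client makes (the 32 KiB threshold matches the `CHUNK_SIZE` constant in `ArtifactManager` below; `TransferSketch`, `LocalArtifact` and `plan` are illustrative helpers, not part of this change):
```scala
// Illustrative sketch only: artifacts up to the chunk size are accumulated into a
// batch request; anything larger is flushed and sent on its own as a chunked transfer.
object TransferSketch {
  val ChunkSize: Long = 32 * 1024

  final case class LocalArtifact(name: String, size: Long)

  // Returns a description of the requests that would be sent, in order.
  def plan(artifacts: Seq[LocalArtifact]): Seq[String] = {
    val out = Seq.newBuilder[String]
    var batch = Vector.empty[LocalArtifact]
    var batchSize = 0L

    def flush(): Unit = if (batch.nonEmpty) {
      out += s"batch(${batch.map(_.name).mkString(", ")})"
      batch = Vector.empty
      batchSize = 0L
    }

    artifacts.foreach { a =>
      if (a.size > ChunkSize) {
        // A payload is either a batch or a single chunked artifact, never both.
        flush()
        val numChunks = (a.size + ChunkSize - 1) / ChunkSize
        out += s"chunked(${a.name}, chunks=$numChunks)"
      } else {
        if (batchSize + a.size > ChunkSize) flush()
        batch :+= a
        batchSize += a.size
      }
    }
    flush()
    out.result()
  }

  def main(args: Array[String]): Unit = {
    // Expected output: List(batch(small.class), chunked(large.jar, chunks=12), batch(tiny.jar))
    println(plan(Seq(
      LocalArtifact("small.class", 424),
      LocalArtifact("large.jar", 384581),
      LocalArtifact("tiny.jar", 787))))
  }
}
```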
### Why are the changes needed?
In the decoupled client-server architecture of Spark Connect, a remote client may reference a local JAR or a newly defined class in its UDFs that is not present on the server. To handle these cases of missing "artifacts", this PR implements a mechanism to transfer artifacts from the client side to the server side, following the protocol defined in https://github.com/apache/spark/pull/40147.
### Does this PR introduce _any_ user-facing change?
Yes. Users can now call the `addArtifact` and `addArtifacts` methods on a `SparkSession` instance to transfer local `.jar` and `.class` files.
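For example, a minimal usage sketch (`spark` is assumed to be a Spark Connect `SparkSession` built by the JVM client, and the file paths are hypothetical):
```scala
import java.net.URI

// Single artifact, by local path or by file URI.
spark.addArtifact("/path/to/my-udfs.jar")
spark.addArtifact(new URI("file:///path/to/MyUdf.class"))

// Several artifacts at once: small files are batched into a single request,
// files larger than the chunk size are streamed as a chunked transfer.
spark.addArtifacts(
  new URI("file:///path/to/my-udfs.jar"),
  new URI("file:///path/to/MyUdf.class"))
```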
### How was this patch tested?
Unit tests - located in `ArtifactSuite`.
Closes #40256 from vicennial/SPARK-42653.
Authored-by: vicennial <ve...@databricks.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
(cherry picked from commit 8a0d6261c00d35cb174d2a68142f05aae364f59b)
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 32 +++
.../spark/sql/connect/client/ArtifactManager.scala | 305 +++++++++++++++++++++
.../sql/connect/client/SparkConnectClient.scala | 23 ++
.../test/resources/artifact-tests/crc/README.md | 5 +
.../resources/artifact-tests/crc/junitLargeJar.txt | 12 +
.../artifact-tests/crc/smallClassFile.txt | 1 +
.../artifact-tests/crc/smallClassFileDup.txt | 1 +
.../test/resources/artifact-tests/crc/smallJar.txt | 1 +
.../resources/artifact-tests/junitLargeJar.jar | Bin 0 -> 384581 bytes
.../resources/artifact-tests/smallClassFile.class | Bin 0 -> 424 bytes
.../artifact-tests/smallClassFileDup.class | Bin 0 -> 424 bytes
.../src/test/resources/artifact-tests/smallJar.jar | Bin 0 -> 787 bytes
.../apache/spark/sql/PlanGenerationTestSuite.scala | 22 +-
.../spark/sql/connect/client/ArtifactSuite.scala | 249 +++++++++++++++++
.../connect/client/SparkConnectClientSuite.scala | 23 +-
.../sql/connect/client/util/ConnectFunSuite.scala | 36 ++-
.../sql/connect/service/SparkConnectService.scala | 23 ++
17 files changed, 710 insertions(+), 23 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index d463af68832..a8a88d63b1a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
import java.io.Closeable
+import java.net.URI
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.AtomicLong
@@ -417,6 +418,37 @@ class SparkSession private[sql] (
execute(command)
}
+ /**
+ * Add a single artifact to the client session.
+ *
+ * Currently only local files with extensions .jar and .class are supported.
+ *
+ * @since 3.4.0
+ */
+ @Experimental
+ def addArtifact(path: String): Unit = client.addArtifact(path)
+
+ /**
+ * Add a single artifact to the client session.
+ *
+ * Currently only local files with extensions .jar and .class are supported.
+ *
+ * @since 3.4.0
+ */
+ @Experimental
+ def addArtifact(uri: URI): Unit = client.addArtifact(uri)
+
+ /**
+ * Add one or more artifacts to the session.
+ *
+ * Currently only local files with extensions .jar and .class are supported.
+ *
+ * @since 3.4.0
+ */
+ @Experimental
+ @scala.annotation.varargs
+ def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri)
+
/**
* This resets the plan id generator so we can produce plans that are comparable.
*
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
new file mode 100644
index 00000000000..ead500a53e6
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import java.io.InputStream
+import java.net.URI
+import java.nio.file.{Files, Path, Paths}
+import java.util.zip.{CheckedInputStream, CRC32}
+
+import scala.collection.mutable
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
+import scala.util.control.NonFatal
+
+import Artifact._
+import com.google.protobuf.ByteString
+import io.grpc.ManagedChannel
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.AddArtifactsResponse
+import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * The Artifact Manager is responsible for handling and transferring artifacts from the local
+ * client to the server (local/remote).
+ * @param userContext user context attached to every AddArtifacts request sent by this manager
+ * @param channel gRPC channel used to reach the Spark Connect server
+ */
+class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
+ // Using the midpoint recommendation of 32KiB for chunk size as specified in
+ // https://github.com/grpc/grpc.github.io/issues/371.
+ private val CHUNK_SIZE: Int = 32 * 1024
+
+ private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel)
+
+ /**
+ * Add a single artifact to the session.
+ *
+ * Currently only local files with extensions .jar and .class are supported.
+ */
+ def addArtifact(path: String): Unit = {
+ addArtifact(Utils.resolveURI(path))
+ }
+
+ private def parseArtifacts(uri: URI): Seq[Artifact] = {
+ // Currently only local files with extensions .jar and .class are supported.
+ uri.getScheme match {
+ case "file" =>
+ val path = Paths.get(uri)
+ val artifact = path.getFileName.toString match {
+ case jar if jar.endsWith(".jar") =>
+ newJarArtifact(path.getFileName, new LocalFile(path))
+ case cf if cf.endsWith(".class") =>
+ newClassArtifact(path.getFileName, new LocalFile(path))
+ case other =>
+ throw new UnsupportedOperationException(s"Unsuppoted file format: $other")
+ }
+ Seq[Artifact](artifact)
+
+ case other =>
+ throw new UnsupportedOperationException(s"Unsupported scheme: $other")
+ }
+ }
+
+ /**
+ * Add a single artifact to the session.
+ *
+ * Currently only local files with extensions .jar and .class are supported.
+ */
+ def addArtifact(uri: URI): Unit = addArtifacts(parseArtifacts(uri))
+
+ /**
+ * Add multiple artifacts to the session.
+ *
+ * Currently only local files with extensions .jar and .class are supported.
+ */
+ def addArtifacts(uris: Seq[URI]): Unit = addArtifacts(uris.flatMap(parseArtifacts))
+
+ /**
+ * Add a number of artifacts to the session.
+ */
+ private def addArtifacts(artifacts: Iterable[Artifact]): Unit = {
+ val promise = Promise[Seq[ArtifactSummary]]
+ val responseHandler = new StreamObserver[proto.AddArtifactsResponse] {
+ private val summaries = mutable.Buffer.empty[ArtifactSummary]
+ override def onNext(v: AddArtifactsResponse): Unit = {
+ v.getArtifactsList.forEach { summary =>
+ summaries += summary
+ }
+ }
+ override def onError(throwable: Throwable): Unit = {
+ promise.failure(throwable)
+ }
+ override def onCompleted(): Unit = {
+ promise.success(summaries.toSeq)
+ }
+ }
+ val stream = stub.addArtifacts(responseHandler)
+ val currentBatch = mutable.Buffer.empty[Artifact]
+ var currentBatchSize = 0L
+
+ def addToBatch(dep: Artifact, size: Long): Unit = {
+ currentBatch += dep
+ currentBatchSize += size
+ }
+
+ def writeBatch(): Unit = {
+ addBatchedArtifacts(currentBatch.toSeq, stream)
+ currentBatch.clear()
+ currentBatchSize = 0
+ }
+
+ artifacts.iterator.foreach { artifact =>
+ val data = artifact.storage
+ val size = data.size
+ if (size > CHUNK_SIZE) {
+ // Payload can either be a batch OR a single chunked artifact. Write batch if non-empty
+ // before chunking current artifact.
+ if (currentBatch.nonEmpty) {
+ writeBatch()
+ }
+ addChunkedArtifact(artifact, stream)
+ } else {
+ if (currentBatchSize + size > CHUNK_SIZE) {
+ writeBatch()
+ }
+ addToBatch(artifact, size)
+ }
+ }
+ if (currentBatch.nonEmpty) {
+ writeBatch()
+ }
+ stream.onCompleted()
+ ThreadUtils.awaitResult(promise.future, Duration.Inf)
+ // TODO(SPARK-42658): Handle responses containing CRC failures.
+ }
+
+ /**
+ * Add a batch of artifacts to the stream. All the artifacts in this call are packaged into a
+ * single [[proto.AddArtifactsRequest]].
+ */
+ private def addBatchedArtifacts(
+ artifacts: Seq[Artifact],
+ stream: StreamObserver[proto.AddArtifactsRequest]): Unit = {
+ val builder = proto.AddArtifactsRequest
+ .newBuilder()
+ .setUserContext(userContext)
+ artifacts.foreach { artifact =>
+ val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
+ try {
+ val data = proto.AddArtifactsRequest.ArtifactChunk
+ .newBuilder()
+ .setData(ByteString.readFrom(in))
+ .setCrc(in.getChecksum.getValue)
+
+ builder.getBatchBuilder
+ .addArtifactsBuilder()
+ .setName(artifact.path.toString)
+ .setData(data)
+ .build()
+ } catch {
+ case NonFatal(e) =>
+ stream.onError(e)
+ throw e
+ } finally {
+ in.close()
+ }
+ }
+ stream.onNext(builder.build())
+ }
+
+ /**
+ * Read data from an [[InputStream]] in pieces of `CHUNK_SIZE` bytes and convert them to a
+ * protobuf-compatible [[ByteString]].
+ * @param in the input stream to read from
+ * @return the next chunk as a [[ByteString]], or an empty [[ByteString]] once the stream is exhausted
+ */
+ private def readNextChunk(in: InputStream): ByteString = {
+ val buf = new Array[Byte](CHUNK_SIZE)
+ var bytesRead = 0
+ var count = 0
+ while (count != -1 && bytesRead < CHUNK_SIZE) {
+ count = in.read(buf, bytesRead, CHUNK_SIZE - bytesRead)
+ if (count != -1) {
+ bytesRead += count
+ }
+ }
+ if (bytesRead == 0) ByteString.empty()
+ else ByteString.copyFrom(buf, 0, bytesRead)
+ }
+
+ /**
+ * Add an artifact in chunks to the stream. The artifact's data is spread out over multiple
+ * [[proto.AddArtifactsRequest requests]].
+ */
+ private def addChunkedArtifact(
+ artifact: Artifact,
+ stream: StreamObserver[proto.AddArtifactsRequest]): Unit = {
+ val builder = proto.AddArtifactsRequest
+ .newBuilder()
+ .setUserContext(userContext)
+
+ val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
+ try {
+ // First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`).
+ // Subsequent RPCs contain the `ArtifactChunk` payload (`chunk`).
+ val artifactChunkBuilder = proto.AddArtifactsRequest.ArtifactChunk.newBuilder()
+ var dataChunk = readNextChunk(in)
+ // Integer division that rounds up to the nearest whole number.
+ def getNumChunks(size: Long): Long = (size + (CHUNK_SIZE - 1)) / CHUNK_SIZE
+
+ builder.getBeginChunkBuilder
+ .setName(artifact.path.toString)
+ .setTotalBytes(artifact.size)
+ .setNumChunks(getNumChunks(artifact.size))
+ .setInitialChunk(
+ artifactChunkBuilder
+ .setData(dataChunk)
+ .setCrc(in.getChecksum.getValue))
+ stream.onNext(builder.build())
+ in.getChecksum.reset()
+ builder.clearBeginChunk()
+
+ dataChunk = readNextChunk(in)
+ // Consume stream in chunks until there is no data left to read.
+ while (!dataChunk.isEmpty) {
+ artifactChunkBuilder.setData(dataChunk).setCrc(in.getChecksum.getValue)
+ builder.setChunk(artifactChunkBuilder.build())
+ stream.onNext(builder.build())
+ in.getChecksum.reset()
+ builder.clearChunk()
+ dataChunk = readNextChunk(in)
+ }
+ } catch {
+ case NonFatal(e) =>
+ stream.onError(e)
+ throw e
+ } finally {
+ in.close()
+ }
+ }
+}
+
+class Artifact private (val path: Path, val storage: LocalData) {
+ require(!path.isAbsolute, s"Bad path: $path")
+
+ lazy val size: Long = storage match {
+ case localData: LocalData => localData.size
+ }
+}
+
+object Artifact {
+ val CLASS_PREFIX: Path = Paths.get("classes")
+ val JAR_PREFIX: Path = Paths.get("jars")
+
+ def newJarArtifact(fileName: Path, storage: LocalData): Artifact = {
+ newArtifact(JAR_PREFIX, ".jar", fileName, storage)
+ }
+
+ def newClassArtifact(fileName: Path, storage: LocalData): Artifact = {
+ newArtifact(CLASS_PREFIX, ".class", fileName, storage)
+ }
+
+ private def newArtifact(
+ prefix: Path,
+ requiredSuffix: String,
+ fileName: Path,
+ storage: LocalData): Artifact = {
+ require(!fileName.isAbsolute)
+ require(fileName.toString.endsWith(requiredSuffix))
+ new Artifact(prefix.resolve(fileName), storage)
+ }
+
+ /**
+ * Payload stored on this machine.
+ */
+ sealed trait LocalData {
+ def stream: InputStream
+ def size: Long
+ }
+
+ /**
+ * Payload stored in a local file.
+ */
+ class LocalFile(val path: Path) extends LocalData {
+ override def size: Long = Files.size(path)
+ override def stream: InputStream = Files.newInputStream(path)
+ }
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index cdc0b381a44..599aab441de 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -36,6 +36,8 @@ private[sql] class SparkConnectClient(
private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
+ private[client] val artifactManager: ArtifactManager = new ArtifactManager(userContext, channel)
+
/**
* Placeholder method.
* @return
@@ -147,6 +149,27 @@ private[sql] class SparkConnectClient(
analyze(request)
}
+ /**
+ * Add a single artifact to the client session.
+ *
+ * Currently only local files with extensions .jar and .class are supported.
+ */
+ def addArtifact(path: String): Unit = artifactManager.addArtifact(path)
+
+ /**
+ * Add a single artifact to the client session.
+ *
+ * Currently only local files with extensions .jar and .class are supported.
+ */
+ def addArtifact(uri: URI): Unit = artifactManager.addArtifact(uri)
+
+ /**
+ * Add multiple artifacts to the session.
+ *
+ * Currently only local files with extensions .jar and .class are supported.
+ */
+ def addArtifacts(uri: Seq[URI]): Unit = artifactManager.addArtifacts(uri)
+
/**
* Shutdown the client's connection to the server.
*/
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/README.md b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/README.md
new file mode 100644
index 00000000000..df9af410644
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/README.md
@@ -0,0 +1,5 @@
+The CRCs for a specific file are stored in a text file with the same name (excluding the original extension).
+
+The CRCs are calculated for data chunks of `32768 bytes` (individual CRCs) and are newline delimited.
+
+The CRCs were calculated using https://simplycalc.com/crc32-file.php
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/junitLargeJar.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/junitLargeJar.txt
new file mode 100644
index 00000000000..3e89631dea5
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/junitLargeJar.txt
@@ -0,0 +1,12 @@
+902183889
+2415704507
+1084811487
+1951510
+1158852476
+2003120166
+3026803842
+3850244775
+3409267044
+652109216
+104029242
+3019434266
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFile.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFile.txt
new file mode 100644
index 00000000000..531f98ce9a2
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFile.txt
@@ -0,0 +1 @@
+1935693963
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFileDup.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFileDup.txt
new file mode 100644
index 00000000000..531f98ce9a2
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFileDup.txt
@@ -0,0 +1 @@
+1935693963
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallJar.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallJar.txt
new file mode 100644
index 00000000000..df32adcce7a
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallJar.txt
@@ -0,0 +1 @@
+1631702900
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/junitLargeJar.jar b/connector/connect/client/jvm/src/test/resources/artifact-tests/junitLargeJar.jar
new file mode 100755
index 00000000000..6da55d8b852
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/junitLargeJar.jar differ
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFile.class b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFile.class
new file mode 100755
index 00000000000..e796030e471
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFile.class differ
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFileDup.class b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFileDup.class
new file mode 100755
index 00000000000..e796030e471
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFileDup.class differ
diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/smallJar.jar b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallJar.jar
new file mode 100755
index 00000000000..3c4930e8e95
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallJar.jar differ
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
old mode 100644
new mode 100755
index 67dc92a7472..6e9583ae725
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -68,27 +68,7 @@ class PlanGenerationTestSuite
// Borrowed from SparkFunSuite
private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"
- // Borrowed from SparkFunSuite
- private def getWorkspaceFilePath(first: String, more: String*): Path = {
- if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) {
- fail("spark.test.home or SPARK_HOME is not set.")
- }
- val sparkHome = sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
- java.nio.file.Paths.get(sparkHome, first +: more: _*)
- }
-
- protected val baseResourcePath: Path = {
- getWorkspaceFilePath(
- "connector",
- "connect",
- "common",
- "src",
- "test",
- "resources",
- "query-tests").toAbsolutePath
- }
-
- protected val queryFilePath: Path = baseResourcePath.resolve("queries")
+ protected val queryFilePath: Path = commonResourcePath.resolve("queries")
// A relative path to /connector/connect/server, used by `ProtoToParsedPlanTestSuite` to run
// with the datasource.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
new file mode 100644
index 00000000000..adb2b3f1908
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import java.io.InputStream
+import java.nio.file.{Files, Path, Paths}
+import java.util.concurrent.TimeUnit
+
+import collection.JavaConverters._
+import com.google.protobuf.ByteString
+import io.grpc.{ManagedChannel, Server}
+import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.AddArtifactsRequest
+import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+
+class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
+
+ private var client: SparkConnectClient = _
+ private var service: DummySparkConnectService = _
+ private var server: Server = _
+ private var artifactManager: ArtifactManager = _
+ private var channel: ManagedChannel = _
+
+ private def startDummyServer(): Unit = {
+ service = new DummySparkConnectService()
+ server = InProcessServerBuilder
+ .forName(getClass.getName)
+ .addService(service)
+ .build()
+ server.start()
+ }
+
+ private def createArtifactManager(): Unit = {
+ channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build()
+ artifactManager = new ArtifactManager(proto.UserContext.newBuilder().build(), channel)
+ }
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ startDummyServer()
+ createArtifactManager()
+ client = null
+ }
+
+ override def afterEach(): Unit = {
+ if (server != null) {
+ server.shutdownNow()
+ assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown")
+ }
+
+ if (channel != null) {
+ channel.shutdownNow()
+ }
+
+ if (client != null) {
+ client.shutdown()
+ }
+ }
+
+ private val CHUNK_SIZE: Int = 32 * 1024
+ protected def artifactFilePath: Path = baseResourcePath.resolve("artifact-tests")
+ protected def artifactCrcPath: Path = artifactFilePath.resolve("crc")
+
+ private def getCrcValues(filePath: Path): Seq[Long] = {
+ val fileName = filePath.getFileName.toString
+ val crcFileName = fileName.split('.').head + ".txt"
+ Files
+ .readAllLines(artifactCrcPath.resolve(crcFileName))
+ .asScala
+ .map(_.toLong)
+ }
+
+ /**
+ * Check if the data sent to the server (stored in `artifactChunk`) is equivalent to the local
+ * data at `localPath`.
+ * @param artifactChunk
+ * @param localPath
+ */
+ private def assertFileDataEquality(
+ artifactChunk: AddArtifactsRequest.ArtifactChunk,
+ localPath: Path): Unit = {
+ val localData = ByteString.readFrom(Files.newInputStream(localPath))
+ val expectedCrc = getCrcValues(localPath).head
+ assert(artifactChunk.getData == localData)
+ assert(artifactChunk.getCrc == expectedCrc)
+ }
+
+ private def singleChunkArtifactTest(path: String): Unit = {
+ test(s"Single Chunk Artifact - $path") {
+ val artifactPath = artifactFilePath.resolve(path)
+ artifactManager.addArtifact(artifactPath.toString)
+
+ val receivedRequests = service.getAndClearLatestAddArtifactRequests()
+ // Single `AddArtifactRequest`
+ assert(receivedRequests.size == 1)
+
+ val request = receivedRequests.head
+ assert(request.hasBatch)
+
+ val batch = request.getBatch
+ // Single artifact in batch
+ assert(batch.getArtifactsList.size() == 1)
+
+ val singleChunkArtifact = batch.getArtifacts(0)
+ val namePrefix = artifactPath.getFileName.toString match {
+ case jar if jar.endsWith(".jar") => "jars"
+ case cf if cf.endsWith(".class") => "classes"
+ }
+ assert(singleChunkArtifact.getName.equals(namePrefix + "/" + path))
+ assertFileDataEquality(singleChunkArtifact.getData, artifactPath)
+ }
+ }
+
+ singleChunkArtifactTest("smallClassFile.class")
+
+ singleChunkArtifactTest("smallJar.jar")
+
+ private def readNextChunk(in: InputStream): ByteString = {
+ val buf = new Array[Byte](CHUNK_SIZE)
+ var bytesRead = 0
+ var count = 0
+ while (count != -1 && bytesRead < CHUNK_SIZE) {
+ count = in.read(buf, bytesRead, CHUNK_SIZE - bytesRead)
+ if (count != -1) {
+ bytesRead += count
+ }
+ }
+ if (bytesRead == 0) ByteString.empty()
+ else ByteString.copyFrom(buf, 0, bytesRead)
+ }
+
+ /**
+ * Reads the file at `filePath` in pieces of `CHUNK_SIZE` bytes and verifies that each chunk's
+ * data and CRC match the corresponding server-side chunk received in `chunks`.
+ * @param filePath
+ * @param chunks
+ */
+ private def checkChunksDataAndCrc(
+ filePath: Path,
+ chunks: Seq[AddArtifactsRequest.ArtifactChunk]): Unit = {
+ val in = Files.newInputStream(filePath)
+ val crcs = getCrcValues(filePath)
+ chunks.zip(crcs).foreach { case (chunk, expectedCrc) =>
+ val expectedData = readNextChunk(in)
+ assert(chunk.getData == expectedData && chunk.getCrc == expectedCrc)
+ }
+ }
+
+ test("Chunked Artifact - junitLargeJar.jar") {
+ val artifactPath = artifactFilePath.resolve("junitLargeJar.jar")
+ artifactManager.addArtifact(artifactPath.toString)
+ // Expected chunks = roundUp( file_size / chunk_size) = 12
+ // File size of `junitLargeJar.jar` is 384581 bytes.
+ val expectedChunks = (384581 + (CHUNK_SIZE - 1)) / CHUNK_SIZE
+ val receivedRequests = service.getAndClearLatestAddArtifactRequests()
+ assert(384581 == Files.size(artifactPath))
+ assert(receivedRequests.size == expectedChunks)
+ assert(receivedRequests.head.hasBeginChunk)
+ val beginChunkRequest = receivedRequests.head.getBeginChunk
+ assert(beginChunkRequest.getName == "jars/junitLargeJar.jar")
+ assert(beginChunkRequest.getTotalBytes == 384581)
+ assert(beginChunkRequest.getNumChunks == expectedChunks)
+ val dataChunks = Seq(beginChunkRequest.getInitialChunk) ++
+ receivedRequests.drop(1).map(_.getChunk)
+ checkChunksDataAndCrc(artifactPath, dataChunks)
+ }
+
+ test("Batched SingleChunkArtifacts") {
+ val file1 = artifactFilePath.resolve("smallClassFile.class").toUri
+ val file2 = artifactFilePath.resolve("smallJar.jar").toUri
+ artifactManager.addArtifacts(Seq(file1, file2))
+ val receivedRequests = service.getAndClearLatestAddArtifactRequests()
+ // Single request containing 2 artifacts.
+ assert(receivedRequests.size == 1)
+
+ val request = receivedRequests.head
+ assert(request.hasBatch)
+
+ val batch = request.getBatch
+ assert(batch.getArtifactsList.size() == 2)
+
+ val artifacts = batch.getArtifactsList
+ assert(artifacts.get(0).getName == "classes/smallClassFile.class")
+ assert(artifacts.get(1).getName == "jars/smallJar.jar")
+
+ assertFileDataEquality(artifacts.get(0).getData, Paths.get(file1))
+ assertFileDataEquality(artifacts.get(1).getData, Paths.get(file2))
+ }
+
+ test("Mix of SingleChunkArtifact and chunked artifact") {
+ val file1 = artifactFilePath.resolve("smallClassFile.class").toUri
+ val file2 = artifactFilePath.resolve("junitLargeJar.jar").toUri
+ val file3 = artifactFilePath.resolve("smallClassFileDup.class").toUri
+ val file4 = artifactFilePath.resolve("smallJar.jar").toUri
+ artifactManager.addArtifacts(Seq(file1, file2, file3, file4))
+ val receivedRequests = service.getAndClearLatestAddArtifactRequests()
+ // There are a total of 14 requests.
+ // The 1st request contains a single artifact - smallClassFile.class (There are no
+ // other artifacts batched with it since the next one is large multi-chunk artifact)
+ // Requests 2-13 (1-indexed) belong to the transfer of junitLargeJar.jar. This includes
+ // the first "beginning chunk" and the subsequent data chunks.
+ // The last request (14) contains both smallClassFileDup.class and smallJar.jar batched
+ // together.
+ assert(receivedRequests.size == 1 + 12 + 1)
+
+ val firstReqBatch = receivedRequests.head.getBatch.getArtifactsList
+ assert(firstReqBatch.size() == 1)
+ assert(firstReqBatch.get(0).getName == "classes/smallClassFile.class")
+ assertFileDataEquality(firstReqBatch.get(0).getData, Paths.get(file1))
+
+ val secondReq = receivedRequests(1)
+ assert(secondReq.hasBeginChunk)
+ val beginChunkRequest = secondReq.getBeginChunk
+ assert(beginChunkRequest.getName == "jars/junitLargeJar.jar")
+ assert(beginChunkRequest.getTotalBytes == 384581)
+ assert(beginChunkRequest.getNumChunks == 12)
+ // Large artifact data chunks are requests number 3 to 13.
+ val dataChunks = Seq(beginChunkRequest.getInitialChunk) ++
+ receivedRequests.drop(2).dropRight(1).map(_.getChunk)
+ checkChunksDataAndCrc(Paths.get(file2), dataChunks)
+
+ val lastBatch = receivedRequests.last.getBatch
+ assert(lastBatch.getArtifactsCount == 2)
+ val remainingArtifacts = lastBatch.getArtifactsList
+ assert(remainingArtifacts.get(0).getName == "classes/smallClassFileDup.class")
+ assert(remainingArtifacts.get(1).getName == "jars/smallJar.jar")
+
+ assertFileDataEquality(remainingArtifacts.get(0).getData, Paths.get(file3))
+ assertFileDataEquality(remainingArtifacts.get(1).getData, Paths.get(file4))
+ }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
old mode 100644
new mode 100755
index 8cead49de0c..dcb13589206
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -22,9 +22,10 @@ import io.grpc.{Server, StatusRuntimeException}
import io.grpc.netty.NettyServerBuilder
import io.grpc.stub.StreamObserver
import org.scalatest.BeforeAndAfterEach
+import scala.collection.mutable
import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
+import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.util.ConnectFunSuite
import org.apache.spark.sql.connect.common.config.ConnectCommon
@@ -181,6 +182,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
private var inputPlan: proto.Plan = _
+ private val inputArtifactRequests: mutable.ListBuffer[AddArtifactsRequest] =
+ mutable.ListBuffer.empty
private[sql] def getAndClearLatestInputPlan(): proto.Plan = {
val plan = inputPlan
@@ -188,6 +191,12 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
plan
}
+ private[sql] def getAndClearLatestAddArtifactRequests(): Seq[AddArtifactsRequest] = {
+ val requests = inputArtifactRequests.toSeq
+ inputArtifactRequests.clear()
+ requests
+ }
+
override def executePlan(
request: ExecutePlanRequest,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
@@ -229,4 +238,16 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
responseObserver.onNext(response)
responseObserver.onCompleted()
}
+
+ override def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse])
+ : StreamObserver[AddArtifactsRequest] = new StreamObserver[AddArtifactsRequest] {
+ override def onNext(v: AddArtifactsRequest): Unit = inputArtifactRequests.append(v)
+
+ override def onError(throwable: Throwable): Unit = responseObserver.onError(throwable)
+
+ override def onCompleted(): Unit = {
+ responseObserver.onNext(proto.AddArtifactsResponse.newBuilder().build())
+ responseObserver.onCompleted()
+ }
+ }
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala
old mode 100644
new mode 100755
index 5100fa7d229..1ece0838b1b
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala
@@ -16,9 +16,43 @@
*/
package org.apache.spark.sql.connect.client.util
+import java.nio.file.Path
+
import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
/**
* The basic testsuite the client tests should extend from.
*/
-trait ConnectFunSuite extends AnyFunSuite {} // scalastyle:ignore funsuite
+trait ConnectFunSuite extends AnyFunSuite { // scalastyle:ignore funsuite
+
+ // Borrowed from SparkFunSuite
+ protected def getWorkspaceFilePath(first: String, more: String*): Path = {
+ if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) {
+ fail("spark.test.home or SPARK_HOME is not set.")
+ }
+ val sparkHome = sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+ java.nio.file.Paths.get(sparkHome, first +: more: _*)
+ }
+
+ protected val baseResourcePath: Path = {
+ getWorkspaceFilePath(
+ "connector",
+ "connect",
+ "client",
+ "jvm",
+ "src",
+ "test",
+ "resources").toAbsolutePath
+ }
+
+ protected val commonResourcePath: Path = {
+ getWorkspaceFilePath(
+ "connector",
+ "connect",
+ "common",
+ "src",
+ "test",
+ "resources",
+ "query-tests").toAbsolutePath
+ }
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
old mode 100644
new mode 100755
index d6446eae4b7..cd353b6ff60
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -39,6 +39,7 @@ import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.api.python.PythonException
import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
@@ -179,6 +180,28 @@ class SparkConnectService(debug: Boolean)
new SparkConnectConfigHandler(responseObserver).handle(request)
} catch handleError("config", observer = responseObserver)
}
+
+ /**
+ * This is the main entry method for all calls to add/transfer artifacts.
+ *
+ * @param responseObserver
+ * @return
+ */
+ override def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse])
+ : StreamObserver[AddArtifactsRequest] = {
+ // TODO: Handle artifact files
+ // No-Op StreamObserver
+ new StreamObserver[AddArtifactsRequest] {
+ override def onNext(v: AddArtifactsRequest): Unit = {}
+
+ override def onError(throwable: Throwable): Unit = responseObserver.onError(throwable)
+
+ override def onCompleted(): Unit = {
+ responseObserver.onNext(proto.AddArtifactsResponse.newBuilder().build())
+ responseObserver.onCompleted()
+ }
+ }
+ }
}
/**