You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2023/06/07 13:01:04 UTC

[spark] branch master updated: [SPARK-43993][SQL][TESTS] Add tests for cache artifacts

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fead8a7962a [SPARK-43993][SQL][TESTS] Add tests for cache artifacts
fead8a7962a is described below

commit fead8a7962a717aae5cab9eef51eed2ac684f070
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Wed Jun 7 16:00:49 2023 +0300

    [SPARK-43993][SQL][TESTS] Add tests for cache artifacts
    
    ### What changes were proposed in this pull request?
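    In this PR, I propose adding a test that checks two methods of the artifact manager:
    - `isCachedArtifact()`
    - `cacheArtifact()`
    
    To make the test possible, `isCachedArtifact()` is widened from `private` to `private[client]`, and the test's `DummySparkConnectService` now implements `artifactStatus`.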
    
    ### Why are the changes needed?
    To improve test coverage of the Artifacts API.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    By running the new test:
    ```
    $ build/sbt "test:testOnly *.ArtifactSuite"
    ```
    
    Closes #41493 from MaxGekk/test-cache-artifact.
    
    Authored-by: Max Gekk <ma...@gmail.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../spark/sql/connect/client/ArtifactManager.scala |  2 +-
 .../spark/sql/connect/client/ArtifactSuite.scala   | 14 ++++++++++++
 .../connect/client/SparkConnectClientSuite.scala   | 25 +++++++++++++++++++++-
 3 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
index acd9f279c6d..6d0d16df946 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
@@ -108,7 +108,7 @@ class ArtifactManager(
    */
   def addArtifacts(uris: Seq[URI]): Unit = addArtifacts(uris.flatMap(parseArtifacts))
 
-  private def isCachedArtifact(hash: String): Boolean = {
+  private[client] def isCachedArtifact(hash: String): Boolean = {
     val artifactName = CACHE_PREFIX + "/" + hash
     val request = proto.ArtifactStatusesRequest
       .newBuilder()
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
index 506ad3625b0..39ab0eef412 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
@@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
 import com.google.protobuf.ByteString
 import io.grpc.{ManagedChannel, Server}
 import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}
+import org.apache.commons.codec.digest.DigestUtils.sha256Hex
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.connect.proto
@@ -248,4 +249,17 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
     assertFileDataEquality(remainingArtifacts.get(0).getData, Paths.get(file3))
     assertFileDataEquality(remainingArtifacts.get(1).getData, Paths.get(file4))
   }
+
+  test("cache an artifact and check its presence") {
+    val blob = "Hello, World!".getBytes("UTF-8")
+    val expectedHash = sha256Hex(blob)
+    assert(artifactManager.isCachedArtifact(expectedHash) === false)
+    assert(artifactManager.cacheArtifact(blob) === expectedHash)
+    assert(artifactManager.isCachedArtifact(expectedHash) === true)
+    // A second cacheArtifact call for an already-cached blob must not upload it again.
+    assert(artifactManager.cacheArtifact(blob) === expectedHash)
+    val receivedRequests = service.getAndClearLatestAddArtifactRequests()
+    assert(receivedRequests.size === 1)
+    assert(receivedRequests.head.getBatch.getArtifacts(0).getName === s"cache/$expectedHash")
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 7a0ad1a9e2a..7e0b687054d 100755
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.client
 
 import java.util.concurrent.TimeUnit
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import io.grpc.{Server, StatusRuntimeException}
@@ -26,7 +27,7 @@ import io.grpc.stub.StreamObserver
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
+import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
 import org.apache.spark.sql.connect.common.config.ConnectCommon
@@ -251,4 +252,26 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
       responseObserver.onCompleted()
     }
   }
+
+  /**
+   * Reports which requested artifacts exist. Only `cache/` artifacts previously sent via
+   * AddArtifacts (inside a batch, or as the start of a chunked upload) are reported present.
+   */
+  override def artifactStatus(
+      request: ArtifactStatusesRequest,
+      responseObserver: StreamObserver[ArtifactStatusesResponse]): Unit = {
+    def wasUploaded(name: String): Boolean = inputArtifactRequests.exists { req =>
+      (req.hasBatch && req.getBatch.getArtifactsList.asScala.exists(_.getName == name)) ||
+      (req.hasBeginChunk && req.getBeginChunk.getName == name)
+    }
+    val builder = proto.ArtifactStatusesResponse.newBuilder()
+    request.getNamesList.asScala.foreach { name =>
+      // Non-cache artifacts are never reported as present by this test double.
+      val exists = name.startsWith("cache/") && wasUploaded(name)
+      val status = proto.ArtifactStatusesResponse.ArtifactStatus.newBuilder().setExists(exists)
+      builder.putStatuses(name, status.build())
+    }
+    responseObserver.onNext(builder.build())
+    responseObserver.onCompleted()
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org