Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/02 16:58:18 UTC

[spark] branch branch-3.4 updated: [SPARK-42172][CONNECT] Scala Client Mima Compatibility Tests

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 7d32904cef6 [SPARK-42172][CONNECT] Scala Client Mima Compatibility Tests
7d32904cef6 is described below

commit 7d32904cef68e363ddf1286c3c133989bb9d0904
Author: Zhen Li <zh...@users.noreply.github.com>
AuthorDate: Thu Feb 2 12:57:46 2023 -0400

    [SPARK-42172][CONNECT] Scala Client Mima Compatibility Tests
    
    ### What changes were proposed in this pull request?
    
    The Spark Connect Scala Client should provide the same API as the existing SQL API. This PR adds tests that use MiMa to verify that the binaries generated by the two modules are compatible; a condensed sketch of the check follows the list below.
    The covered APIs are:
    * `Dataset`,
    * `SparkSession` with all implemented methods,
    * `Column` with all implemented methods,
    * `DataFrame`
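
    A condensed sketch of the core check that the new `CompatibilitySuite` performs (the
    `checkCompatibility` helper and the explicit jar arguments are illustrative only; the
    actual suite resolves the jars via `IntegrationTestUtils.findJar`):

        import java.io.File

        import com.typesafe.tools.mima.core._
        import com.typesafe.tools.mima.lib.MiMaLib

        // Treat the sql jar as the "old" API and the client jar as the "new" API,
        // then fail if MiMa reports any incompatibility in the client jar.
        def checkCompatibility(sqlJar: File, clientJar: File): Unit = {
          val mima = new MiMaLib(Seq(clientJar, sqlJar))
          val problems = mima.collectProblems(sqlJar, clientJar, List.empty)
          if (problems.nonEmpty) {
            sys.error(problems.map(_.description("client")).mkString("\n"))
          }
        }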
    
    ### Why are the changes needed?
    Ensures the binary compatibility of the two APIs.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Integration tests.
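
    After building the `spark-sql` and `spark-connect-client-jvm` jars, the suite can be run
    locally with, for example:

        sbt "testOnly org.apache.spark.sql.connect.client.CompatibilitySuite"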
    
    Note: This PR needs to be merged into 3.4 too.
    
    Closes #39712 from zhenlineo/cp-test.
    
    Authored-by: Zhen Li <zh...@users.noreply.github.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
    (cherry picked from commit 15971a0e6f8fb21d5c6effa55ed0b313cd03ec07)
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 connector/connect/client/jvm/pom.xml               |   8 ++
 .../main/scala/org/apache/spark/sql/Column.scala   |   5 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala  |   9 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  31 +++--
 .../main/scala/org/apache/spark/sql/package.scala  |  22 +++
 .../sql/connect/client/CompatibilitySuite.scala    | 153 +++++++++++++++++++++
 .../connect/client/util/IntegrationTestUtils.scala |  76 ++++++++++
 .../connect/client/util/RemoteSparkSession.scala   |  55 +-------
 8 files changed, 293 insertions(+), 66 deletions(-)

diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml
index 94e49033858..341bc43a072 100644
--- a/connector/connect/client/jvm/pom.xml
+++ b/connector/connect/client/jvm/pom.xml
@@ -33,6 +33,7 @@
   <properties>
     <sbt.project.name>connect-client-jvm</sbt.project.name>
     <guava.version>31.0.1-jre</guava.version>
+    <mima.version>1.1.0</mima.version>
   </properties>
 
   <dependencies>
@@ -92,6 +93,13 @@
       <artifactId>mockito-core</artifactId>
       <scope>test</scope>
     </dependency>
+    <!-- Use mima to perform the compatibility check -->
+    <dependency>
+      <groupId>com.typesafe</groupId>
+      <artifactId>mima-core_${scala.binary.version}</artifactId>
+      <version>${mima.version}</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <build>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
index f25d579d5c3..35ea76e5d98 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 import scala.collection.JavaConverters._
 
 import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Column.fn
 import org.apache.spark.sql.connect.client.unsupported
 import org.apache.spark.sql.functions.lit
@@ -44,7 +45,7 @@ import org.apache.spark.sql.functions.lit
  *
  * @since 3.4.0
  */
-class Column private[sql] (private[sql] val expr: proto.Expression) {
+class Column private[sql] (private[sql] val expr: proto.Expression) extends Logging {
 
   /**
    * Sum of this expression and another expression.
@@ -80,7 +81,7 @@ class Column private[sql] (private[sql] val expr: proto.Expression) {
   }
 }
 
-object Column {
+private[sql] object Column {
 
   def apply(name: String): Column = Column { builder =>
     name match {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6891b2f5bed..51b734d1daa 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,7 +21,8 @@ import scala.collection.JavaConverters._
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.client.SparkResult
 
-class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
+class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan: proto.Plan)
+    extends Serializable {
 
   /**
    * Selects a set of column based expressions.
@@ -33,7 +34,7 @@ class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def select(cols: Column*): Dataset = session.newDataset { builder =>
+  def select(cols: Column*): DataFrame = session.newDataset { builder =>
     builder.getProjectBuilder
       .setInput(plan.getRoot)
       .addAllExpressions(cols.map(_.expr).asJava)
@@ -50,7 +51,7 @@ class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
    * @group typedrel
    * @since 3.4.0
    */
-  def filter(condition: Column): Dataset = session.newDataset { builder =>
+  def filter(condition: Column): Dataset[T] = session.newDataset { builder =>
     builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
   }
 
@@ -62,7 +63,7 @@ class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
    * @group typedrel
    * @since 3.4.0
    */
-  def limit(n: Int): Dataset = session.newDataset { builder =>
+  def limit(n: Int): Dataset[T] = session.newDataset { builder =>
     builder.getLimitBuilder
       .setInput(plan.getRoot)
       .setLimit(n)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 0c4f702ca34..eca5658e33d 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -16,9 +16,12 @@
  */
 package org.apache.spark.sql
 
+import java.io.Closeable
+
 import org.apache.arrow.memory.RootAllocator
 
 import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.util.Cleaner
 
@@ -43,7 +46,9 @@ import org.apache.spark.sql.connect.client.util.Cleaner
  * }}}
  */
 class SparkSession(private val client: SparkConnectClient, private val cleaner: Cleaner)
-    extends AutoCloseable {
+    extends Serializable
+    with Closeable
+    with Logging {
 
   private[this] val allocator = new RootAllocator()
 
@@ -53,7 +58,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
    *
    * @since 3.4.0
    */
-  def sql(query: String): Dataset = newDataset { builder =>
+  def sql(query: String): DataFrame = newDataset { builder =>
     builder.setSql(proto.SQL.newBuilder().setQuery(query))
   }
 
@@ -63,7 +68,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
    *
    * @since 3.4.0
    */
-  def range(end: Long): Dataset = range(0, end)
+  def range(end: Long): Dataset[java.lang.Long] = range(0, end)
 
   /**
    * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
@@ -71,7 +76,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
    *
    * @since 3.4.0
    */
-  def range(start: Long, end: Long): Dataset = {
+  def range(start: Long, end: Long): Dataset[java.lang.Long] = {
     range(start, end, step = 1)
   }
 
@@ -81,7 +86,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
    *
    * @since 3.4.0
    */
-  def range(start: Long, end: Long, step: Long): Dataset = {
+  def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
     range(start, end, step, None)
   }
 
@@ -91,11 +96,15 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
    *
    * @since 3.4.0
    */
-  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset = {
+  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
     range(start, end, step, Option(numPartitions))
   }
 
-  private def range(start: Long, end: Long, step: Long, numPartitions: Option[Int]): Dataset = {
+  private def range(
+      start: Long,
+      end: Long,
+      step: Long,
+      numPartitions: Option[Int]): Dataset[java.lang.Long] = {
     newDataset { builder =>
       val rangeBuilder = builder.getRangeBuilder
         .setStart(start)
@@ -105,11 +114,11 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
     }
   }
 
-  private[sql] def newDataset(f: proto.Relation.Builder => Unit): Dataset = {
+  private[sql] def newDataset[T](f: proto.Relation.Builder => Unit): Dataset[T] = {
     val builder = proto.Relation.newBuilder()
     f(builder)
     val plan = proto.Plan.newBuilder().setRoot(builder).build()
-    new Dataset(this, plan)
+    new Dataset[T](this, plan)
   }
 
   private[sql] def analyze(plan: proto.Plan): proto.AnalyzePlanResponse =
@@ -130,7 +139,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
 
 // The minimal builder needed to create a spark session.
 // TODO: implements all methods mentioned in the scaladoc of [[SparkSession]]
-object SparkSession {
+object SparkSession extends Logging {
   def builder(): Builder = new Builder()
 
   private lazy val cleaner = {
@@ -139,7 +148,7 @@ object SparkSession {
     cleaner
   }
 
-  class Builder() {
+  class Builder() extends Logging {
     private var _client = SparkConnectClient.builder().build()
 
     def client(client: SparkConnectClient): Builder = {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala
new file mode 100644
index 00000000000..ada94b76fcb
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+package object sql {
+  type DataFrame = Dataset[Row]
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
new file mode 100644
index 00000000000..21eed56ee78
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import java.io.File
+import java.net.URLClassLoader
+import java.util.regex.Pattern
+
+import com.typesafe.tools.mima.core._
+import com.typesafe.tools.mima.lib.MiMaLib
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._
+
+/**
+ * This test checks the binary compatibility of the connect client API against the Spark SQL API
+ * using MiMa. We did not write this check as an SBT build rule because such a rule cannot
+ * provide the same level of freedom as a test. With a test we can:
+ *   1. Specify any two jars to run the compatibility check.
+ *   1. Easily make the test automatically pick up all new methods added while the client is being
+ *      built.
+ *
+ * The test requires the following artifacts to be built before running:
+ * {{{
+ *     spark-sql
+ *     spark-connect-client-jvm
+ * }}}
+ * To build the above artifacts, use e.g. `sbt package` or `mvn clean install -DskipTests`.
+ *
+ * When debugging this test, if the client API has changed, the client jar needs to be rebuilt
+ * before running the test. An example workflow with SBT for this test:
+ *   1. The compatibility test reports an unexpected client API change.
+ *   1. Fix the offending client API.
+ *   1. Build the client jar: `sbt package`
+ *   1. Run the test again: `sbt "testOnly
+ *      org.apache.spark.sql.connect.client.CompatibilitySuite"`
+ */
+class CompatibilitySuite extends AnyFunSuite { // scalastyle:ignore funsuite
+
+  private lazy val clientJar: File =
+    findJar(
+      "connector/connect/client/jvm",
+      "spark-connect-client-jvm-assembly",
+      "spark-connect-client-jvm")
+
+  private lazy val sqlJar: File = findJar("sql/core", "spark-sql", "spark-sql")
+
+  /**
+   * MiMa takes an old jar (sql jar) and a new jar (client jar) as inputs and then reports all
+   * incompatibilities found in the new jar. The incompatibility result is then filtered using
+   * include and exclude rules. Include rules are first applied to find all client classes that
+   * need to be checked. Then exclude rules are applied to filter out all unsupported methods in
+   * the client classes.
+   */
+  test("compatibility MiMa tests") {
+    val mima = new MiMaLib(Seq(clientJar, sqlJar))
+    val allProblems = mima.collectProblems(sqlJar, clientJar, List.empty)
+    val includedRules = Seq(
+      IncludeByName("org.apache.spark.sql.Column"),
+      IncludeByName("org.apache.spark.sql.Column$"),
+      IncludeByName("org.apache.spark.sql.Dataset"),
+      // TODO(SPARK-42175) Add the Dataset object definition
+      // IncludeByName("org.apache.spark.sql.Dataset$"),
+      IncludeByName("org.apache.spark.sql.DataFrame"),
+      IncludeByName("org.apache.spark.sql.SparkSession"),
+      IncludeByName("org.apache.spark.sql.SparkSession$")) ++ includeImplementedMethods(clientJar)
+    val excludeRules = Seq(
+      // Filter unsupported rules:
+      // Two sql overloading methods are marked experimental in the API and skipped in the client.
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"),
+      // Skip all shaded dependencies in the client.
+      ProblemFilters.exclude[Problem]("org.sparkproject.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.connect.proto.*"))
+    val problems = allProblems
+      .filter { p =>
+        includedRules.exists(rule => rule(p))
+      }
+      .filter { p =>
+        excludeRules.forall(rule => rule(p))
+      }
+
+    if (problems.nonEmpty) {
+      fail(
+        s"\nComparing client jar: $clientJar\nand sql jar: $sqlJar\n" +
+          problems.map(p => p.description("client")).mkString("\n"))
+    }
+  }
+
+  test("compatibility API tests: Dataset") {
+    val clientClassLoader: URLClassLoader = new URLClassLoader(Seq(clientJar.toURI.toURL).toArray)
+    val sqlClassLoader: URLClassLoader = new URLClassLoader(Seq(sqlJar.toURI.toURL).toArray)
+
+    val clientClass = clientClassLoader.loadClass("org.apache.spark.sql.Dataset")
+    val sqlClass = sqlClassLoader.loadClass("org.apache.spark.sql.Dataset")
+
+    val newMethods = clientClass.getMethods
+    val oldMethods = sqlClass.getMethods
+
+    // For now we simply check that the new methods are a subset of the old methods.
+    newMethods
+      .map(m => m.toString)
+      .foreach(method => {
+        assert(oldMethods.map(m => m.toString).contains(method))
+      })
+  }
+
+  /**
+   * Find all methods that are implemented in the client jar. Once all major methods are
+   * implemented we can switch to include all methods under the class using ".*" e.g.
+   * "org.apache.spark.sql.Dataset.*"
+   */
+  private def includeImplementedMethods(clientJar: File): Seq[IncludeByName] = {
+    val clsNames = Seq(
+      "org.apache.spark.sql.Column",
+      // TODO(SPARK-42175) Add all overloading methods. Temporarily mute the compatibility check
+      // for the Dataset methods, as too many overloaded methods are missing.
+      // "org.apache.spark.sql.Dataset",
+      "org.apache.spark.sql.SparkSession")
+
+    val clientClassLoader: URLClassLoader = new URLClassLoader(Seq(clientJar.toURI.toURL).toArray)
+    clsNames
+      .flatMap { clsName =>
+        val cls = clientClassLoader.loadClass(clsName)
+        // all distinct method names
+        cls.getMethods.map(m => s"$clsName.${m.getName}").toSet
+      }
+      .map { fullName =>
+        IncludeByName(fullName)
+      }
+  }
+
+  private case class IncludeByName(name: String) extends ProblemFilter {
+    private[this] val pattern =
+      Pattern.compile(name.split("\\*", -1).map(Pattern.quote).mkString(".*"))
+
+    override def apply(problem: Problem): Boolean = {
+      pattern.matcher(problem.matchName.getOrElse("")).matches
+    }
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
new file mode 100644
index 00000000000..f0ae4cad679
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.util
+
+import java.io.File
+
+import org.scalatest.Assertions.fail
+
+object IntegrationTestUtils {
+
+  // System properties used for testing and debugging
+  private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
+
+  private[connect] lazy val sparkHome: String = {
+    if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) {
+      fail("spark.test.home or SPARK_HOME is not set.")
+    }
+    sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+  }
+  private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean
+
+  // Log server start/stop debug info to the console
+  // scalastyle:off println
+  private[connect] def debug(msg: String): Unit = if (isDebug) println(msg)
+  // scalastyle:on println
+  private[connect] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace()
+
+  /**
+   * Find a jar in the Spark project artifacts. It requires a build first (e.g. sbt package, mvn
+   * clean install -DskipTests) so that this method can find the jar in the target folders.
+   *
+   * @return
+   *   the jar
+   */
+  private[sql] def findJar(path: String, sbtName: String, mvnName: String): File = {
+    val targetDir = new File(new File(sparkHome, path), "target")
+    assert(
+      targetDir.exists(),
+      s"Fail to locate the target folder: '${targetDir.getCanonicalPath}'. " +
+        s"SPARK_HOME='${new File(sparkHome).getCanonicalPath}'. " +
+        "Make sure the spark project jars has been built (e.g. using sbt package)" +
+        "and the env variable `SPARK_HOME` is set correctly.")
+    val jars = recursiveListFiles(targetDir).filter { f =>
+      // SBT jar
+      (f.getParentFile.getName.startsWith("scala-") &&
+        f.getName.startsWith(sbtName) && f.getName.endsWith(".jar")) ||
+      // Maven Jar
+      (f.getParent.endsWith("target") &&
+        f.getName.startsWith(mvnName) &&
+        f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}.jar"))
+    }
+    // It is possible we found more than one: one built by maven, and another by SBT
+    assert(jars.nonEmpty, s"Failed to find the jar inside folder: ${targetDir.getCanonicalPath}")
+    debug("Using jar: " + jars(0).getCanonicalPath)
+    jars(0) // return the first jar found
+  }
+
+  private def recursiveListFiles(f: File): Array[File] = {
+    val these = f.listFiles
+    these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles)
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 552799d5229..2d9c218b2fb 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -21,16 +21,17 @@ import java.util.concurrent.TimeUnit
 
 import scala.io.Source
 
-import org.scalatest.Assertions.fail
 import org.scalatest.BeforeAndAfterAll
 import sys.process._
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.client.SparkConnectClient
+import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 
 /**
  * An util class to start a local spark connect server in a different process for local E2E tests.
+ * Before running the tests, the spark connect artifact needs to be built using e.g. `sbt package`.
  * It is designed to start the server once but shared by all tests. It is equivalent to use the
  * following command to start the connect server via command line:
  *
@@ -45,22 +46,6 @@ import org.apache.spark.sql.connect.common.config.ConnectCommon
  * print the server process output in the console to debug server start stop problems.
  */
 object SparkConnectServerUtils {
-  // System properties used for testing and debugging
-  private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
-
-  protected lazy val sparkHome: String = {
-    if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) {
-      fail("spark.test.home or SPARK_HOME is not set.")
-    }
-    sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
-  }
-  private val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean
-
-  // Log server start stop debug info into console
-  // scalastyle:off println
-  private[connect] def debug(msg: String): Unit = if (isDebug) println(msg)
-  // scalastyle:on println
-  private[connect] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace()
 
   // Server port
   private[connect] val port = ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
@@ -72,7 +57,10 @@ object SparkConnectServerUtils {
 
   private lazy val sparkConnect: Process = {
     debug("Starting the Spark Connect Server...")
-    val jar = findSparkConnectJar
+    val jar = findJar(
+      "connector/connect/server",
+      "spark-connect-assembly",
+      "spark-connect").getCanonicalPath
     val builder = Process(
       Seq(
         "bin/spark-submit",
@@ -118,37 +106,6 @@ object SparkConnectServerUtils {
     debug(s"Spark Connect Server is stopped with exit code: $code")
     code
   }
-
-  private def findSparkConnectJar: String = {
-    val target = "connector/connect/server/target"
-    val parentDir = new File(sparkHome, target)
-    assert(
-      parentDir.exists(),
-      s"Fail to locate the spark connect server target folder: '${parentDir.getCanonicalPath}'. " +
-        s"SPARK_HOME='${new File(sparkHome).getCanonicalPath}'. " +
-        "Make sure the spark connect server jar has been built " +
-        "and the env variable `SPARK_HOME` is set correctly.")
-    val jars = recursiveListFiles(parentDir).filter { f =>
-      // SBT jar
-      (f.getParentFile.getName.startsWith("scala-") &&
-        f.getName.startsWith("spark-connect-assembly") && f.getName.endsWith(".jar")) ||
-      // Maven Jar
-      (f.getParent.endsWith("target") &&
-        f.getName.startsWith("spark-connect") &&
-        f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}.jar"))
-    }
-    // It is possible we found more than one: one built by maven, and another by SBT
-    assert(
-      jars.nonEmpty,
-      s"Failed to find the `spark-connect` jar inside folder: ${parentDir.getCanonicalPath}")
-    debug("Using jar: " + jars(0).getCanonicalPath)
-    jars(0).getCanonicalPath // return the first one
-  }
-
-  def recursiveListFiles(f: File): Array[File] = {
-    val these = f.listFiles
-    these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles)
-  }
 }
 
 trait RemoteSparkSession


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org