Posted to commits@spark.apache.org by hv...@apache.org on 2023/06/28 14:37:50 UTC

[spark] branch master updated: [SPARK-43757][CONNECT] Change client compatibility from allow list to deny list

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1fcd537a37b [SPARK-43757][CONNECT] Change client compatibility from allow list to deny list
1fcd537a37b is described below

commit 1fcd537a37b2457092e20f8034f23917a8ae2ffa
Author: Zhen Li <zh...@users.noreply.github.com>
AuthorDate: Wed Jun 28 10:37:38 2023 -0400

    [SPARK-43757][CONNECT] Change client compatibility from allow list to deny list
    
    ### What changes were proposed in this pull request?
    Expand the client compatibility check to cover all SQL APIs by replacing the per-class allow list with a deny list of known exclusions (a small illustrative sketch follows the changed-files summary below).
    
    ### Why are the changes needed?
    Enhance the API compatibility coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    No, except that it fixes a few wrong types and hides a few helper methods internally.
    
    ### How was this patch tested?
    Existing tests.
    
    Closes #41284 from zhenlineo/compatibility-check-allowlist.
    
    Authored-by: Zhen Li <zh...@users.noreply.github.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |   6 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |   2 +-
 .../sql/streaming/StreamingQueryException.scala    |   3 +-
 .../sql/streaming/StreamingQueryManager.scala      |   3 +-
 .../CheckConnectJvmClientCompatibility.scala       | 327 ++++++++++++++-------
 5 files changed, 225 insertions(+), 116 deletions(-)
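
The allow-list to deny-list shift can be pictured with a minimal, self-contained Scala sketch. It is hypothetical and independent of MiMa: the object name DenyListSketch and the sample API names are made up for illustration, and only the glob-to-regex handling mirrors the IncludeByName filter in CheckConnectJvmClientCompatibility.scala. With an allow list, only explicitly listed classes are ever compared; with a deny list, everything under org.apache.spark.sql is compared and only knowingly unsupported packages are filtered out.

import java.util.regex.Pattern

object DenyListSketch {
  // Turn a "*"-glob such as "org.apache.spark.sql.*" into a regex and test a name
  // against it, mirroring the pattern handling of the IncludeByName filter.
  private def matches(glob: String, name: String): Boolean = {
    val regex = glob.split("\\*", -1).map(Pattern.quote).mkString(".*")
    Pattern.compile(regex).matcher(name).matches()
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical sample of fully qualified API names, for illustration only.
    val apis = Seq(
      "org.apache.spark.sql.Dataset.map",
      "org.apache.spark.sql.catalyst.InternalRow.numFields", // internal, not client API
      "org.apache.spark.sql.streaming.DataStreamReader.load")

    // Allow list (old): only classes that are explicitly listed get checked at all.
    val allowList = Seq("org.apache.spark.sql.Dataset.*")
    val checkedWithAllowList = apis.filter(a => allowList.exists(matches(_, a)))

    // Deny list (new): include everything under org.apache.spark.sql.*, then
    // exclude the packages and members that are knowingly unsupported.
    val include = "org.apache.spark.sql.*"
    val denyList = Seq("org.apache.spark.sql.catalyst.*")
    val checkedWithDenyList = apis
      .filter(matches(include, _))
      .filterNot(a => denyList.exists(matches(_, a)))

    println(s"allow list checks: $checkedWithAllowList") // Dataset.map only
    println(s"deny list checks:  $checkedWithDenyList")  // Dataset.map and DataStreamReader.load
  }
}

The practical effect is that any new public API under org.apache.spark.sql is compared by default unless it is explicitly excluded, so accidental surface-area differences no longer slip through.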

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 20c130b83cb..e67ef1c0fa7 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode
  *
  * @since 3.5.0
  */
-abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
+class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
 
   /**
    * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
@@ -462,7 +462,7 @@ abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable
       UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
   }
 
-  protected def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder](
+  protected[sql] def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder](
       outputMode: Option[OutputMode],
       timeoutConf: GroupStateTimeout,
       initialState: Option[KeyValueGroupedDataset[K, S]],
@@ -923,7 +923,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     agg(aggregator)
   }
 
-  override protected def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder](
+  override protected[sql] def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder](
       outputMode: Option[OutputMode],
       timeoutConf: GroupStateTimeout,
       initialState: Option[KeyValueGroupedDataset[K, S]],
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 45e7dca38d7..54e9102c55c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -429,7 +429,7 @@ class SparkSession private[sql] (
    *
    * @since 3.4.0
    */
-  object implicits extends SQLImplicits(this)
+  object implicits extends SQLImplicits(this) with Serializable
   // scalastyle:on
 
   def newSession(): SparkSession = {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
index d5e9982dfbf..512c94f5c70 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
@@ -36,7 +36,8 @@ class StreamingQueryException private[sql] (
     message: String,
     errorClass: String,
     stackTrace: String)
-    extends SparkThrowable {
+    extends Exception(message)
+    with SparkThrowable {
 
   override def getErrorClass: String = errorClass
 
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 775921ff579..13bbf470639 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -25,6 +25,7 @@ import org.apache.spark.annotation.Evolving
 import org.apache.spark.connect.proto.Command
 import org.apache.spark.connect.proto.StreamingQueryManagerCommand
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 
 /**
@@ -33,7 +34,7 @@ import org.apache.spark.sql.SparkSession
  * @since 3.5.0
  */
 @Evolving
-class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
+class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging {
 
   /**
    * Returns a list of active queries associated with this SQLContext
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index acc469672b4..f22baddc01e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -17,18 +17,14 @@
 package org.apache.spark.sql.connect.client
 
 import java.io.{File, Writer}
-import java.net.URLClassLoader
 import java.nio.charset.StandardCharsets
 import java.nio.file.{Files, Paths}
 import java.util.regex.Pattern
 
-import scala.reflect.runtime.universe.runtimeMirror
-
-import com.typesafe.tools.mima.core.{Problem, ProblemFilter, ProblemFilters}
+import com.typesafe.tools.mima.core._
 import com.typesafe.tools.mima.lib.MiMaLib
 
 import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._
-import org.apache.spark.util.ChildFirstURLClassLoader
 
 /**
  * A tool for checking the binary compatibility of the connect client API against the spark SQL
@@ -70,6 +66,16 @@ object CheckConnectJvmClientCompatibility {
         sqlJar,
         "Sql")
 
+      val problemsWithClientModule =
+        checkMiMaCompatibilityWithReversedSqlModule(clientJar, sqlJar)
+      appendMimaCheckErrorMessageIfNeeded(
+        resultWriter,
+        problemsWithClientModule,
+        clientJar,
+        sqlJar,
+        "ReversedSql",
+        "Sql")
+
       val avroJar: File = findJar("connector/avro", "spark-avro", "spark-avro")
       val problemsWithAvroModule = checkMiMaCompatibilityWithAvroModule(clientJar, avroJar)
       appendMimaCheckErrorMessageIfNeeded(
@@ -89,9 +95,6 @@ object CheckConnectJvmClientCompatibility {
         clientJar,
         protobufJar,
         "Protobuf")
-
-      val incompatibleApis = checkDatasetApiCompatibility(clientJar, sqlJar)
-      appendIncompatibleDatasetApisErrorMessageIfNeeded(resultWriter, incompatibleApis)
     } catch {
       case e: Throwable =>
         println(e.getMessage)
@@ -122,65 +125,62 @@ object CheckConnectJvmClientCompatibility {
   private def checkMiMaCompatibilityWithSqlModule(
       clientJar: File,
       sqlJar: File): List[Problem] = {
-    val includedRules = Seq(
-      IncludeByName("org.apache.spark.sql.catalog.Catalog.*"),
-      IncludeByName("org.apache.spark.sql.catalog.CatalogMetadata.*"),
-      IncludeByName("org.apache.spark.sql.catalog.Column.*"),
-      IncludeByName("org.apache.spark.sql.catalog.Database.*"),
-      IncludeByName("org.apache.spark.sql.catalog.Function.*"),
-      IncludeByName("org.apache.spark.sql.catalog.Table.*"),
-      IncludeByName("org.apache.spark.sql.Column.*"),
-      IncludeByName("org.apache.spark.sql.ColumnName.*"),
-      IncludeByName("org.apache.spark.sql.DataFrame.*"),
-      IncludeByName("org.apache.spark.sql.DataFrameReader.*"),
-      IncludeByName("org.apache.spark.sql.DataFrameNaFunctions.*"),
-      IncludeByName("org.apache.spark.sql.DataFrameStatFunctions.*"),
-      IncludeByName("org.apache.spark.sql.DataFrameWriter.*"),
-      IncludeByName("org.apache.spark.sql.DataFrameWriterV2.*"),
-      IncludeByName("org.apache.spark.sql.Dataset.*"),
-      IncludeByName("org.apache.spark.sql.functions.*"),
-      IncludeByName("org.apache.spark.sql.KeyValueGroupedDataset.*"),
-      IncludeByName("org.apache.spark.sql.RelationalGroupedDataset.*"),
-      IncludeByName("org.apache.spark.sql.SparkSession.*"),
-      IncludeByName("org.apache.spark.sql.RuntimeConfig.*"),
-      IncludeByName("org.apache.spark.sql.TypedColumn.*"),
-      IncludeByName("org.apache.spark.sql.SQLImplicits.*"),
-      IncludeByName("org.apache.spark.sql.DatasetHolder.*"),
-      IncludeByName("org.apache.spark.sql.streaming.DataStreamReader.*"),
-      IncludeByName("org.apache.spark.sql.streaming.DataStreamWriter.*"),
-      IncludeByName("org.apache.spark.sql.streaming.StreamingQuery.*"),
-      IncludeByName("org.apache.spark.sql.streaming.StreamingQueryManager.active"),
-      IncludeByName("org.apache.spark.sql.streaming.StreamingQueryManager.get"),
-      IncludeByName("org.apache.spark.sql.streaming.StreamingQueryManager.awaitAnyTermination"),
-      IncludeByName("org.apache.spark.sql.streaming.StreamingQueryManager.resetTerminated"),
-      IncludeByName("org.apache.spark.sql.streaming.StreamingQueryStatus.*"),
-      IncludeByName("org.apache.spark.sql.streaming.StreamingQueryProgress.*"))
+    val includedRules = Seq(IncludeByName("org.apache.spark.sql.*"))
     val excludeRules = Seq(
       // Filter unsupported rules:
       // Note when muting errors for a method, checks on all overloading methods are also muted.
 
-      // Skip all shaded dependencies and proto files in the client.
-      ProblemFilters.exclude[Problem]("org.sparkproject.*"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.connect.proto.*"),
+      // Skip unsupported packages
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.api.*"), // Java, Python, R
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.columnar.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.expressions.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.jdbc.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.sources.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.streaming.ui.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.test.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.util.*"),
+
+      // Skip private[sql] constructors
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.*.this"),
+
+      // Skip unsupported classes
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ExperimentalMethods"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$*"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSessionExtensions"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.SparkSessionExtensionsProvider"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDTFRegistration"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDFRegistration"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDFRegistration$"),
 
       // DataFrame Reader & Writer
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.json"), // deprecated
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.json"), // rdd
 
       // DataFrameNaFunctions
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.this"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"),
 
       // DataFrameStatFunctions
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.this"),
 
       // Dataset
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.Dataset$" // private[sql]
+      ),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.ofRows"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_TAG"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.COL_POS_KEY"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_KEY"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.curId"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.observe"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.encoder"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
@@ -191,7 +191,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.this"),
 
       // functions
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udf"),
@@ -203,11 +202,12 @@ object CheckConnectJvmClientCompatibility {
       // KeyValueGroupedDataset
       ProblemFilters.exclude[Problem](
         "org.apache.spark.sql.KeyValueGroupedDataset.queryExecution"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.this"),
 
       // RelationalGroupedDataset
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.RelationalGroupedDataset$*" // private[sql]
+      ),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.this"),
 
       // SparkSession
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"),
@@ -226,16 +226,32 @@ object CheckConnectJvmClientCompatibility {
         "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"),
       // TODO(SPARK-44068): Support positional parameters in Scala connect client
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"),
+
+      // SparkSession#implicits
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession#implicits._sqlContext"),
+
+      // SparkSession#Builder
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession#Builder.appName"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession#Builder.config"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession#Builder.master"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession#Builder.enableHiveSupport"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession#Builder.withExtensions"),
 
       // RuntimeConfig
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.RuntimeConfig.this"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig$"),
 
-      // TypedColumn
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.TypedColumn.this"),
       // DataStreamWriter
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.streaming.DataStreamWriter$"),
       ProblemFilters.exclude[Problem](
         "org.apache.spark.sql.streaming.DataStreamWriter.foreachBatch" // TODO(SPARK-42944)
       ),
@@ -243,27 +259,161 @@ object CheckConnectJvmClientCompatibility {
         "org.apache.spark.sql.streaming.DataStreamWriter.SOURCE*" // These are constant vals.
       ),
 
+      // StreamingQueryException
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.streaming.StreamingQueryException.message"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.streaming.StreamingQueryException.cause"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.streaming.StreamingQueryException.startOffset"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.streaming.StreamingQueryException.endOffset"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.streaming.StreamingQueryException.time"),
+
+      // StreamingQueryManager
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.streaming.StreamingQueryManager.addListener"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.streaming.StreamingQueryManager.removeListener"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.streaming.StreamingQueryManager.listListeners"),
+
+      // Classes missing from streaming API
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ForeachWriter"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupState"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.streaming.TestGroupState"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.streaming.TestGroupState$"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.streaming.PythonStreamingQueryListener"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.streaming.PythonStreamingQueryListenerWrapper"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.streaming.StreamingQueryListener"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.streaming.StreamingQueryListener$*"),
+
       // SQLImplicits
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.this"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.rddToDatasetHolder"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits._sqlContext"))
     checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules)
   }
 
   /**
-   * MiMa takes an old jar (sql jar) and a new jar (client jar) as inputs and then reports all
-   * incompatibilities found in the new jar. The incompatibility result is then filtered using
-   * include and exclude rules. Include rules are first applied to find all client classes that
-   * need to be checked. Then exclude rules are applied to filter out all unsupported methods in
-   * the client classes.
+   * This check ensures the client jar does not expose any unwanted APIs by mistake.
    */
-  private def checkMiMaCompatibility(
+  private def checkMiMaCompatibilityWithReversedSqlModule(
       clientJar: File,
-      targetJar: File,
+      sqlJar: File): List[Problem] = {
+    val includedRules = Seq(IncludeByName("org.apache.spark.sql.*"))
+    val excludeRules = Seq(
+      // Skipped packages
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.avro.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.connect.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.protobuf.*"),
+
+      // private[sql]
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.*.this"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameStatFunctions$"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.KeyValueGroupedDatasetImpl"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.KeyValueGroupedDatasetImpl$"),
+      ProblemFilters.exclude[ReversedMissingMethodProblem](
+        "org.apache.spark.sql.SQLImplicits._sqlContext" // protected
+      ),
+
+      // New public APIs added in the client
+      // ScalarUserDefinedFunction
+      ProblemFilters
+        .exclude[MissingClassProblem](
+          "org.apache.spark.sql.expressions.ScalarUserDefinedFunction"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.expressions.ScalarUserDefinedFunction$"),
+
+      // Dataset
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.Dataset.plan"
+      ), // developer API
+      ProblemFilters.exclude[IncompatibleResultTypeProblem](
+        "org.apache.spark.sql.Dataset.encoder"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.Dataset.collectResult"),
+
+      // RuntimeConfig
+      ProblemFilters.exclude[MissingTypesProblem](
+        "org.apache.spark.sql.RuntimeConfig" // Client version extends Logging
+      ),
+      ProblemFilters.exclude[Problem](
+        "org.apache.spark.sql.RuntimeConfig.*" // Mute missing Logging methods
+      ),
+      // ConnectRepl
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.application.ConnectRepl" // developer API
+      ),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.application.ConnectRepl$" // developer API
+      ),
+
+      // SparkSession
+      // developer API
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession.newDataFrame"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession.newDataset"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession.execute"),
+      // Experimental
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession.addArtifact"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession.addArtifacts"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession.registerClassFinder"),
+      // public
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession.interruptAll"),
+      // SparkSession#Builder
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession#Builder.remote"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession#Builder.client"),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession#Builder.build" // deprecated
+      ),
+      ProblemFilters.exclude[DirectMissingMethodProblem](
+        "org.apache.spark.sql.SparkSession#Builder.create"),
+
+      // Streaming API
+      ProblemFilters.exclude[MissingTypesProblem](
+        "org.apache.spark.sql.streaming.DataStreamWriter" // Client version extends Logging
+      ),
+      ProblemFilters.exclude[Problem](
+        "org.apache.spark.sql.streaming.DataStreamWriter.*" // Mute missing Logging methods
+      ),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.streaming.RemoteStreamingQuery"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.streaming.RemoteStreamingQuery$"))
+
+    checkMiMaCompatibility(sqlJar, clientJar, includedRules, excludeRules)
+  }
+
+  /**
+   * MiMa takes a new jar and an old jar as inputs and then reports all incompatibilities found in
+   * the new jar. The incompatibility result is then filtered using include and exclude rules.
+   * Include rules are first applied to find all client classes that need to be checked. Then
+   * exclude rules are applied to filter out all unsupported methods in the client classes.
+   */
+  private def checkMiMaCompatibility(
+      newJar: File,
+      oldJar: File,
       includedRules: Seq[IncludeByName],
       excludeRules: Seq[ProblemFilter]): List[Problem] = {
-    val mima = new MiMaLib(Seq(clientJar, targetJar))
-    val allProblems = mima.collectProblems(targetJar, clientJar, List.empty)
+    val mima = new MiMaLib(Seq(newJar, oldJar))
+    val allProblems = mima.collectProblems(oldJar, newJar, List.empty)
     val problems = allProblems
       .filter { p =>
         includedRules.exists(rule => rule(p))
@@ -274,46 +424,18 @@ object CheckConnectJvmClientCompatibility {
     problems
   }
 
-  private def checkDatasetApiCompatibility(clientJar: File, sqlJar: File): Seq[String] = {
-
-    def methods(jar: File, className: String): Seq[String] = {
-      val classLoader: URLClassLoader =
-        new ChildFirstURLClassLoader(Seq(jar.toURI.toURL).toArray, this.getClass.getClassLoader)
-      val mirror = runtimeMirror(classLoader)
-      // scalastyle:off classforname
-      val classSymbol =
-        mirror.classSymbol(Class.forName(className, false, classLoader))
-      // scalastyle:on classforname
-      classSymbol.typeSignature.members
-        .filter(_.isMethod)
-        .map(_.asMethod)
-        .filter(m => m.isPublic)
-        .map(_.fullName)
-        .toSeq
-    }
-
-    val className = "org.apache.spark.sql.Dataset"
-    val clientMethods = methods(clientJar, className)
-    val sqlMethods = methods(sqlJar, className)
-    // Exclude some public methods that must be added through `exceptionMethods`
-    val exceptionMethods =
-      Seq("org.apache.spark.sql.Dataset.collectResult", "org.apache.spark.sql.Dataset.plan")
-
-    // Find new public functions that are not in sql module `Dataset`.
-    clientMethods.diff(sqlMethods).diff(exceptionMethods)
-  }
-
   private def appendMimaCheckErrorMessageIfNeeded(
       resultWriter: Writer,
       problems: List[Problem],
       clientModule: File,
       targetModule: File,
-      targetName: String): Unit = {
+      targetName: String,
+      description: String = "client"): Unit = {
     if (problems.nonEmpty) {
       resultWriter.write(
         s"ERROR: Comparing Client jar: $clientModule and $targetName jar: $targetModule \n")
       resultWriter.write(s"problems with $targetName module: \n")
-      resultWriter.write(s"${problems.map(p => p.description("client")).mkString("\n")}")
+      resultWriter.write(s"${problems.map(p => p.description(description)).mkString("\n")}")
       resultWriter.write("\n")
       resultWriter.write(
         "Exceptions to binary compatibility can be added in " +
@@ -321,21 +443,6 @@ object CheckConnectJvmClientCompatibility {
     }
   }
 
-  private def appendIncompatibleDatasetApisErrorMessageIfNeeded(
-      resultWriter: Writer,
-      incompatibleApis: Seq[String]): Unit = {
-    if (incompatibleApis.nonEmpty) {
-      resultWriter.write(
-        "ERROR: The Dataset apis only exist in the connect client " +
-          "module and not belong to the sql module include: \n")
-      resultWriter.write(incompatibleApis.mkString("\n"))
-      resultWriter.write("\n")
-      resultWriter.write(
-        "Exceptions can be added to exceptionMethods in " +
-          "'CheckConnectJvmClientCompatibility#checkDatasetApiCompatibility'\n")
-    }
-  }
-
   private case class IncludeByName(name: String) extends ProblemFilter {
     private[this] val pattern =
       Pattern.compile(name.split("\\*", -1).map(Pattern.quote).mkString(".*"))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org