You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@spark.apache.org by hv...@apache.org on 2023/03/02 00:42:52 UTC

[spark] branch branch-3.4 updated: [SPARK-42631][CONNECT] Support custom extensions in Scala client

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new a1d5e89555a [SPARK-42631][CONNECT] Support custom extensions in Scala client
a1d5e89555a is described below

commit a1d5e89555aa0c516807f3b3e0432728568cfb2a
Author: Tom van Bussel <to...@databricks.com>
AuthorDate: Wed Mar 1 20:42:30 2023 -0400

    [SPARK-42631][CONNECT] Support custom extensions in Scala client
    
    ### What changes were proposed in this pull request?
    
    This PR adds public interfaces for creating `Dataset` and `Column` instances, and for executing commands. To limit how much of the client's internals must be exposed, these interfaces only allow creating `Plan`s, `Expression`s, and `Command`s that contain an `extension`.
    
    ### Why are the changes needed?
    
    These interfaces are required to implement extensions for the Scala Spark Connect client.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, adds new public interfaces (see above).
    
    ### How was this patch tested?
    
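    Added unit tests: a command-extension test in `DatasetSuite`, relation- and expression-extension tests in `PlanGenerationTestSuite`, and the corresponding golden files, which are checked by `ProtoToParsedPlanTestSuite`.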
    
    Closes #40234 from tomvanbussel/SPARK-34827.
    
    Authored-by: Tom van Bussel <to...@databricks.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
    (cherry picked from commit a9c5efa413335c02621d79242e83595c4b932bf0)
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Column.scala   |   6 ++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |   3 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  20 ++++++++++++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  12 ++++++++
 .../apache/spark/sql/PlanGenerationTestSuite.scala |  33 ++++++++++++++++++++-
 .../explain-results/expression_extension.explain   |   2 ++
 .../explain-results/relation_extension.explain     |   1 +
 .../query-tests/queries/drop_multiple_column.json  |   2 +-
 .../query-tests/queries/drop_single_column.json    |   2 +-
 .../query-tests/queries/expression_extension.json  |  26 ++++++++++++++++
 .../queries/expression_extension.proto.bin         | Bin 0 -> 127 bytes
 ..._single_column.json => relation_extension.json} |  10 ++-----
 .../queries/relation_extension.proto.bin           | Bin 0 -> 108 bytes
 .../sql/connect/ProtoToParsedPlanTestSuite.scala   |  13 +++++++-
 14 files changed, 117 insertions(+), 13 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
index c39d5c9757e..4212747f57a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering
 import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection
@@ -1312,6 +1313,11 @@ private[sql] object Column {
     new Column(builder.build())
   }
 
+  @DeveloperApi
+  def apply(extension: com.google.protobuf.Any): Column = {
+    apply(_.setExtension(extension))
+  }
+
   private[sql] def fn(name: String, inputs: Column*): Column = {
     fn(name, isDistinct = false, inputs: _*)
   }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6ef20595630..1cd3c541950 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.control.NonFatal
 
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveLongEncoder, StringEncoder, UnboundRowEncoder}
@@ -120,7 +121,7 @@ import org.apache.spark.util.Utils
  */
 class Dataset[T] private[sql] (
     val sparkSession: SparkSession,
-    private[sql] val plan: proto.Plan,
+    @DeveloperApi val plan: proto.Plan,
     val encoder: AgnosticEncoder[T])
     extends Serializable {
   // Make sure we don't forget to set plan id.
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 12d984f150d..48b86474b9e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.arrow.memory.RootAllocator
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
@@ -261,6 +261,18 @@ class SparkSession private[sql] (
     new Dataset[T](this, plan, encoder)
   }
 
+  @DeveloperApi
+  def newDataFrame(extension: com.google.protobuf.Any): DataFrame = {
+    newDataset(extension, UnboundRowEncoder)
+  }
+
+  @DeveloperApi
+  def newDataset[T](
+      extension: com.google.protobuf.Any,
+      encoder: AgnosticEncoder[T]): Dataset[T] = {
+    newDataset(encoder)(_.setExtension(extension))
+  }
+
   private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = {
     val builder = proto.Command.newBuilder()
     f(builder)
@@ -287,6 +299,12 @@ class SparkSession private[sql] (
     client.execute(plan).asScala.foreach(_ => ())
   }
 
+  @DeveloperApi
+  def execute(extension: com.google.protobuf.Any): Unit = {
+    val command = proto.Command.newBuilder().setExtension(extension).build()
+    execute(command)
+  }
+
   /**
    * This resets the plan id generator so we can produce plans that are comparable.
    *
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 4a26a32353a..43b0cd2674c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -134,4 +134,16 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
       df.groupBy().pivot(Column("c"), Seq(Column("col")))
     }
   }
+
+  test("command extension") {
+    val extension = proto.ExamplePluginCommand.newBuilder().setCustomField("abc").build()
+    val command = proto.Command
+      .newBuilder()
+      .setExtension(com.google.protobuf.Any.pack(extension))
+      .build()
+    val expectedPlan = proto.Plan.newBuilder().setCommand(command).build()
+    ss.execute(com.google.protobuf.Any.pack(extension))
+    val actualPlan = service.getAndClearLatestInputPlan()
+    assert(actualPlan.equals(expectedPlan))
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 0b198ab8f70..80ca5b43622 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
 import scala.util.{Failure, Success, Try}
 
 import com.google.protobuf.util.JsonFormat
+import com.google.protobuf.util.JsonFormat.TypeRegistry
 import io.grpc.inprocess.InProcessChannelBuilder
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
 
@@ -100,7 +101,14 @@ class PlanGenerationTestSuite
     "query-tests",
     "test-data")
 
-  private val printer = JsonFormat.printer()
+  private val registry = TypeRegistry
+    .newBuilder()
+    .add(proto.ExamplePluginRelation.getDescriptor)
+    .add(proto.ExamplePluginExpression.getDescriptor)
+    .add(proto.ExamplePluginCommand.getDescriptor)
+    .build()
+
+  private val printer = JsonFormat.printer().usingTypeRegistry(registry)
 
   private var session: SparkSession = _
 
@@ -2007,4 +2015,27 @@ class PlanGenerationTestSuite
       fn.min("id").over(Window.orderBy("a").rangeBetween(2L, 3L)),
       fn.count(Column("id")).over())
   }
+
+  /* Extensions */
+  test("relation extension") {
+    val input = proto.ExamplePluginRelation
+      .newBuilder()
+      .setInput(simple.plan.getRoot)
+      .build()
+    session.newDataFrame(com.google.protobuf.Any.pack(input))
+  }
+
+  test("expression extension") {
+    val extension = proto.ExamplePluginExpression
+      .newBuilder()
+      .setChild(
+        proto.Expression
+          .newBuilder()
+          .setUnresolvedAttribute(proto.Expression.UnresolvedAttribute
+            .newBuilder()
+            .setUnparsedIdentifier("id")))
+      .setCustomField("abc")
+      .build()
+    simple.select(Column(com.google.protobuf.Any.pack(extension)))
+  }
 }
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension.explain
new file mode 100644
index 00000000000..7426332004a
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension.explain
@@ -0,0 +1,2 @@
+Project [id#0L AS abc#0L]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension.explain
new file mode 100644
index 00000000000..df724a7dd18
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension.explain
@@ -0,0 +1 @@
+LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_column.json b/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_column.json
index 6a8546d9326..3ec19cf8c4c 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_column.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_column.json
@@ -11,7 +11,7 @@
         "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
       }
     },
-    "cols": [{
+    "columns": [{
       "unresolvedAttribute": {
         "unparsedIdentifier": "b"
       }
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json b/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json
index 7f4cd227186..1fe8563e0fd 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json
@@ -11,7 +11,7 @@
         "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
       }
     },
-    "cols": [{
+    "columns": [{
       "unresolvedAttribute": {
         "unparsedIdentifier": "b"
       }
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.json b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.json
new file mode 100644
index 00000000000..acfb3cc2333
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.json
@@ -0,0 +1,26 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "expressions": [{
+      "extension": {
+        "@type": "type.googleapis.com/spark.connect.ExamplePluginExpression",
+        "child": {
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "id"
+          }
+        },
+        "customField": "abc"
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.proto.bin
new file mode 100644
index 00000000000..24669eba642
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension.json
similarity index 63%
copy from connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json
copy to connector/connect/common/src/test/resources/query-tests/queries/relation_extension.json
index 7f4cd227186..47ceba13ca7 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension.json
@@ -2,7 +2,8 @@
   "common": {
     "planId": "1"
   },
-  "drop": {
+  "extension": {
+    "@type": "type.googleapis.com/spark.connect.ExamplePluginRelation",
     "input": {
       "common": {
         "planId": "0"
@@ -10,11 +11,6 @@
       "localRelation": {
         "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
       }
-    },
-    "cols": [{
-      "unresolvedAttribute": {
-        "unparsedIdentifier": "b"
-      }
-    }]
+    }
   }
 }
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/relation_extension.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension.proto.bin
new file mode 100644
index 00000000000..680bb550eca
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension.proto.bin differ
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
index fdf6862c51a..26599b30b99 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
@@ -23,12 +23,13 @@ import java.util
 
 import scala.util.{Failure, Success, Try}
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.{catalog, QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, Analyzer, FunctionRegistry, Resolver, TableFunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, InMemoryCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
@@ -56,6 +57,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  */
 // scalastyle:on
 class ProtoToParsedPlanTestSuite extends SparkFunSuite with SharedSparkSession {
+  override def sparkConf: SparkConf = {
+    super.sparkConf
+      .set(
+        Connect.CONNECT_EXTENSIONS_RELATION_CLASSES.key,
+        "org.apache.spark.sql.connect.plugin.ExampleRelationPlugin")
+      .set(
+        Connect.CONNECT_EXTENSIONS_EXPRESSION_CLASSES.key,
+        "org.apache.spark.sql.connect.plugin.ExampleExpressionPlugin")
+  }
+
   protected val baseResourcePath: Path = {
     getWorkspaceFilePath(
       "connector",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org