Posted to commits@spark.apache.org by hv...@apache.org on 2023/03/02 00:42:52 UTC
[spark] branch branch-3.4 updated: [SPARK-42631][CONNECT] Support custom extensions in Scala client
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new a1d5e89555a [SPARK-42631][CONNECT] Support custom extensions in Scala client
a1d5e89555a is described below
commit a1d5e89555aa0c516807f3b3e0432728568cfb2a
Author: Tom van Bussel <to...@databricks.com>
AuthorDate: Wed Mar 1 20:42:30 2023 -0400
[SPARK-42631][CONNECT] Support custom extensions in Scala client
### What changes were proposed in this pull request?
This PR adds public interfaces for creating `Dataset` and `Column` instances and for executing commands. To limit the API surface that has to be exposed, these interfaces only allow creating `Plan`s and `Expression`s that carry an `extension` field.
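For illustration, a minimal sketch of how a client-side plugin might use the new interfaces (`MyRelation`, `MyExpression`, and `MyCommand` are hypothetical plugin-defined protobuf messages, not part of this change):
```scala
import com.google.protobuf.{Any => ProtoAny}
import org.apache.spark.sql.{Column, DataFrame, SparkSession}

object ExtensionExample {
  // MyRelation, MyExpression and MyCommand stand in for protobuf messages
  // generated by a plugin; they do not exist in this PR.
  def run(spark: SparkSession): Unit = {
    // Relation extension: build a DataFrame whose plan root is a custom relation.
    val df: DataFrame = spark.newDataFrame(ProtoAny.pack(MyRelation.newBuilder().build()))

    // Expression extension: build a Column backed by a custom expression.
    val col: Column = Column(ProtoAny.pack(MyExpression.newBuilder().build()))

    // Command extension: send a custom command to the server for execution.
    spark.execute(ProtoAny.pack(MyCommand.newBuilder().build()))
  }
}
```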
### Why are the changes needed?
Required to implement extensions to the Scala Spark Connect client.
### Does this PR introduce _any_ user-facing change?
Yes, adds new public interfaces (see above).
### How was this patch tested?
Added unit tests.
Closes #40234 from tomvanbussel/SPARK-34827.
Authored-by: Tom van Bussel <to...@databricks.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
(cherry picked from commit a9c5efa413335c02621d79242e83595c4b932bf0)
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../main/scala/org/apache/spark/sql/Column.scala | 6 ++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 3 +-
.../scala/org/apache/spark/sql/SparkSession.scala | 20 ++++++++++++-
.../scala/org/apache/spark/sql/DatasetSuite.scala | 12 ++++++++
.../apache/spark/sql/PlanGenerationTestSuite.scala | 33 ++++++++++++++++++++-
.../explain-results/expression_extension.explain | 2 ++
.../explain-results/relation_extension.explain | 1 +
.../query-tests/queries/drop_multiple_column.json | 2 +-
.../query-tests/queries/drop_single_column.json | 2 +-
.../query-tests/queries/expression_extension.json | 26 ++++++++++++++++
.../queries/expression_extension.proto.bin | Bin 0 -> 127 bytes
..._single_column.json => relation_extension.json} | 10 ++-----
.../queries/relation_extension.proto.bin | Bin 0 -> 108 bytes
.../sql/connect/ProtoToParsedPlanTestSuite.scala | 13 +++++++-
14 files changed, 117 insertions(+), 13 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
index c39d5c9757e..4212747f57a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering
import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection
@@ -1312,6 +1313,11 @@ private[sql] object Column {
new Column(builder.build())
}
+ @DeveloperApi
+ def apply(extension: com.google.protobuf.Any): Column = {
+ apply(_.setExtension(extension))
+ }
+
private[sql] def fn(name: String, inputs: Column*): Column = {
fn(name, isDistinct = false, inputs: _*)
}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6ef20595630..1cd3c541950 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.control.NonFatal
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveLongEncoder, StringEncoder, UnboundRowEncoder}
@@ -120,7 +121,7 @@ import org.apache.spark.util.Utils
*/
class Dataset[T] private[sql] (
val sparkSession: SparkSession,
- private[sql] val plan: proto.Plan,
+ @DeveloperApi val plan: proto.Plan,
val encoder: AgnosticEncoder[T])
extends Serializable {
// Make sure we don't forget to set plan id.
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 12d984f150d..48b86474b9e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.arrow.memory.RootAllocator
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
@@ -261,6 +261,18 @@ class SparkSession private[sql] (
new Dataset[T](this, plan, encoder)
}
+ @DeveloperApi
+ def newDataFrame(extension: com.google.protobuf.Any): DataFrame = {
+ newDataset(extension, UnboundRowEncoder)
+ }
+
+ @DeveloperApi
+ def newDataset[T](
+ extension: com.google.protobuf.Any,
+ encoder: AgnosticEncoder[T]): Dataset[T] = {
+ newDataset(encoder)(_.setExtension(extension))
+ }
+
private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = {
val builder = proto.Command.newBuilder()
f(builder)
@@ -287,6 +299,12 @@ class SparkSession private[sql] (
client.execute(plan).asScala.foreach(_ => ())
}
+ @DeveloperApi
+ def execute(extension: com.google.protobuf.Any): Unit = {
+ val command = proto.Command.newBuilder().setExtension(extension).build()
+ execute(command)
+ }
+
/**
* This resets the plan id generator so we can produce plans that are comparable.
*
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 4a26a32353a..43b0cd2674c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -134,4 +134,16 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
df.groupBy().pivot(Column("c"), Seq(Column("col")))
}
}
+
+ test("command extension") {
+ val extension = proto.ExamplePluginCommand.newBuilder().setCustomField("abc").build()
+ val command = proto.Command
+ .newBuilder()
+ .setExtension(com.google.protobuf.Any.pack(extension))
+ .build()
+ val expectedPlan = proto.Plan.newBuilder().setCommand(command).build()
+ ss.execute(com.google.protobuf.Any.pack(extension))
+ val actualPlan = service.getAndClearLatestInputPlan()
+ assert(actualPlan.equals(expectedPlan))
+ }
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 0b198ab8f70..80ca5b43622 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
import scala.util.{Failure, Success, Try}
import com.google.protobuf.util.JsonFormat
+import com.google.protobuf.util.JsonFormat.TypeRegistry
import io.grpc.inprocess.InProcessChannelBuilder
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
@@ -100,7 +101,14 @@ class PlanGenerationTestSuite
"query-tests",
"test-data")
- private val printer = JsonFormat.printer()
+ private val registry = TypeRegistry
+ .newBuilder()
+ .add(proto.ExamplePluginRelation.getDescriptor)
+ .add(proto.ExamplePluginExpression.getDescriptor)
+ .add(proto.ExamplePluginCommand.getDescriptor)
+ .build()
+
+ private val printer = JsonFormat.printer().usingTypeRegistry(registry)
private var session: SparkSession = _
@@ -2007,4 +2015,27 @@ class PlanGenerationTestSuite
fn.min("id").over(Window.orderBy("a").rangeBetween(2L, 3L)),
fn.count(Column("id")).over())
}
+
+ /* Extensions */
+ test("relation extension") {
+ val input = proto.ExamplePluginRelation
+ .newBuilder()
+ .setInput(simple.plan.getRoot)
+ .build()
+ session.newDataFrame(com.google.protobuf.Any.pack(input))
+ }
+
+ test("expression extension") {
+ val extension = proto.ExamplePluginExpression
+ .newBuilder()
+ .setChild(
+ proto.Expression
+ .newBuilder()
+ .setUnresolvedAttribute(proto.Expression.UnresolvedAttribute
+ .newBuilder()
+ .setUnparsedIdentifier("id")))
+ .setCustomField("abc")
+ .build()
+ simple.select(Column(com.google.protobuf.Any.pack(extension)))
+ }
}
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension.explain
new file mode 100644
index 00000000000..7426332004a
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension.explain
@@ -0,0 +1,2 @@
+Project [id#0L AS abc#0L]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension.explain
new file mode 100644
index 00000000000..df724a7dd18
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension.explain
@@ -0,0 +1 @@
+LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_column.json b/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_column.json
index 6a8546d9326..3ec19cf8c4c 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_column.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/drop_multiple_column.json
@@ -11,7 +11,7 @@
"schema": "struct\u003cid:bigint,a:int,b:double\u003e"
}
},
- "cols": [{
+ "columns": [{
"unresolvedAttribute": {
"unparsedIdentifier": "b"
}
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json b/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json
index 7f4cd227186..1fe8563e0fd 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json
@@ -11,7 +11,7 @@
"schema": "struct\u003cid:bigint,a:int,b:double\u003e"
}
},
- "cols": [{
+ "columns": [{
"unresolvedAttribute": {
"unparsedIdentifier": "b"
}
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.json b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.json
new file mode 100644
index 00000000000..acfb3cc2333
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.json
@@ -0,0 +1,26 @@
+{
+ "common": {
+ "planId": "1"
+ },
+ "project": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+ }
+ },
+ "expressions": [{
+ "extension": {
+ "@type": "type.googleapis.com/spark.connect.ExamplePluginExpression",
+ "child": {
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "id"
+ }
+ },
+ "customField": "abc"
+ }
+ }]
+ }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.proto.bin
new file mode 100644
index 00000000000..24669eba642
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension.json
similarity index 63%
copy from connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json
copy to connector/connect/common/src/test/resources/query-tests/queries/relation_extension.json
index 7f4cd227186..47ceba13ca7 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/drop_single_column.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension.json
@@ -2,7 +2,8 @@
"common": {
"planId": "1"
},
- "drop": {
+ "extension": {
+ "@type": "type.googleapis.com/spark.connect.ExamplePluginRelation",
"input": {
"common": {
"planId": "0"
@@ -10,11 +11,6 @@
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double\u003e"
}
- },
- "cols": [{
- "unresolvedAttribute": {
- "unparsedIdentifier": "b"
- }
- }]
+ }
}
}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/relation_extension.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension.proto.bin
new file mode 100644
index 00000000000..680bb550eca
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension.proto.bin differ
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
index fdf6862c51a..26599b30b99 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
@@ -23,12 +23,13 @@ import java.util
import scala.util.{Failure, Success, Try}
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.{catalog, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, Analyzer, FunctionRegistry, Resolver, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
+import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, InMemoryCatalog}
import org.apache.spark.sql.connector.expressions.Transform
@@ -56,6 +57,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
*/
// scalastyle:on
class ProtoToParsedPlanTestSuite extends SparkFunSuite with SharedSparkSession {
+ override def sparkConf: SparkConf = {
+ super.sparkConf
+ .set(
+ Connect.CONNECT_EXTENSIONS_RELATION_CLASSES.key,
+ "org.apache.spark.sql.connect.plugin.ExampleRelationPlugin")
+ .set(
+ Connect.CONNECT_EXTENSIONS_EXPRESSION_CLASSES.key,
+ "org.apache.spark.sql.connect.plugin.ExampleExpressionPlugin")
+ }
+
protected val baseResourcePath: Path = {
getWorkspaceFilePath(
"connector",