Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/01 18:12:51 UTC
[spark] branch master updated: [SPARK-42283][CONNECT][SCALA] Simple Scalar Scala UDFs
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fb9d7067845 [SPARK-42283][CONNECT][SCALA] Simple Scalar Scala UDFs
fb9d7067845 is described below
commit fb9d706784557ef0fe9e17d59b7096374658954e
Author: vicennial <ve...@databricks.com>
AuthorDate: Wed Feb 1 14:12:38 2023 -0400
[SPARK-42283][CONNECT][SCALA] Simple Scalar Scala UDFs
### What changes were proposed in this pull request?
This PR adds support for "simple" scalar Scala UDFs in the Spark Connect Scala/JVM client. "Simple" here refers to UDFs that use no client-specific class files (e.g., REPL-generated classes) or JARs; a "simple" UDF may only reference built-in libraries and classes defined within the scope of the UDF itself.
A user would then be able to do the following (example):
```
def myFunc(x: Int): Int = x + 5
val myUdf = udf(myFunc _)
val result = df.select(myUdf(Column("id")))
```
#### Implementation Details:
A shared JVM object `UdfPacket` is introduced in the common package to encapsulate the Scala UDF and its encoders (via Agnostic Encoders) so that it can be serialized on the client and deserialized on the server.
Further, a new protobuf message `ScalarScalaUDF` is introduced to transmit Scala/JVM-specific information, such as the serialized JVM object above, to the server.
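In outline, the client/server flow looks roughly like the sketch below (condensed from the diff in this commit; builder plumbing and error handling omitted):
```
// Client side (see ScalarUserDefinedFunction below): pack the function and
// its encoders, serialize, and ship the bytes as the proto payload.
val udfPacketBytes = Utils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
val scalaUdf = proto.ScalarScalaUDF
  .newBuilder()
  .setPayload(ByteString.copyFrom(udfPacketBytes))
  .setNullable(nullable)
  .build()

// Server side (see SparkConnectPlanner below): deserialize the payload and
// rebuild a Catalyst ScalaUDF from the recovered packet.
val udfPacket =
  Utils.deserialize[UdfPacket](scalaUdf.getPayload.toByteArray, Utils.getContextOrSparkClassLoader)
```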
### Why are the changes needed?
UDFs are crucial for the completeness of the Spark Connect Scala/JVM client. We introduce this component incrementally.
### Does this PR introduce _any_ user-facing change?
Yes, users are now able to run "simple" scalar Scala UDFs through the Scala/JVM client.
### How was this patch tested?
Unit test (`UserDefinedFunctionSuite`) + Integration test (`ClientE2ETestSuite`)
Closes #39850 from vicennial/SPARK-42283.
Authored-by: vicennial <ve...@databricks.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../sql/expressions/UserDefinedFunction.scala | 146 ++++++++++++
.../scala/org/apache/spark/sql/functions.scala | 255 +++++++++++++++++++++
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 14 ++
.../spark/sql/UserDefinedFunctionSuite.scala | 53 +++++
connector/connect/common/pom.xml | 12 +
.../main/protobuf/spark/connect/expressions.proto | 12 +
.../spark/sql/connect/common/UdfPacket.scala | 70 ++++++
.../sql/connect/planner/SparkConnectPlanner.scala | 27 +++
.../pyspark/sql/connect/proto/expressions_pb2.py | 22 +-
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 66 +++++-
10 files changed, 671 insertions(+), 6 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
new file mode 100644
index 00000000000..0fe47092e4e
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.expressions
+
+import scala.collection.JavaConverters._
+import scala.reflect.runtime.universe.TypeTag
+
+import com.google.protobuf.ByteString
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.common.UdfPacket
+import org.apache.spark.util.Utils
+
+/**
+ * A user-defined function. To create one, use the `udf` functions in `functions`.
+ *
+ * As an example:
+ * {{{
+ * // Define a UDF that returns true or false based on some numeric score.
+ * val predict = udf((score: Double) => score > 0.5)
+ *
+ * // Projects a column that adds a prediction column based on the score column.
+ * df.select( predict(df("score")) )
+ * }}}
+ *
+ * @since 3.4.0
+ */
+sealed abstract class UserDefinedFunction {
+
+ /**
+ * Returns true when the UDF can return a nullable value.
+ *
+ * @since 3.4.0
+ */
+ def nullable: Boolean
+
+ /**
+ * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the
+ * same input.
+ *
+ * @since 3.4.0
+ */
+ def deterministic: Boolean
+
+ /**
+ * Returns an expression that invokes the UDF, using the given arguments.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def apply(exprs: Column*): Column
+
+ /**
+ * Updates UserDefinedFunction with a given name.
+ *
+ * @since 3.4.0
+ */
+ def withName(name: String): UserDefinedFunction
+
+ /**
+ * Updates UserDefinedFunction to non-nullable.
+ *
+ * @since 3.4.0
+ */
+ def asNonNullable(): UserDefinedFunction
+
+ /**
+ * Updates UserDefinedFunction to nondeterministic.
+ *
+ * @since 3.4.0
+ */
+ def asNondeterministic(): UserDefinedFunction
+}
+
+/**
+ * Holder class for a scalar user-defined function and its input/output encoder(s).
+ */
+case class ScalarUserDefinedFunction(
+ function: AnyRef,
+ inputEncoders: Seq[AgnosticEncoder[_]],
+ outputEncoder: AgnosticEncoder[_],
+ name: Option[String],
+ override val nullable: Boolean,
+ override val deterministic: Boolean)
+ extends UserDefinedFunction {
+
+ private[this] lazy val udf = {
+ val udfPacketBytes = Utils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
+ val scalaUdfBuilder = proto.ScalarScalaUDF
+ .newBuilder()
+ .setPayload(ByteString.copyFrom(udfPacketBytes))
+ .setNullable(nullable)
+
+ scalaUdfBuilder.build()
+ }
+
+ @scala.annotation.varargs
+ override def apply(exprs: Column*): Column = Column { builder =>
+ val udfBuilder = builder.getCommonInlineUserDefinedFunctionBuilder
+ udfBuilder
+ .setDeterministic(deterministic)
+ .setScalarScalaUdf(udf)
+ .addAllArguments(exprs.map(_.expr).asJava)
+
+ name.foreach(udfBuilder.setFunctionName)
+ }
+
+ override def withName(name: String): ScalarUserDefinedFunction = copy(name = Option(name))
+
+ override def asNonNullable(): ScalarUserDefinedFunction = copy(nullable = false)
+
+ override def asNondeterministic(): ScalarUserDefinedFunction = copy(deterministic = false)
+}
+
+object ScalarUserDefinedFunction {
+ private[sql] def apply(
+ function: AnyRef,
+ returnType: TypeTag[_],
+ parameterTypes: TypeTag[_]*): ScalarUserDefinedFunction = {
+
+ ScalarUserDefinedFunction(
+ function = function,
+ inputEncoders = parameterTypes.map(tag => ScalaReflection.encoderFor(tag)),
+ outputEncoder = ScalaReflection.encoderFor(returnType),
+ name = None,
+ nullable = true,
+ deterministic = true)
+ }
+}
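For reference, a minimal usage sketch of the API added above, assuming an active Spark Connect session `spark` (mirroring `ClientE2ETestSuite` below):
```
import org.apache.spark.sql.functions.udf

val plusFive = udf((x: Int) => x + 5)          // deterministic and nullable by default
val renamed = plusFive.withName("plus_five")   // carried to the server as function_name
val df = spark.range(5).select(renamed(Column("id")))
```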
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index bae394785be..61174f1921e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -19,10 +19,13 @@ package org.apache.spark.sql
import java.math.{BigDecimal => JBigDecimal}
import java.time.LocalDate
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
import com.google.protobuf.ByteString
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.client.unsupported
+import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction}
/**
* Commonly used functions available for DataFrame operations.
@@ -80,4 +83,256 @@ object functions {
case _ => unsupported(s"literal $literal not supported (yet).")
}
}
+
+ // scalastyle:off line.size.limit
+
+ /**
+ * Defines a Scala closure of 0 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[RT: TypeTag](f: () => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(f, typeTag[RT])
+ }
+
+ /**
+ * Defines a Scala closure of 1 argument as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[RT: TypeTag, A1: TypeTag](f: A1 => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1])
+ }
+
+ /**
+ * Defines a Scala closure of 2 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: (A1, A2) => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1], typeTag[A2])
+ }
+
+ /**
+ * Defines a Scala closure of 3 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
+ f: (A1, A2, A3) => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1], typeTag[A2], typeTag[A3])
+ }
+
+ /**
+ * Defines a Scala closure of 4 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
+ f: (A1, A2, A3, A4) => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1], typeTag[A2], typeTag[A3], typeTag[A4])
+ }
+
+ /**
+ * Defines a Scala closure of 5 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
+ f: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(
+ f,
+ typeTag[RT],
+ typeTag[A1],
+ typeTag[A2],
+ typeTag[A3],
+ typeTag[A4],
+ typeTag[A5])
+ }
+
+ /**
+ * Defines a Scala closure of 6 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag](f: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(
+ f,
+ typeTag[RT],
+ typeTag[A1],
+ typeTag[A2],
+ typeTag[A3],
+ typeTag[A4],
+ typeTag[A5],
+ typeTag[A6])
+ }
+
+ /**
+ * Defines a Scala closure of 7 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(
+ f,
+ typeTag[RT],
+ typeTag[A1],
+ typeTag[A2],
+ typeTag[A3],
+ typeTag[A4],
+ typeTag[A5],
+ typeTag[A6],
+ typeTag[A7])
+ }
+
+ /**
+ * Defines a Scala closure of 8 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(
+ f,
+ typeTag[RT],
+ typeTag[A1],
+ typeTag[A2],
+ typeTag[A3],
+ typeTag[A4],
+ typeTag[A5],
+ typeTag[A6],
+ typeTag[A7],
+ typeTag[A8])
+ }
+
+ /**
+ * Defines a Scala closure of 9 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(
+ f,
+ typeTag[RT],
+ typeTag[A1],
+ typeTag[A2],
+ typeTag[A3],
+ typeTag[A4],
+ typeTag[A5],
+ typeTag[A6],
+ typeTag[A7],
+ typeTag[A8],
+ typeTag[A9])
+ }
+
+ /**
+ * Defines a Scala closure of 10 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 3.4.0
+ */
+ def udf[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction = {
+ ScalarUserDefinedFunction(
+ f,
+ typeTag[RT],
+ typeTag[A1],
+ typeTag[A2],
+ typeTag[A3],
+ typeTag[A4],
+ typeTag[A5],
+ typeTag[A6],
+ typeTag[A7],
+ typeTag[A8],
+ typeTag[A9],
+ typeTag[A10])
+ }
+ // scalastyle:on line.size.limit
+
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index e31f121ca10..db2b8b26987 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StringType, StructField, StructType}
class ClientE2ETestSuite extends RemoteSparkSession {
@@ -48,6 +49,19 @@ class ClientE2ETestSuite extends RemoteSparkSession {
assert(array(2).getLong(0) == 2)
}
+ test("simple udf test") {
+
+ def dummyUdf(x: Int): Int = x + 5
+ val myUdf = udf(dummyUdf _)
+ val df = spark.range(5).select(myUdf(Column("id")))
+
+ val result = df.collectResult()
+ assert(result.length == 5)
+ result.toArray.zipWithIndex.foreach { case (v, idx) =>
+ assert(v.getInt(0) == idx + 5)
+ }
+ }
+
// TODO test large result when we can create table or view
// test("test spark large result")
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala
new file mode 100644
index 00000000000..b0d92a223c6
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.reflect.runtime.universe.typeTag
+
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.connect.common.UdfPacket
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.util.Utils
+
+class UserDefinedFunctionSuite
+ extends AnyFunSuite // scalastyle:ignore funsuite
+ with BeforeAndAfterEach {
+
+ test("udf and encoder serialization") {
+ def func(x: Int): Int = x + 1
+
+ val myUdf = udf(func _)
+ val colWithUdf = myUdf(Column("dummy"))
+
+ val udfExpr = colWithUdf.expr.getCommonInlineUserDefinedFunction
+ assert(udfExpr.getDeterministic)
+ assert(udfExpr.getArgumentsCount == 1)
+ assert(udfExpr.getArguments(0) == Column("dummy").expr)
+ val udfObj = udfExpr.getScalarScalaUdf
+
+ assert(udfObj.getNullable)
+
+ val deSer = Utils.deserialize[UdfPacket](udfObj.getPayload.toByteArray)
+
+ assert(deSer.function.asInstanceOf[Int => Int](5) == func(5))
+ assert(deSer.outputEncoder == ScalaReflection.encoderFor(typeTag[Int]))
+ assert(deSer.inputEncoders == Seq(ScalaReflection.encoderFor(typeTag[Int])))
+ }
+}
diff --git a/connector/connect/common/pom.xml b/connector/connect/common/pom.xml
index a37f87dda1e..eb1e4cae34d 100644
--- a/connector/connect/common/pom.xml
+++ b/connector/connect/common/pom.xml
@@ -38,6 +38,18 @@
<tomcat.annotations.api.version>6.0.53</tomcat.annotations.api.version>
</properties>
<dependencies>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 5b27d4593db..66361883321 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -307,6 +307,7 @@ message CommonInlineUserDefinedFunction {
// (Required) Indicate the function type of the user-defined function.
oneof function {
PythonUDF python_udf = 4;
+ ScalarScalaUDF scalar_scala_udf = 5;
}
}
@@ -319,3 +320,14 @@ message PythonUDF {
bytes command = 3;
}
+message ScalarScalaUDF {
+ // (Required) Serialized JVM object containing UDF definition, input encoders and output encoder
+ bytes payload = 1;
+ // (Optional) Input type(s) of the UDF
+ repeated DataType inputTypes = 2;
+ // (Required) Output type of the UDF
+ DataType outputType = 3;
+ // (Required) True if the UDF can return a null value
+ bool nullable = 4;
+}
+
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfPacket.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfPacket.scala
new file mode 100644
index 00000000000..6829b8d1b21
--- /dev/null
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfPacket.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.common
+
+import com.google.protobuf.ByteString
+import java.io.{InputStream, ObjectInputStream, ObjectOutputStream, OutputStream}
+
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+
+/**
+ * A wrapper class around the UDF and its input/output [[AgnosticEncoder]](s).
+ *
+ * This class is shared between the client and the server to allow for serialization and
+ * deserialization of the JVM object.
+ *
+ * @param function
+ * The UDF
+ * @param inputEncoders
+ * A list of [[AgnosticEncoder]](s) for all input arguments of the UDF
+ * @param outputEncoder
+ * An [[AgnosticEncoder]] for the output of the UDF
+ */
+@SerialVersionUID(8866761834651399125L)
+case class UdfPacket(
+ function: AnyRef,
+ inputEncoders: Seq[AgnosticEncoder[_]],
+ outputEncoder: AgnosticEncoder[_])
+ extends Serializable {
+
+ def writeTo(out: OutputStream): Unit = {
+ val oos = new ObjectOutputStream(out)
+ oos.writeObject(this)
+ oos.flush()
+ }
+
+ def toByteString: ByteString = {
+ val out = ByteString.newOutput()
+ writeTo(out)
+ out.toByteString
+ }
+}
+
+object UdfPacket {
+ def apply(in: InputStream): UdfPacket = {
+ val ois = new ObjectInputStream(in)
+ ois.readObject().asInstanceOf[UdfPacket]
+ }
+
+ def apply(bytes: ByteString): UdfPacket = {
+ val in = bytes.newInput()
+ try UdfPacket(in)
+ finally {
+ in.close()
+ }
+ }
+}
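A quick round-trip sketch using the helpers above, with encoders resolved via `ScalaReflection` as in the test suite (illustrative only):
```
import scala.reflect.runtime.universe.typeTag
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.connect.common.UdfPacket

val f: Int => Int = _ + 1
val encoder = ScalaReflection.encoderFor(typeTag[Int])
val packet = UdfPacket(f, Seq(encoder), encoder)
val restored = UdfPacket(packet.toByteString) // uses apply(ByteString) above
```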
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 9b5c4b93f62..51d115ef1ca 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -29,6 +29,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
+import org.apache.spark.sql.connect.common.UdfPacket
import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue}
import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -831,12 +833,37 @@ class SparkConnectPlanner(val session: SparkSession) {
fun.getFunctionCase match {
case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
transformPythonUDF(fun)
+ case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
+ transformScalarScalaUDF(fun)
case _ =>
throw InvalidPlanInput(
s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
}
}
+ /**
+ * Translates a Scalar Scala user-defined function from proto to the Catalyst expression.
+ *
+ * @param fun
+ * Proto representation of the Scalar Scala user-defined function.
+ * @return
+ * ScalaUDF.
+ */
+ private def transformScalarScalaUDF(fun: proto.CommonInlineUserDefinedFunction): ScalaUDF = {
+ val udf = fun.getScalarScalaUdf
+ val udfPacket =
+ Utils.deserialize[UdfPacket](udf.getPayload.toByteArray, Utils.getContextOrSparkClassLoader)
+ ScalaUDF(
+ function = udfPacket.function,
+ dataType = udfPacket.outputEncoder.dataType,
+ children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+ inputEncoders = udfPacket.inputEncoders.map(e => Option(ExpressionEncoder(e))),
+ outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)),
+ udfName = Option(fun.getFunctionName),
+ nullable = udf.getNullable,
+ udfDeterministic = fun.getDeterministic)
+ }
+
/**
* Translates a Python user-defined function from proto to the Catalyst expression.
*
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index f320eee54e0..3a06e80c21e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
)
@@ -65,6 +65,7 @@ _COMMONINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[
"CommonInlineUserDefinedFunction"
]
_PYTHONUDF = DESCRIPTOR.message_types_by_name["PythonUDF"]
+_SCALARSCALAUDF = DESCRIPTOR.message_types_by_name["ScalarScalaUDF"]
_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
"FrameType"
]
@@ -283,6 +284,17 @@ PythonUDF = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(PythonUDF)
+ScalarScalaUDF = _reflection.GeneratedProtocolMessageType(
+ "ScalarScalaUDF",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _SCALARSCALAUDF,
+ "__module__": "spark.connect.expressions_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ScalarScalaUDF)
+ },
+)
+_sym_db.RegisterMessage(ScalarScalaUDF)
+
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
@@ -332,7 +344,9 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4784
_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846
_COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4862
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5098
- _PYTHONUDF._serialized_start = 5100
- _PYTHONUDF._serialized_end = 5199
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5173
+ _PYTHONUDF._serialized_start = 5175
+ _PYTHONUDF._serialized_end = 5274
+ _SCALARSCALAUDF._serialized_start = 5277
+ _SCALARSCALAUDF._serialized_end = 5461
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index d8b0485017c..604672a9ad7 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1100,6 +1100,7 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
DETERMINISTIC_FIELD_NUMBER: builtins.int
ARGUMENTS_FIELD_NUMBER: builtins.int
PYTHON_UDF_FIELD_NUMBER: builtins.int
+ SCALAR_SCALA_UDF_FIELD_NUMBER: builtins.int
function_name: builtins.str
"""(Required) Name of the user-defined function."""
deterministic: builtins.bool
@@ -1111,6 +1112,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
"""(Optional) Function arguments. Empty arguments are allowed."""
@property
def python_udf(self) -> global___PythonUDF: ...
+ @property
+ def scalar_scala_udf(self) -> global___ScalarScalaUDF: ...
def __init__(
self,
*,
@@ -1118,10 +1121,18 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
deterministic: builtins.bool = ...,
arguments: collections.abc.Iterable[global___Expression] | None = ...,
python_udf: global___PythonUDF | None = ...,
+ scalar_scala_udf: global___ScalarScalaUDF | None = ...,
) -> None: ...
def HasField(
self,
- field_name: typing_extensions.Literal["function", b"function", "python_udf", b"python_udf"],
+ field_name: typing_extensions.Literal[
+ "function",
+ b"function",
+ "python_udf",
+ b"python_udf",
+ "scalar_scala_udf",
+ b"scalar_scala_udf",
+ ],
) -> builtins.bool: ...
def ClearField(
self,
@@ -1136,11 +1147,13 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
b"function_name",
"python_udf",
b"python_udf",
+ "scalar_scala_udf",
+ b"scalar_scala_udf",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["function", b"function"]
- ) -> typing_extensions.Literal["python_udf"] | None: ...
+ ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf"] | None: ...
global___CommonInlineUserDefinedFunction = CommonInlineUserDefinedFunction
@@ -1171,3 +1184,52 @@ class PythonUDF(google.protobuf.message.Message):
) -> None: ...
global___PythonUDF = PythonUDF
+
+class ScalarScalaUDF(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ PAYLOAD_FIELD_NUMBER: builtins.int
+ INPUTTYPES_FIELD_NUMBER: builtins.int
+ OUTPUTTYPE_FIELD_NUMBER: builtins.int
+ NULLABLE_FIELD_NUMBER: builtins.int
+ payload: builtins.bytes
+ """(Required) Serialized JVM object containing UDF definition, input encoders and output encoder"""
+ @property
+ def inputTypes(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.types_pb2.DataType
+ ]:
+ """(Optional) Input type(s) of the UDF"""
+ @property
+ def outputType(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+ """(Required) Output type of the UDF"""
+ nullable: builtins.bool
+ """(Required) True if the UDF can return null value"""
+ def __init__(
+ self,
+ *,
+ payload: builtins.bytes = ...,
+ inputTypes: collections.abc.Iterable[pyspark.sql.connect.proto.types_pb2.DataType]
+ | None = ...,
+ outputType: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
+ nullable: builtins.bool = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["outputType", b"outputType"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "inputTypes",
+ b"inputTypes",
+ "nullable",
+ b"nullable",
+ "outputType",
+ b"outputType",
+ "payload",
+ b"payload",
+ ],
+ ) -> None: ...
+
+global___ScalarScalaUDF = ScalarScalaUDF