You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/01 18:12:51 UTC

[spark] branch master updated: [SPARK-42283][CONNECT][SCALA] Simple Scalar Scala UDFs

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fb9d7067845 [SPARK-42283][CONNECT][SCALA] Simple Scalar Scala UDFs
fb9d7067845 is described below

commit fb9d706784557ef0fe9e17d59b7096374658954e
Author: vicennial <ve...@databricks.com>
AuthorDate: Wed Feb 1 14:12:38 2023 -0400

    [SPARK-42283][CONNECT][SCALA] Simple Scalar Scala UDFs
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for "simple" scalar Scala UDFs for the Spark Connect Scala/JVM Client. "Simple" here refers to UDFs that use no client-specific class files (e.g., REPL-generated classes) or JARs. Essentially, a "simple" UDF may only reference built-in libraries and classes defined within the scope of the UDF.
    
    A user would then be able to do the following (example):
    ```
    def myFunc(x: Int): Int = x + 5
    
    val myUdf = udf(myFunc _)
    val result = df.select(myUdf(Column("id")))
    ```
    #### Implementation Details:
    
    A shared JVM object, `UdfPacket`, is introduced in the common package to encapsulate the Scala UDF and its encoders (via Agnostic Encoders) so that it can be serialized on the client and deserialized on the server.
    
    Further, a new protobuf message `ScalarScalaUDF` is introduced to transmit Scala/JVM specific information to the server (such as the above serialized JVM object).
    
    ### Why are the changes needed?
    
    UDFs are crucial for the completeness of the Spark Connect Scala/JVM client. We introduce this component incrementally.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users are now able to run "simple" scalar Scala UDFs through the Scala/JVM client.
    
    ### How was this patch tested?
    Unit test + Integration test
    
    Closes #39850 from vicennial/SPARK-42283.
    
    Authored-by: vicennial <ve...@databricks.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../sql/expressions/UserDefinedFunction.scala      | 146 ++++++++++++
 .../scala/org/apache/spark/sql/functions.scala     | 255 +++++++++++++++++++++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  14 ++
 .../spark/sql/UserDefinedFunctionSuite.scala       |  53 +++++
 connector/connect/common/pom.xml                   |  12 +
 .../main/protobuf/spark/connect/expressions.proto  |  12 +
 .../spark/sql/connect/common/UdfPacket.scala       |  70 ++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  27 +++
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  22 +-
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  66 +++++-
 10 files changed, 671 insertions(+), 6 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
new file mode 100644
index 00000000000..0fe47092e4e
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.expressions
+
+import scala.collection.JavaConverters._
+import scala.reflect.runtime.universe.TypeTag
+
+import com.google.protobuf.ByteString
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.common.UdfPacket
+import org.apache.spark.util.Utils
+
+/**
+ * A user-defined function. To create one, use the `udf` functions in `functions`.
+ *
+ * As an example:
+ * {{{
+ *   // Define a UDF that returns true or false based on some numeric score.
+ *   val predict = udf((score: Double) => score > 0.5)
+ *
+ *   // Projects a column that adds a prediction column based on the score column.
+ *   df.select(predict(df("score")))
+ * }}}
+ *
+ * @since 3.4.0
+ */
+sealed abstract class UserDefinedFunction {
+
+  /**
+   * Returns true when the UDF can return a nullable value.
+   *
+   * @since 3.4.0
+   */
+  def nullable: Boolean
+
+  /**
+   * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the
+   * same input.
+   *
+   * @since 3.4.0
+   */
+  def deterministic: Boolean
+
+  /**
+   * Returns an expression that invokes the UDF, using the given arguments.
+   *
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def apply(exprs: Column*): Column
+
+  /**
+   * Returns a copy of this UserDefinedFunction with the given name; this instance is unchanged.
+   *
+   * @since 3.4.0
+   */
+  def withName(name: String): UserDefinedFunction
+
+  /**
+   * Returns a copy of this UserDefinedFunction that is marked non-nullable.
+   *
+   * @since 3.4.0
+   */
+  def asNonNullable(): UserDefinedFunction
+
+  /**
+   * Returns a copy of this UserDefinedFunction that is marked nondeterministic.
+   *
+   * @since 3.4.0
+   */
+  def asNondeterministic(): UserDefinedFunction
+}
+
+/**
+ * Holder class for a scalar user-defined function and its input/output encoder(s).
+ */
+case class ScalarUserDefinedFunction(
+    function: AnyRef,
+    inputEncoders: Seq[AgnosticEncoder[_]],
+    outputEncoder: AgnosticEncoder[_],
+    name: Option[String],
+    override val nullable: Boolean,
+    override val deterministic: Boolean)
+    extends UserDefinedFunction {
+
+  private[this] lazy val udf = {
+    val udfPacketBytes = Utils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
+    val scalaUdfBuilder = proto.ScalarScalaUDF
+      .newBuilder()
+      .setPayload(ByteString.copyFrom(udfPacketBytes))
+      .setNullable(nullable)
+    // NOTE(review): the proto's inputTypes/outputType fields are not populated here.
+    scalaUdfBuilder.build()
+  }
+
+  @scala.annotation.varargs
+  override def apply(exprs: Column*): Column = Column { builder =>
+    val udfBuilder = builder.getCommonInlineUserDefinedFunctionBuilder
+    udfBuilder
+      .setDeterministic(deterministic)
+      .setScalarScalaUdf(udf)
+      .addAllArguments(exprs.map(_.expr).asJava)
+
+    name.foreach(udfBuilder.setFunctionName)
+  }
+
+  override def withName(name: String): ScalarUserDefinedFunction = copy(name = Option(name))
+
+  override def asNonNullable(): ScalarUserDefinedFunction = copy(nullable = false)
+
+  override def asNondeterministic(): ScalarUserDefinedFunction = copy(deterministic = false)
+}
+
+object ScalarUserDefinedFunction {
+  private[sql] def apply(
+      function: AnyRef,
+      returnType: TypeTag[_],
+      parameterTypes: TypeTag[_]*): ScalarUserDefinedFunction = {
+    // Encoders are derived from the type tags; the UDF starts unnamed, nullable, deterministic.
+    ScalarUserDefinedFunction(
+      function = function,
+      inputEncoders = parameterTypes.map(tag => ScalaReflection.encoderFor(tag)),
+      outputEncoder = ScalaReflection.encoderFor(returnType),
+      name = None,
+      nullable = true,
+      deterministic = true)
+  }
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index bae394785be..61174f1921e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -19,10 +19,13 @@ package org.apache.spark.sql
 import java.math.{BigDecimal => JBigDecimal}
 import java.time.LocalDate
 
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
 import com.google.protobuf.ByteString
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.client.unsupported
+import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction}
 
 /**
  * Commonly used functions available for DataFrame operations.
@@ -80,4 +83,256 @@ object functions {
       case _ => unsupported(s"literal $literal not supported (yet).")
     }
   }
+
+  // scalastyle:off line.size.limit
+
+  /**
+   * Defines a Scala closure of 0 arguments as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[RT: TypeTag](f: () => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(f, typeTag[RT])
+  }
+
+  /**
+   * Defines a Scala closure of 1 argument as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[RT: TypeTag, A1: TypeTag](f: A1 => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1])
+  }
+
+  /**
+   * Defines a Scala closure of 2 arguments as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: (A1, A2) => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1], typeTag[A2])
+  }
+
+  /**
+   * Defines a Scala closure of 3 arguments as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
+      f: (A1, A2, A3) => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1], typeTag[A2], typeTag[A3])
+  }
+
+  /**
+   * Defines a Scala closure of 4 arguments as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
+      f: (A1, A2, A3, A4) => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1], typeTag[A2], typeTag[A3], typeTag[A4])
+  }
+
+  /**
+   * Defines a Scala closure of 5 arguments as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
+      f: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(
+      f,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5])
+  }
+
+  /**
+   * Defines a Scala closure of 6 arguments as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag](f: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(
+      f,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6])
+  }
+
+  /**
+   * Defines a Scala closure of 7 arguments as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(
+      f,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7])
+  }
+
+  /**
+   * Defines a Scala closure of 8 arguments as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(
+      f,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8])
+  }
+
+  /**
+   * Defines a Scala closure of 9 arguments as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(
+      f,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9])
+  }
+
+  /**
+   * Defines a Scala closure of 10 arguments as user-defined function (UDF). The data types are
+   * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+   * deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * @group udf_funcs
+   * @since 3.4.0
+   */
+  def udf[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction = {
+    ScalarUserDefinedFunction(
+      f,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10])
+  }
+  // scalastyle:on line.size.limit
+
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index e31f121ca10..db2b8b26987 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.{StringType, StructField, StructType}
 
 class ClientE2ETestSuite extends RemoteSparkSession {
@@ -48,6 +49,19 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     assert(array(2).getLong(0) == 2)
   }
 
+  test("simple udf test") {
+    // spark.range(5) yields ids 0..4, so row idx should hold dummyUdf(idx) == idx + 5.
+    def dummyUdf(x: Int): Int = x + 5
+    val myUdf = udf(dummyUdf _)
+    val df = spark.range(5).select(myUdf(Column("id")))
+
+    val result = df.collectResult()
+    assert(result.length == 5)
+    result.toArray.zipWithIndex.foreach { case (v, idx) =>
+      assert(v.getInt(0) == idx + 5)
+    }
+  }
+
   // TODO test large result when we can create table or view
   // test("test spark large result")
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala
new file mode 100644
index 00000000000..b0d92a223c6
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.reflect.runtime.universe.typeTag
+
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.connect.common.UdfPacket
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.util.Utils
+
+class UserDefinedFunctionSuite
+    extends AnyFunSuite // scalastyle:ignore funsuite
+    with BeforeAndAfterEach {
+  // Builds a UDF column on the client and checks its proto plus the serialized UdfPacket.
+  test("udf and encoder serialization") {
+    def func(x: Int): Int = x + 1
+
+    val myUdf = udf(func _)
+    val colWithUdf = myUdf(Column("dummy"))
+
+    val udfExpr = colWithUdf.expr.getCommonInlineUserDefinedFunction
+    assert(udfExpr.getDeterministic)
+    assert(udfExpr.getArgumentsCount == 1)
+    assert(udfExpr.getArguments(0) == Column("dummy").expr)
+    val udfObj = udfExpr.getScalarScalaUdf
+
+    assert(udfObj.getNullable)
+    // The payload is the UdfPacket Java-serialized by the client via Utils.serialize.
+    val deSer = Utils.deserialize[UdfPacket](udfObj.getPayload.toByteArray)
+
+    assert(deSer.function.asInstanceOf[Int => Int](5) == func(5))
+    assert(deSer.outputEncoder == ScalaReflection.encoderFor(typeTag[Int]))
+    assert(deSer.inputEncoders == Seq(ScalaReflection.encoderFor(typeTag[Int])))
+  }
+}
diff --git a/connector/connect/common/pom.xml b/connector/connect/common/pom.xml
index a37f87dda1e..eb1e4cae34d 100644
--- a/connector/connect/common/pom.xml
+++ b/connector/connect/common/pom.xml
@@ -38,6 +38,18 @@
         <tomcat.annotations.api.version>6.0.53</tomcat.annotations.api.version>
     </properties>
     <dependencies>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <scope>provided</scope>
+            <exclusions>
+                <exclusion>
+                    <groupId>com.google.guava</groupId>
+                    <artifactId>guava</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
         <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-library</artifactId>
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 5b27d4593db..66361883321 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -307,6 +307,7 @@ message CommonInlineUserDefinedFunction {
   // (Required) Indicate the function type of the user-defined function.
   oneof function {
     PythonUDF python_udf = 4;
+    ScalarScalaUDF scalar_scala_udf = 5;
   }
 }
 
@@ -319,3 +320,14 @@ message PythonUDF {
   bytes command = 3;
 }
 
+message ScalarScalaUDF {
+  // (Required) Serialized JVM object containing UDF definition, input encoders and output encoder
+  bytes payload = 1;
+  // (Optional) Input type(s) of the UDF. NOTE(review): not yet set by the Scala client.
+  repeated DataType inputTypes = 2;
+  // (Required) Output type of the UDF. NOTE(review): not yet set by the Scala client.
+  DataType outputType = 3;
+  // (Required) True if the UDF can return a null value
+  bool nullable = 4;
+}
+
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfPacket.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfPacket.scala
new file mode 100644
index 00000000000..6829b8d1b21
--- /dev/null
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfPacket.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.common
+
+import com.google.protobuf.ByteString
+import java.io.{InputStream, ObjectInputStream, ObjectOutputStream, OutputStream}
+
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+
+/**
+ * A wrapper class around the UDF and its Input/Output [[AgnosticEncoder]](s).
+ *
+ * This class is shared between the client and the server to allow for serialization and
+ * deserialization of the JVM object.
+ *
+ * @param function
+ *   The UDF
+ * @param inputEncoders
+ *   A list of [[AgnosticEncoder]](s) for all input arguments of the UDF
+ * @param outputEncoder
+ *   An [[AgnosticEncoder]] for the output of the UDF
+ */
+@SerialVersionUID(8866761834651399125L)
+case class UdfPacket(
+    function: AnyRef,
+    inputEncoders: Seq[AgnosticEncoder[_]],
+    outputEncoder: AgnosticEncoder[_])
+    extends Serializable {
+  // writeTo flushes but does not close `out`; the caller owns that stream.
+  def writeTo(out: OutputStream): Unit = {
+    val oos = new ObjectOutputStream(out)
+    oos.writeObject(this)
+    oos.flush()
+  }
+
+  def toByteString: ByteString = {
+    val out = ByteString.newOutput()
+    writeTo(out)
+    out.toByteString
+  }
+}
+
+object UdfPacket {
+  def apply(in: InputStream): UdfPacket = {
+    val ois = new ObjectInputStream(in)
+    ois.readObject().asInstanceOf[UdfPacket]
+  }
+  // Unlike apply(InputStream), which leaves `in` open, this overload closes the stream it opens.
+  def apply(bytes: ByteString): UdfPacket = {
+    val in = bytes.newInput()
+    try UdfPacket(in)
+    finally {
+      in.close()
+    }
+  }
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 9b5c4b93f62..51d115ef1ca 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -29,6 +29,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
+import org.apache.spark.sql.connect.common.UdfPacket
 import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue}
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -831,12 +833,37 @@ class SparkConnectPlanner(val session: SparkSession) {
     fun.getFunctionCase match {
       case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
         transformPythonUDF(fun)
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
+        transformScalarScalaUDF(fun)
       case _ =>
         throw InvalidPlanInput(
           s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
     }
   }
 
+  /**
+   * Translates a Scalar Scala user-defined function from proto to the Catalyst expression.
+   *
+   * @param fun
+   *   Proto representation of the Scalar Scala user-defined function.
+   * @return
+   *   ScalaUDF. An unset name arrives as "" (proto3 default) and is treated as no name.
+   */
+  private def transformScalarScalaUDF(fun: proto.CommonInlineUserDefinedFunction): ScalaUDF = {
+    val udf = fun.getScalarScalaUdf
+    val udfPacket =
+      Utils.deserialize[UdfPacket](udf.getPayload.toByteArray, Utils.getContextOrSparkClassLoader)
+    ScalaUDF(
+      function = udfPacket.function,
+      dataType = udfPacket.outputEncoder.dataType,
+      children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+      inputEncoders = udfPacket.inputEncoders.map(e => Option(ExpressionEncoder(e))),
+      outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)),
+      udfName = Option(fun.getFunctionName).filter(_.nonEmpty),
+      nullable = udf.getNullable,
+      udfDeterministic = fun.getDeterministic)
+  }
+
   /**
    * Translates a Python user-defined function from proto to the Catalyst expression.
    *
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index f320eee54e0..3a06e80c21e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
 
 
@@ -65,6 +65,7 @@ _COMMONINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[
     "CommonInlineUserDefinedFunction"
 ]
 _PYTHONUDF = DESCRIPTOR.message_types_by_name["PythonUDF"]
+_SCALARSCALAUDF = DESCRIPTOR.message_types_by_name["ScalarScalaUDF"]
 _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
     "FrameType"
 ]
@@ -283,6 +284,17 @@ PythonUDF = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(PythonUDF)
 
+ScalarScalaUDF = _reflection.GeneratedProtocolMessageType(
+    "ScalarScalaUDF",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _SCALARSCALAUDF,
+        "__module__": "spark.connect.expressions_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.ScalarScalaUDF)
+    },
+)
+_sym_db.RegisterMessage(ScalarScalaUDF)
+
 if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
@@ -332,7 +344,9 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4784
     _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846
     _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4862
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5098
-    _PYTHONUDF._serialized_start = 5100
-    _PYTHONUDF._serialized_end = 5199
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5173
+    _PYTHONUDF._serialized_start = 5175
+    _PYTHONUDF._serialized_end = 5274
+    _SCALARSCALAUDF._serialized_start = 5277
+    _SCALARSCALAUDF._serialized_end = 5461
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index d8b0485017c..604672a9ad7 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1100,6 +1100,7 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
     DETERMINISTIC_FIELD_NUMBER: builtins.int
     ARGUMENTS_FIELD_NUMBER: builtins.int
     PYTHON_UDF_FIELD_NUMBER: builtins.int
+    SCALAR_SCALA_UDF_FIELD_NUMBER: builtins.int
     function_name: builtins.str
     """(Required) Name of the user-defined function."""
     deterministic: builtins.bool
@@ -1111,6 +1112,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
         """(Optional) Function arguments. Empty arguments are allowed."""
     @property
     def python_udf(self) -> global___PythonUDF: ...
+    @property
+    def scalar_scala_udf(self) -> global___ScalarScalaUDF: ...
     def __init__(
         self,
         *,
@@ -1118,10 +1121,18 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
         deterministic: builtins.bool = ...,
         arguments: collections.abc.Iterable[global___Expression] | None = ...,
         python_udf: global___PythonUDF | None = ...,
+        scalar_scala_udf: global___ScalarScalaUDF | None = ...,
     ) -> None: ...
     def HasField(
         self,
-        field_name: typing_extensions.Literal["function", b"function", "python_udf", b"python_udf"],
+        field_name: typing_extensions.Literal[
+            "function",
+            b"function",
+            "python_udf",
+            b"python_udf",
+            "scalar_scala_udf",
+            b"scalar_scala_udf",
+        ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
@@ -1136,11 +1147,13 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
             b"function_name",
             "python_udf",
             b"python_udf",
+            "scalar_scala_udf",
+            b"scalar_scala_udf",
         ],
     ) -> None: ...
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["function", b"function"]
-    ) -> typing_extensions.Literal["python_udf"] | None: ...
+    ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf"] | None: ...
 
 global___CommonInlineUserDefinedFunction = CommonInlineUserDefinedFunction
 
@@ -1171,3 +1184,52 @@ class PythonUDF(google.protobuf.message.Message):
     ) -> None: ...
 
 global___PythonUDF = PythonUDF
+
+class ScalarScalaUDF(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    PAYLOAD_FIELD_NUMBER: builtins.int
+    INPUTTYPES_FIELD_NUMBER: builtins.int
+    OUTPUTTYPE_FIELD_NUMBER: builtins.int
+    NULLABLE_FIELD_NUMBER: builtins.int
+    payload: builtins.bytes
+    """(Required) Serialized JVM object containing UDF definition, input encoders and output encoder"""
+    @property
+    def inputTypes(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.types_pb2.DataType
+    ]:
+        """(Optional) Input type(s) of the UDF"""
+    @property
+    def outputType(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+        """(Required) Output type of the UDF"""
+    nullable: builtins.bool
+    """(Required) True if the UDF can return null value"""
+    def __init__(
+        self,
+        *,
+        payload: builtins.bytes = ...,
+        inputTypes: collections.abc.Iterable[pyspark.sql.connect.proto.types_pb2.DataType]
+        | None = ...,
+        outputType: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
+        nullable: builtins.bool = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["outputType", b"outputType"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "inputTypes",
+            b"inputTypes",
+            "nullable",
+            b"nullable",
+            "outputType",
+            b"outputType",
+            "payload",
+            b"payload",
+        ],
+    ) -> None: ...
+
+global___ScalarScalaUDF = ScalarScalaUDF


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org