Posted to commits@spark.apache.org by hv...@apache.org on 2023/06/28 01:05:45 UTC

[spark] branch master updated: [SPARK-44161][CONNECT] Handle Row input for UDFs

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 05fc3497f00 [SPARK-44161][CONNECT] Handle Row input for UDFs
05fc3497f00 is described below

commit 05fc3497f00a0aad9240f14637ea21d271b2bbe4
Author: Zhen Li <zh...@users.noreply.github.com>
AuthorDate: Tue Jun 27 21:05:33 2023 -0400

    [SPARK-44161][CONNECT] Handle Row input for UDFs
    
    ### What changes were proposed in this pull request?
    If the client passes Rows as inputs to UDFs, the Spark Connect planner fails to create the RowEncoder for the Row input.
    
    The Row encoder sent by the client contains no field or schema information. The real input schema should be obtained from the plan's output.
    
    This PR ensures that if the server planner fails to create the encoder for the UDF input using reflection, it falls back to RowEncoders created from the plan.output schema.
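    
    For reference, the core of that fallback, condensed from the SparkConnectPlanner change in the diff below (imports added to keep the sketch self-contained; the real code throws InvalidPlanInput rather than IllegalArgumentException):
    
    ```scala
    import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
    import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.types.{StructField, StructType}
    
    // An UnboundRowEncoder carries no schema, so bind it to the child plan's
    // output attributes; fail only when no plan output is available
    // (e.g. for grouping-key encoders).
    def encoderFor(
        encoder: AgnosticEncoder[_],
        errorType: String,
        inputAttrs: Option[Seq[Attribute]]): ExpressionEncoder[_] = {
      if (encoder == UnboundRowEncoder) {
        inputAttrs
          .map { attrs =>
            val schema = StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
            ExpressionEncoder(RowEncoder.encoderFor(schema))
          }
          .getOrElse(throw new IllegalArgumentException(
            s"Row is not a supported $errorType type for this UDF."))
      } else {
        ExpressionEncoder(encoder)
      }
    }
    ```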
    
    This PR fixes [SPARK-43761](https://issues.apache.org/jira/browse/SPARK-43761) using the same logic.
    This PR also resolves [SPARK-43796](https://issues.apache.org/jira/browse/SPARK-43796); that error was caused solely by the case class being defined inside the test.
    
    ### Why are the changes needed?
    Fixes the bug where Row cannot be used as a UDF input.
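    
    After this change, a UDF can take a Row directly on the Connect client. A minimal sketch mirroring the new E2E test in UserDefinedFunctionE2ETestSuite (assuming a Spark Connect session is already bound to `spark`):
    
    ```scala
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.functions.{col, struct, udf}
    
    val session: SparkSession = spark
    import session.implicits._
    
    val df = Seq((1, 2, 3)).toDF("a", "b", "c")
    
    // The Row parameter is sent as an UnboundRowEncoder; the server now binds it
    // to the schema of struct("a", "b", "c") taken from the plan output.
    val fieldNames = udf((row: Row) => row.schema.fieldNames)
    
    // Yields a single row containing Seq("a", "b", "c"), as checked by the E2E test.
    df.select(fieldNames(struct(df.columns.map(col): _*))).collect()
    ```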
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    E2E tests.
    
    Closes #41704 from zhenlineo/rowEncoder.
    
    Authored-by: Zhen Li <zh...@users.noreply.github.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../sql/expressions/UserDefinedFunction.scala      |  5 +-
 .../spark/sql/streaming/DataStreamWriter.scala     | 11 +--
 .../sql/KeyValueGroupedDatasetE2ETestSuite.scala   | 45 +++++++++++++
 .../sql/UserDefinedFunctionE2ETestSuite.scala      | 36 +++++++++-
 .../spark/sql/streaming/StreamingQuerySuite.scala  | 55 ++++++++-------
 .../sql/connect/planner/SparkConnectPlanner.scala  | 78 ++++++++++++++--------
 .../spark/sql/catalyst/ScalaReflection.scala       | 38 ++++++++---
 .../spark/sql/catalyst/ScalaReflectionSuite.scala  | 21 +++++-
 .../spark/sql/streaming/DataStreamWriter.scala     |  6 +-
 9 files changed, 210 insertions(+), 85 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index bfcd4572e03..14dfc0c6a86 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -142,7 +142,10 @@ object ScalarUserDefinedFunction {
 
     ScalarUserDefinedFunction(
       function = function,
-      inputEncoders = parameterTypes.map(tag => ScalaReflection.encoderFor(tag)),
+      // Input can be a row because the input data schema can be found from the plan.
+      inputEncoders =
+        parameterTypes.map(tag => ScalaReflection.encoderForWithRowEncoderSupport(tag)),
+      // Output cannot be a row as there is no good way to get the return data type.
       outputEncoder = ScalaReflection.encoderFor(returnType))
   }
 
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 263e1e372c8..ed3d2bb8558 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -30,7 +30,6 @@ import org.apache.spark.connect.proto.Command
 import org.apache.spark.connect.proto.WriteStreamOperationStart
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, ForeachWriter}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.connect.common.ForeachWriterPacket
 import org.apache.spark.sql.execution.streaming.AvailableNowTrigger
 import org.apache.spark.sql.execution.streaming.ContinuousTrigger
@@ -215,15 +214,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
    * @since 3.5.0
    */
   def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
-    // TODO [SPARK-43761] Update this once resolved UnboundRowEncoder serialization issue.
-    // ds.encoder equal to UnboundRowEncoder means type parameter T is Row,
-    // which is not able to be serialized. Server will detect this and use default encoder.
-    val rowEncoder = if (ds.encoder != UnboundRowEncoder) {
-      ds.encoder
-    } else {
-      null
-    }
-    val serialized = Utils.serialize(ForeachWriterPacket(writer, rowEncoder))
+    val serialized = Utils.serialize(ForeachWriterPacket(writer, ds.encoder))
     val scalaWriterBuilder = proto.ScalarScalaUDF
       .newBuilder()
       .setPayload(ByteString.copyFrom(serialized))
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index 173867060bd..79bedabd559 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -580,6 +580,51 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
 
     checkDataset(values, ClickState("a", 5), ClickState("b", 3), ClickState("c", 1))
   }
+
+  test("RowEncoder in udf") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
+
+    checkDatasetUnorderly(
+      ds.groupByKey(k => k.getAs[String](0)).agg(sum("c2").as[Long]),
+      ("a", 30L),
+      ("b", 3L),
+      ("c", 1L))
+  }
+
+  test("mapGroups with row encoder") {
+    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
+
+    checkDataset(
+      df.groupByKey(r => r.getAs[String]("c1"))
+        .mapGroups((_, it) =>
+          it.map(r => {
+            r.getAs[Int]("c2")
+          }).sum),
+      30,
+      3,
+      1)
+  }
+
+  test("coGroup with row encoder") {
+    val df1 = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
+    val df2 = Seq(("x", 10), ("x", 20), ("y", 1), ("y", 2), ("a", 1)).toDF("c1", "c2")
+
+    val ds1: KeyValueGroupedDataset[String, Row] =
+      df1.groupByKey(r => r.getAs[String]("c1"))
+    val ds2: KeyValueGroupedDataset[String, Row] =
+      df2.groupByKey(r => r.getAs[String]("c1"))
+    checkDataset(
+      ds1.cogroup(ds2)((_, it, it2) => {
+        val sum1 = it.map(r => r.getAs[Int]("c2")).sum
+        val sum2 = it2.map(r => r.getAs[Int]("c2")).sum
+        Iterator(sum1 + sum2)
+      }),
+      31,
+      3,
+      1,
+      30,
+      3)
+  }
 }
 
 case class K1(a: Long)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index ca1bcf3fe67..a4f1a61cf39 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -26,13 +26,13 @@ import scala.collection.JavaConverters._
 import org.apache.spark.TaskContext
 import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder}
-import org.apache.spark.sql.connect.client.util.RemoteSparkSession
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.connect.client.util.QueryTest
+import org.apache.spark.sql.functions.{col, struct, udf}
 
 /**
  * All tests in this class requires client UDF defined in this test class synced with the server.
  */
-class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession {
+class UserDefinedFunctionE2ETestSuite extends QueryTest {
   test("Dataset typed filter") {
     val rows = spark.range(10).filter(n => n % 2 == 0).collectAsList()
     assert(rows == Arrays.asList[Long](0, 2, 4, 6, 8))
@@ -227,4 +227,34 @@ class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession {
           override def call(v1: Long, v2: Long): Long = v1 + v2
         }) == 55)
   }
+
+  test("udf with row input encoder") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val df = Seq((1, 2, 3)).toDF("a", "b", "c")
+    val f = udf((row: Row) => row.schema.fieldNames)
+    checkDataset(df.select(f(struct(df.columns map col: _*))), Row(Seq("a", "b", "c")))
+  }
+
+  test("Filter with row input encoder") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
+
+    checkDataset(df.filter(r => r.getInt(1) > 5), Row("a", 10), Row("a", 20))
+  }
+
+  test("mapPartitions with row input encoder") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
+
+    checkDataset(
+      df.mapPartitions(it => it.map(r => r.getAs[String]("c1"))),
+      "a",
+      "a",
+      "b",
+      "b",
+      "c")
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index f56fb97f484..6ddcedf19cb 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -215,34 +215,29 @@ class StreamingQuerySuite extends RemoteSparkSession with SQLHelper {
     query.stop()
   }
 
-  // TODO[SPARK-43796]: Enable this test once ammonite client fixes the issue.
-  //  test("foreach Custom class") {
-  //    val session: SparkSession = spark
-  //    import session.implicits._
-  //
-  //    case class TestClass(value: Int) {
-  //      override def toString: String = value.toString
-  //    }
-  //
-  //    val writer = new TestForeachWriter[TestClass]
-  //    val df = spark.readStream
-  //      .format("rate")
-  //      .option("rowsPerSecond", "10")
-  //      .load()
-  //
-  //    val query = df
-  //      .selectExpr("CAST(value AS INT)")
-  //      .as[TestClass]
-  //      .writeStream
-  //      .foreach(writer)
-  //      .outputMode("update")
-  //      .start()
-  //
-  //    assert(query.isActive)
-  //    assert(query.exception.isEmpty)
-  //
-  //    query.stop()
-  //  }
+  test("foreach Custom class") {
+    val session: SparkSession = spark
+    import session.implicits._
+
+    val writer = new TestForeachWriter[TestClass]
+    val df = spark.readStream
+      .format("rate")
+      .option("rowsPerSecond", "10")
+      .load()
+
+    val query = df
+      .selectExpr("CAST(value AS INT)")
+      .as[TestClass]
+      .writeStream
+      .foreach(writer)
+      .outputMode("update")
+      .start()
+
+    assert(query.isActive)
+    assert(query.exception.isEmpty)
+
+    query.stop()
+  }
 
   test("streaming query manager") {
     assert(spark.streams.active.isEmpty)
@@ -293,3 +288,7 @@ class TestForeachWriter[T] extends ForeachWriter[T] {
     Utils.deleteRecursively(path)
   }
 }
+
+case class TestClass(value: Int) {
+  override def toString: String = value.toString
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ff158990560..cecf14a7045 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connect.planner
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.util.Try
 
 import com.google.common.collect.{Lists, Maps}
 import com.google.protobuf.{Any => ProtoAny, ByteString}
@@ -40,11 +41,12 @@ import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.Streami
 import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{functions => MLFunctions}
-import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, RelationalGroupedDataset, Row, SparkSession}
+import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, RelationalGroupedDataset, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
@@ -543,7 +545,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
   private def transformTypedMapPartitions(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): LogicalPlan = {
-    val udf = TypedScalaUdf(fun)
+    val udf = TypedScalaUdf(fun, Some(child.output))
     val deserialized = DeserializeToObject(udf.inputDeserializer(), udf.inputObjAttr, child)
     val mapped = MapPartitions(
       udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
@@ -776,7 +778,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         groupingExprs: java.util.List[proto.Expression],
         sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
       assert(groupingExprs.size() >= 1)
-      val dummyFunc = TypedScalaUdf(groupingExprs.get(0))
+      val dummyFunc = TypedScalaUdf(groupingExprs.get(0), None)
       val groupExprs = groupingExprs.asScala.toSeq.drop(1).map(expr => transformExpression(expr))
 
       val (qe, aliasedGroupings) =
@@ -796,7 +798,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         groupingExprs: java.util.List[proto.Expression],
         sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
       assert(groupingExprs.size() == 1)
-      val groupFunc = TypedScalaUdf(groupingExprs.get(0))
+      val groupFunc = TypedScalaUdf(groupingExprs.get(0), Some(logicalPlan.output))
       val vEnc = groupFunc.inEnc
       val kEnc = groupFunc.outEnc
 
@@ -819,35 +821,58 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
    */
   private case class TypedScalaUdf(
       function: AnyRef,
-      outEnc: ExpressionEncoder[_],
-      outputObjAttr: Attribute,
-      inEnc: ExpressionEncoder[_],
-      inputObjAttr: Attribute) {
-    val outputNamedExpression: Seq[NamedExpression] = outEnc.namedExpressions
-    def inputDeserializer(inputAttributes: Seq[Attribute] = Nil): Expression = {
+      funcOutEnc: AgnosticEncoder[_],
+      funcInEnc: AgnosticEncoder[_],
+      inputAttrs: Option[Seq[Attribute]]) {
+    import TypedScalaUdf.encoderFor
+    def outputNamedExpression: Seq[NamedExpression] = outEnc.namedExpressions
+    def inputDeserializer(inputAttributes: Seq[Attribute] = Nil): Expression =
       UnresolvedDeserializer(inEnc.deserializer, inputAttributes)
-    }
+
+    def outEnc: ExpressionEncoder[_] = encoderFor(funcOutEnc, "output")
+    def outputObjAttr: Attribute = generateObjAttr(outEnc)
+    def inEnc: ExpressionEncoder[_] = encoderFor(funcInEnc, "input", inputAttrs)
+    def inputObjAttr: Attribute = generateObjAttr(inEnc)
   }
+
   private object TypedScalaUdf {
-    def apply(expr: proto.Expression): TypedScalaUdf = {
+    def apply(expr: proto.Expression, inputAttrs: Option[Seq[Attribute]]): TypedScalaUdf = {
       if (expr.hasCommonInlineUserDefinedFunction
         && expr.getCommonInlineUserDefinedFunction.hasScalarScalaUdf) {
-        apply(expr.getCommonInlineUserDefinedFunction)
+        apply(expr.getCommonInlineUserDefinedFunction, inputAttrs)
       } else {
         throw InvalidPlanInput(s"Expecting a Scala UDF, but get ${expr.getExprTypeCase}")
       }
     }
 
-    def apply(commonUdf: proto.CommonInlineUserDefinedFunction): TypedScalaUdf = {
+    def apply(
+        commonUdf: proto.CommonInlineUserDefinedFunction,
+        inputAttrs: Option[Seq[Attribute]] = None): TypedScalaUdf = {
       val udf = unpackUdf(commonUdf)
-      val outEnc = ExpressionEncoder(udf.outputEncoder)
       // There might be more than one inputs, but we only interested in the first one.
       // Most typed API takes one UDF input.
       // For the few that takes more than one inputs, e.g. grouping function mapping UDFs,
       // the first input which is the key of the grouping function.
       assert(udf.inputEncoders.nonEmpty)
-      val inEnc = ExpressionEncoder(udf.inputEncoders.head) // single input encoder or key encoder
-      TypedScalaUdf(udf.function, outEnc, generateObjAttr(outEnc), inEnc, generateObjAttr(inEnc))
+      val inEnc = udf.inputEncoders.head // single input encoder or key encoder
+      TypedScalaUdf(udf.function, udf.outputEncoder, inEnc, inputAttrs)
+    }
+
+    def encoderFor(
+        encoder: AgnosticEncoder[_],
+        errorType: String,
+        inputAttrs: Option[Seq[Attribute]] = None): ExpressionEncoder[_] = {
+      // TODO: handle nested unbound row encoders
+      if (encoder == UnboundRowEncoder) {
+        inputAttrs
+          .map(attrs =>
+            ExpressionEncoder(RowEncoder.encoderFor(StructType(attrs.map(a =>
+              StructField(a.name, a.dataType, a.nullable))))))
+          .getOrElse(
+            throw InvalidPlanInput(s"Row is not a supported $errorType type for this UDF."))
+      } else {
+        ExpressionEncoder(encoder)
+      }
     }
   }
 
@@ -1265,7 +1290,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
   private def transformTypedFilter(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): TypedFilter = {
-    val udf = TypedScalaUdf(fun)
+    val udf = TypedScalaUdf(fun, Some(child.output))
     TypedFilter(udf.function, child)(udf.inEnc)
   }
 
@@ -1455,7 +1480,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       function = udfPacket.function,
       dataType = transformDataType(udf.getOutputType),
       children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
-      inputEncoders = udfPacket.inputEncoders.map(e => Option(ExpressionEncoder(e))),
+      inputEncoders = udfPacket.inputEncoders.map(e => Try(ExpressionEncoder(e)).toOption),
       outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)),
       udfName = Option(fun.getFunctionName),
       nullable = udf.getNullable,
@@ -2620,15 +2645,10 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       } else {
         val foreachWriterPkt = unpackForeachWriter(writeOp.getForeachWriter.getScalaWriter)
         val clientWriter = foreachWriterPkt.foreachWriter
-        if (foreachWriterPkt.datasetEncoder == null) {
-          // datasetEncoder is null means the client-side writer has type parameter Row,
-          // Since server-side dataset is always dataframe, here just use foreach directly.
-          writer.foreach(clientWriter.asInstanceOf[ForeachWriter[Row]])
-        } else {
-          val encoder = ExpressionEncoder(
-            foreachWriterPkt.datasetEncoder.asInstanceOf[AgnosticEncoder[Any]])
-          writer.foreachImplementation(clientWriter.asInstanceOf[ForeachWriter[Any]], encoder)
-        }
+        val encoder: Option[ExpressionEncoder[Any]] = Try(
+          ExpressionEncoder(
+            foreachWriterPkt.datasetEncoder.asInstanceOf[AgnosticEncoder[Any]])).toOption
+        writer.foreachImplementation(clientWriter.asInstanceOf[ForeachWriter[Any]], encoder)
       }
     }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index ebc740be083..d67ba455438 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -29,6 +29,7 @@ import org.apache.commons.lang3.reflect.ConstructorUtils
 
 import org.apache.spark.SPARK_DOC_ROOT
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
 import org.apache.spark.sql.catalyst.SerializerBuildHelper._
@@ -806,25 +807,38 @@ object ScalaReflection extends ScalaReflection {
     encoderFor(typeTag[E].in(mirror).tpe).asInstanceOf[AgnosticEncoder[E]]
   }
 
+  /**
+   * Same as [[encoderFor]] but with extended support to return [[UnboundRowEncoder]] for [[Row]]
+   * type.
+   */
+  def encoderForWithRowEncoderSupport[E: TypeTag]: AgnosticEncoder[E] = {
+    encoderFor(typeTag[E].in(mirror).tpe, isRowEncoderSupported = true)
+      .asInstanceOf[AgnosticEncoder[E]]
+  }
+
   /**
    * Create an [[AgnosticEncoder]] for a [[Type]].
    */
-  def encoderFor(tpe: `Type`): AgnosticEncoder[_] = cleanUpReflectionObjects {
+  def encoderFor(
+     tpe: `Type`,
+     isRowEncoderSupported: Boolean = false): AgnosticEncoder[_] = cleanUpReflectionObjects {
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = WalkedTypePath().recordRoot(clsName)
-    encoderFor(tpe, Set.empty, walkedTypePath)
+    encoderFor(tpe, Set.empty, walkedTypePath, isRowEncoderSupported)
   }
 
   private def encoderFor(
       tpe: `Type`,
       seenTypeSet: Set[`Type`],
-      path: WalkedTypePath): AgnosticEncoder[_] = {
+      path: WalkedTypePath,
+      isRowEncoderSupported: Boolean): AgnosticEncoder[_] = {
     def createIterableEncoder(t: `Type`, fallbackClass: Class[_]): AgnosticEncoder[_] = {
       val TypeRef(_, _, Seq(elementType)) = t
       val encoder = encoderFor(
         elementType,
         seenTypeSet,
-        path.recordArray(getClassNameFromType(elementType)))
+        path.recordArray(getClassNameFromType(elementType)),
+        isRowEncoderSupported)
       val companion = t.dealias.typeSymbol.companion.typeSignature
       val targetClass = companion.member(TermName("newBuilder")) match {
         case NoSymbol => fallbackClass
@@ -888,6 +902,7 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => STRICT_TIMESTAMP_ENCODER
       case t if isSubtype(t, localTypeOf[java.time.Instant]) => STRICT_INSTANT_ENCODER
       case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
+      case t if isSubtype(t, localTypeOf[Row]) => UnboundRowEncoder
 
       // UDT encoders
       case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
@@ -907,7 +922,8 @@ object ScalaReflection extends ScalaReflection {
         val encoder = encoderFor(
           optType,
           seenTypeSet,
-          path.recordOption(getClassNameFromType(optType)))
+          path.recordOption(getClassNameFromType(optType)),
+          isRowEncoderSupported)
         OptionEncoder(encoder)
 
       case t if isSubtype(t, localTypeOf[Array[_]]) =>
@@ -915,7 +931,8 @@ object ScalaReflection extends ScalaReflection {
         val encoder = encoderFor(
           elementType,
           seenTypeSet,
-          path.recordArray(getClassNameFromType(elementType)))
+          path.recordArray(getClassNameFromType(elementType)),
+          isRowEncoderSupported)
         ArrayEncoder(encoder, encoder.nullable)
 
       case t if isSubtype(t, localTypeOf[scala.collection.Seq[_]]) =>
@@ -929,11 +946,13 @@ object ScalaReflection extends ScalaReflection {
         val keyEncoder = encoderFor(
           keyType,
           seenTypeSet,
-          path.recordKeyForMap(getClassNameFromType(keyType)))
+          path.recordKeyForMap(getClassNameFromType(keyType)),
+          isRowEncoderSupported)
         val valueEncoder = encoderFor(
           valueType,
           seenTypeSet,
-          path.recordValueForMap(getClassNameFromType(valueType)))
+          path.recordValueForMap(getClassNameFromType(valueType)),
+          isRowEncoderSupported)
         MapEncoder(ClassTag(getClassFromType(t)), keyEncoder, valueEncoder, valueEncoder.nullable)
 
       case t if definedByConstructorParams(t) =>
@@ -951,7 +970,8 @@ object ScalaReflection extends ScalaReflection {
             val encoder = encoderFor(
               fieldType,
               seenTypeSet + t,
-              path.recordField(getClassNameFromType(fieldType), fieldName))
+              path.recordField(getClassNameFromType(fieldType), fieldName),
+              isRowEncoderSupported)
             EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty)
         }
         ProductEncoder(ClassTag(getClassFromType(t)), params)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index f8ebdfe7676..7df9f45ed83 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.catalyst
 
 import java.sql.{Date, Timestamp}
 
-import scala.reflect.runtime.universe.TypeTag
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.FooEnum.FooEnum
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, MapObjects, NewInstance}
 import org.apache.spark.sql.types._
@@ -596,4 +599,20 @@ class ScalaReflectionSuite extends SparkFunSuite {
         ),
         nullable = true))
   }
+
+  test("encoder for row") {
+    assert(encoderForWithRowEncoderSupport[Row] === UnboundRowEncoder)
+    assert(encoderForWithRowEncoderSupport[Option[Row]] === OptionEncoder(UnboundRowEncoder))
+    assert(encoderForWithRowEncoderSupport[Array[Row]] === ArrayEncoder(UnboundRowEncoder, true))
+    assert(encoderForWithRowEncoderSupport[Map[Row, Row]] ===
+      MapEncoder(
+        ClassTag(getClassFromType(typeTag[Map[Row, Row]].tpe)),
+        UnboundRowEncoder, UnboundRowEncoder, true))
+    assert(encoderForWithRowEncoderSupport[MyClass] ===
+      ProductEncoder(
+        ClassTag(getClassFromType(typeTag[MyClass].tpe)),
+        Seq(EncoderField("row", UnboundRowEncoder, true, Metadata.empty))))
+  }
+
+  case class MyClass(row: Row)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 12977987f08..ac7915921d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -451,16 +451,14 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
   }
 
   private[sql] def foreachImplementation(writer: ForeachWriter[Any],
-      encoder: ExpressionEncoder[Any] = null): DataStreamWriter[T] = {
+      encoder: Option[ExpressionEncoder[Any]] = None): DataStreamWriter[T] = {
     this.source = SOURCE_NAME_FOREACH
     this.foreachWriter = if (writer != null) {
       ds.sparkSession.sparkContext.clean(writer)
     } else {
       throw new IllegalArgumentException("foreach writer cannot be null")
     }
-    if (encoder != null) {
-      this.foreachWriterEncoder = encoder
-    }
+    encoder.foreach(e => this.foreachWriterEncoder = e)
     this
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org