You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by "amaliujia (via GitHub)" <gi...@apache.org> on 2023/04/21 22:26:22 UTC

[GitHub] [spark] amaliujia commented on a diff in pull request #40796: [SPARK-43223][Connect] Typed agg, reduce functions

amaliujia commented on code in PR #40796:
URL: https://github.com/apache/spark/pull/40796#discussion_r1174198988


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -545,34 +540,94 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformTypedMapPartitions(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): LogicalPlan = {
-    val udf = fun.getScalarScalaUdf
-    val udfPacket =
-      Utils.deserialize[UdfPacket](
-        udf.getPayload.toByteArray,
-        SparkConnectArtifactManager.classLoaderWithArtifacts)
-    assert(udfPacket.inputEncoders.size == 1)
-    val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
-    val rEnc = ExpressionEncoder(udfPacket.outputEncoder)
+    val udf = unpackUdf(fun)
+    assert(udf.inputEncoders.size == 1)
+    val iEnc = ExpressionEncoder(udf.inputEncoders.head)
+    val rEnc = ExpressionEncoder(udf.outputEncoder)
 
     val deserializer = UnresolvedDeserializer(iEnc.deserializer)
     val deserialized = DeserializeToObject(deserializer, generateObjAttr(iEnc), child)
     val mapped = MapPartitions(
-      udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
+      udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
       generateObjAttr(rEnc),
       deserialized)
     SerializeFromObject(rEnc.namedExpressions, mapped)
   }
 
   private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
-    val pythonUdf = transformPythonUDF(rel.getFunc)
-    val cols =
-      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
+    val commonUdf = rel.getFunc
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
+        transformTypedGroupMap(rel, commonUdf)
 
-    Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .groupBy(cols: _*)
-      .flatMapGroupsInPandas(pythonUdf)
-      .logicalPlan
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
+        val cols =
+          rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
+
+        Dataset
+          .ofRows(session, transformRelation(rel.getInput))
+          .groupBy(cols: _*)
+          .flatMapGroupsInPandas(pythonUdf)
+          .logicalPlan
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  private def transformTypedGroupMap(
+      rel: GroupMap,
+      commonUdf: CommonInlineUserDefinedFunction): LogicalPlan = {
+    // Compute grouping key
+    val logicalPlan = transformRelation(rel.getInput)
+    val udf = unpackUdf(commonUdf)
+    assert(rel.getGroupingExpressionsCount == 1)
+    val groupFunc = rel.getGroupingExpressionsList.asScala.toSeq
+      .map(expr => unpackUdf(expr.getCommonInlineUserDefinedFunction))
+      .head
+
+    assert(groupFunc.inputEncoders.size == 1)
+    val vEnc = ExpressionEncoder(groupFunc.inputEncoders.head)
+    val kEnc = ExpressionEncoder(groupFunc.outputEncoder)
+    val uEnc = ExpressionEncoder(udf.outputEncoder)
+    assert(udf.inputEncoders.nonEmpty)
+    // ukEnc != kEnc if user has called kvDS.keyAs
+    val ukEnc = ExpressionEncoder(udf.inputEncoders.head)
+
+    val withGroupingKey = new AppendColumns(
+      groupFunc.function.asInstanceOf[Any => Any],
+      vEnc.clsTag.runtimeClass,
+      vEnc.schema,
+      UnresolvedDeserializer(vEnc.deserializer),
+      kEnc.namedExpressions,
+      logicalPlan)
+
+    // Compute sort order
+    val sortExprs =
+      rel.getSortingExpressionsList.asScala.toSeq.map(expr => transformExpression(expr))
+    val sortOrder: Seq[SortOrder] = sortExprs.map {
+      case expr: SortOrder => expr
+      case expr: Expression => SortOrder(expr, Ascending)

Review Comment:
   Who will be in charge of checking unsupported expressions? 



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -545,34 +540,94 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformTypedMapPartitions(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): LogicalPlan = {
-    val udf = fun.getScalarScalaUdf
-    val udfPacket =
-      Utils.deserialize[UdfPacket](
-        udf.getPayload.toByteArray,
-        SparkConnectArtifactManager.classLoaderWithArtifacts)
-    assert(udfPacket.inputEncoders.size == 1)
-    val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
-    val rEnc = ExpressionEncoder(udfPacket.outputEncoder)
+    val udf = unpackUdf(fun)
+    assert(udf.inputEncoders.size == 1)
+    val iEnc = ExpressionEncoder(udf.inputEncoders.head)
+    val rEnc = ExpressionEncoder(udf.outputEncoder)
 
     val deserializer = UnresolvedDeserializer(iEnc.deserializer)
     val deserialized = DeserializeToObject(deserializer, generateObjAttr(iEnc), child)
     val mapped = MapPartitions(
-      udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
+      udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
       generateObjAttr(rEnc),
       deserialized)
     SerializeFromObject(rEnc.namedExpressions, mapped)
   }
 
   private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
-    val pythonUdf = transformPythonUDF(rel.getFunc)
-    val cols =
-      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
+    val commonUdf = rel.getFunc
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
+        transformTypedGroupMap(rel, commonUdf)
 
-    Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .groupBy(cols: _*)
-      .flatMapGroupsInPandas(pythonUdf)
-      .logicalPlan
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
+        val cols =
+          rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
+
+        Dataset
+          .ofRows(session, transformRelation(rel.getInput))
+          .groupBy(cols: _*)
+          .flatMapGroupsInPandas(pythonUdf)
+          .logicalPlan
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  private def transformTypedGroupMap(
+      rel: GroupMap,
+      commonUdf: CommonInlineUserDefinedFunction): LogicalPlan = {
+    // Compute grouping key
+    val logicalPlan = transformRelation(rel.getInput)
+    val udf = unpackUdf(commonUdf)
+    assert(rel.getGroupingExpressionsCount == 1)
+    val groupFunc = rel.getGroupingExpressionsList.asScala.toSeq
+      .map(expr => unpackUdf(expr.getCommonInlineUserDefinedFunction))
+      .head
+
+    assert(groupFunc.inputEncoders.size == 1)
+    val vEnc = ExpressionEncoder(groupFunc.inputEncoders.head)
+    val kEnc = ExpressionEncoder(groupFunc.outputEncoder)
+    val uEnc = ExpressionEncoder(udf.outputEncoder)
+    assert(udf.inputEncoders.nonEmpty)
+    // ukEnc != kEnc if user has called kvDS.keyAs
+    val ukEnc = ExpressionEncoder(udf.inputEncoders.head)
+
+    val withGroupingKey = new AppendColumns(
+      groupFunc.function.asInstanceOf[Any => Any],
+      vEnc.clsTag.runtimeClass,
+      vEnc.schema,
+      UnresolvedDeserializer(vEnc.deserializer),
+      kEnc.namedExpressions,
+      logicalPlan)
+
+    // Compute sort order
+    val sortExprs =
+      rel.getSortingExpressionsList.asScala.toSeq.map(expr => transformExpression(expr))
+    val sortOrder: Seq[SortOrder] = sortExprs.map {
+      case expr: SortOrder => expr
+      case expr: Expression => SortOrder(expr, Ascending)

Review Comment:
   Who will be in charge of checking unsupported expressions? 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org