Posted to reviews@spark.apache.org by "zhenlineo (via GitHub)" <gi...@apache.org> on 2023/04/10 22:40:28 UTC

[GitHub] [spark] zhenlineo opened a new pull request, #40729: [WIP][CONNECT] Adding groupByKey + mapGroup functions

zhenlineo opened a new pull request, #40729:
URL: https://github.com/apache/spark/pull/40729

   
   ### What changes were proposed in this pull request?
   
   
   ### Why are the changes needed?
   
   
   ### Does this PR introduce _any_ user-facing change?
   
   
   ### How was this patch tested?
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhenlineo commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "zhenlineo (via GitHub)" <gi...@apache.org>.
zhenlineo commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1169373531


##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -838,6 +841,12 @@ message CoGroupMap {
 
   // (Required) Input user-defined function.
   CommonInlineUserDefinedFunction func = 5;
+
+  // (Optional) Expressions for sorting. Only used by Scala Sorted CoGroup Map API.
+  repeated Expression input_sorting_expressions = 6;

Review Comment:
   Currently we follow the names already used in the proto: there is one set of sorting expressions for `input` and one for `other`.
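
To make the "one per side" point concrete, here is a minimal plain-Scala sketch of a sorted co-group over in-memory sequences. It is not the Spark Connect implementation: `SortedCoGroupSketch`, `cogroupSorted`, and the sample data are hypothetical names chosen for illustration, and the function returns a `Seq` instead of a `Dataset`. The only thing it is meant to show is that each input gets its own sort order before the user function sees the two iterators.

```scala
object SortedCoGroupSketch {
  type Row = (String, Int)

  // Hypothetical helper, not a Spark API: co-groups two in-memory sequences by key
  // and sorts each side with its own ordering before calling `f`.
  def cogroupSorted[K: Ordering, V, U, R](input: Seq[V], other: Seq[U])(
      inputKey: V => K,
      otherKey: U => K)(
      inputSort: Ordering[V],
      otherSort: Ordering[U])(
      f: (K, Iterator[V], Iterator[U]) => Seq[R]): Seq[R] = {
    val left = input.groupBy(inputKey)
    val right = other.groupBy(otherKey)
    // Visit keys in a deterministic order so the result is reproducible.
    val keys = (left.keySet ++ right.keySet).toSeq.sorted
    keys.flatMap { k =>
      val l = left.getOrElse(k, Seq.empty[V]).sorted(inputSort)
      val r = right.getOrElse(k, Seq.empty[U]).sorted(otherSort)
      f(k, l.iterator, r.iterator)
    }
  }

  // Sample data (made up): `input` is sorted ascending, `other` descending.
  val orders: Seq[Row] = Seq(("a", 3), ("b", 1), ("a", 1))
  val refunds: Seq[Row] = Seq(("a", 10), ("a", 20))

  val example: Seq[String] =
    cogroupSorted(orders, refunds)((r: Row) => r._1, (r: Row) => r._1)(
      Ordering.by[Row, Int](_._2),
      Ordering.by[Row, Int](_._2).reverse) { (k, l, r) =>
      Seq(s"$k:${l.map(_._2).mkString(",")}|${r.map(_._2).mkString(",")}")
    }
}
```

With two independent orderings, a key that is missing on one side still reaches the function with an empty iterator for that side, which is why the two lists of sorting expressions are kept separate.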





[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [WIP][CONNECT] Adding groupByKey + mapGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1165744289


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########


Review Comment:
   Add this to compatibility tests?





[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1169422196


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.api.java.function._
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, SingleInputUserDefinedFunction}
+import org.apache.spark.sql.functions.col
+
+/**
+ * A [[Dataset]] that has been logically grouped by a user-specified grouping key. Users should
+ * not construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on
+ * an existing [[Dataset]].
+ *
+ * @since 3.5.0
+ */
+class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {

Review Comment:
   Is this class mostly here to guarantee binary compatibility? That works for me.
   
   qq, can we make this class abstract?
   
   I am not 100% convinced that we need the `SingleInputUserDefinedFunction`. I think we can capture all of this in the implementation as well.
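
As a rough illustration of the "abstract class plus hidden implementation" shape asked about above, the following plain-Scala sketch keeps the public signatures on an abstract class and moves the behaviour into a private subclass. Everything here is an assumption made for illustration: `AbstractApiSketch`, `GroupedData`, `GroupedDataImpl`, and `groupByKey` over a `Seq` are hypothetical stand-ins, not the classes in this PR.

```scala
object AbstractApiSketch {
  // Public, stable API surface: only signatures and derived helpers live here.
  abstract class GroupedData[K, V] extends Serializable {
    def keys: Seq[K]
    def mapGroups[U](f: (K, Iterator[V]) => U): Seq[U]

    // Derived once in terms of the abstract methods.
    def count(): Seq[(K, Long)] = mapGroups((k, it) => (k, it.size.toLong))
  }

  // Hidden implementation: callers can only reach it through `groupByKey`.
  private final class GroupedDataImpl[K, V](data: Seq[V], keyFunc: V => K)
      extends GroupedData[K, V] {
    private val groups: Map[K, Seq[V]] = data.groupBy(keyFunc)

    // Keep first-seen key order so results are deterministic.
    override val keys: Seq[K] = data.map(keyFunc).distinct

    override def mapGroups[U](f: (K, Iterator[V]) => U): Seq[U] =
      keys.map(k => f(k, groups(k).iterator))
  }

  def groupByKey[K, V](data: Seq[V])(keyFunc: V => K): GroupedData[K, V] =
    new GroupedDataImpl(data, keyFunc)

  // Made-up sample input, grouped by first letter.
  val example: Seq[(Char, Long)] =
    groupByKey(Seq("apple", "avocado", "banana"))(_.head).count()
}
```

In this shape the base class no longer needs method bodies that throw `UnsupportedOperationException`, since nothing can instantiate it directly.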





[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1177045727


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -520,54 +515,205 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformTypedMapPartitions(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): LogicalPlan = {
-    val udf = fun.getScalarScalaUdf
-    val udfPacket =
-      Utils.deserialize[UdfPacket](
-        udf.getPayload.toByteArray,
-        SparkConnectArtifactManager.classLoaderWithArtifacts)
-    assert(udfPacket.inputEncoders.size == 1)
-    val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
-    val rEnc = ExpressionEncoder(udfPacket.outputEncoder)
-
-    val deserializer = UnresolvedDeserializer(iEnc.deserializer)
-    val deserialized = DeserializeToObject(deserializer, generateObjAttr(iEnc), child)
+    val udf = ScalaUdf(fun)
+    val deserialized = DeserializeToObject(udf.inputDeserializer(), udf.inputObjAttr, child)
     val mapped = MapPartitions(
-      udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
-      generateObjAttr(rEnc),
+      udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
+      udf.outputObjAttr,
       deserialized)
-    SerializeFromObject(rEnc.namedExpressions, mapped)
+    SerializeFromObject(udf.outputNamedExpression, mapped)
   }
 
   private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
-    val pythonUdf = transformPythonUDF(rel.getFunc)
-    val cols =
-      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
+    val commonUdf = rel.getFunc
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
+        transformTypedGroupMap(rel, commonUdf)
 
-    Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .groupBy(cols: _*)
-      .flatMapGroupsInPandas(pythonUdf)
-      .logicalPlan
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
+        val cols =
+          rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
+
+        Dataset
+          .ofRows(session, transformRelation(rel.getInput))
+          .groupBy(cols: _*)
+          .flatMapGroupsInPandas(pythonUdf)
+          .logicalPlan
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  private def transformTypedGroupMap(
+      rel: proto.GroupMap,
+      commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
+    val udf = ScalaUdf(commonUdf)
+    val ds = UntypedKeyValueGroupedDataset(
+      rel.getInput,
+      rel.getGroupingExpressionsList,
+      rel.getSortingExpressionsList)
+
+    val mapped = new MapGroups(
+      udf.function.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
+      udf.inputDeserializer(ds.groupingAttributes),
+      ds.valueDeserializer,
+      ds.groupingAttributes,
+      ds.dataAttributes,
+      ds.sortOrder,
+      udf.outputObjAttr,
+      ds.analyzed)
+    SerializeFromObject(udf.outputNamedExpression, mapped)
   }
 
   private def transformCoGroupMap(rel: proto.CoGroupMap): LogicalPlan = {
-    val pythonUdf = transformPythonUDF(rel.getFunc)
+    val commonUdf = rel.getFunc
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
+        transformTypedCoGroupMap(rel, commonUdf)
 
-    val inputCols =
-      rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
-        Column(transformExpression(expr)))
-    val otherCols =
-      rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
-        Column(transformExpression(expr)))
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
 
-    val input = Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .groupBy(inputCols: _*)
-    val other = Dataset
-      .ofRows(session, transformRelation(rel.getOther))
-      .groupBy(otherCols: _*)
+        val inputCols =
+          rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
+        val otherCols =
+          rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
 
-    input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+        val input = Dataset
+          .ofRows(session, transformRelation(rel.getInput))
+          .groupBy(inputCols: _*)
+        val other = Dataset
+          .ofRows(session, transformRelation(rel.getOther))
+          .groupBy(otherCols: _*)
+
+        input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  private def transformTypedCoGroupMap(
+      rel: proto.CoGroupMap,
+      commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
+    val udf = ScalaUdf(commonUdf)
+    val left = UntypedKeyValueGroupedDataset(
+      rel.getInput,
+      rel.getInputGroupingExpressionsList,
+      rel.getInputSortingExpressionsList)
+    val right = UntypedKeyValueGroupedDataset(
+      rel.getOther,
+      rel.getOtherGroupingExpressionsList,
+      rel.getOtherSortingExpressionsList)
+
+    val mapped = CoGroup(
+      udf.function.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
+      // The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to
+      // resolve the `keyDeserializer` based on either of them; here we pick the left one.
+      udf.inputDeserializer(left.groupingAttributes),
+      left.valueDeserializer,
+      right.valueDeserializer,
+      left.groupingAttributes,
+      right.groupingAttributes,
+      left.dataAttributes,
+      right.dataAttributes,
+      left.sortOrder,
+      right.sortOrder,
+      udf.outputObjAttr,
+      left.analyzed,
+      right.analyzed)
+    SerializeFromObject(udf.outputNamedExpression, mapped)
+  }
+
+  /**
+   * This is the untyped version of [[KeyValueGroupedDataset]].
+   */
+  private case class UntypedKeyValueGroupedDataset(
+      kEncoder: ExpressionEncoder[_],
+      vEncoder: ExpressionEncoder[_],
+      valueDeserializer: Expression,
+      analyzed: LogicalPlan,
+      dataAttributes: Seq[Attribute],
+      groupingAttributes: Seq[Attribute],
+      sortOrder: Seq[SortOrder])
+  private object UntypedKeyValueGroupedDataset {
+    def apply(
+        input: proto.Relation,
+        groupingExprs: java.util.List[proto.Expression],
+        sortingExprs: java.util.List[proto.Expression]): UntypedKeyValueGroupedDataset = {
+      val logicalPlan = transformRelation(input)
+      assert(groupingExprs.size() == 1)
+      val groupFunc = groupingExprs.asScala.toSeq
+        .map(expr => unpackUdf(expr.getCommonInlineUserDefinedFunction))
+        .head
+
+      assert(groupFunc.inputEncoders.size == 1)
+      val vEnc = ExpressionEncoder(groupFunc.inputEncoders.head)
+      val kEnc = ExpressionEncoder(groupFunc.outputEncoder)
+
+      val withGroupingKey = new AppendColumns(
+        groupFunc.function.asInstanceOf[Any => Any],
+        vEnc.clsTag.runtimeClass,
+        vEnc.schema,
+        UnresolvedDeserializer(vEnc.deserializer),
+        kEnc.namedExpressions,
+        logicalPlan)
+
+      // The input logical plan of KeyValueGroupedDataset needs to be executed and analyzed
+      val analyzed = session.sessionState.executePlan(withGroupingKey).analyzed
+      val dataAttributes = logicalPlan.output
+      val groupingAttributes = withGroupingKey.newColumns
+      val valueDeserializer = UnresolvedDeserializer(vEnc.deserializer, dataAttributes)
+
+      // Compute sort order
+      val sortExprs =
+        sortingExprs.asScala.toSeq.map(expr => transformExpression(expr))
+      val sortOrder: Seq[SortOrder] = MapGroups.sortOrder(sortExprs)
+
+      UntypedKeyValueGroupedDataset(
+        kEnc,
+        vEnc,
+        valueDeserializer,
+        analyzed,
+        dataAttributes,
+        groupingAttributes,
+        sortOrder)
+    }
+  }
+
+  /**
+   * The UDF used in typed APIs, where the input column is absent.
+   */
+  private case class ScalaUdf(

Review Comment:
   Maybe name this `TypedScalaUdf`
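
For readers skimming the hunk above: `transformGroupMap` and `transformCoGroupMap` both switch on the UDF kind, handle the Scala and Python cases, and reject everything else with `InvalidPlanInput`. A self-contained sketch of that dispatch pattern, assuming made-up stand-ins for the protobuf enum (`FunctionCase`, the case objects, and their numbers are hypothetical and do not match the real proto field numbers):

```scala
object UdfDispatchSketch {
  // Hypothetical stand-ins for the generated protobuf function cases.
  sealed trait FunctionCase { def number: Int }
  case object ScalarScalaUdf extends FunctionCase { val number = 1 }
  case object PythonUdf extends FunctionCase { val number = 2 }
  case object JavaUdf extends FunctionCase { val number = 3 }

  // Stand-in for the planner's error type; only the message format follows the diff.
  final case class InvalidPlanInput(message: String) extends Exception(message)

  // Returns a label instead of a LogicalPlan to keep the sketch dependency-free.
  def transformGroupMap(functionCase: FunctionCase): String = functionCase match {
    case ScalarScalaUdf => "typed group map"
    case PythonUdf => "pandas group map"
    case other =>
      throw InvalidPlanInput(s"Function with ID: ${other.number} is not supported")
  }
}
```

The catch-all branch is what keeps an unknown or newly added function kind from silently falling into the wrong code path.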





[GitHub] [spark] hvanhovell closed pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell closed pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions
URL: https://github.com/apache/spark/pull/40729




[GitHub] [spark] amaliujia commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "amaliujia (via GitHub)" <gi...@apache.org>.
amaliujia commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1169361659


##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -838,6 +841,12 @@ message CoGroupMap {
 
   // (Required) Input user-defined function.
   CommonInlineUserDefinedFunction func = 5;
+
+  // (Optional) Expressions for sorting. Only used by Scala Sorted CoGroup Map API.
+  repeated Expression input_sorting_expressions = 6;

Review Comment:
   I think it would be better to name this `left_input_sorting_expressions` and the other one `right_input_sorting_expressions`?





[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1169434600


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import org.apache.spark.api.java.function._
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, SingleInputUserDefinedFunction}
+import org.apache.spark.sql.functions.col
+
+/**
+ * A [[Dataset]] that has been logically grouped by a user-specified grouping key. Users should
+ * not construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on
+ * an existing [[Dataset]].
+ *
+ * @since 3.5.0
+ */
+class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
+   * specified type. The mapping of key columns to the type follows the same rules as `as` on
+   * [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `valueFunc` has been applied
+   * to the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create values grouped by key from a Dataset[(K, V)]
+   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+   * the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+   *   Dataset<Tuple2<String, Integer>> ds = ...;
+   *   KeyValueGroupedDataset<String, Integer> grouped =
+   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
+    mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
+  }
+
+  /**
+   * Returns a [[Dataset]] that contains each unique key. This is equivalent to mapping over the
+   * Dataset to extract the keys and then running a distinct operation on those.
+   *
+   * @since 3.5.0
+   */
+  def keys: Dataset[K] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    flatMapSortedGroups()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except that the iterator is
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
+      f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except that the iterator is
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U](
+      SortExprs: Array[Column],
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = {
+    flatMapSortedGroups(SortExprs: _*)(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+    flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f))
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each cogrouped data. For each unique group,
+   * the function will be passed the grouping key and 2 iterators containing all elements in the
+   * group from [[Dataset]] `this` and `other`. The function can return an iterator containing
+   * elements of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    cogroupSorted(other)()()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the
+   * function will be passed the grouping key and 2 iterators containing all elements in the group
+   * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
+   * of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)(
+      otherSortExprs: Column*)(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      thisSortExprs: Array[Column],
+      otherSortExprs: Array[Column],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)(
+      UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+}
+
+/**
+ * This class is the implementation of class [[KeyValueGroupedDataset]]. This class memorizes the
+ * initial types of the grouping function so that the original function will be sent to the server
+ * to perform the grouping first. Then any type modifications on the keys and the values will be
+ * applied sequentially to ensure the final type of the result remains the same as how
+ * [[KeyValueGroupedDataset]] behaves on the server.
+ */
+private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
+    ds: Dataset[IV],
+    private val sparkSession: SparkSession,
+    private val plan: proto.Plan,
+    private val kEncoder: AgnosticEncoder[K],
+    private val groupFunc: SingleInputUserDefinedFunction[IV, IK],

Review Comment:
   I am honestly not sure you need `SingleInputUserDefinedFunction` here. The only things you need are the UDF itself and its encoder.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [WIP][CONNECT] Adding groupByKey + mapGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1165781914


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala:
##########
@@ -133,6 +120,40 @@ case class ScalarUserDefinedFunction(
   override def asNondeterministic(): ScalarUserDefinedFunction = copy(deterministic = false)
 }
 
+/**
+ * A special variant of [[ScalarUserDefinedFunction]] for functions that have only one input.
+ */
+private[sql] case class SingleInputUserDefinedFunction[V, K](

Review Comment:
   Why do we need this? It seems like overkill just to deal with a single input column.





[GitHub] [spark] zhenlineo commented on a diff in pull request #40729: [WIP][CONNECT] Adding groupByKey + mapGroup functions

Posted by "zhenlineo (via GitHub)" <gi...@apache.org>.
zhenlineo commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1165825957


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.api.java.function._
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, SingleInputUserDefinedFunction}
+import org.apache.spark.sql.functions.col
+
+/**
+ * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
+ * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on an
+ * existing [[Dataset]].
+ *
+ * @since 3.5.0
+ */
+class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {

Review Comment:
   To support `keyAs` and `mapValues`, we need to keep track of the original types IK and IV, which are what we send over the wire, as well as the cast types K and V, which are used to adapt the UDF passed to e.g. `mapGroups` so that it accepts the original IK and IV. That is why I introduced `KeyValueGroupedDatasetImpl` and `SingleInputUserDefinedFunction`. They keep track of all the type info so that modifications to the UDFs can be chained easily.
   
   It is much easier when the types are set explicitly, because the compiler ensures we do not pass the wrong encoders by mistake. We can revisit whether it is better to merge the two classes once all the methods are done.
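
The type bookkeeping described above can be pictured with a small Spark-free sketch. It is only an illustration under stated assumptions: `Grouped`, `keyCast`, and the in-memory `groupBy` standing in for the server are hypothetical names that do not appear in the PR, and `keyAs` is modelled as a plain function rather than an encoder.

```scala
// Hypothetical, Spark-free model of the IK/IV vs. K/V bookkeeping.
// IK, IV: types produced by the original groupByKey call (what goes over the wire).
// K, V:   types the user sees after keyAs / mapValues.
object GroupedSketch {
  final case class Grouped[K, V, IK, IV](
      data: Seq[IV],
      groupingFunc: IV => IK, // original grouping function, never rewritten
      keyCast: IK => K,       // stand-in for the re-typing done by keyAs
      valueMapFunc: IV => V   // accumulated mapValues functions
  ) {
    // Only the user-facing key type changes; IK and IV stay fixed.
    def keyAs[L](cast: IK => L): Grouped[L, V, IK, IV] =
      new Grouped[L, V, IK, IV](data, groupingFunc, cast, valueMapFunc)

    // Only the user-facing value type changes; the new function is chained.
    def mapValues[W](f: V => W): Grouped[K, W, IK, IV] =
      new Grouped[K, W, IK, IV](data, groupingFunc, keyCast, valueMapFunc.andThen(f))

    // Wrap the user function so that it accepts the original IK and IV,
    // then group locally (a stand-in for the server-side grouping).
    def flatMapGroups[U](f: (K, Iterator[V]) => Iterator[U]): Seq[U] = {
      val adapted: (IK, Iterator[IV]) => Iterator[U] =
        (ik, ivs) => f(keyCast(ik), ivs.map(valueMapFunc))
      data.groupBy(groupingFunc).toSeq.flatMap { case (ik, ivs) =>
        adapted(ik, ivs.iterator)
      }
    }
  }

  val words = Seq("apple", "avocado", "banana")
  val grouped = Grouped[Char, String, Char, String](words, _.head, identity, identity)
  val lengths = grouped.mapValues(_.length) // V becomes Int, IV stays String
  // Sum the word lengths per first letter; sort because groupBy order is unspecified.
  val totals = lengths.flatMapGroups((k, vs) => Iterator(k -> vs.sum)).sortBy(_._1)
}
```

Because `groupingFunc` is never rewritten, the grouping itself always runs on the original IK and IV; only the wrapper around the user function grows as `mapValues` calls are chained.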





[GitHub] [spark] xinrong-meng commented on pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "xinrong-meng (via GitHub)" <gi...@apache.org>.
xinrong-meng commented on PR #40729:
URL: https://github.com/apache/spark/pull/40729#issuecomment-1515124878

   A nit on the PR description: would you please indicate which APIs are supported under the `Does this PR introduce any user-facing change?` section? That would be helpful for future API auditing. Thanks!




[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1173260200


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -0,0 +1,416 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import org.apache.spark.api.java.function._
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
+import org.apache.spark.sql.functions.col
+
+/**
+ * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
+ * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on an
+ * existing [[Dataset]].
+ *
+ * @since 3.5.0
+ */
+abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
+   * specified type. The mapping of key columns to the type follows the same rules as `as` on
+   * [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+   * the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create values grouped by key from a Dataset[(K, V)]
+   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+   * the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+   *   Dataset<Tuple2<String, Integer>> ds = ...;
+   *   KeyValueGroupedDataset<String, Integer> grouped =
+   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
+    mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
+  }
+
+  /**
+   * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping over
+   * the Dataset to extract the keys and then running a distinct operation on those.
+   *
+   * @since 3.5.0
+   */
+  def keys: Dataset[K] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    flatMapSortedGroups()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
+      f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U](
+      SortExprs: Array[Column],
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = {
+    flatMapSortedGroups(SortExprs: _*)(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+    flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f))
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each cogrouped data. For each unique group,
+   * the function will be passed the grouping key and 2 iterators containing all elements in the
+   * group from [[Dataset]] `this` and `other`. The function can return an iterator containing
+   * elements of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    cogroupSorted(other)()()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the
+   * function will be passed the grouping key and 2 iterators containing all elements in the group
+   * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
+   * of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)(
+      otherSortExprs: Column*)(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      thisSortExprs: Array[Column],
+      otherSortExprs: Array[Column],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)(
+      UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+}
+
+/**
+ * This class is the implementation of class [[KeyValueGroupedDataset]]. This class memorizes the
+ * initial types of the grouping function so that the original function will be sent to the server
+ * to perform the grouping first. Then any type modifications on the keys and the values will be
+ * applied sequentially to ensure the final type of the result remains the same as how
+ * [[KeyValueGroupedDataset]] behaves on the server.
+ */
+private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
+    private val ds: Dataset[IV],
+    private val sparkSession: SparkSession,
+    private val plan: proto.Plan,
+    private val ikEncoder: AgnosticEncoder[IK],
+    private val kEncoder: AgnosticEncoder[K],
+    private val groupingFunc: IV => IK,
+    private val valueMapFunc: IV => V)
+    extends KeyValueGroupedDataset[K, V] {
+
+  private val ivEncoder = ds.encoder
+
+  override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
+    new KeyValueGroupedDatasetImpl[L, V, IK, IV](
+      ds,
+      sparkSession,
+      plan,
+      ikEncoder,
+      encoderFor[L],
+      groupingFunc,
+      valueMapFunc)
+  }
+
+  override def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = {
+    new KeyValueGroupedDatasetImpl[K, W, IK, IV](
+      ds,
+      sparkSession,
+      plan,
+      ikEncoder,
+      kEncoder,
+      groupingFunc,
+      valueMapFunc.andThen(valueFunc))
+  }
+
+  override def keys: Dataset[K] = {
+    ds.map(groupingFunc)(ikEncoder)

Review Comment:
   You should be able to use `kEncoder` here.
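
For reference, the Scaladoc for `keys` describes it as mapping over the Dataset to extract the keys and then running a distinct operation. A minimal local-collection sketch of that documented behaviour follows; the `keys` helper here is hypothetical and works on a `Seq`, not on a Spark `Dataset`, so no encoder is involved.

```scala
object KeysSketch {
  // Hypothetical helper: models `keys` as map-then-distinct on a local collection.
  def keys[IV, K](data: Seq[IV])(groupingFunc: IV => K): Seq[K] =
    data.map(groupingFunc).distinct

  // Group words by their first letter and list each key once.
  val result: Seq[Char] = keys(Seq("apple", "avocado", "banana"))(_.head)
}
```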





[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1169418265


##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -838,6 +841,12 @@ message CoGroupMap {
 
   // (Required) Input user-defined function.
   CommonInlineUserDefinedFunction func = 5;
+
+  // (Optional) Expressions for sorting. Only used by Scala Sorted CoGroup Map API.
+  repeated Expression input_sorting_expressions = 6;

Review Comment:
   @xinrong-meng why did you name it `input` and `other`, instead of `left` and `right`?





[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1177017870


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -0,0 +1,416 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import org.apache.spark.api.java.function._
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
+import org.apache.spark.sql.functions.col
+
+/**
+ * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
+ * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on an
+ * existing [[Dataset]].
+ *
+ * @since 3.5.0
+ */
+abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
+   * specified type. The mapping of key columns to the type follows the same rules as `as` on
+   * [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+   * the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create values grouped by key from a Dataset[(K, V)]
+   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+   * the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+   *   Dataset<Tuple2<String, Integer>> ds = ...;
+   *   KeyValueGroupedDataset<String, Integer> grouped =
+   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
+    mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
+  }
+
+  /**
+   * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping over
+   * the Dataset to extract the keys and then running a distinct operation on those.
+   *
+   * @since 3.5.0
+   */
+  def keys: Dataset[K] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    flatMapSortedGroups()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
+      f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U](
+      SortExprs: Array[Column],
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = {
+    flatMapSortedGroups(SortExprs: _*)(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+    flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f))
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each cogrouped data. For each unique group,
+   * the function will be passed the grouping key and 2 iterators containing all elements in the
+   * group from [[Dataset]] `this` and `other`. The function can return an iterator containing
+   * elements of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    cogroupSorted(other)()()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the
+   * function will be passed the grouping key and 2 iterators containing all elements in the group
+   * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
+   * of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)(
+      otherSortExprs: Column*)(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      thisSortExprs: Array[Column],
+      otherSortExprs: Array[Column],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)(
+      UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+}
+
+/**
+ * This class is the implementation of class [[KeyValueGroupedDataset]]. This class memorizes the
+ * initial types of the grouping function so that the original function will be sent to the server
+ * to perform the grouping first. Then any type modifications on the keys and the values will be
+ * applied sequentially to ensure the final type of the result remains the same as how
+ * [[KeyValueGroupedDataset]] behaves on the server.
+ */
+private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
+    private val ds: Dataset[IV],
+    private val sparkSession: SparkSession,
+    private val plan: proto.Plan,
+    private val ikEncoder: AgnosticEncoder[IK],
+    private val kEncoder: AgnosticEncoder[K],
+    private val groupingFunc: IV => IK,
+    private val valueMapFunc: IV => V)
+    extends KeyValueGroupedDataset[K, V] {
+
+  private val ivEncoder = ds.encoder
+
+  override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {

Review Comment:
   We had a discussion about this one. `keyAs` just changes the key type. It does not change the cardinality of the result. The cardinality is determined by the result of the grouping function.
   
   For example, the following returns 4 groups instead of 2:
   ```scala
   case class K1(a: Long)
   case class K2(a: Long, b: Long)
   spark.range(10).as[Long].groupByKey(id => K2(id % 2, id % 4)).keyAs[K1].count().collect()
   ```
   
   @cloud-fan is this expected?
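   
   The same effect can be sketched with plain Scala collections, without a Spark session. This is only an illustration of the reasoning above, not the Spark implementation: `KeyAsCardinalitySketch` and `groupThenRetype` are hypothetical names, and the "retyping" is modelled as a simple projection of each key from `K2` to `K1` after grouping has already happened.
   
   ```scala
   object KeyAsCardinalitySketch {
     case class K1(a: Long)
     case class K2(a: Long, b: Long)
   
     // Group by the original K2 key first, then project each key to K1
     // without regrouping. The number of groups is fixed by the K2 grouping.
     def groupThenRetype(ids: Seq[Long]): Seq[(K1, Int)] =
       ids
         .groupBy(id => K2(id % 2, id % 4))
         .toSeq
         .sortBy { case (k, _) => (k.a, k.b) }
         .map { case (k, vs) => (K1(k.a), vs.size) }
   
     def main(args: Array[String]): Unit = {
       // Four rows come out, although only two distinct K1 values exist.
       groupThenRetype(0L until 10L).foreach(println)
     }
   }
   ```
   
   Under this model, ids 0 to 9 produce four groups, two of which carry `K1(0)` and two `K1(1)`, which matches the count of 4 reported above.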



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1177036733


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -520,54 +515,205 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformTypedMapPartitions(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): LogicalPlan = {
-    val udf = fun.getScalarScalaUdf
-    val udfPacket =
-      Utils.deserialize[UdfPacket](
-        udf.getPayload.toByteArray,
-        SparkConnectArtifactManager.classLoaderWithArtifacts)
-    assert(udfPacket.inputEncoders.size == 1)
-    val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
-    val rEnc = ExpressionEncoder(udfPacket.outputEncoder)
-
-    val deserializer = UnresolvedDeserializer(iEnc.deserializer)
-    val deserialized = DeserializeToObject(deserializer, generateObjAttr(iEnc), child)
+    val udf = ScalaUdf(fun)
+    val deserialized = DeserializeToObject(udf.inputDeserializer(), udf.inputObjAttr, child)
     val mapped = MapPartitions(
-      udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
-      generateObjAttr(rEnc),
+      udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
+      udf.outputObjAttr,
       deserialized)
-    SerializeFromObject(rEnc.namedExpressions, mapped)
+    SerializeFromObject(udf.outputNamedExpression, mapped)
   }
 
   private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
-    val pythonUdf = transformPythonUDF(rel.getFunc)
-    val cols =
-      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
+    val commonUdf = rel.getFunc
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
+        transformTypedGroupMap(rel, commonUdf)
 
-    Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .groupBy(cols: _*)
-      .flatMapGroupsInPandas(pythonUdf)
-      .logicalPlan
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
+        val cols =
+          rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
+
+        Dataset
+          .ofRows(session, transformRelation(rel.getInput))
+          .groupBy(cols: _*)
+          .flatMapGroupsInPandas(pythonUdf)
+          .logicalPlan
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  private def transformTypedGroupMap(
+      rel: proto.GroupMap,
+      commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
+    val udf = ScalaUdf(commonUdf)
+    val ds = UntypedKeyValueGroupedDataset(
+      rel.getInput,
+      rel.getGroupingExpressionsList,
+      rel.getSortingExpressionsList)
+
+    val mapped = new MapGroups(
+      udf.function.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
+      udf.inputDeserializer(ds.groupingAttributes),
+      ds.valueDeserializer,
+      ds.groupingAttributes,
+      ds.dataAttributes,
+      ds.sortOrder,
+      udf.outputObjAttr,
+      ds.analyzed)
+    SerializeFromObject(udf.outputNamedExpression, mapped)
   }
 
   private def transformCoGroupMap(rel: proto.CoGroupMap): LogicalPlan = {
-    val pythonUdf = transformPythonUDF(rel.getFunc)
+    val commonUdf = rel.getFunc
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
+        transformTypedCoGroupMap(rel, commonUdf)
 
-    val inputCols =
-      rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
-        Column(transformExpression(expr)))
-    val otherCols =
-      rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
-        Column(transformExpression(expr)))
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
 
-    val input = Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .groupBy(inputCols: _*)
-    val other = Dataset
-      .ofRows(session, transformRelation(rel.getOther))
-      .groupBy(otherCols: _*)
+        val inputCols =
+          rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
+        val otherCols =
+          rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
 
-    input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+        val input = Dataset
+          .ofRows(session, transformRelation(rel.getInput))
+          .groupBy(inputCols: _*)
+        val other = Dataset
+          .ofRows(session, transformRelation(rel.getOther))
+          .groupBy(otherCols: _*)
+
+        input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  private def transformTypedCoGroupMap(
+      rel: proto.CoGroupMap,
+      commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
+    val udf = ScalaUdf(commonUdf)
+    val left = UntypedKeyValueGroupedDataset(
+      rel.getInput,
+      rel.getInputGroupingExpressionsList,
+      rel.getInputSortingExpressionsList)
+    val right = UntypedKeyValueGroupedDataset(
+      rel.getOther,
+      rel.getOtherGroupingExpressionsList,
+      rel.getOtherSortingExpressionsList)
+
+    val mapped = CoGroup(
+      udf.function.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
+      // The `leftGroup` and `rightGroup` are guaranteed to be of same schema, so it's safe to
+      // resolve the `keyDeserializer` based on either of them, here we pick the left one.
+      udf.inputDeserializer(left.groupingAttributes),
+      left.valueDeserializer,
+      right.valueDeserializer,
+      left.groupingAttributes,
+      right.groupingAttributes,
+      left.dataAttributes,
+      right.dataAttributes,
+      left.sortOrder,
+      right.sortOrder,
+      udf.outputObjAttr,
+      left.analyzed,
+      right.analyzed)
+    SerializeFromObject(udf.outputNamedExpression, mapped)
+  }
+
+  /**
+   * This is the untyped version of [[KeyValueGroupedDataset]].
+   */
+  private case class UntypedKeyValueGroupedDataset(
+      kEncoder: ExpressionEncoder[_],
+      vEncoder: ExpressionEncoder[_],
+      valueDeserializer: Expression,
+      analyzed: LogicalPlan,
+      dataAttributes: Seq[Attribute],
+      groupingAttributes: Seq[Attribute],
+      sortOrder: Seq[SortOrder])
+  private object UntypedKeyValueGroupedDataset {
+    def apply(
+        input: proto.Relation,
+        groupingExprs: java.util.List[proto.Expression],
+        sortingExprs: java.util.List[proto.Expression]): UntypedKeyValueGroupedDataset = {
+      val logicalPlan = transformRelation(input)
+      assert(groupingExprs.size() == 1)
+      val groupFunc = groupingExprs.asScala.toSeq

Review Comment:
   NIT More direct: `unpackUdf(groupingExprs.get(0).getCommonInlineUserDefinedFunction)`





[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1177021671


##########
connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala:
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import java.util.Arrays
+
+import io.grpc.StatusRuntimeException
+
+import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+
+/**
+ * All tests in this class require client UDF artifacts synced with the server. TODO: It means
+ * these tests only work with SBT for now.
+ */
+class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession {
+
+  test("mapGroups") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val values = spark
+      .range(10)
+      .groupByKey(v => v % 2)
+      .mapGroups((_, it) => it.toSeq.size)
+      .collectAsList()
+    assert(values == Arrays.asList[Int](5, 5))
+  }
+
+  test("flatGroupMap") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val values = spark
+      .range(10)
+      .groupByKey(v => v % 2)
+      .flatMapGroups((_, it) => Seq(it.toSeq.size))
+      .collectAsList()
+    assert(values == Arrays.asList[Int](5, 5))
+  }
+
+  test("keys") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val values = spark
+      .range(10)
+      .groupByKey(v => v % 2)
+      .keys
+      .collectAsList()
+    assert(values == Arrays.asList[Long](0, 1))
+  }
+
+  test("keyAs - keys") {
+    // It is okay to cast from Long to Double, but not Long to Int.

Review Comment:
   Well... try using Longs whose highest set bit is above 53...
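   
   A minimal sketch of the concern, assuming it refers to the 53-bit significand of an IEEE 754 double: a `Long` that needs more than 53 significant bits may not survive a round trip through `Double`. `LongToDoublePrecisionSketch` and `roundTrips` are hypothetical names used only for this illustration.
   
   ```scala
   object LongToDoublePrecisionSketch {
     // True when converting to Double and back yields the same Long.
     def roundTrips(v: Long): Boolean = v.toDouble.toLong == v
   
     def main(args: Array[String]): Unit = {
       val exact = 1L << 53           // 2^53, a single set bit, exactly representable
       val inexact = (1L << 53) + 1   // needs 54 significant bits
       println(s"$exact round-trips: ${roundTrips(exact)}")
       println(s"$inexact round-trips: ${roundTrips(inexact)}")
     }
   }
   ```
   
   So a Long to Double cast can lose information for large keys, even though it never throws.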





[GitHub] [spark] zhenlineo commented on pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "zhenlineo (via GitHub)" <gi...@apache.org>.
zhenlineo commented on PR #40729:
URL: https://github.com/apache/spark/pull/40729#issuecomment-1509161614

   CC @amaliujia @vicennial Feel free to give any review feedback.




[GitHub] [spark] zhenlineo commented on pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "zhenlineo (via GitHub)" <gi...@apache.org>.
zhenlineo commented on PR #40729:
URL: https://github.com/apache/spark/pull/40729#issuecomment-1507810721

   cc @xinrong-meng @grundprinzip For the proto change.




[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1169435477


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import org.apache.spark.api.java.function._
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, SingleInputUserDefinedFunction}
+import org.apache.spark.sql.functions.col
+
+/**
+ * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
+ * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on an
+ * existing [[Dataset]].
+ *
+ * @since 3.5.0
+ */
+class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
+   * specified type. The mapping of key columns to the type follows the same rules as `as` on
+   * [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+   * the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create values grouped by key from a Dataset[(K, V)]
+   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+   * the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+   *   Dataset<Tuple2<String, Integer>> ds = ...;
+   *   KeyValueGroupedDataset<String, Integer> grouped =
+   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
+    mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
+  }
+
+  /**
+   * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping over
+   * the Dataset to extract the keys and then running a distinct operation on those.
+   *
+   * @since 3.5.0
+   */
+  def keys: Dataset[K] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    flatMapSortedGroups()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except that the iterator is
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
+      f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except that the iterator is
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U](
+      SortExprs: Array[Column],
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = {
+    flatMapSortedGroups(SortExprs: _*)(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+    flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f))
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each cogrouped data. For each unique group,
+   * the function will be passed the grouping key and 2 iterators containing all elements in the
+   * group from [[Dataset]] `this` and `other`. The function can return an iterator containing
+   * elements of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    cogroupSorted(other)()()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the
+   * function will be passed the grouping key and 2 iterators containing all elements in the group
+   * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
+   * of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except that the iterators are
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)(
+      otherSortExprs: Column*)(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except that the iterators are
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      thisSortExprs: Array[Column],
+      otherSortExprs: Array[Column],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)(
+      UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+}
+
+/**
+ * This class is the implementation of class [[KeyValueGroupedDataset]]. This class remembers the
+ * initial types of the grouping function so that the original function will be sent to the server
+ * to perform the grouping first. Then any type modifications on the keys and the values will be
+ * applied sequentially to ensure the final type of the result remains the same as how
+ * [[KeyValueGroupedDataset]] behaves on the server.
+ */
+private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
+    ds: Dataset[IV],

Review Comment:
   Just wondering why this isn't a `private val`.
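
   For context, here is a minimal sketch of what the distinction changes in plain Scala. The `Plain` and `WithVal` classes are hypothetical and do not appear in the PR. The sketch only illustrates the language rule and makes no claim about why the PR chose one form over the other.

```scala
// Hypothetical classes for illustration; neither appears in the PR.

// A plain constructor parameter: it is kept as an object-private field only
// because `size` uses it, and it cannot be reached through another instance.
class Plain(ds: Seq[Int]) {
  def size: Int = ds.size
  // def sameAs(other: Plain): Boolean = ds == other.ds // would not compile
}

// A `private val` parameter: always a field, and accessible from any
// instance of the same class.
class WithVal(private val ds: Seq[Int]) {
  def size: Int = ds.size
  def sameAs(other: WithVal): Boolean = ds == other.ds
}
```

   In this sketch the two forms behave the same until one instance needs to read the field of another.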



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] xinrong-meng commented on pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "xinrong-meng (via GitHub)" <gi...@apache.org>.
xinrong-meng commented on PR #40729:
URL: https://github.com/apache/spark/pull/40729#issuecomment-1515123046

   Thanks @zhenlineo, the proto changes look nice.




[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [WIP][CONNECT] Adding groupByKey + mapGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1165743716


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.api.java.function._
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, SingleInputUserDefinedFunction}
+import org.apache.spark.sql.functions.col
+
+/**
+ * A [[Dataset]] that has been logically grouped by a user-specified grouping key. Users should
+ * not construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on
+ * an existing [[Dataset]].
+ *
+ * @since 3.5.0
+ */
+class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {

Review Comment:
   Why split the implementation into two classes?
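
   One possible reading of the split, going by the doc comment on `KeyValueGroupedDatasetImpl` in this PR, is that the public class exposes only `[K, V]` while the implementation also tracks the initial types. Below is a minimal Spark-free sketch of that shape. `Grouped` and `GroupedImpl` are hypothetical names, the data is a local `Seq`, and the initial key type `IK` used by the PR is left out for brevity.

```scala
// Hypothetical, Spark-free model of a public type plus a private implementation.
abstract class Grouped[K, V] {
  def mapValues[W](g: V => W): Grouped[K, W]
  def mapGroups[U](f: (K, Iterator[V]) => U): Seq[U]
}

// IV is the element type of the underlying data. It stays fixed while V
// changes, so the original grouping function can still be applied first.
final class GroupedImpl[K, V, IV](
    data: Seq[IV],
    groupingFunc: IV => K,
    valueMapFunc: IV => V)
    extends Grouped[K, V] {

  // Changing the value type only composes functions; the data is untouched.
  override def mapValues[W](g: V => W): Grouped[K, W] =
    new GroupedImpl[K, W, IV](data, groupingFunc, valueMapFunc.andThen(g))

  // Group on the original elements, then apply the accumulated value mapping.
  override def mapGroups[U](f: (K, Iterator[V]) => U): Seq[U] =
    data.groupBy(groupingFunc).toSeq.map { case (k, vs) =>
      f(k, vs.iterator.map(valueMapFunc))
    }
}
```

   Callers of the sketch only ever see `Grouped[K, V]`; the extra type parameter stays an implementation detail.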





[GitHub] [spark] xinrong-meng commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "xinrong-meng (via GitHub)" <gi...@apache.org>.
xinrong-meng commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1171665334


##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -838,6 +841,12 @@ message CoGroupMap {
 
   // (Required) Input user-defined function.
   CommonInlineUserDefinedFunction func = 5;
+
+  // (Optional) Expressions for sorting. Only used by Scala Sorted CoGroup Map API.
+  repeated Expression input_sorting_expressions = 6;

Review Comment:
   When there is only one relation involved, we name it `input` by convention; an additional relation is then named `other`. I don't object if you prefer renaming, though I would suggest minimizing changes to the existing proto.





[GitHub] [spark] hvanhovell commented on a diff in pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1173266690


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -0,0 +1,416 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import org.apache.spark.api.java.function._
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
+import org.apache.spark.sql.functions.col
+
+/**
+ * A [[Dataset]] that has been logically grouped by a user-specified grouping key. Users should
+ * not construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on
+ * an existing [[Dataset]].
+ *
+ * @since 3.5.0
+ */
+abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
+   * specified type. The mapping of key columns to the type follows the same rules as `as` on
+   * [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `valueFunc` has been
+   * applied to the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create values grouped by key from a Dataset[(K, V)]
+   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+   * the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+   *   Dataset<Tuple2<String, Integer>> ds = ...;
+   *   KeyValueGroupedDataset<String, Integer> grouped =
+   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
+    mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
+  }
+
+  /**
+   * Returns a [[Dataset]] that contains each unique key. This is equivalent to mapping over
+   * the Dataset to extract the keys and then running a distinct operation on those.
+   *
+   * @since 3.5.0
+   */
+  def keys: Dataset[K] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    flatMapSortedGroups()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except that the iterator is
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
+      f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except that the iterator is
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U](
+      SortExprs: Array[Column],
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = {
+    flatMapSortedGroups(SortExprs: _*)(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+    flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f))
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each cogrouped data. For each unique group,
+   * the function will be passed the grouping key and 2 iterators containing all elements in the
+   * group from [[Dataset]] `this` and `other`. The function can return an iterator containing
+   * elements of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    cogroupSorted(other)()()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the
+   * function will be passed the grouping key and 2 iterators containing all elements in the group
+   * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
+   * of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except that the iterators are
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)(
+      otherSortExprs: Column*)(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except that the iterators are
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      thisSortExprs: Array[Column],
+      otherSortExprs: Array[Column],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)(
+      UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+}
+
+/**
+ * This class is the implementation of class [[KeyValueGroupedDataset]]. This class remembers the
+ * initial types of the grouping function so that the original function will be sent to the server
+ * to perform the grouping first. Then any type modifications on the keys and the values will be
+ * applied sequentially to ensure the final type of the result remains the same as how
+ * [[KeyValueGroupedDataset]] behaves on the server.
+ */
+private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
+    private val ds: Dataset[IV],
+    private val sparkSession: SparkSession,
+    private val plan: proto.Plan,
+    private val ikEncoder: AgnosticEncoder[IK],
+    private val kEncoder: AgnosticEncoder[K],
+    private val groupingFunc: IV => IK,
+    private val valueMapFunc: IV => V)
+    extends KeyValueGroupedDataset[K, V] {
+
+  private val ivEncoder = ds.encoder
+
+  override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
+    new KeyValueGroupedDatasetImpl[L, V, IK, IV](
+      ds,
+      sparkSession,
+      plan,
+      ikEncoder,
+      encoderFor[L],
+      groupingFunc,
+      valueMapFunc)
+  }
+
+  override def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = {
+    new KeyValueGroupedDatasetImpl[K, W, IK, IV](
+      ds,
+      sparkSession,
+      plan,
+      ikEncoder,
+      kEncoder,
+      groupingFunc,
+      valueMapFunc.andThen(valueFunc))
+  }
+
+  override def keys: Dataset[K] = {
+    ds.map(groupingFunc)(ikEncoder)
+      .dropDuplicates()
+      .as(kEncoder)
+  }
+
+  override def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
+      f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    // Apply mapValues changes to the udf
+    val nf: (K, Iterator[IV]) => TraversableOnce[U] =
+      UdfUtils.mapValuesAdaptor(f, valueMapFunc)

Review Comment:
   A bit of OCD. We would only need this step when there is an actual valueMapFunc, right? We could skip it when it is the identity, right?
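
A minimal sketch of how the adaptor step could be skipped, assuming the identity case is marked by a shared sentinel function that can be compared by reference. The names `ValueMapSketch`, `identical`, and `adaptIfNeeded` are hypothetical, and `Iterator[U]` stands in for `TraversableOnce[U]` to keep the example free of Spark dependencies; the PR's actual `UdfUtils.mapValuesAdaptor` may look different.

```scala
object ValueMapSketch {
  // Hypothetical sentinel: one shared identity function, so "no mapValues applied"
  // can be detected with a reference comparison.
  private val identityFunc: Any => Any = x => x

  // Returns the shared sentinel, typed as T => T (the cast is erased at runtime).
  def identical[T]: T => T = identityFunc.asInstanceOf[T => T]

  // Wraps f so that incoming values pass through valueMapFunc first,
  // unless valueMapFunc is the sentinel, in which case f is returned unchanged.
  def adaptIfNeeded[K, IV, V, U](
      f: (K, Iterator[V]) => Iterator[U],
      valueMapFunc: IV => V): (K, Iterator[IV]) => Iterator[U] =
    if (valueMapFunc eq identityFunc) {
      // Only reachable when IV and V are the same type, so the cast is safe.
      f.asInstanceOf[(K, Iterator[IV]) => Iterator[U]]
    } else {
      (key, values) => f(key, values.map(valueMapFunc))
    }
}
```

Reference equality only recognises the sentinel itself. A user-supplied `x => x` would still be wrapped, which is harmless but does not save the extra `map`.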



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] hvanhovell commented on pull request #40729: [SPARK-43136][CONNECT] Adding groupByKey + mapGroup + coGroup functions

Posted by "hvanhovell (via GitHub)" <gi...@apache.org>.
hvanhovell commented on PR #40729:
URL: https://github.com/apache/spark/pull/40729#issuecomment-1522405921

   Merging to master.

