Posted to reviews@spark.apache.org by "santosh-d3vpl3x (via GitHub)" <gi...@apache.org> on 2023/02/22 21:10:47 UTC

[GitHub] [spark] santosh-d3vpl3x commented on a diff in pull request #40122: [SPARK-42349][PYTHON] Support pandas cogroup with multiple df

santosh-d3vpl3x commented on code in PR #40122:
URL: https://github.com/apache/spark/pull/40122#discussion_r1114960540


##########
sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala:
##########
@@ -582,48 +582,59 @@ class RelationalGroupedDataset protected[sql](
   }
 
   /**
-   * Applies a vectorized python user-defined function to each cogrouped data.
-   * The user-defined function defines a transformation:
-   * `pandas.DataFrame`, `pandas.DataFrame` -> `pandas.DataFrame`.
-   *  For each group in the cogrouped data, all elements in the group are passed as a
-   * `pandas.DataFrame` and the results for all cogroups are combined into a new [[DataFrame]].
+   * Applies a vectorized python user-defined function to each cogrouped data. The user-defined
+   * function defines a transformation: `pandas.DataFrame`, `pandas.DataFrame` ->
+   * `pandas.DataFrame`. For each group in the cogrouped data, all elements in the group are
+   * passed as a `pandas.DataFrame` and the results for all cogroups are combined into a new
+   * [[DataFrame]].
    *
    * This function uses Apache Arrow as serialization format between Java executors and Python
    * workers.
    */
   private[sql] def flatMapCoGroupsInPandas(
-      r: RelationalGroupedDataset,
+      rs: Seq[RelationalGroupedDataset],
       expr: PythonUDF): DataFrame = {
     require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
       "Must pass a cogrouped map udf")
-    require(this.groupingExprs.length == r.groupingExprs.length,
-      "Cogroup keys must have same size: " +
-        s"${this.groupingExprs.length} != ${r.groupingExprs.length}")
+    val groupingExprLengthEquals = rs.map(_.groupingExprs.length).
+      forall(_ == this.groupingExprs.length)
+
+    require(groupingExprLengthEquals, s"Cogroup keys must have same size.")

Review Comment:
   Good suggestion, I will add them.
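The check in the diff generalizes a pairwise key-count comparison to any number of grouped inputs. Below is a minimal, standalone Scala sketch of that idea. It is not Spark's code. `Grouped` is a hypothetical stand-in for `RelationalGroupedDataset`, since only the grouping expressions matter here. Reporting each mismatched input with its key count is an assumption about what the promised message details could look like; the comment does not say which details were suggested. `cogroupAll` is a hypothetical plain-collection analogue of the cogroup behavior described in the doc comment: every key seen in any input produces one call with that key's rows from each input, and all results are concatenated.

```scala
// Hypothetical stand-in for RelationalGroupedDataset: only the name (for
// error messages) and the grouping expressions matter for this sketch.
final case class Grouped(name: String, groupingExprs: Seq[String])

object CogroupSketch {
  // Require every other input to use as many grouping keys as `self`, and
  // name each mismatched input with its key count (assumed message format).
  def requireSameKeyArity(self: Grouped, others: Seq[Grouped]): Unit = {
    val expected = self.groupingExprs.length
    val mismatched = others.filter(_.groupingExprs.length != expected)
    require(
      mismatched.isEmpty,
      s"Cogroup keys must have same size: expected $expected, but " +
        mismatched.map(g => s"${g.name} has ${g.groupingExprs.length}").mkString(", "))
  }

  // Hypothetical plain-collection analogue of an N-way cogroup: for every key
  // seen in any input (in first-seen order), pass that key's rows from each
  // input (empty if absent) to `f` and concatenate all results.
  def cogroupAll[K, R, O](inputs: Seq[Seq[(K, R)]])(f: (K, Seq[Seq[R]]) => Seq[O]): Seq[O] = {
    val keys = inputs.flatMap(_.map(_._1)).distinct
    keys.flatMap(k => f(k, inputs.map(_.collect { case (key, row) if key == k => row })))
  }
}
```

For example, cogrouping `Seq(1 -> "a", 2 -> "b")`, `Seq(1 -> "x")` and `Seq(2 -> "y", 2 -> "z")` with a function that reports each input's group size yields `Seq("1:1,1,0", "2:1,0,2")`. Because `require` prefixes its message with "requirement failed: ", naming the offending inputs keeps the failure actionable once more than two DataFrames can be cogrouped.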



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


For additional commands, e-mail: reviews-help@spark.apache.org