Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/12/02 13:48:24 UTC

[GitHub] [spark] zhengruifeng commented on a diff in pull request #38883: [SPARK-41366][CONNECT] DF.groupby.agg() should be compatible

zhengruifeng commented on code in PR #38883:
URL: https://github.com/apache/spark/pull/38883#discussion_r1038160665


##########
python/pyspark/sql/connect/dataframe.py:
##########
@@ -55,8 +60,109 @@ def __init__(self, df: "DataFrame", *grouping_cols: Union[Column, str]) -> None:
         self._df = df
         self._grouping_cols = [x if isinstance(x, Column) else df[x] for x in grouping_cols]
 
-    def agg(self, measures: Sequence[Column]) -> "DataFrame":
-        assert len(measures) > 0, "exprs should not be empty"
+    @overload
+    def agg(self, *exprs: Column) -> "DataFrame":
+        ...
+
+    @overload
+    def agg(self, __exprs: Dict[str, str]) -> "DataFrame":
+        ...
+
+    def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
+        """Compute aggregates and returns the result as a :class:`DataFrame`.
+
+        The available aggregate functions can be:
+
+        1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`
+
+        2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`
+
+           .. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
+               a full shuffle is required. Also, all the data of a group will be loaded into
+               memory, so the user should be aware of the potential OOM risk if data is skewed
+               and certain groups are too large to fit in memory.
+
+           .. seealso:: :func:`pyspark.sql.functions.pandas_udf`
+
+        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
+        is the column to perform aggregation on, and the value is the aggregate function.
+
+        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
+
+        .. versionadded:: 1.3.0
+
+        Parameters
+        ----------
+        exprs : dict
+            a dict mapping from column name (string) to aggregate functions (string),
+            or a list of :class:`Column`.
+
+        Notes
+        -----
+        Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
+        in a single call to this function.
+
+        Examples
+        --------
+        >>> from pyspark.sql import functions as F
+        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"])
+        >>> df.show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  2|Alice|
+        |  3|Alice|
+        |  5|  Bob|
+        | 10|  Bob|
+        +---+-----+
+
+        Group-by name, and count each group.
+
+        >>> df.groupBy(df.name).agg({"*": "count"}).sort("name").show()
+        +-----+--------+
+        | name|count(1)|
+        +-----+--------+
+        |Alice|       2|
+        |  Bob|       2|
+        +-----+--------+
+
+        Group-by name, and calculate the minimum age.
+
+        >>> df.groupBy(df.name).agg(F.min(df.age)).sort("name").show()
+        +-----+--------+
+        | name|min(age)|
+        +-----+--------+
+        |Alice|       2|
+        |  Bob|       5|
+        +-----+--------+
+
+        Same as above but uses pandas UDF.
+
+        >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
+        ... def min_udf(v):
+        ...     return v.min()
+        ...
+        >>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show()  # doctest: +SKIP
+        +-----+------------+
+        | name|min_udf(age)|
+        +-----+------------+
+        |Alice|           2|
+        |  Bob|           5|
+        +-----+------------+
+        """
+        assert exprs, "exprs should not be empty"
+        if len(exprs) == 1 and isinstance(exprs[0], dict):
+            from pyspark.sql.connect.function_builder import functions as FB
+
+            # Convert the dict into key value pairs
+            measures = [Column(ScalarFunctionExpression(exprs[0][k], col(k))) for k in exprs[0]]

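For context on the dict overload in the diff above, here is a minimal sketch (not the code added by this PR) of what the {column: function} form amounts to in classic PySpark: each dict entry expands to one aggregate expression, with the key naming the column and the value naming the aggregate function. F.expr and the sample data are used here only for illustration.

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]
    )

    exprs = {"age": "min", "*": "count"}
    # Build one aggregate Column per dict entry: the value is the function,
    # the key is the column it is applied to, e.g. "min(age)", "count(*)".
    measures = [F.expr(f"{fn}({col})") for col, fn in exprs.items()]

    # Behaves like df.groupBy("name").agg({"age": "min", "*": "count"})
    df.groupBy("name").agg(*measures).sort("name").show()

The dict branch in the diff follows the same idea, but builds the expressions through the Connect plan machinery (ScalarFunctionExpression) rather than through F.expr.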
Review Comment:
   I think this PR needs rebasing now.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

