Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/30 11:06:39 UTC

[spark] branch master updated: [SPARK-41325][CONNECT] Fix missing avg() for GroupBy on DF

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 32ff77cdb8e [SPARK-41325][CONNECT] Fix missing avg() for GroupBy on DF
32ff77cdb8e is described below

commit 32ff77cdb8ef4973494beb1a31ced05ea493dc6d
Author: Martin Grund <ma...@databricks.com>
AuthorDate: Wed Nov 30 19:06:12 2022 +0800

    [SPARK-41325][CONNECT] Fix missing avg() for GroupBy on DF
    
    ### What changes were proposed in this pull request?
    Previously, the `avg` function was missing from the `GroupedData` class. This patch adds the method and the necessary plan transformation using an unresolved function. In addition, it fixes a small issue: when an alias was used for a grouping column, the planner would incorrectly try to wrap the existing alias expression in an `UnresolvedAlias`, which would then fail. The snippet below reproduces the issue:
    
    ```
    df = (
        self.connect.range(10)
        .groupBy((col("id") % lit(2)).alias("moded"))
        .avg("id")
        .sort("moded")
    )
    ```
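
    For context, the expected result of the query above: the ids 0 through 9 split by parity, so the even group averages (0+2+4+6+8)/5 = 4.0 and the odd group (1+3+5+7+9)/5 = 5.0, matching the assertions in the new unit test below. The rendering is indicative and assumes the usual PySpark `avg(id)` column naming:
    
    ```
    +-----+-------+
    |moded|avg(id)|
    +-----+-------+
    |    0|    4.0|
    |    1|    5.0|
    +-----+-------+
    ```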
    
    ### Why are the changes needed?
    Bug fix / compatibility: the plain PySpark `GroupedData` API already exposes `avg`, so Spark Connect should too.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UT (a new test, `test_agg_with_avg`, in `test_connect_basic.py`)
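
    One hypothetical way to run the new test locally, assuming a standard Spark development checkout with its `python/run-tests` script (the test path is taken from the diff below):
    
    ```
    python/run-tests --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectTests.test_agg_with_avg'
    ```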
    
    Closes #38841 from grundprinzip/SPARK-41325.
    
    Authored-by: Martin Grund <ma...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../spark/sql/connect/planner/SparkConnectPlanner.scala     |  3 ++-
 python/pyspark/sql/connect/dataframe.py                     |  4 ++++
 python/pyspark/sql/tests/connect/test_connect_basic.py      | 13 +++++++++++++
 3 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 7b9e13cadab..d1d4c3d4fa9 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -682,7 +682,8 @@ class SparkConnectPlanner(session: SparkSession) {
       rel.getGroupingExpressionsList.asScala
         .map(transformExpression)
         .map {
-          case x @ UnresolvedAttribute(_) => x
+          case ua @ UnresolvedAttribute(_) => ua
+          case a @ Alias(_, _) => a
           case x => UnresolvedAlias(x)
         }
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c9960a71fb8..ebfb52cdd74 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -84,6 +84,10 @@ class GroupedData(object):
         expr = self._map_cols_to_expression("sum", col)
         return self.agg(expr)
 
+    def avg(self, col: Union[Column, str]) -> "DataFrame":
+        expr = self._map_cols_to_expression("avg", col)
+        return self.agg(expr)
+
     def count(self) -> "DataFrame":
         return self.agg([scalar_function("count", lit(1))])
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f518a09ad4a..22d57994794 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -837,6 +837,19 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         ndf = self.connect.read.table("parquet_test")
         self.assertEqual(set(df.collect()), set(ndf.collect()))
 
+    def test_agg_with_avg(self):
+        # SPARK-41325: groupby.avg()
+        df = (
+            self.connect.range(10)
+            .groupBy((col("id") % lit(2)).alias("moded"))
+            .avg("id")
+            .sort("moded")
+        )
+        res = df.collect()
+        self.assertEqual(2, len(res))
+        self.assertEqual(4.0, res[0][1])
+        self.assertEqual(5.0, res[1][1])
+
 
 class ChannelBuilderTests(ReusedPySparkTestCase):
     def test_invalid_connection_strings(self):

