You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/30 11:06:39 UTC
[spark] branch master updated: [SPARK-41325][CONNECT] Fix missing avg() for GroupBy on DF
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 32ff77cdb8e [SPARK-41325][CONNECT] Fix missing avg() for GroupBy on DF
32ff77cdb8e is described below
commit 32ff77cdb8ef4973494beb1a31ced05ea493dc6d
Author: Martin Grund <ma...@databricks.com>
AuthorDate: Wed Nov 30 19:06:12 2022 +0800
[SPARK-41325][CONNECT] Fix missing avg() for GroupBy on DF
### What changes were proposed in this pull request?
Previously, the `avg` function was missing in the `GroupedData` class. This patch adds this method and the necessary plan transformation using an unresolved function. In addition, it identified a small issue: when an alias is used for a grouping column, the planner would incorrectly try to wrap the existing alias expression in an unresolved alias, which would then fail.
```
df = (
self.connect.range(10)
.groupBy((col("id") % lit(2)).alias("moded"))
.avg("id")
.sort("moded")
)
```
### Why are the changes needed?
Bug / Compatibility
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT
Closes #38841 from grundprinzip/SPARK-41325.
Authored-by: Martin Grund <ma...@databricks.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../spark/sql/connect/planner/SparkConnectPlanner.scala | 3 ++-
python/pyspark/sql/connect/dataframe.py | 4 ++++
python/pyspark/sql/tests/connect/test_connect_basic.py | 13 +++++++++++++
3 files changed, 19 insertions(+), 1 deletion(-)
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 7b9e13cadab..d1d4c3d4fa9 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -682,7 +682,8 @@ class SparkConnectPlanner(session: SparkSession) {
rel.getGroupingExpressionsList.asScala
.map(transformExpression)
.map {
- case x @ UnresolvedAttribute(_) => x
+ case ua @ UnresolvedAttribute(_) => ua
+ case a @ Alias(_, _) => a
case x => UnresolvedAlias(x)
}
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c9960a71fb8..ebfb52cdd74 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -84,6 +84,10 @@ class GroupedData(object):
expr = self._map_cols_to_expression("sum", col)
return self.agg(expr)
+ def avg(self, col: Union[Column, str]) -> "DataFrame":
+ expr = self._map_cols_to_expression("avg", col)
+ return self.agg(expr)
+
def count(self) -> "DataFrame":
return self.agg([scalar_function("count", lit(1))])
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f518a09ad4a..22d57994794 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -837,6 +837,19 @@ class SparkConnectTests(SparkConnectSQLTestCase):
ndf = self.connect.read.table("parquet_test")
self.assertEqual(set(df.collect()), set(ndf.collect()))
+ def test_agg_with_avg(self):
+ # SPARK-41325: groupby.avg()
+ df = (
+ self.connect.range(10)
+ .groupBy((col("id") % lit(2)).alias("moded"))
+ .avg("id")
+ .sort("moded")
+ )
+ res = df.collect()
+ self.assertEqual(2, len(res))
+ self.assertEqual(4.0, res[0][1])
+ self.assertEqual(5.0, res[1][1])
+
class ChannelBuilderTests(ReusedPySparkTestCase):
def test_invalid_connection_strings(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org