Posted to commits@spark.apache.org by we...@apache.org on 2023/03/31 04:13:43 UTC

[spark] branch master updated: [SPARK-41391][SQL] The output column name of groupBy.agg(count_distinct) is incorrect

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cb7d0828062 [SPARK-41391][SQL] The output column name of groupBy.agg(count_distinct) is incorrect
cb7d0828062 is described below

commit cb7d08280623794b238297d2f3de8abdc8b72bdb
Author: Ritika Maheshwari <ri...@gmail.com>
AuthorDate: Fri Mar 31 12:13:23 2023 +0800

    [SPARK-41391][SQL] The output column name of groupBy.agg(count_distinct) is incorrect
    
    ### What changes were proposed in this pull request?
    
    Correct the output column name of groupBy.agg(count_distinct), so that "*" is expanded into the underlying column names and the output column name includes the DISTINCT keyword.
    
    ### Why are the changes needed?
    
    The output column name for groupBy.agg(count_distinct) is incorrect, while equivalent queries in Spark SQL produce the correct column name. For groupBy.agg queries on a DataFrame, "*" is not expanded into column names in the output column, and the DISTINCT keyword is missing from it.
    
    ```
    // initializing data
    scala> val df = spark.range(1, 10).withColumn("value", lit(1))
    df: org.apache.spark.sql.DataFrame = [id: bigint, value: int]
    scala> df.createOrReplaceTempView("table")
    
    // DataFrame aggregate queries with incorrect output column names
    scala> df.groupBy("id").agg(count_distinct($"*"))
    res3: org.apache.spark.sql.DataFrame = [id: bigint, count(unresolvedstar()): bigint]
    scala> df.groupBy("id").agg(count_distinct($"value"))
    res1: org.apache.spark.sql.DataFrame = [id: bigint, count(value): bigint]
    
    // Spark SQL aggregate queries with correct output column names
    scala> spark.sql(" SELECT id, COUNT(DISTINCT *) FROM table GROUP BY id ")
    res4: org.apache.spark.sql.DataFrame = [id: bigint, count(DISTINCT id, value): bigint]
    scala> spark.sql(" SELECT id, COUNT(DISTINCT value) FROM table GROUP BY id ")
    res2: org.apache.spark.sql.DataFrame = [id: bigint, count(DISTINCT value): bigint]
    ```
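
    For reference, after this change the DataFrame API is expected to produce the same column names as the Spark SQL queries above (the added test asserts that the two column lists match). The result numbers below are illustrative REPL output, not verbatim:

    ```
    // DataFrame aggregate queries with corrected output column names (expected)
    scala> df.groupBy("id").agg(count_distinct($"*"))
    res5: org.apache.spark.sql.DataFrame = [id: bigint, count(DISTINCT id, value): bigint]
    scala> df.groupBy("id").agg(count_distinct($"value"))
    res6: org.apache.spark.sql.DataFrame = [id: bigint, count(DISTINCT value): bigint]
    ```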
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added UT
    
    Closes #40116 from ritikam2/master.
    
    Authored-by: Ritika Maheshwari <ri...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../org/apache/spark/sql/RelationalGroupedDataset.scala      |  1 +
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 12 ++++++++++++
 2 files changed, 13 insertions(+)
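
For context on the one-line change to RelationalGroupedDataset.scala shown below: when agg() receives an expression that is not already named, it assigns an alias. Eagerly aliasing an UnresolvedFunction such as count_distinct froze names like count(unresolvedstar()) and dropped DISTINCT; wrapping it in UnresolvedAlias instead lets the analyzer generate the name after resolution, so "*" is expanded and DISTINCT is kept. The following is a simplified sketch of the resulting match, not the exact source (the TypedAggregateExpression case is omitted and an enclosing object is added for self-containment; it assumes spark-catalyst on the classpath):

```
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.util.toPrettySQL

object AliasSketch {
  def alias(expr: Expression): NamedExpression = expr match {
    // Already-named expressions pass through untouched.
    case e: NamedExpression => e
    // New in this commit: unresolved functions are only wrapped, so the
    // analyzer derives the final name after "*" expansion and keeps DISTINCT.
    case u: UnresolvedFunction => UnresolvedAlias(u, None)
    // Everything else gets an eager alias based on its pretty SQL form.
    case e: Expression => Alias(e, toPrettySQL(e))()
  }
}
```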

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 1c2e309bdaf..31c303921f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -89,6 +89,7 @@ class RelationalGroupedDataset protected[sql](
     case expr: NamedExpression => expr
     case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] =>
       UnresolvedAlias(a, Some(Column.generateAlias))
+    case u: UnresolvedFunction => UnresolvedAlias(expr, None)
     case expr: Expression => Alias(expr, toPrettySQL(expr))()
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a15c049715b..d4c4c7c9b16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1137,6 +1137,18 @@ class DataFrameSuite extends QueryTest
     checkAnswer(approxSummaryDF, approxSummaryResult)
   }
 
+  test("SPARK-41391: Correct the output column name of groupBy.agg(count_distinct)") {
+    withTempView("person") {
+      person.createOrReplaceTempView("person")
+      val df1 = person.groupBy("id").agg(count_distinct(col("name")))
+      val df2 = spark.sql("SELECT id, COUNT(DISTINCT name) FROM person GROUP BY id")
+      assert(df1.columns === df2.columns)
+      val df3 = person.groupBy("id").agg(count_distinct(col("*")))
+      val df4 = spark.sql("SELECT id, COUNT(DISTINCT *) FROM person GROUP BY id")
+      assert(df3.columns === df4.columns)
+    }
+  }
+
   test("summary advanced") {
     val stats = Array("count", "50.01%", "max", "mean", "min", "25%")
     val orderMatters = person2.summary(stats: _*)

