You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/17 03:03:18 UTC

[spark] branch branch-3.4 updated: [SPARK-42468][CONNECT] Implement agg by (String, String)*

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 7c1c8be960e [SPARK-42468][CONNECT] Implement agg by (String, String)*
7c1c8be960e is described below

commit 7c1c8be960ed9b30451807b460ca45ca9ddf8a72
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Thu Feb 16 23:02:51 2023 -0400

    [SPARK-42468][CONNECT] Implement agg by (String, String)*
    
    ### What changes were proposed in this pull request?
    
    This starts adding basic aggregation support to the Scala client. The first step is aggregation by (column name, aggregate function name) string pairs.
    
    ### Why are the changes needed?
    
    To increase API coverage of the Spark Connect Scala client.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests: a new `groupby agg` case in `PlanGenerationTestSuite` with golden query/explain files.
    
    Closes #40057 from amaliujia/rw-agg.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
    (cherry picked from commit cc471a52d162d0e4d4063372253ed06a62f5cb19)
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  23 ++++
 .../spark/sql/RelationalGroupedDataset.scala       | 152 +++++++++++++++++++++
 .../apache/spark/sql/PlanGenerationTestSuite.scala |  14 ++
 .../explain-results/groupby_agg.explain            |   2 +
 .../resources/query-tests/queries/groupby_agg.json |  88 ++++++++++++
 .../query-tests/queries/groupby_agg.proto.bin      |  19 +++
 6 files changed, 298 insertions(+)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 977c823f7c7..c39fc6100f5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1035,6 +1035,29 @@ class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan:
     }
   }
 
+  /**
+   * Groups the Dataset using the specified columns, so we can run aggregation on them. See
+   * [[RelationalGroupedDataset]] for all the available aggregate functions.
+   *
+   * {{{
+   *   // Compute the average salary, grouped by department.
+   *   ds.groupBy($"department").agg("salary" -> "avg")
+   *
+   *   // Compute the max age and average salary, grouped by department and gender.
+   *   ds.groupBy($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def groupBy(cols: Column*): RelationalGroupedDataset = {
+    new RelationalGroupedDataset(toDF(), cols.map(_.expr))
+  }
+
   /**
    * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
    * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
new file mode 100644
index 00000000000..a3dfcb01fdc
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+
+/**
+ * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]].
+ * This client currently builds only plain group-by aggregations (no cube, rollup or `pivot`).
+ *
+ * The main method is the `agg` function, which has multiple variants. Statistics such as
+ * `mean` or `sum` are requested by name through `agg`, e.g. `agg("x" -> "sum")`.
+ *
+ * @note
+ *   This class was named `GroupedData` in Spark 1.x.
+ *
+ * @since 3.4.0
+ */
+class RelationalGroupedDataset protected[sql] (
+    private[sql] val df: DataFrame,
+    private[sql] val groupingExprs: Seq[proto.Expression]) {
+
+  private[this] def toDF(aggExprs: Seq[proto.Expression]): DataFrame = {
+    // TODO: support other Aggregate.GroupType values such as rollup, cube and pivot.
+    df.session.newDataset { builder =>
+      builder.getAggregateBuilder
+        .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+        .setInput(df.plan.getRoot)
+        .addAllGroupingExpressions(groupingExprs.asJava)
+        .addAllAggregateExpressions(aggExprs.asJava)
+    }
+  }
+
+  /**
+   * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The
+   * resulting `DataFrame` will also contain the grouping columns.
+   *
+   * Aggregate methods include `avg`, `max`, `min`, `sum`, `count` (case-insensitive).
+   * {{{
+   *   // Selects the age of the oldest employee and the aggregate expense for each department
+   *   df.groupBy("department").agg(
+   *     "age" -> "max",
+   *     "expense" -> "sum"
+   *   )
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
+    toDF((aggExpr +: aggExprs).map { case (colName, expr) =>
+      strToExpr(expr, df(colName).expr)
+    })
+  }
+
+  /**
+   * (Scala-specific) Compute aggregates by specifying a map from column name to aggregate
+   * methods. The resulting `DataFrame` will also contain the grouping columns.
+   *
+   * Aggregate methods include `avg`, `max`, `min`, `sum`, `count` (case-insensitive).
+   * {{{
+   *   // Selects the age of the oldest employee and the aggregate expense for each department
+   *   df.groupBy("department").agg(Map(
+   *     "age" -> "max",
+   *     "expense" -> "sum"
+   *   ))
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def agg(exprs: Map[String, String]): DataFrame = {
+    toDF(exprs.map { case (colName, expr) =>
+      strToExpr(expr, df(colName).expr)
+    }.toSeq)
+  }
+
+  /**
+   * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods.
+   * The resulting `DataFrame` will also contain the grouping columns.
+   *
+   * Aggregate methods include `avg`, `max`, `min`, `sum`, `count` (case-insensitive).
+   * {{{
+   *   // Selects the age of the oldest employee and the aggregate expense for each department
+   *   import com.google.common.collect.ImmutableMap;
+   *   df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum"));
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def agg(exprs: java.util.Map[String, String]): DataFrame = {
+    agg(exprs.asScala.toMap)
+  }
+
+  private[this] def strToExpr(expr: String, inputExpr: proto.Expression): proto.Expression = {
+    val builder = proto.Expression.newBuilder()
+
+    expr.toLowerCase(Locale.ROOT) match {
+      // Specially handle a few aliases that are not in the function registry.
+      case "avg" | "average" | "mean" =>
+        builder.getUnresolvedFunctionBuilder
+          .setFunctionName("avg")
+          .addArguments(inputExpr)
+          .setIsDistinct(false)
+      case "stddev" | "std" =>
+        builder.getUnresolvedFunctionBuilder
+          .setFunctionName("stddev")
+          .addArguments(inputExpr)
+          .setIsDistinct(false)
+      // Also specially handle count, because count(*) needs to be rewritten.
+      case "count" | "size" =>
+        // Turn count(*) into count(1) by replacing the star with the integer literal 1.
+        inputExpr match {
+          case s if s.hasUnresolvedStar =>
+            val exprBuilder = proto.Expression.newBuilder
+            exprBuilder.getLiteralBuilder.setInteger(1)
+            builder.getUnresolvedFunctionBuilder
+              .setFunctionName("count")
+              .addArguments(exprBuilder)
+              .setIsDistinct(false)
+          case _ =>
+            builder.getUnresolvedFunctionBuilder
+              .setFunctionName("count")
+              .addArguments(inputExpr)
+              .setIsDistinct(false)
+        }
+      case name =>
+        builder.getUnresolvedFunctionBuilder
+          .setFunctionName(name)
+          .addArguments(inputExpr)
+          .setIsDistinct(false)
+    }
+    builder.build()
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 7b5d8bd1018..8d4550dfe4f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -784,6 +784,20 @@ class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll wit
     select(fn.max(Column("id")))
   }
 
+  test("groupby agg") { // covers each alias special-cased by strToExpr
+    simple
+      .groupBy(Column("id"))
+      .agg(
+        "a" -> "max",
+        "b" -> "stddev",
+        "b" -> "std",
+        "b" -> "mean",
+        "b" -> "average",
+        "b" -> "avg",
+        "*" -> "size", // "*" is rewritten to count(1) by the client
+        "a" -> "count")
+  }
+
   test("function lit") {
     select(
       fn.lit(fn.col("id")),
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg.explain
new file mode 100644
index 00000000000..acb42c1408c
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg.explain
@@ -0,0 +1,2 @@
+Aggregate [id#0L], [id#0L, max(a#0) AS max(a)#0, stddev(b#0) AS stddev(b)#0, stddev(b#0) AS stddev(b)#0, avg(b#0) AS avg(b)#0, avg(b#0) AS avg(b)#0, avg(b#0) AS avg(b)#0, count(1) AS count(1)#0L, count(a#0) AS count(a)#0L]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
new file mode 100644
index 00000000000..7838a89974d
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
@@ -0,0 +1,88 @@
+{
+  "aggregate": {
+    "input": {
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_GROUPBY",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "id"
+      }
+    }],
+    "aggregateExpressions": [{
+      "unresolvedFunction": {
+        "functionName": "max",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a"
+          }
+        }]
+      }
+    }, {
+      "unresolvedFunction": {
+        "functionName": "stddev",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "b"
+          }
+        }]
+      }
+    }, {
+      "unresolvedFunction": {
+        "functionName": "stddev",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "b"
+          }
+        }]
+      }
+    }, {
+      "unresolvedFunction": {
+        "functionName": "avg",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "b"
+          }
+        }]
+      }
+    }, {
+      "unresolvedFunction": {
+        "functionName": "avg",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "b"
+          }
+        }]
+      }
+    }, {
+      "unresolvedFunction": {
+        "functionName": "avg",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "b"
+          }
+        }]
+      }
+    }, {
+      "unresolvedFunction": {
+        "functionName": "count",
+        "arguments": [{
+          "literal": {
+            "integer": 1
+          }
+        }]
+      }
+    }, {
+      "unresolvedFunction": {
+        "functionName": "count",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a"
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin
new file mode 100644
index 00000000000..9c6d1cca8a4
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin
@@ -0,0 +1,19 @@
+J�
+$Z" struct<id:bigint,a:int,b:double>
+id"
+max
+a"
+stddev
+b"
+stddev
+b"
+avg
+b"
+avg
+b"
+avg
+b"
+count
+0"
+count
+a
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org