You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/17 03:03:18 UTC
[spark] branch branch-3.4 updated: [SPARK-42468][CONNECT] Implement agg by (String, String)*
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 7c1c8be960e [SPARK-42468][CONNECT] Implement agg by (String, String)*
7c1c8be960e is described below
commit 7c1c8be960ed9b30451807b460ca45ca9ddf8a72
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Thu Feb 16 23:02:51 2023 -0400
[SPARK-42468][CONNECT] Implement agg by (String, String)*
### What changes were proposed in this pull request?
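This starts adding basic aggregation support to the Spark Connect Scala client. As a first step, it adds `Dataset.groupBy(Column*)` and `RelationalGroupedDataset.agg`, with aggregations specified as (column name, aggregate function name) strings.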
### Why are the changes needed?
To improve API coverage of the Scala client.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
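Unit tests: a new `groupby agg` case in `PlanGenerationTestSuite`, with golden query (JSON and proto) and explain files.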
Closes #40057 from amaliujia/rw-agg.
Authored-by: Rui Wang <ru...@databricks.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
(cherry picked from commit cc471a52d162d0e4d4063372253ed06a62f5cb19)
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 23 ++++
.../spark/sql/RelationalGroupedDataset.scala | 152 +++++++++++++++++++++
.../apache/spark/sql/PlanGenerationTestSuite.scala | 14 ++
.../explain-results/groupby_agg.explain | 2 +
.../resources/query-tests/queries/groupby_agg.json | 88 ++++++++++++
.../query-tests/queries/groupby_agg.proto.bin | 19 +++
6 files changed, 298 insertions(+)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 977c823f7c7..c39fc6100f5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1035,6 +1035,29 @@ class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan:
}
}
+  /**
+   * Groups the Dataset using the specified columns, so we can run aggregation on them. See
+   * [[RelationalGroupedDataset]] for all the available aggregate functions.
+   *
+   * {{{
+   *   // Compute the max age and average salary, grouped by department and gender.
+   *   ds.groupBy($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   *
+   *   // The same aggregation, using the (column name, function name) pair overload.
+   *   ds.groupBy($"department", $"gender").agg("salary" -> "avg", "age" -> "max")
+   * }}}
+   *
+   * @param cols
+   *   the grouping columns; their expressions are used, in the given order, as the grouping
+   *   expressions of the resulting aggregation.
+   * @return
+   *   a [[RelationalGroupedDataset]] over this Dataset viewed as a `DataFrame`.
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def groupBy(cols: Column*): RelationalGroupedDataset = {
+    new RelationalGroupedDataset(toDF(), cols.map(_.expr))
+  }
+
/**
* Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
* set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
new file mode 100644
index 00000000000..a3dfcb01fdc
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+
+/**
+ * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]].
+ *
+ * The main method is the `agg` function, which has multiple variants. Currently only a plain
+ * GROUP BY with aggregations given as `(column name, aggregate function name)` pairs is
+ * supported; `cube`, `rollup`, `pivot` and convenience statistics such as `mean` or `sum` are
+ * not available on this class yet.
+ *
+ * Every `agg` call builds an `Aggregate` relation whose input is the grouped `DataFrame`'s plan
+ * and wraps it in a new `DataFrame` through the session.
+ *
+ * @note
+ *   This class was named `GroupedData` in Spark 1.x.
+ *
+ * @param df
+ *   the `DataFrame` being grouped; its plan root becomes the input of the aggregation.
+ * @param groupingExprs
+ *   the grouping expressions of the aggregation.
+ *
+ * @since 3.4.0
+ */
+class RelationalGroupedDataset protected[sql] (
+    private[sql] val df: DataFrame,
+    private[sql] val groupingExprs: Seq[proto.Expression]) {
+
+  /**
+   * Builds a GROUP BY `Aggregate` relation over `df` with `groupingExprs` and the given
+   * aggregate expressions, and returns it as a new `DataFrame`.
+   */
+  private[this] def toDF(aggExprs: Seq[proto.Expression]): DataFrame = {
+    // TODO: support other GroupByType such as Rollup, Cube, Pivot.
+    df.session.newDataset { builder =>
+      builder.getAggregateBuilder
+        .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+        .setInput(df.plan.getRoot)
+        .addAllGroupingExpressions(groupingExprs.asJava)
+        .addAllAggregateExpressions(aggExprs.asJava)
+    }
+  }
+
+  /**
+   * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The
+   * resulting `DataFrame` will also contain the grouping columns.
+   *
+   * Each pair is `(column name, aggregate function name)`. Function names are matched
+   * case-insensitively. `avg`/`average`/`mean`, `stddev`/`std` and `count`/`size` are accepted
+   * as aliases, and `count` on the column `"*"` is sent as `count(1)`. Any other name (for
+   * example `max`, `min` or `sum`) is sent lower-cased as an unresolved function, left for the
+   * server to resolve. Aggregate columns appear in the order the pairs are given.
+   * {{{
+   *   // Selects the age of the oldest employee and the aggregate expense for each department
+   *   df.groupBy("department").agg(
+   *     "age" -> "max",
+   *     "expense" -> "sum"
+   *   )
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
+    toDF((aggExpr +: aggExprs).map { case (colName, expr) =>
+      strToExpr(expr, df(colName).expr)
+    })
+  }
+
+  /**
+   * (Scala-specific) Compute aggregates by specifying a map from column name to aggregate
+   * methods. The resulting `DataFrame` will also contain the grouping columns.
+   *
+   * Function names are interpreted exactly as in the `(String, String)*` overload of `agg`.
+   * Aggregate columns follow the iteration order of `exprs`.
+   * {{{
+   *   // Selects the age of the oldest employee and the aggregate expense for each department
+   *   df.groupBy("department").agg(Map(
+   *     "age" -> "max",
+   *     "expense" -> "sum"
+   *   ))
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def agg(exprs: Map[String, String]): DataFrame = {
+    toDF(exprs.map { case (colName, expr) =>
+      strToExpr(expr, df(colName).expr)
+    }.toSeq)
+  }
+
+  /**
+   * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods.
+   * The resulting `DataFrame` will also contain the grouping columns.
+   *
+   * Converts `exprs` to a Scala `Map` and delegates to the Scala `Map` overload of `agg`.
+   * {{{
+   *   // Selects the age of the oldest employee and the aggregate expense for each department
+   *   import com.google.common.collect.ImmutableMap;
+   *   df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum"));
+   * }}}
+   *
+   * @since 3.4.0
+   */
+  def agg(exprs: java.util.Map[String, String]): DataFrame = {
+    agg(exprs.asScala.toMap)
+  }
+
+  /**
+   * Translates one aggregate function name applied to one input expression into an
+   * `UnresolvedFunction` expression.
+   *
+   * @param expr
+   *   the aggregate function name, matched case-insensitively.
+   * @param inputExpr
+   *   the expression of the column being aggregated.
+   * @return
+   *   an `UnresolvedFunction` expression with a single argument and `isDistinct = false`.
+   */
+  private[this] def strToExpr(expr: String, inputExpr: proto.Expression): proto.Expression = {
+    val (functionName, argument) = expr.toLowerCase(Locale.ROOT) match {
+      // We special handle a few cases that have alias that are not in function registry.
+      case "avg" | "average" | "mean" => ("avg", inputExpr)
+      case "stddev" | "std" => ("stddev", inputExpr)
+      // Also special handle count because we need to take care count(*):
+      // count(*) is turned into count(1).
+      case "count" | "size" if inputExpr.hasUnresolvedStar => ("count", integerLiteralOne)
+      case "count" | "size" => ("count", inputExpr)
+      case name => (name, inputExpr)
+    }
+    val builder = proto.Expression.newBuilder()
+    builder.getUnresolvedFunctionBuilder
+      .setFunctionName(functionName)
+      .addArguments(argument)
+      .setIsDistinct(false)
+    builder.build()
+  }
+
+  /** The integer literal `1`, used as the argument when `count(*)` becomes `count(1)`. */
+  private[this] def integerLiteralOne: proto.Expression = {
+    val builder = proto.Expression.newBuilder()
+    builder.getLiteralBuilder.setInteger(1)
+    builder.build()
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 7b5d8bd1018..8d4550dfe4f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -784,6 +784,20 @@ class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll wit
select(fn.max(Column("id")))
}
+ test("groupby agg") {
+ // Plan-generation case for RelationalGroupedDataset.agg((String, String)*): one pass-through
+ // function (max) plus every alias special-cased by strToExpr. The pair order matters: it is
+ // presumably checked against the groupby_agg golden files (.json / .proto.bin / .explain).
+ simple
+ .groupBy(Column("id"))
+ .agg(
+ "a" -> "max", // not special-cased: sent as max(a)
+ "b" -> "stddev",
+ "b" -> "std", // alias: stddev(b)
+ "b" -> "mean", // alias: avg(b)
+ "b" -> "average", // alias: avg(b)
+ "b" -> "avg",
+ "*" -> "size", // alias of count; count(*) is rewritten to count(1)
+ "a" -> "count")
+ }
+
test("function lit") {
select(
fn.lit(fn.col("id")),
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg.explain
new file mode 100644
index 00000000000..acb42c1408c
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg.explain
@@ -0,0 +1,2 @@
+Aggregate [id#0L], [id#0L, max(a#0) AS max(a)#0, stddev(b#0) AS stddev(b)#0, stddev(b#0) AS stddev(b)#0, avg(b#0) AS avg(b)#0, avg(b#0) AS avg(b)#0, avg(b#0) AS avg(b)#0, count(1) AS count(1)#0L, count(a#0) AS count(a)#0L]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
new file mode 100644
index 00000000000..7838a89974d
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
@@ -0,0 +1,88 @@
+{
+ "aggregate": {
+ "input": {
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+ }
+ },
+ "groupType": "GROUP_TYPE_GROUPBY",
+ "groupingExpressions": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "id"
+ }
+ }],
+ "aggregateExpressions": [{
+ "unresolvedFunction": {
+ "functionName": "max",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "a"
+ }
+ }]
+ }
+ }, {
+ "unresolvedFunction": {
+ "functionName": "stddev",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "b"
+ }
+ }]
+ }
+ }, {
+ "unresolvedFunction": {
+ "functionName": "stddev",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "b"
+ }
+ }]
+ }
+ }, {
+ "unresolvedFunction": {
+ "functionName": "avg",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "b"
+ }
+ }]
+ }
+ }, {
+ "unresolvedFunction": {
+ "functionName": "avg",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "b"
+ }
+ }]
+ }
+ }, {
+ "unresolvedFunction": {
+ "functionName": "avg",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "b"
+ }
+ }]
+ }
+ }, {
+ "unresolvedFunction": {
+ "functionName": "count",
+ "arguments": [{
+ "literal": {
+ "integer": 1
+ }
+ }]
+ }
+ }, {
+ "unresolvedFunction": {
+ "functionName": "count",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "a"
+ }
+ }]
+ }
+ }]
+ }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin
new file mode 100644
index 00000000000..9c6d1cca8a4
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin
@@ -0,0 +1,19 @@
+J�
+$Z" struct<id:bigint,a:int,b:double>
+id"
+max
+a"
+stddev
+b"
+stddev
+b"
+avg
+b"
+avg
+b"
+avg
+b"
+count
+0"
+count
+a
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org