Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/25 02:30:33 UTC

[spark] branch master updated: [SPARK-42541][CONNECT] Support Pivot with provided pivot column values

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 34a2d95dadf [SPARK-42541][CONNECT] Support Pivot with provided pivot column values
34a2d95dadf is described below

commit 34a2d95dadfca2ee643eb937d50f12e3b8b148eb
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Fri Feb 24 22:30:20 2023 -0400

    [SPARK-42541][CONNECT] Support Pivot with provided pivot column values
    
    ### What changes were proposed in this pull request?
    
    Support Pivot with provided pivot column values. Pivot without provided column values is not supported here, because it requires checking the maximum number of pivot values, which depends on how Spark configuration is implemented in Spark Connect.
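    
    A minimal sketch of the supported usage from the Connect Scala client, assuming a DataFrame `df` with columns `year`, `course`, and `earnings`:
    
        // Pivot on "course", restricted to the two provided values; each value
        // becomes a separate column in the output. Values must be literals (or
        // literal Columns); anything else throws IllegalArgumentException.
        df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")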
    
    ### Why are the changes needed?
    
    API coverage for the Spark Connect Scala client.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    Unit tests (`DatasetSuite` and `PlanGenerationTestSuite`).
    
    Closes #40145 from amaliujia/rw-pivot.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../spark/sql/RelationalGroupedDataset.scala       | 138 ++++++++++++++++++++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |   7 ++
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 +
 .../query-tests/explain-results/pivot.explain      |   4 +
 .../test/resources/query-tests/queries/pivot.json  |  45 +++++++
 .../resources/query-tests/queries/pivot.proto.bin  | Bin 0 -> 97 bytes
 6 files changed, 196 insertions(+), 2 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 76db231db9e..c918061ac46 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -38,7 +38,8 @@ import org.apache.spark.connect.proto
 class RelationalGroupedDataset protected[sql] (
     private[sql] val df: DataFrame,
     private[sql] val groupingExprs: Seq[proto.Expression],
-    groupType: proto.Aggregate.GroupType) {
+    groupType: proto.Aggregate.GroupType,
+    pivot: Option[proto.Aggregate.Pivot] = None) {
 
   private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
     df.session.newDataset { builder =>
@@ -47,7 +48,6 @@ class RelationalGroupedDataset protected[sql] (
         .addAllGroupingExpressions(groupingExprs.asJava)
         .addAllAggregateExpressions(aggExprs.map(e => e.expr).asJava)
 
-      // TODO: support Pivot.
       groupType match {
         case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
           builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
@@ -55,6 +55,11 @@ class RelationalGroupedDataset protected[sql] (
           builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
         case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
           builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+        case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT =>
+          assert(pivot.isDefined)
+          builder.getAggregateBuilder
+            .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT)
+            .setPivot(pivot.get)
         case g => throw new UnsupportedOperationException(g.toString)
       }
     }
@@ -234,4 +239,133 @@ class RelationalGroupedDataset protected[sql] (
   def sum(colNames: String*): DataFrame = {
     toDF(colNames.map(colName => functions.sum(colName)))
   }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified aggregation. There are
+   * two versions of the pivot function: one that requires the caller to specify the list of distinct
+   * values to pivot on, and one that does not. The latter is more concise but less efficient,
+   * because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * From Spark 3.0.0, values can be literal columns, for instance, struct. For pivoting by
+   * multiple columns, use the `struct` function to combine the columns and values:
+   *
+   * {{{
+   *   df.groupBy("year")
+   *     .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts"))))
+   *     .agg(sum($"earnings"))
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   Name of the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
+   * aggregation.
+   *
+   * There are two versions of the pivot function: one that requires the caller to specify the list of
+   * distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings");
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   Name of the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified aggregation. This is an
+   * overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
+    groupType match {
+      case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
+        val valueExprs = values.map {
+          case c: Column if c.expr.hasLiteral => c.expr.getLiteral
+          case c: Column =>
+            throw new IllegalArgumentException("values only accept literal Column")
+          case v => functions.lit(v).expr.getLiteral
+        }
+        new RelationalGroupedDataset(
+          df,
+          groupingExprs,
+          proto.Aggregate.GroupType.GROUP_TYPE_PIVOT,
+          Some(
+            proto.Aggregate.Pivot
+              .newBuilder()
+              .setCol(pivotColumn.expr)
+              .addAllValues(valueExprs.asJava)
+              .build()))
+      case _ =>
+        throw new UnsupportedOperationException()
+    }
+  }
+
+  /**
+   * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
+   * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of the
+   * `String` type.
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = {
+    pivot(pivotColumn, values.asScala.toSeq)
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c3c80a08379..be69959beac 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -129,4 +129,11 @@ class DatasetSuite
     val actualPlan = service.getAndClearLatestInputPlan()
     assert(actualPlan.equals(expectedPlan))
   }
+
+  test("Pivot") {
+    val df = ss.newDataset(_ => ())
+    intercept[IllegalArgumentException] {
+      df.groupBy().pivot(Column("c"), Seq(Column("col")))
+    }
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index f7589d957ca..465d2091ca2 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1949,6 +1949,10 @@ class PlanGenerationTestSuite
     simple.cube("a", "b").count()
   }
 
+  test("pivot") {
+    simple.groupBy(Column("id")).pivot("a", Seq(1, 2, 3)).agg(functions.count(Column("b")))
+  }
+
   test("function lit") {
     simple.select(
       fn.lit(fn.col("id")),
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/pivot.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/pivot.explain
new file mode 100644
index 00000000000..b8cd8441237
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/pivot.explain
@@ -0,0 +1,4 @@
+Project [id#0L, __pivot_count(b) AS `count(b)`#0[0] AS 1#0L, __pivot_count(b) AS `count(b)`#0[1] AS 2#0L, __pivot_count(b) AS `count(b)`#0[2] AS 3#0L]
++- Aggregate [id#0L], [id#0L, pivotfirst(a#0, count(b)#0L, 1, 2, 3, 0, 0) AS __pivot_count(b) AS `count(b)`#0]
+   +- Aggregate [id#0L, a#0], [id#0L, a#0, count(b#0) AS count(b)#0L]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/pivot.json b/connector/connect/common/src/test/resources/query-tests/queries/pivot.json
new file mode 100644
index 00000000000..30bff04c531
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/pivot.json
@@ -0,0 +1,45 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "aggregate": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_PIVOT",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "id"
+      }
+    }],
+    "aggregateExpressions": [{
+      "unresolvedFunction": {
+        "functionName": "count",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "b"
+          }
+        }]
+      }
+    }],
+    "pivot": {
+      "col": {
+        "unresolvedAttribute": {
+          "unparsedIdentifier": "a"
+        }
+      },
+      "values": [{
+        "integer": 1
+      }, {
+        "integer": 2
+      }, {
+        "integer": 3
+      }]
+    }
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin
new file mode 100644
index 00000000000..67063209a18
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin differ

