Posted to commits@spark.apache.org by gu...@apache.org on 2018/09/29 13:50:41 UTC

spark git commit: [SPARK-25048][SQL] Pivoting by multiple columns in Scala/Java

Repository: spark
Updated Branches:
  refs/heads/master dcb9a97f3 -> 623c2ec4e


[SPARK-25048][SQL] Pivoting by multiple columns in Scala/Java

## What changes were proposed in this pull request?

In this PR, I propose to extend the implementation of the existing method:
```
def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset
```
to support values of struct type. This allows pivoting by multiple columns combined with `struct`:
```
trainingSales
  .groupBy($"sales.year")
  .pivot(
    pivotColumn = struct(lower($"sales.course"), $"training"),
    values = Seq(
      struct(lit("dotnet"), lit("Experts")),
      struct(lit("java"), lit("Dummies"))))
  .agg(sum($"sales.earnings"))
```
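
For illustration, here is a self-contained sketch of the extended API in action (it requires this patch to run; the `trainingSales` data below is a hypothetical stand-in modeled on `DataFramePivotSuite`, not the actual test fixture):

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object PivotByStructExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]").appName("pivot-by-struct").getOrCreate()
    import spark.implicits._

    // Hypothetical stand-in for the trainingSales table used in DataFramePivotSuite:
    // a top-level `training` column plus a nested `sales` struct.
    val trainingSales = Seq(
      ("Experts", "dotNET", 2012, 10000.0),
      ("Experts", "JAVA", 2012, 20000.0),
      ("Dummies", "dotNet", 2012, 5000.0),
      ("Experts", "dotNET", 2013, 48000.0),
      ("Dummies", "Java", 2013, 30000.0)
    ).toDF("training", "course", "year", "earnings")
      .select($"training", struct($"course", $"year", $"earnings").as("sales"))

    // Each struct value becomes one output column; groups with no matching
    // (course, training) combination yield null.
    trainingSales
      .groupBy($"sales.year")
      .pivot(
        struct(lower($"sales.course"), $"training"),
        Seq(
          struct(lit("dotnet"), lit("Experts")),
          struct(lit("java"), lit("Dummies"))))
      .agg(sum($"sales.earnings"))
      .show()

    spark.stop()
  }
}
```

With this toy data, `show()` prints the same rows the new Scala test expects: `(2012, 10000.0, null)` and `(2013, 48000.0, 30000.0)`.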

## How was this patch tested?

Added tests in Java and Scala for pivot values specified via `struct`.

Closes #22316 from MaxGekk/pivoting-by-multiple-columns2.

Lead-authored-by: Maxim Gekk <ma...@databricks.com>
Co-authored-by: Maxim Gekk <ma...@gmail.com>
Signed-off-by: hyukjinkwon <gu...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/623c2ec4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/623c2ec4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/623c2ec4

Branch: refs/heads/master
Commit: 623c2ec4ef3776bc5e2cac2c66300ddc6264db54
Parents: dcb9a97
Author: Maxim Gekk <ma...@databricks.com>
Authored: Sat Sep 29 21:50:35 2018 +0800
Committer: hyukjinkwon <gu...@apache.org>
Committed: Sat Sep 29 21:50:35 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/RelationalGroupedDataset.scala    | 17 +++++++++++++--
 .../apache/spark/sql/JavaDataFrameSuite.java    | 16 ++++++++++++++
 .../apache/spark/sql/DataFramePivotSuite.scala  | 23 ++++++++++++++++++++
 3 files changed, 54 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/623c2ec4/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index d700fb8..dbacdbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -330,6 +330,15 @@ class RelationalGroupedDataset protected[sql](
    *   df.groupBy("year").pivot("course").sum("earnings")
    * }}}
    *
+   * From Spark 2.5.0, values can be literal columns, for instance, struct. For pivoting by
+   * multiple columns, use the `struct` function to combine the columns and values:
+   *
+   * {{{
+   *   df.groupBy("year")
+   *     .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts"))))
+   *     .agg(sum($"earnings"))
+   * }}}
+   *
    * @param pivotColumn Name of the column to pivot.
    * @param values List of values that will be translated to columns in the output DataFrame.
    * @since 1.6.0
@@ -413,10 +422,14 @@ class RelationalGroupedDataset protected[sql](
   def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
     groupType match {
       case RelationalGroupedDataset.GroupByType =>
+        val valueExprs = values.map(_ match {
+          case c: Column => c.expr
+          case v => Literal.apply(v)
+        })
         new RelationalGroupedDataset(
           df,
           groupingExprs,
-          RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
+          RelationalGroupedDataset.PivotType(pivotColumn.expr, valueExprs))
       case _: RelationalGroupedDataset.PivotType =>
         throw new UnsupportedOperationException("repeated pivots are not supported")
       case _ =>
@@ -561,5 +574,5 @@ private[sql] object RelationalGroupedDataset {
   /**
    * To indicate it's the PIVOT
    */
-  private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
+  private[sql] case class PivotType(pivotCol: Expression, values: Seq[Expression]) extends GroupType
 }
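
In short, the behavioural core of the patch is the widening of `PivotType.values` from `Seq[Literal]` to `Seq[Expression]`, plus the conversion in the hunk above: `Column` values (such as `struct(...)`) pass through as expressions, while anything else keeps the old `Literal` path. Restated as a standalone helper (the name `toPivotValueExpr` is hypothetical):

```
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

// Column values (e.g. struct(lit("java"), lit("Experts"))) are used as-is;
// plain Scala values are wrapped in a Literal, as before this patch.
def toPivotValueExpr(value: Any): Expression = value match {
  case c: Column => c.expr
  case v => Literal.apply(v)
}
```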

http://git-wip-us.apache.org/repos/asf/spark/blob/623c2ec4/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 3f37e58..00f41d6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -317,6 +317,22 @@ public class JavaDataFrameSuite {
     Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
   }
 
+  @Test
+  public void pivotColumnValues() {
+    Dataset<Row> df = spark.table("courseSales");
+    List<Row> actual = df.groupBy("year")
+      .pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java")))
+      .agg(sum("earnings")).orderBy("year").collectAsList();
+
+    Assert.assertEquals(2012, actual.get(0).getInt(0));
+    Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
+    Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);
+
+    Assert.assertEquals(2013, actual.get(1).getInt(0));
+    Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
+    Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
+  }
+
   private String getResource(String resource) {
     try {
       // The following "getResource" has different behaviors in SBT and Maven.

http://git-wip-us.apache.org/repos/asf/spark/blob/623c2ec4/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index b972b9e..02ab197 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -308,4 +308,27 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
 
     assert(exception.getMessage.contains("aggregate functions are not allowed"))
   }
+
+  test("pivoting column list with values") {
+    val expected = Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil
+    val df = trainingSales
+      .groupBy($"sales.year")
+      .pivot(struct(lower($"sales.course"), $"training"), Seq(
+        struct(lit("dotnet"), lit("Experts")),
+        struct(lit("java"), lit("Dummies")))
+      ).agg(sum($"sales.earnings"))
+
+    checkAnswer(df, expected)
+  }
+
+  test("pivoting column list") {
+    val exception = intercept[RuntimeException] {
+      trainingSales
+        .groupBy($"sales.year")
+        .pivot(struct(lower($"sales.course"), $"training"))
+        .agg(sum($"sales.earnings"))
+        .collect()
+    }
+    assert(exception.getMessage.contains("Unsupported literal type"))
+  }
 }

