Posted to reviews@spark.apache.org by "hvanhovell (via GitHub)" <gi...@apache.org> on 2023/02/24 03:27:03 UTC

[GitHub] [spark] hvanhovell commented on a diff in pull request #40145: [SPARK-42541][CONNECT] Support Pivot with provided pivot column values

hvanhovell commented on code in PR #40145:
URL: https://github.com/apache/spark/pull/40145#discussion_r1116470992


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala:
##########
@@ -234,4 +239,132 @@ class RelationalGroupedDataset protected[sql] (
   def sum(colNames: String*): DataFrame = {
     toDF(colNames.map(colName => functions.sum(colName)))
   }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified aggregation. There are
+   * two versions of the `pivot` function: one that requires the caller to specify the list of
+   * distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * From Spark 3.0.0, values can be literal columns, for instance a struct. For pivoting by
+   * multiple columns, use the `struct` function to combine the columns and values:
+   *
+   * {{{
+   *   df.groupBy("year")
+   *     .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts"))))
+   *     .agg(sum($"earnings"))
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+   *   aggregation, which cannot be reversed.
+   *
+   * @param pivotColumn
+   *   Name of the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
+   * aggregation.
+   *
+   * There are two versions of the `pivot` function: one that requires the caller to specify the
+   * list of distinct values to pivot on, and one that does not. The latter is more concise but
+   * less efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings");
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+   *   aggregation, which cannot be reversed.
+   *
+   * @param pivotColumn
+   *   Name of the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified aggregation. This is an
+   * overloaded version of the `pivot` method that takes the `pivotColumn` as a `Column` rather
+   * than as a `String`.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+   *   aggregation, which cannot be reversed.
+   *
+   * @param pivotColumn
+   *   The column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
+    groupType match {
+      case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
+        val valueExprs = values.map(_ match {
+          case c: Column if c.expr.hasLiteral => c.expr.getLiteral
+          case v => functions.lit(v).expr.getLiteral

Review Comment:
   If `v` is a `Column` that does not wrap a literal (which can happen because of the guard in the previous case), then this will produce an empty literal: `functions.lit` passes a `Column` through unchanged, and `getLiteral` on a non-literal expression returns the default, empty protobuf message. The same applies to a `scala.Symbol`.
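   
   A minimal sketch of one possible guard (illustrative only, not necessarily the fix adopted in the PR; the thrown exception and its message are assumptions):
   
   ```scala
   val valueExprs = values.map {
     // A Column is only usable as a pivot value when it wraps a literal expression.
     case c: Column if c.expr.hasLiteral => c.expr.getLiteral
     // Fail fast rather than silently emitting an empty literal message.
     case c: Column =>
       throw new IllegalArgumentException("pivot values must be literal Columns")
     // Anything else is converted to a literal first; a similar explicit case may
     // be needed for scala.Symbol, depending on how functions.lit treats it.
     case v => functions.lit(v).expr.getLiteral
   }
   ```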



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.


