Posted to commits@spark.apache.org by ru...@apache.org on 2022/08/01 01:39:47 UTC

[spark] branch master updated: [SPARK-39877][PYTHON] Add unpivot to PySpark DataFrame API

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8585fafab86 [SPARK-39877][PYTHON] Add unpivot to PySpark DataFrame API
8585fafab86 is described below

commit 8585fafab8633c02e6f1b989acd2bbdb0eb1678e
Author: Enrico Minack <gi...@enrico.minack.dev>
AuthorDate: Mon Aug 1 09:39:12 2022 +0800

    [SPARK-39877][PYTHON] Add unpivot to PySpark DataFrame API
    
    ### What changes were proposed in this pull request?
    This adds `unpivot` and its alias `melt` to the PySpark API. It calls into the Scala `Dataset.unpivot` (#36150). A small difference from the Scala method signature is that the PySpark method has default values. This is similar to `melt` in the pandas API on Spark.
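    
    For example, a minimal sketch of the new API, assuming an active `spark`
    session (the data mirrors the docstring example in the diff below):
    
        >>> df = spark.createDataFrame([(1, 11, 1.1), (2, 12, 1.2)], ["id", "int", "double"])
        >>> long = df.unpivot("id", ["int", "double"], "var", "val")  # columns: id, var, val
        >>> long.count()  # 2 rows x 2 value columns
        4
        >>> df.melt("id", ["int", "double"], "var", "val").collect() == long.collect()
        True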
    
    ### Why are the changes needed?
    To support `unpivot` in Python.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, adds `DataFrame.unpivot` and `DataFrame.melt` to PySpark API.
    
    ### How was this patch tested?
    Added test to `test_dataframe.py`.
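    
    The new test can be run on its own with Spark's Python test runner, e.g.
    (a sketch of the usual workflow, assuming a built Spark checkout):
    
        python/run-tests --testnames 'pyspark.sql.tests.test_dataframe DataFrameTests.test_unpivot'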
    
    Closes #37304 from EnricoMi/branch-pyspark-unpivot.
    
    Authored-by: Enrico Minack <gi...@enrico.minack.dev>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/dataframe.py                    | 134 +++++++++++++++++++
 python/pyspark/sql/tests/test_dataframe.py         | 144 +++++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  11 ++
 3 files changed, 289 insertions(+)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 481dafa310d..8c9632fe766 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2238,6 +2238,140 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
         return GroupedData(jgd, self)
 
+    def unpivot(
+        self,
+        ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+        values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+        variableColumnName: str,
+        valueColumnName: str,
+    ) -> "DataFrame":
+        """
+        Unpivot a DataFrame from wide format to long format, optionally leaving
+        identifier columns set. This is the reverse of `groupBy(...).pivot(...).agg(...)`,
+        except for the aggregation, which cannot be reversed.
+
+        This function is useful to massage a DataFrame into a format where some
+        columns are identifier columns ("ids"), while all other columns ("values")
+        are "unpivoted" to the rows, leaving just two non-id columns, named as given
+        by `variableColumnName` and `valueColumnName`.
+
+        When no "id" columns are given, the unpivoted DataFrame consists of only the
+        "variable" and "value" columns.
+
+        All "value" columns must share a least common data type. Unless they are the same data type,
+        all "value" columns are cast to the nearest common data type. For instance, types
+        `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType`
+        do not have a common data type and `unpivot` fails.
+
+        :func:`melt` is an alias for :func:`unpivot`.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        ids : str, Column, tuple, list, optional
+            Column(s) to use as identifiers. Can be a single column or column name,
+            or a list or tuple for multiple columns.
+        values : str, Column, tuple, list, optional
+            Column(s) to unpivot. Can be a single column or column name, or a list or tuple
+            for multiple columns. If not specified or empty, uses all columns that
+            are not set as `ids`.
+        variableColumnName : str
+            Name of the variable column.
+        valueColumnName : str
+            Name of the value column.
+
+        Returns
+        -------
+        DataFrame
+            Unpivoted DataFrame.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...     [(1, 11, 1.1), (2, 12, 1.2)],
+        ...     ["id", "int", "double"],
+        ... )
+        >>> df.show()
+        +---+---+------+
+        | id|int|double|
+        +---+---+------+
+        |  1| 11|   1.1|
+        |  2| 12|   1.2|
+        +---+---+------+
+
+        >>> df.unpivot("id", ["int", "double"], "var", "val").show()
+        +---+------+----+
+        | id|   var| val|
+        +---+------+----+
+        |  1|   int|11.0|
+        |  1|double| 1.1|
+        |  2|   int|12.0|
+        |  2|double| 1.2|
+        +---+------+----+
+        """
+
+        def to_jcols(
+            cols: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]]
+        ) -> JavaObject:
+            if cols is None:
+                lst = []
+            elif isinstance(cols, tuple):
+                lst = list(cols)
+            elif isinstance(cols, list):
+                lst = cols
+            else:
+                lst = [cols]
+            return self._jcols(*lst)
+
+        return DataFrame(
+            self._jdf.unpivotWithSeq(
+                to_jcols(ids), to_jcols(values), variableColumnName, valueColumnName
+            ),
+            self.sparkSession,
+        )
+
+    def melt(
+        self,
+        ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+        values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+        variableColumnName: str,
+        valueColumnName: str,
+    ) -> "DataFrame":
+        """
+        Unpivot a DataFrame from wide format to long format, optionally leaving
+        identifier columns set. This is the reverse of `groupBy(...).pivot(...).agg(...)`,
+        except for the aggregation, which cannot be reversed.
+
+        :func:`melt` is an alias for :func:`unpivot`.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        ids : str, Column, tuple, list, optional
+            Column(s) to use as identifiers. Can be a single column or column name,
+            or a list or tuple for multiple columns.
+        values : str, Column, tuple, list, optional
+            Column(s) to unpivot. Can be a single column or column name, or a list or tuple
+            for multiple columns. If not specified or empty, uses all columns that
+            are not set as `ids`.
+        variableColumnName : str
+            Name of the variable column.
+        valueColumnName : str
+            Name of the value column.
+
+        Returns
+        -------
+        DataFrame
+            Unpivoted DataFrame.
+
+        See Also
+        --------
+        DataFrame.unpivot
+        """
+        return self.unpivot(ids, values, variableColumnName, valueColumnName)
+
     def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
         """Aggregate on the entire :class:`DataFrame` without groups
         (shorthand for ``df.groupBy().agg()``).
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 7c7d3d1e51c..987ff91402d 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -534,6 +534,150 @@ class DataFrameTests(ReusedSQLTestCase):
         self.assertEqual(1, logical_plan.toString().count("what"))
         self.assertEqual(3, logical_plan.toString().count("itworks"))
 
+    def test_unpivot(self):
+        # SPARK-39877: test the DataFrame.unpivot method
+        df = self.spark.createDataFrame(
+            [
+                (1, 10, 1.0, "one"),
+                (2, 20, 2.0, "two"),
+                (3, 30, 3.0, "three"),
+            ],
+            ["id", "int", "double", "str"],
+        )
+
+        with self.subTest(desc="with no identifier and no value columns"):
+            # select only columns that have common data type (double)
+            actual = df.select("id", "int", "double").unpivot(
+                ids=None, values=None, variableColumnName="var", valueColumnName="val"
+            )
+            self.assertEqual(actual.schema.simpleString(), "struct<var:string,val:double>")
+            self.assertEqual(
+                actual.collect(),
+                [
+                    Row(variable="id", value=1.0),
+                    Row(variable="int", value=10.0),
+                    Row(variable="double", value=1.0),
+                    Row(variable="id", value=2.0),
+                    Row(variable="int", value=20.0),
+                    Row(variable="double", value=2.0),
+                    Row(variable="id", value=3.0),
+                    Row(variable="int", value=30.0),
+                    Row(variable="double", value=3.0),
+                ],
+            )
+
+        with self.subTest(desc="with no identifier column and multiple value columns"):
+            for id in [None, [], ()]:
+                for values in [["int", "double"], ("int", "double")]:
+                    with self.subTest(ids=id, values=values):
+                        actual = df.unpivot(id, values, "var", "val")
+                        self.assertEqual(
+                            actual.schema.simpleString(), "struct<var:string,val:double>"
+                        )
+                        self.assertEqual(
+                            actual.collect(),
+                            [
+                                Row(variable="int", value=10.0),
+                                Row(variable="double", value=1.0),
+                                Row(variable="int", value=20.0),
+                                Row(variable="double", value=2.0),
+                                Row(variable="int", value=30.0),
+                                Row(variable="double", value=3.0),
+                            ],
+                        )
+
+        with self.subTest(desc="with single identifier column and multiple value columns"):
+            for id in ["id", ["id"], ("id",)]:
+                for values in [["int", "double"], ("int", "double")]:
+                    with self.subTest(ids=id, values=values):
+                        actual = df.unpivot(id, values, "var", "val")
+                        self.assertEqual(
+                            actual.schema.simpleString(),
+                            "struct<id:bigint,var:string,val:double>",
+                        )
+                        self.assertEqual(
+                            actual.collect(),
+                            [
+                                Row(id=1, variable="int", value=10.0),
+                                Row(id=1, variable="double", value=1.0),
+                                Row(id=2, variable="int", value=20.0),
+                                Row(id=2, variable="double", value=2.0),
+                                Row(id=3, variable="int", value=30.0),
+                                Row(id=3, variable="double", value=3.0),
+                            ],
+                        )
+
+        with self.subTest(desc="with multiple identifier columns and single given value columns"):
+            for ids in [["id", "double"], ("id", "double")]:
+                for values in ["str", ["str"], ("str",)]:
+                    with self.subTest(ids=ids, values=values):
+                        actual = df.unpivot(ids, values, "var", "val")
+                        self.assertEqual(
+                            actual.schema.simpleString(),
+                            "struct<id:bigint,double:double,var:string,val:string>",
+                        )
+                        self.assertEqual(
+                            actual.collect(),
+                            [
+                                Row(id=1, double=1.0, variable="str", value="one"),
+                                Row(id=2, double=2.0, variable="str", value="two"),
+                                Row(id=3, double=3.0, variable="str", value="three"),
+                            ],
+                        )
+
+        with self.subTest(desc="with multiple identifier columns but no given value columns"):
+            for ids in [["id", "str"], ("id", "str")]:
+                for values in [None, [], ()]:
+                    with self.subTest(ids=ids, values=values):
+                        actual = df.unpivot(ids, values, "var", "val")
+                        self.assertEqual(
+                            actual.schema.simpleString(),
+                            "struct<id:bigint,str:string,var:string,val:double>",
+                        )
+                        self.assertEqual(
+                            actual.collect(),
+                            [
+                                Row(id=1, str="one", variable="int", value=10.0),
+                                Row(id=1, str="one", variable="double", value=1.0),
+                                Row(id=2, str="two", variable="int", value=20.0),
+                                Row(id=2, str="two", variable="double", value=2.0),
+                                Row(id=3, str="three", variable="int", value=30.0),
+                                Row(id=3, str="three", variable="double", value=3.0),
+                            ],
+                        )
+
+        with self.subTest(desc="with value columns without common data type"):
+            with self.assertRaisesRegex(
+                AnalysisException,
+                r"\[UNPIVOT_VALUE_DATA_TYPE_MISMATCH\] Unpivot value columns must share "
+                r"a least common type, some types do not: .*",
+            ):
+                df.unpivot("id", ["int", "str"], "var", "val")
+
+        with self.subTest(desc="with columns"):
+            for id in [df.id, [df.id], (df.id,)]:
+                for values in [[df.int, df.double], (df.int, df.double)]:
+                    with self.subTest(ids=id, values=values):
+                        self.assertEqual(
+                            df.unpivot(id, values, "var", "val").collect(),
+                            df.unpivot("id", ["int", "double"], "var", "val").collect(),
+                        )
+
+        with self.subTest(desc="with column names and columns"):
+            for ids in [[df.id, "str"], (df.id, "str")]:
+                for values in [[df.int, "double"], (df.int, "double")]:
+                    with self.subTest(ids=ids, values=values):
+                        self.assertEqual(
+                            df.unpivot(ids, values, "var", "val").collect(),
+                            df.unpivot(["id", "str"], ["int", "double"], "var", "val").collect(),
+                        )
+
+        with self.subTest(desc="melt alias"):
+            self.assertEqual(
+                df.unpivot("id", ["int", "double"], "var", "val").collect(),
+                df.melt("id", ["int", "double"], "var", "val").collect(),
+            )
+
     def test_observe(self):
         # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
         from pyspark.sql import Observation
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 4bc337e5af3..3f0cef33b5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2107,6 +2107,17 @@ class Dataset[T] private[sql](
       valueColumnName: String): DataFrame =
     unpivot(ids, Array.empty, variableColumnName, valueColumnName)
 
+  /**
+   * Called from Python, as a Seq[Column] is easier to create via py4j than an Array[Column].
+   * The public unpivot methods take Array[Column] rather than Seq[Column] to stay Java-friendly.
+   */
+  private[sql] def unpivotWithSeq(
+      ids: Seq[Column],
+      values: Seq[Column],
+      variableColumnName: String,
+      valueColumnName: String): DataFrame =
+    unpivot(ids.toArray, values.toArray, variableColumnName, valueColumnName)
+
   /**
    * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
    * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
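
As a quick illustration of the type-coercion rule described in the new
docstring, here is a sketch assuming an active `spark` session; the expected
schema and error class are taken from the tests above:

    >>> df = spark.createDataFrame([(1, 10, 1.0, "one")], ["id", "int", "double", "str"])
    >>> # int and double share a least common type: both are cast to double
    >>> df.unpivot("id", ["int", "double"], "var", "val").schema.simpleString()
    'struct<id:bigint,var:string,val:double>'
    >>> # int and string do not: analysis fails eagerly, no action needed
    >>> df.unpivot("id", ["int", "str"], "var", "val")
    Traceback (most recent call last):
      ...
    AnalysisException: [UNPIVOT_VALUE_DATA_TYPE_MISMATCH] Unpivot value columns must share a least common type, some types do not: ...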


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org