Posted to commits@spark.apache.org by ru...@apache.org on 2022/08/01 01:39:47 UTC
[spark] branch master updated: [SPARK-39877][PYTHON] Add unpivot to PySpark DataFrame API
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8585fafab86 [SPARK-39877][PYTHON] Add unpivot to PySpark DataFrame API
8585fafab86 is described below
commit 8585fafab8633c02e6f1b989acd2bbdb0eb1678e
Author: Enrico Minack <gi...@enrico.minack.dev>
AuthorDate: Mon Aug 1 09:39:12 2022 +0800
[SPARK-39877][PYTHON] Add unpivot to PySpark DataFrame API
### What changes were proposed in this pull request?
This adds `unpivot` and its alias `melt` to the PySpark API. It calls into the Scala `Dataset.unpivot` (#36150). A small difference to the Scala method signature is that the PySpark method has default values. This is similar to `melt` in the Spark Pandas API.
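For illustration, a minimal sketch of the new API (mirroring the docstring example added in this PR; `spark` is assumed to be an active `SparkSession`):

```python
df = spark.createDataFrame([(1, 11, 1.1), (2, 12, 1.2)], ["id", "int", "double"])

# Keep "id" as the identifier; "int" and "double" are unpivoted into
# ("var", "val") pairs. The "int" values are cast to double, the least
# common type of the two value columns.
df.unpivot("id", ["int", "double"], "var", "val").show()

# melt is an alias for unpivot:
df.melt("id", ["int", "double"], "var", "val").show()
```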
### Why are the changes needed?
To support `unpivot` in Python.
### Does this PR introduce _any_ user-facing change?
Yes, adds `DataFrame.unpivot` and `DataFrame.melt` to PySpark API.
### How was this patch tested?
Added test to `test_dataframe.py`.
Closes #37304 from EnricoMi/branch-pyspark-unpivot.
Authored-by: Enrico Minack <gi...@enrico.minack.dev>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/sql/dataframe.py | 134 +++++++++++++++++++
python/pyspark/sql/tests/test_dataframe.py | 144 +++++++++++++++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 11 ++
3 files changed, 289 insertions(+)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 481dafa310d..8c9632fe766 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2238,6 +2238,140 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return GroupedData(jgd, self)
+ def unpivot(
+ self,
+ ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+ values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+ variableColumnName: str,
+ valueColumnName: str,
+ ) -> "DataFrame":
+ """
+ Unpivot a DataFrame from wide format to long format, optionally leaving
+ identifier columns set. This is the reverse to `groupBy(...).pivot(...).agg(...)`,
+ except for the aggregation, which cannot be reversed.
+
+ This function is useful to massage a DataFrame into a format where some
+ columns are identifier columns ("ids"), while all other columns ("values")
+ are "unpivoted" to the rows, leaving just two non-id columns, named as given
+ by `variableColumnName` and `valueColumnName`.
+
+ When no "id" columns are given, the unpivoted DataFrame consists of only the
+ "variable" and "value" columns.
+
+ All "value" columns must share a least common data type. Unless they are the same data type,
+ all "value" columns are cast to the nearest common data type. For instance, types
+ `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType`
+ do not have a common data type and `unpivot` fails.
+
+ :func:`melt` is an alias for :func:`unpivot`.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ ids : str, Column, tuple, list, optional
+ Column(s) to use as identifiers. Can be a single column or column name,
+ or a list or tuple for multiple columns.
+ values : str, Column, tuple, list, optional
+ Column(s) to unpivot. Can be a single column or column name, or a list or tuple
+ for multiple columns. If not specified or empty, uses all columns that
+ are not set as `ids`.
+ variableColumnName : str
+ Name of the variable column.
+ valueColumnName : str
+ Name of the value column.
+
+ Returns
+ -------
+ DataFrame
+ Unpivoted DataFrame.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame(
+ ... [(1, 11, 1.1), (2, 12, 1.2)],
+ ... ["id", "int", "double"],
+ ... )
+ >>> df.show()
+ +---+---+------+
+ | id|int|double|
+ +---+---+------+
+ | 1| 11| 1.1|
+ | 2| 12| 1.2|
+ +---+---+------+
+
+ >>> df.unpivot("id", ["int", "double"], "var", "val").show()
+ +---+------+----+
+ | id| var| val|
+ +---+------+----+
+ | 1| int|11.0|
+ | 1|double| 1.1|
+ | 2| int|12.0|
+ | 2|double| 1.2|
+ +---+------+----+
+ """
+
+ def to_jcols(
+ cols: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]]
+ ) -> JavaObject:
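+ # Normalize cols: None -> [], tuple -> list, single column/name -> [col],
+ # then convert to a JVM Seq[Column] via self._jcols.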
+ if cols is None:
+ lst = []
+ elif isinstance(cols, tuple):
+ lst = list(cols)
+ elif isinstance(cols, list):
+ lst = cols
+ else:
+ lst = [cols]
+ return self._jcols(*lst)
+
+ return DataFrame(
+ self._jdf.unpivotWithSeq(
+ to_jcols(ids), to_jcols(values), variableColumnName, valueColumnName
+ ),
+ self.sparkSession,
+ )
+
+ def melt(
+ self,
+ ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+ values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+ variableColumnName: str,
+ valueColumnName: str,
+ ) -> "DataFrame":
+ """
+ Unpivot a DataFrame from wide format to long format, optionally leaving
+ identifier columns set. This is the reverse to `groupBy(...).pivot(...).agg(...)`,
+ except for the aggregation, which cannot be reversed.
+
+ :func:`melt` is an alias for :func:`unpivot`.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ ids : str, Column, tuple, list, optional
+ Column(s) to use as identifiers. Can be a single column or column name,
+ or a list or tuple for multiple columns.
+ values : str, Column, tuple, list, optional
+ Column(s) to unpivot. Can be a single column or column name, or a list or tuple
+ for multiple columns. If not specified or empty, uses all columns that
+ are not set as `ids`.
+ variableColumnName : str
+ Name of the variable column.
+ valueColumnName : str
+ Name of the value column.
+
+ Returns
+ -------
+ DataFrame
+ Unpivoted DataFrame.
+
+ See Also
+ --------
+ DataFrame.unpivot
+ """
+ return self.unpivot(ids, values, variableColumnName, valueColumnName)
+
def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
"""Aggregate on the entire :class:`DataFrame` without groups
(shorthand for ``df.groupBy().agg()``).
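A note on the `to_jcols` helper above: it lets `ids` and `values` arrive in several shapes, and the type coercion described in the docstring applies regardless of shape. A hedged sketch of equivalent calls against the docstring's example DataFrame (not part of the commit):

```python
# All of these are equivalent for the example df(id, int, double):
df.unpivot("id", ["int", "double"], "var", "val")     # column names in a list
df.unpivot(df.id, (df.int, df.double), "var", "val")  # Column objects in a tuple
df.unpivot("id", None, "var", "val")                  # None: all non-id columns
```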
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 7c7d3d1e51c..987ff91402d 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -534,6 +534,150 @@ class DataFrameTests(ReusedSQLTestCase):
self.assertEqual(1, logical_plan.toString().count("what"))
self.assertEqual(3, logical_plan.toString().count("itworks"))
+ def test_unpivot(self):
+ # SPARK-39877: test the DataFrame.unpivot method
+ df = self.spark.createDataFrame(
+ [
+ (1, 10, 1.0, "one"),
+ (2, 20, 2.0, "two"),
+ (3, 30, 3.0, "three"),
+ ],
+ ["id", "int", "double", "str"],
+ )
+
+ with self.subTest(desc="with no identifier and no value columns"):
+ # select only columns that have common data type (double)
+ actual = df.select("id", "int", "double").unpivot(
+ ids=None, values=None, variableColumnName="var", valueColumnName="val"
+ )
+ self.assertEqual(actual.schema.simpleString(), "struct<var:string,val:double>")
+ self.assertEqual(
+ actual.collect(),
+ [
+ Row(variable="id", value=1.0),
+ Row(variable="int", value=10.0),
+ Row(variable="double", value=1.0),
+ Row(variable="id", value=2.0),
+ Row(variable="int", value=20.0),
+ Row(variable="double", value=2.0),
+ Row(variable="id", value=3.0),
+ Row(variable="int", value=30.0),
+ Row(variable="double", value=3.0),
+ ],
+ )
+
+ with self.subTest(desc="with no identifier column and multiple value columns"):
+ for id in [None, [], ()]:
+ for values in [["int", "double"], ("int", "double")]:
+ with self.subTest(ids=id, values=values):
+ actual = df.unpivot(id, values, "var", "val")
+ self.assertEqual(
+ actual.schema.simpleString(), "struct<var:string,val:double>"
+ )
+ self.assertEqual(
+ actual.collect(),
+ [
+ Row(variable="int", value=10.0),
+ Row(variable="double", value=1.0),
+ Row(variable="int", value=20.0),
+ Row(variable="double", value=2.0),
+ Row(variable="int", value=30.0),
+ Row(variable="double", value=3.0),
+ ],
+ )
+
+ with self.subTest(desc="with single identifier column and multiple value columns"):
+ for id in ["id", ["id"], ("id",)]:
+ for values in [["int", "double"], ("int", "double")]:
+ with self.subTest(ids=id, values=values):
+ actual = df.unpivot(id, values, "var", "val")
+ self.assertEqual(
+ actual.schema.simpleString(),
+ "struct<id:bigint,var:string,val:double>",
+ )
+ self.assertEqual(
+ actual.collect(),
+ [
+ Row(id=1, variable="int", value=10.0),
+ Row(id=1, variable="double", value=1.0),
+ Row(id=2, variable="int", value=20.0),
+ Row(id=2, variable="double", value=2.0),
+ Row(id=3, variable="int", value=30.0),
+ Row(id=3, variable="double", value=3.0),
+ ],
+ )
+
+ with self.subTest(desc="with multiple identifier columns and a single given value column"):
+ for ids in [["id", "double"], ("id", "double")]:
+ for values in ["str", ["str"], ("str",)]:
+ with self.subTest(ids=ids, values=values):
+ actual = df.unpivot(ids, values, "var", "val")
+ self.assertEqual(
+ actual.schema.simpleString(),
+ "struct<id:bigint,double:double,var:string,val:string>",
+ )
+ self.assertEqual(
+ actual.collect(),
+ [
+ Row(id=1, double=1.0, variable="str", value="one"),
+ Row(id=2, double=2.0, variable="str", value="two"),
+ Row(id=3, double=3.0, variable="str", value="three"),
+ ],
+ )
+
+ with self.subTest(desc="with multiple identifier columns but no given value columns"):
+ for ids in [["id", "str"], ("id", "str")]:
+ for values in [None, [], ()]:
+ with self.subTest(ids=ids, values=values):
+ actual = df.unpivot(ids, values, "var", "val")
+ self.assertEqual(
+ actual.schema.simpleString(),
+ "struct<id:bigint,str:string,var:string,val:double>",
+ )
+ self.assertEqual(
+ actual.collect(),
+ [
+ Row(id=1, str="one", variable="int", value=10.0),
+ Row(id=1, str="one", variable="double", value=1.0),
+ Row(id=2, str="two", variable="int", value=20.0),
+ Row(id=2, str="two", variable="double", value=2.0),
+ Row(id=3, str="three", variable="int", value=30.0),
+ Row(id=3, str="three", variable="double", value=3.0),
+ ],
+ )
+
+ with self.subTest(desc="with value columns without common data type"):
+ with self.assertRaisesRegex(
+ AnalysisException,
+ r"\[UNPIVOT_VALUE_DATA_TYPE_MISMATCH\] Unpivot value columns must share "
+ r"a least common type, some types do not: .*",
+ ):
+ df.unpivot("id", ["int", "str"], "var", "val")
+
+ with self.subTest(desc="with columns"):
+ for id in [df.id, [df.id], (df.id,)]:
+ for values in [[df.int, df.double], (df.int, df.double)]:
+ with self.subTest(ids=id, values=values):
+ self.assertEqual(
+ df.unpivot(id, values, "var", "val").collect(),
+ df.unpivot("id", ["int", "double"], "var", "val").collect(),
+ )
+
+ with self.subTest(desc="with column names and columns"):
+ for ids in [[df.id, "str"], (df.id, "str")]:
+ for values in [[df.int, "double"], (df.int, "double")]:
+ with self.subTest(ids=ids, values=values):
+ self.assertEqual(
+ df.unpivot(ids, values, "var", "val").collect(),
+ df.unpivot(["id", "str"], ["int", "double"], "var", "val").collect(),
+ )
+
+ with self.subTest(desc="melt alias"):
+ self.assertEqual(
+ df.unpivot("id", ["int", "double"], "var", "val").collect(),
+ df.melt("id", ["int", "double"], "var", "val").collect(),
+ )
+
def test_observe(self):
# SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
from pyspark.sql import Observation
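The type-mismatch case in the tests above is worth calling out: `Dataset` analysis is eager, so the error surfaces as soon as the unpivoted DataFrame is constructed, with no action required. A hedged reproduction against the test fixture (not part of the commit):

```python
from pyspark.sql.utils import AnalysisException

try:
    # "int" (bigint) and "str" (string) share no least common type.
    df.unpivot("id", ["int", "str"], "var", "val")
except AnalysisException as e:
    print(e)  # [UNPIVOT_VALUE_DATA_TYPE_MISMATCH] ...
```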
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 4bc337e5af3..3f0cef33b5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2107,6 +2107,17 @@ class Dataset[T] private[sql](
valueColumnName: String): DataFrame =
unpivot(ids, Array.empty, variableColumnName, valueColumnName)
+ /**
+ * Called from Python, as a Seq[Column] is easier to create via py4j than an Array[Column].
+ * The public unpivot methods take Array[Column] rather than Seq[Column] because arrays are Java-friendly.
+ */
+ */
+ private[sql] def unpivotWithSeq(
+ ids: Seq[Column],
+ values: Seq[Column],
+ variableColumnName: String,
+ valueColumnName: String): DataFrame =
+ unpivot(ids.toArray, values.toArray, variableColumnName, valueColumnName)
+
/**
* Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
* This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
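On the Python side, the wrapper reaches this bridge through py4j. A rough sketch of the call it ends up making, using the internal `_jcols` helper referenced in the diff (internal API, shown only for orientation):

```python
# DataFrame._jcols(*cols) turns names/Columns into a JVM Seq[Column],
# which is what unpivotWithSeq expects on the Scala side.
jdf = df._jdf.unpivotWithSeq(
    df._jcols("id"), df._jcols("int", "double"), "var", "val"
)
```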