You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this copy.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/10 04:11:58 UTC

[spark] branch master updated: [SPARK-41010][CONNECT][PYTHON] Complete Support for Except and Intersect in Python client

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9dc39e199de [SPARK-41010][CONNECT][PYTHON] Complete Support for Except and Intersect in Python client
9dc39e199de is described below

commit 9dc39e199de645f60e115267fba2fae782ab53f1
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Thu Nov 10 12:11:40 2022 +0800

    [SPARK-41010][CONNECT][PYTHON] Complete Support for Except and Intersect in Python client
    
    ### What changes were proposed in this pull request?
    
    1. Add support for intersect and except.
    2. Unify union, intersect and except into `SetOperation`.
    
    ### Why are the changes needed?
    
    Improve API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests (UT) in `python/pyspark/sql/tests/connect/test_connect_plan_only.py`.
    
    Closes #38506 from amaliujia/except_python.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py            | 82 +++++++++++++++++++++-
 python/pyspark/sql/connect/plan.py                 | 38 +++++++---
 .../sql/tests/connect/test_connect_plan_only.py    | 22 ++++++
 3 files changed, 132 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c6877707ad2..ccd826cd476 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -389,7 +389,9 @@ class DataFrame(object):
     def unionAll(self, other: "DataFrame") -> "DataFrame":
         if other._plan is None:
             raise ValueError("Argument to Union does not contain a valid plan.")
-        return DataFrame.withPlan(plan.UnionAll(self._plan, other._plan), session=self._session)
+        return DataFrame.withPlan(
+            plan.SetOperation(self._plan, other._plan, "union", is_all=True), session=self._session
+        )
 
     def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame":
         """Returns a new :class:`DataFrame` containing union of rows in this and another
@@ -415,7 +417,83 @@ class DataFrame(object):
         if other._plan is None:
             raise ValueError("Argument to UnionByName does not contain a valid plan.")
         return DataFrame.withPlan(
-            plan.UnionAll(self._plan, other._plan, allowMissingColumns), session=self._session
+            plan.SetOperation(
+                self._plan, other._plan, "union", is_all=True, by_name=allowMissingColumns
+            ),
+            session=self._session,
+        )
+
+    def exceptAll(self, other: "DataFrame") -> "DataFrame":
+        """Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but
+        not in another :class:`DataFrame` while preserving duplicates.
+
+        This is equivalent to `EXCEPT ALL` in SQL.
+        As standard in SQL, this function resolves columns by position (not by name).
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        other : :class:`DataFrame`
+            The other :class:`DataFrame` to compare to.
+
+        Returns
+        -------
+        :class:`DataFrame`
+        """
+        return DataFrame.withPlan(
+            plan.SetOperation(self._plan, other._plan, "except", is_all=True), session=self._session
+        )
+
+    def intersect(self, other: "DataFrame") -> "DataFrame":
+        """Return a new :class:`DataFrame` containing rows only in
+        both this :class:`DataFrame` and another :class:`DataFrame`.
+        Note that any duplicates are removed. To preserve duplicates
+        use :func:`intersectAll`.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        other : :class:`DataFrame`
+            Another :class:`DataFrame` that needs to be combined.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            Combined DataFrame.
+
+        Notes
+        -----
+        This is equivalent to `INTERSECT` in SQL.
+        """
+        return DataFrame.withPlan(
+            plan.SetOperation(self._plan, other._plan, "intersect", is_all=False),
+            session=self._session,
+        )
+
+    def intersectAll(self, other: "DataFrame") -> "DataFrame":
+        """Return a new :class:`DataFrame` containing rows in both this :class:`DataFrame`
+        and another :class:`DataFrame` while preserving duplicates.
+
+        This is equivalent to `INTERSECT ALL` in SQL. As standard in SQL, this function
+        resolves columns by position (not by name).
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        other : :class:`DataFrame`
+            Another :class:`DataFrame` that needs to be combined.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            Combined DataFrame.
+        """
+        return DataFrame.withPlan(
+            plan.SetOperation(self._plan, other._plan, "intersect", is_all=True),
+            session=self._session,
         )
 
     def where(self, condition: Expression) -> "DataFrame":
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 3bb5558d04b..acc5927b519 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -607,21 +607,43 @@ class Join(LogicalPlan):
         """
 
 
-class UnionAll(LogicalPlan):
+class SetOperation(LogicalPlan):
     def __init__(
-        self, child: Optional["LogicalPlan"], other: "LogicalPlan", by_name: bool = False
+        self,
+        child: Optional["LogicalPlan"],
+        other: Optional["LogicalPlan"],
+        set_op: str,
+        is_all: bool = True,
+        by_name: bool = False,
     ) -> None:
         super().__init__(child)
         self.other = other
         self.by_name = by_name
+        self.is_all = is_all
+        self.set_op = set_op
 
     def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
         assert self._child is not None
         rel = proto.Relation()
-        rel.set_op.left_input.CopyFrom(self._child.plan(session))
-        rel.set_op.right_input.CopyFrom(self.other.plan(session))
-        rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_UNION
-        rel.set_op.is_all = True
+        if self._child is not None:
+            rel.set_op.left_input.CopyFrom(self._child.plan(session))
+        if self.other is not None:
+            rel.set_op.right_input.CopyFrom(self.other.plan(session))
+        if self.set_op == "union":
+            rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_UNION
+        elif self.set_op == "intersect":
+            rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_INTERSECT
+        elif self.set_op == "except":
+            rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_EXCEPT
+        else:
+            raise NotImplementedError(
+                """
+                Unsupported set operation type: %s.
+                """
+                % rel.set_op.set_op_type
+            )
+
+        rel.set_op.is_all = self.is_all
         rel.set_op.by_name = self.by_name
         return rel
 
@@ -633,7 +655,7 @@ class UnionAll(LogicalPlan):
         o = " " * (indent + LogicalPlan.INDENT)
         n = indent + LogicalPlan.INDENT * 2
         return (
-            f"{i}UnionAll\n{o}child1=\n{self._child.print(n)}"
+            f"{i}SetOperation\n{o}child1=\n{self._child.print(n)}"
             f"\n{o}child2=\n{self.other.print(n)}"
         )
 
@@ -644,7 +666,7 @@ class UnionAll(LogicalPlan):
         return f"""
         <ul>
             <li>
-                <b>Union</b><br />
+                <b>SetOperation</b><br />
                 Left: {self._child._repr_html_()}
                 Right: {self.other._repr_html_()}
             </li>
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 6807a13a8c9..adfaa651c08 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -210,10 +210,32 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         df2 = self.connect.readTable(table_name=self.tbl_name)
         plan1 = df1.union(df2)._plan.to_proto(self.connect)
         self.assertTrue(plan1.root.set_op.is_all)
+        self.assertEqual(proto.SetOperation.SET_OP_TYPE_UNION, plan1.root.set_op.set_op_type)
         plan2 = df1.union(df2)._plan.to_proto(self.connect)
         self.assertTrue(plan2.root.set_op.is_all)
+        self.assertEqual(proto.SetOperation.SET_OP_TYPE_UNION, plan2.root.set_op.set_op_type)
         plan3 = df1.unionByName(df2, True)._plan.to_proto(self.connect)
         self.assertTrue(plan3.root.set_op.by_name)
+        self.assertEqual(proto.SetOperation.SET_OP_TYPE_UNION, plan3.root.set_op.set_op_type)
+
+    def test_except(self):
+        # SPARK-41010: test `except` API for Python client.
+        df1 = self.connect.readTable(table_name=self.tbl_name)
+        df2 = self.connect.readTable(table_name=self.tbl_name)
+        plan1 = df1.exceptAll(df2)._plan.to_proto(self.connect)
+        self.assertTrue(plan1.root.set_op.is_all)
+        self.assertEqual(proto.SetOperation.SET_OP_TYPE_EXCEPT, plan1.root.set_op.set_op_type)
+
+    def test_intersect(self):
+        # SPARK-41010: test `intersect` API for Python client.
+        df1 = self.connect.readTable(table_name=self.tbl_name)
+        df2 = self.connect.readTable(table_name=self.tbl_name)
+        plan1 = df1.intersect(df2)._plan.to_proto(self.connect)
+        self.assertFalse(plan1.root.set_op.is_all)
+        self.assertEqual(proto.SetOperation.SET_OP_TYPE_INTERSECT, plan1.root.set_op.set_op_type)
+        plan2 = df1.intersectAll(df2)._plan.to_proto(self.connect)
+        self.assertTrue(plan2.root.set_op.is_all)
+        self.assertEqual(proto.SetOperation.SET_OP_TYPE_INTERSECT, plan2.root.set_op.set_op_type)
 
     def test_coalesce_and_repartition(self):
         # SPARK-41026: test Coalesce and Repartition API in Python client.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org