Posted to commits@spark.apache.org by rx...@apache.org on 2016/10/15 01:24:44 UTC

spark git commit: [SPARK-17946][PYSPARK] Python crossJoin API similar to Scala

Repository: spark
Updated Branches:
  refs/heads/master 72adfbf94 -> 2d96d35dc


[SPARK-17946][PYSPARK] Python crossJoin API similar to Scala

## What changes were proposed in this pull request?

Add a crossJoin function to the DataFrame API, similar to the one in Scala. Joins with no join condition (cartesian products) must now be specified with the crossJoin API.
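
For illustration, here is a minimal sketch of the new API. The Spark session and the sample DataFrames (`people`, `heights`) below are illustrative assumptions, not part of this patch:

```python
from pyspark.sql import SparkSession

# Illustrative session and data, not part of the patch itself.
spark = SparkSession.builder.master("local[2]").appName("crossJoinDemo").getOrCreate()

people = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
heights = spark.createDataFrame([(80,), (85,)], ["height"])

# Cartesian product: each row of `people` paired with each row of
# `heights`, yielding 2 x 2 = 4 rows.
people.crossJoin(heights).show()
```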

## How was this patch tested?
Added Python tests to ensure that an AnalysisException is raised if a cartesian product is specified without crossJoin(), and that cartesian products execute successfully when specified via crossJoin().
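
The checked behavior can be sketched as follows (mirroring the new test_require_cross test in the diff below; this assumes the default spark.sql.crossJoin.enabled=false and reuses the `spark` session from the sketch above):

```python
from pyspark.sql.utils import AnalysisException

df1 = spark.createDataFrame([(1, "1")], ("key", "value"))
df2 = spark.createDataFrame([(1, "1")], ("key", "value"))

# A join with no condition is now rejected when the query executes...
try:
    df1.join(df2).collect()
except AnalysisException as err:
    print("cartesian product rejected:", err)

# ...while the explicit crossJoin API succeeds.
assert df1.crossJoin(df2).count() == 1
```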


Author: Srinath Shankar <sr...@databricks.com>

Closes #15493 from srinathshankar/crosspython.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2d96d35d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2d96d35d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2d96d35d

Branch: refs/heads/master
Commit: 2d96d35dc0fed6df249606d9ce9272c0f0109fa2
Parents: 72adfbf
Author: Srinath Shankar <sr...@databricks.com>
Authored: Fri Oct 14 18:24:47 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Fri Oct 14 18:24:47 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 26 ++++++++++++++++----
 python/pyspark/sql/tests.py                     | 15 ++++++++++-
 .../scala/org/apache/spark/sql/Dataset.scala    |  2 +-
 3 files changed, 36 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2d96d35d/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7606ac0..29710ac 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -651,6 +651,25 @@ class DataFrame(object):
         return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
 
     @ignore_unicode_prefix
+    @since(2.1)
+    def crossJoin(self, other):
+        """Returns the cartesian product with another :class:`DataFrame`.
+
+        :param other: Right side of the cartesian product.
+
+        >>> df.select("age", "name").collect()
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+        >>> df2.select("name", "height").collect()
+        [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85)]
+        >>> df.crossJoin(df2.select("height")).select("age", "name", "height").collect()
+        [Row(age=2, name=u'Alice', height=80), Row(age=2, name=u'Alice', height=85),
+         Row(age=5, name=u'Bob', height=80), Row(age=5, name=u'Bob', height=85)]
+        """
+
+        jdf = self._jdf.crossJoin(other._jdf)
+        return DataFrame(jdf, self.sql_ctx)
+
+    @ignore_unicode_prefix
     @since(1.3)
     def join(self, other, on=None, how=None):
         """Joins with another :class:`DataFrame`, using the given join expression.
@@ -690,14 +709,11 @@ class DataFrame(object):
                 on = self._jseq(on)
             else:
                 assert isinstance(on[0], Column), "on should be Column or list of Column"
-                if len(on) > 1:
-                    on = reduce(lambda x, y: x.__and__(y), on)
-                else:
-                    on = on[0]
+                on = reduce(lambda x, y: x.__and__(y), on)
                 on = on._jc
 
         if on is None and how is None:
-            jdf = self._jdf.crossJoin(other._jdf)
+            jdf = self._jdf.join(other._jdf)
         else:
             if how is None:
                 how = "inner"

http://git-wip-us.apache.org/repos/asf/spark/blob/2d96d35d/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 51d5e7a..3d46b85 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1466,7 +1466,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(1, plan1.toString().count("BroadcastHashJoin"))
 
         # no join key -- should not be a broadcast join
-        plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan()
+        plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan()
         self.assertEqual(0, plan2.toString().count("BroadcastHashJoin"))
 
         # planner should not crash without a join
@@ -1514,6 +1514,19 @@ class SQLTests(ReusedPySparkTestCase):
         df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"])
         self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type"))
 
+    # Cartesian products require cross join syntax
+    def test_require_cross(self):
+        from pyspark.sql.functions import broadcast
+
+        df1 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
+        df2 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
+
+        # joins without conditions require cross join syntax
+        self.assertRaises(AnalysisException, lambda: df1.join(df2).collect())
+
+        # works with crossJoin
+        self.assertEqual(1, df1.crossJoin(df2).count())
+
     def test_conf(self):
         spark = self.spark
         spark.conf.set("bogo", "sipeo")

http://git-wip-us.apache.org/repos/asf/spark/blob/2d96d35d/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 70c9cf5..7ae3275 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -774,7 +774,7 @@ class Dataset[T] private[sql](
    * @param right Right side of the join operation.
    *
    * @group untypedrel
-   * @since 2.0.0
+   * @since 2.1.0
    */
   def crossJoin(right: Dataset[_]): DataFrame = withPlan {
     Join(logicalPlan, right.logicalPlan, joinType = Cross, None)

