You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/06/09 08:27:09 UTC

spark git commit: [SPARK-7990][SQL] Add methods to facilitate equi-join on multiple joining keys

Repository: spark
Updated Branches:
  refs/heads/master a5c52c1a3 -> 7658eb28a


[SPARK-7990][SQL] Add methods to facilitate equi-join on multiple joining keys

JIRA: https://issues.apache.org/jira/browse/SPARK-7990

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #6616 from viirya/multi_keys_equi_join and squashes the following commits:

cd5c888 [Liang-Chi Hsieh] Import reduce in python3.
c43722c [Liang-Chi Hsieh] For comments.
0400e89 [Liang-Chi Hsieh] Fix scala style.
cc90015 [Liang-Chi Hsieh] Add methods to facilitate equi-join on multiple joining keys.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7658eb28
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7658eb28
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7658eb28

Branch: refs/heads/master
Commit: 7658eb28a2ea28c06e3b5a26f7734a7dc36edc19
Parents: a5c52c1
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Mon Jun 8 23:27:05 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon Jun 8 23:27:05 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 45 ++++++++++++++------
 .../scala/org/apache/spark/sql/DataFrame.scala  | 40 ++++++++++++++---
 .../apache/spark/sql/DataFrameJoinSuite.scala   |  9 ++++
 3 files changed, 75 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7658eb28/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2d8c595..e9dd05e 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -22,6 +22,7 @@ import random
 if sys.version >= '3':
     basestring = unicode = str
     long = int
+    from functools import reduce
 else:
     from itertools import imap as map
 
@@ -503,36 +504,52 @@ class DataFrame(object):
 
     @ignore_unicode_prefix
     @since(1.3)
-    def join(self, other, joinExprs=None, joinType=None):
+    def join(self, other, on=None, how=None):
         """Joins with another :class:`DataFrame`, using the given join expression.
 
         The following performs a full outer join between ``df1`` and ``df2``.
 
         :param other: Right side of the join
-        :param joinExprs: a string for join column name, or a join expression (Column).
-            If joinExprs is a string indicating the name of the join column,
-            the column must exist on both sides, and this performs an inner equi-join.
-        :param joinType: str, default 'inner'.
+        :param on: a string for join column name, a list of column names,
+            a join expression (Column), or a list of Columns.
+            If `on` is a string or a list of strings indicating the name of the join column(s),
+            the column(s) must exist on both sides, and this performs an inner equi-join.
+        :param how: str, default 'inner'.
             One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
 
         >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
         [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
 
+        >>> cond = [df.name == df3.name, df.age == df3.age]
+        >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
+        [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)]
+
         >>> df.join(df2, 'name').select(df.name, df2.height).collect()
         [Row(name=u'Bob', height=85)]
+
+        >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
+        [Row(name=u'Bob', age=5)]
         """
 
-        if joinExprs is None:
+        if on is not None and not isinstance(on, list):
+            on = [on]
+
+        if on is None or len(on) == 0:
             jdf = self._jdf.join(other._jdf)
-        elif isinstance(joinExprs, basestring):
-            jdf = self._jdf.join(other._jdf, joinExprs)
+
+        elif isinstance(on[0], basestring):
+            jdf = self._jdf.join(other._jdf, self._jseq(on))
         else:
-            assert isinstance(joinExprs, Column), "joinExprs should be Column"
-            if joinType is None:
-                jdf = self._jdf.join(other._jdf, joinExprs._jc)
+            assert isinstance(on[0], Column), "on should be Column or list of Column"
+            if len(on) > 1:
+                on = reduce(lambda x, y: x.__and__(y), on)
+            else:
+                on = on[0]
+            if how is None:
+                jdf = self._jdf.join(other._jdf, on._jc, "inner")
             else:
-                assert isinstance(joinType, basestring), "joinType should be basestring"
-                jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
+                assert isinstance(how, basestring), "how should be basestring"
+                jdf = self._jdf.join(other._jdf, on._jc, how)
         return DataFrame(jdf, self.sql_ctx)
 
     @ignore_unicode_prefix
@@ -1315,6 +1332,8 @@ def _test():
         .toDF(StructType([StructField('age', IntegerType()),
                           StructField('name', StringType())]))
     globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+    globs['df3'] = sc.parallelize([Row(name='Alice', age=2),
+                                   Row(name='Bob', age=5)]).toDF()
     globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
                                   Row(name='Bob', age=5, height=None),
                                   Row(name='Tom', age=None, height=None),

http://git-wip-us.apache.org/repos/asf/spark/blob/7658eb28/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 4a22415..59f64dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -395,22 +395,50 @@ class DataFrame private[sql](
    * @since 1.4.0
    */
   def join(right: DataFrame, usingColumn: String): DataFrame = {
+    join(right, Seq(usingColumn))
+  }
+
+  /**
+   * Inner equi-join with another [[DataFrame]] using the given columns.
+   *
+   * Different from other join functions, the join columns will only appear once in the output,
+   * i.e. similar to SQL's `JOIN USING` syntax.
+   *
+   * {{{
+   *   // Joining df1 and df2 using the columns "user_id" and "user_name"
+   *   df1.join(df2, Seq("user_id", "user_name"))
+   * }}}
+   *
+   * Note that if you perform a self-join using this function without aliasing the input
+   * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since
+   * there is no way to disambiguate which side of the join you would like to reference.
+   *
+   * @param right Right side of the join operation.
+   * @param usingColumns Names of the columns to join on. These columns must exist on both sides.
+   * @group dfops
+   * @since 1.4.0
+   */
+  def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = {
     // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
     // by creating a new instance for one of the branch.
     val joined = sqlContext.executePlan(
       Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join]
 
-    // Project only one of the join column.
-    val joinedCol = joined.right.resolve(usingColumn)
+    // Project only one of the join columns.
+    val joinedCols = usingColumns.map(col => joined.right.resolve(col))
+    val condition = usingColumns.map { col =>
+      catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col))
+    }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) =>
+      catalyst.expressions.And(cond, eqTo)
+    }
+
     Project(
-      joined.output.filterNot(_ == joinedCol),
+      joined.output.filterNot(joinedCols.contains(_)),
       Join(
         joined.left,
         joined.right,
         joinType = Inner,
-        Some(catalyst.expressions.EqualTo(
-          joined.left.resolve(usingColumn),
-          joined.right.resolve(usingColumn))))
+        condition)
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7658eb28/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 051d13e..6165764 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -34,6 +34,15 @@ class DataFrameJoinSuite extends QueryTest {
       Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil)
   }
 
+  test("join - join using multiple columns") {
+    val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str")
+    val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str")
+
+    checkAnswer(
+      df.join(df2, Seq("int", "int2")),
+      Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil)
+  }
+
   test("join - join using self join") {
     val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org