You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by pw...@apache.org on 2014/03/08 03:48:21 UTC

git commit: Spark 1165 rdd.intersection in python and java

Repository: spark
Updated Branches:
  refs/heads/master b7cd9e992 -> 6e730edcd


Spark 1165 rdd.intersection in python and java

Author: Prashant Sharma <pr...@imaginea.com>
Author: Prashant Sharma <sc...@gmail.com>

Closes #80 from ScrapCodes/SPARK-1165/RDD.intersection and squashes the following commits:

9b015e9 [Prashant Sharma] Added a note that a shuffle is required for intersection.
1fea813 [Prashant Sharma] correct the lines wrapping
d0c71f3 [Prashant Sharma] SPARK-1165 RDD.intersection in java
d6effee [Prashant Sharma] SPARK-1165 Implemented RDD.intersection in python.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6e730edc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6e730edc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6e730edc

Branch: refs/heads/master
Commit: 6e730edcde7ca6cbb5727dff7a42f7284b368528
Parents: b7cd9e9
Author: Prashant Sharma <pr...@imaginea.com>
Authored: Fri Mar 7 18:48:07 2014 -0800
Committer: Patrick Wendell <pw...@gmail.com>
Committed: Fri Mar 7 18:48:07 2014 -0800

----------------------------------------------------------------------
 .../apache/spark/api/java/JavaDoubleRDD.scala   |  8 +++++
 .../org/apache/spark/api/java/JavaPairRDD.scala | 10 +++++++
 .../org/apache/spark/api/java/JavaRDD.scala     |  9 ++++++
 .../java/org/apache/spark/JavaAPISuite.java     | 31 ++++++++++++++++++++
 python/pyspark/rdd.py                           | 17 +++++++++++
 5 files changed, 75 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6e730edc/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index d178706..f816bb4 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -140,6 +140,14 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
    */
   def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))
 
+  /**
+   * Return the intersection of this RDD and another one. The output will not contain any duplicate
+   * elements, even if the input RDDs did.
+   *
+   * Note that this method performs a shuffle internally.
+   */
+  def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd))
+
   // Double RDD functions
 
   /** Add up the elements in this RDD. */

http://git-wip-us.apache.org/repos/asf/spark/blob/6e730edc/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 857626f..0ff428c 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -126,6 +126,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
     new JavaPairRDD[K, V](rdd.union(other.rdd))
 
+  /**
+   * Return the intersection of this RDD and another one. The output will not contain any duplicate
+   * elements, even if the input RDDs did.
+   *
+   * Note that this method performs a shuffle internally.
+   */
+  def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
+    new JavaPairRDD[K, V](rdd.intersection(other.rdd))
+
+
   // first() has to be overridden here so that the generated method has the signature
   // 'public scala.Tuple2 first()'; if the trait's definition is used,
   // then the method has the signature 'public java.lang.Object first()',

http://git-wip-us.apache.org/repos/asf/spark/blob/6e730edc/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index e973c46..91bf404 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -106,6 +106,15 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
    */
   def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
 
+
+  /**
+   * Return the intersection of this RDD and another one. The output will not contain any duplicate
+   * elements, even if the input RDDs did.
+   *
+   * Note that this method performs a shuffle internally.
+   */
+  def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd))
+
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/6e730edc/core/src/test/java/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index c7d0e2d..40e853c 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -110,6 +110,37 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(4, pUnion.count());
   }
 
+  @SuppressWarnings("unchecked")
+  @Test
+  public void intersection() {
+    List<Integer> ints1 = Arrays.asList(1, 10, 2, 3, 4, 5);
+    List<Integer> ints2 = Arrays.asList(1, 6, 2, 3, 7, 8);
+    JavaRDD<Integer> s1 = sc.parallelize(ints1);
+    JavaRDD<Integer> s2 = sc.parallelize(ints2);
+
+    JavaRDD<Integer> intersections = s1.intersection(s2);
+    Assert.assertEquals(3, intersections.count());
+
+    ArrayList<Integer> list = new ArrayList<Integer>();
+    JavaRDD<Integer> empty = sc.parallelize(list);
+    JavaRDD<Integer> emptyIntersection = empty.intersection(s2);
+    Assert.assertEquals(0, emptyIntersection.count());
+
+    List<Double> doubles = Arrays.asList(1.0, 2.0);
+    JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles);
+    JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles);
+    JavaDoubleRDD dIntersection = d1.intersection(d2);
+    Assert.assertEquals(2, dIntersection.count());
+
+    List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+    pairs.add(new Tuple2<Integer, Integer>(1, 2));
+    pairs.add(new Tuple2<Integer, Integer>(3, 4));
+    JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs);
+    JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs);
+    JavaPairRDD<Integer, Integer> pIntersection = p1.intersection(p2);
+    Assert.assertEquals(2, pIntersection.count());
+  }
+
   @Test
   public void sortByKey() {
     List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();

http://git-wip-us.apache.org/repos/asf/spark/blob/6e730edc/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 097a0a2..e72f57d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -326,6 +326,23 @@ class RDD(object):
             return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
                        self.ctx.serializer)
 
+    def intersection(self, other):
+        """
+        Return the intersection of this RDD and another one. The output will not
+        contain any duplicate elements, even if the input RDDs did.
+
+        Note that this method performs a shuffle internally.
+
+        >>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5])
+        >>> rdd2 = sc.parallelize([1, 6, 2, 3, 7, 8])
+        >>> rdd1.intersection(rdd2).collect()
+        [1, 2, 3]
+        """
+        return self.map(lambda v: (v, None)) \
+            .cogroup(other.map(lambda v: (v, None))) \
+            .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \
+            .keys()
+
     def _reserialize(self):
         if self._jrdd_deserializer == self.ctx.serializer:
             return self