Posted to commits@spark.apache.org by me...@apache.org on 2015/01/29 02:26:06 UTC

spark git commit: [SPARK-5430] move treeReduce and treeAggregate from mllib to core

Repository: spark
Updated Branches:
  refs/heads/master e80dc1c5a -> 4ee79c71a


[SPARK-5430] move treeReduce and treeAggregate from mllib to core

We have seen many use cases of `treeAggregate`/`treeReduce` outside the ML domain. Maybe it is time to move them to Core. cc pwendell
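
For reference, a minimal usage sketch of the methods now available directly on core RDDs (it assumes an existing SparkContext `sc`; the data, depth values, and variable names are illustrative, not taken from the patch):

    val rdd = sc.parallelize(1 to 1000, 10)

    // Tree-shaped reduce with the default suggested depth of 2.
    val sum = rdd.treeReduce(_ + _)

    // Tree-shaped aggregate with an explicit depth of 3: count and sum in one pass.
    val (cnt, total) = rdd.treeAggregate((0L, 0L))(
      (acc, x) => (acc._1 + 1, acc._2 + x),
      (a, b) => (a._1 + b._1, a._2 + b._2),
      3)

Before this change, the same calls required importing org.apache.spark.mllib.rdd.RDDFunctions._.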

Author: Xiangrui Meng <me...@databricks.com>

Closes #4228 from mengxr/SPARK-5430 and squashes the following commits:

20ad40d [Xiangrui Meng] exclude tree* from mima
e89a43e [Xiangrui Meng] fix compile and update java doc
3ae1a4b [Xiangrui Meng] add treeReduce/treeAggregate to Python
6f948c5 [Xiangrui Meng] add treeReduce/treeAggregate to JavaRDDLike
d600b6c [Xiangrui Meng] move treeReduce and treeAggregate to core


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4ee79c71
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4ee79c71
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4ee79c71

Branch: refs/heads/master
Commit: 4ee79c71afc5175ba42b5e3d4088fe23db3e45d1
Parents: e80dc1c
Author: Xiangrui Meng <me...@databricks.com>
Authored: Wed Jan 28 17:26:03 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Jan 28 17:26:03 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/java/JavaRDDLike.scala | 37 ++++++++
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 63 ++++++++++++++
 .../java/org/apache/spark/JavaAPISuite.java     | 30 +++++++
 .../scala/org/apache/spark/rdd/RDDSuite.scala   | 19 ++++
 .../org/apache/spark/mllib/feature/IDF.scala    |  1 -
 .../spark/mllib/feature/StandardScaler.scala    |  1 -
 .../mllib/linalg/distributed/RowMatrix.scala    |  1 -
 .../mllib/optimization/GradientDescent.scala    |  1 -
 .../apache/spark/mllib/optimization/LBFGS.scala |  1 -
 .../apache/spark/mllib/rdd/RDDFunctions.scala   | 59 ++-----------
 .../mllib/feature/StandardScalerSuite.scala     |  1 -
 .../spark/mllib/rdd/RDDFunctionsSuite.scala     | 18 ----
 project/MimaExcludes.scala                      |  6 ++
 python/pyspark/rdd.py                           | 91 +++++++++++++++++++-
 14 files changed, 254 insertions(+), 75 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 62bf18d..0f91c94 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -349,6 +349,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
 
   /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]]
+   */
+  def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth)
+
+  /**
+   * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2.
+   */
+  def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2)
+
+  /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
    * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
    * modify t1 and return it as its result value to avoid object allocation; however, it should not
@@ -370,6 +383,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
 
   /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]]
+   */
+  def treeAggregate[U](
+      zeroValue: U,
+      seqOp: JFunction2[U, T, U],
+      combOp: JFunction2[U, U, U],
+      depth: Int): U = {
+    rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U])
+  }
+
+  /**
+   * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2.
+   */
+  def treeAggregate[U](
+      zeroValue: U,
+      seqOp: JFunction2[U, T, U],
+      combOp: JFunction2[U, U, U]): U = {
+    treeAggregate(zeroValue, seqOp, combOp, 2)
+  }
+
+  /**
    * Return the number of elements in the RDD.
    */
   def count(): Long = rdd.count()

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index ab7410a..5f39384 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -901,6 +901,38 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   */
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    val cleanF = context.clean(f)
+    val reducePartition: Iterator[T] => Option[T] = iter => {
+      if (iter.hasNext) {
+        Some(iter.reduceLeft(cleanF))
+      } else {
+        None
+      }
+    }
+    val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it)))
+    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+      if (c.isDefined && x.isDefined) {
+        Some(cleanF(c.get, x.get))
+      } else if (c.isDefined) {
+        c
+      } else if (x.isDefined) {
+        x
+      } else {
+        None
+      }
+    }
+    partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth)
+      .getOrElse(throw new UnsupportedOperationException("empty collection"))
+  }
+
+  /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
    * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
    * modify t1 and return it as its result value to avoid object allocation; however, it should not
@@ -936,6 +968,37 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
+  def treeAggregate[U: ClassTag](zeroValue: U)(
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int = 2): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    if (partitions.size == 0) {
+      return Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
+    }
+    val cleanSeqOp = context.clean(seqOp)
+    val cleanCombOp = context.clean(combOp)
+    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+    var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = partiallyAggregated.partitions.size
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
+    while (numPartitions > scale + numPartitions / scale) {
+      numPartitions /= scale
+      val curNumPartitions = numPartitions
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % curNumPartitions, _))
+      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+    }
+    partiallyAggregated.reduce(cleanCombOp)
+  }
+
+  /**
    * Return the number of elements in the RDD.
    */
   def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum
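
The stopping condition in the new RDD.treeAggregate above ("If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.") can be made concrete with a small, self-contained sketch that reproduces only the level-sizing arithmetic on plain integers. The helper name `levels` and the partition counts are illustrative and not part of the patch:

    // Mirrors the level-collapsing arithmetic in the new RDD.treeAggregate:
    // at each level, roughly every `scale` partial results are combined into one.
    def levels(initialPartitions: Int, depth: Int): Seq[Int] = {
      val scale = math.max(math.ceil(math.pow(initialPartitions, 1.0 / depth)).toInt, 2)
      var n = initialPartitions
      val sizes = scala.collection.mutable.ArrayBuffer(n)
      // Stop once an extra level would not shrink the number of partial results
      // enough to pay for the extra shuffle.
      while (n > scale + n / scale) {
        n /= scale
        sizes += n
      }
      sizes.toSeq
    }

    levels(100, 2)   // Seq(100, 10): one intermediate shuffle level, then a final reduce
    levels(1000, 3)  // Seq(1000, 100, 10): two intermediate levels
    levels(4, 2)     // Seq(4): 4 <= 2 + 2, so no intermediate level is worthwhile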

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/core/src/test/java/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 004de05..b16a1e9 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -492,6 +492,36 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(33, sum);
   }
 
+  @Test
+  public void treeReduce() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10);
+    Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer a, Integer b) {
+        return a + b;
+      }
+    };
+    for (int depth = 1; depth <= 10; depth++) {
+      int sum = rdd.treeReduce(add, depth);
+      Assert.assertEquals(-5, sum);
+    }
+  }
+
+  @Test
+  public void treeAggregate() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10);
+    Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer a, Integer b) {
+        return a + b;
+      }
+    };
+    for (int depth = 1; depth <= 10; depth++) {
+      int sum = rdd.treeAggregate(0, add, add, depth);
+      Assert.assertEquals(-5, sum);
+    }
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void aggregateByKey() {

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e33b4bb..bede1ff 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -157,6 +157,24 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
   }
 
+  test("treeAggregate") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    def seqOp = (c: Long, x: Int) => c + x
+    def combOp = (c1: Long, c2: Long) => c1 + c2
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
+      assert(sum === -1000L)
+    }
+  }
+
+  test("treeReduce") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeReduce(_ + _, depth)
+      assert(sum === -1000)
+    }
+  }
+
   test("basic caching") {
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
     assert(rdd.collect().toList === List(1, 2, 3, 4))
@@ -967,4 +985,5 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assertFails { sc.parallelize(1 to 100) }
     assertFails { sc.textFile("/nonexistent-path") }
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index 3260f27..a89eea0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -22,7 +22,6 @@ import breeze.linalg.{DenseVector => BDV}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd.RDD
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 3c20917..2f2c6f9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.feature
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 02075ed..ddca30c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -30,7 +30,6 @@ import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg._
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 0857877..4b7d058 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -25,7 +25,6 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
-import org.apache.spark.mllib.rdd.RDDFunctions._
 
 /**
  * Class used to solve an optimization problem using Gradient Descent.

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index d16d0da..d5e4f4c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -26,7 +26,6 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.axpy
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd.RDD
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
index 57c0768..7817284 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -21,10 +21,7 @@ import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.HashPartitioner
-import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.Utils
 
 /**
  * Machine learning specific RDD functions.
@@ -53,63 +50,25 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
    * Reduces the elements of this RDD in a multi-level tree pattern.
    *
    * @param depth suggested depth of the tree (default: 2)
-   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   * @see [[org.apache.spark.rdd.RDD#treeReduce]]
+   * @deprecated Use [[org.apache.spark.rdd.RDD#treeReduce]] instead.
    */
-  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
-    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
-    val cleanF = self.context.clean(f)
-    val reducePartition: Iterator[T] => Option[T] = iter => {
-      if (iter.hasNext) {
-        Some(iter.reduceLeft(cleanF))
-      } else {
-        None
-      }
-    }
-    val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it)))
-    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
-      if (c.isDefined && x.isDefined) {
-        Some(cleanF(c.get, x.get))
-      } else if (c.isDefined) {
-        c
-      } else if (x.isDefined) {
-        x
-      } else {
-        None
-      }
-    }
-    RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth)
-      .getOrElse(throw new UnsupportedOperationException("empty collection"))
-  }
+  @deprecated("Use RDD.treeReduce instead.", "1.3.0")
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = self.treeReduce(f, depth)
 
   /**
    * Aggregates the elements of this RDD in a multi-level tree pattern.
    *
    * @param depth suggested depth of the tree (default: 2)
-   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   * @see [[org.apache.spark.rdd.RDD#treeAggregate]]
+   * @deprecated Use [[org.apache.spark.rdd.RDD#treeAggregate]] instead.
    */
+  @deprecated("Use RDD.treeAggregate instead.", "1.3.0")
   def treeAggregate[U: ClassTag](zeroValue: U)(
       seqOp: (U, T) => U,
       combOp: (U, U) => U,
       depth: Int = 2): U = {
-    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
-    if (self.partitions.size == 0) {
-      return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
-    }
-    val cleanSeqOp = self.context.clean(seqOp)
-    val cleanCombOp = self.context.clean(combOp)
-    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
-    var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it)))
-    var numPartitions = partiallyAggregated.partitions.size
-    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
-    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
-    while (numPartitions > scale + numPartitions / scale) {
-      numPartitions /= scale
-      val curNumPartitions = numPartitions
-      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
-        iter.map((i % curNumPartitions, _))
-      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
-    }
-    partiallyAggregated.reduce(cleanCombOp)
+    self.treeAggregate(zeroValue)(seqOp, combOp, depth)
   }
 }
 
@@ -117,5 +76,5 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
 object RDDFunctions {
 
   /** Implicit conversion from an RDD to RDDFunctions. */
-  implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd)
+  implicit def fromRDD[T: ClassTag](rdd: RDD[T]): RDDFunctions[T] = new RDDFunctions[T](rdd)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index 4c93c0c..e9e510b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -22,7 +22,6 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
 import org.apache.spark.rdd.RDD
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 681ce92..6d6c0aa 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -46,22 +46,4 @@ class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
     val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq)
     assert(sliding === expected)
   }
-
-  test("treeAggregate") {
-    val rdd = sc.makeRDD(-1000 until 1000, 10)
-    def seqOp = (c: Long, x: Int) => c + x
-    def combOp = (c1: Long, c2: Long) => c1 + c2
-    for (depth <- 1 until 10) {
-      val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
-      assert(sum === -1000L)
-    }
-  }
-
-  test("treeReduce") {
-    val rdd = sc.makeRDD(-1000 until 1000, 10)
-    for (depth <- 1 until 10) {
-      val sum = rdd.treeReduce(_ + _, depth)
-      assert(sum === -1000)
-    }
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index e750fed..14ba03e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -114,6 +114,12 @@ object MimaExcludes {
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.api.java.JavaRDDLike.isEmpty")
           ) ++ Seq(
+            // SPARK-5430
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.java.JavaRDDLike.treeReduce"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.java.JavaRDDLike.treeAggregate")
+          ) ++ Seq(
             // SPARK-5297 Java FileStream do not work with custom key/values
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream")

http://git-wip-us.apache.org/repos/asf/spark/blob/4ee79c71/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b6dd5a3..2f8a0ed 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ import warnings
 import heapq
 import bisect
 import random
-from math import sqrt, log, isinf, isnan
+from math import sqrt, log, isinf, isnan, pow, ceil
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -726,6 +726,43 @@ class RDD(object):
             return reduce(f, vals)
         raise ValueError("Can not reduce() empty RDD")
 
+    def treeReduce(self, f, depth=2):
+        """
+        Reduces the elements of this RDD in a multi-level tree pattern.
+
+        :param depth: suggested depth of the tree (default: 2)
+
+        >>> add = lambda x, y: x + y
+        >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+        >>> rdd.treeReduce(add)
+        -5
+        >>> rdd.treeReduce(add, 1)
+        -5
+        >>> rdd.treeReduce(add, 2)
+        -5
+        >>> rdd.treeReduce(add, 5)
+        -5
+        >>> rdd.treeReduce(add, 10)
+        -5
+        """
+        if depth < 1:
+            raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+        zeroValue = None, True  # Use the second entry to indicate whether this is a dummy value.
+
+        def op(x, y):
+            if x[1]:
+                return y
+            elif y[1]:
+                return x
+            else:
+                return f(x[0], y[0]), False
+
+        reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth)
+        if reduced[1]:
+            raise ValueError("Cannot reduce empty RDD.")
+        return reduced[0]
+
     def fold(self, zeroValue, op):
         """
         Aggregate the elements of each partition, and then the results for all
@@ -777,6 +814,58 @@ class RDD(object):
 
         return self.mapPartitions(func).fold(zeroValue, combOp)
 
+    def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
+        """
+        Aggregates the elements of this RDD in a multi-level tree
+        pattern.
+
+        :param depth: suggested depth of the tree (default: 2)
+
+        >>> add = lambda x, y: x + y
+        >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+        >>> rdd.treeAggregate(0, add, add)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 1)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 2)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 5)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 10)
+        -5
+        """
+        if depth < 1:
+            raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+        if self.getNumPartitions() == 0:
+            return zeroValue
+
+        def aggregatePartition(iterator):
+            acc = zeroValue
+            for obj in iterator:
+                acc = seqOp(acc, obj)
+            yield acc
+
+        partiallyAggregated = self.mapPartitions(aggregatePartition)
+        numPartitions = partiallyAggregated.getNumPartitions()
+        scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2)
+        # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree
+        # aggregation.
+        while numPartitions > scale + numPartitions / scale:
+            numPartitions /= scale
+            curNumPartitions = numPartitions
+
+            def mapPartition(i, iterator):
+                for obj in iterator:
+                    yield (i % curNumPartitions, obj)
+
+            partiallyAggregated = partiallyAggregated \
+                .mapPartitionsWithIndex(mapPartition) \
+                .reduceByKey(combOp, curNumPartitions) \
+                .values()
+
+        return partiallyAggregated.reduce(combOp)
+
     def max(self, key=None):
         """
         Find the maximum item in this RDD.

