Posted to commits@spark.apache.org by an...@apache.org on 2014/10/21 22:15:36 UTC

git commit: [SPARK-3994] Use standard Aggregator code path for countByKey and countByValue

Repository: spark
Updated Branches:
  refs/heads/master 1a623b2e1 -> 5fdaf52a9


[SPARK-3994] Use standard Aggregator code path for countByKey and countByValue

See [JIRA](https://issues.apache.org/jira/browse/SPARK-3994) for more information. Also adds
a note which warns against using these methods.

Author: Aaron Davidson <aa...@databricks.com>

Closes #2839 from aarondav/countByKey and squashes the following commits:

d6fdb2a [Aaron Davidson] Respond to comments
e1f06d3 [Aaron Davidson] [SPARK-3994] Use standard Aggregator code path for countByKey and countByValue


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5fdaf52a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5fdaf52a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5fdaf52a

Branch: refs/heads/master
Commit: 5fdaf52a9df21cac69e2a4612aeb4e760e4424e7
Parents: 1a623b2
Author: Aaron Davidson <aa...@databricks.com>
Authored: Tue Oct 21 13:15:29 2014 -0700
Committer: Andrew Or <an...@gmail.com>
Committed: Tue Oct 21 13:15:29 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/rdd/PairRDDFunctions.scala | 11 +++++--
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 31 +++++---------------
 2 files changed, 16 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5fdaf52a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index ac96de8..da89f63 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -315,8 +315,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
   @deprecated("Use reduceByKeyLocally", "1.0.0")
   def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
 
-  /** Count the number of elements for each key, and return the result to the master as a Map. */
-  def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
+  /** 
+   * Count the number of elements for each key, collecting the results to a local Map.
+   *
+   * Note that this method should only be used if the resulting map is expected to be small, as
+   * the whole thing is loaded into the driver's memory.
+   * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which
+   * returns an RDD[(K, Long)] instead of a map.
+   */
+  def countByKey(): Map[K, Long] = self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap
 
   /**
    * :: Experimental ::
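
Not part of the commit itself, but for orientation, a minimal usage sketch of the new countByKey() path next to the distributed alternative the scaladoc recommends. The object name, local[*] master, and toy data below are illustrative assumptions, not anything from the patch:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits, required in Spark releases of this era

object CountByKeySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("countByKey-sketch").setMaster("local[*]"))

    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

    // Small key space: fine to collect the (key, count) map into driver memory.
    val localCounts = pairs.countByKey() // Map(a -> 2, b -> 1)

    // Very large key space: keep the counts distributed instead, as the new scaladoc suggests.
    val distributedCounts = pairs.mapValues(_ => 1L).reduceByKey(_ + _) // RDD[(String, Long)]

    println(localCounts)
    println(distributedCounts.collect().toSeq)

    sc.stop()
  }
}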

http://git-wip-us.apache.org/repos/asf/spark/blob/5fdaf52a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 71cabf6..b7f125d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -927,32 +927,15 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
-   * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
-   * combine step happens locally on the master, equivalent to running a single reduce task.
+   * Return the count of each unique value in this RDD as a local map of (value, count) pairs.
+   *
+   * Note that this method should only be used if the resulting map is expected to be small, as
+   * the whole thing is loaded into the driver's memory.
+   * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which
+   * returns an RDD[(T, Long)] instead of a map.
    */
   def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = {
-    if (elementClassTag.runtimeClass.isArray) {
-      throw new SparkException("countByValue() does not support arrays")
-    }
-    // TODO: This should perhaps be distributed by default.
-    val countPartition = (iter: Iterator[T]) => {
-      val map = new OpenHashMap[T,Long]
-      iter.foreach {
-        t => map.changeValue(t, 1L, _ + 1L)
-      }
-      Iterator(map)
-    }: Iterator[OpenHashMap[T,Long]]
-    val mergeMaps = (m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]) => {
-      m2.foreach { case (key, value) =>
-        m1.changeValue(key, value, _ + value)
-      }
-      m1
-    }: OpenHashMap[T,Long]
-    val myResult = mapPartitions(countPartition).reduce(mergeMaps)
-    // Convert to a Scala mutable map
-    val mutableResult = scala.collection.mutable.Map[T,Long]()
-    myResult.foreach { case (k, v) => mutableResult.put(k, v) }
-    mutableResult
+    map(value => (value, null)).countByKey()
   }
 
   /**
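
Likewise, a small sketch of the countByValue() side after this change, under the same illustrative assumptions (object name, local[*] master, toy data are placeholders, not from the patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits, required in Spark releases of this era

object CountByValueSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("countByValue-sketch").setMaster("local[*]"))

    val words = sc.parallelize(Seq("spark", "rdd", "spark"))

    // countByValue() now keys each element by itself and delegates to countByKey(),
    // i.e. the standard Aggregator/shuffle code path, then collects to the driver.
    val localCounts = words.countByValue() // Map(spark -> 2, rdd -> 1)

    // The scaladoc's recommendation when the result is too large for the driver:
    val distributedCounts = words.map(x => (x, 1L)).reduceByKey(_ + _) // RDD[(String, Long)]

    println(localCounts)
    println(distributedCounts.collect().toSeq)

    sc.stop()
  }
}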


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org