Posted to commits@spark.apache.org by me...@apache.org on 2015/02/04 02:02:45 UTC

spark git commit: [SPARK-5520][MLlib] Make FP-Growth implementation take generic item types (WIP)

Repository: spark
Updated Branches:
  refs/heads/master 068c0e2ee -> e380d2d46


[SPARK-5520][MLlib] Make FP-Growth implementation take generic item types (WIP)

Make the FPGrowth.run API take generic item types:
`def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item]`
so that users can invoke it as run[String, Seq[String]], run[Int, Seq[Int]], run[Int, List[Int]], etc.
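
For illustration only (not part of the patch), here is a minimal Scala sketch of the generic API as it lands in the diff below, where the Scala entry point takes an RDD[Array[Item]]; `sc` is an existing SparkContext, and the baskets and thresholds are made up:

    import org.apache.spark.mllib.fpm.FPGrowth
    import org.apache.spark.rdd.RDD

    // Baskets of Int items; any item type with a ClassTag works the same way.
    val transactions: RDD[Array[Int]] = sc.parallelize(Seq(
      Array(1, 2, 3),
      Array(1, 2, 4),
      Array(1, 3))).cache()

    val model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(transactions)  // inferred as FPGrowthModel[Int]

    model.freqItemsets.collect().foreach { case (items, count) =>
      println(items.mkString("{", ", ", "}") + ": " + count)
    }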

The Scala part is done, while the Java part is still in progress.
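
While the Java API is settling, here is a hedged sketch of the Java-friendly entry points added in this patch (the JavaRDD overload and the javaFreqItemsets() bridge), driven from Scala for brevity; a Java caller does the same thing, as in the JavaFPGrowthSuite added below. The data is again made up:

    import java.{util => ju}
    import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}

    // Each basket is a java.lang.Iterable of items; the overload converts it
    // to an RDD[Array[Item]] internally using a fake ClassTag.
    val javaBaskets: JavaRDD[ju.List[String]] = sc.parallelize(Seq(
      ju.Arrays.asList("a", "b"),
      ju.Arrays.asList("a", "c"))).toJavaRDD()

    val javaModel = new FPGrowth()
      .setMinSupport(0.5)
      .run[String, ju.List[String]](javaBaskets)

    // Counts come back to Java callers as java.lang.Long.
    val pairs: JavaPairRDD[Array[String], java.lang.Long] =
      javaModel.javaFreqItemsets()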

Author: Jacky Li <ja...@huawei.com>
Author: Jacky Li <ja...@users.noreply.github.com>
Author: Xiangrui Meng <me...@databricks.com>

Closes #4340 from jackylk/SPARK-5520-WIP and squashes the following commits:

f5acf84 [Jacky Li] Merge pull request #2 from mengxr/SPARK-5520
63073d0 [Xiangrui Meng] update to make generic FPGrowth Java-friendly
737d8bb [Jacky Li] fix scalastyle
793f85c [Jacky Li] add Java test case
7783351 [Jacky Li] add generic support in FPGrowth


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e380d2d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e380d2d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e380d2d4

Branch: refs/heads/master
Commit: e380d2d46c92b319eafe30974ac7c1509081fca4
Parents: 068c0e2
Author: Jacky Li <ja...@huawei.com>
Authored: Tue Feb 3 17:02:42 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Feb 3 17:02:42 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/mllib/fpm/FPGrowth.scala   | 50 ++++++++----
 .../spark/mllib/fpm/JavaFPGrowthSuite.java      | 84 ++++++++++++++++++++
 .../apache/spark/mllib/fpm/FPGrowthSuite.scala  | 51 +++++++++++-
 3 files changed, 170 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e380d2d4/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 9591c79..1433ee9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -18,14 +18,31 @@
 package org.apache.spark.mllib.fpm
 
 import java.{util => ju}
+import java.lang.{Iterable => JavaIterable}
 
 import scala.collection.mutable
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
 
-import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
-class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
+/**
+ * Model trained by [[FPGrowth]], which holds frequent itemsets.
+ * @param freqItemsets frequent itemset, which is an RDD of (itemset, frequency) pairs
+ * @tparam Item item type
+ */
+class FPGrowthModel[Item: ClassTag](
+    val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+
+  /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
+  def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
+    JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
+  }
+}
 
 /**
  * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
@@ -69,7 +86,7 @@ class FPGrowth private (
    * @param data input data set, each element contains a transaction
    * @return an [[FPGrowthModel]]
    */
-  def run(data: RDD[Array[String]]): FPGrowthModel = {
+  def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
@@ -82,19 +99,24 @@ class FPGrowth private (
     new FPGrowthModel(freqItemsets)
   }
 
+  def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+    implicit val tag = fakeClassTag[Item]
+    run(data.rdd.map(_.asScala.toArray))
+  }
+
   /**
    * Generates frequent items by filtering the input data using minimal support level.
    * @param minCount minimum count for frequent itemsets
    * @param partitioner partitioner used to distribute items
    * @return array of frequent pattern ordered by their frequencies
    */
-  private def genFreqItems(
-      data: RDD[Array[String]],
+  private def genFreqItems[Item: ClassTag](
+      data: RDD[Array[Item]],
       minCount: Long,
-      partitioner: Partitioner): Array[String] = {
+      partitioner: Partitioner): Array[Item] = {
     data.flatMap { t =>
       val uniq = t.toSet
-      if (t.length != uniq.size) {
+      if (t.size != uniq.size) {
         throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
       }
       t
@@ -114,11 +136,11 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return an RDD of (frequent itemset, count)
    */
-  private def genFreqItemsets(
-      data: RDD[Array[String]],
+  private def genFreqItemsets[Item: ClassTag](
+      data: RDD[Array[Item]],
       minCount: Long,
-      freqItems: Array[String],
-      partitioner: Partitioner): RDD[(Array[String], Long)] = {
+      freqItems: Array[Item],
+      partitioner: Partitioner): RDD[(Array[Item], Long)] = {
     val itemToRank = freqItems.zipWithIndex.toMap
     data.flatMap { transaction =>
       genCondTransactions(transaction, itemToRank, partitioner)
@@ -139,9 +161,9 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return a map of (target partition, conditional transaction)
    */
-  private def genCondTransactions(
-      transaction: Array[String],
-      itemToRank: Map[String, Int],
+  private def genCondTransactions[Item: ClassTag](
+      transaction: Array[Item],
+      itemToRank: Map[Item, Int],
       partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
     val output = mutable.Map.empty[Int, Array[Int]]
     // Filter the basket by frequent items pattern and sort their ranks.

http://git-wip-us.apache.org/repos/asf/spark/blob/e380d2d4/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
new file mode 100644
index 0000000..851707c
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import com.google.common.collect.Lists;
+import static org.junit.Assert.*;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+public class JavaFPGrowthSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaFPGrowth");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void runFPGrowth() {
+
+    @SuppressWarnings("unchecked")
+    JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
+      Lists.newArrayList("r z h k p".split(" ")),
+      Lists.newArrayList("z y x w v u t s".split(" ")),
+      Lists.newArrayList("s x o n r".split(" ")),
+      Lists.newArrayList("x z y m t s q e".split(" ")),
+      Lists.newArrayList("z".split(" ")),
+      Lists.newArrayList("x z y r q t p".split(" "))), 2);
+
+    FPGrowth fpg = new FPGrowth();
+
+    FPGrowthModel<String> model6 = fpg
+      .setMinSupport(0.9)
+      .setNumPartitions(1)
+      .run(rdd);
+    assertEquals(0, model6.javaFreqItemsets().count());
+
+    FPGrowthModel<String> model3 = fpg
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd);
+    assertEquals(18, model3.javaFreqItemsets().count());
+
+    FPGrowthModel<String> model2 = fpg
+      .setMinSupport(0.3)
+      .setNumPartitions(4)
+      .run(rdd);
+    assertEquals(54, model2.javaFreqItemsets().count());
+
+    FPGrowthModel<String> model1 = fpg
+      .setMinSupport(0.1)
+      .setNumPartitions(8)
+      .run(rdd);
+    assertEquals(625, model1.javaFreqItemsets().count());
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e380d2d4/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 71ef60d..6812828 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
 
-  test("FP-Growth") {
+
+  test("FP-Growth using String type") {
     val transactions = Seq(
       "r z h k p",
       "z y x w v u t s",
@@ -70,4 +71,52 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
       .run(rdd)
     assert(model1.freqItemsets.count() === 625)
   }
+
+  test("FP-Growth using Int type") {
+    val transactions = Seq(
+      "1 2 3",
+      "1 2 3 4",
+      "5 4 3 2 1",
+      "6 5 4 3 2 1",
+      "2 4",
+      "1 3",
+      "1 7")
+      .map(_.split(" ").map(_.toInt).toArray)
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val fpg = new FPGrowth()
+
+    val model6 = fpg
+      .setMinSupport(0.9)
+      .setNumPartitions(1)
+      .run(rdd)
+    assert(model6.freqItemsets.count() === 0)
+
+    val model3 = fpg
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd)
+    assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
+      "frequent itemsets should use primitive arrays")
+    val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
+      (items.toSet, count)
+    }
+    val expected = Set(
+      (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
+      (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
+      (Set(2, 4), 4L), (Set(1, 2, 3), 4L))
+    assert(freqItemsets3.toSet === expected)
+
+    val model2 = fpg
+      .setMinSupport(0.3)
+      .setNumPartitions(4)
+      .run(rdd)
+    assert(model2.freqItemsets.count() === 15)
+
+    val model1 = fpg
+      .setMinSupport(0.1)
+      .setNumPartitions(8)
+      .run(rdd)
+    assert(model1.freqItemsets.count() === 65)
+  }
 }

