Posted to commits@spark.apache.org by me...@apache.org on 2015/07/09 01:27:14 UTC

spark git commit: [SPARK-8877] [MLLIB] Public API for association rule generation

Repository: spark
Updated Branches:
  refs/heads/master 381cb161b -> 8c32b2e87


[SPARK-8877] [MLLIB] Public API for association rule generation

Adds FPGrowth.generateAssociationRules to the public API for generating association rules after mining frequent itemsets.

Author: Feynman Liang <fl...@databricks.com>

Closes #7271 from feynmanliang/SPARK-8877 and squashes the following commits:

83b8baf [Feynman Liang] Add API Doc
867abff [Feynman Liang] Add FPGrowth.generateAssociationRules and change access modifiers for AssociationRules


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8c32b2e8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8c32b2e8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8c32b2e8

Branch: refs/heads/master
Commit: 8c32b2e870c7c250a63e838718df833edf6dea07
Parents: 381cb16
Author: Feynman Liang <fl...@databricks.com>
Authored: Wed Jul 8 16:27:11 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Jul 8 16:27:11 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/fpm/AssociationRules.scala      |  5 ++-
 .../org/apache/spark/mllib/fpm/FPGrowth.scala   | 11 ++++-
 .../apache/spark/mllib/fpm/FPGrowthSuite.scala  | 42 ++++++++++++++++++++
 3 files changed, 55 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8c32b2e8/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
index 4a0f842..7e2bbfe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
  * association rules which have a single item as the consequent.
  */
 @Experimental
-class AssociationRules private (
+class AssociationRules private[fpm] (
     private var minConfidence: Double) extends Logging with Serializable {
 
   /**
@@ -45,6 +45,7 @@ class AssociationRules private (
    * Sets the minimal confidence (default: `0.8`).
    */
   def setMinConfidence(minConfidence: Double): this.type = {
+    require(minConfidence >= 0.0 && minConfidence <= 1.0)
     this.minConfidence = minConfidence
     this
   }
@@ -91,7 +92,7 @@ object AssociationRules {
    * @tparam Item item type
    */
   @Experimental
-  class Rule[Item] private[mllib] (
+  class Rule[Item] private[fpm] (
       val antecedent: Array[Item],
       val consequent: Array[Item],
       freqUnion: Double,
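
With the primary constructor now restricted to private[fpm], code outside the fpm package reaches AssociationRules either through the new FPGrowthModel.generateAssociationRules below or through the class's public no-argument constructor. A small sketch of the latter path, assuming that no-arg constructor (which defaults minConfidence to 0.8, per the scaladoc above) and an existing RDD of frequent itemsets:

  import org.apache.spark.mllib.fpm.AssociationRules
  import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
  import org.apache.spark.rdd.RDD

  // `freqItemsets` would typically come from FPGrowthModel.freqItemsets.
  def rulesFrom(freqItemsets: RDD[FreqItemset[String]]) =
    new AssociationRules()
      .setMinConfidence(0.9)   // must lie in [0.0, 1.0] after the require(...) added above
      .run(freqItemsets)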

http://git-wip-us.apache.org/repos/asf/spark/blob/8c32b2e8/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 0da59e8..9cb9a00 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -40,7 +40,16 @@ import org.apache.spark.storage.StorageLevel
  * @tparam Item item type
  */
 @Experimental
-class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable
+class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable {
+  /**
+   * Generates association rules for the [[Item]]s in [[freqItemsets]].
+   * @param confidence minimal confidence of the rules produced
+   */
+  def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = {
+    val associationRules = new AssociationRules(confidence)
+    associationRules.run(freqItemsets)
+  }
+}
 
 /**
  * :: Experimental ::
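
As the method body above shows, generateAssociationRules is a thin wrapper over AssociationRules. Given a trained model, the two call paths below are effectively equivalent (names illustrative):

  // Both produce an RDD of rules meeting the same minimal confidence.
  val viaModel  = model.generateAssociationRules(0.8)
  val viaRunner = new AssociationRules().setMinConfidence(0.8).run(model.freqItemsets)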

http://git-wip-us.apache.org/repos/asf/spark/blob/8c32b2e8/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index ddc296a..4a9bfdb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -132,6 +132,48 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model1.freqItemsets.count() === 625)
   }
 
+  test("FP-Growth String type association rule generation") {
+    val transactions = Seq(
+      "r z h k p",
+      "z y x w v u t s",
+      "s x o n r",
+      "x z y m t s q e",
+      "z",
+      "x z y r q t p")
+      .map(_.split(" "))
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    /* Verify results using the `R` code:
+       transactions = as(sapply(
+         list("r z h k p",
+              "z y x w v u t s",
+              "s x o n r",
+              "x z y m t s q e",
+              "z",
+              "x z y r q t p"),
+         FUN=function(x) strsplit(x," ",fixed=TRUE)),
+         "transactions")
+       ars = apriori(transactions,
+                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
+       arsDF = as(ars, "data.frame")
+       arsDF$support = arsDF$support * length(transactions)
+       names(arsDF)[names(arsDF) == "support"] = "freq"
+       > nrow(arsDF)
+       [1] 23
+       > sum(arsDF$confidence == 1)
+       [1] 23
+     */
+    val rules = (new FPGrowth())
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd)
+      .generateAssociationRules(0.9)
+      .collect()
+
+    assert(rules.size === 23)
+    assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+  }
+
   test("FP-Growth using Int type") {
     val transactions = Seq(
       "1 2 3",

