You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/01/05 22:32:05 UTC

spark git commit: [SPARK-6724][MLLIB] Support model save/load for FPGrowthModel

Repository: spark
Updated Branches:
  refs/heads/master 047a31bb1 -> 13a3b636d


[SPARK-6724][MLLIB] Support model save/load for FPGrowthModel

Support model save/load for FPGrowthModel

Author: Yanbo Liang <yb...@gmail.com>

Closes #9267 from yanboliang/spark-6724.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/13a3b636
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/13a3b636
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/13a3b636

Branch: refs/heads/master
Commit: 13a3b636d9425c5713cd1381203ee1b60f71b8c8
Parents: 047a31b
Author: Yanbo Liang <yb...@gmail.com>
Authored: Tue Jan 5 13:31:59 2016 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Jan 5 13:31:59 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/mllib/fpm/FPGrowth.scala   | 100 ++++++++++++++++++-
 .../spark/mllib/fpm/JavaFPGrowthSuite.java      |  40 ++++++++
 .../apache/spark/mllib/fpm/FPGrowthSuite.scala  |  68 +++++++++++++
 3 files changed, 205 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/13a3b636/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 70ef1ed..5273ed4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -17,19 +17,29 @@
 
 package org.apache.spark.mllib.fpm
 
-import java.{util => ju}
 import java.lang.{Iterable => JavaIterable}
+import java.{util => ju}
 
-import scala.collection.mutable
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe._
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
 
 import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.mllib.fpm.FPGrowth._
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -39,7 +49,8 @@ import org.apache.spark.storage.StorageLevel
  */
 @Since("1.3.0")
 class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
-    @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable {
+    @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]])
+  extends Saveable with Serializable {
   /**
    * Generates association rules for the [[Item]]s in [[freqItemsets]].
    * @param confidence minimal confidence of the rules produced
@@ -49,6 +60,89 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
     val associationRules = new AssociationRules(confidence)
     associationRules.run(freqItemsets)
   }
+
+  /**
+   * Save this model to the given path.
+   * This only works when Item is a data type supported by DataFrames.
+   *
+   * This saves:
+   *  - human-readable (JSON) model metadata to path/metadata/
+   *  - Parquet formatted data to path/data/
+   *
+   * The model may be loaded using [[FPGrowthModel.load]].
+   *
+   * @param sc  Unused here: the model is written through the SparkContext of freqItemsets.
+   * @param path  Path specifying the directory in which to save this model.
+   *              If the directory already exists, this method throws an exception.
+   */
+  @Since("2.0.0")
+  override def save(sc: SparkContext, path: String): Unit = {
+    FPGrowthModel.SaveLoadV1_0.save(this, path)
+  }
+
+  override protected val formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object FPGrowthModel extends Loader[FPGrowthModel[_]] {
+
+  @Since("2.0.0")
+  override def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
+    FPGrowthModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private[fpm] object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private val thisClassName = "org.apache.spark.mllib.fpm.FPGrowthModel"
+
+    def save(model: FPGrowthModel[_], path: String): Unit = {
+      val sc = model.freqItemsets.sparkContext
+      val sqlContext = SQLContext.getOrCreate(sc)
+
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      // Infer the item type from the first item of the first frequent itemset
+      val sample = model.freqItemsets.first().items(0)
+      val className = sample.getClass.getCanonicalName
+      val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
+      val tpe = classSymbol.selfType
+
+      val itemType = ScalaReflection.schemaFor(tpe).dataType
+      val fields = Array(StructField("items", ArrayType(itemType)),
+        StructField("freq", LongType))
+      val schema = StructType(fields)
+      val rowDataRDD = model.freqItemsets.map { x =>
+        Row(x.items, x.freq)
+      }
+      sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
+      implicit val formats = DefaultFormats
+      val sqlContext = SQLContext.getOrCreate(sc)
+
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+
+      val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
+      val sample = freqItemsets.select("items").head().get(0)
+      loadImpl(freqItemsets, sample)
+    }
+
+    def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = {
+      val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x =>
+        val items = x.getAs[Seq[Item]](0).toArray
+        val freq = x.getLong(1)
+        new FreqItemset(items, freq)
+      }
+      new FPGrowthModel(freqItemsetsRDD)
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/13a3b636/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 154f75d..eeeabfe 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.fpm;
 
+import java.io.File;
 import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
@@ -28,6 +29,7 @@ import static org.junit.Assert.*;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.util.Utils;
 
 public class JavaFPGrowthSuite implements Serializable {
   private transient JavaSparkContext sc;
@@ -69,4 +71,42 @@ public class JavaFPGrowthSuite implements Serializable {
       long freq = itemset.freq();
     }
   }
+
+  @Test
+  public void runFPGrowthSaveLoad() {
+
+    @SuppressWarnings("unchecked")
+    JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+      Arrays.asList("r z h k p".split(" ")),
+      Arrays.asList("z y x w v u t s".split(" ")),
+      Arrays.asList("s x o n r".split(" ")),
+      Arrays.asList("x z y m t s q e".split(" ")),
+      Arrays.asList("z".split(" ")),
+      Arrays.asList("x z y r q t p".split(" "))), 2);
+
+    FPGrowthModel<String> model = new FPGrowth()
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd);
+
+    File tempDir = Utils.createTempDir(
+      System.getProperty("java.io.tmpdir"), "JavaFPGrowthSuite");
+    String outputPath = tempDir.getPath();
+
+    try {
+      model.save(sc.sc(), outputPath);
+      FPGrowthModel newModel = FPGrowthModel.load(sc.sc(), outputPath);
+      List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
+        .collect();
+      assertEquals(18, freqItemsets.size());
+
+      for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+        // Test return types.
+        List<String> items = itemset.javaItems();
+        long freq = itemset.freq();
+      }
+    } finally {
+      Utils.deleteRecursively(tempDir);
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/13a3b636/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 4a9bfdb..b9e997c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
 
 class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -274,4 +275,71 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
      */
     assert(model1.freqItemsets.count() === 65)
   }
+
+  test("model save/load with String type") {
+    val transactions = Seq(
+      "r z h k p",
+      "z y x w v u t s",
+      "s x o n r",
+      "x z y m t s q e",
+      "z",
+      "x z y r q t p")
+      .map(_.split(" "))
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val model3 = new FPGrowth()
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd)
+    val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+      (itemset.items.toSet, itemset.freq)
+    }
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    try {
+      model3.save(sc, path)
+      val newModel = FPGrowthModel.load(sc, path)
+      val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+        (itemset.items.toSet, itemset.freq)
+      }
+      assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
+  test("model save/load with Int type") {
+    val transactions = Seq(
+      "1 2 3",
+      "1 2 3 4",
+      "5 4 3 2 1",
+      "6 5 4 3 2 1",
+      "2 4",
+      "1 3",
+      "1 7")
+      .map(_.split(" ").map(_.toInt).toArray)
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val model3 = new FPGrowth()
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd)
+    val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+      (itemset.items.toSet, itemset.freq)
+    }
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    try {
+      model3.save(sc, path)
+      val newModel = FPGrowthModel.load(sc, path)
+      val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+        (itemset.items.toSet, itemset.freq)
+      }
+      assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org