You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink, which is not preserved in this plain-text rendering.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/03/17 18:19:14 UTC

spark git commit: [SPARK-11891] Model export/import for RFormula and RFormulaModel

Repository: spark
Updated Branches:
  refs/heads/master 828213d4c -> edf8b8775


[SPARK-11891] Model export/import for RFormula and RFormulaModel

https://issues.apache.org/jira/browse/SPARK-11891

Author: Xusen Yin <yi...@gmail.com>

Closes #9884 from yinxusen/SPARK-11891.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/edf8b877
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/edf8b877
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/edf8b877

Branch: refs/heads/master
Commit: edf8b8775b81f5522680094bf24f372aa0c61447
Parents: 828213d
Author: Xusen Yin <yi...@gmail.com>
Authored: Thu Mar 17 10:19:10 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Mar 17 10:19:10 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/RFormula.scala  | 179 +++++++++++++++++--
 .../apache/spark/ml/tuning/CrossValidator.scala |   5 +-
 .../apache/spark/ml/feature/RFormulaSuite.scala |  40 ++++-
 3 files changed, 207 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/edf8b877/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index ab5f4a1..e7ca7ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -20,12 +20,14 @@ package org.apache.spark.ml.feature
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.VectorUDT
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types._
@@ -68,7 +70,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
  * will be created from the specified response variable in the formula.
  */
 @Experimental
-class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
+class RFormula(override val uid: String)
+  extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("rFormula"))
 
@@ -180,6 +183,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
   override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
 }
 
+@Since("2.0.0")
+object RFormula extends DefaultParamsReadable[RFormula] {
+
+  @Since("2.0.0")
+  override def load(path: String): RFormula = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
@@ -189,9 +199,9 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
 @Experimental
 class RFormulaModel private[feature](
     override val uid: String,
-    resolvedFormula: ResolvedRFormula,
-    pipelineModel: PipelineModel)
-  extends Model[RFormulaModel] with RFormulaBase {
+    private[ml] val resolvedFormula: ResolvedRFormula,
+    private[ml] val pipelineModel: PipelineModel)
+  extends Model[RFormulaModel] with RFormulaBase with MLWritable {
 
   override def transform(dataset: DataFrame): DataFrame = {
     checkCanTransform(dataset.schema)
@@ -246,14 +256,71 @@ class RFormulaModel private[feature](
       !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
       "Label column already exists and is not of type DoubleType.")
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
+}
+
+@Since("2.0.0")
+object RFormulaModel extends MLReadable[RFormulaModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[RFormulaModel] = new RFormulaModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): RFormulaModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[RFormulaModel]] */
+  private[RFormulaModel] class RFormulaModelWriter(instance: RFormulaModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: resolvedFormula
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(instance.resolvedFormula))
+        .repartition(1).write.parquet(dataPath)
+      // Save pipeline model
+      val pmPath = new Path(path, "pipelineModel").toString
+      instance.pipelineModel.save(pmPath)
+    }
+  }
+
+  private class RFormulaModelReader extends MLReader[RFormulaModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[RFormulaModel].getName
+
+    override def load(path: String): RFormulaModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
+      val label = data.getString(0)
+      val terms = data.getAs[Seq[Seq[String]]](1)
+      val hasIntercept = data.getBoolean(2)
+      val resolvedRFormula = ResolvedRFormula(label, terms, hasIntercept)
+
+      val pmPath = new Path(path, "pipelineModel").toString
+      val pipelineModel = PipelineModel.load(pmPath)
+
+      val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
 
 /**
  * Utility transformer for removing temporary columns from a DataFrame.
  * TODO(ekl) make this a public transformer
  */
-private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
-  override val uid = Identifiable.randomUID("columnPruner")
+private class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
+  extends Transformer with MLWritable {
+
+  def this(columnsToPrune: Set[String]) =
+    this(Identifiable.randomUID("columnPruner"), columnsToPrune)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
@@ -265,6 +332,48 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
   }
 
   override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
+
+  override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
+}
+
+private object ColumnPruner extends MLReadable[ColumnPruner] {
+
+  override def read: MLReader[ColumnPruner] = new ColumnPrunerReader
+
+  override def load(path: String): ColumnPruner = super.load(path)
+
+  /** [[MLWriter]] instance for [[ColumnPruner]] */
+  private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter {
+
+    private case class Data(columnsToPrune: Seq[String])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: columnsToPrune
+      val data = Data(instance.columnsToPrune.toSeq)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class ColumnPrunerReader extends MLReader[ColumnPruner] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[ColumnPruner].getName
+
+    override def load(path: String): ColumnPruner = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head()
+      val columnsToPrune = data.getAs[Seq[String]](0).toSet
+      val pruner = new ColumnPruner(metadata.uid, columnsToPrune)
+
+      DefaultParamsReader.getAndSetParams(pruner, metadata)
+      pruner
+    }
+  }
 }
 
 /**
@@ -278,11 +387,13 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
  *                          by the value in the map.
  */
 private class VectorAttributeRewriter(
-    vectorCol: String,
-    prefixesToRewrite: Map[String, String])
-  extends Transformer {
+    override val uid: String,
+    val vectorCol: String,
+    val prefixesToRewrite: Map[String, String])
+  extends Transformer with MLWritable {
 
-  override val uid = Identifiable.randomUID("vectorAttrRewriter")
+  def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
+    this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val metadata = {
@@ -315,4 +426,48 @@ private class VectorAttributeRewriter(
   }
 
   override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)
+
+  override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
+}
+
+private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
+
+  override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
+
+  override def load(path: String): VectorAttributeRewriter = super.load(path)
+
+  /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */
+  private[VectorAttributeRewriter]
+  class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter {
+
+    private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: vectorCol, prefixesToRewrite
+      val data = Data(instance.vectorCol, instance.prefixesToRewrite)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[VectorAttributeRewriter].getName
+
+    override def load(path: String): VectorAttributeRewriter = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
+      val vectorCol = data.getString(0)
+      val prefixesToRewrite = data.getAs[Map[String, String]](1)
+      val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
+
+      DefaultParamsReader.getAndSetParams(rewriter, metadata)
+      rewriter
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/edf8b877/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 010e7d2..3d7a91d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -221,10 +221,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
           // TODO: SPARK-11892: This case may require special handling.
           throw new UnsupportedOperationException("CrossValidator write will fail because it" +
             " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
-        case rform: RFormulaModel =>
-          // TODO: SPARK-11891: This case may require special handling.
-          throw new UnsupportedOperationException("CrossValidator write will fail because it" +
-            " cannot yet handle an estimator containing an RFormulaModel")
+        case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
         case _: Params => Array()
       }
       val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)

http://git-wip-us.apache.org/repos/asf/spark/blob/edf8b877/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 16e565d..e1b269b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new RFormula())
   }
@@ -252,4 +253,41 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
         new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
     assert(attrs === expectedAttrs)
   }
+
+  test("read/write: RFormula") {
+    val rFormula = new RFormula()
+      .setFormula("id ~ a:b")
+      .setFeaturesCol("myFeatures")
+      .setLabelCol("myLabels")
+
+    testDefaultReadWrite(rFormula)
+  }
+
+  test("read/write: RFormulaModel") {
+    def checkModelData(model: RFormulaModel, model2: RFormulaModel): Unit = {
+      assert(model.uid === model2.uid)
+
+      assert(model.resolvedFormula.label === model2.resolvedFormula.label)
+      assert(model.resolvedFormula.terms === model2.resolvedFormula.terms)
+      assert(model.resolvedFormula.hasIntercept === model2.resolvedFormula.hasIntercept)
+
+      assert(model.pipelineModel.uid === model2.pipelineModel.uid)
+
+      model.pipelineModel.stages.zip(model2.pipelineModel.stages).foreach {
+        case (transformer1, transformer2) =>
+          assert(transformer1.uid === transformer2.uid)
+          assert(transformer1.params === transformer2.params)
+      }
+    }
+
+    val dataset = sqlContext.createDataFrame(
+      Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
+    ).toDF("id", "a", "b")
+
+    val rFormula = new RFormula().setFormula("id ~ a:b")
+
+    val model = rFormula.fit(dataset)
+    val newModel = testDefaultReadWrite(model)
+    checkModelData(model, newModel)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org