You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/03/31 20:17:39 UTC

spark git commit: [SPARK-11892][ML] Model export/import for spark.ml: OneVsRest

Repository: spark
Updated Branches:
  refs/heads/master a0a199158 -> 8b207f3b6


[SPARK-11892][ML] Model export/import for spark.ml: OneVsRest

# What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-11892

Add save/load support for spark.ml OneVsRest and its model. Also add OneVsRest and OneVsRestModel to MetaAlgorithmReadWrite.

# How was this patch tested?

Tested with Scala unit tests (OneVsRestSuite).

Author: Xusen Yin <yi...@gmail.com>

Closes #9934 from yinxusen/SPARK-11892.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8b207f3b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8b207f3b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8b207f3b

Branch: refs/heads/master
Commit: 8b207f3b6a0eb617d38091f3b9001830ac3651fe
Parents: a0a1991
Author: Xusen Yin <yi...@gmail.com>
Authored: Thu Mar 31 11:17:32 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Mar 31 11:17:32 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/OneVsRest.scala     | 165 +++++++++++++++++--
 .../org/apache/spark/ml/util/ReadWrite.scala    |   8 +-
 .../ml/classification/OneVsRestSuite.scala      |  68 +++++++-
 3 files changed, 223 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8b207f3b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index c41a611..98b99a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -21,22 +21,24 @@ import java.util.UUID
 
 import scala.language.existentials
 
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject, _}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 
-/**
- * Params for [[OneVsRest]].
- */
-private[ml] trait OneVsRestParams extends PredictorParams {
-
+private[ml] trait ClassifierTypeTrait {
   // scalastyle:off structural.type
   type ClassifierType = Classifier[F, E, M] forSome {
     type F
@@ -44,6 +46,12 @@ private[ml] trait OneVsRestParams extends PredictorParams {
     type E <: Classifier[F, E, M]
   }
   // scalastyle:on structural.type
+}
+
+/**
+ * Params for [[OneVsRest]].
+ */
+private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
 
   /**
    * param for the base binary classifier that we reduce multiclass classification into.
@@ -57,6 +65,55 @@ private[ml] trait OneVsRestParams extends PredictorParams {
   def getClassifier: ClassifierType = $(classifier)
 }
 
+private[ml] object OneVsRestParams extends ClassifierTypeTrait {
+
+  def validateParams(instance: OneVsRestParams): Unit = {
+    def checkElement(elem: Params, name: String): Unit = elem match {
+      case stage: MLWritable => // good
+      case other =>
+        throw new UnsupportedOperationException("OneVsRest write will fail " +
+          s" because it contains $name which does not implement MLWritable." +
+          s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+    }
+
+    instance match {
+      case ovrModel: OneVsRestModel => ovrModel.models.foreach(checkElement(_, "model"))
+      case _ => // no need to check OneVsRest here
+    }
+
+    checkElement(instance.getClassifier, "classifier")
+  }
+
+  def saveImpl(
+      path: String,
+      instance: OneVsRestParams,
+      sc: SparkContext,
+      extraMetadata: Option[JObject] = None): Unit = {
+
+    val params = instance.extractParamMap().toSeq
+    val jsonParams = render(params
+      .filter { case ParamPair(p, v) => p.name != "classifier" }
+      .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
+      .toList)
+
+    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+    val classifierPath = new Path(path, "classifier").toString
+    instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath)
+  }
+
+  def loadImpl(
+      path: String,
+      sc: SparkContext,
+      expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = {
+
+    val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+    val classifierPath = new Path(path, "classifier").toString
+    val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc)
+    (metadata, estimator)
+  }
+}
+
 /**
  * :: Experimental ::
  * Model produced by [[OneVsRest]].
@@ -73,10 +130,10 @@ private[ml] trait OneVsRestParams extends PredictorParams {
 @Since("1.4.0")
 @Experimental
 final class OneVsRestModel private[ml] (
-    @Since("1.4.0")  override val uid: String,
-    @Since("1.4.0") labelMetadata: Metadata,
+    @Since("1.4.0") override val uid: String,
+    private[ml] val labelMetadata: Metadata,
     @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
-  extends Model[OneVsRestModel] with OneVsRestParams {
+  extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
@@ -143,6 +200,56 @@ final class OneVsRestModel private[ml] (
       uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this)
+}
+
+@Since("2.0.0")
+object OneVsRestModel extends MLReadable[OneVsRestModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[OneVsRestModel] = new OneVsRestModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): OneVsRestModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[OneVsRestModel]] */
+  private[OneVsRestModel] class OneVsRestModelWriter(instance: OneVsRestModel) extends MLWriter {
+
+    OneVsRestParams.validateParams(instance)
+
+    override protected def saveImpl(path: String): Unit = {
+      val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
+        ("numClasses" -> instance.models.length)
+      OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson))
+      instance.models.zipWithIndex.foreach { case (model: MLWritable, idx) =>
+        val modelPath = new Path(path, s"model_$idx").toString
+        model.save(modelPath)
+      }
+    }
+  }
+
+  private class OneVsRestModelReader extends MLReader[OneVsRestModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[OneVsRestModel].getName
+
+    override def load(path: String): OneVsRestModel = {
+      implicit val format = DefaultFormats
+      val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+      val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
+      val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+      val models = Range(0, numClasses).toArray.map { idx =>
+        val modelPath = new Path(path, s"model_$idx").toString
+        DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
+      }
+      val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
+      DefaultParamsReader.getAndSetParams(ovrModel, metadata)
+      ovrModel.set("classifier", classifier)
+      ovrModel
+    }
+  }
 }
 
 /**
@@ -158,7 +265,7 @@ final class OneVsRestModel private[ml] (
 @Experimental
 final class OneVsRest @Since("1.4.0") (
     @Since("1.4.0") override val uid: String)
-  extends Estimator[OneVsRestModel] with OneVsRestParams {
+  extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("oneVsRest"))
@@ -243,4 +350,40 @@ final class OneVsRest @Since("1.4.0") (
     }
     copied
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new OneVsRest.OneVsRestWriter(this)
+}
+
+@Since("2.0.0")
+object OneVsRest extends MLReadable[OneVsRest] {
+
+  @Since("2.0.0")
+  override def read: MLReader[OneVsRest] = new OneVsRestReader
+
+  @Since("2.0.0")
+  override def load(path: String): OneVsRest = super.load(path)
+
+  /** [[MLWriter]] instance for [[OneVsRest]] */
+  private[OneVsRest] class OneVsRestWriter(instance: OneVsRest) extends MLWriter {
+
+    OneVsRestParams.validateParams(instance)
+
+    override protected def saveImpl(path: String): Unit = {
+      OneVsRestParams.saveImpl(path, instance, sc)
+    }
+  }
+
+  private class OneVsRestReader extends MLReader[OneVsRest] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[OneVsRest].getName
+
+    override def load(path: String): OneVsRest = {
+      val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+      val ovr = new OneVsRest(metadata.uid)
+      DefaultParamsReader.getAndSetParams(ovr, metadata)
+      ovr.setClassifier(classifier)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8b207f3b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 5a596ca..39999ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
-import org.apache.spark.ml.classification.OneVsRestParams
+import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
 import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
 import org.apache.spark.ml.tuning.ValidatorParams
@@ -381,10 +381,8 @@ private[ml] object MetaAlgorithmReadWrite {
       case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
       case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
       case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
-      case ovr: OneVsRestParams =>
-        // TODO: SPARK-11892: This case may require special handling.
-        throw new UnsupportedOperationException(s"${instance.getClass.getName} write will fail" +
-          s" because it cannot yet handle an estimator containing type: ${ovr.getClass.getName}.")
+      case ovr: OneVsRest => Array(ovr.getClassifier)
+      case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
       case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
       case _: Params => Array()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/8b207f3b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 2ae74a2..51c1baf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.feature.StringIndexer
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
 import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.Metadata
 
-class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
+class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   @transient var dataset: DataFrame = _
   @transient var rdd: RDD[LabeledPoint] = _
@@ -160,6 +160,70 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
       require(m.getThreshold === 0.1, "copy should handle extra model params")
     }
   }
+
+  test("read/write: OneVsRest") {
+    val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+
+    val ova = new OneVsRest()
+      .setClassifier(lr)
+      .setLabelCol("myLabel")
+      .setFeaturesCol("myFeature")
+      .setPredictionCol("myPrediction")
+
+    val ova2 = testDefaultReadWrite(ova, testParams = false)
+    assert(ova.uid === ova2.uid)
+    assert(ova.getFeaturesCol === ova2.getFeaturesCol)
+    assert(ova.getLabelCol === ova2.getLabelCol)
+    assert(ova.getPredictionCol === ova2.getPredictionCol)
+
+    ova2.getClassifier match {
+      case lr2: LogisticRegression =>
+        assert(lr.uid === lr2.uid)
+        assert(lr.getMaxIter === lr2.getMaxIter)
+        assert(lr.getRegParam === lr2.getRegParam)
+      case other =>
+        throw new AssertionError(s"Loaded OneVsRest expected classifier of type" +
+          s" LogisticRegression but found ${other.getClass.getName}")
+    }
+  }
+
+  test("read/write: OneVsRestModel") {
+    def checkModelData(model: OneVsRestModel, model2: OneVsRestModel): Unit = {
+      assert(model.uid === model2.uid)
+      assert(model.getFeaturesCol === model2.getFeaturesCol)
+      assert(model.getLabelCol === model2.getLabelCol)
+      assert(model.getPredictionCol === model2.getPredictionCol)
+
+      val classifier = model.getClassifier.asInstanceOf[LogisticRegression]
+
+      model2.getClassifier match {
+        case lr2: LogisticRegression =>
+          assert(classifier.uid === lr2.uid)
+          assert(classifier.getMaxIter === lr2.getMaxIter)
+          assert(classifier.getRegParam === lr2.getRegParam)
+        case other =>
+          throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" +
+            s" LogisticRegression but found ${other.getClass.getName}")
+      }
+
+      assert(model.labelMetadata === model2.labelMetadata)
+      model.models.zip(model2.models).foreach {
+        case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) =>
+          assert(lrModel1.uid === lrModel2.uid)
+          assert(lrModel1.coefficients === lrModel2.coefficients)
+          assert(lrModel1.intercept === lrModel2.intercept)
+        case other =>
+          throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
+            s" LogisticRegressionModel but found ${other.getClass.getName}")
+      }
+    }
+
+    val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+    val ova = new OneVsRest().setClassifier(lr)
+    val ovaModel = ova.fit(dataset)
+    val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false)
+    checkModelData(ovaModel, newOvaModel)
+  }
 }
 
 private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org