You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2018/03/23 18:56:40 UTC

spark git commit: [SPARK-23783][SPARK-11239][ML] Add PMML export to Spark ML pipelines

Repository: spark
Updated Branches:
  refs/heads/master cb43bbe13 -> 95c03cbd2


[SPARK-23783][SPARK-11239][ML] Add PMML export to Spark ML pipelines

## What changes were proposed in this pull request?

Adds PMML export support to Spark ML pipelines in the style of Spark's DataSource API to allow library authors to add their own model export formats.

Includes a specific implementation for Spark ML linear regression PMML export.

In addition to adding PMML to reach parity with our current MLlib implementation, this approach will allow other libraries & formats (like PFA) to implement and export models with a unified API.

## How was this patch tested?

Basic unit tests.

Author: Holden Karau <ho...@google.com>
Author: Holden Karau <ho...@pigscanfly.ca>

Closes #19876 from holdenk/SPARK-11171-SPARK-11237-Add-PMML-export-for-ML-KMeans-r2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/95c03cbd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/95c03cbd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/95c03cbd

Branch: refs/heads/master
Commit: 95c03cbd27cea2255d9d748f9a84a0a38e54594d
Parents: cb43bbe
Author: Holden Karau <ho...@google.com>
Authored: Fri Mar 23 11:56:17 2018 -0700
Committer: Holden Karau <ho...@pigscanfly.ca>
Committed: Fri Mar 23 11:56:17 2018 -0700

----------------------------------------------------------------------
 .../org.apache.spark.ml.util.MLFormatRegister   |   2 +
 .../spark/ml/regression/LinearRegression.scala  |  70 +++++---
 .../org/apache/spark/ml/util/ReadWrite.scala    | 173 ++++++++++++++++++-
 .../org.apache.spark.ml.util.MLFormatRegister   |   3 +
 .../ml/regression/LinearRegressionSuite.scala   |  27 ++-
 .../spark/ml/util/PMMLReadWriteTest.scala       |  55 ++++++
 .../org/apache/spark/ml/util/PMMLUtils.scala    |  43 +++++
 .../apache/spark/ml/util/ReadWriteSuite.scala   | 132 ++++++++++++++
 8 files changed, 474 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/95c03cbd/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
----------------------------------------------------------------------
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
new file mode 100644
index 0000000..5e5484f
--- /dev/null
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
@@ -0,0 +1,2 @@
+org.apache.spark.ml.regression.InternalLinearRegressionModelWriter
+org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/95c03cbd/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 9251015..f67d9d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.{PipelineStage, PredictorParams}
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.linalg.BLAS._
@@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -643,7 +644,7 @@ class LinearRegressionModel private[ml] (
     @Since("1.3.0") val intercept: Double,
     @Since("2.3.0") val scale: Double)
   extends RegressionModel[Vector, LinearRegressionModel]
-  with LinearRegressionParams with MLWritable {
+  with LinearRegressionParams with GeneralMLWritable {
 
   private[ml] def this(uid: String, coefficients: Vector, intercept: Double) =
     this(uid, coefficients, intercept, 1.0)
@@ -710,7 +711,7 @@ class LinearRegressionModel private[ml] (
   }
 
   /**
-   * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
+   * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
    *
    * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]].
    * An option to save [[summary]] may be added in the future.
@@ -718,7 +719,50 @@ class LinearRegressionModel private[ml] (
    * This also does not save the [[parent]] currently.
    */
   @Since("1.6.0")
-  override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this)
+  override def write: GeneralMLWriter = new GeneralMLWriter(this)
+}
+
+/** A writer for LinearRegression that handles the "internal" (or default) format */
+private class InternalLinearRegressionModelWriter
+  extends MLWriterFormat with MLFormatRegister {
+
+  override def format(): String = "internal"
+  override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+
+  private case class Data(intercept: Double, coefficients: Vector, scale: Double)
+
+  override def write(path: String, sparkSession: SparkSession,
+    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+    val instance = stage.asInstanceOf[LinearRegressionModel]
+    val sc = sparkSession.sparkContext
+    // Save metadata and Params
+    DefaultParamsWriter.saveMetadata(instance, path, sc)
+    // Save model data: intercept, coefficients, scale
+    val data = Data(instance.intercept, instance.coefficients, instance.scale)
+    val dataPath = new Path(path, "data").toString
+    sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+  }
+}
+
+/** A writer for LinearRegression that handles the "pmml" format */
+private class PMMLLinearRegressionModelWriter
+    extends MLWriterFormat with MLFormatRegister {
+
+  override def format(): String = "pmml"
+
+  override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+
+  private case class Data(intercept: Double, coefficients: Vector)
+
+  override def write(path: String, sparkSession: SparkSession,
+    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+    val sc = sparkSession.sparkContext
+    // Construct the MLlib model which knows how to write to PMML.
+    val instance = stage.asInstanceOf[LinearRegressionModel]
+    val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept)
+    // Save PMML
+    oldModel.toPMML(sc, path)
+  }
 }
 
 @Since("1.6.0")
@@ -730,22 +774,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
   @Since("1.6.0")
   override def load(path: String): LinearRegressionModel = super.load(path)
 
-  /** [[MLWriter]] instance for [[LinearRegressionModel]] */
-  private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel)
-    extends MLWriter with Logging {
-
-    private case class Data(intercept: Double, coefficients: Vector, scale: Double)
-
-    override protected def saveImpl(path: String): Unit = {
-      // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
-      // Save model data: intercept, coefficients, scale
-      val data = Data(instance.intercept, instance.coefficients, instance.scale)
-      val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
-    }
-  }
-
   private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] {
 
     /** Checked against metadata when loading model */

http://git-wip-us.apache.org/repos/asf/spark/blob/95c03cbd/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index a616907..7edcd49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -18,9 +18,11 @@
 package org.apache.spark.ml.util
 
 import java.io.IOException
-import java.util.Locale
+import java.util.{Locale, ServiceLoader}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.util.{Failure, Success, Try}
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
@@ -28,8 +30,8 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.SparkContext
-import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
 import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
@@ -86,7 +88,82 @@ private[util] sealed trait BaseReadWrite {
 }
 
 /**
- * Abstract class for utility classes that can save ML instances.
+ * Abstract class to be implemented by objects that provide ML exportability.
+ *
+ * A new instance of this class will be instantiated each time a save call is made.
+ *
+ * Must have a valid zero argument constructor which will be called to instantiate.
+ *
+ * @since 2.4.0
+ */
+@InterfaceStability.Unstable
+@Since("2.4.0")
+trait MLWriterFormat {
+  /**
+   * Function to write the provided pipeline stage out.
+   *
+   * @param path  The path to write the result out to.
+   * @param session  SparkSession associated with the write request.
+   * @param optionMap  User provided options stored as strings.
+   * @param stage  The pipeline stage to be saved.
+   */
+  @Since("2.4.0")
+  def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String],
+    stage: PipelineStage): Unit
+}
+
+/**
+ * ML export formats should implement this trait so that users can specify a short name rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a save call is made.
+ *
+ * @since 2.4.0
+ */
+@InterfaceStability.Unstable
+@Since("2.4.0")
+trait MLFormatRegister extends MLWriterFormat {
+  /**
+   * The string that represents the format that this format provider uses. This, along with
+   * stageName, is overridden by children to provide a nice alias for the writer. For example:
+   *
+   * {{{
+   *   override def format(): String =
+   *       "pmml"
+   * }}}
+   * Indicates that this format is capable of saving a pmml model.
+   *
+   * Must have a valid zero argument constructor which will be called to instantiate.
+   *
+   * Format discovery is done using a ServiceLoader so make sure to list your format in
+   * META-INF/services.
+   * @since 2.4.0
+   */
+  @Since("2.4.0")
+  def format(): String
+
+  /**
+   * The string that represents the stage type that this writer supports. This, along with
+   * format, is overridden by children to provide a nice alias for the writer. For example:
+   *
+   * {{{
+   *   override def stageName(): String =
+   *       "org.apache.spark.ml.regression.LinearRegressionModel"
+   * }}}
+   * Indicates that this writer is capable of saving Spark's LinearRegressionModel.
+   *
+   * Format discovery is done using a ServiceLoader so make sure to list your format in
+   * META-INF/services.
+   * @since 2.4.0
+   */
+  @Since("2.4.0")
+  def stageName(): String
+
+  private[ml] def shortName(): String = s"${format()}+${stageName()}"
+}
+
+/**
+ * Abstract class for utility classes that can save ML instances in Spark's internal format.
  */
 @Since("1.6.0")
 abstract class MLWriter extends BaseReadWrite with Logging {
@@ -111,6 +188,15 @@ abstract class MLWriter extends BaseReadWrite with Logging {
   protected def saveImpl(path: String): Unit
 
   /**
+   * Overwrites if the output path already exists.
+   */
+  @Since("1.6.0")
+  def overwrite(): this.type = {
+    shouldOverwrite = true
+    this
+  }
+
+  /**
    * Map to store extra options for this writer.
    */
   protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]()
@@ -126,15 +212,73 @@ abstract class MLWriter extends BaseReadWrite with Logging {
     this
   }
 
+  // override for Java compatibility
+  @Since("1.6.0")
+  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
+
+  // override for Java compatibility
+  @Since("1.6.0")
+  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+@InterfaceStability.Unstable
+@Since("2.4.0")
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
+  private var source: String = "internal"
+
   /**
-   * Overwrites if the output path already exists.
+   * Specifies the format of ML export (e.g. "pmml", "internal", or
+   * the fully qualified class name for export).
    */
-  @Since("1.6.0")
-  def overwrite(): this.type = {
-    shouldOverwrite = true
+  @Since("2.4.0")
+  def format(source: String): this.type = {
+    this.source = source
     this
   }
 
+  /**
+   * Dispatches the save to the correct MLWriterFormat.
+   */
+  @Since("2.4.0")
+  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+  @throws[SparkException]("If multiple sources for a given short name format are found.")
+  override protected def saveImpl(path: String): Unit = {
+    val loader = Utils.getContextOrSparkClassLoader
+    val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader)
+    val stageName = stage.getClass.getName
+    val targetName = s"$source+$stageName"
+    val formats = serviceLoader.asScala.toList
+    val shortNames = formats.map(_.shortName())
+    val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
+      // requested name did not match any given registered alias
+      case Nil =>
+        Try(loader.loadClass(source)) match {
+          case Success(writer) =>
+            // Found the ML writer using the fully qualified path
+            writer
+          case Failure(error) =>
+            throw new SparkException(
+              s"Could not load requested format $source for $stageName ($targetName) had $formats" +
+              s"supporting $shortNames", error)
+        }
+      case head :: Nil =>
+        head.getClass
+      case _ =>
+        // Multiple sources
+        throw new SparkException(
+          s"Multiple writers found for $source+$stageName, try using the class name of the writer")
+    }
+    if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
+      val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat]
+      writer.write(path, sparkSession, optionMap, stage)
+    } else {
+      throw new SparkException(s"ML source $source is not a valid MLWriterFormat")
+    }
+  }
+
   // override for Java compatibility
   override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
 
@@ -163,6 +307,19 @@ trait MLWritable {
 }
 
 /**
+ * Trait for classes that provide `GeneralMLWriter`.
+ */
+@Since("2.4.0")
+@InterfaceStability.Unstable
+trait GeneralMLWritable extends MLWritable {
+  /**
+   * Returns a `GeneralMLWriter` instance for this ML instance.
+   */
+  @Since("2.4.0")
+  override def write: GeneralMLWriter
+}
+
+/**
  * :: DeveloperApi ::
  *
  * Helper trait for making simple `Params` types writable.  If a `Params` class stores

http://git-wip-us.apache.org/repos/asf/spark/blob/95c03cbd/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
----------------------------------------------------------------------
diff --git a/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
new file mode 100644
index 0000000..100ef25
--- /dev/null
+++ b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
@@ -0,0 +1,3 @@
+org.apache.spark.ml.util.DuplicateLinearRegressionWriter1
+org.apache.spark.ml.util.DuplicateLinearRegressionWriter2
+org.apache.spark.ml.util.FakeLinearRegressionWriterWithName

http://git-wip-us.apache.org/repos/asf/spark/blob/95c03cbd/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 9b19f63..90ceb7d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -17,18 +17,23 @@
 
 package org.apache.spark.ml.regression
 
+import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.Random
 
+import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel}
+
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.LinearDataGenerator
 import org.apache.spark.sql.{DataFrame, Row}
 
-class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
+
+class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest {
 
   import testImplicits._
 
@@ -1052,6 +1057,24 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       LinearRegressionSuite.allParamSettings, checkModelData)
   }
 
+  test("pmml export") {
+    val lr = new LinearRegression()
+    val model = lr.fit(datasetWithWeight)
+    def checkModel(pmml: PMML): Unit = {
+      val dd = pmml.getDataDictionary
+      assert(dd.getNumberOfFields === 3)
+      val fields = dd.getDataFields.asScala
+      assert(fields(0).getName().toString === "field_0")
+      assert(fields(0).getOpType() == OpType.CONTINUOUS)
+      val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+      val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+      val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList
+      assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+      assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+    }
+    testPMMLWrite(sc, model, checkModel)
+  }
+
   test("should support all NumericType labels and weights, and not support other types") {
     for (solver <- Seq("auto", "l-bfgs", "normal")) {
       val lr = new LinearRegression().setMaxIter(1).setSolver(solver)

http://git-wip-us.apache.org/repos/asf/spark/blob/95c03cbd/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala
new file mode 100644
index 0000000..d2c4832
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.{File, IOException}
+
+import org.dmg.pmml.PMML
+import org.scalatest.Suite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+trait PMMLReadWriteTest extends TempDirectory { self: Suite =>
+  /**
+   * Test PMML export. Requires that the exported model is small enough to be loaded locally.
+   * Checks that the model can be exported, that saving again without `overwrite()` fails, and
+   * that the result is valid PMML; the parsed model is then passed to `checkModelData`.
+   */
+  def testPMMLWrite[T <: Params with GeneralMLWritable](sc: SparkContext, instance: T,
+    checkModelData: PMML => Unit): Unit = {
+    val uid = instance.uid
+    val subdirName = Identifiable.randomUID("pmml-")
+
+    val subdir = new File(tempDir, subdirName)
+    val path = new File(subdir, uid).getPath
+
+    instance.write.format("pmml").save(path)
+    intercept[IOException] {
+      instance.write.format("pmml").save(path)
+    }
+    instance.write.format("pmml").overwrite().save(path)
+    val pmmlStr = sc.textFile(path).collect.mkString("\n")
+    val pmmlModel = PMMLUtils.loadFromString(pmmlStr)
+    assert(pmmlModel.getHeader().getApplication().getName().startsWith("Apache Spark"))
+    checkModelData(pmmlModel)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/95c03cbd/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala
new file mode 100644
index 0000000..dbdc69f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.util
+
+import java.io.StringReader
+import javax.xml.bind.Unmarshaller
+import javax.xml.transform.Source
+
+import org.dmg.pmml._
+import org.jpmml.model.{ImportFilter, JAXBUtil}
+import org.xml.sax.InputSource
+
+/**
+ * Testing utils for working with PMML.
+ * Predictive Model Markup Language (PMML) is an XML-based file format
+ * developed by the Data Mining Group (www.dmg.org).
+ */
+private[spark] object PMMLUtils {
+  /**
+   * :: Experimental ::
+   * Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported
+   * through external spark-packages.
+   */
+  def loadFromString(input: String): PMML = {
+    val is = new StringReader(input)
+    val transformed = ImportFilter.apply(new InputSource(is))
+    JAXBUtil.unmarshalPMML(transformed)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/95c03cbd/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala
new file mode 100644
index 0000000..f4c1f0b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.PipelineStage
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.sql.{DataFrame, SparkSession}
+
+class FakeLinearRegressionWriter extends MLWriterFormat {
+  override def write(path: String, sparkSession: SparkSession,
+    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+    throw new Exception(s"Fake writer doesn't writestart")
+  }
+}
+
+class FakeLinearRegressionWriterWithName extends MLFormatRegister {
+  override def format(): String = "fakeWithName"
+  override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+  override def write(path: String, sparkSession: SparkSession,
+    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+    throw new Exception(s"Fake writer doesn't writestart")
+  }
+}
+
+
+class DuplicateLinearRegressionWriter1 extends MLFormatRegister {
+  override def format(): String = "dupe"
+  override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+  override def write(path: String, sparkSession: SparkSession,
+    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+    throw new Exception(s"Duplicate writer shouldn't have been called")
+  }
+}
+
+class DuplicateLinearRegressionWriter2 extends MLFormatRegister {
+  override def format(): String = "dupe"
+  override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+  override def write(path: String, sparkSession: SparkSession,
+    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+    throw new Exception(s"Duplicate writer shouldn't have been called")
+  }
+}
+
+class ReadWriteSuite extends MLTest {
+
+  import testImplicits._
+
+  private val seed: Int = 42
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    dataset = sc.parallelize(LinearDataGenerator.generateLinearInput(
+      intercept = 0.0, weights = Array(1.0, 2.0), xMean = Array(0.0, 1.0),
+      xVariance = Array(2.0, 1.0), nPoints = 10, seed, eps = 0.2)).map(_.asML).toDF()
+  }
+
+  test("unsupported/non existent export formats") {
+    val lr = new LinearRegression()
+    val model = lr.fit(dataset)
+    // Does not exist with a long class name
+    val thrownDNE = intercept[SparkException] {
+      model.write.format("com.holdenkarau.boop").save("boop")
+    }
+    assert(thrownDNE.getMessage().
+      contains("Could not load requested format"))
+
+    // Does not exist with a short name
+    val thrownDNEShort = intercept[SparkException] {
+      model.write.format("boop").save("boop")
+    }
+    assert(thrownDNEShort.getMessage().
+      contains("Could not load requested format"))
+
+    // Check with a valid class that is not a writer format.
+    val thrownInvalid = intercept[SparkException] {
+      model.write.format("org.apache.spark.SparkContext").save("boop2")
+    }
+    assert(thrownInvalid.getMessage()
+      .contains("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat"))
+  }
+
+  test("invalid paths fail") {
+    val lr = new LinearRegression()
+    val model = lr.fit(dataset)
+    val thrown = intercept[Exception] {
+      model.write.format("pmml").save("")
+    }
+    assert(thrown.getMessage().contains("Can not create a Path from an empty string"))
+  }
+
+  test("dummy export format is called") {
+    val lr = new LinearRegression()
+    val model = lr.fit(dataset)
+    val thrown = intercept[Exception] {
+      model.write.format("org.apache.spark.ml.util.FakeLinearRegressionWriter").save("name")
+    }
+    assert(thrown.getMessage().contains("Fake writer doesn't write"))
+    val thrownWithName = intercept[Exception] {
+      model.write.format("fakeWithName").save("name")
+    }
+    assert(thrownWithName.getMessage().contains("Fake writer doesn't write"))
+  }
+
+  test("duplicate format raises error") {
+    val lr = new LinearRegression()
+    val model = lr.fit(dataset)
+    val thrown = intercept[Exception] {
+      model.write.format("dupe").save("dupepanda")
+    }
+    assert(thrown.getMessage().contains("Multiple writers found for"))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org