Posted to commits@spark.apache.org by jk...@apache.org on 2017/07/17 17:07:36 UTC
spark git commit: [SPARK-21221][ML] CrossValidator and TrainValidationSplit Persist Nested Estimators such as OneVsRest
Repository: spark
Updated Branches:
refs/heads/master 4ce735eed -> 7047f49f4
[SPARK-21221][ML] CrossValidator and TrainValidationSplit Persist Nested Estimators such as OneVsRest
## What changes were proposed in this pull request?
Added functionality for CrossValidator and TrainValidationSplit to persist nested estimators such as OneVsRest. Also added CrossValidator and TrainValidationSplit persistence to PySpark.
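For example, with this change a CrossValidator whose param grid varies the nested classifier of a OneVsRest estimator can be saved and reloaded from PySpark. A minimal sketch distilled from the new pyspark tests added below (the save path is hypothetical):

    from pyspark.ml.classification import LogisticRegression, OneVsRest
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    # Each grid value for ova.classifier is itself an estimator; persisting
    # such non-JSON-serializable param values is what this patch enables.
    ova = OneVsRest(classifier=LogisticRegression())
    grid = ParamGridBuilder() \
        .addGrid(ova.classifier, [LogisticRegression(maxIter=100),
                                  LogisticRegression(maxIter=150)]) \
        .build()
    cv = CrossValidator(estimator=ova, estimatorParamMaps=grid,
                        evaluator=MulticlassClassificationEvaluator())
    cv.save("/tmp/cv_ova")  # hypothetical path
    loaded = CrossValidator.load("/tmp/cv_ova")
    assert loaded.getEstimator().uid == ova.uid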
## How was this patch tested?
Performed both cross validation and train validation split with a one-vs-rest estimator and tested the read/write functionality of the estimator param maps required by these meta-algorithms.
Author: Ajay Saini <aj...@gmail.com>
Closes #18428 from ajaysaini725/MetaAlgorithmPersistNestedEstimators.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7047f49f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7047f49f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7047f49f
Branch: refs/heads/master
Commit: 7047f49f45406be3b4a9b0aa209b3021621392ca
Parents: 4ce735e
Author: Ajay Saini <aj...@gmail.com>
Authored: Mon Jul 17 10:07:32 2017 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Jul 17 10:07:32 2017 -0700
----------------------------------------------------------------------
.../spark/ml/tuning/ValidatorParams.scala | 31 ++-
.../spark/ml/tuning/CrossValidatorSuite.scala | 103 +++++++--
.../ml/tuning/TrainValidationSplitSuite.scala | 84 ++++++-
.../ml/tuning/ValidatorParamsSuiteHelpers.scala | 86 +++++++
.../spark/ml/util/DefaultReadWriteTest.scala | 1 -
python/pyspark/ml/classification.py | 92 +++++---
python/pyspark/ml/tests.py | 145 +++++++++++-
python/pyspark/ml/tuning.py | 226 ++++++++++++++++++-
python/pyspark/ml/wrapper.py | 2 +-
9 files changed, 696 insertions(+), 74 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index d55eb14..0ab6eed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -126,10 +126,26 @@ private[ml] object ValidatorParams {
extraMetadata: Option[JObject] = None): Unit = {
import org.json4s.JsonDSL._
+ var numParamsNotJson = 0
val estimatorParamMapsJson = compact(render(
instance.getEstimatorParamMaps.map { case paramMap =>
paramMap.toSeq.map { case ParamPair(p, v) =>
- Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+ v match {
+ case writeableObj: DefaultParamsWritable =>
+ val relativePath = "epm_" + p.name + numParamsNotJson
+ val paramPath = new Path(path, relativePath).toString
+ numParamsNotJson += 1
+ writeableObj.save(paramPath)
+ Map("parent" -> p.parent, "name" -> p.name,
+ "value" -> compact(render(JString(relativePath))),
+ "isJson" -> compact(render(JBool(false))))
+ case _: MLWritable =>
+ throw new NotImplementedError("ValidatorParams.saveImpl does not handle parameters " +
+ "of type: MLWritable that are not DefaultParamsWritable")
+ case _ =>
+ Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v),
+ "isJson" -> compact(render(JBool(true))))
+ }
}
}.toSeq
))
@@ -183,8 +199,17 @@ private[ml] object ValidatorParams {
val paramPairs = pMap.map { case pInfo: Map[String, String] =>
val est = uidToParams(pInfo("parent"))
val param = est.getParam(pInfo("name"))
- val value = param.jsonDecode(pInfo("value"))
- param -> value
+ // [SPARK-21221] introduced the isJson field
+ if (!pInfo.contains("isJson") ||
+ (pInfo.contains("isJson") && pInfo("isJson").toBoolean.booleanValue())) {
+ val value = param.jsonDecode(pInfo("value"))
+ param -> value
+ } else {
+ val relativePath = param.jsonDecode(pInfo("value")).toString
+ val value = DefaultParamsReader
+ .loadParamsInstance[MLWritable](new Path(path, relativePath).toString, sc)
+ param -> value
+ }
}
ParamMap(paramPairs: _*)
}.toArray
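For illustration, after this change an entry in the persisted estimatorParamMaps metadata for a nested estimator looks roughly like the following (a sketch with a hypothetical parent uid; the map values are strings, and the quoting of "value" follows from the JString/jsonEncode calls above):

    {"parent": "oneVsRest_abc123", "name": "classifier",
     "value": "\"epm_classifier0\"", "isJson": "false"}

The classifier itself is written via DefaultParamsWritable.save under the relative subdirectory epm_classifier0 next to the validator's own metadata, and on load the isJson flag tells the reader to resolve that relative path with DefaultParamsReader.loadParamsInstance, so moving the parent directory does not break loading.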
http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 2b4e6b5..2791ea7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -19,12 +19,12 @@ package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model, Pipeline}
-import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest}
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
-import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator}
import org.apache.spark.ml.feature.HashingTF
-import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
-import org.apache.spark.ml.param.{ParamMap, ParamPair}
+import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -153,7 +153,76 @@ class CrossValidatorSuite
s" LogisticRegression but found ${other.getClass.getName}")
}
- CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+ ValidatorParamsSuiteHelpers
+ .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+ }
+
+ test("read/write: CrossValidator with nested estimator") {
+ val ova = new OneVsRest().setClassifier(new LogisticRegression)
+ val evaluator = new MulticlassClassificationEvaluator()
+ .setMetricName("accuracy")
+ val classifier1 = new LogisticRegression().setRegParam(2.0)
+ val classifier2 = new LogisticRegression().setRegParam(3.0)
+ // params that are not JSON serializable must inherit from Params
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(ova.classifier, Array(classifier1, classifier2))
+ .build()
+ val cv = new CrossValidator()
+ .setEstimator(ova)
+ .setEvaluator(evaluator)
+ .setNumFolds(20)
+ .setEstimatorParamMaps(paramMaps)
+
+ val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+ assert(cv.uid === cv2.uid)
+ assert(cv.getNumFolds === cv2.getNumFolds)
+ assert(cv.getSeed === cv2.getSeed)
+
+ assert(cv2.getEvaluator.isInstanceOf[MulticlassClassificationEvaluator])
+ val evaluator2 = cv2.getEvaluator.asInstanceOf[MulticlassClassificationEvaluator]
+ assert(evaluator.uid === evaluator2.uid)
+ assert(evaluator.getMetricName === evaluator2.getMetricName)
+
+ cv2.getEstimator match {
+ case ova2: OneVsRest =>
+ assert(ova.uid === ova2.uid)
+ val classifier = ova2.getClassifier
+ classifier match {
+ case lr: LogisticRegression =>
+ assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter
+ === lr.getMaxIter)
+ case _ =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" LogisticREgression but found ${classifier.getClass.getName}")
+ }
+
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" OneVsRest but found ${other.getClass.getName}")
+ }
+
+ ValidatorParamsSuiteHelpers
+ .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+ }
+
+ test("read/write: Persistence of nested estimator works if parent directory changes") {
+ val ova = new OneVsRest().setClassifier(new LogisticRegression)
+ val evaluator = new MulticlassClassificationEvaluator()
+ .setMetricName("accuracy")
+ val classifier1 = new LogisticRegression().setRegParam(2.0)
+ val classifier2 = new LogisticRegression().setRegParam(3.0)
+ // params that are not JSON serializable must inherit from Params
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(ova.classifier, Array(classifier1, classifier2))
+ .build()
+ val cv = new CrossValidator()
+ .setEstimator(ova)
+ .setEvaluator(evaluator)
+ .setNumFolds(20)
+ .setEstimatorParamMaps(paramMaps)
+
+ ValidatorParamsSuiteHelpers.testFileMove(cv)
}
test("read/write: CrossValidator with complex estimator") {
@@ -193,7 +262,8 @@ class CrossValidatorSuite
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
assert(cv.getEvaluator.uid === cv2.getEvaluator.uid)
- CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+ ValidatorParamsSuiteHelpers
+ .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
cv2.getEstimator match {
case pipeline2: Pipeline =>
@@ -212,7 +282,8 @@ class CrossValidatorSuite
assert(lrcv.uid === lrcv2.uid)
assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
assert(lrEvaluator.uid === lrcv2.getEvaluator.uid)
- CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
+ ValidatorParamsSuiteHelpers
+ .compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
case other =>
throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" +
" but found: " + other.map(_.getClass.getName).mkString(", "))
@@ -278,7 +349,8 @@ class CrossValidatorSuite
s" LogisticRegression but found ${other.getClass.getName}")
}
- CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+ ValidatorParamsSuiteHelpers
+ .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
cv2.bestModel match {
case lrModel2: LogisticRegressionModel =>
@@ -296,21 +368,6 @@ class CrossValidatorSuite
object CrossValidatorSuite extends SparkFunSuite {
- /**
- * Assert sequences of estimatorParamMaps are identical.
- * Params must be simple types comparable with `===`.
- */
- def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = {
- assert(pMaps.length === pMaps2.length)
- pMaps.zip(pMaps2).foreach { case (pMap, pMap2) =>
- assert(pMap.size === pMap2.size)
- pMap.toSeq.foreach { case ParamPair(p, v) =>
- assert(pMap2.contains(p))
- assert(pMap2(p) === v)
- }
- }
- }
-
abstract class MyModel extends Model[MyModel]
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index a34f930..71a1776 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest}
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{ParamMap}
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -95,7 +95,7 @@ class TrainValidationSplitSuite
}
test("transformSchema should check estimatorParamMaps") {
- import TrainValidationSplitSuite._
+ import TrainValidationSplitSuite.{MyEstimator, MyEvaluator}
val est = new MyEstimator("est")
val eval = new MyEvaluator
@@ -134,6 +134,82 @@ class TrainValidationSplitSuite
assert(tvs.getTrainRatio === tvs2.getTrainRatio)
assert(tvs.getSeed === tvs2.getSeed)
+
+ ValidatorParamsSuiteHelpers
+ .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
+
+ tvs2.getEstimator match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ case other =>
+ throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+ }
+
+ test("read/write: TrainValidationSplit with nested estimator") {
+ val ova = new OneVsRest()
+ .setClassifier(new LogisticRegression)
+ val evaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+ val classifier1 = new LogisticRegression().setRegParam(2.0)
+ val classifier2 = new LogisticRegression().setRegParam(3.0)
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(ova.classifier, Array(classifier1, classifier2))
+ .build()
+ val tvs = new TrainValidationSplit()
+ .setEstimator(ova)
+ .setEvaluator(evaluator)
+ .setTrainRatio(0.5)
+ .setEstimatorParamMaps(paramMaps)
+ .setSeed(42L)
+
+ val tvs2 = testDefaultReadWrite(tvs, testParams = false)
+
+ assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+ assert(tvs.getSeed === tvs2.getSeed)
+
+ tvs2.getEstimator match {
+ case ova2: OneVsRest =>
+ assert(ova.uid === ova2.uid)
+ val classifier = ova2.getClassifier
+ classifier match {
+ case lr: LogisticRegression =>
+ assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter
+ === lr.getMaxIter)
+ case _ =>
+ throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" +
+ s" LogisticREgression but found ${classifier.getClass.getName}")
+ }
+
+ case other =>
+ throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" +
+ s" OneVsRest but found ${other.getClass.getName}")
+ }
+
+ ValidatorParamsSuiteHelpers
+ .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
+ }
+
+ test("read/write: Persistence of nested estimator works if parent directory changes") {
+ val ova = new OneVsRest()
+ .setClassifier(new LogisticRegression)
+ val evaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+ val classifier1 = new LogisticRegression().setRegParam(2.0)
+ val classifier2 = new LogisticRegression().setRegParam(3.0)
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(ova.classifier, Array(classifier1, classifier2))
+ .build()
+ val tvs = new TrainValidationSplit()
+ .setEstimator(ova)
+ .setEvaluator(evaluator)
+ .setTrainRatio(0.5)
+ .setEstimatorParamMaps(paramMaps)
+ .setSeed(42L)
+
+ ValidatorParamsSuiteHelpers.testFileMove(tvs)
}
test("read/write: TrainValidationSplitModel") {
@@ -160,7 +236,7 @@ class TrainValidationSplitSuite
}
}
-object TrainValidationSplitSuite {
+object TrainValidationSplitSuite extends SparkFunSuite {
abstract class MyModel extends Model[MyModel]
http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala
new file mode 100644
index 0000000..1df673c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import java.io.File
+import java.nio.file.{Files, StandardCopyOption}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLReader, MLWritable}
+
+object ValidatorParamsSuiteHelpers extends SparkFunSuite with DefaultReadWriteTest {
+ /**
+ * Assert sequences of estimatorParamMaps are identical.
+ * If the values for a parameter are not directly comparable with ===
+ * and are instead Params types themselves then their corresponding paramMaps
+ * are compared against each other.
+ */
+ def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = {
+ assert(pMaps.length === pMaps2.length)
+ pMaps.zip(pMaps2).foreach { case (pMap, pMap2) =>
+ assert(pMap.size === pMap2.size)
+ pMap.toSeq.foreach { case ParamPair(p, v) =>
+ assert(pMap2.contains(p))
+ val otherParam = pMap2(p)
+ v match {
+ case estimator: Params =>
+ otherParam match {
+ case estimator2: Params =>
+ val estimatorParamMap = Array(estimator.extractParamMap())
+ val estimatorParamMap2 = Array(estimator2.extractParamMap())
+ compareParamMaps(estimatorParamMap, estimatorParamMap2)
+ case other =>
+ throw new AssertionError(s"Expected parameter of type Params but" +
+ s" found ${otherParam.getClass.getName}")
+ }
+ case _ =>
+ assert(otherParam === v)
+ }
+ }
+ }
+ }
+
+ /**
+ * When nested estimators (e.g., OneVsRest) are saved within meta-algorithms such as
+ * CrossValidator and TrainValidationSplit, relative paths should be used to store
+ * the path of the estimator so that if the parent directory changes, loading the
+ * model still works.
+ */
+ def testFileMove[T <: Params with MLWritable](instance: T): Unit = {
+ val uid = instance.uid
+ val subdirName = Identifiable.randomUID("test")
+
+ val subdir = new File(tempDir, subdirName)
+ val subDirWithUid = new File(subdir, uid)
+
+ instance.save(subDirWithUid.getPath)
+
+ val newSubdirName = Identifiable.randomUID("test_moved")
+ val newSubdir = new File(tempDir, newSubdirName)
+ val newSubdirWithUid = new File(newSubdir, uid)
+
+ Files.createDirectory(newSubdir.toPath)
+ Files.createDirectory(newSubdirWithUid.toPath)
+ Files.move(subDirWithUid.toPath, newSubdirWithUid.toPath, StandardCopyOption.ATOMIC_MOVE)
+
+ val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
+ val newInstance = loader.load(newSubdirWithUid.getPath)
+ assert(uid == newInstance.uid)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 27d606c..4da95e7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -55,7 +55,6 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
instance.write.overwrite().save(path)
val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
val newInstance = loader.load(path)
-
assert(newInstance.uid === instance.uid)
if (testParams) {
instance.params.foreach { p =>
http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 948806a..82207f6 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -25,7 +25,7 @@ from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
-from pyspark.ml.common import inherit_doc
+from pyspark.ml.common import inherit_doc, _java2py, _py2java
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
@@ -1472,7 +1472,7 @@ class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
@inherit_doc
-class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
+class OneVsRest(Estimator, OneVsRestParams, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1589,22 +1589,6 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
newOvr.setClassifier(self.getClassifier().copy(extra))
return newOvr
- @since("2.0.0")
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
- @since("2.0.0")
- def save(self, path):
- """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
- self.write().save(path)
-
- @classmethod
- @since("2.0.0")
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
-
@classmethod
def _from_java(cls, java_stage):
"""
@@ -1634,8 +1618,52 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
_java_obj.setPredictionCol(self.getPredictionCol())
return _java_obj
+ def _make_java_param_pair(self, param, value):
+ """
+ Makes a Java param pair.
+ """
+ sc = SparkContext._active_spark_context
+ param = self._resolveParam(param)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
+ self.uid)
+ java_param = _java_obj.getParam(param.name)
+ if isinstance(value, JavaParams):
+ # used in the case of an estimator having another estimator as a parameter
+ # the reason why this is not in _py2java in common.py is that importing
+ # Estimator and Model in common.py results in a circular import with inherit_doc
+ java_value = value._to_java()
+ else:
+ java_value = _py2java(sc, value)
+ return java_param.w(java_value)
-class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
+ def _transfer_param_map_to_java(self, pyParamMap):
+ """
+ Transforms a Python ParamMap into a Java ParamMap.
+ """
+ paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+ for param in self.params:
+ if param in pyParamMap:
+ pair = self._make_java_param_pair(param, pyParamMap[param])
+ paramMap.put([pair])
+ return paramMap
+
+ def _transfer_param_map_from_java(self, javaParamMap):
+ """
+ Transforms a Java ParamMap into a Python ParamMap.
+ """
+ sc = SparkContext._active_spark_context
+ paramMap = dict()
+ for pair in javaParamMap.toList():
+ param = pair.param()
+ if self.hasParam(str(param.name())):
+ if param.name() == "classifier":
+ paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
+ else:
+ paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
+ return paramMap
+
+
+class OneVsRestModel(Model, OneVsRestParams, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1650,6 +1678,16 @@ class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
def __init__(self, models):
super(OneVsRestModel, self).__init__()
self.models = models
+ java_models = [model._to_java() for model in self.models]
+ sc = SparkContext._active_spark_context
+ java_models_array = JavaWrapper._new_java_array(java_models,
+ sc._gateway.jvm.org.apache.spark.ml
+ .classification.ClassificationModel)
+ # TODO: need to set metadata
+ metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
+ self._java_obj = \
+ JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
+ self.uid, metadata.empty(), java_models_array)
def _transform(self, dataset):
# determine the input columns: these need to be passed through
@@ -1715,22 +1753,6 @@ class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
newModel.models = [model.copy(extra) for model in self.models]
return newModel
- @since("2.0.0")
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
- @since("2.0.0")
- def save(self, path):
- """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
- self.write().save(path)
-
- @classmethod
- @since("2.0.0")
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
-
@classmethod
def _from_java(cls, java_stage):
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 7870047..6c71e69 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -49,7 +49,8 @@ from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
from pyspark.ml.classification import *
from pyspark.ml.clustering import *
from pyspark.ml.common import _java2py, _py2java
-from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
+from pyspark.ml.evaluation import BinaryClassificationEvaluator, \
+ MulticlassClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.fpm import FPGrowth, FPGrowthModel
from pyspark.ml.linalg import DenseMatrix, DenseVector, Matrices, MatrixUDT, \
@@ -678,7 +679,7 @@ class CrossValidatorTests(SparkSessionTestCase):
"Best model should have zero induced error")
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
- def test_save_load(self):
+ def test_save_load_trained_model(self):
# This tests saving and loading the trained model only.
# Save/load for CrossValidator will be added later: SPARK-13786
temp_path = tempfile.mkdtemp()
@@ -702,6 +703,76 @@ class CrossValidatorTests(SparkSessionTestCase):
self.assertEqual(loadedLrModel.uid, lrModel.uid)
self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
+ def test_save_load_simple_estimator(self):
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+
+ # test save/load of CrossValidator
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ cvPath = temp_path + "/cv"
+ cv.save(cvPath)
+ loadedCV = CrossValidator.load(cvPath)
+ self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+ self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+ self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+
+ # test save/load of CrossValidatorModel
+ cvModelPath = temp_path + "/cvModel"
+ cvModel.save(cvModelPath)
+ loadedModel = CrossValidatorModel.load(cvModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
+ def test_save_load_nested_estimator(self):
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+
+ ova = OneVsRest(classifier=LogisticRegression())
+ lr1 = LogisticRegression().setMaxIter(100)
+ lr2 = LogisticRegression().setMaxIter(150)
+ grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
+ evaluator = MulticlassClassificationEvaluator()
+
+ # test save/load of CrossValidator
+ cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ cvPath = temp_path + "/cv"
+ cv.save(cvPath)
+ loadedCV = CrossValidator.load(cvPath)
+ self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+ self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+
+ originalParamMap = cv.getEstimatorParamMaps()
+ loadedParamMap = loadedCV.getEstimatorParamMaps()
+ for i, param in enumerate(loadedParamMap):
+ for p in param:
+ if p.name == "classifier":
+ self.assertEqual(param[p].uid, originalParamMap[i][p].uid)
+ else:
+ self.assertEqual(param[p], originalParamMap[i][p])
+
+ # test save/load of CrossValidatorModel
+ cvModelPath = temp_path + "/cvModel"
+ cvModel.save(cvModelPath)
+ loadedModel = CrossValidatorModel.load(cvModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
class TrainValidationSplitTests(SparkSessionTestCase):
@@ -759,7 +830,7 @@ class TrainValidationSplitTests(SparkSessionTestCase):
"validationMetrics has the same size of grid parameter")
self.assertEqual(1.0, max(validationMetrics))
- def test_save_load(self):
+ def test_save_load_trained_model(self):
# This tests saving and loading the trained model only.
# Save/load for TrainValidationSplit will be added later: SPARK-13786
temp_path = tempfile.mkdtemp()
@@ -783,6 +854,74 @@ class TrainValidationSplitTests(SparkSessionTestCase):
self.assertEqual(loadedLrModel.uid, lrModel.uid)
self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
+ def test_save_load_simple_estimator(self):
+ # This tests saving and loading the trained model only.
+ # Save/load for TrainValidationSplit will be added later: SPARK-13786
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+
+ tvsPath = temp_path + "/tvs"
+ tvs.save(tvsPath)
+ loadedTvs = TrainValidationSplit.load(tvsPath)
+ self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+ self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+ self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+
+ tvsModelPath = temp_path + "/tvsModel"
+ tvsModel.save(tvsModelPath)
+ loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
+ def test_save_load_nested_estimator(self):
+ # This tests saving and loading the trained model only.
+ # Save/load for TrainValidationSplit will be added later: SPARK-13786
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ ova = OneVsRest(classifier=LogisticRegression())
+ lr1 = LogisticRegression().setMaxIter(100)
+ lr2 = LogisticRegression().setMaxIter(150)
+ grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
+ evaluator = MulticlassClassificationEvaluator()
+
+ tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ tvsPath = temp_path + "/tvs"
+ tvs.save(tvsPath)
+ loadedTvs = TrainValidationSplit.load(tvsPath)
+ self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+ self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+
+ originalParamMap = tvs.getEstimatorParamMaps()
+ loadedParamMap = loadedTvs.getEstimatorParamMaps()
+ for i, param in enumerate(loadedParamMap):
+ for p in param:
+ if p.name == "classifier":
+ self.assertEqual(param[p].uid, originalParamMap[i][p].uid)
+ else:
+ self.assertEqual(param[p], originalParamMap[i][p])
+
+ tvsModelPath = temp_path + "/tvsModel"
+ tvsModel.save(tvsModelPath)
+ loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
def test_copy(self):
dataset = self.spark.createDataFrame([
(10, 10.0),
http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index b648582..00c348a 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -20,8 +20,11 @@ import numpy as np
from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
+from pyspark.ml.common import _py2java
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
+from pyspark.ml.util import *
+from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
@@ -137,8 +140,37 @@ class ValidatorParams(HasSeed):
"""
return self.getOrDefault(self.evaluator)
+ @classmethod
+ def _from_java_impl(cls, java_stage):
+ """
+ Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.
+ """
+
+ # Load information from java_stage to the instance.
+ estimator = JavaParams._from_java(java_stage.getEstimator())
+ evaluator = JavaParams._from_java(java_stage.getEvaluator())
+ epms = [estimator._transfer_param_map_from_java(epm)
+ for epm in java_stage.getEstimatorParamMaps()]
+ return estimator, epms, evaluator
+
+ def _to_java_impl(self):
+ """
+ Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
+
+ java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
+ for idx, epm in enumerate(self.getEstimatorParamMaps()):
+ java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
-class CrossValidator(Estimator, ValidatorParams):
+ java_estimator = self.getEstimator()._to_java()
+ java_evaluator = self.getEvaluator()._to_java()
+ return java_estimator, java_epms, java_evaluator
+
+
+class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
"""
K-fold cross validation performs model selection by splitting the dataset into a set of
@@ -263,8 +295,53 @@ class CrossValidator(Estimator, ValidatorParams):
newCV.setEvaluator(self.getEvaluator().copy(extra))
return newCV
+ @since("2.3.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @classmethod
+ @since("2.3.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java CrossValidator, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
-class CrossValidatorModel(Model, ValidatorParams):
+ estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
+ numFolds = java_stage.getNumFolds()
+ seed = java_stage.getSeed()
+ # Create a new instance of this stage.
+ py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+ numFolds=numFolds, seed=seed)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java CrossValidator. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
+
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+ _java_obj.setEstimatorParamMaps(epms)
+ _java_obj.setEvaluator(evaluator)
+ _java_obj.setEstimator(estimator)
+ _java_obj.setSeed(self.getSeed())
+ _java_obj.setNumFolds(self.getNumFolds())
+
+ return _java_obj
+
+
+class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
CrossValidatorModel contains the model with the highest average cross-validation
@@ -302,8 +379,55 @@ class CrossValidatorModel(Model, ValidatorParams):
avgMetrics = self.avgMetrics
return CrossValidatorModel(bestModel, avgMetrics)
+ @since("2.3.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @classmethod
+ @since("2.3.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
-class TrainValidationSplit(Estimator, ValidatorParams):
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java CrossValidatorModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ bestModel = JavaParams._from_java(java_stage.bestModel())
+ estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
+
+ py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+ py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
+
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ sc = SparkContext._active_spark_context
+ # TODO: persist average metrics as well
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
+ estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
+
+ _java_obj.set("evaluator", evaluator)
+ _java_obj.set("estimator", estimator)
+ _java_obj.set("estimatorParamMaps", epms)
+ return _java_obj
+
+
+class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -418,8 +542,53 @@ class TrainValidationSplit(Estimator, ValidatorParams):
newTVS.setEvaluator(self.getEvaluator().copy(extra))
return newTVS
+ @since("2.3.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @classmethod
+ @since("2.3.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java TrainValidationSplit, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
+ trainRatio = java_stage.getTrainRatio()
+ seed = java_stage.getSeed()
+ # Create a new instance of this stage.
+ py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+ trainRatio=trainRatio, seed=seed)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.
+ :return: Java object equivalent to this instance.
+ """
+
+ estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
-class TrainValidationSplitModel(Model, ValidatorParams):
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+ self.uid)
+ _java_obj.setEstimatorParamMaps(epms)
+ _java_obj.setEvaluator(evaluator)
+ _java_obj.setEstimator(estimator)
+ _java_obj.setTrainRatio(self.getTrainRatio())
+ _java_obj.setSeed(self.getSeed())
+
+ return _java_obj
+
+
+class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -456,6 +625,55 @@ class TrainValidationSplitModel(Model, ValidatorParams):
validationMetrics = list(self.validationMetrics)
return TrainValidationSplitModel(bestModel, validationMetrics)
+ @since("2.3.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @classmethod
+ @since("2.3.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ # Load information from java_stage to the instance.
+ bestModel = JavaParams._from_java(java_stage.bestModel())
+ estimator, epms, evaluator = super(TrainValidationSplitModel,
+ cls)._from_java_impl(java_stage)
+ # Create a new instance of this stage.
+ py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+ py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
+
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.
+ :return: Java object equivalent to this instance.
+ """
+
+ sc = SparkContext._active_spark_context
+ # TODO: persist validation metrics as well
+ _java_obj = JavaParams._new_java_obj(
+ "org.apache.spark.ml.tuning.TrainValidationSplitModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
+ estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
+
+ _java_obj.set("evaluator", evaluator)
+ _java_obj.set("estimator", estimator)
+ _java_obj.set("estimatorParamMaps", epms)
+ return _java_obj
+
if __name__ == "__main__":
import doctest
http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 80a0b31..ee6301e 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -106,7 +106,7 @@ class JavaParams(JavaWrapper, Params):
def _make_java_param_pair(self, param, value):
"""
- Makes a Java parm pair.
+ Makes a Java param pair.
"""
sc = SparkContext._active_spark_context
param = self._resolveParam(param)