You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/19 19:55:28 UTC

spark git commit: [SPARK-7047] [ML] ml.Model optional parent support

Repository: spark
Updated Branches:
  refs/heads/master 32fa611b1 -> fb9027321


[SPARK-7047] [ML] ml.Model optional parent support

Made Model.parent transient. Added Model.hasParent to test for a null parent.

CC: mengxr

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #5914 from jkbradley/parent-optional and squashes the following commits:

d501774 [Joseph K. Bradley] Made Model.parent transient.  Added Model.hasParent to test for null parent


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fb902732
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fb902732
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fb902732

Branch: refs/heads/master
Commit: fb90273212dc7241c9a0c3446e25e0e0b9377750
Parents: 32fa611
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Tue May 19 10:55:21 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue May 19 10:55:21 2015 -0700

----------------------------------------------------------------------
 mllib/src/main/scala/org/apache/spark/ml/Model.scala            | 5 ++++-
 .../spark/ml/classification/LogisticRegressionSuite.scala       | 1 +
 .../spark/ml/classification/RandomForestClassifierSuite.scala   | 2 ++
 3 files changed, 7 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fb902732/mllib/src/main/scala/org/apache/spark/ml/Model.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index 7fd5153..70e7495 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -32,7 +32,7 @@ abstract class Model[M <: Model[M]] extends Transformer {
    * The parent estimator that produced this model.
    * Note: For ensembles' component Models, this value can be null.
    */
-  var parent: Estimator[M] = _
+  @transient var parent: Estimator[M] = _
 
   /**
    * Sets the parent of this model (Java API).
@@ -42,6 +42,9 @@ abstract class Model[M <: Model[M]] extends Transformer {
     this.asInstanceOf[M]
   }
 
+  /** Indicates whether this [[Model]] has a corresponding parent. */
+  def hasParent: Boolean = parent != null
+
   override def copy(extra: ParamMap): M = {
     // The default implementation of Params.copy doesn't work for models.
     throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")

http://git-wip-us.apache.org/repos/asf/spark/blob/fb902732/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 4376524..97f9749 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -83,6 +83,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     assert(model.getRawPredictionCol === "rawPrediction")
     assert(model.getProbabilityCol === "probability")
     assert(model.intercept !== 0.0)
+    assert(model.hasParent)
   }
 
   test("logistic regression doesn't fit intercept when fitIntercept is off") {

http://git-wip-us.apache.org/repos/asf/spark/blob/fb902732/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 08f86fa..cdbbaca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -162,5 +162,7 @@ private object RandomForestClassifierSuite {
     val oldModelAsNew = RandomForestClassificationModel.fromOld(
       oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
     TreeTests.checkEqual(oldModelAsNew, newModel)
+    assert(newModel.hasParent)
+    assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org