You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/07 10:28:55 UTC

spark git commit: [SPARK-7429] [ML] Params cleanups

Repository: spark
Updated Branches:
  refs/heads/master 8b6b46e4f -> 4f87e9562


[SPARK-7429] [ML] Params cleanups

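The Params.setDefault overload that takes a variable number of ParamPairs should be annotated with @varargs so that it can be called from Java. I previously thought the annotation would cause compilation failures, but it apparently works.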

CrossValidatorModel.transform (in CrossValidator.scala) should call transformSchema, since the underlying best model might be a PipelineModel.

CC: mengxr

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #5960 from jkbradley/params-cleanups and squashes the following commits:

118b158 [Joseph K. Bradley] Params.setDefault taking a set of ParamPairs should be annotated with varargs. I thought it would not work before, but it apparently does. CrossValidator.transform should call transformSchema since the underlying Model might be a PipelineModel


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f87e956
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f87e956
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f87e956

Branch: refs/heads/master
Commit: 4f87e9562aa0dfe5467d7fbaba9278213106377c
Parents: 8b6b46e
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Thu May 7 01:28:44 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu May 7 01:28:44 2015 -0700

----------------------------------------------------------------------
 mllib/src/main/scala/org/apache/spark/ml/param/params.scala      | 4 +---
 .../main/scala/org/apache/spark/ml/tuning/CrossValidator.scala   | 3 ++-
 .../src/test/java/org/apache/spark/ml/param/JavaTestParams.java  | 1 +
 3 files changed, 4 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4f87e956/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 51ce19d..6d09962 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -366,13 +366,11 @@ trait Params extends Identifiable with Serializable {
   /**
    * Sets default values for a list of params.
    *
-   * Note: Java developers should use the single-parameter [[setDefault()]].
-   *       Annotating this with varargs causes compilation failures.
-   *
    * @param paramPairs  a list of param pairs that specify params and their default values to set
    *                    respectively. Make sure that the params are initialized before this method
    *                    gets called.
    */
+  @varargs
   protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
     paramPairs.foreach { p =>
       setDefault(p.param.asInstanceOf[Param[Any]], p.value)

http://git-wip-us.apache.org/repos/asf/spark/blob/4f87e956/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 9208127..ac0d1fe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -105,7 +105,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
 
   override def fit(dataset: DataFrame): CrossValidatorModel = {
     val schema = dataset.schema
-    transformSchema(dataset.schema, logging = true)
+    transformSchema(schema, logging = true)
     val sqlCtx = dataset.sqlContext
     val est = $(estimator)
     val eval = $(evaluator)
@@ -159,6 +159,7 @@ class CrossValidatorModel private[ml] (
   }
 
   override def transform(dataset: DataFrame): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     bestModel.transform(dataset)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4f87e956/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 8abe575..532eca4 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -59,5 +59,6 @@ public class JavaTestParams extends JavaParams {
       ParamValidators.inArray(validStrings));
     setDefault(myIntParam, 1);
     setDefault(myDoubleParam, 0.5);
+    setDefault(myIntParam.w(1), myDoubleParam.w(0.5));
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org