You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/03/31 18:39:18 UTC

spark git commit: [SPARK-14164][MLLIB] Improve input layer validation of MultilayerPerceptronClassifier

Repository: spark
Updated Branches:
  refs/heads/master a9b93e073 -> 208fff3ac


[SPARK-14164][MLLIB] Improve input layer validation of MultilayerPerceptronClassifier

## What changes were proposed in this pull request?

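This pull request improves validation of the `layers` parameter of MultilayerPerceptronClassifier. The parameter now requires more than one layer, and every layer size must be greater than 0. The pull request also adds related test cases.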

```scala
-    // TODO: how to check ALSO that all elements are greater than 0?
-    ParamValidators.arrayLengthGt(1)
+    (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1
```

## How was this patch tested?

Passes the Jenkins tests, including the new test cases.

Author: Dongjoon Hyun <do...@apache.org>

Closes #11964 from dongjoon-hyun/SPARK-14164.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/208fff3a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/208fff3a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/208fff3a

Branch: refs/heads/master
Commit: 208fff3ac87f200fd4e6f0407d70bf81cf8c556f
Parents: a9b93e0
Author: Dongjoon Hyun <do...@apache.org>
Authored: Thu Mar 31 09:39:15 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Mar 31 09:39:15 2016 -0700

----------------------------------------------------------------------
 .../MultilayerPerceptronClassifier.scala           |  3 +--
 .../MultilayerPerceptronClassifierSuite.scala      | 17 +++++++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/208fff3a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index f6de5f2..7ce3ec6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -43,8 +43,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
     "Sizes of layers from input layer to output layer" +
       " E.g., Array(780, 100, 10) means 780 inputs, " +
       "one hidden layer with 100 neurons and output layer of 10 neurons.",
-    // TODO: how to check ALSO that all elements are greater than 0?
-    ParamValidators.arrayLengthGt(1)
+    (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1
   )
 
   /** @group getParam */

http://git-wip-us.apache.org/repos/asf/spark/blob/208fff3a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 5df8e6a..53c7a55 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -43,6 +43,23 @@ class MultilayerPerceptronClassifierSuite
     ).toDF("features", "label")
   }
 
+  test("Input Validation") {
+    val mlpc = new MultilayerPerceptronClassifier()
+    intercept[IllegalArgumentException] {
+      mlpc.setLayers(Array[Int]())
+    }
+    intercept[IllegalArgumentException] {
+      mlpc.setLayers(Array[Int](1))
+    }
+    intercept[IllegalArgumentException] {
+      mlpc.setLayers(Array[Int](0, 1))
+    }
+    intercept[IllegalArgumentException] {
+      mlpc.setLayers(Array[Int](1, 0))
+    }
+    mlpc.setLayers(Array[Int](1, 1))
+  }
+
   test("XOR function learning as binary classification problem with two outputs.") {
     val layers = Array[Int](2, 5, 2)
     val trainer = new MultilayerPerceptronClassifier()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org