You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/09/05 12:04:00 UTC

spark git commit: [SPARK-10013] [ML] [JAVA] [TEST] remove java assert from java unit tests

Repository: spark
Updated Branches:
  refs/heads/master bca8c072b -> 871764c6c


[SPARK-10013] [ML] [JAVA] [TEST] remove java assert from java unit tests

From Jira: We should use assertTrue, etc. instead to make sure the asserts are not ignored in tests.

Author: Holden Karau <ho...@pigscanfly.ca>

Closes #8607 from holdenk/SPARK-10013-remove-java-assert-from-java-unit-tests.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/871764c6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/871764c6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/871764c6

Branch: refs/heads/master
Commit: 871764c6ce531af5b1ac7ccccb32e7a903b59a2a
Parents: bca8c07
Author: Holden Karau <ho...@pigscanfly.ca>
Authored: Sat Sep 5 00:04:00 2015 -1000
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat Sep 5 00:04:00 2015 -1000

----------------------------------------------------------------------
 .../JavaLogisticRegressionSuite.java            | 51 ++++++++++----------
 .../ml/classification/JavaNaiveBayesSuite.java  | 13 ++---
 .../regression/JavaLinearRegressionSuite.java   |  2 +-
 .../spark/mllib/linalg/JavaMatricesSuite.java   | 40 +++++++--------
 4 files changed, 54 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/871764c6/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 618b95b..fd22eb6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -22,6 +22,7 @@ import java.lang.Math;
 import java.util.List;
 
 import org.junit.After;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -63,16 +64,16 @@ public class JavaLogisticRegressionSuite implements Serializable {
   @Test
   public void logisticRegressionDefaultParams() {
     LogisticRegression lr = new LogisticRegression();
-    assert(lr.getLabelCol().equals("label"));
+    Assert.assertEquals(lr.getLabelCol(), "label");
     LogisticRegressionModel model = lr.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
     DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
     predictions.collectAsList();
     // Check defaults
-    assert(model.getThreshold() == 0.5);
-    assert(model.getFeaturesCol().equals("features"));
-    assert(model.getPredictionCol().equals("prediction"));
-    assert(model.getProbabilityCol().equals("probability"));
+    Assert.assertEquals(0.5, model.getThreshold(), eps);
+    Assert.assertEquals("features", model.getFeaturesCol());
+    Assert.assertEquals("prediction", model.getPredictionCol());
+    Assert.assertEquals("probability", model.getProbabilityCol());
   }
 
   @Test
@@ -85,19 +86,19 @@ public class JavaLogisticRegressionSuite implements Serializable {
       .setProbabilityCol("myProbability");
     LogisticRegressionModel model = lr.fit(dataset);
     LogisticRegression parent = (LogisticRegression) model.parent();
-    assert(parent.getMaxIter() == 10);
-    assert(parent.getRegParam() == 1.0);
-    assert(parent.getThresholds()[0] == 0.4);
-    assert(parent.getThresholds()[1] == 0.6);
-    assert(parent.getThreshold() == 0.6);
-    assert(model.getThreshold() == 0.6);
+    Assert.assertEquals(10, parent.getMaxIter());
+    Assert.assertEquals(1.0, parent.getRegParam(), eps);
+    Assert.assertEquals(0.4, parent.getThresholds()[0], eps);
+    Assert.assertEquals(0.6, parent.getThresholds()[1], eps);
+    Assert.assertEquals(0.6, parent.getThreshold(), eps);
+    Assert.assertEquals(0.6, model.getThreshold(), eps);
 
     // Modify model params, and check that the params worked.
     model.setThreshold(1.0);
     model.transform(dataset).registerTempTable("predAllZero");
     DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
     for (Row r: predAllZero.collectAsList()) {
-      assert(r.getDouble(0) == 0.0);
+      Assert.assertEquals(0.0, r.getDouble(0), eps);
     }
     // Call transform with params, and check that the params worked.
     model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
@@ -107,17 +108,17 @@ public class JavaLogisticRegressionSuite implements Serializable {
     for (Row r: predNotAllZero.collectAsList()) {
       if (r.getDouble(0) != 0.0) foundNonZero = true;
     }
-    assert(foundNonZero);
+    Assert.assertTrue(foundNonZero);
 
     // Call fit() with new params, and check as many params as we can.
     LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
         lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
     LogisticRegression parent2 = (LogisticRegression) model2.parent();
-    assert(parent2.getMaxIter() == 5);
-    assert(parent2.getRegParam() == 0.1);
-    assert(parent2.getThreshold() == 0.4);
-    assert(model2.getThreshold() == 0.4);
-    assert(model2.getProbabilityCol().equals("theProb"));
+    Assert.assertEquals(5, parent2.getMaxIter());
+    Assert.assertEquals(0.1, parent2.getRegParam(), eps);
+    Assert.assertEquals(0.4, parent2.getThreshold(), eps);
+    Assert.assertEquals(0.4, model2.getThreshold(), eps);
+    Assert.assertEquals("theProb", model2.getProbabilityCol());
   }
 
   @SuppressWarnings("unchecked")
@@ -125,18 +126,18 @@ public class JavaLogisticRegressionSuite implements Serializable {
   public void logisticRegressionPredictorClassifierMethods() {
     LogisticRegression lr = new LogisticRegression();
     LogisticRegressionModel model = lr.fit(dataset);
-    assert(model.numClasses() == 2);
+    Assert.assertEquals(2, model.numClasses());
 
     model.transform(dataset).registerTempTable("transformed");
     DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
     for (Row row: trans1.collect()) {
       Vector raw = (Vector)row.get(0);
       Vector prob = (Vector)row.get(1);
-      assert(raw.size() == 2);
-      assert(prob.size() == 2);
+      Assert.assertEquals(raw.size(), 2);
+      Assert.assertEquals(prob.size(), 2);
       double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
-      assert(Math.abs(prob.apply(1) - probFromRaw1) < eps);
-      assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
+      Assert.assertEquals(0, Math.abs(prob.apply(1) - probFromRaw1), eps);
+      Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
     }
 
     DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
@@ -145,7 +146,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
       Vector prob = (Vector)row.get(1);
       double probOfPred = prob.apply((int)pred);
       for (int i = 0; i < prob.size(); ++i) {
-        assert(probOfPred >= prob.apply(i));
+        Assert.assertTrue(probOfPred >= prob.apply(i));
       }
     }
   }
@@ -156,6 +157,6 @@ public class JavaLogisticRegressionSuite implements Serializable {
     LogisticRegressionModel model = lr.fit(dataset);
 
     LogisticRegressionTrainingSummary summary = model.summary();
-    assert(summary.totalIterations() == summary.objectiveHistory().length);
+    Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/871764c6/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 8fd7bf5..075a62c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -23,6 +23,7 @@ import java.util.Arrays;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
+import static org.junit.Assert.assertEquals;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -58,18 +59,18 @@ public class JavaNaiveBayesSuite implements Serializable {
     for (Row r : predictionAndLabels.collect()) {
       double prediction = r.getAs(0);
       double label = r.getAs(1);
-      assert(prediction == label);
+      assertEquals(label, prediction, 1E-5);
     }
   }
 
   @Test
   public void naiveBayesDefaultParams() {
     NaiveBayes nb = new NaiveBayes();
-    assert(nb.getLabelCol() == "label");
-    assert(nb.getFeaturesCol() == "features");
-    assert(nb.getPredictionCol() == "prediction");
-    assert(nb.getSmoothing() == 1.0);
-    assert(nb.getModelType() == "multinomial");
+    assertEquals("label", nb.getLabelCol());
+    assertEquals("features", nb.getFeaturesCol());
+    assertEquals("prediction", nb.getPredictionCol());
+    assertEquals(1.0, nb.getSmoothing(), 1E-5);
+    assertEquals("multinomial", nb.getModelType());
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/spark/blob/871764c6/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index d591a45..91c589d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -60,7 +60,7 @@ public class JavaLinearRegressionSuite implements Serializable {
   @Test
   public void linearRegressionDefaultParams() {
     LinearRegression lr = new LinearRegression();
-    assert(lr.getLabelCol().equals("label"));
+    assertEquals("label", lr.getLabelCol());
     LinearRegressionModel model = lr.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
     DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");

http://git-wip-us.apache.org/repos/asf/spark/blob/871764c6/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
index 3349c50..8beea10 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
@@ -80,10 +80,10 @@ public class JavaMatricesSuite implements Serializable {
         assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
         assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
         assertArrayEquals(s.values(), ss.values(), 0.0);
-        assert(s.values().length == 2);
-        assert(ss.values().length == 2);
-        assert(s.colPtrs().length == 4);
-        assert(ss.colPtrs().length == 4);
+        assertEquals(2, s.values().length);
+        assertEquals(2, ss.values().length);
+        assertEquals(4, s.colPtrs().length);
+        assertEquals(4, ss.colPtrs().length);
     }
 
     @Test
@@ -137,27 +137,27 @@ public class JavaMatricesSuite implements Serializable {
         Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
         Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
 
-        assert(deHorz1.numRows() == 3);
-        assert(deHorz2.numRows() == 3);
-        assert(deHorz3.numRows() == 3);
-        assert(spHorz.numRows() == 3);
-        assert(deHorz1.numCols() == 5);
-        assert(deHorz2.numCols() == 5);
-        assert(deHorz3.numCols() == 5);
-        assert(spHorz.numCols() == 5);
+        assertEquals(3, deHorz1.numRows());
+        assertEquals(3, deHorz2.numRows());
+        assertEquals(3, deHorz3.numRows());
+        assertEquals(3, spHorz.numRows());
+        assertEquals(5, deHorz1.numCols());
+        assertEquals(5, deHorz2.numCols());
+        assertEquals(5, deHorz3.numCols());
+        assertEquals(5, spHorz.numCols());
 
         Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
         Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
         Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
         Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
 
-        assert(deVert1.numRows() == 5);
-        assert(deVert2.numRows() == 5);
-        assert(deVert3.numRows() == 5);
-        assert(spVert.numRows() == 5);
-        assert(deVert1.numCols() == 2);
-        assert(deVert2.numCols() == 2);
-        assert(deVert3.numCols() == 2);
-        assert(spVert.numCols() == 2);
+        assertEquals(5, deVert1.numRows());
+        assertEquals(5, deVert2.numRows());
+        assertEquals(5, deVert3.numRows());
+        assertEquals(5, spVert.numRows());
+        assertEquals(2, deVert1.numCols());
+        assertEquals(2, deVert2.numCols());
+        assertEquals(2, deVert3.numCols());
+        assertEquals(2, spVert.numCols());
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org