Posted to commits@spark.apache.org by jk...@apache.org on 2018/05/07 21:55:44 UTC

spark git commit: [SPARK-22885][ML][TEST] ML test for StructuredStreaming: spark.ml.tuning

Repository: spark
Updated Branches:
  refs/heads/master 1c9c5de95 -> f48bd6bdc


[SPARK-22885][ML][TEST] ML test for StructuredStreaming: spark.ml.tuning

## What changes were proposed in this pull request?

Convert the spark.ml.tuning test suites (CrossValidatorSuite and TrainValidationSplitSuite) to the MLTest framework, and add a check that the fitted model's "prediction" column from a plain batch transform matches the output collected via testTransformerByGlobalCheckFunc, which also exercises the Structured Streaming path.
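
For context, below is a minimal sketch of the pattern this patch applies, mirroring the usage added in the diff further down. The suite and helper names (`ExamplePredictionConsistencySuite`, `checkPredictionsConsistent`) are hypothetical and not part of this commit; only the `MLTest` trait and `testTransformerByGlobalCheckFunc` call shape are taken from the change itself.

```scala
// Minimal sketch only; mirrors the usage added in the diff below.
// Suite and helper names here are illustrative, not part of this commit.
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.DataFrame

class ExamplePredictionConsistencySuite extends MLTest {

  import testImplicits._

  // `model` is any fitted Transformer that appends a "prediction" column;
  // `dataset` holds (label: Double, features: Vector) rows.
  def checkPredictionsConsistent(model: Transformer, dataset: DataFrame): Unit = {
    // Reference result from an ordinary batch transform.
    val batchResult =
      model.transform(dataset).select("prediction").as[Double].collect()

    // The MLTest helper replays the same data (including via a streaming
    // source, per this PR's goal) and hands the collected output rows to the
    // check function, so the assertion verifies the predictions agree.
    testTransformerByGlobalCheckFunc[(Double, Vector)](dataset, model, "prediction") { rows =>
      val streamingResult = rows.map(_.getDouble(0))
      assert(batchResult === streamingResult)
    }
  }
}
```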

## How was this patch tested?

N/A

Author: WeichenXu <we...@databricks.com>

Closes #20261 from WeichenXu123/ml_stream_tuning_test.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f48bd6bd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f48bd6bd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f48bd6bd

Branch: refs/heads/master
Commit: f48bd6bdc5aefd9ec43e2d0ee648d17add7ef554
Parents: 1c9c5de
Author: WeichenXu <we...@databricks.com>
Authored: Mon May 7 14:55:41 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon May 7 14:55:41 2018 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/tuning/CrossValidatorSuite.scala | 15 +++++++++++----
 .../spark/ml/tuning/TrainValidationSplitSuite.scala  | 15 +++++++++++----
 2 files changed, 22 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f48bd6bd/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 15dade2..e6ee722 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -25,17 +25,17 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
 import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator}
 import org.apache.spark.ml.feature.HashingTF
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
-import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.mllib.util.LinearDataGenerator
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
 
 class CrossValidatorSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+  extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -66,6 +66,13 @@ class CrossValidatorSuite
     assert(parent.getRegParam === 0.001)
     assert(parent.getMaxIter === 10)
     assert(cvModel.avgMetrics.length === lrParamMaps.length)
+
+    val result = cvModel.transform(dataset).select("prediction").as[Double].collect()
+    testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") {
+      rows =>
+        val result2 = rows.map(_.getDouble(0))
+        assert(result === result2)
+    }
   }
 
   test("cross validation with linear regression") {

http://git-wip-us.apache.org/repos/asf/spark/blob/f48bd6bd/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 9024342..cd76acf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -24,17 +24,17 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest}
 import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
-import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.mllib.util.LinearDataGenerator
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
 
 class TrainValidationSplitSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+  extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -64,6 +64,13 @@ class TrainValidationSplitSuite
     assert(parent.getRegParam === 0.001)
     assert(parent.getMaxIter === 10)
     assert(tvsModel.validationMetrics.length === lrParamMaps.length)
+
+    val result = tvsModel.transform(dataset).select("prediction").as[Double].collect()
+    testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), tvsModel, "prediction") {
+      rows =>
+        val result2 = rows.map(_.getDouble(0))
+        assert(result === result2)
+    }
   }
 
   test("train validation with linear regression") {

