You are viewing a plain-text version of this content. The canonical link for it ("here") was not preserved in this text copy.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/05/07 21:55:44 UTC
spark git commit: [SPARK-22885][ML][TEST] ML test for StructuredStreaming: spark.ml.tuning
Repository: spark
Updated Branches:
refs/heads/master 1c9c5de95 -> f48bd6bdc
[SPARK-22885][ML][TEST] ML test for StructuredStreaming: spark.ml.tuning
## What changes were proposed in this pull request?
ML test for StructuredStreaming: spark.ml.tuning
## How was this patch tested?
N/A
Author: WeichenXu <we...@databricks.com>
Closes #20261 from WeichenXu123/ml_stream_tuning_test.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f48bd6bd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f48bd6bd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f48bd6bd
Branch: refs/heads/master
Commit: f48bd6bdc5aefd9ec43e2d0ee648d17add7ef554
Parents: 1c9c5de
Author: WeichenXu <we...@databricks.com>
Authored: Mon May 7 14:55:41 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon May 7 14:55:41 2018 -0700
----------------------------------------------------------------------
.../apache/spark/ml/tuning/CrossValidatorSuite.scala | 15 +++++++++++----
.../spark/ml/tuning/TrainValidationSplitSuite.scala | 15 +++++++++++----
2 files changed, 22 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/f48bd6bd/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 15dade2..e6ee722 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -25,17 +25,17 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator}
import org.apache.spark.ml.feature.HashingTF
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
-import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
class CrossValidatorSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+ extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -66,6 +66,13 @@ class CrossValidatorSuite
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(cvModel.avgMetrics.length === lrParamMaps.length)
+
+ val result = cvModel.transform(dataset).select("prediction").as[Double].collect()
+ testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") {
+ rows =>
+ val result2 = rows.map(_.getDouble(0))
+ assert(result === result2)
+ }
}
test("cross validation with linear regression") {
http://git-wip-us.apache.org/repos/asf/spark/blob/f48bd6bd/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 9024342..cd76acf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -24,17 +24,17 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest}
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
-import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
class TrainValidationSplitSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+ extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -64,6 +64,13 @@ class TrainValidationSplitSuite
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(tvsModel.validationMetrics.length === lrParamMaps.length)
+
+ val result = tvsModel.transform(dataset).select("prediction").as[Double].collect()
+ testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), tvsModel, "prediction") {
+ rows =>
+ val result2 = rows.map(_.getDouble(0))
+ assert(result === result2)
+ }
}
test("train validation with linear regression") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org