You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2019/05/03 23:17:24 UTC
[spark] branch master updated: [SPARK-27621][ML] Linear Regression
- validate training related params such as loss only during fitting phase
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4241a72 [SPARK-27621][ML] Linear Regression - validate training related params such as loss only during fitting phase
4241a72 is described below
commit 4241a72c654f13b6b4ceafb27daceb7bb553add6
Author: asarb <as...@expedia.com>
AuthorDate: Fri May 3 18:17:04 2019 -0500
[SPARK-27621][ML] Linear Regression - validate training related params such as loss only during fitting phase
## What changes were proposed in this pull request?
When the transform(...) method is called on a LinearRegressionModel created directly from coefficients and an intercept, the following exception is encountered.
```
java.util.NoSuchElementException: Failed to find a default value for loss
at org.apache.spark.ml.param.Params$$anonfun$getOrDefault$2.apply(params.scala:780)
at org.apache.spark.ml.param.Params$$anonfun$getOrDefault$2.apply(params.scala:780)
at scala.Option.getOrElse(Option.scala:121)
at org.apache.spark.ml.param.Params$class.getOrDefault(params.scala:779)
at org.apache.spark.ml.PipelineStage.getOrDefault(Pipeline.scala:42)
at org.apache.spark.ml.param.Params$class.$(params.scala:786)
at org.apache.spark.ml.PipelineStage.$(Pipeline.scala:42)
at org.apache.spark.ml.regression.LinearRegressionParams$class.validateAndTransformSchema(LinearRegression.scala:111)
at org.apache.spark.ml.regression.LinearRegressionModel.validateAndTransformSchema(LinearRegression.scala:637)
at org.apache.spark.ml.PredictionModel.transformSchema(Predictor.scala:192)
at org.apache.spark.ml.PipelineModel$$anonfun$transformSchema$5.apply(Pipeline.scala:311)
at org.apache.spark.ml.PipelineModel$$anonfun$transformSchema$5.apply(Pipeline.scala:311)
at scala.collection.IndexedSeqOptimized$class.foldl(IndexedSeqOptimized.scala:57)
at scala.collection.IndexedSeqOptimized$class.foldLeft(IndexedSeqOptimized.scala:66)
at scala.collection.mutable.ArrayOps$ofRef.foldLeft(ArrayOps.scala:186)
at org.apache.spark.ml.PipelineModel.transformSchema(Pipeline.scala:311)
at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)
at org.apache.spark.ml.PipelineModel.transform(Pipeline.scala:305)
```
This is because validateAndTransformSchema() is called during both the training and scoring phases, but the checks against training-related params like loss should really be performed only during the training phase, I think — please correct me if I'm missing anything. :)
This issue was first reported for mleap (https://github.com/combust/mleap/issues/455) because basically when we serialize the Spark transformers for mleap, we only serialize the params that are relevant for scoring. We do have the option to de-serialize the serialized transformers back into Spark for scoring again, but in that case, we no longer have all the training params.
## How was this patch tested?
Added a unit test to check this scenario.
Please let me know if there's anything additional required, this is the first PR that I've raised in this project.
Closes #24509 from ancasarb/linear_regression_params_fix.
Authored-by: asarb <as...@expedia.com>
Signed-off-by: Sean Owen <se...@databricks.com>
---
.../org/apache/spark/ml/regression/LinearRegression.scala | 13 +++++++------
.../apache/spark/ml/regression/LinearRegressionSuite.scala | 12 ++++++++++++
2 files changed, 19 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index cd0b7d3..09f3f94 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -109,12 +109,13 @@ private[regression] trait LinearRegressionParams extends PredictorParams
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
- if ($(loss) == Huber) {
- require($(solver) != Normal, "LinearRegression with huber loss doesn't support " +
- "normal solver, please change solver to auto or l-bfgs.")
- require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " +
- s"L2 regularization, but got elasticNetParam = $getElasticNetParam.")
-
+ if (fitting) {
+ if ($(loss) == Huber) {
+ require($(solver) != Normal, "LinearRegression with huber loss doesn't support " +
+ "normal solver, please change solver to auto or l-bfgs.")
+ require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " +
+ s"L2 regularization, but got elasticNetParam = $getElasticNetParam.")
+ }
}
super.validateAndTransformSchema(schema, fitting, featuresDataType)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index d3df0e5..82d9849 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -187,6 +187,18 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
assert(model.numFeatures === numFeatures)
}
+ test("linear regression: can transform data with LinearRegressionModel") {
+ withClue("training related params like loss are only validated during fitting phase") {
+ val original = new LinearRegression().fit(datasetWithDenseFeature)
+
+ val deserialized = new LinearRegressionModel(uid = original.uid,
+ coefficients = original.coefficients,
+ intercept = original.intercept)
+ val output = deserialized.transform(datasetWithDenseFeature)
+ assert(output.collect().size > 0) // simple assertion to ensure no exception thrown
+ }
+ }
+
test("linear regression: illegal params") {
withClue("LinearRegression with huber loss only supports L2 regularization") {
intercept[IllegalArgumentException] {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org