You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/10/07 04:10:25 UTC
spark git commit: [SPARK-17792][ML] L-BFGS solver for linear regression does not accept general numeric label column types
Repository: spark
Updated Branches:
refs/heads/master 49d11d499 -> 3713bb199
[SPARK-17792][ML] L-BFGS solver for linear regression does not accept general numeric label column types
## What changes were proposed in this pull request?
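Previously, `instances` was computed in two separate places in LinearRegression, even though both computations were meant to do the same thing. Only the copy used by the normal solver cast the label column to `DoubleType`. The copy used by the L-BFGS solver did not, so a label column of any other numeric type (e.g. `IntegerType`) failed to match the `Row(label: Double, ...)` pattern. This patch consolidates the computation into a single place that always casts the label column to `DoubleType`.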
## How was this patch tested?
Extended the existing numeric-label-type unit test to run against every solver (`auto`, `l-bfgs`, and `normal`). The test fails without this patch.
Author: sethah <se...@gmail.com>
Closes #15364 from sethah/linreg_numeric_type.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3713bb19
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3713bb19
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3713bb19
Branch: refs/heads/master
Commit: 3713bb199142c5e06e2e527c99650f02f41f47b1
Parents: 49d11d4
Author: sethah <se...@gmail.com>
Authored: Thu Oct 6 21:10:17 2016 -0700
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Thu Oct 6 21:10:17 2016 -0700
----------------------------------------------------------------------
.../spark/ml/regression/LinearRegression.scala | 17 ++++++-----------
.../ml/regression/LinearRegressionSuite.scala | 8 +++++---
2 files changed, 11 insertions(+), 14 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/3713bb19/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 536c58f..025ed20 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -188,17 +188,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+ val instances: RDD[Instance] = dataset.select(
+ col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
+ }
+
if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " +
"solver is used.'")
// For low dimensional data, WeightedLeastSquares is more efficiently since the
// training algorithm only requires one pass through the data. (SPARK-10668)
- val instances: RDD[Instance] = dataset.select(
- col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
- }
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
$(standardization), true)
@@ -221,12 +222,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
return lrModel.setSummary(trainingSummary)
}
- val instances: RDD[Instance] =
- dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
- }
-
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
http://git-wip-us.apache.org/repos/asf/spark/blob/3713bb19/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 5ae371b..1c94ec6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -1015,12 +1015,14 @@ class LinearRegressionSuite
}
test("should support all NumericType labels and not support other types") {
- val lr = new LinearRegression().setMaxIter(1)
- MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
- lr, spark, isClassification = false) { (expected, actual) =>
+ for (solver <- Seq("auto", "l-bfgs", "normal")) {
+ val lr = new LinearRegression().setMaxIter(1).setSolver(solver)
+ MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
+ lr, spark, isClassification = false) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
+ }
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org