You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/02/06 20:22:16 UTC
spark git commit: [SPARK-5652][Mllib] Use broadcasted weights in
LogisticRegressionModel
Repository: spark
Updated Branches:
refs/heads/master 0d74bd7fd -> 80f3bcb58
[SPARK-5652][Mllib] Use broadcasted weights in LogisticRegressionModel
`LogisticRegressionModel`'s `predictPoint` should directly use the broadcasted weights. This PR also fixes the compilation errors in two unit test suites: `JavaLogisticRegressionSuite` and `JavaLinearRegressionSuite`.
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #4429 from viirya/use_bcvalue and squashes the following commits:
5a797e5 [Liang-Chi Hsieh] Use broadcasted weights. Fix compilation error.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/80f3bcb5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/80f3bcb5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/80f3bcb5
Branch: refs/heads/master
Commit: 80f3bcb58f836cfe1829c85bdd349c10525c8a5e
Parents: 0d74bd7
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Fri Feb 6 11:22:11 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Feb 6 11:22:11 2015 -0800
----------------------------------------------------------------------
.../spark/mllib/classification/LogisticRegression.scala | 8 ++++----
.../spark/ml/classification/JavaLogisticRegressionSuite.java | 4 ++--
.../spark/ml/regression/JavaLinearRegressionSuite.java | 4 ++--
3 files changed, 8 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/80f3bcb5/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index a668e7a..9a391bf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
*
* @param weights Weights computed for every feature.
* @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression.
- * In Multinomial Logistic Regression, the intercepts will not be a single values,
+ * In Multinomial Logistic Regression, the intercepts will not be a single value,
* so the intercepts will be part of the weights.)
* @param numFeatures the dimension of the features.
* @param numClasses the number of possible outcomes for k classes classification problem in
@@ -107,7 +107,7 @@ class LogisticRegressionModel (
// If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression.
if (numClasses == 2) {
require(numFeatures == weightMatrix.size)
- val margin = dot(weights, dataMatrix) + intercept
+ val margin = dot(weightMatrix, dataMatrix) + intercept
val score = 1.0 / (1.0 + math.exp(-margin))
threshold match {
case Some(t) => if (score > t) 1.0 else 0.0
@@ -116,11 +116,11 @@ class LogisticRegressionModel (
} else {
val dataWithBiasSize = weightMatrix.size / (numClasses - 1)
- val weightsArray = weights match {
+ val weightsArray = weightMatrix match {
case dv: DenseVector => dv.values
case _ =>
throw new IllegalArgumentException(
- s"weights only supports dense vector but got type ${weights.getClass}.")
+ s"weights only supports dense vector but got type ${weightMatrix.getClass}.")
}
val margins = (0 until numClasses - 1).map { i =>
http://git-wip-us.apache.org/repos/asf/spark/blob/80f3bcb5/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 2628402..d4b6644 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -84,7 +84,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
.setThreshold(0.6)
.setProbabilityCol("myProbability");
LogisticRegressionModel model = lr.fit(dataset);
- assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+ assert(model.fittingParamMap().apply(lr.maxIter()).equals(10));
assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6));
assert(model.getThreshold() == 0.6);
@@ -109,7 +109,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
- assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+ assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5));
assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4));
assert(model2.getThreshold() == 0.4);
http://git-wip-us.apache.org/repos/asf/spark/blob/80f3bcb5/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 5bd616e..40d5a92 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -76,13 +76,13 @@ public class JavaLinearRegressionSuite implements Serializable {
.setMaxIter(10)
.setRegParam(1.0);
LinearRegressionModel model = lr.fit(dataset);
- assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+ assert(model.fittingParamMap().apply(lr.maxIter()).equals(10));
assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
// Call fit() with new params, and check as many params as we can.
LinearRegressionModel model2 =
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
- assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+ assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5));
assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
assert(model2.getPredictionCol().equals("thePred"));
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org