You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/02/17 19:17:47 UTC
spark git commit: [SPARK-5858][MLLIB] Remove unnecessary first() call in GLM
Repository: spark
Updated Branches:
refs/heads/master 3ce46e94f -> c76da36c2
[SPARK-5858][MLLIB] Remove unnecessary first() call in GLM
`numFeatures` is only used by multinomial logistic regression. Calling `.first()` on the input for every GLM causes a performance regression, especially in Python.
Author: Xiangrui Meng <me...@databricks.com>
Closes #4647 from mengxr/SPARK-5858 and squashes the following commits:
036dc7f [Xiangrui Meng] remove unnecessary first() call
12c5548 [Xiangrui Meng] check numFeatures only once
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c76da36c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c76da36c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c76da36c
Branch: refs/heads/master
Commit: c76da36c2163276b5c34e59fbb139eeb34ed0faa
Parents: 3ce46e9
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue Feb 17 10:17:45 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Feb 17 10:17:45 2015 -0800
----------------------------------------------------------------------
.../spark/mllib/classification/LogisticRegression.scala | 6 +++++-
.../spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 7 ++++---
2 files changed, 9 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c76da36c/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 420d6e2..b787667 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -355,6 +355,10 @@ class LogisticRegressionWithLBFGS
}
override protected def createModel(weights: Vector, intercept: Double) = {
- new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
+ if (numOfLinearPredictor == 1) {
+ new LogisticRegressionModel(weights, intercept)
+ } else {
+ new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
+ }
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c76da36c/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 2b71453..7c66e8c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -126,7 +126,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
/**
* The dimension of training features.
*/
- protected var numFeatures: Int = 0
+ protected var numFeatures: Int = -1
/**
* Set if the algorithm should use feature scaling to improve the convergence during optimization.
@@ -163,7 +163,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* RDD of LabeledPoint entries.
*/
def run(input: RDD[LabeledPoint]): M = {
- numFeatures = input.first().features.size
+ if (numFeatures < 0) {
+ numFeatures = input.map(_.features.size).first()
+ }
/**
* When `numOfLinearPredictor > 1`, the intercepts are encapsulated into weights,
@@ -193,7 +195,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* of LabeledPoint entries starting from the initial weights provided.
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
- numFeatures = input.first().features.size
if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org