You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ml...@apache.org on 2017/07/26 11:39:01 UTC

spark git commit: [SPARK-20988][ML] Logistic regression uses aggregator hierarchy

Repository: spark
Updated Branches:
  refs/heads/master ae4ea5fe2 -> cf29828d7


[SPARK-20988][ML] Logistic regression uses aggregator hierarchy

## What changes were proposed in this pull request?

This change pulls the `LogisticAggregator` class out of LogisticRegression.scala and makes it extend `DifferentiableLossAggregator`. It also changes logistic regression to use the generic `RDDLossFunction` instead of having its own.

Other minor changes:
* `L2Regularization` accepts an `Option[Int => Double]` for the feature standard deviations
* `L2Regularization` uses the `Vector` type instead of `Array`
* Some tests were added for `LeastSquaresAggregator`

## How was this patch tested?

New unit test suites were added, and existing suites were updated.

Author: sethah <sh...@cloudera.com>

Closes #18305 from sethah/SPARK-20988.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cf29828d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cf29828d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cf29828d

Branch: refs/heads/master
Commit: cf29828d720611f7633a8114917bac901f76dece
Parents: ae4ea5f
Author: sethah <sh...@cloudera.com>
Authored: Wed Jul 26 13:38:53 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Wed Jul 26 13:38:53 2017 +0200

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  | 492 +------------------
 .../optim/aggregator/LogisticAggregator.scala   | 364 ++++++++++++++
 .../loss/DifferentiableRegularization.scala     |  55 ++-
 .../spark/ml/optim/loss/RDDLossFunction.scala   |  10 +-
 .../spark/ml/regression/LinearRegression.scala  |   3 +-
 .../LogisticRegressionSuite.scala               |   9 +-
 .../DifferentiableLossAggregatorSuite.scala     |  37 ++
 .../LeastSquaresAggregatorSuite.scala           |  47 +-
 .../aggregator/LogisticAggregatorSuite.scala    | 253 ++++++++++
 .../DifferentiableRegularizationSuite.scala     |  13 +-
 .../ml/optim/loss/RDDLossFunctionSuite.scala    |   6 +-
 11 files changed, 752 insertions(+), 537 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 65b09e5..6bba7f9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -27,11 +27,11 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg._
-import org.apache.spark.ml.linalg.BLAS._
+import org.apache.spark.ml.optim.aggregator.LogisticAggregator
+import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
@@ -598,8 +598,23 @@ class LogisticRegression @Since("1.2.0") (
         val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
 
         val bcFeaturesStd = instances.context.broadcast(featuresStd)
-        val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
-          $(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial,
+        val getAggregatorFunc = new LogisticAggregator(bcFeaturesStd, numClasses, $(fitIntercept),
+          multinomial = isMultinomial)(_)
+        val getFeaturesStd = (j: Int) => if (j >= 0 && j < numCoefficientSets * numFeatures) {
+          featuresStd(j / numCoefficientSets)
+        } else {
+          0.0
+        }
+
+        val regularization = if (regParamL2 != 0.0) {
+          val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures * numCoefficientSets
+          Some(new L2Regularization(regParamL2, shouldApply,
+            if ($(standardization)) None else Some(getFeaturesStd)))
+        } else {
+          None
+        }
+
+        val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization,
           $(aggregationDepth))
 
         val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets
@@ -1236,7 +1251,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
  * Two MultilabelSummarizer can be merged together to have a statistical summary of the
  * corresponding joint dataset.
  */
-private[classification] class MultiClassSummarizer extends Serializable {
+private[ml] class MultiClassSummarizer extends Serializable {
   // The first element of value in distinctMap is the actually number of instances,
   // and the second element of value is sum of the weights.
   private val distinctMap = new mutable.HashMap[Int, (Long, Double)]
@@ -1468,470 +1483,3 @@ class BinaryLogisticRegressionSummary private[classification] (
     binaryMetrics.recallByThreshold().toDF("threshold", "recall")
   }
 }
-
-/**
- * LogisticAggregator computes the gradient and loss for binary or multinomial logistic (softmax)
- * loss function, as used in classification for instances in sparse or dense vector in an online
- * fashion.
- *
- * Two LogisticAggregators can be merged together to have a summary of loss and gradient of
- * the corresponding joint dataset.
- *
- * For improving the convergence rate during the optimization process and also to prevent against
- * features with very large variances exerting an overly large influence during model training,
- * packages like R's GLMNET perform the scaling to unit variance and remove the mean in order to
- * reduce the condition number. The model is then trained in this scaled space, but returns the
- * coefficients in the original scale. See page 9 in
- * http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
- *
- * However, we don't want to apply the [[org.apache.spark.ml.feature.StandardScaler]] on the
- * training dataset, and then cache the standardized dataset since it will create a lot of overhead.
- * As a result, we perform the scaling implicitly when we compute the objective function (though
- * we do not subtract the mean).
- *
- * Note that there is a difference between multinomial (softmax) and binary loss. The binary case
- * uses one outcome class as a "pivot" and regresses the other class against the pivot. In the
- * multinomial case, the softmax loss function is used to model each class probability
- * independently. Using softmax loss produces `K` sets of coefficients, while using a pivot class
- * produces `K - 1` sets of coefficients (a single coefficient vector in the binary case). In the
- * binary case, we can say that the coefficients are shared between the positive and negative
- * classes. When regularization is applied, multinomial (softmax) loss will produce a result
- * different from binary loss since the positive and negative don't share the coefficients while the
- * binary regression shares the coefficients between positive and negative.
- *
- * The following is a mathematical derivation for the multinomial (softmax) loss.
- *
- * The probability of the multinomial outcome $y$ taking on any of the K possible outcomes is:
- *
- * <blockquote>
- *    $$
- *    P(y_i=0|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_0}}{\sum_{k=0}^{K-1}
- *       e^{\vec{x}_i^T \vec{\beta}_k}} \\
- *    P(y_i=1|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_1}}{\sum_{k=0}^{K-1}
- *       e^{\vec{x}_i^T \vec{\beta}_k}}\\
- *    P(y_i=K-1|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_{K-1}}\,}{\sum_{k=0}^{K-1}
- *       e^{\vec{x}_i^T \vec{\beta}_k}}
- *    $$
- * </blockquote>
- *
- * The model coefficients $\beta = (\beta_0, \beta_1, \beta_2, ..., \beta_{K-1})$ become a matrix
- * which has dimension of $K \times (N+1)$ if the intercepts are added. If the intercepts are not
- * added, the dimension will be $K \times N$.
- *
- * Note that the coefficients in the model above lack identifiability. That is, any constant scalar
- * can be added to all of the coefficients and the probabilities remain the same.
- *
- * <blockquote>
- *    $$
- *    \begin{align}
- *    \frac{e^{\vec{x}_i^T \left(\vec{\beta}_0 + \vec{c}\right)}}{\sum_{k=0}^{K-1}
- *       e^{\vec{x}_i^T \left(\vec{\beta}_k + \vec{c}\right)}}
- *    = \frac{e^{\vec{x}_i^T \vec{\beta}_0}e^{\vec{x}_i^T \vec{c}}\,}{e^{\vec{x}_i^T \vec{c}}
- *       \sum_{k=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_k}}
- *    = \frac{e^{\vec{x}_i^T \vec{\beta}_0}}{\sum_{k=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_k}}
- *    \end{align}
- *    $$
- * </blockquote>
- *
- * However, when regularization is added to the loss function, the coefficients are indeed
- * identifiable because there is only one set of coefficients which minimizes the regularization
- * term. When no regularization is applied, we choose the coefficients with the minimum L2
- * penalty for consistency and reproducibility. For further discussion see:
- *
- * Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent"
- *
- * The loss of objective function for a single instance of data (we do not include the
- * regularization term here for simplicity) can be written as
- *
- * <blockquote>
- *    $$
- *    \begin{align}
- *    \ell\left(\beta, x_i\right) &= -log{P\left(y_i \middle| \vec{x}_i, \beta\right)} \\
- *    &= log\left(\sum_{k=0}^{K-1}e^{\vec{x}_i^T \vec{\beta}_k}\right) - \vec{x}_i^T \vec{\beta}_y\\
- *    &= log\left(\sum_{k=0}^{K-1} e^{margins_k}\right) - margins_y
- *    \end{align}
- *    $$
- * </blockquote>
- *
- * where ${margins}_k = \vec{x}_i^T \vec{\beta}_k$.
- *
- * For optimization, we have to calculate the first derivative of the loss function, and a simple
- * calculation shows that
- *
- * <blockquote>
- *    $$
- *    \begin{align}
- *    \frac{\partial \ell(\beta, \vec{x}_i, w_i)}{\partial \beta_{j, k}}
- *    &= x_{i,j} \cdot w_i \cdot \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k}}{\sum_{k'=0}^{K-1}
- *      e^{\vec{x}_i \cdot \vec{\beta}_{k'}}\,} - I_{y=k}\right) \\
- *    &= x_{i, j} \cdot w_i \cdot multiplier_k
- *    \end{align}
- *    $$
- * </blockquote>
- *
- * where $w_i$ is the sample weight, $I_{y=k}$ is an indicator function
- *
- *  <blockquote>
- *    $$
- *    I_{y=k} = \begin{cases}
- *          1 & y = k \\
- *          0 & else
- *       \end{cases}
- *    $$
- * </blockquote>
- *
- * and
- *
- * <blockquote>
- *    $$
- *    multiplier_k = \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k}}{\sum_{k=0}^{K-1}
- *       e^{\vec{x}_i \cdot \vec{\beta}_k}} - I_{y=k}\right)
- *    $$
- * </blockquote>
- *
- * If any of margins is larger than 709.78, the numerical computation of multiplier and loss
- * function will suffer from arithmetic overflow. This issue occurs when there are outliers in
- * data which are far away from the hyperplane, and this will cause the failing of training once
- * infinity is introduced. Note that this is only a concern when max(margins) &gt; 0.
- *
- * Fortunately, when max(margins) = maxMargin &gt; 0, the loss function and the multiplier can
- * easily be rewritten into the following equivalent numerically stable formula.
- *
- * <blockquote>
- *    $$
- *    \ell\left(\beta, x\right) = log\left(\sum_{k=0}^{K-1} e^{margins_k - maxMargin}\right) -
- *       margins_{y} + maxMargin
- *    $$
- * </blockquote>
- *
- * Note that each term, $(margins_k - maxMargin)$ in the exponential is no greater than zero; as a
- * result, overflow will not happen with this formula.
- *
- * For $multiplier$, a similar trick can be applied as the following,
- *
- * <blockquote>
- *    $$
- *    multiplier_k = \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k - maxMargin}}{\sum_{k'=0}^{K-1}
- *       e^{\vec{x}_i \cdot \vec{\beta}_{k'} - maxMargin}} - I_{y=k}\right)
- *    $$
- * </blockquote>
- *
- * @param bcCoefficients The broadcast coefficients corresponding to the features.
- * @param bcFeaturesStd The broadcast standard deviation values of the features.
- * @param numClasses the number of possible outcomes for k classes classification problem in
- *                   Multinomial Logistic Regression.
- * @param fitIntercept Whether to fit an intercept term.
- * @param multinomial Whether to use multinomial (softmax) or binary loss
- *
- * @note In order to avoid unnecessary computation during calculation of the gradient updates
- * we lay out the coefficients in column major order during training. This allows us to
- * perform feature standardization once, while still retaining sequential memory access
- * for speed. We convert back to row major order when we create the model,
- * since this form is optimal for the matrix operations used for prediction.
- */
-private class LogisticAggregator(
-    bcCoefficients: Broadcast[Vector],
-    bcFeaturesStd: Broadcast[Array[Double]],
-    numClasses: Int,
-    fitIntercept: Boolean,
-    multinomial: Boolean) extends Serializable with Logging {
-
-  private val numFeatures = bcFeaturesStd.value.length
-  private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
-  private val coefficientSize = bcCoefficients.value.size
-  private val numCoefficientSets = if (multinomial) numClasses else 1
-  if (multinomial) {
-    require(numClasses ==  coefficientSize / numFeaturesPlusIntercept, s"The number of " +
-      s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize")
-  } else {
-    require(coefficientSize == numFeaturesPlusIntercept, s"Expected $numFeaturesPlusIntercept " +
-      s"coefficients but got $coefficientSize")
-    require(numClasses == 1 || numClasses == 2, s"Binary logistic aggregator requires numClasses " +
-      s"in {1, 2} but found $numClasses.")
-  }
-
-  private var weightSum = 0.0
-  private var lossSum = 0.0
-
-  @transient private lazy val coefficientsArray: Array[Double] = bcCoefficients.value match {
-    case DenseVector(values) => values
-    case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector but " +
-      s"got type ${bcCoefficients.value.getClass}.)")
-  }
-  private lazy val gradientSumArray = new Array[Double](coefficientSize)
-
-  if (multinomial && numClasses <= 2) {
-    logInfo(s"Multinomial logistic regression for binary classification yields separate " +
-      s"coefficients for positive and negative classes. When no regularization is applied, the" +
-      s"result will be effectively the same as binary logistic regression. When regularization" +
-      s"is applied, multinomial loss will produce a result different from binary loss.")
-  }
-
-  /** Update gradient and loss using binary loss function. */
-  private def binaryUpdateInPlace(
-      features: Vector,
-      weight: Double,
-      label: Double): Unit = {
-
-    val localFeaturesStd = bcFeaturesStd.value
-    val localCoefficients = coefficientsArray
-    val localGradientArray = gradientSumArray
-    val margin = - {
-      var sum = 0.0
-      features.foreachActive { (index, value) =>
-        if (localFeaturesStd(index) != 0.0 && value != 0.0) {
-          sum += localCoefficients(index) * value / localFeaturesStd(index)
-        }
-      }
-      if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
-      sum
-    }
-
-    val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
-
-    features.foreachActive { (index, value) =>
-      if (localFeaturesStd(index) != 0.0 && value != 0.0) {
-        localGradientArray(index) += multiplier * value / localFeaturesStd(index)
-      }
-    }
-
-    if (fitIntercept) {
-      localGradientArray(numFeaturesPlusIntercept - 1) += multiplier
-    }
-
-    if (label > 0) {
-      // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
-      lossSum += weight * MLUtils.log1pExp(margin)
-    } else {
-      lossSum += weight * (MLUtils.log1pExp(margin) - margin)
-    }
-  }
-
-  /** Update gradient and loss using multinomial (softmax) loss function. */
-  private def multinomialUpdateInPlace(
-      features: Vector,
-      weight: Double,
-      label: Double): Unit = {
-    // TODO: use level 2 BLAS operations
-    /*
-      Note: this can still be used when numClasses = 2 for binary
-      logistic regression without pivoting.
-     */
-    val localFeaturesStd = bcFeaturesStd.value
-    val localCoefficients = coefficientsArray
-    val localGradientArray = gradientSumArray
-
-    // marginOfLabel is margins(label) in the formula
-    var marginOfLabel = 0.0
-    var maxMargin = Double.NegativeInfinity
-
-    val margins = new Array[Double](numClasses)
-    features.foreachActive { (index, value) =>
-      val stdValue = value / localFeaturesStd(index)
-      var j = 0
-      while (j < numClasses) {
-        margins(j) += localCoefficients(index * numClasses + j) * stdValue
-        j += 1
-      }
-    }
-    var i = 0
-    while (i < numClasses) {
-      if (fitIntercept) {
-        margins(i) += localCoefficients(numClasses * numFeatures + i)
-      }
-      if (i == label.toInt) marginOfLabel = margins(i)
-      if (margins(i) > maxMargin) {
-        maxMargin = margins(i)
-      }
-      i += 1
-    }
-
-    /**
-     * When maxMargin is greater than 0, the original formula could cause overflow.
-     * We address this by subtracting maxMargin from all the margins, so it's guaranteed
-     * that all of the new margins will be smaller than zero to prevent arithmetic overflow.
-     */
-    val multipliers = new Array[Double](numClasses)
-    val sum = {
-      var temp = 0.0
-      var i = 0
-      while (i < numClasses) {
-        if (maxMargin > 0) margins(i) -= maxMargin
-        val exp = math.exp(margins(i))
-        temp += exp
-        multipliers(i) = exp
-        i += 1
-      }
-      temp
-    }
-
-    margins.indices.foreach { i =>
-      multipliers(i) = multipliers(i) / sum - (if (label == i) 1.0 else 0.0)
-    }
-    features.foreachActive { (index, value) =>
-      if (localFeaturesStd(index) != 0.0 && value != 0.0) {
-        val stdValue = value / localFeaturesStd(index)
-        var j = 0
-        while (j < numClasses) {
-          localGradientArray(index * numClasses + j) +=
-            weight * multipliers(j) * stdValue
-          j += 1
-        }
-      }
-    }
-    if (fitIntercept) {
-      var i = 0
-      while (i < numClasses) {
-        localGradientArray(numFeatures * numClasses + i) += weight * multipliers(i)
-        i += 1
-      }
-    }
-
-    val loss = if (maxMargin > 0) {
-      math.log(sum) - marginOfLabel + maxMargin
-    } else {
-      math.log(sum) - marginOfLabel
-    }
-    lossSum += weight * loss
-  }
-
-  /**
-   * Add a new training instance to this LogisticAggregator, and update the loss and gradient
-   * of the objective function.
-   *
-   * @param instance The instance of data point to be added.
-   * @return This LogisticAggregator object.
-   */
-  def add(instance: Instance): this.type = {
-    instance match { case Instance(label, weight, features) =>
-
-      if (weight == 0.0) return this
-
-      if (multinomial) {
-        multinomialUpdateInPlace(features, weight, label)
-      } else {
-        binaryUpdateInPlace(features, weight, label)
-      }
-      weightSum += weight
-      this
-    }
-  }
-
-  /**
-   * Merge another LogisticAggregator, and update the loss and gradient
-   * of the objective function.
-   * (Note that it's in place merging; as a result, `this` object will be modified.)
-   *
-   * @param other The other LogisticAggregator to be merged.
-   * @return This LogisticAggregator object.
-   */
-  def merge(other: LogisticAggregator): this.type = {
-
-    if (other.weightSum != 0.0) {
-      weightSum += other.weightSum
-      lossSum += other.lossSum
-
-      var i = 0
-      val localThisGradientSumArray = this.gradientSumArray
-      val localOtherGradientSumArray = other.gradientSumArray
-      val len = localThisGradientSumArray.length
-      while (i < len) {
-        localThisGradientSumArray(i) += localOtherGradientSumArray(i)
-        i += 1
-      }
-    }
-    this
-  }
-
-  def loss: Double = {
-    require(weightSum > 0.0, s"The effective number of instances should be " +
-      s"greater than 0.0, but $weightSum.")
-    lossSum / weightSum
-  }
-
-  def gradient: Matrix = {
-    require(weightSum > 0.0, s"The effective number of instances should be " +
-      s"greater than 0.0, but $weightSum.")
-    val result = Vectors.dense(gradientSumArray.clone())
-    scal(1.0 / weightSum, result)
-    new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, result.toArray)
-  }
-}
-
-/**
- * LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial (softmax) logistic loss
- * function, as used in multi-class classification (it is also used in binary logistic regression).
- * It returns the loss and gradient with L2 regularization at a particular point (coefficients).
- * It's used in Breeze's convex optimization routines.
- */
-private class LogisticCostFun(
-    instances: RDD[Instance],
-    numClasses: Int,
-    fitIntercept: Boolean,
-    standardization: Boolean,
-    bcFeaturesStd: Broadcast[Array[Double]],
-    regParamL2: Double,
-    multinomial: Boolean,
-    aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
-
-  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
-    val coeffs = Vectors.fromBreeze(coefficients)
-    val bcCoeffs = instances.context.broadcast(coeffs)
-    val featuresStd = bcFeaturesStd.value
-    val numFeatures = featuresStd.length
-    val numCoefficientSets = if (multinomial) numClasses else 1
-    val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
-
-    val logisticAggregator = {
-      val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
-      val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
-
-      instances.treeAggregate(
-        new LogisticAggregator(bcCoeffs, bcFeaturesStd, numClasses, fitIntercept,
-          multinomial)
-      )(seqOp, combOp, aggregationDepth)
-    }
-
-    val totalGradientMatrix = logisticAggregator.gradient
-    val coefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, coeffs.toArray)
-    // regVal is the sum of coefficients squares excluding intercept for L2 regularization.
-    val regVal = if (regParamL2 == 0.0) {
-      0.0
-    } else {
-      var sum = 0.0
-      coefMatrix.foreachActive { case (classIndex, featureIndex, value) =>
-        // We do not apply regularization to the intercepts
-        val isIntercept = fitIntercept && (featureIndex == numFeatures)
-        if (!isIntercept) {
-          // The following code will compute the loss of the regularization; also
-          // the gradient of the regularization, and add back to totalGradientArray.
-          sum += {
-            if (standardization) {
-              val gradValue = totalGradientMatrix(classIndex, featureIndex)
-              totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * value)
-              value * value
-            } else {
-              if (featuresStd(featureIndex) != 0.0) {
-                // If `standardization` is false, we still standardize the data
-                // to improve the rate of convergence; as a result, we have to
-                // perform this reverse standardization by penalizing each component
-                // differently to get effectively the same objective function when
-                // the training dataset is not standardized.
-                val temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex))
-                val gradValue = totalGradientMatrix(classIndex, featureIndex)
-                totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * temp)
-                value * temp
-              } else {
-                0.0
-              }
-            }
-          }
-        }
-      }
-      0.5 * regParamL2 * sum
-    }
-    bcCoeffs.destroy(blocking = false)
-
-    (logisticAggregator.loss + regVal, new BDV(totalGradientMatrix.toArray))
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala
new file mode 100644
index 0000000..66a5294
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.optim.aggregator
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg.{DenseVector, Vector}
+import org.apache.spark.mllib.util.MLUtils
+
+/**
+ * LogisticAggregator computes the gradient and loss for binary or multinomial logistic (softmax)
+ * loss function, as used in classification for instances in sparse or dense vector in an online
+ * fashion.
+ *
+ * Two LogisticAggregators can be merged together to have a summary of loss and gradient of
+ * the corresponding joint dataset.
+ *
+ * For improving the convergence rate during the optimization process and also to guard against
+ * features with very large variances exerting an overly large influence during model training,
+ * packages like R's GLMNET perform the scaling to unit variance and remove the mean in order to
+ * reduce the condition number. The model is then trained in this scaled space, but returns the
+ * coefficients in the original scale. See page 9 in
+ * http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
+ *
+ * However, we don't want to apply the [[org.apache.spark.ml.feature.StandardScaler]] on the
+ * training dataset, and then cache the standardized dataset since it will create a lot of overhead.
+ * As a result, we perform the scaling implicitly when we compute the objective function (though
+ * we do not subtract the mean).
+ *
+ * Note that there is a difference between multinomial (softmax) and binary loss. The binary case
+ * uses one outcome class as a "pivot" and regresses the other class against the pivot. In the
+ * multinomial case, the softmax loss function is used to model each class probability
+ * independently. Using softmax loss produces `K` sets of coefficients, while using a pivot class
+ * produces `K - 1` sets of coefficients (a single coefficient vector in the binary case). In the
+ * binary case, we can say that the coefficients are shared between the positive and negative
+ * classes. When regularization is applied, multinomial (softmax) loss will produce a result
+ * different from binary loss since the positive and negative don't share the coefficients while the
+ * binary regression shares the coefficients between positive and negative.
+ *
+ * The following is a mathematical derivation for the multinomial (softmax) loss.
+ *
+ * The probability of the multinomial outcome $y$ taking on any of the K possible outcomes is:
+ *
+ * <blockquote>
+ *    $$
+ *    P(y_i=0|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_0}}{\sum_{k=0}^{K-1}
+ *       e^{\vec{x}_i^T \vec{\beta}_k}} \\
+ *    P(y_i=1|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_1}}{\sum_{k=0}^{K-1}
+ *       e^{\vec{x}_i^T \vec{\beta}_k}}\\
+ *    P(y_i=K-1|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_{K-1}}\,}{\sum_{k=0}^{K-1}
+ *       e^{\vec{x}_i^T \vec{\beta}_k}}
+ *    $$
+ * </blockquote>
+ *
+ * The model coefficients $\beta = (\beta_0, \beta_1, \beta_2, ..., \beta_{K-1})$ become a matrix
+ * which has dimension of $K \times (N+1)$ if the intercepts are added. If the intercepts are not
+ * added, the dimension will be $K \times N$.
+ *
+ * Note that the coefficients in the model above lack identifiability. That is, any constant vector
+ * $\vec{c}$ can be added to all of the coefficient vectors and the probabilities remain the same.
+ *
+ * <blockquote>
+ *    $$
+ *    \begin{align}
+ *    \frac{e^{\vec{x}_i^T \left(\vec{\beta}_0 + \vec{c}\right)}}{\sum_{k=0}^{K-1}
+ *       e^{\vec{x}_i^T \left(\vec{\beta}_k + \vec{c}\right)}}
+ *    = \frac{e^{\vec{x}_i^T \vec{\beta}_0}e^{\vec{x}_i^T \vec{c}}\,}{e^{\vec{x}_i^T \vec{c}}
+ *       \sum_{k=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_k}}
+ *    = \frac{e^{\vec{x}_i^T \vec{\beta}_0}}{\sum_{k=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_k}}
+ *    \end{align}
+ *    $$
+ * </blockquote>
+ *
+ * However, when regularization is added to the loss function, the coefficients are indeed
+ * identifiable because there is only one set of coefficients which minimizes the regularization
+ * term. When no regularization is applied, we choose the coefficients with the minimum L2
+ * penalty for consistency and reproducibility. For further discussion see:
+ *
+ * Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent"
+ *
+ * The loss of objective function for a single instance of data (we do not include the
+ * regularization term here for simplicity) can be written as
+ *
+ * <blockquote>
+ *    $$
+ *    \begin{align}
+ *    \ell\left(\beta, x_i\right) &= -log{P\left(y_i \middle| \vec{x}_i, \beta\right)} \\
+ *    &= log\left(\sum_{k=0}^{K-1}e^{\vec{x}_i^T \vec{\beta}_k}\right) - \vec{x}_i^T \vec{\beta}_y\\
+ *    &= log\left(\sum_{k=0}^{K-1} e^{margins_k}\right) - margins_y
+ *    \end{align}
+ *    $$
+ * </blockquote>
+ *
+ * where ${margins}_k = \vec{x}_i^T \vec{\beta}_k$.
+ *
+ * For optimization, we have to calculate the first derivative of the loss function, and a simple
+ * calculation shows that
+ *
+ * <blockquote>
+ *    $$
+ *    \begin{align}
+ *    \frac{\partial \ell(\beta, \vec{x}_i, w_i)}{\partial \beta_{j, k}}
+ *    &= x_{i,j} \cdot w_i \cdot \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k}}{\sum_{k'=0}^{K-1}
+ *      e^{\vec{x}_i \cdot \vec{\beta}_{k'}}\,} - I_{y=k}\right) \\
+ *    &= x_{i, j} \cdot w_i \cdot multiplier_k
+ *    \end{align}
+ *    $$
+ * </blockquote>
+ *
+ * where $w_i$ is the sample weight and $I_{y=k}$ is the indicator function
+ *
+ *  <blockquote>
+ *    $$
+ *    I_{y=k} = \begin{cases}
+ *          1 & y = k \\
+ *          0 & else
+ *       \end{cases}
+ *    $$
+ * </blockquote>
+ *
+ * and
+ *
+ * <blockquote>
+ *    $$
+ *    multiplier_k = \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k}}{\sum_{k'=0}^{K-1}
+ *       e^{\vec{x}_i \cdot \vec{\beta}_{k'}}} - I_{y=k}\right)
+ *    $$
+ * </blockquote>
+ *
+ * If any of the margins is larger than 709.78 (approximately log(Double.MaxValue)), the numerical
+ * computation of the multiplier and the loss function will suffer from arithmetic overflow. This
+ * issue occurs when there are outliers in the data which are far away from the hyperplane, and it
+ * causes training to fail once infinity is introduced. Note that this is only a concern when
+ * max(margins) &gt; 0.
+ *
+ * Fortunately, when max(margins) = maxMargin &gt; 0, the loss function and the multiplier can
+ * easily be rewritten into the following equivalent numerically stable formula.
+ *
+ * <blockquote>
+ *    $$
+ *    \ell\left(\beta, x\right) = log\left(\sum_{k=0}^{K-1} e^{margins_k - maxMargin}\right) -
+ *       margins_{y} + maxMargin
+ *    $$
+ * </blockquote>
+ *
+ * Note that each term, $(margins_k - maxMargin)$ in the exponential is no greater than zero; as a
+ * result, overflow will not happen with this formula.
+ *
+ * For $multiplier$, a similar trick can be applied as the following,
+ *
+ * <blockquote>
+ *    $$
+ *    multiplier_k = \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k - maxMargin}}{\sum_{k'=0}^{K-1}
+ *       e^{\vec{x}_i \cdot \vec{\beta}_{k'} - maxMargin}} - I_{y=k}\right)
+ *    $$
+ * </blockquote>
+ *
+ *
+ * @param bcCoefficients The broadcast coefficients corresponding to the features.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
+ * @param numClasses the number of possible outcomes for k classes classification problem in
+ *                   Multinomial Logistic Regression.
+ * @param fitIntercept Whether to fit an intercept term.
+ * @param multinomial Whether to use multinomial (softmax) or binary loss
+ * @note In order to avoid unnecessary computation during calculation of the gradient updates
+ * we lay out the coefficients in column major order during training. This allows us to
+ * perform feature standardization once, while still retaining sequential memory access
+ * for speed. We convert back to row major order when we create the model,
+ * since this form is optimal for the matrix operations used for prediction.
+ */
+private[ml] class LogisticAggregator(
+    bcFeaturesStd: Broadcast[Array[Double]],
+    numClasses: Int,
+    fitIntercept: Boolean,
+    multinomial: Boolean)(bcCoefficients: Broadcast[Vector])
+  extends DifferentiableLossAggregator[Instance, LogisticAggregator] with Logging {
+
+  private val numFeatures = bcFeaturesStd.value.length
+  private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
+  private val coefficientSize = bcCoefficients.value.size
+  protected override val dim: Int = coefficientSize
+
+  // Validate the coefficient layout up front. The multinomial check compares sizes exactly
+  // (instead of integer-dividing), so a coefficient vector whose length is not a multiple of
+  // numFeaturesPlusIntercept is rejected and a zero divisor cannot occur.
+  if (multinomial) {
+    require(coefficientSize == numClasses * numFeaturesPlusIntercept, s"The number of " +
+      s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize")
+  } else {
+    require(coefficientSize == numFeaturesPlusIntercept, s"Expected $numFeaturesPlusIntercept " +
+      s"coefficients but got $coefficientSize")
+    require(numClasses == 1 || numClasses == 2, s"Binary logistic aggregator requires numClasses " +
+      s"in {1, 2} but found $numClasses.")
+  }
+
+  /**
+   * The broadcast coefficients as a raw array. Only dense coefficient vectors are supported;
+   * because this is lazy, a non-dense vector is reported on the first non-zero-weight `add`.
+   */
+  @transient private lazy val coefficientsArray: Array[Double] = bcCoefficients.value match {
+    case DenseVector(values) => values
+    case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector but " +
+      s"got type ${bcCoefficients.value.getClass}.")
+  }
+
+  if (multinomial && numClasses <= 2) {
+    logInfo(s"Multinomial logistic regression for binary classification yields separate " +
+      s"coefficients for positive and negative classes. When no regularization is applied, the " +
+      s"result will be effectively the same as binary logistic regression. When regularization " +
+      s"is applied, multinomial loss will produce a result different from binary loss.")
+  }
+
+  /**
+   * Update gradient and loss using binary loss function.
+   *
+   * Coefficients are laid out as [feature coefficients..., intercept (when fitIntercept)].
+   * Features whose standard deviation is zero (and explicit zero values) are skipped, since
+   * scaling them would divide by zero.
+   */
+  private def binaryUpdateInPlace(features: Vector, weight: Double, label: Double): Unit = {
+
+    val localFeaturesStd = bcFeaturesStd.value
+    val localCoefficients = coefficientsArray
+    val localGradientArray = gradientSumArray
+    // margin is the NEGATED linear predictor: -(sum_j beta_j * x_j / std_j + intercept).
+    val margin = - {
+      var sum = 0.0
+      features.foreachActive { (index, value) =>
+        if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+          sum += localCoefficients(index) * value / localFeaturesStd(index)
+        }
+      }
+      if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
+      sum
+    }
+
+    // weight * (sigmoid(linear predictor) - label), since 1 / (1 + exp(-dot)) = sigmoid(dot).
+    val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
+
+    features.foreachActive { (index, value) =>
+      if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+        localGradientArray(index) += multiplier * value / localFeaturesStd(index)
+      }
+    }
+
+    if (fitIntercept) {
+      localGradientArray(numFeaturesPlusIntercept - 1) += multiplier
+    }
+
+    if (label > 0) {
+      // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+      lossSum += weight * MLUtils.log1pExp(margin)
+    } else {
+      // log1pExp(margin) - margin == log(1 + exp(-margin)), the loss for the negative class.
+      lossSum += weight * (MLUtils.log1pExp(margin) - margin)
+    }
+  }
+
+  /**
+   * Update gradient and loss using multinomial (softmax) loss function.
+   *
+   * Coefficients are stored in column major order: the coefficient of feature `index` for class
+   * `j` is at `index * numClasses + j`, and the intercepts (when fitIntercept) occupy the slots
+   * starting at `numFeatures * numClasses`.
+   */
+  private def multinomialUpdateInPlace(features: Vector, weight: Double, label: Double): Unit = {
+    // TODO: use level 2 BLAS operations
+    // Note: this can still be used when numClasses = 2 for binary logistic regression without
+    // pivoting.
+    val localFeaturesStd = bcFeaturesStd.value
+    val localCoefficients = coefficientsArray
+    val localGradientArray = gradientSumArray
+
+    // marginOfLabel is margins(label) in the formula
+    var marginOfLabel = 0.0
+    var maxMargin = Double.NegativeInfinity
+
+    val margins = new Array[Double](numClasses)
+    features.foreachActive { (index, value) =>
+      // Skip zero-variance features and explicit zeros, exactly like the binary update and the
+      // gradient loop below. Without this guard `value / 0.0` is Infinity (or NaN for 0.0 / 0.0,
+      // which dense vectors always hit), turning every margin, the loss and the gradient to NaN.
+      if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+        val stdValue = value / localFeaturesStd(index)
+        var j = 0
+        while (j < numClasses) {
+          margins(j) += localCoefficients(index * numClasses + j) * stdValue
+          j += 1
+        }
+      }
+    }
+    var i = 0
+    while (i < numClasses) {
+      if (fitIntercept) {
+        margins(i) += localCoefficients(numClasses * numFeatures + i)
+      }
+      if (i == label.toInt) marginOfLabel = margins(i)
+      if (margins(i) > maxMargin) {
+        maxMargin = margins(i)
+      }
+      i += 1
+    }
+
+    // When maxMargin is greater than 0, exponentiating the raw margins could overflow. We
+    // subtract maxMargin from every margin in that case, so all shifted margins are no greater
+    // than zero and math.exp cannot overflow. `multipliers` temporarily holds the exponentials.
+    val multipliers = new Array[Double](numClasses)
+    val sum = {
+      var temp = 0.0
+      var i = 0
+      while (i < numClasses) {
+        if (maxMargin > 0) margins(i) -= maxMargin
+        val exp = math.exp(margins(i))
+        temp += exp
+        multipliers(i) = exp
+        i += 1
+      }
+      temp
+    }
+
+    // multiplier_k = softmax_k - I(label == k)
+    margins.indices.foreach { i =>
+      multipliers(i) = multipliers(i) / sum - (if (label == i) 1.0 else 0.0)
+    }
+    features.foreachActive { (index, value) =>
+      if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+        val stdValue = value / localFeaturesStd(index)
+        var j = 0
+        while (j < numClasses) {
+          localGradientArray(index * numClasses + j) += weight * multipliers(j) * stdValue
+          j += 1
+        }
+      }
+    }
+    if (fitIntercept) {
+      var i = 0
+      while (i < numClasses) {
+        localGradientArray(numFeatures * numClasses + i) += weight * multipliers(i)
+        i += 1
+      }
+    }
+
+    // marginOfLabel was captured before the shift, so the shift is undone by adding maxMargin.
+    val loss = if (maxMargin > 0) {
+      math.log(sum) - marginOfLabel + maxMargin
+    } else {
+      math.log(sum) - marginOfLabel
+    }
+    lossSum += weight * loss
+  }
+
+  /**
+   * Add a new training instance to this LogisticAggregator, and update the loss and gradient
+   * of the objective function.
+   *
+   * Zero-weight instances are validated but otherwise ignored: they change neither the loss,
+   * the gradient nor the accumulated weight sum.
+   *
+   * @param instance The instance of data point to be added.
+   * @return This LogisticAggregator object.
+   * @throws IllegalArgumentException if the feature dimension does not match, the weight is
+   *                                  negative, or (for a non-zero weight) the coefficients
+   *                                  are not dense.
+   */
+  def add(instance: Instance): this.type = {
+    instance match { case Instance(label, weight, features) =>
+      require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
+        s" Expecting $numFeatures but got ${features.size}.")
+      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
+
+      if (weight != 0.0) {
+        if (multinomial) {
+          multinomialUpdateInPlace(features, weight, label)
+        } else {
+          binaryUpdateInPlace(features, weight, label)
+        }
+        weightSum += weight
+      }
+      this
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala
index 118c0eb..7ac7c22 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala
@@ -18,6 +18,8 @@ package org.apache.spark.ml.optim.loss
 
 import breeze.optimize.DiffFunction
 
+import org.apache.spark.ml.linalg._
+
 /**
  * A Breeze diff function which represents a cost function for differentiable regularization
  * of parameters. e.g. L2 regularization: 1 / 2 regParam * beta dot beta
@@ -32,40 +34,45 @@ private[ml] trait DifferentiableRegularization[T] extends DiffFunction[T] {
 }
 
 /**
- * A Breeze diff function for computing the L2 regularized loss and gradient of an array of
+ * A Breeze diff function for computing the L2 regularized loss and gradient of a vector of
  * coefficients.
  *
  * @param regParam The magnitude of the regularization.
  * @param shouldApply A function (Int => Boolean) indicating whether a given index should have
  *                    regularization applied to it.
- * @param featuresStd Option indicating whether the regularization should be scaled by the standard
- *                    deviation of the features.
+ * @param applyFeaturesStd Option for a function which maps coefficient index (column major) to the
+ *                         feature standard deviation. If `None`, no standardization is applied.
  */
 private[ml] class L2Regularization(
-    val regParam: Double,
+    override val regParam: Double,
     shouldApply: Int => Boolean,
-    featuresStd: Option[Array[Double]]) extends DifferentiableRegularization[Array[Double]] {
+    applyFeaturesStd: Option[Int => Double]) extends DifferentiableRegularization[Vector] {
 
-  override def calculate(coefficients: Array[Double]): (Double, Array[Double]) = {
-    var sum = 0.0
-    val gradient = new Array[Double](coefficients.length)
-    coefficients.indices.filter(shouldApply).foreach { j =>
-      val coef = coefficients(j)
-      featuresStd match {
-        case Some(stds) =>
-          val std = stds(j)
-          if (std != 0.0) {
-            val temp = coef / (std * std)
-            sum += coef * temp
-            gradient(j) = regParam * temp
-          } else {
-            0.0
+  override def calculate(coefficients: Vector): (Double, Vector) = {
+    // Returns (0.5 * regParam * penalty, gradient). Only indices accepted by `shouldApply` are
+    // penalized; the gradient stays 0.0 at every other index.
+    coefficients match {
+      case dv: DenseVector =>
+        var sum = 0.0
+        val gradient = new Array[Double](dv.size)
+        dv.values.indices.filter(shouldApply).foreach { j =>
+          val coef = coefficients(j)
+          applyFeaturesStd match {
+            // Used when `standardization` is off: the data is still scaled internally, so each
+            // coefficient is penalized as coef^2 / std^2 to reverse that scaling and obtain the
+            // same objective as training on the unstandardized data.
+            case Some(getStd) =>
+              val std = getStd(j)
+              if (std != 0.0) {
+                val temp = coef / (std * std)
+                sum += coef * temp
+                gradient(j) = regParam * temp
+              } else {
+                // Zero-variance feature: no penalty and zero gradient (this 0.0 is discarded).
+                0.0
+              }
+            case None =>
+              // Plain L2 penalty on the coefficient as given.
+              sum += coef * coef
+              gradient(j) = coef * regParam
           }
-        case None =>
-          sum += coef * coef
-          gradient(j) = coef * regParam
-      }
+        }
+        (0.5 * sum * regParam, Vectors.dense(gradient))
+      // Only dense coefficient vectors are supported.
+      case _: SparseVector =>
+        throw new IllegalArgumentException("Sparse coefficients are not currently supported.")
     }
-    (0.5 * sum * regParam, gradient)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala
index 3b1618e..1730416 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala
@@ -29,7 +29,7 @@ import org.apache.spark.rdd.RDD
 
 /**
  * This class computes the gradient and loss of a differentiable loss function by mapping a
- * [[DifferentiableLossAggregator]] over an [[RDD]] of [[Instance]]s. The loss function is the
+ * [[DifferentiableLossAggregator]] over an [[RDD]]. The loss function is the
  * sum of the loss computed on a single instance across all points in the RDD. Therefore, the actual
  * analytical form of the loss function is specified by the aggregator, which computes each points
  * contribution to the overall loss.
@@ -37,7 +37,7 @@ import org.apache.spark.rdd.RDD
  * A differentiable regularization component can also be added by providing a
  * [[DifferentiableRegularization]] loss function.
  *
- * @param instances
+ * @param instances RDD containing the data to compute the loss function over.
  * @param getAggregator A function which gets a new loss aggregator in every tree aggregate step.
  * @param regularization An option representing the regularization loss function to apply to the
  *                       coefficients.
@@ -50,7 +50,7 @@ private[ml] class RDDLossFunction[
     Agg <: DifferentiableLossAggregator[T, Agg]: ClassTag](
     instances: RDD[T],
     getAggregator: (Broadcast[Vector] => Agg),
-    regularization: Option[DifferentiableRegularization[Array[Double]]],
+    regularization: Option[DifferentiableRegularization[Vector]],
     aggregationDepth: Int = 2)
   extends DiffFunction[BDV[Double]] {
 
@@ -62,8 +62,8 @@ private[ml] class RDDLossFunction[
     val newAgg = instances.treeAggregate(thisAgg)(seqOp, combOp, aggregationDepth)
     val gradient = newAgg.gradient
     val regLoss = regularization.map { regFun =>
-      val (regLoss, regGradient) = regFun.calculate(coefficients.data)
-      BLAS.axpy(1.0, Vectors.dense(regGradient), gradient)
+      val (regLoss, regGradient) = regFun.calculate(Vectors.fromBreeze(coefficients))
+      BLAS.axpy(1.0, regGradient, gradient)
       regLoss
     }.getOrElse(0.0)
     bcCoefficients.destroy(blocking = false)

http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index ccc61fe..50931fe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -336,10 +336,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
 
     val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept),
       bcFeaturesStd, bcFeaturesMean)(_)
+    val getFeaturesStd = (j: Int) => if (j >= 0 && j < numFeatures) featuresStd(j) else 0.0
     val regularization = if (effectiveL2RegParam != 0.0) {
       val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
       Some(new L2Regularization(effectiveL2RegParam, shouldApply,
-        if ($(standardization)) None else Some(featuresStd)))
+        if ($(standardization)) None else Some(getFeaturesStd)))
     } else {
       None
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 1ffd8dc..0570499 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.classification.LogisticRegressionSuite._
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors}
+import org.apache.spark.ml.optim.aggregator.LogisticAggregator
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
@@ -506,8 +507,8 @@ class LogisticRegressionSuite
   test("sparse coefficients in LogisticAggregator") {
     val bcCoefficientsBinary = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0)))
     val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0))
-    val binaryAgg = new LogisticAggregator(bcCoefficientsBinary, bcFeaturesStd, 2,
-      fitIntercept = true, multinomial = false)
+    val binaryAgg = new LogisticAggregator(bcFeaturesStd, 2,
+      fitIntercept = true, multinomial = false)(bcCoefficientsBinary)
     val thrownBinary = withClue("binary logistic aggregator cannot handle sparse coefficients") {
       intercept[IllegalArgumentException] {
         binaryAgg.add(Instance(1.0, 1.0, Vectors.dense(1.0)))
@@ -516,8 +517,8 @@ class LogisticRegressionSuite
     assert(thrownBinary.getMessage.contains("coefficients only supports dense"))
 
     val bcCoefficientsMulti = spark.sparkContext.broadcast(Vectors.sparse(6, Array(0), Array(1.0)))
-    val multinomialAgg = new LogisticAggregator(bcCoefficientsMulti, bcFeaturesStd, 3,
-      fitIntercept = true, multinomial = true)
+    val multinomialAgg = new LogisticAggregator(bcFeaturesStd, 3,
+      fitIntercept = true, multinomial = true)(bcCoefficientsMulti)
     val thrown = withClue("multinomial logistic aggregator cannot handle sparse coefficients") {
       intercept[IllegalArgumentException] {
         multinomialAgg.add(Instance(1.0, 1.0, Vectors.dense(1.0)))

http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala
index 7a4faeb..d7cdeae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala
@@ -17,9 +17,12 @@
 package org.apache.spark.ml.optim.aggregator
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.classification.MultiClassSummarizer
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 
 class DifferentiableLossAggregatorSuite extends SparkFunSuite {
 
@@ -157,4 +160,38 @@ object DifferentiableLossAggregatorSuite {
       this
     }
   }
+
+  /** Get feature and label summarizers for provided data. */
+  private[ml] def getRegressionSummarizers(
+      instances: Array[Instance]): (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer) = {
+    // One sequential, weighted pass: features feed the first summarizer, and the scalar label,
+    // wrapped as a one-element vector, feeds the second. Both are mutated in place and returned.
+    val featuresSummarizer = new MultivariateOnlineSummarizer
+    val labelSummarizer = new MultivariateOnlineSummarizer
+    instances.foreach { instance =>
+      featuresSummarizer.add(instance.features, instance.weight)
+      labelSummarizer.add(Vectors.dense(instance.label), instance.weight)
+    }
+    (featuresSummarizer, labelSummarizer)
+  }
+
+  /** Get feature and label summarizers for provided data. */
+  private[ml] def getClassificationSummarizers(
+      instances: Array[Instance]): (MultivariateOnlineSummarizer, MultiClassSummarizer) = {
+    // One sequential, weighted pass: features feed the multivariate summarizer and labels the
+    // class-count summarizer. Both are mutated in place and returned.
+    val featuresSummarizer = new MultivariateOnlineSummarizer
+    val labelSummarizer = new MultiClassSummarizer
+    instances.foreach { instance =>
+      featuresSummarizer.add(instance.features, instance.weight)
+      labelSummarizer.add(instance.label, instance.weight)
+    }
+    (featuresSummarizer, labelSummarizer)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala
index d1cb0d3..35b6944 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala
@@ -20,12 +20,12 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
 
+  import DifferentiableLossAggregatorSuite.getRegressionSummarizers
+
   @transient var instances: Array[Instance] = _
   @transient var instancesConstantFeature: Array[Instance] = _
   @transient var instancesConstantLabel: Array[Instance] = _
@@ -49,29 +49,12 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte
     )
   }
 
-  /** Get feature and label summarizers for provided data. */
-  def getSummarizers(
-    instances: Array[Instance]): (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer) = {
-    val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
-                 instance: Instance) =>
-      (c._1.add(instance.features, instance.weight),
-        c._2.add(Vectors.dense(instance.label), instance.weight))
-
-    val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
-                  c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) =>
-      (c1._1.merge(c2._1), c1._2.merge(c2._2))
-
-    instances.aggregate(
-      new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer
-    )(seqOp, combOp)
-  }
-
   /** Get summary statistics for some data and create a new LeastSquaresAggregator. */
-  def getNewAggregator(
+  private def getNewAggregator(
       instances: Array[Instance],
       coefficients: Vector,
       fitIntercept: Boolean): LeastSquaresAggregator = {
-    val (featuresSummarizer, ySummarizer) = getSummarizers(instances)
+    val (featuresSummarizer, ySummarizer) = getRegressionSummarizers(instances)
     val yStd = math.sqrt(ySummarizer.variance(0))
     val yMean = ySummarizer.mean(0)
     val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
@@ -83,6 +66,26 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte
       bcFeaturesMean)(bcCoefficients)
   }
 
+  test("aggregator add method input size") {
+    // the aggregator expects two features (one per coefficient); the instance has only one
+    val aggregator = getNewAggregator(instances, Vectors.dense(1.0, 2.0), fitIntercept = true)
+    withClue("LeastSquaresAggregator features dimension must match coefficients dimension") {
+      intercept[IllegalArgumentException] {
+        aggregator.add(Instance(1.0, 1.0, Vectors.dense(2.0)))
+      }
+    }
+  }
+
+  test("negative weight") {
+    // the instance has two features, matching the two coefficients, so only its
+    // negative weight can make the aggregator reject it
+    val aggregator = getNewAggregator(instances, Vectors.dense(1.0, 2.0), fitIntercept = true)
+    val negativelyWeighted = Instance(1.0, -1.0, Vectors.dense(2.0, 1.0))
+    withClue("LeastSquaresAggregator does not support negative instance weights.") {
+      intercept[IllegalArgumentException](aggregator.add(negativelyWeighted))
+    }
+  }
+
   test("check sizes") {
     val coefficients = Vectors.dense(1.0, 2.0)
     val aggIntercept = getNewAggregator(instances, coefficients, fitIntercept = true)
@@ -102,7 +105,7 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte
      */
     val coefficients = Vectors.dense(1.0, 2.0)
     val numFeatures = coefficients.size
-    val (featuresSummarizer, ySummarizer) = getSummarizers(instances)
+    val (featuresSummarizer, ySummarizer) = getRegressionSummarizers(instances)
     val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
     val featuresMean = featuresSummarizer.mean.toArray
     val yStd = math.sqrt(ySummarizer.variance(0))

http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala
new file mode 100644
index 0000000..2b29c67
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.optim.aggregator
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg.{BLAS, Matrices, Vector, Vectors}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/** Unit tests for [[LogisticAggregator]] in both binomial and multinomial modes. */
+class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  import DifferentiableLossAggregatorSuite.getClassificationSummarizers
+
+  @transient var instances: Array[Instance] = _
+  @transient var instancesConstantFeature: Array[Instance] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    instances = Array(
+      Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)),
+      Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)),
+      Instance(2.0, 0.3, Vectors.dense(4.0, 0.5))
+    )
+    instancesConstantFeature = Array(
+      Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)),
+      Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)),
+      Instance(2.0, 0.3, Vectors.dense(1.0, 0.5))
+    )
+  }
+
+  /** Create a LogisticAggregator using the feature std and class count of `instances`. */
+  private def getNewAggregator(
+      instances: Array[Instance],
+      coefficients: Vector,
+      fitIntercept: Boolean,
+      isMultinomial: Boolean): LogisticAggregator = {
+    val (featuresSummarizer, ySummarizer) = getClassificationSummarizers(instances)
+    val numClasses = ySummarizer.histogram.length
+    val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+    val bcFeaturesStd = spark.sparkContext.broadcast(featuresStd)
+    val bcCoefficients = spark.sparkContext.broadcast(coefficients)
+    new LogisticAggregator(bcFeaturesStd, numClasses, fitIntercept, isMultinomial)(bcCoefficients)
+  }
+
+  test("aggregator add method input size") {
+    val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0)
+    val interceptArray = Array(4.0, 2.0, -3.0)
+    val agg = getNewAggregator(instances, Vectors.dense(coefArray ++ interceptArray),
+      fitIntercept = true, isMultinomial = true)
+    withClue("LogisticAggregator features dimension must match coefficients dimension") {
+      intercept[IllegalArgumentException] {
+        agg.add(Instance(1.0, 1.0, Vectors.dense(2.0)))
+      }
+    }
+  }
+
+  test("negative weight") {
+    val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0)
+    val interceptArray = Array(4.0, 2.0, -3.0)
+    val agg = getNewAggregator(instances, Vectors.dense(coefArray ++ interceptArray),
+      fitIntercept = true, isMultinomial = true)
+    withClue("LogisticAggregator does not support negative instance weights") {
+      intercept[IllegalArgumentException] {
+        agg.add(Instance(1.0, -1.0, Vectors.dense(2.0, 1.0)))
+      }
+    }
+  }
+
+  test("check sizes multinomial") {
+    val rng = new scala.util.Random(42)
+    val numFeatures = instances.head.features.size
+    val numClasses = instances.map(_.label).toSet.size
+    val coefWithIntercept = Vectors.dense(
+      Array.fill(numClasses * (numFeatures + 1))(rng.nextDouble))
+    val coefWithoutIntercept = Vectors.dense(
+      Array.fill(numClasses * numFeatures)(rng.nextDouble))
+    val aggIntercept = getNewAggregator(instances, coefWithIntercept, fitIntercept = true,
+      isMultinomial = true)
+    val aggNoIntercept = getNewAggregator(instances, coefWithoutIntercept, fitIntercept = false,
+      isMultinomial = true)
+    instances.foreach(aggIntercept.add)
+    instances.foreach(aggNoIntercept.add)
+
+    assert(aggIntercept.gradient.size === (numFeatures + 1) * numClasses)
+    assert(aggNoIntercept.gradient.size === numFeatures * numClasses)
+  }
+
+  test("check sizes binomial") {
+    val rng = new scala.util.Random(42)
+    val binaryInstances = instances.filter(_.label < 2.0)
+    val numFeatures = binaryInstances.head.features.size
+    val coefWithIntercept = Vectors.dense(Array.fill(numFeatures + 1)(rng.nextDouble))
+    val coefWithoutIntercept = Vectors.dense(Array.fill(numFeatures)(rng.nextDouble))
+    val aggIntercept = getNewAggregator(binaryInstances, coefWithIntercept, fitIntercept = true,
+      isMultinomial = false)
+    val aggNoIntercept = getNewAggregator(binaryInstances, coefWithoutIntercept,
+      fitIntercept = false, isMultinomial = false)
+    binaryInstances.foreach(aggIntercept.add)
+    binaryInstances.foreach(aggNoIntercept.add)
+
+    assert(aggIntercept.gradient.size === numFeatures + 1)
+    assert(aggNoIntercept.gradient.size === numFeatures)
+  }
+
+  test("check correctness multinomial") {
+    /*
+    Check that the aggregator computes the weighted-average loss (and its gradient)
+      -(1 / sum_i w_i) * sum_i w_i * (beta_{y_i} dot x_i - log(sum_k e^(beta_k dot x_i)))
+    where the coefficients beta act on features scaled by their standard deviations.
+     */
+    val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0)
+    val interceptArray = Array(4.0, 2.0, -3.0)
+    val numFeatures = instances.head.features.size
+    val numClasses = instances.map(_.label).toSet.size
+    val intercepts = Vectors.dense(interceptArray)
+    val (featuresSummarizer, _) = getClassificationSummarizers(instances)
+    val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+    val weightSum = instances.map(_.weight).sum
+
+    val agg = getNewAggregator(instances, Vectors.dense(coefArray ++ interceptArray),
+      fitIntercept = true, isMultinomial = true)
+    instances.foreach(agg.add)
+
+    // compute the loss
+    val stdCoef = coefArray.indices.map(i => coefArray(i) / featuresStd(i / numClasses)).toArray
+    val linearPredictors = instances.map { case Instance(l, w, f) =>
+      val result = intercepts.copy.toDense
+      BLAS.gemv(1.0, Matrices.dense(numClasses, numFeatures, stdCoef), f, 1.0, result)
+      (l, w, result)
+    }
+
+    // sum_i w_i * beta_{y_i} dot x_i
+    val sumLinear = linearPredictors.map { case (l, w, p) =>
+      w * p(l.toInt)
+    }.sum
+
+    // sum_i w_i * log(sum_k e^(beta_k dot x_i))
+    val sumLogs = linearPredictors.map { case (_, w, p) =>
+      w * math.log(p.values.map(math.exp).sum)
+    }.sum
+    val loss = (sumLogs - sumLinear) / weightSum
+
+    // compute the gradients
+    val gradientCoef = new Array[Double](numFeatures * numClasses)
+    val gradientIntercept = new Array[Double](numClasses)
+    instances.foreach { case Instance(l, w, f) =>
+      val margin = intercepts.copy.toDense
+      BLAS.gemv(1.0, Matrices.dense(numClasses, numFeatures, stdCoef), f, 1.0, margin)
+      val sum = margin.values.map(math.exp).sum
+
+      gradientCoef.indices.foreach { i =>
+        val fStd = f(i / numClasses) / featuresStd(i / numClasses)
+        val cidx = i % numClasses
+        if (cidx == l.toInt) gradientCoef(i) -= w * fStd
+        gradientCoef(i) += w * math.exp(margin(cidx)) / sum * fStd
+      }
+
+      // intercept k gets w * (softmax_k - 1{k == y}), with no feature scaling
+      gradientIntercept.indices.foreach { k =>
+        if (k == l.toInt) gradientIntercept(k) -= w
+        gradientIntercept(k) += w * math.exp(margin(k)) / sum
+      }
+    }
+    val gradient = Vectors.dense((gradientCoef ++ gradientIntercept).map(_ / weightSum))
+
+    assert(loss ~== agg.loss relTol 1e-5)
+    assert(gradient ~== agg.gradient relTol 1e-5)
+  }
+
+  test("check correctness binomial") {
+    /*
+    Check that the aggregator computes the weighted-average loss (and its gradient)
+      -(1 / sum_i w_i) * sum_i w_i * (y_i * log(p_i) + (1 - y_i) * log(1 - p_i))
+    where p_i = 1 / (1 + e^(-(beta dot x_i + b))) and beta acts on standardized features.
+     */
+    val binaryInstances = instances.map { instance =>
+      if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features)
+    }
+    val coefArray = Array(1.0, 2.0)
+    val intercept = 1.0
+    val numFeatures = binaryInstances.head.features.size
+    val (featuresSummarizer, _) = getClassificationSummarizers(binaryInstances)
+    val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+    val weightSum = binaryInstances.map(_.weight).sum
+
+    val agg = getNewAggregator(binaryInstances, Vectors.dense(coefArray ++ Array(intercept)),
+      fitIntercept = true, isMultinomial = false)
+    binaryInstances.foreach(agg.add)
+
+    // compute the loss
+    val stdCoef = coefArray.indices.map(i => coefArray(i) / featuresStd(i)).toArray
+    // loss and gradient must share this margin, computed with the standardized coefficients
+    def margin(features: Vector): Double = BLAS.dot(Vectors.dense(stdCoef), features) + intercept
+    val lossSum = binaryInstances.map { case Instance(l, w, f) =>
+      val prob = 1.0 / (1.0 + math.exp(-margin(f)))
+      -w * l * math.log(prob) - w * (1.0 - l) * math.log(1.0 - prob)
+    }.sum
+    val loss = lossSum / weightSum
+
+    // compute the gradients
+    val gradientCoef = new Array[Double](numFeatures)
+    var gradientIntercept = 0.0
+    binaryInstances.foreach { case Instance(l, w, f) =>
+      val multiplier = w * (1.0 / (1.0 + math.exp(-margin(f))) - l)
+      gradientCoef.indices.foreach { i =>
+        gradientCoef(i) += multiplier * f(i) / featuresStd(i)
+      }
+      gradientIntercept += multiplier
+    }
+    val gradient = Vectors.dense((gradientCoef ++ Array(gradientIntercept)).map(_ / weightSum))
+
+    assert(loss ~== agg.loss relTol 1e-5)
+    assert(gradient ~== agg.gradient relTol 1e-5)
+  }
+
+  test("check with zero standard deviation") {
+    val binaryInstances = instancesConstantFeature.map { instance =>
+      if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features)
+    }
+    val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0)
+    val interceptArray = Array(4.0, 2.0, -3.0)
+    val aggConstantFeature = getNewAggregator(instancesConstantFeature,
+      Vectors.dense(coefArray ++ interceptArray), fitIntercept = true, isMultinomial = true)
+    instancesConstantFeature.foreach(aggConstantFeature.add)
+    // feature 0 is constant (zero std): its entries for all 3 classes must stay zero
+    assert(aggConstantFeature.gradient.toArray.take(3) === Array(0.0, 0.0, 0.0))
+
+    val binaryCoefArray = Array(1.0, 2.0)
+    val intercept = 1.0
+    val aggConstantFeatureBinary = getNewAggregator(binaryInstances,
+      Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true,
+      isMultinomial = false)
+    binaryInstances.foreach(aggConstantFeatureBinary.add)
+    // feature 0 is constant (zero std), so its gradient entry must stay zero
+    assert(aggConstantFeatureBinary.gradient(0) === 0.0)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala
index 0794417..4377a6b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala
@@ -17,20 +17,21 @@
 package org.apache.spark.ml.optim.loss
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{BLAS, Vectors}
 
 class DifferentiableRegularizationSuite extends SparkFunSuite {
 
   test("L2 regularization") {
     val shouldApply = (_: Int) => true
     val regParam = 0.3
-    val coefficients = Array(1.0, 3.0, -2.0)
+    val coefficients = Vectors.dense(Array(1.0, 3.0, -2.0))
     val numFeatures = coefficients.size
 
     // check without features standard
     val regFun = new L2Regularization(regParam, shouldApply, None)
     val (loss, grad) = regFun.calculate(coefficients)
-    assert(loss === 0.5 * regParam * coefficients.map(x => x * x).sum)
-    assert(grad === coefficients.map(_ * regParam))
+    assert(loss === 0.5 * regParam * BLAS.dot(coefficients, coefficients))
+    assert(grad === Vectors.dense(coefficients.toArray.map(_ * regParam)))
 
     // check with features standard
     val featuresStd = Array(0.1, 1.1, 0.5)
@@ -39,9 +40,9 @@ class DifferentiableRegularizationSuite extends SparkFunSuite {
     val expectedLossStd = 0.5 * regParam * (0 until numFeatures).map { j =>
       coefficients(j) * coefficients(j) / (featuresStd(j) * featuresStd(j))
     }.sum
-    val expectedGradientStd = (0 until numFeatures).map { j =>
+    val expectedGradientStd = Vectors.dense((0 until numFeatures).map { j =>
       regParam * coefficients(j) / (featuresStd(j) * featuresStd(j))
-    }.toArray
+    }.toArray)
     assert(lossStd === expectedLossStd)
     assert(gradStd === expectedGradientStd)
 
@@ -50,7 +51,7 @@ class DifferentiableRegularizationSuite extends SparkFunSuite {
     val regFunApply = new L2Regularization(regParam, shouldApply2, None)
     val (lossApply, gradApply) = regFunApply.calculate(coefficients)
     assert(lossApply === 0.5 * regParam * coefficients(1) * coefficients(1))
-    assert(gradApply ===  Array(0.0, coefficients(1) * regParam, 0.0))
+    assert(gradApply ===  Vectors.dense(0.0, coefficients(1) * regParam, 0.0))
 
     // check with zero features standard
     val featuresStdZero = Array(0.1, 0.0, 0.5)

http://git-wip-us.apache.org/repos/asf/spark/blob/cf29828d/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala
index cd5cebe..f70da57 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala
@@ -46,11 +46,11 @@ class RDDLossFunctionSuite extends SparkFunSuite with MLlibTestSparkContext {
     val lossWithReg = new RDDLossFunction(instances, getAgg, Some(regLossFun))
 
     val (loss1, grad1) = lossNoReg.calculate(coefficients.asBreeze.toDenseVector)
-    val (regLoss, regGrad) = regLossFun.calculate(coefficients.toArray)
+    val (regLoss, regGrad) = regLossFun.calculate(coefficients)
     val (loss2, grad2) = lossWithReg.calculate(coefficients.asBreeze.toDenseVector)
 
-    BLAS.axpy(1.0, Vectors.fromBreeze(grad1), Vectors.dense(regGrad))
-    assert(Vectors.dense(regGrad) ~== Vectors.fromBreeze(grad2) relTol 1e-5)
+    BLAS.axpy(1.0, Vectors.fromBreeze(grad1), regGrad)
+    assert(regGrad ~== Vectors.fromBreeze(grad2) relTol 1e-5)
     assert(loss1 + regLoss === loss2)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org