You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2017/01/23 20:18:08 UTC

spark git commit: [SPARK-14709][ML] spark.ml API for linear SVM

Repository: spark
Updated Branches:
  refs/heads/master 0ef1421a6 -> 4a11d029d


[SPARK-14709][ML] spark.ml API for linear SVM

## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-14709

Provide an API for the SVM algorithm for DataFrames. As discussed in JIRA, the initial implementation uses OWL-QN with the hinge loss function.
The API should mimic existing spark.ml.classification APIs.
Currently only binary classification is supported. Multinomial support can be added in this or a following release.
## How was this patch tested?

new unit tests and simple manual test

Author: Yuhao <yu...@intel.com>
Author: Yuhao Yang <hh...@gmail.com>

Closes #15211 from hhbyyh/mlsvm.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4a11d029
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4a11d029
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4a11d029

Branch: refs/heads/master
Commit: 4a11d029dc6abeb98fef5725d3d446a3eb5deddf
Parents: 0ef1421
Author: Yuhao <yu...@intel.com>
Authored: Mon Jan 23 12:18:06 2017 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Jan 23 12:18:06 2017 -0800

----------------------------------------------------------------------
 .../spark/ml/classification/LinearSVC.scala     | 546 +++++++++++++++++++
 .../ml/classification/LogisticRegression.scala  |   4 +-
 .../ml/classification/LinearSVCSuite.scala      | 241 ++++++++
 3 files changed, 789 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4a11d029/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
new file mode 100644
index 0000000..c4e93bf
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import scala.collection.mutable
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, OWLQN => BreezeOWLQN}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.linalg.BLAS._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for linear SVM Classifier, shared by [[LinearSVC]] and [[LinearSVCModel]].
+ *
+ * Note: in this file `threshold` is compared against the raw margin
+ * (`coefficients . features + intercept`) in `LinearSVCModel.predict`, not against a probability.
+ * NOTE(review): the setter doc in LinearSVC describes `threshold` as "in range [0, 1]", which
+ * suggests the shared `HasThreshold` validator restricts it to [0, 1]; confirm that restriction
+ * is intended for a margin threshold, which could meaningfully be any real number.
+ */
+private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
+  with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
+  with HasThreshold with HasAggregationDepth
+
+/**
+ * :: Experimental ::
+ *
+ * Linear SVM Classifier (https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM)
+ *
+ * This binary classifier minimizes the hinge loss, plus an optional L2 penalty weighted by
+ * `regParam`, using Breeze's OWL-QN optimizer with the L1 penalty fixed at zero.
+ * Labels must be 0.0 or 1.0.
+ */
+@Since("2.2.0")
+@Experimental
+class LinearSVC @Since("2.2.0") (
+    @Since("2.2.0") override val uid: String)
+  extends Classifier[Vector, LinearSVC, LinearSVCModel]
+  with LinearSVCParams with DefaultParamsWritable {
+
+  @Since("2.2.0")
+  def this() = this(Identifiable.randomUID("linearsvc"))
+
+  /**
+   * Set the regularization parameter.
+   * Default is 0.0.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setRegParam(value: Double): this.type = set(regParam, value)
+  setDefault(regParam -> 0.0)
+
+  /**
+   * Set the maximum number of iterations.
+   * Default is 100.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
+  setDefault(maxIter -> 100)
+
+  /**
+   * Whether to fit an intercept term.
+   * Default is true.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+  setDefault(fitIntercept -> true)
+
+  /**
+   * Set the convergence tolerance of iterations.
+   * Smaller values will lead to higher accuracy at the cost of more iterations.
+   * Default is 1E-6.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setTol(value: Double): this.type = set(tol, value)
+  setDefault(tol -> 1E-6)
+
+  /**
+   * Whether to standardize the training features before fitting the model.
+   * Default is true.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setStandardization(value: Boolean): this.type = set(standardization, value)
+  setDefault(standardization -> true)
+
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
+  /**
+   * Set the threshold in binary classification, in range [0, 1].
+   * The fitted model compares this threshold with the raw margin
+   * `coefficients . features + intercept` (not with a probability): instances whose margin
+   * is greater than the threshold are predicted as 1.0, all others as 0.0.
+   * Default is 0.0.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setThreshold(value: Double): this.type = set(threshold, value)
+  setDefault(threshold -> 0.0)
+
+  /**
+   * Suggested depth for treeAggregate (greater than or equal to 2).
+   * If the dimensions of features or the number of partitions are large,
+   * this param could be adjusted to a larger size.
+   * Default is 2.
+   *
+   * @group expertSetParam
+   */
+  @Since("2.2.0")
+  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+  setDefault(aggregationDepth -> 2)
+
+  @Since("2.2.0")
+  override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
+
+  override protected[classification] def train(dataset: Dataset[_]): LinearSVCModel = {
+    // Cast label and weight to Double so numeric columns of other types (e.g. Int) do not
+    // fail the `Row(label: Double, weight: Double, ...)` match below with a MatchError.
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
+      lit(1.0)
+    } else {
+      col($(weightCol)).cast("double")
+    }
+    val instances: RDD[Instance] =
+      dataset.select(col($(labelCol)).cast("double"), w, col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          Instance(label, weight, features)
+      }
+
+    // The optimizer makes one full pass over `instances` per cost-function evaluation, so cache
+    // the converted RDD unless the caller has already persisted the input dataset.
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    try {
+      trainOnInstances(dataset, instances)
+    } finally {
+      if (handlePersistence) instances.unpersist()
+    }
+  }
+
+  /**
+   * Fits the model on `instances` extracted from `dataset`; `dataset` itself is only
+   * consulted for the label column's schema metadata.
+   */
+  private def trainOnInstances(
+      dataset: Dataset[_],
+      instances: RDD[Instance]): LinearSVCModel = {
+    val instr = Instrumentation.create(this, instances)
+    instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold,
+      aggregationDepth)
+
+    val (summarizer, labelSummarizer) = {
+      val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
+        instance: Instance) =>
+          (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
+
+      val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
+        c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
+          (c1._1.merge(c2._1), c1._2.merge(c2._2))
+
+      instances.treeAggregate(
+        new MultivariateOnlineSummarizer, new MultiClassSummarizer
+      )(seqOp, combOp, $(aggregationDepth))
+    }
+
+    val histogram = labelSummarizer.histogram
+    val numInvalid = labelSummarizer.countInvalid
+    val numFeatures = summarizer.mean.size
+    val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures
+
+    val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+      case Some(n: Int) =>
+        require(n >= histogram.length, s"Specified number of classes $n was " +
+          s"less than the number of unique labels ${histogram.length}.")
+        n
+      case None => histogram.length
+    }
+    // Interpolate the label column's *value*; `$labelCol` would print the Param object.
+    require(numClasses == 2, s"LinearSVC only supports binary classification." +
+      s" $numClasses classes detected in ${$(labelCol)}")
+    instr.logNumClasses(numClasses)
+    instr.logNumFeatures(numFeatures)
+
+    val (coefficientVector, interceptVector, objectiveHistory) = {
+      if (numInvalid != 0) {
+        val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
+          s"Found $numInvalid invalid labels."
+        logError(msg)
+        throw new SparkException(msg)
+      }
+
+      val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+      val regParamL2 = $(regParam)
+      val bcFeaturesStd = instances.context.broadcast(featuresStd)
+      val costFun = new LinearSVCCostFun(instances, $(fitIntercept),
+        $(standardization), bcFeaturesStd, regParamL2, $(aggregationDepth))
+
+      // No L1 penalty: OWL-QN with a zero L1 weight for every coordinate.
+      def regParamL1Fun = (index: Int) => 0D
+      val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
+      val initialCoefWithIntercept = Vectors.zeros(numFeaturesPlusIntercept)
+
+      val states = optimizer.iterations(new CachedDiffFunction(costFun),
+        initialCoefWithIntercept.asBreeze.toDenseVector)
+
+      val scaledObjectiveHistory = mutable.ArrayBuilder.make[Double]
+      var state: optimizer.State = null
+      while (states.hasNext) {
+        state = states.next()
+        scaledObjectiveHistory += state.adjustedValue
+      }
+
+      bcFeaturesStd.destroy(blocking = false)
+      if (state == null) {
+        val msg = s"${optimizer.getClass.getName} failed."
+        logError(msg)
+        throw new SparkException(msg)
+      }
+
+      /*
+         The coefficients are trained in the scaled space; we're converting them back to
+         the original space.
+         Note that the intercept in scaled space and original space is the same;
+         as a result, no scaling is needed.
+       */
+      val rawCoefficients = state.x.toArray
+      val coefficientArray = Array.tabulate(numFeatures) { i =>
+        if (featuresStd(i) != 0.0) {
+          rawCoefficients(i) / featuresStd(i)
+        } else {
+          0.0
+        }
+      }
+
+      val intercept = if ($(fitIntercept)) {
+        rawCoefficients(numFeaturesPlusIntercept - 1)
+      } else {
+        0.0
+      }
+      (Vectors.dense(coefficientArray), intercept, scaledObjectiveHistory.result())
+    }
+
+    val model = copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector))
+    instr.logSuccess(model)
+    model
+  }
+}
+
+/** Companion that loads a saved [[LinearSVC]] estimator via the default params reader. */
+@Since("2.2.0")
+object LinearSVC extends DefaultParamsReadable[LinearSVC] {
+
+  /** Loads a [[LinearSVC]] previously saved to `path`. */
+  @Since("2.2.0")
+  override def load(path: String): LinearSVC = {
+    super.load(path)
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Linear SVM model trained by [[LinearSVC]].
+ *
+ * The model predicts 1.0 when the margin `coefficients . features + intercept` is greater
+ * than `threshold` and 0.0 otherwise; the raw prediction is `[-margin, margin]`.
+ */
+@Since("2.2.0")
+@Experimental
+class LinearSVCModel private[classification] (
+    @Since("2.2.0") override val uid: String,
+    @Since("2.2.0") val coefficients: Vector,
+    @Since("2.2.0") val intercept: Double)
+  extends ClassificationModel[Vector, LinearSVCModel]
+  with LinearSVCParams with MLWritable {
+
+  @Since("2.2.0")
+  override val numClasses: Int = 2
+
+  @Since("2.2.0")
+  override val numFeatures: Int = coefficients.size
+
+  /**
+   * Set the threshold compared against the margin in prediction.
+   * Default is 0.0.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setThreshold(value: Double): this.type = set(threshold, value)
+  // Match the estimator's default so a model constructed directly (not via LinearSVC.fit or
+  // a reader) can still predict without `threshold` having been set explicitly.
+  setDefault(threshold -> 0.0)
+
+  /**
+   * Sets the value of param [[weightCol]].
+   * Instance weights are only used in training; prediction does not read this param.
+   * (Previously this method took a Double and mistakenly set `threshold`.)
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
+  /** The signed margin `coefficients . features + intercept`. */
+  private val margin: Vector => Double = (features) => {
+    BLAS.dot(features, coefficients) + intercept
+  }
+
+  /** Predicts 1.0 if the margin exceeds `threshold`, else 0.0. */
+  override protected def predict(features: Vector): Double = {
+    if (margin(features) > $(threshold)) 1.0 else 0.0
+  }
+
+  /** Raw prediction `[-margin, margin]`, one entry per class. */
+  override protected def predictRaw(features: Vector): Vector = {
+    val m = margin(features)
+    Vectors.dense(-m, m)
+  }
+
+  @Since("2.2.0")
+  override def copy(extra: ParamMap): LinearSVCModel = {
+    copyValues(new LinearSVCModel(uid, coefficients, intercept), extra).setParent(parent)
+  }
+
+  @Since("2.2.0")
+  override def write: MLWriter = new LinearSVCModel.LinearSVCWriter(this)
+
+}
+
+
+@Since("2.2.0")
+object LinearSVCModel extends MLReadable[LinearSVCModel] {
+
+  @Since("2.2.0")
+  override def read: MLReader[LinearSVCModel] = new LinearSVCReader
+
+  @Since("2.2.0")
+  override def load(path: String): LinearSVCModel = super.load(path)
+
+  /**
+   * [[MLWriter]] for [[LinearSVCModel]]: writes params/metadata with `DefaultParamsWriter`,
+   * then the coefficients and intercept as a single-partition Parquet file under `path/data`.
+   */
+  private[LinearSVCModel]
+  class LinearSVCWriter(instance: LinearSVCModel)
+    extends MLWriter with Logging {
+
+    /** Row layout of the persisted model data. */
+    private case class Data(coefficients: Vector, intercept: Double)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val modelData = Seq(Data(instance.coefficients, instance.intercept))
+      val dataDir = new Path(path, "data").toString
+      sparkSession.createDataFrame(modelData).repartition(1).write.parquet(dataDir)
+    }
+  }
+
+  /** Reader that restores a model saved by `LinearSVCWriter`. */
+  private class LinearSVCReader extends MLReader[LinearSVCModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[LinearSVCModel].getName
+
+    override def load(path: String): LinearSVCModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataDir = new Path(path, "data").toString
+      val Row(coefficients: Vector, intercept: Double) = sparkSession.read
+        .format("parquet")
+        .load(dataDir)
+        .select("coefficients", "intercept")
+        .head()
+      val restored = new LinearSVCModel(metadata.uid, coefficients, intercept)
+      DefaultParamsReader.getAndSetParams(restored, metadata)
+      restored
+    }
+  }
+}
+
+/**
+ * LinearSVCCostFun implements Breeze's DiffFunction[T] for the hinge loss function, plus an
+ * L2 penalty of `0.5 * regParamL2 * ||w||^2` on the non-intercept coefficients.
+ *
+ * @param instances training data
+ * @param fitIntercept whether the last coefficient is an (unregularized) intercept
+ * @param standardization if false, the penalty is rescaled per feature by 1 / std^2 so that
+ *                        it matches the objective on the unstandardized data
+ * @param bcFeaturesStd broadcast per-feature standard deviations
+ * @param regParamL2 L2 regularization parameter
+ * @param aggregationDepth depth passed to treeAggregate
+ */
+private class LinearSVCCostFun(
+    instances: RDD[Instance],
+    fitIntercept: Boolean,
+    standardization: Boolean,
+    bcFeaturesStd: Broadcast[Array[Double]],
+    regParamL2: Double,
+    aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
+
+  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
+    val coeffs = Vectors.fromBreeze(coefficients)
+    val bcCoeffs = instances.context.broadcast(coeffs)
+    val featuresStd = bcFeaturesStd.value
+    val numFeatures = featuresStd.length
+
+    // Destroy the per-iteration broadcast even if the aggregation fails.
+    try {
+      val svmAggregator = {
+        val seqOp = (c: LinearSVCAggregator, instance: Instance) => c.add(instance)
+        val combOp = (c1: LinearSVCAggregator, c2: LinearSVCAggregator) => c1.merge(c2)
+
+        instances.treeAggregate(
+          new LinearSVCAggregator(bcCoeffs, bcFeaturesStd, fitIntercept)
+        )(seqOp, combOp, aggregationDepth)
+      }
+
+      val totalGradientArray = svmAggregator.gradient.toArray
+      // regVal is the L2 penalty over the coefficients, excluding the intercept; the matching
+      // gradient term is added into totalGradientArray in place.
+      val regVal = if (regParamL2 == 0.0) {
+        0.0
+      } else {
+        var sum = 0.0
+        coeffs.foreachActive { (index, value) =>
+          // We do not apply regularization to the intercept (stored at index numFeatures).
+          if (index != numFeatures) {
+            if (standardization) {
+              totalGradientArray(index) += regParamL2 * value
+              sum += value * value
+            } else if (featuresStd(index) != 0.0) {
+              // If `standardization` is false, we still standardize the data
+              // to improve the rate of convergence; as a result, we have to
+              // perform this reverse standardization by penalizing each component
+              // differently to get effectively the same objective function when
+              // the training dataset is not standardized.
+              val temp = value / (featuresStd(index) * featuresStd(index))
+              totalGradientArray(index) += regParamL2 * temp
+              sum += value * temp
+            }
+          }
+        }
+        0.5 * regParamL2 * sum
+      }
+
+      (svmAggregator.loss + regVal, new BDV(totalGradientArray))
+    } finally {
+      bcCoeffs.destroy(blocking = false)
+    }
+  }
+}
+
+/**
+ * LinearSVCAggregator computes the gradient and loss for the hinge loss function, as used
+ * in binary classification, for instances in sparse or dense vectors in an online fashion.
+ *
+ * Two LinearSVCAggregators can be merged together to have a summary of loss and gradient of
+ * the corresponding joint dataset.
+ *
+ * This class standardizes feature values during computation using bcFeaturesStd.
+ *
+ * @param bcCoefficients The coefficients corresponding to the features; when fitIntercept is
+ *                       true, the intercept is the last entry. Must be a dense vector.
+ * @param bcFeaturesStd The standard deviation values of the features.
+ * @param fitIntercept Whether to fit an intercept term.
+ */
+private class LinearSVCAggregator(
+    bcCoefficients: Broadcast[Vector],
+    bcFeaturesStd: Broadcast[Array[Double]],
+    fitIntercept: Boolean) extends Serializable {
+
+  private val numFeatures: Int = bcFeaturesStd.value.length
+  private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures
+  private val coefficients: Vector = bcCoefficients.value
+  private var weightSum: Double = 0.0
+  private var lossSum: Double = 0.0
+  require(numFeaturesPlusIntercept == coefficients.size, s"Dimension mismatch. Coefficients " +
+    s"length ${coefficients.size}, FeaturesStd length ${numFeatures}, fitIntercept: $fitIntercept")
+
+  private val coefficientsArray = coefficients match {
+    case dv: DenseVector => dv.values
+    case _ =>
+      throw new IllegalArgumentException(
+        s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
+  }
+  private val gradientSumArray = Array.fill[Double](coefficientsArray.length)(0)
+
+  /**
+   * Add a new training instance to this LinearSVCAggregator, and update the loss and gradient
+   * of the objective function. Instances with zero weight are ignored.
+   *
+   * @param instance The instance of data point to be added.
+   * @return This LinearSVCAggregator object.
+   */
+  def add(instance: Instance): this.type = {
+    instance match { case Instance(label, weight, features) =>
+      if (weight != 0.0) {
+        val localFeaturesStd = bcFeaturesStd.value
+        val localCoefficients = coefficientsArray
+        val localGradientSumArray = gradientSumArray
+
+        // Margin in the standardized space; zero-variance features are skipped.
+        val dotProduct = {
+          var sum = 0.0
+          features.foreachActive { (index, value) =>
+            if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+              sum += localCoefficients(index) * value / localFeaturesStd(index)
+            }
+          }
+          if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
+          sum
+        }
+        // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))).
+        // Where (2y - 1) f_w(x) < 1 the gradient is -(2y - 1) * x; elsewhere loss and
+        // gradient are both zero.
+        val labelScaled = 2 * label - 1.0
+        val scaledMargin = labelScaled * dotProduct
+        if (1.0 > scaledMargin) {
+          lossSum += weight * (1.0 - scaledMargin)
+          val gradientScale = -labelScaled * weight
+          features.foreachActive { (index, value) =>
+            if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+              localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index)
+            }
+          }
+          if (fitIntercept) {
+            localGradientSumArray(localGradientSumArray.length - 1) += gradientScale
+          }
+        }
+
+        weightSum += weight
+      }
+    }
+    this
+  }
+
+  /**
+   * Merge another LinearSVCAggregator, and update the loss and gradient
+   * of the objective function.
+   * (Note that it's in place merging; as a result, `this` object will be modified.)
+   *
+   * @param other The other LinearSVCAggregator to be merged.
+   * @return This LinearSVCAggregator object.
+   */
+  def merge(other: LinearSVCAggregator): this.type = {
+    if (other.weightSum != 0.0) {
+      weightSum += other.weightSum
+      lossSum += other.lossSum
+
+      var i = 0
+      val localThisGradientSumArray = this.gradientSumArray
+      val localOtherGradientSumArray = other.gradientSumArray
+      val len = localThisGradientSumArray.length
+      while (i < len) {
+        localThisGradientSumArray(i) += localOtherGradientSumArray(i)
+        i += 1
+      }
+    }
+    this
+  }
+
+  /** Weighted-average hinge loss over the added instances (0.0 if no weight was added). */
+  def loss: Double = {
+    if (weightSum != 0) {
+      lossSum / weightSum
+    } else 0.0
+  }
+
+  /** Weighted-average gradient as a new dense vector (all zeros if no weight was added). */
+  def gradient: Vector = {
+    if (weightSum != 0) {
+      val result = Vectors.dense(gradientSumArray.clone())
+      scal(1.0 / weightSum, result)
+      result
+    } else Vectors.dense(Array.fill[Double](coefficientsArray.length)(0))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4a11d029/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 5e1d6ee..d2b0f2a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -233,7 +233,7 @@ class LogisticRegression @Since("1.2.0") (
 
   /**
    * Set the convergence tolerance of iterations.
-   * Smaller value will lead to higher accuracy with the cost of more iterations.
+   * Smaller value will lead to higher accuracy at the cost of more iterations.
    * Default is 1E-6.
    *
    * @group setParam
@@ -1431,7 +1431,7 @@ private class LogisticAggregator(
   private var weightSum = 0.0
   private var lossSum = 0.0
 
-  private val gradientSumArray = Array.ofDim[Double](coefficientSize)
+  private val gradientSumArray = Array.fill[Double](coefficientSize)(0.0D)
 
   if (multinomial && numClasses <= 2) {
     logInfo(s"Multinomial logistic regression for binary classification yields separate " +

http://git-wip-us.apache.org/repos/asf/spark/blob/4a11d029/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
new file mode 100644
index 0000000..ee2aefe
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import scala.util.Random
+
+import breeze.linalg.{DenseVector => BDV}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.classification.LinearSVCSuite._
+import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{Dataset, Row}
+
+
+/** Tests for [[LinearSVC]] and [[LinearSVCModel]]. */
+class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  import testImplicits._
+
+  private val nPoints = 50
+  @transient var smallBinaryDataset: Dataset[_] = _
+  @transient var smallValidationDataset: Dataset[_] = _
+  @transient var binaryDataset: Dataset[_] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    // NOTE: Intercept should be small to generate roughly equal numbers of 0s and 1s
+    val A = 0.01
+    val B = -1.5
+    val C = 1.0
+    smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF()
+    smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF()
+    binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF()
+  }
+
+  /**
+   * Enable the ignored test to export the dataset into CSV format,
+   * so we can validate the training accuracy compared with R's e1071 package.
+   */
+  ignore("export test data into CSV format") {
+    binaryDataset.rdd.map { case Row(label: Double, features: Vector) =>
+      label + "," + features.toArray.mkString(",")
+    }.repartition(1).saveAsTextFile("target/tmp/LinearSVC/binaryDataset")
+  }
+
+  test("Linear SVC binary classification") {
+    val svm = new LinearSVC()
+    val model = svm.fit(smallBinaryDataset)
+    assert(model.transform(smallValidationDataset)
+      .where("prediction=label").count() > nPoints * 0.8)
+  }
+
+  test("Linear SVC binary classification with regularization") {
+    val svm = new LinearSVC()
+    val model = svm.setRegParam(0.1).fit(smallBinaryDataset)
+    assert(model.transform(smallValidationDataset)
+      .where("prediction=label").count() > nPoints * 0.8)
+  }
+
+  test("params") {
+    ParamsSuite.checkParams(new LinearSVC)
+    val model = new LinearSVCModel("linearSVC", Vectors.dense(0.0), 0.0)
+    ParamsSuite.checkParams(model)
+  }
+
+  test("linear svc: default params") {
+    val lsvc = new LinearSVC()
+    assert(lsvc.getRegParam === 0.0)
+    assert(lsvc.getMaxIter === 100)
+    assert(lsvc.getFitIntercept)
+    assert(lsvc.getTol === 1E-6)
+    assert(lsvc.getStandardization)
+    assert(!lsvc.isDefined(lsvc.weightCol))
+    assert(lsvc.getThreshold === 0.0)
+    assert(lsvc.getAggregationDepth === 2)
+    assert(lsvc.getLabelCol === "label")
+    assert(lsvc.getFeaturesCol === "features")
+    assert(lsvc.getPredictionCol === "prediction")
+    assert(lsvc.getRawPredictionCol === "rawPrediction")
+    val model = lsvc.setMaxIter(5).fit(smallBinaryDataset)
+    model.transform(smallBinaryDataset)
+      .select("label", "prediction", "rawPrediction")
+      .collect()
+    assert(model.getThreshold === 0.0)
+    assert(model.getFeaturesCol === "features")
+    assert(model.getPredictionCol === "prediction")
+    assert(model.getRawPredictionCol === "rawPrediction")
+    assert(model.intercept !== 0.0)
+    assert(model.hasParent)
+    assert(model.numFeatures === 2)
+
+    // copied model must have the same parent.
+    MLTestingUtils.checkCopy(model)
+  }
+
+  test("linear svc doesn't fit intercept when fitIntercept is off") {
+    val lsvc = new LinearSVC().setFitIntercept(false).setMaxIter(5)
+    val model = lsvc.fit(smallBinaryDataset)
+    assert(model.intercept === 0.0)
+
+    val lsvc2 = new LinearSVC().setFitIntercept(true).setMaxIter(5)
+    val model2 = lsvc2.fit(smallBinaryDataset)
+    assert(model2.intercept !== 0.0)
+  }
+
+  test("linearSVC with sample weights") {
+    def modelEquals(m1: LinearSVCModel, m2: LinearSVCModel): Unit = {
+      assert(m1.coefficients ~== m2.coefficients absTol 0.05)
+      assert(m1.intercept ~== m2.intercept absTol 0.05)
+    }
+
+    val estimator = new LinearSVC().setRegParam(0.01).setTol(0.01)
+    val dataset = smallBinaryDataset
+    MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC](
+      dataset.as[LabeledPoint], estimator, modelEquals)
+    MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC](
+      dataset.as[LabeledPoint], estimator, 2, modelEquals)
+    MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC](
+      dataset.as[LabeledPoint], estimator, modelEquals, 42L)
+  }
+
+  test("linearSVC comparison with R e1071 and scikit-learn") {
+    val trainer1 = new LinearSVC()
+      .setRegParam(0.00002) // set regParam = 2.0 / datasize / c
+      .setMaxIter(200)
+      .setTol(1e-4)
+    val model1 = trainer1.fit(binaryDataset)
+
+    /*
+      Use the following R code to load the data and train the model using the e1071 package.
+
+      library(e1071)
+      data <- read.csv("path/target/tmp/LinearSVC/binaryDataset/part-00000", header=FALSE)
+      label <- factor(data$V1)
+      features <- as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+      svm_model <- svm(features, label, type='C', kernel='linear', cost=10, scale=F, tolerance=1e-4)
+      w <- -t(svm_model$coefs) %*% svm_model$SV
+      w
+      svm_model$rho
+
+      > w
+            data.V2  data.V3  data.V4  data.V5
+      [1,] 7.310338 14.89741 22.21005 29.83508
+      > svm_model$rho
+      [1] 7.440177
+
+     */
+    val coefficientsR = Vectors.dense(7.310338, 14.89741, 22.21005, 29.83508)
+    val interceptR = 7.440177
+    assert(model1.intercept ~== interceptR relTol 1E-2)
+    assert(model1.coefficients ~== coefficientsR relTol 1E-2)
+
+    /*
+      Use the following python code to load the data and train the model using scikit-learn package.
+
+      import numpy as np
+      from sklearn import svm
+      f = open("path/target/tmp/LinearSVC/binaryDataset/part-00000")
+      data = np.loadtxt(f,  delimiter=",")
+      X = data[:, 1:]  # select columns 1 through end
+      y = data[:, 0]   # select column 0 as label
+      clf = svm.LinearSVC(fit_intercept=True, C=10, loss='hinge', tol=1e-4, random_state=42)
+      m = clf.fit(X, y)
+      print m.coef_
+      print m.intercept_
+
+      [[  7.24690165  14.77029087  21.99924004  29.5575729 ]]
+      [ 7.36947518]
+     */
+
+    val coefficientsSK = Vectors.dense(7.24690165, 14.77029087, 21.99924004, 29.5575729)
+    val interceptSK = 7.36947518
+    assert(model1.intercept ~== interceptSK relTol 1E-3)
+    assert(model1.coefficients ~== coefficientsSK relTol 4E-3)
+  }
+
+  test("read/write: SVM") {
+    def checkModelData(model: LinearSVCModel, model2: LinearSVCModel): Unit = {
+      assert(model.intercept === model2.intercept)
+      assert(model.coefficients === model2.coefficients)
+      assert(model.numFeatures === model2.numFeatures)
+    }
+    val svm = new LinearSVC()
+    testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
+      checkModelData)
+  }
+}
+
+/** Shared fixtures and data generators for [[LinearSVCSuite]]. */
+object LinearSVCSuite {
+
+  /** Param values exercised by the read/write test. */
+  val allParamSettings: Map[String, Any] = Map(
+    "regParam" -> 0.01,
+    "maxIter" -> 2, // intentionally small
+    "fitIntercept" -> true,
+    "tol" -> 0.8,
+    "standardization" -> false,
+    "threshold" -> 0.6,
+    "predictionCol" -> "myPredict",
+    "rawPredictionCol" -> "myRawPredict",
+    "aggregationDepth" -> 3
+  )
+
+  /**
+   * Generates `nPoints` labeled points. Each feature is drawn uniformly from [-1, 1), and the
+   * label is 1.0 when `x.dot(weights) + intercept + noise` is positive and 0.0 otherwise,
+   * where `noise` is Gaussian with standard deviation 0.01.
+   *
+   * All feature values are drawn before any noise term, so for a given seed the random
+   * stream is consumed in a fixed order.
+   */
+  def generateSVMInput(
+      intercept: Double,
+      weights: Array[Double],
+      nPoints: Int,
+      seed: Int): Seq[LabeledPoint] = {
+    val random = new Random(seed)
+    val weightVector = new BDV(weights)
+    val featureRows: Array[Array[Double]] = Array.fill(nPoints) {
+      Array.fill(weights.length)(random.nextDouble() * 2.0 - 1.0)
+    }
+    val labels: Array[Double] = featureRows.map { row =>
+      val noisyMargin = new BDV(row).dot(weightVector) + intercept + 0.01 * random.nextGaussian()
+      if (noisyMargin > 0) 1.0 else 0.0
+    }
+    labels.zip(featureRows).map { case (label, row) => LabeledPoint(label, Vectors.dense(row)) }
+  }
+
+}
+


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org