You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/02/01 18:21:15 UTC

spark git commit: [SPARK-5207] [MLLIB] StandardScalerModel mean and variance re-use

Repository: spark
Updated Branches:
  refs/heads/master 80bd715a3 -> bdb0680d3


[SPARK-5207] [MLLIB] StandardScalerModel mean and variance re-use

This seems complete. The duplication of tests for provided means/variances might be overkill; I would appreciate some feedback.

Author: Octavian Geagla <og...@gmail.com>

Closes #4140 from ogeagla/SPARK-5207 and squashes the following commits:

fa64dfa [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] change StandardScalerModel to take stddev instead of variance
9078fe0 [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] Incorporate code review feedback: change arg ordering, add dev api annotations, do better null checking, add another test and some doc for this.
997d2e0 [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] make withMean and withStd public, add constructor which uses defaults, un-refactor test class
64408a4 [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] change StandardScalerModel constructor to not be private to mllib, added tests for newly-exposed functionality


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bdb0680d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bdb0680d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bdb0680d

Branch: refs/heads/master
Commit: bdb0680d37614ccdec8933d2dec53793825e43d7
Parents: 80bd715
Author: Octavian Geagla <og...@gmail.com>
Authored: Sun Feb 1 09:21:14 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sun Feb 1 09:21:14 2015 -0800

----------------------------------------------------------------------
 docs/mllib-feature-extraction.md                |  11 +-
 .../spark/mllib/feature/StandardScaler.scala    |  71 +++--
 .../mllib/feature/StandardScalerSuite.scala     | 258 ++++++++++++++++---
 3 files changed, 267 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bdb0680d/docs/mllib-feature-extraction.md
----------------------------------------------------------------------
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 197bc77..d4a61a7 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -240,11 +240,11 @@ following parameters in the constructor:
 
 * `withMean` False by default. Centers the data with mean before scaling. It will build a dense
 output, so this does not work on sparse input and will raise an exception.
-* `withStd` True by default. Scales the data to unit variance.
+* `withStd` True by default. Scales the data to unit standard deviation.
 
 We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) method in
 `StandardScaler` which can take an input of `RDD[Vector]`, learn the summary statistics, and then
-return a model which can transform the input dataset into unit variance and/or zero mean features
+return a model which can transform the input dataset into unit standard deviation and/or zero mean features
 depending how we configure the `StandardScaler`.
 
 This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer)
@@ -257,7 +257,7 @@ for that feature.
 ### Example
 
 The example below demonstrates how to load a dataset in libsvm format, and standardize the features
-so that the new features have unit variance and/or zero mean.
+so that the new features have unit standard deviation and/or zero mean.
 
 <div class="codetabs">
 <div data-lang="scala">
@@ -271,6 +271,8 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
 
 val scaler1 = new StandardScaler().fit(data.map(x => x.features))
 val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features))
+// scaler3 is an identical model to scaler2, and will produce identical transformations
+val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean)
 
 // data1 will be unit variance.
 val data1 = data.map(x => (x.label, scaler1.transform(x.features)))
@@ -294,6 +296,9 @@ features = data.map(lambda x: x.features)
 
 scaler1 = StandardScaler().fit(features)
 scaler2 = StandardScaler(withMean=True, withStd=True).fit(features)
+# scaler3 is an identical model to scaler2, and will produce identical transformations
+scaler3 = StandardScalerModel(scaler2.std, scaler2.mean)
+
 
 # data1 will be unit variance.
 data1 = label.zip(scaler1.transform(features))

http://git-wip-us.apache.org/repos/asf/spark/blob/bdb0680d/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 2f2c6f9..6ae6917 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -18,14 +18,14 @@
 package org.apache.spark.mllib.feature
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 
 /**
  * :: Experimental ::
- * Standardizes features by removing the mean and scaling to unit variance using column summary
+ * Standardizes features by removing the mean and scaling to unit std using column summary
  * statistics on the samples in the training set.
  *
  * @param withMean False by default. Centers the data with mean before scaling. It will build a
@@ -52,7 +52,11 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
     val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
-    new StandardScalerModel(withMean, withStd, summary.mean, summary.variance)
+    new StandardScalerModel(
+      Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))),
+      summary.mean,
+      withStd,
+      withMean)
   }
 }
 
@@ -60,28 +64,43 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
  * :: Experimental ::
  * Represents a StandardScaler model that can transform vectors.
  *
- * @param withMean whether to center the data before scaling
- * @param withStd whether to scale the data to have unit standard deviation
+ * @param std column standard deviation values
  * @param mean column mean values
- * @param variance column variance values
+ * @param withStd whether to scale the data to have unit standard deviation
+ * @param withMean whether to center the data before scaling
  */
 @Experimental
-class StandardScalerModel private[mllib] (
-    val withMean: Boolean,
-    val withStd: Boolean,
+class StandardScalerModel (
+    val std: Vector,
     val mean: Vector,
-    val variance: Vector) extends VectorTransformer {
-
-  require(mean.size == variance.size)
+    var withStd: Boolean,
+    var withMean: Boolean) extends VectorTransformer {
 
-  private lazy val factor: Array[Double] = {
-    val f = Array.ofDim[Double](variance.size)
-    var i = 0
-    while (i < f.size) {
-      f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
-      i += 1
+  def this(std: Vector, mean: Vector) {
+    this(std, mean, withStd = std != null, withMean = mean != null)
+    require(this.withStd || this.withMean,
+      "at least one of std or mean vectors must be provided")
+    if (this.withStd && this.withMean) {
+      require(mean.size == std.size,
+        "mean and std vectors must have equal size if both are provided")
     }
-    f
+  }
+
+  def this(std: Vector) = this(std, null)
+
+  @DeveloperApi
+  def setWithMean(withMean: Boolean): this.type = {
+    require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null")
+    this.withMean = withMean
+    this
+  }
+
+  @DeveloperApi
+  def setWithStd(withStd: Boolean): this.type = {
+    require(!(withStd && this.std == null),
+      "cannot set withStd to true while std is null")
+    this.withStd = withStd
+    this
   }
 
   // Since `shift` will be only used in `withMean` branch, we have it as
@@ -93,8 +112,8 @@ class StandardScalerModel private[mllib] (
    * Applies standardization transformation on a vector.
    *
    * @param vector Vector to be standardized.
-   * @return Standardized vector. If the variance of a column is zero, it will return default `0.0`
-   *         for the column with zero variance.
+   * @return Standardized vector. If the std of a column is zero, it will return default `0.0`
+   *         for the column with zero std.
    */
   override def transform(vector: Vector): Vector = {
     require(mean.size == vector.size)
@@ -108,11 +127,9 @@ class StandardScalerModel private[mllib] (
           val values = vs.clone()
           val size = values.size
           if (withStd) {
-            // Having a local reference of `factor` to avoid overhead as the comment before.
-            val localFactor = factor
             var i = 0
             while (i < size) {
-              values(i) = (values(i) - localShift(i)) * localFactor(i)
+              values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
               i += 1
             }
           } else {
@@ -126,15 +143,13 @@ class StandardScalerModel private[mllib] (
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
     } else if (withStd) {
-      // Having a local reference of `factor` to avoid overhead as the comment before.
-      val localFactor = factor
       vector match {
         case DenseVector(vs) =>
           val values = vs.clone()
           val size = values.size
           var i = 0
           while(i < size) {
-            values(i) *= localFactor(i)
+            values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0)
             i += 1
           }
           Vectors.dense(values)
@@ -145,7 +160,7 @@ class StandardScalerModel private[mllib] (
           val nnz = values.size
           var i = 0
           while (i < nnz) {
-            values(i) *= localFactor(indices(i))
+            values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0)
             i += 1
           }
           Vectors.sparse(size, indices, values)

http://git-wip-us.apache.org/repos/asf/spark/blob/bdb0680d/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index e9e510b..7f94564 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -27,23 +27,109 @@ import org.apache.spark.rdd.RDD
 
 class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
 
+  // When the input data is all constant, the variance is zero. The standardization against
+  // zero variance is not well-defined, but we decide to just set it into zero here.
+  val constantData = Array(
+    Vectors.dense(2.0),
+    Vectors.dense(2.0),
+    Vectors.dense(2.0)
+  )
+
+  val sparseData = Array(
+    Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
+    Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))),
+    Vectors.sparse(3, Seq((1, -5.1))),
+    Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))),
+    Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))),
+    Vectors.sparse(3, Seq((1, 1.9)))
+  )
+
+  val denseData = Array(
+    Vectors.dense(-2.0, 2.3, 0),
+    Vectors.dense(0.0, -1.0, -3.0),
+    Vectors.dense(0.0, -5.1, 0.0),
+    Vectors.dense(3.8, 0.0, 1.9),
+    Vectors.dense(1.7, -0.6, 0.0),
+    Vectors.dense(0.0, 1.9, 0.0)
+  )
+
   private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
     data.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
   }
 
+  test("Standardization with dense input when means and stds are provided") {
+
+    val dataRDD = sc.parallelize(denseData, 3)
+
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler()
+    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean)
+    val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
+    val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
+
+    val data1 = denseData.map(equivalentModel1.transform)
+    val data2 = denseData.map(equivalentModel2.transform)
+    val data3 = denseData.map(equivalentModel3.transform)
+
+    val data1RDD = equivalentModel1.transform(dataRDD)
+    val data2RDD = equivalentModel2.transform(dataRDD)
+    val data3RDD = equivalentModel3.transform(dataRDD)
+
+    val summary = computeSummary(dataRDD)
+    val summary1 = computeSummary(data1RDD)
+    val summary2 = computeSummary(data2RDD)
+    val summary3 = computeSummary(data3RDD)
+
+    assert((denseData, data1, data1RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
+
+    assert((denseData, data2, data2RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
+
+    assert((denseData, data3, data3RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
+
+    assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+
+    assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+
+    assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+
+    assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary3.variance ~== summary.variance absTol 1E-5)
+
+    assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5)
+    assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5)
+    assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5)
+    assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5)
+    assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5)
+    assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5)
+  }
+
   test("Standardization with dense input") {
-    val data = Array(
-      Vectors.dense(-2.0, 2.3, 0),
-      Vectors.dense(0.0, -1.0, -3.0),
-      Vectors.dense(0.0, -5.1, 0.0),
-      Vectors.dense(3.8, 0.0, 1.9),
-      Vectors.dense(1.7, -0.6, 0.0),
-      Vectors.dense(0.0, 1.9, 0.0)
-    )
 
-    val dataRDD = sc.parallelize(data, 3)
+    val dataRDD = sc.parallelize(denseData, 3)
 
     val standardizer1 = new StandardScaler(withMean = true, withStd = true)
     val standardizer2 = new StandardScaler()
@@ -53,9 +139,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val data1 = data.map(model1.transform)
-    val data2 = data.map(model2.transform)
-    val data3 = data.map(model3.transform)
+    val data1 = denseData.map(model1.transform)
+    val data2 = denseData.map(model2.transform)
+    val data3 = denseData.map(model3.transform)
 
     val data1RDD = model1.transform(dataRDD)
     val data2RDD = model2.transform(dataRDD)
@@ -66,19 +152,19 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val summary2 = computeSummary(data2RDD)
     val summary3 = computeSummary(data3RDD)
 
-    assert((data, data1, data1RDD.collect()).zipped.forall {
+    assert((denseData, data1, data1RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
     }, "The vector type should be preserved after standardization.")
 
-    assert((data, data2, data2RDD.collect()).zipped.forall {
+    assert((denseData, data2, data2RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
     }, "The vector type should be preserved after standardization.")
 
-    assert((data, data3, data3RDD.collect()).zipped.forall {
+    assert((denseData, data3, data3RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
@@ -106,17 +192,58 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
   }
 
 
+  test("Standardization with sparse input when means and stds are provided") {
+
+    val dataRDD = sc.parallelize(sparseData, 3)
+
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler()
+    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean)
+    val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
+    val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
+
+    val data2 = sparseData.map(equivalentModel2.transform)
+
+    withClue("Standardization with mean can not be applied on sparse input.") {
+      intercept[IllegalArgumentException] {
+        sparseData.map(equivalentModel1.transform)
+      }
+    }
+
+    withClue("Standardization with mean can not be applied on sparse input.") {
+      intercept[IllegalArgumentException] {
+        sparseData.map(equivalentModel3.transform)
+      }
+    }
+
+    val data2RDD = equivalentModel2.transform(dataRDD)
+
+    val summary = computeSummary(data2RDD)
+
+    assert((sparseData, data2, data2RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
+
+    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+
+    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+
+    assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
+    assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
+  }
+
   test("Standardization with sparse input") {
-    val data = Array(
-      Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
-      Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))),
-      Vectors.sparse(3, Seq((1, -5.1))),
-      Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))),
-      Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))),
-      Vectors.sparse(3, Seq((1, 1.9)))
-    )
 
-    val dataRDD = sc.parallelize(data, 3)
+    val dataRDD = sc.parallelize(sparseData, 3)
 
     val standardizer1 = new StandardScaler(withMean = true, withStd = true)
     val standardizer2 = new StandardScaler()
@@ -126,25 +253,26 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val data2 = data.map(model2.transform)
+    val data2 = sparseData.map(model2.transform)
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(model1.transform)
+        sparseData.map(model1.transform)
       }
     }
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(model3.transform)
+        sparseData.map(model3.transform)
       }
     }
 
     val data2RDD = model2.transform(dataRDD)
 
-    val summary2 = computeSummary(data2RDD)
 
-    assert((data, data2, data2RDD.collect()).zipped.forall {
+    val summary = computeSummary(data2RDD)
+
+    assert((sparseData, data2, data2RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
@@ -152,23 +280,44 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
 
     assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
 
-    assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
     assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
     assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
   }
 
+  test("Standardization with constant input when means and stds are provided") {
+
+    val dataRDD = sc.parallelize(constantData, 2)
+
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler(withMean = true, withStd = false)
+    val standardizer3 = new StandardScaler(withMean = false, withStd = true)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean)
+    val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
+    val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
+
+    val data1 = constantData.map(equivalentModel1.transform)
+    val data2 = constantData.map(equivalentModel2.transform)
+    val data3 = constantData.map(equivalentModel3.transform)
+
+    assert(data1.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data2.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data3.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+  }
+
   test("Standardization with constant input") {
-    // When the input data is all constant, the variance is zero. The standardization against
-    // zero variance is not well-defined, but we decide to just set it into zero here.
-    val data = Array(
-      Vectors.dense(2.0),
-      Vectors.dense(2.0),
-      Vectors.dense(2.0)
-    )
 
-    val dataRDD = sc.parallelize(data, 2)
+    val dataRDD = sc.parallelize(constantData, 2)
 
     val standardizer1 = new StandardScaler(withMean = true, withStd = true)
     val standardizer2 = new StandardScaler(withMean = true, withStd = false)
@@ -178,9 +327,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val data1 = data.map(model1.transform)
-    val data2 = data.map(model2.transform)
-    val data3 = data.map(model3.transform)
+    val data1 = constantData.map(model1.transform)
+    val data2 = constantData.map(model2.transform)
+    val data3 = constantData.map(model3.transform)
 
     assert(data1.forall(_.toArray.forall(_ == 0.0)),
       "The variance is zero, so the transformed result should be 0.0")
@@ -190,4 +339,29 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
       "The variance is zero, so the transformed result should be 0.0")
   }
 
+  test("StandardScalerModel argument nulls are properly handled") {
+
+    withClue("model needs at least one of std or mean vectors") {
+      intercept[IllegalArgumentException] {
+        val model = new StandardScalerModel(null, null)
+      }
+    }
+    withClue("model needs std to set withStd to true") {
+      intercept[IllegalArgumentException] {
+        val model = new StandardScalerModel(null, Vectors.dense(0.0))
+        model.setWithStd(true)
+      }
+    }
+    withClue("model needs mean to set withMean to true") {
+      intercept[IllegalArgumentException] {
+        val model = new StandardScalerModel(Vectors.dense(0.0), null)
+        model.setWithMean(true)
+      }
+    }
+    withClue("model needs std and mean vectors to be equal size when both are provided") {
+      intercept[IllegalArgumentException] {
+        val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0))
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org