You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2017/02/20 13:39:01 UTC

mahout git commit: MAHOUT-1930 Add Test for Standard Scaler closes apache/mahout#280

Repository: mahout
Updated Branches:
  refs/heads/master 60bb75192 -> a70a8733c


MAHOUT-1930 Add Test for Standard Scaler closes apache/mahout#280


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/a70a8733
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/a70a8733
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/a70a8733

Branch: refs/heads/master
Commit: a70a8733c6db0e5dcf02384f8dd474469c42e7c5
Parents: 60bb751
Author: rawkintrevo <tr...@gmail.com>
Authored: Mon Feb 20 07:38:33 2017 -0600
Committer: rawkintrevo <tr...@gmail.com>
Committed: Mon Feb 20 07:38:33 2017 -0600

----------------------------------------------------------------------
 .../preprocessing/StandardScaler.scala          | 17 ++++++++--
 .../math/algorithms/PreprocessorSuiteBase.scala | 33 +++++++++++++++++++-
 2 files changed, 46 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/a70a8733/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala b/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala
index 98d0be1..5863330 100644
--- a/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala
+++ b/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala
@@ -29,6 +29,18 @@ import org.apache.mahout.math.{Vector => MahoutVector, Matrix}
 
 /**
   * Scales columns to mean 0 and unit variance
+  *
+  * Important note: the equivalent call in R would be something like
+  * ```r
+  * N <- nrow(x)
+  * scale(x, scale= apply(x, 2, sd) * sqrt((N-1)/N))
+  * ```
+  *
+  * This is because R's sd() uses degrees of freedom = 1 (it divides by N-1).
+  * Multiplying the standard deviation by sqrt((N-1)/N) 'undoes' this correction.
+  *
+  * The StandardScaler of sklearn uses degrees of freedom = 0 for its calculation, so results
+  * should match (up to floating-point error).
   */
 class StandardScaler extends PreprocessorFitter {
 
@@ -40,15 +52,14 @@ class StandardScaler extends PreprocessorFitter {
 
 }
 
-class StandardScalerModel(meanVec: MahoutVector,
-                          stdev: MahoutVector
+class StandardScalerModel(val meanVec: MahoutVector,
+                          val stdev: MahoutVector
                          ) extends PreprocessorModel {
 
 
   def transform[K](input: DrmLike[K]): DrmLike[K] = {
     implicit val ctx = input.context
 
-
     // Some mapBlock() calls need it
     // implicit val ktag =  input.keyClassTag
 

http://git-wip-us.apache.org/repos/asf/mahout/blob/a70a8733/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala b/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala
index 9e8f029..ec76c11 100644
--- a/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala
+++ b/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala
@@ -19,7 +19,7 @@
 
 package org.apache.mahout.math.algorithms
 
-import org.apache.mahout.math.algorithms.preprocessing.{AsFactor, AsFactorModel}
+import org.apache.mahout.math.algorithms.preprocessing._
 import org.apache.mahout.math.drm.drmParallelize
 import org.apache.mahout.math.scalabindings.{dense, sparse, svec}
 import org.apache.mahout.math.scalabindings.RLikeOps._
@@ -56,4 +56,35 @@ trait PreprocessorSuiteBase extends DistributedMahoutSuite with Matchers {
     (myAnswer.norm - correctAnswer.norm) should be <= epsilon
 
   }
+
+  test("standard scaler test") {
+    /**
+      * R Prototype (N = 3 rows)
+      * x <- matrix( c(1,2,3,1,5,9,5,-15,-2), nrow=3)
+      * scale(x, scale= apply(x, 2, sd) * sqrt(2/3))
+      * # ^^ note: R's sd() uses degrees of freedom = 1 (divides by N-1).
+      * # we use degrees of freedom = 0 (divide by N), as sklearn does.
+      * # the * sqrt((N-1)/N) = sqrt(2/3) factor 'undoes' R's correction.
+      */
+
+    val A = drmParallelize(dense(
+      (1, 1, 5),
+      (2, 5, -15),
+      (3, 9, -2)), numPartitions = 2)
+
+    val scaler: StandardScalerModel = new StandardScaler().fit(A)
+
+    val correctAnswer = dense(
+      (-1.2247449, -1.2247449,  1.0860993),
+      ( 0.0000000,  0.0000000, -1.3274547),
+      ( 1.2247449,  1.2247449,  0.2413554))
+
+    val myAnswer = scaler.transform(A).collect
+
+    // Compare element-wise: any matrix whose columns have mean 0 and unit (population)
+    // variance has Frobenius norm sqrt(nrow * ncol), so comparing norms alone proves nothing.
+    val epsilon = 1E-6
+    (myAnswer - correctAnswer).norm should be <= epsilon
+
+  }
 }