You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2017/02/20 13:39:01 UTC
mahout git commit: MAHOUT-1930 Add Test for Standard Scaler closes
apache/mahout#280
Repository: mahout
Updated Branches:
refs/heads/master 60bb75192 -> a70a8733c
MAHOUT-1930 Add Test for Standard Scaler closes apache/mahout#280
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/a70a8733
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/a70a8733
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/a70a8733
Branch: refs/heads/master
Commit: a70a8733c6db0e5dcf02384f8dd474469c42e7c5
Parents: 60bb751
Author: rawkintrevo <tr...@gmail.com>
Authored: Mon Feb 20 07:38:33 2017 -0600
Committer: rawkintrevo <tr...@gmail.com>
Committed: Mon Feb 20 07:38:33 2017 -0600
----------------------------------------------------------------------
.../preprocessing/StandardScaler.scala | 17 ++++++++--
.../math/algorithms/PreprocessorSuiteBase.scala | 33 +++++++++++++++++++-
2 files changed, 46 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/mahout/blob/a70a8733/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala b/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala
index 98d0be1..5863330 100644
--- a/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala
+++ b/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala
@@ -29,6 +29,18 @@ import org.apache.mahout.math.{Vector => MahoutVector, Matrix}
/**
* Scales columns to mean 0 and unit variance
+ *
+ * An important note: the equivalent call in R would be something like
+ * ```r
+ * N <- nrow(x)
+ * scale(x, scale= apply(x, 2, sd) * sqrt((N-1)/N))
+ * ```
+ *
+ * This is because R uses degrees of freedom = 1 to calculate standard deviation.
+ * Multiplying the standard deviation by sqrt((N-1)/N) 'undoes' this correction.
+ *
+ * The StandardScaler of sklearn uses degrees of freedom = 0 for its calculation, so results
+ * should be similar.
*/
class StandardScaler extends PreprocessorFitter {
@@ -40,15 +52,14 @@ class StandardScaler extends PreprocessorFitter {
}
-class StandardScalerModel(meanVec: MahoutVector,
- stdev: MahoutVector
+class StandardScalerModel(val meanVec: MahoutVector,
+ val stdev: MahoutVector
) extends PreprocessorModel {
def transform[K](input: DrmLike[K]): DrmLike[K] = {
implicit val ctx = input.context
-
// Some mapBlock() calls need it
// implicit val ktag = input.keyClassTag
http://git-wip-us.apache.org/repos/asf/mahout/blob/a70a8733/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala b/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala
index 9e8f029..ec76c11 100644
--- a/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala
+++ b/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala
@@ -19,7 +19,7 @@
package org.apache.mahout.math.algorithms
-import org.apache.mahout.math.algorithms.preprocessing.{AsFactor, AsFactorModel}
+import org.apache.mahout.math.algorithms.preprocessing._
import org.apache.mahout.math.drm.drmParallelize
import org.apache.mahout.math.scalabindings.{dense, sparse, svec}
import org.apache.mahout.math.scalabindings.RLikeOps._
@@ -56,4 +56,35 @@ trait PreprocessorSuiteBase extends DistributedMahoutSuite with Matchers {
(myAnswer.norm - correctAnswer.norm) should be <= epsilon
}
+
+ test("standard scaler test") {
+   /**
+    * R Prototype
+    * x <- matrix( c(1,2,3,1,5,9,5,-15,-2), nrow=3)
+    * scale(x, scale= apply(x, 2, sd) * sqrt(2/3))
+    * # ^^ note: R uses degrees of freedom = 1 for standard deviation calculations.
+    * # we don't (and neither does sklearn)
+    * # the *sqrt((N-1)/N) 'undoes' the degrees of freedom = 1
+    */
+
+   val A = drmParallelize(dense(
+     (1, 1, 5),
+     (2, 5, -15),
+     (3, 9, -2)), numPartitions = 2)
+
+   val scaler: StandardScalerModel = new StandardScaler().fit(A)
+
+   val correctAnswer = dense(
+     (-1.2247449, -1.2247449, 1.0860993),
+     (0.0000000, 0.0000000, -1.3274547),
+     (1.2247449, 1.2247449, 0.2413554))
+
+   val myAnswer = scaler.transform(A).collect
+   // Compare element-wise via the norm of the difference: every column-standardized 3x3
+   // matrix has the same Frobenius norm (3), so comparing the two norms could never fail.
+
+   val epsilon = 1E-6
+   (myAnswer - correctAnswer).norm should be <= epsilon
+
+ }
}