You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2017/01/13 06:28:11 UTC

spark git commit: [SPARK-19142][SPARKR] spark.kmeans should take seed, initSteps, and tol as parameters

Repository: spark
Updated Branches:
  refs/heads/master 3356b8b6a -> 7f24a0b6c


[SPARK-19142][SPARKR] spark.kmeans should take seed, initSteps, and tol as parameters

## What changes were proposed in this pull request?
spark.kmeans doesn't have an interface to set initSteps, seed, and tol. Because Spark's KMeans algorithm doesn't take the same set of parameters as R's kmeans, we should maintain a separate interface in spark.kmeans.

Add the missing parameters and the corresponding documentation.

Modified the existing unit tests to exercise the additional parameters.

Author: wm624@hotmail.com <wm...@hotmail.com>

Closes #16523 from wangmiao1981/kmeans.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7f24a0b6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7f24a0b6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7f24a0b6

Branch: refs/heads/master
Commit: 7f24a0b6c32c56a38cf879d953bbd523922ab9c9
Parents: 3356b8b
Author: wm624@hotmail.com <wm...@hotmail.com>
Authored: Thu Jan 12 22:27:57 2017 -0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Thu Jan 12 22:27:57 2017 -0800

----------------------------------------------------------------------
 R/pkg/R/mllib_clustering.R                      | 13 +++++++++++--
 .../inst/tests/testthat/test_mllib_clustering.R | 20 ++++++++++++++++++++
 .../org/apache/spark/ml/r/KMeansWrapper.scala   |  9 ++++++++-
 3 files changed, 39 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7f24a0b6/R/pkg/R/mllib_clustering.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R
index c443588..ca5182d 100644
--- a/R/pkg/R/mllib_clustering.R
+++ b/R/pkg/R/mllib_clustering.R
@@ -175,6 +175,10 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact
 #' @param k number of centers.
 #' @param maxIter maximum iteration number.
 #' @param initMode the initialization algorithm choosen to fit the model.
+#' @param seed the random seed for cluster initialization.
+#' @param initSteps the number of steps for the k-means|| initialization mode.
+#'                  This is an advanced setting, the default of 2 is almost always enough. Must be > 0.
+#' @param tol convergence tolerance of iterations.
 #' @param ... additional argument(s) passed to the method.
 #' @return \code{spark.kmeans} returns a fitted k-means model.
 #' @rdname spark.kmeans
@@ -204,11 +208,16 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact
 #' @note spark.kmeans since 2.0.0
 #' @seealso \link{predict}, \link{read.ml}, \link{write.ml}
 setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"),
-          function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) {
+          function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random"),
+                   seed = NULL, initSteps = 2, tol = 1E-4) {
             formula <- paste(deparse(formula), collapse = "")
             initMode <- match.arg(initMode)
+            if (!is.null(seed)) {
+              seed <- as.character(as.integer(seed))
+            }
             jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula,
-                                as.integer(k), as.integer(maxIter), initMode)
+                                as.integer(k), as.integer(maxIter), initMode, seed,
+                                as.integer(initSteps), as.numeric(tol))
             new("KMeansModel", jobj = jobj)
           })
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7f24a0b6/R/pkg/inst/tests/testthat/test_mllib_clustering.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
index 1980fff..f013991 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -132,6 +132,26 @@ test_that("spark.kmeans", {
   expect_true(summary2$is.loaded)
 
   unlink(modelPath)
+
+  # Test Kmeans on dataset that is sensitive to seed value
+  col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
+  col2 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
+  col3 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
+  cols <- as.data.frame(cbind(col1, col2, col3))
+  df <- createDataFrame(cols)
+
+  model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
+                         initMode = "random", seed = 1, tol = 1E-5)
+  model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
+                         initMode = "random", seed = 22222, tol = 1E-5)
+
+  fitted.model1 <- fitted(model1)
+  fitted.model2 <- fitted(model2)
+  # The predicted clusters are different
+  expect_equal(sort(collect(distinct(select(fitted.model1, "prediction")))$prediction),
+             c(0, 1, 2, 3))
+  expect_equal(sort(collect(distinct(select(fitted.model2, "prediction")))$prediction),
+             c(0, 1, 2))
 })
 
 test_that("spark.lda with libsvm", {

http://git-wip-us.apache.org/repos/asf/spark/blob/7f24a0b6/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index ea94585..a1fefd3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -68,7 +68,10 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
       formula: String,
       k: Int,
       maxIter: Int,
-      initMode: String): KMeansWrapper = {
+      initMode: String,
+      seed: String,
+      initSteps: Int,
+      tol: Double): KMeansWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
@@ -87,6 +90,10 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
       .setMaxIter(maxIter)
       .setInitMode(initMode)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setInitSteps(initSteps)
+      .setTol(tol)
+
+    if (seed != null && seed.length > 0) kMeans.setSeed(seed.toInt)
 
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, kMeans))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org