Posted to commits@spark.apache.org by me...@apache.org on 2014/11/01 06:30:39 UTC

git commit: Streaming KMeans [MLLIB][SPARK-3254]

Repository: spark
Updated Branches:
  refs/heads/master 860219551 -> 98c556ebb


Streaming KMeans [MLLIB][SPARK-3254]

This adds a Streaming KMeans algorithm to MLlib. It uses an update rule that generalizes the mini-batch KMeans update to incorporate a decay factor, which allows past data to be forgotten. The decay factor can be specified explicitly, or via a more intuitive half-life setting, in units of either data points or batches.
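
The update rule described above can be sketched for a single cluster on plain Scala arrays, outside Spark. This is a minimal illustration that assumes the decay factor is applied once per batch (the `batches` time unit). `DecayedMiniBatchUpdate` and `update` are hypothetical names and are not part of the MLlib API.

```scala
// Hypothetical, Spark-free sketch of the decayed mini-batch k-means update
// for one cluster, with the decay factor applied once per batch.
object DecayedMiniBatchUpdate {

  /**
   * @param center previous centroid c_t
   * @param weight previous (already discounted) weight n_t
   * @param batch  points assigned to this cluster in the current batch
   * @param decay  decay factor a in [0, 1]
   * @return (updated centroid c_{t+1}, updated weight n_{t+1})
   */
  def update(
      center: Array[Double],
      weight: Double,
      batch: Seq[Array[Double]],
      decay: Double): (Array[Double], Double) = {
    val m = batch.size
    if (m == 0) {
      // No new points: the centroid is unchanged, only its weight is discounted.
      (center.clone(), weight * decay)
    } else {
      val dim = center.length
      // x_t: mean of the points assigned to this cluster in the batch
      val batchMean = Array.tabulate(dim)(j => batch.map(_(j)).sum / m)
      val discounted = weight * decay // n_t * a
      val newWeight = discounted + m  // n_{t+1} = n_t * a + m_t
      // c_{t+1} = (c_t * n_t * a + x_t * m_t) / (n_t * a + m_t)
      val newCenter = Array.tabulate(dim) { j =>
        (center(j) * discounted + batchMean(j) * m) / newWeight
      }
      (newCenter, newWeight)
    }
  }
}
```

Take a previous centroid of [0.0] with weight 4.0 and two new points at 2.0 and 4.0. With a = 0.5 the sketch returns the centroid [1.5] with weight 4.0. With a = 0.0 it returns [3.0] with weight 2.0, which is the batch mean alone.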

The PR includes:
- StreamingKMeans algorithm with decay factor settings
- Usage example
- Additions to documentation clustering page
- Unit tests of basic behavior and decay behaviors

tdas mengxr rezazadeh

Author: freeman <th...@gmail.com>
Author: Jeremy Freeman <th...@gmail.com>
Author: Xiangrui Meng <me...@databricks.com>

Closes #2942 from freeman-lab/streaming-kmeans and squashes the following commits:

b2e5b4a [freeman] Fixes to docs / examples
078617c [Jeremy Freeman] Merge pull request #1 from mengxr/SPARK-3254
2e682c0 [Xiangrui Meng] take discount on previous weights; use BLAS; detect dying clusters
0411bf5 [freeman] Change decay parameterization
9f7aea9 [freeman] Style fixes
374a706 [freeman] Formatting
ad9bdc2 [freeman] Use labeled points and predictOnValues in examples
77dbd3f [freeman] Make initialization check an assertion
9cfc301 [freeman] Make random seed an argument
44050a9 [freeman] Simpler constructor
c7050d5 [freeman] Fix spacing
2899623 [freeman] Use pattern matching for clarity
a4a316b [freeman] Use collect
1472ec5 [freeman] Doc formatting
ea22ec8 [freeman] Fix imports
2086bdc [freeman] Log cluster center updates
ea9877c [freeman] More documentation
9facbe3 [freeman] Bug fix
5db7074 [freeman] Example usage for StreamingKMeans
f33684b [freeman] Add explanation and example to docs
b5b5f8d [freeman] Add better documentation
a0fd790 [freeman] Merge remote-tracking branch 'upstream/master' into streaming-kmeans
9fd9c15 [freeman] Merge remote-tracking branch 'upstream/master' into streaming-kmeans
b93350f [freeman] Streaming KMeans with decay


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/98c556eb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/98c556eb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/98c556eb

Branch: refs/heads/master
Commit: 98c556ebbca6a815813daaefd292d2e46fb16cc2
Parents: 8602195
Author: freeman <th...@gmail.com>
Authored: Fri Oct 31 22:30:12 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Oct 31 22:30:12 2014 -0700

----------------------------------------------------------------------
 docs/mllib-clustering.md                        |  96 ++++++-
 .../spark/examples/mllib/StreamingKMeans.scala  |  77 ++++++
 .../mllib/clustering/StreamingKMeans.scala      | 268 +++++++++++++++++++
 .../mllib/clustering/StreamingKMeansSuite.scala | 157 +++++++++++
 4 files changed, 597 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/98c556eb/docs/mllib-clustering.md
----------------------------------------------------------------------
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 7978e93..c696ae9 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -34,7 +34,7 @@ a given dataset, the algorithm returns the best clustering result).
 * *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
 * *epsilon* determines the distance threshold within which we consider k-means to have converged. 
 
-## Examples
+### Examples
 
 <div class="codetabs">
 <div data-lang="scala" markdown="1">
@@ -153,3 +153,97 @@ provided in the [Self-Contained Applications](quick-start.html#self-contained-ap
 section of the Spark
 Quick Start guide. Be sure to also include *spark-mllib* to your build file as
 a dependency.
+
+## Streaming clustering
+
+When data arrive in a stream, we may want to estimate clusters dynamically, 
+updating them as new data arrive. MLlib provides support for streaming k-means clustering, 
+with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm 
+uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign 
+all points to their nearest cluster, compute new cluster centers, then update each cluster using:
+
+`\begin{equation}
+    c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t}
+\end{equation}`
+`\begin{equation}
+    n_{t+1} = n_t\alpha + m_t
+\end{equation}`
+
+where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned
+to the cluster thus far (discounted by the decay factor), `$x_t$` is the new cluster center from
+the current batch, and `$m_t$` is the number of points added to the cluster in the current batch.
+The decay factor `$\alpha$` can be used to ignore the past: with `$\alpha=1$` all data will be
+used from the beginning; with `$\alpha=0$` only the most recent data will be used. This is
+analogous to an exponentially-weighted moving average.
+
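+The decay can also be specified using a `halfLife` parameter (via `setHalfLife`), which determines
+the correct decay factor `$\alpha$` such that, for data acquired
+at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5.
+The unit of time can be specified either as `batches` or `points`, and the update rule
+will be adjusted accordingly: with `points`, the decay is applied once per new data point,
+so a batch of `$m$` points discounts the existing weights by `$\alpha^m$`.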
+
+### Examples
+
+This example shows how to estimate clusters on streaming data.
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+First we import the necessary classes.
+
+{% highlight scala %}
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.clustering.StreamingKMeans
+
+{% endhighlight %}
+
+Then we make an input stream of vectors for training, as well as a stream of labeled data
+points for testing. We assume a StreamingContext `ssc` has already been created; see the
+[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info.
+
+{% highlight scala %}
+
+val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse)
+val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse)
+
+{% endhighlight %}
+
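+We create a model with random clusters and specify the number of clusters to find.
+Note that `setK` must be called before `setRandomCenters`, which uses it to decide
+how many centers to generate.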
+
+{% highlight scala %}
+
+val numDimensions = 3
+val numClusters = 2
+val model = new StreamingKMeans()
+  .setK(numClusters)
+  .setDecayFactor(1.0)
+  .setRandomCenters(numDimensions, 0.0)
+
+{% endhighlight %}
+
+Now register the streams for training and testing and start the job, printing 
+the predicted cluster assignments on new data points as they arrive.
+
+{% highlight scala %}
+
+model.trainOn(trainingData)
+model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
+
+ssc.start()
+ssc.awaitTermination()
+
+{% endhighlight %}
+
+As you add new text files with data, the cluster centers will update. Each training
+point should be formatted as `[x1, x2, x3]`, and each test data point
+should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier
+(e.g. a true category assignment). Any time a text file is placed in `/training/data/dir`,
+the model will update. Any time a text file is placed in `/testing/data/dir`,
+you will see predictions. With new data, the cluster centers will change!
+
+</div>
+
+</div>
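
A half-life h corresponds to the decay factor a with a^h = 0.5, i.e. a = exp(ln(0.5) / h). This is the formula `setHalfLife` uses in this commit. The sketch below computes that factor and the discount applied to cluster weights for one batch, mirroring the `batches`/`points` cases in `StreamingKMeansModel.update`. `HalfLifeDecay` and its methods are hypothetical. The positive half-life `require` is an assumption of this sketch, not a check the commit performs.

```scala
// Hypothetical helpers illustrating the half-life parameterization of the decay.
object HalfLifeDecay {

  /** Decay factor a such that a^halfLife == 0.5 (same formula as setHalfLife). */
  def decayFromHalfLife(halfLife: Double): Double = {
    require(halfLife > 0, "half life must be positive") // assumption of this sketch
    math.exp(math.log(0.5) / halfLife)
  }

  /**
   * Discount applied to the accumulated cluster weights for one batch.
   * "batches": the factor is applied once per batch.
   * "points":  the factor is applied once per new point, i.e. a^numNewPoints.
   */
  def batchDiscount(decay: Double, timeUnit: String, numNewPoints: Long): Double =
    timeUnit match {
      case "batches" => decay
      case "points"  => math.pow(decay, numNewPoints.toDouble)
      case other     => throw new IllegalArgumentException(s"Invalid time unit for decay: $other")
    }
}
```

A half-life of 2 batches gives a = sqrt(0.5), about 0.707, so data from two batches ago retain half their weight. With the `points` unit and the same factor, a batch of 2 points discounts the existing weights to 0.5 at once.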

http://git-wip-us.apache.org/repos/asf/spark/blob/98c556eb/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala
new file mode 100644
index 0000000..33e5760
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.clustering.StreamingKMeans
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.{Seconds, StreamingContext}
+
+/**
+ * Estimate clusters on one stream of data and make predictions
+ * on another stream, where the data streams arrive as text files
+ * into two different directories.
+ *
+ * The rows of the training text files must be vector data in the form
+ * `[x1,x2,x3,...,xn]`
+ * Where n is the number of dimensions.
+ *
+ * The rows of the test text files must be labeled data in the form
+ * `(y,[x1,x2,x3,...,xn])`
+ * Where y is some identifier. n must be the same for train and test.
+ *
+ * Usage: StreamingKMeans <trainingDir> <testDir> <batchDuration> <numClusters> <numDimensions>
+ *
+ * To run on your local machine using the two directories `trainingDir` and `testDir`,
+ * with updates every 5 seconds, 2 dimensions per data point, and 3 clusters, call:
+ *    $ bin/run-example \
+ *        org.apache.spark.examples.mllib.StreamingKMeans trainingDir testDir 5 3 2
+ *
+ * As you add text files to `trainingDir` the clusters will continuously update.
+ * Anytime you add text files to `testDir`, you'll see predicted labels using the current model.
+ *
+ */
+object StreamingKMeans {
+
+  def main(args: Array[String]) {
+    if (args.length != 5) {
+      System.err.println(
+        "Usage: StreamingKMeans " +
+          "<trainingDir> <testDir> <batchDuration> <numClusters> <numDimensions>")
+      System.exit(1)
+    }
+
+    val conf = new SparkConf().setMaster("local").setAppName("StreamingKMeans")
+    val ssc = new StreamingContext(conf, Seconds(args(2).toLong))
+
+    val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
+    val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
+
+    val model = new StreamingKMeans()
+      .setK(args(3).toInt)
+      .setDecayFactor(1.0)
+      .setRandomCenters(args(4).toInt, 0.0)
+
+    model.trainOn(trainingData)
+    model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
+
+    ssc.start()
+    ssc.awaitTermination()
+  }
+}
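
The example reads training rows like `[x1,x2,...,xn]` with `Vectors.parse` and test rows like `(y,[x1,x2,...,xn])` with `LabeledPoint.parse`. The following Spark-free sketch produces lines in those formats, for instance to generate files for `trainingDir` and `testDir`. `StreamingKMeansRows` and its methods are hypothetical helpers, not part of Spark.

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path}
import scala.util.Random

// Hypothetical helpers that emit rows in the text formats the example parses.
object StreamingKMeansRows {

  /** Training row, e.g. "[1.0,2.0]" (dense-vector text form). */
  def trainingRow(features: Array[Double]): String =
    features.mkString("[", ",", "]")

  /** Test row, e.g. "(0.0,[1.0,2.0])" (labeled-point text form). */
  def testRow(label: Double, features: Array[Double]): String =
    s"($label,${trainingRow(features)})"

  /** Write n Gaussian points around `center` (std dev `spread`) as training rows. */
  def writeTrainingFile(path: Path, center: Array[Double], spread: Double, n: Int, seed: Long): Path = {
    val rand = new Random(seed)
    val lines = (0 until n).map { _ =>
      trainingRow(center.map(c => c + rand.nextGaussian() * spread))
    }
    // One row per line, each terminated by a newline.
    Files.write(path, lines.map(_ + "\n").mkString.getBytes(StandardCharsets.UTF_8))
  }
}
```

Spark Streaming's file streams expect files to appear in the monitored directory atomically. A reasonable pattern is to write the file outside `trainingDir` and then move it in, for example with `java.nio.file.Files.move`.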

http://git-wip-us.apache.org/repos/asf/spark/blob/98c556eb/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
new file mode 100644
index 0000000..6189dce
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * :: DeveloperApi ::
+ * StreamingKMeansModel extends MLlib's KMeansModel for streaming
+ * algorithms, so it can keep track of a continuously updated weight
+ * associated with each cluster, and also update the model by
+ * doing a single iteration of the standard k-means algorithm.
+ *
+ * The update algorithm uses the "mini-batch" KMeans rule,
+ * generalized to incorporate forgetfulness (i.e. decay).
+ * The update rule (for each cluster) is:
+ *
+ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
+ * n_t+1 = n_t * a + m_t
+ *
+ * Where c_t is the previously estimated centroid for that cluster,
+ * n_t is the number of points assigned to it thus far, x_t is the centroid
+ * estimated on the current batch, and m_t is the number of points assigned
+ * to that centroid in the current batch.
+ *
+ * The decay factor 'a' scales the contribution of the clusters as estimated thus far,
+ * by applying a as a discount weighting on the accumulated weight n_t before
+ * incorporating new incoming data. If a=1, all batches are weighted equally. If a=0,
+ * new centroids are determined entirely by recent data. Lower values correspond to
+ * more forgetting.
+ *
+ * Decay can optionally be specified by a half life and associated
+ * time unit. The time unit can either be a batch of data or a single
+ * data point. Considering data arrived at time t, the half life h is defined
+ * such that at time t + h the discount applied to the data from t is 0.5.
+ * The definition remains the same whether the time unit is given
+ * as batches or points.
+ *
+ */
+@DeveloperApi
+class StreamingKMeansModel(
+    override val clusterCenters: Array[Vector],
+    val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
+
+  /** Perform a k-means update on a batch of data. */
+  def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
+
+    // find nearest cluster to each point
+    val closest = data.map(point => (this.predict(point), (point, 1L)))
+
+    // get sums and counts for updating each cluster
+    val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
+      BLAS.axpy(1.0, p2._1, p1._1)
+      (p1._1, p1._2 + p2._2)
+    }
+    val dim = clusterCenters(0).size
+    val pointStats: Array[(Int, (Vector, Long))] = closest
+      .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
+      .collect()
+
+    val discount = timeUnit match {
+      case StreamingKMeans.BATCHES => decayFactor
+      case StreamingKMeans.POINTS =>
+        val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
+          n
+        }.sum
+        math.pow(decayFactor, numNewPoints)
+    }
+
+    // apply discount to weights
+    BLAS.scal(discount, Vectors.dense(clusterWeights))
+
+    // implement update rule
+    pointStats.foreach { case (label, (sum, count)) =>
+      val centroid = clusterCenters(label)
+
+      val updatedWeight = clusterWeights(label) + count
+      val lambda = count / math.max(updatedWeight, 1e-16)
+
+      clusterWeights(label) = updatedWeight
+      BLAS.scal(1.0 - lambda, centroid)
+      BLAS.axpy(lambda / count, sum, centroid)
+
+      // display the updated cluster centers
+      val display = clusterCenters(label).size match {
+        case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
+        case _ => centroid.toArray.mkString("[", ",", "]")
+      }
+
+      logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
+    }
+
+    // Check whether the smallest cluster is dying. If so, split the largest cluster.
+    val weightsWithIndex = clusterWeights.view.zipWithIndex
+    val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
+    val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
+    if (minWeight < 1e-8 * maxWeight) {
+      logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
+      val weight = (maxWeight + minWeight) / 2.0
+      clusterWeights(largest) = weight
+      clusterWeights(smallest) = weight
+      val largestClusterCenter = clusterCenters(largest)
+      val smallestClusterCenter = clusterCenters(smallest)
+      var j = 0
+      while (j < dim) {
+        val x = largestClusterCenter(j)
+        val p = 1e-14 * math.max(math.abs(x), 1.0)
+        largestClusterCenter.toBreeze(j) = x + p
+        smallestClusterCenter.toBreeze(j) = x - p
+        j += 1
+      }
+    }
+
+    this
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * StreamingKMeans provides methods for configuring a
+ * streaming k-means analysis, training the model on streaming data,
+ * and using the model to make predictions on streaming data.
+ * See StreamingKMeansModel for details on the algorithm and update rules.
+ *
+ * Use a builder pattern to construct a streaming k-means analysis
+ * in an application, like:
+ *
+ *  val model = new StreamingKMeans()
+ *    .setDecayFactor(0.5)
+ *    .setK(3)
+ *    .setRandomCenters(5, 100.0)
+ *  model.trainOn(DStream)
+ */
+@DeveloperApi
+class StreamingKMeans(
+    var k: Int,
+    var decayFactor: Double,
+    var timeUnit: String) extends Logging {
+
+  def this() = this(2, 1.0, StreamingKMeans.BATCHES)
+
+  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
+
+  /** Set the number of clusters. */
+  def setK(k: Int): this.type = {
+    this.k = k
+    this
+  }
+
+  /** Set the decay factor directly (for forgetful algorithms). */
+  def setDecayFactor(a: Double): this.type = {
+    this.decayFactor = a
+    this
+  }
+
+  /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
+  def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
+    if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
+      throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
+    }
+    this.decayFactor = math.exp(math.log(0.5) / halfLife)
+    logInfo("Setting decay factor to: %g".format(this.decayFactor))
+    this.timeUnit = timeUnit
+    this
+  }
+
+  /** Specify initial centers directly. */
+  def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+    model = new StreamingKMeansModel(centers, weights)
+    this
+  }
+
+  /**
+   * Initialize random centers, requiring only the number of dimensions.
+   *
+   * @param dim Number of dimensions
+   * @param weight Weight for each center
+   * @param seed Random seed
+   */
+  def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+    val random = new XORShiftRandom(seed)
+    val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
+    val weights = Array.fill(k)(weight)
+    model = new StreamingKMeansModel(centers, weights)
+    this
+  }
+
+  /** Return the latest model. */
+  def latestModel(): StreamingKMeansModel = {
+    model
+  }
+
+  /**
+   * Update the clustering model by training on batches of data from a DStream.
+   * This operation registers a DStream for training the model,
+   * checks whether the cluster centers have been initialized,
+   * and updates the model using each batch of data from the stream.
+   *
+   * @param data DStream containing vector data
+   */
+  def trainOn(data: DStream[Vector]) {
+    assertInitialized()
+    data.foreachRDD { (rdd, time) =>
+      model = model.update(rdd, decayFactor, timeUnit)
+    }
+  }
+
+  /**
+   * Use the clustering model to make predictions on batches of data from a DStream.
+   *
+   * @param data DStream containing vector data
+   * @return DStream containing predictions
+   */
+  def predictOn(data: DStream[Vector]): DStream[Int] = {
+    assertInitialized()
+    data.map(model.predict)
+  }
+
+  /**
+   * Use the model to make predictions on the values of a DStream and carry over its keys.
+   *
+   * @param data DStream containing (key, feature vector) pairs
+   * @tparam K key type
+   * @return DStream containing the input keys and the predictions as values
+   */
+  def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
+    assertInitialized()
+    data.mapValues(model.predict)
+  }
+
+  /** Check whether cluster centers have been initialized. */
+  private[this] def assertInitialized(): Unit = {
+    if (model.clusterCenters == null) {
+      throw new IllegalStateException(
+        "Initial cluster centers must be set before starting predictions")
+    }
+  }
+}
+
+private[clustering] object StreamingKMeans {
+  final val BATCHES = "batches"
+  final val POINTS = "points"
+}
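
The dying-cluster check at the end of `update` can be illustrated on plain arrays. This sketch uses a hypothetical object name and no Spark types, and follows the same rule as the commit. If the smallest weight falls below 1e-8 times the largest, both clusters receive the average of the two weights. The largest centroid is then split into two points, nudged apart by a relative offset of 1e-14 in opposite directions.

```scala
// Hypothetical, Spark-free illustration of the dying-cluster split rule.
object DyingClusterSplit {

  /** Mutates `centers` and `weights` in place; returns true if a split happened. */
  def splitIfDying(centers: Array[Array[Double]], weights: Array[Double]): Boolean = {
    val (maxWeight, largest) = weights.zipWithIndex.maxBy(_._1)
    val (minWeight, smallest) = weights.zipWithIndex.minBy(_._1)
    if (minWeight < 1e-8 * maxWeight) {
      // Give both clusters the average of the two weights.
      val weight = (maxWeight + minWeight) / 2.0
      weights(largest) = weight
      weights(smallest) = weight
      // Replace both centroids with the largest one, nudged apart by a tiny relative offset.
      var j = 0
      while (j < centers(largest).length) {
        val x = centers(largest)(j)
        val p = 1e-14 * math.max(math.abs(x), 1.0)
        centers(largest)(j) = x + p
        centers(smallest)(j) = x - p
        j += 1
      }
      true
    } else {
      false
    }
  }
}
```

The "detecting dying clusters" test relies on this behavior. The far-away center starves, gets replaced by a copy of the populated one, and the two then settle on either side of 0.0.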

http://git-wip-us.apache.org/repos/asf/spark/blob/98c556eb/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
new file mode 100644
index 0000000..850c9fc
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.random.XORShiftRandom
+
+class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
+
+  override def maxWaitTimeMillis = 30000
+
+  test("accuracy for single center and equivalence to grand average") {
+    // set parameters
+    val numBatches = 10
+    val numPoints = 50
+    val k = 1
+    val d = 5
+    val r = 0.1
+
+    // create model with one cluster
+    val model = new StreamingKMeans()
+      .setK(1)
+      .setDecayFactor(1.0)
+      .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)), Array(0.0))
+
+    // generate random data for k-means
+    val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
+
+    // setup and run the model training
+    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // estimated center should be close to true center
+    assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)
+
+    // estimated center from streaming should exactly match the arithmetic mean of all data points
+    // because the decay factor is set to 1.0
+    val grandMean =
+      input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble
+    assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5)
+  }
+
+  test("accuracy for two centers") {
+    val numBatches = 10
+    val numPoints = 5
+    val k = 2
+    val d = 5
+    val r = 0.1
+
+    // create model with two clusters
+    val kMeans = new StreamingKMeans()
+      .setK(2)
+      .setHalfLife(2, "batches")
+      .setInitialCenters(
+        Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1),
+          Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)),
+        Array(5.0, 5.0))
+
+    // generate random data for k-means
+    val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
+
+    // setup and run the model training
+    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+      kMeans.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // check that estimated centers are close to true centers
+    // NOTE exact assignment depends on the initialization!
+    assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1)
+    assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1)
+  }
+
+  test("detecting dying clusters") {
+    val numBatches = 10
+    val numPoints = 5
+    val k = 1
+    val d = 1
+    val r = 1.0
+
+    // create model with two clusters
+    val kMeans = new StreamingKMeans()
+      .setK(2)
+      .setHalfLife(0.5, "points")
+      .setInitialCenters(
+        Array(Vectors.dense(0.0), Vectors.dense(1000.0)),
+        Array(1.0, 1.0))
+
+    // new data are all around the first cluster 0.0
+    val (input, _) =
+      StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0)))
+
+    // setup and run the model training
+    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+      kMeans.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // check that the dying cluster was detected and the largest cluster was split in two,
+    // leaving one center on each side of 0.0
+    val model = kMeans.latestModel()
+    val c0 = model.clusterCenters(0)(0)
+    val c1 = model.clusterCenters(1)(0)
+
+    assert(c0 * c1 < 0.0, "should have one positive center and one negative center")
+    // 0.8 is approximately the mean of the standard half-normal distribution, sqrt(2 / pi)
+    assert(math.abs(c0) ~== 0.8 absTol 0.6)
+    assert(math.abs(c1) ~== 0.8 absTol 0.6)
+  }
+
+  def StreamingKMeansDataGenerator(
+      numPoints: Int,
+      numBatches: Int,
+      k: Int,
+      d: Int,
+      r: Double,
+      seed: Int,
+      initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = {
+    val rand = new XORShiftRandom(seed)
+    val centers = initCenters match {
+      case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian())))
+      case _ => initCenters
+    }
+    val data = (0 until numBatches).map { i =>
+      (0 until numPoints).map { idx =>
+        val center = centers(idx % k)
+        Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r))
+      }
+    }
+    (data, centers)
+  }
+}
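
The first test asserts that, with a decay factor of 1.0 and an initial weight of 0.0, the streamed centroid equals the arithmetic mean of all points. A short Spark-free check shows why: applying the update rule batch by batch with a = 1 computes an incremental mean. The helper below is hypothetical and works on one-dimensional data for brevity.

```scala
// Hypothetical check: the decayed update with a = 1 and zero initial weight
// reproduces the grand mean of all batches.
object GrandMeanCheck {

  /** Fold the update rule over batches of 1-D points; returns (centroid, weight). */
  def streamCentroid(batches: Seq[Seq[Double]], decay: Double): (Double, Double) =
    batches.foldLeft((0.0, 0.0)) { case ((c, n), batch) =>
      val m = batch.size
      if (m == 0) {
        (c, n * decay) // no points: only discount the weight
      } else {
        val x = batch.sum / m       // batch mean x_t
        val nDiscounted = n * decay // n_t * a
        val nNew = nDiscounted + m  // n_{t+1}
        ((c * nDiscounted + x * m) / nNew, nNew)
      }
    }
}
```

Consider the batches (1, 2, 3), (10), (4, 6). With a = 1 the fold returns 26/6, about 4.33, which is the grand mean. With a = 0 it returns 5.0, the mean of the last batch alone.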

