You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/09/21 22:11:37 UTC

spark git commit: [SPARK-3147] [MLLIB] [STREAMING] Streaming 2-sample statistical significance testing

Repository: spark
Updated Branches:
  refs/heads/master ba882db6f -> aeef44a3e


[SPARK-3147] [MLLIB] [STREAMING] Streaming 2-sample statistical significance testing

Implementation of significance testing using Streaming API.

Author: Feynman Liang <fl...@databricks.com>
Author: Feynman Liang <fe...@gmail.com>

Closes #4716 from feynmanliang/ab_testing.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/aeef44a3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/aeef44a3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/aeef44a3

Branch: refs/heads/master
Commit: aeef44a3e32b53f7adecc8e9cfd684fb4598e87d
Parents: ba882db
Author: Feynman Liang <fl...@databricks.com>
Authored: Mon Sep 21 13:11:28 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Sep 21 13:11:28 2015 -0700

----------------------------------------------------------------------
 .../examples/mllib/StreamingTestExample.scala   |  90 +++++++
 .../spark/mllib/stat/test/StreamingTest.scala   | 145 +++++++++++
 .../mllib/stat/test/StreamingTestMethod.scala   | 167 +++++++++++++
 .../spark/mllib/stat/test/TestResult.scala      |  22 ++
 .../spark/mllib/stat/StreamingTestSuite.scala   | 243 +++++++++++++++++++
 5 files changed, 667 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/aeef44a3/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala
new file mode 100644
index 0000000..ab29f90
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.spark.SparkConf
+import org.apache.spark.mllib.stat.test.StreamingTest
+import org.apache.spark.streaming.{Seconds, StreamingContext}
+import org.apache.spark.util.Utils
+
+/**
+ * Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data
+ * stream arrives as text files in a directory. Stops when the two groups are statistically
+ * significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded.
+ *
+ * The rows of the text files must be in the form `Boolean, Double`. For example:
+ *   false, -3.92
+ *   true, 99.32
+ *
+ * Usage:
+ *   StreamingTestExample <dataDir> <batchDuration> <numBatchesTimeout>
+ *
+ * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and
+ * a timeout after 100 insignificant batches, call:
+ *    $ bin/run-example mllib.StreamingTestExample dataDir 5 100
+ *
+ * As you add text files to `dataDir` the significance test will continually update every
+ * `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of
+ * batches processed exceeds `numBatchesTimeout`.
+ */
+object StreamingTestExample {
+
+  def main(args: Array[String]): Unit = {
+    if (args.length != 3) {
+      // scalastyle:off println
+      System.err.println(
+        "Usage: StreamingTestExample " +
+          "<dataDir> <batchDuration> <numBatchesTimeout>")
+      // scalastyle:on println
+      System.exit(1)
+    }
+    val dataDir = args(0)
+    val batchDuration = Seconds(args(1).toLong)
+    val numBatchesTimeout = args(2).toInt
+
+    val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample")
+    val ssc = new StreamingContext(conf, batchDuration)
+    // Checkpoint into a freshly created temporary directory.
+    ssc.checkpoint(Utils.createTempDir().toString)
+
+    // Each row is `Boolean, Double` (e.g. `false, -3.92`); trim both fields so padding around
+    // the comma does not break parsing. Rows of any other shape still fail with a MatchError.
+    val data = ssc.textFileStream(dataDir).map(line => line.split(",") match {
+      case Array(label, value) => (label.trim.toBoolean, value.trim.toDouble)
+    })
+
+    val streamingTest = new StreamingTest()
+      .setPeacePeriod(0)
+      .setWindowSize(0)
+      .setTestMethod("welch")
+
+    val out = streamingTest.registerStream(data)
+    out.print()
+
+    // Stop processing if test becomes significant or we time out (`<=` also covers a timeout <= 0)
+    var timeoutCounter = numBatchesTimeout
+    out.foreachRDD { rdd =>
+      timeoutCounter -= 1
+      val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _)
+      if (timeoutCounter <= 0 || anySignificant) rdd.context.stop()
+    }
+
+    ssc.start()
+    ssc.awaitTermination()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aeef44a3/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
new file mode 100644
index 0000000..75c6a51
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat.test
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.StatCounter
+
+/**
+ * :: Experimental ::
+ * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The
+ * Boolean identifies which sample each observation comes from, and the Double is the numeric value
+ * of the observation.
+ *
+ * To address novelty effects, the `peacePeriod` specifies a set number of initial
+ * [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing.
+ *
+ * The `windowSize` sets the number of batches each significance test is to be performed over. The
+ * window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform
+ * cumulative processing, using all batches seen so far.
+ *
+ * Different tests may be used for assessing statistical significance depending on assumptions
+ * satisfied by data. For more details, see [[StreamingTestMethod]]. The `testMethod` specifies
+ * which test will be used.
+ *
+ * Use a builder pattern to construct a streaming test in an application, for example:
+ * {{{
+ *   val model = new StreamingTest()
+ *     .setPeacePeriod(10)
+ *     .setWindowSize(0)
+ *     .setTestMethod("welch")
+ *     .registerStream(DStream)
+ * }}}
+ */
+@Experimental
+@Since("1.6.0")
+class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
+  private var peacePeriod: Int = 0
+  private var windowSize: Int = 0
+  private var testMethod: StreamingTestMethod = WelchTTest
+
+  /** Set the number of initial batches to ignore. Default: 0. */
+  @Since("1.6.0")
+  def setPeacePeriod(peacePeriod: Int): this.type = {
+    this.peacePeriod = peacePeriod
+    this
+  }
+
+  /**
+   * Set the number of batches to compute significance tests over. Default: 0.
+   * A value of 0 will use all batches seen so far.
+   */
+  @Since("1.6.0")
+  def setWindowSize(windowSize: Int): this.type = {
+    this.windowSize = windowSize
+    this
+  }
+
+  /** Set the statistical method used for significance testing. Default: "welch" */
+  @Since("1.6.0")
+  def setTestMethod(method: String): this.type = {
+    this.testMethod = StreamingTestMethod.getTestMethodFromName(method)
+    this
+  }
+
+  /**
+   * Register a [[DStream]] of values for significance testing.
+   *
+   * @param data stream of (key,value) pairs where the key denotes group membership (true =
+   *             experiment, false = control) and the value is the numerical metric to test for
+   *             significance
+   * @return stream of significance testing results
+   */
+  @Since("1.6.0")
+  def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = {
+    val dataAfterPeacePeriod = dropPeacePeriod(data)
+    val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod)
+    val pairedSummaries = pairSummaries(summarizedData)
+
+    testMethod.doTest(pairedSummaries)
+  }
+
+  /** Drop all batches inside the peace period. */
+  private[stat] def dropPeacePeriod(
+      data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = {
+    data.transform { (rdd, time) =>
+      if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) {
+        rdd
+      } else {
+        data.context.sparkContext.parallelize(Seq())
+      }
+    }
+  }
+
+  /** Compute summary statistics over each key and the specified test window size. */
+  private[stat] def summarizeByKeyAndWindow(
+      data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = {
+    if (this.windowSize == 0) {
+      data.updateStateByKey[StatCounter](
+        (newValues: Seq[Double], oldSummary: Option[StatCounter]) => {
+          val newSummary = oldSummary.getOrElse(new StatCounter())
+          newSummary.merge(newValues)
+          Some(newSummary)
+        })
+    } else {
+      val windowDuration = data.slideDuration * this.windowSize
+      data
+        .groupByKeyAndWindow(windowDuration)
+        .mapValues { values =>
+          val summary = new StatCounter()
+          values.foreach(value => summary.merge(value))
+          summary
+        }
+    }
+  }
+
+  /**
+   * Pair each batch's per-group summaries as (control, experiment), ordered by the Boolean key.
+   */
+  private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)])
+      : DStream[(StatCounter, StatCounter)] = {
+    summarizedData
+      .map[(Int, (Boolean, StatCounter))](x => (0, x))
+      .groupByKey()  // one shared key, so both groups of a batch land in the same Iterable
+      .map { case (_, group) => group.toSeq.sortBy(_._1).map(_._2) }  // false (control) first
+      .map(s => (s.head, s.last))  // a lone group is paired with itself
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/aeef44a3/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala
new file mode 100644
index 0000000..a7eaed5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat.test
+
+import java.io.Serializable
+
+import scala.language.implicitConversions
+import scala.math.pow
+
+import com.twitter.chill.MeatLocker
+import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues
+import org.apache.commons.math3.stat.inference.TTest
+
+import org.apache.spark.Logging
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.StatCounter
+
+/**
+ * Significance testing methods for [[StreamingTest]]. New 2-sample statistical significance tests
+ * should extend [[StreamingTestMethod]] and introduce a new entry in
+ * [[StreamingTestMethod.TEST_NAME_TO_OBJECT]]
+ */
+private[stat] sealed trait StreamingTestMethod extends Serializable {
+
+  val methodName: String
+  val nullHypothesis: String
+
+  protected type SummaryPairStream =
+    DStream[(StatCounter, StatCounter)]
+
+  /**
+   * Perform streaming 2-sample statistical significance testing.
+   *
+   * @param sampleSummaries stream pairs of summary statistics for the 2 samples
+   * @return stream of test results
+   */
+  def doTest(sampleSummaries: SummaryPairStream): DStream[StreamingTestResult]
+
+  /**
+   * Implicit adapter to convert between streaming summary statistics type and the type required by
+   * the t-testing libraries.
+   */
+  protected implicit def toApacheCommonsStats(
+      summaryStats: StatCounter): StatisticalSummaryValues = {
+    new StatisticalSummaryValues(
+      summaryStats.mean,
+      summaryStats.variance,
+      summaryStats.count,
+      summaryStats.max,
+      summaryStats.min,
+      summaryStats.mean * summaryStats.count
+    )
+  }
+}
+
+/**
+ * Performs Welch's 2-sample t-test. The null hypothesis is that the two data sets have equal mean.
+ * This test does not assume equal variance between the two samples and does not assume equal
+ * sample size.
+ *
+ * @see http://en.wikipedia.org/wiki/Welch%27s_t_test
+ */
+private[stat] object WelchTTest extends StreamingTestMethod with Logging {
+
+  override final val methodName = "Welch's 2-sample t-test"
+  override final val nullHypothesis = "Both groups have same mean"
+
+  private final val tTester = MeatLocker(new TTest())
+
+  override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] =
+    data.map[StreamingTestResult]((test _).tupled)
+
+  /** Run Welch's t-test on the two group summaries and package the outcome. */
+  private def test(
+      statsA: StatCounter,
+      statsB: StatCounter): StreamingTestResult = {
+    def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = {
+      val v1 = sample1.getVariance
+      val n1 = sample1.getN
+      val v2 = sample2.getVariance
+      val n2 = sample2.getN
+      // Welch-Satterthwaite: getVariance is already s^2, so it must not be squared again.
+      val a = v1 / n1
+      val b = v2 / n2
+      pow(a + b, 2) / ((pow(a, 2) / (n1 - 1)) + (pow(b, 2) / (n2 - 1)))
+    }
+    // Constructor arguments, in order: p-value, degrees of freedom, t statistic, method, H0.
+    new StreamingTestResult(
+      tTester.get.tTest(statsA, statsB),
+      welchDF(statsA, statsB),
+      tTester.get.t(statsA, statsB),
+      methodName,
+      nullHypothesis
+    )
+  }
+}
+
+/**
+ * Performs Student's 2-sample t-test. The null hypothesis is that the two data sets have equal
+ * mean. This test assumes equal variance between the two samples and does not assume equal sample
+ * size. For unequal variances, Welch's t-test should be used instead.
+ *
+ * @see http://en.wikipedia.org/wiki/Student%27s_t-test
+ */
+private[stat] object StudentTTest extends StreamingTestMethod with Logging {
+
+  override final val methodName = "Student's 2-sample t-test"
+  override final val nullHypothesis = "Both groups have same mean"
+
+  private final val tTester = MeatLocker(new TTest())
+
+  override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] =
+    data.map[StreamingTestResult]((test _).tupled)
+
+  private def test(
+      statsA: StatCounter,
+      statsB: StatCounter): StreamingTestResult = {
+    def studentDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double =
+      sample1.getN + sample2.getN - 2
+
+    new StreamingTestResult(
+      tTester.get.homoscedasticTTest(statsA, statsB),
+      studentDF(statsA, statsB),
+      tTester.get.homoscedasticT(statsA, statsB),
+      methodName,
+      nullHypothesis
+    )
+  }
+}
+
+/**
+ * Companion object holding supported [[StreamingTestMethod]] names and handles conversion between
+ * strings used in [[StreamingTest]] configuration and actual method implementation.
+ *
+ * Currently supported tests: `welch`, `student`.
+ */
+private[stat] object StreamingTestMethod {
+  // Note: after new `StreamingTestMethod`s are implemented, please update this map.
+  private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map(
+    "welch"->WelchTTest,
+    "student"->StudentTTest)
+
+  def getTestMethodFromName(method: String): StreamingTestMethod =
+    TEST_NAME_TO_OBJECT.get(method) match {
+      case Some(test) => test
+      case None =>
+        throw new IllegalArgumentException(
+          "Unrecognized method name. Supported streaming test methods: "
+            + TEST_NAME_TO_OBJECT.keys.mkString(", "))
+    }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/aeef44a3/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
index d01b370..b0916d3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
@@ -115,3 +115,25 @@ class KolmogorovSmirnovTestResult private[stat] (
     "Kolmogorov-Smirnov test summary:\n" + super.toString
   }
 }
+
+/**
+ * :: Experimental ::
+ * Object containing the test results for streaming testing.
+ */
+@Experimental
+@Since("1.6.0")
+private[stat] class StreamingTestResult @Since("1.6.0") (
+    @Since("1.6.0") override val pValue: Double,
+    @Since("1.6.0") override val degreesOfFreedom: Double,
+    @Since("1.6.0") override val statistic: Double,
+    @Since("1.6.0") val method: String,
+    @Since("1.6.0") override val nullHypothesis: String)
+  extends TestResult[Double] with Serializable {
+
+  override def toString: String = {
+    "Streaming test summary:\n" +
+      s"method: $method\n" +
+      super.toString
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/aeef44a3/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
new file mode 100644
index 0000000..d3e9ef4
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest}
+import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.StatCounter
+import org.apache.spark.util.random.XORShiftRandom
+
+class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
+
+  override def maxWaitTimeMillis : Int = 30000
+
+  test("accuracy for null hypothesis using welch t-test") {
+    // set parameters
+    val testMethod = "welch"
+    val numBatches = 2
+    val pointsPerBatch = 1000
+    val meanA = 0
+    val stdevA = 0.001
+    val meanB = 0
+    val stdevB = 0.001
+
+    val model = new StreamingTest()
+      .setWindowSize(0)
+      .setPeacePeriod(0)
+      .setTestMethod(testMethod)
+
+    val input = generateTestData(
+      numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42)
+
+    // setup and run the model
+    val ssc = setupStreams(
+      input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+    val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
+
+    assert(outputBatches.flatten.forall(res =>
+      res.pValue > 0.05 && res.method == WelchTTest.methodName))
+  }
+
+  test("accuracy for alternative hypothesis using welch t-test") {
+    // set parameters
+    val testMethod = "welch"
+    val numBatches = 2
+    val pointsPerBatch = 1000
+    val meanA = -10
+    val stdevA = 1
+    val meanB = 10
+    val stdevB = 1
+
+    val model = new StreamingTest()
+      .setWindowSize(0)
+      .setPeacePeriod(0)
+      .setTestMethod(testMethod)
+
+    val input = generateTestData(
+      numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42)
+
+    // setup and run the model
+    val ssc = setupStreams(
+      input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+    val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
+
+    assert(outputBatches.flatten.forall(res =>
+      res.pValue < 0.05 && res.method == WelchTTest.methodName))
+  }
+
+  test("accuracy for null hypothesis using student t-test") {
+    // set parameters
+    val testMethod = "student"
+    val numBatches = 2
+    val pointsPerBatch = 1000
+    val meanA = 0
+    val stdevA = 0.001
+    val meanB = 0
+    val stdevB = 0.001
+
+    val model = new StreamingTest()
+      .setWindowSize(0)
+      .setPeacePeriod(0)
+      .setTestMethod(testMethod)
+
+    val input = generateTestData(
+      numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42)
+
+    // setup and run the model
+    val ssc = setupStreams(
+      input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+    val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
+
+
+    assert(outputBatches.flatten.forall(res =>
+      res.pValue > 0.05 && res.method == StudentTTest.methodName))
+  }
+
+  test("accuracy for alternative hypothesis using student t-test") {
+    // set parameters
+    val testMethod = "student"
+    val numBatches = 2
+    val pointsPerBatch = 1000
+    val meanA = -10
+    val stdevA = 1
+    val meanB = 10
+    val stdevB = 1
+
+    val model = new StreamingTest()
+      .setWindowSize(0)
+      .setPeacePeriod(0)
+      .setTestMethod(testMethod)
+
+    val input = generateTestData(
+      numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42)
+
+    // setup and run the model
+    val ssc = setupStreams(
+      input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+    val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
+
+    assert(outputBatches.flatten.forall(res =>
+      res.pValue < 0.05 && res.method == StudentTTest.methodName))
+  }
+
+  test("batches within same test window are grouped") {
+    // set parameters
+    val testWindow = 3
+    val numBatches = 5
+    val pointsPerBatch = 100
+    val meanA = -10
+    val stdevA = 1
+    val meanB = 10
+    val stdevB = 1
+
+    val model = new StreamingTest()
+      .setWindowSize(testWindow)
+      .setPeacePeriod(0)
+
+    val input = generateTestData(
+      numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42)
+
+    // setup and run the model
+    val ssc = setupStreams(
+      input,
+      (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream))
+    val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches)
+    val outputCounts = outputBatches.flatten.map(_._2.count)
+
+    // number of batches seen so far does not exceed testWindow, expect counts to continue growing
+    for (i <- 0 until testWindow) {
+      assert(outputCounts.drop(2 * i).take(2).forall(_ == (i + 1) * pointsPerBatch / 2))
+    }
+
+    // number of batches seen exceeds testWindow, expect counts to be constant
+    assert(outputCounts.drop(2 * (testWindow - 1)).forall(_ == testWindow * pointsPerBatch / 2))
+  }
+
+
+  test("entries in peace period are dropped") {
+    // set parameters
+    val peacePeriod = 3
+    val numBatches = 7
+    val pointsPerBatch = 1000
+    val meanA = -10
+    val stdevA = 1
+    val meanB = 10
+    val stdevB = 1
+
+    val model = new StreamingTest()
+      .setWindowSize(0)
+      .setPeacePeriod(peacePeriod)
+
+    val input = generateTestData(
+      numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42)
+
+    // setup and run the model
+    val ssc = setupStreams(
+      input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream))
+    val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches)
+
+    assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch)
+  }
+
+  test("null hypothesis when only data from one group is present") {
+    // set parameters
+    val numBatches = 2
+    val pointsPerBatch = 1000
+    val meanA = 0
+    val stdevA = 0.001
+    val meanB = 0
+    val stdevB = 0.001
+
+    val model = new StreamingTest()
+      .setWindowSize(0)
+      .setPeacePeriod(0)
+
+    val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42)
+      .map(batch => batch.filter(_._1)) // only keep one test group
+
+    // setup and run the model
+    val ssc = setupStreams(
+      input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+    val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
+
+    assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001))
+  }
+
+  // Generate testing input with half of the entries in group A and half in group B
+  private def generateTestData(
+      numBatches: Int,
+      pointsPerBatch: Int,
+      meanA: Double,
+      stdevA: Double,
+      meanB: Double,
+      stdevB: Double,
+      seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = {
+    val rand = new XORShiftRandom(seed)
+    val numTrues = pointsPerBatch / 2
+    val data = (0 until numBatches).map { i =>
+      (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++
+        (pointsPerBatch / 2 until pointsPerBatch).map { idx =>
+          (false, meanB + stdevB * rand.nextGaussian())
+        }
+    }
+
+    data
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org