Posted to commits@mahout.apache.org by ap...@apache.org on 2014/12/19 23:00:40 UTC

[2/2] mahout git commit: MAHOUT-1493 Port Naive Bayes to Scala DSL (apalumbo) closes apache/mahout#32

MAHOUT-1493 Port Naive Bayes to Scala DSL (apalumbo) closes apache/mahout#32


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/31053431
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/31053431
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/31053431

Branch: refs/heads/master
Commit: 310534319ae8df4cd95c3ff044afb6afdaee2605
Parents: ae1808b
Author: Andrew Palumbo <ap...@outlook.com>
Authored: Fri Dec 19 16:58:05 2014 -0500
Committer: Andrew Palumbo <ap...@outlook.com>
Committed: Fri Dec 19 16:58:05 2014 -0500

----------------------------------------------------------------------
 CHANGELOG                                       |   2 +
 bin/mahout                                      |  17 +
 examples/bin/classify-20newsgroups.sh           |  83 ++--
 .../apache/mahout/h2obindings/H2OHelper.java    |   5 +-
 .../classifier/naivebayes/NBClassifier.scala    | 119 +++++
 .../mahout/classifier/naivebayes/NBModel.scala  | 207 ++++++++
 .../classifier/naivebayes/NaiveBayes.scala      | 415 ++++++++++++++++
 .../classifier/stats/ClassifierStats.scala      | 467 +++++++++++++++++++
 .../classifier/stats/ConfusionMatrix.scala      | 460 ++++++++++++++++++
 .../classifier/naivebayes/NBTestBase.scala      | 171 +++++++
 .../stats/ClassifierStatsTestBase.scala         | 257 ++++++++++
 .../mahout/classifier/ConfusionMatrixTest.java  |   2 +-
 .../sparkbindings/shell/MahoutSparkILoop.scala  |   6 +
 .../classifier/naivebayes/SparkNaiveBayes.scala |  99 ++++
 .../apache/mahout/drivers/TestNBDriver.scala    | 131 ++++++
 .../apache/mahout/drivers/TrainNBDriver.scala   | 115 +++++
 .../mahout/sparkbindings/drm/package.scala      |   2 -
 .../naivebayes/NBSparkTestSuite.scala           |  86 ++++
 .../stats/ClassifierStatsSparkTestSuite.scala   |  26 ++
 19 files changed, 2636 insertions(+), 34 deletions(-)
----------------------------------------------------------------------
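
For reviewers who want the shape of the new API before wading into the diffs, here is a
minimal sketch of the Scala DSL flow this patch adds. The driver function, its name, the
rawTfIdf argument and the implicit DistributedContext are illustrative assumptions; the
NaiveBayes methods and their signatures are the ones defined in NaiveBayes.scala below.

    import scala.reflect.ClassTag
    import org.apache.mahout.math.drm._
    import org.apache.mahout.classifier.naivebayes.NaiveBayes

    // Hypothetical driver -- not part of this commit.
    def trainAndSelfTest[K: ClassTag](rawTfIdf: DrmLike[K])
                                     (implicit ctx: DistributedContext) = {
      // seq2sparse keys look like /Category/document_id; the default
      // seq2SparseCategoryParser strips them down to the category.
      val (labelIndex, aggregated) =
        NaiveBayes.extractLabelsAndAggregateObservations(rawTfIdf)
      val model = NaiveBayes.train(aggregated, labelIndex, trainComplementary = false)
      // self-test on the training set; toString gives the confusion matrix and stats
      println(NaiveBayes.test(model, rawTfIdf, testComplementary = false))
      model
    }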


http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/CHANGELOG
----------------------------------------------------------------------
diff --git a/CHANGELOG b/CHANGELOG
index ca9b71c..e4f4cae 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -2,6 +2,8 @@ Mahout Change Log
 
 Release 1.0 - unreleased
 
+  MAHOUT-1493: Port Naive Bayes to Scala DSL (apalumbo)
+
   MAHOUT-1611: Preconditions.checkArgument in org.apache.mahout.utils.ConcatenateVectorsJob (Haishou Ma via smarthi)
 
   MAHOUT-1615: SparkEngine drmFromHDFS returning the same Key for all Key,Vec Pairs for Text-Keyed SequenceFiles (Anand Avati, dlyubimov, apalumbo)

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/bin/mahout
----------------------------------------------------------------------
diff --git a/bin/mahout b/bin/mahout
index c22118b..c51c239 100755
--- a/bin/mahout
+++ b/bin/mahout
@@ -92,6 +92,14 @@ if [ "$1" == "spark-rowsimilarity" ]; then
   SPARK=1
 fi
 
+if [ "$1" == "spark-trainnb" ]; then
+  SPARK=1
+fi
+
+if [ "$1" == "spark-testnb" ]; then
+  SPARK=1
+fi
+
 if [ "$MAHOUT_CORE" != "" ]; then
   IS_CORE=1
 fi
@@ -262,6 +270,15 @@ case "$1" in
     shift
     "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.RowSimilarityDriver" "$@"
     ;;
+  (spark-trainnb)
+    shift
+    "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.TrainNBDriver" "$@"
+    ;;
+  (spark-testnb)
+    shift
+    "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "org.apache.mahout.drivers.TestNBDriver" "$@"
+    ;;
+
   (h2o-node)
     shift
     "$JAVA" $JAVA_HEAP_MAX -classpath "$CLASSPATH" "water.H2O" -md5skip "$@" -name mah2out

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/examples/bin/classify-20newsgroups.sh
----------------------------------------------------------------------
diff --git a/examples/bin/classify-20newsgroups.sh b/examples/bin/classify-20newsgroups.sh
index 562c0ba..80eb403 100755
--- a/examples/bin/classify-20newsgroups.sh
+++ b/examples/bin/classify-20newsgroups.sh
@@ -42,7 +42,7 @@ if [ "$HADOOP_HOME" != "" ] && [ "$MAHOUT_LOCAL" == "" ] ; then
 fi
 
 WORK_DIR=/tmp/mahout-work-${USER}
-algorithm=( cnaivebayes naivebayes sgd clean)
+algorithm=( cnaivebayes-MapReduce naivebayes-MapReduce cnaivebayes-Spark naivebayes-Spark sgd-MapReduce clean)
 if [ -n "$1" ]; then
   choice=$1
 else
@@ -50,7 +50,9 @@ else
   echo "1. ${algorithm[0]}"
   echo "2. ${algorithm[1]}"
   echo "3. ${algorithm[2]}"
-  echo "4. ${algorithm[3]} -- cleans up the work area in $WORK_DIR"
+  echo "4. ${algorithm[3]}"
+  echo "5. ${algorithm[4]}"
+  echo "6. ${algorithm[5]}-- cleans up the work area in $WORK_DIR"
   read -p "Enter your choice : " choice
 fi
 
@@ -79,10 +81,10 @@ cd ../..
 
 set -e
 
-if [ "x$alg" == "xnaivebayes"  -o  "x$alg" == "xcnaivebayes" ]; then
+if  ( [ "x$alg" == "xnaivebayes-MapReduce" ] ||  [ "x$alg" == "xcnaivebayes-MapReduce" ] || [ "x$alg" == "xnaivebayes-Spark"  ] || [ "x$alg" == "xcnaivebayes-Spark" ] ); then
   c=""
 
-  if [ "x$alg" == "xcnaivebayes" ]; then
+  if [ "x$alg" == "xcnaivebayes-MapReduce" -o "x$alg" == "xnaivebayes-Spark" ]; then
     c=" -c"
   fi
 
@@ -96,6 +98,7 @@ if [ "x$alg" == "xnaivebayes"  -o  "x$alg" == "xcnaivebayes" ]; then
     echo "Copying 20newsgroups data to HDFS"
     set +e
     $HADOOP dfs -rmr ${WORK_DIR}/20news-all
+    $HADOOP dfs -rmr ${WORK_DIR}/spark-model
     set -e
     $HADOOP dfs -put ${WORK_DIR}/20news-all ${WORK_DIR}/20news-all
   fi
@@ -117,30 +120,54 @@ if [ "x$alg" == "xnaivebayes"  -o  "x$alg" == "xcnaivebayes" ]; then
     --testOutput ${WORK_DIR}/20news-test-vectors  \
     --randomSelectionPct 40 --overwrite --sequenceFiles -xm sequential
 
-  echo "Training Naive Bayes model"
-  ./bin/mahout trainnb \
-    -i ${WORK_DIR}/20news-train-vectors -el \
-    -o ${WORK_DIR}/model \
-    -li ${WORK_DIR}/labelindex \
-    -ow $c
-
-  echo "Self testing on training set"
-
-  ./bin/mahout testnb \
-    -i ${WORK_DIR}/20news-train-vectors\
-    -m ${WORK_DIR}/model \
-    -l ${WORK_DIR}/labelindex \
-    -ow -o ${WORK_DIR}/20news-testing $c
-
-  echo "Testing on holdout set"
-
-  ./bin/mahout testnb \
-    -i ${WORK_DIR}/20news-test-vectors\
-    -m ${WORK_DIR}/model \
-    -l ${WORK_DIR}/labelindex \
-    -ow -o ${WORK_DIR}/20news-testing $c
-
-elif [ "x$alg" == "xsgd" ]; then
+    if [ "x$alg" == "xnaivebayes-MapReduce"  -o  "x$alg" == "xcnaivebayes-MapReduce" ]; then
+
+      echo "Training Naive Bayes model"
+      ./bin/mahout trainnb \
+        -i ${WORK_DIR}/20news-train-vectors -el \
+        -o ${WORK_DIR}/model \
+        -li ${WORK_DIR}/labelindex \
+        -ow $c
+
+      echo "Self testing on training set"
+
+      ./bin/mahout testnb \
+        -i ${WORK_DIR}/20news-train-vectors\
+        -m ${WORK_DIR}/model \
+        -l ${WORK_DIR}/labelindex \
+        -ow -o ${WORK_DIR}/20news-testing $c
+
+      echo "Testing on holdout set"
+
+      ./bin/mahout testnb \
+        -i ${WORK_DIR}/20news-test-vectors\
+        -m ${WORK_DIR}/model \
+        -l ${WORK_DIR}/labelindex \
+        -ow -o ${WORK_DIR}/20news-testing $c
+
+    elif [ "x$alg" == "xnaivebayes-Spark" -o "x$alg" == "xcnaivebayes-Spark" ]; then
+      set +e
+      $HADOOP dfs -rmr ${WORK_DIR}/spark-model
+      set -e
+
+      echo "Training Naive Bayes model"
+      ./bin/mahout spark-trainnb \
+        -i ${WORK_DIR}/20news-train-vectors \
+        -o ${WORK_DIR}/spark-model $c
+
+      echo "Self testing on training set"
+      ./bin/mahout spark-testnb \
+        -i ${WORK_DIR}/20news-train-vectors\
+        -o ${WORK_DIR}\
+        -m ${WORK_DIR}/spark-model $c
+
+      echo "Testing on holdout set"
+      ./bin/mahout spark-testnb \
+        -i ${WORK_DIR}/20news-test-vectors\
+        -o ${WORK_DIR}\
+        -m ${WORK_DIR}/spark-model $c
+    fi
+elif [ "x$alg" == "xsgd-MapReduce" ]; then
   if [ ! -e "/tmp/news-group.model" ]; then
     echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
     ./bin/mahout org.apache.mahout.classifier.sgd.TrainNewsGroups ${WORK_DIR}/20news-bydate/20news-bydate-train/
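
For reference, the two subcommands wired into bin/mahout above can also be run by hand
with the same flags the script passes; a sketch (paths are illustrative, and -c selects
the complementary variant, as in the script):

    # train a Spark Naive Bayes model from seq2sparse TF-IDF vectors
    ./bin/mahout spark-trainnb \
      -i /tmp/mahout-work-$USER/20news-train-vectors \
      -o /tmp/mahout-work-$USER/spark-model

    # score the holdout set against the trained model
    ./bin/mahout spark-testnb \
      -i /tmp/mahout-work-$USER/20news-test-vectors \
      -o /tmp/mahout-work-$USER \
      -m /tmp/mahout-work-$USER/spark-model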

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/h2o/src/main/java/org/apache/mahout/h2obindings/H2OHelper.java
----------------------------------------------------------------------
diff --git a/h2o/src/main/java/org/apache/mahout/h2obindings/H2OHelper.java b/h2o/src/main/java/org/apache/mahout/h2obindings/H2OHelper.java
index 294ec7e..d89015d 100644
--- a/h2o/src/main/java/org/apache/mahout/h2obindings/H2OHelper.java
+++ b/h2o/src/main/java/org/apache/mahout/h2obindings/H2OHelper.java
@@ -319,7 +319,6 @@ public class H2OHelper {
     for (int c = 0; c < m.columnSize(); c++) {
       writers[c].close(closer);
     }
-
     // If string labeled matrix, create aux Vec
     Map<String,Integer> map = m.getRowLabelBindings();
     if (map != null) {
@@ -327,8 +326,8 @@ public class H2OHelper {
       labels = frame.anyVec().makeZero();
       Vec.Writer writer = labels.open();
       Map<Integer,String> rmap = reverseMap(map);
-
-      for (long r = 0; r < m.rowSize(); r++) {
+      for (int r = 0; r < m.rowSize(); r++) {
+        // TODO: fix bug here... Exception is being thrown when setting Strings
         writer.set(r, rmap.get(r));
       }
 

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala
new file mode 100644
index 0000000..5de0733
--- /dev/null
+++ b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala
@@ -0,0 +1,119 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements.  See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+package org.apache.mahout.classifier.naivebayes
+
+import org.apache.mahout.math.Vector
+import scala.collection.JavaConversions._
+
+/**
+ * Abstract Classifier base for Complementary and Standard Classifiers
+ * @param nbModel a trained NBModel
+ */
+abstract class AbstractNBClassifier(nbModel: NBModel) extends java.io.Serializable {
+
+  // Trained Naive Bayes Model
+  val model = nbModel
+
+  /** scoring method for standard and complementary classifiers */
+  protected def getScoreForLabelFeature(label: Int, feature: Int): Double
+
+  /** getter for model */
+  protected def getModel: NBModel = {
+     model
+  }
+
+  /**
+   * Compute the score for a Vector of weighted TF-IDF features
+   * @param label Label to be scored
+   * @param instance Vector of weights from which to calculate the score
+   * @return score for this Label
+   */
+  protected def getScoreForLabelInstance(label: Int, instance: Vector): Double = {
+    var result: Double = 0.0
+    for (e <- instance.nonZeroes) {
+      result += e.get * getScoreForLabelFeature(label, e.index)
+    }
+    result
+  }
+
+  /** number of categories the model has been trained on */
+  def numCategories: Int = {
+     model.numLabels
+  }
+
+  /**
+   * get a scoring vector for a vector of TF or TF-IDF weights
+   * @param instance vector of TF or TF-IDF weights to be classified
+   * @return a vector of scores.
+   */
+  def classifyFull(instance: Vector): Vector = {
+    classifyFull(model.createScoringVector, instance)
+  }
+
+  /** helper method for classifyFull(Vector) */
+  def classifyFull(r: Vector, instance: Vector): Vector = {
+    var label: Int = 0
+    for (label <- 0 until model.numLabels) {
+        r.setQuick(label, getScoreForLabelInstance(label, instance))
+      }
+    r
+  }
+}
+
+/**
+ * Standard Multinomial Naive Bayes Classifier
+ * @param nbModel a trained NBModel
+ */
+class StandardNBClassifier(nbModel: NBModel) extends AbstractNBClassifier(nbModel) with java.io.Serializable {
+  override def getScoreForLabelFeature(label: Int, feature: Int): Double = {
+    val model: NBModel = getModel
+    StandardNBClassifier.computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI, model.numFeatures)
+  }
+}
+
+/** helper object for StandardNBClassifier */
+object StandardNBClassifier extends java.io.Serializable {
+  /** Compute Standard Multinomial Naive Bayes weights; see Rennie et al., Section 2.1 */
+  def computeWeight(featureLabelWeight: Double, labelWeight: Double, alphaI: Double, numFeatures: Double): Double = {
+    val numerator: Double = featureLabelWeight + alphaI
+    val denominator: Double = labelWeight + alphaI * numFeatures
+    Math.log(numerator / denominator)
+  }
+}
+
+/**
+ * Complementary Naive Bayes Classifier
+ * @param nbModel a trained NBModel
+ */
+class ComplementaryNBClassifier(nbModel: NBModel) extends AbstractNBClassifier(nbModel) with java.io.Serializable {
+  override def getScoreForLabelFeature(label: Int, feature: Int): Double = {
+    val model: NBModel = getModel
+    val weight: Double = ComplementaryNBClassifier.computeWeight(model.featureWeight(feature), model.weight(label, feature), model.totalWeightSum, model.labelWeight(label), model.alphaI, model.numFeatures)
+    weight / model.thetaNormalizer(label)
+  }
+}
+
+/** helper object for ComplementaryNBClassifier */
+object ComplementaryNBClassifier extends java.io.Serializable {
+
+  /** Compute Complementary Naive Bayes weights; see Rennie et al., Section 3.1 */
+  def computeWeight(featureWeight: Double, featureLabelWeight: Double, totalWeight: Double, labelWeight: Double, alphaI: Double, numFeatures: Double): Double = {
+    val numerator: Double = featureWeight - featureLabelWeight + alphaI
+    val denominator: Double = totalWeight - labelWeight + alphaI * numFeatures
+    -Math.log(numerator / denominator)
+  }
+}
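
As a numeric sanity check on the two weighting schemes above (the weights below are
made-up numbers, not data from the patch):

    import org.apache.mahout.classifier.naivebayes.{ComplementaryNBClassifier, StandardNBClassifier}

    // standard NB (Rennie et al., Sec. 2.1): log((w_lf + alpha) / (w_l + alpha * N))
    val standard = StandardNBClassifier.computeWeight(
      featureLabelWeight = 3.0, labelWeight = 10.0, alphaI = 1.0, numFeatures = 5.0)
    // log(4 / 15) ~= -1.3218

    // complementary NB (Sec. 3.1): -log((w_f - w_lf + alpha) / (W - w_l + alpha * N))
    val complementary = ComplementaryNBClassifier.computeWeight(
      featureWeight = 7.0, featureLabelWeight = 3.0, totalWeight = 40.0,
      labelWeight = 10.0, alphaI = 1.0, numFeatures = 5.0)
    // -log(5 / 35) ~= 1.9459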

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala
new file mode 100644
index 0000000..4d19144
--- /dev/null
+++ b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.naivebayes
+
+import org.apache.mahout.math._
+
+import org.apache.mahout.math.{drm, scalabindings}
+
+import scalabindings._
+import scalabindings.RLikeOps._
+import drm.RLikeDrmOps._
+import drm._
+import scala.collection.JavaConverters._
+import scala.language.asInstanceOf
+import scala.collection._
+import JavaConversions._
+
+/**
+ *
+ * @param weightsPerLabelAndFeature Aggregated matrix of weights of labels x features
+ * @param weightsPerFeature Vector of summation of all feature weights.
+ * @param weightsPerLabel Vector of summation of all label weights.
+ * @param perlabelThetaNormalizer Vector of weight normalizers per label (used only for complementary models)
+ * @param labelIndex HashMap of labels and their corresponding row in the weightMatrix
+ * @param alphaI Laplace smoothing factor.
+ * @param isComplementary Whether or not this is a complementary model.
+ */
+class NBModel(val weightsPerLabelAndFeature: Matrix = null,
+              val weightsPerFeature: Vector = null,
+              val weightsPerLabel: Vector = null,
+              val perlabelThetaNormalizer: Vector = null,
+              val labelIndex: Map[String, Integer] = null,
+              val alphaI: Float = .0f,
+              val isComplementary: Boolean = false) extends java.io.Serializable {
+
+
+  val numFeatures: Double = weightsPerFeature.getNumNondefaultElements
+  val totalWeightSum: Double = weightsPerLabel.zSum
+  val alphaVector: Vector = null
+
+  validate()
+
+  // todo: Maybe it is a good idea to move the dfsWrite and dfsRead out
+  // todo: of the model and into a helper
+
+  // TODO: weightsPerLabelAndFeature, a sparse (numLabels x numFeatures) matrix, should fit
+  // TODO: in memory upfront and should not require a DRM; decide if we want this to scale out.
+
+
+  /** getter for summed label weights.  Used by legacy classifier */
+  def labelWeight(label: Int): Double = {
+     weightsPerLabel.getQuick(label)
+  }
+
+  /** getter for weight normalizers.  Used by legacy classifier */
+  def thetaNormalizer(label: Int): Double = {
+    perlabelThetaNormalizer.get(label)
+  }
+
+  /** getter for summed feature weights.  Used by legacy classifier */
+  def featureWeight(feature: Int): Double = {
+    weightsPerFeature.getQuick(feature)
+  }
+
+  /** getter for individual aggregated weights.  Used by legacy classifier */
+  def weight(label: Int, feature: Int): Double = {
+    weightsPerLabelAndFeature.getQuick(label, feature)
+  }
+
+  /** getter for a single empty vector of weights */
+  def createScoringVector: Vector = {
+     weightsPerLabel.like
+  }
+
+  /** getter for the number of labels to consider */
+  def numLabels: Int = {
+     weightsPerLabel.size
+  }
+
+  /**
+   * Write a trained model to the filesystem as a series of DRMs
+   * @param pathToModel Directory to which the model will be written
+   */
+  def dfsWrite(pathToModel: String)(implicit ctx: DistributedContext): Unit = {
+    //todo:  write out as smaller partitions, or possibly use readers and writers to
+    //todo:  write something other than a DRM for labelIndex, isComplementary and alphaI.
+    drmParallelize(weightsPerLabelAndFeature).dfsWrite(pathToModel + "/weightsPerLabelAndFeatureDrm.drm")
+    drmParallelize(sparse(weightsPerFeature)).dfsWrite(pathToModel + "/weightsPerFeatureDrm.drm")
+    drmParallelize(sparse(weightsPerLabel)).dfsWrite(pathToModel + "/weightsPerLabelDrm.drm")
+    drmParallelize(sparse(perlabelThetaNormalizer)).dfsWrite(pathToModel + "/perlabelThetaNormalizerDrm.drm")
+    drmParallelize(sparse(svec((0,alphaI)::Nil))).dfsWrite(pathToModel + "/alphaIDrm.drm")
+
+    // isComplementary is true if isComplementaryDrm(0,0) == 1, else false
+    val isComplementaryDrm = sparse(0 to 1, 0 to 1)
+    if(isComplementary){
+      isComplementaryDrm(0,0) = 1.0
+    } else {
+      isComplementaryDrm(0,0) = 0.0
+    }
+    drmParallelize(isComplementaryDrm).dfsWrite(pathToModel + "/isComplementaryDrm.drm")
+
+    // write the label index as a String-Keyed DRM.
+    val labelIndexDummyDrm = weightsPerLabelAndFeature.like()
+    labelIndexDummyDrm.setRowLabelBindings(labelIndex)
+    // get a reverse map of [Integer, String] and set the value of the first column of the drm
+    // to the corresponding row number for its label (the rows may not be read back in the same order)
+    val revMap = labelIndex.map(x => x._2 -> x._1)
+    for(i <- 0 until labelIndexDummyDrm.numRows() ){
+      labelIndexDummyDrm.set(labelIndex(revMap(i)), 0, i.toDouble)
+    }
+
+    drmParallelizeWithRowLabels(labelIndexDummyDrm).dfsWrite(pathToModel + "/labelIndex.drm")
+  }
+
+  /** Model Validation */
+  def validate() {
+    assert(alphaI > 0, "alphaI has to be greater than 0!")
+    assert(numFeatures > 0, "the vocab count has to be greater than 0!")
+    assert(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!")
+    assert(weightsPerLabel != null, "the number of labels has to be defined!")
+    assert(weightsPerLabel.getNumNondefaultElements > 0, "the number of labels has to be greater than 0!")
+    assert(weightsPerFeature != null, "the feature sums have to be defined")
+    assert(weightsPerFeature.getNumNondefaultElements > 0, "the feature sums have to be greater than 0!")
+    if (isComplementary) {
+      assert(perlabelThetaNormalizer != null, "the theta normalizers have to be defined")
+      assert(perlabelThetaNormalizer.getNumNondefaultElements > 0, "the number of theta normalizers has to be greater than 0!")
+      assert(Math.signum(perlabelThetaNormalizer.minValue) == Math.signum(perlabelThetaNormalizer.maxValue), "Theta normalizers do not all have the same sign")
+      assert(perlabelThetaNormalizer.getNumNonZeroElements == perlabelThetaNormalizer.size, "Weight normalizers can not have zero value.")
+    }
+    assert(labelIndex.size == weightsPerLabel.getNumNondefaultElements, "label index must have entries for all labels")
+  }
+}
+
+object NBModel extends java.io.Serializable {
+  /**
+   * Read a trained model in from the filesystem.
+   * @param pathToModel directory from which to read individual model components
+   * @return a valid NBModel
+   */
+  def dfsRead(pathToModel: String)(implicit ctx: DistributedContext): NBModel = {
+    //todo:  Takes forever to read; we need a more practical method of writing models. Readers/Writers?
+
+    val weightsPerFeatureDrm = drmDfsRead(pathToModel + "/weightsPerFeatureDrm.drm").checkpoint(CacheHint.MEMORY_ONLY)
+    val weightsPerFeature = weightsPerFeatureDrm.collect(0, ::)
+    weightsPerFeatureDrm.uncache()
+
+    val weightsPerLabelDrm = drmDfsRead(pathToModel + "/weightsPerLabelDrm.drm").checkpoint(CacheHint.MEMORY_ONLY)
+    val weightsPerLabel = weightsPerLabelDrm.collect(0, ::)
+    weightsPerLabelDrm.uncache()
+
+    val alphaIDrm = drmDfsRead(pathToModel + "/alphaIDrm.drm").checkpoint(CacheHint.MEMORY_ONLY)
+    val alphaI: Float = alphaIDrm.collect(0, 0).toFloat
+    alphaIDrm.uncache()
+
+    // isComplementary is true if isComplementaryDrm(0,0) == 1, else false
+    val isComplementaryDrm = drmDfsRead(pathToModel + "/isComplementaryDrm.drm").checkpoint(CacheHint.MEMORY_ONLY)
+    val isComplementary = isComplementaryDrm.collect(0, 0).toInt == 1
+    isComplementaryDrm.uncache()
+
+    var perLabelThetaNormalizer = weightsPerFeature.like()
+    if (isComplementary) {
+      val perLabelThetaNormalizerDrm = drm.drmDfsRead(pathToModel + "/perlabelThetaNormalizerDrm.drm")
+                                             .checkpoint(CacheHint.MEMORY_ONLY)
+      perLabelThetaNormalizer = perLabelThetaNormalizerDrm.collect(0, ::)
+    }
+
+    val dummyLabelDrm = drmDfsRead(pathToModel + "/labelIndex.drm")
+                         .checkpoint(CacheHint.MEMORY_ONLY)
+    val labelIndexMap:java.util.Map[String, Integer] = dummyLabelDrm.getRowLabelBindings
+    dummyLabelDrm.uncache()
+
+    // map the labels to the corresponding row numbers of weightsPerFeatureDrm (values in dummyLabelDrm)
+    val scalaLabelIndexMap: mutable.Map[String, Integer] =
+      labelIndexMap.map(x => x._1 -> dummyLabelDrm.get(labelIndexMap(x._1), 0)
+        .toInt
+        .asInstanceOf[Integer])
+
+    val weightsPerLabelAndFeatureDrm = drmDfsRead(pathToModel + "/weightsPerLabelAndFeatureDrm.drm").checkpoint(CacheHint.MEMORY_ONLY)
+    val weightsPerLabelAndFeature = weightsPerLabelAndFeatureDrm.collect
+    weightsPerLabelAndFeatureDrm.uncache()
+
+    // model validation is triggered automatically by constructor
+    val model: NBModel = new NBModel(weightsPerLabelAndFeature,
+      weightsPerFeature,
+      weightsPerLabel,
+      perLabelThetaNormalizer,
+      scalaLabelIndexMap,
+      alphaI,
+      isComplementary)
+
+    model
+  }
+}
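
A round-trip sketch for the serialization implemented above. The helper name and path
are illustrative, and an engine-specific implicit DistributedContext is assumed to be
in scope; dfsWrite/dfsRead are the methods defined in this file.

    import org.apache.mahout.classifier.naivebayes.NBModel
    import org.apache.mahout.math.drm.DistributedContext

    // Hypothetical helper -- not part of this commit.
    def roundTrip(model: NBModel, path: String)(implicit ctx: DistributedContext): NBModel = {
      model.dfsWrite(path)                  // one DRM per model component under path/
      val reloaded = NBModel.dfsRead(path)  // validate() runs in the NBModel constructor
      assert(reloaded.numLabels == model.numLabels)
      reloaded
    }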

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala
new file mode 100644
index 0000000..bff6d48
--- /dev/null
+++ b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes
+
+import org.apache.mahout.classifier.stats.{ResultAnalyzer, ClassifierResult}
+import org.apache.mahout.math._
+import scalabindings._
+import scalabindings.RLikeOps._
+import drm.RLikeDrmOps._
+import drm._
+import scala.reflect.ClassTag
+import scala.language.asInstanceOf
+import collection._
+import scala.collection.JavaConversions._
+
+/**
+ * Distributed training of a Naive Bayes model. Follows the approach presented in Rennie et al.: Tackling the poor
+ * assumptions of Naive Bayes Text classifiers, ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf
+ */
+trait NaiveBayes extends java.io.Serializable {
+
+  /** default value for the Laplacian smoothing parameter */
+  def defaultAlphaI = 1.0f
+
+  // function to extract categories from string keys
+  type CategoryParser = String => String
+
+  /** Default for seqdirectory/seq2sparse: categories are stored in Drm keys as /Category/document_id */
+  def seq2SparseCategoryParser: CategoryParser = x => x.split("/")(1)
+
+
+  /**
+   * Distributed training of a Naive Bayes model. Follows the approach presented in Rennie et al.: Tackling the poor
+   * assumptions of Naive Bayes Text classifiers, ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf
+   *
+   * @param observationsPerLabel a DrmLike[Int] matrix containing term frequency counts for each label.
+   * @param trainComplementary whether or not to train a complementary Naive Bayes model
+   * @param alphaI Laplace smoothing parameter
+   * @return trained naive bayes model
+   */
+  def train(observationsPerLabel: DrmLike[Int],
+            labelIndex: Map[String, Integer],
+            trainComplementary: Boolean = true,
+            alphaI: Float = defaultAlphaI): NBModel = {
+
+    // Summation of all weights per feature
+    val weightsPerFeature = observationsPerLabel.colSums
+
+    // Distributed summation of all weights per label
+    val weightsPerLabel = observationsPerLabel.rowSums
+
+    // Collect a matrix to pass to the NaiveBayesModel
+    val inCoreTFIDF = observationsPerLabel.collect
+
+    // perLabelThetaNormalizer Vector is expected by NaiveBayesModel. We can pass a null value
+    // or Vector of zeroes in the case of a standard NB model.
+    var thetaNormalizer = weightsPerFeature.like()
+
+    // Instantiate a trainer and retrieve the perLabelThetaNormalizer Vector from it in the case of
+    // a complementary NB model
+    if (trainComplementary) {
+      val thetaTrainer = new ComplementaryNBThetaTrainer(weightsPerFeature,
+                                                         weightsPerLabel,
+                                                         alphaI)
+      // local training of the theta normalization
+      for (labelIndex <- 0 until inCoreTFIDF.nrow) {
+        thetaTrainer.train(labelIndex, inCoreTFIDF(labelIndex, ::))
+      }
+      thetaNormalizer = thetaTrainer.retrievePerLabelThetaNormalizer
+    }
+
+    new NBModel(inCoreTFIDF,
+                weightsPerFeature,
+                weightsPerLabel,
+                thetaNormalizer,
+                labelIndex,
+                alphaI,
+                trainComplementary)
+  }
+
+  /**
+   * Extract label Keys from raw TF or TF-IDF Matrix generated by seqdirectory/seq2sparse
+   * and aggregate TF or TF-IDF values by their label
+   * Override this method in engine specific modules to optimize
+   *
+   * @param stringKeyedObservations DrmLike matrix; output from seq2sparse
+   *   in the form K = e.g. /Category/document_title
+   *           V = TF or TF-IDF values per term
+   * @param cParser a String => String function used to extract categories from
+   *   Keys of the stringKeyedObservations DRM. The default
+   *   CategoryParser will extract "Category" from: '/Category/document_id'
+   * @return  (labelIndexMap,aggregatedByLabelObservationDrm)
+   *   labelIndexMap is a HashMap[String, Integer] K = label
+   *                                               V = label row index
+   *   aggregatedByLabelObservationDrm is a DrmLike[Int] of aggregated
+   *   TF or TF-IDF counts per label
+   */
+  def extractLabelsAndAggregateObservations[K: ClassTag](stringKeyedObservations: DrmLike[K],
+                                                         cParser: CategoryParser = seq2SparseCategoryParser)
+                                                        (implicit ctx: DistributedContext):
+                                                        (mutable.HashMap[String, Integer], DrmLike[Int])= {
+
+    stringKeyedObservations.checkpoint()
+
+    val numDocs = stringKeyedObservations.nrow
+    val numFeatures = stringKeyedObservations.ncol
+
+    // Extract categories from labels assigned by seq2sparse
+    // Categories are Stored in Drm Keys as eg.: /Category/document_id
+
+    // Get a new DRM with a single column so that we don't have to collect the
+    // DRM into memory upfront.
+    val strippedObservations = stringKeyedObservations.mapBlock(ncol = 1) {
+      case(keys, block) =>
+        val blockB = block.like(keys.size, 1)
+        keys -> blockB
+    }
+
+    // Extract the row label bindings (the String keys) from the slim Drm
+    // strip the document_id from the row keys keeping only the category.
+    // Sort the bindings alphabetically into a Vector
+    val labelVectorByRowIndex = strippedObservations
+                                  .getRowLabelBindings
+                                  .map(x => x._2 -> cParser(x._1))
+                                  .toVector.sortWith(_._1 < _._1)
+
+    //TODO: add a .toIntKeyed(...) method to DrmLike?
+
+    // Copy stringKeyedObservations to an Int-Keyed Drm so that we can compute transpose
+    // Copy the Collected Matrices up front for now until we have a distributed way of converting
+    val inCoreStringKeyedObservations = stringKeyedObservations.collect
+    val inCoreIntKeyedObservations = new SparseMatrix(
+                             stringKeyedObservations.nrow.toInt,
+                             stringKeyedObservations.ncol)
+    for (i <- 0 until inCoreStringKeyedObservations.nrow.toInt) {
+      inCoreIntKeyedObservations(i, ::) = inCoreStringKeyedObservations(i, ::)
+    }
+
+    val intKeyedObservations = drmParallelize(inCoreIntKeyedObservations)
+
+    stringKeyedObservations.uncache()
+
+    var labelIndex = 0
+    val labelIndexMap = new mutable.HashMap[String, Integer]
+    val encodedLabelByRowIndexVector = new DenseVector(labelVectorByRowIndex.size)
+    
+    // Encode Categories as an Integer (Double) so we can broadcast as a vector
+    // where each element is an Int-encoded category whose index corresponds
+    // to its row in the Drm
+    for (i <- 0 until labelVectorByRowIndex.size) {
+      if (!(labelIndexMap.contains(labelVectorByRowIndex(i)._2))) {
+        encodedLabelByRowIndexVector(i) = labelIndex.toDouble
+        labelIndexMap.put(labelVectorByRowIndex(i)._2, labelIndex)
+        labelIndex += 1
+      }
+      // don't like this casting but need to use a java.lang.Integer when setting rowLabelBindings
+      encodedLabelByRowIndexVector(i) = labelIndexMap
+                                          .getOrElse(labelVectorByRowIndex(i)._2, -1)
+                                          .asInstanceOf[Int].toDouble
+    }
+
+    // "Combiner": Map and aggregate by Category. Do this by broadcasting the encoded
+    // category vector and mapping a transposed IntKeyed Drm out so that all categories
+    // will be present on all nodes as columns and can be referenced by
+    // BCastEncodedCategoryByRowVector.  Iteratively sum all categories.
+    val nLabels = labelIndex
+
+    val bcastEncodedCategoryByRowVector = drmBroadcast(encodedLabelByRowIndexVector)
+
+    val aggregatedObservationByLabelDrm = intKeyedObservations.t.mapBlock(ncol = nLabels) {
+      case (keys, blockA) =>
+        val blockB = blockA.like(keys.size, nLabels)
+        var label : Int = 0
+        for (i <- 0 until keys.size) {
+          blockA(i, ::).nonZeroes().foreach { elem =>
+            label = bcastEncodedCategoryByRowVector.get(elem.index).toInt
+            blockB(i, label) = blockB(i, label) + blockA(i, elem.index)
+          }
+        }
+        keys -> blockB
+    }.t
+
+    (labelIndexMap, aggregatedObservationByLabelDrm)
+  }
+
+  /**
+   * Test a trained model with a labeled dataset
+   * @param model a trained NBModel
+   * @param testSet a labeled testing set
+   * @param testComplementary test using a complementary or a standard NB classifier
+   * @param cParser a String => String function used to extract categories from
+   *   Keys of the testing set DRM. The default
+   *   CategoryParser will extract "Category" from: '/Category/document_id'
+   * @tparam K implicitly determined Key type of test set DRM: String
+   * @return a result analyzer with confusion matrix and accuracy statistics
+   */
+  def test[K: ClassTag](model: NBModel,
+                        testSet: DrmLike[K],
+                        testComplementary: Boolean = false,
+                        cParser: CategoryParser = seq2SparseCategoryParser)
+                        (implicit ctx: DistributedContext): ResultAnalyzer = {
+
+    val labelMap = model.labelIndex
+
+    val numLabels = model.numLabels
+
+    testSet.checkpoint()
+
+    val numTestInstances = testSet.nrow.toInt
+
+    // instantiate the correct type of classifier
+    val classifier = testComplementary match {
+      case true => new ComplementaryNBClassifier(model) with Serializable
+      case _ => new StandardNBClassifier(model) with Serializable
+    }
+    
+    if (testComplementary) {
+      assert(testComplementary == model.isComplementary,
+        "Complementary Label Assignment requires Complementary Training")
+    }
+
+    /**  need to change the model around so that we can broadcast it?            */
+    /*   for now just classifying each sequentially.                             */
+    /*
+    val bcastWeightMatrix = drmBroadcast(model.weightsPerLabelAndFeature)
+    val bcastFeatureWeights = drmBroadcast(model.weightsPerFeature)
+    val bcastLabelWeights = drmBroadcast(model.weightsPerLabel)
+    val bcastWeightNormalizers = drmBroadcast(model.perlabelThetaNormalizer)
+    val bcastLabelIndex = labelMap
+    val alphaI = model.alphaI
+    val bcastIsComplementary = model.isComplementary
+
+    val scoredTestSet = testSet.mapBlock(ncol = numLabels){
+      case (keys, block)=>
+        val closureModel = new NBModel(bcastWeightMatrix,
+                                       bcastFeatureWeights,
+                                       bcastLabelWeights,
+                                       bcastWeightNormalizers,
+                                       bcastLabelIndex,
+                                       alphaI,
+                                       bcastIsComplementary)
+        val classifier = closureModel match {
+          case xx if model.isComplementary => new ComplementaryNBClassifier(closureModel)
+          case _ => new StandardNBClassifier(closureModel)
+        }
+        val numInstances = keys.size
+        val blockB= block.like(numInstances, numLabels)
+        for(i <- 0 until numInstances){
+          blockB(i, ::) := classifier.classifyFull(block(i, ::) )
+        }
+        keys -> blockB
+    }
+
+    // may want to strip this down if we think that numDocuments x numLabels won't fit into memory
+    val testSetLabelMap = scoredTestSet.getRowLabelBindings
+
+    // collect so that we can slice rows.
+    val inCoreScoredTestSet = scoredTestSet.collect
+
+    testSet.uncache()
+    */
+
+
+    /** Sequentially: */
+
+    // Since we can't broadcast the model as-is, classify sequentially up front for now
+    val inCoreTestSet = testSet.collect
+
+    // get the labels of the test set and extract the keys
+    val testSetLabelMap = testSet.getRowLabelBindings //.map(x => cParser(x._1) -> x._2)
+
+    // empty Matrix in which we'll set the classification scores
+    val inCoreScoredTestSet = testSet.like(numTestInstances, numLabels)
+
+    testSet.uncache()
+    
+    for (i <- 0 until numTestInstances) {
+      inCoreScoredTestSet(i, ::) := classifier.classifyFull(inCoreTestSet(i, ::))
+    }
+
+    // todo: reverse the labelMaps in training and through the model?
+
+    // reverse the label map and extract the labels
+    val reverseTestSetLabelMap = testSetLabelMap.map(x => x._2 -> cParser(x._1))
+
+    val reverseLabelMap = labelMap.map(x => x._2 -> x._1)
+
+    val analyzer = new ResultAnalyzer(labelMap.keys.toList.sorted, "DEFAULT")
+
+    // need to do this without collecting
+    // val inCoreScoredTestSet = scoredTestSet.collect
+    for (i <- 0 until numTestInstances) {
+      val (bestIdx, bestScore) = argmax(inCoreScoredTestSet(i,::))
+      val classifierResult = new ClassifierResult(reverseLabelMap(bestIdx), bestScore)
+      analyzer.addInstance(reverseTestSetLabelMap(i), classifierResult)
+    }
+
+    analyzer
+  }
+
+  /**
+   * argmax with values as well
+   * returns a tuple of index of the max score and the score itself.
+   * @param v Vector of of scores
+   * @return  (bestIndex, bestScore)
+   */
+  def argmax(v: Vector): (Int, Double) = {
+    var bestIdx: Int = Integer.MIN_VALUE
+    var bestScore: Double = Integer.MIN_VALUE.toDouble
+    for(i <- 0 until v.size) {
+      if(v(i) > bestScore){
+        bestScore = v(i)
+        bestIdx = i
+      }
+    }
+    (bestIdx, bestScore)
+  }
+
+}
+
+object NaiveBayes extends NaiveBayes with java.io.Serializable
+
+/**
+ * Trainer for the weight normalization vector used by Transform Weight Normalized Complement
+ * Naive Bayes.  See: Rennie et al.: Tackling the poor assumptions of Naive Bayes Text classifiers,
+ * ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf Sec. 3.2.
+ *
+ * @param weightsPerFeature a Vector of summed TF or TF-IDF weights for each word in dictionary.
+ * @param weightsPerLabel a Vector of summed TF or TF-IDF weights for each label.
+ * @param alphaI Laplace smoothing factor. Default value of 1.
+ */
+class ComplementaryNBThetaTrainer(private val weightsPerFeature: Vector,
+                                  private val weightsPerLabel: Vector,
+                                  private val alphaI: Double = 1.0) {
+                                   
+   private val perLabelThetaNormalizer: Vector = weightsPerLabel.like()
+   private val totalWeightSum: Double = weightsPerLabel.zSum
+   private var numFeatures: Double = weightsPerFeature.getNumNondefaultElements
+
+   assert(weightsPerFeature != null, "weightsPerFeature vector can not be null")
+   assert(weightsPerLabel != null, "weightsPerLabel vector can not be null")
+
+  /**
+   * Train the weight normalization vector for each label
+   * @param label index of the label to train
+   * @param featurePerLabelWeight Vector of TF or TF-IDF weights per feature for this label
+   */
+  def train(label: Int, featurePerLabelWeight: Vector) {
+    val currentLabelWeight = labelWeight(label)
+    // sum weights for each label including those with zero word counts
+    for (i <- 0 until featurePerLabelWeight.size) {
+      val currentFeaturePerLabelWeight = featurePerLabelWeight(i)
+      updatePerLabelThetaNormalizer(label,
+        ComplementaryNBClassifier.computeWeight(featureWeight(i),
+                                                currentFeaturePerLabelWeight,
+                                                totalWeightSum,
+                                                currentLabelWeight,
+                                                alphaI,
+                                                numFeatures)
+                                   )
+    }
+  }
+
+  /**
+   * getter for summed TF or TF-IDF weights by label
+   * @param label index of label
+   * @return sum of word TF or TF-IDF weights for label
+   */
+  def labelWeight(label: Int): Double = {
+    weightsPerLabel(label)
+  }
+
+  /**
+   * getter for summed TF or TF-IDF weights by word.
+   * @param feature index of word.
+   * @return sum of TF or TF-IDF weights for word.
+   */
+  def featureWeight(feature: Int): Double = {
+    weightsPerFeature(feature)
+  }
+
+  /**
+   * add the magnitude of the current weight to the current
+   * label's corresponding Vector element.
+   * @param label index of label to update.
+   * @param weight weight to add.
+   */
+  def updatePerLabelThetaNormalizer(label: Int, weight: Double) {
+    perLabelThetaNormalizer(label) = perLabelThetaNormalizer(label) + Math.abs(weight)
+  }
+
+  /**
+   * Getter for the weight normalizer vector as indexed by label
+   * @return a copy of the weight normalizer vector.
+   */
+  def retrievePerLabelThetaNormalizer: Vector = {
+    perLabelThetaNormalizer.cloned
+  }
+
+}
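
The complementary theta normalizer is trained locally over the collected label-by-feature
matrix; a condensed restatement of what train() does above. The helper name and the
in-core matrix argument are illustrative assumptions.

    import org.apache.mahout.classifier.naivebayes.ComplementaryNBThetaTrainer
    import org.apache.mahout.math.{Matrix, Vector}
    import org.apache.mahout.math.scalabindings.RLikeOps._

    // Hypothetical helper -- not part of this commit.
    def complementaryThetaNormalizer(inCoreTFIDF: Matrix): Vector = {
      // column sums are per-feature weights, row sums are per-label weights
      val trainer = new ComplementaryNBThetaTrainer(
        inCoreTFIDF.colSums, inCoreTFIDF.rowSums, alphaI = 1.0)
      for (label <- 0 until inCoreTFIDF.nrow)
        trainer.train(label, inCoreTFIDF(label, ::))
      trainer.retrievePerLabelThetaNormalizer
    }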

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala b/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala
new file mode 100644
index 0000000..8f1413a
--- /dev/null
+++ b/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala
@@ -0,0 +1,467 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements.  See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package org.apache.mahout.classifier.stats
+
+import java.text.{DecimalFormat, NumberFormat}
+import java.util
+import org.apache.mahout.math.stats.OnlineSummarizer
+
+
+/**
+ * Result of a document classification. The label and the associated score (usually a probability)
+ */
+class ClassifierResult (private var label: String = null,
+                        private var score: Double = 0.0,
+                        private var logLikelihood: Double = Integer.MAX_VALUE.toDouble) {
+
+  def getLogLikelihood: Double = logLikelihood
+
+  def setLogLikelihood(llh: Double) {
+    logLikelihood = llh
+  }
+
+  def getLabel: String = label
+
+  def getScore: Double = score
+
+  def setLabel(lbl: String) {
+    label = lbl
+  }
+
+  def setScore(sc: Double) {
+    score = sc
+  }
+
+  override def toString: String = {
+     "ClassifierResult{" + "category='" + label + '\'' + ", score=" + score + '}'
+  }
+
+}
+
+/**
+ * ResultAnalyzer captures the classification statistics and displays them in a tabular manner
+ * @param labelSet Set of labels to be considered in classification
+ * @param defaultLabel  the default label for an unknown classification
+ */
+class ResultAnalyzer(private val labelSet: util.Collection[String], defaultLabel: String) {
+
+  val confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel)
+  val summarizer = new OnlineSummarizer
+
+  private var hasLL: Boolean = false
+  private var correctlyClassified: Int = 0
+  private var incorrectlyClassified: Int = 0
+
+
+  def getConfusionMatrix: ConfusionMatrix = confusionMatrix
+
+  /**
+   *
+   * @param correctLabel
+   * The correct label
+   * @param classifiedResult
+   * The classified result
+   * @return whether the instance was correct or not
+   */
+  def addInstance(correctLabel: String, classifiedResult: ClassifierResult): Boolean = {
+    val result: Boolean = correctLabel == classifiedResult.getLabel
+    if (result) {
+      correctlyClassified += 1
+    }
+    else {
+      incorrectlyClassified += 1
+    }
+    confusionMatrix.addInstance(correctLabel, classifiedResult)
+    if (classifiedResult.getLogLikelihood != Integer.MAX_VALUE.toDouble) {
+      summarizer.add(classifiedResult.getLogLikelihood)
+      hasLL = true
+    }
+
+    result
+  }
+
+  /** Dump the resulting statistics to a string */
+  override def toString: String = {
+    val returnString: StringBuilder = new StringBuilder
+    returnString.append('\n')
+    returnString.append("=======================================================\n")
+    returnString.append("Summary\n")
+    returnString.append("-------------------------------------------------------\n")
+    val totalClassified: Int = correctlyClassified + incorrectlyClassified
+    val percentageCorrect: Double = 100.0 * correctlyClassified / totalClassified
+    val percentageIncorrect: Double = 100.0 * incorrectlyClassified / totalClassified
+    val decimalFormatter: NumberFormat = new DecimalFormat("0.####")
+    returnString.append("Correctly Classified Instances")
+                .append(": ")
+                .append(Integer.toString(correctlyClassified))
+                .append('\t')
+                .append(decimalFormatter.format(percentageCorrect))
+                .append("%\n")
+    returnString.append("Incorrectly Classified Instances")
+                .append(": ")
+                .append(Integer.toString(incorrectlyClassified))
+                .append('\t')
+                .append(decimalFormatter.format(percentageIncorrect))
+                .append("%\n")
+    returnString.append("Total Classified Instances")
+                .append(": ")
+                .append(Integer.toString(totalClassified))
+                .append('\n')
+    returnString.append('\n')
+    returnString.append(confusionMatrix)
+    returnString.append("=======================================================\n")
+    returnString.append("Statistics\n")
+    returnString.append("-------------------------------------------------------\n")
+    val normStats: RunningAverageAndStdDev = confusionMatrix.getNormalizedStats
+    returnString.append("Kappa: \t")
+                .append(decimalFormatter.format(confusionMatrix.getKappa))
+                .append('\n')
+    returnString.append("Accuracy: \t")
+                .append(decimalFormatter.format(confusionMatrix.getAccuracy))
+                .append("%\n")
+    returnString.append("Reliability: \t")
+                .append(decimalFormatter.format(normStats.getAverage * 100.00000001))
+                .append("%\n")
+    returnString.append("Reliability (std dev): \t")
+                .append(decimalFormatter.format(normStats.getStandardDeviation))
+                .append('\n')
+    returnString.append("Weighted precision: \t")
+                .append(decimalFormatter.format(confusionMatrix.getWeightedPrecision))
+                .append('\n')
+    returnString.append("Weighted recall: \t")
+                .append(decimalFormatter.format(confusionMatrix.getWeightedRecall))
+                .append('\n')
+    returnString.append("Weighted F1 score: \t")
+                .append(decimalFormatter.format(confusionMatrix.getWeightedF1score))
+                .append('\n')
+    if (hasLL) {
+      returnString.append("Log-likelihood: \t")
+                  .append("mean      :  \t")
+                  .append(decimalFormatter.format(summarizer.getMean))
+                  .append('\n')
+      returnString.append("25%-ile   :  \t")
+                  .append(decimalFormatter.format(summarizer.getQuartile(1)))
+                  .append('\n')
+      returnString.append("75%-ile   :  \t")
+                  .append(decimalFormatter.format(summarizer.getQuartile(3)))
+                  .append('\n')
+    }
+
+    returnString.toString()
+  }
+
+
+}
+
+/**
+ *
+ * Interface for classes that can keep track of a running average of a series of numbers. One can add to or
+ * remove from the series, as well as update a datum in the series. The class does not actually keep track of
+ * the series of values, just its running average, so it doesn't even matter if you remove/change a value that
+ * wasn't added.
+ *
+ * Ported from org.apache.mahout.cf.taste.impl.common.RunningAverage.java
+ */
+trait RunningAverage {
+
+  /**
+   * @param datum
+   * new item to add to the running average
+   * @throws IllegalArgumentException
+   * if datum is { @link Double#NaN}
+   */
+  def addDatum(datum: Double)
+
+  /**
+   * @param datum
+   * item to remove from the running average
+   * @throws IllegalArgumentException
+   * if datum is { @link Double#NaN}
+   * @throws IllegalStateException
+   * if count is 0
+   */
+  def removeDatum(datum: Double)
+
+  /**
+   * @param delta
+   * amount by which to change a datum in the running average
+   * @throws IllegalArgumentException
+   * if delta is { @link Double#NaN}
+   * @throws IllegalStateException
+   * if count is 0
+   */
+  def changeDatum(delta: Double)
+
+  def getCount: Int
+
+  def getAverage: Double
+
+  /**
+   * @return a (possibly immutable) object whose average is the negative of this object's
+   */
+  def inverse: RunningAverage
+}
+
+/**
+ *
+ * Extends {@link RunningAverage} by adding standard deviation too.
+ *
+ * Ported from org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev.java
+ */
+trait RunningAverageAndStdDev extends RunningAverage {
+
+  /** @return standard deviation of data */
+  def getStandardDeviation: Double
+
+  /**
+   * @return a (possibly immutable) object whose average is the negative of this object's
+   */
+  def inverse: RunningAverageAndStdDev
+}
+
+
+class InvertedRunningAverage(private val delegate: RunningAverage) extends RunningAverage {
+
+  override def addDatum(datum: Double) {
+    throw new UnsupportedOperationException
+  }
+
+  override def removeDatum(datum: Double) {
+    throw new UnsupportedOperationException
+  }
+
+  override def changeDatum(delta: Double) {
+    throw new UnsupportedOperationException
+  }
+
+  override def getCount: Int = {
+     delegate.getCount
+  }
+
+  override def getAverage: Double = {
+     -delegate.getAverage
+  }
+
+  override def inverse: RunningAverage = {
+     delegate
+  }
+}
+
+
+/**
+ *
+ * A simple class that can keep track of a running average of a series of numbers. One can add to or remove
+ * from the series, as well as update a datum in the series. The class does not actually keep track of the
+ * series of values, just its running average, so it doesn't even matter if you remove/change a value that
+ * wasn't added.
+ *
+ * Ported from org.apache.mahout.cf.taste.impl.common.FullRunningAverage.java
+ */
+class FullRunningAverage(private var count: Int = 0,
+                         private var average: Double = Double.NaN ) extends RunningAverage {
+
+  /**
+   * @param datum
+   * new item to add to the running average
+   */
+  override def addDatum(datum: Double) {
+    count += 1
+    if (count == 1) {
+      average = datum
+    }
+    else {
+      average = average * (count - 1) / count + datum / count
+    }
+  }
+
+  /**
+   * @param datum
+   * item to remove from the running average
+   * @throws IllegalStateException
+   * if count is 0
+   */
+  override def removeDatum(datum: Double) {
+    if (count == 0) {
+      throw new IllegalStateException
+    }
+    count -= 1
+    if (count == 0) {
+      average = Double.NaN
+    }
+    else {
+      average = average * (count + 1) / count - datum / count
+    }
+  }
+
+  /**
+   * @param delta
+   * amount by which to change a datum in the running average
+   * @throws IllegalStateException
+   * if count is 0
+   */
+  override def changeDatum(delta: Double) {
+    if (count == 0) {
+      throw new IllegalStateException
+    }
+    average += delta / count
+  }
+
+  override def getCount: Int = {
+    count
+  }
+
+  override def getAverage: Double = {
+    average
+  }
+
+  override def inverse: RunningAverage = {
+    new InvertedRunningAverage(this)
+  }
+
+  override def toString: String = {
+    String.valueOf(average)
+  }
+}
+
+
+/**
+ *
+ * Extends {@link FullRunningAverage} to add a running standard deviation computation.
+ * Uses Welford's method, as described at http://www.johndcook.com/standard_deviation.html
+ *
+ * Ported from org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev.java
+ */
+class FullRunningAverageAndStdDev(private var count: Int = 0,
+                                  private var average: Double = 0.0,
+                                  private var mk: Double = 0.0,
+                                  private var sk: Double = 0.0) extends FullRunningAverage with RunningAverageAndStdDev {
+
+  var stdDev: Double = 0.0
+
+  recomputeStdDev
+
+  def getMk: Double = {
+     mk
+  }
+
+  def getSk: Double = {
+    sk
+  }
+
+  override def getStandardDeviation: Double = {
+    stdDev
+  }
+
+  override def addDatum(datum: Double) {
+    super.addDatum(datum)
+    val count: Int = getCount
+    if (count == 1) {
+      mk = datum
+      sk = 0.0
+    }
+    else {
+      val oldmk: Double = mk
+      val diff: Double = datum - oldmk
+      mk += diff / count
+      sk += diff * (datum - mk)
+    }
+    recomputeStdDev
+  }
+
+  override def removeDatum(datum: Double) {
+    val oldCount: Int = getCount
+    super.removeDatum(datum)
+    val oldmk: Double = mk
+    mk = (oldCount * oldmk - datum) / (oldCount - 1)
+    sk -= (datum - mk) * (datum - oldmk)
+    recomputeStdDev
+  }
+
+  /**
+   * @throws UnsupportedOperationException
+   */
+  override def changeDatum(delta: Double) {
+    throw new UnsupportedOperationException
+  }
+
+  private def recomputeStdDev {
+    val count: Int = getCount
+    stdDev = if (count > 1) Math.sqrt(sk / (count - 1)) else Double.NaN
+  }
+
+  override def inverse: RunningAverageAndStdDev = {
+     new InvertedRunningAverageAndStdDev(this)
+  }
+
+  override def toString: String = {
+     String.valueOf(String.valueOf(getAverage) + ',' + stdDev)
+  }
+
+}
+
+
+/**
+ *
+ * @param delegate RunningAverageAndStdDev instance
+ *
+ * Ported from org.apache.mahout.cf.taste.impl.common.InvertedRunningAverageAndStdDev.java
+ */
+class InvertedRunningAverageAndStdDev(private val delegate: RunningAverageAndStdDev) extends RunningAverageAndStdDev {
+
+  /**
+   * @throws UnsupportedOperationException
+   */
+  override def addDatum(datum: Double) {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * @throws UnsupportedOperationException
+   */
+  override def removeDatum(datum: Double) {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * @throws UnsupportedOperationException
+   */
+  override def changeDatum(delta: Double) {
+    throw new UnsupportedOperationException
+  }
+
+  override def getCount: Int = {
+     delegate.getCount
+  }
+
+  override def getAverage: Double = {
+     -delegate.getAverage
+  }
+
+  override def getStandardDeviation: Double = {
+     delegate.getStandardDeviation
+  }
+
+  override def inverse: RunningAverageAndStdDev = {
+     delegate
+  }
+}
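+
+// A small sketch (illustrative, not part of the port): inverse() returns a read-only view
+// that negates the average but leaves count and standard deviation unchanged; the mutators
+// of the inverted view throw UnsupportedOperationException.
+//
+//   val ras = new FullRunningAverageAndStdDev()
+//   ras.addDatum(1.0); ras.addDatum(3.0)
+//   val inv = ras.inverse       // inv.getAverage == -2.0
+//   inv.getStandardDeviation    // same as ras.getStandardDeviation
+//   inv.inverse eq ras          // true: inverting twice yields the original delegate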
+
+
+
+

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala b/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala
new file mode 100644
index 0000000..328d27b
--- /dev/null
+++ b/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala
@@ -0,0 +1,460 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements.  See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package org.apache.mahout.classifier.stats
+
+import java.util
+import org.apache.commons.math3.stat.descriptive.moment.Mean // This is brought in by mahout-math
+import org.apache.mahout.math.{DenseMatrix, Matrix}
+import scala.collection.mutable
+import scala.collection.JavaConversions._
+
+/**
+ *
+ * Ported from org.apache.mahout.classifier.ConfusionMatrix.java
+ *
+ * The ConfusionMatrix class stores the result of classifying a test dataset.
+ *
+ * Whether a default label was assigned is not stored explicitly; a row of zeros is the only
+ * indicator that no default label was used.
+ *
+ * See http://en.wikipedia.org/wiki/Confusion_matrix for background
+ *
+ *
+ * @param labels The labels to consider for classification
+ * @param defaultLabel default unknown label
+ */
+class ConfusionMatrix(private var labels: util.Collection[String] = null,
+                      private var defaultLabel: String = "unknown")  {
+  /**
+   * Matrix Constructor
+   * @param m a DenseMatrix with RowLabelBindings
+   */
+//   def this(m: Matrix) {
+//     this()
+//     confusionMatrix = Array.ofDim[Int](m.numRows, m.numRows)
+//     setMatrix(m)
+//   }
+
+   // val LOG: Logger = LoggerFactory.getLogger(classOf[ConfusionMatrix])
+
+  var confusionMatrix = Array.ofDim[Int](labels.size + 1, labels.size + 1)
+
+  val labelMap = new mutable.HashMap[String,Integer]()
+
+  var samples: Int = 0
+
+  var i: Integer = 0
+  for (label <- labels) {
+    labelMap.put(label, i)
+    i += 1
+  }
+  labelMap.put(defaultLabel, i)
+
+
+  def getConfusionMatrix: Array[Array[Int]] = confusionMatrix
+
+  def getLabels = labelMap.keys.toList
+
+  def numLabels: Int = labelMap.size
+
+  def getAccuracy(label: String): Double = {
+    val labelId: Int = labelMap(label)
+    var labelTotal: Int = 0
+    var correct: Int = 0
+    for (i <- 0 until numLabels) {
+      labelTotal += confusionMatrix(labelId)(i)
+      if (i == labelId) {
+        correct += confusionMatrix(labelId)(i)
+      }
+    }
+
+    100.0 * correct / labelTotal
+  }
+
+  def getAccuracy: Double = {
+    var total: Int = 0
+    var correct: Int = 0
+    for (i <- 0 until numLabels) {
+      for (j <- 0 until numLabels) {
+        total += confusionMatrix(i)(j)
+        if (i == j) {
+          correct += confusionMatrix(i)(j)
+        }
+      }
+    }
+
+    100.0 * correct / total
+  }
+
+  /** Sum of true positives and false negatives */
+  private def getActualNumberOfTestExamplesForClass(label: String): Int = {
+    val labelId: Int = labelMap(label)
+    var sum: Int = 0
+    for (i <- 0 until numLabels) {
+      sum += confusionMatrix(labelId)(i)
+    }
+    sum
+  }
+
+  def getPrecision(label: String): Double = {
+    val labelId: Int = labelMap(label)
+    val truePositives: Int = confusionMatrix(labelId)(labelId)
+    var falsePositives: Int = 0
+
+    for (i <- 0 until numLabels) {
+      if (i != labelId) {
+        falsePositives += confusionMatrix(i)(labelId)
+      }
+    }
+
+    if (truePositives + falsePositives == 0) {
+      0
+    } else {
+      (truePositives.asInstanceOf[Double]) / (truePositives + falsePositives)
+    }
+  }
+
+
+  def getWeightedPrecision: Double = {
+    val precisions: Array[Double] = new Array[Double](numLabels)
+    val weights: Array[Double] = new Array[Double](numLabels)
+    var index: Int = 0
+    for (label <- labelMap.keys) {
+      precisions(index) = getPrecision(label)
+      weights(index) = getActualNumberOfTestExamplesForClass(label)
+      index += 1
+    }
+    new Mean().evaluate(precisions, weights)
+  }
+
+  def getRecall(label: String): Double = {
+    val labelId: Int = labelMap(label)
+    val truePositives: Int = confusionMatrix(labelId)(labelId)
+    var falseNegatives: Int = 0
+    for (i <- 0 until numLabels) {
+      if (i != labelId) {
+        falseNegatives += confusionMatrix(labelId)(i)
+      }
+    }
+
+    if (truePositives + falseNegatives == 0) {
+      0
+    } else {
+      (truePositives.asInstanceOf[Double]) / (truePositives + falseNegatives)
+    }
+  }
+
+  def getWeightedRecall: Double = {
+    val recalls: Array[Double] = new Array[Double](numLabels)
+    val weights: Array[Double] = new Array[Double](numLabels)
+    var index: Int = 0
+    for (label <- labelMap.keys) {
+      recalls(index) = getRecall(label)
+      weights(index) = getActualNumberOfTestExamplesForClass(label)
+      index += 1
+    }
+    new Mean().evaluate(recalls, weights)
+  }
+
+  def getF1score(label: String): Double = {
+    val precision: Double = getPrecision(label)
+    val recall: Double = getRecall(label)
+    if (precision + recall == 0) {
+      0
+    } else {
+      2 * precision * recall / (precision + recall)
+    }
+  }
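+
+  // Worked example (hypothetical counts, not from the test data): with a 2-label matrix
+  // [[8, 2], [4, 6]] (rows = actual, columns = classified), the first label has
+  // TP = 8, FP = 4 (its column minus the diagonal) and FN = 2 (its row minus the
+  // diagonal), so precision = 8/12 ~ 0.667, recall = 8/10 = 0.8, and
+  // F1 = 2 * 0.667 * 0.8 / (0.667 + 0.8) ~ 0.727.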
+
+  def getWeightedF1score: Double = {
+    val f1Scores: Array[Double] = new Array[Double](numLabels)
+    val weights: Array[Double] = new Array[Double](numLabels)
+    var index: Int = 0
+    for (label <- labelMap.keys) {
+      f1Scores(index) = getF1score(label)
+      weights(index) = getActualNumberOfTestExamplesForClass(label)
+      index += 1
+    }
+    new Mean().evaluate(f1Scores, weights)
+  }
+
+  def getReliability: Double = {
+    var count: Int = 0
+    var accuracy: Double = 0
+    for (label <- labelMap.keys) {
+      if (!(label == defaultLabel)) {
+        accuracy += getAccuracy(label)
+      }
+      count += 1
+    }
+    accuracy / count
+  }
+
+  /**
+   * Accuracy v.s. randomly classifying all samples.
+   * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
+   * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
+   * Educational And Psychological Measurement 20:37-46.
+   *
+   * Formula and variable names from:
+   * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
+   *
+   * @return double
+   */
+  def getKappa: Double = {
+    var a: Double = 0.0
+    var b: Double = 0.0
+    for (i <- 0 until confusionMatrix.length) {
+      a += confusionMatrix(i)(i)
+      var br: Int = 0
+      for (j <- 0 until confusionMatrix.length) {
+        br += confusionMatrix(i)(j)
+      }
+      var bc: Int = 0
+      //TODO: verify this as an iterator
+      for (vec <- confusionMatrix) {
+        bc += vec(i)
+      }
+      b += br * bc
+    }
+    (samples * a - b) / (samples * samples - b)
+  }
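+
+  // Worked example (hypothetical counts): for the same matrix [[8, 2], [4, 6]] with
+  // samples = 20, a = 8 + 6 = 14 (diagonal sum) and b = 10*12 + 10*8 = 200 (sum over
+  // rows of rowTotal * columnTotal), giving
+  // kappa = (20 * 14 - 200) / (20 * 20 - 200) = 80 / 200 = 0.4.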
+
+  def getCorrect(label: String): Int = {
+    val labelId: Int = labelMap(label)
+    confusionMatrix(labelId)(labelId)
+  }
+
+  def getTotal(label: String): Int = {
+    val labelId: Int = labelMap(label)
+    var labelTotal: Int = 0
+    for (i <- 0 until numLabels) {
+      labelTotal += confusionMatrix(labelId)(i)
+    }
+    labelTotal
+  }
+
+  /**
+   * Standard deviation of normalized producer accuracy
+   * Not a standard score
+   * @return double
+   */
+  def getNormalizedStats: RunningAverageAndStdDev = {
+    val summer = new FullRunningAverageAndStdDev()
+    for (d <- 0 until confusionMatrix.length) {
+      var total: Double = 0.0
+      for (j <- 0 until confusionMatrix.length) {
+        total += confusionMatrix(d)(j)
+      }
+      summer.addDatum(confusionMatrix(d)(d) / (total + 0.000001))
+    }
+    summer
+  }
+
+  def addInstance(correctLabel: String, classifiedResult: ClassifierResult): Unit = {
+    samples += 1
+    incrementCount(correctLabel, classifiedResult.getLabel)
+  }
+
+  def addInstance(correctLabel: String, classifiedLabel: String): Unit = {
+    samples += 1
+    incrementCount(correctLabel, classifiedLabel)
+  }
+
+  def getCount(correctLabel: String, classifiedLabel: String): Int = {
+    if (!labelMap.containsKey(correctLabel)) {
+    //  LOG.warn("Label {} did not appear in the training examples", correctLabel)
+      return 0
+    }
+    assert(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel)
+    val correctId: Int = labelMap(correctLabel)
+    val classifiedId: Int = labelMap(classifiedLabel)
+    confusionMatrix(correctId)(classifiedId)
+  }
+
+  def putCount(correctLabel: String, classifiedLabel: String, count: Int): Unit = {
+    if (!labelMap.containsKey(correctLabel)) {
+    //  LOG.warn("Label {} did not appear in the training examples", correctLabel)
+      return
+    }
+    assert(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel)
+    val correctId: Int = labelMap(correctLabel)
+    val classifiedId: Int = labelMap(classifiedLabel)
+    if (confusionMatrix(correctId)(classifiedId) == 0.0 && count != 0) {
+      samples += 1
+    }
+    confusionMatrix(correctId)(classifiedId) = count
+  }
+
+  def incrementCount(correctLabel: String, classifiedLabel: String, count: Int): Unit = {
+    putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel))
+  }
+
+  def incrementCount(correctLabel: String, classifiedLabel: String): Unit = {
+    incrementCount(correctLabel, classifiedLabel, 1)
+  }
+
+  def getDefaultLabel: String = {
+    defaultLabel
+  }
+
+  def merge(b: ConfusionMatrix): ConfusionMatrix = {
+    assert(labelMap.size == b.getLabels.size, "The label sizes do not match")
+    for (correctLabel <- this.labelMap.keys) {
+      for (classifiedLabel <- this.labelMap.keys) {
+        incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel))
+      }
+    }
+    this
+  }
+
+  def getMatrix: Matrix = {
+    val length: Int = confusionMatrix.length
+    val m: Matrix = new DenseMatrix(length, length)
+
+    val labels: java.util.HashMap[String, Integer] = new java.util.HashMap()
+
+    for (r <- 0 until length) {
+      for (c <- 0 until length) {
+        m.set(r, c, confusionMatrix(r)(c))
+      }
+    }
+
+    for (entry <- labelMap.entrySet) {
+      labels.put(entry.getKey, entry.getValue)
+    }
+    m.setRowLabelBindings(labels)
+    m.setColumnLabelBindings(labels)
+
+    m
+  }
+
+  def setMatrix(m: Matrix) : Unit = {
+    val length: Int = confusionMatrix.length
+    if (m.numRows != m.numCols) {
+      throw new IllegalArgumentException("ConfusionMatrix: matrix(" + m.numRows + ',' + m.numCols + ") must be square")
+    }
+
+    for (r <- 0 until length) {
+      for (c <- 0 until length) {
+        confusionMatrix(r)(c) = Math.round(m.get(r, c)).toInt
+      }
+    }
+
+    var labels = m.getRowLabelBindings
+    if (labels == null) {
+      labels = m.getColumnLabelBindings
+    }
+
+    if (labels != null) {
+      val sorted: Array[String] = sortLabels(labels)
+      verifyLabels(length, sorted)
+      labelMap.clear
+      for (i <- 0 until length) {
+        labelMap.put(sorted(i), i)
+      }
+    }
+  }
+
+  def verifyLabels(length: Int, sorted: Array[String]): Unit = {
+    assert(sorted.length == length, "One label, one row")
+    for (i <- 0 until length) {
+      if (sorted(i) == null) {
+        assert(false, "One label, one row")
+      }
+    }
+  }
+
+  def sortLabels(labels: java.util.Map[String, Integer]): Array[String] = {
+    val sorted: Array[String] = new Array[String](labels.size)
+    for (entry <- labels.entrySet) {
+      sorted(entry.getValue) = entry.getKey
+    }
+
+    sorted
+  }
+
+  /**
+   * Note: toString() is overridden here to dump the raw matrix; it is not a formatted report you print for a manager :)
+   * Assume that if there are no default assignments, the default label was not used.
+   */
+  override def toString: String = {
+
+    val returnString: StringBuilder = new StringBuilder(200)
+
+    returnString.append("=======================================================").append('\n')
+    returnString.append("Confusion Matrix\n")
+    returnString.append("-------------------------------------------------------").append('\n')
+
+    val unclassified: Int = getTotal(defaultLabel)
+
+    for (entry <- this.labelMap.entrySet) {
+      if (!((entry.getKey == defaultLabel) && unclassified == 0)) {
+        returnString.append(getSmallLabel(entry.getValue) + "     ").append('\t')
+      }
+    }
+
+    returnString.append("<--Classified as").append('\n')
+
+    for (entry <- this.labelMap.entrySet) {
+      if (!((entry.getKey == defaultLabel) && unclassified == 0)) {
+        val correctLabel: String = entry.getKey
+        var labelTotal: Int = 0
+
+        for (classifiedLabel <- this.labelMap.keySet) {
+          if (!((classifiedLabel == defaultLabel) && unclassified == 0)) {
+            returnString.append(Integer.toString(getCount(correctLabel, classifiedLabel)) + "     ")
+                        .append('\t')
+            labelTotal += getCount(correctLabel, classifiedLabel)
+          }
+        }
+        returnString.append(" |  ").append(String.valueOf(labelTotal) + "      ")
+                    .append('\t')
+                    .append(getSmallLabel(entry.getValue) + "     ")
+                    .append(" = ")
+                    .append(correctLabel)
+                    .append('\n')
+      }
+    }
+
+    if (unclassified > 0) {
+      returnString.append("Default Category: ")
+                  .append(defaultLabel)
+                  .append(": ")
+                  .append(unclassified)
+                  .append('\n')
+    }
+    returnString.append('\n')
+
+    returnString.toString()
+  }
+
+
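+  /** Compact letter label for column i, e.g. 0 -> "a", 1 -> "b", 25 -> "z", 26 -> "ba" (0-based base-26 digits). */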
+  def getSmallLabel(i: Int): String = {
+    var value: Int = i
+    val returnString: StringBuilder = new StringBuilder
+    do {
+      val n: Int = value % 26
+      returnString.insert(0, ('a' + n).asInstanceOf[Char])
+      value /= 26
+    } while (value > 0)
+
+    returnString.toString()
+  }
+
+
+}
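+
+// A minimal usage sketch (the labels and counts below are hypothetical, not from the port):
+//
+//   import scala.collection.JavaConversions._
+//   val cm = new ConfusionMatrix(List("spam", "ham"), "unknown")
+//   cm.addInstance("spam", "spam")
+//   cm.addInstance("spam", "ham")
+//   cm.addInstance("ham", "ham")
+//   cm.getAccuracy("spam")   // 50.0: one of two actual "spam" instances was correct
+//   cm.getPrecision("ham")   // 0.5: one true positive, one false positive
+//   println(cm)              // multi-line matrix dump via toString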

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/math-scala/src/test/scala/org/apache/mahout/classifier/naivebayes/NBTestBase.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/test/scala/org/apache/mahout/classifier/naivebayes/NBTestBase.scala b/math-scala/src/test/scala/org/apache/mahout/classifier/naivebayes/NBTestBase.scala
new file mode 100644
index 0000000..67b1c08
--- /dev/null
+++ b/math-scala/src/test/scala/org/apache/mahout/classifier/naivebayes/NBTestBase.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes
+
+import org.apache.mahout.math._
+import org.apache.mahout.math.scalabindings._
+import org.apache.mahout.test.DistributedMahoutSuite
+import org.apache.mahout.test.MahoutSuite
+import org.scalatest.{FunSuite, Matchers}
+import scala.collection._
+import scala.collection.JavaConversions._
+
+trait NBTestBase extends DistributedMahoutSuite with Matchers { this:FunSuite =>
+
+  val epsilon = 1E-6
+
+  test("Simple Standard NB Model") {
+
+    // test from simulated sparse TF-IDF data
+    val inCoreTFIDF = sparse(
+      (0, 0.7) ::(1, 0.1) ::(2, 0.1) ::(3, 0.3) :: Nil,
+      (0, 0.4) ::(1, 0.4) ::(2, 0.1) ::(3, 0.1) :: Nil,
+      (0, 0.1) ::(1, 0.0) ::(2, 0.8) ::(3, 0.1) :: Nil,
+      (0, 0.1) ::(1, 0.1) ::(2, 0.1) ::(3, 0.7) :: Nil
+    )
+
+    val TFIDFDrm = drm.drmParallelize(m = inCoreTFIDF, numPartitions = 2)
+
+    val labelIndex = new java.util.HashMap[String,Integer]()
+    labelIndex.put("Cat1", 3)
+    labelIndex.put("Cat2", 2)
+    labelIndex.put("Cat3", 1)
+    labelIndex.put("Cat4", 0)
+
+    // train a Standard NB Model
+    val model = NaiveBayes.train(TFIDFDrm, labelIndex, false)
+
+    // validate the model - will throw an exception if the model is invalid
+    model.validate()
+
+    // check the labelWeights (the row sums of the class-aggregated TF-IDF matrix)
+    model.labelWeight(0) - 1.2 should be < epsilon
+    model.labelWeight(1) - 1.0 should be < epsilon
+    model.labelWeight(2) - 1.0 should be < epsilon
+    model.labelWeight(3) - 1.0 should be < epsilon
+
+    // check the feature weights (the column sums of the class-aggregated TF-IDF matrix)
+    model.featureWeight(0) - 1.3 should be < epsilon
+    model.featureWeight(1) - 0.6 should be < epsilon
+    model.featureWeight(2) - 1.1 should be < epsilon
+    model.featureWeight(3) - 1.2 should be < epsilon
+  }
+
+  test("NB Aggregator") {
+
+    val rowBindings = new java.util.HashMap[String,Integer]()
+    rowBindings.put("/Cat1/doc_a/", 0)
+    rowBindings.put("/Cat2/doc_b/", 1)
+    rowBindings.put("/Cat1/doc_c/", 2)
+    rowBindings.put("/Cat2/doc_d/", 3)
+    rowBindings.put("/Cat1/doc_e/", 4)
+
+
+    val matrixSetup = sparse(
+      (0, 0.1) ::(1, 0.0) ::(2, 0.1) ::(3, 0.0) :: Nil,
+      (0, 0.0) ::(1, 0.1) ::(2, 0.0) ::(3, 0.1) :: Nil,
+      (0, 0.1) ::(1, 0.0) ::(2, 0.1) ::(3, 0.0) :: Nil,
+      (0, 0.0) ::(1, 0.1) ::(2, 0.0) ::(3, 0.1) :: Nil,
+      (0, 0.1) ::(1, 0.0) ::(2, 0.1) ::(3, 0.0) :: Nil
+    )
+
+
+    matrixSetup.setRowLabelBindings(rowBindings)
+
+    val TFIDFDrm = drm.drmParallelizeWithRowLabels(m = matrixSetup, numPartitions = 2)
+
+    val (labelIndex, aggregatedTFIDFDrm) = NaiveBayes.extractLabelsAndAggregateObservations(TFIDFDrm)
+
+    labelIndex.size should be (2)
+
+    val cat1=labelIndex("Cat1")
+    val cat2=labelIndex("Cat2")
+
+    cat1 should be (0)
+    cat2 should be (1)
+
+    val aggregatedTFIDFInCore = aggregatedTFIDFDrm.collect
+    aggregatedTFIDFInCore.numCols should be (4)
+    aggregatedTFIDFInCore.numRows should be (2)
+
+    aggregatedTFIDFInCore.get(cat1, 0) - 0.3 should be < epsilon
+    aggregatedTFIDFInCore.get(cat1, 1) - 0.0 should be < epsilon
+    aggregatedTFIDFInCore.get(cat1, 2) - 0.3 should be < epsilon
+    aggregatedTFIDFInCore.get(cat1, 3) - 0.0 should be < epsilon
+    aggregatedTFIDFInCore.get(cat2, 0) - 0.0 should be < epsilon
+    aggregatedTFIDFInCore.get(cat2, 1) - 0.2 should be < epsilon
+    aggregatedTFIDFInCore.get(cat2, 2) - 0.0 should be < epsilon
+    aggregatedTFIDFInCore.get(cat2, 3) - 0.2 should be < epsilon
+
+  }
+
+  test("Model DFS Serialization") {
+
+    // test from simulated sparse TF-IDF data
+    val inCoreTFIDF = sparse(
+      (0, 0.7) ::(1, 0.1) ::(2, 0.1) ::(3, 0.3) :: Nil,
+      (0, 0.4) ::(1, 0.4) ::(2, 0.1) ::(3, 0.1) :: Nil,
+      (0, 0.1) ::(1, 0.0) ::(2, 0.8) ::(3, 0.1) :: Nil,
+      (0, 0.1) ::(1, 0.1) ::(2, 0.1) ::(3, 0.7) :: Nil
+    )
+
+    val labelIndex = new java.util.HashMap[String,Integer]()
+    labelIndex.put("Cat1", 0)
+    labelIndex.put("Cat2", 1)
+    labelIndex.put("Cat3", 2)
+    labelIndex.put("Cat4", 3)
+
+    val TFIDFDrm = drm.drmParallelize(m = inCoreTFIDF, numPartitions = 2)
+
+    // train a Standard NB Model using the label index constructed above
+    val model = NaiveBayes.train(TFIDFDrm, labelIndex, false)
+
+    // validate the model - will throw an exception if the model is invalid
+    model.validate()
+
+    // save the model
+    model.dfsWrite(TmpDir)
+
+    // reload a new model which should be equal to the original
+    // this will automatically trigger a validate() call
+    val materializedModel= NBModel.dfsRead(TmpDir)
+
+
+    // check the labelWeights
+    model.labelWeight(0) - materializedModel.labelWeight(0) should be < epsilon //1.2
+    model.labelWeight(1) - materializedModel.labelWeight(1) should be < epsilon //1.0
+    model.labelWeight(2) - materializedModel.labelWeight(2) should be < epsilon //1.0
+    model.labelWeight(3) - materializedModel.labelWeight(3) should be < epsilon //1.0
+
+    // check the Feature weights
+    model.featureWeight(0) - materializedModel.featureWeight(0) should be < epsilon //1.3
+    model.featureWeight(1) - materializedModel.featureWeight(1) should be < epsilon //0.6
+    model.featureWeight(2) - materializedModel.featureWeight(2) should be < epsilon //1.1
+    model.featureWeight(3) - materializedModel.featureWeight(3) should be < epsilon //1.2
+
+    // check to see if the new model is complementary
+    materializedModel.isComplementary should be (model.isComplementary)
+
+    // check the label indexMaps
+    for (elem <- model.labelIndex) {
+      model.labelIndex(elem._1) == materializedModel.labelIndex(elem._1) should be (true)
+    }
+  }
+
+}