Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/28 14:54:39 UTC

[11/51] [partial] mahout git commit: NO-JIRA Clean up MR refactor

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
new file mode 100644
index 0000000..a1cd3e0
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
@@ -0,0 +1,488 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Collection;
+import java.util.Iterator;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Class containing several algorithms used to train a Hidden Markov Model. The
+ * three main algorithms are: supervised learning, unsupervised Viterbi and
+ * unsupervised Baum-Welch.
+ */
+public final class HmmTrainer {
+
+  /**
+   * No public constructor for utility classes.
+   */
+  private HmmTrainer() {
+    // nothing to do here really.
+  }
+
+  /**
+   * Create a supervised initial estimate of an HMM Model based on a sequence
+   * of observed and hidden states.
+   *
+   * @param nrOfHiddenStates The total number of hidden states
+   * @param nrOfOutputStates The total number of output states
+   * @param observedSequence Integer array containing the observed sequence
+   * @param hiddenSequence   Integer array containing the hidden sequence
+   * @param pseudoCount      Value that is assigned to non-occurring transitions to avoid zero
+   *                         probabilities.
+   * @return An initial model using the estimated parameters
+   */
+  public static HmmModel trainSupervised(int nrOfHiddenStates, int nrOfOutputStates, int[] observedSequence,
+      int[] hiddenSequence, double pseudoCount) {
+    // make sure the pseudo count is not zero
+    pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;
+
+    // initialize the parameters
+    DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
+    DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
+    // assign a small initial probability that is larger than zero, so
+    // unseen states will not get a zero probability
+    transitionMatrix.assign(pseudoCount);
+    emissionMatrix.assign(pseudoCount);
+    // given no prior knowledge, we have to assume that all initial hidden
+    // states are equally likely
+    DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);
+    initialProbabilities.assign(1.0 / nrOfHiddenStates);
+
+    // now loop over the sequences to count the number of transitions
+    countTransitions(transitionMatrix, emissionMatrix, observedSequence,
+        hiddenSequence);
+
+    // make sure that probabilities are normalized
+    for (int i = 0; i < nrOfHiddenStates; i++) {
+      // compute sum of probabilities for current row of transition matrix
+      double sum = 0;
+      for (int j = 0; j < nrOfHiddenStates; j++) {
+        sum += transitionMatrix.getQuick(i, j);
+      }
+      // normalize current row of transition matrix
+      for (int j = 0; j < nrOfHiddenStates; j++) {
+        transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
+      }
+      // compute sum of probabilities for current row of emission matrix
+      sum = 0;
+      for (int j = 0; j < nrOfOutputStates; j++) {
+        sum += emissionMatrix.getQuick(i, j);
+      }
+      // normalize current row of emission matrix
+      for (int j = 0; j < nrOfOutputStates; j++) {
+        emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
+      }
+    }
+
+    // return a new model using the parameter estimations
+    return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
+  }
+
+  /**
+   * Function that counts the number of state->state and state->output
+   * transitions for the given observed/hidden sequence.
+   *
+   * @param transitionMatrix transition matrix to use.
+   * @param emissionMatrix emission matrix to use for counting.
+   * @param observedSequence observation sequence to use.
+   * @param hiddenSequence sequence of hidden states to use.
+   */
+  private static void countTransitions(Matrix transitionMatrix,
+                                       Matrix emissionMatrix, int[] observedSequence, int[] hiddenSequence) {
+    emissionMatrix.setQuick(hiddenSequence[0], observedSequence[0],
+        emissionMatrix.getQuick(hiddenSequence[0], observedSequence[0]) + 1);
+    for (int i = 1; i < observedSequence.length; ++i) {
+      transitionMatrix
+          .setQuick(hiddenSequence[i - 1], hiddenSequence[i], transitionMatrix
+              .getQuick(hiddenSequence[i - 1], hiddenSequence[i]) + 1);
+      emissionMatrix.setQuick(hiddenSequence[i], observedSequence[i],
+          emissionMatrix.getQuick(hiddenSequence[i], observedSequence[i]) + 1);
+    }
+  }
+
+  /**
+   * Create a supervised initial estimate of an HMM Model based on a number of
+   * sequences of observed and hidden states.
+   *
+   * @param nrOfHiddenStates The total number of hidden states
+   * @param nrOfOutputStates The total number of output states
+   * @param hiddenSequences Collection of hidden sequences to use for training
+   * @param observedSequences Collection of observed sequences to use for training, each paired with the corresponding hidden sequence.
+   * @param pseudoCount      Value that is assigned to non-occurring transitions to avoid zero
+   *                         probabilities.
+   * @return An initial model using the estimated parameters
+   */
+  public static HmmModel trainSupervisedSequence(int nrOfHiddenStates,
+                                                 int nrOfOutputStates, Collection<int[]> hiddenSequences,
+                                                 Collection<int[]> observedSequences, double pseudoCount) {
+
+    // make sure the pseudo count is not zero
+    pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;
+
+    // initialize parameters
+    DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates,
+        nrOfHiddenStates);
+    DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates,
+        nrOfOutputStates);
+    DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);
+
+    // assign pseudo count to avoid zero probabilities
+    transitionMatrix.assign(pseudoCount);
+    emissionMatrix.assign(pseudoCount);
+    initialProbabilities.assign(pseudoCount);
+
+    // now loop over the sequences to count the number of transitions
+    Iterator<int[]> hiddenSequenceIt = hiddenSequences.iterator();
+    Iterator<int[]> observedSequenceIt = observedSequences.iterator();
+    while (hiddenSequenceIt.hasNext() && observedSequenceIt.hasNext()) {
+      // fetch the current set of sequences
+      int[] hiddenSequence = hiddenSequenceIt.next();
+      int[] observedSequence = observedSequenceIt.next();
+      // increase the count for initial probabilities
+      initialProbabilities.setQuick(hiddenSequence[0], initialProbabilities
+          .getQuick(hiddenSequence[0]) + 1);
+      countTransitions(transitionMatrix, emissionMatrix, observedSequence,
+          hiddenSequence);
+    }
+
+    // make sure that probabilities are normalized
+    double isum = 0; // sum of initial probabilities
+    for (int i = 0; i < nrOfHiddenStates; i++) {
+      isum += initialProbabilities.getQuick(i);
+      // compute sum of probabilities for current row of transition matrix
+      double sum = 0;
+      for (int j = 0; j < nrOfHiddenStates; j++) {
+        sum += transitionMatrix.getQuick(i, j);
+      }
+      // normalize current row of transition matrix
+      for (int j = 0; j < nrOfHiddenStates; j++) {
+        transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
+      }
+      // compute sum of probabilities for current row of emission matrix
+      sum = 0;
+      for (int j = 0; j < nrOfOutputStates; j++) {
+        sum += emissionMatrix.getQuick(i, j);
+      }
+      // normalize current row of emission matrix
+      for (int j = 0; j < nrOfOutputStates; j++) {
+        emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
+      }
+    }
+    // normalize the initial probabilities
+    for (int i = 0; i < nrOfHiddenStates; ++i) {
+      initialProbabilities.setQuick(i, initialProbabilities.getQuick(i) / isum);
+    }
+
+    // return a new model using the parameter estimates
+    return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
+  }
+
+  /**
+   * Iteratively train the parameters of the given initial model with respect to the
+   * observed sequence using Viterbi training.
+   *
+   * @param initialModel     The initial model that gets iterated
+   * @param observedSequence The sequence of observed states
+   * @param pseudoCount      Value that is assigned to non-occurring transitions to avoid zero
+   *                         probabilities.
+   * @param epsilon          Convergence criteria
+   * @param maxIterations    The maximum number of training iterations
+   * @param scaled           Use the log-scaled implementation; this is computationally more
+   *                         expensive, but offers better numerical stability for long observed
+   *                         sequences
+   * @return The iterated model
+   */
+  public static HmmModel trainViterbi(HmmModel initialModel,
+                                      int[] observedSequence, double pseudoCount, double epsilon,
+                                      int maxIterations, boolean scaled) {
+
+    // make sure the pseudo count is not zero
+    pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;
+
+    // allocate space for iteration models
+    HmmModel lastIteration = initialModel.clone();
+    HmmModel iteration = initialModel.clone();
+
+    // allocate space for Viterbi path calculation
+    int[] viterbiPath = new int[observedSequence.length];
+    int[][] phi = new int[observedSequence.length - 1][initialModel
+        .getNrOfHiddenStates()];
+    double[][] delta = new double[observedSequence.length][initialModel
+        .getNrOfHiddenStates()];
+
+    // now run the Viterbi training iteration
+    for (int i = 0; i < maxIterations; ++i) {
+      // compute the Viterbi path
+      HmmAlgorithms.viterbiAlgorithm(viterbiPath, delta, phi, lastIteration,
+          observedSequence, scaled);
+      // Viterbi iteration uses the viterbi path to update
+      // the probabilities
+      Matrix emissionMatrix = iteration.getEmissionMatrix();
+      Matrix transitionMatrix = iteration.getTransitionMatrix();
+
+      // first, assign the pseudo count
+      emissionMatrix.assign(pseudoCount);
+      transitionMatrix.assign(pseudoCount);
+
+      // now count the transitions
+      countTransitions(transitionMatrix, emissionMatrix, observedSequence,
+          viterbiPath);
+
+      // and normalize the probabilities
+      for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+        double sum = 0;
+        // normalize the rows of the transition matrix
+        for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+          sum += transitionMatrix.getQuick(j, k);
+        }
+        for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+          transitionMatrix
+              .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum);
+        }
+        // normalize the rows of the emission matrix
+        sum = 0;
+        for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+          sum += emissionMatrix.getQuick(j, k);
+        }
+        for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+          emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum);
+        }
+      }
+      // check for convergence
+      if (checkConvergence(lastIteration, iteration, epsilon)) {
+        break;
+      }
+      // overwrite the last iterated model by the new iteration
+      lastIteration.assign(iteration);
+    }
+    // we are done :)
+    return iteration;
+  }
+
+  /**
+   * Iteratively train the parameters of the given initial model with respect to the
+   * observed sequence using Baum-Welch training.
+   *
+   * @param initialModel     The initial model that gets iterated
+   * @param observedSequence The sequence of observed states
+   * @param epsilon          Convergence criteria
+   * @param maxIterations    The maximum number of training iterations
+   * @param scaled           Use log-scaled implementations of forward/backward algorithm. This
+   *                         is computationally more expensive, but offers better numerical
+   *                         stability for long output sequences.
+   * @return The iterated model
+   */
+  public static HmmModel trainBaumWelch(HmmModel initialModel,
+                                        int[] observedSequence, double epsilon, int maxIterations, boolean scaled) {
+    // allocate space for the iterations
+    HmmModel lastIteration = initialModel.clone();
+    HmmModel iteration = initialModel.clone();
+
+    // allocate space for baum-welch factors
+    int hiddenCount = initialModel.getNrOfHiddenStates();
+    int visibleCount = observedSequence.length;
+    Matrix alpha = new DenseMatrix(visibleCount, hiddenCount);
+    Matrix beta = new DenseMatrix(visibleCount, hiddenCount);
+
+    // now run the Baum-Welch training iteration
+    for (int it = 0; it < maxIterations; ++it) {
+      // fetch emission and transition matrix of current iteration
+      Vector initialProbabilities = iteration.getInitialProbabilities();
+      Matrix emissionMatrix = iteration.getEmissionMatrix();
+      Matrix transitionMatrix = iteration.getTransitionMatrix();
+
+      // compute forward and backward factors
+      HmmAlgorithms.forwardAlgorithm(alpha, iteration, observedSequence, scaled);
+      HmmAlgorithms.backwardAlgorithm(beta, iteration, observedSequence, scaled);
+
+      if (scaled) {
+        logScaledBaumWelch(observedSequence, iteration, alpha, beta);
+      } else {
+        unscaledBaumWelch(observedSequence, iteration, alpha, beta);
+      }
+      // normalize transition/emission probabilities
+      // and normalize the probabilities
+      double isum = 0;
+      for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+        double sum = 0;
+        // normalize the rows of the transition matrix
+        for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+          sum += transitionMatrix.getQuick(j, k);
+        }
+        for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+          transitionMatrix
+              .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum);
+        }
+        // normalize the rows of the emission matrix
+        sum = 0;
+        for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+          sum += emissionMatrix.getQuick(j, k);
+        }
+        for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+          emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum);
+        }
+        // normalization parameter for initial probabilities
+        isum += initialProbabilities.getQuick(j);
+      }
+      // normalize initial probabilities
+      for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+        initialProbabilities.setQuick(i, initialProbabilities.getQuick(i)
+            / isum);
+      }
+      // check for convergence
+      if (checkConvergence(lastIteration, iteration, epsilon)) {
+        break;
+      }
+      // overwrite the last iterated model by the new iteration
+      lastIteration.assign(iteration);
+    }
+    // we are done :)
+    return iteration;
+  }
+
+  private static void unscaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
+    Vector initialProbabilities = iteration.getInitialProbabilities();
+    Matrix emissionMatrix = iteration.getEmissionMatrix();
+    Matrix transitionMatrix = iteration.getTransitionMatrix();
+    double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, false);
+
+    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+      initialProbabilities.setQuick(i, alpha.getQuick(0, i)
+          * beta.getQuick(0, i));
+    }
+
+    // recompute transition probabilities
+    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+      for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+        double temp = 0;
+        for (int t = 0; t < observedSequence.length - 1; ++t) {
+          temp += alpha.getQuick(t, i)
+              * emissionMatrix.getQuick(j, observedSequence[t + 1])
+              * beta.getQuick(t + 1, j);
+        }
+        transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j)
+            * temp / modelLikelihood);
+      }
+    }
+    // recompute emission probabilities
+    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+      for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) {
+        double temp = 0;
+        for (int t = 0; t < observedSequence.length; ++t) {
+          // Kronecker delta: count only when observation t equals output state j
+          if (observedSequence[t] == j) {
+            temp += alpha.getQuick(t, i) * beta.getQuick(t, i);
+          }
+        }
+        emissionMatrix.setQuick(i, j, temp / modelLikelihood);
+      }
+    }
+  }
+
+  private static void logScaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
+    Vector initialProbabilities = iteration.getInitialProbabilities();
+    Matrix emissionMatrix = iteration.getEmissionMatrix();
+    Matrix transitionMatrix = iteration.getTransitionMatrix();
+    double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, true);
+
+    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+      initialProbabilities.setQuick(i, Math.exp(alpha.getQuick(0, i) + beta.getQuick(0, i)));
+    }
+
+    // recompute transition probabilities
+    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+      for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+        double sum = Double.NEGATIVE_INFINITY; // log(0)
+        for (int t = 0; t < observedSequence.length - 1; ++t) {
+          double temp = alpha.getQuick(t, i)
+              + Math.log(emissionMatrix.getQuick(j, observedSequence[t + 1]))
+              + beta.getQuick(t + 1, j);
+          if (temp > Double.NEGATIVE_INFINITY) {
+            // handle 0-probabilities
+            sum = temp + Math.log1p(Math.exp(sum - temp));
+          }
+        }
+        transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j)
+            * Math.exp(sum - modelLikelihood));
+      }
+    }
+    // recompute emission probabilities
+    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+      for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) {
+        double sum = Double.NEGATIVE_INFINITY; // log(0)
+        for (int t = 0; t < observedSequence.length; ++t) {
+          // Kronecker delta: contribute only when observation t equals output state j
+          if (observedSequence[t] == j) {
+            double temp = alpha.getQuick(t, i) + beta.getQuick(t, i);
+            if (temp > Double.NEGATIVE_INFINITY) {
+              // handle 0-probabilities
+              sum = temp + Math.log1p(Math.exp(sum - temp));
+            }
+          }
+        }
+        emissionMatrix.setQuick(i, j, Math.exp(sum - modelLikelihood));
+      }
+    }
+  }
+
+  /**
+   * Check convergence of two HMM models by computing a simple distance between
+   * emission / transition matrices
+   *
+   * @param oldModel Old HMM Model
+   * @param newModel New HMM Model
+   * @param epsilon  Convergence Factor
+   * @return true if training converged to a stable state.
+   */
+  private static boolean checkConvergence(HmmModel oldModel, HmmModel newModel,
+                                          double epsilon) {
+    // check convergence of transitionProbabilities
+    Matrix oldTransitionMatrix = oldModel.getTransitionMatrix();
+    Matrix newTransitionMatrix = newModel.getTransitionMatrix();
+    double diff = 0;
+    for (int i = 0; i < oldModel.getNrOfHiddenStates(); ++i) {
+      for (int j = 0; j < oldModel.getNrOfHiddenStates(); ++j) {
+        double tmp = oldTransitionMatrix.getQuick(i, j)
+            - newTransitionMatrix.getQuick(i, j);
+        diff += tmp * tmp;
+      }
+    }
+    double norm = Math.sqrt(diff);
+    diff = 0;
+    // check convergence of emissionProbabilities
+    Matrix oldEmissionMatrix = oldModel.getEmissionMatrix();
+    Matrix newEmissionMatrix = newModel.getEmissionMatrix();
+    for (int i = 0; i < oldModel.getNrOfHiddenStates(); i++) {
+      for (int j = 0; j < oldModel.getNrOfOutputStates(); j++) {
+
+        double tmp = oldEmissionMatrix.getQuick(i, j)
+            - newEmissionMatrix.getQuick(i, j);
+        diff += tmp * tmp;
+      }
+    }
+    norm += Math.sqrt(diff);
+    // iteration has converged :)
+    return norm < epsilon;
+  }
+
+}
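
As a minimal sketch of how the three trainers above might be driven (using only the signatures shown in this file; the toy sequences, state counts and tuning values are made-up placeholders, not part of the commit):

    import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
    import org.apache.mahout.classifier.sequencelearning.hmm.HmmTrainer;

    public class HmmTrainerSketch {
      public static void main(String[] args) {
        // toy data: 2 hidden states, 3 output states (placeholder values)
        int[] observed = {0, 1, 2, 2, 0, 1};
        int[] hidden   = {0, 0, 1, 1, 0, 0};

        // supervised estimate from a labeled sequence
        HmmModel supervised =
            HmmTrainer.trainSupervised(2, 3, observed, hidden, 0.01);

        // unsupervised refinement starting from the supervised estimate;
        // scaled=true trades speed for numerical stability on long sequences
        HmmModel viterbi =
            HmmTrainer.trainViterbi(supervised, observed, 0.01, 0.0001, 100, true);
        HmmModel baumWelch =
            HmmTrainer.trainBaumWelch(supervised, observed, 0.0001, 100, true);
      }
    }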

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
new file mode 100644
index 0000000..e710816
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
@@ -0,0 +1,360 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * A collection of utilities for handling HmmModel objects.
+ */
+public final class HmmUtils {
+
+  /**
+   * No public constructor for utility classes.
+   */
+  private HmmUtils() {
+    // nothing to do here really.
+  }
+
+  /**
+   * Compute the cumulative transition probability matrix for the given HMM
+   * model: a matrix in which each row i is the cumulative distribution of the
+   * transition probabilities for hidden state i.
+   *
+   * @param model The HMM model for which the cumulative transition matrix should be
+   *              computed
+   * @return The computed cumulative transition matrix.
+   */
+  public static Matrix getCumulativeTransitionMatrix(HmmModel model) {
+    // fetch the needed parameters from the model
+    int hiddenStates = model.getNrOfHiddenStates();
+    Matrix transitionMatrix = model.getTransitionMatrix();
+    // now compute the cumulative transition matrix
+    Matrix resultMatrix = new DenseMatrix(hiddenStates, hiddenStates);
+    for (int i = 0; i < hiddenStates; ++i) {
+      double sum = 0;
+      for (int j = 0; j < hiddenStates; ++j) {
+        sum += transitionMatrix.get(i, j);
+        resultMatrix.set(i, j, sum);
+      }
+      // make sure the last state always has a cumulative probability of exactly 1.0
+      resultMatrix.set(i, hiddenStates - 1, 1.0);
+    }
+    return resultMatrix;
+  }
+
+  /**
+   * Compute the cumulative output probability matrix for the given HMM model:
+   * a matrix in which each row i is the cumulative distribution of the output
+   * probabilities for hidden state i.
+   *
+   * @param model The HMM model for which the cumulative output matrix should be
+   *              computed
+   * @return The computed cumulative output matrix.
+   */
+  public static Matrix getCumulativeOutputMatrix(HmmModel model) {
+    // fetch the needed parameters from the model
+    int hiddenStates = model.getNrOfHiddenStates();
+    int outputStates = model.getNrOfOutputStates();
+    Matrix outputMatrix = model.getEmissionMatrix();
+    // now compute the cumulative output matrix
+    Matrix resultMatrix = new DenseMatrix(hiddenStates, outputStates);
+    for (int i = 0; i < hiddenStates; ++i) {
+      double sum = 0;
+      for (int j = 0; j < outputStates; ++j) {
+        sum += outputMatrix.get(i, j);
+        resultMatrix.set(i, j, sum);
+      }
+      // make sure the last output state always has a cumulative probability of exactly 1.0
+      resultMatrix.set(i, outputStates - 1, 1.0);
+    }
+    return resultMatrix;
+  }
+
+  /**
+   * Compute the cumulative distribution of the initial hidden state
+   * probabilities for the given HMM model.
+   *
+   * @param model The HMM model for which the cumulative initial state probabilities
+   *              should be computed
+   * @return The computed cumulative initial state probability vector.
+   */
+  public static Vector getCumulativeInitialProbabilities(HmmModel model) {
+    // fetch the needed parameters from the model
+    int hiddenStates = model.getNrOfHiddenStates();
+    Vector initialProbabilities = model.getInitialProbabilities();
+    // now compute the cumulative output matrix
+    Vector resultVector = new DenseVector(initialProbabilities.size());
+    double sum = 0;
+    for (int i = 0; i < hiddenStates; ++i) {
+      sum += initialProbabilities.get(i);
+      resultVector.set(i, sum);
+    }
+    // make sure the last initial hidden state always has a cumulative probability of exactly 1.0
+    resultVector.set(hiddenStates - 1, 1.0);
+    return resultVector;
+  }
+
+  /**
+   * Validates an HMM model.
+   *
+   * @param model model to sanity check.
+   */
+  public static void validate(HmmModel model) {
+    if (model == null) {
+      return; // empty models are valid
+    }
+
+    /*
+     * The number of hidden states is positive.
+     */
+    Preconditions.checkArgument(model.getNrOfHiddenStates() > 0,
+      "Error: The number of hidden states has to be greater than 0");
+    
+    /*
+     * The number of output states is positive.
+     */
+    Preconditions.checkArgument(model.getNrOfOutputStates() > 0,
+      "Error: The number of output states has to be greater than 0!");
+
+    /*
+     * The size of the vector of initial probabilities is equal to the number of
+     * the hidden states. Each initial probability is non-negative. The sum of
+     * initial probabilities is equal to 1.
+     */
+    Preconditions.checkArgument(model.getInitialProbabilities() != null
+      && model.getInitialProbabilities().size() == model.getNrOfHiddenStates(),
+      "Error: The vector of initial probabilities is not initialized!");
+    
+    double sum = 0;
+    for (int i = 0; i < model.getInitialProbabilities().size(); i++) {
+      Preconditions.checkArgument(model.getInitialProbabilities().get(i) >= 0,
+        "Error: Initial probability of state %s is negative", i);
+      sum += model.getInitialProbabilities().get(i);
+    }
+    Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001,
+                                "Error: Initial probabilities do not add up to 1");
+    /*
+     * The row count of the output matrix is equal to the number of hidden
+     * states. The column count is equal to the number of output states. Each
+     * probability in the matrix is non-negative. The sum of each row is equal
+     * to 1.
+     */
+    Preconditions.checkNotNull(model.getEmissionMatrix(), "Error: The output state matrix is not initialized!");
+    Preconditions.checkArgument(model.getEmissionMatrix().numRows() == model.getNrOfHiddenStates()
+        && model.getEmissionMatrix().numCols() == model.getNrOfOutputStates(),
+        "Error: The output state matrix is not of the form nrOfHiddenStates x nrOfOutputStates");
+    for (int i = 0; i < model.getEmissionMatrix().numRows(); i++) {
+      sum = 0;
+      for (int j = 0; j < model.getEmissionMatrix().numCols(); j++) {
+        Preconditions.checkArgument(model.getEmissionMatrix().get(i, j) >= 0,
+            "The output state probability from hidden state " + i + " to output state " + j + " is negative");
+        sum += model.getEmissionMatrix().get(i, j);
+      }
+      Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001,
+        "Error: The output state probabilities for hidden state %s don't add up to 1", i);
+    }
+
+    /*
+     * Both dimensions of the transition matrix are equal to the number of
+     * hidden states. Each probability in the matrix is non-negative. The sum
+     * of each row of the transition matrix is equal to 1.
+     */
+    Preconditions.checkArgument(model.getTransitionMatrix() != null,
+      "Error: The transition matrix is not initialized!");
+    Preconditions.checkArgument(model.getTransitionMatrix().numRows() == model.getNrOfHiddenStates()
+      && model.getTransitionMatrix().numCols() == model.getNrOfHiddenStates(),
+      "Error: The transition matrix is not of the form nrOfHiddenStates x nrOfHiddenStates");
+    for (int i = 0; i < model.getTransitionMatrix().numRows(); i++) {
+      sum = 0;
+      for (int j = 0; j < model.getTransitionMatrix().numCols(); j++) {
+        Preconditions.checkArgument(model.getTransitionMatrix().get(i, j) >= 0,
+          "Error: The transition probability from hidden state %s to hidden state %s is negative", i, j);
+        sum += model.getTransitionMatrix().get(i, j);
+      }
+      Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001,
+        "Error: The transition probabilities for hidden state " + i + " don't add up to 1.");
+    }
+  }
+
+  /**
+   * Encodes a given collection of state names by the corresponding state IDs
+   * registered in a given model.
+   *
+   * @param model        Model to provide the encoding for
+   * @param sequence     Collection of state names
+   * @param observed     If set, the sequence is encoded as a sequence of observed states,
+   *                     else it is encoded as a sequence of hidden states
+   * @param defaultValue The default value in case a state is not known
+   * @return integer array containing the encoded state IDs
+   */
+  public static int[] encodeStateSequence(HmmModel model,
+                                          Collection<String> sequence, boolean observed, int defaultValue) {
+    int[] encoded = new int[sequence.size()];
+    Iterator<String> seqIter = sequence.iterator();
+    for (int i = 0; i < sequence.size(); ++i) {
+      String nextState = seqIter.next();
+      int nextID;
+      if (observed) {
+        nextID = model.getOutputStateID(nextState);
+      } else {
+        nextID = model.getHiddenStateID(nextState);
+      }
+      // if the ID is -1, use the default value
+      encoded[i] = nextID < 0 ? defaultValue : nextID;
+    }
+    return encoded;
+  }
+
+  /**
+   * Decodes a given collection of state IDs into the corresponding state names
+   * registered in a given model.
+   *
+   * @param model        model to use for retrieving state names
+   * @param sequence     int array of state IDs
+   * @param observed     If set, the sequence is decoded as a sequence of observed states,
+   *                     else it is decoded as a sequence of hidden states
+   * @param defaultValue The default value in case a state is not known
+   * @return list containing the decoded state names
+   */
+  public static List<String> decodeStateSequence(HmmModel model,
+                                                 int[] sequence,
+                                                 boolean observed,
+                                                 String defaultValue) {
+    List<String> decoded = new ArrayList<>(sequence.length);
+    for (int position : sequence) {
+      String nextState;
+      if (observed) {
+        nextState = model.getOutputStateName(position);
+      } else {
+        nextState = model.getHiddenStateName(position);
+      }
+      // if null was returned, use the default value
+      decoded.add(nextState == null ? defaultValue : nextState);
+    }
+    return decoded;
+  }
+
+  /**
+   * Function used to normalize the probabilities of a given HMM model
+   *
+   * @param model model to normalize
+   */
+  public static void normalizeModel(HmmModel model) {
+    Vector ip = model.getInitialProbabilities();
+    Matrix emission = model.getEmissionMatrix();
+    Matrix transition = model.getTransitionMatrix();
+    // check normalization for all probabilities
+    double isum = 0;
+    for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+      isum += ip.getQuick(i);
+      double sum = 0;
+      for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
+        sum += transition.getQuick(i, j);
+      }
+      if (sum != 1.0) {
+        for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
+          transition.setQuick(i, j, transition.getQuick(i, j) / sum);
+        }
+      }
+      sum = 0;
+      for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
+        sum += emission.getQuick(i, j);
+      }
+      if (sum != 1.0) {
+        for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
+          emission.setQuick(i, j, emission.getQuick(i, j) / sum);
+        }
+      }
+    }
+    if (isum != 1.0) {
+      for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+        ip.setQuick(i, ip.getQuick(i) / isum);
+      }
+    }
+  }
+
+  /**
+   * Method to reduce the size of an HmmModel by converting the model's
+   * DenseMatrix/DenseVector members to sparse implementations and setting
+   * every value at or below the threshold to 0
+   *
+   * @param model model to truncate
+   * @param threshold entries at or below this value are dropped from the model.
+   * @return Truncated model
+   */
+  public static HmmModel truncateModel(HmmModel model, double threshold) {
+    Vector ip = model.getInitialProbabilities();
+    Matrix em = model.getEmissionMatrix();
+    Matrix tr = model.getTransitionMatrix();
+    // allocate the sparse data structures
+    RandomAccessSparseVector sparseIp = new RandomAccessSparseVector(model
+        .getNrOfHiddenStates());
+    SparseMatrix sparseEm = new SparseMatrix(model.getNrOfHiddenStates(), model.getNrOfOutputStates());
+    SparseMatrix sparseTr = new SparseMatrix(model.getNrOfHiddenStates(), model.getNrOfHiddenStates());
+    // now transfer the values
+    for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+      double value = ip.getQuick(i);
+      if (value > threshold) {
+        sparseIp.setQuick(i, value);
+      }
+      for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
+        value = tr.getQuick(i, j);
+        if (value > threshold) {
+          sparseTr.setQuick(i, j, value);
+        }
+      }
+
+      for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
+        value = em.getQuick(i, j);
+        if (value > threshold) {
+          sparseEm.setQuick(i, j, value);
+        }
+      }
+    }
+    // create a new model
+    HmmModel sparseModel = new HmmModel(sparseTr, sparseEm, sparseIp);
+    // normalize the model
+    normalizeModel(sparseModel);
+    // register the names
+    sparseModel.registerHiddenStateNames(model.getHiddenStateNames());
+    sparseModel.registerOutputStateNames(model.getOutputStateNames());
+    // and return
+    return sparseModel;
+  }
+}
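
A usage sketch tying these utilities together (the state names and threshold are illustrative, and encode/decode assume the model has state names registered; none of this is part of the commit):

    import java.util.Arrays;
    import java.util.List;

    import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
    import org.apache.mahout.classifier.sequencelearning.hmm.HmmUtils;

    public class HmmUtilsSketch {
      public static void sketch(HmmModel model) {
        // fail fast on malformed transition/emission/initial probabilities
        HmmUtils.validate(model);

        // names -> IDs (hidden states); unknown names fall back to default ID 0
        int[] ids = HmmUtils.encodeStateSequence(
            model, Arrays.asList("rainy", "sunny"), false, 0);

        // IDs -> names; unknown IDs fall back to the default name
        List<String> names = HmmUtils.decodeStateSequence(
            model, ids, false, "unknown");

        // drop entries at or below 1e-5 and move to sparse storage
        HmmModel compact = HmmUtils.truncateModel(model, 1e-5);
      }
    }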

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
new file mode 100644
index 0000000..d0ae9c2
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Utilities for serializing the Writable parts of an HmmModel (i.e. without hidden state names and other metadata)
+ */
+final class LossyHmmSerializer {
+
+  private LossyHmmSerializer() {
+  }
+
+  static void serialize(HmmModel model, DataOutput output) throws IOException {
+    MatrixWritable matrix = new MatrixWritable(model.getEmissionMatrix());
+    matrix.write(output);
+    matrix.set(model.getTransitionMatrix());
+    matrix.write(output);
+
+    VectorWritable vector = new VectorWritable(model.getInitialProbabilities());
+    vector.write(output);
+  }
+
+  static HmmModel deserialize(DataInput input) throws IOException {
+    MatrixWritable matrix = new MatrixWritable();
+    matrix.readFields(input);
+    Matrix emissionMatrix = matrix.get();
+
+    matrix.readFields(input);
+    Matrix transitionMatrix = matrix.get();
+
+    VectorWritable vector = new VectorWritable();
+    vector.readFields(input);
+    Vector initialProbabilities = vector.get();
+
+    return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
+  }
+
+}
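
Since serialize() and deserialize() are symmetric, a round trip is straightforward. A sketch (note that LossyHmmSerializer is package-private, so this would have to live in the same package; the class name and file are arbitrary):

    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileOutputStream;
    import java.io.IOException;

    final class LossyHmmSerializerSketch {
      static HmmModel roundTrip(HmmModel model, File file) throws IOException {
        try (DataOutputStream out = new DataOutputStream(new FileOutputStream(file))) {
          LossyHmmSerializer.serialize(model, out);
        }
        try (DataInputStream in = new DataInputStream(new FileInputStream(file))) {
          // state names are not serialized, hence "lossy": re-register them if needed
          return LossyHmmSerializer.deserialize(in);
        }
      }
    }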

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java
new file mode 100644
index 0000000..02baef1
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java
@@ -0,0 +1,102 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.DataInputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.common.CommandLineUtil;
+
+/**
+ * Command-line tool for generating random sequences from a given HMM
+ */
+public final class RandomSequenceGenerator {
+
+  private RandomSequenceGenerator() {
+  }
+
+  public static void main(String[] args) throws IOException {
+    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+    Option outputOption = optionBuilder.withLongName("output").
+      withDescription("Output file with sequence of observed states").
+      withShortName("o").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("path").create()).withRequired(false).create();
+
+    Option modelOption = optionBuilder.withLongName("model").
+      withDescription("Path to serialized HMM model").
+      withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("path").create()).withRequired(true).create();
+
+    Option lengthOption = optionBuilder.withLongName("length").
+      withDescription("Length of generated sequence").
+      withShortName("l").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("number").create()).withRequired(true).create();
+
+    Group optionGroup = new GroupBuilder().
+      withOption(outputOption).withOption(modelOption).withOption(lengthOption).
+      withName("Options").create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(optionGroup);
+      CommandLine commandLine = parser.parse(args);
+
+      String output = (String) commandLine.getValue(outputOption);
+
+      String modelPath = (String) commandLine.getValue(modelOption);
+
+      int length = Integer.parseInt((String) commandLine.getValue(lengthOption));
+
+      //reading serialized HMM
+      HmmModel model;
+      try (DataInputStream modelStream = new DataInputStream(new FileInputStream(modelPath))){
+        model = LossyHmmSerializer.deserialize(modelStream);
+      }
+
+      //generating observations
+      int[] observations = HmmEvaluator.predict(model, length, System.currentTimeMillis());
+
+      //writing output
+      try (PrintWriter writer =
+               new PrintWriter(new OutputStreamWriter(new FileOutputStream(output), Charsets.UTF_8), true)){
+        for (int observation : observations) {
+          writer.print(observation);
+          writer.print(' ');
+        }
+      }
+    } catch (OptionException e) {
+      CommandLineUtil.printHelp(optionGroup);
+    }
+  }
+}
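
Assuming the compiled classes and their dependencies are on the classpath (the jar name and file paths below are placeholders), an invocation might look like:

    java -cp mahout-mr.jar \
      org.apache.mahout.classifier.sequencelearning.hmm.RandomSequenceGenerator \
      -m hmm-model.bin -l 100 -o sequence.txt

Per the option group above, -m and -l are required while -o is declared optional.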

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java
new file mode 100644
index 0000000..317237d
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.DataInputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Scanner;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+
+/**
+ * Command-line tool for Viterbi evaluation
+ */
+public final class ViterbiEvaluator {
+
+  private ViterbiEvaluator() {
+  }
+
+  public static void main(String[] args) throws IOException {
+    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+    Option inputOption = DefaultOptionCreator.inputOption().create();
+
+    Option outputOption = DefaultOptionCreator.outputOption().create();
+
+    Option modelOption = optionBuilder.withLongName("model").
+      withDescription("Path to serialized HMM model").
+      withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("path").create()).withRequired(true).create();
+
+    Option likelihoodOption = optionBuilder.withLongName("likelihood").
+      withDescription("Compute likelihood of observed sequence").
+      withShortName("l").withRequired(false).create();
+
+    Group optionGroup = new GroupBuilder().withOption(inputOption).
+      withOption(outputOption).withOption(modelOption).withOption(likelihoodOption).
+      withName("Options").create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(optionGroup);
+      CommandLine commandLine = parser.parse(args);
+
+      String input = (String) commandLine.getValue(inputOption);
+      String output = (String) commandLine.getValue(outputOption);
+
+      String modelPath = (String) commandLine.getValue(modelOption);
+
+      boolean computeLikelihood = commandLine.hasOption(likelihoodOption);
+
+      //reading serialized HMM
+      HmmModel model;
+      try (DataInputStream modelStream = new DataInputStream(new FileInputStream(modelPath))) {
+        model = LossyHmmSerializer.deserialize(modelStream);
+      }
+
+      //reading observations
+      List<Integer> observations = new ArrayList<>();
+      try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) {
+        while (scanner.hasNextInt()) {
+          observations.add(scanner.nextInt());
+        }
+      }
+
+      int[] observationsArray = new int[observations.size()];
+      for (int i = 0; i < observations.size(); ++i) {
+        observationsArray[i] = observations.get(i);
+      }
+
+      //decoding
+      int[] hiddenStates = HmmEvaluator.decode(model, observationsArray, true);
+
+      //writing output
+      try (PrintWriter writer =
+               new PrintWriter(new OutputStreamWriter(new FileOutputStream(output), Charsets.UTF_8), true)) {
+        for (int hiddenState : hiddenStates) {
+          writer.print(hiddenState);
+          writer.print(' ');
+        }
+      }
+
+      if (computeLikelihood) {
+        System.out.println("Likelihood: " + HmmEvaluator.modelLikelihood(model, observationsArray, true));
+      }
+    } catch (OptionException e) {
+      CommandLineUtil.printHelp(optionGroup);
+    }
+  }
+}
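
Usage mirrors RandomSequenceGenerator. Assuming DefaultOptionCreator's standard -i/--input and -o/--output short names (the paths are placeholders):

    java -cp mahout-mr.jar \
      org.apache.mahout.classifier.sequencelearning.hmm.ViterbiEvaluator \
      -i observations.txt -o hidden-states.txt -m hmm-model.bin -l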

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
new file mode 100644
index 0000000..0b2c41b
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Generic definition of a 1 of n logistic regression classifier that returns probabilities in
+ * response to a feature vector.  This classifier uses 1 of n-1 coding where the 0-th category
+ * is not stored explicitly.
+ * <p/>
+ * Provides the SGD based algorithm for learning a logistic regression, but omits all
+ * annealing of learning rates.  Any extension of this abstract class must define the overall
+ * and per-term annealing for themselves.
+ */
+public abstract class AbstractOnlineLogisticRegression extends AbstractVectorClassifier implements OnlineLearner {
+  // coefficients for the classification.  This is a dense matrix
+  // that is (numCategories-1) x numFeatures
+  protected Matrix beta;
+
+  // number of categories we are classifying.  This should be the number of rows of beta plus one.
+  protected int numCategories;
+
+  protected int step;
+
+  // information about how long it has been since coefficient rows were updated.  This allows lazy regularization.
+  protected Vector updateSteps;
+
+  // information about how many updates we have had on a location.  This allows per-term
+  // annealing a la confidence weighted learning.
+  protected Vector updateCounts;
+
+  // weight of the prior on beta
+  private double lambda = 1.0e-5;
+  protected PriorFunction prior;
+
+  // can we ignore any further regularization when doing classification?
+  private boolean sealed;
+
+  // by default we don't do any fancy training
+  private Gradient gradient = new DefaultGradient();
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param lambda New value of lambda, the weighting factor for the prior distribution.
+   * @return This, so other configurations can be chained.
+   */
+  public AbstractOnlineLogisticRegression lambda(double lambda) {
+    this.lambda = lambda;
+    return this;
+  }
+
+  /**
+   * Computes the inverse link function, by default the logistic link function.
+   *
+   * @param v The output of the linear combination in a GLM.  Note that v is
+   *          modified in place.
+   * @return A version of v with the link function applied.
+   */
+  public static Vector link(Vector v) {
+    double max = v.maxValue();
+    if (max >= 40) {
+      // if max >= 40, we subtract the large offset first
+      // the size of the max means that 1+sum(exp(v)) = sum(exp(v)) to within round-off
+      v.assign(Functions.minus(max)).assign(Functions.EXP);
+      return v.divide(v.norm(1));
+    } else {
+      v.assign(Functions.EXP);
+      return v.divide(1 + v.norm(1));
+    }
+  }
+
+  /**
+   * Computes the binomial logistic inverse link function.
+   *
+   * @param r The value to transform.
+   * @return The logistic function (inverse logit) of r.
+   */
+  public static double link(double r) {
+    if (r < 0.0) {
+      double s = Math.exp(r);
+      return s / (1.0 + s);
+    } else {
+      double s = Math.exp(-r);
+      return 1.0 / (1.0 + s);
+    }
+  }
+
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    // apply pending regularization to whichever coefficients matter
+    regularize(instance);
+    return beta.times(instance);
+  }
+
+  public double classifyScalarNoLink(Vector instance) {
+    return beta.viewRow(0).dot(instance);
+  }
+
+  /**
+   * Returns n-1 probabilities, one for each category but the 0-th.  The probability of the 0-th
+   * category is 1 - sum(this result).
+   *
+   * @param instance A vector of features to be classified.
+   * @return A vector of probabilities, one for each of the first n-1 categories.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    return link(classifyNoLink(instance));
+  }
+
+  /**
+   * Returns a single scalar probability in the case where we have two categories.  Using this
+   * method avoids one extra vector allocation compared to calling classify(), and two extra
+   * vector allocations compared to classifyFull().
+   *
+   * @param instance The vector of features to be classified.
+   * @return The probability of the first of two categories.
+   * @throws IllegalArgumentException If the classifier doesn't have two categories.
+   */
+  @Override
+  public double classifyScalar(Vector instance) {
+    Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories");
+
+    // apply pending regularization to whichever coefficients matter
+    regularize(instance);
+
+    // result is a vector with one element so we can just use dot product
+    return link(classifyScalarNoLink(instance));
+  }
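+
+  // Usage sketch (illustrative; learner and features are assumed to exist):
+  //   double p = learner.classifyScalar(features);  // P(category == 1)
+  //   int predicted = p > 0.5 ? 1 : 0;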
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    unseal();
+
+    double learningRate = currentLearningRate();
+
+    // push coefficients back to zero based on the prior
+    regularize(instance);
+
+    // update each row of coefficients according to result
+    Vector gradient = this.gradient.apply(groupKey, actual, instance, this);
+    for (int i = 0; i < numCategories - 1; i++) {
+      double gradientBase = gradient.get(i);
+
+      // apply the gradient to each non-zero element of the instance
+      for (Element updateLocation : instance.nonZeroes()) {
+        int j = updateLocation.index();
+
+        double newValue = beta.getQuick(i, j) + gradientBase * learningRate * perTermLearningRate(j) * instance.get(j);
+        beta.setQuick(i, j, newValue);
+      }
+    }
+
+    // remember that these elements got updated
+    for (Element element : instance.nonZeroes()) {
+      int j = element.index();
+      updateSteps.setQuick(j, getStep());
+      updateCounts.incrementQuick(j, 1);
+    }
+    nextStep();
+
+  }
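+
+  // Illustrative SGD step: with learning rate 0.1, per-term rate 1.0, a
+  // gradient component of 0.5 and a feature value of 2.0, the corresponding
+  // coefficient moves by 0.5 * 0.1 * 1.0 * 2.0 = 0.1 in one call to train().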
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+  public void regularize(Vector instance) {
+    if (updateSteps == null || isSealed()) {
+      return;
+    }
+
+    // anneal learning rate
+    double learningRate = currentLearningRate();
+
+    // here we lazily apply the prior to make up for our neglect
+    for (int i = 0; i < numCategories - 1; i++) {
+      for (Element updateLocation : instance.nonZeroes()) {
+        int j = updateLocation.index();
+        double missingUpdates = getStep() - updateSteps.get(j);
+        if (missingUpdates > 0) {
+          double rate = getLambda() * learningRate * perTermLearningRate(j);
+          double newValue = prior.age(beta.get(i, j), missingUpdates, rate);
+          beta.set(i, j, newValue);
+          updateSteps.set(j, getStep());
+        }
+      }
+    }
+  }
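+
+  // Illustrative trace of the lazy update: if feature j was last regularized
+  // at step 100 and getStep() now returns 103, missingUpdates == 3 and the
+  // prior ages beta[i][j] by all three skipped steps in a single age() call;
+  // an L2-style prior, for example, would shrink the coefficient
+  // multiplicatively once per skipped step.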
+
+  // these two abstract methods are how extensions can modify the basic learning behavior of this object.
+
+  public abstract double perTermLearningRate(int j);
+
+  public abstract double currentLearningRate();
+
+  public void setPrior(PriorFunction prior) {
+    this.prior = prior;
+  }
+
+  public void setGradient(Gradient gradient) {
+    this.gradient = gradient;
+  }
+
+  public PriorFunction getPrior() {
+    return prior;
+  }
+
+  public Matrix getBeta() {
+    close();
+    return beta;
+  }
+
+  public void setBeta(int i, int j, double betaIJ) {
+    beta.set(i, j, betaIJ);
+  }
+
+  @Override
+  public int numCategories() {
+    return numCategories;
+  }
+
+  public int numFeatures() {
+    return beta.numCols();
+  }
+
+  public double getLambda() {
+    return lambda;
+  }
+
+  public int getStep() {
+    return step;
+  }
+
+  protected void nextStep() {
+    step++;
+  }
+
+  public boolean isSealed() {
+    return sealed;
+  }
+
+  protected void unseal() {
+    sealed = false;
+  }
+
+  private void regularizeAll() {
+    Vector all = new DenseVector(beta.numCols());
+    all.assign(1);
+    regularize(all);
+  }
+
+  @Override
+  public void close() {
+    if (!sealed) {
+      step++;
+      regularizeAll();
+      sealed = true;
+    }
+  }
+
+  public void copyFrom(AbstractOnlineLogisticRegression other) {
+    // number of categories we are classifying.  This should be the number of rows of beta plus one.
+    Preconditions.checkArgument(numCategories == other.numCategories,
+            "Can't copy unless number of target categories is the same");
+
+    beta.assign(other.beta);
+
+    step = other.step;
+
+    updateSteps.assign(other.updateSteps);
+    updateCounts.assign(other.updateCounts);
+  }
+
+  public boolean validModel() {
+    double k = beta.aggregate(Functions.PLUS, new DoubleFunction() {
+      @Override
+      public double apply(double v) {
+        return Double.isNaN(v) || Double.isInfinite(v) ? 1 : 0;
+      }
+    });
+    return k < 1;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
new file mode 100644
index 0000000..24e5798
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
@@ -0,0 +1,586 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.ep.EvolutionaryProcess;
+import org.apache.mahout.ep.Mapping;
+import org.apache.mahout.ep.Payload;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.stats.OnlineAuc;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.ExecutionException;
+
+/**
+ * This is a meta-learner that maintains a pool of ordinary
+ * {@link org.apache.mahout.classifier.sgd.OnlineLogisticRegression} learners. Each
+ * member of the pool has a different learning rate.  Whichever learners in the pool fall
+ * behind in terms of average log-likelihood are tossed out and replaced with variants of the
+ * survivors.  This lets us automatically derive an annealing schedule that optimizes learning
+ * speed.  Since on-line learners tend to be IO bound anyway, maintaining multiple learners in
+ * memory costs less than it might seem.  Doing this adaptation on-line as we learn also
+ * decreases the number of learning rate parameters required and replaces the usual
+ * hyper-parameter search.
+ * <p/>
+ * One wrinkle is that the pool of learners that we maintain is actually a pool of
+ * {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} which themselves contain several OnlineLogisticRegression
+ * objects.  These pools allow estimation
+ * of performance on the fly even if we make many passes through the data.  This does, however,
+ * increase the cost of training since if we are using 5-fold cross-validation, each vector is used
+ * 4 times for training and once for classification.  If this becomes a problem, then we should
+ * probably use a 2-way unbalanced train/test split rather than full cross validation.  With the
+ * current default settings, we have 100 learners running.  This is better than the alternative of
+ * running hundreds of training passes to find good hyper-parameters because we only have to parse
+ * and feature-ize our inputs once.  If you already have good hyper-parameters, then you might
+ * prefer to just run one CrossFoldLearner with those settings.
+ * <p/>
+ * The fitness used here is AUC.  An alternative would be log-likelihood, but bogus values of
+ * log-likelihood are much easier to produce than bogus AUC values, and the two seem to agree
+ * fairly well.  It would be nice to allow the fitness function to be pluggable.  This use of AUC means that
+ * AdaptiveLogisticRegression is mostly suited for binary target variables. This will be fixed
+ * before long by extending OnlineAuc to handle non-binary cases or by using a different fitness
+ * value in non-binary cases.
+ */
+public class AdaptiveLogisticRegression implements OnlineLearner, Writable {
+  public static final int DEFAULT_THREAD_COUNT = 20;
+  public static final int DEFAULT_POOL_SIZE = 20;
+  private static final int SURVIVORS = 2;
+
+  private int record;
+  private int cutoff = 1000;
+  private int minInterval = 1000;
+  private int maxInterval = 1000;
+  private int currentStep = 1000;
+  private int bufferSize = 1000;
+
+  private List<TrainingExample> buffer = new ArrayList<>();
+  private EvolutionaryProcess<Wrapper, CrossFoldLearner> ep;
+  private State<Wrapper, CrossFoldLearner> best;
+  private int threadCount = DEFAULT_THREAD_COUNT;
+  private int poolSize = DEFAULT_POOL_SIZE;
+  private State<Wrapper, CrossFoldLearner> seed;
+  private int numFeatures;
+
+  private boolean freezeSurvivors = true;
+
+  private static final Logger log = LoggerFactory.getLogger(AdaptiveLogisticRegression.class);
+
+  public AdaptiveLogisticRegression() {}
+
+  /**
+   * Uses {@link #DEFAULT_THREAD_COUNT} and {@link #DEFAULT_POOL_SIZE}
+   * @param numCategories The number of categories (labels) to train on
+   * @param numFeatures The number of features used in creating the vectors (i.e. the cardinality of the vector)
+   * @param prior The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
+   *
+   * @see #AdaptiveLogisticRegression(int, int, org.apache.mahout.classifier.sgd.PriorFunction, int, int)
+   */
+  public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+    this(numCategories, numFeatures, prior, DEFAULT_THREAD_COUNT, DEFAULT_POOL_SIZE);
+  }
+
+  /**
+   *
+   * @param numCategories The number of categories (labels) to train on
+   * @param numFeatures The number of features used in creating the vectors (i.e. the cardinality of the vector)
+   * @param prior The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
+   * @param threadCount The number of threads to use for training
+   * @param poolSize The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} instances to use.
+   */
+  public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior, int threadCount,
+      int poolSize) {
+    this.numFeatures = numFeatures;
+    this.threadCount = threadCount;
+    this.poolSize = poolSize;
+    seed = new State<>(new double[2], 10);
+    Wrapper w = new Wrapper(numCategories, numFeatures, prior);
+    seed.setPayload(w);
+
+    Wrapper.setMappings(seed);
+    setPoolSize(this.poolSize);
+  }
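+
+  // Illustrative usage sketch (assumed training loop, not part of this class):
+  //   AdaptiveLogisticRegression learner =
+  //       new AdaptiveLogisticRegression(2, 1000, new L1());
+  //   for (TrainingExample ex : examples) {    // examples is assumed
+  //     learner.train(ex.getActual(), ex.getInstance());
+  //   }
+  //   learner.close();
+  //   CrossFoldLearner best = learner.getBest().getPayload().getLearner();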
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(record, null, actual, instance);
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    record++;
+
+    buffer.add(new TrainingExample(trackingKey, groupKey, actual, instance));
+    // don't train until we have enough examples
+    if (buffer.size() > bufferSize) {
+      trainWithBufferedExamples();
+    }
+  }
+
+  private void trainWithBufferedExamples() {
+    try {
+      this.best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {
+        @Override
+        public double apply(Payload<CrossFoldLearner> z, double[] params) {
+          Wrapper x = (Wrapper) z;
+          for (TrainingExample example : buffer) {
+            x.train(example);
+          }
+          if (x.getLearner().validModel()) {
+            if (x.getLearner().numCategories() == 2) {
+              return x.wrapped.auc();
+            } else {
+              return x.wrapped.logLikelihood();
+            }
+          } else {
+            return Double.NaN;
+          }
+        }
+      });
+    } catch (InterruptedException e) {
+      // ignore ... shouldn't happen
+      log.warn("Ignoring exception", e);
+    } catch (ExecutionException e) {
+      throw new IllegalStateException(e.getCause());
+    }
+    buffer.clear();
+
+    if (record > cutoff) {
+      cutoff = nextStep(record);
+
+      // evolve based on new fitness
+      ep.mutatePopulation(SURVIVORS);
+
+      if (freezeSurvivors) {
+        // now grossly hack the top survivors so they stick around.  Set their
+        // mutation rates small and also hack their learning rate to be small
+        // as well.
+        for (State<Wrapper, CrossFoldLearner> state : ep.getPopulation().subList(0, SURVIVORS)) {
+          Wrapper.freeze(state);
+        }
+      }
+    }
+
+  }
+
+  public int nextStep(int recordNumber) {
+    int stepSize = stepSize(recordNumber, 2.6);
+    if (stepSize < minInterval) {
+      stepSize = minInterval;
+    }
+
+    if (stepSize > maxInterval) {
+      stepSize = maxInterval;
+    }
+
+    int newCutoff = stepSize * (recordNumber / stepSize + 1);
+    if (newCutoff < cutoff + currentStep) {
+      newCutoff = cutoff + currentStep;
+    } else {
+      this.currentStep = stepSize;
+    }
+    return newCutoff;
+  }
+
+  public static int stepSize(int recordNumber, double multiplier) {
+    int[] bumps = {1, 2, 5};
+    double log = Math.floor(multiplier * Math.log10(recordNumber));
+    int bump = bumps[(int) log % bumps.length];
+    int scale = (int) Math.pow(10, Math.floor(log / bumps.length));
+
+    return bump * scale;
+  }
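+
+  // Worked example: stepSize(1000, 2.6) computes floor(2.6 * log10(1000)) = 7,
+  // so bump = bumps[7 % 3] = 2 and scale = 10^floor(7 / 3) = 100, giving a
+  // step size of 200; successive record counts walk a rough 1-2-5 progression.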
+
+  @Override
+  public void close() {
+    trainWithBufferedExamples();
+    try {
+      ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {
+        @Override
+        public double apply(Payload<CrossFoldLearner> payload, double[] params) {
+          CrossFoldLearner learner = ((Wrapper) payload).getLearner();
+          learner.close();
+          return learner.logLikelihood();
+        }
+      });
+    } catch (InterruptedException e) {
+      log.warn("Ignoring exception", e);
+    } catch (ExecutionException e) {
+      throw new IllegalStateException(e);
+    } finally {
+      ep.close();
+    }
+  }
+
+  /**
+   * How often should the evolutionary optimization of learning parameters occur?
+   *
+   * @param interval Number of training examples to use in each epoch of optimization.
+   */
+  public void setInterval(int interval) {
+    setInterval(interval, interval);
+  }
+
+  /**
+   * Starts optimization using the shorter interval and progresses toward the longer one on a
+   * roughly 1-2-5 step schedule.  Note that values below 200 are silently raised to 200; even
+   * values that small are unlikely to be useful.
+   *
+   * @param minInterval The minimum epoch length for the evolutionary optimization
+   * @param maxInterval The maximum epoch length
+   */
+  public void setInterval(int minInterval, int maxInterval) {
+    this.minInterval = Math.max(200, minInterval);
+    this.maxInterval = Math.max(200, maxInterval);
+    this.cutoff = minInterval * (record / minInterval + 1);
+    this.currentStep = minInterval;
+    bufferSize = Math.min(minInterval, bufferSize);
+  }
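+
+  // Example: setInterval(1000, 20000) adapts the hyper-parameters every 1000
+  // examples at first and lets the epoch grow toward 20000 as training
+  // proceeds (both bounds are clamped to at least 200).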
+
+  public final void setPoolSize(int poolSize) {
+    this.poolSize = poolSize;
+    setupOptimizer(poolSize);
+  }
+
+  public void setThreadCount(int threadCount) {
+    this.threadCount = threadCount;
+    setupOptimizer(poolSize);
+  }
+
+  public void setAucEvaluator(OnlineAuc auc) {
+    seed.getPayload().setAucEvaluator(auc);
+    setupOptimizer(poolSize);
+  }
+
+  private void setupOptimizer(int poolSize) {
+    ep = new EvolutionaryProcess<>(threadCount, poolSize, seed);
+  }
+
+  /**
+   * Returns the size of the internal feature vector.  Note that this is not the same as the number
+   * of distinct features, especially if feature hashing is being used.
+   *
+   * @return The internal feature vector size.
+   */
+  public int numFeatures() {
+    return numFeatures;
+  }
+
+  /**
+   * Returns the AUC of the current best member of the population.  If there is no best member,
+   * usually because no training has been done yet, the result is NaN.
+   *
+   * @return The AUC of the best member of the population or NaN if we can't figure that out.
+   */
+  public double auc() {
+    if (best == null) {
+      return Double.NaN;
+    } else {
+      Wrapper payload = best.getPayload();
+      return payload.getLearner().auc();
+    }
+  }
+
+  public State<Wrapper, CrossFoldLearner> getBest() {
+    return best;
+  }
+
+  public void setBest(State<Wrapper, CrossFoldLearner> best) {
+    this.best = best;
+  }
+
+  public int getRecord() {
+    return record;
+  }
+
+  public void setRecord(int record) {
+    this.record = record;
+  }
+
+  public int getMinInterval() {
+    return minInterval;
+  }
+
+  public int getMaxInterval() {
+    return maxInterval;
+  }
+
+  public int getNumCategories() {
+    return seed.getPayload().getLearner().numCategories();
+  }
+
+  public PriorFunction getPrior() {
+    return seed.getPayload().getLearner().getPrior();
+  }
+
+  public void setBuffer(List<TrainingExample> buffer) {
+    this.buffer = buffer;
+  }
+
+  public List<TrainingExample> getBuffer() {
+    return buffer;
+  }
+
+  public EvolutionaryProcess<Wrapper, CrossFoldLearner> getEp() {
+    return ep;
+  }
+
+  public void setEp(EvolutionaryProcess<Wrapper, CrossFoldLearner> ep) {
+    this.ep = ep;
+  }
+
+  public State<Wrapper, CrossFoldLearner> getSeed() {
+    return seed;
+  }
+
+  public void setSeed(State<Wrapper, CrossFoldLearner> seed) {
+    this.seed = seed;
+  }
+
+  public int getNumFeatures() {
+    return numFeatures;
+  }
+
+  public void setAveragingWindow(int averagingWindow) {
+    seed.getPayload().getLearner().setWindowSize(averagingWindow);
+    setupOptimizer(poolSize);
+  }
+
+  public void setFreezeSurvivors(boolean freezeSurvivors) {
+    this.freezeSurvivors = freezeSurvivors;
+  }
+
+  /**
+   * Provides a shim between the EP optimization machinery and the CrossFoldLearner.  The most
+   * important part of the interface concerns the optimization parameters, which are taken from
+   * the double[] params in the following order: <ul> <li> regularization constant lambda
+   * <li> learning rate </ul>.  All other parameters are set so as to defeat annealing to the
+   * extent possible.  This lets the evolutionary algorithm handle the annealing.
+   * <p/>
+   * Note that per-coefficient annealing is still done, and the per-coefficient offset is not
+   * optimized.
+   */
+  public static class Wrapper implements Payload<CrossFoldLearner> {
+    private CrossFoldLearner wrapped;
+
+    public Wrapper() {
+    }
+
+    public Wrapper(int numCategories, int numFeatures, PriorFunction prior) {
+      wrapped = new CrossFoldLearner(5, numCategories, numFeatures, prior);
+    }
+
+    @Override
+    public Wrapper copy() {
+      Wrapper r = new Wrapper();
+      r.wrapped = wrapped.copy();
+      return r;
+    }
+
+    @Override
+    public void update(double[] params) {
+      int i = 0;
+      wrapped.lambda(params[i++]);
+      wrapped.learningRate(params[i]);
+
+      wrapped.stepOffset(1);
+      wrapped.alpha(1);
+      wrapped.decayExponent(0);
+    }
+
+    public static void freeze(State<Wrapper, CrossFoldLearner> s) {
+      // radically decrease the learning rate; the raw parameter is mapped onto
+      // a log scale (see setMappings), so subtracting 10 is a drastic cut
+      double[] params = s.getParams();
+      params[1] -= 10;
+
+      // and cause evolution to hold (almost)
+      s.setOmni(s.getOmni() / 20);
+      double[] step = s.getStep();
+      for (int i = 0; i < step.length; i++) {
+        step[i] /= 20;
+      }
+    }
+
+    public static void setMappings(State<Wrapper, CrossFoldLearner> x) {
+      int i = 0;
+      // set the range for regularization (lambda)
+      x.setMap(i++, Mapping.logLimit(1.0e-8, 0.1));
+      // set the range for learning rate (mu)
+      x.setMap(i, Mapping.logLimit(1.0e-8, 1));
+    }
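+
+    // Note (illustrative): Mapping.logLimit(low, high) constrains a raw,
+    // unbounded EP parameter to [low, high] on a log scale, which is why
+    // freeze() can make the learning rate tiny by subtracting 10 from the
+    // raw parameter rather than from the mapped value.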
+
+    public void train(TrainingExample example) {
+      wrapped.train(example.getKey(), example.getGroupKey(), example.getActual(), example.getInstance());
+    }
+
+    public CrossFoldLearner getLearner() {
+      return wrapped;
+    }
+
+    @Override
+    public String toString() {
+      return String.format(Locale.ENGLISH, "auc=%.2f", wrapped.auc());
+    }
+
+    public void setAucEvaluator(OnlineAuc auc) {
+      wrapped.setAucEvaluator(auc);
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+      wrapped.write(out);
+    }
+
+    @Override
+    public void readFields(DataInput input) throws IOException {
+      wrapped = new CrossFoldLearner();
+      wrapped.readFields(input);
+    }
+  }
+
+  public static class TrainingExample implements Writable {
+    private long key;
+    private String groupKey;
+    private int actual;
+    private Vector instance;
+
+    private TrainingExample() {
+    }
+
+    public TrainingExample(long key, String groupKey, int actual, Vector instance) {
+      this.key = key;
+      this.groupKey = groupKey;
+      this.actual = actual;
+      this.instance = instance;
+    }
+
+    public long getKey() {
+      return key;
+    }
+
+    public int getActual() {
+      return actual;
+    }
+
+    public Vector getInstance() {
+      return instance;
+    }
+
+    public String getGroupKey() {
+      return groupKey;
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+      out.writeLong(key);
+      if (groupKey != null) {
+        out.writeBoolean(true);
+        out.writeUTF(groupKey);
+      } else {
+        out.writeBoolean(false);
+      }
+      out.writeInt(actual);
+      VectorWritable.writeVector(out, instance, true);
+    }
+
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      key = in.readLong();
+      if (in.readBoolean()) {
+        groupKey = in.readUTF();
+      }
+      actual = in.readInt();
+      instance = VectorWritable.readVector(in);
+    }
+  }
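+
+  // Serialization note (illustrative): write() emits a boolean null-flag
+  // before groupKey, so a write()/readFields() round trip through a
+  // DataOutput/DataInput pair reproduces key, groupKey (possibly null),
+  // actual and instance exactly.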
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(record);
+    out.writeInt(cutoff);
+    out.writeInt(minInterval);
+    out.writeInt(maxInterval);
+    out.writeInt(currentStep);
+    out.writeInt(bufferSize);
+
+    out.writeInt(buffer.size());
+    for (TrainingExample example : buffer) {
+      example.write(out);
+    }
+
+    ep.write(out);
+
+    best.write(out);
+
+    out.writeInt(threadCount);
+    out.writeInt(poolSize);
+    seed.write(out);
+    out.writeInt(numFeatures);
+
+    out.writeBoolean(freezeSurvivors);
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    record = in.readInt();
+    cutoff = in.readInt();
+    minInterval = in.readInt();
+    maxInterval = in.readInt();
+    currentStep = in.readInt();
+    bufferSize = in.readInt();
+
+    int n = in.readInt();
+    buffer = new ArrayList<>();
+    for (int i = 0; i < n; i++) {
+      TrainingExample example = new TrainingExample();
+      example.readFields(in);
+      buffer.add(example);
+    }
+
+    ep = new EvolutionaryProcess<>();
+    ep.readFields(in);
+
+    best = new State<>();
+    best.readFields(in);
+
+    threadCount = in.readInt();
+    poolSize = in.readInt();
+    seed = new State<>();
+    seed.readFields(in);
+
+    numFeatures = in.readInt();
+    freezeSurvivors = in.readBoolean();
+  }
+}
+