Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/28 14:54:39 UTC
[11/51] [partial] mahout git commit: NO-JIRA Clean up MR refactor
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
new file mode 100644
index 0000000..a1cd3e0
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
@@ -0,0 +1,488 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Collection;
+import java.util.Iterator;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Class containing several algorithms used to train a Hidden Markov Model. The
+ * three main algorithms are: supervised learning, unsupervised Viterbi and
+ * unsupervised Baum-Welch.
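+ *
+ * <p>A minimal usage sketch (hypothetical state counts and sequences; the method
+ * signatures are the ones defined in this class):</p>
+ *
+ * <pre>{@code
+ * int[] observed = {0, 1, 2, 1};
+ * int[] hidden   = {0, 0, 1, 1};
+ * // supervised estimate from one labeled sequence
+ * HmmModel model = HmmTrainer.trainSupervised(2, 3, observed, hidden, 0.01);
+ * // unsupervised refinement via Baum-Welch (epsilon, max iterations, log-scaled)
+ * HmmModel refined = HmmTrainer.trainBaumWelch(model, observed, 1e-4, 10, true);
+ * }</pre>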
+ */
+public final class HmmTrainer {
+
+ /**
+ * No public constructor for utility classes.
+ */
+ private HmmTrainer() {
+ // nothing to do here really.
+ }
+
+ /**
+   * Create a supervised initial estimate of an HMM model based on a sequence
+ * of observed and hidden states.
+ *
+ * @param nrOfHiddenStates The total number of hidden states
+ * @param nrOfOutputStates The total number of output states
+ * @param observedSequence Integer array containing the observed sequence
+ * @param hiddenSequence Integer array containing the hidden sequence
+ * @param pseudoCount Value that is assigned to non-occurring transitions to avoid zero
+ * probabilities.
+ * @return An initial model using the estimated parameters
+ */
+ public static HmmModel trainSupervised(int nrOfHiddenStates, int nrOfOutputStates, int[] observedSequence,
+ int[] hiddenSequence, double pseudoCount) {
+ // make sure the pseudo count is not zero
+ pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;
+
+ // initialize the parameters
+ DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
+ DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
+ // assign a small initial probability that is larger than zero, so
+ // unseen states will not get a zero probability
+ transitionMatrix.assign(pseudoCount);
+ emissionMatrix.assign(pseudoCount);
+ // given no prior knowledge, we have to assume that all initial hidden
+ // states are equally likely
+ DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);
+ initialProbabilities.assign(1.0 / nrOfHiddenStates);
+
+ // now loop over the sequences to count the number of transitions
+ countTransitions(transitionMatrix, emissionMatrix, observedSequence,
+ hiddenSequence);
+
+ // make sure that probabilities are normalized
+ for (int i = 0; i < nrOfHiddenStates; i++) {
+ // compute sum of probabilities for current row of transition matrix
+ double sum = 0;
+ for (int j = 0; j < nrOfHiddenStates; j++) {
+ sum += transitionMatrix.getQuick(i, j);
+ }
+ // normalize current row of transition matrix
+ for (int j = 0; j < nrOfHiddenStates; j++) {
+ transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
+ }
+ // compute sum of probabilities for current row of emission matrix
+ sum = 0;
+ for (int j = 0; j < nrOfOutputStates; j++) {
+ sum += emissionMatrix.getQuick(i, j);
+ }
+ // normalize current row of emission matrix
+ for (int j = 0; j < nrOfOutputStates; j++) {
+ emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
+ }
+ }
+
+ // return a new model using the parameter estimations
+ return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
+ }
+
+ /**
+ * Function that counts the number of state->state and state->output
+ * transitions for the given observed/hidden sequence.
+ *
+ * @param transitionMatrix transition matrix to use.
+ * @param emissionMatrix emission matrix to use for counting.
+ * @param observedSequence observation sequence to use.
+ * @param hiddenSequence sequence of hidden states to use.
+ */
+ private static void countTransitions(Matrix transitionMatrix,
+ Matrix emissionMatrix, int[] observedSequence, int[] hiddenSequence) {
+ emissionMatrix.setQuick(hiddenSequence[0], observedSequence[0],
+ emissionMatrix.getQuick(hiddenSequence[0], observedSequence[0]) + 1);
+ for (int i = 1; i < observedSequence.length; ++i) {
+ transitionMatrix
+ .setQuick(hiddenSequence[i - 1], hiddenSequence[i], transitionMatrix
+ .getQuick(hiddenSequence[i - 1], hiddenSequence[i]) + 1);
+ emissionMatrix.setQuick(hiddenSequence[i], observedSequence[i],
+ emissionMatrix.getQuick(hiddenSequence[i], observedSequence[i]) + 1);
+ }
+ }
+
+ /**
+   * Create a supervised initial estimate of an HMM model based on a number of
+ * sequences of observed and hidden states.
+ *
+ * @param nrOfHiddenStates The total number of hidden states
+ * @param nrOfOutputStates The total number of output states
+ * @param hiddenSequences Collection of hidden sequences to use for training
+ * @param observedSequences Collection of observed sequences to use for training associated with hidden sequences.
+ * @param pseudoCount Value that is assigned to non-occurring transitions to avoid zero
+ * probabilities.
+ * @return An initial model using the estimated parameters
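+   *
+   * <p>Sketch (hypothetical sequences): {@code
+   * HmmModel model = HmmTrainer.trainSupervisedSequence(2, 3,
+   *     Arrays.asList(hidden1, hidden2), Arrays.asList(observed1, observed2), 0.01);}</p>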
+ */
+ public static HmmModel trainSupervisedSequence(int nrOfHiddenStates,
+ int nrOfOutputStates, Collection<int[]> hiddenSequences,
+ Collection<int[]> observedSequences, double pseudoCount) {
+
+ // make sure the pseudo count is not zero
+ pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;
+
+ // initialize parameters
+ DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates,
+ nrOfHiddenStates);
+ DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates,
+ nrOfOutputStates);
+ DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);
+
+ // assign pseudo count to avoid zero probabilities
+ transitionMatrix.assign(pseudoCount);
+ emissionMatrix.assign(pseudoCount);
+ initialProbabilities.assign(pseudoCount);
+
+ // now loop over the sequences to count the number of transitions
+ Iterator<int[]> hiddenSequenceIt = hiddenSequences.iterator();
+ Iterator<int[]> observedSequenceIt = observedSequences.iterator();
+ while (hiddenSequenceIt.hasNext() && observedSequenceIt.hasNext()) {
+ // fetch the current set of sequences
+ int[] hiddenSequence = hiddenSequenceIt.next();
+ int[] observedSequence = observedSequenceIt.next();
+ // increase the count for initial probabilities
+ initialProbabilities.setQuick(hiddenSequence[0], initialProbabilities
+ .getQuick(hiddenSequence[0]) + 1);
+ countTransitions(transitionMatrix, emissionMatrix, observedSequence,
+ hiddenSequence);
+ }
+
+ // make sure that probabilities are normalized
+ double isum = 0; // sum of initial probabilities
+ for (int i = 0; i < nrOfHiddenStates; i++) {
+ isum += initialProbabilities.getQuick(i);
+ // compute sum of probabilities for current row of transition matrix
+ double sum = 0;
+ for (int j = 0; j < nrOfHiddenStates; j++) {
+ sum += transitionMatrix.getQuick(i, j);
+ }
+ // normalize current row of transition matrix
+ for (int j = 0; j < nrOfHiddenStates; j++) {
+ transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
+ }
+ // compute sum of probabilities for current row of emission matrix
+ sum = 0;
+ for (int j = 0; j < nrOfOutputStates; j++) {
+ sum += emissionMatrix.getQuick(i, j);
+ }
+ // normalize current row of emission matrix
+ for (int j = 0; j < nrOfOutputStates; j++) {
+ emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
+ }
+ }
+ // normalize the initial probabilities
+ for (int i = 0; i < nrOfHiddenStates; ++i) {
+ initialProbabilities.setQuick(i, initialProbabilities.getQuick(i) / isum);
+ }
+
+ // return a new model using the parameter estimates
+ return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
+ }
+
+ /**
+   * Iteratively train the parameters of the given initial model with respect
+   * to the observed sequence using Viterbi training.
+ *
+ * @param initialModel The initial model that gets iterated
+ * @param observedSequence The sequence of observed states
+ * @param pseudoCount Value that is assigned to non-occurring transitions to avoid zero
+ * probabilities.
+ * @param epsilon Convergence criteria
+ * @param maxIterations The maximum number of training iterations
+   * @param scaled Use the log-scaled implementation; this is computationally more
+   *          expensive but offers better numerical stability for long observed
+   *          sequences
+ * @return The iterated model
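+   *
+   * <p>Sketch (hypothetical inputs): {@code
+   * HmmModel trained = HmmTrainer.trainViterbi(model, observed, 0.01, 1e-4, 10, true);}</p>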
+ */
+ public static HmmModel trainViterbi(HmmModel initialModel,
+ int[] observedSequence, double pseudoCount, double epsilon,
+ int maxIterations, boolean scaled) {
+
+ // make sure the pseudo count is not zero
+ pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;
+
+ // allocate space for iteration models
+ HmmModel lastIteration = initialModel.clone();
+ HmmModel iteration = initialModel.clone();
+
+ // allocate space for Viterbi path calculation
+ int[] viterbiPath = new int[observedSequence.length];
+ int[][] phi = new int[observedSequence.length - 1][initialModel
+ .getNrOfHiddenStates()];
+ double[][] delta = new double[observedSequence.length][initialModel
+ .getNrOfHiddenStates()];
+
+ // now run the Viterbi training iteration
+ for (int i = 0; i < maxIterations; ++i) {
+ // compute the Viterbi path
+ HmmAlgorithms.viterbiAlgorithm(viterbiPath, delta, phi, lastIteration,
+ observedSequence, scaled);
+ // Viterbi iteration uses the viterbi path to update
+ // the probabilities
+ Matrix emissionMatrix = iteration.getEmissionMatrix();
+ Matrix transitionMatrix = iteration.getTransitionMatrix();
+
+ // first, assign the pseudo count
+ emissionMatrix.assign(pseudoCount);
+ transitionMatrix.assign(pseudoCount);
+
+ // now count the transitions
+ countTransitions(transitionMatrix, emissionMatrix, observedSequence,
+ viterbiPath);
+
+ // and normalize the probabilities
+ for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+ double sum = 0;
+ // normalize the rows of the transition matrix
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+ sum += transitionMatrix.getQuick(j, k);
+ }
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+ transitionMatrix
+ .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum);
+ }
+ // normalize the rows of the emission matrix
+ sum = 0;
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+ sum += emissionMatrix.getQuick(j, k);
+ }
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+ emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum);
+ }
+ }
+ // check for convergence
+ if (checkConvergence(lastIteration, iteration, epsilon)) {
+ break;
+ }
+ // overwrite the last iterated model by the new iteration
+ lastIteration.assign(iteration);
+ }
+ // we are done :)
+ return iteration;
+ }
+
+ /**
+   * Iteratively train the parameters of the given initial model with respect
+   * to the observed sequence using Baum-Welch training.
+ *
+ * @param initialModel The initial model that gets iterated
+ * @param observedSequence The sequence of observed states
+ * @param epsilon Convergence criteria
+ * @param maxIterations The maximum number of training iterations
+   * @param scaled Use log-scaled implementations of the forward/backward algorithm. This
+ * is computationally more expensive, but offers better numerical
+ * stability for long output sequences.
+ * @return The iterated model
+ */
+ public static HmmModel trainBaumWelch(HmmModel initialModel,
+ int[] observedSequence, double epsilon, int maxIterations, boolean scaled) {
+ // allocate space for the iterations
+ HmmModel lastIteration = initialModel.clone();
+ HmmModel iteration = initialModel.clone();
+
+ // allocate space for baum-welch factors
+ int hiddenCount = initialModel.getNrOfHiddenStates();
+ int visibleCount = observedSequence.length;
+ Matrix alpha = new DenseMatrix(visibleCount, hiddenCount);
+ Matrix beta = new DenseMatrix(visibleCount, hiddenCount);
+
+    // now run the Baum-Welch training iteration
+ for (int it = 0; it < maxIterations; ++it) {
+ // fetch emission and transition matrix of current iteration
+ Vector initialProbabilities = iteration.getInitialProbabilities();
+ Matrix emissionMatrix = iteration.getEmissionMatrix();
+ Matrix transitionMatrix = iteration.getTransitionMatrix();
+
+ // compute forward and backward factors
+ HmmAlgorithms.forwardAlgorithm(alpha, iteration, observedSequence, scaled);
+ HmmAlgorithms.backwardAlgorithm(beta, iteration, observedSequence, scaled);
+
+ if (scaled) {
+ logScaledBaumWelch(observedSequence, iteration, alpha, beta);
+ } else {
+ unscaledBaumWelch(observedSequence, iteration, alpha, beta);
+ }
+      // normalize the transition/emission probabilities
+ double isum = 0;
+ for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+ double sum = 0;
+ // normalize the rows of the transition matrix
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+ sum += transitionMatrix.getQuick(j, k);
+ }
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+ transitionMatrix
+ .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum);
+ }
+ // normalize the rows of the emission matrix
+ sum = 0;
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+ sum += emissionMatrix.getQuick(j, k);
+ }
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+ emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum);
+ }
+ // normalization parameter for initial probabilities
+ isum += initialProbabilities.getQuick(j);
+ }
+ // normalize initial probabilities
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ initialProbabilities.setQuick(i, initialProbabilities.getQuick(i)
+ / isum);
+ }
+ // check for convergence
+ if (checkConvergence(lastIteration, iteration, epsilon)) {
+ break;
+ }
+ // overwrite the last iterated model by the new iteration
+ lastIteration.assign(iteration);
+ }
+ // we are done :)
+ return iteration;
+ }
+
+ private static void unscaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
+ Vector initialProbabilities = iteration.getInitialProbabilities();
+ Matrix emissionMatrix = iteration.getEmissionMatrix();
+ Matrix transitionMatrix = iteration.getTransitionMatrix();
+ double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, false);
+
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ initialProbabilities.setQuick(i, alpha.getQuick(0, i)
+ * beta.getQuick(0, i));
+ }
+
+ // recompute transition probabilities
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+ double temp = 0;
+ for (int t = 0; t < observedSequence.length - 1; ++t) {
+ temp += alpha.getQuick(t, i)
+ * emissionMatrix.getQuick(j, observedSequence[t + 1])
+ * beta.getQuick(t + 1, j);
+ }
+ transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j)
+ * temp / modelLikelihood);
+ }
+ }
+ // recompute emission probabilities
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) {
+ double temp = 0;
+ for (int t = 0; t < observedSequence.length; ++t) {
+ // delta tensor
+ if (observedSequence[t] == j) {
+ temp += alpha.getQuick(t, i) * beta.getQuick(t, i);
+ }
+ }
+ emissionMatrix.setQuick(i, j, temp / modelLikelihood);
+ }
+ }
+ }
+
+ private static void logScaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
+ Vector initialProbabilities = iteration.getInitialProbabilities();
+ Matrix emissionMatrix = iteration.getEmissionMatrix();
+ Matrix transitionMatrix = iteration.getTransitionMatrix();
+ double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, true);
+
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ initialProbabilities.setQuick(i, Math.exp(alpha.getQuick(0, i) + beta.getQuick(0, i)));
+ }
+
+ // recompute transition probabilities
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+ double sum = Double.NEGATIVE_INFINITY; // log(0)
+ for (int t = 0; t < observedSequence.length - 1; ++t) {
+ double temp = alpha.getQuick(t, i)
+ + Math.log(emissionMatrix.getQuick(j, observedSequence[t + 1]))
+ + beta.getQuick(t + 1, j);
+ if (temp > Double.NEGATIVE_INFINITY) {
+ // handle 0-probabilities
+ sum = temp + Math.log1p(Math.exp(sum - temp));
+ }
+ }
+ transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j)
+ * Math.exp(sum - modelLikelihood));
+ }
+ }
+ // recompute emission probabilities
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) {
+ double sum = Double.NEGATIVE_INFINITY; // log(0)
+ for (int t = 0; t < observedSequence.length; ++t) {
+ // delta tensor
+ if (observedSequence[t] == j) {
+ double temp = alpha.getQuick(t, i) + beta.getQuick(t, i);
+ if (temp > Double.NEGATIVE_INFINITY) {
+ // handle 0-probabilities
+ sum = temp + Math.log1p(Math.exp(sum - temp));
+ }
+ }
+ }
+ emissionMatrix.setQuick(i, j, Math.exp(sum - modelLikelihood));
+ }
+ }
+ }
+
+ /**
+   * Check convergence of two HMM models by computing a simple distance between
+   * emission / transition matrices: the sum of the Frobenius norms of the
+   * matrix differences is compared against epsilon.
+ *
+ * @param oldModel Old HMM Model
+ * @param newModel New HMM Model
+ * @param epsilon Convergence Factor
+ * @return true if training converged to a stable state.
+ */
+ private static boolean checkConvergence(HmmModel oldModel, HmmModel newModel,
+ double epsilon) {
+ // check convergence of transitionProbabilities
+ Matrix oldTransitionMatrix = oldModel.getTransitionMatrix();
+ Matrix newTransitionMatrix = newModel.getTransitionMatrix();
+ double diff = 0;
+ for (int i = 0; i < oldModel.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < oldModel.getNrOfHiddenStates(); ++j) {
+ double tmp = oldTransitionMatrix.getQuick(i, j)
+ - newTransitionMatrix.getQuick(i, j);
+ diff += tmp * tmp;
+ }
+ }
+ double norm = Math.sqrt(diff);
+ diff = 0;
+ // check convergence of emissionProbabilities
+ Matrix oldEmissionMatrix = oldModel.getEmissionMatrix();
+ Matrix newEmissionMatrix = newModel.getEmissionMatrix();
+ for (int i = 0; i < oldModel.getNrOfHiddenStates(); i++) {
+ for (int j = 0; j < oldModel.getNrOfOutputStates(); j++) {
+
+ double tmp = oldEmissionMatrix.getQuick(i, j)
+ - newEmissionMatrix.getQuick(i, j);
+ diff += tmp * tmp;
+ }
+ }
+ norm += Math.sqrt(diff);
+ // iteration has converged :)
+ return norm < epsilon;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
new file mode 100644
index 0000000..e710816
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
@@ -0,0 +1,360 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * A collection of utilities for handling HmmModel objects.
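+ *
+ * <p>A small sketch of encoding/decoding state names (hypothetical model and
+ * state names; the methods are defined in this class):</p>
+ *
+ * <pre>{@code
+ * int[] ids = HmmUtils.encodeStateSequence(model,
+ *     Arrays.asList("rainy", "sunny"), false, 0);      // hidden-state names -> IDs
+ * List<String> names = HmmUtils.decodeStateSequence(model, ids, false, "unknown");
+ * }</pre>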
+ */
+public final class HmmUtils {
+
+ /**
+ * No public constructor for utility classes.
+ */
+ private HmmUtils() {
+ // nothing to do here really.
+ }
+
+ /**
+   * Compute the cumulative transition probability matrix for the given HMM
+   * model. The result is a matrix in which each row i holds the cumulative
+   * distribution of the transition probabilities for hidden state i.
+ *
+ * @param model The HMM model for which the cumulative transition matrix should be
+ * computed
+ * @return The computed cumulative transition matrix.
+ */
+ public static Matrix getCumulativeTransitionMatrix(HmmModel model) {
+ // fetch the needed parameters from the model
+ int hiddenStates = model.getNrOfHiddenStates();
+ Matrix transitionMatrix = model.getTransitionMatrix();
+ // now compute the cumulative transition matrix
+ Matrix resultMatrix = new DenseMatrix(hiddenStates, hiddenStates);
+ for (int i = 0; i < hiddenStates; ++i) {
+ double sum = 0;
+ for (int j = 0; j < hiddenStates; ++j) {
+ sum += transitionMatrix.get(i, j);
+ resultMatrix.set(i, j, sum);
+ }
+      // make sure the last state always has a cumulative probability of exactly 1.0
+      resultMatrix.set(i, hiddenStates - 1, 1.0);
+ return resultMatrix;
+ }
+
+ /**
+   * Compute the cumulative output probability matrix for the given HMM model.
+   * The result is a matrix in which each row i holds the cumulative
+   * distribution of the output probabilities for hidden state i.
+ *
+ * @param model The HMM model for which the cumulative output matrix should be
+ * computed
+ * @return The computed cumulative output matrix.
+ */
+ public static Matrix getCumulativeOutputMatrix(HmmModel model) {
+ // fetch the needed parameters from the model
+ int hiddenStates = model.getNrOfHiddenStates();
+ int outputStates = model.getNrOfOutputStates();
+ Matrix outputMatrix = model.getEmissionMatrix();
+ // now compute the cumulative output matrix
+ Matrix resultMatrix = new DenseMatrix(hiddenStates, outputStates);
+ for (int i = 0; i < hiddenStates; ++i) {
+ double sum = 0;
+ for (int j = 0; j < outputStates; ++j) {
+ sum += outputMatrix.get(i, j);
+ resultMatrix.set(i, j, sum);
+ }
+      // make sure the last output state always has a cumulative probability of 1.0
+      resultMatrix.set(i, outputStates - 1, 1.0);
+ }
+ return resultMatrix;
+ }
+
+ /**
+ * Compute the cumulative distribution of the initial hidden state
+ * probabilities for the given HMM model.
+ *
+ * @param model The HMM model for which the cumulative initial state probabilities
+ * should be computed
+ * @return The computed cumulative initial state probability vector.
+ */
+ public static Vector getCumulativeInitialProbabilities(HmmModel model) {
+ // fetch the needed parameters from the model
+ int hiddenStates = model.getNrOfHiddenStates();
+ Vector initialProbabilities = model.getInitialProbabilities();
+ // now compute the cumulative output matrix
+ Vector resultVector = new DenseVector(initialProbabilities.size());
+ double sum = 0;
+ for (int i = 0; i < hiddenStates; ++i) {
+ sum += initialProbabilities.get(i);
+ resultVector.set(i, sum);
+ }
+    // make sure the last initial hidden state always has a cumulative probability of 1.0
+    resultVector.set(hiddenStates - 1, 1.0);
+ return resultVector;
+ }
+
+ /**
+   * Validates an HMM model's parameter set.
+ *
+ * @param model model to sanity check.
+ */
+ public static void validate(HmmModel model) {
+ if (model == null) {
+ return; // empty models are valid
+ }
+
+ /*
+ * The number of hidden states is positive.
+ */
+ Preconditions.checkArgument(model.getNrOfHiddenStates() > 0,
+ "Error: The number of hidden states has to be greater than 0");
+
+ /*
+ * The number of output states is positive.
+ */
+ Preconditions.checkArgument(model.getNrOfOutputStates() > 0,
+ "Error: The number of output states has to be greater than 0!");
+
+ /*
+     * The size of the vector of initial probabilities is equal to the number
+     * of hidden states. Each initial probability is non-negative. The sum of
+ * initial probabilities is equal to 1.
+ */
+ Preconditions.checkArgument(model.getInitialProbabilities() != null
+ && model.getInitialProbabilities().size() == model.getNrOfHiddenStates(),
+ "Error: The vector of initial probabilities is not initialized!");
+
+ double sum = 0;
+ for (int i = 0; i < model.getInitialProbabilities().size(); i++) {
+ Preconditions.checkArgument(model.getInitialProbabilities().get(i) >= 0,
+ "Error: Initial probability of state %d is negative", i);
+ sum += model.getInitialProbabilities().get(i);
+ }
+ Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001,
+ "Error: Initial probabilities do not add up to 1");
+ /*
+     * The row size of the output matrix is equal to the number of hidden
+     * states. The column size is equal to the number of output states. Each
+ * probability of the matrix is non-negative. The sum of each row is equal
+ * to 1.
+ */
+ Preconditions.checkNotNull(model.getEmissionMatrix(), "Error: The output state matrix is not initialized!");
+ Preconditions.checkArgument(model.getEmissionMatrix().numRows() == model.getNrOfHiddenStates()
+ && model.getEmissionMatrix().numCols() == model.getNrOfOutputStates(),
+ "Error: The output state matrix is not of the form nrOfHiddenStates x nrOfOutputStates");
+ for (int i = 0; i < model.getEmissionMatrix().numRows(); i++) {
+ sum = 0;
+ for (int j = 0; j < model.getEmissionMatrix().numCols(); j++) {
+ Preconditions.checkArgument(model.getEmissionMatrix().get(i, j) >= 0,
+ "The output state probability from hidden state " + i + " to output state " + j + " is negative");
+ sum += model.getEmissionMatrix().get(i, j);
+ }
+ Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001,
+ "Error: The output state probabilities for hidden state %d don't add up to 1", i);
+ }
+
+ /*
+     * Both dimensions of the transition matrix are equal to the number of
+     * hidden states. Each probability of the matrix is non-negative. The sum
+     * of each row in the transition matrix is equal to 1.
+ */
+ Preconditions.checkArgument(model.getTransitionMatrix() != null,
+ "Error: The hidden state matrix is not initialized!");
+ Preconditions.checkArgument(model.getTransitionMatrix().numRows() == model.getNrOfHiddenStates()
+ && model.getTransitionMatrix().numCols() == model.getNrOfHiddenStates(),
+ "Error: The output state matrix is not of the form nrOfHiddenStates x nrOfHiddenStates");
+ for (int i = 0; i < model.getTransitionMatrix().numRows(); i++) {
+ sum = 0;
+ for (int j = 0; j < model.getTransitionMatrix().numCols(); j++) {
+ Preconditions.checkArgument(model.getTransitionMatrix().get(i, j) >= 0,
+ "Error: The transition probability from hidden state %d to hidden state %d is negative", i, j);
+ sum += model.getTransitionMatrix().get(i, j);
+ }
+ Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001,
+ "Error: The transition probabilities for hidden state " + i + " don't add up to 1.");
+ }
+ }
+
+ /**
+ * Encodes a given collection of state names by the corresponding state IDs
+ * registered in a given model.
+ *
+ * @param model Model to provide the encoding for
+ * @param sequence Collection of state names
+ * @param observed If set, the sequence is encoded as a sequence of observed states,
+ * else it is encoded as sequence of hidden states
+ * @param defaultValue The default value in case a state is not known
+ * @return integer array containing the encoded state IDs
+ */
+ public static int[] encodeStateSequence(HmmModel model,
+ Collection<String> sequence, boolean observed, int defaultValue) {
+ int[] encoded = new int[sequence.size()];
+ Iterator<String> seqIter = sequence.iterator();
+ for (int i = 0; i < sequence.size(); ++i) {
+ String nextState = seqIter.next();
+ int nextID;
+ if (observed) {
+ nextID = model.getOutputStateID(nextState);
+ } else {
+ nextID = model.getHiddenStateID(nextState);
+ }
+ // if the ID is -1, use the default value
+ encoded[i] = nextID < 0 ? defaultValue : nextID;
+ }
+ return encoded;
+ }
+
+ /**
+ * Decodes a given collection of state IDs into the corresponding state names
+ * registered in a given model.
+ *
+ * @param model model to use for retrieving state names
+ * @param sequence int array of state IDs
+ * @param observed If set, the sequence is encoded as a sequence of observed states,
+ * else it is encoded as sequence of hidden states
+ * @param defaultValue The default value in case a state is not known
+ * @return list containing the decoded state names
+ */
+ public static List<String> decodeStateSequence(HmmModel model,
+ int[] sequence,
+ boolean observed,
+ String defaultValue) {
+ List<String> decoded = new ArrayList<>(sequence.length);
+ for (int position : sequence) {
+ String nextState;
+ if (observed) {
+ nextState = model.getOutputStateName(position);
+ } else {
+ nextState = model.getHiddenStateName(position);
+ }
+ // if null was returned, use the default value
+ decoded.add(nextState == null ? defaultValue : nextState);
+ }
+ return decoded;
+ }
+
+ /**
+ * Function used to normalize the probabilities of a given HMM model
+ *
+ * @param model model to normalize
+ */
+ public static void normalizeModel(HmmModel model) {
+ Vector ip = model.getInitialProbabilities();
+ Matrix emission = model.getEmissionMatrix();
+ Matrix transition = model.getTransitionMatrix();
+ // check normalization for all probabilities
+ double isum = 0;
+ for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+ isum += ip.getQuick(i);
+ double sum = 0;
+ for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
+ sum += transition.getQuick(i, j);
+ }
+ if (sum != 1.0) {
+ for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
+ transition.setQuick(i, j, transition.getQuick(i, j) / sum);
+ }
+ }
+ sum = 0;
+ for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
+ sum += emission.getQuick(i, j);
+ }
+ if (sum != 1.0) {
+ for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
+ emission.setQuick(i, j, emission.getQuick(i, j) / sum);
+ }
+ }
+ }
+ if (isum != 1.0) {
+ for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+ ip.setQuick(i, ip.getQuick(i) / isum);
+ }
+ }
+ }
+
+ /**
+   * Method to reduce the size of an HmmModel by converting the model's
+   * DenseMatrix/DenseVector instances to sparse implementations and setting
+   * every value below the threshold to 0
+ *
+ * @param model model to truncate
+ * @param threshold minimum value a model entry must have to be retained.
+ * @return Truncated model
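+   *
+   * <p>Sketch: {@code HmmModel small = HmmUtils.truncateModel(model, 1e-5);}</p>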
+ */
+ public static HmmModel truncateModel(HmmModel model, double threshold) {
+ Vector ip = model.getInitialProbabilities();
+ Matrix em = model.getEmissionMatrix();
+ Matrix tr = model.getTransitionMatrix();
+ // allocate the sparse data structures
+ RandomAccessSparseVector sparseIp = new RandomAccessSparseVector(model
+ .getNrOfHiddenStates());
+ SparseMatrix sparseEm = new SparseMatrix(model.getNrOfHiddenStates(), model.getNrOfOutputStates());
+ SparseMatrix sparseTr = new SparseMatrix(model.getNrOfHiddenStates(), model.getNrOfHiddenStates());
+ // now transfer the values
+ for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+ double value = ip.getQuick(i);
+ if (value > threshold) {
+ sparseIp.setQuick(i, value);
+ }
+ for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
+ value = tr.getQuick(i, j);
+ if (value > threshold) {
+ sparseTr.setQuick(i, j, value);
+ }
+ }
+
+ for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
+ value = em.getQuick(i, j);
+ if (value > threshold) {
+ sparseEm.setQuick(i, j, value);
+ }
+ }
+ }
+ // create a new model
+ HmmModel sparseModel = new HmmModel(sparseTr, sparseEm, sparseIp);
+ // normalize the model
+ normalizeModel(sparseModel);
+ // register the names
+ sparseModel.registerHiddenStateNames(model.getHiddenStateNames());
+ sparseModel.registerOutputStateNames(model.getOutputStateNames());
+ // and return
+ return sparseModel;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
new file mode 100644
index 0000000..d0ae9c2
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Utilities for serializing the Writable parts of an HmmModel (i.e., without hidden state names and similar metadata).
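+ *
+ * <p>A round-trip sketch (hypothetical file name; serialize/deserialize are the
+ * methods below):</p>
+ *
+ * <pre>{@code
+ * try (DataOutputStream out = new DataOutputStream(new FileOutputStream("model.bin"))) {
+ *   LossyHmmSerializer.serialize(model, out);
+ * }
+ * try (DataInputStream in = new DataInputStream(new FileInputStream("model.bin"))) {
+ *   HmmModel restored = LossyHmmSerializer.deserialize(in);
+ * }
+ * }</pre>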
+ */
+final class LossyHmmSerializer {
+
+ private LossyHmmSerializer() {
+ }
+
+ static void serialize(HmmModel model, DataOutput output) throws IOException {
+ MatrixWritable matrix = new MatrixWritable(model.getEmissionMatrix());
+ matrix.write(output);
+ matrix.set(model.getTransitionMatrix());
+ matrix.write(output);
+
+ VectorWritable vector = new VectorWritable(model.getInitialProbabilities());
+ vector.write(output);
+ }
+
+ static HmmModel deserialize(DataInput input) throws IOException {
+ MatrixWritable matrix = new MatrixWritable();
+ matrix.readFields(input);
+ Matrix emissionMatrix = matrix.get();
+
+ matrix.readFields(input);
+ Matrix transitionMatrix = matrix.get();
+
+ VectorWritable vector = new VectorWritable();
+ vector.readFields(input);
+ Vector initialProbabilities = vector.get();
+
+ return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java
new file mode 100644
index 0000000..02baef1
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java
@@ -0,0 +1,102 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.DataInputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.common.CommandLineUtil;
+
+/**
+ * Command-line tool for generating random sequences from a given HMM.
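+ *
+ * <p>A hypothetical invocation (option names match the definitions below):</p>
+ *
+ * <pre>{@code
+ * java ... RandomSequenceGenerator --model model.bin --output sequence.txt --length 100
+ * }</pre>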
+ */
+public final class RandomSequenceGenerator {
+
+ private RandomSequenceGenerator() {
+ }
+
+ public static void main(String[] args) throws IOException {
+ DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+ Option outputOption = optionBuilder.withLongName("output").
+ withDescription("Output file with sequence of observed states").
+ withShortName("o").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("path").create()).withRequired(false).create();
+
+ Option modelOption = optionBuilder.withLongName("model").
+ withDescription("Path to serialized HMM model").
+ withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("path").create()).withRequired(true).create();
+
+ Option lengthOption = optionBuilder.withLongName("length").
+ withDescription("Length of generated sequence").
+ withShortName("l").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ Group optionGroup = new GroupBuilder().
+ withOption(outputOption).withOption(modelOption).withOption(lengthOption).
+ withName("Options").create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(optionGroup);
+ CommandLine commandLine = parser.parse(args);
+
+ String output = (String) commandLine.getValue(outputOption);
+
+ String modelPath = (String) commandLine.getValue(modelOption);
+
+ int length = Integer.parseInt((String) commandLine.getValue(lengthOption));
+
+ //reading serialized HMM
+ HmmModel model;
+ try (DataInputStream modelStream = new DataInputStream(new FileInputStream(modelPath))){
+ model = LossyHmmSerializer.deserialize(modelStream);
+ }
+
+ //generating observations
+ int[] observations = HmmEvaluator.predict(model, length, System.currentTimeMillis());
+
+ //writing output
+ try (PrintWriter writer =
+ new PrintWriter(new OutputStreamWriter(new FileOutputStream(output), Charsets.UTF_8), true)){
+ for (int observation : observations) {
+ writer.print(observation);
+ writer.print(' ');
+ }
+ }
+ } catch (OptionException e) {
+ CommandLineUtil.printHelp(optionGroup);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java
new file mode 100644
index 0000000..317237d
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.DataInputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Scanner;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+
+/**
+ * Command-line tool for decoding observed sequences with the Viterbi algorithm.
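+ *
+ * <p>A hypothetical invocation (option names match the definitions below):</p>
+ *
+ * <pre>{@code
+ * java ... ViterbiEvaluator --input observations.txt --output states.txt \
+ *     --model model.bin --likelihood
+ * }</pre>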
+ */
+public final class ViterbiEvaluator {
+
+ private ViterbiEvaluator() {
+ }
+
+ public static void main(String[] args) throws IOException {
+ DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+ Option inputOption = DefaultOptionCreator.inputOption().create();
+
+ Option outputOption = DefaultOptionCreator.outputOption().create();
+
+ Option modelOption = optionBuilder.withLongName("model").
+ withDescription("Path to serialized HMM model").
+ withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("path").create()).withRequired(true).create();
+
+ Option likelihoodOption = optionBuilder.withLongName("likelihood").
+ withDescription("Compute likelihood of observed sequence").
+ withShortName("l").withRequired(false).create();
+
+ Group optionGroup = new GroupBuilder().withOption(inputOption).
+ withOption(outputOption).withOption(modelOption).withOption(likelihoodOption).
+ withName("Options").create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(optionGroup);
+ CommandLine commandLine = parser.parse(args);
+
+ String input = (String) commandLine.getValue(inputOption);
+ String output = (String) commandLine.getValue(outputOption);
+
+ String modelPath = (String) commandLine.getValue(modelOption);
+
+ boolean computeLikelihood = commandLine.hasOption(likelihoodOption);
+
+      //reading serialized HMM
+ HmmModel model;
+ try (DataInputStream modelStream = new DataInputStream(new FileInputStream(modelPath))) {
+ model = LossyHmmSerializer.deserialize(modelStream);
+ }
+
+ //reading observations
+ List<Integer> observations = new ArrayList<>();
+ try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) {
+ while (scanner.hasNextInt()) {
+ observations.add(scanner.nextInt());
+ }
+ }
+
+ int[] observationsArray = new int[observations.size()];
+ for (int i = 0; i < observations.size(); ++i) {
+ observationsArray[i] = observations.get(i);
+ }
+
+ //decoding
+ int[] hiddenStates = HmmEvaluator.decode(model, observationsArray, true);
+
+ //writing output
+ try (PrintWriter writer =
+ new PrintWriter(new OutputStreamWriter(new FileOutputStream(output), Charsets.UTF_8), true)) {
+ for (int hiddenState : hiddenStates) {
+ writer.print(hiddenState);
+ writer.print(' ');
+ }
+ }
+
+ if (computeLikelihood) {
+ System.out.println("Likelihood: " + HmmEvaluator.modelLikelihood(model, observationsArray, true));
+ }
+ } catch (OptionException e) {
+ CommandLineUtil.printHelp(optionGroup);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
new file mode 100644
index 0000000..0b2c41b
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Generic definition of a 1 of n logistic regression classifier that returns probabilities in
+ * response to a feature vector. This classifier uses 1 of n-1 coding where the 0-th category
+ * is not stored explicitly.
+ * <p/>
+ * Provides the SGD based algorithm for learning a logistic regression, but omits all
+ * annealing of learning rates. Any extension of this abstract class must define the overall
+ * and per-term annealing for itself.
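+ * <p/>
+ * A minimal training sketch, assuming the concrete subclass {@code OnlineLogisticRegression}
+ * and the {@code L1} prior from this package, with hypothetical data:
+ * <pre>{@code
+ * OnlineLogisticRegression learner = new OnlineLogisticRegression(2, 1000, new L1());
+ * learner.lambda(1e-5);                         // weight of the prior on beta
+ * learner.train(1, features);                   // actual category plus feature vector
+ * double p = learner.classifyScalar(features);  // scalar probability (two categories)
+ * }</pre>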
+ */
+public abstract class AbstractOnlineLogisticRegression extends AbstractVectorClassifier implements OnlineLearner {
+ // coefficients for the classification. This is a dense matrix
+ // that is (numCategories-1) x numFeatures
+ protected Matrix beta;
+
+  // number of categories we are classifying. This should be the number of rows of beta plus one.
+ protected int numCategories;
+
+ protected int step;
+
+ // information about how long since coefficient rows were updated. This allows lazy regularization.
+ protected Vector updateSteps;
+
+ // information about how many updates we have had on a location. This allows per-term
+ // annealing a la confidence weighted learning.
+ protected Vector updateCounts;
+
+ // weight of the prior on beta
+ private double lambda = 1.0e-5;
+ protected PriorFunction prior;
+
+ // can we ignore any further regularization when doing classification?
+ private boolean sealed;
+
+ // by default we don't do any fancy training
+ private Gradient gradient = new DefaultGradient();
+
+ /**
+ * Chainable configuration option.
+ *
+ * @param lambda New value of lambda, the weighting factor for the prior distribution.
+ * @return This, so other configurations can be chained.
+ */
+ public AbstractOnlineLogisticRegression lambda(double lambda) {
+ this.lambda = lambda;
+ return this;
+ }
+
+ /**
+ * Computes the inverse link function, by default the logistic link function.
+ *
+   * @param v The output of the linear combination in a GLM. Note that v is
+   * modified in place.
+ * @return A version of v with the link function applied.
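+   *         Concretely, as implemented below, element i becomes
+   *         exp(v_i) / (1 + sum_j exp(v_j)); when the maximum element is at
+   *         least 40, it is subtracted first and the leading 1 dropped, since
+   *         it is negligible at that scale.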
+ */
+ public static Vector link(Vector v) {
+ double max = v.maxValue();
+ if (max >= 40) {
+      // if max >= 40, we subtract the large offset first
+ // the size of the max means that 1+sum(exp(v)) = sum(exp(v)) to within round-off
+ v.assign(Functions.minus(max)).assign(Functions.EXP);
+ return v.divide(v.norm(1));
+ } else {
+ v.assign(Functions.EXP);
+ return v.divide(1 + v.norm(1));
+ }
+ }
+
+ /**
+ * Computes the binomial logistic inverse link function.
+ *
+ * @param r The value to transform.
+   * @return The logistic function of r, i.e. 1 / (1 + exp(-r)).
+ */
+ public static double link(double r) {
+ if (r < 0.0) {
+ double s = Math.exp(r);
+ return s / (1.0 + s);
+ } else {
+ double s = Math.exp(-r);
+ return 1.0 / (1.0 + s);
+ }
+ }
+
+ @Override
+ public Vector classifyNoLink(Vector instance) {
+ // apply pending regularization to whichever coefficients matter
+ regularize(instance);
+ return beta.times(instance);
+ }
+
+ public double classifyScalarNoLink(Vector instance) {
+ return beta.viewRow(0).dot(instance);
+ }
+
+ /**
+ * Returns n-1 probabilities, one for each category but the 0-th. The probability of the 0-th
+ * category is 1 - sum(this result).
+ *
+ * @param instance A vector of features to be classified.
+ * @return A vector of probabilities, one for each of the first n-1 categories.
+ */
+ @Override
+ public Vector classify(Vector instance) {
+ return link(classifyNoLink(instance));
+ }
+
+ /**
+ * Returns a single scalar probability in the case where we have two categories. Using this
+ * method avoids an extra vector allocation as opposed to calling classify() or an extra two
+ * vector allocations relative to classifyFull().
+ *
+ * @param instance The vector of features to be classified.
+ * @return The probability of the first of two categories.
+ * @throws IllegalArgumentException If the classifier doesn't have two categories.
+ */
+ @Override
+ public double classifyScalar(Vector instance) {
+ Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories");
+
+ // apply pending regularization to whichever coefficients matter
+ regularize(instance);
+
+ // result is a vector with one element so we can just use dot product
+ return link(classifyScalarNoLink(instance));
+ }
+
+ @Override
+ public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+ unseal();
+
+ double learningRate = currentLearningRate();
+
+ // push coefficients back to zero based on the prior
+ regularize(instance);
+
+ // update each row of coefficients according to result
+ Vector gradient = this.gradient.apply(groupKey, actual, instance, this);
+ for (int i = 0; i < numCategories - 1; i++) {
+ double gradientBase = gradient.get(i);
+
+ // then we apply the gradientBase to the resulting element.
+ for (Element updateLocation : instance.nonZeroes()) {
+ int j = updateLocation.index();
+
+ double newValue = beta.getQuick(i, j) + gradientBase * learningRate * perTermLearningRate(j) * instance.get(j);
+ beta.setQuick(i, j, newValue);
+ }
+ }
+
+ // remember that these elements got updated
+ for (Element element : instance.nonZeroes()) {
+ int j = element.index();
+ updateSteps.setQuick(j, getStep());
+ updateCounts.incrementQuick(j, 1);
+ }
+ nextStep();
+
+ }
+
+ @Override
+ public void train(long trackingKey, int actual, Vector instance) {
+ train(trackingKey, null, actual, instance);
+ }
+
+ @Override
+ public void train(int actual, Vector instance) {
+ train(0, null, actual, instance);
+ }
+
+ public void regularize(Vector instance) {
+ if (updateSteps == null || isSealed()) {
+ return;
+ }
+
+ // anneal learning rate
+ double learningRate = currentLearningRate();
+
+ // here we lazily apply the prior to make up for our neglect
+ for (int i = 0; i < numCategories - 1; i++) {
+ for (Element updateLocation : instance.nonZeroes()) {
+ int j = updateLocation.index();
+ double missingUpdates = getStep() - updateSteps.get(j);
+ if (missingUpdates > 0) {
+ double rate = getLambda() * learningRate * perTermLearningRate(j);
+ double newValue = prior.age(beta.get(i, j), missingUpdates, rate);
+ beta.set(i, j, newValue);
+ updateSteps.set(j, getStep());
+ }
+ }
+ }
+ }
+
+ // these two abstract methods are how extensions can modify the basic learning behavior of this object.
+
+ public abstract double perTermLearningRate(int j);
+
+ public abstract double currentLearningRate();
+
+ public void setPrior(PriorFunction prior) {
+ this.prior = prior;
+ }
+
+ public void setGradient(Gradient gradient) {
+ this.gradient = gradient;
+ }
+
+ public PriorFunction getPrior() {
+ return prior;
+ }
+
+ public Matrix getBeta() {
+ close();
+ return beta;
+ }
+
+ public void setBeta(int i, int j, double betaIJ) {
+ beta.set(i, j, betaIJ);
+ }
+
+ @Override
+ public int numCategories() {
+ return numCategories;
+ }
+
+ public int numFeatures() {
+ return beta.numCols();
+ }
+
+ public double getLambda() {
+ return lambda;
+ }
+
+ public int getStep() {
+ return step;
+ }
+
+ protected void nextStep() {
+ step++;
+ }
+
+ public boolean isSealed() {
+ return sealed;
+ }
+
+ protected void unseal() {
+ sealed = false;
+ }
+
+ private void regularizeAll() {
+ Vector all = new DenseVector(beta.numCols());
+ all.assign(1);
+ regularize(all);
+ }
+
+ @Override
+ public void close() {
+ if (!sealed) {
+ step++;
+ regularizeAll();
+ sealed = true;
+ }
+ }
+
+ public void copyFrom(AbstractOnlineLogisticRegression other) {
+    // number of categories we are classifying. This should be the number of rows of beta plus one.
+ Preconditions.checkArgument(numCategories == other.numCategories,
+ "Can't copy unless number of target categories is the same");
+
+ beta.assign(other.beta);
+
+ step = other.step;
+
+ updateSteps.assign(other.updateSteps);
+ updateCounts.assign(other.updateCounts);
+ }
+
+ public boolean validModel() {
+ double k = beta.aggregate(Functions.PLUS, new DoubleFunction() {
+ @Override
+ public double apply(double v) {
+ return Double.isNaN(v) || Double.isInfinite(v) ? 1 : 0;
+ }
+ });
+ return k < 1;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
new file mode 100644
index 0000000..24e5798
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
@@ -0,0 +1,586 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.ep.EvolutionaryProcess;
+import org.apache.mahout.ep.Mapping;
+import org.apache.mahout.ep.Payload;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.stats.OnlineAuc;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.ExecutionException;
+
+/**
+ * This is a meta-learner that maintains a pool of ordinary
+ * {@link org.apache.mahout.classifier.sgd.OnlineLogisticRegression} learners. Each
+ * member of the pool has different learning rates. Whichever of the learners in the pool falls
+ * behind in terms of average log-likelihood will be tossed out and replaced with variants of the
+ * survivors. This will let us automatically derive an annealing schedule that optimizes learning
+ * speed. Since on-line learners tend to be I/O bound anyway, maintaining multiple learners in
+ * memory costs less than it might seem. Doing this adaptation on-line as we learn also decreases
+ * the number of learning-rate parameters required and replaces the usual hyper-parameter search.
+ * <p/>
+ * One wrinkle is that the pool of learners that we maintain is actually a pool of
+ * {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} which themselves contain several OnlineLogisticRegression
+ * objects. These pools allow estimation
+ * of performance on the fly even if we make many passes through the data. This does, however,
+ * increase the cost of training since if we are using 5-fold cross-validation, each vector is used
+ * 4 times for training and once for classification. If this becomes a problem, then we should
+ * probably use a 2-way unbalanced train/test split rather than full cross validation. With the
+ * current default settings, we have 100 learners running. This is better than the alternative of
+ * running hundreds of training passes to find good hyper-parameters because we only have to parse
+ * and featurize our inputs once. If you already have good hyper-parameters, then you might
+ * prefer to just run one CrossFoldLearner with those settings.
+ * <p/>
+ * The fitness used here is AUC. An alternative would be log-likelihood, but it is much easier to
+ * get bogus values of log-likelihood than of AUC, and the two seem to accord pretty well anyway.
+ * It would be nice to allow the fitness function to be pluggable. This use of AUC means that
+ * AdaptiveLogisticRegression is mostly suited for binary target variables. This will be fixed
+ * before long by extending OnlineAuc to handle non-binary cases or by using a different fitness
+ * value in non-binary cases.
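+ * <p/>
+ * A minimal usage sketch (TrainingPair here is a hypothetical holder for a label and its
+ * feature vector; any encoding that produces Vectors will do, and L1 is one of the provided
+ * PriorFunction implementations):
+ * <pre>{@code
+ * AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 1000, new L1());
+ * for (TrainingPair pair : trainingData) {
+ *   learner.train(pair.getActual(), pair.getInstance());
+ * }
+ * learner.close();
+ * double auc = learner.auc();  // AUC of the best learner found so far
+ * }</pre>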
+ */
+public class AdaptiveLogisticRegression implements OnlineLearner, Writable {
+ public static final int DEFAULT_THREAD_COUNT = 20;
+ public static final int DEFAULT_POOL_SIZE = 20;
+ private static final int SURVIVORS = 2;
+
+ private int record;
+ private int cutoff = 1000;
+ private int minInterval = 1000;
+ private int maxInterval = 1000;
+ private int currentStep = 1000;
+ private int bufferSize = 1000;
+
+ private List<TrainingExample> buffer = new ArrayList<>();
+ private EvolutionaryProcess<Wrapper, CrossFoldLearner> ep;
+ private State<Wrapper, CrossFoldLearner> best;
+ private int threadCount = DEFAULT_THREAD_COUNT;
+ private int poolSize = DEFAULT_POOL_SIZE;
+ private State<Wrapper, CrossFoldLearner> seed;
+ private int numFeatures;
+
+ private boolean freezeSurvivors = true;
+
+ private static final Logger log = LoggerFactory.getLogger(AdaptiveLogisticRegression.class);
+
+ public AdaptiveLogisticRegression() {}
+
+ /**
+ * Uses {@link #DEFAULT_THREAD_COUNT} and {@link #DEFAULT_POOL_SIZE}
+ * @param numCategories The number of categories (labels) to train on
+ * @param numFeatures The number of features used in creating the vectors (i.e. the cardinality of the vector)
+ * @param prior The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
+ *
+ * @see #AdaptiveLogisticRegression(int, int, org.apache.mahout.classifier.sgd.PriorFunction, int, int)
+ */
+ public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+ this(numCategories, numFeatures, prior, DEFAULT_THREAD_COUNT, DEFAULT_POOL_SIZE);
+ }
+
+ /**
+ *
+ * @param numCategories The number of categories (labels) to train on
+ * @param numFeatures The number of features used in creating the vectors (i.e. the cardinality of the vector)
+ * @param prior The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
+ * @param threadCount The number of threads to use for training
+ * @param poolSize The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} to use.
+ */
+ public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior, int threadCount,
+ int poolSize) {
+ this.numFeatures = numFeatures;
+ this.threadCount = threadCount;
+ this.poolSize = poolSize;
+ seed = new State<>(new double[2], 10);
+ Wrapper w = new Wrapper(numCategories, numFeatures, prior);
+ seed.setPayload(w);
+
+ Wrapper.setMappings(seed);
+ setPoolSize(this.poolSize);
+ }
+
+ @Override
+ public void train(int actual, Vector instance) {
+ train(record, null, actual, instance);
+ }
+
+ @Override
+ public void train(long trackingKey, int actual, Vector instance) {
+ train(trackingKey, null, actual, instance);
+ }
+
+ @Override
+ public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+ record++;
+
+ buffer.add(new TrainingExample(trackingKey, groupKey, actual, instance));
+ // don't train until we have enough examples
+ if (buffer.size() > bufferSize) {
+ trainWithBufferedExamples();
+ }
+ }
+
+ private void trainWithBufferedExamples() {
+ try {
+ this.best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {
+ @Override
+ public double apply(Payload<CrossFoldLearner> z, double[] params) {
+ Wrapper x = (Wrapper) z;
+ for (TrainingExample example : buffer) {
+ x.train(example);
+ }
+ if (x.getLearner().validModel()) {
+ if (x.getLearner().numCategories() == 2) {
+ return x.wrapped.auc();
+ } else {
+ return x.wrapped.logLikelihood();
+ }
+ } else {
+ return Double.NaN;
+ }
+ }
+ });
+ } catch (InterruptedException e) {
+ // ignore ... shouldn't happen
+ log.warn("Ignoring exception", e);
+ } catch (ExecutionException e) {
+ throw new IllegalStateException(e.getCause());
+ }
+ buffer.clear();
+
+ if (record > cutoff) {
+ cutoff = nextStep(record);
+
+ // evolve based on new fitness
+ ep.mutatePopulation(SURVIVORS);
+
+ if (freezeSurvivors) {
+ // now grossly hack the top survivors so they stick around. Set their
+ // mutation rates small and also hack their learning rate to be small
+ // as well.
+ for (State<Wrapper, CrossFoldLearner> state : ep.getPopulation().subList(0, SURVIVORS)) {
+ Wrapper.freeze(state);
+ }
+ }
+ }
+
+ }
+
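+ /**
+ * Computes the record count at which the next evolutionary step should run. The interval
+ * between steps grows roughly geometrically with the number of records seen, clamped to
+ * [minInterval, maxInterval].
+ *
+ * @param recordNumber The number of training records seen so far.
+ * @return The record count that, once exceeded, triggers the next mutation of the population.
+ */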
+ public int nextStep(int recordNumber) {
+ int stepSize = stepSize(recordNumber, 2.6);
+ if (stepSize < minInterval) {
+ stepSize = minInterval;
+ }
+
+ if (stepSize > maxInterval) {
+ stepSize = maxInterval;
+ }
+
+ int newCutoff = stepSize * (recordNumber / stepSize + 1);
+ if (newCutoff < cutoff + currentStep) {
+ newCutoff = cutoff + currentStep;
+ } else {
+ this.currentStep = stepSize;
+ }
+ return newCutoff;
+ }
+
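+ /**
+ * Returns a step size from the 1-2-5 sequence (1, 2, 5, 10, 20, 50, ...), indexed by
+ * floor(multiplier * log10(recordNumber)). For example, stepSize(100, 2.6) == 50.
+ */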
+ public static int stepSize(int recordNumber, double multiplier) {
+ int[] bumps = {1, 2, 5};
+ double log = Math.floor(multiplier * Math.log10(recordNumber));
+ int bump = bumps[(int) log % bumps.length];
+ int scale = (int) Math.pow(10, Math.floor(log / bumps.length));
+
+ return bump * scale;
+ }
+
+ @Override
+ public void close() {
+ trainWithBufferedExamples();
+ try {
+ ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {
+ @Override
+ public double apply(Payload<CrossFoldLearner> payload, double[] params) {
+ CrossFoldLearner learner = ((Wrapper) payload).getLearner();
+ learner.close();
+ return learner.logLikelihood();
+ }
+ });
+ } catch (InterruptedException e) {
+ log.warn("Ignoring exception", e);
+ } catch (ExecutionException e) {
+ throw new IllegalStateException(e);
+ } finally {
+ ep.close();
+ }
+ }
+
+ /**
+ * How often should the evolutionary optimization of learning parameters occur?
+ *
+ * @param interval Number of training examples to use in each epoch of optimization.
+ */
+ public void setInterval(int interval) {
+ setInterval(interval, interval);
+ }
+
+ /**
+ * Starts optimization using the shorter interval and progresses to the longer using the specified
+ * number of steps per decade. Note that values below 200 are clamped up to 200; even values
+ * that small are unlikely to be useful.
+ *
+ * @param minInterval The minimum epoch length for the evolutionary optimization
+ * @param maxInterval The maximum epoch length
+ */
+ public void setInterval(int minInterval, int maxInterval) {
+ this.minInterval = Math.max(200, minInterval);
+ this.maxInterval = Math.max(200, maxInterval);
+ this.cutoff = minInterval * (record / minInterval + 1);
+ this.currentStep = minInterval;
+ bufferSize = Math.min(minInterval, bufferSize);
+ }
+
+ public final void setPoolSize(int poolSize) {
+ this.poolSize = poolSize;
+ setupOptimizer(poolSize);
+ }
+
+ public void setThreadCount(int threadCount) {
+ this.threadCount = threadCount;
+ setupOptimizer(poolSize);
+ }
+
+ public void setAucEvaluator(OnlineAuc auc) {
+ seed.getPayload().setAucEvaluator(auc);
+ setupOptimizer(poolSize);
+ }
+
+ private void setupOptimizer(int poolSize) {
+ ep = new EvolutionaryProcess<>(threadCount, poolSize, seed);
+ }
+
+ /**
+ * Returns the size of the internal feature vector. Note that this is not the same as the number
+ * of distinct features, especially if feature hashing is being used.
+ *
+ * @return The internal feature vector size.
+ */
+ public int numFeatures() {
+ return numFeatures;
+ }
+
+ /**
+ * Returns the AUC for the current best member of the population. If there is no best member,
+ * usually because we haven't done any training yet, the result is NaN.
+ *
+ * @return The AUC of the best member of the population or NaN if we can't figure that out.
+ */
+ public double auc() {
+ if (best == null) {
+ return Double.NaN;
+ } else {
+ Wrapper payload = best.getPayload();
+ return payload.getLearner().auc();
+ }
+ }
+
+ public State<Wrapper, CrossFoldLearner> getBest() {
+ return best;
+ }
+
+ public void setBest(State<Wrapper, CrossFoldLearner> best) {
+ this.best = best;
+ }
+
+ public int getRecord() {
+ return record;
+ }
+
+ public void setRecord(int record) {
+ this.record = record;
+ }
+
+ public int getMinInterval() {
+ return minInterval;
+ }
+
+ public int getMaxInterval() {
+ return maxInterval;
+ }
+
+ public int getNumCategories() {
+ return seed.getPayload().getLearner().numCategories();
+ }
+
+ public PriorFunction getPrior() {
+ return seed.getPayload().getLearner().getPrior();
+ }
+
+ public void setBuffer(List<TrainingExample> buffer) {
+ this.buffer = buffer;
+ }
+
+ public List<TrainingExample> getBuffer() {
+ return buffer;
+ }
+
+ public EvolutionaryProcess<Wrapper, CrossFoldLearner> getEp() {
+ return ep;
+ }
+
+ public void setEp(EvolutionaryProcess<Wrapper, CrossFoldLearner> ep) {
+ this.ep = ep;
+ }
+
+ public State<Wrapper, CrossFoldLearner> getSeed() {
+ return seed;
+ }
+
+ public void setSeed(State<Wrapper, CrossFoldLearner> seed) {
+ this.seed = seed;
+ }
+
+ public int getNumFeatures() {
+ return numFeatures;
+ }
+
+ public void setAveragingWindow(int averagingWindow) {
+ seed.getPayload().getLearner().setWindowSize(averagingWindow);
+ setupOptimizer(poolSize);
+ }
+
+ public void setFreezeSurvivors(boolean freezeSurvivors) {
+ this.freezeSurvivors = freezeSurvivors;
+ }
+
+ /**
+ * Provides a shim between the EP optimization stuff and the CrossFoldLearner. The most important
+ * interface has to do with the parameters of the optimization. These are taken from the double[]
+ * params in the following order <ul> <li> regularization constant lambda <li> learningRate </ul>.
+ * All other parameters are set so as to defeat annealing to the extent possible.
+ * This lets the evolutionary algorithm handle the annealing.
+ * <p/>
+ * Note that per coefficient annealing is still done and no optimization of the per coefficient
+ * offset is done.
+ */
+ public static class Wrapper implements Payload<CrossFoldLearner> {
+ private CrossFoldLearner wrapped;
+
+ public Wrapper() {
+ }
+
+ public Wrapper(int numCategories, int numFeatures, PriorFunction prior) {
+ wrapped = new CrossFoldLearner(5, numCategories, numFeatures, prior);
+ }
+
+ @Override
+ public Wrapper copy() {
+ Wrapper r = new Wrapper();
+ r.wrapped = wrapped.copy();
+ return r;
+ }
+
+ @Override
+ public void update(double[] params) {
+ int i = 0;
+ wrapped.lambda(params[i++]);
+ wrapped.learningRate(params[i]);
+
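+ // neutralize the built-in annealing (constant step offset, alpha 1, no decay) so the
+ // evolutionary algorithm controls the learning rate instead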
+ wrapped.stepOffset(1);
+ wrapped.alpha(1);
+ wrapped.decayExponent(0);
+ }
+
+ public static void freeze(State<Wrapper, CrossFoldLearner> s) {
+ // radically decrease learning rate
+ double[] params = s.getParams();
+ params[1] -= 10;
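+ // params live in the unbounded EP space and the learning rate is mapped through logLimit
+ // (see setMappings), so subtracting 10 here drastically shrinks the mapped rate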
+
+ // and cause evolution to hold (almost)
+ s.setOmni(s.getOmni() / 20);
+ double[] step = s.getStep();
+ for (int i = 0; i < step.length; i++) {
+ step[i] /= 20;
+ }
+ }
+
+ public static void setMappings(State<Wrapper, CrossFoldLearner> x) {
+ int i = 0;
+ // set the range for regularization (lambda)
+ x.setMap(i++, Mapping.logLimit(1.0e-8, 0.1));
+ // set the range for learning rate (mu)
+ x.setMap(i, Mapping.logLimit(1.0e-8, 1));
+ }
+
+ public void train(TrainingExample example) {
+ wrapped.train(example.getKey(), example.getGroupKey(), example.getActual(), example.getInstance());
+ }
+
+ public CrossFoldLearner getLearner() {
+ return wrapped;
+ }
+
+ @Override
+ public String toString() {
+ return String.format(Locale.ENGLISH, "auc=%.2f", wrapped.auc());
+ }
+
+ public void setAucEvaluator(OnlineAuc auc) {
+ wrapped.setAucEvaluator(auc);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ wrapped.write(out);
+ }
+
+ @Override
+ public void readFields(DataInput input) throws IOException {
+ wrapped = new CrossFoldLearner();
+ wrapped.readFields(input);
+ }
+ }
+
+ public static class TrainingExample implements Writable {
+ private long key;
+ private String groupKey;
+ private int actual;
+ private Vector instance;
+
+ private TrainingExample() {
+ }
+
+ public TrainingExample(long key, String groupKey, int actual, Vector instance) {
+ this.key = key;
+ this.groupKey = groupKey;
+ this.actual = actual;
+ this.instance = instance;
+ }
+
+ public long getKey() {
+ return key;
+ }
+
+ public int getActual() {
+ return actual;
+ }
+
+ public Vector getInstance() {
+ return instance;
+ }
+
+ public String getGroupKey() {
+ return groupKey;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
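+ // wire format: key (long), has-group-key flag (boolean), optional groupKey (UTF),
+ // actual (int), instance (vector)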
+ out.writeLong(key);
+ if (groupKey != null) {
+ out.writeBoolean(true);
+ out.writeUTF(groupKey);
+ } else {
+ out.writeBoolean(false);
+ }
+ out.writeInt(actual);
+ VectorWritable.writeVector(out, instance, true);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ key = in.readLong();
+ if (in.readBoolean()) {
+ groupKey = in.readUTF();
+ }
+ actual = in.readInt();
+ instance = VectorWritable.readVector(in);
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(record);
+ out.writeInt(cutoff);
+ out.writeInt(minInterval);
+ out.writeInt(maxInterval);
+ out.writeInt(currentStep);
+ out.writeInt(bufferSize);
+
+ out.writeInt(buffer.size());
+ for (TrainingExample example : buffer) {
+ example.write(out);
+ }
+
+ ep.write(out);
+
+ best.write(out);
+
+ out.writeInt(threadCount);
+ out.writeInt(poolSize);
+ seed.write(out);
+ out.writeInt(numFeatures);
+
+ out.writeBoolean(freezeSurvivors);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ record = in.readInt();
+ cutoff = in.readInt();
+ minInterval = in.readInt();
+ maxInterval = in.readInt();
+ currentStep = in.readInt();
+ bufferSize = in.readInt();
+
+ int n = in.readInt();
+ buffer = new ArrayList<>();
+ for (int i = 0; i < n; i++) {
+ TrainingExample example = new TrainingExample();
+ example.readFields(in);
+ buffer.add(example);
+ }
+
+ ep = new EvolutionaryProcess<>();
+ ep.readFields(in);
+
+ best = new State<>();
+ best.readFields(in);
+
+ threadCount = in.readInt();
+ poolSize = in.readInt();
+ seed = new State<>();
+ seed.readFields(in);
+
+ numFeatures = in.readInt();
+ freezeSurvivors = in.readBoolean();
+ }
+}
+