You are viewing a plain text version of this content. The hyperlink to the canonical version was not preserved in this plain-text copy.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/09/25 11:51:44 UTC
svn commit: r1001180 [1/2] - in /mahout/trunk:
core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/
core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/
core/src/main/java/org/apache/mahout/classifier/sgd/
core/src/main/java/org/...
Author: srowen
Date: Sat Sep 25 09:51:42 2010
New Revision: 1001180
URL: http://svn.apache.org/viewvc?rev=1001180&view=rev
Log:
Checkstyle/PMD changes, mostly on new HMM code
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJob.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/VarLongWritable.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
mahout/trunk/core/src/main/java/org/apache/mahout/vectors/CachingContinuousValueEncoder.java
mahout/trunk/core/src/main/java/org/apache/mahout/vectors/CachingStaticWordValueEncoder.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java
mahout/trunk/eclipse/src/main/resources/findbugs-exclude.xml
mahout/trunk/etc/findbugs-exclude.xml
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/hadoop/DatasetSplit.java
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
mahout/trunk/examples/src/test/java/org/apache/mahout/examples/MahoutTestCase.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/engine/MersenneTwister.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/sampling/RandomSamplingAssistant.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/sampling/WeightedRandomSampler.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/matrix/DoubleMatrix2D.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/matrix/impl/AbstractMatrix1D.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/matrix/linalg/Algebra.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/matrix/linalg/LUDecompositionQuick.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/MahoutTestCase.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSingularValueDecomposition.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/decomposer/hebbian/TestHebbianSolver.java
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/SequenceFileDumper.java
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/nlp/collocations/llr/GramKey.java
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/vectors/common/PartialVectorMerger.java
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/vectors/lucene/ClusterLabels.java
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/vectors/lucene/Driver.java
mahout/trunk/utils/src/test/java/org/apache/mahout/utils/MahoutTestCase.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJob.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJob.java Sat Sep 25 09:51:42 2010
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
@@ -245,7 +246,7 @@ public final class RecommenderJob extend
FileSystem fs = FileSystem.get(tempDirPath.toUri(), partialMultiplyConf);
prePartialMultiplyPath1 = prePartialMultiplyPath1.makeQualified(fs);
prePartialMultiplyPath2 = prePartialMultiplyPath2.makeQualified(fs);
- SequenceFileInputFormat.setInputPaths(partialMultiply, prePartialMultiplyPath1, prePartialMultiplyPath2);
+ FileInputFormat.setInputPaths(partialMultiply, prePartialMultiplyPath1, prePartialMultiplyPath2);
partialMultiply.waitForCompletion(true);
}
@@ -280,7 +281,7 @@ public final class RecommenderJob extend
FileSystem fs = FileSystem.get(tempDirPath.toUri(), aggregateAndRecommendConf);
partialMultiplyPath = partialMultiplyPath.makeQualified(fs);
explicitFilterPath = explicitFilterPath.makeQualified(fs);
- SequenceFileInputFormat.setInputPaths(aggregateAndRecommend, partialMultiplyPath, explicitFilterPath);
+ FileInputFormat.setInputPaths(aggregateAndRecommend, partialMultiplyPath, explicitFilterPath);
}
setIOSort(aggregateAndRecommend);
aggregateAndRecommendConf.set(AggregateAndRecommendReducer.ITEMID_INDEX_PATH, itemIDIndexPath.toString());
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java Sat Sep 25 09:51:42 2010
@@ -24,8 +24,6 @@ import org.apache.mahout.math.Vector;
/**
* Class containing implementations of the three major HMM algorithms: forward,
* backward and Viterbi
- *
- * @author mheimel
*/
public final class HmmAlgorithms {
@@ -48,8 +46,7 @@ public final class HmmAlgorithms {
public static Matrix forwardAlgorithm(HmmModel model, int[] observations,
boolean scaled) {
- DenseMatrix alpha = new DenseMatrix(observations.length, model
- .getNrOfHiddenStates());
+ Matrix alpha = new DenseMatrix(observations.length, model.getNrOfHiddenStates());
forwardAlgorithm(alpha, model, observations, scaled);
@@ -74,9 +71,9 @@ public final class HmmAlgorithms {
if (scaled) { // compute log scaled alpha values
// Initialization
- for (int i = 0; i < model.getNrOfHiddenStates(); i++)
- alpha.setQuick(0, i, Math.log(ip.getQuick(i)
- * b.getQuick(i, observations[0])));
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ alpha.setQuick(0, i, Math.log(ip.getQuick(i) * b.getQuick(i, observations[0])));
+ }
// Induction
for (int t = 1; t < observations.length; t++) {
@@ -84,10 +81,10 @@ public final class HmmAlgorithms {
double sum = Double.NEGATIVE_INFINITY; // log(0)
for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
double tmp = alpha.getQuick(t - 1, j) + Math.log(a.getQuick(j, i));
- if (tmp > Double.NEGATIVE_INFINITY) // make sure we
- // handle
- // log(0) correctly
+ if (tmp > Double.NEGATIVE_INFINITY) {
+ // make sure we handle log(0) correctly
sum = tmp + Math.log(1 + Math.exp(sum - tmp));
+ }
}
alpha.setQuick(t, i, sum + Math.log(b.getQuick(i, observations[t])));
}
@@ -95,8 +92,9 @@ public final class HmmAlgorithms {
} else {
// Initialization
- for (int i = 0; i < model.getNrOfHiddenStates(); i++)
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
alpha.setQuick(0, i, ip.getQuick(i) * b.getQuick(i, observations[0]));
+ }
// Induction
for (int t = 1; t < observations.length; t++) {
@@ -123,8 +121,7 @@ public final class HmmAlgorithms {
boolean scaled) {
// initialize the matrix
- DenseMatrix beta = new DenseMatrix(observations.length, model
- .getNrOfHiddenStates());
+ Matrix beta = new DenseMatrix(observations.length, model.getNrOfHiddenStates());
// compute the beta factors
backwardAlgorithm(beta, model, observations, scaled);
@@ -147,8 +144,9 @@ public final class HmmAlgorithms {
if (scaled) { // compute log-scaled factors
// initialization
- for (int i = 0; i < model.getNrOfHiddenStates(); i++)
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
beta.setQuick(observations.length - 1, i, 0);
+ }
// induction
for (int t = observations.length - 2; t >= 0; t--) {
@@ -157,23 +155,26 @@ public final class HmmAlgorithms {
for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
double tmp = beta.getQuick(t + 1, j) + Math.log(a.getQuick(i, j))
+ Math.log(b.getQuick(j, observations[t + 1]));
- if (tmp > Double.NEGATIVE_INFINITY) // handle log(0)
+ if (tmp > Double.NEGATIVE_INFINITY) {
+ // handle log(0)
sum = tmp + Math.log(1 + Math.exp(sum - tmp));
+ }
}
beta.setQuick(t, i, sum);
}
}
} else {
// initialization
- for (int i = 0; i < model.getNrOfHiddenStates(); i++)
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
beta.setQuick(observations.length - 1, i, 1);
+ }
// induction
for (int t = observations.length - 2; t >= 0; t--) {
for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
double sum = 0;
- for (int j = 0; j < model.getNrOfHiddenStates(); j++)
- sum += beta.getQuick(t + 1, j) * a.getQuick(i, j)
- * b.getQuick(j, observations[t + 1]);
+ for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
+ sum += beta.getQuick(t + 1, j) * a.getQuick(i, j) * b.getQuick(j, observations[t + 1]);
+ }
beta.setQuick(t, i, sum);
}
}
@@ -292,10 +293,11 @@ public final class HmmAlgorithms {
// find the most likely end state for initialization
double maxProb;
- if (scaled)
+ if (scaled) {
maxProb = Double.NEGATIVE_INFINITY;
- else
- maxProb = 0;
+ } else {
+ maxProb = 0.0;
+ }
for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
if (delta[observations.length - 1][i] > maxProb) {
maxProb = delta[observations.length - 1][i];
@@ -304,8 +306,9 @@ public final class HmmAlgorithms {
}
// now backtrack to find the most likely hidden sequence
- for (int t = observations.length - 2; t >= 0; t--)
+ for (int t = observations.length - 2; t >= 0; t--) {
sequence[t] = phi[t][sequence[t + 1]];
+ }
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java Sat Sep 25 09:51:42 2010
@@ -30,8 +30,6 @@ import org.apache.mahout.math.Vector;
* generated a given sequence of output states (model likelihood). 3) Compute
* the most likely hidden sequence for a given model and a given observed
* sequence (decoding).
- *
- * @author mheimel
*/
public final class HmmEvaluator {
@@ -67,10 +65,11 @@ public final class HmmEvaluator {
public static int[] predict(HmmModel model, int steps, long seed) {
// create the random number generator
Random rand;
- if (seed == 0)
+ if (seed == 0) {
rand = RandomUtils.getRandom();
- else
+ } else {
rand = RandomUtils.getRandom(seed);
+ }
// fetch the cumulative distributions
Vector cip = HmmUtils.getCumulativeInitialProbabilities(model);
Matrix ctm = HmmUtils.getCumulativeTransitionMatrix(model);
@@ -81,8 +80,9 @@ public final class HmmEvaluator {
int hiddenState = 0;
double randnr = rand.nextDouble();
- while (cip.get(hiddenState) < randnr)
+ while (cip.get(hiddenState) < randnr) {
hiddenState++;
+ }
// now draw steps output states according to the cumulative
// distributions
@@ -90,14 +90,16 @@ public final class HmmEvaluator {
// choose output state to given hidden state
randnr = rand.nextDouble();
int outputState = 0;
- while (com.get(hiddenState, outputState) < randnr)
+ while (com.get(hiddenState, outputState) < randnr) {
outputState++;
+ }
result[step] = outputState;
// choose the next hidden state
randnr = rand.nextDouble();
int nextHiddenState = 0;
- while (ctm.get(hiddenState, nextHiddenState) < randnr)
+ while (ctm.get(hiddenState, nextHiddenState) < randnr) {
nextHiddenState++;
+ }
hiddenState = nextHiddenState;
}
return result;
@@ -118,8 +120,7 @@ public final class HmmEvaluator {
*/
public static double modelLikelihood(HmmModel model, int[] outputSequence,
boolean scaled) {
- return modelLikelihood(HmmAlgorithms.forwardAlgorithm(model,
- outputSequence, scaled), scaled);
+ return modelLikelihood(HmmAlgorithms.forwardAlgorithm(model, outputSequence, scaled), scaled);
}
/**
@@ -163,13 +164,11 @@ public final class HmmEvaluator {
int firstOutput = outputSequence[0];
if (scaled) {
for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
- likelihood += pi.getQuick(i) * Math.exp(beta.getQuick(0, i))
- * e.getQuick(i, firstOutput);
+ likelihood += pi.getQuick(i) * Math.exp(beta.getQuick(0, i)) * e.getQuick(i, firstOutput);
}
} else {
for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
- likelihood += pi.getQuick(i) * beta.getQuick(0, i)
- * e.getQuick(i, firstOutput);
+ likelihood += pi.getQuick(i) * beta.getQuick(0, i) * e.getQuick(i, firstOutput);
}
}
return likelihood;
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java Sat Sep 25 09:51:42 2010
@@ -44,8 +44,6 @@ import com.google.gson.reflect.TypeToken
/**
* Main class defining a Hidden Markov Model
- *
- * @author mheimel
*/
public class HmmModel implements JsonDeserializer<HmmModel>,
JsonSerializer<HmmModel>, Cloneable {
@@ -60,14 +58,16 @@ public class HmmModel implements JsonDes
/**
* Get a copy of this model
*/
+ @Override
public HmmModel clone() throws CloneNotSupportedException {
super.clone();
- HmmModel model = new HmmModel(transitionMatrix.clone(), emissionMatrix
- .clone(), initialProbabilities.clone());
- if (hiddenStateNames != null)
+ HmmModel model = new HmmModel(transitionMatrix.clone(), emissionMatrix.clone(), initialProbabilities.clone());
+ if (hiddenStateNames != null) {
model.hiddenStateNames = new TreeBidiMap(hiddenStateNames);
- if (outputStateNames != null)
+ }
+ if (outputStateNames != null) {
model.outputStateNames = new TreeBidiMap(outputStateNames);
+ }
return model;
}
@@ -142,10 +142,11 @@ public class HmmModel implements JsonDes
private void initRandomParameters(long seed) {
Random rand;
// initialize the random number generator
- if (seed == 0)
+ if (seed == 0) {
rand = RandomUtils.getRandom();
- else
+ } else {
rand = RandomUtils.getRandom(seed);
+ }
// initialize the initial Probabilities
double sum = 0; // used for normalization
for (int i = 0; i < nrOfHiddenStates; i++) {
@@ -165,8 +166,9 @@ public class HmmModel implements JsonDes
sum += values[j];
}
// normalize the random values to obtain probabilities
- for (int j = 0; j < nrOfHiddenStates; j++)
+ for (int j = 0; j < nrOfHiddenStates; j++) {
values[j] /= sum;
+ }
// set this row of the transition matrix
transitionMatrix.set(i, values);
}
@@ -180,8 +182,9 @@ public class HmmModel implements JsonDes
sum += values[j];
}
// normalize the random values to obtain probabilities
- for (int j = 0; j < nrOfOutputStates; j++)
+ for (int j = 0; j < nrOfOutputStates; j++) {
values[j] /= sum;
+ }
// set this row of the output matrix
emissionMatrix.set(i, values);
}
@@ -280,9 +283,8 @@ public class HmmModel implements JsonDes
*
* @return hidden state names.
*/
- @SuppressWarnings("unchecked")
public Map<String, Integer> getHiddenStateNames() {
- return hiddenStateNames;
+ return (Map<String, Integer>) hiddenStateNames;
}
/**
@@ -306,8 +308,9 @@ public class HmmModel implements JsonDes
* @param stateNames <String,Integer> Map that assigns each state name an integer ID
*/
public void registerHiddenStateNames(Map<String, Integer> stateNames) {
- if (stateNames != null)
+ if (stateNames != null) {
hiddenStateNames = new TreeBidiMap(stateNames);
+ }
}
/**
@@ -318,8 +321,9 @@ public class HmmModel implements JsonDes
* known or no hidden state names were specified
*/
public String getHiddenStateName(int id) {
- if (hiddenStateNames == null)
+ if (hiddenStateNames == null) {
return null;
+ }
return (String) hiddenStateNames.getKey(id);
}
@@ -331,8 +335,9 @@ public class HmmModel implements JsonDes
* known or no hidden state names were specified
*/
public int getHiddenStateID(String name) {
- if (hiddenStateNames == null)
+ if (hiddenStateNames == null) {
return -1;
+ }
Integer tmp = (Integer) hiddenStateNames.get(name);
return (tmp == null) ? -1 : tmp;
}
@@ -347,9 +352,8 @@ public class HmmModel implements JsonDes
*
* @return names of output states.
*/
- @SuppressWarnings("unchecked")
public Map<String, Integer> getOutputStateNames() {
- return outputStateNames;
+ return (Map<String, Integer>) outputStateNames;
}
/**
@@ -373,8 +377,9 @@ public class HmmModel implements JsonDes
* @param stateNames <String,Integer> Map that assigns each state name an integer ID
*/
public void registerOutputStateNames(Map<String, Integer> stateNames) {
- if (stateNames != null)
+ if (stateNames != null) {
outputStateNames = new TreeBidiMap(stateNames);
+ }
}
/**
@@ -385,8 +390,9 @@ public class HmmModel implements JsonDes
* known or no output state names were specified
*/
public String getOutputStateName(int id) {
- if (outputStateNames == null)
+ if (outputStateNames == null) {
return null;
+ }
return (String) outputStateNames.getKey(id);
}
@@ -398,8 +404,9 @@ public class HmmModel implements JsonDes
* known or no output state names were specified
*/
public int getOutputStateID(String name) {
- if (outputStateNames == null)
+ if (outputStateNames == null) {
return -1;
+ }
Integer tmp = (Integer) outputStateNames.get(name);
return (tmp == null) ? -1 : tmp;
}
@@ -435,9 +442,6 @@ public class HmmModel implements JsonDes
private static final String OUTNAMES = "HMMOutNames";
private static final String HIDDENNAMES = "HmmHiddenNames";
- /**
- * {@inheritDoc}
- */
@Override
public HmmModel deserialize(JsonElement json, Type type,
JsonDeserializationContext context) {
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java Sat Sep 25 09:51:42 2010
@@ -29,8 +29,6 @@ import org.apache.mahout.math.Vector;
* Class containing several algorithms used to train a Hidden Markov Model. The
* three main algorithms are: supervised learning, unsupervised Viterbi and
* unsupervised Baum-Welch.
- *
- * @author mheimel
*/
public final class HmmTrainer {
@@ -60,10 +58,8 @@ public final class HmmTrainer {
pseudoCount = (pseudoCount == 0) ? Double.MIN_VALUE : pseudoCount;
// initialize the parameters
- DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates,
- nrOfHiddenStates);
- DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates,
- nrOfOutputStates);
+ DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
+ DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
// assign a small initial probability that is larger than zero, so
// unseen states will not get a zero probability
transitionMatrix.assign(pseudoCount);
@@ -81,18 +77,22 @@ public final class HmmTrainer {
for (int i = 0; i < nrOfHiddenStates; i++) {
// compute sum of probabilities for current row of transition matrix
double sum = 0;
- for (int j = 0; j < nrOfHiddenStates; j++)
+ for (int j = 0; j < nrOfHiddenStates; j++) {
sum += transitionMatrix.getQuick(i, j);
+ }
// normalize current row of transition matrix
- for (int j = 0; j < nrOfHiddenStates; j++)
+ for (int j = 0; j < nrOfHiddenStates; j++) {
transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
+ }
// compute sum of probabilities for current row of emission matrix
sum = 0;
- for (int j = 0; j < nrOfOutputStates; j++)
+ for (int j = 0; j < nrOfOutputStates; j++) {
sum += emissionMatrix.getQuick(i, j);
+ }
// normalize current row of emission matrix
- for (int j = 0; j < nrOfOutputStates; j++)
+ for (int j = 0; j < nrOfOutputStates; j++) {
emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
+ }
}
// return a new model using the parameter estimations
@@ -172,22 +172,27 @@ public final class HmmTrainer {
isum += initialProbabilities.getQuick(i);
// compute sum of probabilities for current row of transition matrix
double sum = 0;
- for (int j = 0; j < nrOfHiddenStates; j++)
+ for (int j = 0; j < nrOfHiddenStates; j++) {
sum += transitionMatrix.getQuick(i, j);
+ }
// normalize current row of transition matrix
- for (int j = 0; j < nrOfHiddenStates; j++)
+ for (int j = 0; j < nrOfHiddenStates; j++) {
transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
+ }
// compute sum of probabilities for current row of emission matrix
sum = 0;
- for (int j = 0; j < nrOfOutputStates; j++)
+ for (int j = 0; j < nrOfOutputStates; j++) {
sum += emissionMatrix.getQuick(i, j);
+ }
// normalize current row of emission matrix
- for (int j = 0; j < nrOfOutputStates; j++)
+ for (int j = 0; j < nrOfOutputStates; j++) {
emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
+ }
}
// normalize the initial probabilities
- for (int i = 0; i < nrOfHiddenStates; ++i)
+ for (int i = 0; i < nrOfHiddenStates; ++i) {
initialProbabilities.setQuick(i, initialProbabilities.getQuick(i) / isum);
+ }
// return a new model using the parameter estimates
return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
@@ -254,21 +259,26 @@ public final class HmmTrainer {
for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
double sum = 0;
// normalize the rows of the transition matrix
- for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k)
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
sum += transitionMatrix.getQuick(j, k);
- for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k)
+ }
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
transitionMatrix
.setQuick(j, k, transitionMatrix.getQuick(j, k) / sum);
+ }
// normalize the rows of the emission matrix
sum = 0;
- for (int k = 0; k < iteration.getNrOfOutputStates(); ++k)
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
sum += emissionMatrix.getQuick(j, k);
- for (int k = 0; k < iteration.getNrOfOutputStates(); ++k)
+ }
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum);
+ }
}
// check for convergence
- if (checkConvergence(lastIteration, iteration, epsilon))
+ if (checkConvergence(lastIteration, iteration, epsilon)) {
break;
+ }
// overwrite the last iterated model by the new iteration
lastIteration.assign(iteration);
}
@@ -303,8 +313,8 @@ public final class HmmTrainer {
// allocate space for baum-welch factors
int hiddenCount = initialModel.getNrOfHiddenStates();
int visibleCount = observedSequence.length;
- DenseMatrix alpha = new DenseMatrix(visibleCount, hiddenCount);
- DenseMatrix beta = new DenseMatrix(visibleCount, hiddenCount);
+ Matrix alpha = new DenseMatrix(visibleCount, hiddenCount);
+ Matrix beta = new DenseMatrix(visibleCount, hiddenCount);
// now run the baum Welch training iteration
for (int it = 0; it < maxIterations; ++it) {
@@ -328,17 +338,21 @@ public final class HmmTrainer {
for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
double sum = 0;
// normalize the rows of the transition matrix
- for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k)
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
sum += transitionMatrix.getQuick(j, k);
- for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k)
+ }
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
transitionMatrix
.setQuick(j, k, transitionMatrix.getQuick(j, k) / sum);
+ }
// normalize the rows of the emission matrix
sum = 0;
- for (int k = 0; k < iteration.getNrOfOutputStates(); ++k)
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
sum += emissionMatrix.getQuick(j, k);
- for (int k = 0; k < iteration.getNrOfOutputStates(); ++k)
+ }
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum);
+ }
// normalization parameter for initial probabilities
isum += initialProbabilities.getQuick(j);
}
@@ -348,8 +362,9 @@ public final class HmmTrainer {
/ isum);
}
// check for convergence
- if (checkConvergence(lastIteration, iteration, epsilon))
+ if (checkConvergence(lastIteration, iteration, epsilon)) {
break;
+ }
// overwrite the last iterated model by the new iteration
lastIteration.assign(iteration);
}
@@ -357,7 +372,7 @@ public final class HmmTrainer {
return iteration;
}
- private static void unscaledBaumWelch(int[] observedSequence, HmmModel iteration, DenseMatrix alpha, DenseMatrix beta) {
+ private static void unscaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
Vector initialProbabilities = iteration.getInitialProbabilities();
Matrix emissionMatrix = iteration.getEmissionMatrix();
Matrix transitionMatrix = iteration.getTransitionMatrix();
@@ -396,7 +411,7 @@ public final class HmmTrainer {
}
}
- private static void logScaledBaumWelch(int[] observedSequence, HmmModel iteration, DenseMatrix alpha, DenseMatrix beta) {
+ private static void logScaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
Vector initialProbabilities = iteration.getInitialProbabilities();
Matrix emissionMatrix = iteration.getEmissionMatrix();
Matrix transitionMatrix = iteration.getTransitionMatrix();
@@ -414,9 +429,10 @@ public final class HmmTrainer {
double temp = alpha.getQuick(t, i)
+ Math.log(emissionMatrix.getQuick(j, observedSequence[t + 1]))
+ beta.getQuick(t + 1, j);
- if (temp > Double.NEGATIVE_INFINITY) // handle
- // 0-probabilities
+ if (temp > Double.NEGATIVE_INFINITY) {
+ // handle 0-probabilities
sum = temp + Math.log(1 + Math.exp(sum - temp));
+ }
}
transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j)
* Math.exp(sum - modelLikelihood));
@@ -430,9 +446,10 @@ public final class HmmTrainer {
// delta tensor
if (observedSequence[t] == j) {
double temp = alpha.getQuick(t, i) + beta.getQuick(t, i);
- if (temp > Double.NEGATIVE_INFINITY) // handle
- // 0-probabilities
+ if (temp > Double.NEGATIVE_INFINITY) {
+ // handle 0-probabilities
sum = temp + Math.log(1 + Math.exp(sum - temp));
+ }
}
}
emissionMatrix.setQuick(i, j, Math.exp(sum - modelLikelihood));
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java Sat Sep 25 09:51:42 2010
@@ -17,8 +17,10 @@
package org.apache.mahout.classifier.sequencelearning.hmm;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
+import java.util.List;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
@@ -30,8 +32,6 @@ import org.uncommons.maths.Maths;
/**
* A collection of utilities for handling HMMModel objects.
- *
- * @author mheimel
*/
public final class HmmUtils {
@@ -264,10 +264,11 @@ public final class HmmUtils {
for (int i = 0; i < sequence.size(); ++i) {
String nextState = seqIter.next();
int nextID;
- if (observed)
+ if (observed) {
nextID = model.getOutputStateID(nextState);
- else
+ } else {
nextID = model.getHiddenStateID(nextState);
+ }
// if the ID is -1, use the default value
encoded[i] = (nextID < 0) ? defaultValue : nextID;
}
@@ -283,18 +284,17 @@ public final class HmmUtils {
* @param observed If set, the sequence is encoded as a sequence of observed states,
* else it is encoded as sequence of hidden states
* @param defaultValue The default value in case a state is not known
- * @return java.util.Vector containing the decoded state names
+ * @return list containing the decoded state names
*/
- public static java.util.Vector<String> decodeStateSequence(HmmModel model,
- int[] sequence, boolean observed, String defaultValue) {
- java.util.Vector<String> decoded = new java.util.Vector<String>(
- sequence.length);
+ public static List<String> decodeStateSequence(HmmModel model, int[] sequence, boolean observed, String defaultValue) {
+ List<String> decoded = new ArrayList<String>(sequence.length);
for (int position : sequence) {
String nextState;
- if (observed)
+ if (observed) {
nextState = model.getOutputStateName(position);
- else
+ } else {
nextState = model.getHiddenStateName(position);
+ }
// if null was returned, use the default value
decoded.add(nextState == null ? defaultValue : nextState);
}
@@ -315,23 +315,28 @@ public final class HmmUtils {
for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
isum += ip.getQuick(i);
double sum = 0;
- for (int j = 0; j < model.getNrOfHiddenStates(); ++j)
+ for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
sum += transition.getQuick(i, j);
+ }
if (sum != 1.0) {
- for (int j = 0; j < model.getNrOfHiddenStates(); ++j)
+ for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
transition.setQuick(i, j, transition.getQuick(i, j) / sum);
+ }
}
sum = 0;
- for (int j = 0; j < model.getNrOfOutputStates(); ++j)
+ for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
sum += emission.getQuick(i, j);
+ }
if (sum != 1.0) {
- for (int j = 0; j < model.getNrOfOutputStates(); ++j)
+ for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
emission.setQuick(i, j, emission.getQuick(i, j) / sum);
+ }
}
}
if (isum != 1.0) {
- for (int i = 0; i < model.getNrOfHiddenStates(); ++i)
+ for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
ip.setQuick(i, ip.getQuick(i) / isum);
+ }
}
}
@@ -358,24 +363,27 @@ public final class HmmUtils {
// now transfer the values
for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
double value = ip.getQuick(i);
- if (value > threshold)
+ if (value > threshold) {
sparseIp.setQuick(i, value);
+ }
for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
value = tr.getQuick(i, j);
- if (value > threshold)
+ if (value > threshold) {
sparseTr.setQuick(i, j, value);
+ }
}
for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
value = em.getQuick(i, j);
- if (value > threshold)
+ if (value > threshold) {
sparseEm.setQuick(i, j, value);
+ }
}
}
// create a new model
HmmModel sparseModel = new HmmModel(sparseTr, sparseEm, sparseIp);
// normalize the model
- HmmUtils.normalizeModel(sparseModel);
+ normalizeModel(sparseModel);
// register the names
sparseModel.registerHiddenStateNames(model.getHiddenStateNames());
sparseModel.registerOutputStateNames(model.getOutputStateNames());
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java Sat Sep 25 09:51:42 2010
@@ -29,6 +29,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
+import java.util.Queue;
import java.util.Set;
/**
@@ -52,7 +53,7 @@ import java.util.Set;
* but instead can be cleared between updates.
*/
public class ModelDissector {
- private Map<String,Vector> weightMap;
+ private final Map<String,Vector> weightMap;
public ModelDissector(int n) {
weightMap = Maps.newHashMap();
@@ -60,16 +61,18 @@ public class ModelDissector {
public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) {
features.assign(0);
- for (String feature : traceDictionary.keySet()) {
- if (!weightMap.containsKey(feature)) {
- for (Integer where : traceDictionary.get(feature)) {
+ for (Map.Entry<String, Set<Integer>> entry : traceDictionary.entrySet()) {
+ String key = entry.getKey();
+ Set<Integer> value = entry.getValue();
+ if (!weightMap.containsKey(key)) {
+ for (Integer where : value) {
features.set(where, 1);
}
Vector v = learner.classifyNoLink(features);
- weightMap.put(feature, v);
+ weightMap.put(key, v);
- for (Integer where : traceDictionary.get(feature)) {
+ for (Integer where : value) {
features.set(where, 0);
}
}
@@ -78,9 +81,9 @@ public class ModelDissector {
}
public List<Weight> summary(int n) {
- PriorityQueue<Weight> pq = new PriorityQueue<Weight>();
- for (String s : weightMap.keySet()) {
- pq.add(new Weight(s, weightMap.get(s)));
+ Queue<Weight> pq = new PriorityQueue<Weight>();
+ for (Map.Entry<String, Vector> entry : weightMap.entrySet()) {
+ pq.add(new Weight(entry.getKey(), entry.getValue()));
while (pq.size() > n) {
pq.poll();
}
@@ -91,9 +94,9 @@ public class ModelDissector {
}
public static class Weight implements Comparable<Weight> {
- private String feature;
- private double value;
- private int maxIndex;
+ private final String feature;
+ private final double value;
+ private final int maxIndex;
public Weight(String feature, Vector weights) {
this.feature = feature;
@@ -117,10 +120,10 @@ public class ModelDissector {
@Override
public int compareTo(Weight other) {
int r = Double.compare(Math.abs(this.value), Math.abs(other.value));
- if (r != 0) {
- return r;
- } else {
+ if (r == 0) {
return feature.compareTo(other.feature);
+ } else {
+ return r;
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java Sat Sep 25 09:51:42 2010
@@ -84,8 +84,6 @@ public abstract class AbstractCluster im
private transient Vector s2;
- protected static final double SQRT2PI = Math.sqrt(2.0 * Math.PI);
-
/**
* @return the s0
*/
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Sat Sep 25 09:51:42 2010
@@ -90,7 +90,7 @@ public class KMeansDriver extends Abstra
HadoopUtil.overwriteOutput(output);
}
ClassLoader ccl = Thread.currentThread().getContextClassLoader();
- DistanceMeasure measure = ((Class<?>) ccl.loadClass(measureClass)).asSubclass(DistanceMeasure.class).newInstance();
+ DistanceMeasure measure = ccl.loadClass(measureClass).asSubclass(DistanceMeasure.class).newInstance();
if (hasOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)) {
clusters = RandomSeedGenerator.buildRandom(input, clusters, Integer
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/VarLongWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/VarLongWritable.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/VarLongWritable.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/VarLongWritable.java Sat Sep 25 09:51:42 2010
@@ -44,7 +44,7 @@ public class VarLongWritable implements
@Override
public boolean equals(Object other) {
- return other instanceof VarLongWritable && ((VarLongWritable) other).value == value;
+ return other != null && VarLongWritable.class.equals(other.getClass()) && ((VarLongWritable) other).value == value;
}
@Override
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java Sat Sep 25 09:51:42 2010
@@ -24,7 +24,6 @@ package org.apache.mahout.math.stats;
* as a recommendation system.
*/
public interface OnlineAuc {
- @SuppressWarnings({"UnusedDeclaration"})
double addSample(int category, String groupKey, double score);
double addSample(int category, double score);
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/vectors/CachingContinuousValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/vectors/CachingContinuousValueEncoder.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/vectors/CachingContinuousValueEncoder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/vectors/CachingContinuousValueEncoder.java Sat Sep 25 09:51:42 2010
@@ -21,7 +21,7 @@ import org.apache.mahout.math.map.OpenIn
public class CachingContinuousValueEncoder extends ContinuousValueEncoder {
- private int dataSize;
+ private final int dataSize;
private OpenIntIntHashMap[] caches;
public CachingContinuousValueEncoder(String name, int dataSize) {
@@ -49,7 +49,8 @@ public class CachingContinuousValueEncod
protected int hashForProbe(String originalForm, int dataSize, String name, int probe) {
if (dataSize != this.dataSize) {
- throw new IllegalArgumentException("dataSize argument [" + dataSize + "] does not match expected dataSize [" + this.dataSize + "]");
+ throw new IllegalArgumentException("dataSize argument ["
+ + dataSize + "] does not match expected dataSize [" + this.dataSize + "]");
}
if (caches[probe].containsKey(originalForm.hashCode())) {
return caches[probe].get(originalForm.hashCode());
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/vectors/CachingStaticWordValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/vectors/CachingStaticWordValueEncoder.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/vectors/CachingStaticWordValueEncoder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/vectors/CachingStaticWordValueEncoder.java Sat Sep 25 09:51:42 2010
@@ -20,7 +20,7 @@ package org.apache.mahout.vectors;
import org.apache.mahout.math.map.OpenIntIntHashMap;
public class CachingStaticWordValueEncoder extends StaticWordValueEncoder {
- private int dataSize;
+ private final int dataSize;
private OpenIntIntHashMap[] caches;
// private TIntIntHashMap[] caches;
@@ -49,7 +49,8 @@ public class CachingStaticWordValueEncod
protected int hashForProbe(String originalForm, int dataSize, String name, int probe) {
if (dataSize != this.dataSize) {
- throw new IllegalArgumentException("dataSize argument [" + dataSize + "] does not match expected dataSize [" + this.dataSize + "]");
+ throw new IllegalArgumentException("dataSize argument ["
+ + dataSize + "] does not match expected dataSize [" + this.dataSize + "]");
}
if (caches[probe].containsKey(originalForm.hashCode())) {
return caches[probe].get(originalForm.hashCode());
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java Sat Sep 25 09:51:42 2010
@@ -17,8 +17,6 @@
package org.apache.mahout.classifier.sequencelearning.hmm;
-import junit.framework.Assert;
-
import org.apache.mahout.math.Matrix;
import org.junit.Test;
@@ -33,7 +31,7 @@ public class HMMAlgorithmsTest extends H
@Test
public void testForwardAlgorithm() {
// initialize the expected alpha values
- double alphaExpectedA[][] = {
+ double[][] alphaExpectedA = {
{0.02, 0.0392, 0.002438, 0.00035456, 0.0011554672, 7.158497e-04,
4.614927e-05},
{0.01, 0.0054, 0.001824, 0.00069486, 0.0007586904, 2.514137e-04,
@@ -43,21 +41,23 @@ public class HMMAlgorithmsTest extends H
{0.03, 0.0000, 0.013428, 0.00951084, 0.0000000000, 0.000000e+00,
2.428986e-05},};
// fetch the alpha matrix using the forward algorithm
- Matrix alpha = HmmAlgorithms.forwardAlgorithm(model, sequence, false);
+ Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), false);
// first do some basic checking
- Assert.assertNotNull(alpha);
- Assert.assertEquals(alpha.numCols(), 4);
- Assert.assertEquals(alpha.numRows(), 7);
+ assertNotNull(alpha);
+ assertEquals(4, alpha.numCols());
+ assertEquals(7, alpha.numRows());
// now compare the resulting matrices
- for (int i = 0; i < 4; ++i)
- for (int j = 0; j < 7; ++j)
- Assert.assertEquals(alphaExpectedA[i][j], alpha.get(j, i), 0.00001);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 7; ++j) {
+ assertEquals(alphaExpectedA[i][j], alpha.get(j, i), EPSILON);
+ }
+ }
}
@Test
public void testLogScaledForwardAlgorithm() {
// initialize the expected alpha values
- double alphaExpectedA[][] = {
+ double[][] alphaExpectedA = {
{0.02, 0.0392, 0.002438, 0.00035456, 0.0011554672, 7.158497e-04,
4.614927e-05},
{0.01, 0.0054, 0.001824, 0.00069486, 0.0007586904, 2.514137e-04,
@@ -67,16 +67,17 @@ public class HMMAlgorithmsTest extends H
{0.03, 0.0000, 0.013428, 0.00951084, 0.0000000000, 0.000000e+00,
2.428986e-05},};
// fetch the alpha matrix using the forward algorithm
- Matrix alpha = HmmAlgorithms.forwardAlgorithm(model, sequence, true);
+ Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), true);
// first do some basic checking
- Assert.assertNotNull(alpha);
- Assert.assertEquals(alpha.numCols(), 4);
- Assert.assertEquals(alpha.numRows(), 7);
+ assertNotNull(alpha);
+ assertEquals(4, alpha.numCols());
+ assertEquals(7, alpha.numRows());
// now compare the resulting matrices
- for (int i = 0; i < 4; ++i)
- for (int j = 0; j < 7; ++j)
- Assert.assertEquals(Math.log(alphaExpectedA[i][j]), alpha.get(j, i),
- 0.00001);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 7; ++j) {
+ assertEquals(Math.log(alphaExpectedA[i][j]), alpha.get(j, i), EPSILON);
+ }
+ }
}
/**
@@ -88,42 +89,45 @@ public class HMMAlgorithmsTest extends H
@Test
public void testBackwardAlgorithm() {
// initialize the expected beta values
- double betaExpectedA[][] = {
+ double[][] betaExpectedA = {
{0.0015730559, 0.003543656, 0.00738264, 0.040692, 0.0848, 0.17, 1},
{0.0017191865, 0.002386795, 0.00923652, 0.052232, 0.1018, 0.17, 1},
{0.0003825772, 0.001238558, 0.00259464, 0.012096, 0.0664, 0.66, 1},
{0.0004390858, 0.007076994, 0.01063512, 0.013556, 0.0304, 0.17, 1}};
// fetch the beta matrix using the backward algorithm
- Matrix beta = HmmAlgorithms.backwardAlgorithm(model, sequence, false);
+ Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), false);
// first do some basic checking
- Assert.assertNotNull(beta);
- Assert.assertEquals(beta.numCols(), 4);
- Assert.assertEquals(beta.numRows(), 7);
+ assertNotNull(beta);
+ assertEquals(4, beta.numCols());
+ assertEquals(7, beta.numRows());
// now compare the resulting matrices
- for (int i = 0; i < 4; ++i)
- for (int j = 0; j < 7; ++j)
- Assert.assertEquals(betaExpectedA[i][j], beta.get(j, i), 0.00001);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 7; ++j) {
+ assertEquals(betaExpectedA[i][j], beta.get(j, i), EPSILON);
+ }
+ }
}
@Test
public void testLogScaledBackwardAlgorithm() {
// initialize the expected beta values
- double betaExpectedA[][] = {
+ double[][] betaExpectedA = {
{0.0015730559, 0.003543656, 0.00738264, 0.040692, 0.0848, 0.17, 1},
{0.0017191865, 0.002386795, 0.00923652, 0.052232, 0.1018, 0.17, 1},
{0.0003825772, 0.001238558, 0.00259464, 0.012096, 0.0664, 0.66, 1},
{0.0004390858, 0.007076994, 0.01063512, 0.013556, 0.0304, 0.17, 1}};
// fetch the beta matrix using the backward algorithm
- Matrix beta = HmmAlgorithms.backwardAlgorithm(model, sequence, true);
+ Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), true);
// first do some basic checking
- Assert.assertNotNull(beta);
- Assert.assertEquals(beta.numCols(), 4);
- Assert.assertEquals(beta.numRows(), 7);
+ assertNotNull(beta);
+ assertEquals(4, beta.numCols());
+ assertEquals(7, beta.numRows());
// now compare the resulting matrices
- for (int i = 0; i < 4; ++i)
- for (int j = 0; j < 7; ++j)
- Assert.assertEquals(Math.log(betaExpectedA[i][j]), beta.get(j, i),
- 0.00001);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 7; ++j) {
+ assertEquals(Math.log(betaExpectedA[i][j]), beta.get(j, i), EPSILON);
+ }
+ }
}
@Test
@@ -131,13 +135,14 @@ public class HMMAlgorithmsTest extends H
// initialize the expected hidden sequence
int[] expected = {2, 0, 3, 3, 0, 0, 2};
// fetch the viterbi generated sequence
- int[] computed = HmmAlgorithms.viterbiAlgorithm(model, sequence, false);
+ int[] computed = HmmAlgorithms.viterbiAlgorithm(getModel(), getSequence(), false);
// first make sure we return the correct size
- Assert.assertNotNull(computed);
- Assert.assertEquals(computed.length, sequence.length);
+ assertNotNull(computed);
+ assertEquals(computed.length, getSequence().length);
// now check the contents
- for (int i = 0; i < sequence.length; ++i)
- Assert.assertEquals(expected[i], computed[i]);
+ for (int i = 0; i < getSequence().length; ++i) {
+ assertEquals(expected[i], computed[i]);
+ }
}
@Test
@@ -145,13 +150,14 @@ public class HMMAlgorithmsTest extends H
// initialize the expected hidden sequence
int[] expected = {2, 0, 3, 3, 0, 0, 2};
// fetch the viterbi generated sequence
- int[] computed = HmmAlgorithms.viterbiAlgorithm(model, sequence, true);
+ int[] computed = HmmAlgorithms.viterbiAlgorithm(getModel(), getSequence(), true);
// first make sure we return the correct size
- Assert.assertNotNull(computed);
- Assert.assertEquals(computed.length, sequence.length);
+ assertNotNull(computed);
+ assertEquals(computed.length, getSequence().length);
// now check the contents
- for (int i = 0; i < sequence.length; ++i)
- Assert.assertEquals(expected[i], computed[i]);
+ for (int i = 0; i < getSequence().length; ++i) {
+ assertEquals(expected[i], computed[i]);
+ }
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java Sat Sep 25 09:51:42 2010
@@ -17,8 +17,6 @@
package org.apache.mahout.classifier.sequencelearning.hmm;
-import junit.framework.Assert;
-
import org.apache.mahout.math.Matrix;
import org.junit.Test;
@@ -32,15 +30,15 @@ public class HMMEvaluatorTest extends HM
@Test
public void testModelLikelihood() {
// compute alpha and beta values
- Matrix alpha = HmmAlgorithms.forwardAlgorithm(model, sequence, false);
- Matrix beta = HmmAlgorithms.backwardAlgorithm(model, sequence, false);
+ Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), false);
+ Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), false);
// now test whether forward == backward likelihood
double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, false);
- double backwardLikelihood = HmmEvaluator.modelLikelihood(model, sequence,
+ double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(),
beta, false);
- Assert.assertEquals(forwardLikelihood, backwardLikelihood, 1e-6);
+ assertEquals(forwardLikelihood, backwardLikelihood, EPSILON);
// also make sure that the likelihood matches the expected one
- Assert.assertEquals(forwardLikelihood, 1.8425e-4, 1e-6);
+ assertEquals(1.8425e-4, forwardLikelihood, EPSILON);
}
/**
@@ -51,15 +49,15 @@ public class HMMEvaluatorTest extends HM
@Test
public void testScaledModelLikelihood() {
// compute alpha and beta values
- Matrix alpha = HmmAlgorithms.forwardAlgorithm(model, sequence, true);
- Matrix beta = HmmAlgorithms.backwardAlgorithm(model, sequence, true);
+ Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), true);
+ Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), true);
// now test whether forward == backward likelihood
double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, true);
- double backwardLikelihood = HmmEvaluator.modelLikelihood(model, sequence,
+ double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(),
beta, true);
- Assert.assertEquals(forwardLikelihood, backwardLikelihood, 1e-6);
+ assertEquals(forwardLikelihood, backwardLikelihood, EPSILON);
// also make sure that the likelihood matches the expected one
- Assert.assertEquals(forwardLikelihood, 1.8425e-4, 1e-6);
+ assertEquals(1.8425e-4, forwardLikelihood, EPSILON);
}
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java Sat Sep 25 09:51:42 2010
@@ -17,8 +17,6 @@
package org.apache.mahout.classifier.sequencelearning.hmm;
-import junit.framework.Assert;
-
import org.junit.Test;
public class HMMModelTest extends HMMTestBase {
@@ -33,12 +31,12 @@ public class HMMModelTest extends HMMTes
@Test
public void testSerialization() {
- String serialized = model.toJson();
+ String serialized = getModel().toJson();
HmmModel model2 = HmmModel.fromJson(serialized);
String serialized2 = model2.toJson();
// since there are no equals methods for the underlying objects, we
// check identity via the serialization string
- Assert.assertEquals(serialized, serialized2);
+ assertEquals(serialized, serialized2);
}
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java Sat Sep 25 09:51:42 2010
@@ -6,8 +6,8 @@ import org.apache.mahout.math.DenseVecto
public class HMMTestBase extends MahoutTestCase {
- protected HmmModel model;
- protected int[] sequence = {1, 0, 2, 2, 0, 0, 1};
+ private HmmModel model;
+ private final int[] sequence = {1, 0, 2, 2, 0, 0, 1};
/**
* We initialize a new HMM model using the following parameters # hidden
@@ -27,16 +27,16 @@ public class HMMTestBase extends MahoutT
public void setUp() throws Exception {
super.setUp();
// initialize the hidden/output state names
- String hiddenNames[] = {"H0", "H1", "H2", "H3"};
- String outputNames[] = {"O0", "O1", "O2"};
+ String[] hiddenNames = {"H0", "H1", "H2", "H3"};
+ String[] outputNames = {"O0", "O1", "O2"};
// initialize the transition matrix
- double transitionP[][] = {{0.5, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
+ double[][] transitionP = {{0.5, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
{0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
// initialize the emission matrix
- double emissionP[][] = {{0.8, 0.1, 0.1}, {0.6, 0.1, 0.3},
+ double[][] emissionP = {{0.8, 0.1, 0.1}, {0.6, 0.1, 0.3},
{0.1, 0.8, 0.1}, {0.0, 0.1, 0.9}};
// initialize the initial probability vector
- double initialP[] = {0.2, 0.1, 0.4, 0.3};
+ double[] initialP = {0.2, 0.1, 0.4, 0.3};
// now generate the model
model = new HmmModel(new DenseMatrix(transitionP), new DenseMatrix(
emissionP), new DenseVector(initialP));
@@ -46,4 +46,11 @@ public class HMMTestBase extends MahoutT
HmmUtils.validate(model);
}
+ protected HmmModel getModel() {
+ return model;
+ }
+
+ protected int[] getSequence() {
+ return sequence;
+ }
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java Sat Sep 25 09:51:42 2010
@@ -17,8 +17,6 @@
package org.apache.mahout.classifier.sequencelearning.hmm;
-import junit.framework.Assert;
-
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.junit.Test;
@@ -29,32 +27,31 @@ public class HMMTrainerTest extends HMMT
public void testViterbiTraining() {
// initialize the expected model parameters (from R)
// expected transition matrix
- double transitionE[][] = {{0.3125, 0.0625, 0.3125, 0.3125},
+ double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125},
{0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
{0.5, 0.1, 0.1, 0.3}};
// initialize the emission matrix
- double emissionE[][] = {{0.882353, 0.058824, 0.058824},
+ double[][] emissionE = {{0.882353, 0.058824, 0.058824},
{0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
{0.111111, 0.111111, 0.777778}};
// train the given network to the following output sequence
int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
- HmmModel trained = HmmTrainer.trainViterbi(model, observed, 0.5, 0.1, 10,
- false);
+ HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10, false);
// now check whether the model matches our expectations
Matrix emissionMatrix = trained.getEmissionMatrix();
Matrix transitionMatrix = trained.getTransitionMatrix();
for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
- for (int j = 0; j < trained.getNrOfHiddenStates(); ++j)
- Assert.assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j],
- 0.00001);
-
- for (int j = 0; j < trained.getNrOfOutputStates(); ++j)
- Assert.assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j],
- 0.00001);
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j], EPSILON);
+ }
+
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j], EPSILON);
+ }
}
}
@@ -63,18 +60,18 @@ public class HMMTrainerTest extends HMMT
public void testScaledViterbiTraining() {
// initialize the expected model parameters (from R)
// expected transition matrix
- double transitionE[][] = {{0.3125, 0.0625, 0.3125, 0.3125},
+ double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125},
{0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
{0.5, 0.1, 0.1, 0.3}};
// initialize the emission matrix
- double emissionE[][] = {{0.882353, 0.058824, 0.058824},
+ double[][] emissionE = {{0.882353, 0.058824, 0.058824},
{0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
{0.111111, 0.111111, 0.777778}};
// train the given network to the following output sequence
int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
- HmmModel trained = HmmTrainer.trainViterbi(model, observed, 0.5, 0.1, 10,
+ HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10,
true);
// now check whether the model matches our expectations
@@ -82,13 +79,15 @@ public class HMMTrainerTest extends HMMT
Matrix transitionMatrix = trained.getTransitionMatrix();
for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
- for (int j = 0; j < trained.getNrOfHiddenStates(); ++j)
- Assert.assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j],
- 0.00001);
-
- for (int j = 0; j < trained.getNrOfOutputStates(); ++j)
- Assert.assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j],
- 0.00001);
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j],
+ EPSILON);
+ }
+
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j],
+ EPSILON);
+ }
}
}
@@ -106,7 +105,7 @@ public class HMMTrainerTest extends HMMT
double[][] emissionExpected = {{0.9995, 0.0004, 0.0001},
{0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
- HmmModel trained = HmmTrainer.trainBaumWelch(model, observed, 0.1, 10,
+ HmmModel trained = HmmTrainer.trainBaumWelch(getModel(), observed, 0.1, 10,
false);
Vector initialProbabilities = trained.getInitialProbabilities();
@@ -114,14 +113,16 @@ public class HMMTrainerTest extends HMMT
Matrix transitionMatrix = trained.getTransitionMatrix();
for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
- Assert.assertEquals(initialProbabilities.get(i), initialExpected[i],
+ assertEquals(initialProbabilities.get(i), initialExpected[i],
0.0001);
- for (int j = 0; j < trained.getNrOfHiddenStates(); ++j)
- Assert.assertEquals(transitionMatrix.getQuick(i, j),
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j),
transitionExpected[i][j], 0.0001);
- for (int j = 0; j < trained.getNrOfOutputStates(); ++j)
- Assert.assertEquals(emissionMatrix.getQuick(i, j),
+ }
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j),
emissionExpected[i][j], 0.0001);
+ }
}
}
@@ -139,21 +140,23 @@ public class HMMTrainerTest extends HMMT
{0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
HmmModel trained = HmmTrainer
- .trainBaumWelch(model, observed, 0.1, 10, true);
+ .trainBaumWelch(getModel(), observed, 0.1, 10, true);
Vector initialProbabilities = trained.getInitialProbabilities();
Matrix emissionMatrix = trained.getEmissionMatrix();
Matrix transitionMatrix = trained.getTransitionMatrix();
for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
- Assert.assertEquals(initialProbabilities.get(i), initialExpected[i],
+ assertEquals(initialProbabilities.get(i), initialExpected[i],
0.0001);
- for (int j = 0; j < trained.getNrOfHiddenStates(); ++j)
- Assert.assertEquals(transitionMatrix.getQuick(i, j),
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j),
transitionExpected[i][j], 0.0001);
- for (int j = 0; j < trained.getNrOfOutputStates(); ++j)
- Assert.assertEquals(emissionMatrix.getQuick(i, j),
+ }
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j),
emissionExpected[i][j], 0.0001);
+ }
}
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java Sat Sep 25 09:51:42 2010
@@ -18,8 +18,7 @@
package org.apache.mahout.classifier.sequencelearning.hmm;
import java.util.Arrays;
-
-import junit.framework.Assert;
+import java.util.List;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
@@ -29,48 +28,49 @@ import org.junit.Test;
public class HMMUtilsTest extends HMMTestBase {
- Matrix legal2_2;
- Matrix legal2_3;
- Matrix legal3_3;
- Vector legal2;
- Matrix illegal2_2;
+ private Matrix legal22;
+ private Matrix legal23;
+ private Matrix legal33;
+ private Vector legal2;
+ private Matrix illegal22;
+ @Override
public void setUp() throws Exception {
super.setUp();
- legal2_2 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}});
- legal2_3 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6},
+ legal22 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}});
+ legal23 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6},
{0.3, 0.3, 0.4}});
- legal3_3 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8},
+ legal33 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8},
{0.1, 0.2, 0.7}, {0.2, 0.3, 0.5}});
legal2 = new DenseVector(new double[]{0.4, 0.6});
- illegal2_2 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}});
+ illegal22 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}});
}
@Test
public void testValidatorLegal() {
- HmmUtils.validate(new HmmModel(legal2_2, legal2_3, legal2));
+ HmmUtils.validate(new HmmModel(legal22, legal23, legal2));
}
@Test
public void testValidatorDimensionError() {
try {
- HmmUtils.validate(new HmmModel(legal3_3, legal2_3, legal2));
+ HmmUtils.validate(new HmmModel(legal33, legal23, legal2));
} catch (IllegalArgumentException e) {
// success
return;
}
- Assert.fail();
+ fail();
}
@Test
public void testValidatorIllegelMatrixError() {
try {
- HmmUtils.validate(new HmmModel(illegal2_2, legal2_3, legal2));
+ HmmUtils.validate(new HmmModel(illegal22, legal23, legal2));
} catch (IllegalArgumentException e) {
// success
return;
}
- Assert.fail();
+ fail();
}
@Test
@@ -78,18 +78,20 @@ public class HMMUtilsTest extends HMMTes
String[] hiddenSequence = {"H1", "H2", "H0", "H3", "H4"};
String[] outputSequence = {"O1", "O2", "O4", "O0"};
// test encoding the hidden Sequence
- int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(model, Arrays
+ int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(getModel(), Arrays
.asList(hiddenSequence), false, -1);
- int[] outputSequenceEnc = HmmUtils.encodeStateSequence(model, Arrays
+ int[] outputSequenceEnc = HmmUtils.encodeStateSequence(getModel(), Arrays
.asList(outputSequence), true, -1);
// expected state sequences
int[] hiddenSequenceExp = {1, 2, 0, 3, -1};
int[] outputSequenceExp = {1, 2, -1, 0};
// compare
- for (int i = 0; i < hiddenSequenceEnc.length; ++i)
- Assert.assertEquals(hiddenSequenceExp[i], hiddenSequenceEnc[i]);
- for (int i = 0; i < outputSequenceEnc.length; ++i)
- Assert.assertEquals(outputSequenceExp[i], outputSequenceEnc[i]);
+ for (int i = 0; i < hiddenSequenceEnc.length; ++i) {
+ assertEquals(hiddenSequenceExp[i], hiddenSequenceEnc[i]);
+ }
+ for (int i = 0; i < outputSequenceEnc.length; ++i) {
+ assertEquals(outputSequenceExp[i], outputSequenceEnc[i]);
+ }
}
@Test
@@ -97,18 +99,20 @@ public class HMMUtilsTest extends HMMTes
int[] hiddenSequence = {1, 2, 0, 3, 10};
int[] outputSequence = {1, 2, 10, 0};
// test encoding the hidden Sequence
- java.util.Vector<String> hiddenSequenceDec = HmmUtils.decodeStateSequence(
- model, hiddenSequence, false, "unknown");
- java.util.Vector<String> outputSequenceDec = HmmUtils.decodeStateSequence(
- model, outputSequence, true, "unknown");
+ List<String> hiddenSequenceDec = HmmUtils.decodeStateSequence(
+ getModel(), hiddenSequence, false, "unknown");
+ List<String> outputSequenceDec = HmmUtils.decodeStateSequence(
+ getModel(), outputSequence, true, "unknown");
// expected state sequences
String[] hiddenSequenceExp = {"H1", "H2", "H0", "H3", "unknown"};
String[] outputSequenceExp = {"O1", "O2", "unknown", "O0"};
// compare
- for (int i = 0; i < hiddenSequenceExp.length; ++i)
- Assert.assertEquals(hiddenSequenceExp[i], hiddenSequenceDec.get(i));
- for (int i = 0; i < outputSequenceExp.length; ++i)
- Assert.assertEquals(outputSequenceExp[i], outputSequenceDec.get(i));
+ for (int i = 0; i < hiddenSequenceExp.length; ++i) {
+ assertEquals(hiddenSequenceExp[i], hiddenSequenceDec.get(i));
+ }
+ for (int i = 0; i < outputSequenceExp.length; ++i) {
+ assertEquals(outputSequenceExp[i], outputSequenceDec.get(i));
+ }
}
@Test
@@ -141,17 +145,18 @@ public class HMMUtilsTest extends HMMTes
Matrix sparse_tr = sparseModel.getTransitionMatrix();
Matrix sparse_em = sparseModel.getEmissionMatrix();
for (int i = 0; i < sparseModel.getNrOfHiddenStates(); ++i) {
- if (i == 2)
- Assert.assertEquals(1.0, sparse_ip.getQuick(i));
- else
- Assert.assertEquals(0.0, sparse_ip.getQuick(i));
+ if (i == 2) {
+ assertEquals(1.0, sparse_ip.getQuick(i), EPSILON);
+ } else {
+ assertEquals(0.0, sparse_ip.getQuick(i), EPSILON);
+ }
for (int j = 0; j < sparseModel.getNrOfHiddenStates(); ++j) {
if (i == j) {
- Assert.assertEquals(1.0, sparse_tr.getQuick(i, j));
- Assert.assertEquals(1.0, sparse_em.getQuick(i, j));
+ assertEquals(1.0, sparse_tr.getQuick(i, j), EPSILON);
+ assertEquals(1.0, sparse_em.getQuick(i, j), EPSILON);
} else {
- Assert.assertEquals(0.0, sparse_tr.getQuick(i, j));
- Assert.assertEquals(0.0, sparse_em.getQuick(i, j));
+ assertEquals(0.0, sparse_tr.getQuick(i, j), EPSILON);
+ assertEquals(0.0, sparse_em.getQuick(i, j), EPSILON);
}
}
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java Sat Sep 25 09:51:42 2010
@@ -29,7 +29,7 @@ import org.junit.Before;
public abstract class MahoutTestCase extends org.apache.mahout.math.MahoutTestCase {
/** "Close enough" value for floating-point comparisons. */
- public static final double EPSILON = 0.0000001;
+ public static final double EPSILON = 0.000001;
private Path testTempDirPath;
private FileSystem fs;
Modified: mahout/trunk/eclipse/src/main/resources/findbugs-exclude.xml
URL: http://svn.apache.org/viewvc/mahout/trunk/eclipse/src/main/resources/findbugs-exclude.xml?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/eclipse/src/main/resources/findbugs-exclude.xml (original)
+++ mahout/trunk/eclipse/src/main/resources/findbugs-exclude.xml Sat Sep 25 09:51:42 2010
@@ -1,6 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<FindBugsFilter>
<Match>
+ <Bug pattern="CN_IDIOM_NO_SUPER_CALL"/>
+ </Match>
+ <Match>
<Bug pattern="DLS_DEAD_LOCAL_STORE"/>
</Match>
<Match>
Modified: mahout/trunk/etc/findbugs-exclude.xml
URL: http://svn.apache.org/viewvc/mahout/trunk/etc/findbugs-exclude.xml?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/etc/findbugs-exclude.xml (original)
+++ mahout/trunk/etc/findbugs-exclude.xml Sat Sep 25 09:51:42 2010
@@ -1,6 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<FindBugsFilter>
<Match>
+ <Bug pattern="CN_IDIOM_NO_SUPER_CALL"/>
+ </Match>
+ <Match>
<Bug pattern="DLS_DEAD_LOCAL_STORE"/>
</Match>
<Match>
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java?rev=1001180&r1=1001179&r2=1001180&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java Sat Sep 25 09:51:42 2010
@@ -28,9 +28,9 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.mahout.math.Matrix;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* This class implements a sample program that uses a pre-tagged training data
@@ -42,11 +42,11 @@ import org.apache.mahout.math.Matrix;
* http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt Further
* details regarding the data files can be found at
* http://flexcrfs.sourceforge.net/#Case_Study
- *
- * @author mheimel
*/
public final class PosTagger {
+ private static final Logger log = LoggerFactory.getLogger(PosTagger.class);
+
/**
* No public constructors for utility classes.
*/
@@ -55,10 +55,6 @@ public final class PosTagger {
}
/**
- * Logger for this class.
- */
- private static final Log LOG = LogFactory.getLog(PosTagger.class);
- /**
* Model trained in the example.
*/
private static HmmModel taggingModel;
@@ -146,10 +142,12 @@ public final class PosTagger {
String[] tags = line.split(" ");
// when analyzing the training set, assign IDs
if (assignIDs) {
- if (!wordIDs.containsKey(tags[0]))
+ if (!wordIDs.containsKey(tags[0])) {
wordIDs.put(tags[0], nextWordId++);
- if (!tagIDs.containsKey(tags[1]))
+ }
+ if (!tagIDs.containsKey(tags[1])) {
tagIDs.put(tags[1], nextTagId++);
+ }
}
// determine the IDs
Integer wordID = wordIDs.get(tags[0]);
@@ -179,16 +177,14 @@ public final class PosTagger {
tagIDs = new HashMap<String, Integer>(44); // we expect 44 distinct tags
wordIDs = new HashMap<String, Integer>(19122); // we expect 19122
// distinct words
- LOG.info("Reading and parsing training data file from URL: " + trainingURL);
+ log.info("Reading and parsing training data file from URL: {}", trainingURL);
long start = System.currentTimeMillis();
readFromURL(trainingURL, true);
long end = System.currentTimeMillis();
- double duration = (end - start) / (double) 1000;
- LOG.info("Parsing done in " + duration + " seconds!");
- LOG.info("Read " + readLines + " lines containing "
- + hiddenSequences.size() + " sentences with a total of "
- + (nextWordId - 1) + " distinct words and " + (nextTagId - 1)
- + " distinct POS tags.");
+ double duration = (end - start) / 1000.0;
+ log.info("Parsing done in {} seconds!", duration);
+ log.info("Read {} lines containing {} sentences with a total of {} distinct words and {} distinct POS tags.",
+ new Object[] {readLines, hiddenSequences.size(), (nextWordId - 1), (nextTagId - 1)});
start = System.currentTimeMillis();
taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId,
hiddenSequences, observedSequences, 0.05);
@@ -196,55 +192,53 @@ public final class PosTagger {
// since we assume a higher probability that a given unknown word is NNP
// than anything else
Matrix emissions = taggingModel.getEmissionMatrix();
- for (int i = 0; i < taggingModel.getNrOfHiddenStates(); ++i)
- emissions.setQuick(i, 0, 0.1 / (double) taggingModel
- .getNrOfHiddenStates());
+ for (int i = 0; i < taggingModel.getNrOfHiddenStates(); ++i) {
+ emissions.setQuick(i, 0, 0.1 / (double) taggingModel.getNrOfHiddenStates());
+ }
int nnptag = tagIDs.get("NNP");
- emissions.setQuick(nnptag, 0, 1 / (double) taggingModel
- .getNrOfHiddenStates());
+ emissions.setQuick(nnptag, 0, 1 / (double) taggingModel.getNrOfHiddenStates());
// re-normalize the emission probabilities
HmmUtils.normalizeModel(taggingModel);
// now register the names
taggingModel.registerHiddenStateNames(tagIDs);
taggingModel.registerOutputStateNames(wordIDs);
end = System.currentTimeMillis();
- duration = (end - start) / (double) 1000;
- LOG.info("Trained HMM model sin " + duration + " seconds!");
+ duration = (end - start) / 1000.0;
+ log.info("Trained HMM models in {} seconds!", duration);
}
private static void testModel(String testingURL) throws IOException {
- LOG.info("Reading and parsing test data file from URL:" + testingURL);
+ log.info("Reading and parsing test data file from URL:" + testingURL);
long start = System.currentTimeMillis();
readFromURL(testingURL, false);
long end = System.currentTimeMillis();
- double duration = (end - start) / (double) 1000;
- LOG.info("Parsing done in " + duration + " seconds!");
- LOG.info("Read " + readLines + " lines containing "
- + hiddenSequences.size() + " sentences.");
+ double duration = (end - start) / 1000.0;
+ log.info("Parsing done in {} seconds!", duration);
+ log.info("Read {} lines containing {} sentences.", readLines, hiddenSequences.size());
start = System.currentTimeMillis();
int errorCount = 0;
int totalCount = 0;
for (int i = 0; i < observedSequences.size(); ++i) {
// fetch the viterbi path as the POS tag for this observed sequence
- int[] posEstimate = HmmEvaluator.decode(taggingModel, observedSequences
- .get(i), false);
+ int[] posEstimate = HmmEvaluator.decode(taggingModel, observedSequences.get(i), false);
// compare with the expected
int[] posExpected = hiddenSequences.get(i);
for (int j = 0; j < posExpected.length; ++j) {
totalCount++;
- if (posEstimate[j] != posExpected[j])
+ if (posEstimate[j] != posExpected[j]) {
errorCount++;
+ }
}
}
end = System.currentTimeMillis();
- duration = (end - start) / (double) 1000;
- LOG.info("POS tagged test file in " + duration + " seconds!");
+ duration = (end - start) / 1000.0;
+ log.info("POS tagged test file in {} seconds!", duration);
double errorRate = (double) errorCount / (double) totalCount;
- LOG.info("Tagged the test file with an error rate of: " + errorRate);
+ log.info("Tagged the test file with an error rate of: {}", errorRate);
}
- private static java.util.Vector<String> tagSentence(String sentence) {
+ private static List<String> tagSentence(String sentence) {
// first, we need to isolate all punctuation characters, so that they
// can be recognized
sentence = sentence.replaceAll("[,.!?:;\"]", " $0 ");
@@ -252,14 +246,11 @@ public final class PosTagger {
// now we tokenize the sentence
String[] tokens = sentence.split("[ ]+");
// now generate the observed sequence
- int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays
- .asList(tokens), true, 0);
+ int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays.asList(tokens), true, 0);
// POS tag this observedSequence
- int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence,
- false);
+ int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence, false);
// and now decode the tag names
- return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false,
- null);
+ return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false, null);
}
public static void main(String[] args) throws IOException {
@@ -269,10 +260,10 @@ public final class PosTagger {
// tag an exemplary sentence
String test = "McDonalds is a huge company with many employees .";
String[] testWords = test.split(" ");
- java.util.Vector<String> posTags;
- posTags = tagSentence(test);
- for (int i = 0; i < posTags.size(); ++i)
- LOG.info(testWords[i] + "[" + posTags.get(i) + "]");
+ List<String> posTags = tagSentence(test);
+ for (int i = 0; i < posTags.size(); ++i) {
+ log.info("{}[{}]", testWords[i], posTags.get(i));
+ }
}
}