You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by is...@apache.org on 2010/09/24 13:17:14 UTC
svn commit: r1000807 [2/2] - in /mahout/trunk: core/
core/src/main/java/org/apache/mahout/classifier/sequencelearning/
core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/
core/src/test/java/org/apache/mahout/classifier/sequencelearnin...
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java Fri Sep 24 11:17:13 2010
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import junit.framework.Assert;
+
+import org.junit.Test;
+
+/**
+ * Tests construction and JSON round-tripping of {@link HmmModel}, using the
+ * fixture model inherited from {@link HMMTestBase}.
+ */
+public class HMMModelTest extends HMMTestBase {
+
+  /**
+   * A model created through the two-int constructor (a randomly generated
+   * model, per the original author's note) has to pass validation.
+   */
+  @Test
+  public void testRandomModelGeneration() {
+    HmmModel randomModel = new HmmModel(10, 20);
+    HmmUtils.validate(randomModel);
+  }
+
+  /**
+   * Serializes the fixture model, reads it back and serializes the copy
+   * again. The underlying objects offer no equals methods, so the two JSON
+   * strings are compared instead of the models themselves.
+   */
+  @Test
+  public void testSerialization() {
+    String firstPass = model.toJson();
+    HmmModel restored = HmmModel.fromJson(firstPass);
+    String secondPass = restored.toJson();
+    Assert.assertEquals(firstPass, secondPass);
+  }
+
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java Fri Sep 24 11:17:13 2010
@@ -0,0 +1,49 @@
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+
+/**
+ * Shared fixture for the HMM tests: provides a small, hand-specified model
+ * and an observation sequence to every subclass.
+ */
+public class HMMTestBase extends MahoutTestCase {
+
+  /** Fixture model; rebuilt by {@link #setUp()}. */
+  protected HmmModel model;
+
+  /**
+   * Observation sequence "O1" "O0" "O2" "O2" "O0" "O0" "O1", written as
+   * output state indices.
+   */
+  protected int[] sequence = {1, 0, 2, 2, 0, 0, 1};
+
+  /**
+   * Builds the fixture model from the parameters below and validates it.
+   *
+   * <pre>
+   * hidden states: 4 ("H0", "H1", "H2", "H3")
+   * output states: 3 ("O0", "O1", "O2")
+   *
+   * transition matrix    to:  H0   H1   H2   H3
+   *             from: H0      0.5  0.1  0.1  0.3
+   *                   H1      0.4  0.4  0.1  0.1
+   *                   H2      0.1  0.0  0.8  0.1
+   *                   H3      0.1  0.1  0.1  0.7
+   *
+   * output matrix        to:  O0   O1   O2
+   *             from: H0      0.8  0.1  0.1
+   *                   H1      0.6  0.1  0.3
+   *                   H2      0.1  0.8  0.1
+   *                   H3      0.0  0.1  0.9
+   *
+   * initial probabilities:    H0 0.2, H1 0.1, H2 0.4, H3 0.3
+   * </pre>
+   */
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+
+    // one row per source hidden state, one column per target hidden state
+    double[][] transitions = {
+        {0.5, 0.1, 0.1, 0.3},
+        {0.4, 0.4, 0.1, 0.1},
+        {0.1, 0.0, 0.8, 0.1},
+        {0.1, 0.1, 0.1, 0.7}};
+    // one row per hidden state, one column per output state
+    double[][] emissions = {
+        {0.8, 0.1, 0.1},
+        {0.6, 0.1, 0.3},
+        {0.1, 0.8, 0.1},
+        {0.0, 0.1, 0.9}};
+    double[] initial = {0.2, 0.1, 0.4, 0.3};
+
+    DenseMatrix transitionMatrix = new DenseMatrix(transitions);
+    DenseMatrix emissionMatrix = new DenseMatrix(emissions);
+    DenseVector initialVector = new DenseVector(initial);
+    model = new HmmModel(transitionMatrix, emissionMatrix, initialVector);
+
+    // attach readable names to the state indices
+    String[] hiddenStateNames = {"H0", "H1", "H2", "H3"};
+    String[] outputStateNames = {"O0", "O1", "O2"};
+    model.registerHiddenStateNames(hiddenStateNames);
+    model.registerOutputStateNames(outputStateNames);
+
+    // fail early if the fixture itself is inconsistent
+    HmmUtils.validate(model);
+  }
+
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java Fri Sep 24 11:17:13 2010
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import junit.framework.Assert;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Tests the Viterbi and Baum-Welch training routines of {@link HmmTrainer}
+ * against reference parameters computed with external HMM packages. Each
+ * algorithm is exercised twice, with the final boolean argument of the
+ * trainer set to {@code false} and to {@code true} (the "scaled" variant);
+ * both variants must reproduce the same reference values.
+ */
+public class HMMTrainerTest extends HMMTestBase {
+
+  /** Tolerance used when comparing Viterbi-trained parameters. */
+  private static final double VITERBI_DELTA = 0.00001;
+
+  /** Tolerance used when comparing Baum-Welch-trained parameters. */
+  private static final double BAUM_WELCH_DELTA = 0.0001;
+
+  @Test
+  public void testViterbiTraining() {
+    checkViterbiTraining(false);
+  }
+
+  @Test
+  public void testScaledViterbiTraining() {
+    checkViterbiTraining(true);
+  }
+
+  @Test
+  public void testBaumWelchTraining() {
+    checkBaumWelchTraining(false);
+  }
+
+  @Test
+  public void testScaledBaumWelchTraining() {
+    checkBaumWelchTraining(true);
+  }
+
+  /**
+   * Trains the fixture model with Viterbi training and compares the
+   * resulting transition and emission matrices with the reference values.
+   *
+   * @param scaled value passed as the last argument of
+   *               {@code HmmTrainer.trainViterbi}
+   */
+  private void checkViterbiTraining(boolean scaled) {
+    // expected model parameters (from R)
+    double[][] transitionExpected = {{0.3125, 0.0625, 0.3125, 0.3125},
+        {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
+        {0.5, 0.1, 0.1, 0.3}};
+    double[][] emissionExpected = {{0.882353, 0.058824, 0.058824},
+        {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
+        {0.111111, 0.111111, 0.777778}};
+
+    HmmModel trained = HmmTrainer.trainViterbi(model, observedSequence(),
+        0.5, 0.1, 10, scaled);
+
+    assertStateCounts(trained, transitionExpected, emissionExpected);
+    assertMatrixEquals("transition", transitionExpected,
+        trained.getTransitionMatrix(), VITERBI_DELTA);
+    assertMatrixEquals("emission", emissionExpected,
+        trained.getEmissionMatrix(), VITERBI_DELTA);
+  }
+
+  /**
+   * Trains the fixture model with Baum-Welch training and compares the
+   * resulting initial, transition and emission parameters with the
+   * reference values.
+   *
+   * @param scaled value passed as the last argument of
+   *               {@code HmmTrainer.trainBaumWelch}
+   */
+  private void checkBaumWelchTraining(boolean scaled) {
+    // expected values from Matlab HMM package / R HMM package
+    double[] initialExpected = {0, 0, 1.0, 0};
+    double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683},
+        {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0},
+        {0.0024, 0.6657, 0, 0.3319}};
+    double[][] emissionExpected = {{0.9995, 0.0004, 0.0001},
+        {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
+
+    HmmModel trained = HmmTrainer.trainBaumWelch(model, observedSequence(),
+        0.1, 10, scaled);
+
+    assertStateCounts(trained, transitionExpected, emissionExpected);
+    Vector initialProbabilities = trained.getInitialProbabilities();
+    for (int i = 0; i < initialExpected.length; ++i) {
+      Assert.assertEquals("initial[" + i + "]", initialExpected[i],
+          initialProbabilities.get(i), BAUM_WELCH_DELTA);
+    }
+    assertMatrixEquals("transition", transitionExpected,
+        trained.getTransitionMatrix(), BAUM_WELCH_DELTA);
+    assertMatrixEquals("emission", emissionExpected,
+        trained.getEmissionMatrix(), BAUM_WELCH_DELTA);
+  }
+
+  /**
+   * Returns a fresh copy of the output sequence all training tests fit the
+   * model to, so that no test can be affected by another one's array.
+   */
+  private static int[] observedSequence() {
+    return new int[] {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+  }
+
+  /**
+   * Makes sure the trained model has exactly as many hidden and output
+   * states as the reference matrices describe, so that the element-wise
+   * comparison neither skips nor overruns entries.
+   */
+  private static void assertStateCounts(HmmModel trained,
+      double[][] transitionExpected, double[][] emissionExpected) {
+    Assert.assertEquals("number of hidden states", transitionExpected.length,
+        trained.getNrOfHiddenStates());
+    Assert.assertEquals("number of output states",
+        emissionExpected[0].length, trained.getNrOfOutputStates());
+  }
+
+  /**
+   * Compares a matrix element by element with the expected values. The
+   * expected value is passed first, as the assertion API requires, and the
+   * failing cell is named in the assertion message.
+   */
+  private static void assertMatrixEquals(String label, double[][] expected,
+      Matrix actual, double delta) {
+    for (int i = 0; i < expected.length; ++i) {
+      for (int j = 0; j < expected[i].length; ++j) {
+        Assert.assertEquals(label + "[" + i + "][" + j + "]", expected[i][j],
+            actual.getQuick(i, j), delta);
+      }
+    }
+  }
+
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java Fri Sep 24 11:17:13 2010
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Arrays;
+
+import junit.framework.Assert;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public class HMMUtilsTest extends HMMTestBase {
+
+  /** 2x2 matrix whose rows each sum to one. */
+  Matrix legal2_2;
+  /** 2x3 matrix whose rows each sum to one. */
+  Matrix legal2_3;
+  /** 3x3 matrix whose rows each sum to one. */
+  Matrix legal3_3;
+  /** Vector of length two whose entries sum to one. */
+  Vector legal2;
+  /** 2x2 matrix with entries above one; its rows do not sum to one. */
+  Matrix illegal2_2;
+
+  /**
+   * Builds the base-class fixture model and then the small matrices and
+   * vector the validator tests combine into legal and illegal models.
+   */
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    legal2_2 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}});
+    legal2_3 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6},
+        {0.3, 0.3, 0.4}});
+    legal3_3 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8},
+        {0.1, 0.2, 0.7}, {0.2, 0.3, 0.5}});
+    legal2 = new DenseVector(new double[]{0.4, 0.6});
+    illegal2_2 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}});
+  }
+
+ @Test
+ public void testValidatorLegal() {
+ HmmUtils.validate(new HmmModel(legal2_2, legal2_3, legal2));
+ }
+
+  /**
+   * Combining the 3x3 matrix with the 2x3 matrix and the length-two vector
+   * gives inconsistent dimensions; an {@link IllegalArgumentException} is
+   * required, and the test fails if none is thrown.
+   */
+  @Test(expected = IllegalArgumentException.class)
+  public void testValidatorDimensionError() {
+    HmmUtils.validate(new HmmModel(legal3_3, legal2_3, legal2));
+  }
+
+  /**
+   * Using the matrix {{1, 2}, {3, 4}} as the first constructor argument must
+   * be rejected with an {@link IllegalArgumentException}; the test fails if
+   * none is thrown.
+   */
+  @Test(expected = IllegalArgumentException.class)
+  public void testValidatorIllegelMatrixError() {
+    HmmUtils.validate(new HmmModel(illegal2_2, legal2_3, legal2));
+  }
+
+  /**
+   * Encodes sequences of hidden and output state names into index arrays.
+   * "H4" and "O4" are not registered in the fixture model, so they must be
+   * mapped to the supplied default value -1.
+   */
+  @Test
+  public void testEncodeStateSequence() {
+    String[] hiddenSequence = {"H1", "H2", "H0", "H3", "H4"};
+    String[] outputSequence = {"O1", "O2", "O4", "O0"};
+    // encode both sequences (third argument: false = hidden, true = output)
+    int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(model, Arrays
+        .asList(hiddenSequence), false, -1);
+    int[] outputSequenceEnc = HmmUtils.encodeStateSequence(model, Arrays
+        .asList(outputSequence), true, -1);
+    // expected state sequences
+    int[] hiddenSequenceExp = {1, 2, 0, 3, -1};
+    int[] outputSequenceExp = {1, 2, -1, 0};
+    // check the lengths first: without this a truncated result would pass
+    // unnoticed and an over-long one would fail with an index error
+    Assert.assertEquals(hiddenSequenceExp.length, hiddenSequenceEnc.length);
+    Assert.assertEquals(outputSequenceExp.length, outputSequenceEnc.length);
+    // compare element by element
+    for (int i = 0; i < hiddenSequenceExp.length; ++i) {
+      Assert.assertEquals(hiddenSequenceExp[i], hiddenSequenceEnc[i]);
+    }
+    for (int i = 0; i < outputSequenceExp.length; ++i) {
+      Assert.assertEquals(outputSequenceExp[i], outputSequenceEnc[i]);
+    }
+  }
+
+  /**
+   * Decodes index arrays back into state names. Index 10 is not a state of
+   * the fixture model, so it must be replaced by the supplied default name
+   * "unknown".
+   */
+  @Test
+  public void testDecodeStateSequence() {
+    int[] hiddenSequence = {1, 2, 0, 3, 10};
+    int[] outputSequence = {1, 2, 10, 0};
+    // decode both sequences (third argument: false = hidden, true = output)
+    java.util.Vector<String> hiddenSequenceDec = HmmUtils.decodeStateSequence(
+        model, hiddenSequence, false, "unknown");
+    java.util.Vector<String> outputSequenceDec = HmmUtils.decodeStateSequence(
+        model, outputSequence, true, "unknown");
+    // expected state sequences
+    String[] hiddenSequenceExp = {"H1", "H2", "H0", "H3", "unknown"};
+    String[] outputSequenceExp = {"O1", "O2", "unknown", "O0"};
+    // check the sizes first so that surplus decoded entries are detected
+    Assert.assertEquals(hiddenSequenceExp.length, hiddenSequenceDec.size());
+    Assert.assertEquals(outputSequenceExp.length, outputSequenceDec.size());
+    // compare element by element
+    for (int i = 0; i < hiddenSequenceExp.length; ++i) {
+      Assert.assertEquals(hiddenSequenceExp[i], hiddenSequenceDec.get(i));
+    }
+    for (int i = 0; i < outputSequenceExp.length; ++i) {
+      Assert.assertEquals(outputSequenceExp[i], outputSequenceDec.get(i));
+    }
+  }
+
+  /**
+   * Builds a model from unnormalized counts, normalizes it and expects the
+   * validator to accept the result.
+   */
+  @Test
+  public void testNormalizeModel() {
+    DenseMatrix transitionCounts =
+        new DenseMatrix(new double[][]{{10, 10}, {20, 25}});
+    DenseMatrix emissionCounts =
+        new DenseMatrix(new double[][]{{5, 7}, {10, 15}});
+    DenseVector initialCounts = new DenseVector(new double[]{10, 20});
+    // named differently from the inherited fixture field to avoid shadowing
+    HmmModel unnormalized =
+        new HmmModel(transitionCounts, emissionCounts, initialCounts);
+    HmmUtils.normalizeModel(unnormalized);
+    // validate() throws if the model is still not valid
+    HmmUtils.validate(unnormalized);
+  }
+
+  /**
+   * Truncates a model whose off-diagonal probabilities (0.0001) lie below
+   * the threshold 0.01. The truncated model must be valid, its transition
+   * and emission matrices must be the identity, and all initial probability
+   * mass must sit on state 2.
+   */
+  @Test
+  public void testTruncateModel() {
+    DenseVector ip = new DenseVector(new double[]{0.0001, 0.0001, 0.9998});
+    DenseMatrix tr = new DenseMatrix(new double[][]{
+        {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+        {0.0001, 0.0001, 0.9998}});
+    DenseMatrix em = new DenseMatrix(new double[][]{
+        {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+        {0.0001, 0.0001, 0.9998}});
+    // named differently from the inherited fixture field to avoid shadowing
+    HmmModel denseModel = new HmmModel(tr, em, ip);
+    // now truncate the model
+    HmmModel sparseModel = HmmUtils.truncateModel(denseModel, 0.01);
+    // first make sure this is a valid model
+    HmmUtils.validate(sparseModel);
+    // now check whether the values are as expected; doubles are compared
+    // with an explicit tolerance instead of boxed Object equality
+    final double delta = 1.0e-12;
+    Vector sparse_ip = sparseModel.getInitialProbabilities();
+    Matrix sparse_tr = sparseModel.getTransitionMatrix();
+    Matrix sparse_em = sparseModel.getEmissionMatrix();
+    for (int i = 0; i < sparseModel.getNrOfHiddenStates(); ++i) {
+      double expectedInitial = (i == 2) ? 1.0 : 0.0;
+      Assert.assertEquals(expectedInitial, sparse_ip.getQuick(i), delta);
+      // this model has as many output states as hidden states (3), so one
+      // loop covers both matrices
+      for (int j = 0; j < sparseModel.getNrOfHiddenStates(); ++j) {
+        double expected = (i == j) ? 1.0 : 0.0;
+        Assert.assertEquals(expected, sparse_tr.getQuick(i, j), delta);
+        Assert.assertEquals(expected, sparse_em.getQuick(i, j), delta);
+      }
+    }
+  }
+
+}
Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java Fri Sep 24 11:17:13 2010
@@ -0,0 +1,278 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.net.URL;
+import java.net.URLConnection;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.mahout.math.Matrix;
+
+/**
+ * This class implements a sample program that uses a pre-tagged training data
+ * set to train an HMM model as a POS tagger. The training data is automatically
+ * downloaded from the following URL:
+ * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt
+ * <p/>
+ * It then trains an HMM model using supervised learning and tests the model on
+ * the following test data set:
+ * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt
+ * <p/>
+ * Further details regarding the data files can be found at
+ * http://flexcrfs.sourceforge.net/#Case_Study
+ *
+ * @author mheimel
+ */
+public final class PosTagger {
+
+  /**
+   * Private on purpose: this example consists of static members only and is
+   * not meant to be instantiated.
+   */
+  private PosTagger() {
+  }
+
+  /**
+   * Logger for this class.
+   */
+  private static final Log LOG = LogFactory.getLog(PosTagger.class);
+
+  /**
+   * Model trained in the example.
+   */
+  private static HmmModel taggingModel;
+
+  /**
+   * Map for storing the IDs for the POS tags (hidden states)
+   */
+  private static Map<String, Integer> tagIDs;
+
+  /**
+   * Counter for the next assigned POS tag ID. The value of 0 is reserved for
+   * "unknown POS tag", so counting has to start at 1; starting at 0 would
+   * give the first tag read the same ID that unknown tags are mapped to.
+   */
+  private static int nextTagId = 1; // 0 is reserved for "unknown POS tag"
+
+  /**
+   * Map for storing the IDs for observed words (observed states)
+   */
+  private static Map<String, Integer> wordIDs;
+
+  /**
+   * Counter for the next assigned word ID. The value of 0 is reserved for
+   * "unknown word".
+   */
+  private static int nextWordId = 1; // 0 is reserved for "unknown word"
+
+  /**
+   * Used for storing a list of POS tags of read sentences.
+   */
+  private static List<int[]> hiddenSequences;
+
+  /**
+   * Used for storing a list of word tags of read sentences.
+   */
+  private static List<int[]> observedSequences;
+
+  /**
+   * number of read lines
+   */
+  private static int readLines;
+
+  /**
+   * Given an URL, this function fetches the data file, parses it, assigns POS
+   * tag/word IDs and fills the hiddenSequences/observedSequences lists with
+   * data from that file. The data is expected to contain one word per line
+   * in the format {@code word pos-tag np-tag}; sentences are separated by
+   * empty lines.
+   * <p/>
+   * Lines with fewer than two space-separated fields are skipped with a
+   * warning, and runs of empty lines do not produce empty sentences. The
+   * stream is decoded as UTF-8 and is always closed before returning.
+   *
+   * @param url       Where the data file is stored
+   * @param assignIDs Should IDs for unknown words/tags be assigned? (Needed for
+   *                  training data, not needed for test data)
+   * @throws IOException in case data file cannot be read.
+   */
+  private static void readFromURL(String url, boolean assignIDs) throws IOException {
+    URLConnection connection = (new URL(url)).openConnection();
+    // decode with a fixed charset so that word IDs do not depend on the
+    // platform default encoding
+    BufferedReader input = new BufferedReader(
+        new InputStreamReader(connection.getInputStream(), "UTF-8"));
+    try {
+      // initialize the data structures
+      hiddenSequences = new LinkedList<int[]>();
+      observedSequences = new LinkedList<int[]>();
+      readLines = 0;
+
+      // now read line by line of the input file
+      String line;
+      List<Integer> observedSequence = new LinkedList<Integer>();
+      List<Integer> hiddenSequence = new LinkedList<Integer>();
+      while ((line = input.readLine()) != null) {
+        if (line.isEmpty()) {
+          // an empty line closes the current sentence
+          registerSentence(observedSequence, hiddenSequence);
+          continue;
+        }
+        // we expect the format [word] [POS tag] [NP tag]
+        String[] tags = line.split(" ");
+        if (tags.length < 2) {
+          // word or POS tag missing: indexing tags[1] would fail
+          LOG.warn("Skipping malformed line: " + line);
+          continue;
+        }
+        readLines++;
+        // when analyzing the training set, assign IDs
+        if (assignIDs) {
+          if (!wordIDs.containsKey(tags[0])) {
+            wordIDs.put(tags[0], nextWordId++);
+          }
+          if (!tagIDs.containsKey(tags[1])) {
+            tagIDs.put(tags[1], nextTagId++);
+          }
+        }
+        // determine the IDs
+        Integer wordID = wordIDs.get(tags[0]);
+        Integer tagID = tagIDs.get(tags[1]);
+        // handle unknown values
+        wordID = (wordID == null) ? 0 : wordID;
+        tagID = (tagID == null) ? 0 : tagID;
+        // now construct the current sequence
+        observedSequence.add(wordID);
+        hiddenSequence.add(tagID);
+      }
+      // if there is still something in the pipe, register it
+      registerSentence(observedSequence, hiddenSequence);
+    } finally {
+      input.close();
+    }
+  }
+
+  /**
+   * Copies the IDs collected for the current sentence into arrays, appends
+   * them to hiddenSequences/observedSequences and clears both input lists.
+   * Does nothing when no word has been collected.
+   */
+  private static void registerSentence(List<Integer> observedSequence,
+                                       List<Integer> hiddenSequence) {
+    if (observedSequence.isEmpty()) {
+      return;
+    }
+    int[] observedSequenceArray = new int[observedSequence.size()];
+    int[] hiddenSequenceArray = new int[hiddenSequence.size()];
+    // iterate instead of calling get(i), which is linear on a linked list
+    int position = 0;
+    for (Integer wordID : observedSequence) {
+      observedSequenceArray[position++] = wordID;
+    }
+    position = 0;
+    for (Integer tagID : hiddenSequence) {
+      hiddenSequenceArray[position++] = tagID;
+    }
+    // now register those arrays
+    hiddenSequences.add(hiddenSequenceArray);
+    observedSequences.add(observedSequenceArray);
+    // and reset the lists for the next sentence
+    observedSequence.clear();
+    hiddenSequence.clear();
+  }
+
+  /**
+   * Downloads and parses the training data, trains the tagging model on it
+   * with supervised learning, raises the emission probability of the
+   * "unknown word" (output state 0) for the NNP tag relative to all other
+   * tags, re-normalizes the model and registers the tag and word names.
+   *
+   * @param trainingURL location of the training data file
+   * @throws IOException in case the training data cannot be read
+   */
+  private static void trainModel(String trainingURL) throws IOException {
+    tagIDs = new HashMap<String, Integer>(44); // we expect 44 distinct tags
+    wordIDs = new HashMap<String, Integer>(19122); // we expect 19122
+    // distinct words
+    LOG.info("Reading and parsing training data file from URL: " + trainingURL);
+    long start = System.currentTimeMillis();
+    readFromURL(trainingURL, true);
+    long end = System.currentTimeMillis();
+    double duration = (end - start) / (double) 1000;
+    LOG.info("Parsing done in " + duration + " seconds!");
+    // report the sizes of the ID maps: they are the distinct counts no
+    // matter which value the ID counters started from
+    LOG.info("Read " + readLines + " lines containing "
+        + hiddenSequences.size() + " sentences with a total of "
+        + wordIDs.size() + " distinct words and " + tagIDs.size()
+        + " distinct POS tags.");
+    start = System.currentTimeMillis();
+    taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId,
+        hiddenSequences, observedSequences, 0.05);
+    // we have to adjust the model a bit,
+    // since we assume a higher probability that a given unknown word is NNP
+    // than anything else
+    Matrix emissions = taggingModel.getEmissionMatrix();
+    for (int i = 0; i < taggingModel.getNrOfHiddenStates(); ++i) {
+      emissions.setQuick(i, 0, 0.1 / (double) taggingModel
+          .getNrOfHiddenStates());
+    }
+    // the tag is looked up as an Integer: unboxing a missing entry would
+    // fail with a NullPointerException
+    Integer nnptag = tagIDs.get("NNP");
+    if (nnptag == null) {
+      LOG.warn("Training data contains no NNP tag; unknown words are"
+          + " not biased towards any tag.");
+    } else {
+      emissions.setQuick(nnptag, 0, 1 / (double) taggingModel
+          .getNrOfHiddenStates());
+    }
+    // re-normalize the emission probabilities
+    HmmUtils.normalizeModel(taggingModel);
+    // now register the names
+    taggingModel.registerHiddenStateNames(tagIDs);
+    taggingModel.registerOutputStateNames(wordIDs);
+    end = System.currentTimeMillis();
+    duration = (end - start) / (double) 1000;
+    LOG.info("Trained HMM model in " + duration + " seconds!");
+  }
+
+  /**
+   * Downloads and parses the test data, tags every sentence with the trained
+   * model and logs the share of tokens whose estimated POS tag differs from
+   * the tag given in the data. No IDs are assigned while reading, so words
+   * and tags not seen during training are mapped to ID 0.
+   *
+   * @param testingURL location of the test data file
+   * @throws IOException in case the test data cannot be read
+   */
+  private static void testModel(String testingURL) throws IOException {
+    LOG.info("Reading and parsing test data file from URL:" + testingURL);
+    long start = System.currentTimeMillis();
+    readFromURL(testingURL, false);
+    long end = System.currentTimeMillis();
+    double duration = (end - start) / (double) 1000;
+    LOG.info("Parsing done in " + duration + " seconds!");
+    LOG.info("Read " + readLines + " lines containing "
+        + hiddenSequences.size() + " sentences.");
+
+    start = System.currentTimeMillis();
+    int errorCount = 0;
+    int totalCount = 0;
+    // both lists are filled pairwise by readFromURL; walk them in lockstep
+    // with iterators, since get(i) is linear on a linked list
+    java.util.Iterator<int[]> expectedIterator = hiddenSequences.iterator();
+    for (int[] observed : observedSequences) {
+      // fetch the viterbi path as the POS tag for this observed sequence
+      int[] posEstimate = HmmEvaluator.decode(taggingModel, observed, false);
+      // compare with the expected
+      int[] posExpected = expectedIterator.next();
+      for (int j = 0; j < posExpected.length; ++j) {
+        totalCount++;
+        if (posEstimate[j] != posExpected[j]) {
+          errorCount++;
+        }
+      }
+    }
+    end = System.currentTimeMillis();
+    duration = (end - start) / (double) 1000;
+    LOG.info("POS tagged test file in " + duration + " seconds!");
+    if (totalCount == 0) {
+      // 0 / 0 would be reported as NaN
+      LOG.warn("Test data contains no tokens; no error rate available.");
+      return;
+    }
+    double errorRate = (double) errorCount / (double) totalCount;
+    LOG.info("Tagged the test file with an error rate of: " + errorRate);
+  }
+
+  /**
+   * POS tags a single sentence with the trained model.
+   * <p/>
+   * Punctuation characters are first separated from the neighbouring words,
+   * then the sentence is split at runs of spaces. Tokens the model does not
+   * know are encoded as output state 0 (the "unknown word").
+   *
+   * @param sentence the sentence to tag
+   * @return one tag name per token, in token order; empty if the sentence
+   *         contains no token
+   */
+  private static java.util.Vector<String> tagSentence(String sentence) {
+    // first, we need to isolate all punctuation characters, so that they
+    // can be recognized
+    String spaced = sentence.replaceAll("[,.!?:;\"]", " $0 ");
+    spaced = spaced.replaceAll("''", " '' ");
+    // strip surrounding blanks: a leading space (e.g. in front of an
+    // isolated opening quote) would otherwise yield an empty first token
+    // that gets tagged like a word
+    spaced = spaced.trim();
+    if (spaced.isEmpty()) {
+      return new java.util.Vector<String>();
+    }
+    // now we tokenize the sentence
+    String[] tokens = spaced.split("[ ]+");
+    // now generate the observed sequence
+    int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays
+        .asList(tokens), true, 0);
+    // POS tag this observedSequence
+    int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence,
+        false);
+    // and now decode the tag names
+    return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false,
+        null);
+  }
+
+  /**
+   * Runs the example: trains the model from the training URL, evaluates it
+   * on the test URL and finally logs the tags found for a sample sentence.
+   *
+   * @param args not used
+   * @throws IOException in case one of the data files cannot be read
+   */
+  public static void main(String[] args) throws IOException {
+    // build and evaluate the model from the published data files
+    trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
+    testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
+    // tag a sample sentence and log each word next to its tag
+    String exampleSentence = "McDonalds is a huge company with many employees .";
+    String[] exampleWords = exampleSentence.split(" ");
+    java.util.Vector<String> exampleTags = tagSentence(exampleSentence);
+    int position = 0;
+    for (String tag : exampleTags) {
+      LOG.info(exampleWords[position] + "[" + tag + "]");
+      position++;
+    }
+  }
+
+}