You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/06/22 20:56:19 UTC
[3/4] incubator-joshua git commit: ClassLMs: fixed a bug with
class-based lms not mapping to class ids when estimateCost(). Also refactored
the code a little bit to have StateMinimizingLanguageModels support classes
as well. Added some unit tests. The ex
ClassLMs: fixed a bug where class-based LMs did not map words to class ids in estimateCost(). Also refactored the code a little bit so that StateMinimizingLanguageModels support classes as well. Added some unit tests. The existing regression test output was updated to the new output.
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/8fc7544e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/8fc7544e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/8fc7544e
Branch: refs/heads/master
Commit: 8fc7544eaaf35f71367b48778eaa1f22772ca390
Parents: 55e88d1
Author: Felix Hieber <fh...@amazon.com>
Authored: Mon Jun 20 11:21:03 2016 +0200
Committer: Felix Hieber <fh...@amazon.com>
Committed: Tue Jun 21 07:54:44 2016 +0200
----------------------------------------------------------------------
.../apache/joshua/decoder/ff/lm/ClassMap.java | 73 +
.../joshua/decoder/ff/lm/LanguageModelFF.java | 146 +-
.../ff/lm/StateMinimizingLanguageModel.java | 125 +-
.../class_lm/ClassBasedLanguageModelTest.java | 71 +
.../decoder/ff/lm/class_lm/ClassMapTest.java | 67 +
.../resources/bn-en/hiero/joshua-classlm.config | 4 +-
.../resources/bn-en/hiero/output-classlm.gold | 1565 +++---
src/test/resources/lm/class_lm/class.map | 5140 ++++++++++++++++++
.../resources/lm/class_lm/class_lm_9gram.gz | Bin 0 -> 12733137 bytes
9 files changed, 6355 insertions(+), 836 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/main/java/org/apache/joshua/decoder/ff/lm/ClassMap.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/ClassMap.java b/src/main/java/org/apache/joshua/decoder/ff/lm/ClassMap.java
new file mode 100644
index 0000000..c86d739
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/ClassMap.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm;
+
+import java.io.IOException;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.collect.ImmutableMap;
+
+public class ClassMap {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ClassMap.class);
+
+ private static final int OOV_ID = Vocabulary.getUnknownId();
+ private final ImmutableMap<Integer, Integer> mapping;
+
+ public ClassMap(String file_name) {
+ this.mapping = read(file_name);
+ LOG.info("{} entries read from class map", this.mapping.size());
+ }
+
+ public int getClassID(int wordID) {
+ return this.mapping.getOrDefault(wordID, OOV_ID);
+ }
+
+ public int size() {
+ return mapping.size();
+ }
+
+ /**
+ * Reads a class map from file_name
+ */
+ private static ImmutableMap<Integer, Integer> read(String file_name) {
+ final ImmutableMap.Builder<Integer, Integer> builder = ImmutableMap.builder();
+ int lineno = 0;
+ try {
+ for (String line : new LineReader(file_name, false)) {
+ lineno++;
+ String[] lineComp = line.trim().split("\\s+");
+ try {
+ builder.put(Vocabulary.id(lineComp[0]), Vocabulary.id(lineComp[1]));
+ } catch (java.lang.ArrayIndexOutOfBoundsException e) {
+ LOG.warn("bad vocab line #{} '{}'. skipping!", lineno, line);
+ LOG.warn(e.getMessage(), e);
+ }
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ return builder.build();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java b/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
index 3fea410..9388ed7 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
@@ -21,12 +21,9 @@ package org.apache.joshua.decoder.ff.lm;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
-import com.google.common.primitives.Ints;
-
import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.JoshuaConfiguration;
import org.apache.joshua.decoder.Support;
@@ -34,7 +31,6 @@ import org.apache.joshua.decoder.chart_parser.SourcePath;
import org.apache.joshua.decoder.ff.FeatureVector;
import org.apache.joshua.decoder.ff.StatefulFF;
import org.apache.joshua.decoder.ff.lm.berkeley_lm.LMGrammarBerkeley;
-import org.apache.joshua.decoder.ff.lm.KenLM;
import org.apache.joshua.decoder.ff.state_maintenance.DPState;
import org.apache.joshua.decoder.ff.state_maintenance.NgramDPState;
import org.apache.joshua.decoder.ff.tm.Rule;
@@ -44,6 +40,9 @@ import org.apache.joshua.util.FormatUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.primitives.Ints;
+
/**
* This class performs the following:
* <ol>
@@ -52,21 +51,21 @@ import org.slf4j.LoggerFactory;
* <li>Gets the LM state</li>
* <li>Gets the left-side LM state estimation score</li>
* </ol>
- *
+ *
* @author Matt Post post@cs.jhu.edu
* @author Juri Ganitkevitch juri@cs.jhu.edu
* @author Zhifei Li, zhifei.work@gmail.com
*/
public class LanguageModelFF extends StatefulFF {
- private static final Logger LOG = LoggerFactory.getLogger(LanguageModelFF.class);
+ static final Logger LOG = LoggerFactory.getLogger(LanguageModelFF.class);
public static int LM_INDEX = 0;
private int startSymbolId;
/**
* N-gram language model. We assume the language model is in ARPA format for equivalent state:
- *
+ *
* <ol>
* <li>We assume it is a backoff lm, and high-order ngram implies low-order ngram; absence of
* low-order ngram implies absence of high-order ngram</li>
@@ -94,61 +93,20 @@ public class LanguageModelFF extends StatefulFF {
protected String path;
/* Whether this is a class-based LM */
- private boolean isClassLM;
+ protected boolean isClassLM;
private ClassMap classMap;
- protected class ClassMap {
-
- private final int OOV_id = Vocabulary.getUnknownId();
- private HashMap<Integer, Integer> classMap;
-
- public ClassMap(String file_name) throws IOException {
- this.classMap = new HashMap<Integer, Integer>();
- read(file_name);
- }
-
- public int getClassID(int wordID) {
- return this.classMap.getOrDefault(wordID, OOV_id);
- }
-
- /**
- * Reads a class map from file.
- *
- * @param file_name
- * @throws IOException
- */
- private void read(String file_name) throws IOException {
-
- int lineno = 0;
- for (String line: new org.apache.joshua.util.io.LineReader(file_name, false)) {
- lineno++;
- String[] lineComp = line.trim().split("\\s+");
- try {
- this.classMap.put(Vocabulary.id(lineComp[0]), Vocabulary.id(lineComp[1]));
- } catch (java.lang.ArrayIndexOutOfBoundsException e) {
- LOG.warn("bad vocab line #{} '{}'", lineno, line);
- LOG.warn(e.getMessage(), e);
- }
- }
- }
-
- }
-
public LanguageModelFF(FeatureVector weights, String[] args, JoshuaConfiguration config) {
super(weights, String.format("lm_%d", LanguageModelFF.LM_INDEX++), args, config);
this.type = parsedArgs.get("lm_type");
- this.ngramOrder = Integer.parseInt(parsedArgs.get("lm_order"));
+ this.ngramOrder = Integer.parseInt(parsedArgs.get("lm_order"));
this.path = parsedArgs.get("lm_file");
- if (parsedArgs.containsKey("class_map"))
- try {
- this.isClassLM = true;
- this.classMap = new ClassMap(parsedArgs.get("class_map"));
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
+ if (parsedArgs.containsKey("class_map")) {
+ this.isClassLM = true;
+ this.classMap = new ClassMap(parsedArgs.get("class_map"));
+ }
// The dense feature initialization hasn't happened yet, so we have to retrieve this as sparse
this.weight = weights.getSparse(name);
@@ -160,7 +118,7 @@ public class LanguageModelFF extends StatefulFF {
public ArrayList<String> reportDenseFeatures(int index) {
denseFeatureIndex = index;
- ArrayList<String> names = new ArrayList<String>();
+ final ArrayList<String> names = new ArrayList<String>(1);
names.add(name);
return names;
}
@@ -191,42 +149,52 @@ public class LanguageModelFF extends StatefulFF {
return this.languageModel;
}
+ public boolean isClassLM() {
+ return this.isClassLM;
+ }
+
public String logString() {
- if (languageModel != null)
- return String.format("%s, order %d (weight %.3f)", name, languageModel.getOrder(), weight);
- else
- return "WHOA";
+ return String.format("%s, order %d (weight %.3f), classLm=%s", name, languageModel.getOrder(), weight, isClassLM);
}
/**
- * Computes the features incurred along this edge. Note that these features are unweighted costs
- * of the feature; they are the feature cost, not the model cost, or the inner product of them.
+ * Computes the features incurred along this edge. Note that these features
+ * are unweighted costs of the feature; they are the feature cost, not the
+ * model cost, or the inner product of them.
*/
@Override
- public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
- Sentence sentence, Accumulator acc) {
-
- NgramDPState newState = null;
- if (rule != null) {
- if (config.source_annotations) {
- // Get source side annotations and project them to the target side
- newState = computeTransition(getTags(rule, i, j, sentence), tailNodes, acc);
- }
- else {
- if (this.isClassLM) {
- // Use a class language model
- // Return target side classes
- newState = computeTransition(getClasses(rule), tailNodes, acc);
- }
- else {
- // Default LM
- newState = computeTransition(rule.getEnglish(), tailNodes, acc);
- }
- }
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j,
+ SourcePath sourcePath, Sentence sentence, Accumulator acc) {
+ if (rule == null) {
+ return null;
}
- return newState;
+ int[] words;
+ if (config.source_annotations) {
+ // get source side annotations and project them to the target side
+ words = getTags(rule, i, j, sentence);
+ } else {
+ words = getRuleIds(rule);
+ }
+
+ return computeTransition(words, tailNodes, acc);
+
+ }
+
+ /**
+ * Retrieve ids from rule. These are either simply the rule ids on the target
+ * side or, if this is a class-based LM, their corresponding class map ids.
+ * Source-side annotation tags are not handled here; see getTags().
+ */
+ @VisibleForTesting
+ public int[] getRuleIds(final Rule rule) {
+ if (this.isClassLM) {
+ // map words to class ids
+ return getClasses(rule);
+ }
+ // Regular LM: use rule word ids
+ return rule.getEnglish();
}
/**
@@ -256,7 +224,7 @@ public class LanguageModelFF extends StatefulFF {
if (alignments[j] == i) {
String annotation = sentence.getAnnotation((int)alignments[i] + begin, "class");
if (annotation != null) {
- // System.err.println(String.format(" word %d source %d abs %d annotation %d/%s",
+ // System.err.println(String.format(" word %d source %d abs %d annotation %d/%s",
// i, alignments[i], alignments[i] + begin, annotation, Vocabulary.word(annotation)));
tokens[i] = Vocabulary.id(annotation);
break;
@@ -270,8 +238,8 @@ public class LanguageModelFF extends StatefulFF {
return tokens;
}
- /**
- * Sets the class map if this is a class LM
+ /**
+ * Sets the class map if this is a class LM
* @param fileName a string path to a file
* @throws IOException if there is an error reading the input file
*/
@@ -314,7 +282,7 @@ public class LanguageModelFF extends StatefulFF {
float estimate = 0.0f;
boolean considerIncompleteNgrams = true;
- int[] enWords = rule.getEnglish();
+ int[] enWords = getRuleIds(rule);
List<Integer> words = new ArrayList<Integer>();
boolean skipStart = (enWords[0] == startSymbolId);
@@ -366,7 +334,7 @@ public class LanguageModelFF extends StatefulFF {
* terminal words in the rule string are preceded by a nonterminal (c) we encounter adjacent
* nonterminals. In all of these situations, the corresponding boundary words of the node in the
* hypergraph represented by the nonterminal must be retrieved.
- *
+ *
* IMPORTANT: only complete n-grams are scored. This means that hypotheses with fewer words
* than the complete n-gram state remain *unscored*. This fact adds a lot of complication to the
* code, including the use of the computeFinal* family of functions, which correct this fact for
@@ -445,7 +413,7 @@ public class LanguageModelFF extends StatefulFF {
* This function differs from regular transitions because we incorporate the cost of incomplete
* left-hand ngrams, as well as including the start- and end-of-sentence markers (if they were
* requested when the object was created).
- *
+ *
* @param state the dynamic programming state
* @return the final transition probability (including incomplete n-grams)
*/
@@ -492,7 +460,7 @@ public class LanguageModelFF extends StatefulFF {
* This function is basically a wrapper for NGramLanguageModel::sentenceLogProbability(). It
* computes the probability of a phrase ("chunk"), using lower-order n-grams for the first n-1
* words.
- *
+ *
* @param words
* @param considerIncompleteNgrams
* @param skipStart
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
index 533365c..88dc647 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
@@ -18,7 +18,8 @@
*/
package org.apache.joshua.decoder.ff.lm;
-import java.util.ArrayList;
+import static org.apache.joshua.util.FormatUtils.isNonterminal;
+
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
@@ -26,18 +27,16 @@ import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.JoshuaConfiguration;
import org.apache.joshua.decoder.chart_parser.SourcePath;
import org.apache.joshua.decoder.ff.FeatureVector;
-import org.apache.joshua.decoder.ff.lm.KenLM;
import org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair;
import org.apache.joshua.decoder.ff.state_maintenance.DPState;
import org.apache.joshua.decoder.ff.state_maintenance.KenLMState;
import org.apache.joshua.decoder.ff.tm.Rule;
import org.apache.joshua.decoder.hypergraph.HGNode;
import org.apache.joshua.decoder.segment_file.Sentence;
-import org.apache.joshua.util.FormatUtils;
/**
* Wrapper for KenLM LMs with left-state minimization. We inherit from the regular {@link LanguageModelFF} class.
- *
+ *
* @author Matt Post post@cs.jhu.edu
* @author Juri Ganitkevitch juri@cs.jhu.edu
*/
@@ -55,61 +54,37 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
throw new RuntimeException(msg);
}
}
-
- @Override
- public ArrayList<String> reportDenseFeatures(int index) {
- denseFeatureIndex = index;
-
- ArrayList<String> names = new ArrayList<String>();
- names.add(name);
- return names;
- }
/**
* Initializes the underlying language model.
*/
@Override
public void initializeLM() {
-
+
// Override type (only KenLM supports left-state minimization)
this.languageModel = new KenLM(ngramOrder, path);
Vocabulary.registerLanguageModel(this.languageModel);
Vocabulary.id(config.default_non_terminal);
-
+
}
-
+
/**
- * Estimates the cost of a rule. We override here since KenLM can do it more efficiently
- * than the default {@link LanguageModelFF} class.
- *
- * Most of this function implementation is redundant with compute().
+ * Estimates the cost of a rule. We override here since KenLM can do it more
+ * efficiently than the default {@link LanguageModelFF} class.
*/
@Override
public float estimateCost(Rule rule, Sentence sentence) {
-
- int[] ruleWords = rule.getEnglish();
-
- // The IDs we'll pass to KenLM
- long[] words = new long[ruleWords.length];
- for (int x = 0; x < ruleWords.length; x++) {
- int id = ruleWords[x];
+ int[] ruleWords = getRuleIds(rule);
- if (FormatUtils.isNonterminal(id)) {
- // For the estimate, we can just mark negative values
- words[x] = -1;
+ // map to ken lm ids
+ final long[] words = mapToKenLmIds(ruleWords, null, true);
- } else {
- // Terminal: just add it
- words[x] = id;
- }
- }
-
// Get the probability of applying the rule and the new state
return weight * ((KenLM) languageModel).estimateRule(words);
}
-
+
/**
* Computes the features incurred along this edge. Note that these features are unweighted costs
* of the feature; they are the feature cost, not the model cost, or the inner product of them.
@@ -118,39 +93,31 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
Sentence sentence, Accumulator acc) {
- int[] ruleWords = config.source_annotations
- ? getTags(rule, i, j, sentence)
- : rule.getEnglish();
-
- // The IDs we'll pass to KenLM
- long[] words = new long[ruleWords.length];
+ if (rule == null ) {
+ return null;
+ }
- for (int x = 0; x < ruleWords.length; x++) {
- int id = ruleWords[x];
+ int[] ruleWords;
+ if (config.source_annotations) {
+ // get source side annotations and project them to the target side
+ ruleWords = getTags(rule, i, j, sentence);
+ } else {
+ ruleWords = getRuleIds(rule);
+ }
- if (FormatUtils.isNonterminal(id)) {
- // Nonterminal: retrieve the KenLM long that records the state
- int index = -(id + 1);
- KenLMState state = (KenLMState) tailNodes.get(index).getDPState(stateIndex);
- words[x] = -state.getState();
+ // map to ken lm ids
+ final long[] words = mapToKenLmIds(ruleWords, tailNodes, false);
- } else {
- // Terminal: just add it
- words[x] = id;
- }
- }
-
- int sentID = sentence.id();
+ final int sentID = sentence.id();
// Since sentId is unique across threads, next operations are safe, but not atomic!
if (!poolMap.containsKey(sentID)) {
poolMap.put(sentID, KenLM.createPool());
}
// Get the probability of applying the rule and the new state
- StateProbPair pair = ((KenLM) languageModel).probRule(words, poolMap.get(sentID));
+ final StateProbPair pair = ((KenLM) languageModel).probRule(words, poolMap.get(sentID));
// Record the prob
-// acc.add(name, pair.prob);
acc.add(denseFeatureIndex, pair.prob);
// Return the state
@@ -158,10 +125,40 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
}
/**
+ * Maps the given array of word/class ids to KenLM ids. When only estimating, nonterminals are
+ * marked as -1; otherwise each is replaced by the negated KenLM state of its tail node.
+ */
+ private long[] mapToKenLmIds(int[] ids, List<HGNode> tailNodes, boolean isOnlyEstimate) {
+ // The IDs we will pass to KenLM
+ long[] kenIds = new long[ids.length];
+ for (int x = 0; x < ids.length; x++) {
+ int id = ids[x];
+
+ if (isNonterminal(id)) {
+
+ if (isOnlyEstimate) {
+ // For the estimate, we can just mark negative values
+ kenIds[x] = -1;
+ } else {
+ // Nonterminal: retrieve the KenLM long that records the state
+ int index = -(id + 1);
+ final KenLMState state = (KenLMState) tailNodes.get(index).getDPState(stateIndex);
+ kenIds[x] = -state.getState();
+ }
+
+ } else {
+ // Terminal: just add it
+ kenIds[x] = id;
+ }
+ }
+ return kenIds;
+ }
+
+ /**
* Destroys the pool created to allocate state for this sentence. Called from the
* {@link org.apache.joshua.decoder.Translation} class after outputting the sentence or k-best list. Hosting
* this map here in KenLMFF statically allows pools to be shared across KenLM instances.
- *
+ *
* @param sentId a key in the poolmap table to destroy
*/
public void destroyPool(int sentId) {
@@ -174,19 +171,13 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
* This function differs from regular transitions because we incorporate the cost of incomplete
* left-hand ngrams, as well as including the start- and end-of-sentence markers (if they were
* requested when the object was created).
- *
+ *
* KenLM already includes the prefix probabilities (of shorter n-grams on the left-hand side), so
* there's nothing that needs to be done.
*/
@Override
public DPState computeFinal(HGNode tailNode, int i, int j, SourcePath sourcePath, Sentence sentence,
Accumulator acc) {
-
- // KenLMState state = (KenLMState) tailNode.getDPState(getStateIndex());
-
- // This is unnecessary
- // acc.add(name, 0.0f);
-
// The state is the same since no rule was applied
return new KenLMState();
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassBasedLanguageModelTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassBasedLanguageModelTest.java b/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassBasedLanguageModelTest.java
new file mode 100644
index 0000000..7207d80
--- /dev/null
+++ b/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassBasedLanguageModelTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm.class_lm;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.ff.FeatureVector;
+import org.apache.joshua.decoder.ff.lm.LanguageModelFF;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+public class ClassBasedLanguageModelTest {
+
+ private static final float WEIGHT = 0.5f;
+
+ private LanguageModelFF ff;
+
+ @BeforeMethod
+ public void setUp() {
+ Decoder.resetGlobalState();
+
+ FeatureVector weights = new FeatureVector();
+ weights.set("lm_0", WEIGHT);
+ String[] args = { "-lm_type", "kenlm", "-lm_order", "9",
+ "-lm_file", "./src/test/resources/lm/class_lm/class_lm_9gram.gz",
+ "-class_map", "./src/test/resources/lm/class_lm/class.map" };
+
+ JoshuaConfiguration config = new JoshuaConfiguration();
+ ff = new LanguageModelFF(weights, args, config);
+ }
+
+ @AfterMethod
+ public void tearDown() {
+ Decoder.resetGlobalState();
+ }
+
+ @Test
+ public void givenLmDefinition_whenInitialized_thenInitializationIsCorrect() {
+ assertTrue(ff.isClassLM());
+ assertTrue(ff.isStateful());
+ }
+
+ @Test
+ public void givenRuleWithSingleWord_whenGetRuleId_thenIsMappedToClass() {
+ final int[] target = Vocabulary.addAll(new String[] { "professionalism" });
+ final Rule rule = new Rule(0, null, target, "", 0, 0);
+ assertEquals(Vocabulary.word(ff.getRuleIds(rule)[0]), "13");
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassMapTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassMapTest.java b/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassMapTest.java
new file mode 100644
index 0000000..5d37a05
--- /dev/null
+++ b/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassMapTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm.class_lm;
+
+import static org.testng.Assert.assertEquals;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.ff.lm.ClassMap;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+
+public class ClassMapTest {
+
+ private static final int EXPECTED_CLASS_MAP_SIZE = 5140;
+
+ @BeforeMethod
+ public void setUp() {
+ Decoder.resetGlobalState();
+ }
+
+ @AfterMethod
+ public void tearDown() {
+ Decoder.resetGlobalState();
+ }
+
+ @Test
+ public void givenClassMapFile_whenClassMapRead_thenEntriesAreRead() {
+ // GIVEN
+ final String classMapFile = "./src/test/resources/lm/class_lm/class.map";
+
+ // WHEN
+ final ClassMap classMap = new ClassMap(classMapFile);
+
+ // THEN
+ assertEquals(classMap.size(), EXPECTED_CLASS_MAP_SIZE);
+ assertEquals(
+ Vocabulary.word(
+ classMap.getClassID(
+ Vocabulary.id("professionalism"))),
+ "13");
+ assertEquals(
+ Vocabulary.word(
+ classMap.getClassID(
+ Vocabulary.id("convenience"))),
+ "0");
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/test/resources/bn-en/hiero/joshua-classlm.config
----------------------------------------------------------------------
diff --git a/src/test/resources/bn-en/hiero/joshua-classlm.config b/src/test/resources/bn-en/hiero/joshua-classlm.config
index 970b9b7..3be7392 100644
--- a/src/test/resources/bn-en/hiero/joshua-classlm.config
+++ b/src/test/resources/bn-en/hiero/joshua-classlm.config
@@ -1,7 +1,7 @@
-feature-function = LanguageModel -lm_type kenlm -lm_order 5 -minimizing false -lm_file lm.gz
+feature-function = StateMinimizingLanguageModel -lm_type kenlm -lm_order 5 -lm_file lm.gz
# Class LM feature
-feature-function = LanguageModel -lm_type kenlm -lm_order 9 -minimizing false -lm_file class_lm_9gram.gz -lm_class -class_map class.map
+feature-function = StateMinimizingLanguageModel -lm_type kenlm -lm_order 9 -lm_file class_lm_9gram.gz -class_map class.map
###### Old format for lms
# lm = kenlm 5 false false 100 lm.gz