You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/06/22 20:56:19 UTC

[3/4] incubator-joshua git commit: ClassLMs: fixed a bug with class-based LMs not mapping words to class ids in estimateCost(). Also refactored the code a little bit to have StateMinimizingLanguageModels support classes as well. Added some unit tests. […]

ClassLMs: fixed a bug with class-based LMs not mapping words to class ids in estimateCost(). Also refactored the code a little bit so that StateMinimizingLanguageModel supports classes as well. Added some unit tests. The gold output of the existing regression test was updated to match the new output.


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/8fc7544e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/8fc7544e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/8fc7544e

Branch: refs/heads/master
Commit: 8fc7544eaaf35f71367b48778eaa1f22772ca390
Parents: 55e88d1
Author: Felix Hieber <fh...@amazon.com>
Authored: Mon Jun 20 11:21:03 2016 +0200
Committer: Felix Hieber <fh...@amazon.com>
Committed: Tue Jun 21 07:54:44 2016 +0200

----------------------------------------------------------------------
 .../apache/joshua/decoder/ff/lm/ClassMap.java   |   73 +
 .../joshua/decoder/ff/lm/LanguageModelFF.java   |  146 +-
 .../ff/lm/StateMinimizingLanguageModel.java     |  125 +-
 .../class_lm/ClassBasedLanguageModelTest.java   |   71 +
 .../decoder/ff/lm/class_lm/ClassMapTest.java    |   67 +
 .../resources/bn-en/hiero/joshua-classlm.config |    4 +-
 .../resources/bn-en/hiero/output-classlm.gold   | 1565 +++---
 src/test/resources/lm/class_lm/class.map        | 5140 ++++++++++++++++++
 .../resources/lm/class_lm/class_lm_9gram.gz     |  Bin 0 -> 12733137 bytes
 9 files changed, 6355 insertions(+), 836 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/main/java/org/apache/joshua/decoder/ff/lm/ClassMap.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/ClassMap.java b/src/main/java/org/apache/joshua/decoder/ff/lm/ClassMap.java
new file mode 100644
index 0000000..c86d739
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/ClassMap.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.collect.ImmutableMap;
+
+/**
+ * Immutable mapping from word ids to class ids, used by class-based language
+ * models. Both words and class labels are registered in the global
+ * {@link Vocabulary}.
+ */
+public class ClassMap {
+
+  private static final Logger LOG = LoggerFactory.getLogger(ClassMap.class);
+
+  private static final int OOV_ID = Vocabulary.getUnknownId();
+  private final ImmutableMap<Integer, Integer> mapping;
+
+  /**
+   * Reads a class map from a file with one whitespace-separated "word class"
+   * pair per line.
+   *
+   * @param file_name path to the class map file
+   * @throws RuntimeException wrapping the underlying {@link IOException} if
+   *           the file cannot be read
+   */
+  public ClassMap(String file_name) {
+    this.mapping = read(file_name);
+    LOG.info("{} entries read from class map", this.mapping.size());
+  }
+
+  /**
+   * Returns the class id of the given word id, or the vocabulary's unknown
+   * word id if the word is not in the class map.
+   */
+  public int getClassID(int wordID) {
+    return this.mapping.getOrDefault(wordID, OOV_ID);
+  }
+
+  /** Returns the number of word-to-class entries. */
+  public int size() {
+    return mapping.size();
+  }
+
+  /**
+   * Reads a class map from file_name. Empty lines are ignored, lines with fewer
+   * than two fields are logged and skipped, and if a word is listed more than
+   * once its last entry wins (a warning is logged).
+   */
+  private static ImmutableMap<Integer, Integer> read(String file_name) {
+    // Collect into a mutable map first: ImmutableMap.Builder#build() throws an
+    // IllegalArgumentException on duplicate keys, so a single repeated word
+    // would otherwise abort loading the whole class map.
+    final Map<Integer, Integer> entries = new HashMap<>();
+    int lineno = 0;
+    try {
+      for (String line : new LineReader(file_name, false)) {
+        lineno++;
+        final String trimmed = line.trim();
+        if (trimmed.isEmpty()) {
+          continue;
+        }
+        final String[] lineComp = trimmed.split("\\s+");
+        if (lineComp.length < 2) {
+          // Checked up front rather than catching ArrayIndexOutOfBoundsException,
+          // so the first field of a malformed line is not added to the Vocabulary.
+          LOG.warn("bad vocab line #{} '{}'. skipping!", lineno, line);
+          continue;
+        }
+        final Integer previous = entries.put(Vocabulary.id(lineComp[0]), Vocabulary.id(lineComp[1]));
+        if (previous != null) {
+          LOG.warn("duplicate class map entry for '{}' on line #{}, keeping the last one", lineComp[0], lineno);
+        }
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+    return ImmutableMap.copyOf(entries);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java b/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
index 3fea410..9388ed7 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
@@ -21,12 +21,9 @@ package org.apache.joshua.decoder.ff.lm;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
 
-import com.google.common.primitives.Ints;
-
 import org.apache.joshua.corpus.Vocabulary;
 import org.apache.joshua.decoder.JoshuaConfiguration;
 import org.apache.joshua.decoder.Support;
@@ -34,7 +31,6 @@ import org.apache.joshua.decoder.chart_parser.SourcePath;
 import org.apache.joshua.decoder.ff.FeatureVector;
 import org.apache.joshua.decoder.ff.StatefulFF;
 import org.apache.joshua.decoder.ff.lm.berkeley_lm.LMGrammarBerkeley;
-import org.apache.joshua.decoder.ff.lm.KenLM;
 import org.apache.joshua.decoder.ff.state_maintenance.DPState;
 import org.apache.joshua.decoder.ff.state_maintenance.NgramDPState;
 import org.apache.joshua.decoder.ff.tm.Rule;
@@ -44,6 +40,9 @@ import org.apache.joshua.util.FormatUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.primitives.Ints;
+
 /**
  * This class performs the following:
  * <ol>
@@ -52,21 +51,21 @@ import org.slf4j.LoggerFactory;
  * <li>Gets the LM state</li>
  * <li>Gets the left-side LM state estimation score</li>
  * </ol>
- * 
+ *
  * @author Matt Post post@cs.jhu.edu
  * @author Juri Ganitkevitch juri@cs.jhu.edu
  * @author Zhifei Li, zhifei.work@gmail.com
  */
 public class LanguageModelFF extends StatefulFF {
 
-  private static final Logger LOG = LoggerFactory.getLogger(LanguageModelFF.class);
+  static final Logger LOG = LoggerFactory.getLogger(LanguageModelFF.class);
 
   public static int LM_INDEX = 0;
   private int startSymbolId;
 
   /**
    * N-gram language model. We assume the language model is in ARPA format for equivalent state:
-   * 
+   *
    * <ol>
    * <li>We assume it is a backoff lm, and high-order ngram implies low-order ngram; absense of
    * low-order ngram implies high-order ngram</li>
@@ -94,61 +93,20 @@ public class LanguageModelFF extends StatefulFF {
   protected String path;
 
   /* Whether this is a class-based LM */
   private boolean isClassLM;
   private ClassMap classMap;
 
-  protected class ClassMap {
-
-    private final int OOV_id = Vocabulary.getUnknownId();
-    private HashMap<Integer, Integer> classMap;
-
-    public ClassMap(String file_name) throws IOException {
-      this.classMap = new HashMap<Integer, Integer>();
-      read(file_name);
-    }
-
-    public int getClassID(int wordID) {
-      return this.classMap.getOrDefault(wordID, OOV_id);
-    }
-
-    /**
-     * Reads a class map from file.
-     * 
-     * @param file_name
-     * @throws IOException
-     */
-    private void read(String file_name) throws IOException {
-
-      int lineno = 0;
-      for (String line: new org.apache.joshua.util.io.LineReader(file_name, false)) {
-        lineno++;
-        String[] lineComp = line.trim().split("\\s+");
-        try {
-          this.classMap.put(Vocabulary.id(lineComp[0]), Vocabulary.id(lineComp[1]));
-        } catch (java.lang.ArrayIndexOutOfBoundsException e) {
-          LOG.warn("bad vocab line #{} '{}'", lineno, line);
-          LOG.warn(e.getMessage(), e);
-        }
-      }
-    }
-
-  }
-
   public LanguageModelFF(FeatureVector weights, String[] args, JoshuaConfiguration config) {
     super(weights, String.format("lm_%d", LanguageModelFF.LM_INDEX++), args, config);
 
     this.type = parsedArgs.get("lm_type");
-    this.ngramOrder = Integer.parseInt(parsedArgs.get("lm_order")); 
+    this.ngramOrder = Integer.parseInt(parsedArgs.get("lm_order"));
     this.path = parsedArgs.get("lm_file");
 
-    if (parsedArgs.containsKey("class_map"))
-      try {
-        this.isClassLM = true;
-        this.classMap = new ClassMap(parsedArgs.get("class_map"));
-      } catch (IOException e) {
-        // TODO Auto-generated catch block
-        e.printStackTrace();
-      }
+    if (parsedArgs.containsKey("class_map")) {
+      this.isClassLM = true;
+      this.classMap = new ClassMap(parsedArgs.get("class_map")); // ClassMap rethrows read failures as RuntimeException, so a bad path fails fast
+    }
 
     // The dense feature initialization hasn't happened yet, so we have to retrieve this as sparse
     this.weight = weights.getSparse(name);
@@ -160,7 +118,7 @@ public class LanguageModelFF extends StatefulFF {
   public ArrayList<String> reportDenseFeatures(int index) {
     denseFeatureIndex = index;
 
-    ArrayList<String> names = new ArrayList<String>();
+    final ArrayList<String> names = new ArrayList<String>(1);
     names.add(name);
     return names;
   }
@@ -191,42 +149,52 @@ public class LanguageModelFF extends StatefulFF {
     return this.languageModel;
   }
 
+  /** Returns true if this LM was configured with a class map ("class_map" argument). */
+  public boolean isClassLM() {
+    return this.isClassLM;
+  }
+
   public String logString() {
-    if (languageModel != null)
-      return String.format("%s, order %d (weight %.3f)", name, languageModel.getOrder(), weight);
-    else
-      return "WHOA";
+    return String.format("%s, order %d (weight %.3f), classLm=%s", name, languageModel.getOrder(), weight, isClassLM);
   }
 
   /**
-   * Computes the features incurred along this edge. Note that these features are unweighted costs
-   * of the feature; they are the feature cost, not the model cost, or the inner product of them.
+   * Computes the features incurred along this edge. Note that these features
+   * are unweighted costs of the feature; they are the feature cost, not the
+   * model cost, or the inner product of them.
    */
   @Override
-  public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
-      Sentence sentence, Accumulator acc) {
-
-    NgramDPState newState = null;
-    if (rule != null) {
-      if (config.source_annotations) {
-        // Get source side annotations and project them to the target side
-        newState = computeTransition(getTags(rule, i, j, sentence), tailNodes, acc);
-      }
-      else {
-        if (this.isClassLM) {
-          // Use a class language model
-          // Return target side classes
-          newState = computeTransition(getClasses(rule), tailNodes, acc);
-        }
-        else {
-          // Default LM 
-          newState = computeTransition(rule.getEnglish(), tailNodes, acc);
-        }
-      }
+  public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j,
+      SourcePath sourcePath, Sentence sentence, Accumulator acc) {
 
+    if (rule == null) {
+      return null;
     }
 
-    return newState;
+    int[] words;
+    if (config.source_annotations) {
+      // get source side annotations and project them to the target side
+      words = getTags(rule, i, j, sentence);
+    } else {
+      words = getRuleIds(rule);
+    }
+
+    return computeTransition(words, tailNodes, acc);
+  }
+
+  /**
+   * Returns the ids this LM scores for the target side of a rule: the target
+   * word ids, or, for a class-based LM, the class ids computed by getClasses().
+   * Source-side annotation tags are not handled here; see compute().
+   */
+  @VisibleForTesting
+  public int[] getRuleIds(final Rule rule) {
+    if (this.isClassLM) {
+      // map words to class ids
+      return getClasses(rule);
+    }
+    // Regular LM: use rule word ids
+    return rule.getEnglish();
   }
 
   /**
@@ -256,7 +224,7 @@ public class LanguageModelFF extends StatefulFF {
             if (alignments[j] == i) {
               String annotation = sentence.getAnnotation((int)alignments[i] + begin, "class");
               if (annotation != null) {
-                //                System.err.println(String.format("  word %d source %d abs %d annotation %d/%s", 
+                //                System.err.println(String.format("  word %d source %d abs %d annotation %d/%s",
                 //                    i, alignments[i], alignments[i] + begin, annotation, Vocabulary.word(annotation)));
                 tokens[i] = Vocabulary.id(annotation);
                 break;
@@ -270,8 +238,8 @@ public class LanguageModelFF extends StatefulFF {
     return tokens;
   }
 
-  /** 
-   * Sets the class map if this is a class LM 
+  /**
+   * Sets the class map if this is a class LM
    * @param fileName a string path to a file
    * @throws IOException if there is an error reading the input file
    */
@@ -314,7 +282,7 @@ public class LanguageModelFF extends StatefulFF {
     float estimate = 0.0f;
     boolean considerIncompleteNgrams = true;
 
-    int[] enWords = rule.getEnglish();
+    int[] enWords = getRuleIds(rule); // maps words to class ids for class-based LMs (see getRuleIds)
 
     List<Integer> words = new ArrayList<Integer>();
     boolean skipStart = (enWords[0] == startSymbolId);
@@ -366,7 +334,7 @@ public class LanguageModelFF extends StatefulFF {
    * terminal words in the rule string are preceded by a nonterminal (c) we encounter adjacent
    * nonterminals. In all of these situations, the corresponding boundary words of the node in the
    * hypergraph represented by the nonterminal must be retrieved.
-   * 
+   *
    * IMPORTANT: only complete n-grams are scored. This means that hypotheses with fewer words
    * than the complete n-gram state remain *unscored*. This fact adds a lot of complication to the
    * code, including the use of the computeFinal* family of functions, which correct this fact for
@@ -445,7 +413,7 @@ public class LanguageModelFF extends StatefulFF {
    * This function differs from regular transitions because we incorporate the cost of incomplete
    * left-hand ngrams, as well as including the start- and end-of-sentence markers (if they were
    * requested when the object was created).
-   * 
+   *
    * @param state the dynamic programming state
    * @return the final transition probability (including incomplete n-grams)
    */
@@ -492,7 +460,7 @@ public class LanguageModelFF extends StatefulFF {
    * This function is basically a wrapper for NGramLanguageModel::sentenceLogProbability(). It
    * computes the probability of a phrase ("chunk"), using lower-order n-grams for the first n-1
    * words.
-   * 
+   *
    * @param words
    * @param considerIncompleteNgrams
    * @param skipStart

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
index 533365c..88dc647 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
@@ -18,7 +18,8 @@
  */
 package org.apache.joshua.decoder.ff.lm;
 
-import java.util.ArrayList;
+import static org.apache.joshua.util.FormatUtils.isNonterminal;
+
 import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
 
@@ -26,18 +27,16 @@ import org.apache.joshua.corpus.Vocabulary;
 import org.apache.joshua.decoder.JoshuaConfiguration;
 import org.apache.joshua.decoder.chart_parser.SourcePath;
 import org.apache.joshua.decoder.ff.FeatureVector;
-import org.apache.joshua.decoder.ff.lm.KenLM;
 import org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair;
 import org.apache.joshua.decoder.ff.state_maintenance.DPState;
 import org.apache.joshua.decoder.ff.state_maintenance.KenLMState;
 import org.apache.joshua.decoder.ff.tm.Rule;
 import org.apache.joshua.decoder.hypergraph.HGNode;
 import org.apache.joshua.decoder.segment_file.Sentence;
-import org.apache.joshua.util.FormatUtils;
 
 /**
  * Wrapper for KenLM LMs with left-state minimization. We inherit from the regular
- * 
+ *
  * @author Matt Post post@cs.jhu.edu
  * @author Juri Ganitkevitch juri@cs.jhu.edu
  */
@@ -55,61 +54,37 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
       throw new RuntimeException(msg);
     }
   }
-  
-  @Override
-  public ArrayList<String> reportDenseFeatures(int index) {
-    denseFeatureIndex = index;
-    
-    ArrayList<String> names = new ArrayList<String>();
-    names.add(name);
-    return names;
-  }
 
   /**
    * Initializes the underlying language model.
    */
   @Override
   public void initializeLM() {
-    
+
     // Override type (only KenLM supports left-state minimization)
     this.languageModel = new KenLM(ngramOrder, path);
 
     Vocabulary.registerLanguageModel(this.languageModel);
     Vocabulary.id(config.default_non_terminal);
-    
+
   }
-  
+
   /**
-   * Estimates the cost of a rule. We override here since KenLM can do it more efficiently
-   * than the default {@link LanguageModelFF} class.
-   *    
-   * Most of this function implementation is redundant with compute().
+   * Estimates the cost of a rule. We override here since KenLM can do it more
+   * efficiently than the default {@link LanguageModelFF} class.
    */
   @Override
   public float estimateCost(Rule rule, Sentence sentence) {
-    
-    int[] ruleWords = rule.getEnglish();
-
-    // The IDs we'll pass to KenLM
-    long[] words = new long[ruleWords.length];
 
-    for (int x = 0; x < ruleWords.length; x++) {
-      int id = ruleWords[x];
+    int[] ruleWords = getRuleIds(rule);
 
-      if (FormatUtils.isNonterminal(id)) {
-        // For the estimate, we can just mark negative values
-        words[x] = -1;
+    // map to KenLM ids; nonterminals become -1 placeholders for the estimate
+    final long[] words = mapToKenLmIds(ruleWords, null, true);
 
-      } else {
-        // Terminal: just add it
-        words[x] = id;
-      }
-    }
-    
     // Get the probability of applying the rule and the new state
     return weight * ((KenLM) languageModel).estimateRule(words);
   }
-  
+
   /**
    * Computes the features incurred along this edge. Note that these features are unweighted costs
    * of the feature; they are the feature cost, not the model cost, or the inner product of them.
@@ -118,39 +93,31 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
   public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
       Sentence sentence, Accumulator acc) {
 
-    int[] ruleWords = config.source_annotations 
-        ? getTags(rule, i, j, sentence)
-        : rule.getEnglish();
-
-    // The IDs we'll pass to KenLM
-    long[] words = new long[ruleWords.length];
+    if (rule == null) {
+      return null;
+    }
 
-    for (int x = 0; x < ruleWords.length; x++) {
-      int id = ruleWords[x];
+    int[] ruleWords;
+    if (config.source_annotations) {
+      // get source side annotations and project them to the target side
+      ruleWords = getTags(rule, i, j, sentence);
+    } else {
+      ruleWords = getRuleIds(rule);
+    }
 
-      if (FormatUtils.isNonterminal(id)) {
-        // Nonterminal: retrieve the KenLM long that records the state
-        int index = -(id + 1);
-        KenLMState state = (KenLMState) tailNodes.get(index).getDPState(stateIndex);
-        words[x] = -state.getState();
+    // map to KenLM ids; nonterminals take the KenLM state of their tail node
+    final long[] words = mapToKenLmIds(ruleWords, tailNodes, false);
 
-      } else {
-        // Terminal: just add it
-        words[x] = id;
-      }
-    }
-    
-    int sentID = sentence.id();
+    final int sentID = sentence.id();
     // Since sentId is unique across threads, next operations are safe, but not atomic!
     if (!poolMap.containsKey(sentID)) {
       poolMap.put(sentID, KenLM.createPool());
     }
 
     // Get the probability of applying the rule and the new state
-    StateProbPair pair = ((KenLM) languageModel).probRule(words, poolMap.get(sentID));
+    final StateProbPair pair = ((KenLM) languageModel).probRule(words, poolMap.get(sentID));
 
     // Record the prob
-//    acc.add(name, pair.prob);
     acc.add(denseFeatureIndex, pair.prob);
 
     // Return the state
@@ -158,10 +125,44 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
   }
 
   /**
+   * Maps the given ids to the ids passed to KenLM. Terminals are passed
+   * through unchanged. Nonterminals are replaced by -1 when only the cost is
+   * estimated, and otherwise by the negated KenLM state stored in the
+   * corresponding tail node.
+   *
+   * @param ids the ids to map (target word ids, class ids, or annotation tags)
+   * @param tailNodes the rule's tail nodes; only read when isOnlyEstimate is false
+   * @param isOnlyEstimate true when called from estimateCost(), which passes no tail nodes
+   * @return one KenLM id per input id
+   */
+  private long[] mapToKenLmIds(int[] ids, List<HGNode> tailNodes, boolean isOnlyEstimate) {
+    // The IDs we will pass to KenLM
+    long[] kenIds = new long[ids.length];
+    for (int x = 0; x < ids.length; x++) {
+      int id = ids[x];
+      if (isNonterminal(id)) {
+        if (isOnlyEstimate) {
+          // For the estimate, we can just mark negative values
+          kenIds[x] = -1;
+        } else {
+          // Nonterminal: retrieve the KenLM long that records the state
+          int index = -(id + 1);
+          final KenLMState state = (KenLMState) tailNodes.get(index).getDPState(stateIndex);
+          kenIds[x] = -state.getState();
+        }
+      } else {
+        // Terminal: just add it
+        kenIds[x] = id;
+      }
+    }
+    return kenIds;
+  }
+
+  /**
    * Destroys the pool created to allocate state for this sentence. Called from the
    * {@link org.apache.joshua.decoder.Translation} class after outputting the sentence or k-best list. Hosting
    * this map here in KenLMFF statically allows pools to be shared across KenLM instances.
-   * 
+   *
    * @param sentId a key in the poolmap table to destroy
    */
   public void destroyPool(int sentId) {
@@ -174,19 +175,13 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
    * This function differs from regular transitions because we incorporate the cost of incomplete
    * left-hand ngrams, as well as including the start- and end-of-sentence markers (if they were
    * requested when the object was created).
-   * 
+   *
    * KenLM already includes the prefix probabilities (of shorter n-grams on the left-hand side), so
    * there's nothing that needs to be done.
    */
   @Override
   public DPState computeFinal(HGNode tailNode, int i, int j, SourcePath sourcePath, Sentence sentence,
       Accumulator acc) {
-
-    // KenLMState state = (KenLMState) tailNode.getDPState(getStateIndex());
-
-    // This is unnecessary
-    // acc.add(name, 0.0f);
-
     // The state is the same since no rule was applied
     return new KenLMState();
   }

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassBasedLanguageModelTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassBasedLanguageModelTest.java b/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassBasedLanguageModelTest.java
new file mode 100644
index 0000000..7207d80
--- /dev/null
+++ b/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassBasedLanguageModelTest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm.class_lm;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.ff.FeatureVector;
+import org.apache.joshua.decoder.ff.lm.LanguageModelFF;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+/**
+ * Tests a KenLM-backed {@link LanguageModelFF} configured with a class map.
+ */
+public class ClassBasedLanguageModelTest {
+
+  private static final float WEIGHT = 0.5f;
+  private static final String RESOURCE_DIR = "./src/test/resources/lm/class_lm/";
+
+  private LanguageModelFF ff;
+
+  @BeforeMethod
+  public void setUp() {
+    Decoder.resetGlobalState();
+
+    final FeatureVector weights = new FeatureVector();
+    weights.set("lm_0", WEIGHT);
+    final String[] args = {
+        "-lm_type", "kenlm",
+        "-lm_order", "9",
+        "-lm_file", RESOURCE_DIR + "class_lm_9gram.gz",
+        "-class_map", RESOURCE_DIR + "class.map" };
+
+    ff = new LanguageModelFF(weights, args, new JoshuaConfiguration());
+  }
+
+  @AfterMethod
+  public void tearDown() {
+    Decoder.resetGlobalState();
+  }
+
+  @Test
+  public void givenLmDefinition_whenInitialized_thenInitializationIsCorrect() {
+    assertTrue(ff.isClassLM());
+    assertTrue(ff.isStateful());
+  }
+
+  @Test
+  public void givenRuleWithSingleWord_whenGetRuleId_thenIsMappedToClass() {
+    final int[] target = Vocabulary.addAll(new String[] { "professionalism" });
+    final Rule rule = new Rule(0, null, target, "", 0, 0);
+
+    final int mappedId = ff.getRuleIds(rule)[0];
+    assertEquals(Vocabulary.word(mappedId), "13");
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassMapTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassMapTest.java b/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassMapTest.java
new file mode 100644
index 0000000..5d37a05
--- /dev/null
+++ b/src/test/java/org/apache/joshua/decoder/ff/lm/class_lm/ClassMapTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm.class_lm;
+
+import static org.testng.Assert.assertEquals;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.ff.lm.ClassMap;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+/**
+ * Tests reading a {@link ClassMap} from file and looking up class ids.
+ */
+public class ClassMapTest {
+
+  private static final String CLASS_MAP_FILE = "./src/test/resources/lm/class_lm/class.map";
+  private static final int EXPECTED_CLASS_MAP_SIZE = 5140;
+
+  @BeforeMethod
+  public void setUp() {
+    Decoder.resetGlobalState();
+  }
+
+  @AfterMethod
+  public void tearDown() {
+    Decoder.resetGlobalState();
+  }
+
+  @Test
+  public void givenClassMapFile_whenClassMapRead_thenEntriesAreRead() {
+    // WHEN
+    final ClassMap classMap = new ClassMap(CLASS_MAP_FILE);
+
+    // THEN
+    assertEquals(classMap.size(), EXPECTED_CLASS_MAP_SIZE);
+    assertEquals(Vocabulary.word(classMap.getClassID(Vocabulary.id("professionalism"))), "13");
+    assertEquals(Vocabulary.word(classMap.getClassID(Vocabulary.id("convenience"))), "0");
+  }
+
+  @Test
+  public void givenUnmappedWord_whenClassIdIsLookedUp_thenUnknownIdIsReturned() {
+    // GIVEN
+    final ClassMap classMap = new ClassMap(CLASS_MAP_FILE);
+
+    // WHEN
+    final int classId = classMap.getClassID(Vocabulary.id("__word_not_in_class_map__"));
+
+    // THEN
+    assertEquals(classId, Vocabulary.getUnknownId());
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8fc7544e/src/test/resources/bn-en/hiero/joshua-classlm.config
----------------------------------------------------------------------
diff --git a/src/test/resources/bn-en/hiero/joshua-classlm.config b/src/test/resources/bn-en/hiero/joshua-classlm.config
index 970b9b7..3be7392 100644
--- a/src/test/resources/bn-en/hiero/joshua-classlm.config
+++ b/src/test/resources/bn-en/hiero/joshua-classlm.config
@@ -1,7 +1,7 @@
-feature-function = LanguageModel -lm_type kenlm -lm_order 5 -minimizing false -lm_file lm.gz
+feature-function = StateMinimizingLanguageModel -lm_type kenlm -lm_order 5 -lm_file lm.gz
 
 # Class LM feature
-feature-function = LanguageModel -lm_type kenlm -lm_order 9 -minimizing false -lm_file class_lm_9gram.gz -lm_class -class_map class.map
+feature-function = StateMinimizingLanguageModel -lm_type kenlm -lm_order 9 -lm_file class_lm_9gram.gz -class_map class.map
 
 ###### Old format for lms
 # lm = kenlm 5 false false 100 lm.gz