You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/09/28 12:45:47 UTC

[3/6] incubator-joshua git commit: Remove unneeded modifications for estimate in KenLM

Remove unneeded modifications for estimate in KenLM


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/d9c3d7ec
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/d9c3d7ec
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/d9c3d7ec

Branch: refs/heads/master
Commit: d9c3d7ecf069a6a0339b911b9defb8ce31ebb1f1
Parents: c8d8a65
Author: Kellen Sunderland <ke...@amazon.com>
Authored: Tue Sep 27 17:31:37 2016 +0200
Committer: Kellen Sunderland <ke...@amazon.com>
Committed: Tue Sep 27 18:16:44 2016 +0200

----------------------------------------------------------------------
 jni/kenlm_wrap.cc                               | 30 ++++++------
 .../org/apache/joshua/decoder/ff/lm/KenLM.java  | 32 ++++++++-----
 .../ff/lm/StateMinimizingLanguageModel.java     | 49 +++++++++-----------
 .../org/apache/joshua/system/KenLmTest.java     | 16 +++----
 4 files changed, 64 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d9c3d7ec/jni/kenlm_wrap.cc
----------------------------------------------------------------------
diff --git a/jni/kenlm_wrap.cc b/jni/kenlm_wrap.cc
index 8947a61..445b57b 100644
--- a/jni/kenlm_wrap.cc
+++ b/jni/kenlm_wrap.cc
@@ -76,11 +76,10 @@ class EqualIndex : public std::binary_function<StateIndex, StateIndex, bool> {
 typedef std::unordered_set<StateIndex, HashIndex, EqualIndex> Lookup;
 
 /**
- * A Chart bundles together a unordered_multimap that maps ChartState signatures to a single
- * object instantiated using a pool. This allows duplicate states to avoid allocating separate
- * state objects at multiple places throughout a sentence, and also allows state to be shared
- * across KenLMs for the same sentence.  Multimap is used to avoid hash collisions which can
- * return incorrect results, and cause out-of-bounds lookups when multiple KenLMs are in use.
+ * A Chart bundles together a vector holding CharStates and an unordered_set of StateIndexes
+ * which provides a mapping between StateIndexes and the positions of ChartStates in the vector.
+ * This allows for duplicate states to avoid allocating separate state objects at multiple places
+ * throughout a sentence.
  */
 class Chart {
   public:
@@ -148,7 +147,7 @@ public:
   virtual float ProbString(jint * const begin, jint * const end,
       jint start) const = 0;
 
-  virtual float EstimateRule(const Chart &chart) const = 0;
+  virtual float EstimateRule(jlong *begin, jlong *end) const = 0;
 
   virtual uint8_t Order() const = 0;
 
@@ -202,7 +201,7 @@ public:
 
   float ProbRule(lm::ngram::ChartState& state, const Chart &chart) const {
 
-    // By convention the first long in the ngramBuffer denots the size of the buffer
+    // By convention the first long in the ngramBuffer denotes the size of the buffer
     long* begin = chart.ngramBuffer_ + 1;
     long* end = begin + *chart.ngramBuffer_;
 
@@ -229,12 +228,7 @@ public:
     return ruleScore.Finish();
   }
 
-  float EstimateRule(const Chart &chart) const {
-
-    // By convention the first long in the ngramBuffer denotes the size of the buffer
-    long* begin = chart.ngramBuffer_ + 1;
-    long* end = begin + *chart.ngramBuffer_;
-
+  float EstimateRule(jlong * const begin, jlong * const end) const {
     if (begin == end) return 0.0;
     lm::ngram::ChartState nullState;
     lm::ngram::RuleScore<Model> ruleScore(m_, nullState);
@@ -477,11 +471,15 @@ JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
 }
 
 JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_estimateRule(
-  JNIEnv *env, jclass, jlong pointer, jlong chartPtr) {
+  JNIEnv *env, jclass, jlong pointer, jlongArray arr) {
+  jint length = env->GetArrayLength(arr);
+  // GCC only.
+  jlong values[length];
+  env->GetLongArrayRegion(arr, 0, length, values);
 
   // Compute the probability
-  Chart* chart = reinterpret_cast<Chart*>(chartPtr);
-  return reinterpret_cast<const VirtualBase*>(pointer)->EstimateRule(*chart);
+  return reinterpret_cast<const VirtualBase*>(pointer)->EstimateRule(values,
+      values + length);
 }
 
 } // extern

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d9c3d7ec/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
index df16019..d138495 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@ -21,13 +21,15 @@ package org.apache.joshua.decoder.ff.lm;
 import org.apache.joshua.corpus.Vocabulary;
 import org.apache.joshua.decoder.KenLMPool;
 import org.apache.joshua.decoder.ff.state_maintenance.KenLMState;
-import org.apache.joshua.util.Constants;
 import org.apache.joshua.util.FormatUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.nio.ByteBuffer;
 
+import static java.nio.ByteOrder.LITTLE_ENDIAN;
+import static org.apache.joshua.util.Constants.LONG_SIZE_IN_BYTES;
+
 /**
  * JNI wrapper for KenLM. This version of KenLM supports two use cases, implemented by the separate
  * feature functions KenLMFF and LanguageModelFF. KenLMFF uses the RuleScore() interface in
@@ -88,7 +90,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
 
   private static native long probRule(long ptr, long pool);
 
-  private static native float estimateRule(long ptr, long poolWrapper);
+  private static native float estimateRule(long ptr, long words[]);
 
   private static native float probString(long ptr, int words[], int start);
 
@@ -107,9 +109,8 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
   }
 
   public KenLMPool createLMPool() {
-    ByteBuffer ngramBuffer = ByteBuffer.allocateDirect(MAX_TARGET_LENGTH *
-            Constants.LONG_SIZE_IN_BYTES);
-    ngramBuffer.order(java.nio.ByteOrder.LITTLE_ENDIAN);
+    ByteBuffer ngramBuffer = ByteBuffer.allocateDirect(MAX_TARGET_LENGTH * LONG_SIZE_IN_BYTES);
+    ngramBuffer.order(LITTLE_ENDIAN);
     long pool = createPool(ngramBuffer);
     return new KenLMPool(pool, this, ngramBuffer);
   }
@@ -158,11 +159,18 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
    * needed so KenLM knows which memory pool to use. When finished, it returns the updated KenLM
    * state and the LM probability incurred along this rule.
    *
+   * @param words       array of words
    * @param poolWrapper an object that wraps a pool reference returned from KenLM createPool
    * @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
    * KenLM state and the LM probability incurred along this rule
    */
-  public StateProbPair probRule(KenLMPool poolWrapper) {
+  public StateProbPair probRule(long[] words, KenLMPool poolWrapper) {
+
+    poolWrapper.setBufferLength(words.length);
+    for (int i = 0; i < words.length; i++) {
+      poolWrapper.writeIdToBuffer(i, words[i]);
+    }
+
     long packedResult = probRule(pointer, poolWrapper.getPool());
     int state = (int) (packedResult >> 32);
     float probVal = Float.intBitsToFloat((int) packedResult);
@@ -174,12 +182,13 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
    * Public facing function that estimates the cost of a rule, which value is used for sorting
    * rules during cube pruning.
    *
+   * @param words array of words
    * @return the estimated cost of the rule (the (partial) n-gram probabilities of all words in the rule)
    */
-  public float estimateRule(KenLMPool poolWrapper) {
+  public float estimateRule(long[] words) {
     float estimate = 0.0f;
     try {
-      estimate = estimateRule(pointer, poolWrapper.getPool());
+      estimate = estimateRule(pointer, words);
     } catch (NoSuchMethodError e) {
       throw new RuntimeException(e);
     }
@@ -249,11 +258,12 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
   }
 
   /**
-   * Inner class used to hold the results returned from KenLM with left-state minimization.
+   * Inner class used to hold the results returned from KenLM with left-state minimization. Note
+   * that inner classes have to be static to be accessible from the JNI!
    */
   public static class StateProbPair {
-    public KenLMState state = null;
-    public float prob = 0.0f;
+    public final KenLMState state;
+    public final float prob;
 
     public StateProbPair(long state, float prob) {
       this.state = new KenLMState(state);

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d9c3d7ec/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
index 155522b..c3281d6 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
@@ -76,15 +76,13 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
 
     int[] ruleWords = getRuleIds(rule);
 
-    try(KenLMPool poolWrapper = ((KenLM)languageModel).createLMPool();) {
-      // Write KenLM word ids to a shared ByteBuffer.
-      writeKenLmIds(ruleWords, null, poolWrapper);
-
-      // Get the probability of applying the rule and the new state
-      float lmCost = weight * ((KenLM)languageModel).estimateRule(poolWrapper);
-      float oovCost = oovWeight * ((withOovFeature) ? getOovs(ruleWords) : 0f);
-      return lmCost + oovCost;
-    }
+    // map to ken lm ids
+    final long[] words = mapToKenLmIds(ruleWords, null, true);
+    
+    // Get the probability of applying the rule and the new state
+    float lmCost = weight * ((KenLM) languageModel).estimateRule(words);
+    float oovCost = oovWeight * ((withOovFeature) ? getOovs(ruleWords) : 0f);
+    return lmCost + oovCost;
   }
 
   private UUID languageModelPoolId = UUID.randomUUID();
@@ -103,7 +101,7 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
 
     int[] ruleWords;
     if (config.source_annotations) {
-      // Get source side annotations and project them to the target side
+      // get source side annotations and project them to the target side
       ruleWords = getTags(rule, i, j, sentence);
     } else {
       ruleWords = getRuleIds(rule);
@@ -114,16 +112,14 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
       acc.add(oovDenseFeatureIndex, getOovs(ruleWords));
     }
 
-    KenLMPool statePool = sentence.getStateManager().getStatePool(languageModelPoolId,
-            (KenLM)languageModel);
-
-     // Write KenLM ngram ids to the shared direct buffer
-    writeKenLmIds(ruleWords, tailNodes, statePool);
-
+     // map to ken lm ids
+    final long[] words = mapToKenLmIds(ruleWords, tailNodes, false);
 
+    KenLMPool statePool = sentence.getStateManager().getStatePool(languageModelPoolId, (KenLM)
+            languageModel);
 
     // Get the probability of applying the rule and the new state
-    final StateProbPair pair = ((KenLM)languageModel).probRule(statePool);
+    final StateProbPair pair = ((KenLM) languageModel).probRule(words, statePool);
 
     // Record the prob
     acc.add(denseFeatureIndex, pair.prob);
@@ -135,34 +131,31 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
   /**
    * Maps given array of word/class ids to KenLM ids. For estimating cost and computing,
    * state retrieval differs slightly.
-   *
-   * When used for estimation tailNodes may be null.
    */
-  private void writeKenLmIds(int[] ids, List<HGNode> tailNodes, KenLMPool poolWrapper) {
-
-    poolWrapper.setBufferLength(ids.length);
-
+  private long[] mapToKenLmIds(int[] ids, List<HGNode> tailNodes, boolean isOnlyEstimate) {
     // The IDs we will to KenLM
+    long[] kenIds = new long[ids.length];
     for (int x = 0; x < ids.length; x++) {
       int id = ids[x];
 
       if (isNonterminal(id)) {
 
-        if (tailNodes == null) {
-          // For the estimation, we can just mark negative values
-          poolWrapper.writeIdToBuffer(x, -1);
+        if (isOnlyEstimate) {
+          // For the estimate, we can just mark negative values
+          kenIds[x] = -1;
         } else {
           // Nonterminal: retrieve the KenLM long that records the state
           int index = -(id + 1);
           final KenLMState state = (KenLMState) tailNodes.get(index).getDPState(stateIndex);
-          poolWrapper.writeIdToBuffer(x, -state.getState());
+          kenIds[x] = -state.getState();
         }
 
       } else {
         // Terminal: just add it
-        poolWrapper.writeIdToBuffer(x, id);
+        kenIds[x] = id;
       }
     }
+    return kenIds;
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d9c3d7ec/src/test/java/org/apache/joshua/system/KenLmTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/system/KenLmTest.java b/src/test/java/org/apache/joshua/system/KenLmTest.java
index 38d6fcc..2d129f1 100644
--- a/src/test/java/org/apache/joshua/system/KenLmTest.java
+++ b/src/test/java/org/apache/joshua/system/KenLmTest.java
@@ -90,23 +90,23 @@ public class KenLmTest {
     registerLanguageModel(kenLm);
     String sentence = "Wayne Gretzky";
     String[] words = sentence.split("\\s+");
-    Vocabulary.addAll(sentence);
+    int[] ids = Vocabulary.addAll(sentence);
+    long[] longIds = new long[ids.length];
+
+    for (int i = 0; i < words.length; i++) {
+      longIds[i] = Vocabulary.id(words[i]);
+    }
 
     // WHEN
     KenLM.StateProbPair result;
     try (KenLMPool poolPointer = kenLm.createLMPool()) {
-
-      poolPointer.setBufferLength(words.length);
-      for(int i =0; i< words.length; i++) {
-        poolPointer.writeIdToBuffer(i, Vocabulary.id(words[i]));
-      }
-      result = kenLm.probRule(poolPointer);
+      result = kenLm.probRule(longIds, poolPointer);
     }
 
     // THEN
     assertThat(result, is(notNullValue()));
     assertThat(result.state.getState(), is(1L));
-    assertThat(result.prob, is(-3.7906885F));
+    assertThat(result.prob, is(-3.7906885f));
   }
 
   @Test