You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/09/15 17:36:27 UTC

incubator-joshua git commit: Convert to a DirectBuffer to transfer ngrams during probRule

Repository: incubator-joshua
Updated Branches:
  refs/heads/kellen-more_kenlm_opts [created] 9ea7eebf0


Convert to a DirectBuffer to transfer ngrams during probRule


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/9ea7eebf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/9ea7eebf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/9ea7eebf

Branch: refs/heads/kellen-more_kenlm_opts
Commit: 9ea7eebf0164d1676f633b441bd952eaa20b0760
Parents: 9c6ae40
Author: Kellen Sunderland <ke...@amazon.com>
Authored: Thu Sep 15 19:06:04 2016 +0200
Committer: Kellen Sunderland <ke...@amazon.com>
Committed: Thu Sep 15 19:35:46 2016 +0200

----------------------------------------------------------------------
 jni/kenlm_wrap.cc                               |  30 +++---
 .../org/apache/joshua/decoder/KenLMPool.java    |  10 +-
 .../org/apache/joshua/decoder/ff/lm/KenLM.java  | 108 +++++++++++--------
 3 files changed, 89 insertions(+), 59 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/9ea7eebf/jni/kenlm_wrap.cc
----------------------------------------------------------------------
diff --git a/jni/kenlm_wrap.cc b/jni/kenlm_wrap.cc
index bd82fe4..0f3148a 100644
--- a/jni/kenlm_wrap.cc
+++ b/jni/kenlm_wrap.cc
@@ -84,7 +84,9 @@ typedef std::unordered_set<StateIndex, HashIndex, EqualIndex> Lookup;
  */
 class Chart {
   public:
-    Chart() : lookup_(1000, HashIndex(vec_), EqualIndex(vec_)) {}
+    Chart(long* ngramBuffer) : 
+    ngramBuffer_(ngramBuffer),
+    lookup_(1000, HashIndex(vec_), EqualIndex(vec_)) {}
 
     StateIndex Intern(const lm::ngram::ChartState &state) {
       vec_.push_back(state);
@@ -99,6 +101,7 @@ class Chart {
     const lm::ngram::ChartState &InterpretState(StateIndex index) const {
       return vec_[index - 1];
     }
+    long* ngramBuffer_;
 
   private:
     StateVector vec_;
@@ -140,7 +143,7 @@ public:
 
   virtual bool IsKnownWordIndex(const lm::WordIndex& id) const = 0;
 
-  virtual float ProbRule(jlong *begin, jlong *end, lm::ngram::ChartState& state, const Chart &chart) const = 0;
+  virtual float ProbRule(lm::ngram::ChartState& state, const Chart &chart) const = 0;
 
   virtual float ProbString(jint * const begin, jint * const end,
       jint start) const = 0;
@@ -197,7 +200,12 @@ public:
       return id != m_.GetVocabulary().NotFound();
   }
 
-  float ProbRule(jlong * const begin, jlong * const end, lm::ngram::ChartState& state, const Chart &chart) const {
+  float ProbRule(lm::ngram::ChartState& state, const Chart &chart) const {
+
+    // By convention the first long in the ngramBuffer holds the number of word IDs that follow it
+    long* begin = chart.ngramBuffer_ + 1;
+    long* end = begin + *chart.ngramBuffer_;
+
     if (begin == end) return 0.0;
     lm::ngram::RuleScore<Model> ruleScore(m_, state);
 
@@ -351,8 +359,10 @@ JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroy(
 }
 
 JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_createPool(
-    JNIEnv *env, jclass) {
-  return reinterpret_cast<long>(new Chart());
+    JNIEnv *env, jclass, jobject arr) {
+  jlong* ngramBuffer = (jlong*)env->GetDirectBufferAddress(arr);
+  Chart *newChart = new Chart(ngramBuffer);
+  return reinterpret_cast<long>(newChart);
 }
 
 JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroyPool(
@@ -449,20 +459,14 @@ union FloatConverter {
 };
 
 JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
-  JNIEnv *env, jclass, jlong pointer, jlong chartPtr, jlongArray arr) {
-
-  jint length = env->GetArrayLength(arr);
-  // GCC only.
-  jlong values[length];
-  env->GetLongArrayRegion(arr, 0, length, values);
+  JNIEnv *env, jclass, jlong pointer, jlong chartPtr) {
 
   // Compute the probability
   lm::ngram::ChartState outState;
   const VirtualBase *base = reinterpret_cast<const VirtualBase*>(pointer);
   Chart* chart = reinterpret_cast<Chart*>(chartPtr);
   FloatConverter prob;
-  prob.f = base->ProbRule(values, values + length, outState, *chart);
-
+  prob.f = base->ProbRule(outState, *chart);
   StateIndex index = chart->Intern(outState);
   return static_cast<uint64_t>(index) << 32 | static_cast<uint64_t>(prob.i);
 }

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/9ea7eebf/src/main/java/org/apache/joshua/decoder/KenLMPool.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/KenLMPool.java b/src/main/java/org/apache/joshua/decoder/KenLMPool.java
index 378ac51..a1e709b 100644
--- a/src/main/java/org/apache/joshua/decoder/KenLMPool.java
+++ b/src/main/java/org/apache/joshua/decoder/KenLMPool.java
@@ -2,6 +2,8 @@ package org.apache.joshua.decoder;
 
 import org.apache.joshua.decoder.ff.lm.KenLM;
 
+import java.nio.ByteBuffer;
+
 /**
  * Class to wrap a KenLM pool of states.  This class is not ThreadSafe.  It should be
  * used in a scoped context, and close must be called to release native resources.  It
@@ -15,11 +17,13 @@ public class KenLMPool implements AutoCloseable {
 
   private final long pool;
   private final KenLM languageModel;
+  private final ByteBuffer ngramBuffer;
   private boolean released = false;
 
-  public KenLMPool(long pool, KenLM languageModel) {
+  public KenLMPool(long pool, KenLM languageModel, ByteBuffer ngramBuffer) {
     this.pool = pool;
     this.languageModel = languageModel;
+    this.ngramBuffer = ngramBuffer;
   }
 
   public long getPool() {
@@ -39,4 +43,8 @@ public class KenLMPool implements AutoCloseable {
       languageModel.destroyLMPool(pool);
     }
   }
+
+  public ByteBuffer getNgramBuffer() {
+    return ngramBuffer;
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/9ea7eebf/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
index 0646f68..e8a9f0f 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@ -25,6 +25,8 @@ import org.apache.joshua.util.FormatUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.nio.ByteBuffer;
+
 /**
  * JNI wrapper for KenLM. This version of KenLM supports two use cases, implemented by the separate
  * feature functions KenLMFF and LanguageModelFF. KenLMFF uses the RuleScore() interface in
@@ -37,8 +39,13 @@ import org.slf4j.LoggerFactory;
 
 public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
 
+  private static final int LONG_SIZE_IN_BYTES = Long.SIZE / 8;
+
   private static final Logger LOG = LoggerFactory.getLogger(KenLM.class);
 
+  // Maximum number of terminal and non-terminal symbols on a rule's target side
+  private static final int MAX_TARGET_LENGTH = 256;
+
   private final long pointer;
 
   // this is read from the config file, used to set maximum order
@@ -46,6 +53,24 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
   // inferred from model file (may be larger than ngramOrder)
   private final int N;
 
+  public KenLM(int order, String file_name) {
+    pointer = initializeSystemLibrary(file_name);
+    ngramOrder = order;
+    N = order(pointer);
+  }
+
+  /**
+   * Constructor if order is not known.
+   * Order will be inferred from the model.
+   *
+   * @param file_name string path to an input file
+   */
+  public KenLM(String file_name) {
+    pointer = initializeSystemLibrary(file_name);
+    N = order(pointer);
+    ngramOrder = N;
+  }
+
   private static native long construct(String file_name);
 
   private static native void destroy(long ptr);
@@ -62,33 +87,16 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
 
   private static native boolean isLmOov(long ptr, int word);
 
-  private static native long probRule(long ptr, long pool, long words[]);
+  private static native long probRule(long ptr, long pool);
 
   private static native float estimateRule(long ptr, long words[]);
 
   private static native float probString(long ptr, int words[], int start);
 
-  private static native long createPool();
+  private static native long createPool(ByteBuffer wordsBuffer);
 
   private static native void destroyPool(long pointer);
 
-  public KenLM(int order, String file_name) {
-    pointer = initializeSystemLibrary(file_name);
-    ngramOrder = order;
-    N = order(pointer);
-  }
-
-  /**
-   * Constructor if order is not known.
-   * Order will be inferred from the model.
-   * @param file_name string path to an input file
-   */
-  public KenLM(String file_name) {
-    pointer = initializeSystemLibrary(file_name);
-    N = order(pointer);
-    ngramOrder = N;
-  }
-
   private long initializeSystemLibrary(String file_name) {
     try {
       System.loadLibrary("ken");
@@ -99,15 +107,11 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
     }
   }
 
-  public static class KenLMLoadException extends RuntimeException {
-
-    public KenLMLoadException(UnsatisfiedLinkError e) {
-      super(e);
-    }
-  }
-
   public KenLMPool createLMPool() {
-    return new KenLMPool(createPool(), this);
+    ByteBuffer ngramBuffer = ByteBuffer.allocateDirect(MAX_TARGET_LENGTH * LONG_SIZE_IN_BYTES);
+    ngramBuffer.order(java.nio.ByteOrder.LITTLE_ENDIAN);
+    long pool = createPool(ngramBuffer);
+    return new KenLMPool(pool, this, ngramBuffer);
   }
 
   public void destroyLMPool(long pointer) {
@@ -134,6 +138,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
 
   /**
    * Query for n-gram probability using strings.
+   *
    * @param words a string array of words
    * @return float value denoting probability
    */
@@ -153,15 +158,21 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
    * needed so KenLM knows which memory pool to use. When finished, it returns the updated KenLM
    * state and the LM probability incurred along this rule.
    *
-   * @param words array of words
+   * @param words       array of words
    * @param poolWrapper an object that wraps a pool reference returned from KenLM createPool
    * @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
    * KenLM state and the LM probability incurred along this rule
    */
   public StateProbPair probRule(long[] words, KenLMPool poolWrapper) {
-    long packedResult = probRule(pointer, poolWrapper.getPool(), words);
+
+    poolWrapper.getNgramBuffer().putLong(0, words.length);
+    for (int i = 0; i < words.length; i++) {
+      poolWrapper.getNgramBuffer().putLong((i + 1) * LONG_SIZE_IN_BYTES, words[i]);
+    }
+
+    long packedResult = probRule(pointer, poolWrapper.getPool());
     int state = (int) (packedResult >> 32);
-    float probVal = Float.intBitsToFloat((int)packedResult);
+    float probVal = Float.intBitsToFloat((int) packedResult);
 
     return new StateProbPair(state, probVal);
   }
@@ -186,6 +197,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
 
   /**
    * The start symbol for a KenLM is the Vocabulary.START_SYM.
+   *
    * @return "&lt;s&gt;"
    */
   public String getStartSymbol() {
@@ -209,21 +221,6 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
     return isKnownWord(pointer, word);
   }
 
-
-  /**
-   * Inner class used to hold the results returned from KenLM with left-state minimization. Note
-   * that inner classes have to be static to be accessible from the JNI!
-   */
-  public static class StateProbPair {
-    public KenLMState state = null;
-    public float prob = 0.0f;
-
-    public StateProbPair(long state, float prob) {
-      this.state = new KenLMState(state);
-      this.prob = prob;
-    }
-  }
-
   @Override
   public int compareTo(KenLM other) {
     if (this == other)
@@ -252,4 +249,25 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
     return prob(ngram);
   }
 
+  public static class KenLMLoadException extends RuntimeException {
+
+    public KenLMLoadException(UnsatisfiedLinkError e) {
+      super(e);
+    }
+  }
+
+  /**
+   * Inner class used to hold the results returned from KenLM with left-state minimization. Note
+   * that inner classes have to be static to be accessible from the JNI!
+   */
+  public static class StateProbPair {
+    public KenLMState state = null;
+    public float prob = 0.0f;
+
+    public StateProbPair(long state, float prob) {
+      this.state = new KenLMState(state);
+      this.prob = prob;
+    }
+  }
+
 }