You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/09/15 17:36:27 UTC
incubator-joshua git commit: Convert to a DirectBuffer to transfer
ngrams during probRule
Repository: incubator-joshua
Updated Branches:
refs/heads/kellen-more_kenlm_opts [created] 9ea7eebf0
Convert to a DirectBuffer to transfer ngrams during probRule
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/9ea7eebf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/9ea7eebf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/9ea7eebf
Branch: refs/heads/kellen-more_kenlm_opts
Commit: 9ea7eebf0164d1676f633b441bd952eaa20b0760
Parents: 9c6ae40
Author: Kellen Sunderland <ke...@amazon.com>
Authored: Thu Sep 15 19:06:04 2016 +0200
Committer: Kellen Sunderland <ke...@amazon.com>
Committed: Thu Sep 15 19:35:46 2016 +0200
----------------------------------------------------------------------
jni/kenlm_wrap.cc | 30 +++---
.../org/apache/joshua/decoder/KenLMPool.java | 10 +-
.../org/apache/joshua/decoder/ff/lm/KenLM.java | 108 +++++++++++--------
3 files changed, 89 insertions(+), 59 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/9ea7eebf/jni/kenlm_wrap.cc
----------------------------------------------------------------------
diff --git a/jni/kenlm_wrap.cc b/jni/kenlm_wrap.cc
index bd82fe4..0f3148a 100644
--- a/jni/kenlm_wrap.cc
+++ b/jni/kenlm_wrap.cc
@@ -84,7 +84,9 @@ typedef std::unordered_set<StateIndex, HashIndex, EqualIndex> Lookup;
*/
class Chart {
public:
- Chart() : lookup_(1000, HashIndex(vec_), EqualIndex(vec_)) {}
+ Chart(long* ngramBuffer) :
+ ngramBuffer_(ngramBuffer),
+ lookup_(1000, HashIndex(vec_), EqualIndex(vec_)) {}
StateIndex Intern(const lm::ngram::ChartState &state) {
vec_.push_back(state);
@@ -99,6 +101,7 @@ class Chart {
const lm::ngram::ChartState &InterpretState(StateIndex index) const {
return vec_[index - 1];
}
+ long* ngramBuffer_;
private:
StateVector vec_;
@@ -140,7 +143,7 @@ public:
virtual bool IsKnownWordIndex(const lm::WordIndex& id) const = 0;
- virtual float ProbRule(jlong *begin, jlong *end, lm::ngram::ChartState& state, const Chart &chart) const = 0;
+ virtual float ProbRule(lm::ngram::ChartState& state, const Chart &chart) const = 0;
virtual float ProbString(jint * const begin, jint * const end,
jint start) const = 0;
@@ -197,7 +200,12 @@ public:
return id != m_.GetVocabulary().NotFound();
}
- float ProbRule(jlong * const begin, jlong * const end, lm::ngram::ChartState& state, const Chart &chart) const {
+ float ProbRule(lm::ngram::ChartState& state, const Chart &chart) const {
+
+ // By convention the first long in the ngramBuffer denotes the number of word ids that follow it
+ long* begin = chart.ngramBuffer_ + 1;
+ long* end = begin + *chart.ngramBuffer_;
+
if (begin == end) return 0.0;
lm::ngram::RuleScore<Model> ruleScore(m_, state);
@@ -351,8 +359,10 @@ JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroy(
}
JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_createPool(
- JNIEnv *env, jclass) {
- return reinterpret_cast<long>(new Chart());
+ JNIEnv *env, jclass, jobject arr) {
+ jlong* ngramBuffer = (jlong*)env->GetDirectBufferAddress(arr);
+ Chart *newChart = new Chart(ngramBuffer);
+ return reinterpret_cast<long>(newChart);
}
JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroyPool(
@@ -449,20 +459,14 @@ union FloatConverter {
};
JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
- JNIEnv *env, jclass, jlong pointer, jlong chartPtr, jlongArray arr) {
-
- jint length = env->GetArrayLength(arr);
- // GCC only.
- jlong values[length];
- env->GetLongArrayRegion(arr, 0, length, values);
+ JNIEnv *env, jclass, jlong pointer, jlong chartPtr) {
// Compute the probability
lm::ngram::ChartState outState;
const VirtualBase *base = reinterpret_cast<const VirtualBase*>(pointer);
Chart* chart = reinterpret_cast<Chart*>(chartPtr);
FloatConverter prob;
- prob.f = base->ProbRule(values, values + length, outState, *chart);
-
+ prob.f = base->ProbRule(outState, *chart);
StateIndex index = chart->Intern(outState);
return static_cast<uint64_t>(index) << 32 | static_cast<uint64_t>(prob.i);
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/9ea7eebf/src/main/java/org/apache/joshua/decoder/KenLMPool.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/KenLMPool.java b/src/main/java/org/apache/joshua/decoder/KenLMPool.java
index 378ac51..a1e709b 100644
--- a/src/main/java/org/apache/joshua/decoder/KenLMPool.java
+++ b/src/main/java/org/apache/joshua/decoder/KenLMPool.java
@@ -2,6 +2,8 @@ package org.apache.joshua.decoder;
import org.apache.joshua.decoder.ff.lm.KenLM;
+import java.nio.ByteBuffer;
+
/**
* Class to wrap a KenLM pool of states. This class is not ThreadSafe. It should be
* used in a scoped context, and close must be called to release native resources. It
@@ -15,11 +17,13 @@ public class KenLMPool implements AutoCloseable {
private final long pool;
private final KenLM languageModel;
+ private final ByteBuffer ngramBuffer;
private boolean released = false;
- public KenLMPool(long pool, KenLM languageModel) {
+ public KenLMPool(long pool, KenLM languageModel, ByteBuffer ngramBuffer) {
this.pool = pool;
this.languageModel = languageModel;
+ this.ngramBuffer = ngramBuffer;
}
public long getPool() {
@@ -39,4 +43,8 @@ public class KenLMPool implements AutoCloseable {
languageModel.destroyLMPool(pool);
}
}
+
+ public ByteBuffer getNgramBuffer() {
+ return ngramBuffer;
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/9ea7eebf/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
index 0646f68..e8a9f0f 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@ -25,6 +25,8 @@ import org.apache.joshua.util.FormatUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.nio.ByteBuffer;
+
/**
* JNI wrapper for KenLM. This version of KenLM supports two use cases, implemented by the separate
* feature functions KenLMFF and LanguageModelFF. KenLMFF uses the RuleScore() interface in
@@ -37,8 +39,13 @@ import org.slf4j.LoggerFactory;
public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
+ private static final int LONG_SIZE_IN_BYTES = Long.SIZE / 8;
+
private static final Logger LOG = LoggerFactory.getLogger(KenLM.class);
+ // Maximum number of terminal and non-terminal symbols on a rule's target side
+ private static final int MAX_TARGET_LENGTH = 256;
+
private final long pointer;
// this is read from the config file, used to set maximum order
@@ -46,6 +53,24 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
// inferred from model file (may be larger than ngramOrder)
private final int N;
+ public KenLM(int order, String file_name) {
+ pointer = initializeSystemLibrary(file_name);
+ ngramOrder = order;
+ N = order(pointer);
+ }
+
+ /**
+ * Constructor if order is not known.
+ * Order will be inferred from the model.
+ *
+ * @param file_name string path to an input file
+ */
+ public KenLM(String file_name) {
+ pointer = initializeSystemLibrary(file_name);
+ N = order(pointer);
+ ngramOrder = N;
+ }
+
private static native long construct(String file_name);
private static native void destroy(long ptr);
@@ -62,33 +87,16 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
private static native boolean isLmOov(long ptr, int word);
- private static native long probRule(long ptr, long pool, long words[]);
+ private static native long probRule(long ptr, long pool);
private static native float estimateRule(long ptr, long words[]);
private static native float probString(long ptr, int words[], int start);
- private static native long createPool();
+ private static native long createPool(ByteBuffer wordsBuffer);
private static native void destroyPool(long pointer);
- public KenLM(int order, String file_name) {
- pointer = initializeSystemLibrary(file_name);
- ngramOrder = order;
- N = order(pointer);
- }
-
- /**
- * Constructor if order is not known.
- * Order will be inferred from the model.
- * @param file_name string path to an input file
- */
- public KenLM(String file_name) {
- pointer = initializeSystemLibrary(file_name);
- N = order(pointer);
- ngramOrder = N;
- }
-
private long initializeSystemLibrary(String file_name) {
try {
System.loadLibrary("ken");
@@ -99,15 +107,11 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
}
}
- public static class KenLMLoadException extends RuntimeException {
-
- public KenLMLoadException(UnsatisfiedLinkError e) {
- super(e);
- }
- }
-
public KenLMPool createLMPool() {
- return new KenLMPool(createPool(), this);
+ ByteBuffer ngramBuffer = ByteBuffer.allocateDirect(MAX_TARGET_LENGTH * LONG_SIZE_IN_BYTES);
+ ngramBuffer.order(java.nio.ByteOrder.LITTLE_ENDIAN);
+ long pool = createPool(ngramBuffer);
+ return new KenLMPool(pool, this, ngramBuffer);
}
public void destroyLMPool(long pointer) {
@@ -134,6 +138,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
/**
* Query for n-gram probability using strings.
+ *
* @param words a string array of words
* @return float value denoting probability
*/
@@ -153,15 +158,21 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
* needed so KenLM knows which memory pool to use. When finished, it returns the updated KenLM
* state and the LM probability incurred along this rule.
*
- * @param words array of words
+ * @param words array of words
* @param poolWrapper an object that wraps a pool reference returned from KenLM createPool
* @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
* KenLM state and the LM probability incurred along this rule
*/
public StateProbPair probRule(long[] words, KenLMPool poolWrapper) {
- long packedResult = probRule(pointer, poolWrapper.getPool(), words);
+
+ poolWrapper.getNgramBuffer().putLong(0, words.length);
+ for (int i = 0; i < words.length; i++) {
+ poolWrapper.getNgramBuffer().putLong((i + 1) * LONG_SIZE_IN_BYTES, words[i]);
+ }
+
+ long packedResult = probRule(pointer, poolWrapper.getPool());
int state = (int) (packedResult >> 32);
- float probVal = Float.intBitsToFloat((int)packedResult);
+ float probVal = Float.intBitsToFloat((int) packedResult);
return new StateProbPair(state, probVal);
}
@@ -186,6 +197,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
/**
* The start symbol for a KenLM is the Vocabulary.START_SYM.
+ *
* @return "<s>"
*/
public String getStartSymbol() {
@@ -209,21 +221,6 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
return isKnownWord(pointer, word);
}
-
- /**
- * Inner class used to hold the results returned from KenLM with left-state minimization. Note
- * that inner classes have to be static to be accessible from the JNI!
- */
- public static class StateProbPair {
- public KenLMState state = null;
- public float prob = 0.0f;
-
- public StateProbPair(long state, float prob) {
- this.state = new KenLMState(state);
- this.prob = prob;
- }
- }
-
@Override
public int compareTo(KenLM other) {
if (this == other)
@@ -252,4 +249,25 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
return prob(ngram);
}
+ public static class KenLMLoadException extends RuntimeException {
+
+ public KenLMLoadException(UnsatisfiedLinkError e) {
+ super(e);
+ }
+ }
+
+ /**
+ * Inner class used to hold the results returned from KenLM with left-state minimization. Note
+ * that inner classes have to be static to be accessible from the JNI!
+ */
+ public static class StateProbPair {
+ public KenLMState state = null;
+ public float prob = 0.0f;
+
+ public StateProbPair(long state, float prob) {
+ this.state = new KenLMState(state);
+ this.prob = prob;
+ }
+ }
+
}