You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/09/28 12:45:47 UTC
[3/6] incubator-joshua git commit: Remove unneeded modifications for
estimate in KenLM
Remove unneeded modifications for estimate in KenLM
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/d9c3d7ec
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/d9c3d7ec
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/d9c3d7ec
Branch: refs/heads/master
Commit: d9c3d7ecf069a6a0339b911b9defb8ce31ebb1f1
Parents: c8d8a65
Author: Kellen Sunderland <ke...@amazon.com>
Authored: Tue Sep 27 17:31:37 2016 +0200
Committer: Kellen Sunderland <ke...@amazon.com>
Committed: Tue Sep 27 18:16:44 2016 +0200
----------------------------------------------------------------------
jni/kenlm_wrap.cc | 30 ++++++------
.../org/apache/joshua/decoder/ff/lm/KenLM.java | 32 ++++++++-----
.../ff/lm/StateMinimizingLanguageModel.java | 49 +++++++++-----------
.../org/apache/joshua/system/KenLmTest.java | 16 +++----
4 files changed, 64 insertions(+), 63 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d9c3d7ec/jni/kenlm_wrap.cc
----------------------------------------------------------------------
diff --git a/jni/kenlm_wrap.cc b/jni/kenlm_wrap.cc
index 8947a61..445b57b 100644
--- a/jni/kenlm_wrap.cc
+++ b/jni/kenlm_wrap.cc
@@ -76,11 +76,10 @@ class EqualIndex : public std::binary_function<StateIndex, StateIndex, bool> {
typedef std::unordered_set<StateIndex, HashIndex, EqualIndex> Lookup;
/**
- * A Chart bundles together a unordered_multimap that maps ChartState signatures to a single
- * object instantiated using a pool. This allows duplicate states to avoid allocating separate
- * state objects at multiple places throughout a sentence, and also allows state to be shared
- * across KenLMs for the same sentence. Multimap is used to avoid hash collisions which can
- * return incorrect results, and cause out-of-bounds lookups when multiple KenLMs are in use.
+ * A Chart bundles together a vector holding ChartStates and an unordered_set of StateIndexes
+ * which provides a mapping between StateIndexes and the positions of ChartStates in the vector.
+ * This allows for duplicate states to avoid allocating separate state objects at multiple places
+ * throughout a sentence.
*/
class Chart {
public:
@@ -148,7 +147,7 @@ public:
virtual float ProbString(jint * const begin, jint * const end,
jint start) const = 0;
- virtual float EstimateRule(const Chart &chart) const = 0;
+ virtual float EstimateRule(jlong *begin, jlong *end) const = 0;
virtual uint8_t Order() const = 0;
@@ -202,7 +201,7 @@ public:
float ProbRule(lm::ngram::ChartState& state, const Chart &chart) const {
- // By convention the first long in the ngramBuffer denots the size of the buffer
+ // By convention the first long in the ngramBuffer denotes the size of the buffer
long* begin = chart.ngramBuffer_ + 1;
long* end = begin + *chart.ngramBuffer_;
@@ -229,12 +228,7 @@ public:
return ruleScore.Finish();
}
- float EstimateRule(const Chart &chart) const {
-
- // By convention the first long in the ngramBuffer denotes the size of the buffer
- long* begin = chart.ngramBuffer_ + 1;
- long* end = begin + *chart.ngramBuffer_;
-
+ float EstimateRule(jlong * const begin, jlong * const end) const {
if (begin == end) return 0.0;
lm::ngram::ChartState nullState;
lm::ngram::RuleScore<Model> ruleScore(m_, nullState);
@@ -477,11 +471,15 @@ JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
}
JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_estimateRule(
- JNIEnv *env, jclass, jlong pointer, jlong chartPtr) {
+ JNIEnv *env, jclass, jlong pointer, jlongArray arr) {
+ jint length = env->GetArrayLength(arr);
+ // GCC only.
+ jlong values[length];
+ env->GetLongArrayRegion(arr, 0, length, values);
// Compute the probability
- Chart* chart = reinterpret_cast<Chart*>(chartPtr);
- return reinterpret_cast<const VirtualBase*>(pointer)->EstimateRule(*chart);
+ return reinterpret_cast<const VirtualBase*>(pointer)->EstimateRule(values,
+ values + length);
}
} // extern
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d9c3d7ec/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
index df16019..d138495 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@ -21,13 +21,15 @@ package org.apache.joshua.decoder.ff.lm;
import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.KenLMPool;
import org.apache.joshua.decoder.ff.state_maintenance.KenLMState;
-import org.apache.joshua.util.Constants;
import org.apache.joshua.util.FormatUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.ByteBuffer;
+import static java.nio.ByteOrder.LITTLE_ENDIAN;
+import static org.apache.joshua.util.Constants.LONG_SIZE_IN_BYTES;
+
/**
* JNI wrapper for KenLM. This version of KenLM supports two use cases, implemented by the separate
* feature functions KenLMFF and LanguageModelFF. KenLMFF uses the RuleScore() interface in
@@ -88,7 +90,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
private static native long probRule(long ptr, long pool);
- private static native float estimateRule(long ptr, long poolWrapper);
+ private static native float estimateRule(long ptr, long words[]);
private static native float probString(long ptr, int words[], int start);
@@ -107,9 +109,8 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
}
public KenLMPool createLMPool() {
- ByteBuffer ngramBuffer = ByteBuffer.allocateDirect(MAX_TARGET_LENGTH *
- Constants.LONG_SIZE_IN_BYTES);
- ngramBuffer.order(java.nio.ByteOrder.LITTLE_ENDIAN);
+ ByteBuffer ngramBuffer = ByteBuffer.allocateDirect(MAX_TARGET_LENGTH * LONG_SIZE_IN_BYTES);
+ ngramBuffer.order(LITTLE_ENDIAN);
long pool = createPool(ngramBuffer);
return new KenLMPool(pool, this, ngramBuffer);
}
@@ -158,11 +159,18 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
* needed so KenLM knows which memory pool to use. When finished, it returns the updated KenLM
* state and the LM probability incurred along this rule.
*
+ * @param words array of words
* @param poolWrapper an object that wraps a pool reference returned from KenLM createPool
* @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
* KenLM state and the LM probability incurred along this rule
*/
- public StateProbPair probRule(KenLMPool poolWrapper) {
+ public StateProbPair probRule(long[] words, KenLMPool poolWrapper) {
+
+ poolWrapper.setBufferLength(words.length);
+ for (int i = 0; i < words.length; i++) {
+ poolWrapper.writeIdToBuffer(i, words[i]);
+ }
+
long packedResult = probRule(pointer, poolWrapper.getPool());
int state = (int) (packedResult >> 32);
float probVal = Float.intBitsToFloat((int) packedResult);
@@ -174,12 +182,13 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
* Public facing function that estimates the cost of a rule, which value is used for sorting
* rules during cube pruning.
*
+ * @param words array of words
* @return the estimated cost of the rule (the (partial) n-gram probabilities of all words in the rule)
*/
- public float estimateRule(KenLMPool poolWrapper) {
+ public float estimateRule(long[] words) {
float estimate = 0.0f;
try {
- estimate = estimateRule(pointer, poolWrapper.getPool());
+ estimate = estimateRule(pointer, words);
} catch (NoSuchMethodError e) {
throw new RuntimeException(e);
}
@@ -249,11 +258,12 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
}
/**
- * Inner class used to hold the results returned from KenLM with left-state minimization.
+ * Inner class used to hold the results returned from KenLM with left-state minimization. Note
+ * that inner classes have to be static to be accessible from the JNI!
*/
public static class StateProbPair {
- public KenLMState state = null;
- public float prob = 0.0f;
+ public final KenLMState state;
+ public final float prob;
public StateProbPair(long state, float prob) {
this.state = new KenLMState(state);
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d9c3d7ec/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
index 155522b..c3281d6 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
@@ -76,15 +76,13 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
int[] ruleWords = getRuleIds(rule);
- try(KenLMPool poolWrapper = ((KenLM)languageModel).createLMPool();) {
- // Write KenLM word ids to a shared ByteBuffer.
- writeKenLmIds(ruleWords, null, poolWrapper);
-
- // Get the probability of applying the rule and the new state
- float lmCost = weight * ((KenLM)languageModel).estimateRule(poolWrapper);
- float oovCost = oovWeight * ((withOovFeature) ? getOovs(ruleWords) : 0f);
- return lmCost + oovCost;
- }
+ // map to ken lm ids
+ final long[] words = mapToKenLmIds(ruleWords, null, true);
+
+ // Get the probability of applying the rule and the new state
+ float lmCost = weight * ((KenLM) languageModel).estimateRule(words);
+ float oovCost = oovWeight * ((withOovFeature) ? getOovs(ruleWords) : 0f);
+ return lmCost + oovCost;
}
private UUID languageModelPoolId = UUID.randomUUID();
@@ -103,7 +101,7 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
int[] ruleWords;
if (config.source_annotations) {
- // Get source side annotations and project them to the target side
+ // get source side annotations and project them to the target side
ruleWords = getTags(rule, i, j, sentence);
} else {
ruleWords = getRuleIds(rule);
@@ -114,16 +112,14 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
acc.add(oovDenseFeatureIndex, getOovs(ruleWords));
}
- KenLMPool statePool = sentence.getStateManager().getStatePool(languageModelPoolId,
- (KenLM)languageModel);
-
- // Write KenLM ngram ids to the shared direct buffer
- writeKenLmIds(ruleWords, tailNodes, statePool);
-
+ // map to ken lm ids
+ final long[] words = mapToKenLmIds(ruleWords, tailNodes, false);
+ KenLMPool statePool = sentence.getStateManager().getStatePool(languageModelPoolId, (KenLM)
+ languageModel);
// Get the probability of applying the rule and the new state
- final StateProbPair pair = ((KenLM)languageModel).probRule(statePool);
+ final StateProbPair pair = ((KenLM) languageModel).probRule(words, statePool);
// Record the prob
acc.add(denseFeatureIndex, pair.prob);
@@ -135,34 +131,31 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
/**
* Maps given array of word/class ids to KenLM ids. For estimating cost and computing,
* state retrieval differs slightly.
- *
- * When used for estimation tailNodes may be null.
*/
- private void writeKenLmIds(int[] ids, List<HGNode> tailNodes, KenLMPool poolWrapper) {
-
- poolWrapper.setBufferLength(ids.length);
-
+ private long[] mapToKenLmIds(int[] ids, List<HGNode> tailNodes, boolean isOnlyEstimate) {
// The IDs we will to KenLM
+ long[] kenIds = new long[ids.length];
for (int x = 0; x < ids.length; x++) {
int id = ids[x];
if (isNonterminal(id)) {
- if (tailNodes == null) {
- // For the estimation, we can just mark negative values
- poolWrapper.writeIdToBuffer(x, -1);
+ if (isOnlyEstimate) {
+ // For the estimate, we can just mark negative values
+ kenIds[x] = -1;
} else {
// Nonterminal: retrieve the KenLM long that records the state
int index = -(id + 1);
final KenLMState state = (KenLMState) tailNodes.get(index).getDPState(stateIndex);
- poolWrapper.writeIdToBuffer(x, -state.getState());
+ kenIds[x] = -state.getState();
}
} else {
// Terminal: just add it
- poolWrapper.writeIdToBuffer(x, id);
+ kenIds[x] = id;
}
}
+ return kenIds;
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d9c3d7ec/src/test/java/org/apache/joshua/system/KenLmTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/system/KenLmTest.java b/src/test/java/org/apache/joshua/system/KenLmTest.java
index 38d6fcc..2d129f1 100644
--- a/src/test/java/org/apache/joshua/system/KenLmTest.java
+++ b/src/test/java/org/apache/joshua/system/KenLmTest.java
@@ -90,23 +90,23 @@ public class KenLmTest {
registerLanguageModel(kenLm);
String sentence = "Wayne Gretzky";
String[] words = sentence.split("\\s+");
- Vocabulary.addAll(sentence);
+ int[] ids = Vocabulary.addAll(sentence);
+ long[] longIds = new long[ids.length];
+
+ for (int i = 0; i < words.length; i++) {
+ longIds[i] = Vocabulary.id(words[i]);
+ }
// WHEN
KenLM.StateProbPair result;
try (KenLMPool poolPointer = kenLm.createLMPool()) {
-
- poolPointer.setBufferLength(words.length);
- for(int i =0; i< words.length; i++) {
- poolPointer.writeIdToBuffer(i, Vocabulary.id(words[i]));
- }
- result = kenLm.probRule(poolPointer);
+ result = kenLm.probRule(longIds, poolPointer);
}
// THEN
assertThat(result, is(notNullValue()));
assertThat(result.state.getState(), is(1L));
- assertThat(result.prob, is(-3.7906885F));
+ assertThat(result.prob, is(-3.7906885f));
}
@Test