You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/09/13 15:48:53 UTC
[4/7] incubator-joshua git commit: Manage pool of states on a per LM,
per sentence basis
Manage pool of states on a per LM, per sentence basis
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/0252942d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/0252942d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/0252942d
Branch: refs/heads/master
Commit: 0252942dafc1679f2c5d6b8d6da7cd6884ca40c3
Parents: 4e07bb6
Author: Kellen Sunderland <ke...@amazon.com>
Authored: Tue Sep 13 13:58:05 2016 +0200
Committer: Kellen Sunderland <ke...@amazon.com>
Committed: Tue Sep 13 15:59:46 2016 +0200
----------------------------------------------------------------------
.../org/apache/joshua/decoder/KenLMPool.java | 42 ++++++++++++++++++++
.../decoder/LanguageModelStateManager.java | 29 ++++++++++++++
.../org/apache/joshua/decoder/Translation.java | 17 +-------
.../org/apache/joshua/decoder/ff/lm/KenLM.java | 25 +++++-------
.../ff/lm/StateMinimizingLanguageModel.java | 30 ++++----------
.../joshua/decoder/segment_file/Sentence.java | 11 +++++
.../org/apache/joshua/system/KenLmTest.java | 16 ++++----
7 files changed, 109 insertions(+), 61 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/KenLMPool.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/KenLMPool.java b/src/main/java/org/apache/joshua/decoder/KenLMPool.java
new file mode 100644
index 0000000..378ac51
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/KenLMPool.java
@@ -0,0 +1,62 @@
+package org.apache.joshua.decoder;
+
+import org.apache.joshua.decoder.ff.lm.KenLM;
+
+/**
+ * Wraps a KenLM pool of states that is backed by native resources.
+ *
+ * <p>This class is not generally thread-safe. It should be used in a scoped context (for
+ * example try-with-resources), and {@link #close()} must be called to release the native
+ * resources. A finalizer also releases them if the pool was never closed, but this should not
+ * be relied on. {@link #close()} is synchronized because the finalizer runs on a different
+ * thread; without it the finalizer might not observe an earlier close and could destroy the
+ * same native pool twice.
+ *
+ * @author Kellen Sunderland
+ */
+public class KenLMPool implements AutoCloseable {
+
+  /** Pool pointer as returned by KenLM's createPool; invalid once the pool is closed. */
+  private final long pool;
+  /** The model that created {@link #pool} and is asked to destroy it. */
+  private final KenLM languageModel;
+  /** Whether the pool has already been destroyed. Guarded by {@code this}. */
+  private boolean released = false;
+
+  /**
+   * @param pool pool pointer returned from KenLM createPool
+   * @param languageModel the model whose destroyLMPool will release the pool
+   */
+  public KenLMPool(long pool, KenLM languageModel) {
+    this.pool = pool;
+    this.languageModel = languageModel;
+  }
+
+  /**
+   * @return the raw pool pointer; it must not be used after {@link #close()} has run
+   */
+  public long getPool() {
+    return pool;
+  }
+
+  @Override
+  protected void finalize() throws Throwable {
+    try {
+      close();
+    } finally {
+      // Run the superclass finalizer even if releasing the pool failed.
+      super.finalize();
+    }
+  }
+
+  /**
+   * Destroys the pool via {@link KenLM#destroyLMPool(long)}. Subsequent calls are no-ops.
+   */
+  @Override
+  public synchronized void close() {
+    if (!released) {
+      released = true;
+      languageModel.destroyLMPool(pool);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/LanguageModelStateManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/LanguageModelStateManager.java b/src/main/java/org/apache/joshua/decoder/LanguageModelStateManager.java
new file mode 100644
index 0000000..6a3c4b3
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/LanguageModelStateManager.java
@@ -0,0 +1,48 @@
+package org.apache.joshua.decoder;
+
+import org.apache.joshua.decoder.ff.lm.KenLM;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.UUID;
+
+/**
+ * Holds the KenLM state pools created for one Sentence (each Sentence owns one manager), one
+ * pool per language model id, so that they can all be released via {@link #clearStatePool()}.
+ *
+ * <p>This class is not synchronized.
+ *
+ * @author Kellen Sunderland
+ */
+public class LanguageModelStateManager {
+
+  private final Map<UUID, KenLMPool> languageModelPoolMapping = new HashMap<>();
+
+  /**
+   * Returns the pool registered for {@code languageModelId}, creating it with
+   * {@link KenLM#createLMPool()} on first request.
+   *
+   * @param languageModelId key identifying the requesting language model instance
+   * @param languageModel model used to create the pool if none is registered yet
+   * @return the pool associated with {@code languageModelId}
+   */
+  public KenLMPool getStatePool(UUID languageModelId, KenLM languageModel) {
+    return languageModelPoolMapping.computeIfAbsent(
+        languageModelId, id -> languageModel.createLMPool());
+  }
+
+  /**
+   * Closes and forgets every pool handed out by {@link #getStatePool(UUID, KenLM)}.
+   */
+  public void clearStatePool() {
+    Iterator<KenLMPool> pools = languageModelPoolMapping.values().iterator();
+    while (pools.hasNext()) {
+      KenLMPool pool = pools.next();
+      // Unregister before closing so that an already-closed pool can never be returned by
+      // getStatePool, even if close() throws part-way through the loop.
+      pools.remove();
+      pool.close();
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/Translation.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/Translation.java b/src/main/java/org/apache/joshua/decoder/Translation.java
index ade9b22..ff2aed0 100644
--- a/src/main/java/org/apache/joshua/decoder/Translation.java
+++ b/src/main/java/org/apache/joshua/decoder/Translation.java
@@ -182,8 +182,8 @@ public class Translation {
}
- // remove state from StateMinimizingLanguageModel instances in features.
- destroyKenLMStates(featureFunctions);
+ // Force any StateMinimizingLanguageModel pool mappings to be cleaned
+ source.getStateManager().clearStatePool();
}
@@ -224,17 +224,4 @@ public class Translation {
}
return structuredTranslations;
}
-
- /**
- * KenLM hack. If using KenLMFF, we need to tell KenLM to delete the pool used to create chart
- * objects for this sentence.
- */
- private void destroyKenLMStates(final List<FeatureFunction> featureFunctions) {
- for (FeatureFunction feature : featureFunctions) {
- if (feature instanceof StateMinimizingLanguageModel) {
- ((StateMinimizingLanguageModel) feature).destroyPool(getSourceSentence().id());
- break;
- }
- }
- }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
index b0a1117..0646f68 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@ -19,6 +19,7 @@
package org.apache.joshua.decoder.ff.lm;
import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.KenLMPool;
import org.apache.joshua.decoder.ff.state_maintenance.KenLMState;
import org.apache.joshua.util.FormatUtils;
import org.slf4j.Logger;
@@ -105,8 +106,8 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
}
}
- public long createLMPool() {
- return createPool();
+ public KenLMPool createLMPool() {
+ return new KenLMPool(createPool(), this);
}
public void destroyLMPool(long pointer) {
@@ -153,24 +154,16 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
* state and the LM probability incurred along this rule.
*
* @param words array of words
- * @param poolPointer todo
+ * @param poolWrapper an object that wraps a pool reference returned from KenLM createPool
* @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
* KenLM state and the LM probability incurred along this rule
*/
- public StateProbPair probRule(long[] words, long poolPointer) {
+ public StateProbPair probRule(long[] words, KenLMPool poolWrapper) {
+ long packedResult = probRule(pointer, poolWrapper.getPool(), words);
+ int state = (int) (packedResult >> 32);
+ float probVal = Float.intBitsToFloat((int)packedResult);
- StateProbPair pair = null;
- try {
- long packedResult = probRule(pointer, poolPointer, words);
- int state = (int) (packedResult >> 32);
- float probVal = Float.intBitsToFloat((int)packedResult);
- pair = new StateProbPair(state, probVal);
- } catch (NoSuchMethodError e) {
- e.printStackTrace();
- System.exit(1);
- }
-
- return pair;
+ return new StateProbPair(state, probVal);
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
index 4bec379..2219ce8 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
@@ -21,10 +21,11 @@ package org.apache.joshua.decoder.ff.lm;
import static org.apache.joshua.util.FormatUtils.isNonterminal;
import java.util.List;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.UUID;
import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.KenLMPool;
import org.apache.joshua.decoder.chart_parser.SourcePath;
import org.apache.joshua.decoder.ff.FeatureVector;
import org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair;
@@ -42,9 +43,6 @@ import org.apache.joshua.decoder.segment_file.Sentence;
*/
public class StateMinimizingLanguageModel extends LanguageModelFF {
- // maps from sentence numbers to KenLM-side pools used to allocate state
- private static final ConcurrentHashMap<Integer, Long> poolMap = new ConcurrentHashMap<>();
-
public StateMinimizingLanguageModel(FeatureVector weights, String[] args, JoshuaConfiguration config) {
super(weights, args, config);
this.type = "kenlm";
@@ -87,6 +85,8 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
return lmCost + oovCost;
}
+ private UUID languageModelPoolId = UUID.randomUUID();
+
/**
* Computes the features incurred along this edge. Note that these features are unweighted costs
* of the feature; they are the feature cost, not the model cost, or the inner product of them.
@@ -115,14 +115,11 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
// map to ken lm ids
final long[] words = mapToKenLmIds(ruleWords, tailNodes, false);
- final int sentID = sentence.id();
- // Since sentId is unique across threads, next operations are safe, but not atomic!
- if (!poolMap.containsKey(sentID)) {
- poolMap.put(sentID, ((KenLM) languageModel).createLMPool());
- }
+ KenLMPool statePool = sentence.getStateManager().getStatePool(languageModelPoolId, (KenLM)
+ languageModel);
// Get the probability of applying the rule and the new state
- final StateProbPair pair = ((KenLM) languageModel).probRule(words, poolMap.get(sentID));
+ final StateProbPair pair = ((KenLM) languageModel).probRule(words, statePool);
// Record the prob
acc.add(denseFeatureIndex, pair.prob);
@@ -162,19 +159,6 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
}
/**
- * Destroys the pool created to allocate state for this sentence. Called from the
- * {@link org.apache.joshua.decoder.Translation} class after outputting the sentence or k-best list. Hosting
- * this map here in KenLMFF statically allows pools to be shared across KenLM instances.
- *
- * @param sentId a key in the poolmap table to destroy
- */
- public void destroyPool(int sentId) {
- if (poolMap.containsKey(sentId))
- ((KenLM) languageModel).destroyLMPool(poolMap.get(sentId));
- poolMap.remove(sentId);
- }
-
- /**
* This function differs from regular transitions because we incorporate the cost of incomplete
* left-hand ngrams, as well as including the start- and end-of-sentence markers (if they were
* requested when the object was created).
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/segment_file/Sentence.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/segment_file/Sentence.java b/src/main/java/org/apache/joshua/decoder/segment_file/Sentence.java
index 7127870..f84c41a 100644
--- a/src/main/java/org/apache/joshua/decoder/segment_file/Sentence.java
+++ b/src/main/java/org/apache/joshua/decoder/segment_file/Sentence.java
@@ -21,16 +21,21 @@ package org.apache.joshua.decoder.segment_file;
import static org.apache.joshua.util.FormatUtils.addSentenceMarkers;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
import java.util.StringTokenizer;
+import java.util.UUID;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.KenLMPool;
+import org.apache.joshua.decoder.LanguageModelStateManager;
import org.apache.joshua.decoder.ff.tm.Grammar;
import org.apache.joshua.lattice.Arc;
import org.apache.joshua.lattice.Lattice;
@@ -77,6 +82,8 @@ public class Sentence {
public JoshuaConfiguration config = null;
+ private LanguageModelStateManager stateManager = new LanguageModelStateManager();
+
/**
* Constructor. Receives a string representing the input sentence. This string may be a
* string-encoded lattice or a plain text string for decoding.
@@ -447,4 +454,8 @@ public class Sentence {
public Node<Token> getNode(int i) {
return getLattice().getNode(i);
}
+
+ public LanguageModelStateManager getStateManager() {
+ return stateManager;
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/test/java/org/apache/joshua/system/KenLmTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/system/KenLmTest.java b/src/test/java/org/apache/joshua/system/KenLmTest.java
index aa396d2..003b5d9 100644
--- a/src/test/java/org/apache/joshua/system/KenLmTest.java
+++ b/src/test/java/org/apache/joshua/system/KenLmTest.java
@@ -19,6 +19,7 @@
package org.apache.joshua.system;
import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.KenLMPool;
import org.apache.joshua.decoder.ff.lm.KenLM;
import org.apache.joshua.util.io.KenLmTestUtil;
import org.testng.annotations.AfterMethod;
@@ -29,8 +30,7 @@ import static org.apache.joshua.corpus.Vocabulary.registerLanguageModel;
import static org.apache.joshua.corpus.Vocabulary.unregisterLanguageModels;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.Is.is;
-import static org.mockito.Matchers.isNotNull;
-import static org.mockito.Matchers.notNull;
+import static org.hamcrest.core.IsNull.notNullValue;
import static org.testng.Assert.assertTrue;
import static org.testng.AssertJUnit.assertEquals;
import static org.testng.AssertJUnit.assertFalse;
@@ -84,7 +84,7 @@ public class KenLmTest {
}
@Test
- public void givenKenLm_whenQueryingForNgramProbability2_thenIdAndStringMethodsReturnTheSame() {
+ public void givenKenLm_whenQueryingWithState_thenStateAndProbReturned() {
// GIVEN
KenLmTestUtil.Guard(() -> kenLm = new KenLM(LANGUAGE_MODEL_PATH));
@@ -94,16 +94,18 @@ public class KenLmTest {
int[] ids = Vocabulary.addAll(sentence);
long[] longIds = new long[ids.length];
- for(int i = 0; i< words.length; i++) {
+ for (int i = 0; i < words.length; i++) {
longIds[i] = ids[i];
}
// WHEN
- long poolPointer = kenLm.createLMPool();
- KenLM.StateProbPair result = kenLm.probRule(longIds, poolPointer);
- kenLm.destroyLMPool(poolPointer);
+ KenLM.StateProbPair result;
+ try (KenLMPool poolPointer = kenLm.createLMPool()) {
+ result = kenLm.probRule(longIds, poolPointer);
+ }
// THEN
+ assertThat(result, is(notNullValue()));
assertThat(result.state.getState(), is(0L));
assertThat(result.prob, is(-3.7906885f));
}