You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/09/13 15:48:53 UTC

[4/7] incubator-joshua git commit: Manage pool of states on a per LM, per sentence basis

Manage pool of states on a per LM, per sentence basis


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/0252942d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/0252942d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/0252942d

Branch: refs/heads/master
Commit: 0252942dafc1679f2c5d6b8d6da7cd6884ca40c3
Parents: 4e07bb6
Author: Kellen Sunderland <ke...@amazon.com>
Authored: Tue Sep 13 13:58:05 2016 +0200
Committer: Kellen Sunderland <ke...@amazon.com>
Committed: Tue Sep 13 15:59:46 2016 +0200

----------------------------------------------------------------------
 .../org/apache/joshua/decoder/KenLMPool.java    | 42 ++++++++++++++++++++
 .../decoder/LanguageModelStateManager.java      | 29 ++++++++++++++
 .../org/apache/joshua/decoder/Translation.java  | 17 +-------
 .../org/apache/joshua/decoder/ff/lm/KenLM.java  | 25 +++++-------
 .../ff/lm/StateMinimizingLanguageModel.java     | 30 ++++----------
 .../joshua/decoder/segment_file/Sentence.java   | 11 +++++
 .../org/apache/joshua/system/KenLmTest.java     | 16 ++++----
 7 files changed, 109 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/KenLMPool.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/KenLMPool.java b/src/main/java/org/apache/joshua/decoder/KenLMPool.java
new file mode 100644
index 0000000..378ac51
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/KenLMPool.java
@@ -0,0 +1,42 @@
+package org.apache.joshua.decoder;
+
+import org.apache.joshua.decoder.ff.lm.KenLM;
+
+/**
+ * Class to wrap a KenLM pool of states.  This class is not ThreadSafe.  It should be
+ * used in a scoped context, and close must be called to release native resources.  It
+ * does implement a custom finalizer that will release these resources if needed, but
+ * this should not be relied on.
+ *
+ * @author Kellen Sunderland
+ */
+
+public class KenLMPool implements AutoCloseable {
+
+  private final long pool;
+  private final KenLM languageModel;
+  private boolean released = false;
+
+  public KenLMPool(long pool, KenLM languageModel) {
+    this.pool = pool;
+    this.languageModel = languageModel;
+  }
+
+  public long getPool() {
+    return pool;
+  }
+
+  @Override
+  protected void finalize() throws Throwable {
+    close();
+    super.finalize();
+  }
+
+  @Override
+  public void close() {
+    if (!released) {
+      released = true;
+      languageModel.destroyLMPool(pool);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/LanguageModelStateManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/LanguageModelStateManager.java b/src/main/java/org/apache/joshua/decoder/LanguageModelStateManager.java
new file mode 100644
index 0000000..6a3c4b3
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/LanguageModelStateManager.java
@@ -0,0 +1,29 @@
+package org.apache.joshua.decoder;
+
+import org.apache.joshua.decoder.ff.lm.KenLM;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+
+/**
+ * @author Kellen Sunderland
+ */
+public class LanguageModelStateManager {
+
+  private Map<UUID, KenLMPool> languageModelPoolMapping = new HashMap<>();
+
+  public KenLMPool getStatePool(UUID languageModelId, KenLM languageModel) {
+    KenLMPool statePool = languageModelPoolMapping.get(languageModelId);
+    if (statePool == null) {
+      statePool = languageModel.createLMPool();
+      languageModelPoolMapping.put(languageModelId, statePool);
+    }
+    return statePool;
+  }
+
+  public void clearStatePool() {
+    languageModelPoolMapping.values().forEach(KenLMPool::close);
+    languageModelPoolMapping.clear();
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/Translation.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/Translation.java b/src/main/java/org/apache/joshua/decoder/Translation.java
index ade9b22..ff2aed0 100644
--- a/src/main/java/org/apache/joshua/decoder/Translation.java
+++ b/src/main/java/org/apache/joshua/decoder/Translation.java
@@ -182,8 +182,8 @@ public class Translation {
 
     }
 
-    // remove state from StateMinimizingLanguageModel instances in features.
-    destroyKenLMStates(featureFunctions);
+    // Force any StateMinimizingLanguageModel pool mappings to be cleaned
+    source.getStateManager().clearStatePool();
 
   }
 
@@ -224,17 +224,4 @@ public class Translation {
     }
     return structuredTranslations;
   }
-
-  /**
-   * KenLM hack. If using KenLMFF, we need to tell KenLM to delete the pool used to create chart
-   * objects for this sentence.
-   */
-  private void destroyKenLMStates(final List<FeatureFunction> featureFunctions) {
-    for (FeatureFunction feature : featureFunctions) {
-      if (feature instanceof StateMinimizingLanguageModel) {
-        ((StateMinimizingLanguageModel) feature).destroyPool(getSourceSentence().id());
-        break;
-      }
-    }
-  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
index b0a1117..0646f68 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@ -19,6 +19,7 @@
 package org.apache.joshua.decoder.ff.lm;
 
 import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.KenLMPool;
 import org.apache.joshua.decoder.ff.state_maintenance.KenLMState;
 import org.apache.joshua.util.FormatUtils;
 import org.slf4j.Logger;
@@ -105,8 +106,8 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
     }
   }
 
-  public long createLMPool() {
-    return createPool();
+  public KenLMPool createLMPool() {
+    return new KenLMPool(createPool(), this);
   }
 
   public void destroyLMPool(long pointer) {
@@ -153,24 +154,16 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
    * state and the LM probability incurred along this rule.
    *
    * @param words array of words
-   * @param poolPointer todo
+   * @param poolWrapper an object that wraps a pool reference returned from KenLM createPool
    * @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
    * KenLM state and the LM probability incurred along this rule
    */
-  public StateProbPair probRule(long[] words, long poolPointer) {
+  public StateProbPair probRule(long[] words, KenLMPool poolWrapper) {
+    long packedResult = probRule(pointer, poolWrapper.getPool(), words);
+    int state = (int) (packedResult >> 32);
+    float probVal = Float.intBitsToFloat((int)packedResult);
 
-    StateProbPair pair = null;
-    try {
-      long packedResult = probRule(pointer, poolPointer, words);
-      int state = (int) (packedResult >> 32);
-      float probVal = Float.intBitsToFloat((int)packedResult);
-      pair = new StateProbPair(state, probVal);
-    } catch (NoSuchMethodError e) {
-      e.printStackTrace();
-      System.exit(1);
-    }
-
-    return pair;
+    return new StateProbPair(state, probVal);
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
index 4bec379..2219ce8 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/StateMinimizingLanguageModel.java
@@ -21,10 +21,11 @@ package org.apache.joshua.decoder.ff.lm;
 import static org.apache.joshua.util.FormatUtils.isNonterminal;
 
 import java.util.List;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.UUID;
 
 import org.apache.joshua.corpus.Vocabulary;
 import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.KenLMPool;
 import org.apache.joshua.decoder.chart_parser.SourcePath;
 import org.apache.joshua.decoder.ff.FeatureVector;
 import org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair;
@@ -42,9 +43,6 @@ import org.apache.joshua.decoder.segment_file.Sentence;
  */
 public class StateMinimizingLanguageModel extends LanguageModelFF {
 
-  // maps from sentence numbers to KenLM-side pools used to allocate state
-  private static final ConcurrentHashMap<Integer, Long> poolMap = new ConcurrentHashMap<>();
-
   public StateMinimizingLanguageModel(FeatureVector weights, String[] args, JoshuaConfiguration config) {
     super(weights, args, config);
     this.type = "kenlm";
@@ -87,6 +85,8 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
     return lmCost + oovCost;
   }
 
+  private UUID languageModelPoolId = UUID.randomUUID();
+
   /**
    * Computes the features incurred along this edge. Note that these features are unweighted costs
    * of the feature; they are the feature cost, not the model cost, or the inner product of them.
@@ -115,14 +115,11 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
      // map to ken lm ids
     final long[] words = mapToKenLmIds(ruleWords, tailNodes, false);
 
-    final int sentID = sentence.id();
-    // Since sentId is unique across threads, next operations are safe, but not atomic!
-    if (!poolMap.containsKey(sentID)) {
-      poolMap.put(sentID, ((KenLM) languageModel).createLMPool());
-    }
+    KenLMPool statePool = sentence.getStateManager().getStatePool(languageModelPoolId, (KenLM)
+            languageModel);
 
     // Get the probability of applying the rule and the new state
-    final StateProbPair pair = ((KenLM) languageModel).probRule(words, poolMap.get(sentID));
+    final StateProbPair pair = ((KenLM) languageModel).probRule(words, statePool);
 
     // Record the prob
     acc.add(denseFeatureIndex, pair.prob);
@@ -162,19 +159,6 @@ public class StateMinimizingLanguageModel extends LanguageModelFF {
   }
 
   /**
-   * Destroys the pool created to allocate state for this sentence. Called from the
-   * {@link org.apache.joshua.decoder.Translation} class after outputting the sentence or k-best list. Hosting
-   * this map here in KenLMFF statically allows pools to be shared across KenLM instances.
-   *
-   * @param sentId a key in the poolmap table to destroy
-   */
-  public void destroyPool(int sentId) {
-    if (poolMap.containsKey(sentId))
-      ((KenLM) languageModel).destroyLMPool(poolMap.get(sentId));
-    poolMap.remove(sentId);
-  }
-
-  /**
    * This function differs from regular transitions because we incorporate the cost of incomplete
    * left-hand ngrams, as well as including the start- and end-of-sentence markers (if they were
    * requested when the object was created).

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/main/java/org/apache/joshua/decoder/segment_file/Sentence.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/segment_file/Sentence.java b/src/main/java/org/apache/joshua/decoder/segment_file/Sentence.java
index 7127870..f84c41a 100644
--- a/src/main/java/org/apache/joshua/decoder/segment_file/Sentence.java
+++ b/src/main/java/org/apache/joshua/decoder/segment_file/Sentence.java
@@ -21,16 +21,21 @@ package org.apache.joshua.decoder.segment_file;
 import static org.apache.joshua.util.FormatUtils.addSentenceMarkers;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.StringTokenizer;
+import java.util.UUID;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
 import org.apache.joshua.corpus.Vocabulary;
 import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.KenLMPool;
+import org.apache.joshua.decoder.LanguageModelStateManager;
 import org.apache.joshua.decoder.ff.tm.Grammar;
 import org.apache.joshua.lattice.Arc;
 import org.apache.joshua.lattice.Lattice;
@@ -77,6 +82,8 @@ public class Sentence {
   
   public JoshuaConfiguration config = null;
 
+  private LanguageModelStateManager stateManager = new LanguageModelStateManager();
+
   /**
    * Constructor. Receives a string representing the input sentence. This string may be a
    * string-encoded lattice or a plain text string for decoding.
@@ -447,4 +454,8 @@ public class Sentence {
   public Node<Token> getNode(int i) {
     return getLattice().getNode(i);
   }
+
+  public LanguageModelStateManager getStateManager() {
+    return stateManager;
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0252942d/src/test/java/org/apache/joshua/system/KenLmTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/system/KenLmTest.java b/src/test/java/org/apache/joshua/system/KenLmTest.java
index aa396d2..003b5d9 100644
--- a/src/test/java/org/apache/joshua/system/KenLmTest.java
+++ b/src/test/java/org/apache/joshua/system/KenLmTest.java
@@ -19,6 +19,7 @@
 package org.apache.joshua.system;
 
 import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.KenLMPool;
 import org.apache.joshua.decoder.ff.lm.KenLM;
 import org.apache.joshua.util.io.KenLmTestUtil;
 import org.testng.annotations.AfterMethod;
@@ -29,8 +30,7 @@ import static org.apache.joshua.corpus.Vocabulary.registerLanguageModel;
 import static org.apache.joshua.corpus.Vocabulary.unregisterLanguageModels;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.Is.is;
-import static org.mockito.Matchers.isNotNull;
-import static org.mockito.Matchers.notNull;
+import static org.hamcrest.core.IsNull.notNullValue;
 import static org.testng.Assert.assertTrue;
 import static org.testng.AssertJUnit.assertEquals;
 import static org.testng.AssertJUnit.assertFalse;
@@ -84,7 +84,7 @@ public class KenLmTest {
   }
 
   @Test
-  public void givenKenLm_whenQueryingForNgramProbability2_thenIdAndStringMethodsReturnTheSame() {
+  public void givenKenLm_whenQueryingWithState_thenStateAndProbReturned() {
     // GIVEN
     KenLmTestUtil.Guard(() -> kenLm = new KenLM(LANGUAGE_MODEL_PATH));
 
@@ -94,16 +94,18 @@ public class KenLmTest {
     int[] ids = Vocabulary.addAll(sentence);
     long[] longIds = new long[ids.length];
 
-    for(int i = 0; i< words.length; i++) {
+    for (int i = 0; i < words.length; i++) {
       longIds[i] = ids[i];
     }
 
     // WHEN
-    long poolPointer = kenLm.createLMPool();
-    KenLM.StateProbPair result = kenLm.probRule(longIds, poolPointer);
-    kenLm.destroyLMPool(poolPointer);
+    KenLM.StateProbPair result;
+    try (KenLMPool poolPointer = kenLm.createLMPool()) {
+      result = kenLm.probRule(longIds, poolPointer);
+    }
 
     // THEN
+    assertThat(result, is(notNullValue()));
     assertThat(result.state.getState(), is(0L));
     assertThat(result.prob, is(-3.7906885f));
   }