You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text rendering.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/09/13 15:48:52 UTC

[3/7] incubator-joshua git commit: Adapted Java side of JNI interface to get state and prob from packed long

Adapted Java side of JNI interface to get state and prob from packed long


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/4e07bb66
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/4e07bb66
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/4e07bb66

Branch: refs/heads/master
Commit: 4e07bb66d28e55357ee6b19b3c60a76a31d8dd75
Parents: 929760a
Author: Kellen Sunderland <ke...@amazon.com>
Authored: Tue Sep 13 12:39:41 2016 +0200
Committer: Kellen Sunderland <ke...@amazon.com>
Committed: Tue Sep 13 12:39:41 2016 +0200

----------------------------------------------------------------------
 jni/kenlm_wrap.cc                               |  9 +++---
 .../org/apache/joshua/decoder/ff/lm/KenLM.java  |  7 +++--
 .../org/apache/joshua/system/KenLmTest.java     | 29 ++++++++++++++++++++
 3 files changed, 39 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/4e07bb66/jni/kenlm_wrap.cc
----------------------------------------------------------------------
diff --git a/jni/kenlm_wrap.cc b/jni/kenlm_wrap.cc
index 8f69e19..bbe6e7c 100644
--- a/jni/kenlm_wrap.cc
+++ b/jni/kenlm_wrap.cc
@@ -95,7 +95,7 @@ class Chart {
       return *ins.first;
     }
 
-    const ChartState &InterpretState(StateIndex index) const {
+    const lm::ngram::ChartState &InterpretState(StateIndex index) const {
       return vec_[index];
     }
 
@@ -201,7 +201,7 @@ public:
     lm::ngram::RuleScore<Model> ruleScore(m_, state);
 
     if (*begin < 0) {
-      ruleScore.BeginNonTerminal(chart.Interpet(-*begin));
+      ruleScore.BeginNonTerminal(chart.InterpretState(-*begin));
     } else {
       const lm::WordIndex word = map_[*begin];
       if (word == m_.GetVocabulary().BeginSentence()) {
@@ -213,7 +213,7 @@ public:
     for (jlong* i = begin + 1; i != end; i++) {
       long word = *i;
       if (word < 0)
-        ruleScore.NonTerminal(chart.Interpret(-word));
+        ruleScore.NonTerminal(chart.InterpretState(-word));
       else
         ruleScore.Terminal(map_[word]);
     }
@@ -449,6 +449,7 @@ union FloatConverter {
 
 JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
   JNIEnv *env, jclass, jlong pointer, jlong chartPtr, jlongArray arr) {
+
   jint length = env->GetArrayLength(arr);
   // GCC only.
   jlong values[length];
@@ -458,7 +459,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
   lm::ngram::ChartState outState;
   const VirtualBase *base = reinterpret_cast<const VirtualBase*>(pointer);
   Chart* chart = reinterpret_cast<Chart*>(chartPtr);
-  FloatConvert prob;
+  FloatConverter prob;
   prob.f = base->ProbRule(values, values + length, outState, *chart);
 
   StateIndex index = chart->Intern(outState);

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/4e07bb66/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
index 044c85f..b0a1117 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@ -61,7 +61,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
 
   private static native boolean isLmOov(long ptr, int word);
 
-  private static native StateProbPair probRule(long ptr, long pool, long words[]);
+  private static native long probRule(long ptr, long pool, long words[]);
 
   private static native float estimateRule(long ptr, long words[]);
 
@@ -161,7 +161,10 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
 
     StateProbPair pair = null;
     try {
-      pair = probRule(pointer, poolPointer, words);
+      long packedResult = probRule(pointer, poolPointer, words);
+      int state = (int) (packedResult >> 32);
+      float probVal = Float.intBitsToFloat((int)packedResult);
+      pair = new StateProbPair(state, probVal);
     } catch (NoSuchMethodError e) {
       e.printStackTrace();
       System.exit(1);

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/4e07bb66/src/test/java/org/apache/joshua/system/KenLmTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/system/KenLmTest.java b/src/test/java/org/apache/joshua/system/KenLmTest.java
index 74baef3..aa396d2 100644
--- a/src/test/java/org/apache/joshua/system/KenLmTest.java
+++ b/src/test/java/org/apache/joshua/system/KenLmTest.java
@@ -27,6 +27,10 @@ import org.testng.annotations.Test;
 
 import static org.apache.joshua.corpus.Vocabulary.registerLanguageModel;
 import static org.apache.joshua.corpus.Vocabulary.unregisterLanguageModels;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.Is.is;
+import static org.mockito.Matchers.isNotNull;
+import static org.mockito.Matchers.notNull;
 import static org.testng.Assert.assertTrue;
 import static org.testng.AssertJUnit.assertEquals;
 import static org.testng.AssertJUnit.assertFalse;
@@ -80,6 +84,31 @@ public class KenLmTest {
   }
 
   @Test
+  public void givenKenLm_whenQueryingForNgramProbability2_thenIdAndStringMethodsReturnTheSame() {
+    // GIVEN
+    KenLmTestUtil.Guard(() -> kenLm = new KenLM(LANGUAGE_MODEL_PATH));
+
+    registerLanguageModel(kenLm);
+    String sentence = "Wayne Gretzky";
+    String[] words = sentence.split("\\s+");
+    int[] ids = Vocabulary.addAll(sentence);
+    long[] longIds = new long[ids.length];
+
+    for(int i = 0; i< words.length; i++) {
+      longIds[i] = ids[i];
+    }
+
+    // WHEN
+    long poolPointer = kenLm.createLMPool();
+    KenLM.StateProbPair result = kenLm.probRule(longIds, poolPointer);
+    kenLm.destroyLMPool(poolPointer);
+
+    // THEN
+    assertThat(result.state.getState(), is(0L));
+    assertThat(result.prob, is(-3.7906885f));
+  }
+
+  @Test
   public void givenKenLm_whenIsKnownWord_thenReturnValuesAreCorrect() {
     KenLmTestUtil.Guard(() -> kenLm = new KenLM(LANGUAGE_MODEL_PATH));
     assertTrue(kenLm.isKnownWord("Wayne"));