You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/09/14 09:11:37 UTC

[15/43] incubator-joshua git commit: Probably won't compile but gets the idea across

Probably won't compile but gets the idea across


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/5e954752
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/5e954752
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/5e954752

Branch: refs/heads/7
Commit: 5e9547526ad4bc15f48e665608897def552cb9ab
Parents: c3e7a15
Author: Kenneth Heafield <gi...@kheafield.com>
Authored: Tue Sep 13 10:58:26 2016 +0200
Committer: Kenneth Heafield <gi...@kheafield.com>
Committed: Tue Sep 13 10:58:26 2016 +0200

----------------------------------------------------------------------
 jni/kenlm_wrap.cc | 142 ++++++++++++++++++++++---------------------------
 1 file changed, 63 insertions(+), 79 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/5e954752/jni/kenlm_wrap.cc
----------------------------------------------------------------------
diff --git a/jni/kenlm_wrap.cc b/jni/kenlm_wrap.cc
index 11d9c28..8f69e19 100644
--- a/jni/kenlm_wrap.cc
+++ b/jni/kenlm_wrap.cc
@@ -20,7 +20,6 @@
 #include "lm/left.hh"
 #include "lm/state.hh"
 #include "util/murmur_hash.hh"
-#include "util/pool.hh"
 
 #include <iostream>
 
@@ -30,7 +29,8 @@
 #include <pthread.h>
 
 // Grr.  Everybody's compiler is slightly different and I'm trying to not depend on boost.
-#include <unordered_map>
+#include <unordered_set>
+#include <vector>
 
 // Verify that jint and lm::ngram::WordIndex are the same size. If this breaks
 // for you, there's a need to revise probString.
@@ -45,7 +45,40 @@ template<> struct StaticCheck<true> {
 
 typedef StaticCheck<sizeof(jint) == sizeof(lm::WordIndex)>::StaticAssertionPassed FloatSize;
 
-typedef std::unordered_multimap<uint64_t, lm::ngram::ChartState*> PoolHash;
+// Index of an interned ChartState within a Chart.  Keep this 32-bit: probRule
+// packs it into the upper half of the jlong it returns to Java.
+typedef uint32_t StateIndex;
+typedef std::vector<lm::ngram::ChartState> StateVector;
+
+// Hash functor for Lookup: hashes an index by the ChartState stored at that
+// position of the referenced StateVector.
+class HashIndex {
+  public:
+    explicit HashIndex(const StateVector &vec) : vec_(vec) {}
+
+    std::size_t operator()(StateIndex index) const {
+      return static_cast<std::size_t>(lm::ngram::hash_value(vec_[index]));
+    }
+
+  private:
+    const StateVector &vec_;
+};
+
+// Equality functor for Lookup: two indices match when their ChartStates are equal.
+class EqualIndex {
+  public:
+    explicit EqualIndex(const StateVector &vec) : vec_(vec) {}
+
+    bool operator()(StateIndex first, StateIndex second) const {
+      return vec_[first] == vec_[second];
+    }
+
+  private:
+    const StateVector &vec_;
+};
+
+// Set of indices into a StateVector, deduplicated by ChartState value.
+typedef std::unordered_set<StateIndex, HashIndex, EqualIndex> Lookup;
 
 /**
  * A Chart bundles together a unordered_multimap that maps ChartState signatures to a single
@@ -54,46 +82,38 @@ typedef std::unordered_multimap<uint64_t, lm::ngram::ChartState*> PoolHash;
  * across KenLMs for the same sentence.  Multimap is used to avoid hash collisions which can
  * return incorrect results, and cause out-of-bounds lookups when multiple KenLMs are in use.
  */
-struct Chart {
-  // A cache for allocated chart objects
-  PoolHash* poolHash;
-  // Pool used to allocate new ones
-  util::Pool* pool;
-
-  Chart() {
-    poolHash = new PoolHash();
-    pool = new util::Pool();
-  }
-
-  ~Chart() {
-    delete poolHash;
-    pool->FreeAll();
-    delete pool;
-  }
-
-  lm::ngram::ChartState* put(const lm::ngram::ChartState& state) {
-    lm::ngram::ChartState* state_ptr = nullptr;
-    uint64_t hashValue = lm::ngram::hash_value(state);
-    auto state_it = poolHash->find(hashValue);
-
-    // Try to retrieve a matching ChartState pointer from our Pool
-    while(state_it != poolHash->end()) {
-      if (state == *(state_it->second)) {
-        state_ptr = state_it->second;
-        break;
+class Chart {
+  public:
+    // Slot 0 of vec_ is a never-interned placeholder.  ProbRule treats only negative
+    // entries as nonterminals and resolves them with Interpret(-entry), so a real
+    // state stored at index 0 could never be referenced.
+    Chart() : vec_(1), lookup_(1000, HashIndex(vec_), EqualIndex(vec_)) {}
+
+    // lookup_'s functors hold references to vec_, so a copy would alias the
+    // source Chart's vector.
+    Chart(const Chart &) = delete;
+    Chart &operator=(const Chart &) = delete;
+
+    // Returns the index of a stored state equal to `state`, appending a copy of
+    // `state` when no equal state has been interned yet.
+    StateIndex Intern(const lm::ngram::ChartState &state) {
+      vec_.push_back(state);
+      std::pair<Lookup::iterator, bool> ins(lookup_.insert(vec_.size() - 1));
+      if (!ins.second) {
+        vec_.pop_back();
       }
-      state_it++;
+      return *ins.first;
     }
 
-    // Unable to find this ChartState in our pool, allocate new space for it
-    if (!state_ptr) {
-      state_ptr = (lm::ngram::ChartState *) pool->Allocate(sizeof(lm::ngram::ChartState));
-      *state_ptr = state;
-      (*poolHash).insert({hashValue, state_ptr});
+    // Returns the state previously interned under `index`.
+    const lm::ngram::ChartState &Interpret(StateIndex index) const {
+      return vec_[index];
     }
 
-    return state_ptr;
-  }
+  private:
+    // Interned states; lookup_ stores indices into this vector.
+    StateVector vec_;
+    Lookup lookup_;
 };
 
 // Vocab ids above what the vocabulary knows about are unknown and should
@@ -131,7 +139,7 @@ public:
 
   virtual bool IsKnownWordIndex(const lm::WordIndex& id) const = 0;
 
-  virtual float ProbRule(jlong *begin, jlong *end, lm::ngram::ChartState& state) const = 0;
+  virtual float ProbRule(jlong *begin, jlong *end, lm::ngram::ChartState& state, const Chart &chart) const = 0;
 
   virtual float ProbString(jint * const begin, jint * const end,
       jint start) const = 0;
@@ -142,22 +150,9 @@ public:
 
   virtual bool RegisterWord(const StringPiece& word, const int joshua_id) = 0;
 
-  void RememberReturnMethod(jclass chart_pair, jmethodID chart_pair_init) {
-    chart_pair_ = chart_pair;
-    chart_pair_init_ = chart_pair_init;
-  }
-
-  jclass ChartPair() const { return chart_pair_; }
-  jmethodID ChartPairInit() const { return chart_pair_init_; }
-
 protected:
   VirtualBase() {
   }
-
-private:
-  // Hack: these are remembered so we can avoid looking them up every time.
-  jclass chart_pair_;
-  jmethodID chart_pair_init_;
 };
 
 template<class Model> class VirtualImpl: public VirtualBase {
@@ -201,12 +196,14 @@ public:
       return id != m_.GetVocabulary().NotFound();
   }
 
-  float ProbRule(jlong * const begin, jlong * const end, lm::ngram::ChartState& state) const {
+  // Scores the rule [begin, end) into `state`.  Non-negative entries are looked up
+  // through map_; a negative entry is a nonterminal: the negated StateIndex of a state in `chart`.
+  float ProbRule(jlong * const begin, jlong * const end, lm::ngram::ChartState& state, const Chart &chart) const {
     if (begin == end) return 0.0;
     lm::ngram::RuleScore<Model> ruleScore(m_, state);
 
     if (*begin < 0) {
-      ruleScore.BeginNonTerminal(*reinterpret_cast<const lm::ngram::ChartState*>(-*begin));
+      ruleScore.BeginNonTerminal(chart.Interpret(-*begin));
     } else {
       const lm::WordIndex word = map_[*begin];
       if (word == m_.GetVocabulary().BeginSentence()) {
@@ -218,7 +213,7 @@ public:
     for (jlong* i = begin + 1; i != end; i++) {
       long word = *i;
       if (word < 0)
-        ruleScore.NonTerminal(*reinterpret_cast<const lm::ngram::ChartState*>(-word));
+        ruleScore.NonTerminal(chart.Interpret(-word));
       else
         ruleScore.Terminal(map_[word]);
     }
@@ -341,18 +336,6 @@ JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_construct(
   VirtualBase *ret;
   try {
     ret = ConstructModel(str);
-
-    // Get a class reference for the type pair that char
-    jclass local_chart_pair = env->FindClass("org/apache/joshua/decoder/ff/lm/KenLM$StateProbPair");
-    UTIL_THROW_IF(!local_chart_pair, util::Exception, "Failed to find org/apache/joshua/decoder/ff/lm/KenLM$StateProbPair");
-    jclass chart_pair = (jclass)env->NewGlobalRef(local_chart_pair);
-    env->DeleteLocalRef(local_chart_pair);
-
-    // Get the Method ID of the constructor which takes an int
-    jmethodID chart_pair_init = env->GetMethodID(chart_pair, "<init>", "(JF)V");
-    UTIL_THROW_IF(!chart_pair_init, util::Exception, "Failed to find init method");
-
-    ret->RememberReturnMethod(chart_pair, chart_pair_init);
   } catch (std::exception &e) {
     std::cerr << e.what() << std::endl;
     abort();
@@ -363,20 +346,22 @@ JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_construct(
 
+// Deletes the language-model wrapper that `pointer` refers to.
 JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroy(
     JNIEnv *env, jclass, jlong pointer) {
-  VirtualBase *base = reinterpret_cast<VirtualBase*>(pointer);
-  env->DeleteGlobalRef(base->ChartPair());
-  delete base;
+  delete reinterpret_cast<VirtualBase*>(pointer);
 }
 
-JNIEXPORT long JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_createPool(
+// Allocates an empty Chart and hands it to Java as an opaque jlong handle;
+// release it with destroyPool.
+JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_createPool(
     JNIEnv *env, jclass) {
-  return reinterpret_cast<long>(new Chart());
+  // jlong, not long: long is only 32 bits on LLP64 platforms such as 64-bit Windows.
+  return reinterpret_cast<jlong>(new Chart());
 }
 
+// Frees a Chart previously returned by createPool.
 JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroyPool(
     JNIEnv *env, jclass, jlong pointer) {
-  Chart* chart = reinterpret_cast<Chart*>(pointer);
-  delete chart;
+  delete reinterpret_cast<Chart*>(pointer);
 }
 
 JNIEXPORT jint JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_order(
@@ -462,7 +442,16 @@ JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probString(
       values + length, start);
 }
 
-JNIEXPORT jobject JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
+// Exposes a float's bit pattern as a uint32_t so probRule can pack the probability
+// into the low 32 bits of its jlong result.  Reading the inactive union member is
+// a GCC-documented extension (portable ISO C++ would use std::memcpy).
+union FloatConverter {
+  float f;
+  uint32_t i;
+};
+typedef StaticCheck<sizeof(float) == sizeof(uint32_t)>::StaticAssertionPassed FloatConverterSize;
+
+JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
   JNIEnv *env, jclass, jlong pointer, jlong chartPtr, jlongArray arr) {
   jint length = env->GetArrayLength(arr);
   // GCC only.
@@ -472,13 +457,14 @@ JNIEXPORT jobject JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
   // Compute the probability
   lm::ngram::ChartState outState;
   const VirtualBase *base = reinterpret_cast<const VirtualBase*>(pointer);
-  float prob = base->ProbRule(values, values + length, outState);
-
   Chart* chart = reinterpret_cast<Chart*>(chartPtr);
-  lm::ngram::ChartState* outStatePtr = chart->put(outState);
+  FloatConverter prob;
+  prob.f = base->ProbRule(values, values + length, outState, *chart);
 
-  // Call back constructor to allocate a new instance, with an int argument
-  return env->NewObject(base->ChartPair(), base->ChartPairInit(), (long)outStatePtr, prob);
+  // Interning may grow the chart's storage, so it happens only after ProbRule is done reading it.
+  StateIndex index = chart->Intern(outState);
+  // Upper 32 bits: interned state index.  Lower 32 bits: raw bits of the float probability.
+  return static_cast<jlong>((static_cast<uint64_t>(index) << 32) | static_cast<uint64_t>(prob.i));
 }
 
 JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_estimateRule(