You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/09/14 09:11:37 UTC
[15/43] incubator-joshua git commit: Probably won't compile but gets
the idea across
Probably won't compile but gets the idea across
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/5e954752
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/5e954752
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/5e954752
Branch: refs/heads/7
Commit: 5e9547526ad4bc15f48e665608897def552cb9ab
Parents: c3e7a15
Author: Kenneth Heafield <gi...@kheafield.com>
Authored: Tue Sep 13 10:58:26 2016 +0200
Committer: Kenneth Heafield <gi...@kheafield.com>
Committed: Tue Sep 13 10:58:26 2016 +0200
----------------------------------------------------------------------
jni/kenlm_wrap.cc | 142 ++++++++++++++++++++++---------------------------
1 file changed, 63 insertions(+), 79 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/5e954752/jni/kenlm_wrap.cc
----------------------------------------------------------------------
diff --git a/jni/kenlm_wrap.cc b/jni/kenlm_wrap.cc
index 11d9c28..8f69e19 100644
--- a/jni/kenlm_wrap.cc
+++ b/jni/kenlm_wrap.cc
@@ -20,7 +20,6 @@
#include "lm/left.hh"
#include "lm/state.hh"
#include "util/murmur_hash.hh"
-#include "util/pool.hh"
#include <iostream>
@@ -30,7 +29,8 @@
#include <pthread.h>
// Grr. Everybody's compiler is slightly different and I'm trying to not depend on boost.
-#include <unordered_map>
+#include <unordered_set>
+#include <vector>
// Verify that jint and lm::ngram::WordIndex are the same size. If this breaks
// for you, there's a need to revise probString.
@@ -45,7 +45,35 @@ template<> struct StaticCheck<true> {
typedef StaticCheck<sizeof(jint) == sizeof(lm::WordIndex)>::StaticAssertionPassed FloatSize;
-typedef std::unordered_multimap<uint64_t, lm::ngram::ChartState*> PoolHash;
+// Could be uint64_t if you ever need to index more than 2^32 states.
+typedef uint32_t StateIndex;
+typedef std::vector<lm::ngram::ChartState> StateVector;
+
+class HashIndex : public std::unary_function<StateIndex, uint64_t> { // Hash functor used by Lookup: hashes a StateIndex by the ChartState stored at that position, not by the integer itself.
+ public:
+ explicit HashIndex(const StateVector &vec) : vec_(vec) {} // Binds to the caller's vector by reference; nothing is copied.
+
+ uint64_t operator()(StateIndex index) const {
+ return hash_value(vec_[index]); // Unchecked operator[]: index must be < vec_.size() at call time.
+ }
+
+ private:
+ const StateVector &vec_; // Reference member: the referenced vector must outlive this functor.
+};
+
+class EqualIndex : public std::binary_function<StateIndex, StateIndex, bool> { // Equality functor used by Lookup: two indices are equal when the ChartStates they refer to compare equal.
+ public:
+ explicit EqualIndex(const StateVector &vec) : vec_(vec) {} // Binds to the caller's vector by reference; nothing is copied.
+
+ bool operator()(StateIndex first, StateIndex second) const {
+ return vec_[first] == vec_[second]; // Unchecked operator[]: both indices must be < vec_.size() at call time.
+ }
+
+ private:
+ const StateVector &vec_; // Reference member: the referenced vector must outlive this functor.
+};
+
+typedef std::unordered_set<StateIndex, HashIndex, EqualIndex> Lookup;
/**
* A Chart bundles together a unordered_multimap that maps ChartState signatures to a single
@@ -54,46 +82,26 @@ typedef std::unordered_multimap<uint64_t, lm::ngram::ChartState*> PoolHash;
* across KenLMs for the same sentence. Multimap is used to avoid hash collisions which can
* return incorrect results, and cause out-of-bounds lookups when multiple KenLMs are in use.
*/
-struct Chart {
- // A cache for allocated chart objects
- PoolHash* poolHash;
- // Pool used to allocate new ones
- util::Pool* pool;
-
- Chart() {
- poolHash = new PoolHash();
- pool = new util::Pool();
- }
-
- ~Chart() {
- delete poolHash;
- pool->FreeAll();
- delete pool;
- }
-
- lm::ngram::ChartState* put(const lm::ngram::ChartState& state) {
- lm::ngram::ChartState* state_ptr = nullptr;
- uint64_t hashValue = lm::ngram::hash_value(state);
- auto state_it = poolHash->find(hashValue);
-
- // Try to retrieve a matching ChartState pointer from our Pool
- while(state_it != poolHash->end()) {
- if (state == *(state_it->second)) {
- state_ptr = state_it->second;
- break;
+class Chart {
+ public:
+ Chart() : lookup_(1000, HashIndex(vec_), EqualIndex(vec_)) {}
+
+ StateIndex Intern(const lm::ngram::ChartState &state) {
+ vec_.push_back(state);
+ std::pair<Lookup::iterator, bool> ins(lookup_.insert(vec_.size() - 1));
+ if (!ins.second) {
+ vec_.pop_back();
}
- state_it++;
+ return *ins.first;
}
- // Unable to find this ChartState in our pool, allocate new space for it
- if (!state_ptr) {
- state_ptr = (lm::ngram::ChartState *) pool->Allocate(sizeof(lm::ngram::ChartState));
- *state_ptr = state;
- (*poolHash).insert({hashValue, state_ptr});
+ const ChartState &InterpretState(StateIndex index) const {
+ return vec_[index];
}
- return state_ptr;
- }
+ private:
+ StateVector vec_;
+ Lookup lookup_;
};
// Vocab ids above what the vocabulary knows about are unknown and should
@@ -131,7 +139,7 @@ public:
virtual bool IsKnownWordIndex(const lm::WordIndex& id) const = 0;
- virtual float ProbRule(jlong *begin, jlong *end, lm::ngram::ChartState& state) const = 0;
+ virtual float ProbRule(jlong *begin, jlong *end, lm::ngram::ChartState& state, const Chart &chart) const = 0;
virtual float ProbString(jint * const begin, jint * const end,
jint start) const = 0;
@@ -142,22 +150,9 @@ public:
virtual bool RegisterWord(const StringPiece& word, const int joshua_id) = 0;
- void RememberReturnMethod(jclass chart_pair, jmethodID chart_pair_init) {
- chart_pair_ = chart_pair;
- chart_pair_init_ = chart_pair_init;
- }
-
- jclass ChartPair() const { return chart_pair_; }
- jmethodID ChartPairInit() const { return chart_pair_init_; }
-
protected:
VirtualBase() {
}
-
-private:
- // Hack: these are remembered so we can avoid looking them up every time.
- jclass chart_pair_;
- jmethodID chart_pair_init_;
};
template<class Model> class VirtualImpl: public VirtualBase {
@@ -201,12 +196,12 @@ public:
return id != m_.GetVocabulary().NotFound();
}
- float ProbRule(jlong * const begin, jlong * const end, lm::ngram::ChartState& state) const {
+ float ProbRule(jlong * const begin, jlong * const end, lm::ngram::ChartState& state, const Chart &chart) const {
if (begin == end) return 0.0;
lm::ngram::RuleScore<Model> ruleScore(m_, state);
if (*begin < 0) {
- ruleScore.BeginNonTerminal(*reinterpret_cast<const lm::ngram::ChartState*>(-*begin));
+ ruleScore.BeginNonTerminal(chart.Interpet(-*begin));
} else {
const lm::WordIndex word = map_[*begin];
if (word == m_.GetVocabulary().BeginSentence()) {
@@ -218,7 +213,7 @@ public:
for (jlong* i = begin + 1; i != end; i++) {
long word = *i;
if (word < 0)
- ruleScore.NonTerminal(*reinterpret_cast<const lm::ngram::ChartState*>(-word));
+ ruleScore.NonTerminal(chart.Interpret(-word));
else
ruleScore.Terminal(map_[word]);
}
@@ -341,18 +336,6 @@ JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_construct(
VirtualBase *ret;
try {
ret = ConstructModel(str);
-
- // Get a class reference for the type pair that char
- jclass local_chart_pair = env->FindClass("org/apache/joshua/decoder/ff/lm/KenLM$StateProbPair");
- UTIL_THROW_IF(!local_chart_pair, util::Exception, "Failed to find org/apache/joshua/decoder/ff/lm/KenLM$StateProbPair");
- jclass chart_pair = (jclass)env->NewGlobalRef(local_chart_pair);
- env->DeleteLocalRef(local_chart_pair);
-
- // Get the Method ID of the constructor which takes an int
- jmethodID chart_pair_init = env->GetMethodID(chart_pair, "<init>", "(JF)V");
- UTIL_THROW_IF(!chart_pair_init, util::Exception, "Failed to find init method");
-
- ret->RememberReturnMethod(chart_pair, chart_pair_init);
} catch (std::exception &e) {
std::cerr << e.what() << std::endl;
abort();
@@ -363,20 +346,17 @@ JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_construct(
JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroy(
JNIEnv *env, jclass, jlong pointer) {
- VirtualBase *base = reinterpret_cast<VirtualBase*>(pointer);
- env->DeleteGlobalRef(base->ChartPair());
- delete base;
+ delete reinterpret_cast<VirtualBase*>(pointer);
}
-JNIEXPORT long JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_createPool(
+JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_createPool(
JNIEnv *env, jclass) {
return reinterpret_cast<long>(new Chart());
}
JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroyPool(
JNIEnv *env, jclass, jlong pointer) {
- Chart* chart = reinterpret_cast<Chart*>(pointer);
- delete chart;
+ delete reinterpret_cast<Chart*>(pointer);
}
JNIEXPORT jint JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_order(
@@ -462,7 +442,12 @@ JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probString(
values + length, start);
}
-JNIEXPORT jobject JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
+union FloatConverter { // Overlays a float and a uint32_t so probRule (below) can store the score via .f and pack its raw bits via .i into the low half of the returned jlong.
+ float f; // Written by probRule with the ProbRule() result.
+ uint32_t i; // Read back by probRule when building the return value. NOTE(review): probRule declares its variable as `FloatConvert`, which does not match this union's name.
+};
+
+JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
JNIEnv *env, jclass, jlong pointer, jlong chartPtr, jlongArray arr) {
jint length = env->GetArrayLength(arr);
// GCC only.
@@ -472,13 +457,12 @@ JNIEXPORT jobject JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
// Compute the probability
lm::ngram::ChartState outState;
const VirtualBase *base = reinterpret_cast<const VirtualBase*>(pointer);
- float prob = base->ProbRule(values, values + length, outState);
-
Chart* chart = reinterpret_cast<Chart*>(chartPtr);
- lm::ngram::ChartState* outStatePtr = chart->put(outState);
+ FloatConvert prob;
+ prob.f = base->ProbRule(values, values + length, outState, *chart);
- // Call back constructor to allocate a new instance, with an int argument
- return env->NewObject(base->ChartPair(), base->ChartPairInit(), (long)outStatePtr, prob);
+ StateIndex index = chart->Intern(outState);
+ return static_cast<uint64_t>(index) << 32 | static_cast<uint64_t>(prob.i);
}
JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_estimateRule(