Posted to commits@joshua.apache.org by mj...@apache.org on 2016/06/01 02:51:52 UTC
[57/94] [abbrv] incubator-joshua git commit: Merge branch 'master'
into JOSHUA-252 (compiling, but not tested)
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/abb8c518/src/main/java/org/apache/joshua/oracle/OracleExtractionHG.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/oracle/OracleExtractionHG.java
index 0660e8a,0000000..575515a
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/oracle/OracleExtractionHG.java
+++ b/src/main/java/org/apache/joshua/oracle/OracleExtractionHG.java
@@@ -1,796 -1,0 +1,797 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.oracle;
+
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiString;
+import static org.apache.joshua.util.FormatUtils.removeSentenceMarkers;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.Support;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.hypergraph.HyperEdge;
+import org.apache.joshua.decoder.hypergraph.HyperGraph;
+import org.apache.joshua.decoder.hypergraph.KBestExtractor;
+import org.apache.joshua.util.FileUtility;
+import org.apache.joshua.util.io.LineReader;
++import org.apache.joshua.util.FormatUtils;
+
+/**
+ * Approximated BLEU: (1) does not consider the clipping effect; (2) in the dynamic programming,
+ * does not maintain different states for different hypothesis lengths; (3) the brevity penalty is
+ * calculated based on the average reference length; (4) uses sentence-level BLEU instead of
+ * doc-level BLEU
+ *
+ * @author Zhifei Li, zhifei.work@gmail.com (Johns Hopkins University)
+ */
+public class OracleExtractionHG extends SplitHg {
+ static String BACKOFF_LEFT_LM_STATE_SYM = "<lzfbo>";
+ public int BACKOFF_LEFT_LM_STATE_SYM_ID;// used for equivalent state
+
+ static String NULL_LEFT_LM_STATE_SYM = "<lzflnull>";
+ public int NULL_LEFT_LM_STATE_SYM_ID;// used for equivalent state
+
+ static String NULL_RIGHT_LM_STATE_SYM = "<lzfrnull>";
+ public int NULL_RIGHT_LM_STATE_SYM_ID;// used for equivalent state
+
+ // int[] ref_sentence;//reference string (not tree)
+ protected int src_sent_len = 0;
+ protected int ref_sent_len = 0;
+ protected int g_lm_order = 4; // only used to decide whether this class should compute the LM
+ // state itself in compute_state
+ static protected boolean do_local_ngram_clip = false;
+ static protected boolean maintain_length_state = false;
+ static protected int g_bleu_order = 4;
+
+ static boolean using_left_equiv_state = true;
+ static boolean using_right_equiv_state = true;
+
+ // TODO Add generics to hash tables in this class
+ HashMap<String, Boolean> tbl_suffix = new HashMap<String, Boolean>();
+ HashMap<String, Boolean> tbl_prefix = new HashMap<String, Boolean>();
+ static PrefixGrammar grammar_prefix = new PrefixGrammar();// TODO
+ static PrefixGrammar grammar_suffix = new PrefixGrammar();// TODO
+
+ // maps each reference ngram (a space-separated string of word ids) to its count in the reference
+ protected HashMap<String, Integer> tbl_ref_ngrams = new HashMap<String, Integer>();
+
+ static boolean always_maintain_separate_lm_state = true; // if true: the virtual item maintains
+ // its own lm state regardless of
+ // whether lm_order >= g_bleu_order
+
+ int lm_feat_id = 0; // the baseline LM feature id
+
+ /**
+ * Constructs a new object capable of extracting a tree from a hypergraph that most closely
+ * matches a provided oracle sentence.
+ * <p>
+ * It seems that the symbol table here should only need to represent monolingual terminals, plus
+ * nonterminals.
+ *
+ * @param lm_feat_id_ a language model feature identifier
+ */
+ public OracleExtractionHG(int lm_feat_id_) {
+ this.lm_feat_id = lm_feat_id_;
+ this.BACKOFF_LEFT_LM_STATE_SYM_ID = Vocabulary.id(BACKOFF_LEFT_LM_STATE_SYM);
+ this.NULL_LEFT_LM_STATE_SYM_ID = Vocabulary.id(NULL_LEFT_LM_STATE_SYM);
+ this.NULL_RIGHT_LM_STATE_SYM_ID = Vocabulary.id(NULL_RIGHT_LM_STATE_SYM);
+ }
+
+ /*
+ * for 919 sent, time_on_reading: 148797 time_on_orc_extract: 580286
+ */
+ @SuppressWarnings({ "unused" })
+ public static void main(String[] args) throws IOException {
+ JoshuaConfiguration joshuaConfiguration = new JoshuaConfiguration();
+ /*
+ * String f_hypergraphs="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.items"; String
+ * f_rule_tbl="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.rules"; String
+ * f_ref_files="C:\\Users\\zli\\Documents\\mt03.ref.txt.1"; String f_orc_out
+ * ="C:\\Users\\zli\\Documents\\mt03.orc.txt";
+ */
+ if (6 != args.length) {
+ System.out
+ .println("Usage: java Decoder f_hypergraphs f_rule_tbl f_ref_files f_orc_out lm_order orc_extract_nbest");
+ System.out.println("num of args is " + args.length);
+ for (int i = 0; i < args.length; i++) {
+ System.out.println("arg is: " + args[i]);
+ }
+ System.exit(1);
+ }
+ // String f_hypergraphs = args[0].trim();
+ // String f_rule_tbl = args[1].trim();
+ String f_ref_files = args[2].trim();
+ String f_orc_out = args[3].trim();
+ int lm_order = Integer.parseInt(args[4].trim());
+ boolean orc_extract_nbest = Boolean.valueOf(args[5].trim()); // oracle extraction from nbest or hg
+
+ int baseline_lm_feat_id = 0;
+
+ KBestExtractor kbest_extractor = null;
+ int topN = 300;// TODO
+ joshuaConfiguration.use_unique_nbest = true;
+ joshuaConfiguration.include_align_index = false;
+ boolean do_ngram_clip_nbest = true; // TODO
+ if (orc_extract_nbest) {
+ System.out.println("oracle extraction from nbest list");
+
+ kbest_extractor = new KBestExtractor(null, null, Decoder.weights, false, joshuaConfiguration);
+ }
+
+ BufferedWriter orc_out = FileUtility.getWriteFileStream(f_orc_out);
+
+ long start_time0 = System.currentTimeMillis();
+ long time_on_reading = 0;
+ long time_on_orc_extract = 0;
+ // DiskHyperGraph dhg_read = new DiskHyperGraph(baseline_lm_feat_id, true, null);
+
+ // dhg_read.initRead(f_hypergraphs, f_rule_tbl, null);
+
+ OracleExtractionHG orc_extractor = new OracleExtractionHG(baseline_lm_feat_id);
+ long start_time = System.currentTimeMillis();
+ int sent_id = 0;
+ for (String ref_sent: new LineReader(f_ref_files)) {
+ System.out.println("############Process sentence " + sent_id);
+ start_time = System.currentTimeMillis();
+ sent_id++;
+ // if(sent_id>10)break;
+
+ // HyperGraph hg = dhg_read.readHyperGraph();
+ HyperGraph hg = null;
+ if (hg == null)
+ continue;
+
+ // System.out.println("read disk hyp: " + (System.currentTimeMillis()-start_time));
+ time_on_reading += System.currentTimeMillis() - start_time;
+ start_time = System.currentTimeMillis();
+
+ String orc_sent = null;
+ double orc_bleu = 0;
+ if (orc_extract_nbest) {
+ Object[] res = orc_extractor.oracle_extract_nbest(kbest_extractor, hg, topN,
+ do_ngram_clip_nbest, ref_sent);
+ orc_sent = (String) res[0];
+ orc_bleu = (Double) res[1];
+ } else {
+ HyperGraph hg_oracle = orc_extractor.oracle_extract_hg(hg, hg.sentLen(), lm_order, ref_sent);
+ orc_sent = removeSentenceMarkers(getViterbiString(hg_oracle));
+ orc_bleu = orc_extractor.get_best_goal_cost(hg, orc_extractor.g_tbl_split_virtual_items);
+
+ time_on_orc_extract += System.currentTimeMillis() - start_time;
+ System.out.println("num_virtual_items: " + orc_extractor.g_num_virtual_items
+ + " num_virtual_dts: " + orc_extractor.g_num_virtual_deductions);
+ // System.out.println("oracle extract: " + (System.currentTimeMillis()-start_time));
+ }
+
+ orc_out.write(orc_sent + "\n");
+ System.out.println("orc bleu is " + orc_bleu);
+ }
+ orc_out.close();
+
+ System.out.println("time_on_reading: " + time_on_reading);
+ System.out.println("time_on_orc_extract: " + time_on_orc_extract);
+ System.out.println("total running time: " + (System.currentTimeMillis() - start_time0));
+ }
+
+ // find the oracle hypothesis in the nbest list
+ public Object[] oracle_extract_nbest(KBestExtractor kbest_extractor, HyperGraph hg, int n,
+ boolean do_ngram_clip, String ref_sent) {
+ if (hg.goalNode == null)
+ return null;
+ kbest_extractor.resetState();
+ int next_n = 0;
+ double orc_bleu = -1;
+ String orc_sent = null;
+ while (true) {
+ String hyp_sent = kbest_extractor.getKthHyp(hg.goalNode, ++next_n);// ?????????
+ if (hyp_sent == null || next_n > n)
+ break;
+ double t_bleu = compute_sentence_bleu(ref_sent, hyp_sent, do_ngram_clip, 4);
+ if (t_bleu > orc_bleu) {
+ orc_bleu = t_bleu;
+ orc_sent = hyp_sent;
+ }
+ }
+ System.out.println("Oracle sent: " + orc_sent);
+ System.out.println("Oracle bleu: " + orc_bleu);
+ Object[] res = new Object[2];
+ res[0] = orc_sent;
+ res[1] = orc_bleu;
+ return res;
+ }
+
+ public HyperGraph oracle_extract_hg(HyperGraph hg, int src_sent_len_in, int lm_order,
+ String ref_sent_str) {
+ int[] ref_sent = Vocabulary.addAll(ref_sent_str);
+ g_lm_order = lm_order;
+ src_sent_len = src_sent_len_in;
+ ref_sent_len = ref_sent.length;
+
+ tbl_ref_ngrams.clear();
+ get_ngrams(tbl_ref_ngrams, g_bleu_order, ref_sent, false);
+ if (using_left_equiv_state || using_right_equiv_state) {
+ tbl_prefix.clear();
+ tbl_suffix.clear();
+ setup_prefix_suffix_tbl(ref_sent, g_bleu_order, tbl_prefix, tbl_suffix);
+ setup_prefix_suffix_grammar(ref_sent, g_bleu_order, grammar_prefix, grammar_suffix);// TODO
+ }
+ split_hg(hg);
+
+ // System.out.println("best bleu is " + get_best_goal_cost( hg, g_tbl_split_virtual_items));
+ return get_1best_tree_hg(hg, g_tbl_split_virtual_items);
+ }
+
+ /*
+ * This procedure (1) identifies all possible matches and (2) adds a new deduction for each match
+ */
+ protected void process_one_combination_axiom(HGNode parent_item,
+ HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt) {
+ if (null == cur_dt.getRule()) {
+ throw new RuntimeException("error null rule in axiom");
+ }
+ double avg_ref_len = (parent_item.j - parent_item.i >= src_sent_len) ? ref_sent_len
+ : (parent_item.j - parent_item.i) * ref_sent_len * 1.0 / src_sent_len;// avg len?
+ double bleu_score[] = new double[1];
+ DPStateOracle dps = compute_state(parent_item, cur_dt, null, tbl_ref_ngrams,
+ do_local_ngram_clip, g_lm_order, avg_ref_len, bleu_score, tbl_suffix, tbl_prefix);
+ VirtualDeduction t_dt = new VirtualDeduction(cur_dt, null, -bleu_score[0]);// cost: -best_bleu
+ g_num_virtual_deductions++;
+ add_deduction(parent_item, virtual_item_sigs, t_dt, dps, true);
+ }
+
+ /*
+ * This procedure (1) creates a new deduction (based on cur_dt and ant_virtual_item) and (2) finds
+ * whether an item can contain this deduction (based on virtual_item_sigs, which is a hashmap
+ * specific to a parent_item): (2.1) if yes, adds the deduction; (2.2) otherwise, (2.2.1) creates
+ * a new item (2.2.2) and adds the item into virtual_item_sigs
+ */
+ protected void process_one_combination_nonaxiom(HGNode parent_item,
+ HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt,
+ ArrayList<VirtualItem> l_ant_virtual_item) {
+ if (null == l_ant_virtual_item) {
+ throw new RuntimeException("wrong call in process_one_combination_nonaxiom");
+ }
+ double avg_ref_len = (parent_item.j - parent_item.i >= src_sent_len) ? ref_sent_len
+ : (parent_item.j - parent_item.i) * ref_sent_len * 1.0 / src_sent_len;// avg len?
+ double bleu_score[] = new double[1];
+ DPStateOracle dps = compute_state(parent_item, cur_dt, l_ant_virtual_item, tbl_ref_ngrams,
+ do_local_ngram_clip, g_lm_order, avg_ref_len, bleu_score, tbl_suffix, tbl_prefix);
+ VirtualDeduction t_dt = new VirtualDeduction(cur_dt, l_ant_virtual_item, -bleu_score[0]);// cost:
+ // -best_bleu
+ g_num_virtual_deductions++;
+ add_deduction(parent_item, virtual_item_sigs, t_dt, dps, true);
+ }
+
+ // DPState maintains all the state information at an item that is required during dynamic
+ // programming
+ protected static class DPStateOracle extends DPState {
+ int best_len; // this may not be used in the signature
+ int[] ngram_matches;
+ int[] left_lm_state;
+ int[] right_lm_state;
+
+ public DPStateOracle(int blen, int[] matches, int[] left, int[] right) {
+ best_len = blen;
+ ngram_matches = matches;
+ left_lm_state = left;
+ right_lm_state = right;
+ }
+
+ protected String get_signature() {
+ StringBuffer res = new StringBuffer();
+ if (maintain_length_state) {
+ res.append(best_len);
+ res.append(' ');
+ }
+ if (null != left_lm_state) { // the goal item has null state
+ for (int i = 0; i < left_lm_state.length; i++) {
+ res.append(left_lm_state[i]);
+ res.append(' ');
+ }
+ }
+ res.append("lzf ");
+
+ if (null != right_lm_state) { // the goal item has null state
+ for (int i = 0; i < right_lm_state.length; i++) {
+ res.append(right_lm_state[i]);
+ res.append(' ');
+ }
+ }
+ // if(left_lm_state==null || right_lm_state==null)System.out.println("sig is: " +
+ // res.toString());
+ return res.toString();
+ }
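+
+ // Illustrative example (not from the original source): with maintain_length_state off,
+ // left_lm_state = {12, 47} and right_lm_state = {47, 98}, get_signature() returns
+ // "12 47 lzf 47 98 "; the "lzf " marker separates the left and right state blocks.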
+
+ protected void print() {
+ StringBuffer res = new StringBuffer();
+ res.append("DPstate: best_len: ");
+ res.append(best_len);
+ for (int i = 0; i < ngram_matches.length; i++) {
+ res.append("; ngram: ");
+ res.append(ngram_matches[i]);
+ }
+ System.out.println(res.toString());
+ }
+ }
+
+ // ########################## common functions #####################
+ // based on tbl_oracle_states, tbl_ref_ngrams, and dt, get the state
+ // get the new state: STATE_BEST_DEDUCT STATE_BEST_BLEU STATE_BEST_LEN NGRAM_MATCH_COUNTS
+ protected DPStateOracle compute_state(HGNode parent_item, HyperEdge dt,
+ ArrayList<VirtualItem> l_ant_virtual_item, HashMap<String, Integer> tbl_ref_ngrams,
+ boolean do_local_ngram_clip, int lm_order, double ref_len, double[] bleu_score,
+ HashMap<String, Boolean> tbl_suffix, HashMap<String, Boolean> tbl_prefix) {
+ // ##### deductions under "goal item" does not have rule
+ if (null == dt.getRule()) {
+ if (l_ant_virtual_item.size() != 1) {
+ throw new RuntimeException("error deduction under goal item have more than one item");
+ }
+ bleu_score[0] = -l_ant_virtual_item.get(0).best_virtual_deduction.best_cost;
+ return new DPStateOracle(0, null, null, null); // no DPState at all
+ }
+
+ // ################## deductions *not* under "goal item"
+ HashMap<String, Integer> new_ngram_counts = new HashMap<String, Integer>();// new ngrams created
+ // due to the
+ // combination
+ HashMap<String, Integer> old_ngram_counts = new HashMap<String, Integer>();// the ngrams that
+ // have already been
+ // counted
+ int total_hyp_len = 0;
+ int[] num_ngram_match = new int[g_bleu_order];
+ int[] en_words = dt.getRule().getEnglish();
+
+ // #### calculate new and old ngram counts, and len
+
+ ArrayList<Integer> words = new ArrayList<Integer>();
+
+ // used to compute the left lm state
+ ArrayList<Integer> left_state_sequence = null;
+ // used to compute the right lm state
+ ArrayList<Integer> right_state_sequence = null;
+
+ int correct_lm_order = lm_order;
+ if (always_maintain_separate_lm_state || lm_order < g_bleu_order) {
+ left_state_sequence = new ArrayList<Integer>();
+ right_state_sequence = new ArrayList<Integer>();
+ correct_lm_order = g_bleu_order; // if lm_order is smaller than g_bleu_order, we will get the
+ // lm state ourselves
+ }
+
+ // #### get left_state_sequence, right_state_sequence, total_hyp_len, num_ngram_match
+ for (int c = 0; c < en_words.length; c++) {
+ int c_id = en_words[c];
- if (Vocabulary.nt(c_id)) {
++ if (FormatUtils.isNonterminal(c_id)) {
+ int index = -(c_id + 1);
+ DPStateOracle ant_state = (DPStateOracle) l_ant_virtual_item.get(index).dp_state;
+ total_hyp_len += ant_state.best_len;
+ for (int t = 0; t < g_bleu_order; t++) {
+ num_ngram_match[t] += ant_state.ngram_matches[t];
+ }
+ int[] l_context = ant_state.left_lm_state;
+ int[] r_context = ant_state.right_lm_state;
+ for (int t : l_context) { // always have l_context
+ words.add(t);
+ if (null != left_state_sequence && left_state_sequence.size() < g_bleu_order - 1) {
+ left_state_sequence.add(t);
+ }
+ }
+ get_ngrams(old_ngram_counts, g_bleu_order, l_context, true);
+ if (r_context.length >= correct_lm_order - 1) { // the right and left are NOT overlapping
+ get_ngrams(new_ngram_counts, g_bleu_order, words, true);
+ get_ngrams(old_ngram_counts, g_bleu_order, r_context, true);
+ words.clear();// start a new chunk
+ if (null != right_state_sequence) {
+ right_state_sequence.clear();
+ }
+ for (int t : r_context) {
+ words.add(t);
+ }
+ }
+ if (null != right_state_sequence) {
+ for (int t : r_context) {
+ right_state_sequence.add(t);
+ }
+ }
+ } else {
+ words.add(c_id);
+ total_hyp_len += 1;
+ if (null != left_state_sequence && left_state_sequence.size() < g_bleu_order - 1) {
+ left_state_sequence.add(c_id);
+ }
+ if (null != right_state_sequence) {
+ right_state_sequence.add(c_id);
+ }
+ }
+ }
+ get_ngrams(new_ngram_counts, g_bleu_order, words, true);
+
+ // ####now deduct ngram counts
+ for (String ngram : new_ngram_counts.keySet()) {
+ if (tbl_ref_ngrams.containsKey(ngram)) {
+ int final_count = (Integer) new_ngram_counts.get(ngram);
+ if (old_ngram_counts.containsKey(ngram)) {
+ final_count -= (Integer) old_ngram_counts.get(ngram);
+ if (final_count < 0) {
+ throw new RuntimeException("negative count for ngram: " + ngram
+ + "; new: " + new_ngram_counts.get(ngram) + "; old: " + old_ngram_counts.get(ngram));
+ }
+ }
+ if (final_count > 0) { // TODO: not correct/global ngram clip
+ if (do_local_ngram_clip) {
+ // BUG: use joshua.util.Regex.spaces.split(...)
+ num_ngram_match[ngram.split("\\s+").length - 1] += Support.findMin(final_count,
+ (Integer) tbl_ref_ngrams.get(ngram));
+ } else {
+ // BUG: use joshua.util.Regex.spaces.split(...)
+ num_ngram_match[ngram.split("\\s+").length - 1] += final_count; // do not do any clipping
+ }
+ }
+ }
+ }
+
+ // ####now calculate the BLEU score and state
+ int[] left_lm_state = null;
+ int[] right_lm_state = null;
+ left_lm_state = get_left_equiv_state(left_state_sequence, tbl_suffix);
+ right_lm_state = get_right_equiv_state(right_state_sequence, tbl_prefix);
+
+ // debug
+ // System.out.println("lm_order is " + lm_order);
+ // compare_two_int_arrays(left_lm_state,
+ // (int[])parent_item.tbl_states.get(Symbol.LM_L_STATE_SYM_ID));
+ // compare_two_int_arrays(right_lm_state,
+ // (int[])parent_item.tbl_states.get(Symbol.LM_R_STATE_SYM_ID));
+ // end
+
+ bleu_score[0] = compute_bleu(total_hyp_len, ref_len, num_ngram_match, g_bleu_order);
+ // System.out.println("blue score is " + bleu_score[0]);
+ return new DPStateOracle(total_hyp_len, num_ngram_match, left_lm_state, right_lm_state);
+ }
+
+ private int[] get_left_equiv_state(ArrayList<Integer> left_state_sequence,
+ HashMap<String, Boolean> tbl_suffix) {
+ int l_size = (left_state_sequence.size() < g_bleu_order - 1) ? left_state_sequence.size()
+ : (g_bleu_order - 1);
+ int[] left_lm_state = new int[l_size];
+ if (!using_left_equiv_state || l_size < g_bleu_order - 1) { // regular
+ for (int i = 0; i < l_size; i++) {
+ left_lm_state[i] = left_state_sequence.get(i);
+ }
+ } else {
+ for (int i = l_size - 1; i >= 0; i--) { // right to left
+ if (is_a_suffix_in_tbl(left_state_sequence, 0, i, tbl_suffix)) {
+ // if(is_a_suffix_in_grammar(left_state_sequence, 0, i, grammar_suffix)){
+ for (int j = i; j >= 0; j--) {
+ left_lm_state[j] = left_state_sequence.get(j);
+ }
+ break;
+ } else {
+ left_lm_state[i] = this.NULL_LEFT_LM_STATE_SYM_ID;
+ }
+ }
+ // System.out.println("origi left:" + Symbol.get_string(left_state_sequence) + "; equiv left:"
+ // + Symbol.get_string(left_lm_state));
+ }
+ return left_lm_state;
+ }
+
+ private boolean is_a_suffix_in_tbl(ArrayList<Integer> left_state_sequence, int start_pos,
+ int end_pos, HashMap<String, Boolean> tbl_suffix) {
+ if ((Integer) left_state_sequence.get(end_pos) == this.NULL_LEFT_LM_STATE_SYM_ID) {
+ return false;
+ }
+ StringBuffer suffix = new StringBuffer();
+ for (int i = end_pos; i >= start_pos; i--) { // right-most first
+ suffix.append(left_state_sequence.get(i));
+ if (i > start_pos)
+ suffix.append(' ');
+ }
+ return tbl_suffix.containsKey(suffix.toString());
+ }
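+
+ // For example (illustration only): for left_state_sequence = {3, 7}, start_pos = 0 and
+ // end_pos = 1, the lookup key is "7 3" -- right-most word first, matching the reverse
+ // order in which setup_prefix_suffix_tbl stores suffixes.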
+
+ private int[] get_right_equiv_state(ArrayList<Integer> right_state_sequence,
+ HashMap<String, Boolean> tbl_prefix) {
+ int r_size = (right_state_sequence.size() < g_bleu_order - 1) ? right_state_sequence.size()
+ : (g_bleu_order - 1);
+ int[] right_lm_state = new int[r_size];
+ if (!using_right_equiv_state || r_size < g_bleu_order - 1) { // regular
+ for (int i = 0; i < r_size; i++) {
+ right_lm_state[i] = (Integer) right_state_sequence.get(right_state_sequence.size() - r_size
+ + i);
+ }
+ } else {
+ for (int i = 0; i < r_size; i++) { // left to right
+ if (is_a_prefix_in_tbl(right_state_sequence, right_state_sequence.size() - r_size + i,
+ right_state_sequence.size() - 1, tbl_prefix)) {
+ // if(is_a_prefix_in_grammar(right_state_sequence, right_state_sequence.size()-r_size+i,
+ // right_state_sequence.size()-1, grammar_prefix)){
+ for (int j = i; j < r_size; j++) {
+ right_lm_state[j] = (Integer) right_state_sequence.get(right_state_sequence.size()
+ - r_size + j);
+ }
+ break;
+ } else {
+ right_lm_state[i] = this.NULL_RIGHT_LM_STATE_SYM_ID;
+ }
+ }
+ // System.out.println("origi right:" + Symbol.get_string(right_state_sequence)+
+ // "; equiv right:" + Symbol.get_string(right_lm_state));
+ }
+ return right_lm_state;
+ }
+
+ private boolean is_a_prefix_in_tbl(ArrayList<Integer> right_state_sequence, int start_pos,
+ int end_pos, HashMap<String, Boolean> tbl_prefix) {
+ if (right_state_sequence.get(start_pos) == this.NULL_RIGHT_LM_STATE_SYM_ID) {
+ return false;
+ }
+ StringBuffer prefix = new StringBuffer();
+ for (int i = start_pos; i <= end_pos; i++) {
+ prefix.append(right_state_sequence.get(i));
+ if (i < end_pos)
+ prefix.append(' ');
+ }
+ return tbl_prefix.containsKey(prefix.toString());
+ }
+
+ public static void compare_two_int_arrays(int[] a, int[] b) {
+ if (a.length != b.length) {
+ throw new RuntimeException("two arrays do not have same size");
+ }
+ for (int i = 0; i < a.length; i++) {
+ if (a[i] != b[i]) {
+ throw new RuntimeException("elements in two arrays are not same");
+ }
+ }
+ }
+
+ // sentence-bleu: BLEU= bp * prec; where prec = exp (sum 1/4 * log(prec[order]))
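+ // Worked example (illustration, not from the original source): hyp_len = 4, ref_len = 5,
+ // num_ngram_match = {3, 2, 1, 0}. prec = exp(0.25 * (log(3/4) + log(2/3) + log(1/2)
+ // + log(0.5/1))) ~= 0.595 (the unmatched 4-gram order is smoothed by the 0.5 factor), and
+ // bp = exp(1 - 5/4) ~= 0.779, so BLEU ~= 0.463.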
+ public static double compute_bleu(int hyp_len, double ref_len, int[] num_ngram_match,
+ int bleu_order) {
+ if (hyp_len <= 0 || ref_len <= 0) {
+ throw new RuntimeException("ref or hyp is zero len");
+ }
+ double res = 0;
+ double wt = 1.0 / bleu_order;
+ double prec = 0;
+ double smooth_factor = 1.0;
+ for (int t = 0; t < bleu_order && t < hyp_len; t++) {
+ if (num_ngram_match[t] > 0) {
+ prec += wt * Math.log(num_ngram_match[t] * 1.0 / (hyp_len - t));
+ } else {
+ smooth_factor *= 0.5;// TODO
+ prec += wt * Math.log(smooth_factor / (hyp_len - t));
+ }
+ }
+ double bp = (hyp_len >= ref_len) ? 1.0 : Math.exp(1 - ref_len / hyp_len);
+ res = bp * Math.exp(prec);
+ // System.out.println("hyp_len: " + hyp_len + "; ref_len:" + ref_len + "prec: " + Math.exp(prec)
+ // + "; bp: " + bp + "; bleu: " + res);
+ return res;
+ }
+
+ // accumulate ngram counts into tbl
+ public void get_ngrams(HashMap<String, Integer> tbl, int order, int[] wrds,
+ boolean ignore_null_equiv_symbol) {
+ for (int i = 0; i < wrds.length; i++) {
+ for (int j = 0; j < order && j + i < wrds.length; j++) { // ngram: [i,i+j]
+ boolean contain_null = false;
+ StringBuffer ngram = new StringBuffer();
+ for (int k = i; k <= i + j; k++) {
+ if (wrds[k] == this.NULL_LEFT_LM_STATE_SYM_ID
+ || wrds[k] == this.NULL_RIGHT_LM_STATE_SYM_ID) {
+ contain_null = true;
+ if (ignore_null_equiv_symbol)
+ break;
+ }
+ ngram.append(wrds[k]);
+ if (k < i + j)
+ ngram.append(' ');
+ }
+ if (ignore_null_equiv_symbol && contain_null)
+ continue; // skip this ngram
+ String ngram_str = ngram.toString();
+ if (tbl.containsKey(ngram_str)) {
+ tbl.put(ngram_str, (Integer) tbl.get(ngram_str) + 1);
+ } else {
+ tbl.put(ngram_str, 1);
+ }
+ }
+ }
+ }
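+
+ // For example (illustration only): wrds = {5, 9, 5} with order = 2 yields
+ // tbl = {"5" -> 2, "9" -> 1, "5 9" -> 1, "9 5" -> 1}.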
+
+ /**
+ * Accumulates ngram counts into tbl.
+ * @param tbl a {@link java.util.HashMap} in which ngram counts are accumulated
+ * @param order the maximum ngram order to accumulate
+ * @param wrds an {@link java.util.ArrayList} containing {@link java.lang.Integer} word representations
+ * @param ignore_null_equiv_symbol set to true to skip ngrams containing NULL-equivalent symbols
+ */
+ public void get_ngrams(HashMap<String, Integer> tbl, int order, ArrayList<Integer> wrds,
+ boolean ignore_null_equiv_symbol) {
+ for (int i = 0; i < wrds.size(); i++) {
+ // ngram: [i,i+j]
+ for (int j = 0; j < order && j + i < wrds.size(); j++) {
+ boolean contain_null = false;
+ StringBuffer ngram = new StringBuffer();
+ for (int k = i; k <= i + j; k++) {
+ int t_wrd = (Integer) wrds.get(k);
+ if (t_wrd == this.NULL_LEFT_LM_STATE_SYM_ID || t_wrd == this.NULL_RIGHT_LM_STATE_SYM_ID) {
+ contain_null = true;
+ if (ignore_null_equiv_symbol)
+ break;
+ }
+ ngram.append(t_wrd);
+ if (k < i + j)
+ ngram.append(' ');
+ }
+ // skip this ngram
+ if (ignore_null_equiv_symbol && contain_null)
+ continue;
+
+ String ngram_str = ngram.toString();
+ if (tbl.containsKey(ngram_str)) {
+ tbl.put(ngram_str, (Integer) tbl.get(ngram_str) + 1);
+ } else {
+ tbl.put(ngram_str, 1);
+ }
+ }
+ }
+ }
+
+ // do_ngram_clip: consider global n-gram clip
+ public double compute_sentence_bleu(String ref_sent, String hyp_sent, boolean do_ngram_clip,
+ int bleu_order) {
+ // BUG: use joshua.util.Regex.spaces.split(...)
+ int[] numeric_ref_sent = Vocabulary.addAll(ref_sent);
+ int[] numeric_hyp_sent = Vocabulary.addAll(hyp_sent);
+ return compute_sentence_bleu(numeric_ref_sent, numeric_hyp_sent, do_ngram_clip, bleu_order);
+ }
+
+ public double compute_sentence_bleu(int[] ref_sent, int[] hyp_sent, boolean do_ngram_clip,
+ int bleu_order) {
+ double res_bleu = 0;
+ int order = 4;
+ HashMap<String, Integer> ref_ngram_tbl = new HashMap<String, Integer>();
+ get_ngrams(ref_ngram_tbl, order, ref_sent, false);
+ HashMap<String, Integer> hyp_ngram_tbl = new HashMap<String, Integer>();
+ get_ngrams(hyp_ngram_tbl, order, hyp_sent, false);
+
+ int[] num_ngram_match = new int[order];
+ for (String ngram : hyp_ngram_tbl.keySet()) {
+ if (ref_ngram_tbl.containsKey(ngram)) {
+ if (do_ngram_clip) {
+ // BUG: use joshua.util.Regex.spaces.split(...)
+ num_ngram_match[ngram.split("\\s+").length - 1] += Support.findMin(
+ (Integer) ref_ngram_tbl.get(ngram), (Integer) hyp_ngram_tbl.get(ngram)); // ngram clip
+ } else {
+ // BUG: use joshua.util.Regex.spaces.split(...)
+ num_ngram_match[ngram.split("\\s+").length - 1] += (Integer) hyp_ngram_tbl.get(ngram);// without
+ // ngram
+ // count
+ // clipping
+ }
+ }
+ }
+ res_bleu = compute_bleu(hyp_sent.length, ref_sent.length, num_ngram_match, bleu_order);
+ // System.out.println("hyp_len: " + hyp_sent.length + "; ref_len:" + ref_sent.length +
+ // "; bleu: " + res_bleu +" num_ngram_matches: " + num_ngram_match[0] + " " +num_ngram_match[1]+
+ // " " + num_ngram_match[2] + " " +num_ngram_match[3]);
+
+ return res_bleu;
+ }
+
+ // #### equivalent lm stuff ############
+ public static void setup_prefix_suffix_tbl(int[] wrds, int order,
+ HashMap<String, Boolean> prefix_tbl, HashMap<String, Boolean> suffix_tbl) {
+ for (int i = 0; i < wrds.length; i++) {
+ for (int j = 0; j < order && j + i < wrds.length; j++) { // ngram: [i,i+j]
+ StringBuffer ngram = new StringBuffer();
+ // ### prefix
+ for (int k = i; k < i + j; k++) { // all ngrams [i,i+j-1]
+ ngram.append(wrds[k]);
+ prefix_tbl.put(ngram.toString(), true);
+ ngram.append(' ');
+ }
+ // ### suffix: right-most wrd first
+ ngram = new StringBuffer();
+ for (int k = i + j; k > i; k--) { // all ngrams [i+1,i+j]: reverse order
+ ngram.append(wrds[k]);
+ suffix_tbl.put(ngram.toString(), true);// stored in reverse order
+ ngram.append(' ');
+ }
+ }
+ }
+ }
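+
+ // Illustration (not from the original source): for wrds = {3, 7, 9} and order = 3,
+ // prefix_tbl gains the keys {"3", "3 7", "7"} (proper prefixes of reference ngrams) and
+ // suffix_tbl gains {"7", "9 7", "9"} (proper suffixes, stored right-most word first).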
+
+ // #### equivalent lm stuff ############
+ public static void setup_prefix_suffix_grammar(int[] wrds, int order, PrefixGrammar prefix_gr,
+ PrefixGrammar suffix_gr) {
+ for (int i = 0; i < wrds.length; i++) {
+ for (int j = 0; j < order && j + i < wrds.length; j++) { // ngram: [i,i+j]
+ // ### prefix
+ prefix_gr.add_ngram(wrds, i, i + j - 1);// ngram: [i,i+j-1]
+
+ // ### suffix: right-most wrd first
+ int[] reverse_wrds = new int[j];
+ for (int k = i + j, t = 0; k > i; k--) { // all ngrams [i+1,i+j]: reverse order
+ reverse_wrds[t++] = wrds[k];
+ }
+ suffix_gr.add_ngram(reverse_wrds, 0, j - 1);
+ }
+ }
+ }
+
+ /*
+ * a backoff node is a hashtable; it may include: (1) probabilities for next words (2) pointers
+ * to a next-layer backoff node (hashtable) (3) the backoff weight for this node (4) a
+ * suffix/prefix flag to indicate that there are ngrams starting from this suffix
+ */
+ private static class PrefixGrammar {
+
+ private static class PrefixGrammarNode extends HashMap<Integer, PrefixGrammarNode> {
+ private static final long serialVersionUID = 1L;
+ };
+
+ PrefixGrammarNode root = new PrefixGrammarNode();
+
+ // add prefix information
+ public void add_ngram(int[] wrds, int start_pos, int end_pos) {
+ // ######### identify the position, and insert the trie nodes if necessary
+ PrefixGrammarNode pos = root;
+ for (int k = start_pos; k <= end_pos; k++) {
+ int cur_sym_id = wrds[k];
+ PrefixGrammarNode next_layer = pos.get(cur_sym_id);
+
+ if (null != next_layer) {
+ pos = next_layer;
+ } else {
+ // next layer node
+ PrefixGrammarNode tmp = new PrefixGrammarNode();
+ pos.put(cur_sym_id, tmp);
+ pos = tmp;
+ }
+ }
+ }
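+
+ // For example (illustration only): add_ngram(new int[]{5, 9, 2}, 0, 2) builds the trie
+ // path root -> 5 -> 9 -> 2, so contain_ngram then accepts {5}, {5, 9} and {5, 9, 2}.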
+
+ @SuppressWarnings("unused")
+ public boolean contain_ngram(ArrayList<Integer> wrds, int start_pos, int end_pos) {
+ if (end_pos < start_pos)
+ return false;
+ PrefixGrammarNode pos = root;
+ for (int k = start_pos; k <= end_pos; k++) {
+ int cur_sym_id = wrds.get(k);
+ PrefixGrammarNode next_layer = pos.get(cur_sym_id);
+ if (next_layer != null) {
+ pos = next_layer;
+ } else {
+ return false;
+ }
+ }
+ return true;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/abb8c518/src/main/java/org/apache/joshua/tools/GrammarPacker.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/tools/GrammarPacker.java
index 8bdeb3b,0000000..3b38c29
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/tools/GrammarPacker.java
+++ b/src/main/java/org/apache/joshua/tools/GrammarPacker.java
@@@ -1,982 -1,0 +1,958 @@@
- /*
++/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.tools;
+
+import static org.apache.joshua.decoder.ff.tm.packed.PackedGrammar.VOCABULARY_FILENAME;
+
+import java.io.BufferedOutputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.nio.ByteBuffer;
- import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+import java.util.TreeMap;
+
+import org.apache.joshua.corpus.Vocabulary;
++import org.apache.joshua.decoder.ff.tm.Rule;
++import org.apache.joshua.decoder.ff.tm.format.HieroFormatReader;
++import org.apache.joshua.decoder.ff.tm.format.MosesFormatReader;
+import org.apache.joshua.util.FormatUtils;
+import org.apache.joshua.util.encoding.EncoderConfiguration;
+import org.apache.joshua.util.encoding.FeatureTypeAnalyzer;
+import org.apache.joshua.util.encoding.IntEncoder;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
++
+public class GrammarPacker {
+
+ private static final Logger LOG = LoggerFactory.getLogger(GrammarPacker.class);
+
++ /**
++ * The packed grammar version number. Increment this any time you add new features, and update
++ * the documentation.
++ *
++ * Version history:
++ *
++ * - 3 (May 2016). This was the first version that was marked. It removed the special phrase-
++ * table packing that packed phrases without the [X,1] on the source and target sides, which
++ * then required special handling in the decoder to use them for phrase-based decoding.
++ *
++ */
++ public static final int VERSION = 3;
++
+ // Size limit for slice in bytes.
+ private static int DATA_SIZE_LIMIT = (int) (Integer.MAX_VALUE * 0.8);
+ // Estimated average number of feature entries for one rule.
+ private static int DATA_SIZE_ESTIMATE = 20;
+
+ private static final String SOURCE_WORDS_SEPARATOR = " ||| ";
+
+ // Output directory name.
+ private String output;
+
+ // Input grammar to be packed.
+ private String grammar;
+
+ public String getGrammar() {
+ return grammar;
+ }
-
++
+ public String getOutputDirectory() {
+ return output;
+ }
+
+ // Approximate maximum size of a slice in number of rules
+ private int approximateMaximumSliceSize;
+
+ private boolean labeled;
+
+ private boolean packAlignments;
+ private boolean grammarAlignments;
+ private String alignments;
+
+ private FeatureTypeAnalyzer types;
+ private EncoderConfiguration encoderConfig;
+
+ private String dump;
+
+ private int max_source_len;
+
+ public GrammarPacker(String grammar_filename, String config_filename, String output_filename,
+ String alignments_filename, String featuredump_filename, boolean grammar_alignments,
+ int approximateMaximumSliceSize)
+ throws IOException {
+ this.labeled = true;
+ this.grammar = grammar_filename;
+ this.output = output_filename;
+ this.dump = featuredump_filename;
+ this.grammarAlignments = grammar_alignments;
+ this.approximateMaximumSliceSize = approximateMaximumSliceSize;
+ this.max_source_len = 0;
+
+ // TODO: Always open encoder config? This is debatable.
+ this.types = new FeatureTypeAnalyzer(true);
+
+ this.alignments = alignments_filename;
+ packAlignments = grammarAlignments || (alignments != null);
+ if (!packAlignments) {
+ LOG.info("No alignments file or grammar specified, skipping.");
+ } else if (alignments != null && !new File(alignments_filename).exists()) {
+ throw new RuntimeException("Alignments file does not exist: " + alignments);
+ }
+
+ if (config_filename != null) {
+ readConfig(config_filename);
+ types.readConfig(config_filename);
+ } else {
+ LOG.info("No config specified. Attempting auto-detection of feature types.");
+ }
+ LOG.info("Approximate maximum slice size (in # of rules) set to {}", approximateMaximumSliceSize);
+
+ File working_dir = new File(output);
+ working_dir.mkdir();
+ if (!working_dir.exists()) {
+ throw new RuntimeException("Failed creating output directory.");
+ }
+ }
+
+ private void readConfig(String config_filename) throws IOException {
+ LineReader reader = new LineReader(config_filename);
+ while (reader.hasNext()) {
+ // Clean up line, chop comments off and skip if the result is empty.
+ String line = reader.next().trim();
+ if (line.indexOf('#') != -1)
+ line = line.substring(0, line.indexOf('#'));
+ if (line.isEmpty())
+ continue;
+ String[] fields = line.split("[\\s]+");
+
+ if (fields.length < 2) {
+ throw new RuntimeException("Incomplete line in config.");
+ }
+ if ("slice_size".equals(fields[0])) {
+ // Number of records to concurrently load into memory for sorting.
+ approximateMaximumSliceSize = Integer.parseInt(fields[1]);
+ }
+ }
+ reader.close();
+ }
+
+ /**
+ * Executes the packing.
+ *
+ * @throws IOException if there is an error reading the grammar
+ */
+ public void pack() throws IOException {
+ LOG.info("Beginning exploration pass.");
- LineReader grammar_reader = null;
- LineReader alignment_reader = null;
+
+ // Explore pass. Learn vocabulary and feature value histograms.
+ LOG.info("Exploring: {}", grammar);
- grammar_reader = new LineReader(grammar);
- explore(grammar_reader);
++
++ HieroFormatReader grammarReader = getGrammarReader();
++ explore(grammarReader);
+
+ LOG.info("Exploration pass complete. Freezing vocabulary and finalizing encoders.");
+ if (dump != null) {
+ PrintWriter dump_writer = new PrintWriter(dump);
+ dump_writer.println(types.toString());
+ dump_writer.close();
+ }
+
+ types.inferTypes(this.labeled);
+ LOG.info("Type inference complete.");
+
+ LOG.info("Finalizing encoding.");
+
+ LOG.info("Writing encoding.");
+ types.write(output + File.separator + "encoding");
+
+ writeVocabulary();
+
+ String configFile = output + File.separator + "config";
+ LOG.info("Writing config to '{}'", configFile);
+ // Write config options
+ FileWriter config = new FileWriter(configFile);
- config.write(String.format("max-source-len = {}\n", max_source_len));
++ config.write(String.format("version = %d\n", VERSION));
++ config.write(String.format("max-source-len = %d\n", max_source_len));
+ config.close();
-
++
+ // Read previously written encoder configuration to match up to changed
+ // vocabulary id's.
+ LOG.info("Reading encoding.");
+ encoderConfig = new EncoderConfiguration();
+ encoderConfig.load(output + File.separator + "encoding");
+
+ LOG.info("Beginning packing pass.");
+ // Actual binarization pass. Slice and pack source, target and data.
- grammar_reader = new LineReader(grammar);
-
++ grammarReader = getGrammarReader();
++ LineReader alignment_reader = null;
+ if (packAlignments && !grammarAlignments)
+ alignment_reader = new LineReader(alignments);
- binarize(grammar_reader, alignment_reader);
++ binarize(grammarReader, alignment_reader);
+ LOG.info("Packing complete.");
+
+ LOG.info("Packed grammar in: {}", output);
+ LOG.info("Done.");
+ }
+
- private void explore(LineReader grammar) {
- int counter = 0;
++ /**
++ * Returns a reader that turns whatever file format is found into Hiero grammar rules.
++ *
++ * @return a {@link HieroFormatReader} over the grammar (a {@link MosesFormatReader} for
++ * Moses-format input)
++ * @throws IOException if the grammar file cannot be opened
++ */
++ private HieroFormatReader getGrammarReader() throws IOException {
++ LineReader reader = new LineReader(grammar);
++ String line = reader.next();
++ reader.close();
++ if (line.startsWith("[")) {
++ return new HieroFormatReader(grammar);
++ } else {
++ return new MosesFormatReader(grammar);
++ }
++ }
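++
++ // Format sniffing, illustrated (example lines only, not from the original source):
++ // a Hiero rule line starts with its LHS nonterminal, e.g.
++ // [X] ||| [X,1] casa ||| [X,1] house ||| 0.316 0.223
++ // whereas a Moses phrase-table line has no LHS, e.g.
++ // casa ||| house ||| 0.316 0.223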
++
++ /**
++ * This first pass over the grammar learns the vocabulary and feature value histograms.
++ *
++ * @param reader the reader over the grammar rules
++ */
++ private void explore(HieroFormatReader reader) {
++
+ // We always assume a labeled grammar. Unlabeled features are assumed to be dense and to always
+ // appear in the same order. They are assigned numeric names in order of appearance.
+ this.types.setLabeled(true);
+
- while (grammar.hasNext()) {
- String line = grammar.next().trim();
- counter++;
- ArrayList<String> fields = new ArrayList<String>(Arrays.asList(line.split("\\s\\|{3}\\s")));
-
- String lhs = null;
- if (line.startsWith("[")) {
- // hierarchical model
- if (fields.size() < 4) {
- LOG.warn("Incomplete grammar line at line {}: '{}'", counter, line);
- continue;
- }
- lhs = fields.remove(0);
- } else {
- // phrase-based model
- if (fields.size() < 3) {
- LOG.warn("Incomplete phrase line at line {}", counter);
- LOG.warn(line);
- continue;
- }
- lhs = "[X]";
- }
++ for (Rule rule: reader) {
+
- String[] source = fields.get(0).split("\\s");
- String[] target = fields.get(1).split("\\s");
- String[] features = fields.get(2).split("\\s");
-
- max_source_len = Math.max(max_source_len, source.length);
-
- Vocabulary.id(lhs);
- try {
- /* Add symbols to vocabulary.
- * NOTE: In case of nonterminals, we add both stripped versions ("[X]")
- * and "[X,1]" to the vocabulary.
- */
- for (String source_word : source) {
- Vocabulary.id(source_word);
- if (FormatUtils.isNonterminal(source_word)) {
- Vocabulary.id(FormatUtils.stripNonTerminalIndex(source_word));
- }
- }
- for (String target_word : target) {
- Vocabulary.id(target_word);
- if (FormatUtils.isNonterminal(target_word)) {
- Vocabulary.id(FormatUtils.stripNonTerminalIndex(target_word));
- }
- }
- } catch (java.lang.StringIndexOutOfBoundsException e) {
- LOG.warn("* Skipping bad grammar line '{}'", line);
- continue;
- }
++ max_source_len = Math.max(max_source_len, rule.getFrench().length);
++
++ /* Add symbols to vocabulary.
++ * NOTE: In case of nonterminals, we add both stripped versions ("[X]")
++ * and "[X,1]" to the vocabulary.
++ *
++ * TODO: MJP May 2016: Is it necessary to add [X,1]?
++ */
+
+ // Add feature names to vocabulary and pass the value through the
+ // appropriate encoder.
+ int feature_counter = 0;
++ String[] features = rule.getFeatureString().split("\\s+");
+ for (int f = 0; f < features.length; ++f) {
+ if (features[f].contains("=")) {
+ String[] fe = features[f].split("=");
+ if (fe[0].equals("Alignment"))
+ continue;
+ types.observe(Vocabulary.id(fe[0]), Float.parseFloat(fe[1]));
+ } else {
+ types.observe(Vocabulary.id(String.valueOf(feature_counter++)),
+ Float.parseFloat(features[f]));
+ }
+ }
+ }
+ }
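+
+ // Feature-string handling, illustrated (example values only): a labeled entry such as
+ // "p(e|f)=0.5" is observed under its own name, "Alignment=..." entries are skipped, and
+ // an unlabeled dense string like "0.5 2.3" is observed under the generated names "0" and "1".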
+
+ /**
+ * Returns a String encoding the first two source words.
+ * If there is only one source word, use empty string for the second.
+ */
+ private String getFirstTwoSourceWords(final String[] source_words) {
+ return source_words[0] + SOURCE_WORDS_SEPARATOR + ((source_words.length > 1) ? source_words[1] : "");
+ }
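+
+ // E.g. (illustration only): {"el", "gato", "negro"} -> "el ||| gato", and a single-word
+ // source {"uno"} -> "uno ||| ".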
+
- private void binarize(LineReader grammar_reader, LineReader alignment_reader) throws IOException {
++ private void binarize(HieroFormatReader grammarReader, LineReader alignment_reader) throws IOException {
+ int counter = 0;
+ int slice_counter = 0;
+ int num_slices = 0;
+
+ boolean ready_to_flush = false;
+ // to determine when flushing is possible
+ String prev_first_two_source_words = null;
+
+ PackingTrie<SourceValue> source_trie = new PackingTrie<SourceValue>();
+ PackingTrie<TargetValue> target_trie = new PackingTrie<TargetValue>();
+ FeatureBuffer feature_buffer = new FeatureBuffer();
+
+ AlignmentBuffer alignment_buffer = null;
+ if (packAlignments)
+ alignment_buffer = new AlignmentBuffer();
+
+ TreeMap<Integer, Float> features = new TreeMap<Integer, Float>();
- while (grammar_reader.hasNext()) {
- String grammar_line = grammar_reader.next().trim();
++ for (Rule rule: grammarReader) {
+ counter++;
+ slice_counter++;
+
- ArrayList<String> fields = new ArrayList<String>(Arrays.asList(grammar_line.split("\\s\\|{3}\\s")));
- String lhs_word;
- String[] source_words;
- String[] target_words;
- String[] feature_entries;
- if (grammar_line.startsWith("[")) {
- if (fields.size() < 4)
- continue;
-
- lhs_word = fields.remove(0);
- source_words = fields.get(0).split("\\s");
- target_words = fields.get(1).split("\\s");
- feature_entries = fields.get(2).split("\\s");
-
- } else {
- if (fields.size() < 3)
- continue;
-
- lhs_word = "[X]";
- String tmp = "[X,1] " + fields.get(0);
- source_words = tmp.split("\\s");
- tmp = "[X,1] " + fields.get(1);
- target_words = tmp.split("\\s");
- feature_entries = fields.get(2).split("\\s");
- }
++ String lhs_word = Vocabulary.word(rule.getLHS());
++ String[] source_words = rule.getFrenchWords().split("\\s+");
++ String[] target_words = rule.getEnglishWords().split("\\s+");
++ String[] feature_entries = rule.getFeatureString().split("\\s+");
+
+ // Reached slice limit size, indicate that we're closing up.
+ if (!ready_to_flush
+ && (slice_counter > approximateMaximumSliceSize
+ || feature_buffer.overflowing()
+ || (packAlignments && alignment_buffer.overflowing()))) {
+ ready_to_flush = true;
+ // store the first two source words when slice size limit was reached
+ prev_first_two_source_words = getFirstTwoSourceWords(source_words);
+ }
+ // ready to flush
+ if (ready_to_flush) {
+ final String first_two_source_words = getFirstTwoSourceWords(source_words);
+ // the grammar can only be partitioned at the level of first two source word changes.
+ // Thus, we can only flush if the current first two source words differ from the ones
+ // when the slice size limit was reached.
+ if (!first_two_source_words.equals(prev_first_two_source_words)) {
+ LOG.warn("ready to flush and first two words have changed ({} vs. {})",
+ prev_first_two_source_words, first_two_source_words);
+ LOG.info("flushing {} rules to slice.", slice_counter);
+ flush(source_trie, target_trie, feature_buffer, alignment_buffer, num_slices);
+ source_trie.clear();
+ target_trie.clear();
+ feature_buffer.clear();
+ if (packAlignments)
+ alignment_buffer.clear();
+
+ num_slices++;
+ slice_counter = 0;
+ ready_to_flush = false;
+ }
+ }
+
+ int alignment_index = -1;
+ // If present, process alignments.
+ if (packAlignments) {
+ String alignment_line;
+ if (grammarAlignments) {
- alignment_line = fields.get(3);
++ alignment_line = rule.getAlignmentString();
+ } else {
+ if (!alignment_reader.hasNext()) {
+ LOG.error("No more alignments starting in line {}", counter);
+ throw new RuntimeException("No more alignments starting in line " + counter);
+ }
+ alignment_line = alignment_reader.next().trim();
+ }
+ String[] alignment_entries = alignment_line.split("\\s");
+ byte[] alignments = new byte[alignment_entries.length * 2];
+ if (alignment_entries.length != 0) {
+ for (int i = 0; i < alignment_entries.length; i++) {
+ String[] parts = alignment_entries[i].split("-");
+ alignments[2 * i] = Byte.parseByte(parts[0]);
+ alignments[2 * i + 1] = Byte.parseByte(parts[1]);
+ }
+ }
+ alignment_index = alignment_buffer.add(alignments);
+ }
+
+ // Process features.
+ // Implicitly sort via TreeMap, write to data buffer, remember position
+ // to pass on to the source trie node.
+ features.clear();
+ int feature_count = 0;
+ for (int f = 0; f < feature_entries.length; ++f) {
+ String feature_entry = feature_entries[f];
+ int feature_id;
- float feature_value;
++ float feature_value;
+ if (feature_entry.contains("=")) {
+ String[] parts = feature_entry.split("=");
+ if (parts[0].equals("Alignment"))
+ continue;
+ feature_id = Vocabulary.id(parts[0]);
+ feature_value = Float.parseFloat(parts[1]);
+ } else {
+ feature_id = Vocabulary.id(String.valueOf(feature_count++));
+ feature_value = Float.parseFloat(feature_entry);
+ }
+ if (feature_value != 0)
+ features.put(encoderConfig.innerId(feature_id), feature_value);
+ }
+ int features_index = feature_buffer.add(features);
+
+ // Sanity check on the data block index.
+ if (packAlignments && features_index != alignment_index) {
+ LOG.error("Block index mismatch between features ({}) and alignments ({}).",
+ features_index, alignment_index);
+ throw new RuntimeException("Data block index mismatch.");
+ }
+
+ // Process source side.
+ SourceValue sv = new SourceValue(Vocabulary.id(lhs_word), features_index);
+ int[] source = new int[source_words.length];
+ for (int i = 0; i < source_words.length; i++) {
+ if (FormatUtils.isNonterminal(source_words[i]))
+ source[i] = Vocabulary.id(FormatUtils.stripNonTerminalIndex(source_words[i]));
+ else
+ source[i] = Vocabulary.id(source_words[i]);
+ }
+ source_trie.add(source, sv);
+
+ // Process target side.
+ TargetValue tv = new TargetValue(sv);
+ int[] target = new int[target_words.length];
+ for (int i = 0; i < target_words.length; i++) {
+ if (FormatUtils.isNonterminal(target_words[i])) {
+ target[target_words.length - (i + 1)] = -FormatUtils.getNonterminalIndex(target_words[i]);
+ } else {
+ target[target_words.length - (i + 1)] = Vocabulary.id(target_words[i]);
+ }
+ }
+ target_trie.add(target, tv);
+ }
+ // flush last slice and clear buffers
+ flush(source_trie, target_trie, feature_buffer, alignment_buffer, num_slices);
+ }
+
+ /**
+ * Serializes the source, target and feature data structures into interlinked binary files. Target
+ * is written first, into a skeletal (nodes don't carry any data) upward-pointing trie, updating
+ * the linking source trie nodes with the position once it is known. Source and feature data are
+ * written simultaneously. The source structure is written into a downward-pointing trie and
+ * stores the rule's lhs as well as links to the target and feature stream. The feature stream is
+ * prompted to write out a block for each source value as it is serialized.
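+ *
+ * On-disk layout sketch (derived from the code, for orientation): each target trie node is
+ * written as [parent_address, symbol]; each source trie node as [num_children,
+ * (symbol, child_address)*, num_values, (lhs, target_address, feature_block_index)*].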
+ *
+ * @param source_trie
+ * @param target_trie
+ * @param feature_buffer
+ * @param id
+ * @throws IOException
+ */
+ private void flush(PackingTrie<SourceValue> source_trie,
+ PackingTrie<TargetValue> target_trie, FeatureBuffer feature_buffer,
+ AlignmentBuffer alignment_buffer, int id) throws IOException {
+ // Make a slice object for this piece of the grammar.
+ PackingFileTuple slice = new PackingFileTuple("slice_" + String.format("%05d", id));
+ // Pull out the streams for source, target and data output.
+ DataOutputStream source_stream = slice.getSourceOutput();
+ DataOutputStream target_stream = slice.getTargetOutput();
+ DataOutputStream target_lookup_stream = slice.getTargetLookupOutput();
+ DataOutputStream feature_stream = slice.getFeatureOutput();
+ DataOutputStream alignment_stream = slice.getAlignmentOutput();
+
+ Queue<PackingTrie<TargetValue>> target_queue;
+ Queue<PackingTrie<SourceValue>> source_queue;
+
+ // The number of bytes both written into the source stream and
+ // buffered in the source queue.
+ int source_position;
+ // The number of bytes written into the target stream.
+ int target_position;
+
+ // Add trie root into queue, set target position to 0 and set cumulated
+ // size to size of trie root.
+ target_queue = new LinkedList<PackingTrie<TargetValue>>();
+ target_queue.add(target_trie);
+ target_position = 0;
+
+ // Target lookup table for trie levels.
+ int current_level_size = 1;
+ int next_level_size = 0;
+ ArrayList<Integer> target_lookup = new ArrayList<Integer>();
+
+ // Packing loop for upwards-pointing target trie.
+ while (!target_queue.isEmpty()) {
+ // Pop top of queue.
+ PackingTrie<TargetValue> node = target_queue.poll();
+ // Register that this is where we're writing the node to.
+ node.address = target_position;
+ // Tell source nodes that we're writing to this position in the file.
+ for (TargetValue tv : node.values)
+ tv.parent.target = node.address;
+ // Write link to parent.
+ if (node.parent != null)
+ target_stream.writeInt(node.parent.address);
+ else
+ target_stream.writeInt(-1);
+ target_stream.writeInt(node.symbol);
+ // Enqueue children.
+ for (int k : node.children.descendingKeySet()) {
+ PackingTrie<TargetValue> child = node.children.get(k);
+ target_queue.add(child);
+ }
+ target_position += node.size(false, true);
+ next_level_size += node.children.descendingKeySet().size();
+
+ current_level_size--;
+ if (current_level_size == 0) {
+ target_lookup.add(target_position);
+ current_level_size = next_level_size;
+ next_level_size = 0;
+ }
+ }
+ target_lookup_stream.writeInt(target_lookup.size());
+ for (int i : target_lookup)
+ target_lookup_stream.writeInt(i);
+ target_lookup_stream.close();
+
+ // Setting up for source and data writing.
+ source_queue = new LinkedList<PackingTrie<SourceValue>>();
+ source_queue.add(source_trie);
+ source_position = source_trie.size(true, false);
+ source_trie.address = target_position;
+
+ // Ready data buffers for writing.
+ feature_buffer.initialize();
+ if (packAlignments)
+ alignment_buffer.initialize();
+
+ // Packing loop for downwards-pointing source trie.
+ while (!source_queue.isEmpty()) {
+ // Pop top of queue.
+ PackingTrie<SourceValue> node = source_queue.poll();
+ // Write number of children.
+ source_stream.writeInt(node.children.size());
+ // Write links to children.
+ for (int k : node.children.descendingKeySet()) {
+ PackingTrie<SourceValue> child = node.children.get(k);
+ // Enqueue child.
+ source_queue.add(child);
+ // Child's address will be at the current end of the queue.
+ child.address = source_position;
+ // Advance cumulated size by child's size.
+ source_position += child.size(true, false);
+ // Write the link.
+ source_stream.writeInt(k);
+ source_stream.writeInt(child.address);
+ }
+ // Write number of data items.
+ source_stream.writeInt(node.values.size());
+ // Write lhs and links to target and data.
+ for (SourceValue sv : node.values) {
+ int feature_block_index = feature_buffer.write(sv.data);
+ if (packAlignments) {
+ int alignment_block_index = alignment_buffer.write(sv.data);
+ if (alignment_block_index != feature_block_index) {
+ LOG.error("Block index mismatch.");
+ throw new RuntimeException("Block index mismatch: alignment (" + alignment_block_index
+ + ") and features (" + feature_block_index + ") don't match.");
+ }
+ }
+ source_stream.writeInt(sv.lhs);
+ source_stream.writeInt(sv.target);
+ source_stream.writeInt(feature_block_index);
+ }
+ }
+ // Flush the data stream.
+ feature_buffer.flush(feature_stream);
+ if (packAlignments)
+ alignment_buffer.flush(alignment_stream);
+
+ target_stream.close();
+ source_stream.close();
+ feature_stream.close();
+ if (packAlignments)
+ alignment_stream.close();
+ }
+
+ public void writeVocabulary() throws IOException {
+ final String vocabularyFilename = output + File.separator + VOCABULARY_FILENAME;
+ LOG.info("Writing vocabulary to {}", vocabularyFilename);
+ Vocabulary.write(vocabularyFilename);
+ }
+
+ /**
+ * Integer-labeled, doubly-linked trie with some provisions for packing.
+ *
+ * @author Juri Ganitkevitch
+ *
+ * @param <D> The trie's value type.
+ */
+ class PackingTrie<D extends PackingTrieValue> {
+ int symbol;
+ PackingTrie<D> parent;
+
+ TreeMap<Integer, PackingTrie<D>> children;
+ List<D> values;
+
+ int address;
+
+ PackingTrie() {
+ address = -1;
+
+ symbol = 0;
+ parent = null;
+
+ children = new TreeMap<Integer, PackingTrie<D>>();
+ values = new ArrayList<D>();
+ }
+
+ PackingTrie(PackingTrie<D> parent, int symbol) {
+ this();
+ this.parent = parent;
+ this.symbol = symbol;
+ }
+
+ void add(int[] path, D value) {
+ add(path, 0, value);
+ }
+
+ private void add(int[] path, int index, D value) {
+ if (index == path.length)
+ this.values.add(value);
+ else {
+ PackingTrie<D> child = children.get(path[index]);
+ if (child == null) {
+ child = new PackingTrie<D>(this, path[index]);
+ children.put(path[index], child);
+ }
+ child.add(path, index + 1, value);
+ }
+ }
+
+ /**
+ * Calculate the size (in ints) of a packed trie node. Distinguishes downwards pointing (parent
+ * points to children) from upwards pointing (children point to parent) tries, as well as
+ * skeletal (no data, just the labeled links) and non-skeletal (nodes have a data block)
+ * packing.
+ *
+ * @param downwards Are we packing into a downwards-pointing trie?
+ * @param skeletal Are we packing into a skeletal trie?
+ *
+ * @return Number of ints the trie node would occupy.
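+ *
+ * Example (illustration only): a downwards, non-skeletal node with two children and one
+ * SourceValue (size() == 3) occupies 1 + 2*2 + 1 + 1*3 = 9 ints.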
+ */
+ int size(boolean downwards, boolean skeletal) {
+ int size = 0;
+ if (downwards) {
+ // Number of children and links to children.
+ size = 1 + 2 * children.size();
+ } else {
+ // Link to parent.
+ size += 2;
+ }
+ // Non-skeletal packing: number of data items.
+ if (!skeletal)
+ size += 1;
+ // Non-skeletal packing: write size taken up by data items.
+ if (!skeletal && !values.isEmpty())
+ size += values.size() * values.get(0).size();
+
+ return size;
+ }
+
+ void clear() {
+ children.clear();
+ values.clear();
+ }
+ }
+
+ interface PackingTrieValue {
+ int size();
+ }
+
+ class SourceValue implements PackingTrieValue {
+ int lhs;
+ int data;
+ int target;
+
+ public SourceValue() {
+ }
+
+ SourceValue(int lhs, int data) {
+ this.lhs = lhs;
+ this.data = data;
+ }
+
+ void setTarget(int target) {
+ this.target = target;
+ }
+
+ public int size() {
+ return 3;
+ }
+ }
+
+ class TargetValue implements PackingTrieValue {
+ SourceValue parent;
+
+ TargetValue(SourceValue parent) {
+ this.parent = parent;
+ }
+
+ public int size() {
+ return 0;
+ }
+ }
+
+ abstract class PackingBuffer<T> {
+ private byte[] backing;
+ protected ByteBuffer buffer;
+
+ protected ArrayList<Integer> memoryLookup;
+ protected int totalSize;
+ protected ArrayList<Integer> onDiskOrder;
+
+ PackingBuffer() throws IOException {
+ allocate();
+ memoryLookup = new ArrayList<Integer>();
+ onDiskOrder = new ArrayList<Integer>();
+ totalSize = 0;
+ }
+
+ abstract int add(T item);
+
+ // Allocate a reasonably-sized buffer for the packed data.
+ private void allocate() {
+ backing = new byte[approximateMaximumSliceSize * DATA_SIZE_ESTIMATE];
+ buffer = ByteBuffer.wrap(backing);
+ }
+
+ // Reallocates the backing array and buffer, copying the existing data over.
+ protected void reallocate() {
+ if (backing.length == Integer.MAX_VALUE)
+ return;
+ long attempted_length = backing.length * 2L;
+ int new_length;
+ // Detect overflow.
+ if (attempted_length >= Integer.MAX_VALUE)
+ new_length = Integer.MAX_VALUE;
+ else
+ new_length = (int) attempted_length;
+ byte[] new_backing = new byte[new_length];
+ System.arraycopy(backing, 0, new_backing, 0, backing.length);
+ int old_position = buffer.position();
+ ByteBuffer new_buffer = ByteBuffer.wrap(new_backing);
+ new_buffer.position(old_position);
+ buffer = new_buffer;
+ backing = new_backing;
+ }
+
+ /**
+ * Prepare the data buffer for disk writing.
+ */
+ void initialize() {
+ onDiskOrder.clear();
+ }
+
+ /**
+ * Enqueue a data block for later writing.
+ *
+ * @param block_index The index of the data block to add to the writing queue.
+ * @return The to-be-written block's output index.
+ */
+ int write(int block_index) {
+ onDiskOrder.add(block_index);
+ return onDiskOrder.size() - 1;
+ }
+
+ /**
+ * Performs the actual writing to disk in the order specified by calls to write() since the last
+ * call to initialize().
+ *
+ * @param out the DataOutputStream to write the queued blocks to
+ * @throws IOException
+ */
+ void flush(DataOutputStream out) throws IOException {
+ writeHeader(out);
+ int size;
+ int block_address;
+ for (int block_index : onDiskOrder) {
+ block_address = memoryLookup.get(block_index);
+ size = blockSize(block_index);
+ out.write(backing, block_address, size);
+ }
+ }
+
+ void clear() {
+ buffer.clear();
+ memoryLookup.clear();
+ onDiskOrder.clear();
+ }
+
+ boolean overflowing() {
+ return (buffer.position() >= DATA_SIZE_LIMIT);
+ }
+
+ private void writeHeader(DataOutputStream out) throws IOException {
+ if (out.size() == 0) {
+ out.writeInt(onDiskOrder.size());
+ out.writeInt(totalSize);
+ int disk_position = headerSize();
+ for (int block_index : onDiskOrder) {
+ out.writeInt(disk_position);
+ disk_position += blockSize(block_index);
+ }
+ } else {
+ throw new RuntimeException("Got a used stream for header writing.");
+ }
+ }
+
+ private int headerSize() {
+ // One integer for each data block, plus number of blocks and total size.
+ return 4 * (onDiskOrder.size() + 2);
+ }
+
+ private int blockSize(int block_index) {
+ int block_address = memoryLookup.get(block_index);
+ return (block_index < memoryLookup.size() - 1 ? memoryLookup.get(block_index + 1) : totalSize)
+ - block_address;
+ }
+ }
+
+ class FeatureBuffer extends PackingBuffer<TreeMap<Integer, Float>> {
+
+ private IntEncoder idEncoder;
+
+ FeatureBuffer() throws IOException {
+ super();
+ idEncoder = types.getIdEncoder();
+ LOG.info("Encoding feature ids in: {}", idEncoder.getKey());
+ }
+
+ /**
+ * Add a block of features to the buffer.
+ *
+ * @param features TreeMap with the features for one rule.
+ * @return The index of the resulting data block.
+ */
+ int add(TreeMap<Integer, Float> features) {
+ int data_position = buffer.position();
+
+ // Over-estimate how much room this addition will need: for each
+ // feature (ID_SIZE for label, "upper bound" of 4 for the value), plus ID_SIZE for
+ // the number of features. If this won't fit, reallocate the buffer.
+ int size_estimate = (4 + EncoderConfiguration.ID_SIZE) * features.size()
+ + EncoderConfiguration.ID_SIZE;
+ if (buffer.capacity() - buffer.position() <= size_estimate)
+ reallocate();
+
+ // Write features to buffer.
+ idEncoder.write(buffer, features.size());
+ for (Integer k : features.descendingKeySet()) {
+ float v = features.get(k);
+ // Sparse features.
+ if (v != 0.0) {
+ idEncoder.write(buffer, k);
+ encoderConfig.encoder(k).write(buffer, v);
+ }
+ }
+ // Store position the block was written to.
+ memoryLookup.add(data_position);
+ // Update total size (in bytes).
+ totalSize = buffer.position();
+
+ // Return block index.
+ return memoryLookup.size() - 1;
+ }
+ }
+
+ class AlignmentBuffer extends PackingBuffer<byte[]> {
+
+ AlignmentBuffer() throws IOException {
+ super();
+ }
+
+ /**
+ * Add a rule's alignment points to the buffer.
+ *
+ * @param alignments a byte array with the alignment points for one rule.
+ * @return The index of the resulting data block.
+ */
+ int add(byte[] alignments) {
+ int data_position = buffer.position();
+ int size_estimate = alignments.length + 1;
+ if (buffer.capacity() - buffer.position() <= size_estimate)
+ reallocate();
+
+ // Write alignment points to buffer.
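+ // A count byte precedes the points; alignments stores (source, target) byte pairs,
+ // hence the point count is alignments.length / 2.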
+ buffer.put((byte) (alignments.length / 2));
+ buffer.put(alignments);
+
+ // Store position the block was written to.
+ memoryLookup.add(data_position);
+ // Update total size (in bytes).
+ totalSize = buffer.position();
+ // Return block index.
+ return memoryLookup.size() - 1;
+ }
+ }
+
+ class PackingFileTuple implements Comparable<PackingFileTuple> {
+ private File sourceFile;
+ private File targetLookupFile;
+ private File targetFile;
+
+ private File featureFile;
+ private File alignmentFile;
+
+ PackingFileTuple(String prefix) {
+ sourceFile = new File(output + File.separator + prefix + ".source");
+ targetFile = new File(output + File.separator + prefix + ".target");
+ targetLookupFile = new File(output + File.separator + prefix + ".target.lookup");
+ featureFile = new File(output + File.separator + prefix + ".features");
+
+ alignmentFile = null;
+ if (packAlignments)
+ alignmentFile = new File(output + File.separator + prefix + ".alignments");
+
+ LOG.info("Allocated slice: {}", sourceFile.getAbsolutePath());
+ }
+
+ DataOutputStream getSourceOutput() throws IOException {
+ return getOutput(sourceFile);
+ }
+
+ DataOutputStream getTargetOutput() throws IOException {
+ return getOutput(targetFile);
+ }
+
+ DataOutputStream getTargetLookupOutput() throws IOException {
+ return getOutput(targetLookupFile);
+ }
+
+ DataOutputStream getFeatureOutput() throws IOException {
+ return getOutput(featureFile);
+ }
+
+ DataOutputStream getAlignmentOutput() throws IOException {
+ if (alignmentFile != null)
+ return getOutput(alignmentFile);
+ return null;
+ }
+
+ private DataOutputStream getOutput(File file) throws IOException {
+ if (file.createNewFile()) {
+ return new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)));
+ } else {
+ throw new RuntimeException("File doesn't exist: " + file.getName());
+ }
+ }
+
+ long getSize() {
+ return sourceFile.length() + targetFile.length() + featureFile.length();
+ }
+
+ @Override
+ public int compareTo(PackingFileTuple o) {
+ if (getSize() > o.getSize()) {
+ return -1;
+ } else if (getSize() < o.getSize()) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+ }
+}
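
A note on the slice layout: PackingBuffer.flush() above writes a header -- the block
count, the total data size, and one absolute file offset per block, in onDiskOrder --
followed by the raw data blocks themselves. Below is a minimal sketch of a standalone
reader for that header; the class and method names are illustrative, not part of this
commit.

    import java.io.DataInputStream;
    import java.io.FileInputStream;
    import java.io.IOException;

    // Illustrative reader for the header laid out by PackingBuffer.writeHeader():
    // [numBlocks][totalSize][offset_0 .. offset_{n-1}], then the packed blocks.
    public class SliceHeaderReader {

      public static int[] readBlockOffsets(String filename) throws IOException {
        try (DataInputStream in = new DataInputStream(new FileInputStream(filename))) {
          int numBlocks = in.readInt(); // onDiskOrder.size()
          in.readInt();                 // total size of the packed data, unused here
          int[] offsets = new int[numBlocks];
          for (int i = 0; i < numBlocks; i++)
            offsets[i] = in.readInt();  // absolute file position of block i
          return offsets;
        }
      }
    }
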
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/abb8c518/src/main/java/org/apache/joshua/util/FormatUtils.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/util/FormatUtils.java
index eb59480,0000000..6ab58eb
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/util/FormatUtils.java
+++ b/src/main/java/org/apache/joshua/util/FormatUtils.java
@@@ -1,232 -1,0 +1,242 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util;
+
+import java.io.PrintStream;
+import java.io.UnsupportedEncodingException;
+import java.util.regex.Pattern;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Utility class for format issues.
+ *
+ * @author Juri Ganitkevitch
+ * @author Lane Schwartz
+ */
+public class FormatUtils {
+
+ private static final Logger LOG = LoggerFactory.getLogger(FormatUtils.class);
+
+ private static final String INDEX_SEPARATOR = ",";
+
+ /**
+ * Determines whether the string is a nonterminal by checking that it is at least three
+ * characters long, that its first character is [, and that its last character is ].
+ *
+ * @param token input string
+ * @return true if it's a nonterminal symbol, false otherwise
+ */
+ public static boolean isNonterminal(String token) {
+ return (token.length() >= 3 && token.charAt(0) == '[') && (token.charAt(token.length() - 1) == ']');
+ }
++
++ /**
++ * Determines whether the ID represents a nonterminal. This is a trivial check, since nonterminal
++ * IDs are simply negative ones.
++ */
++ public static boolean isNonterminal(int id) {
++ return id < 0;
++ }
+
+ /**
+ * Nonterminals are stored in the vocabulary in square brackets. This removes them when you
+ * just want the raw nonterminal word.
+ * Supports indexed and non-indexed nonTerminals:
+ * [GOAL] -> GOAL
+ * [X,1] -> X
+ *
+ * @param nt the nonterminal, e.g., "[GOAL]"
+ * @return the cleaned nonterminal, e.g., "GOAL"
+ */
+ public static String cleanNonTerminal(String nt) {
+ if (isNonterminal(nt)) {
+ if (isIndexedNonTerminal(nt)) {
+ // strip ",.*]"
+ return nt.substring(1, nt.indexOf(INDEX_SEPARATOR));
+ }
+ // strip "]"
+ return nt.substring(1, nt.length() - 1);
+ }
+ return nt;
+ }
+
+ private static boolean isIndexedNonTerminal(String nt) {
+ return nt.contains(INDEX_SEPARATOR);
+ }
+
+ /**
+ * Removes the index from a nonTerminal: [X,1] -> [X].
+ * @param nt an input non-terminal string
+ * @return the stripped nonterminal string, e.g., "[X]"
+ */
+ public static String stripNonTerminalIndex(String nt) {
- return markup(cleanNonTerminal(nt));
++ return ensureNonTerminalBrackets(cleanNonTerminal(nt));
+ }
+
++ /**
++ * Nonterminals on source and target sides are represented as [X,1], where 1 is an integer
++ * that links the two sides. This function extracts the index, e.g.,
++ *
++ * getNonterminalIndex("[X,7]") -> 7
++ *
++ * @param nt an indexed nonterminal, e.g., "[X,7]"
++ * @return the index linking the source and target sides, e.g., 7
++ */
+ public static int getNonterminalIndex(String nt) {
+ return Integer.parseInt(nt.substring(nt.indexOf(INDEX_SEPARATOR) + 1, nt.length() - 1));
+ }
+
+ /**
+ * Ensures that a string looks like what the system considers a nonterminal to be.
+ *
+ * @param nt the nonterminal string
+ * @return the nonterminal string surrounded in square brackets (if not already)
+ */
- public static String markup(String nt) {
++ public static String ensureNonTerminalBrackets(String nt) {
+ if (isNonterminal(nt))
+ return nt;
+ else
+ return "[" + nt + "]";
+ }
-
- public static String markup(String nt, int index) {
- if (isNonterminal(nt)) {
- return markup(cleanNonTerminal(nt), index);
- }
- return "[" + nt + INDEX_SEPARATOR + index + "]";
- }
+
+ public static String escapeSpecialSymbols(String s) {
+ return s.replaceAll("\\[", "-lsb-")
+ .replaceAll("\\]", "-rsb-")
+ .replaceAll("\\|", "-pipe-");
+ }
+
+ public static String unescapeSpecialSymbols(String s) {
+ return s.replaceAll("-lsb-", "[")
+ .replaceAll("-rsb-", "]")
+ .replaceAll("-pipe-", "|");
+ }
+
+ /**
+ * Wraps a sentence with the sentence start/stop markers
+ * defined by Vocabulary, each separated by a single space.
+ * @param s an input sentence
+ * @return the wrapped sentence
+ */
+ public static String addSentenceMarkers(String s) {
+ return Vocabulary.START_SYM + " " + s + " " + Vocabulary.STOP_SYM;
+ }
+
+ /**
+ * Strips the sentence start/stop markers (and adjoining whitespace) from a string.
+ * @param s the sentence to strip of markers and adjoining whitespace
+ * @return the stripped string
+ */
+ public static String removeSentenceMarkers(String s) {
+ return s.replaceAll("<s> ", "").replace(" </s>", "");
+ }
+
+ /**
+ * Returns true if the String parameter represents a valid number.
+ * <p>
+ * The body of this method is taken from the Javadoc documentation for the Java Double class.
+ *
+ * @param string an input string
+ * @see java.lang.Double
+ * @return <code>true</code> if the string represents a valid number, <code>false</code> otherwise
+ */
+ public static boolean isNumber(String string) {
+ final String Digits = "(\\p{Digit}+)";
+ final String HexDigits = "(\\p{XDigit}+)";
+ // an exponent is 'e' or 'E' followed by an optionally
+ // signed decimal integer.
+ final String Exp = "[eE][+-]?" + Digits;
+ final String fpRegex = ("[\\x00-\\x20]*" + // Optional leading "whitespace"
+ "[+-]?(" + // Optional sign character
+ "NaN|" + // "NaN" string
+ "Infinity|" + // "Infinity" string
+
+ // A decimal floating-point string representing a finite positive
+ // number without a leading sign has at most five basic pieces:
+ // Digits . Digits ExponentPart FloatTypeSuffix
+ //
+ // Since this method allows integer-only strings as input
+ // in addition to strings of floating-point literals, the
+ // two sub-patterns below are simplifications of the grammar
+ // productions from the Java Language Specification, 2nd
+ // edition, section 3.10.2.
+
+ // Digits ._opt Digits_opt ExponentPart_opt FloatTypeSuffix_opt
+ "(((" + Digits + "(\\.)?(" + Digits + "?)(" + Exp + ")?)|" +
+
+ // . Digits ExponentPart_opt FloatTypeSuffix_opt
+ "(\\.(" + Digits + ")(" + Exp + ")?)|" +
+
+ // Hexadecimal strings
+ "((" +
+ // 0[xX] HexDigits ._opt BinaryExponent FloatTypeSuffix_opt
+ "(0[xX]" + HexDigits + "(\\.)?)|" +
+
+ // 0[xX] HexDigits_opt . HexDigits BinaryExponent FloatTypeSuffix_opt
+ "(0[xX]" + HexDigits + "?(\\.)" + HexDigits + ")" +
+
+ ")[pP][+-]?" + Digits + "))" + "[fFdD]?))" + "[\\x00-\\x20]*");// Optional
+ // trailing
+ // "whitespace"
+
+ return Pattern.matches(fpRegex, string);
+ }
+
+ /**
+ * Set System.out and System.err to use the UTF8 character encoding.
+ *
+ * @return <code>true</code> if both System.out and System.err were successfully set to use UTF8,
+ * <code>false</code> otherwise.
+ */
+ public static boolean useUTF8() {
+
+ try {
+ System.setOut(new PrintStream(System.out, true, "UTF8"));
+ System.setErr(new PrintStream(System.err, true, "UTF8"));
+ return true;
+ } catch (UnsupportedEncodingException e1) {
+ LOG.warn("UTF8 is not a valid encoding; using system default encoding for System.out and System.err.");
+ return false;
+ } catch (SecurityException e2) {
+ LOG.warn("Security manager is configured to disallow changes to System.out or System.err; using system default encoding.");
+ return false;
+ }
+ }
+
+ /**
+ * Determines whether a string is ALL CAPS.
+ *
+ * @param token an input token
+ * @return true if the string is all in uppercase, false otherwise
+ */
+ public static boolean ISALLUPPERCASE(String token) {
+ for (int i = 0; i < token.length(); i++)
+ if (! Character.isUpperCase(token.charAt(i)))
+ return false;
+ return true;
+ }
+
+ public static String capitalize(String word) {
+ if (word == null || word.length() == 0)
+ return word;
+ return word.substring(0, 1).toUpperCase() + word.substring(1);
+ }
+}
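
Taken together, the nonterminal helpers above compose as follows. Here is a short,
hypothetical driver; the expected outputs in the comments follow directly from the
code in this diff.

    import org.apache.joshua.util.FormatUtils;

    public class FormatUtilsDemo {
      public static void main(String[] args) {
        System.out.println(FormatUtils.isNonterminal("[X]"));              // true
        System.out.println(FormatUtils.isNonterminal("word"));             // false
        System.out.println(FormatUtils.cleanNonTerminal("[X,1]"));         // X
        System.out.println(FormatUtils.stripNonTerminalIndex("[X,1]"));    // [X]
        System.out.println(FormatUtils.getNonterminalIndex("[X,7]"));      // 7
        System.out.println(FormatUtils.ensureNonTerminalBrackets("GOAL")); // [GOAL]
      }
    }
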
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/abb8c518/src/test/java/org/apache/joshua/corpus/VocabularyTest.java
----------------------------------------------------------------------
diff --cc src/test/java/org/apache/joshua/corpus/VocabularyTest.java
index fc41a1e,0000000..a282ba3
mode 100644,000000..100644
--- a/src/test/java/org/apache/joshua/corpus/VocabularyTest.java
+++ b/src/test/java/org/apache/joshua/corpus/VocabularyTest.java
@@@ -1,133 -1,0 +1,135 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus;
+
+import static org.junit.Assert.*;
+
+import java.io.File;
+import java.io.IOException;
++
++import org.apache.joshua.util.FormatUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class VocabularyTest {
+ private static final String WORD1 = "word1";
+ private static final String WORD2 = "word2";
+ private static final String NON_TERMINAL = "[X]";
+ private static final String GOAL = "[GOAL]";
+
+ @Before
+ public void init() {
+ Vocabulary.clear();
+ }
+
+ @After
+ public void deinit() {
+ Vocabulary.clear();
+ }
+
+ @Test
+ public void givenVocabulary_whenEmpty_thenOnlyContainsUnknownWord() {
+ assertTrue(Vocabulary.hasId(Vocabulary.UNKNOWN_ID));
+ assertFalse(Vocabulary.hasId(1));
+ assertFalse(Vocabulary.hasId(-1));
+ assertEquals(Vocabulary.UNKNOWN_WORD, Vocabulary.word(Vocabulary.UNKNOWN_ID));
+ assertEquals(1, Vocabulary.size());
+ }
+
+ @Test
+ public void givenVocabulary_whenNewWord_thenMappingIsAdded() {
+ final int FIRST_WORD_ID = 1;
+ assertFalse(Vocabulary.hasId(FIRST_WORD_ID));
+ assertEquals(FIRST_WORD_ID, Vocabulary.id(WORD1));
+ //should return same id after second call:
+ assertEquals(FIRST_WORD_ID, Vocabulary.id(WORD1));
+ assertTrue(Vocabulary.hasId(FIRST_WORD_ID));
+ assertEquals(WORD1, Vocabulary.word(FIRST_WORD_ID));
+ assertEquals(2, Vocabulary.size());
+ }
+
+ @Test
+ public void givenVocabulary_whenCheckingStringInBracketsOrNegativeNumber_thenIsNonTerminal() {
+ //non-terminals
- assertTrue(Vocabulary.nt(NON_TERMINAL));
++ assertTrue(FormatUtils.isNonterminal(NON_TERMINAL));
+ //terminals
- assertFalse(Vocabulary.nt(WORD1));
- assertFalse(Vocabulary.nt("[]"));
- assertFalse(Vocabulary.nt("["));
- assertFalse(Vocabulary.nt("]"));
- assertFalse(Vocabulary.nt(""));
++ assertFalse(FormatUtils.isNonterminal(WORD1));
++ assertFalse(FormatUtils.isNonterminal("[]"));
++ assertFalse(FormatUtils.isNonterminal("["));
++ assertFalse(FormatUtils.isNonterminal("]"));
++ assertFalse(FormatUtils.isNonterminal(""));
+
+ //negative numbers indicate non-terminals
- assertTrue(Vocabulary.nt(-1));
- assertTrue(Vocabulary.nt(-5));
++ assertTrue(FormatUtils.isNonterminal(-1));
++ assertTrue(FormatUtils.isNonterminal(-5));
+
+ //positive numbers indicate terminals:
- assertFalse(Vocabulary.nt(0));
- assertFalse(Vocabulary.nt(5));
++ assertFalse(FormatUtils.isNonterminal(0));
++ assertFalse(FormatUtils.isNonterminal(5));
+
+ }
+
+ @Test
+ public void givenVocabulary_whenNonTerminal_thenReturnsStrictlyPositiveNonTerminalIndices() {
+ final int FIRST_NON_TERMINAL_INDEX = 1;
+ assertTrue(Vocabulary.id(NON_TERMINAL) < 0);
+ assertTrue(Vocabulary.hasId(FIRST_NON_TERMINAL_INDEX));
+ assertTrue(Vocabulary.hasId(-FIRST_NON_TERMINAL_INDEX));
+
+ assertTrue(Vocabulary.id("") > 0);
+ assertTrue(Vocabulary.id(WORD1) > 0);
+
+ final int SECOND_NON_TERMINAL_INDEX = 4;
+ assertTrue(Vocabulary.id(GOAL) < 0);
+ assertTrue(Vocabulary.hasId(SECOND_NON_TERMINAL_INDEX));
+ assertTrue(Vocabulary.hasId(-SECOND_NON_TERMINAL_INDEX));
+
+ assertTrue(Vocabulary.id(WORD2) > 0);
+ }
+
+ @Rule
+ public TemporaryFolder folder = new TemporaryFolder();
+
+ @Test
+ public void givenVocabulary_whenWrittenAndRead_thenVocabularyStaysTheSame() throws IOException {
+ File vocabFile = folder.newFile();
+
+ int id1 = Vocabulary.id(WORD1);
+ int id2 = Vocabulary.id(NON_TERMINAL);
+ int id3 = Vocabulary.id(WORD2);
+
+ Vocabulary.write(vocabFile.getAbsolutePath());
+
+ Vocabulary.clear();
+
+ Vocabulary.read(vocabFile);
+
+ assertEquals(4, Vocabulary.size()); //unknown word + 3 other words
+ assertTrue(Vocabulary.hasId(id1));
+ assertTrue(Vocabulary.hasId(id2));
+ assertTrue(Vocabulary.hasId(id3));
+ assertEquals(id1, Vocabulary.id(WORD1));
+ assertEquals(id2, Vocabulary.id(NON_TERMINAL));
+ assertEquals(id3, Vocabulary.id(WORD2));
+ }
+}
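
The tests above pin down the Vocabulary contract this merge relies on: terminals get
strictly positive IDs, nonterminals strictly negative ones, and a write/read round trip
preserves the mapping. A compact sketch of the same contract from application code
follows; the driver class is illustrative, not part of this commit.

    import org.apache.joshua.corpus.Vocabulary;

    public class VocabularyDemo {
      public static void main(String[] args) {
        Vocabulary.clear();
        int word = Vocabulary.id("word1"); // terminals receive positive IDs
        int nt = Vocabulary.id("[X]");     // nonterminals receive negative IDs
        System.out.println(word > 0);      // true
        System.out.println(nt < 0);        // true
        System.out.println(Vocabulary.word(word)); // word1
      }
    }
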