You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/05/27 00:34:30 UTC
[24/32] incubator-joshua git commit: Merge branch 'master' into
JOSHUA-252 (compiling, but not tested)
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/abb8c518/src/main/java/org/apache/joshua/decoder/ff/tm/hash_based/MemoryBasedBatchGrammar.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/decoder/ff/tm/hash_based/MemoryBasedBatchGrammar.java
index 4099797,0000000..3a37e08
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/decoder/ff/tm/hash_based/MemoryBasedBatchGrammar.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/tm/hash_based/MemoryBasedBatchGrammar.java
@@@ -1,321 -1,0 +1,317 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.tm.hash_based;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.JoshuaConfiguration.OOVItem;
+import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.ff.tm.AbstractGrammar;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.GrammarReader;
+import org.apache.joshua.decoder.ff.tm.Trie;
+import org.apache.joshua.decoder.ff.tm.format.HieroFormatReader;
- import org.apache.joshua.decoder.ff.tm.format.PhraseFormatReader;
- import org.apache.joshua.decoder.ff.tm.format.SamtFormatReader;
++import org.apache.joshua.decoder.ff.tm.format.MosesFormatReader;
+import org.apache.joshua.util.FormatUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This class implements a memory-based bilingual BatchGrammar.
+ * <p>
+ * The rules are stored in a trie. Each trie node has: (1) RuleBin: a list of rules matching the
+ * french sides so far (2) A HashMap of next-layer trie nodes, the next french word used as the key
+ * in HashMap
+ *
+ * @author Zhifei Li zhifei.work@gmail.com
+ * @author Matt Post post@cs.jhu.edu
+ */
+public class MemoryBasedBatchGrammar extends AbstractGrammar {
+
+ private static final Logger LOG = LoggerFactory.getLogger(MemoryBasedBatchGrammar.class);
+
+ // ===============================================================
+ // Instance Fields
+ // ===============================================================
+
+ /* The number of rules read. */
+ private int qtyRulesRead = 0;
+
+ /* The number of distinct source sides. */
+ private int qtyRuleBins = 0;
+
+ private int numDenseFeatures = 0;
+
+ /* The trie root. */
+ private MemoryBasedTrie root = null;
+
+ /* The file containing the grammar. */
+ private String grammarFile;
+
+ private GrammarReader<Rule> modelReader;
+
+ /* Whether the grammar's rules contain regular expressions. */
+ private boolean isRegexpGrammar = false;
+
+ // ===============================================================
+ // Static Fields
+ // ===============================================================
+
+ // ===============================================================
+ // Constructors
+ // ===============================================================
+
+ /**
+ * Creates an empty grammar: hands the configuration to the superclass constructor and installs
+ * a fresh {@link MemoryBasedTrie} as the trie root. No rules are loaded.
+ *
+ * @param joshuaConfiguration the decoder configuration, also stored on this instance
+ */
+ public MemoryBasedBatchGrammar(JoshuaConfiguration joshuaConfiguration) {
+ super(joshuaConfiguration);
+ this.root = new MemoryBasedTrie();
+ this.joshuaConfiguration = joshuaConfiguration;
+ }
+
+ /**
+ * Creates an empty grammar and sets its owner.
+ *
+ * @param owner owner name; stored as the id returned by {@code Vocabulary.id(owner)}
+ * @param joshuaConfiguration the decoder configuration
+ */
+ public MemoryBasedBatchGrammar(String owner, JoshuaConfiguration joshuaConfiguration) {
+ this(joshuaConfiguration);
+ this.owner = Vocabulary.id(owner);
+ }
+
+ /**
+ * Creates an empty grammar that remembers the given reader. The reader is only stored in
+ * {@code modelReader}; this constructor does not iterate it, so no rules are added here.
+ *
+ * @param gr the grammar reader to keep
+ * @param joshuaConfiguration the decoder configuration
+ */
+ public MemoryBasedBatchGrammar(GrammarReader<Rule> gr, JoshuaConfiguration joshuaConfiguration) {
+ // this.defaultOwner = Vocabulary.id(defaultOwner);
+ // this.defaultLHS = Vocabulary.id(defaultLHSSymbol);
+ this(joshuaConfiguration);
+ modelReader = gr;
+ }
+
+ /**
+ * Creates a grammar and loads every rule produced by the reader for {@code grammarFile}.
+ * <p>
+ * The reader is obtained from {@link #createReader(String, String)}. Each non-null rule it
+ * yields is passed to {@link #addRule(Rule)}. If no reader could be created (createReader
+ * returns null when {@code grammarFile} is null) a message is logged and the grammar stays
+ * empty. A summary is logged via {@link #printGrammar()} in either case.
+ *
+ * @param formatKeyword grammar format keyword handed to createReader; the value "regexp"
+ * additionally marks this grammar as a regexp grammar. Must not be null, since
+ * {@code equals} is invoked on it.
+ * @param grammarFile path of the grammar file, may be null
+ * @param owner owner name, stored as {@code Vocabulary.id(owner)}
+ * @param defaultLHSSymbol passed to {@code Vocabulary.id}; the returned id is not kept
+ * @param spanLimit span limit stored on this grammar
+ * @param joshuaConfiguration the decoder configuration
+ * @throws IOException declared because createReader declares it
+ */
+ public MemoryBasedBatchGrammar(String formatKeyword, String grammarFile, String owner,
+ String defaultLHSSymbol, int spanLimit, JoshuaConfiguration joshuaConfiguration)
+ throws IOException {
+
+ this(joshuaConfiguration);
+ this.owner = Vocabulary.id(owner);
+ Vocabulary.id(defaultLHSSymbol); // result discarded; presumably called to register the symbol - confirm
+ this.spanLimit = spanLimit;
+ this.grammarFile = grammarFile;
+ this.setRegexpGrammar(formatKeyword.equals("regexp"));
+
+ // ==== loading grammar
+ this.modelReader = createReader(formatKeyword, grammarFile);
+ if (modelReader != null) {
- modelReader.initialize();
+ for (Rule rule : modelReader)
+ if (rule != null) {
+ addRule(rule);
+ }
+ } else {
+ LOG.info("Couldn't create a GrammarReader for file {} with format {}",
+ grammarFile, formatKeyword);
+ }
+
+ this.printGrammar();
+ }
+
+ /**
+ * Selects the grammar reader for a format keyword.
+ * <p>
+ * "hiero", "thrax" and "regexp" yield a {@link HieroFormatReader}; "moses" yields a
+ * {@link MosesFormatReader}. Any other keyword raises a RuntimeException.
+ *
+ * @param format the grammar format keyword
+ * @param grammarFile path of the grammar file
+ * @return a reader for the file, or null when {@code grammarFile} is null (the format is not
+ * inspected in that case)
+ * @throws IOException declared on the signature for the reader constructors
+ * @throws RuntimeException if {@code grammarFile} is non-null and the format is unknown
+ */
- protected GrammarReader<Rule> createReader(String format, String grammarFile) {
++ protected GrammarReader<Rule> createReader(String format, String grammarFile) throws IOException {
+
+ if (grammarFile != null) {
+ if ("hiero".equals(format) || "thrax".equals(format) || "regexp".equals(format)) {
+ return new HieroFormatReader(grammarFile);
- } else if ("samt".equals(format)) {
- return new SamtFormatReader(grammarFile);
- } else if ("phrase".equals(format) || "moses".equals(format)) {
- return new PhraseFormatReader(grammarFile, format.equals("moses"));
++ } else if ("moses".equals(format)) {
++ return new MosesFormatReader(grammarFile);
+ } else {
+ throw new RuntimeException(String.format("* FATAL: unknown grammar format '%s'", format));
+ }
+ }
+ return null;
+ }
+
+ // ===============================================================
+ // Methods
+ // ===============================================================
+
+ /**
+ * Sets the span limit consulted by {@link #hasRuleForSpan(int, int, int)}; the value -1 is
+ * treated specially there.
+ *
+ * @param spanLimit the new span limit
+ */
+ public void setSpanLimit(int spanLimit) {
+ this.spanLimit = spanLimit;
+ }
+
+ /**
+ * @return the value of the rule counter, which {@link #addRule(Rule)} increments once per call
+ */
+ @Override
+ public int getNumRules() {
+ return this.qtyRulesRead;
+ }
+
+ /**
+ * Not implemented for this grammar: all arguments are ignored.
+ *
+ * @return always null
+ */
+ @Override
+ public Rule constructManualRule(int lhs, int[] sourceWords, int[] targetWords,
+ float[] denseScores, int arity) {
+ return null;
+ }
+
+ /**
+ * Decides whether this grammar may apply to a span.
+ * <p>
+ * With a span limit of -1 (mono-glue grammar) only spans starting at position 0 qualify.
+ * Otherwise the span qualifies when its path length does not exceed the span limit.
+ *
+ * @param i start index of the span
+ * @param j end index of the span (not used by this implementation)
+ * @param pathLength length of the path covered by the span
+ * @return true if the grammar may be applied to the span
+ */
+ public boolean hasRuleForSpan(int i, int j, int pathLength) {
+ if (this.spanLimit == -1) { // mono-glue grammar
+ return (i == 0);
+ } else {
+ // System.err.println(String.format("%s HASRULEFORSPAN(%d,%d,%d)/%d = %s",
+ // Vocabulary.word(this.owner), i, j, pathLength, spanLimit, pathLength <= this.spanLimit));
+ return (pathLength <= this.spanLimit);
+ }
+ }
+
+ /**
+ * @return the root node of the trie that {@link #addRule(Rule)} populates
+ */
+ public Trie getTrieRoot() {
+ return this.root;
+ }
+
+ /**
+ * Adds a rule to the grammar.
+ * <p>
+ * The rule is stamped with this grammar's owner and then filed in the trie under its source
+ * side: the trie is walked one source symbol at a time, creating missing nodes, and the rule is
+ * appended to the rule bin of the final node (the bin is created on first use). Along the way
+ * the rule counter, the rule-bin counter and the maximum source phrase length are updated. The
+ * dense feature count is taken from the rule's feature vector as long as it is still 0.
+ *
+ * @param rule the rule to add; its owner is overwritten
+ */
+ public void addRule(Rule rule) {
+
+ // TODO: Why two increments?
+ this.qtyRulesRead++;
+
+ // if (owner == -1) {
+ // System.err.println("* FATAL: MemoryBasedBatchGrammar::addRule(): owner not set for grammar");
+ // System.exit(1);
+ // }
+ rule.setOwner(owner);
+
+ if (numDenseFeatures == 0)
+ numDenseFeatures = rule.getFeatureVector().getDenseFeatures().size();
+
+ // === identify the position, and insert the trie nodes as necessary
+ MemoryBasedTrie pos = root;
+ int[] french = rule.getFrench();
+
+ maxSourcePhraseLength = Math.max(maxSourcePhraseLength, french.length);
+
+ for (int k = 0; k < french.length; k++) {
+ int curSymID = french[k];
+
+ /*
+ * Note that the nonTerminal symbol in the french is not cleaned (i.e., will be sth like
+ * [X,1]), but the symbol in the Trie has to be cleaned, so that the match does not care about
+ * the markup (i.e., [X,1] or [X,2] means the same thing, that is X) if
+ * (Vocabulary.nt(french[k])) { curSymID = modelReader.cleanNonTerminal(french[k]); if
+ * (logger.isLoggable(Level.FINEST)) logger.finest("Amended to: " + curSymID); }
+ */
+
+ MemoryBasedTrie nextLayer = (MemoryBasedTrie) pos.match(curSymID);
+ if (null == nextLayer) {
+ nextLayer = new MemoryBasedTrie();
+ // the children table is assigned here when the node reports no extensions yet
+ if (pos.hasExtensions() == false) {
+ pos.childrenTbl = new HashMap<Integer, MemoryBasedTrie>();
+ }
+ pos.childrenTbl.put(curSymID, nextLayer);
+ }
+ pos = nextLayer;
+ }
+
+ // === add the rule into the trie node
+ if (!pos.hasRules()) {
+ pos.ruleBin = new MemoryBasedRuleBin(rule.getArity(), rule.getFrench());
+ this.qtyRuleBins++;
+ }
+ pos.ruleBin.addRule(rule);
+ }
+
+ /**
+ * Logs, at info level, the number of rules read, the number of distinct source sides (rule
+ * bins) and the grammar file name.
+ */
+ protected void printGrammar() {
+ LOG.info("MemoryBasedBatchGrammar: Read {} rules with {} distinct source sides from '{}'",
+ this.qtyRulesRead, this.qtyRuleBins, grammarFile);
+ }
+
+ /**
+ * This returns true if the grammar contains rules that are regular expressions, possibly matching
+ * many different inputs. The flag is whatever was last passed to
+ * {@link #setRegexpGrammar(boolean)}; the file-loading constructor sets it to true exactly when
+ * the format keyword equals "regexp".
+ *
+ * @return true if the grammar's rules may contain regular expressions.
+ */
+ @Override
+ public boolean isRegexpGrammar() {
+ return this.isRegexpGrammar;
+ }
+
+ /**
+ * Sets the flag reported by {@link #isRegexpGrammar()}.
+ *
+ * @param value true if the grammar's rules may contain regular expressions
+ */
+ public void setRegexpGrammar(boolean value) {
+ this.isRegexpGrammar = value;
+ }
+
+ /***
+ * Takes an input word and creates an OOV rule in the current grammar for that word.
+ * <p>
+ * The rule translates the word to itself, or to the word with "_OOV" appended when
+ * {@code mark_oovs} is set in the configuration. If the configuration has a non-empty
+ * {@code oovList}, one rule is created per list item with that item's label as left-hand side;
+ * otherwise a single rule is created with the configured default non-terminal. Every rule gets
+ * the alignment "0-0", is added via {@link #addRule(Rule)}, and then has its cost estimated.
+ *
+ * @param sourceWord integer representation of word
+ * @param featureFunctions {@link java.util.List} of {@link org.apache.joshua.decoder.ff.FeatureFunction}'s
+ */
+ @Override
+ public void addOOVRules(int sourceWord, List<FeatureFunction> featureFunctions) {
+
+ // TODO: _OOV shouldn't be outright added, since the word might not be OOV for the LM (but now
+ // almost certainly is)
+ final int targetWord = this.joshuaConfiguration.mark_oovs ? Vocabulary.id(Vocabulary
+ .word(sourceWord) + "_OOV") : sourceWord;
+
+ int[] sourceWords = { sourceWord };
+ int[] targetWords = { targetWord };
+ final String oovAlignment = "0-0";
+
+ if (this.joshuaConfiguration.oovList != null && this.joshuaConfiguration.oovList.size() != 0) {
+ for (OOVItem item : this.joshuaConfiguration.oovList) {
+ Rule oovRule = new Rule(Vocabulary.id(item.label), sourceWords, targetWords, "", 0,
+ oovAlignment);
+ addRule(oovRule);
+ oovRule.estimateRuleCost(featureFunctions);
+ }
+ } else {
+ int nt_i = Vocabulary.id(this.joshuaConfiguration.default_non_terminal);
+ Rule oovRule = new Rule(nt_i, sourceWords, targetWords, "", 0, oovAlignment);
+ addRule(oovRule);
+ oovRule.estimateRuleCost(featureFunctions);
+ }
+ }
+
+ /**
+ * Adds a default set of glue rules.
+ * <p>
+ * Three rules are built as text from the configured goal symbol and default non-terminal (both
+ * passed through {@code FormatUtils.cleanNonTerminal}), parsed with a {@link HieroFormatReader},
+ * added via {@link #addRule(Rule)} and then cost-estimated:
+ * <ol>
+ * <li>goal rewrites to the start symbol;</li>
+ * <li>goal rewrites to goal followed by the default non-terminal;</li>
+ * <li>goal rewrites to goal followed by the stop symbol.</li>
+ * </ol>
+ *
+ * @param featureFunctions an {@link java.util.ArrayList} of {@link org.apache.joshua.decoder.ff.FeatureFunction}'s
+ */
+ public void addGlueRules(ArrayList<FeatureFunction> featureFunctions) {
+ HieroFormatReader reader = new HieroFormatReader();
+
+ String goalNT = FormatUtils.cleanNonTerminal(joshuaConfiguration.goal_symbol);
+ String defaultNT = FormatUtils.cleanNonTerminal(joshuaConfiguration.default_non_terminal);
+
+ String[] ruleStrings = new String[] {
+ String.format("[%s] ||| %s ||| %s ||| 0", goalNT, Vocabulary.START_SYM,
+ Vocabulary.START_SYM),
+ String.format("[%s] ||| [%s,1] [%s,2] ||| [%s,1] [%s,2] ||| -1", goalNT, goalNT, defaultNT,
+ goalNT, defaultNT),
+ String.format("[%s] ||| [%s,1] %s ||| [%s,1] %s ||| 0", goalNT, goalNT,
+ Vocabulary.STOP_SYM, goalNT, Vocabulary.STOP_SYM) };
+
+ for (String ruleString : ruleStrings) {
+ Rule rule = reader.parseLine(ruleString);
+ addRule(rule);
+ rule.estimateRuleCost(featureFunctions);
+ }
+ }
+
+ /**
+ * @return the dense feature count recorded by {@link #addRule(Rule)}; 0 until a rule with a
+ * non-empty dense feature list has been added
+ */
+ @Override
+ public int getNumDenseFeatures() {
+ return numDenseFeatures;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/abb8c518/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
index c6dbadc,0000000..632644f
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
@@@ -1,1057 -1,0 +1,1080 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.tm.packed;
+
+/***
+ * This package implements Joshua's packed grammar structure, which enables the efficient loading
+ * and accessing of grammars. It is described in the paper:
+ *
+ * @article{ganitkevitch2012joshua,
+ * Author = {Ganitkevitch, J. and Cao, Y. and Weese, J. and Post, M. and Callison-Burch, C.},
+ * Journal = {Proceedings of WMT12},
+ * Title = {Joshua 4.0: Packing, PRO, and paraphrases},
+ * Year = {2012}}
+ *
+ * The packed grammar works by compiling out the grammar tries into a compact format that is loaded
+ * and parsed directly from Java arrays. A fundamental problem is that Java arrays are indexed
+ * by ints and not longs, meaning the maximum size of the packed grammar is about 2 GB. This forces
+ * the use of packed grammar slices, which together constitute the grammar. The figure in the
+ * paper above shows what each slice looks like.
+ *
+ * The division across slices is done in a depth-first manner. Consider the entire grammar organized
+ * into a single source-side trie. The splits across tries are done by grouping the root-level
+ * outgoing trie arcs --- and the entire trie beneath them --- across slices.
+ *
+ * This presents a problem: if the subtree rooted beneath a single top-level arc is too big for a
+ * slice, the grammar can't be packed. This happens with very large Hiero grammars, for example,
+ * where there are a *lot* of rules that start with [X].
+ *
+ * A solution being worked on is to split that symbol and pack them into separate grammars with a
+ * shared vocabulary, and then rely on Joshua's ability to query multiple grammars for rules to
+ * solve this problem. This is not currently implemented but could be done directly in the
+ * Grammar Packer.
+ *
+ * *UPDATE 10/2015*
+ * The introduction of a SliceAggregatingTrie together with sorting the grammar by the full source string
+ * (not just by the first source word) allows distributing rules with the same first source word
+ * across multiple slices.
+ * @author fhieber
+ */
+
+import static java.util.Collections.sort;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.BufferUnderflowException;
+import java.nio.ByteBuffer;
+import java.nio.IntBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.channels.FileChannel.MapMode;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.security.DigestInputStream;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.ff.FeatureVector;
+import org.apache.joshua.decoder.ff.tm.AbstractGrammar;
+import org.apache.joshua.decoder.ff.tm.BasicRuleCollection;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.RuleCollection;
+import org.apache.joshua.decoder.ff.tm.Trie;
+import org.apache.joshua.decoder.ff.tm.hash_based.ExtensionIterator;
++import org.apache.joshua.util.FormatUtils;
+import org.apache.joshua.util.encoding.EncoderConfiguration;
+import org.apache.joshua.util.encoding.FloatEncoder;
+import org.apache.joshua.util.io.LineReader;
+
+import com.google.common.base.Supplier;
+import com.google.common.base.Suppliers;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class PackedGrammar extends AbstractGrammar {
+
+ private static final Logger LOG = LoggerFactory.getLogger(PackedGrammar.class);
+ public static final String VOCABULARY_FILENAME = "vocabulary";
+
+ private EncoderConfiguration encoding;
+ private PackedRoot root;
+ private ArrayList<PackedSlice> slices;
+
+ private final File vocabFile; // store path to vocabulary file
+
+ // The grammar specification keyword (e.g., "thrax" or "moses")
+ private String type;
+
++ // The version number of the earliest supported grammar packer
++ public static final int SUPPORTED_VERSION = 3;
++
+ // A rule cache for commonly used tries to avoid excess object allocations
+ // Testing shows there's up to ~95% hit rate when cache size is 5000 Trie nodes.
+ private final Cache<Trie, List<Rule>> cached_rules;
+
++ private String grammarDir;
++
+ /**
+ * Loads a packed grammar from a directory.
+ * <p>
+ * Steps, in order: read the vocabulary file, read the optional "config" file if it exists, load
+ * the encoder configuration from "encoding", set the owner, open one {@link PackedSlice} per
+ * directory entry whose name starts with "slice_" and ends with ".source" (entries are sorted by
+ * name first), build the {@link PackedRoot} over all slices and create the rule cache.
+ *
+ * @param grammar_dir directory holding the packed grammar files
+ * @param span_limit span limit stored on this grammar
+ * @param owner owner name, stored as {@code Vocabulary.id(owner)}
+ * @param type grammar specification keyword. NOTE(review): after this change the value is no
+ * longer assigned to the {@code type} field here - confirm it is set elsewhere (e.g. by
+ * readConfig).
+ * @param joshuaConfiguration the decoder configuration; supplies the rule cache size
+ * @throws IOException declared on the signature; slice construction declares it
+ * @throws RuntimeException if {@code Vocabulary.read} reports failure
+ */
+ public PackedGrammar(String grammar_dir, int span_limit, String owner, String type,
+ JoshuaConfiguration joshuaConfiguration) throws IOException {
+ super(joshuaConfiguration);
++
++ this.grammarDir = grammar_dir;
+ this.spanLimit = span_limit;
- this.type = type;
+
+ // Read the vocabulary.
+ vocabFile = new File(grammar_dir + File.separator + VOCABULARY_FILENAME);
+ LOG.info("Reading vocabulary: {}", vocabFile);
+ if (!Vocabulary.read(vocabFile)) {
+ throw new RuntimeException("mismatches or collisions while reading on-disk vocabulary");
+ }
+
+ // Read the config
+ String configFile = grammar_dir + File.separator + "config";
+ if (new File(configFile).exists()) {
+ LOG.info("Reading packed config: {}", configFile);
+ readConfig(configFile);
+ }
+
+ // Read the quantizer setup.
+ LOG.info("Reading encoder configuration: {}{}encoding", grammar_dir, File.separator);
+ encoding = new EncoderConfiguration();
+ encoding.load(grammar_dir + File.separator + "encoding");
+
+ // Set phrase owner.
+ this.owner = Vocabulary.id(owner);
+
+ // NOTE(review): File.list() returns null when grammar_dir is not a readable directory, in
+ // which case Arrays.asList throws a NullPointerException - confirm whether a clearer error
+ // is wanted here.
+ final List<String> listing = Arrays.asList(new File(grammar_dir).list());
+ sort(listing); // File.list() has arbitrary sort order
+ slices = new ArrayList<PackedSlice>();
+ for (String prefix : listing) {
+ // only the first 11 characters of the file name are kept as the slice prefix
+ // (presumably "slice_" plus a five-digit number - confirm against the packer)
+ if (prefix.startsWith("slice_") && prefix.endsWith(".source"))
+ slices.add(new PackedSlice(grammar_dir + File.separator + prefix.substring(0, 11)));
+ }
+
+ // the logged rule count is the sum of the slices' estimated-cost array lengths
+ long count = 0;
+ for (PackedSlice s : slices)
+ count += s.estimated.length;
+ root = new PackedRoot(slices);
+ cached_rules = CacheBuilder.newBuilder().maximumSize(joshuaConfiguration.cachedRuleSize).build();
+
+ LOG.info("Loaded {} rules", count);
+ }
+
+ /**
+ * @return the {@link PackedRoot} built over all slices by the constructor
+ */
+ @Override
+ public Trie getTrieRoot() {
+ return root;
+ }
+
+ /**
+ * Decides whether this grammar may apply to a span. A span limit of -1 means no limit;
+ * otherwise the path length must not exceed the span limit. The start and end indices are not
+ * consulted.
+ *
+ * @return true if the grammar may be applied to the span
+ */
+ @Override
+ public boolean hasRuleForSpan(int startIndex, int endIndex, int pathLength) {
+ return (spanLimit == -1 || pathLength <= spanLimit);
+ }
+
+ /**
+ * @return the sum of {@code featureSize} over all slices, accumulated in an int
+ */
+ @Override
+ public int getNumRules() {
+ int num_rules = 0;
+ for (PackedSlice ps : slices)
+ num_rules += ps.featureSize;
+ return num_rules;
+ }
+
+ /**
+ * @return the dense feature count reported by the encoder configuration loaded in the
+ * constructor
+ */
+ @Override
+ public int getNumDenseFeatures() {
+ return encoding.getNumDenseFeatures();
+ }
+
+ /**
+ * Not implemented for packed grammars: all arguments are ignored.
+ *
+ * @return always null
+ */
+ public Rule constructManualRule(int lhs, int[] src, int[] tgt, float[] scores, int arity) {
+ return null;
+ }
+
+ /**
+ * Computes the MD5 checksum of the vocabulary file.
+ * Can be used for comparing vocabularies across multiple packedGrammars.
+ *
+ * @return the checksum as a lowercase hexadecimal string, two digits per digest byte
+ * @throws RuntimeException if the MD5 algorithm is unavailable or the vocabulary file cannot be
+ * read; the underlying exception is attached as the cause
+ */
+ public String computeVocabularyChecksum() {
+ MessageDigest md;
+ try {
+ md = MessageDigest.getInstance("MD5");
+ } catch (NoSuchAlgorithmException e) {
+ // keep the original exception as the cause so the failure can be diagnosed
+ throw new RuntimeException("Unknown checksum algorithm", e);
+ }
+ byte[] buffer = new byte[1024];
+ try (final InputStream is = Files.newInputStream(Paths.get(vocabFile.toString()));
+ DigestInputStream dis = new DigestInputStream(is, md)) {
+ // reading through the DigestInputStream feeds every byte into md; the data is discarded
+ while (dis.read(buffer) != -1) {}
+ } catch (IOException e) {
+ // any I/O failure lands here, not only a missing file, so the cause must be preserved
+ throw new RuntimeException("Can not find vocabulary file. This should not happen.", e);
+ }
+ byte[] digest = md.digest();
+ // convert the byte to hex format; adding 0x100 forces a leading zero for values below 0x10
+ StringBuilder sb = new StringBuilder(2 * digest.length);
+ for (byte b : digest) {
+ sb.append(Integer.toString((b & 0xff) + 0x100, 16).substring(1));
+ }
+ return sb.toString();
+ }
+
+ /**
+ * PackedRoot represents the root of the packed grammar trie.
+ * Tries for different source-side firstwords are organized in
+ * packedSlices on disk. A packedSlice can contain multiple trie
+ * roots (i.e. multiple source-side firstwords).
+ * The PackedRoot builds a lookup table, mapping from
+ * source-side firstwords to the addresses in the packedSlices
+ * that represent the subtrie for a particular firstword.
+ * If the GrammarPacker has to distribute rules for a
+ * source-side firstword across multiple slices, a
+ * SliceAggregatingTrie node is created that aggregates those
+ * tries to hide
+ * this additional complexity from the grammar interface
+ * This feature allows packing of grammars where the list of rules
+ * for a single source-side firstword would exceed the maximum array
+ * size of Java (2gb).
+ */
+ public final class PackedRoot implements Trie {
+
+ // maps a first-word id to the trie node (or aggregate of nodes) reached by that word
+ private final HashMap<Integer, Trie> lookup;
+
+ /**
+ * Builds the lookup table from the first-level children of every slice.
+ *
+ * @param slices the slices that make up the grammar
+ */
+ public PackedRoot(final List<PackedSlice> slices) {
+ final Map<Integer, List<Trie>> childTries = collectChildTries(slices);
+ lookup = buildLookupTable(childTries);
+ }
+
+ /**
+ * Determines whether trie nodes for source first-words are spread over
+ * multiple packedSlices by counting their occurrences.
+ * @param slices the slices whose root-level children are collected
+ * @return A mapping from first word ids to a list of trie nodes.
+ */
+ private Map<Integer, List<Trie>> collectChildTries(final List<PackedSlice> slices) {
+ final Map<Integer, List<Trie>> childTries = new HashMap<>();
+ for (PackedSlice packedSlice : slices) {
+
+ // number of tries stored in this packedSlice
+ final int num_children = packedSlice.source[0];
+ for (int i = 0; i < num_children; i++) {
+ // child entries are (symbol, address) pairs following the count, so symbols sit at odd offsets
+ final int id = packedSlice.source[2 * i + 1];
+
+ /* aggregate tries with same root id
+ * obtain a Trie node, already at the correct address in the packedSlice.
+ * In other words, the lookup index already points to the correct trie node in the packedSlice.
+ * packedRoot.match() thus can directly return the result of lookup.get(id);
+ */
+ if (!childTries.containsKey(id)) {
+ childTries.put(id, new ArrayList<Trie>(1));
+ }
+ final Trie trie = packedSlice.root().match(id);
+ childTries.get(id).add(trie);
+ }
+ }
+ return childTries;
+ }
+
+ /**
+ * Build a lookup table for children tries.
+ * If the list contains only a single child node, a regular trie node
+ * is inserted into the table; otherwise a SliceAggregatingTrie node is
+ * created that hides this partitioning into multiple packedSlices
+ * upstream.
+ *
+ * @param childTries mapping from first-word id to the trie nodes found for it
+ * @return mapping from first-word id to a single trie node
+ */
+ private HashMap<Integer,Trie> buildLookupTable(final Map<Integer, List<Trie>> childTries) {
+ HashMap<Integer,Trie> lookup = new HashMap<>(childTries.size());
+ for (int id : childTries.keySet()) {
+ final List<Trie> tries = childTries.get(id);
+ if (tries.size() == 1) {
+ lookup.put(id, tries.get(0));
+ } else {
+ lookup.put(id, new SliceAggregatingTrie(tries));
+ }
+ }
+ return lookup;
+ }
+
+ /**
+ * @return the trie node registered for {@code word_id}, or null if there is none
+ */
+ @Override
+ public Trie match(int word_id) {
+ return lookup.get(word_id);
+ }
+
+ /**
+ * @return true if the lookup table has at least one entry
+ */
+ @Override
+ public boolean hasExtensions() {
+ return !lookup.isEmpty();
+ }
+
+ /**
+ * @return the lookup table itself, not a copy
+ */
+ @Override
+ public HashMap<Integer, ? extends Trie> getChildren() {
+ return lookup;
+ }
+
+ /**
+ * @return a new list holding the values of the lookup table
+ */
+ @Override
+ public ArrayList<? extends Trie> getExtensions() {
+ return new ArrayList<>(lookup.values());
+ }
+
+ /**
+ * @return always false; the root node carries no rules of its own
+ */
+ @Override
+ public boolean hasRules() {
+ return false;
+ }
+
+ /**
+ * @return a new {@link BasicRuleCollection} constructed with arity 0 and an empty source side
+ */
+ @Override
+ public RuleCollection getRuleCollection() {
+ return new BasicRuleCollection(0, new int[0]);
+ }
+
+ /**
+ * @return an {@link ExtensionIterator} over the lookup table, created with the flag set to true
+ */
+ @Override
+ public Iterator<Integer> getTerminalExtensionIterator() {
+ return new ExtensionIterator(lookup, true);
+ }
+
+ /**
+ * @return an {@link ExtensionIterator} over the lookup table, created with the flag set to false
+ */
+ @Override
+ public Iterator<Integer> getNonterminalExtensionIterator() {
+ return new ExtensionIterator(lookup, false);
+ }
+ }
+
+ public final class PackedSlice {
+ private final String name;
+
+ private final int[] source;
+ private final IntBuffer target;
+ private final ByteBuffer features;
+ private final ByteBuffer alignments;
+
+ private final int[] targetLookup;
+ private int featureSize;
+ private float[] estimated;
+ private float[] precomputable;
+
+ private final static int BUFFER_HEADER_POSITION = 8;
+
+ /**
+ * Provides a cache of packedTrie nodes to be used in getTrie.
+ */
+ private HashMap<Integer, PackedTrie> tries;
+
+ /**
+ * Opens the files of one slice.
+ * <p>
+ * The ".source" and ".target.lookup" files are copied into int arrays; ".target" and
+ * ".features" are memory-mapped. The ".alignments" file is optional: when it is absent the
+ * alignment buffer is left null.
+ *
+ * @param prefix path prefix shared by the slice's files; also used as the slice name
+ * @throws IOException if one of the files cannot be opened or mapped
+ */
+ public PackedSlice(String prefix) throws IOException {
+ name = prefix;
+
+ File source_file = new File(prefix + ".source");
+ File target_file = new File(prefix + ".target");
+ File target_lookup_file = new File(prefix + ".target.lookup");
+ File feature_file = new File(prefix + ".features");
+ File alignment_file = new File(prefix + ".alignments");
+
+ source = fullyLoadFileToArray(source_file);
+ // First int specifies the size of this file, load from 1st int on
+ targetLookup = fullyLoadFileToArray(target_lookup_file, 1);
+
+ target = associateMemoryMappedFile(target_file).asIntBuffer();
+ features = associateMemoryMappedFile(feature_file);
+ initializeFeatureStructures();
+
+ if (alignment_file.exists()) {
+ alignments = associateMemoryMappedFile(alignment_file);
+ } else {
+ alignments = null;
+ }
+
+ tries = new HashMap<Integer, PackedTrie>();
+ }
+
+ /**
+ * Helper function to help create all the structures which describe features
+ * in the Slice. Only called during object construction.
+ * <p>
+ * Reads the block count from byte offset 0 of the feature buffer and sizes the
+ * {@code estimated} and {@code precomputable} arrays accordingly, filling both with negative
+ * infinity. Then reads {@code featureSize} from byte offset 4.
+ */
+ private void initializeFeatureStructures() {
+ int num_blocks = features.getInt(0);
+ estimated = new float[num_blocks];
+ precomputable = new float[num_blocks];
+ Arrays.fill(estimated, Float.NEGATIVE_INFINITY);
+ Arrays.fill(precomputable, Float.NEGATIVE_INFINITY);
+ featureSize = features.getInt(4);
+ }
+
+ /**
+ * Reads the {@code position}-th int that follows the 8-byte header of a buffer, using an
+ * absolute read that leaves the buffer's own position untouched.
+ *
+ * @param position zero-based index of the int after the header
+ * @param buffer the buffer to read from
+ * @return the int stored at byte offset {@code 8 + 4 * position}
+ */
+ private int getIntFromByteBuffer(int position, ByteBuffer buffer) {
+ return buffer.getInt(BUFFER_HEADER_POSITION + (4 * position));
+ }
+
+ /**
+ * Loads a whole file into an int array, starting at the first int.
+ *
+ * @param file the file to read
+ * @return the file's contents as ints
+ * @throws IOException if the file cannot be opened or mapped
+ */
+ private int[] fullyLoadFileToArray(File file) throws IOException {
+ return fullyLoadFileToArray(file, 0);
+ }
+
+ /**
+ * This function will use a bulk loading method to fully populate a target
+ * array from file.
+ *
+ * @param file
+ * File that will be read from disk.
+ * @param startIndex
+ * number of leading ints (4 bytes each) to skip.
+ * @return an int array holding every int of the file from {@code startIndex} on; its length is
+ * (file length in bytes - 4 * startIndex) / 4.
+ * @throws IOException
+ * if the file cannot be opened or mapped.
+ */
+ private int[] fullyLoadFileToArray(File file, int startIndex) throws IOException {
+ IntBuffer buffer = associateMemoryMappedFile(file).asIntBuffer();
+ // the cast binds to the parenthesised byte count, which is then divided by 4
+ int size = (int) (file.length() - (4 * startIndex))/4;
+ int[] result = new int[size];
+ buffer.position(startIndex);
+ buffer.get(result, 0, size);
+ return result;
+ }
+
+ /**
+ * Maps an entire file into memory, read-only.
+ * <p>
+ * The stream (and with it the channel) is closed before returning; the returned buffer is
+ * still handed to callers, relying on the mapping outliving the channel that created it. The
+ * file size is narrowed to an int before mapping.
+ *
+ * @param file the file to map
+ * @return a read-only buffer over the file's contents
+ * @throws IOException if the file cannot be opened or mapped
+ */
+ private ByteBuffer associateMemoryMappedFile(File file) throws IOException {
+ try(FileInputStream fileInputStream = new FileInputStream(file)) {
+ FileChannel fileChannel = fileInputStream.getChannel();
+ int size = (int) fileChannel.size();
+ MappedByteBuffer result = fileChannel.map(MapMode.READ_ONLY, 0, size);
+ return result;
+ }
+ }
+
+ /**
+ * Reconstructs a target side from the target buffer.
+ * <p>
+ * The array length is found by scanning {@code targetLookup} for the first entry greater than
+ * {@code pointer}. The symbols are then collected by following parent pointers: at each step
+ * the int at {@code pointer} is the parent address and, unless that parent is -1, the int at
+ * {@code pointer + 1} is appended. The walk stops when the parent is -1.
+ *
+ * @param pointer address of the starting entry in the target buffer
+ * @return the collected symbols; trailing slots stay 0 if fewer symbols are found than the
+ * computed length
+ */
+ private final int[] getTarget(int pointer) {
+ // Figure out level.
+ int tgt_length = 1;
+ // NOTE(review): the guard admits tgt_length == targetLookup.length, at which point the array
+ // access on the right-hand side would be out of bounds - confirm the lookup table always
+ // ends with an entry greater than any valid pointer.
+ while (tgt_length < (targetLookup.length + 1) && targetLookup[tgt_length] <= pointer)
+ tgt_length++;
+ int[] tgt = new int[tgt_length];
+ int index = 0;
+ int parent;
+ do {
+ parent = target.get(pointer);
+ if (parent != -1)
+ tgt[index++] = target.get(pointer + 1);
+ pointer = parent;
+ } while (pointer != -1);
+ return tgt;
+ }
+
+ /**
+ * Returns the cached trie node for an address, creating and caching one with an empty source
+ * side and arity 0 if none exists yet.
+ *
+ * @param node_address position of the node within the source array
+ * @return the trie node for that address
+ */
+ private synchronized PackedTrie getTrie(final int node_address) {
+ PackedTrie t = tries.get(node_address);
+ if (t == null) {
+ t = new PackedTrie(node_address);
+ tries.put(node_address, t);
+ }
+ return t;
+ }
+
+ /**
+ * Returns the cached trie node for an address, creating and caching one from its parent's
+ * source side, the parent's arity and the connecting symbol if none exists yet. The cache is
+ * keyed by address only, so a cached node is returned as-is whatever the other arguments are.
+ *
+ * @param node_address position of the node within the source array
+ * @param parent_src source side leading to the parent node
+ * @param parent_arity arity accumulated up to the parent node
+ * @param symbol the symbol on the arc from the parent to this node
+ * @return the trie node for that address
+ */
+ private synchronized PackedTrie getTrie(int node_address, int[] parent_src, int parent_arity,
+ int symbol) {
+ PackedTrie t = tries.get(node_address);
+ if (t == null) {
+ t = new PackedTrie(node_address, parent_src, parent_arity, symbol);
+ tries.put(node_address, t);
+ }
+ return t;
+ }
+
+ /**
+ * Returns the FeatureVector associated with a rule (represented as a block ID).
+ * <p>
+ * The block's byte position is looked up in the feature buffer; the block starts with the
+ * number of features, followed by one (id, value) entry per feature. Each feature's name is
+ * the vocabulary word of its outer id. A name that parses as an integer is treated as a dense
+ * feature index and the value is added with its sign flipped; any other name is added as a
+ * named feature with the value unchanged.
+ *
+ * @param block_id the block ID of the rule
+ * @return feature vector
+ */
+
+ private final FeatureVector loadFeatureVector(int block_id) {
+ int featurePosition = getIntFromByteBuffer(block_id, features);
+ final int numFeatures = encoding.readId(features, featurePosition);
+
+ featurePosition += EncoderConfiguration.ID_SIZE;
+ final FeatureVector featureVector = new FeatureVector();
+ FloatEncoder encoder;
+ String featureName;
+
+ for (int i = 0; i < numFeatures; i++) {
+ final int innerId = encoding.readId(features, featurePosition);
+ final int outerId = encoding.outerId(innerId);
+ encoder = encoding.encoder(innerId);
+ // TODO (fhieber): why on earth are dense feature ids (ints) encoded in the vocabulary?
+ featureName = Vocabulary.word(outerId);
+ final float value = encoder.read(features, featurePosition);
+ try {
+ int index = Integer.parseInt(featureName);
+ featureVector.increment(index, -value);
+ } catch (NumberFormatException e) {
+ // not a number: the feature is identified by name rather than by dense index
+ featureVector.increment(featureName, value);
+ }
+ featurePosition += EncoderConfiguration.ID_SIZE + encoder.size();
+ }
+
+ return featureVector;
+ }
+
+ /**
+ * We need to synchronize this method as there is a many to one ratio between
+ * PackedRule/PhrasePair and this class (PackedSlice). This means during concurrent first
+ * getAlignments calls to PackedRule objects they could alter each other's positions within the
+ * buffer before calling read on the buffer.
+ * <p>
+ * The block's byte position is looked up in the alignment buffer; the byte stored there is the
+ * number of alignment points, and the following two bytes per point are copied out.
+ *
+ * @param block_id the block ID of the rule
+ * @return an array of {@code 2 * num_points} bytes
+ * @throws RuntimeException if this slice has no alignment file
+ * @throws BufferUnderflowException if the buffer holds fewer bytes than requested; the buffer
+ * state is logged before rethrowing
+ */
+ private synchronized final byte[] getAlignmentArray(int block_id) {
+ if (alignments == null)
+ throw new RuntimeException("No alignments available.");
+ int alignment_position = getIntFromByteBuffer(block_id, alignments);
+ // NOTE(review): the count is read as a signed byte, so a stored value above 127 would become
+ // negative - confirm the packer never writes more points than that.
+ int num_points = (int) alignments.get(alignment_position);
+ byte[] alignment = new byte[num_points * 2];
+
+ alignments.position(alignment_position + 1);
+ try {
+ alignments.get(alignment, 0, num_points * 2);
+ } catch (BufferUnderflowException bue) {
+ LOG.warn("Had an exception when accessing alignment mapped byte buffer");
+ LOG.warn("Attempting to access alignments at position: {}", alignment_position + 1);
+ LOG.warn("And to read this many bytes: {}", num_points * 2);
+ LOG.warn("Buffer capacity is : {}", alignments.capacity());
+ LOG.warn("Buffer position is : {}", alignments.position());
+ LOG.warn("Buffer limit is : {}", alignments.limit());
+ throw bue;
+ }
+ return alignment;
+ }
+
+ /**
+ * @return the trie node at address 0 of this slice
+ */
+ private final PackedTrie root() {
+ return getTrie(0);
+ }
+
+ /**
+ * @return the slice name, i.e. the path prefix given to the constructor
+ */
+ public String toString() {
+ return name;
+ }
+
+ /**
+ * A trie node within the grammar slice. Identified by its position within the source array,
+ * and, as a supplement, the source string leading from the trie root to the node.
+ *
+ * @author jg
+ *
+ */
+ public class PackedTrie implements Trie, RuleCollection {
+
+ private final int position;
+
+ private boolean sorted = false;
+
+ private int[] src;
+ private int arity;
+
+ private PackedTrie(int position) { // node with an empty source side and arity 0
+ this.position = position;
+ src = new int[0];
+ arity = 0;
+ }
+
+ private PackedTrie(int position, int[] parent_src, int parent_arity, int symbol) { // node whose source side is the parent's plus one symbol
+ this.position = position;
+ src = new int[parent_src.length + 1]; // room for the parent's symbols and the new one
+ System.arraycopy(parent_src, 0, src, 0, parent_src.length);
+ src[src.length - 1] = symbol;
+ arity = parent_arity;
- if (Vocabulary.nt(symbol))
++ if (FormatUtils.isNonterminal(symbol))
+ arity++; // a nonterminal symbol raises the arity by one
+ }
+
+ @Override
+ public final Trie match(int token_id) { // returns the child trie reached via token_id, or null if there is none
+ int num_children = source[position]; // node layout: child count, then (symbol, address) pairs
+ if (num_children == 0)
+ return null;
+ if (num_children == 1 && token_id == source[position + 1])
+ return getTrie(source[position + 2], src, arity, token_id);
+ int top = 0;
+ int bottom = num_children - 1;
+ while (true) { // binary search over the child symbols
+ int candidate = (top + bottom) / 2;
+ int candidate_position = position + 1 + 2 * candidate;
+ int read_token = source[candidate_position];
+ if (read_token == token_id) {
+ return getTrie(source[candidate_position + 1], src, arity, token_id);
+ } else if (top == bottom) {
+ return null;
+ } else if (read_token > token_id) {
+ top = candidate + 1; // a larger symbol means the target lies at a higher index, i.e. symbols are assumed to be in descending order
+ } else {
+ bottom = candidate - 1;
+ }
+ if (bottom < top)
+ return null;
+ }
+ }
+
+ @Override
+ public HashMap<Integer, ? extends Trie> getChildren() { // maps each child symbol to the trie it leads to
+ HashMap<Integer, Trie> children = new HashMap<Integer, Trie>();
+ int num_children = source[position];
+ for (int i = 0; i < num_children; i++) {
+ int symbol = source[position + 1 + 2 * i]; // children are stored as (symbol, address) pairs after the count
+ int address = source[position + 2 + 2 * i];
+ children.put(symbol, getTrie(address, src, arity, symbol));
+ }
+ return children;
+ }
+
+ @Override
+ public boolean hasExtensions() { // true when the node's child count is nonzero
+ return (source[position] != 0);
+ }
+
+ @Override
+ public ArrayList<? extends Trie> getExtensions() { // lists the child tries in their stored order
+ int num_children = source[position];
+ ArrayList<PackedTrie> tries = new ArrayList<PackedTrie>(num_children);
+
+ for (int i = 0; i < num_children; i++) {
+ int symbol = source[position + 1 + 2 * i]; // same (symbol, address) pair layout that getChildren() reads
+ int address = source[position + 2 + 2 * i];
+ tries.add(getTrie(address, src, arity, symbol));
+ }
+
+ return tries;
+ }
+
+ @Override
+ public boolean hasRules() { // true when the rule count stored right after the child pairs is nonzero
+ int num_children = source[position];
+ return (source[position + 1 + 2 * num_children] != 0);
+ }
+
+ @Override
+ public RuleCollection getRuleCollection() { // the trie node is its own rule collection
+ return this;
+ }
+
+ @Override
+ public List<Rule> getRules() { // returns the cached rule list, building and caching it on a miss
+ List<Rule> rules = cached_rules.getIfPresent(this);
+ if (rules != null) {
+ return rules;
+ }
+
+ int num_children = source[position];
+ int rule_position = position + 2 * (num_children + 1); // skips the child count, the child pairs and the rule count
+ int num_rules = source[rule_position - 1]; // the rule count sits just before the first rule
+
+ rules = new ArrayList<Rule>(num_rules);
+ for (int i = 0; i < num_rules; i++) {
- if (type.equals("moses") || type.equals("phrase"))
- rules.add(new PackedPhrasePair(rule_position + 3 * i));
- else
- rules.add(new PackedRule(rule_position + 3 * i));
++ rules.add(new PackedRule(rule_position + 3 * i)); // each rule occupies three consecutive ints
+ }
+
+ cached_rules.put(this, rules);
+ return rules;
+ }
+
+ /**
+ * Reports whether the rules at this node have been sorted, as recorded by the sorted flag
+ * that sortRules() sets once it has finished (or has found no rules to sort).
+ */
+ @Override
+ public boolean isSorted() {
+ return sorted;
+ }
+
+ private synchronized void sortRules(List<FeatureFunction> models) { // scores every rule at this node, then reorders the rule triples in place by estimated cost
+ int num_children = source[position];
+ int rule_position = position + 2 * (num_children + 1);
+ int num_rules = source[rule_position - 1];
+ if (num_rules == 0) {
+ this.sorted = true;
+ return;
+ }
+ Integer[] rules = new Integer[num_rules]; // per rule: index of its third int, which holds the block id
+
+ int target_address;
+ int block_id;
+ for (int i = 0; i < num_rules; ++i) {
+ target_address = source[rule_position + 1 + 3 * i];
+ rules[i] = rule_position + 2 + 3 * i;
+ block_id = source[rules[i]];
+
+ Rule rule = new Rule(source[rule_position + 3 * i], src, // temporary Rule, used only to compute the two costs below
+ getTarget(target_address), loadFeatureVector(block_id), arity, owner);
+ estimated[block_id] = rule.estimateRuleCost(models);
+ precomputable[block_id] = rule.getPrecomputableCost();
+ }
+
+ Arrays.sort(rules, new Comparator<Integer>() {
+ public int compare(Integer a, Integer b) { // the rule with the higher estimated cost sorts first
+ float a_cost = estimated[source[a]];
+ float b_cost = estimated[source[b]];
+ if (a_cost == b_cost)
+ return 0;
+ return (a_cost > b_cost ? -1 : 1);
+ }
+ });
+
+ int[] sorted = new int[3 * num_rules]; // local array that shadows the boolean field, hence this.sorted below
+ int j = 0;
+ for (int i = 0; i < rules.length; i++) {
+ int address = rules[i];
+ sorted[j++] = source[address - 2];
+ sorted[j++] = source[address - 1];
+ sorted[j++] = source[address];
+ }
+ for (int i = 0; i < sorted.length; i++)
+ source[rule_position + i] = sorted[i]; // write the reordered triples back over the originals
+
+ // Replace rules in cache with their sorted values on next getRules()
+ cached_rules.invalidate(this);
+ this.sorted = true;
+ }
+
+ @Override
+ public List<Rule> getSortedRules(List<FeatureFunction> featureFunctions) { // sorts the rules if the flag is unset, then returns getRules()
+ if (!isSorted()) // NOTE(review): unsynchronized read of the flag, and sortRules() does not re-check it after taking the lock -- confirm concurrent callers are safe
+ sortRules(featureFunctions);
+ return getRules();
+ }
+
+ @Override
+ public int[] getSourceSide() { // returns the internal src array itself, not a copy
+ return src;
+ }
+
+ @Override
+ public int getArity() { // arity as counted by the constructors
+ return arity;
+ }
+
+ @Override
+ public Iterator<Integer> getTerminalExtensionIterator() { // iterator created in terminal mode
+ return new PackedChildIterator(position, true);
+ }
+
+ @Override
+ public Iterator<Integer> getNonterminalExtensionIterator() { // iterator created in nonterminal mode
+ return new PackedChildIterator(position, false);
+ }
+
+ public final class PackedChildIterator implements Iterator<Integer> { // walks a node's child symbols: forward from the first child in terminal mode, backward from the last child otherwise
+
+ private int current;
+ private boolean terminal;
+ private boolean done;
+ private int last;
+
+ PackedChildIterator(int position, boolean terminal) {
+ this.terminal = terminal;
+ int num_children = source[position];
+ done = (num_children == 0);
+ if (!done) {
+ current = (terminal ? position + 1 : position - 1 + 2 * num_children); // first child symbol slot vs. last child symbol slot
+ last = (terminal ? position - 1 + 2 * num_children : position + 1);
+ }
+ }
+
+ @Override
+ public boolean hasNext() { // NOTE(review): tests the slot after current rather than current itself -- confirm this matches how callers pair hasNext() with next()
+ if (done)
+ return false;
+ int next = (terminal ? current + 2 : current - 2); // symbol slots are two ints apart
+ if (next == last)
+ return false;
+ return (terminal ? source[next] > 0 : source[next] < 0); // terminal mode accepts positive symbols, nonterminal mode negative ones
+ }
+
+ @Override
+ public Integer next() { // returns the symbol at current, then advances or marks the iterator done
+ if (done)
+ throw new RuntimeException("No more symbols!");
+ int symbol = source[current];
+ if (current == last)
+ done = true;
+ if (!done) {
+ current = (terminal ? current + 2 : current - 2);
+ done = (terminal ? source[current] < 0 : source[current] > 0); // stop once the symbol sign no longer matches the mode
+ }
+ return symbol;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ /**
+ * A packed phrase pair represents a rule of the form of a phrase pair, packed with the
+ * grammar-packer.pl script, which simply adds a nonterminal [X] to the left-hand side of
+ * all phrase pairs (and converts the Moses features). The packer then packs these. We have
+ * to then put a nonterminal on the source and target sides to treat the phrase pairs like
+ * left-branching rules, which is how Joshua deals with phrase decoding.
+ *
+ * @author Matt Post post@cs.jhu.edu
+ *
+ */
+ public final class PackedPhrasePair extends PackedRule { // NOTE(review): the merged getRules() in this chunk constructs only PackedRule -- confirm this subclass is still instantiated elsewhere
+
+ private final Supplier<int[]> englishSupplier;
+ private final Supplier<byte[]> alignmentSupplier;
+
+ public PackedPhrasePair(int address) {
+ super(address);
+ englishSupplier = initializeEnglishSupplier();
+ alignmentSupplier = initializeAlignmentSupplier();
+ }
+
+ @Override
+ public int getArity() { // one more than the trie's arity, for the prepended nonterminal
+ return PackedTrie.this.getArity() + 1;
+ }
+
+ /**
+ * Initialize a number of suppliers which get evaluated when their respective getters
+ * are called.
+ * The inner lambda functions are guaranteed to be called only once; because of this, the
+ * underlying structures are accessed in a thread-safe way.
+ * Guava's implementation makes sure only one read of a volatile variable occurs per get.
+ * This means this implementation should be as thread-safe and performant as possible.
+ */
+
+ private Supplier<int[]> initializeEnglishSupplier(){
+ Supplier<int[]> result = Suppliers.memoize(() ->{
+ int[] phrase = getTarget(source[address + 1]);
+ int[] tgt = new int[phrase.length + 1];
+ tgt[0] = -1; // the prepended nonterminal slot
+ for (int i = 0; i < phrase.length; i++)
+ tgt[i+1] = phrase[i];
+ return tgt;
+ });
+ return result;
+ }
+
+ private Supplier<byte[]> initializeAlignmentSupplier(){
+ Supplier<byte[]> result = Suppliers.memoize(() ->{
+ byte[] raw_alignment = getAlignmentArray(source[address + 2]);
+ byte[] points = new byte[raw_alignment.length + 2];
+ points[0] = points[1] = 0; // extra leading point (0, 0) for the prepended nonterminal
+ for (int i = 0; i < raw_alignment.length; i++)
+ points[i + 2] = (byte) (raw_alignment[i] + 1); // every stored index shifts up by one
+ return points;
+ });
+ return result;
+ }
+
+ /**
+ * Take the English phrase of the underlying rule and prepend an [X].
+ *
+ * @return the augmented phrase
+ */
+ @Override
+ public int[] getEnglish() {
+ return this.englishSupplier.get();
+ }
+
+ /**
+ * Take the French phrase of the underlying rule and prepend an [X].
+ *
+ * @return the augmented French phrase
+ */
+ @Override
+ public int[] getFrench() { // builds a fresh array on every call (not memoized)
+ int phrase[] = new int[src.length + 1];
+ int ntid = Vocabulary.id(PackedGrammar.this.joshuaConfiguration.default_non_terminal);
+ phrase[0] = ntid;
+ System.arraycopy(src, 0, phrase, 1, src.length);
+ return phrase;
+ }
+
+ /**
+ * Similarly the alignment array needs to be shifted over by one.
+ *
+ * @return the byte[] alignment
+ */
+ @Override
+ public byte[] getAlignment() {
+ // if no alignments in grammar do not fail
+ if (alignments == null) {
+ return null;
+ }
+
+ return this.alignmentSupplier.get();
+ }
+ }
+
+ public class PackedRule extends Rule { // a Rule view over three consecutive ints in source, starting at address; all setters below are no-ops
+ protected final int address;
+ private final Supplier<int[]> englishSupplier;
+ private final Supplier<FeatureVector> featureVectorSupplier;
+ private final Supplier<byte[]> alignmentsSupplier;
+
+ public PackedRule(int address) {
+ this.address = address;
+ this.englishSupplier = intializeEnglishSupplier();
+ this.featureVectorSupplier = initializeFeatureVectorSupplier();
+ this.alignmentsSupplier = initializeAlignmentsSupplier();
+ }
+
+ private Supplier<int[]> intializeEnglishSupplier(){
+ Supplier<int[]> result = Suppliers.memoize(() ->{
+ return getTarget(source[address + 1]); // the second int of the rule is passed to getTarget()
+ });
+ return result;
+ }
+
+ private Supplier<FeatureVector> initializeFeatureVectorSupplier(){
+ Supplier<FeatureVector> result = Suppliers.memoize(() ->{
+ return loadFeatureVector(source[address + 2]); // the third int of the rule selects the feature block
+ });
+ return result;
+ }
+
+ private Supplier<byte[]> initializeAlignmentsSupplier(){
+ Supplier<byte[]> result = Suppliers.memoize(()->{
+ // if no alignments in grammar do not fail
+ if (alignments == null){
+ return null;
+ }
+ return getAlignmentArray(source[address + 2]); // same third int that selects the feature block
+ });
+ return result;
+ }
+
+ @Override
+ public void setArity(int arity) {
+ }
+
+ @Override
+ public int getArity() {
+ return PackedTrie.this.getArity();
+ }
+
+ @Override
+ public void setOwner(int ow) {
+ }
+
+ @Override
+ public int getOwner() {
+ return owner;
+ }
+
+ @Override
+ public void setLHS(int lhs) {
+ }
+
+ @Override
+ public int getLHS() {
+ return source[address]; // the first int of the rule is its left-hand side
+ }
+
+ @Override
+ public void setEnglish(int[] eng) {
+ }
+
+ @Override
+ public int[] getEnglish() {
+ return this.englishSupplier.get();
+ }
+
+ @Override
+ public void setFrench(int[] french) {
+ }
+
+ @Override
+ public int[] getFrench() {
+ return src; // the enclosing trie's source side, returned without copying
+ }
+
+ @Override
+ public FeatureVector getFeatureVector() {
+ return this.featureVectorSupplier.get();
+ }
+
+ @Override
+ public byte[] getAlignment() {
+ return this.alignmentsSupplier.get();
+ }
+
+ @Override
+ public String getAlignmentString() {
+ throw new RuntimeException("AlignmentString not implemented for PackedRule!");
+ }
+
+ @Override
+ public float getEstimatedCost() {
+ return estimated[source[address + 2]]; // value written by sortRules() for this block id
+ }
+
+// @Override
+// public void setPrecomputableCost(float cost) {
+// precomputable[source[address + 2]] = cost;
+// }
+
+ @Override
+ public float getPrecomputableCost() {
+ return precomputable[source[address + 2]];
+ }
+
+ @Override
+ public float estimateRuleCost(List<FeatureFunction> models) { // ignores models and returns the stored estimate
+ return estimated[source[address + 2]];
+ }
+
+ @Override
+ public String toString() { // format: LHS ||| source words ||| target words ||| features ||| estimated cost
+ StringBuffer sb = new StringBuffer();
+ sb.append(Vocabulary.word(this.getLHS()));
+ sb.append(" ||| ");
+ sb.append(getFrenchWords());
+ sb.append(" ||| ");
+ sb.append(getEnglishWords());
+ sb.append(" |||");
+ sb.append(" " + getFeatureVector());
+ sb.append(String.format(" ||| %.3f", getEstimatedCost()));
+ return sb.toString();
+ }
+ }
+ }
+ }
+
+ @Override
+ public boolean isRegexpGrammar() { // always false for this grammar
+ return false;
+ }
+
+ @Override
+ public void addOOVRules(int word, List<FeatureFunction> featureFunctions) { // unsupported: always throws
+ throw new RuntimeException("PackedGrammar.addOOVRules(): I can't add OOV rules");
+ }
+
+ @Override
+ public void addRule(Rule rule) { // unsupported: always throws
+ throw new RuntimeException("PackedGrammar.addRule(): I can't add rules");
+ }
+
++ /**
++ * Read the config file: picks up max-source-len and version, and rejects any version but 3.
++ *
++ * TODO: this should be rewritten using typeconfig.
++ *
++ * @param config the config file handed to LineReader
++ * @throws IOException presumably raised by LineReader when the file cannot be read -- confirm
++ */
+ private void readConfig(String config) throws IOException {
++ int version = 0; // stays 0 when the file has no version line, which then fails the check below
++
+ for (String line: new LineReader(config)) {
+ String[] tokens = line.split(" = "); // lines are expected in "key = value" form
+ if (tokens[0].equals("max-source-len"))
+ this.maxSourcePhraseLength = Integer.parseInt(tokens[1]);
++ else if (tokens[0].equals("version")) {
++ version = Integer.parseInt(tokens[1]);
++ }
++ }
++
++ if (version != 3) { // NOTE(review): literal 3 here, while the message reports SUPPORTED_VERSION as the earliest supported version -- confirm the two agree
++ String message = String.format("The grammar at %s was packed with packer version %d, but the earliest supported version is %d",
++ this.grammarDir, version, SUPPORTED_VERSION);
++ throw new RuntimeException(message);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/abb8c518/src/main/java/org/apache/joshua/decoder/hypergraph/GrammarBuilderWalkerFunction.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/decoder/hypergraph/GrammarBuilderWalkerFunction.java
index 7908d28,0000000..a6edddd
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/decoder/hypergraph/GrammarBuilderWalkerFunction.java
+++ b/src/main/java/org/apache/joshua/decoder/hypergraph/GrammarBuilderWalkerFunction.java
@@@ -1,180 -1,0 +1,180 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.hypergraph;
+
+import java.io.PrintStream;
+import java.util.HashSet;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.ff.tm.Grammar;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.format.HieroFormatReader;
+import org.apache.joshua.decoder.ff.tm.hash_based.MemoryBasedBatchGrammar;
++import org.apache.joshua.util.FormatUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This walker function builds up a new context-free grammar by visiting each node in a hypergraph.
+ * For a quick overview, see Chris Dyer's 2010 NAACL paper
+ * "Two monolingual parses are better than one (synchronous parse)".
+ * <p>
+ * From a functional-programming point of view, this walker really wants to calculate a fold over
+ * the entire hypergraph: the initial value is an empty grammar, and as we visit each node, we add
+ * more rules to the grammar. After we have traversed the whole hypergraph, the resulting grammar
+ * will contain all rules needed for synchronous parsing.
+ * <p>
+ * These rules look just like the rules already present in the hypergraph, except that each
+ * non-terminal symbol is annotated with the span of its node.
+ */
+public class GrammarBuilderWalkerFunction implements WalkerFunction {
+
+ private static final Logger LOG = LoggerFactory.getLogger(GrammarBuilderWalkerFunction.class);
+
+ private MemoryBasedBatchGrammar grammar;
+ private static HieroFormatReader reader = new HieroFormatReader();
+ private PrintStream outStream;
+ private int goalSymbol;
+ private HashSet<Rule> rules; // rules already added, used by apply() to skip duplicates
+
+ public GrammarBuilderWalkerFunction(String goal,JoshuaConfiguration joshuaConfiguration) {
+ grammar = new MemoryBasedBatchGrammar(reader,joshuaConfiguration);
+ grammar.setSpanLimit(1000);
+ outStream = null;
+ goalSymbol = Vocabulary.id(goal);
+ rules = new HashSet<Rule>();
+ }
+
+ public GrammarBuilderWalkerFunction(String goal, PrintStream out,JoshuaConfiguration joshuaConfiguration) {
+ this(goal,joshuaConfiguration);
+ outStream = out; // when set, apply() also prints each new rule to this stream
+ }
+
+ public void apply(HGNode node, int index) { // adds one span-annotated rule per hyperedge of node, skipping nulls and duplicates
+ // System.err.printf("VISITING NODE: %s\n", getLabelWithSpan(node));
+ for (HyperEdge e : node.hyperedges) {
+ Rule r = getRuleWithSpans(e, node);
+ if (r != null && !rules.contains(r)) {
+ if (outStream != null) outStream.println(r);
+ grammar.addRule(r);
+ rules.add(r);
+ }
+ }
+ }
+
+ private static int getLabelWithSpan(HGNode node) { // vocabulary id of the span-annotated label
+ return Vocabulary.id(getLabelWithSpanAsString(node));
+ }
+
+ private static String getLabelWithSpanAsString(HGNode node) { // formats the label as [i-LABEL-j]
+ String label = Vocabulary.word(node.lhs);
- String cleanLabel = HieroFormatReader.cleanNonTerminal(label);
- String unBracketedCleanLabel = cleanLabel.substring(1, cleanLabel.length() - 1);
++ String unBracketedCleanLabel = label.substring(1, label.length() - 1); // drops the first and last character of the label
+ return String.format("[%d-%s-%d]", node.i, unBracketedCleanLabel, node.j);
+ }
+
+ private boolean nodeHasGoalSymbol(HGNode node) {
+ return node.lhs == goalSymbol;
+ }
+
+ private Rule getRuleWithSpans(HyperEdge edge, HGNode head) { // returns the span-annotated copy of the edge's rule, or null when getNewSource() returns null
+ Rule edgeRule = edge.getRule();
+ int headLabel = getLabelWithSpan(head);
+ // System.err.printf("Head label: %s\n", headLabel);
+ // if (edge.getAntNodes() != null) {
+ // for (HGNode n : edge.getAntNodes())
+ // System.err.printf("> %s\n", getLabelWithSpan(n));
+ // }
+ int[] source = getNewSource(nodeHasGoalSymbol(head), edge);
+ // if this would be unary abstract, getNewSource will be null
+ if (source == null) return null;
+ int[] target = getNewTargetFromSource(source);
+ Rule result =
+ new Rule(headLabel, source, target, edgeRule.getFeatureString(), edgeRule.getArity());
+ // System.err.printf("new rule is %s\n", result);
+ return result;
+ }
+
+ private static int[] getNewSource(boolean isGlue, HyperEdge edge) {
+ Rule rule = edge.getRule();
+ int[] english = rule.getEnglish(); // NOTE(review): despite the method name, this reads the rule's English side -- confirm intended
+ // if this is a unary abstract rule, just return null
+ // TODO: except glue rules!
+ if (english.length == 1 && english[0] < 0 && !isGlue) return null;
+ int[] result = new int[english.length];
+ for (int i = 0; i < english.length; i++) {
+ int curr = english[i];
- if (!Vocabulary.nt(curr)) {
- // If it's a terminal symbol, we just copy it into the new rule.
++ if (! FormatUtils.isNonterminal(curr)) {
++ // If it's a terminal symbol, we just copy it into the new rule.
+ result[i] = curr;
+ } else {
- // If it's a nonterminal, its value is -N, where N is the index
- // of the nonterminal on the source side.
- //
- // That is, if we would call a nonterminal "[X,2]", the value of
- // curr at this point is -2. And the tail node that it points at
- // is #1 (since getTailNodes() is 0-indexed).
++ // If it's a nonterminal, its value is -N, where N is the index
++ // of the nonterminal on the source side.
++ //
++ // That is, if we would call a nonterminal "[X,2]", the value of
++ // curr at this point is -2. And the tail node that it points at
++ // is #1 (since getTailNodes() is 0-indexed).
+ int index = -curr - 1;
+ result[i] = getLabelWithSpan(edge.getTailNodes().get(index)); // replace the slot with the tail node's span-annotated label
+ }
+ }
+ // System.err.printf("source: %s\n", result);
+ return result;
+ }
+
+ private static int[] getNewTargetFromSource(int[] source) { // copies source, replacing nonterminals with -1, -2, ... in order of appearance
+ int[] result = new int[source.length];
- int currNT = -1; // value to stick into NT slots
++ int currNT = -1; // value to stick into NT slots
+ for (int i = 0; i < source.length; i++) {
+ result[i] = source[i];
- if (Vocabulary.nt(result[i])) {
++ if (FormatUtils.isNonterminal(result[i])) {
+ result[i] = currNT;
- currNT--;
++ currNT--;
+ }
+ }
+ // System.err.printf("target: %s\n", result);
+ return result;
+ }
+
+ private static HGNode getGoalSymbolNode(HGNode root) { // first tail node of the root's first hyperedge, or null (after logging) when the root has no hyperedges
+ if (root.hyperedges == null || root.hyperedges.size() == 0) {
+ LOG.error("getGoalSymbolNode: root node has no hyperedges");
+ return null;
+ }
+ return root.hyperedges.get(0).getTailNodes().get(0);
+ }
+
+
+ public static int goalSymbol(HyperGraph hg) { // span-annotated label id of the goal symbol node, or -1 when it cannot be found
+ if (hg.goalNode == null) {
+ LOG.error("goalSymbol: goalNode of hypergraph is null");
+ return -1;
+ }
+ HGNode symbolNode = getGoalSymbolNode(hg.goalNode);
+ if (symbolNode == null) return -1;
+ // System.err.printf("goalSymbol: %s\n", result);
+ // System.err.printf("symbol node LHS is %d\n", symbolNode.lhs);
+ // System.err.printf("i = %d, j = %d\n", symbolNode.i, symbolNode.j);
+ return getLabelWithSpan(symbolNode);
+ }
+
+ public Grammar getGrammar() {
+ return grammar;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/abb8c518/src/main/java/org/apache/joshua/decoder/hypergraph/HyperEdge.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/decoder/hypergraph/HyperEdge.java
index d7bcc4d,0000000..b188650
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/decoder/hypergraph/HyperEdge.java
+++ b/src/main/java/org/apache/joshua/decoder/hypergraph/HyperEdge.java
@@@ -1,108 -1,0 +1,101 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.hypergraph;
+
+import java.util.List;
+
+import org.apache.joshua.decoder.chart_parser.SourcePath;
+import org.apache.joshua.decoder.ff.tm.Rule;
+
+/**
+ * This class implements a hyperedge.
+ *
+ * @author Zhifei Li, zhifei.work@gmail.com
+ * @author Matt Post post@cs.jhu.edu
+ */
+
+public class HyperEdge {
+
+ /**
+ * the 1-best logP of all possible derivations: best logP of ant hgnodes + transitionlogP
+ **/
+ private float bestDerivationScore = Float.NEGATIVE_INFINITY;
+
+ /**
+ * this remembers the stateless + non_stateless logP associated with the rule (excluding the
+ * best-logP from ant nodes)
+ * */
- private Float transitionScore = null;
++ private float transitionScore;
+
+ private Rule rule;
+
+ private SourcePath srcPath = null;
+
+ /**
+ * If antNodes is null, then this edge corresponds to a rule with zero arity. Also, the nodes
+ * appear in the list as per the index of the Foreign side non-terminal
+ * */
+ private List<HGNode> tailNodes = null;
+
+ public HyperEdge(Rule rule, float bestDerivationScore, float transitionScore,
+ List<HGNode> tailNodes, SourcePath srcPath) {
+ this.bestDerivationScore = bestDerivationScore;
+ this.transitionScore = transitionScore;
+ this.rule = rule;
+ this.tailNodes = tailNodes;
+ this.srcPath = srcPath;
+ }
+
+ public Rule getRule() {
+ return rule;
+ }
+
+ public float getBestDerivationScore() {
+ return bestDerivationScore;
+ }
+
+ public SourcePath getSourcePath() {
+ return srcPath;
+ }
+
+ public List<HGNode> getTailNodes() {
+ return tailNodes;
+ }
+
+ public float getTransitionLogP(boolean forceCompute) { // recomputes and stores the score only when forceCompute is set; otherwise returns the stored value
- StringBuilder sb = new StringBuilder();
- if (forceCompute || transitionScore == null) {
++ if (forceCompute) {
+ float res = bestDerivationScore;
- sb.append(String.format("Best derivation = %.5f", res));
+ if (tailNodes != null) for (HGNode tailNode : tailNodes) {
+ res += tailNode.bestHyperedge.bestDerivationScore; // NOTE(review): adds the tails' best scores, while the field comment says they are excluded -- confirm the sign
- sb.append(String.format(", tail = %.5f", tailNode.bestHyperedge.bestDerivationScore));
+ }
+ transitionScore = res;
+ }
- // System.err.println("HYPEREDGE SCORE = " + sb.toString());
+ return transitionScore;
+ }
+
+ public void setTransitionLogP(float transitionLogP) {
+ this.transitionScore = transitionLogP;
+ }
+
+ public String toString() { // the string form of the rule only
+ StringBuilder sb = new StringBuilder();
+ sb.append(this.rule);
- // if (getTailNodes() != null) for (HGNode tailNode : getTailNodes()) {
- // sb.append(" tail=" + tailNode);
- // }
+ return sb.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/abb8c518/src/main/java/org/apache/joshua/decoder/hypergraph/OutputStringExtractor.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/decoder/hypergraph/OutputStringExtractor.java
index 4366e21,0000000..341edbd
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/decoder/hypergraph/OutputStringExtractor.java
+++ b/src/main/java/org/apache/joshua/decoder/hypergraph/OutputStringExtractor.java
@@@ -1,195 -1,0 +1,196 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.hypergraph;
+
+import static java.lang.Math.min;
+import static org.apache.joshua.corpus.Vocabulary.getWords;
- import static org.apache.joshua.corpus.Vocabulary.nt;
++import static org.apache.joshua.util.FormatUtils.isNonterminal;
+
+import java.util.Stack;
+
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationState;
+import org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationVisitor;
++import org.apache.joshua.util.FormatUtils;
+
+public class OutputStringExtractor implements WalkerFunction, DerivationVisitor { // assembles the source- or target-side word string of a derivation by merging rules on a stack
+
+ public OutputStringExtractor(final boolean extractSource) {
+ this.extractSource = extractSource;
+ }
+
+ private Stack<OutputString> outputStringStack = new Stack<>();
+ private final boolean extractSource; // true: use rule.getFrench(); false: use rule.getEnglish()
+
+ @Override
+ public void apply(HGNode node, int nodeIndex) {
+ apply(node.bestHyperedge.getRule(), nodeIndex);
+ }
+
+ /**
+ * Visiting a node during k-best extraction is the same as
+ * apply() for Viterbi extraction but using the edge from
+ * the Derivation state.
+ */
+ @Override
+ public void before(final DerivationState state, int level, int tailNodeIndex) {
+ apply(state.edge.getRule(), tailNodeIndex);
+ }
+
+ private void apply(Rule rule, int nodeIndex) { // null rules are skipped silently
+ if (rule != null) {
+ final int[] words = extractSource ? rule.getFrench() : rule.getEnglish();
+ merge(new OutputString(words, rule.getArity(), nodeIndex));
+ }
+ }
+
+ /** Nothing to do */
+ @Override
+ public void after(DerivationState state, int level, int tailNodeIndex) {}
+
+ private static int getSourceNonTerminalPosition(final int[] words, int nonTerminalIndex) { // index in words of the nonTerminalIndex-th nonterminal, counting from 1
+ int nonTerminalsSeen = 0;
+ for (int i = 0; i < words.length; i++) {
- if (nt(words[i])) {
++ if (FormatUtils.isNonterminal(words[i])) {
+ nonTerminalsSeen++;
+ if (nonTerminalsSeen == nonTerminalIndex) {
+ return i;
+ }
+ }
+ }
+ throw new RuntimeException(
+ String.format(
+ "Can not find %s-th non terminal in source ids: %s. This should not happen!",
+ nonTerminalIndex,
+ arrayToString(words)));
+ }
+
+ /**
+ * Returns the position of the nonTerminalIndex-th non-terminal in words.
+ * Non-terminals on target sides of rules are indexed by
+ * their order on the source side, e.g. '-1', '-2',
+ * Thus, if index==0 we return the index of '-1'.
+ * For index==1, we return index of '-2'
+ */
+ private static int getTargetNonTerminalPosition(int[] words, int nonTerminalIndex) {
+ for (int pos = 0; pos < words.length; pos++) {
- if (nt(words[pos]) && -(words[pos] + 1) == nonTerminalIndex) {
++ if (FormatUtils.isNonterminal(words[pos]) && -(words[pos] + 1) == nonTerminalIndex) {
+ return pos;
+ }
+ }
+ throw new RuntimeException(
+ String.format(
+ "Can not find %s-th non terminal in target ids: %s. This should not happen!",
+ nonTerminalIndex,
+ arrayToString(words)));
+ }
+
+ private static String arrayToString(int[] ids) { // space-separated ids, used only in the error messages above
+ StringBuilder sb = new StringBuilder();
+ for (int i : ids) {
+ sb.append(i + " ");
+ }
+ return sb.toString().trim();
+ }
+
+ private void substituteNonTerminal(
+ final OutputString parentState,
+ final OutputString childState) {
+ int mergePosition;
+ if (extractSource) {
+ /* correct nonTerminal is given by the tailNodePosition of the childState (zero-index, thus +1) and
+ * current parentState's arity. If the parentState has already filled one of two available slots,
+ * we need to use the remaining one, even if childState refers to the second slot.
+ */
+ mergePosition = getSourceNonTerminalPosition(
+ parentState.words, min(childState.tailNodePosition + 1, parentState.arity));
+ } else {
+ mergePosition = getTargetNonTerminalPosition(
+ parentState.words, childState.tailNodePosition);
+ }
+ parentState.substituteNonTerminalAtPosition(childState.words, mergePosition);
+ }
+
+ private void merge(final OutputString state) { // a state with arity 0 is folded into the parent on top of the stack, recursively; anything else is pushed
+ if (!outputStringStack.isEmpty()
+ && state.arity == 0) {
+ if (outputStringStack.peek().arity == 0) {
+ throw new IllegalStateException("Parent OutputString has arity of 0. Cannot merge.");
+ }
+ final OutputString parent = outputStringStack.pop();
+ substituteNonTerminal(parent, state);
+ merge(parent);
+ } else {
+ outputStringStack.add(state);
+ }
+ }
+
+ @Override
+ public String toString() { // note: pops the single remaining element, so a second call returns ""
+ if (outputStringStack.isEmpty()) {
+ return "";
+ }
+
+ if (outputStringStack.size() != 1) {
+ throw new IllegalStateException(
+ String.format(
+ "Stack should contain only a single (last) element, but was size %d", outputStringStack.size()));
+ }
+ return getWords(outputStringStack.pop().words);
+ }
+
+ /** Stores necessary information to obtain an output string on source or target side */
+ private class OutputString {
+
+ private int[] words;
+ private int arity; // number of nonterminal slots still to be filled; decremented on each substitution
+ private final int tailNodePosition;
+
+ private OutputString(int[] words, int arity, int tailNodePosition) {
+ this.words = words;
+ this.arity = arity;
+ this.tailNodePosition = tailNodePosition;
+ }
+
+ /**
+ * Merges child words into this at the correct
+ * non terminal position of this.
+ * The correct position is determined by the tailNodePosition
+ * of child and the arity of this.
+ */
+ private void substituteNonTerminalAtPosition(final int[] words, final int position) {
- assert(nt(this.words[position]));
++ assert(FormatUtils.isNonterminal(this.words[position]));
+ final int[] result = new int[words.length + this.words.length - 1]; // the child's words replace exactly one slot
+ int resultIndex = 0;
+ for (int i = 0; i < position; i++) {
+ result[resultIndex++] = this.words[i];
+ }
+ for (int i = 0; i < words.length; i++) {
+ result[resultIndex++] = words[i];
+ }
+ for (int i = position + 1; i < this.words.length; i++) {
+ result[resultIndex++] = this.words[i];
+ }
+ // update words and reduce arity of this OutputString
+ this.words = result;
+ arity--;
+ }
+ }
+
- }
++}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/abb8c518/src/main/java/org/apache/joshua/decoder/phrase/PhraseTable.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/decoder/phrase/PhraseTable.java
index 733e1e1,0000000..df3bd99
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/decoder/phrase/PhraseTable.java
+++ b/src/main/java/org/apache/joshua/decoder/phrase/PhraseTable.java
@@@ -1,202 -1,0 +1,197 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.phrase;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.ff.tm.Grammar;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.RuleCollection;
+import org.apache.joshua.decoder.ff.tm.Trie;
+import org.apache.joshua.decoder.ff.tm.hash_based.MemoryBasedBatchGrammar;
+import org.apache.joshua.decoder.ff.tm.packed.PackedGrammar;
+
+/**
+ * Represents a phrase table, and is implemented as a wrapper around either a {@link PackedGrammar}
+ * or a {@link MemoryBasedBatchGrammar}. Apart from phrase lookup and OOV handling, every
+ * {@link Grammar} method simply delegates to the wrapped backend.
+ *
+ * TODO: this should all be implemented as a two-level trie (source trie and target trie).
+ */
+public class PhraseTable implements Grammar {
+
+  private final JoshuaConfiguration config;
+  private final Grammar backend;
+
+  /**
+   * Builds a phrase table over a grammar read from disk. If {@code grammarFile} names a
+   * directory, the backend is a {@link PackedGrammar}; otherwise it is a
+   * {@link MemoryBasedBatchGrammar} using the single nonterminal {@code [X]}. A span limit of 0
+   * is passed to the backend in both cases.
+   *
+   * @param grammarFile path to the grammar; a directory is loaded as a packed grammar
+   * @param owner used to set phrase owners
+   * @param type the grammar specification keyword (e.g., "thrax" or "moses")
+   * @param config a populated {@link org.apache.joshua.decoder.JoshuaConfiguration}
+   * @throws IOException if there is an error reading the grammar file
+   * @throws RuntimeException if a packed grammar reports a maximum source phrase length of -1
+   */
+  public PhraseTable(String grammarFile, String owner, String type, JoshuaConfiguration config)
+      throws IOException {
+    this.config = config;
+    final int spanLimit = 0;
+
+    if (grammarFile != null && new File(grammarFile).isDirectory()) {
+      this.backend = new PackedGrammar(grammarFile, spanLimit, owner, type, config);
+      if (this.backend.getMaxSourcePhraseLength() == -1) {
+        String msg = "FATAL: Using a packed grammar for a phrase table backend requires that you "
+            + "packed the grammar with Joshua 6.0.2 or greater";
+        throw new RuntimeException(msg);
+      }
+    } else {
+      this.backend = new MemoryBasedBatchGrammar(type, grammarFile, owner, "[X]", spanLimit, config);
+    }
+  }
+
+  /**
+   * Builds an empty phrase table backed by a {@link MemoryBasedBatchGrammar} created from the
+   * given owner and configuration.
+   *
+   * @param owner used to set phrase owners
+   * @param config a populated {@link org.apache.joshua.decoder.JoshuaConfiguration}
+   */
+  public PhraseTable(String owner, JoshuaConfiguration config) {
+    this.config = config;
+    this.backend = new MemoryBasedBatchGrammar(owner, config);
+  }
+
+  /**
- * Returns the longest source phrase read. For {@link MemoryBasedBatchGrammar}s, we subtract 1
- * since the grammar includes the nonterminal. For {@link PackedGrammar}s, the value was either
- * in the packed config file (Joshua 6.0.2+) or was passed in via the TM config line.
++   * Returns the longest source phrase read. Because phrases have a dummy nonterminal prepended to
++   * them, we need to subtract 1.
+   *
+   * @return the longest source phrase read.
+   */
+  @Override
+  public int getMaxSourcePhraseLength() {
- if (backend instanceof MemoryBasedBatchGrammar)
- return this.backend.getMaxSourcePhraseLength() - 1;
- else
- return this.backend.getMaxSourcePhraseLength();
++    return this.backend.getMaxSourcePhraseLength() - 1;
+  }
+
+  /**
+   * Collect the set of target-side phrases associated with a source phrase. The lookup first
+   * matches the {@code [X]} nonterminal from the trie root and then follows the source words
+   * one at a time.
+   *
+   * @param sourceWords the sequence of source words
+   * @return the rules, or {@code null} if {@code sourceWords} is empty, the words cannot be
+   *         matched in the trie, or the matched node has no rules
+   */
+  public RuleCollection getPhrases(int[] sourceWords) {
+    if (sourceWords.length == 0) {
+      return null;
+    }
+
+    Trie pointer = getTrieRoot();
- if (! (backend instanceof PackedGrammar))
- pointer = pointer.match(Vocabulary.id("[X]"));
++    pointer = pointer.match(Vocabulary.id("[X]"));
+    for (int i = 0; pointer != null && i < sourceWords.length; i++) {
+      pointer = pointer.match(sourceWords[i]);
+    }
+
+    if (pointer != null && pointer.hasRules()) {
+      return pointer.getRuleCollection();
+    }
+    return null;
+  }
+
+  /**
+   * Adds a rule to the grammar. Only supported when the backend is a MemoryBasedBatchGrammar.
+   *
+   * @param rule the rule to add
+   * @throws UnsupportedOperationException if the backend is not a
+   *         {@link MemoryBasedBatchGrammar}
+   */
+  public void addRule(Rule rule) {
+    // Fail with an explicit message instead of a bare ClassCastException.
+    if (!(backend instanceof MemoryBasedBatchGrammar)) {
+      throw new UnsupportedOperationException(
+          "Rules can only be added to a phrase table backed by a MemoryBasedBatchGrammar, not "
+              + backend.getClass().getName());
+    }
+    ((MemoryBasedBatchGrammar) backend).addRule(rule);
+  }
+
+  /**
+   * Creates a rule for an out-of-vocabulary source word, adds it to this phrase table, and then
+   * estimates its cost with the given feature functions. The target side is the source word
+   * itself, or the source word with {@code _OOV} appended when {@code config.mark_oovs} is set.
+   *
+   * @param sourceWord id of the out-of-vocabulary source word
+   * @param featureFunctions the feature functions used to estimate the rule's cost
+   */
+  @Override
+  public void addOOVRules(int sourceWord, List<FeatureFunction> featureFunctions) {
+    // TODO: _OOV shouldn't be outright added, since the word might not be OOV for the LM (but now almost
+    // certainly is)
+    int targetWord = config.mark_oovs
+        ? Vocabulary.id(Vocabulary.word(sourceWord) + "_OOV")
+        : sourceWord;
+
+    int nonterminalId = Vocabulary.id("[X]");
+    // NOTE(review): the -1 on the target side presumably refers to the first source nonterminal;
+    // confirm against Rule.
+    Rule oovRule = new Rule(nonterminalId, new int[] { nonterminalId, sourceWord },
+        new int[] { -1, targetWord }, "", 1, null);
+    addRule(oovRule);
+    oovRule.estimateRuleCost(featureFunctions);
+  }
+
+  /**
+   * @return the trie root of the backend grammar
+   */
+  @Override
+  public Trie getTrieRoot() {
+    return backend.getTrieRoot();
+  }
+
+  /**
+   * Sorts the backend grammar using the given models.
+   *
+   * @param models the feature functions passed through to the backend
+   */
+  @Override
+  public void sortGrammar(List<FeatureFunction> models) {
+    backend.sortGrammar(models);
+  }
+
+  /**
+   * @return whether the backend grammar reports itself as sorted
+   */
+  @Override
+  public boolean isSorted() {
+    return backend.isSorted();
+  }
+
+  /**
+   * This should never be called. It unconditionally returns {@code true} and ignores its
+   * arguments.
+   */
+  @Override
+  public boolean hasRuleForSpan(int startIndex, int endIndex, int pathLength) {
+    return true;
+  }
+
+  /**
+   * @return the number of rules reported by the backend grammar
+   */
+  @Override
+  public int getNumRules() {
+    return backend.getNumRules();
+  }
+
+  /**
+   * Delegates manual rule construction to the backend grammar, passing all arguments through
+   * unchanged.
+   */
+  @Override
+  public Rule constructManualRule(int lhs, int[] sourceWords, int[] targetWords, float[] scores,
+      int arity) {
+    return backend.constructManualRule(lhs, sourceWords, targetWords, scores, arity);
+  }
+
+  /**
+   * Asks the backend grammar to write itself to the given file.
+   *
+   * @param file the output path passed through to the backend
+   */
+  @Override
+  public void writeGrammarOnDisk(String file) {
+    backend.writeGrammarOnDisk(file);
+  }
+
+  /**
+   * @return whether the backend grammar reports itself as a regular-expression grammar
+   */
+  @Override
+  public boolean isRegexpGrammar() {
+    return backend.isRegexpGrammar();
+  }
+
+  /**
+   * @return the owner id of the backend grammar
+   */
+  @Override
+  public int getOwner() {
+    return backend.getOwner();
+  }
+
+  /**
+   * @return the number of dense features reported by the backend grammar
+   */
+  @Override
+  public int getNumDenseFeatures() {
+    return backend.getNumDenseFeatures();
+  }
+}