You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/08/30 21:04:59 UTC
[14/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/Translation.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/Translation.java
index e88f00a,0000000..5c75188
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/Translation.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/Translation.java
@@@ -1,239 -1,0 +1,238 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder;
+
- import static java.util.Arrays.asList;
+import static org.apache.joshua.decoder.StructuredTranslationFactory.fromViterbiDerivation;
+import static org.apache.joshua.decoder.ff.FeatureMap.hashFeature;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiFeatures;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiString;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiWordAlignments;
+import static org.apache.joshua.util.FormatUtils.removeSentenceMarkers;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.io.StringWriter;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.ff.FeatureVector;
+import org.apache.joshua.decoder.ff.lm.StateMinimizingLanguageModel;
+import org.apache.joshua.decoder.hypergraph.HyperGraph;
+import org.apache.joshua.decoder.hypergraph.KBestExtractor;
+import org.apache.joshua.decoder.io.DeNormalize;
+import org.apache.joshua.decoder.segment_file.Sentence;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This class represents translated input objects (sentences or lattices). It is aware of the source
+ * sentence and id and contains the decoded hypergraph. Translation objects are returned by
+ * DecoderTask instances to the InputHandler, where they are assembled in order for output.
- *
++ *
+ * @author Matt Post post@cs.jhu.edu
+ * @author Felix Hieber fhieber@amazon.com
+ */
+
+public class Translation {
+ private static final Logger LOG = LoggerFactory.getLogger(Translation.class);
+ private final Sentence source;
+
+ /**
+ * This stores the output of the translation so we don't have to hold onto the hypergraph while we
+ * wait for the outputs to be assembled.
+ */
+ private String output = null;
+
+ /**
+ * Stores the list of StructuredTranslations.
+ * If joshuaConfig.topN == 0, will only contain the Viterbi translation.
+ * Else it will use KBestExtractor to populate this list.
+ */
+ private List<StructuredTranslation> structuredTranslations = null;
-
- public Translation(Sentence source, HyperGraph hypergraph,
++
++ /**
++ * Extracts the output for one decoded input. With structured output enabled, the Viterbi
++ * derivation (topN == 0) or the k-best list is stored as StructuredTranslations; otherwise
++ * the output is rendered as text, either by substituting into outputFormat (topN == 0) or
++ * by the KBestExtractor. Finally, KenLM state for this sentence is released.
++ *
++ * @param source the input sentence
++ * @param hypergraph the decoded hypergraph; a null value is handled (as a failed translation)
++ *     only in text-output mode
++ * @param featureFunctions the feature functions used for decoding
++ * @param joshuaConfiguration the decoder configuration (output mode, topN, output format)
++ */
++ public Translation(Sentence source, HyperGraph hypergraph,
+ List<FeatureFunction> featureFunctions, JoshuaConfiguration joshuaConfiguration) {
+ this.source = source;
-
++
+ /**
+ * Structured output from Joshua provides a way to programmatically access translation results
+ * from downstream applications, instead of writing results as strings to an output buffer.
+ */
+ if (joshuaConfiguration.use_structured_output) {
-
++
+ if (joshuaConfiguration.topN == 0) {
+ /*
+ * Obtain Viterbi StructuredTranslation
+ */
+ StructuredTranslation translation = fromViterbiDerivation(source, hypergraph, featureFunctions);
+ this.output = translation.getTranslationString();
+ structuredTranslations = Collections.singletonList(translation);
-
++
+ } else {
+ /*
+ * Get K-Best list of StructuredTranslations
+ */
+ final KBestExtractor kBestExtractor = new KBestExtractor(source, featureFunctions, Decoder.weights, false, joshuaConfiguration);
+ structuredTranslations = kBestExtractor.KbestExtractOnHG(hypergraph, joshuaConfiguration.topN);
+ if (structuredTranslations.isEmpty()) {
+ structuredTranslations = Collections
+ .singletonList(StructuredTranslationFactory.fromEmptyOutput(source));
+ this.output = "";
+ } else {
+ this.output = structuredTranslations.get(0).getTranslationString();
+ }
+ // TODO: We omit the BLEU rescoring for now since it is not clear whether it works at all and what the desired output is below.
+ }
+
+ } else {
+
+ StringWriter sw = new StringWriter();
+ BufferedWriter out = new BufferedWriter(sw);
+
+ try {
-
++
+ if (hypergraph != null) {
-
++
+ long startTime = System.currentTimeMillis();
+
+ if (joshuaConfiguration.topN == 0) {
+
+ /* construct Viterbi output */
+ final String best = getViterbiString(hypergraph);
+
+ LOG.info("Translation {}: {} {}", source.id(), hypergraph.goalNode.getScore(), best);
+
+ /*
+ * Setting topN to 0 turns off k-best extraction, in which case we need to parse through
+ * the output-string, with the understanding that we can only substitute variables for the
+ * output string, sentence number, and model score.
+ */
+ String translation = joshuaConfiguration.outputFormat
+ .replace("%s", removeSentenceMarkers(best))
+ .replace("%S", DeNormalize.processSingleLine(best))
+ .replace("%c", String.format("%.3f", hypergraph.goalNode.getScore()))
+ .replace("%i", String.format("%d", source.id()));
+
+ if (joshuaConfiguration.outputFormat.contains("%a")) {
+ translation = translation.replace("%a", getViterbiWordAlignments(hypergraph));
+ }
+
+ if (joshuaConfiguration.outputFormat.contains("%f")) {
+ final FeatureVector features = getViterbiFeatures(hypergraph, featureFunctions, source);
+ translation = translation.replace("%f", features.textFormat());
+ }
+
+ out.write(translation);
+ out.newLine();
+
+ } else {
+
+ final KBestExtractor kBestExtractor = new KBestExtractor(
+ source, featureFunctions, Decoder.weights, false, joshuaConfiguration);
+ kBestExtractor.lazyKBestExtractOnHG(hypergraph, joshuaConfiguration.topN, out);
+
+ if (joshuaConfiguration.rescoreForest) {
+ final int bleuFeatureHash = hashFeature("BLEU");
+ Decoder.weights.add(bleuFeatureHash, joshuaConfiguration.rescoreForestWeight);
+ kBestExtractor.lazyKBestExtractOnHG(hypergraph, joshuaConfiguration.topN, out);
+
+ Decoder.weights.add(bleuFeatureHash, -joshuaConfiguration.rescoreForestWeight);
+ kBestExtractor.lazyKBestExtractOnHG(hypergraph, joshuaConfiguration.topN, out);
+ }
+ }
+
- float seconds = (float) (System.currentTimeMillis() - startTime) / 1000.0f;
++ float seconds = (System.currentTimeMillis() - startTime) / 1000.0f;
+ LOG.info("Input {}: {}-best extraction took {} seconds", id(),
+ joshuaConfiguration.topN, seconds);
+
+ } else {
-
++
+ // Failed translations and blank lines get empty formatted outputs
+ out.write(getFailedTranslationOutput(source, joshuaConfiguration));
+ out.newLine();
-
++
+ }
+
+ out.flush();
-
++
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ this.output = sw.toString();
+
+ }
-
++
+ // remove state from StateMinimizingLanguageModel instances in features.
+ destroyKenLMStates(featureFunctions);
+
+ }
+
++ /** @return the input sentence this translation was produced for */
+ public Sentence getSourceSentence() {
+ return this.source;
+ }
+
++ /** @return the id of the source sentence */
+ public int id() {
+ return source.id();
+ }
+
++ /**
++ * Returns the rendered output. In text mode this is everything written to the output buffer
++ * (newline-terminated); in structured mode it is the top translation string, or "" if the
++ * k-best list was empty.
++ */
+ @Override
+ public String toString() {
+ return output;
+ }
-
++
++ /**
++ * Builds the output line for an input that has no hypergraph (failed translation or blank
++ * input) by filling the known placeholders of outputFormat with empty/default values.
++ * The %s placeholder (the raw source text) is substituted last so that source text which
++ * itself contains placeholder-like sequences such as "%c" or "%i" is copied through verbatim.
++ */
+ private String getFailedTranslationOutput(final Sentence source, final JoshuaConfiguration joshuaConfiguration) {
+ return joshuaConfiguration.outputFormat
- .replace("%s", source.source())
+ .replace("%e", "")
+ .replace("%S", "")
+ .replace("%t", "()")
+ .replace("%i", Integer.toString(source.id()))
+ .replace("%f", "")
- .replace("%c", "0.000");
++ .replace("%c", "0.000")
++ .replace("%a", "")
++ .replace("%s", source.source());
+ }
-
++
+ /**
+ * Returns the StructuredTranslations
- * if JoshuaConfiguration.use_structured_output == True.
- * @throws RuntimeException if JoshuaConfiguration.use_structured_output == False.
++ * if JoshuaConfiguration.use_structured_output is true.
++ * @throws RuntimeException if JoshuaConfiguration.use_structured_output is false.
+ * @return List of StructuredTranslations.
+ */
+ public List<StructuredTranslation> getStructuredTranslations() {
+ if (structuredTranslations == null) {
+ throw new RuntimeException(
- "No StructuredTranslation objects created. You should set JoshuaConfigration.use_structured_output = true");
++ "No StructuredTranslation objects created. You should set JoshuaConfiguration.use_structured_output = true");
+ }
+ return structuredTranslations;
+ }
-
++
+ /**
+ * KenLM hack. If using KenLMFF, we need to tell KenLM to delete the pool used to create chart
+ * objects for this sentence.
+ */
+ private void destroyKenLMStates(final List<FeatureFunction> featureFunctions) {
+ for (FeatureFunction feature : featureFunctions) {
+ if (feature instanceof StateMinimizingLanguageModel) {
+ ((StateMinimizingLanguageModel) feature).destroyPool(getSourceSentence().id());
++ // Only the first StateMinimizingLanguageModel is asked to destroy the pool.
++ // NOTE(review): assumes the per-sentence pool is shared across LM instances -- confirm.
+ break;
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
index 8b5c81a,0000000..0e5139a
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
@@@ -1,474 -1,0 +1,438 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.chart_parser;
+
+import java.util.ArrayList;
- import java.util.HashMap;
+import java.util.List;
- import java.util.Map;
+
- import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.tm.Grammar;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.RuleCollection;
+import org.apache.joshua.decoder.ff.tm.Trie;
+import org.apache.joshua.decoder.segment_file.Token;
+import org.apache.joshua.lattice.Arc;
+import org.apache.joshua.lattice.Lattice;
+import org.apache.joshua.lattice.Node;
+import org.apache.joshua.util.ChartSpan;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The DotChart handles Earley-style implicit binarization of translation rules.
- *
++ *
+ * The {@link DotNode} object represents the (possibly partial) application of a synchronous rule.
+ * The implicit binarization is maintained with a pointer to the {@link Trie} node in the grammar,
+ * for easy retrieval of the next symbol to be matched. At every span (i,j) of the input sentence,
+ * every incomplete DotNode is examined to see whether it (a) needs a terminal and matches against
+ * the final terminal of the span or (b) needs a nonterminal and matches against a completed
+ * nonterminal in the main chart at some split point (k,j).
- *
++ *
+ * Once a rule is completed, it is entered into the {@link DotChart}. {@link DotCell} objects are
+ * used to group completed DotNodes over a span.
- *
++ *
+ * There is a separate DotChart for every grammar.
- *
++ *
+ * @author Zhifei Li, <zh...@gmail.com>
+ * @author Matt Post <po...@cs.jhu.edu>
+ * @author Kristy Hollingshead Seitz
+ */
+class DotChart {
+
+ // ===============================================================
+ // Static fields
+ // ===============================================================
+
+ private static final Logger LOG = LoggerFactory.getLogger(DotChart.class);
+
+
+ // ===============================================================
- // Package-protected instance fields
++ // Chart of dot cells (private; read through getDotCell)
+ // ===============================================================
+ /**
+ * Two-dimensional chart of cells. Some cells might be null. This could definitely be represented
- * more efficiently, since only the upper half of this triangle is every used.
++ * more efficiently, since only the upper half of this triangle is ever used.
+ */
+ private final ChartSpan<DotCell> dotcells;
+
++ /**
++ * Returns the dot cell over span (i, j), or null if no dot item has been added to that span.
++ */
+ public DotCell getDotCell(int i, int j) {
+ return dotcells.get(i, j);
+ }
+
+ // ===============================================================
+ // Private instance fields (maybe could be protected instead)
+ // ===============================================================
+
+ /**
+ * CKY+ style parse chart in which completed span entries are stored.
+ */
+ private final Chart dotChart;
+
+ /**
+ * Translation grammar which contains the translation rules.
+ */
+ private final Grammar pGrammar;
+
+ /* Length of input sentence. */
+ private final int sentLen;
+
+ /* Represents the input sentence being translated. */
+ private final Lattice<Token> input;
+
+ // ===============================================================
+ // Constructors
+ // ===============================================================
+
+ // TODO: Maybe this should be a non-static inner class of Chart. That would give us implicit
+ // access to all the arguments of this constructor. Though we would need to take an argument, i,
+ // to know which Chart.this.grammars[i] to use.
+
+ /**
+ * Constructs a new dot chart from a specified input lattice, a translation grammar, and a parse
+ * chart.
- *
++ *
+ * @param input A lattice which represents an input sentence.
+ * @param grammar A translation grammar.
+ * @param chart A CKY+ style chart in which completed span entries are stored.
+ */
+ public DotChart(Lattice<Token> input, Grammar grammar, Chart chart) {
+
+ this.dotChart = chart;
+ this.pGrammar = grammar;
+ this.input = input;
+ this.sentLen = input.size();
+ this.dotcells = new ChartSpan<>(sentLen, null);
+
+ seed();
+ }
+
+ /**
- * Add initial dot items: dot-items pointer to the root of the grammar trie.
++ * Adds the initial dot items: for every position j at which the grammar has rules for span
++ * (j, j), a dot item pointing to the root of the grammar trie is added to cell (j, j).
+ */
+ void seed() {
+ for (int j = 0; j <= sentLen - 1; j++) {
+ if (pGrammar.hasRuleForSpan(j, j, input.distance(j, j))) {
+ if (null == pGrammar.getTrieRoot()) {
+ throw new RuntimeException("trie root is null");
+ }
+ addDotItem(pGrammar.getTrieRoot(), j, j, null, null, new SourcePath());
+ }
+ }
+ }
+
+ /**
+ * This function computes all possible expansions of all rules over the provided span (i,j). By
+ * expansions, we mean the moving of the dot forward (from left to right) over a nonterminal or
+ * terminal symbol on the rule's source side.
- *
++ *
+ * There are two kinds of expansions:
- *
++ *
+ * <ol>
+ * <li>Expansion over a nonterminal symbol. For this kind of expansion, a rule has a dot
+ * immediately prior to a source-side nonterminal. The main Chart is consulted to see whether
+ * there exists a completed nonterminal with the same label. If so, the dot is advanced.
- *
++ *
+ * Discovering nonterminal expansions is a matter of enumerating all split points k such that i <
+ * k and k < j. The nonterminal symbol must exist in the main Chart over (k,j).
- *
++ *
- * <li>Expansion over a terminal symbol. In this case, expansion is a simple matter of determing
++ * <li>Expansion over a terminal symbol. In this case, expansion is a simple matter of determining
+ * whether the input symbol at position j (the end of the span) matches the next symbol in the
+ * rule. This is equivalent to choosing a split point k = j - 1 and looking for terminal symbols
+ * over (k,j). Note that phrases in the input rule are handled one-by-one as we consider longer
+ * spans.
+ * </ol>
+ */
+ void expandDotCell(int i, int j) {
+ if (LOG.isDebugEnabled())
+ LOG.debug("Expanding dot cell ({}, {})", i, j);
+
+ /*
+ * (1) If the dot is just to the left of a non-terminal variable, we look for theorems or axioms
+ * in the Chart that may apply and extend the dot position. We look for existing axioms over all
+ * spans (k,j), i < k < j.
+ */
+ for (int k = i + 1; k < j; k++) {
+ extendDotItemsWithProvedItems(i, k, j, false);
+ }
+
+ /*
- * (2) If the the dot-item is looking for a source-side terminal symbol, we simply match against
++ * (2) If the dot-item is looking for a source-side terminal symbol, we simply match against
+ * the input sentence and advance the dot.
+ */
+ Node<Token> node = input.getNode(j - 1);
+ for (Arc<Token> arc : node.getOutgoingArcs()) {
+
+ int last_word = arc.getLabel().getWord();
+ int arc_len = arc.getHead().getNumber() - arc.getTail().getNumber();
+
+ // int last_word=foreign_sent[j-1]; // input.getNode(j-1).getNumber(); //
+
+ if (null != dotcells.get(i, j - 1)) {
+ // dotitem in dot_bins[i][k]: looking for an item in the right to the dot
+
+
+ for (DotNode dotNode : dotcells.get(i, j - 1).getDotNodes()) {
+
+ // String arcWord = Vocabulary.word(last_word);
+ // Assert.assertFalse(arcWord.endsWith("]"));
+ // Assert.assertFalse(arcWord.startsWith("["));
+ // logger.info("DotChart.expandDotCell: " + arcWord);
+
+ // List<Trie> child_tnodes = ruleMatcher.produceMatchingChildTNodesTerminalevel(dotNode,
+ // last_word);
+
- List<Trie> child_tnodes = null;
-
+ Trie child_node = dotNode.trieNode.match(last_word);
+ if (null != child_node) {
+ addDotItem(child_node, i, j - 1 + arc_len, dotNode.antSuperNodes, null,
+ dotNode.srcPath.extend(arc));
+ }
+ }
+ }
+ }
+ }
+
+ /**
- * note: (i,j) is a non-terminal, this cannot be a cn-side terminal, which have been handled in
- * case2 of dotchart.expand_cell add dotitems that start with the complete super-items in
- * cell(i,j)
++ * Adds dot items that start with the completed super-items in chart cell (i,j), i.e. it begins
++ * rules whose first source symbol is a nonterminal covering (i,j). Terminals are not handled
++ * here; they are matched in case (2) of expandDotCell. Unary rules are skipped.
+ */
+ void startDotItems(int i, int j) {
+ extendDotItemsWithProvedItems(i, i, j, true);
+ }
+
+ // ===============================================================
+ // Private methods
+ // ===============================================================
+
+ /**
+ * Attempt to combine an item in the dot chart with an item in the main chart to create a new item
+ * in the dot chart. The DotChart item is a {@link DotNode} begun at position i with the dot
+ * currently at position k, that is, a partially-applied rule.
- *
++ *
+ * In other words, this method looks for (proved) theorems or axioms in the completed chart that
+ * may apply and extend the dot position.
- *
++ *
+ * @param i Start index of a dot chart item
+ * @param k End index of a dot chart item; start index of a completed chart item
+ * @param j End index of a completed chart item
+ * @param skipUnary if true, don't extend unary rules
+ */
+ private void extendDotItemsWithProvedItems(int i, int k, int j, boolean skipUnary) {
+ if (this.dotcells.get(i, k) == null || this.dotChart.getCell(k, j) == null) {
+ return;
+ }
+
+ // complete super-items (items over the same span with different LHSs)
+ List<SuperNode> superNodes = new ArrayList<>(this.dotChart.getCell(k, j).getSortedSuperItems().values());
+
+ /* For every partially complete item over (i,k) */
+ for (DotNode dotNode : dotcells.get(i, k).dotNodes) {
+ /* For every completed nonterminal in the main chart */
+ for (SuperNode superNode : superNodes) {
+
+ // String arcWord = Vocabulary.word(superNode.lhs);
+ // logger.info("DotChart.extendDotItemsWithProvedItems: " + arcWord);
+ // Assert.assertTrue(arcWord.endsWith("]"));
+ // Assert.assertTrue(arcWord.startsWith("["));
+
+ /*
+ * Regular Expression matching allows for a regular-expression style rules in the grammar,
+ * which allows for a very primitive treatment of morphology. This is an advanced,
+ * undocumented feature that introduces a complexity, in that the next "word" in the grammar
+ * rule might match more than one outgoing arc in the grammar trie.
+ */
+ Trie child_node = dotNode.getTrieNode().match(superNode.lhs);
+ if (child_node != null) {
+ if ((!skipUnary) || (child_node.hasExtensions())) {
+ addDotItem(child_node, i, j, dotNode.getAntSuperNodes(), superNode, dotNode
+ .getSourcePath().extendNonTerminal());
+ }
+ }
+ }
+ }
+ }
+
- /*
- * We introduced the ability to have regular expressions in rules for matching against terminals.
- * For example, you could have the rule
- *
- * <pre> [X] ||| l?s herman?s ||| siblings </pre>
- *
- * When this is enabled for a grammar, we need to test against *all* (positive) outgoing arcs of
- * the grammar trie node to see if any of them match, and then return the whole set. This is quite
- * expensive, which is why you should only enable regular expressions for small grammars.
- */
-
- private ArrayList<Trie> matchAll(DotNode dotNode, int wordID) {
- ArrayList<Trie> trieList = new ArrayList<>();
- HashMap<Integer, ? extends Trie> childrenTbl = dotNode.trieNode.getChildren();
-
- if (childrenTbl != null && wordID >= 0) {
- // get all the extensions, map to string, check for *, build regexp
- for (Map.Entry<Integer, ? extends Trie> entry : childrenTbl.entrySet()) {
- Integer arcID = entry.getKey();
- if (arcID == wordID) {
- trieList.add(entry.getValue());
- } else {
- String arcWord = Vocabulary.word(arcID);
- if (Vocabulary.word(wordID).matches(arcWord)) {
- trieList.add(entry.getValue());
- }
- }
- }
- }
- return trieList;
- }
-
-
+ /**
+ * Creates a {@link DotNode} and adds it into the {@link DotChart} at the correct place. These
- * are (possibly incomplete) rule applications.
- *
++ * are (possibly incomplete) rule applications.
++ *
+ * @param tnode the trie node pointing to the location ("dot") in the grammar trie
- * @param i
- * @param j
++ * @param i start index of the span of the new dot item
++ * @param j end index of the span of the new dot item
+ * @param antSuperNodesIn the supernodes representing the rule's tail nodes
+ * @param curSuperNode the lefthand side of the rule being created
+ * @param srcPath the path taken through the input lattice
+ */
+ private void addDotItem(Trie tnode, int i, int j, ArrayList<SuperNode> antSuperNodesIn,
+ SuperNode curSuperNode, SourcePath srcPath) {
+ ArrayList<SuperNode> antSuperNodes = new ArrayList<>();
+ if (antSuperNodesIn != null) {
+ antSuperNodes.addAll(antSuperNodesIn);
+ }
+ if (curSuperNode != null) {
+ antSuperNodes.add(curSuperNode);
+ }
+
+ DotNode item = new DotNode(i, j, tnode, antSuperNodes, srcPath);
++ // Cells are created lazily, on the first dot item added over a span.
+ if (dotcells.get(i, j) == null) {
+ dotcells.set(i, j, new DotCell());
+ }
+ dotcells.get(i, j).addDotNode(item);
+ dotChart.nDotitemAdded++;
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Add a dotitem in cell ({}, {}), n_dotitem={}, {}", i, j,
+ dotChart.nDotitemAdded, srcPath);
+
+ RuleCollection rules = tnode.getRuleCollection();
+ if (rules != null) {
+ for (Rule r : rules.getRules()) {
+ // System.out.println("rule: "+r.toString());
+ LOG.debug("{}", r);
+ }
+ }
+ }
+ }
+
+ // ===============================================================
+ // Package-protected classes
+ // ===============================================================
+
+ /**
+ * A DotCell groups together DotNodes that have been applied over a particular span. A DotNode, in
+ * turn, is a partially-applied grammar rule, represented as a pointer into the grammar trie
+ * structure.
+ */
+ static class DotCell {
+
- // Package-protected fields
++ // The dot nodes added to this cell, in insertion order
+ private final List<DotNode> dotNodes = new ArrayList<>();
+
+ public List<DotNode> getDotNodes() {
+ return dotNodes;
+ }
+
+ private void addDotNode(DotNode dt) {
+ /*
+ * if(l_dot_items==null) l_dot_items= new ArrayList<DotItem>();
+ */
+ dotNodes.add(dt);
+ }
+ }
+
+ /**
+ * A DotNode represents the partial application of a rule rooted to a particular span (i,j). It
+ * maintains a pointer to the trie node in the grammar for efficient mapping.
+ */
+ static class DotNode {
+
+ private final int i;
+ private final int j;
+ private Trie trieNode = null;
-
++
+ /* A list of grounded (over a span) nonterminals that have been crossed in traversing the rule */
+ private ArrayList<SuperNode> antSuperNodes = null;
-
++
+ /* The source lattice cost of applying the rule */
+ private final SourcePath srcPath;
+
+ @Override
+ public String toString() {
+ int size = 0;
+ if (trieNode != null && trieNode.getRuleCollection() != null)
+ size = trieNode.getRuleCollection().getRules().size();
+ return String.format("DOTNODE i=%d j=%d #rules=%d #tails=%d", i, j, size, antSuperNodes.size());
+ }
-
++
+ /**
+ * Initialize a dot node with the span, grammar trie node, list of supernode tail pointers, and
+ * the lattice sourcepath.
- *
++ *
- * @param i
- * @param j
- * @param trieNode
- * @param antSuperNodes
- * @param srcPath
++ * @param i start index of the span
++ * @param j end index of the span
++ * @param trieNode the grammar trie node at the current dot position
++ * @param antSuperNodes the nonterminals crossed so far (toString assumes this is non-null)
++ * @param srcPath the path taken through the input lattice
+ */
+ public DotNode(int i, int j, Trie trieNode, ArrayList<SuperNode> antSuperNodes, SourcePath srcPath) {
+ this.i = i;
+ this.j = j;
+ this.trieNode = trieNode;
+ this.antSuperNodes = antSuperNodes;
+ this.srcPath = srcPath;
+ }
+
++ @Override
+ public boolean equals(Object obj) {
+ if (obj == null)
+ return false;
+ if (!this.getClass().equals(obj.getClass()))
+ return false;
+ DotNode state = (DotNode) obj;
+
+ /*
- * Technically, we should be comparing the span inforamtion as well, but that would require us
- * to store it, increasing memory requirements, and we should be able to guarantee that we
- * won't be comparing DotNodes across spans.
++ * Technically, we should be comparing the span information (i, j) as well. It is deliberately
++ * left out: DotNodes are grouped by span in DotCells, so we should be able to guarantee that
++ * we won't be comparing DotNodes across spans. Only the trie node is compared, by identity.
+ */
+ // if (this.i != state.i || this.j != state.j)
+ // return false;
+
+ return this.trieNode == state.trieNode;
+
+ }
+
+ /**
+ * Technically the hash should include the span (i,j), but since DotNodes are grouped by span,
- * this isn't necessary, and we gain something by not having to store the span.
++ * this isn't necessary. Consistent with equals(), only the trie node is hashed.
+ */
++ @Override
+ public int hashCode() {
+ return this.trieNode.hashCode();
+ }
+
+ // convenience function
+ public boolean hasRules() {
+ return getTrieNode().getRuleCollection() != null && getTrieNode().getRuleCollection().getRules().size() != 0;
+ }
-
++
+ public RuleCollection getRuleCollection() {
+ return getTrieNode().getRuleCollection();
+ }
+
+ public Trie getTrieNode() {
+ return trieNode;
+ }
+
+ public SourcePath getSourcePath() {
+ return srcPath;
+ }
+
+ public ArrayList<SuperNode> getAntSuperNodes() {
+ return antSuperNodes;
+ }
+
+ public int begin() {
+ return i;
+ }
-
++
+ public int end() {
+ return j;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
index 9338b0d,0000000..d9b894c
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
@@@ -1,215 -1,0 +1,214 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff;
+
+import static org.apache.joshua.decoder.ff.FeatureMap.hashFeature;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.chart_parser.SourcePath;
+import org.apache.joshua.decoder.ff.state_maintenance.DPState;
+import org.apache.joshua.decoder.ff.state_maintenance.NgramDPState;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.segment_file.Sentence;
+import org.apache.joshua.util.FormatUtils;
+import org.apache.joshua.util.io.LineReader;
+
- /***
- * The RuleBigram feature is an indicator feature that counts target word bigrams that are created when
++/**
++ * The TargetBigram feature is an indicator feature that counts target word bigrams that are created when
+ * a rule is applied. It accepts three parameters:
+ *
+ * -vocab /path/to/vocab
+ *
+ * The path to a vocabulary, where each line is of the format ID WORD COUNT.
+ *
+ * -threshold N
+ *
- * Mask to UNK all words whose COUNT is less than N.
++ * Mask to UNK all words whose COUNT is less than N (only used together with -vocab).
+ *
+ * -top-n N
+ *
- * Only use the top N words.
++ * Only use the words on the first N lines of the vocabulary file (default 1000000).
+ */
+
+public class TargetBigram extends StatefulFF {
+
++ // Words kept as-is; all other words are mapped to "UNK". Null means no masking is done.
+ private HashSet<String> vocab = null;
+ private int maxTerms = 1000000;
+ private int threshold = 0;
+
+ public TargetBigram(FeatureVector weights, String[] args, JoshuaConfiguration config) {
+ super(weights, "TargetBigram", args, config);
+
+ if (parsedArgs.containsKey("threshold"))
+ threshold = Integer.parseInt(parsedArgs.get("threshold"));
+
+ if (parsedArgs.containsKey("top-n"))
+ maxTerms = Integer.parseInt(parsedArgs.get("top-n"));
+
+ if (parsedArgs.containsKey("vocab")) {
+ loadVocab(parsedArgs.get("vocab"));
+ }
+ }
+
+ /**
+ * Load vocabulary items passing the 'threshold' and 'top-n' filters.
+ *
- * @param filename
++ * Blank lines are skipped; a non-blank line with fewer than three fields (ID WORD COUNT) is
++ * an error. The sentence markers {@code <s>} and {@code </s>} are always included.
++ *
++ * @param filename path of the vocabulary file
++ * @throws RuntimeException if the file cannot be read or contains a malformed line
+ */
+ private void loadVocab(String filename) {
+ this.vocab = new HashSet<>();
+ this.vocab.add("<s>");
+ this.vocab.add("</s>");
- try {
- LineReader lineReader = new LineReader(filename);
++ try (LineReader lineReader = new LineReader(filename)) {
+ for (String line: lineReader) {
+ if (lineReader.lineno() > maxTerms)
+ break;
+
- String[] tokens = line.split("\\s+");
++ // Trim first so that leading whitespace does not produce an empty first field.
++ String trimmed = line.trim();
++ if (trimmed.isEmpty())
++ continue;
++
++ String[] tokens = trimmed.split("\\s+");
++ if (tokens.length < 3) {
++ throw new RuntimeException(String.format(
++ "* FATAL: malformed line %d in TargetBigram vocabulary '%s': '%s'",
++ lineReader.lineno(), filename, line));
++ }
+ String word = tokens[1];
+ int count = Integer.parseInt(tokens[2]);
+
+ if (count >= threshold)
+ vocab.add(word);
+ }
+
+ } catch (IOException e) {
+ throw new RuntimeException(String.format(
+ "* FATAL: couldn't load TargetBigram vocabulary '%s'", filename), e);
+ }
+ }
+
++ /**
++ * Fires an indicator feature named {@code <name>_<w1>_<w2>} for every target-side bigram formed
++ * when the rule is applied, using the boundary words stored in the tail nodes' states for
++ * nonterminals. Words outside the vocabulary (when one is loaded) are mapped to "UNK".
++ *
++ * @return a state holding the leftmost and rightmost target word ids of the new item
++ *     (-1 if the rule produced no target words)
++ */
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int spanStart, int spanEnd,
+ SourcePath sourcePath, Sentence sentence, Accumulator acc) {
+
+ int[] enWords = rule.getTarget();
+
+ int left = -1;
+ int right = -1;
+
+ List<String> currentNgram = new LinkedList<>();
+ for (int curID : enWords) {
+ if (FormatUtils.isNonterminal(curID)) {
+ int index = -(curID + 1);
+ NgramDPState state = (NgramDPState) tailNodes.get(index).getDPState(stateIndex);
+ int[] leftContext = state.getLeftLMStateWords();
+ int[] rightContext = state.getRightLMStateWords();
+
+ // Left context.
+ for (int token : leftContext) {
+ currentNgram.add(getWord(token));
+ if (left == -1)
+ left = token;
+ right = token;
+ if (currentNgram.size() == 2) {
+ String ngram = join(currentNgram);
+ acc.add(hashFeature(String.format("%s_%s", name, ngram)), 1);
+ // System.err.println(String.format("ADDING %s_%s", name, ngram));
+ currentNgram.remove(0);
+ }
+ }
+ // Replace right context.
+ int tSize = currentNgram.size();
+ for (int i = 0; i < rightContext.length; i++)
+ currentNgram.set(tSize - rightContext.length + i, getWord(rightContext[i]));
++ // The rightmost word of this nonterminal is the last word of its RIGHT context (not of its
++ // left context), so the state returned below must record it.
++ if (rightContext.length > 0)
++ right = rightContext[rightContext.length - 1];
+
+ } else { // terminal words
+ currentNgram.add(getWord(curID));
+ if (left == -1)
+ left = curID;
+ right = curID;
+ if (currentNgram.size() == 2) {
+ String ngram = join(currentNgram);
+ acc.add(hashFeature(String.format("%s_%s", name, ngram)), 1);
+ // System.err.println(String.format("ADDING %s_%s", name, ngram));
+ currentNgram.remove(0);
+ }
+ }
+ }
+
+ // System.err.println(String.format("RULE %s -> state %s", rule.getRuleString(), state));
+ return new NgramDPState(new int[] { left }, new int[] { right });
+ }
+
+ /**
+ * Returns the word after comparing against the private vocabulary (if set).
+ *
+ * @param curID
+ * @return the word
+ */
+ private String getWord(int curID) {
+ String word = Vocabulary.word(curID);
+
+ if (vocab != null && ! vocab.contains(word)) {
+ return "UNK";
+ }
+
+ return word;
+ }
+
+ /**
+ * We don't compute a future cost.
+ */
+ @Override
+ public float estimateFutureCost(Rule rule, DPState state, Sentence sentence) {
+ return 0.0f;
+ }
+
+ /**
+ * There is nothing to be done here, since <s> and </s> are included in rules that are part
+ * of the grammar. We simply return the DP state of the tail node.
+ */
+ @Override
+ public DPState computeFinal(HGNode tailNode, int i, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
+
+ return tailNode.getDPState(stateIndex);
+ }
+
+ /**
- * TargetBigram features are only computed across hyperedges, so there is nothing to be done here.
++ * TargetBigram features are only computed across hyperedges, so there is nothing to be done here.
+ */
+ @Override
+ public float estimateCost(Rule rule, Sentence sentence) {
+ return 0.0f;
+ }
+
+ /**
- * Join a list with the _ character. I am sure this is in a library somewhere.
++ * Joins a list with the _ character (an empty list yields "").
+ *
+ * @param list a list of strings
+ * @return the joined String
+ */
+ private String join(List<String> list) {
- StringBuilder sb = new StringBuilder();
- for (String item : list) {
- sb.append(item).append("_");
- }
-
- return sb.substring(0, sb.length() - 1);
++ return String.join("_", list);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
index 036c4bc,0000000..f822fe4
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
@@@ -1,777 -1,0 +1,786 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.fragmentlm;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.io.StringReader;
- import java.util.*;
++import java.util.ArrayList;
++import java.util.Collection;
++import java.util.Collections;
++import java.util.HashMap;
++import java.util.HashSet;
++import java.util.Iterator;
++import java.util.List;
++import java.util.Set;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.fragmentlm.Trees.PennTreeReader;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.hypergraph.HyperEdge;
+import org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationState;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Represent phrase-structure trees, with each node consisting of a label and a list of children.
+ * Borrowed from the Berkeley Parser, and extended to allow the representation of tree fragments in
+ * addition to complete trees (the BP requires terminals to be immediately governed by a
+ * preterminal). To distinguish terminals from nonterminals in fragments, the former must be
+ * enclosed in double-quotes when read in.
- *
++ *
+ * @author Dan Klein
+ * @author Matt Post post@cs.jhu.edu
+ */
+public class Tree implements Serializable {
+
+ private static final Logger LOG = LoggerFactory.getLogger(Tree.class);
+ private static final long serialVersionUID = 1L;
+
+ /* Vocabulary id of this node's label (assigned via Vocabulary.id in setLabel). */
+ protected int label;
+
+ /* Marks a frontier node as a terminal (as opposed to a nonterminal). Set by setLabel for
+ * double-quoted labels. */
+ boolean isTerminal = false;
+
+ /*
+ * Marks the root and frontier nodes of a fragment. Useful for denoting fragment derivations in
+ * larger trees.
+ */
+ boolean isBoundary = false;
+
+ /* A list of the node's children. Leaves created by Tree(String) share Collections.emptyList(). */
+ List<Tree> children;
+
+ /* The maximum distance from the root to any of the frontier nodes; -1 until getDepth() caches it. */
+ int depth = -1;
+
+ /* Cached by isLexicalized(); -1 means "not yet computed". Despite the name, isLexicalized()
+ * stores 1 for a terminal and otherwise the number of lexicalized children. */
+ private int numLexicalItems = -1;
+
+ /*
+ * This maps the flat right-hand sides of Joshua rules to the tree fragments they were derived
+ * from. It is used to look up the fragment that language model fragments should be matched against.
+ * For example, if the target (English) side of your rule is
- *
++ *
+ * [NP,1] said [SBAR,2]
- *
++ *
+ * we will retrieve the unflattened fragment
- *
++ *
+ * (S NP (VP (VBD said) SBAR))
- *
++ *
+ * which presumably was the frontier fragment used to derive the translation rule. With this in
+ * hand, we can iterate through our store of language model fragments to match them against this,
+ * following tail nodes if necessary.
+ *
+ * Values are fragment strings, parsed on demand by getFragmentFromYield; entries are added by
+ * readMapping. NOTE(review): this is a shared, unsynchronized HashMap, so concurrent calls to
+ * readMapping would need external synchronization.
+ */
+ public static final HashMap<String, String> rulesToFragmentStrings = new HashMap<>();
+
+  /**
+   * Creates a node whose label is interpreted by {@link #setLabel(String)}: a double-quoted
+   * label marks the node as a terminal.
+   *
+   * @param label the node label
+   * @param children the child list, stored by reference (not copied)
+   */
+  public Tree(String label, List<Tree> children) {
+    setLabel(label);
+    this.children = children;
+  }
+
+  /**
+   * Creates a leaf node whose child list is the shared immutable empty list.
+   *
+   * @param label the node label
+   */
+  public Tree(String label) {
+    this(label, Collections.<Tree>emptyList());
+  }
+
+  /**
+   * Creates a node from an existing Vocabulary id. The terminal flag is left unset.
+   *
+   * @param label2 the Vocabulary id of the label
+   * @param newChildren the child list, stored by reference
+   */
+  public Tree(int label2, ArrayList<Tree> newChildren) {
+    this.children = newChildren;
+    this.label = label2;
+  }
+
+  /** Replaces the child list (by reference). */
+  public void setChildren(List<Tree> c) {
+    this.children = c;
+  }
+
+  /** @return the live child list (not a copy) */
+  public List<Tree> getChildren() {
+    return this.children;
+  }
+
+  /** @return the Vocabulary id of this node's label */
+  public int getLabel() {
+    return this.label;
+  }
+
+  /**
+   * Computes the depth-one rule rooted at this node: an opening parenthesis, this node's label,
+   * then the labels of its immediate children separated by spaces. Returns null for a leaf.
+   *
+   * @return string representation of the rule, or null if the node has no children
+   */
+  public String getRule() {
+    if (isLeaf())
+      return null;
+
+    StringBuilder rule = new StringBuilder("(");
+    rule.append(Vocabulary.word(getLabel()));
+    for (Tree kid : getChildren())
+      rule.append(' ').append(Vocabulary.word(kid.getLabel()));
+
+    // NOTE(review): no closing ")" is appended, so the result is unbalanced. Callers may depend
+    // on this exact format, so it is left unchanged -- confirm before "fixing".
+    return rule.toString();
+  }
+
+  /*
+   * Boundary nodes are used externally to mark merge points between different fragments. This
+   * flag is independent of the frontier substitution points of a single fragment.
+   */
+  public boolean isBoundary() {
+    return this.isBoundary;
+  }
+
+  public void setBoundary(boolean b) {
+    this.isBoundary = b;
+  }
+
+  /** @return whether this node was read as a (double-quoted) terminal */
+  public boolean isTerminal() {
+    return this.isTerminal;
+  }
+
+  /** @return whether this node has no children */
+  public boolean isLeaf() {
+    return getChildren().isEmpty();
+  }
+
+  /** @return whether this node has exactly one child and that child is a leaf */
+  public boolean isPreTerminal() {
+    List<Tree> kids = getChildren();
+    return kids.size() == 1 && kids.get(0).isLeaf();
+  }
+
+  /** @return the frontier nonterminals (leaves not marked as terminals), left to right */
+  public List<Tree> getNonterminalYield() {
+    final List<Tree> result = new ArrayList<>();
+    appendNonterminalYield(this, result);
+    return result;
+  }
+
+  /** @return every leaf of the tree, left to right */
+  public List<Tree> getYield() {
+    final List<Tree> result = new ArrayList<>();
+    appendYield(this, result);
+    return result;
+  }
+
+  /** @return every leaf of the tree, left to right (see the note on appendTerminals) */
+  public List<Tree> getTerminals() {
+    final List<Tree> result = new ArrayList<>();
+    appendTerminals(this, result);
+    return result;
+  }
+
+  /*
+   * Collects every leaf under tree, left to right.
+   * NOTE(review): this traversal is identical to appendYield. It does not check isTerminal(), so
+   * for fragments the result also contains frontier nonterminals. Confirm this is intended.
+   */
+  private static void appendTerminals(Tree tree, List<Tree> yield) {
+    if (!tree.isLeaf()) {
+      for (Tree child : tree.getChildren())
+        appendTerminals(child, yield);
+    } else {
+      yield.add(tree);
+    }
+  }
+
+  /**
+   * Copies the node structure of the tree. Despite the name, every descendant node is copied
+   * recursively; labels are ints, so they are copied by value. The terminal and boundary flags
+   * are carried over. Cached depth/lexical counts are not copied and are recomputed lazily.
+   *
+   * @return a cloned tree
+   */
+  public Tree shallowClone() {
+    final int n = children.size();
+    ArrayList<Tree> copiedKids = new ArrayList<>(n);
+    for (int k = 0; k < n; k++)
+      copiedKids.add(children.get(k).shallowClone());
+
+    Tree copy = new Tree(label, copiedKids);
+    copy.setIsTerminal(isTerminal());
+    copy.setBoundary(isBoundary());
+    return copy;
+  }
+
+  private void setIsTerminal(boolean terminal) {
+    this.isTerminal = terminal;
+  }
+
+  /* Collects leaves that are not marked as terminals (substitution points), left to right. */
+  private static void appendNonterminalYield(Tree tree, List<Tree> yield) {
+    if (!tree.isLeaf() || tree.isTerminal()) {
+      for (Tree child : tree.getChildren())
+        appendNonterminalYield(child, yield);
+    } else {
+      yield.add(tree);
+    }
+  }
+
+  /* Collects every leaf, left to right. */
+  private static void appendYield(Tree tree, List<Tree> yield) {
+    if (!tree.isLeaf()) {
+      for (Tree child : tree.getChildren())
+        appendYield(child, yield);
+    } else {
+      yield.add(tree);
+    }
+  }
+
+  /** @return every pre-terminal node (one child that is a leaf), left to right */
+  public List<Tree> getPreTerminalYield() {
+    final List<Tree> result = new ArrayList<>();
+    appendPreTerminalYield(this, result);
+    return result;
+  }
+
+  /* Collects pre-terminals without descending below them. */
+  private static void appendPreTerminalYield(Tree tree, List<Tree> yield) {
+    if (!tree.isPreTerminal()) {
+      for (Tree child : tree.getChildren())
+        appendPreTerminalYield(child, yield);
+    } else {
+      yield.add(tree);
+    }
+  }
+
+ /**
+ * A tree is lexicalized if it has terminal nodes among the leaves of its frontier. For normal
+ * trees this is always true since they bottom out in terminals, but for fragments, this may or
+ * may not be true.
- *
++ *
+ * @return true if the tree is lexicalized
+ */
+ public boolean isLexicalized() {
+ if (this.numLexicalItems < 0) {
+ if (isTerminal())
+ this.numLexicalItems = 1;
+ else {
+ this.numLexicalItems = 0;
+ children.stream().filter(child -> child.isLexicalized())
+ .forEach(child -> this.numLexicalItems += 1);
+ }
+ }
+
+ return (this.numLexicalItems > 0);
+ }
+
+  /**
+   * The depth of a tree is the maximum distance from the root to any of the frontier nodes.
+   * Leaves have depth 0. The value is cached in the depth field after the first call.
+   *
+   * @return the tree depth
+   */
+  public int getDepth() {
+    if (this.depth >= 0)
+      return this.depth;
+
+    int computed = 0;
+    if (!isLeaf()) {
+      int deepest = 0;
+      for (Tree child : children)
+        deepest = Math.max(deepest, child.getDepth());
+      computed = deepest + 1;
+    }
+    this.depth = computed;
+    return this.depth;
+  }
+
+  /**
+   * @param depth distance below this node (0 returns this node; negative returns nothing)
+   * @return all nodes exactly that many levels below this node, left to right
+   */
+  public List<Tree> getAtDepth(int depth) {
+    final List<Tree> result = new ArrayList<>();
+    appendAtDepth(depth, this, result);
+    return result;
+  }
+
+  private static void appendAtDepth(int depth, Tree tree, List<Tree> yield) {
+    if (depth == 0) {
+      yield.add(tree);
+    } else if (depth > 0) {
+      for (Tree child : tree.getChildren())
+        appendAtDepth(depth - 1, child, yield);
+    }
+  }
+
+ public void setLabel(String label) {
+ if (label.length() >= 3 && label.startsWith("\"") && label.endsWith("\"")) {
+ this.isTerminal = true;
+ label = label.substring(1, label.length() - 1);
+ }
+
+ this.label = Vocabulary.id(label);
+ }
+
+  @Override
+  public String toString() {
+    final StringBuilder out = new StringBuilder();
+    toStringBuilder(out);
+    return out.toString();
+  }
+
+  /**
+   * Removes the quotes around terminals. Note that the resulting tree could not be read back
+   * in by this class, since unquoted leaves are interpreted as nonterminals.
+   *
+   * @return unquoted string
+   */
+  public String unquotedString() {
+    return toString().replace("\"", "");
+  }
+
+  /** @return the bracketed form with every space replaced by an underscore */
+  public String escapedString() {
+    return toString().replace(" ", "_");
+  }
+
+  /**
+   * Appends the bracketed form of this subtree: interior nodes are "(label child ...)", leaves
+   * are the bare label, and terminals are double-quoted.
+   *
+   * @param sb the builder to append to
+   */
+  public void toStringBuilder(StringBuilder sb) {
+    final boolean interior = !isLeaf();
+    if (interior)
+      sb.append('(');
+
+    sb.append(isTerminal()
+        ? String.format("\"%s\"", Vocabulary.word(getLabel()))
+        : Vocabulary.word(getLabel()));
+
+    if (interior) {
+      for (Tree child : getChildren()) {
+        sb.append(' ');
+        child.toStringBuilder(sb);
+      }
+      sb.append(')');
+    }
+  }
+
+  /**
+   * Get the set of all subtrees inside the tree by returning a tree rooted at each node. These are
+   * <i>not</i> copies, but all share structure. The tree is regarded as a subtree of itself.
+   * Tree does not override equals/hashCode, so membership is by identity.
+   *
+   * @return the <code>Set</code> of all subtrees in the tree.
+   */
+  public Set<Tree> subTrees() {
+    Set<Tree> all = new HashSet<>();
+    subTrees(all);
+    return all;
+  }
+
+  /**
+   * Get the list of all subtrees inside the tree, in preorder, by returning a tree rooted at each
+   * node. These are <i>not</i> copies, but all share structure. The tree is regarded as a subtree
+   * of itself.
+   *
+   * @return the <code>List</code> of all subtrees in the tree.
+   */
+  public List<Tree> subTreeList() {
+    List<Tree> all = new ArrayList<>();
+    subTrees(all);
+    return all;
+  }
+
+  /**
+   * Add the set of all subtrees inside a tree (including the tree itself) to the given
+   * <code>Collection</code>, visiting each node before its children.
+   *
+   * @param n A collection of nodes to which the subtrees will be added
+   * @return The collection parameter with the subtrees added
+   */
+  public Collection<Tree> subTrees(Collection<Tree> n) {
+    n.add(this);
+    for (Tree kid : getChildren())
+      kid.subTrees(n);
+    return n;
+  }
+
+  /**
+   * Returns an iterator over the nodes of the tree in preorder: each node comes before its
+   * children, and children are visited left to right. Tree does not implement {@link Iterable};
+   * this method is provided for direct use.
+   *
+   * @return an iterator over the nodes of the tree
+   */
+  public Iterator<Tree> iterator() {
+    // Declared as Iterator<Tree>. The old return type was the private TreeIterator class,
+    // which code outside Tree cannot access.
+    return new TreeIterator();
+  }
+
+  /** Preorder iterator that keeps the pending subtrees on an explicit stack. */
+  private class TreeIterator implements Iterator<Tree> {
+
+    /* Subtrees still to visit; the top of the stack is the last element. */
+    private final List<Tree> treeStack;
+
+    private TreeIterator() {
+      treeStack = new ArrayList<>();
+      treeStack.add(Tree.this);
+    }
+
+    @Override
+    public boolean hasNext() {
+      return !treeStack.isEmpty();
+    }
+
+    @Override
+    public Tree next() {
+      if (treeStack.isEmpty()) {
+        // The Iterator contract requires NoSuchElementException. ArrayList.remove(-1) threw
+        // IndexOutOfBoundsException here before.
+        throw new java.util.NoSuchElementException();
+      }
+      Tree node = treeStack.remove(treeStack.size() - 1);
+      List<Tree> kids = node.getChildren();
+      // Push children right-to-left so the leftmost child is popped first.
+      for (int i = kids.size() - 1; i >= 0; i--) {
+        treeStack.add(kids.get(i));
+      }
+      return node;
+    }
+
+    /** Removal is not supported. */
+    @Override
+    public void remove() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  /**
+   * Returns whether the tree contains a unary chain. A unary chain is a node with exactly one
+   * child, where that child is not a pre-terminal and itself has exactly one child.
+   *
+   * @return true if such a chain exists anywhere in the tree
+   */
+  public boolean hasUnaryChain() {
+    return hasUnaryChainHelper(this, false);
+  }
+
+  private boolean hasUnaryChainHelper(Tree tree, boolean unaryAbove) {
+    List<Tree> kids = tree.getChildren();
+    if (kids.size() == 1) {
+      if (unaryAbove)
+        return true;
+      Tree only = kids.get(0);
+      return !only.isPreTerminal() && hasUnaryChainHelper(only, true);
+    }
+    for (Tree child : kids) {
+      if (!child.isPreTerminal() && hasUnaryChainHelper(child, false))
+        return true;
+    }
+    return false;
+  }
+
+  /**
+   * Inserts the SOS (and EOS) symbols into a parse tree, attaching them as a left (right) sibling
+   * to the leftmost (rightmost) pre-terminal in the tree. This facilitates using trees as language
+   * models. Leaf and pre-terminal roots are left unchanged.
+   *
+   * @param sos presumably "<s>"
+   * @param eos presumably "</s>"
+   */
+  public void insertSentenceMarkers(String sos, String eos) {
+    insertSentenceMarker(sos, 0);
+    insertSentenceMarker(eos, -1);
+  }
+
+  /** Same as the two-argument overload with {@code <s>} and {@code </s>}. */
+  public void insertSentenceMarkers() {
+    insertSentenceMarker("<s>", 0);
+    insertSentenceMarker("</s>", -1);
+  }
+
+  /**
+   * Descends along the leftmost (pos 0) or rightmost (pos -1) child until that child is a
+   * pre-terminal. The marker is then inserted as its sibling.
+   *
+   * @param symbol the marker to insert
+   * @param pos 0 to insert at the front, -1 to append at the end
+   */
+  private void insertSentenceMarker(String symbol, int pos) {
+    if (isLeaf() || isPreTerminal())
+      return;
+
+    List<Tree> kids = getChildren();
+    boolean atEnd = pos == -1;
+    Tree edge = kids.get(atEnd ? kids.size() - 1 : pos);
+    if (!edge.isPreTerminal()) {
+      edge.insertSentenceMarker(symbol, pos);
+    } else if (atEnd) {
+      kids.add(new Tree(symbol));
+    } else {
+      kids.add(pos, new Tree(symbol));
+    }
+  }
+
+  /**
+   * This is a convenience function for producing a fragment from its string representation.
+   *
+   * @param ptbStr input string from which to produce a fragment
+   * @return the fragment
+   */
+  public static Tree fromString(String ptbStr) {
+    return new PennTreeReader(new StringReader(ptbStr)).next();
+  }
+
+  /**
+   * Looks up the fragment recorded for a rule's flat target side and parses it.
+   *
+   * @param yield the rule's target words
+   * @return a freshly parsed fragment, or null if none is recorded
+   */
+  public static Tree getFragmentFromYield(String yield) {
+    String fragmentString = rulesToFragmentStrings.get(yield);
+    return (fragmentString == null) ? null : fromString(fragmentString);
+  }
+
+  /**
+   * Reads the rule-to-fragment mapping into rulesToFragmentStrings. Each line is split on a
+   * "|||" separator (surrounded by whitespace). The first field must be a fragment starting
+   * with "(" and becomes the value; the second field becomes the key. Malformed lines are
+   * logged and skipped.
+   *
+   * @param fragmentMappingFile path of the mapping file
+   * @throws RuntimeException wrapping the IOException if the file cannot be read
+   */
+  public static void readMapping(String fragmentMappingFile) {
+    try (LineReader reader = new LineReader(fragmentMappingFile)) {
+      for (String line : reader) {
+        String[] fields = line.split("\\s+\\|{3}\\s+");
+        boolean wellFormed = fields.length == 2 && fields[0].startsWith("(");
+        if (!wellFormed) {
+          LOG.warn("malformed line {}: {}", reader.lineno(), line);
+          continue;
+        }
+        String fragment = fields[0].trim();
+        String ruleTarget = fields[1].trim();
+        rulesToFragmentStrings.put(ruleTarget, fragment);
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(String.format("* WARNING: couldn't read fragment mapping file '%s'",
+          fragmentMappingFile), e);
+    }
+    LOG.info("FragmentLMFF: Read {} mappings from '{}'", rulesToFragmentStrings.size(),
+        fragmentMappingFile);
+  }
+
+  /**
+   * Builds a tree from the kth-best derivation state. This is done by initializing the tree with
+   * the internal fragment corresponding to the rule; this will be the top of the tree. We then
+   * recursively visit the derivation state objects, following the route through the hypergraph
+   * defined by them.
+   *
+   * This function is like {@link #buildTree(Rule, List, int)}, but that one simply follows the
+   * best incoming hyperedge for each node.
+   *
+   * @param rule for which corresponding internal fragment can be used to initialize the tree
+   * @param derivationStates child derivation states, ordered by the source-side nonterminals;
+   *        may be null or empty when the rule has no tail nodes
+   * @param maxDepth maximum number of levels of tail nodes to expand; at 0 or below, the rule's
+   *        own fragment is returned without expanding its substitution points
+   * @return the Tree, or null if no fragment is recorded for the rule's target side
+   */
+  public static Tree buildTree(Rule rule, DerivationState[] derivationStates, int maxDepth) {
+    Tree tree = getFragmentFromYield(rule.getTargetWords());
+    if (tree == null) {
+      return null;
+    }
+    tree = tree.shallowClone();
+
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("buildTree({})", tree);
+      if (derivationStates != null) {
+        for (int i = 0; i < derivationStates.length; i++) {
+          LOG.debug(" -> {}: {}", i, derivationStates[i]);
+        }
+      }
+    }
+
+    // The recursive call below passes null when an edge has no tail nodes. Previously this
+    // method dereferenced derivationStates unconditionally and threw a NullPointerException.
+    // maxDepth is now honored, consistent with the other buildTree overloads.
+    if (derivationStates == null || derivationStates.length == 0 || maxDepth <= 0) {
+      return tree;
+    }
+
+    List<Tree> frontier = tree.getNonterminalYield();
+
+    /* The English side of a rule is a sequence of integers. Nonnegative integers are word
+     * indices in the Vocabulary, while negative indices are used for nonterminals. These
+     * negative indices are a *permutation* of the source side nonterminals, which contain the
+     * actual nonterminal Vocabulary indices for the nonterminal names. Here, we convert this
+     * permutation to a nonnegative 0-based permutation and store it in tailIndices. This is
+     * used to index the incoming DerivationState items, which are ordered by the source side.
+     */
+    List<Integer> tailIndices = new ArrayList<>();
+    for (int symbol : rule.getTarget()) {
+      if (symbol < 0) {
+        tailIndices.add(-(symbol + 1));
+      }
+    }
+
+    /*
+     * Each substitution point on the fragment's frontier is matched with the derivation state at
+     * the same permuted position, and the subtree built from that state is grafted in.
+     */
+    for (int i = 0; i < derivationStates.length; i++) {
+      Tree frontierTree = frontier.get(tailIndices.get(i));
+      frontierTree.setBoundary(true);
+
+      HyperEdge nextEdge = derivationStates[i].edge;
+      if (nextEdge != null) {
+        DerivationState[] nextStates = null;
+        List<HGNode> nextTails = nextEdge.getTailNodes();
+        if (nextTails != null && !nextTails.isEmpty()) {
+          nextStates = new DerivationState[nextTails.size()];
+          for (int j = 0; j < nextStates.length; j++) {
+            nextStates[j] = derivationStates[i].getChildDerivationState(nextEdge, j);
+          }
+        }
+        Tree childTree = buildTree(nextEdge.getRule(), nextStates, maxDepth - 1);
+
+        /* This can be null if there is no entry for the rule in the map */
+        if (childTree != null) {
+          frontierTree.children = childTree.children;
+        }
+      } else {
+        // NOTE(review): this makes the root's children descendants of a frontier node, which
+        // produces a cyclic structure. Kept as-is; confirm the intent.
+        frontierTree.children = tree.children;
+      }
+    }
+
+    return tree;
+  }
-
++
+  /**
+   * <p>Builds a tree from the kth-best derivation state. This is done by initializing the tree with
+   * the internal fragment corresponding to the rule; this will be the top of the tree. We then
+   * recursively visit the derivation state objects, following the route through the hypergraph
+   * defined by them.</p>
+   *
+   * @param derivationState the derivation state whose edge supplies the rule
+   * @param maxDepth of route through the hypergraph; expansion stops at 0
+   * @return the Tree, or null if no fragment is recorded for the rule's target side
+   */
+  public static Tree buildTree(DerivationState derivationState, int maxDepth) {
+    Rule rule = derivationState.edge.getRule();
+    Tree fragment = getFragmentFromYield(rule.getTargetWords());
+    if (fragment == null) {
+      return null;
+    }
+
+    Tree tree = fragment.shallowClone();
+    LOG.debug("buildTree({})", tree);
+
+    if (rule.getArity() <= 0 || maxDepth <= 0) {
+      return tree;
+    }
+
+    List<Tree> frontier = tree.getNonterminalYield();
+
+    // Negative target-side ids are nonterminals: -(id + 1) gives the 0-based source-side index
+    // used to pick the matching child derivation state.
+    List<Integer> tailIndices = new ArrayList<>();
+    for (int symbol : rule.getTarget()) {
+      if (symbol < 0) {
+        tailIndices.add(-(symbol + 1));
+      }
+    }
+
+    for (int i = 0; i < rule.getArity(); i++) {
+      Tree frontierTree = frontier.get(tailIndices.get(i));
+      frontierTree.setBoundary(true);
+
+      DerivationState childState = derivationState.getChildDerivationState(derivationState.edge, i);
+      Tree childTree = buildTree(childState, maxDepth - 1);
+
+      /* This can be null if there is no entry for the rule in the map */
+      if (childTree != null) {
+        frontierTree.children = childTree.children;
+      }
+    }
+
+    return tree;
+  }
+
+  /**
+   * Takes a rule and its tail pointers and recursively constructs a tree (up to maxDepth),
+   * following each tail node's best incoming hyperedge.
+   *
+   * This could be implemented by using the other buildTree() function and using the 1-best
+   * DerivationState.
+   *
+   * @param rule {@link org.apache.joshua.decoder.ff.tm.Rule} to be used whilst building the tree
+   * @param tailNodes {@link java.util.List} of {@link org.apache.joshua.decoder.hypergraph.HGNode}'s
+   * @param maxDepth to go in the tree
+   * @return a copied tree; never null
+   * @throws RuntimeException if a tail node cannot be matched to a frontier position
+   */
+  public static Tree buildTree(Rule rule, List<HGNode> tailNodes, int maxDepth) {
+    Tree fragment = getFragmentFromYield(rule.getTargetWords());
+    // NOTE(review): the fallback creates a single *leaf* whose label is the whole string
+    // "(LHS words)". It is not parsed (compare fromString). Kept as-is; confirm the intent.
+    Tree tree = (fragment != null)
+        ? fragment.shallowClone()
+        : new Tree(String.format("(%s %s)", Vocabulary.word(rule.getLHS()), rule.getTargetWords()));
+
+    boolean expand = tailNodes != null && !tailNodes.isEmpty() && maxDepth > 0;
+    if (!expand) {
+      return tree;
+    }
+
+    List<Tree> frontier = tree.getNonterminalYield();
+
+    // Negative target-side ids are nonterminals; -id - 1 is the 0-based tail-node index.
+    List<Integer> tailIndices = new ArrayList<>();
+    for (int symbol : rule.getTarget()) {
+      if (symbol < 0) {
+        tailIndices.add(-symbol - 1);
+      }
+    }
+
+    for (int i = 0; i < tailNodes.size(); i++) {
+      try {
+        Tree frontierTree = frontier.get(tailIndices.get(i));
+        frontierTree.setBoundary(true);
+
+        HyperEdge edge = tailNodes.get(i).bestHyperedge;
+        if (edge != null) {
+          Tree childTree = buildTree(edge.getRule(), edge.getTailNodes(), maxDepth - 1);
+          /* This can be null if there is no entry for the rule in the map */
+          if (childTree != null) {
+            frontierTree.children = childTree.children;
+          }
+        } else {
+          // NOTE(review): grafts the root's children under a frontier node (cyclic structure).
+          frontierTree.children = tree.children;
+        }
+      } catch (IndexOutOfBoundsException e) {
+        LOG.error("ERROR at index {}", i);
+        LOG.error("RULE: {} TREE: {}", rule.getTargetWords(), tree);
+        LOG.error(" FRONTIER:");
+        for (Tree kid : frontier) {
+          LOG.error(" {}", kid);
+        }
+        throw new RuntimeException(String.format("ERROR at index %d", i), e);
+      }
+    }
+
+    return tree;
+  }
+
+  /**
+   * Command-line utility. For each line of standard input it parses a bracketed tree, inserts
+   * the {@code <s>} / {@code </s>} markers, and prints the tree. Lines that cannot be processed
+   * produce an empty output line.
+   *
+   * @param args ignored
+   * @throws IOException propagated from the LineReader
+   */
+  public static void main(String[] args) throws IOException {
+    try (LineReader stdin = new LineReader(System.in)) {
+      for (String line : stdin) {
+        String output;
+        try {
+          Tree parsed = Tree.fromString(line);
+          parsed.insertSentenceMarkers();
+          output = parsed.toString();
+        } catch (Exception e) {
+          // Best effort: keep output line-aligned with input even when a line fails.
+          output = "";
+        }
+        System.out.println(output);
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
index 93d54ed,0000000..044c85f
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@@ -1,257 -1,0 +1,259 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.state_maintenance.KenLMState;
+import org.apache.joshua.util.FormatUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * JNI wrapper for KenLM. This version of KenLM supports two use cases, implemented by the separate
+ * feature functions KenLMFF and LanguageModelFF. KenLMFF uses the RuleScore() interface in
+ * lm/left.hh, returning a state pointer representing the KenLM state, while LanguageModelFF handles
+ * state by itself and just passes in the ngrams for scoring.
- *
++ *
+ * @author Kenneth Heafield
+ * @author Matt Post post@cs.jhu.edu
+ */
+
+public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
+
+ private static final Logger LOG = LoggerFactory.getLogger(KenLM.class);
+
+ // Native model handle returned by construct(); passed to every native call and released by destroy().
+ private final long pointer;
+
+ // this is read from the config file, used to set maximum order
+ // (the single-argument constructor sets it equal to N)
+ private final int ngramOrder;
+ // inferred from model file (may be larger than ngramOrder)
+ private final int N;
+
+ // Native entry points, implemented in the "ken" library loaded by initializeSystemLibrary().
+ // Signatures are bound by JNI and must not be changed. The first argument is the model handle.
+ private static native long construct(String file_name);
+
+ private static native void destroy(long ptr);
+
+ private static native int order(long ptr);
+
+ private static native boolean registerWord(long ptr, String word, int id);
+
+ private static native float prob(long ptr, int words[]);
+
+ private static native float probForString(long ptr, String[] words);
+
+ private static native boolean isKnownWord(long ptr, String word);
-
++
+ private static native boolean isLmOov(long ptr, int word);
+
+ // Scores a whole rule using the given memory pool (see probRule(long[], long)).
+ private static native StateProbPair probRule(long ptr, long pool, long words[]);
-
++
+ private static native float estimateRule(long ptr, long words[]);
+
+ // "start" is 0-based here; the public probString converts from a 1-based index.
+ private static native float probString(long ptr, int words[], int start);
+
+ private static native long createPool();
+
+ private static native void destroyPool(long pointer);
+
+  /**
+   * Loads a KenLM model with an explicit maximum order.
+   *
+   * @param order maximum order to report from getOrder()
+   * @param file_name path to the model file
+   * @throws KenLMLoadException if the native library cannot be loaded
+   */
+  public KenLM(int order, String file_name) {
+    this.pointer = initializeSystemLibrary(file_name);
+    this.ngramOrder = order;
+    this.N = order(this.pointer);
+  }
+
+  /**
+   * Constructor if order is not known.
+   * Order will be inferred from the model.
+   * @param file_name string path to an input file
+   */
+  public KenLM(String file_name) {
+    this.pointer = initializeSystemLibrary(file_name);
+    this.N = order(this.pointer);
+    this.ngramOrder = this.N;
+  }
+
+  /* Loads libken and constructs the native model, converting link failures to KenLMLoadException. */
+  private long initializeSystemLibrary(String file_name) {
+    long handle;
+    try {
+      System.loadLibrary("ken");
+      handle = construct(file_name);
+    } catch (UnsatisfiedLinkError linkError) {
+      LOG.error("Can't find libken.so (libken.dylib on OS X) on the Java library path.");
+      throw new KenLMLoadException(linkError);
+    }
+    return handle;
+  }
+
- public class KenLMLoadException extends RuntimeException {
++ public static class KenLMLoadException extends RuntimeException {
+
+ public KenLMLoadException(UnsatisfiedLinkError e) {
+ super(e);
+ }
+ }
+
+ public long createLMPool() {
+ return createPool();
+ }
+
+ public void destroyLMPool(long pointer) {
+ destroyPool(pointer);
+ }
+
+ public void destroy() {
+ destroy(pointer);
+ }
+
+  @Override
+  public int getOrder() {
+    return this.ngramOrder;
+  }
+
+  @Override
+  public boolean registerWord(String word, int id) {
+    return registerWord(this.pointer, word, id);
+  }
+
+  /**
+   * Query for n-gram probability using Vocabulary ids.
+   * @param words the n-gram as Vocabulary ids
+   * @return float value denoting probability
+   */
+  public float prob(int[] words) {
+    return prob(this.pointer, words);
+  }
+
+  /**
+   * Query for n-gram probability using strings.
+   * @param words a string array of words
+   * @return float value denoting probability
+   */
+  public float prob(String[] words) {
+    return probForString(this.pointer, words);
+  }
+
+  /**
+   * Delegates to the native probString. The caller's 1-based start index (a convention inherited
+   * from older code) is converted to the 0-based index KenLM expects.
+   */
+  public float probString(int[] words, int start) {
+    final int zeroBasedStart = start - 1;
+    return probString(this.pointer, words, zeroBasedStart);
+  }
+
+  /**
+   * This function is the bridge to the interface in kenlm/lm/left.hh, which has KenLM score the
+   * whole rule. It takes an array of words and states retrieved from tail nodes (nonterminals in the
+   * rule). Nonterminals have a negative value so KenLM can distinguish them. The pool pointer tells
+   * KenLM which memory pool to use. When finished, it returns the updated KenLM state and the LM
+   * probability incurred along this rule.
+   *
+   * @param words array of words
+   * @param poolPointer handle of a pool obtained from createLMPool()
+   * @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
+   * KenLM state and the LM probability incurred along this rule
+   * @throws RuntimeException wrapping a NoSuchMethodError raised across the JNI boundary
+   */
+  public StateProbPair probRule(long[] words, long poolPointer) {
+    try {
+      return probRule(pointer, poolPointer, words);
+    } catch (NoSuchMethodError e) {
+      // Previously this printed the stack trace and called System.exit(1), killing the whole
+      // JVM from library code. Rethrow instead, consistent with estimateRule.
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Public facing function that estimates the cost of a rule, which value is used for sorting
+   * rules during cube pruning.
+   *
+   * @param words array of words
+   * @return the estimated cost of the rule (the (partial) n-gram probabilities of all words in the rule)
+   * @throws RuntimeException wrapping a NoSuchMethodError raised across the JNI boundary
+   */
+  public float estimateRule(long[] words) {
+    try {
+      return estimateRule(pointer, words);
+    } catch (NoSuchMethodError e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * The start symbol for a KenLM is the Vocabulary.START_SYM.
+   * @return "<s>"
+   */
+  public String getStartSymbol() {
+    return Vocabulary.START_SYM;
+  }
+
+  /**
+   * Returns whether the given Vocabulary ID is unknown to the
+   * KenLM vocabulary. This can be used for a LanguageModel_OOV features
+   * and does not need to convert to an intermediate string.
+   *
+   * @throws IllegalArgumentException if wordId denotes a nonterminal
+   */
+  @Override
+  public boolean isOov(int wordId) {
+    if (!FormatUtils.isNonterminal(wordId)) {
+      return isLmOov(pointer, wordId);
+    }
+    throw new IllegalArgumentException("Should not query for nonterminals!");
+  }
+
+  /** @return whether KenLM's vocabulary contains the given surface string */
+  public boolean isKnownWord(String word) {
+    return isKnownWord(this.pointer, word);
+  }
+
+
+ /**
+ * Inner class used to hold the results returned from KenLM with left-state minimization. Note
+ * that inner classes have to be static to be accessible from the JNI!
+ * NOTE(review): presumably instantiated from native code through the (long, float)
+ * constructor, so keep that signature and the field names stable.
+ */
+ public static class StateProbPair {
+ // Wrapper around the native state handle passed to the constructor.
+ public KenLMState state = null;
+ // LM probability incurred along the rule (see probRule).
+ public float prob = 0.0f;
+
+ public StateProbPair(long state, float prob) {
+ this.state = new KenLMState(state);
+ this.prob = prob;
+ }
+ }
+
+ @Override
+ public int compareTo(KenLM other) {
+ if (this == other)
+ return 0;
+ else
+ return -1;
+ }
+
+ /**
+ * These functions are used if KenLM is invoked under LanguageModelFF instead of KenLMFF.
+ */
+ @Override
+ public float sentenceLogProbability(int[] sentence, int order, int startIndex) {
+ return probString(sentence, startIndex);
+ }
+
+ @Override
+ public float ngramLogProbability(int[] ngram, int order) {
+ if (order != N && order != ngram.length)
+ throw new RuntimeException("Lower order not supported.");
+ return prob(ngram);
+ }
+
+ @Override
+ public float ngramLogProbability(int[] ngram) {
+ return prob(ngram);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
index 9bfccb0,0000000..0615077
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
@@@ -1,334 -1,0 +1,284 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm.buildin_lm;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.Scanner;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.lm.AbstractLM;
+import org.apache.joshua.decoder.ff.lm.ArpaFile;
+import org.apache.joshua.decoder.ff.lm.ArpaNgram;
+import org.apache.joshua.util.Bits;
+import org.apache.joshua.util.Regex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Relatively memory-compact language model
+ * stored as a reversed-word-order trie.
+ * <p>
+ * The trie itself represents language model context.
+ * <p>
- * Conceptually, each node in the trie stores a map
++ * Conceptually, each node in the trie stores a map
+ * from conditioning word to log probability.
+ * <p>
- * Additionally, each node in the trie stores
++ * Additionally, each node in the trie stores
+ * the backoff weight for that context.
- *
++ *
+ * @author Lane Schwartz
+ * @see <a href="http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html">SRILM ngram-discount documentation</a>
+ */
+public class TrieLM extends AbstractLM { //DefaultNGramLanguageModel {
+
+ private static final Logger LOG = LoggerFactory.getLogger(TrieLM.class);
+
+ /**
+ * Node ID for the root node.
+ */
+ private static final int ROOT_NODE_ID = 0;
+
+
- /**
- * Maps from (node id, word id for child) --> node id of child.
++ /**
++ * Maps from (node id, word id for child) --> node id of child.
+ */
+ private final Map<Long,Integer> children;
+
+ /**
- * Maps from (node id, word id for lookup word) -->
- * log prob of lookup word given context
- *
++ * Maps from (node id, word id for lookup word) -->
++ * log prob of lookup word given context
++ *
+ * (the context is defined by where you are in the tree).
+ */
+ private final Map<Long,Float> logProbs;
+
+ /**
- * Maps from (node id) -->
- * backoff weight for that context
- *
++ * Maps from (node id) -->
++ * backoff weight for that context
++ *
+ * (the context is defined by where you are in the tree).
+ */
+ private final Map<Integer,Float> backoffs;
+
+ /**
+ * Constructs a language model by reading the ARPA file at {@code file}.
+ *
+ * @param vocab vocabulary handed to the {@link ArpaFile} reader
+ * @param file path to the ARPA file
+ * @throws FileNotFoundException propagated from reading the ARPA file
+ */
+ public TrieLM(Vocabulary vocab, String file) throws FileNotFoundException {
+ this(new ArpaFile(file, vocab));
+ }
+
+ /**
+ * Constructs a language model object from the specified ARPA file.
+ * <p>
+ * Each n-gram's log probability is stored in {@code logProbs} under the key
+ * (context node, word). The context node is reached from the root by following the context
+ * words from last to first. Each n-gram's backoff weight is stored in {@code backoffs} under
+ * the node reached from the root by following the word itself and then its context words from
+ * last to first. Missing trie nodes are created along the way.
- *
++ *
+ * @param arpaFile input ARPA file
+ * @throws FileNotFoundException if the input file cannot be located
+ */
+ public TrieLM(ArpaFile arpaFile) throws FileNotFoundException {
+ super(Vocabulary.size(), arpaFile.getOrder());
+
+ int ngramCounts = arpaFile.size();
+ LOG.debug("ARPA file contains {} n-grams", ngramCounts);
+
+ // The n-gram count is used as the initial-capacity hint for all three maps.
+ this.children = new HashMap<>(ngramCounts);
+ this.logProbs = new HashMap<>(ngramCounts);
+ this.backoffs = new HashMap<>(ngramCounts);
+
+ // Last node id handed out. ROOT_NODE_ID is 0, so the first created child gets id 1.
+ int nodeCounter = 0;
+
+ // Counts n-grams read (one per iteration, not file lines); used only for progress logging.
+ int lineNumber = 0;
+ for (ArpaNgram ngram : arpaFile) {
+ lineNumber += 1;
+ if (lineNumber % 100000 == 0){
+ LOG.info("Line: {}", lineNumber);
+ }
+
+ LOG.debug("{}-gram: ({} | {})", ngram.order(), ngram.getWord(),
+ Arrays.toString(ngram.getContext()));
+ int word = ngram.getWord();
+
+ int[] context = ngram.getContext();
+
+ {
+ // Find where the log prob should be stored
+ // (walk from the root through context[length-1] down to context[0], creating missing nodes).
+ int contextNodeID = ROOT_NODE_ID;
+ {
+ for (int i=context.length-1; i>=0; i--) {
+ long key = Bits.encodeAsLong(contextNodeID, context[i]);
+ int childID;
+ if (children.containsKey(key)) {
+ childID = children.get(key);
+ } else {
+ childID = ++nodeCounter;
+ LOG.debug("children.put({}:{}, {})", contextNodeID, context[i], childID);
+ children.put(key, childID);
+ }
+ contextNodeID = childID;
+ }
+ }
+
+ // Store the log prob for this n-gram at this node in the trie
+ {
+ long key = Bits.encodeAsLong(contextNodeID, word);
+ float logProb = ngram.getValue();
+ // NOTE(review): this debug format string is missing its closing ')'.
+ LOG.debug("logProbs.put({}:{}, {}", contextNodeID, word, logProb);
+ this.logProbs.put(key, logProb);
+ }
+ }
+
+ {
+ // Find where the backoff should be stored
+ // (first step is the child for the word itself, then the context words in reverse order).
+ int backoffNodeID = ROOT_NODE_ID;
- {
++ {
+ long backoffNodeKey = Bits.encodeAsLong(backoffNodeID, word);
+ int wordChildID;
+ if (children.containsKey(backoffNodeKey)) {
+ wordChildID = children.get(backoffNodeKey);
+ } else {
+ wordChildID = ++nodeCounter;
+ LOG.debug("children.put({}: {}, {})", backoffNodeID, word, wordChildID);
+ children.put(backoffNodeKey, wordChildID);
+ }
+ backoffNodeID = wordChildID;
+
+ for (int i=context.length-1; i>=0; i--) {
+ long key = Bits.encodeAsLong(backoffNodeID, context[i]);
+ int childID;
+ if (children.containsKey(key)) {
+ childID = children.get(key);
+ } else {
+ childID = ++nodeCounter;
+ LOG.debug("children.put({}:{}, {})", backoffNodeID, context[i], childID);
+ children.put(key, childID);
+ }
+ backoffNodeID = childID;
+ }
+ }
+
+ // Store the backoff for this n-gram at this node in the trie
+ // (keyed by node id alone: the word is already encoded in the path to backoffNodeID).
+ {
+ float backoff = ngram.getBackoff();
+ LOG.debug("backoffs.put({}:{}, {})", backoffNodeID, word, backoff);
+ this.backoffs.put(backoffNodeID, backoff);
+ }
+ }
+
+ }
+ }
+
+
+ /**
+ * Not supported by TrieLM.
+ *
+ * @param ngram unused
+ * @param order unused
+ * @param qtyAdditionalBackoffWeight unused
+ * @return never returns normally
+ * @throws UnsupportedOperationException always; the message names this method
+ */
+ @Override
- protected double logProbabilityOfBackoffState_helper(
- int[] ngram, int order, int qtyAdditionalBackoffWeight
- ) {
++ protected double logProbabilityOfBackoffState_helper(int[] ngram, int order, int qtyAdditionalBackoffWeight) {
+ throw new UnsupportedOperationException("logProbabilityOfBackoffState_helper undefined for TrieLM");
+ }
+
+ /**
+ * Not supported by TrieLM: the former trie-walking implementation never produced a value.
+ * It ended in {@code return (Float) null}.
+ * NOTE(review): if AbstractLM's public n-gram scoring delegates here, TrieLM cannot score
+ * n-grams at all — confirm.
+ *
+ * @param ngram unused
+ * @param order unused
+ * @return never returns normally
+ * @throws UnsupportedOperationException always; the message names this method, like its sibling
+ */
+ @Override
+ protected float ngramLogProbability_helper(int[] ngram, int order) {
-
- // float logProb = (float) -JoshuaConfiguration.lm_ceiling_cost;//Float.NEGATIVE_INFINITY; // log(0.0f)
- float backoff = 0.0f; // log(1.0f)
-
- int i = ngram.length - 1;
- int word = ngram[i];
- i -= 1;
-
- int nodeID = ROOT_NODE_ID;
-
- while (true) {
-
- {
- long key = Bits.encodeAsLong(nodeID, word);
- if (logProbs.containsKey(key)) {
- // logProb = logProbs.get(key);
- backoff = 0.0f; // log(0.0f)
- }
- }
-
- if (i < 0) {
- break;
- }
-
- {
- long key = Bits.encodeAsLong(nodeID, ngram[i]);
-
- if (children.containsKey(key)) {
- nodeID = children.get(key);
-
- backoff += backoffs.get(nodeID);
-
- i -= 1;
-
- } else {
- break;
- }
- }
-
- }
-
- // double result = logProb + backoff;
- // if (result < -JoshuaConfiguration.lm_ceiling_cost) {
- // result = -JoshuaConfiguration.lm_ceiling_cost;
- // }
- //
- // return result;
- return (Float) null;
++ throw new UnsupportedOperationException("ngramLogProbability_helper undefined for TrieLM");
+ }
+
+ /**
+ * Returns the trie's child map: encoded (parent node id, word id) key to child node id.
+ * This is the live internal map, not a copy, so callers that modify it modify this model.
+ *
+ * @return the internal children map
+ */
+ public Map<Long,Integer> getChildren() {
+ final Map<Long, Integer> childMap = children;
+ return childMap;
+ }
+
+ /**
+ * Command-line driver. It loads an ARPA model into a TrieLM, then logs per-window and
+ * whole-sentence log-probabilities for every line of a text file.
+ * <p>
+ * Arguments: {@code args[0]} is the ARPA file, {@code args[1]} is the text file to score (each
+ * line is split on {@code Regex.spaces} and wrapped in {@code <s>} ... {@code </s>}), and
+ * {@code args[2]} is the n-gram order used for the sliding window.
+ * <p>
+ * NOTE(review): the argument count is not validated. {@code vocab} is read but never used.
+ * Per-window scoring calls {@code ngramLogProbability}, which may reach the unsupported
+ * {@code ngramLogProbability_helper} — confirm.
+ *
+ * @param args command-line arguments as described above
+ * @throws IOException if an input file cannot be opened or read
+ */
+ public static void main(String[] args) throws IOException {
+
+ LOG.info("Constructing ARPA file");
+ ArpaFile arpaFile = new ArpaFile(args[0]);
+
+ LOG.info("Getting symbol table");
+ Vocabulary vocab = arpaFile.getVocab();
+
+ LOG.info("Constructing TrieLM");
+ TrieLM lm = new TrieLM(arpaFile);
+
+ int n = Integer.valueOf(args[2]);
+ LOG.info("N-gram order will be {}", n);
+
- Scanner scanner = new Scanner(new File(args[1]));
++ try (Scanner scanner = new Scanner(new File(args[1]));) {
++ LinkedList<String> wordList = new LinkedList<>();
++ LinkedList<String> window = new LinkedList<>();
+
- LinkedList<String> wordList = new LinkedList<>();
- LinkedList<String> window = new LinkedList<>();
++ LOG.info("Starting to scan {}", args[1]);
++ while (scanner.hasNext()) {
+
- LOG.info("Starting to scan {}", args[1]);
- while (scanner.hasNext()) {
++ LOG.info("Getting next line...");
++ String line = scanner.nextLine();
++ LOG.info("Line: {}", line);
+
- LOG.info("Getting next line...");
- String line = scanner.nextLine();
- LOG.info("Line: {}", line);
++ String[] words = Regex.spaces.split(line);
++ wordList.clear();
+
- String[] words = Regex.spaces.split(line);
- wordList.clear();
++ wordList.add("<s>");
++ Collections.addAll(wordList, words);
++ wordList.add("</s>");
+
- wordList.add("<s>");
- Collections.addAll(wordList, words);
- wordList.add("</s>");
-
- ArrayList<Integer> sentence = new ArrayList<>();
- // int[] ids = new int[wordList.size()];
- for (String aWordList : wordList) {
- sentence.add(Vocabulary.id(aWordList));
- // ids[i] = ;
- }
++ ArrayList<Integer> sentence = new ArrayList<>();
++ // int[] ids = new int[wordList.size()];
++ for (String aWordList : wordList) {
++ sentence.add(Vocabulary.id(aWordList));
++ // ids[i] = ;
++ }
+
++ while (!wordList.isEmpty()) {
++ window.clear();
+
++ {
++ int i = 0;
++ for (String word : wordList) {
++ if (i >= n)
++ break;
++ window.add(word);
++ i++;
++ }
++ wordList.remove();
++ }
+
- while (! wordList.isEmpty()) {
- window.clear();
++ {
++ int i = 0;
++ int[] wordIDs = new int[window.size()];
++ for (String word : window) {
++ wordIDs[i] = Vocabulary.id(word);
++ i++;
++ }
+
- {
- int i=0;
- for (String word : wordList) {
- if (i>=n) break;
- window.add(word);
- i++;
++ LOG.info("logProb {} = {}", window, lm.ngramLogProbability(wordIDs, n));
+ }
- wordList.remove();
+ }
+
- {
- int i=0;
- int[] wordIDs = new int[window.size()];
- for (String word : window) {
- wordIDs[i] = Vocabulary.id(word);
- i++;
- }
++ double logProb = lm.sentenceLogProbability(sentence, n, 2);// .ngramLogProbability(ids,
++ // n);
++ double prob = Math.exp(logProb);
+
- LOG.info("logProb {} = {}", window, lm.ngramLogProbability(wordIDs, n));
- }
++ LOG.info("Total logProb = {}", logProb);
++ LOG.info("Total prob = {}", prob);
+ }
-
- double logProb = lm.sentenceLogProbability(sentence, n, 2);//.ngramLogProbability(ids, n);
- double prob = Math.exp(logProb);
-
- LOG.info("Total logProb = {}", logProb);
- LOG.info("Total prob = {}", prob);
+ }
-
+ }
+
+ /**
+ * OOV queries are not supported by TrieLM. This now throws UnsupportedOperationException
+ * (a RuntimeException subclass, so existing catch blocks still apply), consistent with the
+ * other unsupported operations in this class.
+ *
+ * @param id word id (unused)
+ * @return never returns normally
+ * @throws UnsupportedOperationException always
+ */
+ @Override
+ public boolean isOov(int id) {
+ throw new UnsupportedOperationException("Not implemented!");
+ }
+
+}