You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by le...@apache.org on 2016/05/16 06:26:32 UTC
[16/66] [partial] incubator-joshua git commit: JOSHUA-252 Make it
possible to use Maven to build Joshua
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8cdbc4b8/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/FragmentLMFF.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/FragmentLMFF.java b/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/FragmentLMFF.java
new file mode 100644
index 0000000..0375dc0
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/FragmentLMFF.java
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package joshua.decoder.ff.fragmentlm;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Stack;
+
+import joshua.decoder.JoshuaConfiguration;
+import joshua.decoder.chart_parser.SourcePath;
+import joshua.decoder.ff.FeatureVector;
+import joshua.decoder.ff.StatefulFF;
+import joshua.decoder.ff.state_maintenance.DPState;
+import joshua.decoder.ff.tm.Rule;
+import joshua.decoder.ff.tm.format.HieroFormatReader;
+import joshua.decoder.hypergraph.HGNode;
+import joshua.decoder.hypergraph.HyperEdge;
+import joshua.decoder.segment_file.Sentence;
+
+/**
+ * Feature function that reads in a list of language model fragments and matches them against the
+ * hypergraph. This allows for language model fragment "glue" features, which fire when LM fragments
+ * (supplied as input) are assembled. These LM fragments are presumably useful in ensuring
+ * grammaticality and can be independent of the translation model fragments.
+ *
+ * Usage: in the Joshua Configuration file, put
+ *
+ * feature-function = FragmentLM -lm LM_FRAGMENTS_FILE -map RULE_FRAGMENTS_MAP_FILE
+ *
+ * LM_FRAGMENTS_FILE is a pointer to a file containing a list of fragments that it should look for.
+ * The format of the file is one fragment per line in PTB format, e.g.:
+ *
+ * (S NP (VP (VBD said) SBAR) (. .))
+ *
+ * RULE_FRAGMENTS_MAP_FILE points to a file that maps fragments to the flattened SCFG rule format
+ * that Joshua uses. This mapping is necessary because Joshua's rules have been flattened, meaning
+ * that their internal structure has been removed, yet this structure is needed for matching LM
+ * fragments. The format of the file is
+ *
+ * FRAGMENT ||| RULE-TARGET-SIDE
+ *
+ * for example,
+ *
+ * (S (NP (DT the) (NN man)) VP .) ||| the man [VP,1] [.,2]
+ * (SBAR (IN that) (S (NP (PRP he)) (VP (VBD was) (VB done)))) ||| that he was done
+ * (VP (VBD said) SBAR) ||| said SBAR
+ *
+ * @author Matt Post <po...@cs.jhu.edu>
+ */
+public class FragmentLMFF extends StatefulFF {
+
+ /*
+ * When building a fragment from a rule rooted in the hypergraph, this parameter determines how
+ * deep we'll go. Smaller values mean less hypergraph traversal but may also limit the LM
+ * fragments that can be fired. Passed to Tree.buildTree() by compute(); the constructor sets it
+ * from the "build-depth" argument.
+ */
+ private int BUILD_DEPTH = 1;
+
+ /*
+ * The maximum depth of a fragment, defined as the longest path from the fragment root to any of
+ * its leaves. A value of 0 disables the limit (see addLMFragment()); the constructor sets it
+ * from the "max-depth" argument.
+ */
+ private int MAX_DEPTH = 0;
+
+ /*
+ * This is the minimum depth for lexicalized LM fragments. This allows you to easily exclude small
+ * depth-one fragments that may be overfit to the training data. A depth of 1 (the default) does
+ * not exclude any fragments. The constructor sets it from the "min-lex-depth" argument.
+ */
+ private int MIN_LEX_DEPTH = 1;
+
+ /*
+ * Set to true to activate meta-features: compute() then additionally fires a
+ * "FragmentFF_lexdepth<N>" or "FragmentFF_depth<N>" feature for every matched fragment. Nothing
+ * in this class ever changes the value from false.
+ */
+ private boolean OPTS_DEPTH = false;
+
+ /*
+ * The language model fragments, grouped by the depth-one rule at each fragment's root, i.e. the
+ * string returned by Tree.getRule() (root label followed by the labels of its children).
+ */
+ private HashMap<String, ArrayList<Tree>> lmFragments = null;
+
+ /* The number of fragments that passed the filters in addLMFragment(). */
+ private int numFragments = 0;
+
+ /* The location of the file containing the language model fragments */
+ private String fragmentLMFile = "";
+
+ /**
+ * Creates the feature function and loads the language model fragments.
+ *
+ * The following entries of the parsed argument map are read: "lm" (the path of the fragment
+ * file) and the integer settings "build-depth", "max-depth" and "min-lex-depth". The integer
+ * settings are optional: one that is not supplied keeps its field default rather than aborting
+ * construction with a NumberFormatException.
+ *
+ * If the fragment file cannot be read, a message is printed and the JVM exits with status 1.
+ *
+ * @param weights the feature weights, handed to the superclass
+ * @param args the arguments of this feature function, handed to the superclass for parsing
+ * @param config the decoder configuration, handed to the superclass
+ */
+ public FragmentLMFF(FeatureVector weights, String[] args, JoshuaConfiguration config) {
+   super(weights, "FragmentLMFF", args, config);
+
+   lmFragments = new HashMap<String, ArrayList<Tree>>();
+
+   fragmentLMFile = parsedArgs.get("lm");
+
+   /* Each depth setting falls back to the value its field was initialized with. */
+   BUILD_DEPTH = intArgument("build-depth", BUILD_DEPTH);
+   MAX_DEPTH = intArgument("max-depth", MAX_DEPTH);
+   MIN_LEX_DEPTH = intArgument("min-lex-depth", MIN_LEX_DEPTH);
+
+   /* Read in the language model fragments */
+   try {
+     Collection<Tree> trees = PennTreebankReader.readTrees(fragmentLMFile);
+     for (Tree fragment : trees) {
+       addLMFragment(fragment);
+     }
+   } catch (IOException e) {
+     System.err.println(String.format("* WARNING: couldn't read fragment LM file '%s'",
+         fragmentLMFile));
+     System.exit(1);
+   }
+   System.err.println(String.format("FragmentLMFF: Read %d LM fragments from '%s'", numFragments,
+       fragmentLMFile));
+ }
+
+ /**
+ * Looks up an integer-valued argument in the parsed argument map.
+ *
+ * @param key the argument name, without the leading dash
+ * @param defaultValue the value to use when the argument was not supplied
+ * @return the parsed argument, or defaultValue if it is absent
+ * @throws NumberFormatException if the argument is present but is not an integer
+ */
+ private int intArgument(String key, int defaultValue) {
+   String value = parsedArgs.get(key);
+   return (value == null) ? defaultValue : Integer.parseInt(value);
+ }
+
+ /**
+ * Add the provided fragment to the language model, subject to some filtering.
+ *
+ * @param fragment
+ */
+ public void addLMFragment(Tree fragment) {
+ if (lmFragments == null)
+ return;
+
+ int fragmentDepth = fragment.getDepth();
+
+ if (MAX_DEPTH != 0 && fragmentDepth > MAX_DEPTH) {
+ System.err.println(String.format(" Skipping fragment %s (depth %d > %d)", fragment,
+ fragmentDepth, MAX_DEPTH));
+ return;
+ }
+
+ if (MIN_LEX_DEPTH > 1 && fragment.isLexicalized() && fragmentDepth < MIN_LEX_DEPTH) {
+ System.err.println(String.format(" Skipping fragment %s (lex depth %d < %d)", fragment,
+ fragmentDepth, MIN_LEX_DEPTH));
+ return;
+ }
+
+ if (lmFragments.get(fragment.getRule()) == null)
+ lmFragments.put(fragment.getRule(), new ArrayList<Tree>());
+ lmFragments.get(fragment.getRule()).add(fragment);
+ numFragments++;
+ }
+
+ /**
+ * This function computes the features that fire when the current rule is applied. The features
+ * that fire are any LM fragments that match the fragment associated with the current rule. LM
+ * fragments may recurse over the tail nodes, following 1-best backpointers until the fragment
+ * either matches or fails.
+ */
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
+
+ /*
+ * Get the fragment associated with the target side of this rule.
+ *
+ * This could be done more efficiently. For example, just build the tree fragment once and then
+ * pattern match against it. This would circumvent having to build the tree possibly once every
+ * time you try to apply a rule.
+ */
+ Tree baseTree = Tree.buildTree(rule, tailNodes, BUILD_DEPTH);
+
+ Stack<Tree> nodeStack = new Stack<Tree>();
+ nodeStack.add(baseTree);
+ while (!nodeStack.empty()) {
+ Tree tree = nodeStack.pop();
+ if (tree == null)
+ continue;
+
+ if (lmFragments.get(tree.getRule()) != null) {
+ for (Tree fragment : lmFragments.get(tree.getRule())) {
+// System.err.println(String.format("Does\n %s match\n %s??\n -> %s", fragment, tree,
+// match(fragment, tree)));
+
+ if (fragment.getLabel() == tree.getLabel() && match(fragment, tree)) {
+// System.err.println(String.format(" FIRING: matched %s against %s", fragment, tree));
+ acc.add(fragment.escapedString(), 1);
+ if (OPTS_DEPTH)
+ if (fragment.isLexicalized())
+ acc.add(String.format("FragmentFF_lexdepth%d", fragment.getDepth()), 1);
+ else
+ acc.add(String.format("FragmentFF_depth%d", fragment.getDepth()), 1);
+ }
+ }
+ }
+
+ // We also need to try matching rules against internal nodes of the fragment corresponding to
+ // this
+ // rule
+ if (tree.getChildren() != null)
+ for (Tree childNode : tree.getChildren()) {
+ if (!childNode.isBoundary())
+ nodeStack.add(childNode);
+ }
+ }
+
+ return new FragmentState(baseTree);
+ }
+
+ /**
+ * Matches the fragment against the (possibly partially-built) tree. Assumption
+ *
+ * @param fragment the language model fragment
+ * @param tree the tree to match against (expanded from the hypergraph)
+ * @return
+ */
+ private boolean match(Tree fragment, Tree tree) {
+ // System.err.println(String.format("MATCH(%s,%s)", fragment, tree));
+
+ /* Make sure the root labels match. */
+ if (fragment.getLabel() != tree.getLabel()) {
+ return false;
+ }
+
+ /* Same number of kids? */
+ List<Tree> fkids = fragment.getChildren();
+ if (fkids.size() > 0) {
+ List<Tree> tkids = tree.getChildren();
+ if (fkids.size() != tkids.size()) {
+ return false;
+ }
+
+ /* Do the kids match on all labels? */
+ for (int i = 0; i < fkids.size(); i++)
+ if (fkids.get(i).getLabel() != tkids.get(i).getLabel())
+ return false;
+
+ /* Recursive match. */
+ for (int i = 0; i < fkids.size(); i++) {
+ if (!match(fkids.get(i), tkids.get(i)))
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ /**
+ * Stub: performs no fragment matching, leaves the accumulator untouched and produces no state.
+ *
+ * @return always null
+ */
+ @Override
+ public DPState computeFinal(HGNode tailNodes, int i, int j, SourcePath sourcePath, Sentence sentence,
+ Accumulator acc) {
+ // Not implemented for this feature function.
+ return null;
+ }
+
+ /**
+ * Stub: this feature function provides no future cost estimate.
+ *
+ * @return always 0
+ */
+ @Override
+ public float estimateFutureCost(Rule rule, DPState state, Sentence sentence) {
+ // Not implemented for this feature function.
+ return 0;
+ }
+
+ /**
+ * Stub: this feature function provides no cost estimate for a rule.
+ *
+ * @return always 0
+ */
+ @Override
+ public float estimateCost(Rule rule, Sentence sentence) {
+ // Not implemented for this feature function.
+ return 0;
+ }
+
+ /**
+ * Ad-hoc smoke test: builds a small hypergraph for "the man said that he was done ." by hand and
+ * prints whether a hard-coded LM fragment matches the tree built from it.
+ *
+ * The fragment and mapping files are read from fixed relative paths under "test/"; the command
+ * line arguments are ignored.
+ */
+ public static void main(String[] args) {
+ /* Add an LM fragment, then create a dummy multi-level hypergraph to match the fragment against. */
+ FragmentLMFF fragmentLMFF = new FragmentLMFF(new FeatureVector(),
+ new String[] {"-lm", "test/fragments.txt", "-map", "test/mapping.txt"}, null);
+
+ // Terminals must be double-quoted so that they are not read as nonterminals.
+ Tree fragment = Tree.fromString("(S NP (VP (VBD \"said\") SBAR) (. \".\"))");
+
+ // One rule per constituent: S -> the man VP . ; VP -> said SBAR ; SBAR and "." are lexical.
+ Rule ruleS = new HieroFormatReader()
+ .parseLine("[S] ||| the man [VP,1] [.,2] ||| the man [VP,1] [.,2] ||| 0");
+ Rule ruleVP = new HieroFormatReader()
+ .parseLine("[VP] ||| said [SBAR,1] ||| said [SBAR,1] ||| 0");
+ Rule ruleSBAR = new HieroFormatReader()
+ .parseLine("[SBAR] ||| that he was done ||| that he was done ||| 0");
+ Rule rulePERIOD = new HieroFormatReader().parseLine("[.] ||| . ||| . ||| 0");
+
+ ruleS.setOwner(0);
+ ruleVP.setOwner(0);
+ ruleSBAR.setOwner(0);
+ rulePERIOD.setOwner(0);
+
+ // Bottom-up: the SBAR node has no tail nodes ...
+ HyperEdge edgeSBAR = new HyperEdge(ruleSBAR, 0.0f, 0.0f, null, (SourcePath) null);
+
+ HGNode nodeSBAR = new HGNode(3, 7, ruleSBAR.getLHS(), null, edgeSBAR, 0.0f);
+ // ... the VP node has the SBAR node as its single tail ...
+ ArrayList<HGNode> tailNodesVP = new ArrayList<HGNode>();
+ Collections.addAll(tailNodesVP, nodeSBAR);
+ HyperEdge edgeVP = new HyperEdge(ruleVP, 0.0f, 0.0f, tailNodesVP, (SourcePath) null);
+ HGNode nodeVP = new HGNode(2, 7, ruleVP.getLHS(), null, edgeVP, 0.0f);
+
+ // ... and the period is another node without tail nodes.
+ HyperEdge edgePERIOD = new HyperEdge(rulePERIOD, 0.0f, 0.0f, null, (SourcePath) null);
+ HGNode nodePERIOD = new HGNode(7, 8, rulePERIOD.getLHS(), null, edgePERIOD, 0.0f);
+
+ // Tail nodes of the S rule, in the order of its nonterminals.
+ ArrayList<HGNode> tailNodes = new ArrayList<HGNode>();
+ Collections.addAll(tailNodes, nodeVP, nodePERIOD);
+
+ // Expand the S rule one level into its tail nodes and test the fragment against the result.
+ Tree tree = Tree.buildTree(ruleS, tailNodes, 1);
+ boolean matched = fragmentLMFF.match(fragment, tree);
+ System.err.println(String.format("Does\n %s match\n %s??\n -> %s", fragment, tree, matched));
+ }
+
+ /**
+ * The dynamic programming state returned by {@link FragmentLMFF#compute}: a wrapper around the
+ * tree that was built for the rule application. The wrapped tree may be null, since compute()
+ * wraps whatever Tree.buildTree() returned.
+ *
+ * Two states are equal only if they are the same object, so states are never merged.
+ *
+ * @author Matt Post <po...@cs.jhu.edu>
+ * @author Juri Ganitkevitch <ju...@cs.jhu.edu>
+ */
+ public class FragmentState extends DPState {
+
+   private final Tree tree;
+
+   /**
+   * @param tree the tree to wrap; may be null
+   */
+   public FragmentState(Tree tree) {
+     this.tree = tree;
+   }
+
+   /**
+   * Every tree is unique.
+   *
+   * Some savings could be had here if we grouped together items with the same string.
+   *
+   * @return the hash code of the wrapped tree, or 0 when there is no tree
+   */
+   @Override
+   public int hashCode() {
+     /* A state without a tree must still be hashable; the original dereferenced null here. */
+     return (tree == null) ? 0 : tree.hashCode();
+   }
+
+   /**
+   * @return true only if other is this very object
+   */
+   @Override
+   public boolean equals(Object other) {
+     return this == other;
+   }
+
+   @Override
+   public String toString() {
+     return String.format("[FragmentState %s]", tree);
+   }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8cdbc4b8/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/PennTreebankReader.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/PennTreebankReader.java b/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/PennTreebankReader.java
new file mode 100644
index 0000000..6ab52e1
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/PennTreebankReader.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package joshua.decoder.ff.fragmentlm;
+
+import java.util.*;
+import java.io.*;
+import java.nio.charset.Charset;
+import java.nio.charset.UnsupportedCharsetException;
+
+/**
+ * @author Dan Klein
+ */
+public class PennTreebankReader {
+
+ static class TreeCollection extends AbstractCollection<Tree> {
+
+ List<File> files;
+ Charset charset;
+
+ static class TreeIteratorIterator implements Iterator<Iterator<Tree>> {
+ Iterator<File> fileIterator;
+ Iterator<Tree> nextTreeIterator;
+ Charset charset;
+
+ public boolean hasNext() {
+ return nextTreeIterator != null;
+ }
+
+ public Iterator<Tree> next() {
+ Iterator<Tree> currentTreeIterator = nextTreeIterator;
+ advance();
+ return currentTreeIterator;
+ }
+
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+ private void advance() {
+ nextTreeIterator = null;
+ while (nextTreeIterator == null && fileIterator.hasNext()) {
+ File file = fileIterator.next();
+ // System.out.println(file);
+ try {
+ nextTreeIterator = new Trees.PennTreeReader(new BufferedReader(new InputStreamReader(
+ new FileInputStream(file), this.charset)));
+ } catch (FileNotFoundException e) {
+ } catch (UnsupportedCharsetException e) {
+ throw new Error("Unsupported charset in file " + file.getPath());
+ }
+ }
+ }
+
+ TreeIteratorIterator(List<File> files, Charset charset) {
+ this.fileIterator = files.iterator();
+ this.charset = charset;
+ advance();
+ }
+ }
+
+ public Iterator<Tree> iterator() {
+ return new ConcatenationIterator<Tree>(new TreeIteratorIterator(files, this.charset));
+ }
+
+ public int size() {
+ int size = 0;
+ Iterator<Tree> i = iterator();
+ while (i.hasNext()) {
+ size++;
+ i.next();
+ }
+ return size;
+ }
+
+ @SuppressWarnings("unused")
+ private List<File> getFilesUnder(String path, FileFilter fileFilter) {
+ File root = new File(path);
+ List<File> files = new ArrayList<File>();
+ addFilesUnder(root, files, fileFilter);
+ return files;
+ }
+
+ private void addFilesUnder(File root, List<File> files, FileFilter fileFilter) {
+ if (!fileFilter.accept(root))
+ return;
+ if (root.isFile()) {
+ files.add(root);
+ return;
+ }
+ if (root.isDirectory()) {
+ File[] children = root.listFiles();
+ for (int i = 0; i < children.length; i++) {
+ File child = children[i];
+ addFilesUnder(child, files, fileFilter);
+ }
+ }
+ }
+
+ public TreeCollection(String file) throws FileNotFoundException, IOException {
+ this.files = new ArrayList<File>();
+ this.files.add(new File(file));
+ this.charset = Charset.defaultCharset();
+ }
+ }
+
+ public static Collection<Tree> readTrees(String path) throws FileNotFoundException, IOException {
+ return new TreeCollection(path);
+ }
+
+ public static void main(String[] args) {
+/* Collection<Tree> trees = readTrees(args[0], Charset.defaultCharset());
+ for (Tree tree : trees) {
+ tree = (new Trees.StandardTreeNormalizer()).transformTree(tree);
+ System.out.println(Trees.PennTreeRenderer.render(tree));
+ }
+ */
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8cdbc4b8/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java b/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
new file mode 100644
index 0000000..b52ccce
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
@@ -0,0 +1,776 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package joshua.decoder.ff.fragmentlm;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.io.StringReader;
+import java.util.*;
+
+import joshua.corpus.Vocabulary;
+import joshua.decoder.ff.fragmentlm.Trees.PennTreeReader;
+import joshua.decoder.ff.tm.Rule;
+import joshua.decoder.hypergraph.HGNode;
+import joshua.decoder.hypergraph.HyperEdge;
+import joshua.decoder.hypergraph.KBestExtractor.DerivationState;
+import joshua.util.io.LineReader;
+
+/**
+ * Represent phrase-structure trees, with each node consisting of a label and a list of children.
+ * Borrowed from the Berkeley Parser, and extended to allow the representation of tree fragments in
+ * addition to complete trees (the BP requires terminals to be immediately governed by a
+ * preterminal). To distinguish terminals from nonterminals in fragments, the former must be
+ * enclosed in double-quotes when read in.
+ *
+ * @author Dan Klein
+ * @author Matt Post <po...@cs.jhu.edu>
+ */
+public class Tree implements Serializable {
+
+ private static final long serialVersionUID = 1L;
+
+ /* The node label, as an id obtained from (and resolvable through) the Vocabulary. */
+ protected int label;
+
+ /* Marks a frontier node as a terminal (as opposed to a nonterminal). */
+ boolean isTerminal = false;
+
+ /*
+ * Marks the root and frontier nodes of a fragment. Useful for denoting fragment derivations in
+ * larger trees.
+ */
+ boolean isBoundary = false;
+
+ /* A list of the node's children. */
+ List<Tree> children;
+
+ /*
+ * The maximum distance from the root to any of the frontier nodes. Computed lazily by
+ * getDepth(); -1 means "not computed yet".
+ */
+ int depth = -1;
+
+ /*
+ * Cache behind isLexicalized(): 1 for a terminal node, otherwise the number of children that
+ * are themselves lexicalized. -1 means "not computed yet".
+ */
+ private int numLexicalItems = -1;
+
+ /*
+ * This maps the flat right-hand sides of Joshua rules to the tree fragments they were derived
+ * from. It is used to look up the fragment that language model fragments should be matched
+ * against. For example, if the target (English) side of your rule is
+ *
+ * [NP,1] said [SBAR,2]
+ *
+ * we will retrieve the unflattened fragment
+ *
+ * (S NP (VP (VBD said) SBAR))
+ *
+ * which presumably was the frontier fragment used to derive the translation rule. With this in
+ * hand, we can iterate through our store of language model fragments to match them against this,
+ * following tail nodes if necessary.
+ *
+ * The map is filled by readMapping() and read by getFragmentFromYield().
+ */
+ public static HashMap<String, String> rulesToFragmentStrings = new HashMap<String, String>();
+
+ /**
+ * Creates a node with the given children. The label is processed by setLabel(), so a label in
+ * double quotes marks the node as a terminal.
+ */
+ public Tree(String label, List<Tree> children) {
+ setLabel(label);
+ this.children = children;
+ }
+
+ /**
+ * Creates a node without children. The label is processed by setLabel(). The child list is the
+ * immutable empty list, so children cannot be added to it directly.
+ */
+ public Tree(String label) {
+ setLabel(label);
+ this.children = Collections.emptyList();
+ }
+
+ /**
+ * Creates a node from a label id. Unlike the String constructors this does not go through
+ * setLabel(), so the terminal flag keeps its default (false) until set explicitly.
+ */
+ public Tree(int label2, ArrayList<Tree> newChildren) {
+ this.label = label2;
+ this.children = newChildren;
+ }
+
+ /**
+ * Replaces the child list. The cached depth and lexicalization count are not reset here.
+ */
+ public void setChildren(List<Tree> c) {
+ this.children = c;
+ }
+
+ /**
+ * @return the node's own child list (not a copy)
+ */
+ public List<Tree> getChildren() {
+ return children;
+ }
+
+ /**
+ * @return the label id of this node
+ */
+ public int getLabel() {
+ return label;
+ }
+
+ /**
+ * Renders the depth-one rule rooted at this node: an opening parenthesis, the node's label, and
+ * the labels of its children, separated by single spaces. No closing parenthesis is appended.
+ *
+ * @return the rule string, or null if the node has no children
+ */
+ public String getRule() {
+   if (isLeaf()) {
+     return null;
+   }
+
+   StringBuilder rule = new StringBuilder();
+   rule.append('(').append(Vocabulary.word(getLabel()));
+   for (Tree kid : getChildren()) {
+     rule.append(' ').append(Vocabulary.word(kid.getLabel()));
+   }
+   return rule.toString();
+ }
+
+ /*
+ * Boundary nodes are used externally to mark merge points between different fragments. This is
+ * separate from the internal denotation of substitution points.
+ */
+ public boolean isBoundary() {
+ return isBoundary;
+ }
+
+ /* Sets or clears the boundary mark of this node. */
+ public void setBoundary(boolean b) {
+ this.isBoundary = b;
+ }
+
+ /* True if this node was marked as a terminal (see setLabel()). */
+ public boolean isTerminal() {
+ return isTerminal;
+ }
+
+ /* True if this node has no children; in a fragment a leaf may be a terminal or a nonterminal. */
+ public boolean isLeaf() {
+ return getChildren().isEmpty();
+ }
+
+ /* True if this node has exactly one child and that child is a leaf. */
+ public boolean isPreTerminal() {
+ return getChildren().size() == 1 && getChildren().get(0).isLeaf();
+ }
+
+ /**
+ * @return the leaves that are not marked as terminals, from left to right
+ */
+ public List<Tree> getNonterminalYield() {
+   List<Tree> leaves = new ArrayList<Tree>();
+   appendNonterminalYield(this, leaves);
+   return leaves;
+ }
+
+ /**
+ * @return all leaves of the tree, from left to right
+ */
+ public List<Tree> getYield() {
+   List<Tree> leaves = new ArrayList<Tree>();
+   appendYield(this, leaves);
+   return leaves;
+ }
+
+ /**
+ * @return all leaves of the tree, from left to right, whether or not they are marked as terminals
+ */
+ public List<Tree> getTerminals() {
+   List<Tree> leaves = new ArrayList<Tree>();
+   appendTerminals(this, leaves);
+   return leaves;
+ }
+
+ /* Appends every leaf below (or at) the given node to the list, left to right. */
+ private static void appendTerminals(Tree tree, List<Tree> yield) {
+   if (!tree.isLeaf()) {
+     for (Tree kid : tree.getChildren()) {
+       appendTerminals(kid, yield);
+     }
+     return;
+   }
+   yield.add(tree);
+ }
+
+ /**
+ * Copies the tree structure: every node is duplicated, recursively, carrying over its label and
+ * its terminal and boundary marks. The cached depth and lexicalization count are not copied;
+ * the copy recomputes them on demand.
+ *
+ * @return a cloned tree
+ */
+ public Tree shallowClone() {
+   ArrayList<Tree> copiedKids = new ArrayList<Tree>(children.size());
+   for (Tree kid : children) {
+     copiedKids.add(kid.shallowClone());
+   }
+
+   Tree copy = new Tree(label, copiedKids);
+   copy.isTerminal = isTerminal();
+   copy.isBoundary = isBoundary();
+   return copy;
+ }
+
+ /* Sets or clears the terminal mark of this node. */
+ private void setIsTerminal(boolean terminal) {
+   this.isTerminal = terminal;
+ }
+
+ /*
+ * Appends the leaves that are not marked as terminals to the list, left to right. Terminal
+ * leaves have no children, so falling through to the loop adds nothing for them.
+ */
+ private static void appendNonterminalYield(Tree tree, List<Tree> yield) {
+   boolean nonterminalLeaf = tree.isLeaf() && !tree.isTerminal();
+   if (nonterminalLeaf) {
+     yield.add(tree);
+   } else {
+     for (Tree kid : tree.getChildren()) {
+       appendNonterminalYield(kid, yield);
+     }
+   }
+ }
+
+ /* Appends every leaf below (or at) the given node to the list, left to right. */
+ private static void appendYield(Tree tree, List<Tree> yield) {
+   if (!tree.isLeaf()) {
+     for (Tree kid : tree.getChildren()) {
+       appendYield(kid, yield);
+     }
+     return;
+   }
+   yield.add(tree);
+ }
+
+ /**
+ * @return the preterminal nodes of the tree (nodes whose only child is a leaf), left to right
+ */
+ public List<Tree> getPreTerminalYield() {
+   List<Tree> preTerminals = new ArrayList<Tree>();
+   appendPreTerminalYield(this, preTerminals);
+   return preTerminals;
+ }
+
+ /* Appends the preterminals below (or at) the given node; does not descend below a preterminal. */
+ private static void appendPreTerminalYield(Tree tree, List<Tree> yield) {
+   if (!tree.isPreTerminal()) {
+     for (Tree kid : tree.getChildren()) {
+       appendPreTerminalYield(kid, yield);
+     }
+     return;
+   }
+   yield.add(tree);
+ }
+
+ /**
+ * A tree is lexicalized if it has terminal nodes among the leaves of its frontier. For normal
+ * trees this is always true since they bottom out in terminals, but for fragments, this may or
+ * may not be true.
+ */
+ public boolean isLexicalized() {
+ if (this.numLexicalItems < 0) {
+ if (isTerminal())
+ this.numLexicalItems = 1;
+ else {
+ this.numLexicalItems = 0;
+ for (Tree child : children)
+ if (child.isLexicalized())
+ this.numLexicalItems += 1;
+ }
+ }
+
+ return (this.numLexicalItems > 0);
+ }
+
+ /**
+ * Returns the length of the longest path from this node down to a leaf: 0 for a leaf, otherwise
+ * one more than the deepest child. The value is computed on the first call and cached.
+ *
+ * @return the tree depth
+ */
+ public int getDepth() {
+   if (this.depth < 0) {
+     /* Starting from -1 makes a leaf come out at depth 0 after the increment below. */
+     int deepestKid = -1;
+     if (!isLeaf()) {
+       deepestKid = 0;
+       for (Tree kid : children) {
+         deepestKid = Math.max(deepestKid, kid.getDepth());
+       }
+     }
+     this.depth = deepestKid + 1;
+   }
+   return this.depth;
+ }
+
+ /**
+ * Collects the nodes that lie exactly the given number of levels below this node.
+ *
+ * @param depth the distance from this node; 0 selects the node itself, negative values nothing
+ * @return the nodes at that distance, from left to right
+ */
+ public List<Tree> getAtDepth(int depth) {
+   List<Tree> level = new ArrayList<Tree>();
+   appendAtDepth(depth, this, level);
+   return level;
+ }
+
+ /* Appends the nodes lying the given number of levels below the given node, left to right. */
+ private static void appendAtDepth(int depth, Tree tree, List<Tree> yield) {
+   if (depth == 0) {
+     yield.add(tree);
+   } else if (depth > 0) {
+     for (Tree kid : tree.getChildren()) {
+       appendAtDepth(depth - 1, kid, yield);
+     }
+   }
+ }
+
+ public void setLabel(String label) {
+ if (label.length() >= 3 && label.startsWith("\"") && label.endsWith("\"")) {
+ this.isTerminal = true;
+ label = label.substring(1, label.length() - 1);
+ }
+
+ this.label = Vocabulary.id(label);
+ }
+
+ /**
+ * @return the bracketed representation of the tree, with terminals in double quotes
+ */
+ public String toString() {
+   StringBuilder out = new StringBuilder();
+   toStringBuilder(out);
+   return out.toString();
+ }
+
+ /**
+ * Renders the tree without the quotes around terminals. Note that the result cannot be read
+ * back in by this class, since unquoted leaves are interpreted as nonterminals.
+ *
+ * @return the bracketed representation with every double quote removed
+ */
+ public String unquotedString() {
+   return toString().replace("\"", "");
+ }
+
+ /**
+ * @return the bracketed representation with every space replaced by an underscore
+ */
+ public String escapedString() {
+   return toString().replace(' ', '_');
+ }
+
+ public void toStringBuilder(StringBuilder sb) {
+ if (!isLeaf())
+ sb.append('(');
+
+ if (isTerminal())
+ sb.append(String.format("\"%s\"", Vocabulary.word(getLabel())));
+ else
+ sb.append(Vocabulary.word(getLabel()));
+
+ if (!isLeaf()) {
+ for (Tree child : getChildren()) {
+ sb.append(' ');
+ child.toStringBuilder(sb);
+ }
+ sb.append(')');
+ }
+ }
+
+ /**
+ * Gathers the subtree rooted at every node of this tree, the tree itself included. The
+ * subtrees are the nodes themselves, <i>not</i> copies, so they share structure.
+ *
+ * @return the <code>Set</code> of all subtrees in the tree.
+ */
+ public Set<Tree> subTrees() {
+   Set<Tree> all = new HashSet<Tree>();
+   subTrees(all);
+   return all;
+ }
+
+ /**
+ * Gathers the subtree rooted at every node of this tree, the tree itself included, with each
+ * node listed before its children. The subtrees are the nodes themselves, <i>not</i> copies.
+ *
+ * @return the <code>List</code> of all subtrees in the tree.
+ */
+ public List<Tree> subTreeList() {
+   List<Tree> all = new ArrayList<Tree>();
+   subTrees(all);
+   return all;
+ }
+
+ /**
+ * Adds this node and then, recursively, all nodes below it to the given
+ * <code>Collection</code>.
+ *
+ * @param n A collection of nodes to which the subtrees will be added
+ * @return The collection parameter with the subtrees added
+ */
+ public Collection<Tree> subTrees(Collection<Tree> n) {
+   n.add(this);
+   for (Tree kid : getChildren()) {
+     kid.subTrees(n);
+   }
+   return n;
+ }
+
+ /**
+ * Returns an iterator over the nodes of the tree, starting at this node. The traversal is
+ * preorder: every node is returned before its children, and children are visited from left to
+ * right. (A possible extension to the class at some point would be to allow different traversal
+ * orderings via variant iterators.)
+ *
+ * Note that this class does not declare <code>Iterable</code>, so the method has to be called
+ * explicitly; a tree cannot be used directly in a for-each loop.
+ *
+ * @return An iterator over the nodes of the tree
+ */
+ public TreeIterator iterator() {
+ return new TreeIterator();
+ }
+
+ private class TreeIterator implements Iterator<Tree> {
+
+ private List<Tree> treeStack;
+
+ private TreeIterator() {
+ treeStack = new ArrayList<Tree>();
+ treeStack.add(Tree.this);
+ }
+
+ public boolean hasNext() {
+ return (!treeStack.isEmpty());
+ }
+
+ public Tree next() {
+ int lastIndex = treeStack.size() - 1;
+ Tree tr = treeStack.remove(lastIndex);
+ List<Tree> kids = tr.getChildren();
+ // so that we can efficiently use one List, we reverse them
+ for (int i = kids.size() - 1; i >= 0; i--) {
+ treeStack.add(kids.get(i));
+ }
+ return tr;
+ }
+
+ /**
+ * Not supported
+ */
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+ }
+
+ /**
+ * @return true if somewhere in the tree a node with a single child is directly dominated by
+ *         another node with a single child, as decided by the helper below
+ */
+ public boolean hasUnaryChain() {
+   return hasUnaryChainHelper(this, false);
+ }
+
+ /*
+ * Searches for a unary chain at or below the given node. unaryAbove tells whether the parent of
+ * the node was itself unary. Preterminal children are never descended into.
+ */
+ private boolean hasUnaryChainHelper(Tree tree, boolean unaryAbove) {
+   List<Tree> kids = tree.getChildren();
+
+   if (kids.size() == 1) {
+     if (unaryAbove) {
+       return true;
+     }
+     Tree onlyKid = kids.get(0);
+     return !onlyKid.isPreTerminal() && hasUnaryChainHelper(onlyKid, true);
+   }
+
+   for (Tree kid : kids) {
+     if (!kid.isPreTerminal() && hasUnaryChainHelper(kid, false)) {
+       return true;
+     }
+   }
+   return false;
+ }
+
+ /**
+ * Inserts the SOS (and EOS) symbols into a parse tree, attaching them as a left (right) sibling
+ * to the leftmost (rightmost) pre-terminal in the tree. This facilitates using trees as language
+ * models. The start symbol is inserted first, then the end symbol.
+ *
+ * @param sos presumably "<s>"
+ * @param eos presumably "</s>"
+ */
+ public void insertSentenceMarkers(String sos, String eos) {
+ insertSentenceMarker(sos, 0);
+ insertSentenceMarker(eos, -1);
+ }
+
+ /**
+ * Inserts the sentence markers using the symbols "<s>" and "</s>".
+ */
+ public void insertSentenceMarkers() {
+ insertSentenceMarker("<s>", 0);
+ insertSentenceMarker("</s>", -1);
+ }
+
+ /**
+ *
+ * @param symbol
+ * @param pos
+ */
+ private void insertSentenceMarker(String symbol, int pos) {
+
+ if (isLeaf() || isPreTerminal())
+ return;
+
+ List<Tree> children = getChildren();
+ int index = (pos == -1) ? children.size() - 1 : pos;
+ if (children.get(index).isPreTerminal()) {
+ if (pos == -1)
+ children.add(new Tree(symbol));
+ else
+ children.add(pos, new Tree(symbol));
+ } else {
+ children.get(index).insertSentenceMarker(symbol, pos);
+ }
+ }
+
+ /**
+ * Convenience function that parses a fragment from its bracketed string representation. Only
+ * the first tree found in the string is returned.
+ *
+ * @param ptbStr the bracketed representation, with terminals in double quotes
+ * @return the fragment read from the string
+ */
+ public static Tree fromString(String ptbStr) {
+   return new PennTreeReader(new StringReader(ptbStr)).next();
+ }
+
+ /**
+ * Looks up the fragment recorded for a flattened rule target side and parses it.
+ *
+ * @param yield the flattened target side, as used as key in rulesToFragmentStrings
+ * @return a freshly parsed fragment, or null if no fragment is recorded for the yield
+ */
+ public static Tree getFragmentFromYield(String yield) {
+   String fragmentString = rulesToFragmentStrings.get(yield);
+   return (fragmentString == null) ? null : fromString(fragmentString);
+ }
+
+ /**
+ * Loads the rule / fragment mapping into rulesToFragmentStrings. Every line is expected to have
+ * the form "FRAGMENT ||| RULE-TARGET-SIDE" with the fragment starting with an opening
+ * parenthesis; other lines are reported on stderr and skipped. Entries are keyed by the trimmed
+ * target side.
+ *
+ * If the file cannot be read, a message is printed and the JVM exits with status 1.
+ *
+ * @param fragmentMappingFile the path of the mapping file
+ */
+ public static void readMapping(String fragmentMappingFile) {
+   try {
+     LineReader reader = new LineReader(fragmentMappingFile);
+     for (String line : reader) {
+       String[] fields = line.split("\\s+\\|{3}\\s+");
+       boolean wellFormed = fields.length == 2 && fields[0].startsWith("(");
+       if (wellFormed) {
+         rulesToFragmentStrings.put(fields[1].trim(), fields[0].trim());
+       } else {
+         System.err.println(String.format("* WARNING: malformed line %d: %s", reader.lineno(),
+             line));
+       }
+     }
+   } catch (IOException e) {
+     System.err.println(String.format("* WARNING: couldn't read fragment mapping file '%s'",
+         fragmentMappingFile));
+     System.exit(1);
+   }
+   System.err.println(String.format("FragmentLMFF: Read %d mappings from '%s'",
+       rulesToFragmentStrings.size(), fragmentMappingFile));
+ }
+
+ /**
+  * Builds a tree from a rule and the derivation states attached to its tail nodes. The tree is
+  * initialized with a shallow clone of the fragment registered for the rule's English side; this
+  * is the top of the tree. Each derivation state is then visited recursively, following the
+  * hyperedge stored in it, and the resulting subtree's children are attached at the matching
+  * frontier node.
+  *
+  * This function is like the buildTree() overload that takes tail nodes, but that one simply
+  * follows the best incoming hyperedge for each node.
+  *
+  * @param rule the rule whose fragment forms the top of the tree
+  * @param derivationStates one state per nonterminal of the rule; null is treated as "no states",
+  *        which is what the recursive call passes for edges without tail nodes
+  * @param maxDepth decremented on every recursive call; NOTE(review): unlike the other
+  *        buildTree() overloads this method never compares it against 0 - confirm whether the
+  *        depth limit is meant to apply here
+  * @return the assembled tree, or null if no fragment is registered for the rule
+  */
+ public static Tree buildTree(Rule rule, DerivationState[] derivationStates, int maxDepth) {
+   Tree tree = getFragmentFromYield(rule.getEnglishWords());
+
+   if (tree == null) {
+     return null;
+   }
+
+   tree = tree.shallowClone();
+
+   /*
+    * The recursive call below hands in null when an edge has no tail nodes, so a null array must
+    * be handled like an empty one instead of being dereferenced.
+    */
+   int numStates = (derivationStates == null) ? 0 : derivationStates.length;
+
+   System.err.println(String.format("buildTree(%s)", tree));
+   for (int i = 0; i < numStates; i++) {
+     System.err.println(String.format(" -> %d: %s", i, derivationStates[i]));
+   }
+
+   List<Tree> frontier = tree.getNonterminalYield();
+
+   /*
+    * The English side of a rule is a sequence of integers. Nonnegative integers are word indices
+    * in the Vocabulary, while negative indices are used for nonterminals. These negative indices
+    * are a *permutation* of the source side nonterminals. Here, we convert this permutation to a
+    * nonnegative 0-based permutation and store it in tailIndices. This is used to index the
+    * incoming DerivationState items, which are ordered by the source side.
+    */
+   ArrayList<Integer> tailIndices = new ArrayList<Integer>();
+   int[] englishInts = rule.getEnglish();
+   for (int i = 0; i < englishInts.length; i++) {
+     if (englishInts[i] < 0) {
+       tailIndices.add(-(englishInts[i] + 1));
+     }
+   }
+
+   /*
+    * Walk through the derivation states and attach the subtree built for each one at the
+    * frontier node selected by the permutation computed above.
+    */
+   for (int i = 0; i < numStates; i++) {
+
+     Tree frontierTree = frontier.get(tailIndices.get(i));
+     frontierTree.setBoundary(true);
+
+     HyperEdge nextEdge = derivationStates[i].edge;
+     if (nextEdge != null) {
+       List<HGNode> nextTails = nextEdge.getTailNodes();
+       DerivationState[] nextStates = null;
+       if (nextTails != null && nextTails.size() > 0) {
+         nextStates = new DerivationState[nextTails.size()];
+         for (int j = 0; j < nextStates.length; j++) {
+           nextStates[j] = derivationStates[i].getChildDerivationState(nextEdge, j);
+         }
+       }
+       Tree childTree = buildTree(nextEdge.getRule(), nextStates, maxDepth - 1);
+
+       /* This can be null if there is no entry for the rule in the map */
+       if (childTree != null) {
+         frontierTree.children = childTree.children;
+       }
+     } else {
+       // NOTE(review): this points the frontier node at the root's own child list; kept as in
+       // the original - confirm that this is intended.
+       frontierTree.children = tree.children;
+     }
+   }
+
+   return tree;
+ }
+
+ /**
+  * Builds a tree from a single derivation state. The tree starts as a shallow clone of the
+  * fragment registered for the rule of the state's edge. If the rule has nonterminals and the
+  * depth budget is not exhausted, the child derivation state of every nonterminal is expanded
+  * recursively and its children are attached at the corresponding frontier node.
+  *
+  * @param derivationState the state whose edge supplies the rule and the child states
+  * @param maxDepth remaining depth; children are only expanded while this is positive
+  * @return the assembled tree, or null if no fragment is registered for the rule
+  */
+ public static Tree buildTree(DerivationState derivationState, int maxDepth) {
+   Rule rule = derivationState.edge.getRule();
+
+   Tree fragment = getFragmentFromYield(rule.getEnglishWords());
+   if (fragment == null) {
+     return null;
+   }
+
+   Tree tree = fragment.shallowClone();
+
+   System.err.println(String.format("buildTree(%s)", tree));
+
+   if (rule.getArity() <= 0 || maxDepth <= 0) {
+     return tree;
+   }
+
+   List<Tree> frontier = tree.getNonterminalYield();
+
+   /*
+    * Negative entries on the English side stand for nonterminals and form a permutation of the
+    * source-side nonterminals. Convert each of them to a nonnegative 0-based index; the result
+    * selects the frontier node for the i-th child state.
+    */
+   ArrayList<Integer> tailIndices = new ArrayList<Integer>();
+   for (int symbol : rule.getEnglish()) {
+     if (symbol < 0) {
+       tailIndices.add(-(symbol + 1));
+     }
+   }
+
+   for (int i = 0; i < rule.getArity(); i++) {
+     Tree slot = frontier.get(tailIndices.get(i));
+     slot.setBoundary(true);
+
+     DerivationState childState =
+         derivationState.getChildDerivationState(derivationState.edge, i);
+     Tree subtree = buildTree(childState, maxDepth - 1);
+
+     /* This can be null if there is no entry for the rule in the map */
+     if (subtree != null) {
+       slot.children = subtree.children;
+     }
+   }
+
+   return tree;
+ }
+
+ /**
+  * Takes a rule and its tail pointers and recursively constructs a tree (up to maxDepth),
+  * following the best hyperedge of every tail node. When no fragment is registered for the
+  * rule, a single node is created from the rule's left-hand side and English words instead.
+  *
+  * @param rule the rule whose fragment forms the top of the tree
+  * @param tailNodes the rule's tail nodes; may be null or empty
+  * @param maxDepth remaining depth; tail nodes are only expanded while this is positive
+  * @return the assembled tree
+  */
+ public static Tree buildTree(Rule rule, List<HGNode> tailNodes, int maxDepth) {
+   Tree tree;
+   Tree fragment = getFragmentFromYield(rule.getEnglishWords());
+   if (fragment != null) {
+     tree = fragment.shallowClone();
+   } else {
+     // No fragment registered for this rule: fall back to a node built from the rule itself.
+     tree = new Tree(String.format("(%s %s)", Vocabulary.word(rule.getLHS()), rule.getEnglishWords()));
+   }
+
+   if (tree == null || tailNodes == null || tailNodes.size() <= 0 || maxDepth <= 0) {
+     return tree;
+   }
+
+   List<Tree> frontier = tree.getNonterminalYield();
+
+   // Map the negative nonterminal markers of the English side to 0-based frontier positions.
+   ArrayList<Integer> tailIndices = new ArrayList<Integer>();
+   for (int symbol : rule.getEnglish()) {
+     if (symbol < 0) {
+       tailIndices.add(-1 * symbol - 1);
+     }
+   }
+
+   for (int i = 0; i < tailNodes.size(); i++) {
+     try {
+       Tree slot = frontier.get(tailIndices.get(i).intValue());
+       slot.setBoundary(true);
+
+       HyperEdge best = tailNodes.get(i).bestHyperedge;
+       if (best == null) {
+         slot.children = tree.children;
+       } else {
+         Tree subtree = buildTree(best.getRule(), best.getTailNodes(), maxDepth - 1);
+         /* This can be null if there is no entry for the rule in the map */
+         if (subtree != null) {
+           slot.children = subtree.children;
+         }
+       }
+     } catch (IndexOutOfBoundsException e) {
+       // The frontier and the tail nodes disagree: dump diagnostics and terminate.
+       System.err.println(String.format("ERROR at index %d", i));
+       System.err.println(String.format("RULE: %s TREE: %s", rule.getEnglishWords(), tree));
+       System.err.println(" FRONTIER:");
+       for (Tree kid : frontier) {
+         System.err.println(" " + kid);
+       }
+       e.printStackTrace();
+       System.exit(1);
+     }
+   }
+
+   return tree;
+ }
+
+ /**
+  * Reads one tree per line from standard input, adds the default sentence markers and prints
+  * the result. A line that cannot be processed produces an empty output line.
+  *
+  * @param args unused
+  */
+ public static void main(String[] args) {
+   LineReader stdin = new LineReader(System.in);
+
+   for (String input : stdin) {
+     try {
+       Tree parsed = fromString(input);
+       parsed.insertSentenceMarkers();
+       System.out.println(parsed);
+     } catch (Exception ignored) {
+       // Keep the output aligned with the input: emit an empty line for a bad tree.
+       System.out.println("");
+     }
+   }
+
+   /*
+    * Tree fragment = Tree
+    * .fromString("(TOP (S (NP (DT the) (NN boy)) (VP (VBD ate) (NP (DT the) (NN food)))))");
+    * fragment.insertSentenceMarkers("<s>", "</s>");
+    *
+    * System.out.println(fragment);
+    *
+    * ArrayList<Tree> trees = new ArrayList<Tree>(); trees.add(Tree.fromString("(NN \"mat\")"));
+    * trees.add(Tree.fromString("(S (NP DT NN) VP)"));
+    * trees.add(Tree.fromString("(S (NP (DT \"the\") NN) VP)"));
+    * trees.add(Tree.fromString("(S (NP (DT the) NN) VP)"));
+    *
+    * for (Tree tree : trees) { System.out.println(String.format("TREE %s DEPTH %d LEX? %s", tree,
+    * tree.getDepth(), tree.isLexicalized())); }
+    */
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8cdbc4b8/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Trees.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Trees.java b/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Trees.java
new file mode 100644
index 0000000..94a0f44
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Trees.java
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package joshua.decoder.ff.fragmentlm;
+
+import java.io.IOException;
+import java.io.PushbackReader;
+import java.io.Reader;
+import java.io.StringReader;
+import java.util.*;
+
+import joshua.corpus.Vocabulary;
+
+/**
+ * Tools for displaying, reading, and modifying trees. Borrowed from the Berkeley Parser.
+ *
+ * @author Dan Klein
+ */
+public class Trees {
+
+ public static class PennTreeReader implements Iterator<Tree> {
+ public static String ROOT_LABEL = "ROOT";
+
+ PushbackReader in;
+ Tree nextTree;
+
+ public boolean hasNext() {
+ return (nextTree != null);
+ }
+
+ public Tree next() {
+ if (!hasNext())
+ throw new NoSuchElementException();
+ Tree tree = nextTree;
+ nextTree = readRootTree();
+ // System.out.println(nextTree);
+ return tree;
+ }
+
+ private Tree readRootTree() {
+ try {
+ readWhiteSpace();
+ if (!isLeftParen(peek()))
+ return null;
+ return readTree(true);
+ } catch (IOException e) {
+ throw new RuntimeException("Error reading tree.");
+ }
+ }
+
+ private Tree readTree(boolean isRoot) throws IOException {
+ if (!isLeftParen(peek())) {
+ return readLeaf();
+ } else {
+ readLeftParen();
+ String label = readLabel();
+ if (label.length() == 0 && isRoot)
+ label = ROOT_LABEL;
+ List<Tree> children = readChildren();
+ readRightParen();
+ return new Tree(label, children);
+ }
+ }
+
+ private String readLabel() throws IOException {
+ readWhiteSpace();
+ return readText();
+ }
+
+ private String readText() throws IOException {
+ StringBuilder sb = new StringBuilder();
+ int ch = in.read();
+ while (!isWhiteSpace(ch) && !isLeftParen(ch) && !isRightParen(ch)) {
+ sb.append((char) ch);
+ ch = in.read();
+ }
+ in.unread(ch);
+ // System.out.println("Read text: ["+sb+"]");
+ return sb.toString().intern();
+ }
+
+ private List<Tree> readChildren() throws IOException {
+ readWhiteSpace();
+ // if (!isLeftParen(peek()))
+ // return Collections.singletonList(readLeaf());
+ return readChildList();
+ }
+
+ private int peek() throws IOException {
+ int ch = in.read();
+ in.unread(ch);
+ return ch;
+ }
+
+ private Tree readLeaf() throws IOException {
+ String label = readText();
+ return new Tree(label);
+ }
+
+ private List<Tree> readChildList() throws IOException {
+ List<Tree> children = new ArrayList<Tree>();
+ readWhiteSpace();
+ while (!isRightParen(peek())) {
+ children.add(readTree(false));
+ readWhiteSpace();
+ }
+ return children;
+ }
+
+ private void readLeftParen() throws IOException {
+ // System.out.println("Read left.");
+ readWhiteSpace();
+ int ch = in.read();
+ if (!isLeftParen(ch))
+ throw new RuntimeException("Format error reading tree. (leftParen)");
+ }
+
+ private void readRightParen() throws IOException {
+ // System.out.println("Read right.");
+ readWhiteSpace();
+ int ch = in.read();
+
+ if (!isRightParen(ch)) {
+ System.out.println((char) ch);
+ throw new RuntimeException("Format error reading tree. (rightParen)");
+ }
+ }
+
+ private void readWhiteSpace() throws IOException {
+ int ch = in.read();
+ while (isWhiteSpace(ch)) {
+ ch = in.read();
+ }
+ in.unread(ch);
+ }
+
+ private boolean isWhiteSpace(int ch) {
+ return (ch == ' ' || ch == '\t' || ch == '\f' || ch == '\r' || ch == '\n');
+ }
+
+ private boolean isLeftParen(int ch) {
+ return ch == '(';
+ }
+
+ private boolean isRightParen(int ch) {
+ return ch == ')';
+ }
+
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+ public PennTreeReader(Reader in) {
+ this.in = new PushbackReader(in);
+ nextTree = readRootTree();
+ // System.out.println(nextTree);
+ }
+ }
+
+ /**
+ * Renderer for pretty-printing trees according to the Penn Treebank indenting guidelines
+ * (mutliline). Adapted from code originally written by Dan Klein and modified by Chris Manning.
+ */
+ public static class PennTreeRenderer {
+
+ /**
+ * Print the tree as done in Penn Treebank merged files. The formatting should be exactly the
+ * same, but we don't print the trailing whitespace found in Penn Treebank trees. The basic
+ * deviation from a bracketed indented tree is to in general collapse the printing of adjacent
+ * preterminals onto one line of tags and words. Additional complexities are that conjunctions
+ * (tag CC) are not collapsed in this way, and that the unlabeled outer brackets are collapsed
+ * onto the same line as the next bracket down.
+ */
+ public static String render(Tree tree) {
+ StringBuilder sb = new StringBuilder();
+ renderTree(tree, 0, false, false, false, true, sb);
+ sb.append('\n');
+ return sb.toString();
+ }
+
+ /**
+ * Display a node, implementing Penn Treebank style layout
+ */
+ private static void renderTree(Tree tree, int indent, boolean parentLabelNull,
+ boolean firstSibling, boolean leftSiblingPreTerminal, boolean topLevel, StringBuilder sb) {
+ // the condition for staying on the same line in Penn Treebank
+ boolean suppressIndent = (parentLabelNull || (firstSibling && tree.isPreTerminal()) || (leftSiblingPreTerminal
+ && tree.isPreTerminal()));
+ if (suppressIndent) {
+ sb.append(' ');
+ } else {
+ if (!topLevel) {
+ sb.append('\n');
+ }
+ for (int i = 0; i < indent; i++) {
+ sb.append(" ");
+ }
+ }
+ if (tree.isLeaf() || tree.isPreTerminal()) {
+ renderFlat(tree, sb);
+ return;
+ }
+ sb.append('(');
+ sb.append(tree.getLabel());
+ renderChildren(tree.getChildren(), indent + 1, false, sb);
+ sb.append(')');
+ }
+
+ private static void renderFlat(Tree tree, StringBuilder sb) {
+ if (tree.isLeaf()) {
+ sb.append(Vocabulary.word(tree.getLabel()));
+ return;
+ }
+ sb.append('(');
+ sb.append(Vocabulary.word(tree.getLabel()));
+ sb.append(' ');
+ sb.append(Vocabulary.word(tree.getChildren().get(0).getLabel()));
+ sb.append(')');
+ }
+
+ private static void renderChildren(List<Tree> children, int indent,
+ boolean parentLabelNull, StringBuilder sb) {
+ boolean firstSibling = true;
+ boolean leftSibIsPreTerm = true; // counts as true at beginning
+ for (Tree child : children) {
+ renderTree(child, indent, parentLabelNull, firstSibling, leftSibIsPreTerm, false, sb);
+ leftSibIsPreTerm = child.isPreTerminal();
+ firstSibling = false;
+ }
+ }
+ }
+
+ public static void main(String[] args) {
+ String ptbTreeString = "((S (NP (DT the) (JJ quick) (JJ brown) (NN fox)) (VP (VBD jumped) (PP (IN over) (NP (DT the) (JJ lazy) (NN dog)))) (. .)))";
+
+ if (args.length > 0) {
+ String tree = "";
+ for (String str : args) {
+ tree += " " + str;
+ }
+ ptbTreeString = tree.substring(1);
+ }
+
+ PennTreeReader reader = new PennTreeReader(new StringReader(ptbTreeString));
+
+ Tree tree = reader.next();
+ System.out.println(PennTreeRenderer.render(tree));
+ System.out.println(tree);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8cdbc4b8/src/main/java/org/apache/joshua/decoder/ff/lm/DefaultNGramLanguageModel.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/DefaultNGramLanguageModel.java b/src/main/java/org/apache/joshua/decoder/ff/lm/DefaultNGramLanguageModel.java
new file mode 100644
index 0000000..20f29f1
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/DefaultNGramLanguageModel.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package joshua.decoder.ff.lm;
+
+import java.util.Arrays;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+import joshua.corpus.Vocabulary;
+
+/**
+ * This class provides a default implementation for the Equivalent LM State optimization (namely,
+ * don't back off anywhere). It also provides some default implementations for more general
+ * functions on the interface to fall back to more specific ones (e.g. from ArrayList<Integer> to
+ * int[]) and a default implementation for sentenceLogProbability which enumerates the n-grams and
+ * calls ngramLogProbability for each of them.
+ *
+ * @author Zhifei Li, <zh...@gmail.com>
+ * @author wren ng thornton <wr...@users.sourceforge.net>
+ */
+public abstract class DefaultNGramLanguageModel implements NGramLanguageModel {
+
+  /** Logger for this class. */
+  private static final Logger logger = Logger.getLogger(DefaultNGramLanguageModel.class.getName());
+
+  /** Order reported by getOrder() and used when an n-gram is scored without an explicit order. */
+  protected final int ngramOrder;
+
+  /** Smallest value ngramLogProbability(int[], int) will return; lower scores are raised to it. */
+  protected float ceiling_cost = -100;
+
+  // ===============================================================
+  // Constructors
+  // ===============================================================
+
+  /**
+   * @param order the n-gram order of this model
+   * @param ceiling_cost lower bound applied to n-gram log probabilities
+   */
+  public DefaultNGramLanguageModel(int order, float ceiling_cost) {
+    this(order);
+    this.ceiling_cost = ceiling_cost;
+  }
+
+  /**
+   * Creates a model that keeps the default lower bound of -100.
+   *
+   * @param order the n-gram order of this model
+   */
+  public DefaultNGramLanguageModel(int order) {
+    this.ngramOrder = order;
+  }
+
+  // ===============================================================
+  // Attributes
+  // ===============================================================
+
+  @Override
+  public final int getOrder() {
+    return this.ngramOrder;
+  }
+
+  // ===============================================================
+  // NGramLanguageModel Methods
+  // ===============================================================
+
+  /** There is no private LM ID mapping, so nothing is registered and false is returned. */
+  @Override
+  public boolean registerWord(String token, int id) {
+    return false;
+  }
+
+  /**
+   * Sums n-gram log probabilities over a sentence: first the prefixes of length startIndex up to
+   * order - 1 (bounded by the sentence length), then every window of exactly order words.
+   * A null or empty sentence scores 0.
+   *
+   * @param sentence word ids of the sentence
+   * @param order n-gram order to score with
+   * @param startIndex length of the first prefix that is scored
+   * @return the accumulated log probability
+   */
+  @Override
+  public float sentenceLogProbability(int[] sentence, int order, int startIndex) {
+    if (sentence == null || sentence.length <= 0) {
+      return 0.0f;
+    }
+    final int length = sentence.length;
+
+    float total = 0.0f;
+
+    // Partial n-grams at the beginning of the sentence.
+    // TODO: startIndex depends on the order, e.g., this.ngramOrder-1 (in srilm, for 3-gram lm,
+    // start_index=2. other cases need to be checked)
+    int prefixLength = startIndex;
+    while (prefixLength < order && prefixLength <= length) {
+      int[] prefix = Arrays.copyOfRange(sentence, 0, prefixLength);
+      double logProb = ngramLogProbability(prefix, order);
+      if (logger.isLoggable(Level.FINE)) {
+        String words = Vocabulary.getWords(prefix);
+        logger.fine("\tlogp ( " + words + " ) = " + logProb);
+      }
+      total += logProb;
+      prefixLength++;
+    }
+
+    // Full-order n-grams: slide a window of `order` words across the sentence.
+    for (int start = 0; start + order <= length; start++) {
+      int[] window = Arrays.copyOfRange(sentence, start, start + order);
+      double logProb = ngramLogProbability(window, order);
+      if (logger.isLoggable(Level.FINE)) {
+        String words = Vocabulary.getWords(window);
+        logger.fine("\tlogp ( " + words + " ) = " + logProb);
+      }
+      total += logProb;
+    }
+
+    return total;
+  }
+
+  /** Scores the n-gram using this model's own order. */
+  @Override
+  public float ngramLogProbability(int[] ngram) {
+    return this.ngramLogProbability(ngram, this.ngramOrder);
+  }
+
+  /** Model-specific scoring; called by ngramLogProbability(int[], int) after validation. */
+  protected abstract float ngramLogProbability_helper(int[] ngram, int order);
+
+  /**
+   * Validates the n-gram length, delegates to ngramLogProbability_helper and raises results
+   * that fall below ceiling_cost to ceiling_cost.
+   *
+   * @param ngram word ids; must contain between 1 and order entries
+   * @param order maximum n-gram order
+   * @return the (bounded) log probability
+   * @throws RuntimeException if the n-gram is empty or longer than order
+   */
+  @Override
+  public float ngramLogProbability(int[] ngram, int order) {
+    final int length = ngram.length;
+    if (length > order) {
+      throw new RuntimeException("ngram length is greather than the max order");
+    }
+
+    final int historySize = length - 1;
+    final boolean validHistory = historySize >= 0 && historySize < order;
+    if (!validHistory) {
+      // Fail loudly rather than silently defaulting to a score of zero.
+      throw new RuntimeException("Error: history size is " + historySize);
+    }
+
+    final float raw = ngramLogProbability_helper(ngram, order);
+    return (raw < ceiling_cost) ? ceiling_cost : raw;
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8cdbc4b8/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
new file mode 100644
index 0000000..329b631
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package joshua.decoder.ff.lm;
+
+import joshua.corpus.Vocabulary;
+import joshua.decoder.ff.lm.NGramLanguageModel;
+import joshua.decoder.ff.state_maintenance.KenLMState;
+
+/**
+ * JNI wrapper for KenLM. This version of KenLM supports two use cases, implemented by the separate
+ * feature functions KenLMFF and LanguageModelFF. KenLMFF uses the RuleScore() interface in
+ * lm/left.hh, returning a state pointer representing the KenLM state, while LanguageModelFF handles
+ * state by itself and just passes in the ngrams for scoring.
+ *
+ * @author Kenneth Heafield
+ * @author Matt Post <po...@cs.jhu.edu>
+ */
+
+public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
+
+  static {
+    try {
+      System.loadLibrary("ken");
+    } catch (UnsatisfiedLinkError e) {
+      System.err.println("* FATAL: Can't find libken.so (libken.dylib on OS X) in $JOSHUA/lib");
+      System.err.println("* This probably means that the KenLM library didn't compile.");
+      System.err.println("* Make sure that BOOST_ROOT is set to the root of your boost");
+      System.err.println("* installation (it's not /opt/local/, the default), change to");
+      System.err.println("* $JOSHUA, and type 'ant kenlm'. If problems persist, see the");
+      System.err.println("* website (joshua-decoder.org).");
+      System.exit(1);
+    }
+  }
+
+  /** Handle returned by the native construct() call; passed to every other native call. */
+  private final long pointer;
+
+  // this is read from the config file, used to set maximum order
+  private final int ngramOrder;
+  // inferred from model file (may be larger than ngramOrder)
+  private final int N;
+  // whether left-state minimization was requested
+  private boolean minimizing;
+
+  private final static native long construct(String file_name);
+
+  private final static native void destroy(long ptr);
+
+  private final static native int order(long ptr);
+
+  private final static native boolean registerWord(long ptr, String word, int id);
+
+  private final static native float prob(long ptr, int words[]);
+
+  private final static native float probForString(long ptr, String[] words);
+
+  private final static native boolean isKnownWord(long ptr, String word);
+
+  private final static native StateProbPair probRule(long ptr, long pool, long words[]);
+
+  private final static native float estimateRule(long ptr, long words[]);
+
+  private final static native float probString(long ptr, int words[], int start);
+
+  public final static native long createPool();
+  public final static native void destroyPool(long pointer);
+
+  /**
+   * Loads a model and records the requested order.
+   *
+   * @param order the order reported by getOrder()
+   * @param file_name the model file handed to the native library
+   */
+  public KenLM(int order, String file_name) {
+    ngramOrder = order;
+
+    pointer = construct(file_name);
+    N = order(pointer);
+  }
+
+  /**
+   * Constructor if order is not known.
+   * Order will be inferred from the model.
+   *
+   * @param file_name the model file handed to the native library
+   */
+  public KenLM(String file_name) {
+    pointer = construct(file_name);
+    N = order(pointer);
+    ngramOrder = N;
+  }
+
+  /** Passes this model's handle to the native destroy() call. */
+  public void destroy() {
+    destroy(pointer);
+  }
+
+  public int getOrder() {
+    return ngramOrder;
+  }
+
+  public boolean registerWord(String word, int id) {
+    return registerWord(pointer, word, id);
+  }
+
+  public float prob(int[] words) {
+    return prob(pointer, words);
+  }
+
+  /**
+   * Query for n-gram probability using strings.
+   */
+  public float prob(String[] words) {
+    return probForString(pointer, words);
+  }
+
+  // Apparently Zhifei starts some array indices at 1. Change to 0-indexing.
+  public float probString(int words[], int start) {
+    return probString(pointer, words, start - 1);
+  }
+
+  /**
+   * This function is the bridge to the interface in kenlm/lm/left.hh, which has KenLM score the
+   * whole rule. It takes a list of words and states retrieved from tail nodes (nonterminals in the
+   * rule). Nonterminals have a negative value so KenLM can distinguish them. When finished, it
+   * returns the updated KenLM state and the LM probability incurred along this rule. A
+   * NoSuchMethodError from the native call is printed and terminates the JVM with exit status 1.
+   *
+   * @param words the words and tail-node states of the rule
+   * @param poolPointer handle of the memory pool to use, as returned by createPool()
+   * @return the state / probability pair produced by the native call
+   */
+  public StateProbPair probRule(long[] words, long poolPointer) {
+
+    StateProbPair pair = null;
+    try {
+      pair = probRule(pointer, poolPointer, words);
+    } catch (NoSuchMethodError e) {
+      e.printStackTrace();
+      System.exit(1);
+    }
+
+    return pair;
+  }
+
+  /**
+   * Public facing function that estimates the cost of a rule, which value is used for sorting
+   * rules during cube pruning.
+   *
+   * @param words the words of the rule
+   * @return the estimated cost of the rule (the (partial) n-gram probabilities of all words in the rule)
+   */
+  public float estimateRule(long[] words) {
+    float estimate = 0.0f;
+    try {
+      estimate = estimateRule(pointer, words);
+    } catch (NoSuchMethodError e) {
+      e.printStackTrace();
+      System.exit(1);
+    }
+
+    return estimate;
+  }
+
+  /**
+   * The start symbol for a KenLM is the Vocabulary.START_SYM.
+   */
+  public String getStartSymbol() {
+    return Vocabulary.START_SYM;
+  }
+
+  public boolean isKnownWord(String word) {
+    return isKnownWord(pointer, word);
+  }
+
+
+  /**
+   * Inner class used to hold the results returned from KenLM with left-state minimization. Note
+   * that inner classes have to be static to be accessible from the JNI!
+   */
+  public static class StateProbPair {
+    public KenLMState state = null;
+    public float prob = 0.0f;
+
+    public StateProbPair(long state, float prob) {
+      this.state = new KenLMState(state);
+      this.prob = prob;
+    }
+  }
+
+  /**
+   * Orders models by their native handle. The same instance compares as equal; distinct
+   * instances are ordered consistently in both directions, so that
+   * sgn(a.compareTo(b)) == -sgn(b.compareTo(a)) as the Comparable contract requires.
+   * (Returning -1 for every pair of distinct instances breaks sorted collections.)
+   */
+  @Override
+  public int compareTo(KenLM other) {
+    if (this == other)
+      return 0;
+    if (pointer < other.pointer)
+      return -1;
+    if (pointer > other.pointer)
+      return 1;
+    return 0;
+  }
+
+  /**
+   * These functions are used if KenLM is invoked under LanguageModelFF instead of KenLMFF.
+   */
+  @Override
+  public float sentenceLogProbability(int[] sentence, int order, int startIndex) {
+    return probString(sentence, startIndex);
+  }
+
+  /**
+   * Scores an n-gram; the order must equal either the model's own order or the n-gram length.
+   *
+   * @throws RuntimeException if the order matches neither
+   */
+  @Override
+  public float ngramLogProbability(int[] ngram, int order) {
+    if (order != N && order != ngram.length)
+      throw new RuntimeException("Lower order not supported.");
+    return prob(ngram);
+  }
+
+  @Override
+  public float ngramLogProbability(int[] ngram) {
+    return prob(ngram);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8cdbc4b8/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java b/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
new file mode 100644
index 0000000..a002de7
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/LanguageModelFF.java
@@ -0,0 +1,520 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package joshua.decoder.ff.lm;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+
+import com.google.common.primitives.Ints;
+
+import joshua.corpus.Vocabulary;
+import joshua.decoder.JoshuaConfiguration;
+import joshua.decoder.Support;
+import joshua.decoder.chart_parser.SourcePath;
+import joshua.decoder.ff.FeatureVector;
+import joshua.decoder.ff.StatefulFF;
+import joshua.decoder.ff.lm.berkeley_lm.LMGrammarBerkeley;
+import joshua.decoder.ff.lm.KenLM;
+import joshua.decoder.ff.state_maintenance.DPState;
+import joshua.decoder.ff.state_maintenance.NgramDPState;
+import joshua.decoder.ff.tm.Rule;
+import joshua.decoder.hypergraph.HGNode;
+import joshua.decoder.segment_file.Sentence;
+
+/**
+ * This class performs the following:
+ * <ol>
+ * <li>Gets the additional LM score due to combinations of small items into larger ones by using
+ * rules
+ * <li>Gets the LM state
+ * <li>Gets the left-side LM state estimation score
+ * </ol>
+ *
+ * @author Matt Post <po...@cs.jhu.edu>
+ * @author Juri Ganitkevitch <ju...@cs.jhu.edu>
+ * @author Zhifei Li, <zh...@gmail.com>
+ */
+public class LanguageModelFF extends StatefulFF {
+
+ public static int LM_INDEX = 0;
+ private int startSymbolId;
+
+ /**
+ * N-gram language model. We assume the language model is in ARPA format for equivalent state:
+ *
+ * <ol>
+ * <li>We assume it is a backoff LM: the presence of a high-order n-gram implies the presence of
+ * its low-order n-grams, and the absence of a low-order n-gram implies the absence of the
+ * high-order n-grams built on it.</li>
+ * <li>For an n-gram, the existence of a backoff weight implies the existence of a probability.
+ * Two ways of dealing with low counts:
+ * <ul>
+ * <li>SRILM: don't multiply zeros in for unknown words</li>
+ * <li>Pharaoh: cap at a minimum score exp(-10), including unknown words</li>
+ * </ul>
+ * </li>
+ * </ol>
+ */
+ protected NGramLanguageModel languageModel;
+
+ /**
+ * We always use this order of ngram, though the LMGrammar may provide higher order probability.
+ */
+ protected final int ngramOrder;
+
+ /*
+ * We cache the weight of the feature since there is only one.
+ */
+ protected float weight;
+ protected String type;
+ protected String path;
+
+ /* Whether this is a class-based LM */
+ private boolean isClassLM;
+ private ClassMap classMap;
+
+ /**
+  * Lookup table from word ids to class ids, filled from a file with one whitespace-separated
+  * "word class" pair per line. Word ids without an entry map to Vocabulary.getUnknownId().
+  */
+ protected class ClassMap {
+
+   /** Class id returned for any word id that has no entry in the map. */
+   private final int OOV_id = Vocabulary.getUnknownId();
+   private HashMap<Integer, Integer> classMap;
+
+   /**
+    * Builds the map by reading the given file.
+    *
+    * @param file_name path of the class map file
+    * @throws IOException if reading the file fails
+    */
+   public ClassMap(String file_name) throws IOException {
+     this.classMap = new HashMap<Integer, Integer>();
+     read(file_name);
+   }
+
+   /**
+    * @param wordID id of a word
+    * @return the class id stored for that word, or the unknown id if there is none
+    */
+   public int getClassID(int wordID) {
+     return this.classMap.getOrDefault(wordID, OOV_id);
+   }
+
+   /**
+    * Reads a class map from file. Each line is trimmed and split on whitespace; the first two
+    * fields are used as the word and its class. Lines with fewer than two fields are reported
+    * on stderr and skipped.
+    *
+    * @param file_name path of the class map file
+    * @throws IOException if reading the file fails
+    */
+   private void read(String file_name) throws IOException {
+     int lineno = 0;
+     for (String line : new joshua.util.io.LineReader(file_name, false)) {
+       lineno++;
+       String[] lineComp = line.trim().split("\\s+");
+
+       // Check the field count up front rather than catching ArrayIndexOutOfBoundsException,
+       // so a malformed (or blank) line never passes its first token to Vocabulary.id().
+       if (lineComp.length < 2) {
+         System.err.println(String.format("* WARNING: bad vocab line #%d '%s'", lineno, line));
+         continue;
+       }
+
+       this.classMap.put(Vocabulary.id(lineComp[0]), Vocabulary.id(lineComp[1]));
+     }
+   }
+
+ }
+
+ public LanguageModelFF(FeatureVector weights, String[] args, JoshuaConfiguration config) {
+ super(weights, String.format("lm_%d", LanguageModelFF.LM_INDEX++), args, config);
+
+ this.type = parsedArgs.get("lm_type");
+ this.ngramOrder = Integer.parseInt(parsedArgs.get("lm_order"));
+ this.path = parsedArgs.get("lm_file");
+
+ if (parsedArgs.containsKey("class_map"))
+ try {
+ this.isClassLM = true;
+ this.classMap = new ClassMap(parsedArgs.get("class_map"));
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ // The dense feature initialization hasn't happened yet, so we have to retrieve this as sparse
+ this.weight = weights.getSparse(name);
+
+ initializeLM();
+ }
+
+ /**
+  * Stores the given index in denseFeatureIndex and reports this feature's single name.
+  *
+  * @param index the value to store in denseFeatureIndex
+  * @return a new list holding only this feature's name
+  */
+ @Override
+ public ArrayList<String> reportDenseFeatures(int index) {
+   denseFeatureIndex = index;
+
+   final ArrayList<String> featureNames = new ArrayList<String>(1);
+   featureNames.add(name);
+   return featureNames;
+ }
+
+ /**
+  * Initializes the underlying language model from the type, ngramOrder and path fields, then
+  * registers it with the Vocabulary and caches the id of the start symbol. An unrecognized
+  * type is reported on stderr and terminates the JVM.
+  */
+ protected void initializeLM() {
+   switch (type) {
+     case "kenlm":
+       this.languageModel = new KenLM(ngramOrder, path);
+       break;
+
+     case "berkeleylm":
+       this.languageModel = new LMGrammarBerkeley(ngramOrder, path);
+       break;
+
+     default:
+       System.err.println(String.format("* FATAL: Invalid backend lm_type '%s' for LanguageModel", type));
+       System.err.println(String.format("* Permissible values for 'lm_type' are 'kenlm' and 'berkeleylm'"));
+       System.exit(-1);
+       break;
+   }
+
+   Vocabulary.registerLanguageModel(this.languageModel);
+   // Look up the default nonterminal's id; the returned value is not used here.
+   Vocabulary.id(config.default_non_terminal);
+
+   startSymbolId = Vocabulary.id(Vocabulary.START_SYM);
+ }
+
+ public NGramLanguageModel getLM() { // accessor for the languageModel field
+ return this.languageModel;
+ }
+
+ public String logString() {
+ if (languageModel != null)
+ return String.format("%s, order %d (weight %.3f)", name, languageModel.getOrder(), weight);
+ else
+ return "WHOA";
+ }
+
+ /**
+ * Computes the features incurred along this edge. Note that these features are unweighted costs
+ * of the feature; they are the feature cost, not the model cost, or the inner product of them.
+ */
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
+
+ NgramDPState newState = null;
+ if (rule != null) {
+ if (config.source_annotations) {
+ // Get source side annotations and project them to the target side
+ newState = computeTransition(getTags(rule, i, j, sentence), tailNodes, acc);
+ }
+ else {
+ if (this.isClassLM) {
+ // Use a class language model
+ // Return target side classes
+ newState = computeTransition(getClasses(rule), tailNodes, acc);
+ }
+ else {
+ // Default LM
+ newState = computeTransition(rule.getEnglish(), tailNodes, acc);
+ }
+ }
+
+ }
+
+ return newState;
+ }
+
+ /**
+ * Input sentences can be tagged with information specific to the language model. This looks for
+ * such annotations by following a word's alignments back to the source words, checking for
+ * annotations, and replacing the surface word if such annotations are found.
+ *
+ */
+ protected int[] getTags(Rule rule, int begin, int end, Sentence sentence) {
+ /* Very important to make a copy here, so the original rule is not modified */
+ int[] tokens = Arrays.copyOf(rule.getEnglish(), rule.getEnglish().length);
+ byte[] alignments = rule.getAlignment();
+
+// System.err.println(String.format("getTags() %s", rule.getRuleString()));
+
+ /* For each target-side token, project it to each of its source-language alignments. If any of those
+ * are annotated, take the first annotation and quit.
+ */
+ if (alignments != null) {
+ for (int i = 0; i < tokens.length; i++) {
+ if (tokens[i] > 0) { // skip nonterminals
+ for (int j = 0; j < alignments.length; j += 2) {
+ if (alignments[j] == i) {
+ String annotation = sentence.getAnnotation((int)alignments[i] + begin, "class");
+ if (annotation != null) {
+// System.err.println(String.format(" word %d source %d abs %d annotation %d/%s",
+// i, alignments[i], alignments[i] + begin, annotation, Vocabulary.word(annotation)));
+ tokens[i] = Vocabulary.id(annotation);
+ break;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return tokens;
+ }
+
+ /**
+ * Sets the class map if this is a class LM, replacing any map that was set before. The
+ * isClassLM flag is not changed by this call.
+ *
+ * @param fileName path of the class map file to load
+ * @throws IOException if constructing the ClassMap from the file fails
+ */
+ public void setClassMap(String fileName) throws IOException {
+ this.classMap = new ClassMap(fileName);
+ }
+
+
+ /**
+ * Replace each word in a rule with the target side classes.
+ */
+ protected int[] getClasses(Rule rule) {
+ if (this.classMap == null) {
+ System.err.println("The class map is not set. Cannot use the class LM ");
+ System.exit(2);
+ }
+ /* Very important to make a copy here, so the original rule is not modified */
+ int[] tokens = Arrays.copyOf(rule.getEnglish(), rule.getEnglish().length);
+ for (int i = 0; i < tokens.length; i++) {
+ if (tokens[i] > 0 ) {
+ tokens[i] = this.classMap.getClassID(tokens[i]);
+ }
+ }
+ return tokens;
+ }
+
+ @Override
+ public DPState computeFinal(HGNode tailNode, int i, int j, SourcePath sourcePath, Sentence sentence,
+ Accumulator acc) { // only tailNode and acc are used; the other arguments are ignored
+ return computeFinalTransition((NgramDPState) tailNode.getDPState(stateIndex), acc);
+ }
+
+ /**
+  * This function computes all the complete n-grams found in the rule, as well as the incomplete
+  * n-grams on the left-hand side. The rule's target words are gathered into chunks delimited by
+  * nonterminals, and each chunk is scored with scoreChunkLogP.
+  *
+  * @param rule the rule whose target side is scored
+  * @param sentence not used by this implementation
+  * @return the summed chunk scores multiplied by the feature weight; a rule with an empty target
+  *         side contributes no chunk score
+  */
+ @Override
+ public float estimateCost(Rule rule, Sentence sentence) {
+
+   float estimate = 0.0f;
+   final boolean considerIncompleteNgrams = true;
+
+   int[] enWords = rule.getEnglish();
+
+   List<Integer> words = new ArrayList<Integer>();
+   // Guard the peek at the first word so that an empty target side does not throw.
+   boolean skipStart = enWords.length > 0 && enWords[0] == startSymbolId;
+
+   /*
+    * Move through the words, accumulating language model costs each time we have an n-gram (n >=
+    * 2), and resetting the series of words when we hit a nonterminal.
+    */
+   for (int c = 0; c < enWords.length; c++) {
+     int currentWord = enWords[c];
+     if (Vocabulary.nt(currentWord)) {
+       estimate += scoreChunkLogP(words, considerIncompleteNgrams, skipStart);
+       words.clear();
+       // Only the very first chunk can begin with the start symbol.
+       skipStart = false;
+     } else {
+       words.add(currentWord);
+     }
+   }
+   // Score whatever follows the last nonterminal (or the whole rule if there was none).
+   estimate += scoreChunkLogP(words, considerIncompleteNgrams, skipStart);
+
+   return weight * estimate;
+ }
+
+ /**
+  * Estimates the future cost of a rule. For the language model feature, this is the sum of the
+  * costs of the leftmost k-grams, k = [1..n-1].
+  *
+  * @param rule not used by this implementation
+  * @param currentState the state to estimate from; it is cast to NgramDPState
+  * @param sentence not used by this implementation
+  * @return the score of the state's left-context words multiplied by the feature weight; a
+  *         missing or empty left context contributes no score
+  */
+ @Override
+ public float estimateFutureCost(Rule rule, DPState currentState, Sentence sentence) {
+   NgramDPState state = (NgramDPState) currentState;
+
+   float estimate = 0.0f;
+   int[] leftContext = state.getLeftLMStateWords();
+
+   // The length check keeps the peek at leftContext[0] from throwing on an empty array.
+   if (leftContext != null && leftContext.length > 0) {
+     boolean skipStart = (leftContext[0] == startSymbolId);
+     estimate += scoreChunkLogP(leftContext, true, skipStart);
+   }
+   return weight * estimate;
+ }
+
+ /**
+ * Compute the cost of a rule application. The cost of applying a rule is computed by determining
+ * the n-gram costs for all n-grams created by this rule application, and summing them. N-grams
+ * are created when (a) terminal words in the rule string are followed by a nonterminal (b)
+ * terminal words in the rule string are preceded by a nonterminal (c) we encounter adjacent
+ * nonterminals. In all of these situations, the corresponding boundary words of the node in the
+ * hypergraph represented by the nonterminal must be retrieved.
+ *
+ * IMPORTANT: only complete n-grams are scored. This means that hypotheses with fewer words
+ * than the complete n-gram state remain *unscored*. This fact adds a lot of complication to the
+ * code, including the use of the computeFinal* family of functions, which correct this fact for
+ * sentences that are too short on the final transition.
+ *
+ * @param enWords target-side symbols to score; entries for which Vocabulary.nt() is true are
+ * treated as nonterminals and resolved against tailNodes
+ * @param tailNodes antecedent nodes; a nonterminal id x selects tailNodes.get(-(x + 1))
+ * @param acc accumulator that receives the summed n-gram log probability under
+ * denseFeatureIndex
+ * @return the new state. Once ngramOrder - 1 words have been seen, its left context is a copy of
+ * those first words and its right context is the last ngramOrder - 1 entries of the window;
+ * otherwise both contexts are the same array holding every word seen
+ */
+ private NgramDPState computeTransition(int[] enWords, List<HGNode> tailNodes, Accumulator acc) {
+
+ int[] current = new int[this.ngramOrder]; // window holding the most recent words
+ int[] shadow = new int[this.ngramOrder]; // scratch buffer, swapped with current on every shift
+ int ccount = 0; // number of valid entries in current
+ float transitionLogP = 0.0f;
+ int[] left_context = null; // set once, the first time ngramOrder - 1 words are in the window
+
+ for (int c = 0; c < enWords.length; c++) {
+ int curID = enWords[c];
+
+ if (Vocabulary.nt(curID)) {
+ int index = -(curID + 1); // maps -1 to tail node 0, -2 to tail node 1, ...
+
+ NgramDPState state = (NgramDPState) tailNodes.get(index).getDPState(stateIndex);
+ int[] left = state.getLeftLMStateWords();
+ int[] right = state.getRightLMStateWords();
+
+ // Left context: feed the antecedent's left words through the window one at a time.
+ for (int i = 0; i < left.length; i++) {
+ current[ccount++] = left[i];
+
+ if (left_context == null && ccount == this.ngramOrder - 1)
+ left_context = Arrays.copyOf(current, ccount);
+
+ if (ccount == this.ngramOrder) {
+ // The window is full: score it, then shift left by one to drop the oldest word.
+ float prob = this.languageModel.ngramLogProbability(current, this.ngramOrder);
+// System.err.println(String.format("-> prob(%s) = %f", Vocabulary.getWords(current), prob));
+ transitionLogP += prob;
+ System.arraycopy(current, 1, shadow, 0, this.ngramOrder - 1);
+ int[] tmp = current;
+ current = shadow;
+ shadow = tmp;
+ --ccount;
+ }
+ }
+ System.arraycopy(right, 0, current, ccount - right.length, right.length); // overwrite the tail of the window with the antecedent's right words
+ } else { // terminal words
+ current[ccount++] = curID;
+
+ if (left_context == null && ccount == this.ngramOrder - 1)
+ left_context = Arrays.copyOf(current, ccount);
+
+ if (ccount == this.ngramOrder) {
+ // The window is full: score it, then shift left by one to drop the oldest word.
+ float prob = this.languageModel.ngramLogProbability(current, this.ngramOrder);
+// System.err.println(String.format("-> prob(%s) = %f", Vocabulary.getWords(current), prob));
+ transitionLogP += prob;
+ System.arraycopy(current, 1, shadow, 0, this.ngramOrder - 1);
+ int[] tmp = current;
+ current = shadow;
+ shadow = tmp;
+ --ccount;
+ }
+ }
+ }
+// acc.add(name, transitionLogP);
+ acc.add(denseFeatureIndex, transitionLogP);
+
+ if (left_context != null) {
+ return new NgramDPState(left_context, Arrays.copyOfRange(current, ccount - this.ngramOrder
+ + 1, ccount));
+ } else {
+ int[] context = Arrays.copyOf(current, ccount); // fewer than ngramOrder - 1 words were seen
+ return new NgramDPState(context, context);
+ }
+ }
+
+ /**
+ * This function differs from regular transitions because we incorporate the cost of incomplete
+ * left-hand ngrams, as well as including the start- and end-of-sentence markers (if they were
+ * requested when the object was created).
+ *
+ * @param state the dynamic programming state
+ * @return the final transition probability (including incomplete n-grams)
+ */
+ private NgramDPState computeFinalTransition(NgramDPState state, Accumulator acc) {
+
+// System.err.println(String.format("LanguageModel::computeFinalTransition()"));
+
+ float res = 0.0f;
+ LinkedList<Integer> currentNgram = new LinkedList<Integer>();
+ int[] leftContext = state.getLeftLMStateWords();
+ int[] rightContext = state.getRightLMStateWords();
+
+ for (int i = 0; i < leftContext.length; i++) {
+ int t = leftContext[i];
+ currentNgram.add(t);
+
+ if (currentNgram.size() >= 2) { // start from bigram
+ float prob = this.languageModel.ngramLogProbability(Support.toArray(currentNgram),
+ currentNgram.size());
+ res += prob;
+ }
+ if (currentNgram.size() == this.ngramOrder)
+ currentNgram.removeFirst();
+ }
+
+ // Tell the accumulator
+// acc.add(name, res);
+ acc.add(denseFeatureIndex, res);
+
+ // State is the same
+ return new NgramDPState(leftContext, rightContext);
+ }
+
+
+ /**
+ * Compatibility method for {@link #scoreChunkLogP(int[], boolean, boolean)}: converts the list
+ * to an int array with Ints.toArray and forwards the two flags unchanged.
+ */
+ private float scoreChunkLogP(List<Integer> words, boolean considerIncompleteNgrams,
+ boolean skipStart) {
+ return scoreChunkLogP(Ints.toArray(words), considerIncompleteNgrams, skipStart);
+ }
+
+ /**
+ * This function is basically a wrapper for NGramLanguageModel::sentenceLogProbability(). It
+ * computes the probability of a phrase ("chunk"), using lower-order n-grams for the first n-1
+ * words.
+ *
+ * @param words
+ * @param considerIncompleteNgrams
+ * @param skipStart
+ * @return the phrase log probability
+ */
+ private float scoreChunkLogP(int[] words, boolean considerIncompleteNgrams,
+ boolean skipStart) {
+
+ float score = 0.0f;
+ if (words.length > 0) {
+ int startIndex;
+ if (!considerIncompleteNgrams) {
+ startIndex = this.ngramOrder;
+ } else if (skipStart) {
+ startIndex = 2;
+ } else {
+ startIndex = 1;
+ }
+ score = this.languageModel.sentenceLogProbability(words, this.ngramOrder, startIndex);
+ }
+
+ return score;
+ }
+
+ /**
+ * Public method to set LM_INDEX back to 0.
+ * Required if multiple instances of the JoshuaDecoder live in the same JVM.
+ * The constructor uses LM_INDEX to build the feature name, so the next instance created after
+ * this call is named lm_0 again.
+ */
+ public static void resetLmIndex() {
+ LM_INDEX = 0;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/8cdbc4b8/src/main/java/org/apache/joshua/decoder/ff/lm/NGramLanguageModel.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/NGramLanguageModel.java b/src/main/java/org/apache/joshua/decoder/ff/lm/NGramLanguageModel.java
new file mode 100644
index 0000000..15da650
--- /dev/null
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/NGramLanguageModel.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package joshua.decoder.ff.lm;
+
+/**
+ * An interface for new language models to implement. An object of this type is passed to
+ * LanguageModelFF, which will handle all the dynamic programming and state maintenance.
+ *
+ * @author wren ng thornton <wr...@users.sourceforge.net>
+ * @author Zhifei Li, <zh...@gmail.com>
+ * @author Matt Post <po...@cs.jhu.edu>
+ * @author Juri Ganitkevitch <ju...@cs.jhu.edu>
+ */
+public interface NGramLanguageModel {
+
+ // ===============================================================
+ // Attributes
+ // ===============================================================
+ int getOrder(); // presumably the n-gram order of the model -- confirm against implementations
+
+ // ===============================================================
+ // Methods
+ // ===============================================================
+
+ /**
+ * Language models may have their own private vocabulary mapping strings to integers; for example,
+ * if they make use of a compiled format (as KenLM and BerkeleyLM do). This mapping is likely
+ * different from the global mapping contained in joshua.corpus.Vocabulary, which is used to
+ * convert the input string and grammars. This function is used to tell the language model what
+ * the global mapping is, so that the language model can convert it into its own private mapping.
+ *
+ * @param token the word string in the global mapping
+ * @param id the id the global mapping assigns to that word
+ * @return Whether any collisions were detected.
+ */
+ boolean registerWord(String token, int id);
+
+ /**
+ * Compute the log probability of a sentence.
+ *
+ * @param sentence the sentence to be scored
+ * @param order the order of N-grams for the LM
+ * @param startIndex the index of the first event-word whose probability we want; to get the
+ * probability of the whole sentence, startIndex should be 1
+ * @return the LogP of the whole sentence
+ */
+ float sentenceLogProbability(int[] sentence, int order, int startIndex);
+
+ /**
+ * Compute the probability of a single word given its context.
+ *
+ * @param ngram the word ids of the n-gram to score; presumably the last entry is the word being
+ * predicted and the entries before it are its context -- confirm against implementations
+ * @param order the n-gram order to score at
+ * @return the log probability, going by the method name; the base of the logarithm is not
+ * stated here
+ */
+ float ngramLogProbability(int[] ngram, int order);
+
+ float ngramLogProbability(int[] ngram); // overload without an explicit order argument
+}