You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/05/31 19:39:23 UTC
[3/5] incubator-joshua git commit: Merge branch 'sparse' of
https://github.com/fhieber/incubator-joshua into JOSHUA-PR21
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/5c0d5388/src/main/java/org/apache/joshua/decoder/ff/RuleFF.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/decoder/ff/RuleFF.java
index bc6d67b,0000000..20f91ee
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/decoder/ff/RuleFF.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/RuleFF.java
@@@ -1,100 -1,0 +1,135 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff;
+
++import static com.google.common.cache.CacheBuilder.newBuilder;
++
+import java.util.List;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.chart_parser.SourcePath;
+import org.apache.joshua.decoder.ff.state_maintenance.DPState;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.segment_file.Sentence;
+
++import com.google.common.cache.Cache;
++
+/**
- * This feature just counts rules that are used. You can restrict it with a number of flags:
- *
- * -owner OWNER
- * Only count rules owned by OWNER
- * -target|-source
- * Only count the target or source side (plus the LHS)
- *
- * TODO: add an option to separately provide a list of rule counts, restrict to counts above a threshold.
++ * This feature fires for rule ids.
++ * Firing can be restricted to rules from a certain owner, and rule ids
++ * can be generated from source side and/or target side.
+ */
+public class RuleFF extends StatelessFF {
+
+ private enum Sides { SOURCE, TARGET, BOTH };
+
- private int owner = 0;
- private Sides sides = Sides.BOTH;
++ private static final String NAME = "RuleFF";
++ // value to fire for features
++ private static final int VALUE = 1;
++ // whether this feature is restricted to a certain grammar/owner
++ private final boolean ownerRestriction;
++ // the grammar/owner this feature is restricted to fire
++ private final int owner;
++ // what part of the rule should be extracted;
++ private final Sides sides;
++ // Strings separating words and rule sides
++ private static final String SEPARATOR = "~";
++ private static final String SIDES_SEPARATOR = "->";
++
++ private final Cache<Rule, String> featureCache;
+
++ /**
++ * Creates the feature from its configuration arguments (read via parsedArgs):
++ * "owner" restricts firing to rules of that owner; "sides" selects which rule
++ * sides form the feature name (source, target, or both; case-insensitive,
++ * defaults to both); "cacheSize" bounds the feature-name cache (defaults to
++ * config.cachedRuleSize).
++ * @param weights feature weights, passed to the superclass
++ * @param args raw feature arguments, passed to the superclass
++ * @param config decoder configuration; supplies the default cache size
++ * @throws IllegalArgumentException if "sides" has an unrecognized value
++ */
+ public RuleFF(FeatureVector weights, String[] args, JoshuaConfiguration config) {
- super(weights, "RuleFF", args, config);
++ super(weights, NAME, args, config);
++
++ ownerRestriction = parsedArgs.containsKey("owner");
++ owner = ownerRestriction ? Vocabulary.id(parsedArgs.get("owner")) : 0;
+
- owner = Vocabulary.id(parsedArgs.get("owner"));
- if (parsedArgs.containsKey("source"))
- sides = Sides.SOURCE;
- else if (parsedArgs.containsKey("target"))
- sides = Sides.TARGET;
++ if (parsedArgs.containsKey("sides")) {
++ final String sideValue = parsedArgs.get("sides");
++ if (sideValue.equalsIgnoreCase("source")) {
++ sides = Sides.SOURCE;
++ } else if (sideValue.equalsIgnoreCase("target")) {
++ sides = Sides.TARGET;
++ } else if (sideValue.equalsIgnoreCase("both")){
++ sides = Sides.BOTH;
++ } else {
++ // Name the rejected value so a misconfigured feature line can be located.
++ throw new IllegalArgumentException(
++ "Unknown side value: '" + sideValue + "' (expected source, target, or both).");
++ }
++ } else {
++ sides = Sides.BOTH;
++ }
++
++ // initialize cache
++ if (parsedArgs.containsKey("cacheSize")) {
++ featureCache = newBuilder().maximumSize(Integer.parseInt(parsedArgs.get("cacheSize"))).build();
++ } else {
++ featureCache = newBuilder().maximumSize(config.cachedRuleSize).build();
++ }
+ }
+
++ /**
++ * Adds one feature, named after the rule, with value VALUE to the accumulator.
++ * If an owner restriction is configured and the rule belongs to a different
++ * owner, nothing is added. Always returns null.
++ */
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
-
- if (owner > 0 && rule.getOwner() == owner) {
- String ruleString = getRuleString(rule);
- acc.add(ruleString, 1);
++
++ // Skip rules from other owners when this feature is restricted to one owner.
++ if (ownerRestriction && rule.getOwner() != owner) {
++ return null;
+ }
+
++ // Feature names are memoized per rule; the name is only rebuilt on a cache miss.
++ String featureName = featureCache.getIfPresent(rule);
++ if (featureName == null) {
++ featureName = getRuleString(rule);
++ featureCache.put(rule, featureName);
++ }
++ acc.add(featureName, VALUE);
++
+ return null;
+ }
-
- private String getRuleString(Rule rule) {
- String ruleString = "";
- switch(sides) {
- case BOTH:
- ruleString = String.format("%s %s %s", Vocabulary.word(rule.getLHS()), rule.getFrenchWords(),
- rule.getEnglishWords());
- break;
-
- case SOURCE:
- ruleString = String.format("%s %s", Vocabulary.word(rule.getLHS()), rule.getFrenchWords());
- break;
-
- case TARGET:
- ruleString = String.format("%s %s", Vocabulary.word(rule.getLHS()), rule.getEnglishWords());
- break;
++
++ /**
++ * Builds the feature name for the given rule: the LHS word, SIDES_SEPARATOR,
++ * the source words (only for Sides.SOURCE or Sides.BOTH), SIDES_SEPARATOR,
++ * and the target words (only for Sides.TARGET or Sides.BOTH). Both separators
++ * are always emitted, even when a side is omitted.
++ * @param rule the rule whose LHS, source ids, and target ids are read
++ * @return String representing the feature name.
++ */
++ private String getRuleString(final Rule rule) {
++ final StringBuilder sb = new StringBuilder(Vocabulary.word(rule.getLHS()))
++ .append(SIDES_SEPARATOR);
++ if (sides == Sides.SOURCE || sides == Sides.BOTH) {
++ // presumably joins the words for these ids with SEPARATOR -- confirm in Vocabulary.getWords
++ sb.append(Vocabulary.getWords(rule.getFrench(), SEPARATOR));
++ }
++ sb.append(SIDES_SEPARATOR);
++ if (sides == Sides.TARGET || sides == Sides.BOTH) {
++ sb.append(Vocabulary.getWords(rule.getEnglish(), SEPARATOR));
+ }
- return ruleString.replaceAll("[ =]", "~");
++ return sb.toString();
+ }
+
+ @Override
+ public double estimateLogP(Rule rule, int sentID) {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+
+ @Override
+ public double getWeight() {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/5c0d5388/src/main/java/org/apache/joshua/decoder/ff/RuleLength.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/decoder/ff/RuleLength.java
index 59b1c20,0000000..02c520b
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/decoder/ff/RuleLength.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/RuleLength.java
@@@ -1,51 -1,0 +1,52 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff;
+
+import java.util.List;
+
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.chart_parser.SourcePath;
+import org.apache.joshua.decoder.ff.state_maintenance.DPState;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.segment_file.Sentence;
+
+/*
+ * This feature computes three feature templates: a feature indicating the length of the rule's
+ * source side, its target side, and a feature that pairs them.
+ */
+public abstract class RuleLength extends StatelessFF {
++
++ private static final int VALUE = 1;
+
+ public RuleLength(FeatureVector weights, String[] args, JoshuaConfiguration config) {
+ super(weights, "RuleLength", args, config);
+ }
+
++ /**
++ * Adds three indicator features for the rule, each with value VALUE: one keyed
++ * by the source-side length, one by the target-side length, and one by the
++ * pair of both lengths. Always returns null.
++ */
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
- int sourceLen = rule.getFrench().length;
- int targetLen = rule.getEnglish().length;
- acc.add(String.format("%s_sourceLength%d", name, sourceLen), 1);
- acc.add(String.format("%s_targetLength%d", name, targetLen), 1);
- acc.add(String.format("%s_pairLength%d-%d", name, sourceLen, targetLen), 1);
-
++ int sourceLength = rule.getFrench().length;
++ int targetLength = rule.getEnglish().length;
++ acc.add(name + "_source" + sourceLength, VALUE);
++ // Keyed by targetLength: using sourceLength here would make the target
++ // feature a duplicate of the source-length signal.
++ acc.add(name + "_target" + targetLength, VALUE);
++ acc.add(name + "_sourceTarget" + sourceLength + "-" + targetLength, VALUE);
+ return null;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/5c0d5388/src/main/java/org/apache/joshua/decoder/ff/RuleShape.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/decoder/ff/RuleShape.java
index a514021,0000000..6333701
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/decoder/ff/RuleShape.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/RuleShape.java
@@@ -1,85 -1,0 +1,112 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff;
+
+import java.util.List;
+
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.chart_parser.SourcePath;
+import org.apache.joshua.decoder.ff.state_maintenance.DPState;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.segment_file.Sentence;
++import org.apache.joshua.util.FormatUtils;
++import org.apache.joshua.corpus.Vocabulary;
+
+/*
+ * Implements the RuleShape feature for source, target, and paired source+target sides.
+ */
+public class RuleShape extends StatelessFF {
+
+ public RuleShape(FeatureVector weights, String[] args, JoshuaConfiguration config) {
+ super(weights, "RuleShape", args, config);
+ }
+
- private int gettype(int id) {
- if (id < 0)
- return -1;
- return 1;
++ // Token classes. Kept immutable: enum constants are shared singletons, so
++ // per-pattern "repeats" state must live in getRulePattern, not here.
++ private enum WordType {
++ N("N"), T("x"), P("+");
++ private final String string;
++
++ private WordType(final String string) {
++ this.string = string;
++ }
++
++ @Override
++ public String toString() {
++ return this.string;
++ }
++ }
++
++ private WordType getWordType(int id) {
++ if (FormatUtils.isNonterminal(id)) {
++ return WordType.N;
++ } else {
++ return WordType.T;
++ }
+ }
+
- private String pattern(int[] ids) {
- StringBuilder pattern = new StringBuilder();
- int curtype = gettype(ids[0]);
- int curcount = 1;
++ /**
++ * Returns a String describing the rule pattern: one symbol per run of
++ * same-typed tokens ("N" or "x"), followed by "+" if the run is longer
++ * than one token. An empty id array yields an empty pattern.
++ */
++ private String getRulePattern(int[] ids) {
++ final StringBuilder pattern = new StringBuilder();
++ if (ids.length == 0) {
++ return pattern.toString();
++ }
++ WordType currentType = getWordType(ids[0]);
++ boolean repeats = false;
+ for (int i = 1; i < ids.length; i++) {
- if (gettype(ids[i]) != curtype) {
- pattern.append(String.format("%s%s_", curtype < 0 ? "N" : "x", curcount > 1 ? "+" : ""));
- curtype = gettype(ids[i]);
- curcount = 1;
++ if (getWordType(ids[i]) != currentType) {
++ pattern.append(currentType.toString()).append(repeats ? "+" : "");
++ currentType = getWordType(ids[i]);
++ repeats = false;
+ } else {
- curcount++;
++ repeats = true;
+ }
+ }
- pattern.append(String.format("%s%s_", curtype < 0 ? "N" : "x", curcount > 1 ? "+" : ""));
++ pattern.append(currentType.toString()).append(repeats ? "+" : "");
+ return pattern.toString();
+ }
+
++ /**
++ * Adds three indicator features for the rule, each with value 1: one keyed by
++ * the source-side pattern, one by the target-side pattern, and one by the
++ * combination of both. Always returns null.
++ */
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int i_, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
- String sourceShape = pattern(rule.getFrench());
- String targetShape = pattern(rule.getEnglish());
- acc.add(String.format("%s_source_%s", name, sourceShape), 1);
- acc.add(String.format("%s_target_%s", name, targetShape), 1);
- acc.add(String.format("%s_both_%s__%s", name, sourceShape, targetShape), 1);
-
++ final String sourceShape = getRulePattern(rule.getFrench());
++ final String targetShape = getRulePattern(rule.getEnglish());
++ acc.add(name + "_source_" + sourceShape, 1);
++ // Keyed by targetShape: using sourceShape here would make the target
++ // feature a duplicate of the source-shape signal.
++ acc.add(name + "_target_" + targetShape, 1);
++ acc.add(name + "_sourceTarget_" + sourceShape + "_" + targetShape, 1);
+ return null;
+ }
+
+ @Override
+ public double estimateLogP(Rule rule, int sentID) {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+
+ @Override
+ public double getWeight() {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/5c0d5388/src/main/java/org/apache/joshua/decoder/ff/WordPenalty.java
----------------------------------------------------------------------
diff --cc src/main/java/org/apache/joshua/decoder/ff/WordPenalty.java
index 62c889f,0000000..e1f74c2
mode 100644,000000..100644
--- a/src/main/java/org/apache/joshua/decoder/ff/WordPenalty.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/WordPenalty.java
@@@ -1,90 -1,0 +1,92 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.ff.state_maintenance.DPState;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.chart_parser.SourcePath;
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.phrase.Hypothesis;
+import org.apache.joshua.decoder.segment_file.Sentence;
+
+/**
+ *
+ * @author Zhifei Li zhifei.work@gmail.com
+ * @author Matt Post post@cs.jhu.edu
+ */
+public final class WordPenalty extends StatelessFF {
+
+ private float OMEGA = -(float) Math.log10(Math.E); // -0.435
++ private final boolean isCky;
+
+ public WordPenalty(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
+ super(weights, "WordPenalty", args, config);
+
+ if (parsedArgs.containsKey("value"))
+ OMEGA = Float.parseFloat(parsedArgs.get("value"));
++
++ isCky = config.search_algorithm.equals("cky");
+ }
+
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
+
+ if (rule != null) {
+ // TODO: this is an inefficient way to do this. Find a better way to not apply this rule
+ // to start and stop glue rules when phrase-based decoding.
- if (config.search_algorithm.equals("cky")
- || (rule != Hypothesis.BEGIN_RULE && rule != Hypothesis.END_RULE))
- // acc.add(name, OMEGA * (rule.getEnglish().length - rule.getArity()));
++ if (isCky || (rule != Hypothesis.BEGIN_RULE && rule != Hypothesis.END_RULE)) {
+ acc.add(denseFeatureIndex, OMEGA * (rule.getEnglish().length - rule.getArity()));
++ }
+ }
+
+ return null;
+ }
+
+ @Override
+ public ArrayList<String> reportDenseFeatures(int index) {
+ denseFeatureIndex = index;
- ArrayList<String> names = new ArrayList<String>();
++ ArrayList<String> names = new ArrayList<>(1);
+ names.add(name);
+ return names;
+ }
+
+ @Override
+ public float estimateCost(Rule rule, Sentence sentence) {
+ if (rule != null)
+ return weights.getDense(denseFeatureIndex) * OMEGA * (rule.getEnglish().length - rule.getArity());
+ return 0.0f;
+ }
+
+ @Override
+ public double estimateLogP(Rule rule, int sentID) {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+
+ @Override
+ public double getWeight() {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/5c0d5388/src/test/java/org/apache/joshua/decoder/ff/lm/berkeley_lm/LMGrammarBerkeleyTest.java
----------------------------------------------------------------------
diff --cc src/test/java/org/apache/joshua/decoder/ff/lm/berkeley_lm/LMGrammarBerkeleyTest.java
index df73136,0000000..00a6a36
mode 100644,000000..100644
--- a/src/test/java/org/apache/joshua/decoder/ff/lm/berkeley_lm/LMGrammarBerkeleyTest.java
+++ b/src/test/java/org/apache/joshua/decoder/ff/lm/berkeley_lm/LMGrammarBerkeleyTest.java
@@@ -1,79 -1,0 +1,79 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm.berkeley_lm;
+
+import static org.junit.Assert.assertEquals;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.Translation;
+import org.apache.joshua.decoder.segment_file.Sentence;
+
+/**
+ * Replacement for test/lm/berkeley/test.sh regression test
+ */
+@RunWith(value = Parameterized.class)
+public class LMGrammarBerkeleyTest {
+
+ private static final String INPUT = "the chat-rooms";
+ private static final String[] OPTIONS = "-v 0 -output-format %f".split(" ");
+
+ private JoshuaConfiguration joshuaConfig;
+ private Decoder decoder;
+
+ @Parameters
+ public static List<String> lmFiles() {
+ return Arrays.asList("resources/berkeley_lm/lm",
+ "resources/berkeley_lm/lm.gz",
+ "resources/berkeley_lm/lm.berkeleylm",
+ "resources/berkeley_lm/lm.berkeleylm.gz");
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ decoder.cleanUp();
+ }
+
+ //TODO @Parameters
+ public String lmFile;
+
+ @Test
+ public void verifyLM() {
+ joshuaConfig = new JoshuaConfiguration();
+ joshuaConfig.processCommandLineOptions(OPTIONS);
- joshuaConfig.features.add("feature_function = LanguageModel -lm_type berkeleylm -lm_order 2 -lm_file " + lmFile);
++ joshuaConfig.features.add("LanguageModel -lm_type berkeleylm -lm_order 2 -lm_file " + lmFile);
+ decoder = new Decoder(joshuaConfig, null);
+ String translation = decode(INPUT).toString();
+ assertEquals(lmFile, "tm_glue_0=2.000 lm_0=-7.153\n", translation);
+ }
+
+ private Translation decode(String input) {
+ final Sentence sentence = new Sentence(input, 0, joshuaConfig);
+ return decoder.decode(sentence);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/5c0d5388/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
----------------------------------------------------------------------
diff --cc src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
index f006363,0000000..c760586
mode 100644,000000..100644
--- a/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
+++ b/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
@@@ -1,164 -1,0 +1,164 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+ package org.apache.joshua.system;
+
+import static org.junit.Assert.assertTrue;
+
+import java.io.BufferedReader;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.MetaDataException;
+import org.apache.joshua.decoder.io.TranslationRequestStream;
+import org.apache.joshua.decoder.segment_file.Sentence;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Integration test for multithreaded Joshua decoder tests. Grammar used is a
+ * toy packed grammar.
+ *
+ * @author kellens
+ */
+public class MultithreadedTranslationTests {
+
+ private JoshuaConfiguration joshuaConfig = null;
+ private Decoder decoder = null;
+ private static final String INPUT = "A K B1 U Z1 Z2 B2 C";
+ private int previousLogLevel;
+ private final static long NANO_SECONDS_PER_SECOND = 1_000_000_000;
+
+ @Before
+ public void setUp() throws Exception {
+ joshuaConfig = new JoshuaConfiguration();
+ joshuaConfig.search_algorithm = "cky";
+ joshuaConfig.mark_oovs = false;
+ joshuaConfig.pop_limit = 100;
+ joshuaConfig.use_unique_nbest = false;
+ joshuaConfig.include_align_index = false;
+ joshuaConfig.topN = 0;
+ joshuaConfig.tms.add("thrax -owner pt -maxspan 20 -path resources/wa_grammar.packed");
+ joshuaConfig.tms.add("thrax -owner glue -maxspan -1 -path resources/grammar.glue");
+ joshuaConfig.goal_symbol = "[GOAL]";
+ joshuaConfig.default_non_terminal = "[X]";
- joshuaConfig.features.add("feature_function = OOVPenalty");
++ joshuaConfig.features.add("OOVPenalty");
+ joshuaConfig.weights.add("tm_pt_0 1");
+ joshuaConfig.weights.add("tm_pt_1 1");
+ joshuaConfig.weights.add("tm_pt_2 1");
+ joshuaConfig.weights.add("tm_pt_3 1");
+ joshuaConfig.weights.add("tm_pt_4 1");
+ joshuaConfig.weights.add("tm_pt_5 1");
+ joshuaConfig.weights.add("tm_glue_0 1");
+ joshuaConfig.weights.add("OOVPenalty 2");
+ joshuaConfig.num_parallel_decoders = 500; // This will enable 500 parallel
+ // decoders to run at once.
+ // Useful to help flush out
+ // concurrency errors in
+ // underlying
+ // data-structures.
+ this.decoder = new Decoder(joshuaConfig, ""); // Second argument
+ // (configFile)
+ // is not even used by the
+ // constructor/initialize.
+
+ previousLogLevel = Decoder.VERBOSE;
+ Decoder.VERBOSE = 0;
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ this.decoder.cleanUp();
+ this.decoder = null;
+ Decoder.VERBOSE = previousLogLevel;
+ }
+
+
+
+ // This test was created specifically to reproduce a multithreaded issue
+ // related to mapped byte array access in the PackedGrammer getAlignmentArray
+ // function.
+
+ // We'll test the decoding engine using N = 10,000 identical inputs. This
+ // should be sufficient to induce concurrent data access for many shared
+ // data structures.
+
+ @Test
+ public void givenPackedGrammar_whenNTranslationsCalledConcurrently_thenReturnNResults() {
+ // GIVEN
+
+ int inputLines = 10000;
+ joshuaConfig.use_structured_output = true; // Enabled alignments.
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < inputLines; i++) {
+ sb.append(INPUT + "\n");
+ }
+
+ // Append a large string together to simulate N requests to the decoding
+ // engine.
+ TranslationRequestStream req = new TranslationRequestStream(
+ new BufferedReader(new InputStreamReader(new ByteArrayInputStream(sb.toString()
+ .getBytes(Charset.forName("UTF-8"))))), joshuaConfig);
+
+ ByteArrayOutputStream output = new ByteArrayOutputStream();
+
+ // WHEN
+ // Translate all spans in parallel.
+ try {
+ this.decoder.decodeAll(req, output);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ ArrayList<Sentence> translationResults = new ArrayList<Sentence>();
+
+
+ final long translationStartTime = System.nanoTime();
+ Sentence t;
+ try {
+ while ((t = req.next()) != null) {
+ translationResults.add(t);
+ }
+ } catch (MetaDataException e) {
+ e.printStackTrace();
+ } finally {
+ if (output != null) {
+ try {
+ output.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ }
+
+ final long translationEndTime = System.nanoTime();
+ final double pipelineLoadDurationInSeconds = (translationEndTime - translationStartTime) / ((double)NANO_SECONDS_PER_SECOND);
+ System.err.println(String.format("%.2f seconds", pipelineLoadDurationInSeconds));
+
+ // THEN
+ assertTrue(translationResults.size() == inputLines);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/5c0d5388/src/test/java/org/apache/joshua/system/StructuredTranslationTest.java
----------------------------------------------------------------------
diff --cc src/test/java/org/apache/joshua/system/StructuredTranslationTest.java
index a78a4a1,0000000..69412e2
mode 100644,000000..100644
--- a/src/test/java/org/apache/joshua/system/StructuredTranslationTest.java
+++ b/src/test/java/org/apache/joshua/system/StructuredTranslationTest.java
@@@ -1,272 -1,0 +1,272 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.system;
+
+import static java.util.Arrays.asList;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.StructuredTranslation;
+import org.apache.joshua.decoder.Translation;
+import org.apache.joshua.decoder.segment_file.Sentence;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Integration test for the complete Joshua decoder using a toy grammar that translates
+ * a bunch of capital letters to lowercase letters. Rules in the test grammar
+ * drop and generate additional words and simulate reordering of rules, so that
+ * proper extraction of word alignments and other information from the decoder
+ * can be tested.
+ *
+ * @author fhieber
+ */
+public class StructuredTranslationTest {
+
+ private JoshuaConfiguration joshuaConfig = null; // re-created by setUp() before each test
+ private Decoder decoder = null; // built in setUp(), cleaned up in tearDown()
+ private static final String INPUT = "A K B1 U Z1 Z2 B2 C"; // source sentence most tests pass to decode()
+ private static final String EXPECTED_TRANSLATION = "a b n1 u z c1 k1 k2 k3 n1 n2 n3 c2"; // expected best translation of INPUT
+ private static final List<String> EXPECTED_TRANSLATED_TOKENS = asList(EXPECTED_TRANSLATION.split("\\s+")); // EXPECTED_TRANSLATION split on whitespace
+ private static final String EXPECTED_WORD_ALIGNMENT_STRING = "0-0 2-1 6-1 3-3 4-4 5-4 7-5 1-6 1-7 1-8 7-12"; // expected "%a" output; pairs read as sourceIndex-targetIndex when compared with EXPECTED_WORD_ALIGNMENT
+ private static final List<List<Integer>> EXPECTED_WORD_ALIGNMENT = asList( // 13 entries, one per expected target token; each entry appears to list the source indices aligned to that token (empty = unaligned)
+ asList(0), asList(2, 6), asList(), asList(3),
+ asList(4, 5), asList(7), asList(1),
+ asList(1), asList(1), asList(), asList(),
+ asList(), asList(7));
+ private static final double EXPECTED_SCORE = -17.0; // compared against getTranslationScore() with a 0.00001 tolerance
+ private static final Map<String,Float> EXPECTED_FEATURES = new HashMap<>(); // filled by the static initializer below; compared against getTranslationFeatures()
+ private static final int EXPECTED_NBEST_LIST_SIZE = 8; // expected number of structured translations in the k-best test (topN = 100)
+ static {
+ EXPECTED_FEATURES.put("tm_glue_0", 1.0f);
+ EXPECTED_FEATURES.put("tm_pt_0", -3.0f);
+ EXPECTED_FEATURES.put("tm_pt_1", -3.0f);
+ EXPECTED_FEATURES.put("tm_pt_2", -3.0f);
+ EXPECTED_FEATURES.put("tm_pt_3", -3.0f);
+ EXPECTED_FEATURES.put("tm_pt_4", -3.0f);
+ EXPECTED_FEATURES.put("tm_pt_5", -3.0f);
+ EXPECTED_FEATURES.put("OOV", 7.0f); // NOTE(review): key is "OOV" while the k-best test reads "OOVPenalty" - confirm both feature names exist
+ }
+
+ @Before // JUnit 4: runs before every @Test method, so each test starts from this configuration
+ public void setUp() throws Exception {
+ joshuaConfig = new JoshuaConfiguration();
+ joshuaConfig.search_algorithm = "cky";
+ joshuaConfig.mark_oovs = false;
+ joshuaConfig.pop_limit = 100;
+ joshuaConfig.use_unique_nbest = false;
+ joshuaConfig.include_align_index = false;
+ joshuaConfig.topN = 0; // default; several tests overwrite topN before decoding
+ joshuaConfig.tms.add("thrax -owner pt -maxspan 20 -path resources/wa_grammar"); // translation grammar registered under owner "pt"
+ joshuaConfig.tms.add("thrax -owner glue -maxspan -1 -path resources/grammar.glue"); // glue grammar registered under owner "glue"
+ joshuaConfig.goal_symbol = "[GOAL]";
+ joshuaConfig.default_non_terminal = "[X]";
- joshuaConfig.features.add("feature_function = OOVPenalty");
++ joshuaConfig.features.add("OOVPenalty");
+ joshuaConfig.weights.add("tm_pt_0 1"); // this and the following lines give weight 1 to every tm_pt_*, tm_glue_0 and OOVPenalty entry
+ joshuaConfig.weights.add("tm_pt_1 1");
+ joshuaConfig.weights.add("tm_pt_2 1");
+ joshuaConfig.weights.add("tm_pt_3 1");
+ joshuaConfig.weights.add("tm_pt_4 1");
+ joshuaConfig.weights.add("tm_pt_5 1");
+ joshuaConfig.weights.add("tm_glue_0 1");
+ joshuaConfig.weights.add("OOVPenalty 1");
+ decoder = new Decoder(joshuaConfig, ""); // second argument (configFile
+ // is not even used by the
+ // constructor/initialize)
+ }
+
+ /**
+ * Releases the decoder created by {@link #setUp()}.
+ *
+ * JUnit 4 still runs {@code @After} methods when a {@code @Before} method
+ * throws, in which case {@code decoder} may never have been assigned. The
+ * null guard keeps that situation from adding a secondary
+ * NullPointerException on top of the real setup failure.
+ *
+ * @throws Exception if the decoder fails while cleaning up
+ */
+ @After
+ public void tearDown() throws Exception {
+ if (decoder != null) {
+ decoder.cleanUp();
+ decoder = null;
+ }
+ }
+
+ /**
+ * Decodes a single input string with the decoder built in setUp().
+ * The input is wrapped in a {@link Sentence} created with 0 as its second
+ * constructor argument (presumably the sentence id) and the current test
+ * configuration.
+ *
+ * @param input raw source text to translate
+ * @return whatever the decoder returns for that sentence
+ */
+ private Translation decode(String input) {
+ return decoder.decode(new Sentence(input, 0, joshuaConfig));
+ }
+
+ /**
+ * With structured output disabled and the "%s | %a " output format, the
+ * trimmed decoder output must be the expected translation followed by the
+ * expected word-alignment string.
+ */
+ @Test
+ public void givenInput_whenRegularOutputFormat_thenExpectedOutput() {
+ // GIVEN
+ joshuaConfig.outputFormat = "%s | %a ";
+ joshuaConfig.use_structured_output = false;
+ final String expectedOutput = EXPECTED_TRANSLATION + " | " + EXPECTED_WORD_ALIGNMENT_STRING;
+
+ // WHEN
+ final Translation decoded = decode(INPUT);
+ final String actualOutput = decoded.toString().trim();
+
+ // THEN
+ assertEquals(expectedOutput, actualOutput);
+ }
+
+ /**
+ * With structured output disabled, topN = 1 and the
+ * "%s | %e | %a | %c" output format, the trimmed decoder output must list
+ * the expected translation, the input, the expected alignment string and
+ * the expected score formatted with three decimals.
+ */
+ @Test
+ public void givenInput_whenRegularOutputFormatWithTopN1_thenExpectedOutput() {
+ // GIVEN
+ joshuaConfig.topN = 1;
+ joshuaConfig.outputFormat = "%s | %e | %a | %c";
+ joshuaConfig.use_structured_output = false;
+ final String formattedScore = String.format(" | %.3f", EXPECTED_SCORE);
+ final String expectedOutput =
+ EXPECTED_TRANSLATION + " | " + INPUT + " | " + EXPECTED_WORD_ALIGNMENT_STRING + formattedScore;
+
+ // WHEN
+ final Translation decoded = decode(INPUT);
+ final String actualOutput = decoded.toString().trim();
+
+ // THEN
+ assertEquals(expectedOutput, actualOutput);
+ }
+
+ /**
+ * With structured output enabled and topN = 0, decoding INPUT must yield
+ * exactly one structured translation whose string, tokens, score, word
+ * alignment and features match the expected fixtures.
+ */
+ @Test
+ public void givenInput_whenStructuredOutputFormatWithTopN0_thenExpectedOutput() {
+ // GIVEN
+ joshuaConfig.use_structured_output = true;
+ joshuaConfig.topN = 0;
+
+ // WHEN
+ final Translation translation = decode(INPUT);
+ final List<StructuredTranslation> structuredTranslations = translation.getStructuredTranslations();
+
+ // THEN
+ // Assert the size before indexing: an unexpected size then fails with the
+ // actual count instead of an IndexOutOfBoundsException or a bare "false".
+ assertEquals(1, structuredTranslations.size());
+ final StructuredTranslation structuredTranslation = structuredTranslations.get(0);
+ final String translationString = structuredTranslation.getTranslationString();
+ final List<String> translatedTokens = structuredTranslation.getTranslationTokens();
+ final float translationScore = structuredTranslation.getTranslationScore();
+ final List<List<Integer>> wordAlignment = structuredTranslation.getTranslationWordAlignments();
+ final Map<String,Float> translationFeatures = structuredTranslation.getTranslationFeatures();
+
+ assertEquals(EXPECTED_TRANSLATION, translationString);
+ assertEquals(EXPECTED_TRANSLATED_TOKENS, translatedTokens);
+ assertEquals(EXPECTED_SCORE, translationScore, 0.00001);
+ assertEquals(EXPECTED_WORD_ALIGNMENT, wordAlignment);
+ // The alignment list must have one entry per translated token.
+ assertEquals(wordAlignment.size(), translatedTokens.size());
+ assertEquals(EXPECTED_FEATURES.entrySet(), translationFeatures.entrySet());
+ }
+
+ /**
+ * With structured output enabled and topN = 1, decoding INPUT must yield
+ * exactly one structured translation whose string, tokens, score, word
+ * alignment and features match the expected fixtures.
+ */
+ @Test
+ public void givenInput_whenStructuredOutputFormatWithTopN1_thenExpectedOutput() {
+ // GIVEN
+ joshuaConfig.use_structured_output = true;
+ joshuaConfig.topN = 1;
+
+ // WHEN
+ final Translation translation = decode(INPUT);
+ final List<StructuredTranslation> structuredTranslations = translation.getStructuredTranslations();
+
+ // THEN
+ // Assert the size before indexing: an unexpected size then fails with the
+ // actual count instead of an IndexOutOfBoundsException or a bare "false".
+ assertEquals(1, structuredTranslations.size());
+ final StructuredTranslation structuredTranslation = structuredTranslations.get(0);
+ final String translationString = structuredTranslation.getTranslationString();
+ final List<String> translatedTokens = structuredTranslation.getTranslationTokens();
+ final float translationScore = structuredTranslation.getTranslationScore();
+ final List<List<Integer>> wordAlignment = structuredTranslation.getTranslationWordAlignments();
+ final Map<String,Float> translationFeatures = structuredTranslation.getTranslationFeatures();
+
+ assertEquals(EXPECTED_TRANSLATION, translationString);
+ assertEquals(EXPECTED_TRANSLATED_TOKENS, translatedTokens);
+ assertEquals(EXPECTED_SCORE, translationScore, 0.00001);
+ assertEquals(EXPECTED_WORD_ALIGNMENT, wordAlignment);
+ // The alignment list must have one entry per translated token.
+ assertEquals(wordAlignment.size(), translatedTokens.size());
+ assertEquals(EXPECTED_FEATURES.entrySet(), translationFeatures.entrySet());
+ }
+
+ /**
+ * With structured output enabled and topN = 100, decoding INPUT must yield
+ * EXPECTED_NBEST_LIST_SIZE structured translations. The first one must
+ * match the expected fixtures; the last one must be the untranslated input
+ * with an "OOVPenalty" feature value of -800.
+ */
+ @Test
+ public void givenInput_whenStructuredOutputFormatWithKBest_thenExpectedOutput() {
+ // GIVEN
+ joshuaConfig.use_structured_output = true;
+ joshuaConfig.topN = 100;
+
+ // WHEN
+ final Translation translation = decode(INPUT);
+ final List<StructuredTranslation> structuredTranslations = translation.getStructuredTranslations();
+
+ // THEN
+ // JUnit's contract is assertEquals(expected, actual); the size is checked
+ // before any element is accessed so a wrong count is reported as such.
+ assertEquals(EXPECTED_NBEST_LIST_SIZE, structuredTranslations.size());
+ assertTrue(structuredTranslations.size() > 1);
+ final StructuredTranslation viterbiTranslation = structuredTranslations.get(0);
+ final StructuredTranslation lastKBest = structuredTranslations.get(structuredTranslations.size() - 1);
+
+ assertEquals(EXPECTED_TRANSLATION, viterbiTranslation.getTranslationString());
+ assertEquals(EXPECTED_TRANSLATED_TOKENS, viterbiTranslation.getTranslationTokens());
+ assertEquals(EXPECTED_SCORE, viterbiTranslation.getTranslationScore(), 0.00001);
+ assertEquals(EXPECTED_WORD_ALIGNMENT, viterbiTranslation.getTranslationWordAlignments());
+ assertEquals(EXPECTED_FEATURES.entrySet(), viterbiTranslation.getTranslationFeatures().entrySet());
+ // last entry in KBEST is all input words untranslated, should have 8 OOVs.
+ assertEquals(INPUT, lastKBest.getTranslationString());
+ // Check presence first: unboxing a missing (null) Float would throw a
+ // NullPointerException instead of producing a readable assertion failure.
+ final Float oovPenalty = lastKBest.getTranslationFeatures().get("OOVPenalty");
+ assertTrue("last k-best entry has no OOVPenalty feature", oovPenalty != null);
+ assertEquals(-800.0, oovPenalty, 0.0001);
+ }
+
+ /**
+ * Decoding the empty string with structured output enabled must give a
+ * first structured translation with an empty string, no tokens, a score of
+ * zero and no word alignments.
+ */
+ @Test
+ public void givenEmptyInput_whenStructuredOutputFormat_thenEmptyOutput() {
+ // GIVEN
+ joshuaConfig.use_structured_output = true;
+
+ // WHEN
+ final StructuredTranslation best = decode("").getStructuredTranslations().get(0);
+ final String text = best.getTranslationString();
+ final List<String> tokens = best.getTranslationTokens();
+ final float score = best.getTranslationScore();
+ final List<List<Integer>> alignments = best.getTranslationWordAlignments();
+
+ // THEN
+ assertEquals("", text);
+ assertTrue(tokens.isEmpty());
+ assertEquals(0, score, 0.00001);
+ assertTrue(alignments.isEmpty());
+ }
+
+ /**
+ * Decoding a single word with structured output enabled must return that
+ * word unchanged as the translation, with a score of -99 and an alignment
+ * entry pointing at source index 0.
+ */
+ @Test
+ public void givenOOVInput_whenStructuredOutputFormat_thenOOVOutput() {
+ // GIVEN
+ final String input = "gabarbl";
+ joshuaConfig.use_structured_output = true;
+
+ // WHEN
+ final StructuredTranslation best = decode(input).getStructuredTranslations().get(0);
+ final String text = best.getTranslationString();
+ final List<String> tokens = best.getTranslationTokens();
+ final float score = best.getTranslationScore();
+ final List<List<Integer>> alignments = best.getTranslationWordAlignments();
+
+ // THEN
+ assertEquals(input, text);
+ assertTrue(tokens.contains(input));
+ assertEquals(-99.0, score, 0.00001);
+ assertTrue(alignments.contains(asList(0)));
+ }
+
+ /**
+ * Decoding the empty string with structured output disabled must render
+ * as a single newline.
+ */
+ @Test
+ public void givenEmptyInput_whenRegularOutputFormat_thenNewlineOutput() {
+ // GIVEN
+ joshuaConfig.use_structured_output = false;
+
+ // WHEN
+ final String rendered = decode("").toString();
+
+ // THEN
+ assertEquals("\n", rendered);
+ }
+
+}