Posted to commits@joshua.apache.org by mj...@apache.org on 2016/06/23 18:45:40 UTC
[29/60] [partial] incubator-joshua git commit: maven multi-module
layout 1st commit: moving files into joshua-core
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/mira/Optimizer.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/mira/Optimizer.java b/joshua-core/src/main/java/org/apache/joshua/mira/Optimizer.java
new file mode 100755
index 0000000..6eaced4
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/mira/Optimizer.java
@@ -0,0 +1,643 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.mira;
+
+import java.util.Collections;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.Vector;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.metrics.EvaluationMetric;
+
+// This class implements the MIRA (Margin Infused Relaxed Algorithm) tuning procedure.
+public class Optimizer {
+ public Optimizer(Vector<String> _output, boolean[] _isOptimizable, double[] _initialLambda,
+ HashMap<String, String>[] _feat_hash, HashMap<String, String>[] _stats_hash) {
+ output = _output; // collects per-iteration progress messages
+ isOptimizable = _isOptimizable;
+ initialLambda = _initialLambda; // initial weights array
+ paramDim = initialLambda.length - 1;
+ feat_hash = _feat_hash; // feature hash table
+ stats_hash = _stats_hash; // suff. stats hash table
+ finalLambda = new double[initialLambda.length];
+ for (int i = 0; i < finalLambda.length; i++)
+ finalLambda[i] = initialLambda[i];
+ }
+
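+ /*
+ * Minimal usage sketch (illustrative only; variable names here are placeholders): the static
+ * configuration fields below (sentNum, miraIter, batchSize, evalMetric, C, ...) must be set by
+ * the MIRA driver before use, then:
+ *
+ * Optimizer opt = new Optimizer(output, isOptimizable, lambda, featHash, statsHash);
+ * double[] tunedLambda = opt.runOptimizer();
+ */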
+ // run MIRA for one epoch
+ public double[] runOptimizer() {
+ List<Integer> sents = new ArrayList<Integer>();
+ for (int i = 0; i < sentNum; ++i)
+ sents.add(i);
+ double[] avgLambda = new double[initialLambda.length]; // only needed if averaging is required
+ for (int i = 0; i < initialLambda.length; i++)
+ avgLambda[i] = 0.0;
+ double[] bestLambda = new double[initialLambda.length]; // only needed if averaging is required
+ for (int i = 0; i < initialLambda.length; i++)
+ bestLambda[i] = 0.0;
+ double bestMetricScore = evalMetric.getToBeMinimized() ? PosInf : NegInf;
+ int bestIter = 0;
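+ // Each iteration is one epoch: optionally shuffle the sentences, sweep over them in
+ // mini-batches while updating finalLambda, then score the corpus and remember the best
+ // weight vector seen so far.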
+ for (int iter = 0; iter < miraIter; ++iter) {
+ System.arraycopy(finalLambda, 1, initialLambda, 1, paramDim);
+ if (needShuffle)
+ Collections.shuffle(sents);
+
+ double oraMetric, oraScore, predMetric, predScore;
+ double[] oraPredScore = new double[4];
+ double eta = 1.0; // learning rate, will not be changed if run percep
+ double avgEta = 0; // average eta, just for analysis
+ double loss = 0;
+ double diff = 0;
+ double featNorm = 0;
+ double sumMetricScore = 0;
+ double sumModelScore = 0;
+ String oraFeat = "";
+ String predFeat = "";
+ String[] oraPredFeat = new String[2];
+ String[] vecOraFeat;
+ String[] vecPredFeat;
+ String[] featInfo;
+ int thisBatchSize = 0;
+ int numBatch = 0;
+ int numUpdate = 0;
+ Integer diffFeatId;
+
+ // update weights
+ Integer s;
+ int sentCount = 0;
+ while( sentCount < sentNum ) {
+ loss = 0;
+ thisBatchSize = batchSize;
+ ++numBatch;
+ HashMap<Integer, Double> featDiff = new HashMap<Integer, Double>();
+ for(int b = 0; b < batchSize; ++b ) {
+ // find the oracle and the prediction for sentence s
+ s = sents.get(sentCount);
+ findOraPred(s, oraPredScore, oraPredFeat, finalLambda, featScale);
+
+ // the model scores here are already scaled in findOraPred
+ oraMetric = oraPredScore[0];
+ oraScore = oraPredScore[1];
+ predMetric = oraPredScore[2];
+ predScore = oraPredScore[3];
+ oraFeat = oraPredFeat[0];
+ predFeat = oraPredFeat[1];
+
+ // update the scale
+ if (needScale) { // otherwise featscale remains 1.0
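+ // keep the cumulative |model score| within scoreRatio of the cumulative |metric score|,
+ // so the two terms of the loss stay on comparable scales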
+ sumMetricScore += java.lang.Math.abs(oraMetric + predMetric);
+ // restore the original model score
+ sumModelScore += java.lang.Math.abs(oraScore + predScore) / featScale;
+
+ if (sumModelScore / sumMetricScore > scoreRatio)
+ featScale = sumMetricScore / sumModelScore;
+ }
+
+ vecOraFeat = oraFeat.split("\\s+");
+ vecPredFeat = predFeat.split("\\s+");
+
+ // accumulate the difference feature vector (oracle minus prediction)
+ if ( b == 0 ) {
+ for (int i = 0; i < vecOraFeat.length; i++) {
+ featInfo = vecOraFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ featDiff.put(diffFeatId, Double.parseDouble(featInfo[1]));
+ }
+ for (int i = 0; i < vecPredFeat.length; i++) {
+ featInfo = vecPredFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId)-Double.parseDouble(featInfo[1]);
+ if ( Math.abs(diff) > 1e-20 )
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ }
+ else //features only firing in the 2nd feature vector
+ featDiff.put(diffFeatId, -1.0*Double.parseDouble(featInfo[1]));
+ }
+ } else {
+ for (int i = 0; i < vecOraFeat.length; i++) {
+ featInfo = vecOraFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId)+Double.parseDouble(featInfo[1]);
+ if ( Math.abs(diff) > 1e-20 )
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ }
+ else //features only firing in the new oracle feature vector
+ featDiff.put(diffFeatId, Double.parseDouble(featInfo[1]));
+ }
+ for (int i = 0; i < vecPredFeat.length; i++) {
+ featInfo = vecPredFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId)-Double.parseDouble(featInfo[1]);
+ if ( Math.abs(diff) > 1e-20 )
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ }
+ else //features only firing in the new prediction feature vector
+ featDiff.put(diffFeatId, -1.0*Double.parseDouble(featInfo[1]));
+ }
+ }
+ if (!runPercep) { // otherwise eta=1.0
+ // remember the model scores here are already scaled
+ double singleLoss = evalMetric.getToBeMinimized() ?
+ (predMetric - oraMetric) - (oraScore - predScore) / featScale
+ : (oraMetric - predMetric) - (oraScore - predScore) / featScale;
+ loss += singleLoss;
+ }
+ ++sentCount;
+ if( sentCount >= sentNum ) {
+ thisBatchSize = b + 1;
+ break;
+ }
+ } //for(int b = 0; b < batchSize; ++b)
+
+ if (!runPercep) { // otherwise eta=1.0
+ featNorm = 0;
+ for (double d : featDiff.values()) {
+ featNorm += d * d / (thisBatchSize * thisBatchSize);
+ }
+ }
+ if (!runPercep) { // otherwise eta stays 1.0
+ if (loss <= 0)
+ eta = 0;
+ else {
+ loss /= thisBatchSize;
+ // feat vector not scaled before
+ eta = C < loss / featNorm ? C : loss / featNorm;
+ }
+ }
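+ // eta is the MIRA step size min(C, loss / ||featDiff / thisBatchSize||^2): the smallest
+ // update that satisfies the margin constraint, capped by the relaxation coefficient C.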
+ avgEta += eta;
+ if (Math.abs(eta) > 1e-20) {
+ for (Integer featId : featDiff.keySet()) {
+ finalLambda[featId] += eta * featDiff.get(featId) / thisBatchSize;
+ }
+ }
+ if (needAvg) {
+ for (int i = 0; i < avgLambda.length; ++i)
+ avgLambda[i] += finalLambda[i];
+ }
+ } //while( sentCount < sentNum )
+
+ avgEta /= numBatch;
+
+ /*
+ * for( int i=0; i<finalLambda.length; i++ ) System.out.print(finalLambda[i]+" ");
+ * System.out.println(); System.exit(0);
+ */
+
+ double initMetricScore;
+ if(iter == 0 ) {
+ initMetricScore = computeCorpusMetricScore(initialLambda);
+ if(needAvg)
+ finalMetricScore = computeCorpusMetricScore(avgLambda);
+ else
+ finalMetricScore = computeCorpusMetricScore(finalLambda);
+ } else {
+ initMetricScore = finalMetricScore;
+ if(needAvg)
+ finalMetricScore = computeCorpusMetricScore(avgLambda);
+ else
+ finalMetricScore = computeCorpusMetricScore(finalLambda);
+ }
+
+ if(evalMetric.getToBeMinimized()) {
+ if( finalMetricScore < bestMetricScore ) {
+ bestMetricScore = finalMetricScore;
+ bestIter = iter;
+ for( int i = 0; i < finalLambda.length; ++i )
+ bestLambda[i] = needAvg ? avgLambda[i] : finalLambda[i];
+ }
+ } else {
+ if( finalMetricScore > bestMetricScore ) {
+ bestMetricScore = finalMetricScore;
+ bestIter = iter;
+ for( int i = 0; i < finalLambda.length; ++i )
+ bestLambda[i] = needAvg ? avgLambda[i] : finalLambda[i];
+ }
+ }
+
+ if ( iter == miraIter - 1 ) {
+ for (int i = 0; i < finalLambda.length; ++i)
+ finalLambda[i] =
+ needAvg ? bestLambda[i] / ( numBatch * ( bestIter + 1 ) ) : bestLambda[i];
+ }
+
+ // prepare the printing info
+ String result = "Iter " + iter + ": Avg learning rate=" + String.format("%.4f", avgEta);
+ result += " Initial " + evalMetric.get_metricName() + "="
+ + String.format("%.4f", initMetricScore) + " Final " + evalMetric.get_metricName() + "="
+ + String.format("%.4f", finalMetricScore);
+ output.add(result);
+ } // for ( int iter = 0; iter < miraIter; ++iter )
+ String result = "Best " + evalMetric.get_metricName() + "="
+ + String.format("%.4f", bestMetricScore)
+ + " (iter = " + bestIter + ")\n";
+ output.add(result);
+ finalMetricScore = bestMetricScore;
+
+ // non-optimizable weights should remain unchanged
+ ArrayList<Double> cpFixWt = new ArrayList<Double>();
+ for (int i = 1; i < isOptimizable.length; ++i) {
+ if (!isOptimizable[i])
+ cpFixWt.add(finalLambda[i]);
+ }
+ normalizeLambda(finalLambda);
+ int countNonOpt = 0;
+ for (int i = 1; i < isOptimizable.length; ++i) {
+ if (!isOptimizable[i]) {
+ finalLambda[i] = cpFixWt.get(countNonOpt);
+ ++countNonOpt;
+ }
+ }
+ return finalLambda;
+ }
+
+ public double computeCorpusMetricScore(double[] finalLambda) {
+ int suffStatsCount = evalMetric.get_suffStatsCount();
+ double modelScore;
+ double maxModelScore;
+ Set<String> candSet;
+ String candStr;
+ String[] feat_str;
+ String[] tmpStatsVal = new String[suffStatsCount];
+ int[] corpusStatsVal = new int[suffStatsCount];
+ for (int i = 0; i < suffStatsCount; i++)
+ corpusStatsVal[i] = 0;
+
+ for (int i = 0; i < sentNum; i++) {
+ candSet = feat_hash[i].keySet();
+ // find out the 1-best candidate for each sentence
+ // this depends on the training mode
+ maxModelScore = NegInf;
+ for (Iterator<String> it = candSet.iterator(); it.hasNext();) {
+ modelScore = 0.0;
+ candStr = it.next();
+ feat_str = feat_hash[i].get(candStr).split("\\s+");
+ String[] feat_info;
+ for (int f = 0; f < feat_str.length; f++) {
+ feat_info = feat_str[f].split("=");
+ modelScore += Double.parseDouble(feat_info[1]) * finalLambda[Vocabulary.id(feat_info[0])];
+ }
+ if (maxModelScore < modelScore) {
+ maxModelScore = modelScore;
+ tmpStatsVal = stats_hash[i].get(candStr).split("\\s+"); // save the suff stats
+ }
+ }
+
+ for (int j = 0; j < suffStatsCount; j++)
+ corpusStatsVal[j] += Integer.parseInt(tmpStatsVal[j]); // accumulate corpus-level suff stats
+ } // for( int i=0; i<sentNum; i++ )
+
+ return evalMetric.score(corpusStatsVal);
+ }
+
+ private void findOraPred(int sentId, double[] oraPredScore, String[] oraPredFeat,
+ double[] lambda, double featScale) {
+ double oraMetric = 0, oraScore = 0, predMetric = 0, predScore = 0;
+ String oraFeat = "", predFeat = "";
+ double candMetric = 0, candScore = 0; // metric and model scores for each cand
+ Set<String> candSet = stats_hash[sentId].keySet();
+ String cand = "";
+ String feats = "";
+ String oraCand = ""; // only used when BLEU/TER-BLEU is used as metric
+ String[] featStr;
+ String[] featInfo;
+
+ int actualFeatId;
+ double bestOraScore;
+ double worstPredScore;
+
+ if (oraSelectMode == 1)
+ bestOraScore = NegInf; // larger score will be selected
+ else {
+ if (evalMetric.getToBeMinimized())
+ bestOraScore = PosInf; // smaller score will be selected
+ else
+ bestOraScore = NegInf;
+ }
+
+ if (predSelectMode == 1 || predSelectMode == 2)
+ worstPredScore = NegInf; // larger score will be selected
+ else {
+ if (evalMetric.getToBeMinimized())
+ worstPredScore = NegInf; // larger score will be selected
+ else
+ worstPredScore = PosInf;
+ }
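+ // Candidate selection below: oraSelectMode==1 picks the "hope" candidate (max model score
+ // plus metric gain); any other mode picks the best-metric candidate. predSelectMode==1
+ // picks the "fear" candidate (max model score plus metric loss), predSelectMode==2 the
+ // max-model-score candidate, and any other mode the worst-metric candidate.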
+
+ for (Iterator<String> it = candSet.iterator(); it.hasNext();) {
+ cand = it.next();
+ candMetric = computeSentMetric(sentId, cand); // compute metric score
+
+ // start to compute model score
+ candScore = 0;
+ featStr = feat_hash[sentId].get(cand).split("\\s+");
+ feats = "";
+
+ for (int i = 0; i < featStr.length; i++) {
+ featInfo = featStr[i].split("=");
+ actualFeatId = Vocabulary.id(featInfo[0]);
+ candScore += Double.parseDouble(featInfo[1]) * lambda[actualFeatId];
+ if ((actualFeatId < isOptimizable.length && isOptimizable[actualFeatId])
+ || actualFeatId >= isOptimizable.length)
+ feats += actualFeatId + "=" + Double.parseDouble(featInfo[1]) + " ";
+ }
+
+ candScore *= featScale; // scale the model score
+
+ // is this cand oracle?
+ if (oraSelectMode == 1) {// "hope", b=1, r=1
+ if (evalMetric.getToBeMinimized()) {// if the smaller the metric score, the better
+ if (bestOraScore <= (candScore - candMetric)) {
+ bestOraScore = candScore - candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ } else {
+ if (bestOraScore <= (candScore + candMetric)) {
+ bestOraScore = candScore + candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ } else {// best metric score(ex: max BLEU), b=1, r=0
+ if (evalMetric.getToBeMinimized()) {// if the smaller the metric score, the better
+ if (bestOraScore >= candMetric) {
+ bestOraScore = candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ } else {
+ if (bestOraScore <= candMetric) {
+ bestOraScore = candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ }
+
+ // is this cand prediction?
+ if (predSelectMode == 1) {// "fear"
+ if (evalMetric.getToBeMinimized()) {// if the smaller the metric score, the better
+ if (worstPredScore <= (candScore + candMetric)) {
+ worstPredScore = candScore + candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {
+ if (worstPredScore <= (candScore - candMetric)) {
+ worstPredScore = candScore - candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ } else if (predSelectMode == 2) {// model prediction(max model score)
+ if (worstPredScore <= candScore) {
+ worstPredScore = candScore;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {// worst metric score(ex: min BLEU)
+ if (evalMetric.getToBeMinimized()) {// if the smaller the metric score, the better
+ if (worstPredScore <= candMetric) {
+ worstPredScore = candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {
+ if (worstPredScore >= candMetric) {
+ worstPredScore = candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ }
+ }
+
+ oraPredScore[0] = oraMetric;
+ oraPredScore[1] = oraScore;
+ oraPredScore[2] = predMetric;
+ oraPredScore[3] = predScore;
+ oraPredFeat[0] = oraFeat;
+ oraPredFeat[1] = predFeat;
+
+ // update the BLEU metric statistics if pseudo corpus is used to compute BLEU/TER-BLEU
+ if (evalMetric.get_metricName().equals("BLEU") && usePseudoBleu) {
+ String statString;
+ String[] statVal_str;
+ statString = stats_hash[sentId].get(oraCand);
+ statVal_str = statString.split("\\s+");
+
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ bleuHistory[sentId][j] = R * bleuHistory[sentId][j] + Integer.parseInt(statVal_str[j]);
+ }
+
+ if (evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu) {
+ String statString;
+ String[] statVal_str;
+ statString = stats_hash[sentId].get(oraCand);
+ statVal_str = statString.split("\\s+");
+
+ for (int j = 0; j < evalMetric.get_suffStatsCount() - 2; j++)
+ bleuHistory[sentId][j] = R * bleuHistory[sentId][j] + Integer.parseInt(statVal_str[j + 2]); // the first 2 stats are TER stats
+ }
+ }
+
+ // compute *sentence-level* metric score for cand
+ private double computeSentMetric(int sentId, String cand) {
+ String statString;
+ String[] statVal_str;
+ int[] statVal = new int[evalMetric.get_suffStatsCount()];
+
+ statString = stats_hash[sentId].get(cand);
+ statVal_str = statString.split("\\s+");
+
+ if (evalMetric.get_metricName().equals("BLEU") && usePseudoBleu) {
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ statVal[j] = (int) (Integer.parseInt(statVal_str[j]) + bleuHistory[sentId][j]);
+ } else if (evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu) {
+ for (int j = 0; j < evalMetric.get_suffStatsCount() - 2; j++)
+ statVal[j + 2] = (int) (Integer.parseInt(statVal_str[j + 2]) + bleuHistory[sentId][j]); // only modify the BLEU stats part (TER has 2 stats)
+ } else { // in all other situations, use normal stats
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ statVal[j] = Integer.parseInt(statVal_str[j]);
+ }
+
+ return evalMetric.score(statVal);
+ }
+
+ // from ZMERT
+ private void normalizeLambda(double[] origLambda) {
+ // private String[] normalizationOptions;
+ // How should a lambda[] vector be normalized (before decoding)?
+ // nO[0] = 0: no normalization
+ // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+ // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+ // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+ // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
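+ //
+ // Example: normalizationOptions = {2, 1.0} selects method 2 and rescales the weights so
+ // that the largest absolute weight becomes 1.0.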
+
+ int normalizationMethod = (int) normalizationOptions[0];
+ double scalingFactor = 1.0;
+ if (normalizationMethod == 0) {
+ scalingFactor = 1.0;
+ } else if (normalizationMethod == 1) {
+ int c = (int) normalizationOptions[2];
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[c]);
+ } else if (normalizationMethod == 2) {
+ double maxAbsVal = -1;
+ int maxAbsVal_c = 0;
+ for (int c = 1; c <= paramDim; ++c) {
+ if (Math.abs(origLambda[c]) > maxAbsVal) {
+ maxAbsVal = Math.abs(origLambda[c]);
+ maxAbsVal_c = c;
+ }
+ }
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[maxAbsVal_c]);
+
+ } else if (normalizationMethod == 3) {
+ double minAbsVal = PosInf;
+ int minAbsVal_c = 0;
+
+ for (int c = 1; c <= paramDim; ++c) {
+ if (Math.abs(origLambda[c]) < minAbsVal) {
+ minAbsVal = Math.abs(origLambda[c]);
+ minAbsVal_c = c;
+ }
+ }
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[minAbsVal_c]);
+
+ } else if (normalizationMethod == 4) {
+ double pow = normalizationOptions[1];
+ double norm = L_norm(origLambda, pow);
+ scalingFactor = normalizationOptions[2] / norm;
+ }
+
+ for (int c = 1; c <= paramDim; ++c) {
+ origLambda[c] *= scalingFactor;
+ }
+ }
+
+ // from ZMERT
+ private double L_norm(double[] A, double pow) {
+ // calculates the L-pow norm of A[]
+ // NOTE: this calculation ignores A[0]
+ double sum = 0.0;
+ for (int i = 1; i < A.length; ++i)
+ sum += Math.pow(Math.abs(A[i]), pow);
+
+ return Math.pow(sum, 1 / pow);
+ }
+
+ public static double getScale() {
+ return featScale;
+ }
+
+ public static void initBleuHistory(int sentNum, int statCount) {
+ bleuHistory = new double[sentNum][statCount];
+ for (int i = 0; i < sentNum; i++) {
+ for (int j = 0; j < statCount; j++) {
+ bleuHistory[i][j] = 0.0;
+ }
+ }
+ }
+
+ public double getMetricScore() {
+ return finalMetricScore;
+ }
+
+ private Vector<String> output;
+ private double[] initialLambda;
+ private double[] finalLambda;
+ private double finalMetricScore;
+ private HashMap<String, String>[] feat_hash;
+ private HashMap<String, String>[] stats_hash;
+ private int paramDim;
+ private boolean[] isOptimizable;
+ public static int sentNum;
+ public static int miraIter; // MIRA internal iterations
+ public static int oraSelectMode;
+ public static int predSelectMode;
+ public static int batchSize;
+ public static boolean needShuffle;
+ public static boolean needScale;
+ public static double scoreRatio;
+ public static boolean runPercep;
+ public static boolean needAvg;
+ public static boolean usePseudoBleu;
+ public static double featScale = 1.0; // scale the features so the model score is comparable with the metric score; updated each epoch if necessary
+ public static double C; // relaxation coefficient
+ public static double R; // corpus decay (used only when a pseudo corpus is used to compute BLEU)
+ public static EvaluationMetric evalMetric;
+ public static double[] normalizationOptions;
+ public static double[][] bleuHistory;
+
+ private final static double NegInf = (-1.0 / 0.0);
+ private final static double PosInf = (+1.0 / 0.0);
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/oracle/OracleExtractionHG.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/oracle/OracleExtractionHG.java b/joshua-core/src/main/java/org/apache/joshua/oracle/OracleExtractionHG.java
new file mode 100644
index 0000000..575515a
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/oracle/OracleExtractionHG.java
@@ -0,0 +1,797 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.oracle;
+
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiString;
+import static org.apache.joshua.util.FormatUtils.removeSentenceMarkers;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.Support;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.hypergraph.HyperEdge;
+import org.apache.joshua.decoder.hypergraph.HyperGraph;
+import org.apache.joshua.decoder.hypergraph.KBestExtractor;
+import org.apache.joshua.util.FileUtility;
+import org.apache.joshua.util.io.LineReader;
+import org.apache.joshua.util.FormatUtils;
+
+/**
+ * Approximated BLEU: (1) the clipping effect is not considered; (2) the dynamic programming does
+ * not maintain different states for different hypothesis lengths; (3) the brevity penalty is
+ * calculated from the average reference length; (4) sentence-level BLEU is used instead of
+ * document-level BLEU.
+ *
+ * @author Zhifei Li, zhifei.work@gmail.com (Johns Hopkins University)
+ */
+public class OracleExtractionHG extends SplitHg {
+ static String BACKOFF_LEFT_LM_STATE_SYM = "<lzfbo>";
+ public int BACKOFF_LEFT_LM_STATE_SYM_ID; // used for equivalent state
+
+ static String NULL_LEFT_LM_STATE_SYM = "<lzflnull>";
+ public int NULL_LEFT_LM_STATE_SYM_ID; // used for equivalent state
+
+ static String NULL_RIGHT_LM_STATE_SYM = "<lzfrnull>";
+ public int NULL_RIGHT_LM_STATE_SYM_ID; // used for equivalent state
+
+ // int[] ref_sentence;//reference string (not tree)
+ protected int src_sent_len = 0;
+ protected int ref_sent_len = 0;
+ protected int g_lm_order = 4; // only used to decide whether this class computes the LM state itself in compute_state
+ static protected boolean do_local_ngram_clip = false;
+ static protected boolean maitain_length_state = false;
+ static protected int g_bleu_order = 4;
+
+ static boolean using_left_equiv_state = true;
+ static boolean using_right_equiv_state = true;
+
+ // TODO Add generics to hash tables in this class
+ HashMap<String, Boolean> tbl_suffix = new HashMap<String, Boolean>();
+ HashMap<String, Boolean> tbl_prefix = new HashMap<String, Boolean>();
+ static PrefixGrammar grammar_prefix = new PrefixGrammar();// TODO
+ static PrefixGrammar grammar_suffix = new PrefixGrammar();// TODO
+
+ // key: item; value: best_deduction, best_bleu, best_len, # of n-gram match where n is in [1,4]
+ protected HashMap<String, Integer> tbl_ref_ngrams = new HashMap<String, Integer>();
+
+ static boolean always_maintain_seperate_lm_state = true; // if true, the virtual item maintains its own LM state regardless of whether lm_order >= g_bleu_order
+
+ int lm_feat_id = 0; // the baseline LM feature id
+
+ /**
+ * Constructs a new object capable of extracting a tree from a hypergraph that most closely
+ * matches a provided oracle sentence.
+ * <p>
+ * It seems that the symbol table here should only need to represent monolingual terminals, plus
+ * nonterminals.
+ *
+ * @param lm_feat_id_ a language model feature identifier
+ */
+ public OracleExtractionHG(int lm_feat_id_) {
+ this.lm_feat_id = lm_feat_id_;
+ this.BACKOFF_LEFT_LM_STATE_SYM_ID = Vocabulary.id(BACKOFF_LEFT_LM_STATE_SYM);
+ this.NULL_LEFT_LM_STATE_SYM_ID = Vocabulary.id(NULL_LEFT_LM_STATE_SYM);
+ this.NULL_RIGHT_LM_STATE_SYM_ID = Vocabulary.id(NULL_RIGHT_LM_STATE_SYM);
+ }
+
+ /*
+ * for 919 sent, time_on_reading: 148797 time_on_orc_extract: 580286
+ */
+ @SuppressWarnings({ "unused" })
+ public static void main(String[] args) throws IOException {
+ JoshuaConfiguration joshuaConfiguration = new JoshuaConfiguration();
+ /*
+ * String f_hypergraphs="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.items"; String
+ * f_rule_tbl="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.rules"; String
+ * f_ref_files="C:\\Users\\zli\\Documents\\mt03.ref.txt.1"; String f_orc_out
+ * ="C:\\Users\\zli\\Documents\\mt03.orc.txt";
+ */
+ if (6 != args.length) {
+ System.out
+ .println("Usage: java Decoder f_hypergraphs f_rule_tbl f_ref_files f_orc_out lm_order orc_extract_nbest");
+ System.out.println("num of args is " + args.length);
+ for (int i = 0; i < args.length; i++) {
+ System.out.println("arg is: " + args[i]);
+ }
+ System.exit(1);
+ }
+ // String f_hypergraphs = args[0].trim();
+ // String f_rule_tbl = args[1].trim();
+ String f_ref_files = args[2].trim();
+ String f_orc_out = args[3].trim();
+ int lm_order = Integer.parseInt(args[4].trim());
+ boolean orc_extract_nbest = Boolean.valueOf(args[5].trim()); // oracle extraction from nbest or hg
+
+ int baseline_lm_feat_id = 0;
+
+ KBestExtractor kbest_extractor = null;
+ int topN = 300;// TODO
+ joshuaConfiguration.use_unique_nbest = true;
+ joshuaConfiguration.include_align_index = false;
+ boolean do_ngram_clip_nbest = true; // TODO
+ if (orc_extract_nbest) {
+ System.out.println("oracle extraction from nbest list");
+
+ kbest_extractor = new KBestExtractor(null, null, Decoder.weights, false, joshuaConfiguration);
+ }
+
+ BufferedWriter orc_out = FileUtility.getWriteFileStream(f_orc_out);
+
+ long start_time0 = System.currentTimeMillis();
+ long time_on_reading = 0;
+ long time_on_orc_extract = 0;
+ // DiskHyperGraph dhg_read = new DiskHyperGraph(baseline_lm_feat_id, true, null);
+
+ // dhg_read.initRead(f_hypergraphs, f_rule_tbl, null);
+
+ OracleExtractionHG orc_extractor = new OracleExtractionHG(baseline_lm_feat_id);
+ long start_time = System.currentTimeMillis();
+ int sent_id = 0;
+ for (String ref_sent: new LineReader(f_ref_files)) {
+ System.out.println("############Process sentence " + sent_id);
+ start_time = System.currentTimeMillis();
+ sent_id++;
+ // if(sent_id>10)break;
+
+ // HyperGraph hg = dhg_read.readHyperGraph();
+ HyperGraph hg = null; // placeholder: hypergraph reading is disabled above, so every sentence is skipped
+ if (hg == null)
+ continue;
+
+ // System.out.println("read disk hyp: " + (System.currentTimeMillis()-start_time));
+ time_on_reading += System.currentTimeMillis() - start_time;
+ start_time = System.currentTimeMillis();
+
+ String orc_sent = null;
+ double orc_bleu = 0;
+ if (orc_extract_nbest) {
+ Object[] res = orc_extractor.oracle_extract_nbest(kbest_extractor, hg, topN,
+ do_ngram_clip_nbest, ref_sent);
+ orc_sent = (String) res[0];
+ orc_bleu = (Double) res[1];
+ } else {
+ HyperGraph hg_oracle = orc_extractor.oracle_extract_hg(hg, hg.sentLen(), lm_order, ref_sent);
+ orc_sent = removeSentenceMarkers(getViterbiString(hg_oracle));
+ orc_bleu = orc_extractor.get_best_goal_cost(hg, orc_extractor.g_tbl_split_virtual_items);
+
+ time_on_orc_extract += System.currentTimeMillis() - start_time;
+ System.out.println("num_virtual_items: " + orc_extractor.g_num_virtual_items
+ + " num_virtual_dts: " + orc_extractor.g_num_virtual_deductions);
+ // System.out.println("oracle extract: " + (System.currentTimeMillis()-start_time));
+ }
+
+ orc_out.write(orc_sent + "\n");
+ System.out.println("orc bleu is " + orc_bleu);
+ }
+ orc_out.close();
+
+ System.out.println("time_on_reading: " + time_on_reading);
+ System.out.println("time_on_orc_extract: " + time_on_orc_extract);
+ System.out.println("total running time: " + (System.currentTimeMillis() - start_time0));
+ }
+
+ // find the oracle hypothesis in the nbest list
+ public Object[] oracle_extract_nbest(KBestExtractor kbest_extractor, HyperGraph hg, int n,
+ boolean do_ngram_clip, String ref_sent) {
+ if (hg.goalNode == null)
+ return null;
+ kbest_extractor.resetState();
+ int next_n = 0;
+ double orc_bleu = -1;
+ String orc_sent = null;
+ while (true) {
+ String hyp_sent = kbest_extractor.getKthHyp(hg.goalNode, ++next_n); // next-best hypothesis string
+ if (hyp_sent == null || next_n > n)
+ break;
+ double t_bleu = compute_sentence_bleu(ref_sent, hyp_sent, do_ngram_clip, 4);
+ if (t_bleu > orc_bleu) {
+ orc_bleu = t_bleu;
+ orc_sent = hyp_sent;
+ }
+ }
+ System.out.println("Oracle sent: " + orc_sent);
+ System.out.println("Oracle bleu: " + orc_bleu);
+ Object[] res = new Object[2];
+ res[0] = orc_sent;
+ res[1] = orc_bleu;
+ return res;
+ }
+
+ public HyperGraph oracle_extract_hg(HyperGraph hg, int src_sent_len_in, int lm_order,
+ String ref_sent_str) {
+ int[] ref_sent = Vocabulary.addAll(ref_sent_str);
+ g_lm_order = lm_order;
+ src_sent_len = src_sent_len_in;
+ ref_sent_len = ref_sent.length;
+
+ tbl_ref_ngrams.clear();
+ get_ngrams(tbl_ref_ngrams, g_bleu_order, ref_sent, false);
+ if (using_left_equiv_state || using_right_equiv_state) {
+ tbl_prefix.clear();
+ tbl_suffix.clear();
+ setup_prefix_suffix_tbl(ref_sent, g_bleu_order, tbl_prefix, tbl_suffix);
+ setup_prefix_suffix_grammar(ref_sent, g_bleu_order, grammar_prefix, grammar_suffix);// TODO
+ }
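+ // split the hypergraph so each virtual item carries BLEU dynamic-programming state, then
+ // read the best tree (ranked by negated BLEU cost) back out of the split hypergraph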
+ split_hg(hg);
+
+ // System.out.println("best bleu is " + get_best_goal_cost( hg, g_tbl_split_virtual_items));
+ return get_1best_tree_hg(hg, g_tbl_split_virtual_items);
+ }
+
+ /*
+ * This procedure (1) identifies all possible matches and (2) adds a new deduction for each match.
+ */
+ protected void process_one_combination_axiom(HGNode parent_item,
+ HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt) {
+ if (null == cur_dt.getRule()) {
+ throw new RuntimeException("error null rule in axiom");
+ }
+ double avg_ref_len = (parent_item.j - parent_item.i >= src_sent_len) ? ref_sent_len
+ : (parent_item.j - parent_item.i) * ref_sent_len * 1.0 / src_sent_len; // reference length scaled by the span's share of the source
+ double[] bleu_score = new double[1];
+ DPStateOracle dps = compute_state(parent_item, cur_dt, null, tbl_ref_ngrams,
+ do_local_ngram_clip, g_lm_order, avg_ref_len, bleu_score, tbl_suffix, tbl_prefix);
+ VirtualDeduction t_dt = new VirtualDeduction(cur_dt, null, -bleu_score[0]);// cost: -best_bleu
+ g_num_virtual_deductions++;
+ add_deduction(parent_item, virtual_item_sigs, t_dt, dps, true);
+ }
+
+ /*
+ * This procedure (1) creates a new deduction (based on cur_dt and ant_virtual_item) and (2)
+ * checks whether an existing item can contain this deduction (based on virtual_item_sigs, a
+ * hashmap specific to a parent_item): (2.1) if yes, it adds the deduction; (2.2) otherwise it
+ * creates a new item and adds that item to virtual_item_sigs.
+ */
+ protected void process_one_combination_nonaxiom(HGNode parent_item,
+ HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt,
+ ArrayList<VirtualItem> l_ant_virtual_item) {
+ if (null == l_ant_virtual_item) {
+ throw new RuntimeException("wrong call in process_one_combination_nonaxiom");
+ }
+ double avg_ref_len = (parent_item.j - parent_item.i >= src_sent_len) ? ref_sent_len
+ : (parent_item.j - parent_item.i) * ref_sent_len * 1.0 / src_sent_len; // reference length scaled by the span's share of the source
+ double[] bleu_score = new double[1];
+ DPStateOracle dps = compute_state(parent_item, cur_dt, l_ant_virtual_item, tbl_ref_ngrams,
+ do_local_ngram_clip, g_lm_order, avg_ref_len, bleu_score, tbl_suffix, tbl_prefix);
+ VirtualDeduction t_dt = new VirtualDeduction(cur_dt, l_ant_virtual_item, -bleu_score[0]); // cost: -best_bleu
+ g_num_virtual_deductions++;
+ add_deduction(parent_item, virtual_item_sigs, t_dt, dps, true);
+ }
+
+ // DPState maintains all the state information at an item that is required during dynamic
+ // programming
+ protected static class DPStateOracle extends DPState {
+ int best_len; // this may not be used in the signature
+ int[] ngram_matches;
+ int[] left_lm_state;
+ int[] right_lm_state;
+
+ public DPStateOracle(int blen, int[] matches, int[] left, int[] right) {
+ best_len = blen;
+ ngram_matches = matches;
+ left_lm_state = left;
+ right_lm_state = right;
+ }
+
+ protected String get_signature() {
+ StringBuffer res = new StringBuffer();
+ if (maitain_length_state) {
+ res.append(best_len);
+ res.append(' ');
+ }
+ if (null != left_lm_state) { // the goal item has null state
+ for (int i = 0; i < left_lm_state.length; i++) {
+ res.append(left_lm_state[i]);
+ res.append(' ');
+ }
+ }
+ res.append("lzf ");
+
+ if (null != right_lm_state) { // the goal item has null state
+ for (int i = 0; i < right_lm_state.length; i++) {
+ res.append(right_lm_state[i]);
+ res.append(' ');
+ }
+ }
+ // if(left_lm_state==null || right_lm_state==null)System.out.println("sig is: " +
+ // res.toString());
+ return res.toString();
+ }
+
+ protected void print() {
+ StringBuffer res = new StringBuffer();
+ res.append("DPstate: best_len: ");
+ res.append(best_len);
+ for (int i = 0; i < ngram_matches.length; i++) {
+ res.append("; ngram: ");
+ res.append(ngram_matches[i]);
+ }
+ System.out.println(res.toString());
+ }
+ }
+
+ // ########################## common functions #####################
+ // based on tbl_oracle_states, tbl_ref_ngrams, and dt, get the state
+ // get the new state: STATE_BEST_DEDUCT STATE_BEST_BLEU STATE_BEST_LEN NGRAM_MATCH_COUNTS
+ protected DPStateOracle compute_state(HGNode parent_item, HyperEdge dt,
+ ArrayList<VirtualItem> l_ant_virtual_item, HashMap<String, Integer> tbl_ref_ngrams,
+ boolean do_local_ngram_clip, int lm_order, double ref_len, double[] bleu_score,
+ HashMap<String, Boolean> tbl_suffix, HashMap<String, Boolean> tbl_prefix) {
+ // ##### deductions under the "goal item" do not have a rule
+ if (null == dt.getRule()) {
+ if (l_ant_virtual_item.size() != 1) {
+ throw new RuntimeException("deduction under goal item has more than one antecedent item");
+ }
+ bleu_score[0] = -l_ant_virtual_item.get(0).best_virtual_deduction.best_cost;
+ return new DPStateOracle(0, null, null, null); // no DPState at all
+ }
+
+ // ################## deductions *not* under "goal item"
+ HashMap<String, Integer> new_ngram_counts = new HashMap<String, Integer>(); // new ngrams created by the combination
+ HashMap<String, Integer> old_ngram_counts = new HashMap<String, Integer>(); // ngrams already counted in the antecedents
+ int total_hyp_len = 0;
+ int[] num_ngram_match = new int[g_bleu_order];
+ int[] en_words = dt.getRule().getEnglish();
+
+ // #### calculate new and old ngram counts, and the hypothesis length
+
+ ArrayList<Integer> words = new ArrayList<Integer>();
+
+ // used for compute left- and right- lm state
+ ArrayList<Integer> left_state_sequence = null;
+ // used for compute left- and right- lm state
+ ArrayList<Integer> right_state_sequence = null;
+
+ int correct_lm_order = lm_order;
+ if (always_maintain_seperate_lm_state || lm_order < g_bleu_order) {
+ left_state_sequence = new ArrayList<Integer>();
+ right_state_sequence = new ArrayList<Integer>();
+ correct_lm_order = g_bleu_order; // if lm_order is smaller than g_bleu_order, we compute the LM state ourselves
+ }
+
+ // #### get left_state_sequence, right_state_sequence, total_hyp_len, num_ngram_match
+ for (int c = 0; c < en_words.length; c++) {
+ int c_id = en_words[c];
+ if (FormatUtils.isNonterminal(c_id)) {
+ int index = -(c_id + 1);
+ DPStateOracle ant_state = (DPStateOracle) l_ant_virtual_item.get(index).dp_state;
+ total_hyp_len += ant_state.best_len;
+ for (int t = 0; t < g_bleu_order; t++) {
+ num_ngram_match[t] += ant_state.ngram_matches[t];
+ }
+ int[] l_context = ant_state.left_lm_state;
+ int[] r_context = ant_state.right_lm_state;
+ for (int t : l_context) { // always have l_context
+ words.add(t);
+ if (null != left_state_sequence && left_state_sequence.size() < g_bleu_order - 1) {
+ left_state_sequence.add(t);
+ }
+ }
+ get_ngrams(old_ngram_counts, g_bleu_order, l_context, true);
+ if (r_context.length >= correct_lm_order - 1) { // the right and left are NOT overlapping
+ get_ngrams(new_ngram_counts, g_bleu_order, words, true);
+ get_ngrams(old_ngram_counts, g_bleu_order, r_context, true);
+ words.clear();// start a new chunk
+ if (null != right_state_sequence) {
+ right_state_sequence.clear();
+ }
+ for (int t : r_context) {
+ words.add(t);
+ }
+ }
+ if (null != right_state_sequence) {
+ for (int t : r_context) {
+ right_state_sequence.add(t);
+ }
+ }
+ } else {
+ words.add(c_id);
+ total_hyp_len += 1;
+ if (null != left_state_sequence && left_state_sequence.size() < g_bleu_order - 1) {
+ left_state_sequence.add(c_id);
+ }
+ if (null != right_state_sequence) {
+ right_state_sequence.add(c_id);
+ }
+ }
+ }
+ get_ngrams(new_ngram_counts, g_bleu_order, words, true);
+
+ // #### now deduct the ngram counts
+ for (String ngram : new_ngram_counts.keySet()) {
+ if (tbl_ref_ngrams.containsKey(ngram)) {
+ int final_count = (Integer) new_ngram_counts.get(ngram);
+ if (old_ngram_counts.containsKey(ngram)) {
+ final_count -= (Integer) old_ngram_counts.get(ngram);
+ if (final_count < 0) {
+ throw new RuntimeException("negative count for ngram: " + ngram
+ + "; new: " + new_ngram_counts.get(ngram) + "; old: " + old_ngram_counts.get(ngram));
+ }
+ }
+ if (final_count > 0) { // TODO: not correct/global ngram clip
+ if (do_local_ngram_clip) {
+ // BUG: use joshua.util.Regex.spaces.split(...)
+ num_ngram_match[ngram.split("\\s+").length - 1] += Support.findMin(final_count,
+ (Integer) tbl_ref_ngrams.get(ngram));
+ } else {
+ // BUG: use joshua.util.Regex.spaces.split(...)
+ num_ngram_match[ngram.split("\\s+").length - 1] += final_count; // do not do any clipping
+ }
+ }
+ }
+ }
+
+ // ####now calculate the BLEU score and state
+ int[] left_lm_state = null;
+ int[] right_lm_state = null;
+ left_lm_state = get_left_equiv_state(left_state_sequence, tbl_suffix);
+ right_lm_state = get_right_equiv_state(right_state_sequence, tbl_prefix);
+
+ // debug
+ // System.out.println("lm_order is " + lm_order);
+ // compare_two_int_arrays(left_lm_state,
+ // (int[])parent_item.tbl_states.get(Symbol.LM_L_STATE_SYM_ID));
+ // compare_two_int_arrays(right_lm_state,
+ // (int[])parent_item.tbl_states.get(Symbol.LM_R_STATE_SYM_ID));
+ // end
+
+ bleu_score[0] = compute_bleu(total_hyp_len, ref_len, num_ngram_match, g_bleu_order);
+ // System.out.println("blue score is " + bleu_score[0]);
+ return new DPStateOracle(total_hyp_len, num_ngram_match, left_lm_state, right_lm_state);
+ }
+
+ private int[] get_left_equiv_state(ArrayList<Integer> left_state_sequence,
+ HashMap<String, Boolean> tbl_suffix) {
+ int l_size = (left_state_sequence.size() < g_bleu_order - 1) ? left_state_sequence.size()
+ : (g_bleu_order - 1);
+ int[] left_lm_state = new int[l_size];
+ if (!using_left_equiv_state || l_size < g_bleu_order - 1) { // regular
+ for (int i = 0; i < l_size; i++) {
+ left_lm_state[i] = left_state_sequence.get(i);
+ }
+ } else {
+ for (int i = l_size - 1; i >= 0; i--) { // right to left
+ if (is_a_suffix_in_tbl(left_state_sequence, 0, i, tbl_suffix)) {
+ // if(is_a_suffix_in_grammar(left_state_sequence, 0, i, grammar_suffix)){
+ for (int j = i; j >= 0; j--) {
+ left_lm_state[j] = left_state_sequence.get(j);
+ }
+ break;
+ } else {
+ left_lm_state[i] = this.NULL_LEFT_LM_STATE_SYM_ID;
+ }
+ }
+ // System.out.println("origi left:" + Symbol.get_string(left_state_sequence) + "; equiv left:"
+ // + Symbol.get_string(left_lm_state));
+ }
+ return left_lm_state;
+ }
+
+ private boolean is_a_suffix_in_tbl(ArrayList<Integer> left_state_sequence, int start_pos,
+ int end_pos, HashMap<String, Boolean> tbl_suffix) {
+ if ((Integer) left_state_sequence.get(end_pos) == this.NULL_LEFT_LM_STATE_SYM_ID) {
+ return false;
+ }
+ StringBuffer suffix = new StringBuffer();
+ for (int i = end_pos; i >= start_pos; i--) { // right-most first
+ suffix.append(left_state_sequence.get(i));
+ if (i > start_pos)
+ suffix.append(' ');
+ }
+ return tbl_suffix.containsKey(suffix.toString());
+ }
+
+ private int[] get_right_equiv_state(ArrayList<Integer> right_state_sequence,
+ HashMap<String, Boolean> tbl_prefix) {
+ int r_size = (right_state_sequence.size() < g_bleu_order - 1) ? right_state_sequence.size()
+ : (g_bleu_order - 1);
+ int[] right_lm_state = new int[r_size];
+ if (!using_right_equiv_state || r_size < g_bleu_order - 1) { // regular
+ for (int i = 0; i < r_size; i++) {
+ right_lm_state[i] = (Integer) right_state_sequence.get(right_state_sequence.size() - r_size
+ + i);
+ }
+ } else {
+ for (int i = 0; i < r_size; i++) { // left to right
+ if (is_a_prefix_in_tbl(right_state_sequence, right_state_sequence.size() - r_size + i,
+ right_state_sequence.size() - 1, tbl_prefix)) {
+ // if(is_a_prefix_in_grammar(right_state_sequence, right_state_sequence.size()-r_size+i,
+ // right_state_sequence.size()-1, grammar_prefix)){
+ for (int j = i; j < r_size; j++) {
+ right_lm_state[j] = (Integer) right_state_sequence.get(right_state_sequence.size()
+ - r_size + j);
+ }
+ break;
+ } else {
+ right_lm_state[i] = this.NULL_RIGHT_LM_STATE_SYM_ID;
+ }
+ }
+ // System.out.println("origi right:" + Symbol.get_string(right_state_sequence)+
+ // "; equiv right:" + Symbol.get_string(right_lm_state));
+ }
+ return right_lm_state;
+ }
+
+ private boolean is_a_prefix_in_tbl(ArrayList<Integer> right_state_sequence, int start_pos,
+ int end_pos, HashMap<String, Boolean> tbl_prefix) {
+ if (right_state_sequence.get(start_pos) == this.NULL_RIGHT_LM_STATE_SYM_ID) {
+ return false;
+ }
+ StringBuffer prefix = new StringBuffer();
+ for (int i = start_pos; i <= end_pos; i++) {
+ prefix.append(right_state_sequence.get(i));
+ if (i < end_pos)
+ prefix.append(' ');
+ }
+ return tbl_prefix.containsKey(prefix.toString());
+ }
+
+ public static void compare_two_int_arrays(int[] a, int[] b) {
+ if (a.length != b.length) {
+ throw new RuntimeException("the two arrays do not have the same size");
+ }
+ for (int i = 0; i < a.length; i++) {
+ if (a[i] != b[i]) {
+ throw new RuntimeException("elements in the two arrays are not the same");
+ }
+ }
+ }
+
+ // sentence-level BLEU = bp * prec, where prec = exp(sum_t wt * log(prec[t])) and wt = 1/bleu_order
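+ // Worked example: hyp_len=4, ref_len=5, num_ngram_match={4,3,2,1} makes every precision
+ // term log(1)=0, so prec=exp(0)=1 and BLEU = bp = exp(1 - 5/4) ~= 0.779.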
+ public static double compute_bleu(int hyp_len, double ref_len, int[] num_ngram_match,
+ int bleu_order) {
+ if (hyp_len <= 0 || ref_len <= 0) {
+ throw new RuntimeException("ref or hyp is zero len");
+ }
+ double res = 0;
+ double wt = 1.0 / bleu_order;
+ double prec = 0;
+ double smooth_factor = 1.0;
+ for (int t = 0; t < bleu_order && t < hyp_len; t++) {
+ if (num_ngram_match[t] > 0) {
+ prec += wt * Math.log(num_ngram_match[t] * 1.0 / (hyp_len - t));
+ } else {
+ smooth_factor *= 0.5; // smooth zero-count orders by repeated halving (TODO: revisit)
+ prec += wt * Math.log(smooth_factor / (hyp_len - t));
+ }
+ }
+ double bp = (hyp_len >= ref_len) ? 1.0 : Math.exp(1 - ref_len / hyp_len);
+ res = bp * Math.exp(prec);
+ // System.out.println("hyp_len: " + hyp_len + "; ref_len:" + ref_len + "prec: " + Math.exp(prec)
+ // + "; bp: " + bp + "; bleu: " + res);
+ return res;
+ }
+
+ // accumulate ngram counts into tbl
+ public void get_ngrams(HashMap<String, Integer> tbl, int order, int[] wrds,
+ boolean ignore_null_equiv_symbol) {
+ for (int i = 0; i < wrds.length; i++) {
+ for (int j = 0; j < order && j + i < wrds.length; j++) { // ngram: [i,i+j]
+ boolean contain_null = false;
+ StringBuffer ngram = new StringBuffer();
+ for (int k = i; k <= i + j; k++) {
+ if (wrds[k] == this.NULL_LEFT_LM_STATE_SYM_ID
+ || wrds[k] == this.NULL_RIGHT_LM_STATE_SYM_ID) {
+ contain_null = true;
+ if (ignore_null_equiv_symbol)
+ break;
+ }
+ ngram.append(wrds[k]);
+ if (k < i + j)
+ ngram.append(' ');
+ }
+ if (ignore_null_equiv_symbol && contain_null)
+ continue; // skip this ngram
+ String ngram_str = ngram.toString();
+ if (tbl.containsKey(ngram_str)) {
+ tbl.put(ngram_str, (Integer) tbl.get(ngram_str) + 1);
+ } else {
+ tbl.put(ngram_str, 1);
+ }
+ }
+ }
+ }
+
+ /**
+ * accumulate ngram counts into tbl.
+ * @param tbl a {@link java.util.HashMap} which is used to store ngram counts
+ * @param order todo
+ * @param wrds an {@link java.util.ArrayList} containing {@link java.lang.Integer} word representations
+ * @param ignore_null_equiv_symbol set to true to skip some nGrams
+ */
+ public void get_ngrams(HashMap<String, Integer> tbl, int order, ArrayList<Integer> wrds,
+ boolean ignore_null_equiv_symbol) {
+ for (int i = 0; i < wrds.size(); i++) {
+ // ngram: [i,i+j]
+ for (int j = 0; j < order && j + i < wrds.size(); j++) {
+ boolean contain_null = false;
+ StringBuffer ngram = new StringBuffer();
+ for (int k = i; k <= i + j; k++) {
+ int t_wrd = (Integer) wrds.get(k);
+ if (t_wrd == this.NULL_LEFT_LM_STATE_SYM_ID || t_wrd == this.NULL_RIGHT_LM_STATE_SYM_ID) {
+ contain_null = true;
+ if (ignore_null_equiv_symbol)
+ break;
+ }
+ ngram.append(t_wrd);
+ if (k < i + j)
+ ngram.append(' ');
+ }
+ // skip this ngram
+ if (ignore_null_equiv_symbol && contain_null)
+ continue;
+
+ String ngram_str = ngram.toString();
+ if (tbl.containsKey(ngram_str)) {
+ tbl.put(ngram_str, (Integer) tbl.get(ngram_str) + 1);
+ } else {
+ tbl.put(ngram_str, 1);
+ }
+ }
+ }
+ }
+
+ // do_ngram_clip: consider global n-gram clip
+ public double compute_sentence_bleu(String ref_sent, String hyp_sent, boolean do_ngram_clip,
+ int bleu_order) {
+ // BUG: use joshua.util.Regex.spaces.split(...)
+ int[] numeric_ref_sent = Vocabulary.addAll(ref_sent);
+ int[] numeric_hyp_sent = Vocabulary.addAll(hyp_sent);
+ return compute_sentence_bleu(numeric_ref_sent, numeric_hyp_sent, do_ngram_clip, bleu_order);
+ }
+
+ public double compute_sentence_bleu(int[] ref_sent, int[] hyp_sent, boolean do_ngram_clip,
+ int bleu_order) {
+ double res_bleu = 0;
+ int order = 4;
+ HashMap<String, Integer> ref_ngram_tbl = new HashMap<String, Integer>();
+ get_ngrams(ref_ngram_tbl, order, ref_sent, false);
+ HashMap<String, Integer> hyp_ngram_tbl = new HashMap<String, Integer>();
+ get_ngrams(hyp_ngram_tbl, order, hyp_sent, false);
+
+ int[] num_ngram_match = new int[order];
+ for (String ngram : hyp_ngram_tbl.keySet()) {
+ if (ref_ngram_tbl.containsKey(ngram)) {
+ if (do_ngram_clip) {
+ // BUG: use joshua.util.Regex.spaces.split(...)
+ num_ngram_match[ngram.split("\\s+").length - 1] += Support.findMin(
+ (Integer) ref_ngram_tbl.get(ngram), (Integer) hyp_ngram_tbl.get(ngram)); // ngram clip
+ } else {
+ // BUG: use joshua.util.Regex.spaces.split(...)
+ num_ngram_match[ngram.split("\\s+").length - 1] += (Integer) hyp_ngram_tbl.get(ngram); // without ngram count clipping
+ }
+ }
+ }
+ res_bleu = compute_bleu(hyp_sent.length, ref_sent.length, num_ngram_match, bleu_order);
+ // System.out.println("hyp_len: " + hyp_sent.length + "; ref_len:" + ref_sent.length +
+ // "; bleu: " + res_bleu +" num_ngram_matches: " + num_ngram_match[0] + " " +num_ngram_match[1]+
+ // " " + num_ngram_match[2] + " " +num_ngram_match[3]);
+
+ return res_bleu;
+ }
+
+ // #### equivalent lm stuff ############
+ public static void setup_prefix_suffix_tbl(int[] wrds, int order,
+ HashMap<String, Boolean> prefix_tbl, HashMap<String, Boolean> suffix_tbl) {
+ for (int i = 0; i < wrds.length; i++) {
+ for (int j = 0; j < order && j + i < wrds.length; j++) { // ngram: [i,i+j]
+ StringBuffer ngram = new StringBuffer();
+ // ### prefix
+ for (int k = i; k < i + j; k++) { // all ngrams [i,i+j-1]
+ ngram.append(wrds[k]);
+ prefix_tbl.put(ngram.toString(), true);
+ ngram.append(' ');
+ }
+ // ### suffix: right-most wrd first
+ ngram = new StringBuffer();
+ for (int k = i + j; k > i; k--) { // all ngrams [i+1,i+j]: reverse order
+ ngram.append(wrds[k]);
+ suffix_tbl.put(ngram.toString(), true);// stored in reverse order
+ ngram.append(' ');
+ }
+ }
+ }
+ }
+
+ // #### equivalent lm stuff ############
+ public static void setup_prefix_suffix_grammar(int[] wrds, int order, PrefixGrammar prefix_gr,
+ PrefixGrammar suffix_gr) {
+ for (int i = 0; i < wrds.length; i++) {
+ for (int j = 0; j < order && j + i < wrds.length; j++) { // ngram: [i,i+j]
+ // ### prefix
+ prefix_gr.add_ngram(wrds, i, i + j - 1);// ngram: [i,i+j-1]
+
+ // ### suffix: right-most wrd first
+ int[] reverse_wrds = new int[j];
+ for (int k = i + j, t = 0; k > i; k--) { // all ngrams [i+1,i+j]: reverse order
+ reverse_wrds[t++] = wrds[k];
+ }
+ suffix_gr.add_ngram(reverse_wrds, 0, j - 1);
+ }
+ }
+ }
+
+ /*
+ * a backoff node is a hashtable; it may include: (1) probabilities for the next words, (2)
+ * pointers to next-layer backoff nodes (hashtables), (3) the backoff weight for this node, and
+ * (4) a suffix/prefix flag indicating that there are ngrams starting from this suffix
+ */
+ private static class PrefixGrammar {
+
+ private static class PrefixGrammarNode extends HashMap<Integer, PrefixGrammarNode> {
+ private static final long serialVersionUID = 1L;
+ }
+
+ PrefixGrammarNode root = new PrefixGrammarNode();
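+ // The grammar is a trie over word ids: each node maps the next word id to a child node,
+ // so an ngram is present iff the path of its word ids can be walked from the root.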
+
+ // add prefix information
+ public void add_ngram(int[] wrds, int start_pos, int end_pos) {
+ // ######### identify the position, and insert the trie nodes if necessary
+ PrefixGrammarNode pos = root;
+ for (int k = start_pos; k <= end_pos; k++) {
+ int cur_sym_id = wrds[k];
+ PrefixGrammarNode next_layer = pos.get(cur_sym_id);
+
+ if (null != next_layer) {
+ pos = next_layer;
+ } else {
+ // next layer node
+ PrefixGrammarNode tmp = new PrefixGrammarNode();
+ pos.put(cur_sym_id, tmp);
+ pos = tmp;
+ }
+ }
+ }
+
+ @SuppressWarnings("unused")
+ public boolean contain_ngram(ArrayList<Integer> wrds, int start_pos, int end_pos) {
+ if (end_pos < start_pos)
+ return false;
+ PrefixGrammarNode pos = root;
+ for (int k = start_pos; k <= end_pos; k++) {
+ int cur_sym_id = wrds.get(k);
+ PrefixGrammarNode next_layer = pos.get(cur_sym_id);
+ if (next_layer != null) {
+ pos = next_layer;
+ } else {
+ return false;
+ }
+ }
+ return true;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/oracle/OracleExtractor.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/oracle/OracleExtractor.java b/joshua-core/src/main/java/org/apache/joshua/oracle/OracleExtractor.java
new file mode 100644
index 0000000..ef67905
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/oracle/OracleExtractor.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.oracle;
+
+import org.apache.joshua.decoder.hypergraph.HyperGraph;
+
+/**
+ * Convenience wrapper class for oracle extraction code.
+ *
+ * @author Lane Schwartz
+ */
+public class OracleExtractor {
+
+ private final OracleExtractionHG extractor;
+
+ /**
+ * Constructs an object capable of extracting an oracle hypergraph.
+ */
+ public OracleExtractor() {
+
+ int baselineLanguageModelFeatureID = 0;
+ this.extractor = new OracleExtractionHG(baselineLanguageModelFeatureID);
+
+ }
+
+ /**
+ * Extract a hypergraph that represents the translation from the original shared forest hypergraph
+ * that is closest to the reference translation.
+ *
+ * @param forest Original hypergraph representing a shared forest.
+ * @param lmOrder N-gram order of the language model.
+ * @param reference Reference sentence.
+ * @return Hypergraph closest to the reference.
+ */
+ public HyperGraph getOracle(HyperGraph forest, int lmOrder, String reference) {
+ if (reference != null)
+ return extractor.oracle_extract_hg(forest, forest.sentLen(), lmOrder, reference);
+
+ return null;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/oracle/SplitHg.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/oracle/SplitHg.java b/joshua-core/src/main/java/org/apache/joshua/oracle/SplitHg.java
new file mode 100644
index 0000000..9fcdd35
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/oracle/SplitHg.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.oracle;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.hypergraph.HyperEdge;
+import org.apache.joshua.decoder.hypergraph.HyperGraph;
+
+/**
+ * This class implements general ways of splitting a hypergraph based on the coarse-to-fine idea:
+ * the input is a hypergraph and the output is another hypergraph with changed state structures.
+ *
+ * @author Zhifei Li, zhifei.work@gmail.com (Johns Hopkins University)
+ */
+public abstract class SplitHg {
+
+ HashMap<HGNode, ArrayList<VirtualItem>> g_tbl_split_virtual_items = new HashMap<HGNode, ArrayList<VirtualItem>>();
+
+ // number of items or deductions after splitting the hypergraph
+ public int g_num_virtual_items = 0;
+ public int g_num_virtual_deductions = 0;
+
+ // Note: the implementation of the following two functions should call add_deduction
+ protected abstract void process_one_combination_axiom(HGNode parent_item,
+ HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt);
+
+ protected abstract void process_one_combination_nonaxiom(HGNode parent_item,
+ HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt,
+ ArrayList<VirtualItem> l_ant_virtual_item);
+
+ // #### all of the following functions should be called after running split_hg(), and before
+ // clearing g_tbl_split_virtual_items
+ public double get_best_goal_cost(HyperGraph hg,
+ HashMap<HGNode, ArrayList<VirtualItem>> g_tbl_split_virtual_items) {
+ double res = get_virtual_goal_item(hg, g_tbl_split_virtual_items).best_virtual_deduction.best_cost;
+ // System.out.println("best bleu is " +res);
+ return res;
+ }
+
+ public VirtualItem get_virtual_goal_item(HyperGraph original_hg,
+ HashMap<HGNode, ArrayList<VirtualItem>> g_tbl_split_virtual_items) {
+ ArrayList<VirtualItem> l_virtual_items = g_tbl_split_virtual_items.get(original_hg.goalNode);
+
+ if (l_virtual_items.size() != 1) {
+ // TODO: log this properly, fail properly
+ throw new RuntimeException("number of virtual goal items is not equal to one");
+ }
+ return l_virtual_items.get(0);
+ }
+
+ // Get the 1-best tree hypergraph. The 1-best derivation is ranked by the split hypergraph,
+ // but the returned hypergraph is in the form of the original hypergraph.
+ public HyperGraph get_1best_tree_hg(HyperGraph original_hg,
+ HashMap<HGNode, ArrayList<VirtualItem>> g_tbl_split_virtual_items) {
+ VirtualItem virtual_goal_item = get_virtual_goal_item(original_hg, g_tbl_split_virtual_items);
+ HGNode onebest_goal_item = clone_item_with_best_deduction(virtual_goal_item);
+ HyperGraph res = new HyperGraph(onebest_goal_item, -1, -1, null);
+ // TODO: number of items/deductions
+ get_1best_tree_item(virtual_goal_item, onebest_goal_item);
+ return res;
+ }
+
+ private void get_1best_tree_item(VirtualItem virtual_it, HGNode onebest_item) {
+ VirtualDeduction virtual_dt = virtual_it.best_virtual_deduction;
+ if (virtual_dt.l_ant_virtual_items != null)
+ for (int i = 0; i < virtual_dt.l_ant_virtual_items.size(); i++) {
+ VirtualItem ant_it = (VirtualItem) virtual_dt.l_ant_virtual_items.get(i);
+ HGNode new_it = clone_item_with_best_deduction(ant_it);
+ onebest_item.bestHyperedge.getTailNodes().set(i, new_it);
+ get_1best_tree_item(ant_it, new_it);
+ }
+ }
+
+ // TODO: tbl_states
+ private static HGNode clone_item_with_best_deduction(VirtualItem virtual_it) {
+ HGNode original_it = virtual_it.p_item;
+ ArrayList<HyperEdge> l_deductions = new ArrayList<HyperEdge>();
+ HyperEdge clone_dt = clone_deduction(virtual_it.best_virtual_deduction);
+ l_deductions.add(clone_dt);
+ return new HGNode(original_it.i, original_it.j, original_it.lhs, l_deductions, clone_dt,
+ original_it.getDPStates());
+ }
+
+ private static HyperEdge clone_deduction(VirtualDeduction virtual_dt) {
+ HyperEdge original_dt = virtual_dt.p_dt;
+ ArrayList<HGNode> l_ant_items = null;
+ // l_ant_items will be changed in get_1best_tree_item
+ if (original_dt.getTailNodes() != null)
+ l_ant_items = new ArrayList<HGNode>(original_dt.getTailNodes());
+ HyperEdge res = new HyperEdge(original_dt.getRule(), original_dt.getBestDerivationScore(),
+ original_dt.getTransitionLogP(false), l_ant_items, original_dt.getSourcePath());
+ return res;
+ }
+
+ // ############### split hg #####
+ public void split_hg(HyperGraph hg) {
+ // TODO: more pre-process in the extended class
+ g_tbl_split_virtual_items.clear();
+ g_num_virtual_items = 0;
+ g_num_virtual_deductions = 0;
+ split_item(hg.goalNode);
+ }
+
+ // for each original Item, get a list of VirtualItem
+ private void split_item(HGNode it) {
+ if (g_tbl_split_virtual_items.containsKey(it))
+ return;// already processed
+ HashMap<String, VirtualItem> virtual_item_sigs = new HashMap<String, VirtualItem>();
+ // ### recursive call on each deduction
+ if (speed_up_item(it)) {
+ for (HyperEdge dt : it.hyperedges) {
+ split_deduction(dt, virtual_item_sigs, it);
+ }
+ }
+ // ### item-specific operation
+ // a list of items result by splitting me
+ ArrayList<VirtualItem> l_virtual_items = new ArrayList<VirtualItem>();
+ for (String signature : virtual_item_sigs.keySet())
+ l_virtual_items.add(virtual_item_sigs.get(signature));
+ g_tbl_split_virtual_items.put(it, l_virtual_items);
+ g_num_virtual_items += l_virtual_items.size();
+ // if(virtual_item_sigs.size()!=1)System.out.println("num of split items is " +
+ // virtual_item_sigs.size());
+ // get_best_virtual_score(it);//debug
+ }
+
+ private void split_deduction(HyperEdge cur_dt, HashMap<String, VirtualItem> virtual_item_sigs,
+ HGNode parent_item) {
+ if (!speed_up_deduction(cur_dt))
+ return;// no need to continue
+
+ // ### recursively split all my ant items, get a l_split_items for each original item
+ if (cur_dt.getTailNodes() != null)
+ for (HGNode ant_it : cur_dt.getTailNodes())
+ split_item(ant_it);
+
+ // ### recombine the deduction
+ redo_combine(cur_dt, virtual_item_sigs, parent_item);
+ }
+
+ private void redo_combine(HyperEdge cur_dt, HashMap<String, VirtualItem> virtual_item_sigs,
+ HGNode parent_item) {
+ List<HGNode> l_ant_items = cur_dt.getTailNodes();
+ if (l_ant_items != null) {
+ // arity: one
+ if (l_ant_items.size() == 1) {
+ HGNode it = l_ant_items.get(0);
+ ArrayList<VirtualItem> l_virtual_items = g_tbl_split_virtual_items.get(it);
+ for (VirtualItem ant_virtual_item : l_virtual_items) {
+ // used in combination
+ ArrayList<VirtualItem> l_ant_virtual_item = new ArrayList<VirtualItem>();
+ l_ant_virtual_item.add(ant_virtual_item);
+ process_one_combination_nonaxiom(parent_item, virtual_item_sigs, cur_dt,
+ l_ant_virtual_item);
+ }
+ // arity: two
+ } else if (l_ant_items.size() == 2) {
+ HGNode it1 = l_ant_items.get(0);
+ HGNode it2 = l_ant_items.get(1);
+ ArrayList<VirtualItem> l_virtual_items1 = g_tbl_split_virtual_items.get(it1);
+ ArrayList<VirtualItem> l_virtual_items2 = g_tbl_split_virtual_items.get(it2);
+ for (VirtualItem virtual_it1 : l_virtual_items1) {
+ for (VirtualItem virtual_it2 : l_virtual_items2) {
+ // used in combination
+ ArrayList<VirtualItem> l_ant_virtual_item = new ArrayList<VirtualItem>();
+ l_ant_virtual_item.add(virtual_it1);
+ l_ant_virtual_item.add(virtual_it2);
+ process_one_combination_nonaxiom(parent_item, virtual_item_sigs, cur_dt,
+ l_ant_virtual_item);
+ }
+ }
+ } else {
+ throw new RuntimeException(
+ "Sorry, we can only deal with rules with at most TWO non-terminals");
+ }
+ // axiom case: no nonterminal
+ } else {
+ process_one_combination_axiom(parent_item, virtual_item_sigs, cur_dt);
+ }
+ }
+
+ // this function should be called by
+ // process_one_combination_axiom/process_one_combination_nonaxiom
+ // virtual_item_sigs is specific to parent_item
+ protected void add_deduction(HGNode parent_item, HashMap<String, VirtualItem> virtual_item_sigs,
+ VirtualDeduction t_ded, DPState dpstate, boolean maintain_onebest_only) {
+ if (null == t_ded) {
+ throw new RuntimeException("deduction is null");
+ }
+ String sig = VirtualItem.get_signature(parent_item, dpstate);
+ VirtualItem t_virtual_item = (VirtualItem) virtual_item_sigs.get(sig);
+ if (t_virtual_item != null) {
+ t_virtual_item.add_deduction(t_ded, dpstate, maintain_onebest_only);
+ } else {
+ t_virtual_item = new VirtualItem(parent_item, dpstate, t_ded, maintain_onebest_only);
+ virtual_item_sigs.put(sig, t_virtual_item);
+ }
+ }
+
+ // return false if we can skip the item;
+ protected boolean speed_up_item(HGNode it) {
+ return true;// e.g., if the lm state is not valid, then no need to continue
+ }
+
+ // return false if we can skip the deduction;
+ protected boolean speed_up_deduction(HyperEdge dt) {
+ return true;// if the rule state is not valid, then no need to continue
+ }
+
+ protected abstract static class DPState {
+ protected abstract String get_signature();
+ }
+
+ /*
+ * In general, an item has the following variables: (1) a list of hyperedges, (2) the best
+ * hyperedge, (3) a DP state, and (4) a signature (computed over part or all of the DP state)
+ */
+
+ protected static class VirtualItem {
+ HGNode p_item = null;// pointer to the true item
+ ArrayList<VirtualDeduction> l_virtual_deductions = null;
+ VirtualDeduction best_virtual_deduction = null;
+ DPState dp_state;// dynamic programming state: not all the variables in dp_state are in the
+ // signature
+
+ public VirtualItem(HGNode item, DPState dstate, VirtualDeduction fdt,
+ boolean maintain_onebest_only) {
+ p_item = item;
+ add_deduction(fdt, dstate, maintain_onebest_only);
+ }
+
+ public void add_deduction(VirtualDeduction fdt, DPState dstate, boolean maintain_onebest_only) {
+ if (!maintain_onebest_only) {
+ if (l_virtual_deductions == null)
+ l_virtual_deductions = new ArrayList<VirtualDeduction>();
+ l_virtual_deductions.add(fdt);
+ }
+ if (best_virtual_deduction == null || fdt.best_cost < best_virtual_deduction.best_cost) {
+ dp_state = dstate;
+ best_virtual_deduction = fdt;
+ }
+ }
+
+ // not all the variables in dp_state are in the signature
+ public String get_signature() {
+ return get_signature(p_item, dp_state);
+ }
+
+ public static String get_signature(HGNode item, DPState dstate) {
+ /*
+ * StringBuffer res = new StringBuffer(); //res.append(item); res.append(" ");//TODO:
+ * res.append(dstate.get_signature()); return res.toString();
+ */
+ return dstate.get_signature();
+ }
+ }
+
+ protected static class VirtualDeduction {
+ HyperEdge p_dt = null;// pointer to the true deduction
+ ArrayList<VirtualItem> l_ant_virtual_items = null;
+ double best_cost = Double.POSITIVE_INFINITY;// the 1-best cost over all possible derivations:
+ // best costs of ant items +
+ // non_stateless_transition_cost + r.statelesscost
+
+ public VirtualDeduction(HyperEdge dt, ArrayList<VirtualItem> ant_items, double best_cost_in) {
+ p_dt = dt;
+ l_ant_virtual_items = ant_items;
+ best_cost = best_cost_in;
+ }
+
+ public double get_transition_cost() {// note: transition_cost is already linearly interpolated
+ double res = best_cost;
+ if (l_ant_virtual_items != null)
+ for (VirtualItem ant_it : l_ant_virtual_items)
+ res -= ant_it.best_virtual_deduction.best_cost;
+ return res;
+ }
+ }
+}
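
To make the contract of the two abstract callbacks concrete, here is a hedged sketch of a degenerate subclass that performs no real splitting: every deduction is registered under a single constant signature, so each original node maps to exactly one VirtualItem. The class name and the cost shortcut are illustrative assumptions, not part of this commit:

    package org.apache.joshua.oracle; // same package, so the nested helper types are visible

    import java.util.ArrayList;
    import java.util.HashMap;

    import org.apache.joshua.decoder.hypergraph.HGNode;
    import org.apache.joshua.decoder.hypergraph.HyperEdge;

    // Illustrative only: with a single constant DP state, the "split"
    // hypergraph keeps exactly the structure of the original one.
    class IdentitySplitHg extends SplitHg {

      private static final DPState EMPTY_STATE = new DPState() {
        protected String get_signature() {
          return ""; // one shared signature => no actual splitting
        }
      };

      protected void process_one_combination_axiom(HGNode parent_item,
          HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt) {
        VirtualDeduction ded = new VirtualDeduction(cur_dt, null, cur_dt.getBestDerivationScore());
        add_deduction(parent_item, virtual_item_sigs, ded, EMPTY_STATE, true);
      }

      protected void process_one_combination_nonaxiom(HGNode parent_item,
          HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt,
          ArrayList<VirtualItem> l_ant_virtual_item) {
        // Since nothing is split, the edge's best derivation score already equals
        // (sum of antecedent best costs + transition cost).
        VirtualDeduction ded =
            new VirtualDeduction(cur_dt, l_ant_virtual_item, cur_dt.getBestDerivationScore());
        add_deduction(parent_item, virtual_item_sigs, ded, EMPTY_STATE, true);
      }
    }
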
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/oracle/package-info.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/oracle/package-info.java b/joshua-core/src/main/java/org/apache/joshua/oracle/package-info.java
new file mode 100644
index 0000000..ae14e82
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/oracle/package-info.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/**
+ * Provides for extracting the target string from a hypergraph
+ * that most closely matches a reference sentence. Much of the
+ * code in this package is based on descriptions in Adam
+ * Lopez's <a href="http://homepages.inf.ed.ac.uk/alopez/papers/adam.lopez.dissertation.pdf">
+ * doctoral thesis</a>.
+ */
+package org.apache.joshua.oracle;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierInterface.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierInterface.java b/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierInterface.java
new file mode 100755
index 0000000..d6dca73
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierInterface.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.pro;
+
+import java.util.Vector;
+
+public interface ClassifierInterface {
+ /*
+ * Arguments required to train a binary linear classifier:
+ *
+ * Vector<String> samples: all training samples, in sparse feature-value representation.
+ * Format: feat_id1:feat_val1 feat_id2:feat_val2 ... label (1 or -1)
+ * Example: 3:0.2 6:2 8:0.5 -1 (only firing features are enumerated)
+ * Note that feat_id should start from 1.
+ *
+ * double[] initialLambda: the initial weight vector (it doesn't have to be used, depending
+ * on the classifier - just ignore the array if it is not to be used). The length of the
+ * vector should be the same as the feature dimension. Note that the 0th entry is not used,
+ * so the array should have length featDim+1 (to be consistent with Z-MERT).
+ *
+ * int featDim: feature vector dimension
+ *
+ * Return value: double[]: a vector containing the weights of all features after training
+ * (it should also have length featDim+1)
+ */
+ double[] runClassifier(Vector<String> samples, double[] initialLambda, int featDim);
+
+ // Set classifier-specific parameters, like config file path, num of iterations, command line...
+ void setClassifierParam(String[] param);
+}
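
Since the sparse sample format above is easy to get wrong, here is a hedged helper sketch (the method name is hypothetical) that converts a dense feature array into the expected string form:

    // Hypothetical helper: formats a dense feature vector as the sample string
    // expected by runClassifier ("feat_id:feat_val ... label"), with feature ids
    // starting at 1 and only nonzero (firing) features listed.
    static String toSparseSample(double[] denseFeatures, int label) {
      StringBuilder sb = new StringBuilder();
      for (int i = 0; i < denseFeatures.length; i++) {
        if (denseFeatures[i] != 0.0) {
          sb.append(i + 1).append(':').append(denseFeatures[i]).append(' ');
        }
      }
      sb.append(label); // class label: 1 or -1
      return sb.toString();
    }

    // e.g. toSparseSample(new double[] {0, 0, 0.2, 0, 0, 2, 0, 0.5}, -1)
    // yields "3:0.2 6:2.0 8:0.5 -1"
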
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierMegaM.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierMegaM.java b/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierMegaM.java
new file mode 100755
index 0000000..f75605f
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierMegaM.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.pro;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Vector;
+
+import org.apache.joshua.util.StreamGobbler;
+import org.apache.joshua.util.io.LineReader;
+
+// sparse feature representation version
+public class ClassifierMegaM implements ClassifierInterface {
+ @Override
+ public double[] runClassifier(Vector<String> samples, double[] initialLambda, int featDim) {
+ double[] lambda = new double[featDim + 1];
+ System.out.println("------- MegaM training starts ------");
+
+ try {
+ // prepare training file for MegaM
+ PrintWriter prt = new PrintWriter(new FileOutputStream(trainingFilePath));
+ String[] feat;
+ String[] featInfo;
+
+ for (String line : samples) {
+ feat = line.split("\\s+");
+
+ if (feat[feat.length - 1].equals("1"))
+ prt.print("1 ");
+ else
+ prt.print("0 ");
+
+ // only for dense representation
+ // for(int i=0; i<feat.length-1; i++)
+ // prt.print( (i+1) + " " + feat[i]+" "); //feat id starts from 1!
+
+ for (int i = 0; i < feat.length - 1; i++) {
+ featInfo = feat[i].split(":");
+ prt.print(featInfo[0] + " " + featInfo[1] + " ");
+ }
+ prt.println();
+ }
+ prt.close();
+
+ // start running MegaM
+ Runtime rt = Runtime.getRuntime();
+ Process p = rt.exec(commandFilePath);
+
+ StreamGobbler errorGobbler = new StreamGobbler(p.getErrorStream(), 1);
+ StreamGobbler outputGobbler = new StreamGobbler(p.getInputStream(), 1);
+
+ errorGobbler.start();
+ outputGobbler.start();
+
+ int decStatus = p.waitFor();
+ if (decStatus != 0) {
+ throw new RuntimeException("Call to decoder returned " + decStatus + "; was expecting " + 0 + ".");
+ }
+
+ // read the weights
+ for (String line: new LineReader(weightFilePath)) {
+ String[] val = line.split("\\s+");
+ lambda[Integer.parseInt(val[0])] = Double.parseDouble(val[1]);
+ }
+
+ File file = new File(trainingFilePath);
+ file.delete();
+ file = new File(weightFilePath);
+ file.delete();
+ } catch (IOException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+
+ System.out.println("------- MegaM training ends ------");
+
+ /*
+ * try { Thread.sleep(20000); } catch(InterruptedException e) { }
+ */
+
+ return lambda;
+ }
+
+ @Override
+ /*
+ * For the MegaM classifier:
+ * param[0] = MegaM command file path
+ * param[1] = MegaM training data file path (generated on the fly)
+ * param[2] = MegaM weight file path (generated after training)
+ * Note that the training and weight file paths should be consistent with those specified in
+ * the command file.
+ */
+ public void setClassifierParam(String[] param) {
+ if (param == null) {
+ throw new RuntimeException("ERROR: must provide parameters for MegaM classifier!");
+ } else {
+ commandFilePath = param[0];
+ trainingFilePath = param[1];
+ weightFilePath = param[2];
+ }
+ }
+
+ String commandFilePath;
+ String trainingFilePath;
+ String weightFilePath;
+}
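
A hedged usage sketch follows; the three paths and the sample data are hypothetical, and a working MegaM installation reachable through the command file is assumed. Note that param[1] and param[2] must match the paths written inside the command file itself:

    import java.util.Vector;

    import org.apache.joshua.pro.ClassifierInterface;
    import org.apache.joshua.pro.ClassifierMegaM;

    public class MegaMExample {
      public static void main(String[] args) {
        Vector<String> samples = new Vector<String>();
        samples.add("3:0.2 6:2 8:0.5 -1"); // sparse format, see ClassifierInterface
        samples.add("1:1.0 3:0.7 1");
        int featDim = 8; // highest feature id in the data

        ClassifierInterface megam = new ClassifierMegaM();
        megam.setClassifierParam(new String[] {
            "/path/to/run_megam.sh", // param[0]: command file that invokes MegaM (hypothetical)
            "/tmp/megam_train.data", // param[1]: training file, written on the fly (hypothetical)
            "/tmp/megam_weights"     // param[2]: weight file produced by MegaM (hypothetical)
        });
        double[] weights = megam.runClassifier(samples, new double[featDim + 1], featDim);
        System.out.println("weight of feature 3: " + weights[3]);
      }
    }
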
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierPerceptron.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierPerceptron.java b/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierPerceptron.java
new file mode 100755
index 0000000..1b5d75c
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/pro/ClassifierPerceptron.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.pro;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Vector;
+
+// sparse feature representation version
+public class ClassifierPerceptron implements ClassifierInterface {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ClassifierPerceptron.class);
+
+ @Override
+ public double[] runClassifier(Vector<String> samples, double[] initialLambda, int featDim) {
+ System.out.println("------- Average-perceptron training starts ------");
+
+ int sampleSize = samples.size();
+ double score = 0; // model score
+ double label;
+ double[] lambda = new double[featDim + 1]; // in ZMERT lambda[0] is not used
+ double[] sum_lambda = new double[featDim + 1];
+ String[] featVal;
+
+ for (int i = 1; i <= featDim; i++) {
+ sum_lambda[i] = 0;
+ lambda[i] = initialLambda[i];
+ }
+
+ System.out.print("Perceptron iteration ");
+ int numError = 0;
+ // int numPosSamp = 0;
+ String[] feat_info;
+
+ for (int it = 0; it < maxIter; it++) {
+ System.out.print(it + " ");
+ numError = 0;
+ // numPosSamp = 0;
+
+ for (int s = 0; s < sampleSize; s++) {
+ featVal = samples.get(s).split("\\s+");
+
+ // only consider positive samples
+ // if( featVal[featDim].equals("1") )
+ // {
+ // numPosSamp++;
+ score = 0;
+ for (int d = 0; d < featVal.length - 1; d++) {
+ feat_info = featVal[d].split(":");
+ score += Double.parseDouble(feat_info[1]) * lambda[Integer.parseInt(feat_info[0])];
+ }
+
+ label = Double.parseDouble(featVal[featVal.length - 1]);
+ score *= label; // the last element is the class label (+1/-1)
+
+ if (score <= bias) // incorrect classification
+ {
+ numError++;
+ for (int d = 0; d < featVal.length - 1; d++) {
+ feat_info = featVal[d].split(":");
+ int featID = Integer.parseInt(feat_info[0]);
+ lambda[featID] += learningRate * label * Double.parseDouble(feat_info[1]);
+ sum_lambda[featID] += lambda[featID];
+ }
+ }
+ // }//if( featVal[featDim].equals("1") )
+ }
+ if (numError == 0) break;
+ }
+
+ System.out.println("\n------- Average-perceptron training ends ------");
+
+ for (int i = 1; i <= featDim; i++)
+ sum_lambda[i] /= maxIter;
+
+ return sum_lambda;
+ }
+
+ @Override
+ /*
+ * For the averaged perceptron:
+ * param[0] = maximum number of iterations
+ * param[1] = learning rate (step size)
+ * param[2] = bias (usually set to 0)
+ */
+ public void setClassifierParam(String[] param) {
+ if (param == null)
+ LOG.warn("no parameters specified for perceptron classifier, using default settings.");
+ else {
+ maxIter = Integer.parseInt(param[0]);
+ learningRate = Double.parseDouble(param[1]);
+ bias = Double.parseDouble(param[2]);
+ }
+ }
+
+ int maxIter = 20;
+ double learningRate = 0.5;
+ double bias = 0.0;
+}
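
And a corresponding hedged sketch for the perceptron; the toy samples and parameter values are illustrative assumptions:

    import java.util.Vector;

    import org.apache.joshua.pro.ClassifierInterface;
    import org.apache.joshua.pro.ClassifierPerceptron;

    public class PerceptronExample {
      public static void main(String[] args) {
        Vector<String> samples = new Vector<String>();
        samples.add("1:1.0 2:0.5 1");  // positive sample
        samples.add("1:0.2 2:1.5 -1"); // negative sample
        int featDim = 2;

        ClassifierInterface perceptron = new ClassifierPerceptron();
        // param[0] = max iterations, param[1] = learning rate, param[2] = bias
        perceptron.setClassifierParam(new String[] {"50", "0.1", "0.0"});

        double[] initialLambda = new double[featDim + 1]; // entry 0 is unused (Z-MERT convention)
        double[] avgWeights = perceptron.runClassifier(samples, initialLambda, featDim);
        System.out.println("w1=" + avgWeights[1] + " w2=" + avgWeights[2]);
      }
    }
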