Posted to commits@joshua.apache.org by mj...@apache.org on 2016/08/30 21:05:00 UTC
[15/17] incubator-joshua git commit: Merge branch 'master' into 7-with-master
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
index 16c25cd,0000000..6ad85a8
mode 100755,000000..100755
--- a/joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
+++ b/joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
@@@ -1,716 -1,0 +1,712 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.adagrad;
+
- import java.util.Collections;
+import java.util.ArrayList;
++import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.Vector;
- import java.lang.Math;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.metrics.EvaluationMetric;
+
+// this class implements the AdaGrad algorithm
+public class Optimizer {
+ public Optimizer(Vector<String> _output, boolean[] _isOptimizable, double[] _initialLambda,
+ HashMap<String, String>[] _feat_hash, HashMap<String, String>[] _stats_hash) {
+ output = _output; // (not used for now)
+ isOptimizable = _isOptimizable;
+ initialLambda = _initialLambda; // initial weights array
- paramDim = initialLambda.length - 1;
++ paramDim = initialLambda.length - 1;
+ feat_hash = _feat_hash; // feature hash table
+ stats_hash = _stats_hash; // suff. stats hash table
+ finalLambda = new double[initialLambda.length];
+ System.arraycopy(initialLambda, 0, finalLambda, 0, finalLambda.length);
+ }
+
+ //run AdaGrad for one epoch
+ public double[] runOptimizer() {
+ List<Integer> sents = new ArrayList<>();
+ for( int i = 0; i < sentNum; ++i )
+ sents.add(i);
+ double[] avgLambda = new double[initialLambda.length]; //only needed if averaging is required
+ for( int i = 0; i < initialLambda.length; ++i )
+ avgLambda[i] = 0;
+ for ( int iter = 0; iter < adagradIter; ++iter ) {
+ System.arraycopy(finalLambda, 1, initialLambda, 1, paramDim);
+ if(needShuffle)
+ Collections.shuffle(sents);
-
++
+ double oraMetric, oraScore, predMetric, predScore;
+ double[] oraPredScore = new double[4];
+ double loss = 0;
+ double diff = 0;
+ double sumMetricScore = 0;
+ double sumModelScore = 0;
+ String oraFeat = "";
+ String predFeat = "";
+ String[] oraPredFeat = new String[2];
+ String[] vecOraFeat;
+ String[] vecPredFeat;
+ String[] featInfo;
- int thisBatchSize = 0;
+ int numBatch = 0;
+ int numUpdate = 0;
- Iterator it;
++ Iterator<Integer> it;
+ Integer diffFeatId;
+
+ //update weights
+ Integer s;
+ int sentCount = 0;
+ double prevLambda = 0;
+ double diffFeatVal = 0;
+ double oldVal = 0;
+ double gdStep = 0;
+ double Hii = 0;
+ double gradiiSquare = 0;
+ int lastUpdateTime = 0;
+ HashMap<Integer, Integer> lastUpdate = new HashMap<>();
+ HashMap<Integer, Double> lastVal = new HashMap<>();
+ HashMap<Integer, Double> H = new HashMap<>();
+ while( sentCount < sentNum ) {
+ loss = 0;
- thisBatchSize = batchSize;
+ ++numBatch;
+ HashMap<Integer, Double> featDiff = new HashMap<>();
+ for(int b = 0; b < batchSize; ++b ) {
+ //find out oracle and prediction
+ s = sents.get(sentCount);
+ findOraPred(s, oraPredScore, oraPredFeat, finalLambda, featScale);
-
++
+ //the model scores here are already scaled in findOraPred
+ oraMetric = oraPredScore[0];
+ oraScore = oraPredScore[1];
+ predMetric = oraPredScore[2];
+ predScore = oraPredScore[3];
+ oraFeat = oraPredFeat[0];
+ predFeat = oraPredFeat[1];
-
++
+ //update the scale
+ if(needScale) { //otherwise featscale remains 1.0
+ sumMetricScore += Math.abs(oraMetric + predMetric);
+ //restore the original model score
+ sumModelScore += Math.abs(oraScore + predScore) / featScale;
-
++
+ if(sumModelScore/sumMetricScore > scoreRatio)
+ featScale = sumMetricScore/sumModelScore;
+ }
+ // processedSent++;
-
++
+ vecOraFeat = oraFeat.split("\\s+");
+ vecPredFeat = predFeat.split("\\s+");
+
+ //accumulate difference feature vector
+ if ( b == 0 ) {
+ for (String aVecOraFeat : vecOraFeat) {
+ featInfo = aVecOraFeat.split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ featDiff.put(diffFeatId, Double.parseDouble(featInfo[1]));
+ }
+ for (String aVecPredFeat : vecPredFeat) {
+ featInfo = aVecPredFeat.split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId) - Double.parseDouble(featInfo[1]);
+ if (Math.abs(diff) > 1e-20)
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ } else //features only firing in the 2nd feature vector
+ featDiff.put(diffFeatId, -1.0 * Double.parseDouble(featInfo[1]));
+ }
+ } else {
+ for (String aVecOraFeat : vecOraFeat) {
+ featInfo = aVecOraFeat.split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId) + Double.parseDouble(featInfo[1]);
+ if (Math.abs(diff) > 1e-20)
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ } else //features only firing in the new oracle feature vector
+ featDiff.put(diffFeatId, Double.parseDouble(featInfo[1]));
+ }
+ for (String aVecPredFeat : vecPredFeat) {
+ featInfo = aVecPredFeat.split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId) - Double.parseDouble(featInfo[1]);
+ if (Math.abs(diff) > 1e-20)
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ } else //features only firing in the new prediction feature vector
+ featDiff.put(diffFeatId, -1.0 * Double.parseDouble(featInfo[1]));
+ }
+ }
+
+ //remember the model scores here are already scaled
+ double singleLoss = evalMetric.getToBeMinimized() ?
- (predMetric-oraMetric) - (oraScore-predScore)/featScale:
++ (predMetric-oraMetric) - (oraScore-predScore)/featScale:
+ (oraMetric-predMetric) - (oraScore-predScore)/featScale;
+ if(singleLoss > 0)
+ loss += singleLoss;
+ ++sentCount;
+ if( sentCount >= sentNum ) {
- thisBatchSize = b + 1;
+ break;
+ }
+ } //for(int b = 0; b < batchSize; ++b)
+
+ //System.out.println("\n\n"+sentCount+":");
+
+ if( loss > 0 ) {
+ //if(true) {
+ ++numUpdate;
+ //update weights (see Duchi'11, Eq.23. For l1-reg, use lazy update)
+ Set<Integer> diffFeatSet = featDiff.keySet();
+ it = diffFeatSet.iterator();
+ while(it.hasNext()) { //note these are all non-zero gradients!
- diffFeatId = (Integer)it.next();
++ diffFeatId = it.next();
+ diffFeatVal = -1.0 * featDiff.get(diffFeatId); //gradient
+ if( regularization > 0 ) {
+ lastUpdateTime =
+ lastUpdate.get(diffFeatId) == null ? 0 : lastUpdate.get(diffFeatId);
+ if( lastUpdateTime < numUpdate - 1 ) {
+ //haven't been updated (gradient=0) for at least 2 steps
+ //lazy compute prevLambda now
+ oldVal =
+ lastVal.get(diffFeatId) == null ? initialLambda[diffFeatId] : lastVal.get(diffFeatId);
+ Hii =
+ H.get(diffFeatId) == null ? 0 : H.get(diffFeatId);
+ if(Math.abs(Hii) > 1e-20) {
+ if( regularization == 1 )
+ prevLambda =
+ Math.signum(oldVal) * clip( Math.abs(oldVal) - lam * eta * (numBatch - 1 - lastUpdateTime) / Hii );
+ else if( regularization == 2 ) {
+ prevLambda =
+ Math.pow( Hii/(lam+Hii), (numUpdate - 1 - lastUpdateTime) ) * oldVal;
+ if(needAvg) { //fill the gap due to lazy update
+ double prevLambdaCopy = prevLambda;
+ double scale = Hii/(lam+Hii);
+ for( int t = 0; t < numUpdate - 1 - lastUpdateTime; ++t ) {
+ avgLambda[diffFeatId] += prevLambdaCopy;
+ prevLambdaCopy /= scale;
+ }
+ }
+ }
+ } else {
+ if( regularization == 1 )
+ prevLambda = 0;
+ else if( regularization == 2 )
+ prevLambda = oldVal;
+ }
+ } else //just updated at last time step or just started
+ prevLambda = finalLambda[diffFeatId];
+ if(H.get(diffFeatId) != null) {
+ gradiiSquare = H.get(diffFeatId);
+ gradiiSquare *= gradiiSquare;
+ gradiiSquare += diffFeatVal * diffFeatVal;
+ Hii = Math.sqrt(gradiiSquare);
+ } else
+ Hii = Math.abs(diffFeatVal);
+ H.put(diffFeatId, Hii);
+ //update the weight
+ if( regularization == 1 ) {
+ gdStep = prevLambda - eta * diffFeatVal / Hii;
+ finalLambda[diffFeatId] = Math.signum(gdStep) * clip( Math.abs(gdStep) - lam * eta / Hii );
+ } else if(regularization == 2 ) {
+ finalLambda[diffFeatId] = (Hii * prevLambda - eta * diffFeatVal) / (lam + Hii);
+ if(needAvg)
+ avgLambda[diffFeatId] += finalLambda[diffFeatId];
+ }
+ lastUpdate.put(diffFeatId, numUpdate);
+ lastVal.put(diffFeatId, finalLambda[diffFeatId]);
+ } else { //if no regularization
+ if(H.get(diffFeatId) != null) {
+ gradiiSquare = H.get(diffFeatId);
+ gradiiSquare *= gradiiSquare;
+ gradiiSquare += diffFeatVal * diffFeatVal;
+ Hii = Math.sqrt(gradiiSquare);
+ } else
+ Hii = Math.abs(diffFeatVal);
+ H.put(diffFeatId, Hii);
+ finalLambda[diffFeatId] = finalLambda[diffFeatId] - eta * diffFeatVal / Hii;
+ if(needAvg)
+ avgLambda[diffFeatId] += finalLambda[diffFeatId];
+ }
+ } //while(it.hasNext())
+ } //if(loss > 0)
+ else { //no loss, therefore the weight update is skipped
+ //however, the avg weights still need to be accumulated
+ if( regularization == 0 ) {
+ for( int i = 1; i < finalLambda.length; ++i )
+ avgLambda[i] += finalLambda[i];
+ } else if( regularization == 2 ) {
+ if(needAvg) {
+ //due to lazy update, we need to figure out the actual
+ //weight vector at this point first...
+ for( int i = 1; i < finalLambda.length; ++i ) {
+ if( lastUpdate.get(i) != null ) {
+ if( lastUpdate.get(i) < numUpdate ) {
+ oldVal = lastVal.get(i);
+ Hii = H.get(i);
+ //lazy compute
+ avgLambda[i] +=
+ Math.pow( Hii/(lam+Hii), (numUpdate - lastUpdate.get(i)) ) * oldVal;
+ } else
+ avgLambda[i] += finalLambda[i];
+ }
+ avgLambda[i] += finalLambda[i];
+ }
+ }
+ }
+ }
+ } //while( sentCount < sentNum )
+ if( regularization > 0 ) {
+ for( int i = 1; i < finalLambda.length; ++i ) {
+ //now lazy compute those weights that haven't been taken care of
+ if( lastUpdate.get(i) == null )
+ finalLambda[i] = 0;
+ else if( lastUpdate.get(i) < numUpdate ) {
+ oldVal = lastVal.get(i);
+ Hii = H.get(i);
+ if( regularization == 1 )
+ finalLambda[i] =
+ Math.signum(oldVal) * clip( Math.abs(oldVal) - lam * eta * (numUpdate - lastUpdate.get(i)) / Hii );
+ else if( regularization == 2 ) {
- finalLambda[i] =
++ finalLambda[i] =
+ Math.pow( Hii/(lam+Hii), (numUpdate - lastUpdate.get(i)) ) * oldVal;
+ if(needAvg) { //fill the gap due to lazy update
+ double prevLambdaCopy = finalLambda[i];
+ double scale = Hii/(lam+Hii);
+ for( int t = 0; t < numUpdate - lastUpdate.get(i); ++t ) {
+ avgLambda[i] += prevLambdaCopy;
+ prevLambdaCopy /= scale;
+ }
+ }
+ }
+ }
+ if( regularization == 2 && needAvg ) {
+ if( iter == adagradIter - 1 )
+ finalLambda[i] = avgLambda[i] / ( numBatch * adagradIter );
+ }
+ }
+ } else { //if no regularization
+ if( iter == adagradIter - 1 && needAvg ) {
+ for( int i = 1; i < finalLambda.length; ++i )
+ finalLambda[i] = avgLambda[i] / ( numBatch * adagradIter );
+ }
+ }
+
+ double initMetricScore;
+ if (iter == 0) {
+ initMetricScore = computeCorpusMetricScore(initialLambda);
+ finalMetricScore = computeCorpusMetricScore(finalLambda);
+ } else {
+ initMetricScore = finalMetricScore;
+ finalMetricScore = computeCorpusMetricScore(finalLambda);
+ }
+ // prepare the printing info
+ String result = " Initial "
+ + evalMetric.get_metricName() + "=" + String.format("%.4f", initMetricScore) + " Final "
+ + evalMetric.get_metricName() + "=" + String.format("%.4f", finalMetricScore);
+ //print lambda info
+ // int numParamToPrint = 0;
+ // numParamToPrint = paramDim > 10 ? 10 : paramDim; // how many parameters
+ // // to print
+ // result = paramDim > 10 ? "Final lambda (first 10): {" : "Final lambda: {";
-
++
+ // for (int i = 1; i <= numParamToPrint; ++i)
+ // result += String.format("%.4f", finalLambda[i]) + " ";
+
+ output.add(result);
+ } //for ( int iter = 0; iter < adagradIter; ++iter ) {
+
+ //non-optimizable weights should remain unchanged
+ ArrayList<Double> cpFixWt = new ArrayList<>();
+ for ( int i = 1; i < isOptimizable.length; ++i ) {
+ if ( ! isOptimizable[i] )
+ cpFixWt.add(finalLambda[i]);
+ }
+ normalizeLambda(finalLambda);
+ int countNonOpt = 0;
+ for ( int i = 1; i < isOptimizable.length; ++i ) {
+ if ( ! isOptimizable[i] ) {
+ finalLambda[i] = cpFixWt.get(countNonOpt);
+ ++countNonOpt;
+ }
+ }
+ return finalLambda;
+ }
+
+ private double clip(double x) {
+ return x > 0 ? x : 0;
+ }
+
+ public double computeCorpusMetricScore(double[] finalLambda) {
+ int suffStatsCount = evalMetric.get_suffStatsCount();
+ double modelScore;
+ double maxModelScore;
+ Set<String> candSet;
+ String candStr;
+ String[] feat_str;
+ String[] tmpStatsVal = new String[suffStatsCount];
+ int[] corpusStatsVal = new int[suffStatsCount];
+ for (int i = 0; i < suffStatsCount; i++)
+ corpusStatsVal[i] = 0;
+
+ for (int i = 0; i < sentNum; i++) {
+ candSet = feat_hash[i].keySet();
+
+ // find out the 1-best candidate for each sentence
+ // this depends on the training mode
+ maxModelScore = NegInf;
+ for (String aCandSet : candSet) {
+ modelScore = 0.0;
+ candStr = aCandSet.toString();
+
+ feat_str = feat_hash[i].get(candStr).split("\\s+");
+
+ String[] feat_info;
+
+ for (String aFeat_str : feat_str) {
+ feat_info = aFeat_str.split("=");
+ modelScore += Double.parseDouble(feat_info[1]) * finalLambda[Vocabulary.id(feat_info[0])];
+ }
+
+ if (maxModelScore < modelScore) {
+ maxModelScore = modelScore;
+ tmpStatsVal = stats_hash[i].get(candStr).split("\\s+"); // save the
+ // suff stats
+ }
+ }
+
+ for (int j = 0; j < suffStatsCount; j++)
+ corpusStatsVal[j] += Integer.parseInt(tmpStatsVal[j]); // accumulate
+ // corpus-level
+ // suff stats
+ } // for( int i=0; i<sentNum; i++ )
+
+ return evalMetric.score(corpusStatsVal);
+ }
-
++
+ private void findOraPred(int sentId, double[] oraPredScore, String[] oraPredFeat, double[] lambda, double featScale)
+ {
+ double oraMetric=0, oraScore=0, predMetric=0, predScore=0;
+ String oraFeat="", predFeat="";
+ double candMetric = 0, candScore = 0; //metric and model scores for each cand
+ Set<String> candSet = stats_hash[sentId].keySet();
+ String cand = "";
+ String feats = "";
+ String oraCand = ""; //only used when BLEU/TER-BLEU is used as metric
+ String[] featStr;
+ String[] featInfo;
-
++
+ int actualFeatId;
+ double bestOraScore;
+ double worstPredScore;
-
++
+ if(oraSelectMode==1)
+ bestOraScore = NegInf; //larger score will be selected
+ else {
+ if(evalMetric.getToBeMinimized())
+ bestOraScore = PosInf; //smaller score will be selected
+ else
+ bestOraScore = NegInf;
+ }
-
++
+ if(predSelectMode==1 || predSelectMode==2)
+ worstPredScore = NegInf; //larger score will be selected
+ else {
+ if(evalMetric.getToBeMinimized())
+ worstPredScore = NegInf; //larger score will be selected
+ else
+ worstPredScore = PosInf;
+ }
+
+ for (String aCandSet : candSet) {
+ cand = aCandSet.toString();
+ candMetric = computeSentMetric(sentId, cand); //compute metric score
+
+ //start to compute model score
+ candScore = 0;
+ featStr = feat_hash[sentId].get(cand).split("\\s+");
+ feats = "";
+
+ for (String aFeatStr : featStr) {
+ featInfo = aFeatStr.split("=");
+ actualFeatId = Vocabulary.id(featInfo[0]);
+ candScore += Double.parseDouble(featInfo[1]) * lambda[actualFeatId];
+ if ((actualFeatId < isOptimizable.length && isOptimizable[actualFeatId])
+ || actualFeatId >= isOptimizable.length)
+ feats += actualFeatId + "=" + Double.parseDouble(featInfo[1]) + " ";
+ }
+
+ candScore *= featScale; //scale the model score
+
+ //is this cand oracle?
+ if (oraSelectMode == 1) {//"hope", b=1, r=1
+ if (evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if (bestOraScore <= (candScore - candMetric)) {
+ bestOraScore = candScore - candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ } else {
+ if (bestOraScore <= (candScore + candMetric)) {
+ bestOraScore = candScore + candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ } else {//best metric score(ex: max BLEU), b=1, r=0
+ if (evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if (bestOraScore >= candMetric) {
+ bestOraScore = candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ } else {
+ if (bestOraScore <= candMetric) {
+ bestOraScore = candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ }
+
+ //is this cand prediction?
+ if (predSelectMode == 1) {//"fear"
+ if (evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if (worstPredScore <= (candScore + candMetric)) {
+ worstPredScore = candScore + candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {
+ if (worstPredScore <= (candScore - candMetric)) {
+ worstPredScore = candScore - candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ } else if (predSelectMode == 2) {//model prediction(max model score)
+ if (worstPredScore <= candScore) {
+ worstPredScore = candScore;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {//worst metric score(ex: min BLEU)
+ if (evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if (worstPredScore <= candMetric) {
+ worstPredScore = candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {
+ if (worstPredScore >= candMetric) {
+ worstPredScore = candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ }
+ }
-
++
+ oraPredScore[0] = oraMetric;
+ oraPredScore[1] = oraScore;
+ oraPredScore[2] = predMetric;
+ oraPredScore[3] = predScore;
+ oraPredFeat[0] = oraFeat;
+ oraPredFeat[1] = predFeat;
-
++
+ //update the BLEU metric statistics if pseudo corpus is used to compute BLEU/TER-BLEU
+ if(evalMetric.get_metricName().equals("BLEU") && usePseudoBleu ) {
+ String statString;
+ String[] statVal_str;
+ statString = stats_hash[sentId].get(oraCand);
+ statVal_str = statString.split("\\s+");
+
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ bleuHistory[sentId][j] = R*bleuHistory[sentId][j]+Integer.parseInt(statVal_str[j]);
+ }
-
++
+ if(evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu ) {
+ String statString;
+ String[] statVal_str;
+ statString = stats_hash[sentId].get(oraCand);
+ statVal_str = statString.split("\\s+");
+
+ for (int j = 0; j < evalMetric.get_suffStatsCount()-2; j++)
+ bleuHistory[sentId][j] = R*bleuHistory[sentId][j]+Integer.parseInt(statVal_str[j+2]); //the first 2 stats are TER stats
+ }
+ }
-
++
+ // compute *sentence-level* metric score for cand
+ private double computeSentMetric(int sentId, String cand) {
+ String statString;
+ String[] statVal_str;
+ int[] statVal = new int[evalMetric.get_suffStatsCount()];
+
+ statString = stats_hash[sentId].get(cand);
+ statVal_str = statString.split("\\s+");
+
+ if(evalMetric.get_metricName().equals("BLEU") && usePseudoBleu) {
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ statVal[j] = (int) (Integer.parseInt(statVal_str[j]) + bleuHistory[sentId][j]);
+ } else if(evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu) {
+ for (int j = 0; j < evalMetric.get_suffStatsCount()-2; j++)
+ statVal[j+2] = (int)(Integer.parseInt(statVal_str[j+2]) + bleuHistory[sentId][j]); //only modify the BLEU stats part(TER has 2 stats)
+ } else { //in all other situations, use normal stats
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ statVal[j] = Integer.parseInt(statVal_str[j]);
+ }
+
+ return evalMetric.score(statVal);
+ }
+
+ // from ZMERT
+ private void normalizeLambda(double[] origLambda) {
+ // private String[] normalizationOptions;
+ // How should a lambda[] vector be normalized (before decoding)?
+ // nO[0] = 0: no normalization
+ // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+ // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+ // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+ // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+
+ int normalizationMethod = (int) normalizationOptions[0];
+ double scalingFactor = 1.0;
+ if (normalizationMethod == 0) {
+ scalingFactor = 1.0;
+ } else if (normalizationMethod == 1) {
+ int c = (int) normalizationOptions[2];
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[c]);
+ } else if (normalizationMethod == 2) {
+ double maxAbsVal = -1;
+ int maxAbsVal_c = 0;
+ for (int c = 1; c <= paramDim; ++c) {
+ if (Math.abs(origLambda[c]) > maxAbsVal) {
+ maxAbsVal = Math.abs(origLambda[c]);
+ maxAbsVal_c = c;
+ }
+ }
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[maxAbsVal_c]);
+
+ } else if (normalizationMethod == 3) {
+ double minAbsVal = PosInf;
+ int minAbsVal_c = 0;
+
+ for (int c = 1; c <= paramDim; ++c) {
+ if (Math.abs(origLambda[c]) < minAbsVal) {
+ minAbsVal = Math.abs(origLambda[c]);
+ minAbsVal_c = c;
+ }
+ }
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[minAbsVal_c]);
+
+ } else if (normalizationMethod == 4) {
+ double pow = normalizationOptions[1];
+ double norm = L_norm(origLambda, pow);
+ scalingFactor = normalizationOptions[2] / norm;
+ }
+
+ for (int c = 1; c <= paramDim; ++c) {
+ origLambda[c] *= scalingFactor;
+ }
+ }
+
+ // from ZMERT
+ private double L_norm(double[] A, double pow) {
+ // calculates the L-pow norm of A[]
+ // NOTE: this calculation ignores A[0]
+ double sum = 0.0;
+ for (int i = 1; i < A.length; ++i)
+ sum += Math.pow(Math.abs(A[i]), pow);
+
+ return Math.pow(sum, 1 / pow);
+ }
+
+ public static double getScale()
+ {
+ return featScale;
+ }
-
++
+ public static void initBleuHistory(int sentNum, int statCount)
+ {
+ bleuHistory = new double[sentNum][statCount];
+ for(int i=0; i<sentNum; i++) {
+ for(int j=0; j<statCount; j++) {
+ bleuHistory[i][j] = 0.0;
+ }
+ }
+ }
+
+ public double getMetricScore()
+ {
+ return finalMetricScore;
+ }
-
++
+ private final Vector<String> output;
+ private double[] initialLambda;
+ private final double[] finalLambda;
+ private double finalMetricScore;
+ private final HashMap<String, String>[] feat_hash;
+ private final HashMap<String, String>[] stats_hash;
+ private final int paramDim;
+ private final boolean[] isOptimizable;
+ public static int sentNum;
+ public static int adagradIter; //AdaGrad internal iterations
+ public static int oraSelectMode;
+ public static int predSelectMode;
+ public static int batchSize;
+ public static int regularization;
+ public static boolean needShuffle;
+ public static boolean needScale;
+ public static double scoreRatio;
+ public static boolean needAvg;
+ public static boolean usePseudoBleu;
+ public static double featScale = 1.0; //scale the features in order to make the model score comparable with the metric score;
+ //updated in each epoch if necessary
+ public static double eta;
+ public static double lam;
- public static double R; //corpus decay(used only when pseudo corpus is used to compute BLEU)
++ public static double R; //corpus decay(used only when pseudo corpus is used to compute BLEU)
+ public static EvaluationMetric evalMetric;
+ public static double[] normalizationOptions;
+ public static double[][] bleuHistory;
-
++
+ private final static double NegInf = (-1.0 / 0.0);
+ private final static double PosInf = (+1.0 / 0.0);
+}
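
For readers unfamiliar with AdaGrad, the per-coordinate step that runOptimizer applies above reduces to the following minimal sketch (plain Java, not part of the patch): H_ii accumulates the root of the summed squared gradients for each coordinate, and each weight moves by eta * g / H_ii. Names mirror the patch; the Map-based sparse gradient and the class itself are illustrative assumptions, and regularization and lazy updates are omitted.

import java.util.HashMap;
import java.util.Map;

public class AdaGradSketch {
  static void update(Map<Integer, Double> gradient, double[] lambda,
                     Map<Integer, Double> H, double eta) {
    for (Map.Entry<Integer, Double> e : gradient.entrySet()) {
      int id = e.getKey();
      double g = e.getValue();
      // H_ii <- sqrt(H_ii^2 + g^2), seeded with |g| on the first update,
      // matching the gradiiSquare/Hii bookkeeping in the patch
      Double prev = H.get(id);
      double hii = (prev == null) ? Math.abs(g)
                                  : Math.sqrt(prev * prev + g * g);
      H.put(id, hii);
      lambda[id] -= eta * g / hii; // per-coordinate learning rate eta / H_ii
    }
  }

  public static void main(String[] args) {
    double[] lambda = new double[] {0.0, 0.5, -0.2};
    Map<Integer, Double> H = new HashMap<>();
    Map<Integer, Double> gradient = new HashMap<>();
    gradient.put(1, 0.3); // only non-zero gradients, as in the patch
    update(gradient, lambda, H, 0.1);
    System.out.println(lambda[1]); // 0.5 - 0.1 * 0.3 / 0.3 = 0.4
  }
}
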
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
index 10efdc6,0000000..27303ec
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
@@@ -1,412 -1,0 +1,414 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus.syntax;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.util.io.LineReader;
+
+public class ArraySyntaxTree implements SyntaxTree, Externalizable {
+
+ /**
+ * Note that index stores the indices of lattice node positions, i.e. the last element of index is
+ * the terminal node, pointing to lattice.size()
+ */
+ private ArrayList<Integer> forwardIndex;
+ private ArrayList<Integer> forwardLattice;
+ private ArrayList<Integer> backwardIndex;
+ private ArrayList<Integer> backwardLattice;
+
+ private ArrayList<Integer> terminals;
+
+ private final boolean useBackwardLattice = true;
+
+ private static final int MAX_CONCATENATIONS = 3;
+ private static final int MAX_LABELS = 100;
+
+ public ArraySyntaxTree() {
+ forwardIndex = null;
+ forwardLattice = null;
+ backwardIndex = null;
+ backwardLattice = null;
+
+ terminals = null;
+ }
+
+
+ public ArraySyntaxTree(String parsed_line) {
+ initialize();
+ appendFromPennFormat(parsed_line);
+ }
+
+
+ /**
+ * Returns a collection of single-non-terminal labels that exactly cover the specified span in the
+ * lattice.
+ */
++ @Override
+ public Collection<Integer> getConstituentLabels(int from, int to) {
+ Collection<Integer> labels = new HashSet<>();
+ int span_length = to - from;
+ for (int i = forwardIndex.get(from); i < forwardIndex.get(from + 1); i += 2) {
+ int current_span = forwardLattice.get(i + 1);
+ if (current_span == span_length)
+ labels.add(forwardLattice.get(i));
+ else if (current_span < span_length) break;
+ }
+ return labels;
+ }
+
+
+ public int getOneConstituent(int from, int to) {
+ int spanLength = to - from;
+ Stack<Integer> stack = new Stack<>();
+
+ for (int i = forwardIndex.get(from); i < forwardIndex.get(from + 1); i += 2) {
+ int currentSpan = forwardLattice.get(i + 1);
+ if (currentSpan == spanLength) {
+ return forwardLattice.get(i);
+ } else if (currentSpan < spanLength) break;
+ }
+ if (stack.isEmpty()) return 0;
+ StringBuilder sb = new StringBuilder();
+ while (!stack.isEmpty()) {
+ String w = Vocabulary.word(stack.pop());
+ if (sb.length() != 0) sb.append(":");
+ sb.append(w);
+ }
+ String label = sb.toString();
+ return Vocabulary.id(adjustMarkup(label));
+ }
+
+
+ public int getOneSingleConcatenation(int from, int to) {
+ for (int midpt = from + 1; midpt < to; midpt++) {
+ int x = getOneConstituent(from, midpt);
+ if (x == 0) continue;
+ int y = getOneConstituent(midpt, to);
+ if (y == 0) continue;
+ String label = Vocabulary.word(x) + "+" + Vocabulary.word(y);
+ return Vocabulary.id(adjustMarkup(label));
+ }
+ return 0;
+ }
+
+
+ public int getOneDoubleConcatenation(int from, int to) {
+ for (int a = from + 1; a < to - 1; a++) {
+ for (int b = a + 1; b < to; b++) {
+ int x = getOneConstituent(from, a);
+ if (x == 0) continue;
+ int y = getOneConstituent(a, b);
+ if (y == 0) continue;
+ int z = getOneConstituent(b, to);
+ if (z == 0) continue;
+ String label = Vocabulary.word(x) + "+" + Vocabulary.word(y) + "+" + Vocabulary.word(z);
+ return Vocabulary.id(adjustMarkup(label));
+ }
+ }
+ return 0;
+ }
+
+
+ public int getOneRightSideCCG(int from, int to) {
+ for (int end = to + 1; end <= forwardLattice.size(); end++) {
+ int x = getOneConstituent(from, end);
+ if (x == 0) continue;
+ int y = getOneConstituent(to, end);
+ if (y == 0) continue;
+ String label = Vocabulary.word(x) + "/" + Vocabulary.word(y);
+ return Vocabulary.id(adjustMarkup(label));
+ }
+ return 0;
+ }
+
+
+ public int getOneLeftSideCCG(int from, int to) {
+ for (int start = from - 1; start >= 0; start--) {
+ int x = getOneConstituent(start, to);
+ if (x == 0) continue;
+ int y = getOneConstituent(start, from);
+ if (y == 0) continue;
+ String label = Vocabulary.word(y) + "\\" + Vocabulary.word(x);
+ return Vocabulary.id(adjustMarkup(label));
+ }
+ return 0;
+ }
+
+
+ /**
+ * Returns a collection of concatenated non-terminal labels that exactly cover the specified span
+ * in the lattice. The number of non-terminals concatenated is limited by MAX_CONCATENATIONS and
+ * the total number of labels returned is bounded by MAX_LABELS.
+ */
++ @Override
+ public Collection<Integer> getConcatenatedLabels(int from, int to) {
+ Collection<Integer> labels = new HashSet<>();
+
+ int span_length = to - from;
+ Stack<Integer> nt_stack = new Stack<>();
+ Stack<Integer> pos_stack = new Stack<>();
+ Stack<Integer> depth_stack = new Stack<>();
+
+ // seed stacks (reverse order to save on iterations, longer spans)
+ for (int i = forwardIndex.get(from + 1) - 2; i >= forwardIndex.get(from); i -= 2) {
+ int current_span = forwardLattice.get(i + 1);
+ if (current_span < span_length) {
+ nt_stack.push(forwardLattice.get(i));
+ pos_stack.push(from + current_span);
+ depth_stack.push(1);
+ } else if (current_span >= span_length) break;
+ }
+
+ while (!nt_stack.isEmpty() && labels.size() < MAX_LABELS) {
+ int nt = nt_stack.pop();
+ int pos = pos_stack.pop();
+ int depth = depth_stack.pop();
+
+ // maximum depth reached without filling span
+ if (depth == MAX_CONCATENATIONS) continue;
+
+ int remaining_span = to - pos;
+ for (int i = forwardIndex.get(pos + 1) - 2; i >= forwardIndex.get(pos); i -= 2) {
+ int current_span = forwardLattice.get(i + 1);
+ if (current_span > remaining_span) break;
+
+ // create and look up concatenated label
+ int concatenated_nt =
+ Vocabulary.id(adjustMarkup(Vocabulary.word(nt) + "+"
+ + Vocabulary.word(forwardLattice.get(i))));
+ if (current_span < remaining_span) {
+ nt_stack.push(concatenated_nt);
+ pos_stack.push(pos + current_span);
+ depth_stack.push(depth + 1);
+ } else if (current_span == remaining_span) {
+ labels.add(concatenated_nt);
+ }
+ }
+ }
+
+ return labels;
+ }
+
+ // TODO: can pre-compute all that in top-down fashion.
++ @Override
+ public Collection<Integer> getCcgLabels(int from, int to) {
+ Collection<Integer> labels = new HashSet<>();
+
+ int span_length = to - from;
+ // TODO: range checks on the to and from
+
+ boolean is_prefix = (forwardLattice.get(forwardIndex.get(from) + 1) > span_length);
+ if (is_prefix) {
+ Map<Integer, Set<Integer>> main_constituents = new HashMap<>();
+ // find missing to the right
+ for (int i = forwardIndex.get(from); i < forwardIndex.get(from + 1); i += 2) {
+ int current_span = forwardLattice.get(i + 1);
+ if (current_span <= span_length)
+ break;
+ else {
+ int end_pos = forwardLattice.get(i + 1) + from;
+ Set<Integer> nts = main_constituents.get(end_pos);
+ if (nts == null) main_constituents.put(end_pos, new HashSet<>());
+ main_constituents.get(end_pos).add(forwardLattice.get(i));
+ }
+ }
+ for (int i = forwardIndex.get(to); i < forwardIndex.get(to + 1); i += 2) {
+ Set<Integer> main_set = main_constituents.get(to + forwardLattice.get(i + 1));
+ if (main_set != null) {
+ for (int main : main_set)
+ labels.add(Vocabulary.id(adjustMarkup(Vocabulary.word(main) + "/"
+ + Vocabulary.word(forwardLattice.get(i)))));
+ }
+ }
+ }
+
+ if (!is_prefix) {
+ if (useBackwardLattice) {
+ // check if there is any possible higher-level constituent overlapping
+ int to_end =
+ (to == backwardIndex.size() - 1) ? backwardLattice.size() : backwardIndex.get(to + 1);
+ // check longest span ending in to..
+ if (backwardLattice.get(to_end - 1) <= span_length) return labels;
+
+ Map<Integer, Set<Integer>> main_constituents = new HashMap<>();
+ // find missing to the left
+ for (int i = to_end - 2; i >= backwardIndex.get(to); i -= 2) {
+ int current_span = backwardLattice.get(i + 1);
+ if (current_span <= span_length)
+ break;
+ else {
+ int start_pos = to - backwardLattice.get(i + 1);
+ Set<Integer> nts = main_constituents.get(start_pos);
+ if (nts == null) main_constituents.put(start_pos, new HashSet<>());
+ main_constituents.get(start_pos).add(backwardLattice.get(i));
+ }
+ }
+ for (int i = backwardIndex.get(from); i < backwardIndex.get(from + 1); i += 2) {
+ Set<Integer> main_set = main_constituents.get(from - backwardLattice.get(i + 1));
+ if (main_set != null) {
+ for (int main : main_set)
+ labels.add(Vocabulary.id(adjustMarkup(Vocabulary.word(main) + "\\"
+ + Vocabulary.word(backwardLattice.get(i)))));
+ }
+ }
+ } else {
+ // TODO: bothersome no-backwards-arrays method.
+ }
+ }
+ return labels;
+ }
+
+ @Override
+ public int[] getTerminals() {
+ return getTerminals(0, terminals.size());
+ }
+
+ @Override
+ public int[] getTerminals(int from, int to) {
+ int[] span = new int[to - from];
+ for (int i = from; i < to; i++)
+ span[i - from] = terminals.get(i);
+ return span;
+ }
+
++ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ // TODO Auto-generated method stub
+ }
+
++ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ // TODO Auto-generated method stub
+ }
+
+ /**
+ * Reads Penn Treebank format file
+ * @param file_name the string path of the Penn Treebank file
+ * @throws IOException if the file does not exist
+ */
+ public void readExternalText(String file_name) throws IOException {
- LineReader reader = new LineReader(file_name);
- initialize();
- for (String line : reader) {
- if (line.trim().equals("")) continue;
- appendFromPennFormat(line);
++ try (LineReader reader = new LineReader(file_name);) {
++ initialize();
++ for (String line : reader) {
++ if (line.trim().equals("")) continue;
++ appendFromPennFormat(line);
++ }
+ }
+ }
+
- public void writeExternalText(String file_name) throws IOException {
- // TODO Auto-generated method stub
- }
-
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < forwardIndex.size(); i++)
+ sb.append("FI[").append(i).append("] =\t").append(forwardIndex.get(i)).append("\n");
+ sb.append("\n");
+ for (int i = 0; i < forwardLattice.size(); i += 2)
+ sb.append("F[").append(i).append("] =\t").append(Vocabulary.word(forwardLattice.get(i)))
+ .append(" , ").append(forwardLattice.get(i + 1)).append("\n");
+
+ sb.append("\n");
+ for (int i = 0; i < terminals.size(); i += 1)
+ sb.append("T[").append(i).append("] =\t").append(Vocabulary.word(terminals.get(i)))
+ .append(" , 1 \n");
+
+ if (this.useBackwardLattice) {
+ sb.append("\n");
+ for (int i = 0; i < backwardIndex.size(); i++)
+ sb.append("BI[").append(i).append("] =\t").append(backwardIndex.get(i)).append("\n");
+ sb.append("\n");
+ for (int i = 0; i < backwardLattice.size(); i += 2)
+ sb.append("B[").append(i).append("] =\t").append(Vocabulary.word(backwardLattice.get(i)))
+ .append(" , ").append(backwardLattice.get(i + 1)).append("\n");
+ }
+ return sb.toString();
+ }
+
+
+ private void initialize() {
+ forwardIndex = new ArrayList<>();
+ forwardIndex.add(0);
+ forwardLattice = new ArrayList<>();
+ if (this.useBackwardLattice) {
+ backwardIndex = new ArrayList<>();
+ backwardIndex.add(0);
+ backwardLattice = new ArrayList<>();
+ }
+
+ terminals = new ArrayList<>();
+ }
+
+
+ // TODO: could make this way more efficient
+ private void appendFromPennFormat(String line) {
+ String[] tokens = line.replaceAll("\\(", " ( ").replaceAll("\\)", " ) ").trim().split("\\s+");
+
+ boolean next_nt = false;
+ int current_id = 0;
+ Stack<Integer> stack = new Stack<>();
+
+ for (String token : tokens) {
+ if ("(".equals(token)) {
+ next_nt = true;
+ continue;
+ }
+ if (")".equals(token)) {
+ int closing_pos = stack.pop();
+ forwardLattice.set(closing_pos, forwardIndex.size() - forwardLattice.get(closing_pos));
+ if (this.useBackwardLattice) {
+ backwardLattice.add(forwardLattice.get(closing_pos - 1));
+ backwardLattice.add(forwardLattice.get(closing_pos));
+ }
+ continue;
+ }
+ if (next_nt) {
+ // get NT id
+ current_id = Vocabulary.id(adjustMarkup(token));
+ // add into lattice
+ forwardLattice.add(current_id);
+ // push NT span field onto stack (added hereafter, we're just saving the "- 1")
+ stack.push(forwardLattice.size());
+ // add NT span field
+ forwardLattice.add(forwardIndex.size());
+ } else {
+ current_id = Vocabulary.id(token);
+ terminals.add(current_id);
+
+ forwardIndex.add(forwardLattice.size());
+ if (this.useBackwardLattice) backwardIndex.add(backwardLattice.size());
+ }
+ next_nt = false;
+ }
+ }
+
+ private String adjustMarkup(String nt) {
+ return "[" + nt.replaceAll("[\\[\\]]", "") + "]";
+ }
+}
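
A hypothetical usage sketch for the class above (not part of the patch): build a tree from a Penn-format parse, then read back the terminals and the non-terminal labels that exactly cover a span. The expected outputs in the comments are assumptions based on the lattice layout described above; note that adjustMarkup wraps labels in brackets, so they come back as "[NP]" etc.

import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.corpus.syntax.ArraySyntaxTree;

public class ArraySyntaxTreeDemo {
  public static void main(String[] args) {
    ArraySyntaxTree tree =
        new ArraySyntaxTree("(S (NP (DT the) (NN dog)) (VP (VBZ barks)))");

    // terminal word ids, decoded back through the shared Vocabulary
    for (int id : tree.getTerminals())
      System.out.print(Vocabulary.word(id) + " "); // the dog barks
    System.out.println();

    // single non-terminals exactly covering span [0, 2), i.e. "the dog"
    for (int label : tree.getConstituentLabels(0, 2))
      System.out.println(Vocabulary.word(label)); // expect [NP]
  }
}
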
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ArgsParser.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ArgsParser.java
index 26ed674,0000000..97baa27
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ArgsParser.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ArgsParser.java
@@@ -1,116 -1,0 +1,116 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder;
+
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * @author orluke
- *
++ *
+ */
+public class ArgsParser {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ArgsParser.class);
+
+ private String configFile = null;
+
+ /**
+ * Parse the arguments passed on the command line when the JoshuaDecoder application was
+ * executed.
- *
++ *
+ * @param args string array of input arguments
+ * @param config the {@link org.apache.joshua.decoder.JoshuaConfiguration}
+ * @throws IOException if there is an error with the input arguments
+ */
+ public ArgsParser(String[] args, JoshuaConfiguration config) throws IOException {
+
+ /*
+ * Look for a verbose flag, -v.
- *
- * Look for an argument to the "-config" flag to find the config file, if any.
++ *
++ * Look for an argument to the "-config" flag to find the config file, if any.
+ */
+ if (args.length >= 1) {
+ // Search for a verbose flag
+ for (int i = 0; i < args.length; i++) {
+ if (args[i].equals("-v")) {
+ Decoder.VERBOSE = Integer.parseInt(args[i + 1].trim());
+ config.setVerbosity(Decoder.VERBOSE);
+ }
-
- if (args[i].equals("-version")) {
- LineReader reader = new LineReader(String.format("%s/VERSION", System.getenv("JOSHUA")));
- reader.readLine();
- String version = reader.readLine().split("\\s+")[2];
- System.out.println(String.format("The Apache Joshua machine translator, version %s", version));
- System.out.println("joshua.incubator.apache.org");
- System.exit(0);
+
++ if (args[i].equals("-version")) {
++ try (LineReader reader = new LineReader(String.format("%s/VERSION", System.getenv("JOSHUA")));) {
++ reader.readLine();
++ String version = reader.readLine().split("\\s+")[2];
++ System.out.println(String.format("The Apache Joshua machine translator, version %s", version));
++ System.out.println("joshua.incubator.apache.org");
++ System.exit(0);
++ }
+ } else if (args[i].equals("-license")) {
+ try {
+ Files.readAllLines(Paths.get(String.format("%s/../LICENSE",
+ JoshuaConfiguration.class.getProtectionDomain().getCodeSource().getLocation()
+ .getPath())), Charset.defaultCharset()).forEach(System.out::println);
+ } catch (IOException e) {
+ throw new RuntimeException("FATAL: missing license file!", e);
+ }
+ System.exit(0);
+ }
+ }
+
+ // Search for the configuration file from the end (so as to take the last one)
+ for (int i = args.length-1; i >= 0; i--) {
+ if (args[i].equals("-c") || args[i].equals("-config")) {
+
+ setConfigFile(args[i + 1].trim());
+ try {
+ LOG.info("Parameters read from configuration file: {}", getConfigFile());
+ config.readConfigFile(getConfigFile());
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ break;
+ }
+ }
+
+ // Now process all the command-line args
+ config.processCommandLineOptions(args);
+ }
+ }
+
+ /**
+ * @return the configFile
+ */
+ public String getConfigFile() {
+ return configFile;
+ }
+
+ /**
+ * @param configFile the configFile to set
+ */
+ public void setConfigFile(String configFile) {
+ this.configFile = configFile;
+ }
+}
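
An illustrative sketch (not part of the patch) of the "last one wins" scan ArgsParser uses for -c/-config above: walking the argument array backwards means the final occurrence on the command line takes effect. The class and method names here are hypothetical, and a bounds guard is added for a trailing bare flag.

public class LastConfigWins {
  static String findConfig(String[] args) {
    // search from the end so the last -c/-config occurrence wins
    for (int i = args.length - 1; i >= 0; i--) {
      if ((args[i].equals("-c") || args[i].equals("-config"))
          && i + 1 < args.length) { // guard against a trailing bare flag
        return args[i + 1].trim();
      }
    }
    return null; // no config file given
  }

  public static void main(String[] args) {
    String[] argv = {"-config", "first.cfg", "-v", "1", "-c", "last.cfg"};
    System.out.println(findConfig(argv)); // prints last.cfg
  }
}
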
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/Decoder.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/Decoder.java
index 9cfb6eb,0000000..3d6f3bc
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/Decoder.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/Decoder.java
@@@ -1,598 -1,0 +1,597 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder;
+
+import static org.apache.joshua.decoder.ff.FeatureMap.hashFeature;
+import static org.apache.joshua.decoder.ff.tm.OwnerMap.getOwner;
+import static org.apache.joshua.util.Constants.spaceSeparator;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
+
- import com.google.common.base.Strings;
- import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.ff.FeatureMap;
+import org.apache.joshua.decoder.ff.FeatureVector;
+import org.apache.joshua.decoder.ff.PhraseModel;
+import org.apache.joshua.decoder.ff.StatefulFF;
+import org.apache.joshua.decoder.ff.lm.LanguageModelFF;
+import org.apache.joshua.decoder.ff.tm.Grammar;
+import org.apache.joshua.decoder.ff.tm.OwnerId;
+import org.apache.joshua.decoder.ff.tm.OwnerMap;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.format.HieroFormatReader;
+import org.apache.joshua.decoder.ff.tm.hash_based.MemoryBasedBatchGrammar;
+import org.apache.joshua.decoder.ff.tm.packed.PackedGrammar;
+import org.apache.joshua.decoder.io.TranslationRequestStream;
+import org.apache.joshua.decoder.phrase.PhraseTable;
+import org.apache.joshua.decoder.segment_file.Sentence;
+import org.apache.joshua.util.FileUtility;
+import org.apache.joshua.util.FormatUtils;
+import org.apache.joshua.util.Regex;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
++import com.google.common.base.Strings;
++import com.google.common.util.concurrent.ThreadFactoryBuilder;
++
+/**
+ * This class handles decoder initialization and the complication introduced by multithreading.
+ *
+ * After initialization, the main entry point to the Decoder object is
+ * decodeAll(TranslationRequest), which returns a set of Translation objects wrapped in an iterable
+ * TranslationResponseStream object. It is important that we support multithreading both (a) across the sentences
+ * within a request and (b) across requests, in a round-robin fashion. This is done by maintaining a
+ * fixed sized concurrent thread pool. When a new request comes in, a RequestParallelizer thread is
+ * launched. This object iterates over the request's sentences, obtaining a thread from the
+ * thread pool, and using that thread to decode the sentence. If a decoding thread is not available,
+ * it will block until one is in a fair (FIFO) manner. RequestParallelizer thereby permits intra-request
+ * parallelization by separating out reading the input stream from processing the translated sentences,
+ * but also ensures that round-robin parallelization occurs, since RequestParallelizer uses the
+ * thread pool before translating each request.
+ *
+ * A decoding thread is handled by DecoderTask and launched from DecoderThreadRunner. The purpose
+ * of the runner is to record where to place the translated sentence when it is done (i.e., which
+ * TranslationResponseStream object). TranslationResponseStream itself is an iterator whose next() call blocks until the next
+ * translation is available.
+ *
+ * @author Matt Post post@cs.jhu.edu
+ * @author Zhifei Li, zhifei.work@gmail.com
+ * @author wren ng thornton wren@users.sourceforge.net
+ * @author Lane Schwartz dowobeha@users.sourceforge.net
+ * @author Kellen Sunderland kellen.sunderland@gmail.com
+ */
+public class Decoder {
+
+ private static final Logger LOG = LoggerFactory.getLogger(Decoder.class);
+
+ private final JoshuaConfiguration joshuaConfiguration;
+
+ public JoshuaConfiguration getJoshuaConfiguration() {
+ return joshuaConfiguration;
+ }
+
+ /*
+ * Many of these objects themselves are global objects. We pass them in when constructing other
+ * objects, so that they all share pointers to the same object. This is good because it reduces
+ * overhead, but it can be problematic because of unseen dependencies (for example, in the
+ * Vocabulary shared by language model, translation grammar, etc).
+ */
+ private final List<Grammar> grammars = new ArrayList<Grammar>();
+ private final ArrayList<FeatureFunction> featureFunctions = new ArrayList<>();
+ private Grammar customPhraseTable = null;
+
+ /* The feature weights. */
+ public static FeatureVector weights;
+
+ public static int VERBOSE = 1;
+
+ /**
+ * Constructor method that creates a new decoder using the specified configuration file.
+ *
+ * @param joshuaConfiguration a populated {@link org.apache.joshua.decoder.JoshuaConfiguration}
+ */
+ public Decoder(JoshuaConfiguration joshuaConfiguration) {
+ this.joshuaConfiguration = joshuaConfiguration;
+ this.initialize();
+ }
+
+ /**
+ * This function is the main entry point into the decoder. It translates all the sentences in a
+ * (possibly boundless) set of input sentences. Each request launches its own thread to read the
+ * sentences of the request.
+ *
+ * @param request the populated {@link TranslationRequestStream}
+ * @throws RuntimeException if any fatal errors occur during translation
+ * @return an iterable, asynchronously-filled list of TranslationResponseStream
+ */
+ public TranslationResponseStream decodeAll(TranslationRequestStream request) {
+ TranslationResponseStream results = new TranslationResponseStream(request);
+ CompletableFuture.runAsync(() -> decodeAllAsync(request, results));
+ return results;
+ }
+
+ private void decodeAllAsync(TranslationRequestStream request,
+ TranslationResponseStream responseStream) {
+
+ // Give the threadpool a friendly name to help debuggers
+ final ThreadFactory threadFactory = new ThreadFactoryBuilder()
+ .setNameFormat("TranslationWorker-%d")
+ .setDaemon(true)
+ .build();
+ ExecutorService executor = Executors.newFixedThreadPool(this.joshuaConfiguration.num_parallel_decoders,
+ threadFactory);
+ try {
+ for (; ; ) {
+ Sentence sentence = request.next();
+
+ if (sentence == null) {
+ break;
+ }
+
+ executor.execute(() -> {
+ try {
+ Translation result = decode(sentence);
+ responseStream.record(result);
+ } catch (Throwable ex) {
+ responseStream.propagate(ex);
+ }
+ });
+ }
+ responseStream.finish();
+ } finally {
+ executor.shutdown();
+ }
+ }
+
+
+ /**
+ * We can also just decode a single sentence in the same thread.
+ *
+ * @param sentence {@link org.apache.joshua.lattice.Lattice} input
+ * @throws RuntimeException if any fatal errors occur during translation
+ * @return the sentence {@link org.apache.joshua.decoder.Translation}
+ */
+ public Translation decode(Sentence sentence) {
+ DecoderTask decoderTask = new DecoderTask(this.grammars, this.featureFunctions, joshuaConfiguration);
+ return decoderTask.translate(sentence);
+ }
+
+ /**
+ * Clean shutdown of Decoder, resetting all
+ * static variables, such that any other instance of Decoder
+ * afterwards gets a fresh start.
+ */
+ public void cleanUp() {
+ resetGlobalState();
+ }
+
+ public static void resetGlobalState() {
+ // clear/reset static variables
+ OwnerMap.clear();
+ FeatureMap.clear();
+ Vocabulary.clear();
+ Vocabulary.unregisterLanguageModels();
+ LanguageModelFF.resetLmIndex();
+ StatefulFF.resetGlobalStateIndex();
+ }
+
+ public static void writeConfigFile(double[] newWeights, String template, String outputFile,
+ String newDiscriminativeModel) {
+ try {
+ int columnID = 0;
+
+ try (LineReader reader = new LineReader(template);
+ BufferedWriter writer = FileUtility.getWriteFileStream(outputFile)) {
+ for (String line : reader) {
+ line = line.trim();
+ if (Regex.commentOrEmptyLine.matches(line) || line.contains("=")) {
+ // comment, empty line, or parameter lines: just copy
+ writer.write(line);
+ writer.newLine();
+
+ } else { // models: replace the weight
+ String[] fds = Regex.spaces.split(line);
+ StringBuilder newSent = new StringBuilder();
+ if (!Regex.floatingNumber.matches(fds[fds.length - 1])) {
+ throw new IllegalArgumentException("last field is not a number; the field is: "
+ + fds[fds.length - 1]);
+ }
+
+ if (newDiscriminativeModel != null && "discriminative".equals(fds[0])) {
+ newSent.append(fds[0]).append(' ');
+ newSent.append(newDiscriminativeModel).append(' ');// change the
+ // file name
+ for (int i = 2; i < fds.length - 1; i++) {
+ newSent.append(fds[i]).append(' ');
+ }
+ } else {// regular
+ for (int i = 0; i < fds.length - 1; i++) {
+ newSent.append(fds[i]).append(' ');
+ }
+ }
+ if (newWeights != null)
+ newSent.append(newWeights[columnID++]);// change the weight
+ else
+ newSent.append(fds[fds.length - 1]);// do not change
+
+ writer.write(newSent.toString());
+ writer.newLine();
+ }
+ }
+ }
+
+ if (newWeights != null && columnID != newWeights.length) {
+ throw new IllegalArgumentException("number of models does not match number of weights");
+ }
+
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /**
+ * Initialize all parts of the JoshuaDecoder.
+ */
+ private void initialize() {
+ try {
+
+ long pre_load_time = System.currentTimeMillis();
+ resetGlobalState();
+
+ /* Weights can be listed in a separate file (denoted by parameter "weights-file") or directly
+ * in the Joshua config file. Config file values take precedence.
+ */
+ this.readWeights(joshuaConfiguration.weights_file);
-
-
++
++
+ /* Add command-line-passed weights to the weights array for processing below */
+ if (!Strings.isNullOrEmpty(joshuaConfiguration.weight_overwrite)) {
+ String[] tokens = joshuaConfiguration.weight_overwrite.split("\\s+");
+ for (int i = 0; i < tokens.length; i += 2) {
+ String feature = tokens[i];
+ float value = Float.parseFloat(tokens[i+1]);
+
+ if (joshuaConfiguration.moses)
+ feature = demoses(feature);
+
+ joshuaConfiguration.weights.add(String.format("%s %s", feature, tokens[i+1]));
+ LOG.info("COMMAND LINE WEIGHT: {} -> {}", feature, value);
+ }
+ }
+
+ /* Read the weights found in the config file */
+ for (String pairStr: joshuaConfiguration.weights) {
+ String pair[] = pairStr.split("\\s+");
+
+ /* Sanity check for old-style unsupported feature invocations. */
+ if (pair.length != 2) {
+ String errMsg = "FATAL: Invalid feature weight line found in config file.\n" +
+ String.format("The line was '%s'\n", pairStr) +
+ "You might be using an old version of the config file that is no longer supported\n" +
+ "Check joshua.apache.org or email dev@joshua.apache.org for help\n" +
+ "Code = " + 17;
+ throw new RuntimeException(errMsg);
+ }
+
+ weights.add(hashFeature(pair[0]), Float.parseFloat(pair[1]));
+ }
+
+ LOG.info("Read {} weights", weights.size());
+
+ // Do this before loading the grammars and the LM.
+ this.featureFunctions.clear();
+
+ // Initialize and load grammars. This must happen first, since the vocab gets defined by
+ // the packed grammar (if any)
+ this.initializeTranslationGrammars();
+ LOG.info("Grammar loading took: {} seconds.",
+ (System.currentTimeMillis() - pre_load_time) / 1000);
+
+ // Initialize the features: requires that LM model has been initialized.
+ this.initializeFeatureFunctions();
+
+ // This is mostly for compatibility with the Moses tuning script
+ if (joshuaConfiguration.show_weights_and_quit) {
+ for (Entry<Integer, Float> entry : weights.entrySet()) {
+ System.out.println(String.format("%s=%.5f", FeatureMap.getFeature(entry.getKey()), entry.getValue()));
+ }
+ // TODO (fhieber): this functionality should not be in main Decoder class and simply exit.
+ System.exit(0);
+ }
+
+ // Sort the TM grammars (needed to do cube pruning)
+ if (joshuaConfiguration.amortized_sorting) {
+ LOG.info("Grammar sorting happening lazily on-demand.");
+ } else {
+ long pre_sort_time = System.currentTimeMillis();
+ for (Grammar grammar : this.grammars) {
+ grammar.sortGrammar(this.featureFunctions);
+ }
+ LOG.info("Grammar sorting took {} seconds.",
+ (System.currentTimeMillis() - pre_sort_time) / 1000);
+ }
+
+ } catch (IOException e) {
+ LOG.warn(e.getMessage(), e);
+ }
+ }
+
+ /**
+ * Initializes translation grammars. Retained for backward compatibility.
+ *
+ * @throws IOException if reading any of the grammar elements from
+ * disk fails
+ */
+ private void initializeTranslationGrammars() throws IOException {
+
+ if (joshuaConfiguration.tms.size() > 0) {
+
+ // collect packedGrammars to check if they use a shared vocabulary
+ final List<PackedGrammar> packed_grammars = new ArrayList<>();
+
+ // tm = {thrax/hiero,packed,samt,moses} OWNER LIMIT FILE
+ for (String tmLine : joshuaConfiguration.tms) {
+
+ String type = tmLine.substring(0, tmLine.indexOf(' '));
+ String[] args = tmLine.substring(tmLine.indexOf(' ')).trim().split("\\s+");
+ HashMap<String, String> parsedArgs = FeatureFunction.parseArgs(args);
+
+ String owner = parsedArgs.get("owner");
+ int span_limit = Integer.parseInt(parsedArgs.get("maxspan"));
+ String path = parsedArgs.get("path");
+
+ Grammar grammar;
+ if (! type.equals("moses") && ! type.equals("phrase")) {
+ if (new File(path).isDirectory()) {
+ try {
+ PackedGrammar packed_grammar = new PackedGrammar(path, span_limit, owner, type, joshuaConfiguration);
+ packed_grammars.add(packed_grammar);
+ grammar = packed_grammar;
+ } catch (FileNotFoundException e) {
+ String msg = String.format("Couldn't load packed grammar from '%s'", path)
+ + "Perhaps it doesn't exist, or it may be an old packed file format.";
+ throw new RuntimeException(msg);
+ }
+ } else {
+ // thrax, hiero, samt
+ grammar = new MemoryBasedBatchGrammar(type, path, owner,
+ joshuaConfiguration.default_non_terminal, span_limit, joshuaConfiguration);
+ }
+
+ } else {
+
+ joshuaConfiguration.search_algorithm = "stack";
+ grammar = new PhraseTable(path, owner, type, joshuaConfiguration);
+ }
+
+ this.grammars.add(grammar);
+ }
+
+ checkSharedVocabularyChecksumsForPackedGrammars(packed_grammars);
+
+ } else {
+ LOG.warn("no grammars supplied! Supplying dummy glue grammar.");
+ MemoryBasedBatchGrammar glueGrammar = new MemoryBasedBatchGrammar("glue", joshuaConfiguration, -1);
+ glueGrammar.addGlueRules(featureFunctions);
+ this.grammars.add(glueGrammar);
+ }
-
++
+ /* Add the grammar for custom entries */
+ if (joshuaConfiguration.search_algorithm.equals("stack"))
+ this.customPhraseTable = new PhraseTable("custom", joshuaConfiguration);
+ else
+ this.customPhraseTable = new MemoryBasedBatchGrammar("custom", joshuaConfiguration, 20);
+ this.grammars.add(this.customPhraseTable);
-
++
+ /* Create an epsilon-deleting grammar */
+ if (joshuaConfiguration.lattice_decoding) {
+ LOG.info("Creating an epsilon-deleting grammar");
+ MemoryBasedBatchGrammar latticeGrammar = new MemoryBasedBatchGrammar("lattice", joshuaConfiguration, -1);
+ HieroFormatReader reader = new HieroFormatReader(OwnerMap.register("lattice"));
+
+ String goalNT = FormatUtils.cleanNonTerminal(joshuaConfiguration.goal_symbol);
+ String defaultNT = FormatUtils.cleanNonTerminal(joshuaConfiguration.default_non_terminal);
+
+ // FIXME: arguments changed to match the string format on a best-effort basis. Author, please review.
+ String ruleString = String.format("[%s] ||| [%s,1] <eps> ||| [%s,1] ||| ", goalNT, defaultNT, defaultNT);
+
+ Rule rule = reader.parseLine(ruleString);
+ latticeGrammar.addRule(rule);
+ rule.estimateRuleCost(featureFunctions);
+
+ this.grammars.add(latticeGrammar);
+ }
+
+ /* Now create a feature function for each owner */
+ final Set<OwnerId> ownersSeen = new HashSet<>();
+
+ for (Grammar grammar: this.grammars) {
+ OwnerId owner = grammar.getOwner();
+ if (! ownersSeen.contains(owner)) {
+ this.featureFunctions.add(
+ new PhraseModel(
+ weights, new String[] { "tm", "-owner", getOwner(owner) }, joshuaConfiguration, grammar));
+ ownersSeen.add(owner);
+ }
+ }
+
+ LOG.info("Memory used {} MB",
+ ((Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / 1000000.0));
+ }
+
+ /**
+ * Checks if multiple packedGrammars have the same vocabulary by comparing their vocabulary file checksums.
+ */
+ private static void checkSharedVocabularyChecksumsForPackedGrammars(final List<PackedGrammar> packed_grammars) {
+ String previous_checksum = "";
+ for (PackedGrammar grammar : packed_grammars) {
+ final String checksum = grammar.computeVocabularyChecksum();
+ if (previous_checksum.isEmpty()) {
+ previous_checksum = checksum;
+ } else {
+ if (!checksum.equals(previous_checksum)) {
+ throw new RuntimeException(
+ "Trying to load multiple packed grammars with different vocabularies!" +
+ "Have you packed them jointly?");
+ }
+ previous_checksum = checksum;
+ }
+ }
+ }
+
+ /*
+ * This function reads the weights for the model. Feature names and their weights are listed one
+ * per line in the following format:
- *
++ *
+ * FEATURE_NAME WEIGHT
+ */
+ private void readWeights(String fileName) {
+ Decoder.weights = new FeatureVector(5);
+
+ if (fileName.equals(""))
+ return;
+
- try {
- LineReader lineReader = new LineReader(fileName);
-
++ try (LineReader lineReader = new LineReader(fileName)) {
+ for (String line : lineReader) {
+ line = line.replaceAll(spaceSeparator, " ");
+
+ if (line.equals("") || line.startsWith("#") || line.startsWith("//")
+ || line.indexOf(' ') == -1)
+ continue;
+
+ String tokens[] = line.split(spaceSeparator);
+ String feature = tokens[0];
+ Float value = Float.parseFloat(tokens[1]);
+
+ // Kludge for compatibility with Moses tuners
+ if (joshuaConfiguration.moses) {
+ feature = demoses(feature);
+ }
+
+ weights.add(hashFeature(feature), value);
+ }
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
+ LOG.info("Read {} weights from file '{}'", weights.size(), fileName);
+ }
+
+ private String demoses(String feature) {
+ if (feature.endsWith("="))
+ feature = feature.replace("=", "");
+ if (feature.equals("OOV_Penalty"))
+ feature = "OOVPenalty";
+ else if (feature.startsWith("tm-") || feature.startsWith("lm-"))
+ feature = feature.replace("-", "_");
+ return feature;
+ }
+
+ /**
+ * Feature functions are instantiated with a line of the form
+ *
+ * <pre>
+ * FEATURE OPTIONS
+ * </pre>
+ *
+ * Weights for features are listed separately.
+ *
+ */
+ private void initializeFeatureFunctions() {
+
+ for (String featureLine : joshuaConfiguration.features) {
+ // line starts with NAME, followed by args
+ // 1. create new class named NAME, pass it config, weights, and the args
+
+ String fields[] = featureLine.split("\\s+");
+ String featureName = fields[0];
-
++
+ try {
-
++
+ Class<?> clas = getFeatureFunctionClass(featureName);
+ Constructor<?> constructor = clas.getConstructor(FeatureVector.class,
+ String[].class, JoshuaConfiguration.class);
+ FeatureFunction feature = (FeatureFunction) constructor.newInstance(weights, fields, joshuaConfiguration);
+ this.featureFunctions.add(feature);
-
++
+ } catch (Exception e) {
- throw new RuntimeException(String.format("Unable to instantiate feature function '%s'!", featureLine), e);
++ throw new RuntimeException(String.format("Unable to instantiate feature function '%s'!", featureLine), e);
+ }
+ }
+
+ for (FeatureFunction feature : featureFunctions) {
+ LOG.info("FEATURE: {}", feature.logString());
+ }
+ }
+
+ /**
+ * Searches a list of predefined paths for classes, and returns the first one found. Meant for
+ * instantiating feature functions.
+ *
+ * @param featureName Class name of the feature to return.
+ * @return the class if found on one of the search paths, or null if no class was found
+ */
+ private Class<?> getFeatureFunctionClass(String featureName) {
+ Class<?> clas = null;
+
+ String[] packages = { "org.apache.joshua.decoder.ff", "org.apache.joshua.decoder.ff.lm", "org.apache.joshua.decoder.ff.phrase" };
+ for (String path : packages) {
+ try {
+ clas = Class.forName(String.format("%s.%s", path, featureName));
+ break;
+ } catch (ClassNotFoundException e) {
+ try {
+ clas = Class.forName(String.format("%s.%sFF", path, featureName));
+ break;
+ } catch (ClassNotFoundException e2) {
+ // not found with the FF suffix either; keep searching the remaining packages
+ }
+ }
+ }
+ return clas;
+ }
-
++
+ /**
- * Adds a rule to the custom grammar.
- *
++ * Adds a rule to the custom grammar.
++ *
+ * @param rule the rule to add
+ */
+ public void addCustomRule(Rule rule) {
+ customPhraseTable.addRule(rule);
+ rule.estimateRuleCost(featureFunctions);
+ }
+
+ public Grammar getCustomPhraseTable() {
+ return customPhraseTable;
+ }
+}
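[Editor's aside on the weight formats handled above.] readWeights() expects a plain-text weights file with one FEATURE_NAME WEIGHT pair per line; blank lines and lines starting with '#' or '//' are skipped. A minimal sketch of such a file (the feature names here are hypothetical, not taken from this commit):

    # dense feature weights (hypothetical names)
    lm_0 0.5
    tm_pt_0 -0.25
    OOVPenalty -100.0

With joshuaConfiguration.moses set, Moses-style names such as 'OOV_Penalty=' are first normalized by demoses() (trailing '=' stripped, 'OOV_Penalty' mapped to 'OOVPenalty', 'tm-'/'lm-' dashes turned into underscores) before being hashed into the weight vector. Feature-function lines in the config follow the analogous "FEATURE OPTIONS" shape described in the javadoc of initializeFeatureFunctions() above, with weights listed separately.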
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
index f25590c,0000000..2ac5269
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
@@@ -1,148 -1,0 +1,147 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder;
+
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.net.InetSocketAddress;
+
- import com.sun.net.httpserver.HttpServer;
-
+import org.apache.joshua.decoder.JoshuaConfiguration.SERVER_TYPE;
+import org.apache.joshua.decoder.io.TranslationRequestStream;
++import org.apache.joshua.server.ServerThread;
+import org.apache.joshua.server.TcpServer;
+import org.apache.log4j.Level;
+import org.apache.log4j.LogManager;
- import org.apache.joshua.server.ServerThread;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
++import com.sun.net.httpserver.HttpServer;
++
+/**
+ * Implements decoder initialization, including interaction with <code>JoshuaConfiguration</code>
+ * and <code>DecoderTask</code>.
- *
++ *
+ * @author Zhifei Li, zhifei.work@gmail.com
+ * @author wren ng thornton wren@users.sourceforge.net
+ * @author Lane Schwartz dowobeha@users.sourceforge.net
+ */
+public class JoshuaDecoder {
+
+ private static final Logger LOG = LoggerFactory.getLogger(JoshuaDecoder.class);
+
+ // ===============================================================
+ // Main
+ // ===============================================================
+ public static void main(String[] args) throws IOException {
+
+ // default log level
+ LogManager.getRootLogger().setLevel(Level.INFO);
+
+ JoshuaConfiguration joshuaConfiguration = new JoshuaConfiguration();
+ ArgsParser userArgs = new ArgsParser(args,joshuaConfiguration);
+
+ long startTime = System.currentTimeMillis();
+
+ /* Step-0: some sanity checking */
+ joshuaConfiguration.sanityCheck();
+
+ /* Step-1: initialize the decoder, test-set independent */
+ Decoder decoder = new Decoder(joshuaConfiguration);
+
+ LOG.info("Model loading took {} seconds", (System.currentTimeMillis() - startTime) / 1000);
+ LOG.info("Memory used {} MB", ((Runtime.getRuntime().totalMemory()
+ - Runtime.getRuntime().freeMemory()) / 1000000.0));
+
+ /* Step-2: Decoding */
+ // create a server if requested, which will create TranslationRequest objects
+ if (joshuaConfiguration.server_port > 0) {
+ int port = joshuaConfiguration.server_port;
+ if (joshuaConfiguration.server_type == SERVER_TYPE.TCP) {
+ new TcpServer(decoder, port, joshuaConfiguration).start();
+
+ } else if (joshuaConfiguration.server_type == SERVER_TYPE.HTTP) {
+ joshuaConfiguration.use_structured_output = true;
-
++
+ HttpServer server = HttpServer.create(new InetSocketAddress(port), 0);
+ LOG.info("HTTP Server running and listening on port {}.", port);
+ server.createContext("/", new ServerThread(null, decoder, joshuaConfiguration));
+ server.setExecutor(null); // creates a default executor
+ server.start();
+ } else {
+ LOG.error("Unknown server type");
+ System.exit(1);
+ }
+ return;
+ }
-
++
+ // Create a TranslationRequest object, reading from a file if requested, or from STDIN
- InputStream input = (joshuaConfiguration.input_file != null)
++ InputStream input = (joshuaConfiguration.input_file != null)
+ ? new FileInputStream(joshuaConfiguration.input_file)
+ : System.in;
+
+ BufferedReader reader = new BufferedReader(new InputStreamReader(input));
+ TranslationRequestStream fileRequest = new TranslationRequestStream(reader, joshuaConfiguration);
+ TranslationResponseStream translationResponseStream = decoder.decodeAll(fileRequest);
-
++
+ // Create the n-best output stream
+ FileWriter nbest_out = null;
+ if (joshuaConfiguration.n_best_file != null)
+ nbest_out = new FileWriter(joshuaConfiguration.n_best_file);
+
+ for (Translation translation: translationResponseStream) {
-
+ /**
+ * We need to munge the feature value outputs in order to be compatible with Moses tuners.
+ * Whereas Joshua writes to STDOUT whatever is specified in the `output-format` parameter,
+ * Moses expects the simple translation on STDOUT and the n-best list in a file with a fixed
+ * format.
+ */
+ if (joshuaConfiguration.moses) {
+ String text = translation.toString().replaceAll("=", "= ");
+ // Write the complete formatted string to STDOUT
+ if (joshuaConfiguration.n_best_file != null)
+ nbest_out.write(text);
+
+ // Extract just the translation and output that to STDOUT
+ text = text.substring(0, text.indexOf('\n'));
+ String[] fields = text.split(" \\|\\|\\| ");
+ text = fields[1];
+
+ System.out.println(text);
+
+ } else {
+ System.out.print(translation.toString());
+ }
+ }
+
+ if (joshuaConfiguration.n_best_file != null)
+ nbest_out.close();
+
+ LOG.info("Decoding completed.");
+ LOG.info("Memory used {} MB", ((Runtime.getRuntime().totalMemory()
+ - Runtime.getRuntime().freeMemory()) / 1000000.0));
+
+ /* Step-3: clean up */
+ decoder.cleanUp();
+ LOG.info("Total running time: {} seconds", (System.currentTimeMillis() - startTime) / 1000);
+ }
+}
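[Editor's aside.] To make the Moses-compatibility munging in the loop above concrete, here is a minimal, self-contained sketch. It assumes the usual 'id ||| hypothesis ||| features ||| score' n-best layout; the sample line and the class name are hypothetical, and in Joshua the string actually comes from Translation.toString():

    // Sketch of the Moses-compatibility munging done in JoshuaDecoder.main().
    public class MosesOutputSketch {
      public static void main(String[] args) {
        // Hypothetical n-best line: id ||| hypothesis ||| features ||| score
        String text = "0 ||| ein kleines haus ||| lm_0=-4.2 tm_pt_0=1.1 ||| -3.7\n";
        // Moses tuners expect a space after each '=' in the feature list
        text = text.replaceAll("=", "= ");
        // Keep only the first line, then pull out just the hypothesis field
        String firstLine = text.substring(0, text.indexOf('\n'));
        String[] fields = firstLine.split(" \\|\\|\\| ");
        System.out.println(fields[1]); // prints: ein kleines haus
      }
    }

The complete munged string goes to the n-best file, while only fields[1] (the bare translation) is echoed to STDOUT, which is exactly the split the code above performs.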
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
index 6453bd1,0000000..544e16b
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
@@@ -1,116 -1,0 +1,114 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.emptyList;
+import static java.util.Collections.emptyMap;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiFeatures;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiString;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiWordAlignmentList;
+import static org.apache.joshua.util.FormatUtils.removeSentenceMarkers;
+
+import java.util.List;
+
+import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.hypergraph.HyperGraph;
+import org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationState;
+import org.apache.joshua.decoder.segment_file.Sentence;
- import org.apache.joshua.decoder.segment_file.Token;
- import org.apache.joshua.util.FormatUtils;
+
+/**
+ * This factory provides methods to create StructuredTranslation objects
+ * from either Viterbi derivations or KBest derivations.
- *
++ *
+ * @author fhieber
+ */
+public class StructuredTranslationFactory {
-
++
+ /**
+ * Returns a StructuredTranslation instance from the Viterbi derivation.
- *
++ *
+ * @param sourceSentence the source sentence
+ * @param hypergraph the hypergraph object
+ * @param featureFunctions the list of active feature functions
+ * @return A StructuredTranslation object representing the Viterbi derivation.
+ */
+ public static StructuredTranslation fromViterbiDerivation(
+ final Sentence sourceSentence,
+ final HyperGraph hypergraph,
+ final List<FeatureFunction> featureFunctions) {
+ final long startTime = System.currentTimeMillis();
+ final String translationString = removeSentenceMarkers(getViterbiString(hypergraph));
+ return new StructuredTranslation(
+ sourceSentence,
+ translationString,
+ extractTranslationTokens(translationString),
+ extractTranslationScore(hypergraph),
+ getViterbiWordAlignmentList(hypergraph),
+ getViterbiFeatures(hypergraph, featureFunctions, sourceSentence).toStringMap(),
+ (System.currentTimeMillis() - startTime) / 1000.0f);
+ }
-
++
+ /**
+ * Returns a StructuredTranslation from an empty decoder output
+ * @param sourceSentence the source sentence
+ * @return a StructuredTranslation object
+ */
+ public static StructuredTranslation fromEmptyOutput(final Sentence sourceSentence) {
+ return new StructuredTranslation(
+ sourceSentence, "", emptyList(), 0, emptyList(), emptyMap(), 0f);
+ }
-
++
+ /**
- * Returns a StructuredTranslation instance from a KBest DerivationState.
++ * Returns a StructuredTranslation instance from a KBest DerivationState.
+ * @param sourceSentence Sentence object representing the source.
+ * @param derivationState the KBest DerivationState.
+ * @return A StructuredTranslation object representing the derivation encoded by derivationState.
+ */
+ public static StructuredTranslation fromKBestDerivation(
+ final Sentence sourceSentence,
+ final DerivationState derivationState) {
+ final long startTime = System.currentTimeMillis();
+ final String translationString = removeSentenceMarkers(derivationState.getHypothesis());
+ return new StructuredTranslation(
+ sourceSentence,
+ translationString,
+ extractTranslationTokens(translationString),
+ derivationState.getModelCost(),
+ derivationState.getWordAlignmentList(),
+ derivationState.getFeatures().toStringMap(),
+ (System.currentTimeMillis() - startTime) / 1000.0f);
+ }
-
++
+ private static float extractTranslationScore(final HyperGraph hypergraph) {
+ if (hypergraph == null) {
+ return 0;
+ } else {
+ return hypergraph.goalNode.getScore();
+ }
+ }
-
++
+ private static List<String> extractTranslationTokens(final String translationString) {
+ if (translationString.isEmpty()) {
+ return emptyList();
+ } else {
+ return asList(translationString.split("\\s+"));
+ }
+ }
+}
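[Editor's aside.] One detail worth calling out in extractTranslationTokens() above: in Java, "".split("\\s+") returns a one-element array holding the empty string, so the isEmpty() guard is what keeps an empty translation from yielding a spurious single token. A standalone sketch of that behavior (class name hypothetical):

    import static java.util.Arrays.asList;
    import static java.util.Collections.emptyList;
    import java.util.List;

    // Demonstrates why extractTranslationTokens() guards against the empty string.
    public class TokenSplitSketch {
      static List<String> tokens(String s) {
        return s.isEmpty() ? emptyList() : asList(s.split("\\s+"));
      }
      public static void main(String[] args) {
        System.out.println("".split("\\s+").length); // 1 -- a single empty token!
        System.out.println(tokens(""));              // [] -- guarded version
        System.out.println(tokens("a small test"));  // [a, small, test]
      }
    }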