You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/06/23 18:45:56 UTC
[45/60] [partial] incubator-joshua git commit: maven multi-module
layout 1st commit: moving files into joshua-core
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java b/joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
new file mode 100755
index 0000000..722c593
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
@@ -0,0 +1,728 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.adagrad;
+
+import java.util.Collections;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.Vector;
+import java.lang.Math;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.metrics.EvaluationMetric;
+
+// this class implements the AdaGrad algorithm
+public class Optimizer {
+ public Optimizer(Vector<String>_output, boolean[] _isOptimizable, double[] _initialLambda,
+ HashMap<String, String>[] _feat_hash, HashMap<String, String>[] _stats_hash) {
+ output = _output; // (not used for now)
+ isOptimizable = _isOptimizable;
+ initialLambda = _initialLambda; // initial weights array
+ paramDim = initialLambda.length - 1;
+ initialLambda = _initialLambda;
+ feat_hash = _feat_hash; // feature hash table
+ stats_hash = _stats_hash; // suff. stats hash table
+ finalLambda = new double[initialLambda.length];
+ for(int i = 0; i < finalLambda.length; i++)
+ finalLambda[i] = initialLambda[i];
+ }
+
+ //run AdaGrad for one epoch
+ public double[] runOptimizer() {
+ List<Integer> sents = new ArrayList<Integer>();
+ for( int i = 0; i < sentNum; ++i )
+ sents.add(i);
+ double[] avgLambda = new double[initialLambda.length]; //only needed if averaging is required
+ for( int i = 0; i < initialLambda.length; ++i )
+ avgLambda[i] = 0;
+ for ( int iter = 0; iter < adagradIter; ++iter ) {
+ System.arraycopy(finalLambda, 1, initialLambda, 1, paramDim);
+ if(needShuffle)
+ Collections.shuffle(sents);
+
+ double oraMetric, oraScore, predMetric, predScore;
+ double[] oraPredScore = new double[4];
+ double loss = 0;
+ double diff = 0;
+ double sumMetricScore = 0;
+ double sumModelScore = 0;
+ String oraFeat = "";
+ String predFeat = "";
+ String[] oraPredFeat = new String[2];
+ String[] vecOraFeat;
+ String[] vecPredFeat;
+ String[] featInfo;
+ int thisBatchSize = 0;
+ int numBatch = 0;
+ int numUpdate = 0;
+ Iterator it;
+ Integer diffFeatId;
+
+ //update weights
+ Integer s;
+ int sentCount = 0;
+ double prevLambda = 0;
+ double diffFeatVal = 0;
+ double oldVal = 0;
+ double gdStep = 0;
+ double Hii = 0;
+ double gradiiSquare = 0;
+ int lastUpdateTime = 0;
+ HashMap<Integer, Integer> lastUpdate = new HashMap<Integer, Integer>();
+ HashMap<Integer, Double> lastVal = new HashMap<Integer, Double>();
+ HashMap<Integer, Double> H = new HashMap<Integer, Double>();
+ while( sentCount < sentNum ) {
+ loss = 0;
+ thisBatchSize = batchSize;
+ ++numBatch;
+ HashMap<Integer, Double> featDiff = new HashMap<Integer, Double>();
+ for(int b = 0; b < batchSize; ++b ) {
+ //find out oracle and prediction
+ s = sents.get(sentCount);
+ findOraPred(s, oraPredScore, oraPredFeat, finalLambda, featScale);
+
+ //the model scores here are already scaled in findOraPred
+ oraMetric = oraPredScore[0];
+ oraScore = oraPredScore[1];
+ predMetric = oraPredScore[2];
+ predScore = oraPredScore[3];
+ oraFeat = oraPredFeat[0];
+ predFeat = oraPredFeat[1];
+
+ //update the scale
+ if(needScale) { //otherwise featscale remains 1.0
+ sumMetricScore += Math.abs(oraMetric + predMetric);
+ //restore the original model score
+ sumModelScore += Math.abs(oraScore + predScore) / featScale;
+
+ if(sumModelScore/sumMetricScore > scoreRatio)
+ featScale = sumMetricScore/sumModelScore;
+ }
+ // processedSent++;
+
+ vecOraFeat = oraFeat.split("\\s+");
+ vecPredFeat = predFeat.split("\\s+");
+
+ //accumulate difference feature vector
+ if ( b == 0 ) {
+ for (int i = 0; i < vecOraFeat.length; i++) {
+ featInfo = vecOraFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ featDiff.put(diffFeatId, Double.parseDouble(featInfo[1]));
+ }
+ for (int i = 0; i < vecPredFeat.length; i++) {
+ featInfo = vecPredFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId)-Double.parseDouble(featInfo[1]);
+ if ( Math.abs(diff) > 1e-20 )
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ }
+ else //features only firing in the 2nd feature vector
+ featDiff.put(diffFeatId, -1.0*Double.parseDouble(featInfo[1]));
+ }
+ } else {
+ for (int i = 0; i < vecOraFeat.length; i++) {
+ featInfo = vecOraFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId)+Double.parseDouble(featInfo[1]);
+ if ( Math.abs(diff) > 1e-20 )
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ }
+ else //features only firing in the new oracle feature vector
+ featDiff.put(diffFeatId, Double.parseDouble(featInfo[1]));
+ }
+ for (int i = 0; i < vecPredFeat.length; i++) {
+ featInfo = vecPredFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId)-Double.parseDouble(featInfo[1]);
+ if ( Math.abs(diff) > 1e-20 )
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ }
+ else //features only firing in the new prediction feature vector
+ featDiff.put(diffFeatId, -1.0*Double.parseDouble(featInfo[1]));
+ }
+ }
+
+ //remember the model scores here are already scaled
+ double singleLoss = evalMetric.getToBeMinimized() ?
+ (predMetric-oraMetric) - (oraScore-predScore)/featScale:
+ (oraMetric-predMetric) - (oraScore-predScore)/featScale;
+ if(singleLoss > 0)
+ loss += singleLoss;
+ ++sentCount;
+ if( sentCount >= sentNum ) {
+ thisBatchSize = b + 1;
+ break;
+ }
+ } //for(int b : batchSize)
+
+ //System.out.println("\n\n"+sentCount+":");
+
+ if( loss > 0 ) {
+ //if(true) {
+ ++numUpdate;
+ //update weights (see Duchi'11, Eq.23. For l1-reg, use lazy update)
+ Set<Integer> diffFeatSet = featDiff.keySet();
+ it = diffFeatSet.iterator();
+ while(it.hasNext()) { //note these are all non-zero gradients!
+ diffFeatId = (Integer)it.next();
+ diffFeatVal = -1.0 * featDiff.get(diffFeatId); //gradient
+ if( regularization > 0 ) {
+ lastUpdateTime =
+ lastUpdate.get(diffFeatId) == null ? 0 : lastUpdate.get(diffFeatId);
+ if( lastUpdateTime < numUpdate - 1 ) {
+ //haven't been updated (gradient=0) for at least 2 steps
+ //lazy compute prevLambda now
+ oldVal =
+ lastVal.get(diffFeatId) == null ? initialLambda[diffFeatId] : lastVal.get(diffFeatId);
+ Hii =
+ H.get(diffFeatId) == null ? 0 : H.get(diffFeatId);
+ if(Math.abs(Hii) > 1e-20) {
+ if( regularization == 1 )
+ prevLambda =
+ Math.signum(oldVal) * clip( Math.abs(oldVal) - lam * eta * (numBatch - 1 - lastUpdateTime) / Hii );
+ else if( regularization == 2 ) {
+ prevLambda =
+ Math.pow( Hii/(lam+Hii), (numUpdate - 1 - lastUpdateTime) ) * oldVal;
+ if(needAvg) { //fill the gap due to lazy update
+ double prevLambdaCopy = prevLambda;
+ double scale = Hii/(lam+Hii);
+ for( int t = 0; t < numUpdate - 1 - lastUpdateTime; ++t ) {
+ avgLambda[diffFeatId] += prevLambdaCopy;
+ prevLambdaCopy /= scale;
+ }
+ }
+ }
+ } else {
+ if( regularization == 1 )
+ prevLambda = 0;
+ else if( regularization == 2 )
+ prevLambda = oldVal;
+ }
+ } else //just updated at last time step or just started
+ prevLambda = finalLambda[diffFeatId];
+ if(H.get(diffFeatId) != null) {
+ gradiiSquare = H.get(diffFeatId);
+ gradiiSquare *= gradiiSquare;
+ gradiiSquare += diffFeatVal * diffFeatVal;
+ Hii = Math.sqrt(gradiiSquare);
+ } else
+ Hii = Math.abs(diffFeatVal);
+ H.put(diffFeatId, Hii);
+ //update the weight
+ if( regularization == 1 ) {
+ gdStep = prevLambda - eta * diffFeatVal / Hii;
+ finalLambda[diffFeatId] = Math.signum(gdStep) * clip( Math.abs(gdStep) - lam * eta / Hii );
+ } else if(regularization == 2 ) {
+ finalLambda[diffFeatId] = (Hii * prevLambda - eta * diffFeatVal) / (lam + Hii);
+ if(needAvg)
+ avgLambda[diffFeatId] += finalLambda[diffFeatId];
+ }
+ lastUpdate.put(diffFeatId, numUpdate);
+ lastVal.put(diffFeatId, finalLambda[diffFeatId]);
+ } else { //if no regularization
+ if(H.get(diffFeatId) != null) {
+ gradiiSquare = H.get(diffFeatId);
+ gradiiSquare *= gradiiSquare;
+ gradiiSquare += diffFeatVal * diffFeatVal;
+ Hii = Math.sqrt(gradiiSquare);
+ } else
+ Hii = Math.abs(diffFeatVal);
+ H.put(diffFeatId, Hii);
+ finalLambda[diffFeatId] = finalLambda[diffFeatId] - eta * diffFeatVal / Hii;
+ if(needAvg)
+ avgLambda[diffFeatId] += finalLambda[diffFeatId];
+ }
+ } //while(it.hasNext())
+ } //if(loss > 0)
+ else { //no loss, therefore the weight update is skipped
+ //however, the avg weights still need to be accumulated
+ if( regularization == 0 ) {
+ for( int i = 1; i < finalLambda.length; ++i )
+ avgLambda[i] += finalLambda[i];
+ } else if( regularization == 2 ) {
+ if(needAvg) {
+ //due to lazy update, we need to figure out the actual
+ //weight vector at this point first...
+ for( int i = 1; i < finalLambda.length; ++i ) {
+ if( lastUpdate.get(i) != null ) {
+ if( lastUpdate.get(i) < numUpdate ) {
+ oldVal = lastVal.get(i);
+ Hii = H.get(i);
+ //lazy compute
+ avgLambda[i] +=
+ Math.pow( Hii/(lam+Hii), (numUpdate - lastUpdate.get(i)) ) * oldVal;
+ } else
+ avgLambda[i] += finalLambda[i];
+ }
+ avgLambda[i] += finalLambda[i];
+ }
+ }
+ }
+ }
+ } //while( sentCount < sentNum )
+ if( regularization > 0 ) {
+ for( int i = 1; i < finalLambda.length; ++i ) {
+ //now lazy compute those weights that haven't been taken care of
+ if( lastUpdate.get(i) == null )
+ finalLambda[i] = 0;
+ else if( lastUpdate.get(i) < numUpdate ) {
+ oldVal = lastVal.get(i);
+ Hii = H.get(i);
+ if( regularization == 1 )
+ finalLambda[i] =
+ Math.signum(oldVal) * clip( Math.abs(oldVal) - lam * eta * (numUpdate - lastUpdate.get(i)) / Hii );
+ else if( regularization == 2 ) {
+ finalLambda[i] =
+ Math.pow( Hii/(lam+Hii), (numUpdate - lastUpdate.get(i)) ) * oldVal;
+ if(needAvg) { //fill the gap due to lazy update
+ double prevLambdaCopy = finalLambda[i];
+ double scale = Hii/(lam+Hii);
+ for( int t = 0; t < numUpdate - lastUpdate.get(i); ++t ) {
+ avgLambda[i] += prevLambdaCopy;
+ prevLambdaCopy /= scale;
+ }
+ }
+ }
+ }
+ if( regularization == 2 && needAvg ) {
+ if( iter == adagradIter - 1 )
+ finalLambda[i] = avgLambda[i] / ( numBatch * adagradIter );
+ }
+ }
+ } else { //if no regularization
+ if( iter == adagradIter - 1 && needAvg ) {
+ for( int i = 1; i < finalLambda.length; ++i )
+ finalLambda[i] = avgLambda[i] / ( numBatch * adagradIter );
+ }
+ }
+
+ double initMetricScore;
+ if (iter == 0) {
+ initMetricScore = computeCorpusMetricScore(initialLambda);
+ finalMetricScore = computeCorpusMetricScore(finalLambda);
+ } else {
+ initMetricScore = finalMetricScore;
+ finalMetricScore = computeCorpusMetricScore(finalLambda);
+ }
+ // prepare the printing info
+ String result = " Initial "
+ + evalMetric.get_metricName() + "=" + String.format("%.4f", initMetricScore) + " Final "
+ + evalMetric.get_metricName() + "=" + String.format("%.4f", finalMetricScore);
+ //print lambda info
+ // int numParamToPrint = 0;
+ // numParamToPrint = paramDim > 10 ? 10 : paramDim; // how many parameters
+ // // to print
+ // result = paramDim > 10 ? "Final lambda (first 10): {" : "Final lambda: {";
+
+ // for (int i = 1; i <= numParamToPrint; ++i)
+ // result += String.format("%.4f", finalLambda[i]) + " ";
+
+ output.add(result);
+ } //for ( int iter = 0; iter < adagradIter; ++iter ) {
+
+ //non-optimizable weights should remain unchanged
+ ArrayList<Double> cpFixWt = new ArrayList<Double>();
+ for ( int i = 1; i < isOptimizable.length; ++i ) {
+ if ( ! isOptimizable[i] )
+ cpFixWt.add(finalLambda[i]);
+ }
+ normalizeLambda(finalLambda);
+ int countNonOpt = 0;
+ for ( int i = 1; i < isOptimizable.length; ++i ) {
+ if ( ! isOptimizable[i] ) {
+ finalLambda[i] = cpFixWt.get(countNonOpt);
+ ++countNonOpt;
+ }
+ }
+ return finalLambda;
+ }
+
+ private double clip(double x) {
+ return x > 0 ? x : 0;
+ }
+
+ public double computeCorpusMetricScore(double[] finalLambda) {
+ int suffStatsCount = evalMetric.get_suffStatsCount();
+ double modelScore;
+ double maxModelScore;
+ Set<String> candSet;
+ String candStr;
+ String[] feat_str;
+ String[] tmpStatsVal = new String[suffStatsCount];
+ int[] corpusStatsVal = new int[suffStatsCount];
+ for (int i = 0; i < suffStatsCount; i++)
+ corpusStatsVal[i] = 0;
+
+ for (int i = 0; i < sentNum; i++) {
+ candSet = feat_hash[i].keySet();
+
+ // find out the 1-best candidate for each sentence
+ // this depends on the training mode
+ maxModelScore = NegInf;
+ for (Iterator it = candSet.iterator(); it.hasNext();) {
+ modelScore = 0.0;
+ candStr = it.next().toString();
+
+ feat_str = feat_hash[i].get(candStr).split("\\s+");
+
+ String[] feat_info;
+
+ for (int f = 0; f < feat_str.length; f++) {
+ feat_info = feat_str[f].split("=");
+ modelScore +=
+ Double.parseDouble(feat_info[1]) * finalLambda[Vocabulary.id(feat_info[0])];
+ }
+
+ if (maxModelScore < modelScore) {
+ maxModelScore = modelScore;
+ tmpStatsVal = stats_hash[i].get(candStr).split("\\s+"); // save the
+ // suff stats
+ }
+ }
+
+ for (int j = 0; j < suffStatsCount; j++)
+ corpusStatsVal[j] += Integer.parseInt(tmpStatsVal[j]); // accumulate
+ // corpus-leve
+ // suff stats
+ } // for( int i=0; i<sentNum; i++ )
+
+ return evalMetric.score(corpusStatsVal);
+ }
+
+ private void findOraPred(int sentId, double[] oraPredScore, String[] oraPredFeat, double[] lambda, double featScale)
+ {
+ double oraMetric=0, oraScore=0, predMetric=0, predScore=0;
+ String oraFeat="", predFeat="";
+ double candMetric = 0, candScore = 0; //metric and model scores for each cand
+ Set<String> candSet = stats_hash[sentId].keySet();
+ String cand = "";
+ String feats = "";
+ String oraCand = ""; //only used when BLEU/TER-BLEU is used as metric
+ String[] featStr;
+ String[] featInfo;
+
+ int actualFeatId;
+ double bestOraScore;
+ double worstPredScore;
+
+ if(oraSelectMode==1)
+ bestOraScore = NegInf; //larger score will be selected
+ else {
+ if(evalMetric.getToBeMinimized())
+ bestOraScore = PosInf; //smaller score will be selected
+ else
+ bestOraScore = NegInf;
+ }
+
+ if(predSelectMode==1 || predSelectMode==2)
+ worstPredScore = NegInf; //larger score will be selected
+ else {
+ if(evalMetric.getToBeMinimized())
+ worstPredScore = NegInf; //larger score will be selected
+ else
+ worstPredScore = PosInf;
+ }
+
+ for (Iterator it = candSet.iterator(); it.hasNext();) {
+ cand = it.next().toString();
+ candMetric = computeSentMetric(sentId, cand); //compute metric score
+
+ //start to compute model score
+ candScore = 0;
+ featStr = feat_hash[sentId].get(cand).split("\\s+");
+ feats = "";
+
+ for (int i = 0; i < featStr.length; i++) {
+ featInfo = featStr[i].split("=");
+ actualFeatId = Vocabulary.id(featInfo[0]);
+ candScore += Double.parseDouble(featInfo[1]) * lambda[actualFeatId];
+ if ( (actualFeatId < isOptimizable.length && isOptimizable[actualFeatId]) ||
+ actualFeatId >= isOptimizable.length )
+ feats += actualFeatId + "=" + Double.parseDouble(featInfo[1]) + " ";
+ }
+
+ candScore *= featScale; //scale the model score
+
+ //is this cand oracle?
+ if(oraSelectMode == 1) {//"hope", b=1, r=1
+ if(evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if( bestOraScore<=(candScore-candMetric) ) {
+ bestOraScore = candScore-candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ else {
+ if( bestOraScore<=(candScore+candMetric) ) {
+ bestOraScore = candScore+candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ }
+ else {//best metric score(ex: max BLEU), b=1, r=0
+ if(evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if( bestOraScore>=candMetric ) {
+ bestOraScore = candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ else {
+ if( bestOraScore<=candMetric ) {
+ bestOraScore = candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ }
+
+ //is this cand prediction?
+ if(predSelectMode == 1) {//"fear"
+ if(evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if( worstPredScore<=(candScore+candMetric) ) {
+ worstPredScore = candScore+candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ else {
+ if( worstPredScore<=(candScore-candMetric) ) {
+ worstPredScore = candScore-candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ }
+ else if(predSelectMode == 2) {//model prediction(max model score)
+ if( worstPredScore<=candScore ) {
+ worstPredScore = candScore;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ else {//worst metric score(ex: min BLEU)
+ if(evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if( worstPredScore<=candMetric ) {
+ worstPredScore = candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ else {
+ if( worstPredScore>=candMetric ) {
+ worstPredScore = candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ }
+ }
+
+ oraPredScore[0] = oraMetric;
+ oraPredScore[1] = oraScore;
+ oraPredScore[2] = predMetric;
+ oraPredScore[3] = predScore;
+ oraPredFeat[0] = oraFeat;
+ oraPredFeat[1] = predFeat;
+
+ //update the BLEU metric statistics if pseudo corpus is used to compute BLEU/TER-BLEU
+ if(evalMetric.get_metricName().equals("BLEU") && usePseudoBleu ) {
+ String statString;
+ String[] statVal_str;
+ statString = stats_hash[sentId].get(oraCand);
+ statVal_str = statString.split("\\s+");
+
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ bleuHistory[sentId][j] = R*bleuHistory[sentId][j]+Integer.parseInt(statVal_str[j]);
+ }
+
+ if(evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu ) {
+ String statString;
+ String[] statVal_str;
+ statString = stats_hash[sentId].get(oraCand);
+ statVal_str = statString.split("\\s+");
+
+ for (int j = 0; j < evalMetric.get_suffStatsCount()-2; j++)
+ bleuHistory[sentId][j] = R*bleuHistory[sentId][j]+Integer.parseInt(statVal_str[j+2]); //the first 2 stats are TER stats
+ }
+ }
+
+ // compute *sentence-level* metric score for cand
+ private double computeSentMetric(int sentId, String cand) {
+ String statString;
+ String[] statVal_str;
+ int[] statVal = new int[evalMetric.get_suffStatsCount()];
+
+ statString = stats_hash[sentId].get(cand);
+ statVal_str = statString.split("\\s+");
+
+ if(evalMetric.get_metricName().equals("BLEU") && usePseudoBleu) {
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ statVal[j] = (int) (Integer.parseInt(statVal_str[j]) + bleuHistory[sentId][j]);
+ } else if(evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu) {
+ for (int j = 0; j < evalMetric.get_suffStatsCount()-2; j++)
+ statVal[j+2] = (int)(Integer.parseInt(statVal_str[j+2]) + bleuHistory[sentId][j]); //only modify the BLEU stats part(TER has 2 stats)
+ } else { //in all other situations, use normal stats
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ statVal[j] = Integer.parseInt(statVal_str[j]);
+ }
+
+ return evalMetric.score(statVal);
+ }
+
+ // from ZMERT
+ private void normalizeLambda(double[] origLambda) {
+ // private String[] normalizationOptions;
+ // How should a lambda[] vector be normalized (before decoding)?
+ // nO[0] = 0: no normalization
+ // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+ // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+ // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+ // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+
+ int normalizationMethod = (int) normalizationOptions[0];
+ double scalingFactor = 1.0;
+ if (normalizationMethod == 0) {
+ scalingFactor = 1.0;
+ } else if (normalizationMethod == 1) {
+ int c = (int) normalizationOptions[2];
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[c]);
+ } else if (normalizationMethod == 2) {
+ double maxAbsVal = -1;
+ int maxAbsVal_c = 0;
+ for (int c = 1; c <= paramDim; ++c) {
+ if (Math.abs(origLambda[c]) > maxAbsVal) {
+ maxAbsVal = Math.abs(origLambda[c]);
+ maxAbsVal_c = c;
+ }
+ }
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[maxAbsVal_c]);
+
+ } else if (normalizationMethod == 3) {
+ double minAbsVal = PosInf;
+ int minAbsVal_c = 0;
+
+ for (int c = 1; c <= paramDim; ++c) {
+ if (Math.abs(origLambda[c]) < minAbsVal) {
+ minAbsVal = Math.abs(origLambda[c]);
+ minAbsVal_c = c;
+ }
+ }
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[minAbsVal_c]);
+
+ } else if (normalizationMethod == 4) {
+ double pow = normalizationOptions[1];
+ double norm = L_norm(origLambda, pow);
+ scalingFactor = normalizationOptions[2] / norm;
+ }
+
+ for (int c = 1; c <= paramDim; ++c) {
+ origLambda[c] *= scalingFactor;
+ }
+ }
+
+ // from ZMERT
+ private double L_norm(double[] A, double pow) {
+ // calculates the L-pow norm of A[]
+ // NOTE: this calculation ignores A[0]
+ double sum = 0.0;
+ for (int i = 1; i < A.length; ++i)
+ sum += Math.pow(Math.abs(A[i]), pow);
+
+ return Math.pow(sum, 1 / pow);
+ }
+
+ public static double getScale()
+ {
+ return featScale;
+ }
+
+ public static void initBleuHistory(int sentNum, int statCount)
+ {
+ bleuHistory = new double[sentNum][statCount];
+ for(int i=0; i<sentNum; i++) {
+ for(int j=0; j<statCount; j++) {
+ bleuHistory[i][j] = 0.0;
+ }
+ }
+ }
+
+ public double getMetricScore()
+ {
+ return finalMetricScore;
+ }
+
+ private Vector<String> output;
+ private double[] initialLambda;
+ private double[] finalLambda;
+ private double finalMetricScore;
+ private HashMap<String, String>[] feat_hash;
+ private HashMap<String, String>[] stats_hash;
+ private int paramDim;
+ private boolean[] isOptimizable;
+ public static int sentNum;
+ public static int adagradIter; //AdaGrad internal iterations
+ public static int oraSelectMode;
+ public static int predSelectMode;
+ public static int batchSize;
+ public static int regularization;
+ public static boolean needShuffle;
+ public static boolean needScale;
+ public static double scoreRatio;
+ public static boolean needAvg;
+ public static boolean usePseudoBleu;
+ public static double featScale = 1.0; //scale the features in order to make the model score comparable with metric score
+ //updates in each epoch if necessary
+ public static double eta;
+ public static double lam;
+ public static double R; //corpus decay(used only when pseudo corpus is used to compute BLEU)
+ public static EvaluationMetric evalMetric;
+ public static double[] normalizationOptions;
+ public static double[][] bleuHistory;
+
+ private final static double NegInf = (-1.0 / 0.0);
+ private final static double PosInf = (+1.0 / 0.0);
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/corpus/AbstractPhrase.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/corpus/AbstractPhrase.java b/joshua-core/src/main/java/org/apache/joshua/corpus/AbstractPhrase.java
new file mode 100644
index 0000000..b4637d4
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/AbstractPhrase.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus;
+
+/**
+ * This class provides a skeletal implementation of the base methods likely to be common to most or
+ * all implementations of the <code>Phrase</code> interface.
+ *
+ * @author Lane Schwartz
+ * @author Chris Callison-Burch
+ */
+public abstract class AbstractPhrase implements Phrase {
+
+ // ===============================================================
+ // Constants
+ // ===============================================================
+
+ /** seed used in hash code generation */
+ public static final int HASH_SEED = 17;
+
+ /** offset used in has code generation */
+ public static final int HASH_OFFSET = 37;
+
+ /**
+ * Splits a sentence (on white space), then looks up the integer representations of each word
+ * using the supplied symbol table.
+ *
+ * @param sentence White-space separated String of words.
+ *
+ * @return Array of integers corresponding to the words in the sentence.
+ */
+ protected int[] splitSentence(String sentence) {
+ String[] w = sentence.split("\\s+");
+ int[] words = new int[w.length];
+ for (int i = 0; i < w.length; i++)
+ words[i] = Vocabulary.id(w[i]);
+ return words;
+ }
+
+ /**
+ * Uses the standard java approach of calculating hashCode. Start with a seed, add in every value
+ * multiplying the exsiting hash times an offset.
+ *
+ * @return int hashCode for the list
+ */
+ public int hashCode() {
+ int result = HASH_SEED;
+ for (int i = 0; i < size(); i++) {
+ result = HASH_OFFSET * result + getWordID(i);
+ }
+ return result;
+ }
+
+
+ /**
+ * Two phrases are their word IDs are the same. Note that this could give a false positive if
+ * their Vocabularies were different but their IDs were somehow the same.
+ */
+ public boolean equals(Object o) {
+
+ if (o instanceof Phrase) {
+ Phrase other = (Phrase) o;
+
+ if (this.size() != other.size()) return false;
+ for (int i = 0; i < size(); i++) {
+ if (this.getWordID(i) != other.getWordID(i)) return false;
+ }
+ return true;
+ } else {
+ return false;
+ }
+
+ }
+
+
+ /**
+ * Compares the two strings based on the lexicographic order of words defined in the Vocabulary.
+ *
+ * @param other the object to compare to
+ * @return -1 if this object is less than the parameter, 0 if equals, 1 if greater
+ * @exception ClassCastException if the passed object is not of type Phrase
+ */
+ public int compareTo(Phrase other) {
+ int length = size();
+ int otherLength = other.size();
+ for (int i = 0; i < length; i++) {
+ if (i < otherLength) {
+ int difference = getWordID(i) - other.getWordID(i);
+ if (difference != 0) return difference;
+ } else {
+ // same but other is shorter, so we are after
+ return 1;
+ }
+ }
+ if (length < otherLength) {
+ return -1;
+ } else {
+ return 0;
+ }
+ }
+
+ /**
+ * Returns a string representation of the phrase.
+ *
+ * @return a space-delimited string of the words in the phrase.
+ */
+ public String toString() {
+ StringBuffer buf = new StringBuffer();
+ for (int i = 0; i < size(); i++) {
+ String word = Vocabulary.word(getWordID(i));
+ if (i != 0) buf.append(' ');
+ buf.append(word);
+ }
+ return buf.toString();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/corpus/BasicPhrase.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/corpus/BasicPhrase.java b/joshua-core/src/main/java/org/apache/joshua/corpus/BasicPhrase.java
new file mode 100644
index 0000000..6c50458
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/BasicPhrase.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus;
+
+import java.util.ArrayList;
+
+/**
+ * The simplest concrete implementation of Phrase.
+ *
+ * @author wren ng thornton wren@users.sourceforge.net
+ * @version $LastChangedDate$
+ */
+public class BasicPhrase extends AbstractPhrase {
+ private byte language;
+ private int[] words;
+
+
+ public BasicPhrase(byte language, String sentence) {
+ this.language = language;
+ this.words = splitSentence(sentence);
+ }
+
+ private BasicPhrase() {}
+
+ public int[] getWordIDs() {
+ return words;
+ }
+
+ /* See Javadoc for Phrase interface. */
+ public BasicPhrase subPhrase(int start, int end) {
+ BasicPhrase that = new BasicPhrase();
+ that.language = this.language;
+ that.words = new int[end - start + 1];
+ System.arraycopy(this.words, start, that.words, 0, end - start + 1);
+ return that;
+ }
+
+ /* See Javadoc for Phrase interface. */
+ public ArrayList<Phrase> getSubPhrases() {
+ return this.getSubPhrases(this.size());
+ }
+
+ /* See Javadoc for Phrase interface. */
+ public ArrayList<Phrase> getSubPhrases(int maxLength) {
+ ArrayList<Phrase> phrases = new ArrayList<Phrase>();
+ int len = this.size();
+ for (int n = 1; n <= maxLength; n++)
+ for (int i = 0; i <= len - n; i++)
+ phrases.add(this.subPhrase(i, i + n - 1));
+ return phrases;
+ }
+
+ /* See Javadoc for Phrase interface. */
+ public int size() {
+ return (words == null ? 0 : words.length);
+ }
+
+ /* See Javadoc for Phrase interface. */
+ public int getWordID(int position) {
+ return words[position];
+ }
+
+ /**
+ * Returns a human-readable String representation of the phrase.
+ * <p>
+ * The implementation of this method is slightly more efficient than that inherited from
+ * <code>AbstractPhrase</code>.
+ *
+ * @return a human-readable String representation of the phrase.
+ */
+ public String toString() {
+ StringBuffer sb = new StringBuffer();
+ if (words != null) {
+ for (int i = 0; i < words.length; ++i) {
+ if (i != 0) sb.append(' ');
+ sb.append(Vocabulary.word(words[i]));
+ }
+ }
+ return sb.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/corpus/ContiguousPhrase.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/corpus/ContiguousPhrase.java b/joshua-core/src/main/java/org/apache/joshua/corpus/ContiguousPhrase.java
new file mode 100644
index 0000000..af669b7
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/ContiguousPhrase.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * ContiguousPhrase implements the Phrase interface by linking into indices within a corpus. This is
+ * intended to be a very low-memory implementation of the class.
+ *
+ * @author Chris Callison-Burch
+ * @since 29 May 2008
+ * @version $LastChangedDate:2008-09-18 12:47:23 -0500 (Thu, 18 Sep 2008) $
+ */
+public class ContiguousPhrase extends AbstractPhrase {
+
+ protected int startIndex;
+ protected int endIndex;
+ protected Corpus corpusArray;
+
+ public ContiguousPhrase(int startIndex, int endIndex, Corpus corpusArray) {
+ this.startIndex = startIndex;
+ this.endIndex = endIndex;
+ this.corpusArray = corpusArray;
+ }
+
+ /**
+ * This method copies the phrase into an array of ints. This method should be avoided if possible.
+ *
+ * @return an int[] corresponding to the ID of each word in the phrase
+ */
+ public int[] getWordIDs() {
+ int[] words = new int[endIndex - startIndex];
+ for (int i = startIndex; i < endIndex; i++) {
+ words[i - startIndex] = corpusArray.getWordID(i); // corpusArray.corpus[i];
+ }
+ return words;
+ }
+
+ public int getWordID(int position) {
+ return corpusArray.getWordID(startIndex + position);
+ // return corpusArray.corpus[startIndex+position];
+ }
+
+ public int size() {
+ return endIndex - startIndex;
+ }
+
+ /**
+ * Gets all possible subphrases of this phrase, up to and including the phrase itself. For
+ * example, the phrase "I like cheese ." would return the following:
+ * <ul>
+ * <li>I
+ * <li>like
+ * <li>cheese
+ * <li>.
+ * <li>I like
+ * <li>like cheese
+ * <li>cheese .
+ * <li>I like cheese
+ * <li>like cheese .
+ * <li>I like cheese .
+ * </ul>
+ *
+ * @return ArrayList of all possible subphrases.
+ */
+ public List<Phrase> getSubPhrases() {
+ return getSubPhrases(size());
+ }
+
+ /**
+ * Returns a list of subphrases only of length <code>maxLength</code> or smaller.
+ *
+ * @param maxLength the maximum length phrase to return.
+ * @return ArrayList of all possible subphrases of length maxLength or less
+ * @see #getSubPhrases()
+ */
+ public List<Phrase> getSubPhrases(int maxLength) {
+ if (maxLength > size()) return getSubPhrases(size());
+ List<Phrase> phrases = new ArrayList<Phrase>();
+ for (int i = 0; i < size(); i++) {
+ for (int j = i + 1; (j <= size()) && (j - i <= maxLength); j++) {
+ Phrase subPhrase = subPhrase(i, j);
+ phrases.add(subPhrase);
+ }
+ }
+ return phrases;
+ }
+
+ /**
+ * creates a new phrase object from the indexes provided.
+ * <P>
+ * NOTE: subList merely creates a "view" of the existing Phrase object. Memory taken up by other
+ * Words in the Phrase is not freed since the underlying subList object still points to the
+ * complete Phrase List.
+ *
+ * @see ArrayList#subList(int, int)
+ */
+ public Phrase subPhrase(int start, int end) {
+ return new ContiguousPhrase(startIndex + start, startIndex + end, corpusArray);
+ }
+
+ /**
+ * Main contains test code
+ * @param args String array of arguments used to run this class.
+ */
+ public static void main(String[] args) {
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/corpus/Corpus.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/corpus/Corpus.java b/joshua-core/src/main/java/org/apache/joshua/corpus/Corpus.java
new file mode 100755
index 0000000..1a7d1b0
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/Corpus.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus;
+
+/**
+ * Corpus is an interface that contains methods for accessing the information within a monolingual
+ * corpus.
+ *
+ * @author Chris Callison-Burch
+ * @since 7 February 2005
+ * @version $LastChangedDate:2008-07-30 17:15:52 -0400 (Wed, 30 Jul 2008) $
+ */
+
+public interface Corpus { // extends Externalizable {
+
+ // ===============================================================
+ // Attribute definitions
+ // ===============================================================
+
+ /**
+ * @param position the position at which we want to obtain a word ID
+ * @return the integer representation of the Word at the specified position in the corpus.
+ */
+ int getWordID(int position);
+
+
+ /**
+ * Gets the sentence index associated with the specified position in the corpus.
+ *
+ * @param position Index into the corpus
+ * @return the sentence index associated with the specified position in the corpus.
+ */
+ int getSentenceIndex(int position);
+
+
+ /**
+ * Gets the sentence index of each specified position.
+ *
+ * @param positions Index into the corpus
+ * @return array of the sentence indices associated with the specified positions in the corpus.
+ */
+ int[] getSentenceIndices(int[] positions);
+
+ /**
+ * Gets the position in the corpus of the first word of the specified sentence. If the sentenceID
+ * is outside of the bounds of the sentences, then it returns the last position in the corpus + 1.
+ *
+ * @param sentenceID a specific sentence to obtain a position for
+ * @return the position in the corpus of the first word of the specified sentence. If the
+ * sentenceID is outside of the bounds of the sentences, then it returns the last position
+ * in the corpus + 1.
+ */
+ int getSentencePosition(int sentenceID);
+
+ /**
+ * Gets the exclusive end position of a sentence in the corpus.
+ *
+ * @param sentenceID a specific sentence to obtain an end position for
+ * @return the position in the corpus one past the last word of the specified sentence. If the
+ * sentenceID is outside of the bounds of the sentences, then it returns one past the last
+ * position in the corpus.
+ */
+ int getSentenceEndPosition(int sentenceID);
+
+ /**
+ * Gets the specified sentence as a phrase.
+ *
+ * @param sentenceIndex Zero-based sentence index
+ * @return the sentence, or null if the specified sentence number doesn't exist
+ */
+ Phrase getSentence(int sentenceIndex);
+
+
+ /**
+ * Gets the number of words in the corpus.
+ *
+ * @return the number of words in the corpus.
+ */
+ int size();
+
+
+ /**
+ * Gets the number of sentences in the corpus.
+ *
+ * @return the number of sentences in the corpus.
+ */
+ int getNumSentences();
+
+
+ // ===========================================================
+ // Methods
+ // ===========================================================
+
+
+ /**
+ * Compares the phrase that starts at position start with the subphrase indicated by the start and
+ * end points of the phrase.
+ *
+ * @param corpusStart the point in the corpus where the comparison begins
+ * @param phrase the superphrase that the comparsion phrase is drawn from
+ * @param phraseStart the point in the phrase where the comparison begins (inclusive)
+ * @param phraseEnd the point in the phrase where the comparison ends (exclusive)
+ * @return an int that follows the conventions of {@link java.util.Comparator#compare(Object, Object)}
+ */
+ int comparePhrase(int corpusStart, Phrase phrase, int phraseStart, int phraseEnd);
+
+
+ /**
+ * Compares the phrase that starts at position start with the phrase passed in. Compares the
+ * entire phrase.
+ *
+ * @param corpusStart position start
+ * @param phrase {@link org.apache.joshua.corpus.Phrase} to compare against
+ * @return an int that follows the conventions of {@link java.util.Comparator#compare(Object, Object)}
+ */
+ int comparePhrase(int corpusStart, Phrase phrase);
+
+ /**
+ * Compares the suffixes starting a positions index1 and index2.
+ *
+ * @param position1 the position in the corpus where the first suffix begins
+ * @param position2 the position in the corpus where the second suffix begins
+ * @param maxComparisonLength a cutoff point to stop the comparison
+ * @return an int that follows the conventions of {@link java.util.Comparator#compare(Object, Object)}
+ */
+ int compareSuffixes(int position1, int position2, int maxComparisonLength);
+
+ /**
+ *
+ * @param startPosition start position for phrase
+ * @param endPosition end position for phrase
+ * @return the {@link org.apache.joshua.corpus.ContiguousPhrase}
+ */
+ ContiguousPhrase getPhrase(int startPosition, int endPosition);
+
+ /**
+ * Gets an object capable of iterating over all positions in the corpus, in order.
+ *
+ * @return An object capable of iterating over all positions in the corpus, in order.
+ */
+ Iterable<Integer> corpusPositions();
+
+ // void write(String corpusFilename, String vocabFilename, String charset) throws IOException;
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/corpus/Phrase.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/corpus/Phrase.java b/joshua-core/src/main/java/org/apache/joshua/corpus/Phrase.java
new file mode 100644
index 0000000..5a06a8b
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/Phrase.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus;
+
+import java.util.ArrayList;
+import java.util.List;
+
+
+/**
+ * Representation of a sequence of tokens.
+ *
+ * @version $LastChangedDate:2008-09-18 10:31:54 -0500 (Thu, 18 Sep 2008) $
+ */
+public interface Phrase extends Comparable<Phrase> {
+
+ /**
+ * This method gets the integer IDs of the phrase as an array of ints.
+ *
+ * @return an int[] corresponding to the ID of each word in the phrase
+ */
+ public int[] getWordIDs();
+
+ /**
+ * Returns the integer word id of the word at the specified position.
+ *
+ * @param position Index of a word in this phrase.
+ * @return the integer word id of the word at the specified position.
+ */
+ int getWordID(int position);
+
+
+ /**
+ * Returns the number of words in this phrase.
+ *
+ * @return the number of words in this phrase.
+ */
+ int size();
+
+
+
+ /**
+ * Gets all possible subphrases of this phrase, up to and including the phrase itself. For
+ * example, the phrase "I like cheese ." would return the following:
+ * <ul>
+ * <li>I
+ * <li>like
+ * <li>cheese
+ * <li>.
+ * <li>I like
+ * <li>like cheese
+ * <li>cheese .
+ * <li>I like cheese
+ * <li>like cheese .
+ * <li>I like cheese .
+ * </ul>
+ *
+ * @return List of all possible subphrases.
+ */
+ List<Phrase> getSubPhrases();
+
+
+ /**
+ * Returns a list of subphrases only of length <code>maxLength</code> or smaller.
+ *
+ * @param maxLength the maximum length phrase to return.
+ * @return List of all possible subphrases of length maxLength or less
+ * @see #getSubPhrases()
+ */
+ List<Phrase> getSubPhrases(int maxLength);
+
+
+ /**
+ * creates a new phrase object from the indexes provided.
+ * <P>
+ * NOTE: subList merely creates a "view" of the existing Phrase object. Memory taken up by other
+ * Words in the Phrase is not freed since the underlying subList object still points to the
+ * complete Phrase List.
+ *
+ * @see ArrayList#subList(int, int)
+ * @param start start position to begin new phrase
+ * @param end end position to end new phrase
+ * @return a new {@link org.apache.joshua.corpus.Phrase} object from the indexes provided.
+ */
+ Phrase subPhrase(int start, int end);
+
+
+ /**
+ * Compares the two strings based on the lexicographic order of words defined in the Vocabulary.
+ *
+ * @param other the object to compare to
+ * @return -1 if this object is less than the parameter, 0 if equals, 1 if greater
+ */
+ int compareTo(Phrase other);
+
+ /**
+ * Returns a human-readable String representation of the phrase.
+ *
+ * @return a human-readable String representation of the phrase.
+ */
+ String toString();
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/corpus/Span.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/corpus/Span.java b/joshua-core/src/main/java/org/apache/joshua/corpus/Span.java
new file mode 100644
index 0000000..414fe95
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/Span.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+/**
+ * Represents a span with an inclusive starting index and an exclusive ending index.
+ *
+ * @author Lane Schwartz
+ */
+public class Span implements Iterable<Integer>, Comparable<Span> {
+
+ /** Inclusive starting index of this span. */
+ public int start;
+
+ /** Exclusive ending index of this span. */
+ public int end;
+
+
+ /**
+ * Constructs a new span with the given inclusive starting and exclusive ending indices.
+ *
+ * @param start Inclusive starting index of this span.
+ * @param end Exclusive ending index of this span.
+ */
+ public Span(int start, int end) {
+ this.start = start;
+ this.end = end;
+ }
+
+
+ /**
+ * Returns the length of the span.
+ *
+ * @return the length of the span; this is equivalent to <code>span.end - span.start</code>.
+ */
+ public int size() {
+ return end - start;
+ }
+
+ /**
+ * Returns all subspans of the given Span.
+ *
+ * @return a list of all subspans.
+ */
+ public List<Span> getSubSpans() {
+ return getSubSpans(size());
+ }
+
+ /**
+ * Returns all subspans of the given Span, up to a specified Span size.
+ *
+ * @param max the maximum Span size to return
+ * @return a list all subspans up to the given size
+ */
+ public List<Span> getSubSpans(int max) {
+ int spanSize = size();
+ ArrayList<Span> result = new ArrayList<Span>(max * spanSize);
+ for (int len = max; len > 0; len--) {
+ for (int i = start; i < end - len + 1; i++) {
+ result.add(new Span(i, i + len));
+ }
+ }
+ return result;
+ }
+
+ public boolean strictlyContainedIn(Span o) {
+ return (start >= o.start) && (end <= o.end) && !(start == o.start && end == o.end);
+ }
+
+ /**
+ * Returns true if the other span does not intersect with this one.
+ * @param o new {@link org.apache.joshua.corpus.Span} to check for intersection
+ * @return true if the other span does not intersect with this one
+ */
+ public boolean disjointFrom(Span o) {
+ if (start < o.start) {
+ return end <= o.start;
+ }
+ if (end > o.end) {
+ return start >= o.end;
+ }
+ return false;
+ }
+
+ public String toString() {
+ return "[" + start + "-" + end + ")";
+ }
+
+
+ public Iterator<Integer> iterator() {
+ return new Iterator<Integer>() {
+
+ int next = start;
+
+ public boolean hasNext() {
+ return next < end;
+ }
+
+ public Integer next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException();
+ }
+ return next++;
+ }
+
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+ };
+ }
+
+
+ public int compareTo(Span o) {
+
+ if (o == null) {
+ throw new NullPointerException();
+ } else {
+
+ if (start < o.start) {
+ return -1;
+ } else if (start > o.start) {
+ return 1;
+ } else {
+ if (end < o.end) {
+ return -1;
+ } else if (end > o.end) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+ }
+
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ } else if (o instanceof Span) {
+ Span other = (Span) o;
+ return (start == other.start && end == other.end);
+
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return start * 31 + end * 773;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/corpus/SymbolTable.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/corpus/SymbolTable.java b/joshua-core/src/main/java/org/apache/joshua/corpus/SymbolTable.java
new file mode 100644
index 0000000..274e8b9
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/SymbolTable.java
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus;
+
+import java.util.Collection;
+
+/**
+ * Represents a symbol table capable of mapping between strings and
+ * symbols.
+ *
+ * @author Lane Schwartz
+ * @author Zhifei Li
+ * @version $LastChangedDate: 2009-11-24 23:07:43 -0600 (Tue, 24 Nov 2009) $
+ */
+public interface SymbolTable {
+
+ //TODO Remove all hard-coded references to nonterminals
+
+ /**
+ * The unknown word's ID will be the size of the vocabulary,
+ * ensuring that it is outside of the vocabulary. Note that
+ * for vocabularies which have not been fixed yet, this
+ * means the actual value is volatile and therefore a word
+ * ID can only be compared against UNKNOWN_WORD at the time
+ * the word ID is generated (otherwise unknown words can
+ * become "known" if new words are added to the vocabulary
+ * before testing).
+ * <p>
+ * Negative IDs are reserved for non-terminals.
+ *
+ * Zero is reserved as the UNKNOWN_WORD.
+ */
+ int UNKNOWN_WORD = 1;
+
+ /** String representation for out-of-vocabulary words. */
+ String UNKNOWN_WORD_STRING = "<unk>";
+
+ /**
+ * Integer representation of the bare (non-indexed) nonterminal X,
+ * which represents a wild-card gap in a phrase.
+ * <p>
+ * All nonterminals are guaranteed to be represented by negative integers.
+ */
+ int X = -1;
+
+ /**
+ * String representation of the bare (non-indexed) nonterminal X,
+ * which represents a wild-card gap in a phrase.
+ */
+ String X_STRING = "[X]";
+
+
+
+ /**
+ * String representation of the nonterminal X with index 1,
+ * which represents a wild-card gap in a phrase.
+ */
+ String X1_STRING = "[X,1]";
+
+
+
+ /**
+ * String representation of the nonterminal X with index 2,
+ * which represents a wild-card gap in a phrase.
+ */
+ String X2_STRING = "[X,2]";
+
+ /**
+ * Integer representation of the nonterminal S.
+ * <p>
+ * All nonterminals are guaranteed to be represented by negative integers.
+ */
+ int S = -4;
+
+ /**
+ * String representation of the nonterminal S..
+ */
+ String S_STRING = "[S]";
+
+ /**
+ * Integer representation of the nonterminal X with index 1,
+ * which represents a wild-card gap in a phrase.
+ * <p>
+ * All nonterminals are guaranteed to be represented by negative integers.
+ */
+ int S1 = -5;
+
+ /**
+ * String representation of the nonterminal X with index 2,
+ * which represents a wild-card gap in a phrase.
+ */
+ String S1_STRING = "[S,1]";
+
+ /**
+ * Gets a unique integer identifier for the nonterminal.
+ * <p>
+ * The integer returned is guaranteed to be a negative number.
+ *
+ * If the nonterminal is {@link #X_STRING},
+ * then the value returned must be {@link #X}.
+ *
+ * Otherwise, the value returned must be a negative number
+ * whose value is less than {@link X}.
+ *
+ * @param nonterminal Nonterminal symbol
+ * @return a unique integer identifier for the nonterminal
+ */
+ int addNonterminal(String nonterminal);
+
+ /**
+ * Gets a unique integer identifier for the terminal.
+ *
+ * @param terminal Terminal symbol
+ * @return a unique integer identifier for the terminal
+ */
+ int addTerminal(String terminal);
+
+ /**
+ * Gets the unique integer identifiers for the words.
+ *
+ * @param words Array of symbols
+ * @return the unique integer identifiers for the words
+ */
+ int[] addTerminals(String[] words);
+
+ /**
+ * Gets the unique integer identifiers for the words
+ * in the sentence.
+ *
+ * @param sentence Space-delimited string of symbols
+ * @return the unique integer identifiers for the words
+ * in the sentence
+ */
+ int[] addTerminals(String sentence);
+
+ /**
+ * Gets an integer identifier for the word.
+ * <p>
+ * If the word is in the vocabulary, the integer returned
+ * will uniquely identify that word.
+ * <p>
+ * If the word is not in the vocabulary, the integer returned
+ * by <code>getUnknownWordID</code> may be returned.
+ *
+ * Alternatively, implementations may, if they choose, add
+ * unknown words and assign them a symbol ID instead of
+ * returning <code>getUnknownWordID</code>.
+ *
+ * @see #getUnknownWordID
+ * @return the unique integer identifier for wordString,
+ * or the result of <code>getUnknownWordID</code>
+ * if wordString is not in the vocabulary
+ * @param wordString the word to retrieve the integer identifier
+ */
+ int getID(String wordString);
+
+ /**
+ * Gets the integer identifiers for all words in the provided
+ * sentence.
+ * <p>
+ * The sentence will be split (on spaces) into words, then
+ * the integer identifier for each word will be retrieved
+ * using <code>getID</code>.
+ *
+ * @see #getID(String)
+ * @param sentence String of words, separated by spaces.
+ * @return Array of integer identifiers for each word in
+ * the sentence
+ */
+ int[] getIDs(String sentence);
+
+ /**
+ * Gets the String that corresponds to the specified integer
+ * identifier.
+ * <p>
+ * If the identifier is in the symbol vocabulary, the String
+ * returned will correspond to that identifier.
+ *
+ * Otherwise, the String returned by <code>getUnknownWord</code>
+ * will be returned.
+ *
+ * @param wordID an integer identifier for a specific String
+ * @return the String that corresponds to the specified
+ * integer identifier, or the result of
+ * <code>getUnknownWord</code> if the identifier
+ * does not correspond to a word in the vocabulary
+ */
+ String getTerminal(int wordID);
+
+ /**
+ * Gets the String that corresponds to the specified integer
+ * identifier.
+ * <p>
+ * This method can be called for terminals or nonterminals.
+ *
+ * @param tokenID Integer identifier
+ * @return the String that corresponds to the specified
+ * integer identifier
+ */
+ String getWord(int tokenID);
+
+ /**
+ * Gets the String that corresponds to the sequence of
+ * specified integer identifiers.
+ *
+ * @param ids Sequence of integer identifiers
+ * @return the String that corresponds to the sequence of
+ * specified integer identifiers
+ */
+ String getWords(int[] ids);
+
+ /**
+ *
+ * @param wordIDs an int[] of identifiers for a specific Strings
+ * @return the String that corresponds to the specified
+ * integer identifiers
+ */
+ String getTerminals(int[] wordIDs);
+
+ /**
+ * Gets a collection over all symbol identifiers for the
+ * vocabulary.
+ *
+ * @return a collection over all symbol identifiers for the
+ * vocabulary
+ */
+ Collection<Integer> getAllIDs();
+
+ /**
+ * Gets the list of all words represented by this vocabulary.
+ *
+ * @return the list of all words represented by this
+ * vocabulary
+ */
+ Collection<String> getWords();
+
+ /**
+ * Gets the number of unique words in the vocabulary.
+ *
+ * @return the number of unique words in the vocabulary.
+ */
+ int size();
+
+ /**
+ * Gets the integer symbol representation of the unknown
+ * word.
+ *
+ * @return the integer symbol representation of the unknown
+ * word.
+ */
+ int getUnknownWordID();
+
+ /**
+ * Gets the string representation of the unknown word.
+ *
+ * @return the string representation of the unknown word.
+ */
+ String getUnknownWord();
+
+ /**
+ * Returns <code>true</code> if the symbol id represents a
+ * nonterminal, <code>false</code> otherwise.
+ *
+ * @param id int symbol id
+ * @return <code>true</code> if the symbol id represents a
+ * nonterminal, <code>false</code> otherwise.
+ */
+ boolean isNonterminal(int id);
+
+ /**
+ * Gets the lowest-valued allowable terminal symbol id in
+ * this table.
+ *
+ * @return the lowest-valued allowable terminal symbol id
+ * in this table.
+ */
+ int getLowestID();
+
+
+ /**
+ * Gets the highest-valued allowable terminal symbol id in
+ * this table.
+ * <p>
+ * NOTE: This may or may not return the same value as
+ * <code>size</code>.
+ *
+ * @return the highest-valued allowable terminal symbol id
+ * in this table.
+ */
+ int getHighestID();
+
+ /**
+ * @param id todo
+ * @return todo
+ */
+ int getTargetNonterminalIndex(int id);//first convert id to its String mapping, then call the function below
+
+ /**
+ * @param word todo
+ * @return todo
+ */
+ int getTargetNonterminalIndex(String word);
+
+ /**
+ * @param wordIDs todo
+ * @param ntIndexIncrements todo
+ * @return todo
+ */
+ String getWords(int[] wordIDs, boolean ntIndexIncrements);
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/corpus/TerminalIterator.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/corpus/TerminalIterator.java b/joshua-core/src/main/java/org/apache/joshua/corpus/TerminalIterator.java
new file mode 100644
index 0000000..fcf5c72
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/TerminalIterator.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus;
+
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+import org.apache.joshua.util.FormatUtils;
+
+/**
+ * Iterator capable of iterating over those word identifiers in a phrase which represent terminals.
+ * <p>
+ * <em>Note</em>: This class is <em>not</em> thread-safe.
+ *
+ * @author Lane Schwartz
+ */
+public class TerminalIterator implements Iterator<Integer> {
+
+ private final int[] words;
+
+ private int nextIndex = -1;
+ private int next = Integer.MIN_VALUE;
+ private boolean dirty = true;
+
+ /**
+ * Constructs an iterator for the terminals in the given list of words.
+ *
+ * @param words array of words
+ */
+ public TerminalIterator(int[] words) {
+ this.words = words;
+ }
+
+ /* See Javadoc for java.util.Iterator#next(). */
+ public boolean hasNext() {
+
+ while (dirty || FormatUtils.isNonterminal(next)) {
+ nextIndex++;
+ if (nextIndex < words.length) {
+ next = words[nextIndex];
+ dirty = false;
+ } else {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ /* See Javadoc for java.util.Iterator#next(). */
+ public Integer next() {
+ if (hasNext()) {
+ dirty = true;
+ return next;
+ } else {
+ throw new NoSuchElementException();
+ }
+ }
+
+ /**
+ * Unsupported operation, guaranteed to throw an UnsupportedOperationException.
+ *
+ * @throws UnsupportedOperationException operation not supported yet!
+ */
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e2734396/joshua-core/src/main/java/org/apache/joshua/corpus/Vocabulary.java
----------------------------------------------------------------------
diff --git a/joshua-core/src/main/java/org/apache/joshua/corpus/Vocabulary.java b/joshua-core/src/main/java/org/apache/joshua/corpus/Vocabulary.java
new file mode 100644
index 0000000..24644ee
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/Vocabulary.java
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.Externalizable;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.StampedLock;
+
+import org.apache.joshua.decoder.ff.lm.NGramLanguageModel;
+import org.apache.joshua.util.FormatUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Static singular vocabulary class.
+ * Supports (de-)serialization into a vocabulary file.
+ *
+ * @author Juri Ganitkevitch
+ */
+
+public class Vocabulary implements Externalizable {
+
+ private static final Logger LOG = LoggerFactory.getLogger(Vocabulary.class);
+ private final static ArrayList<NGramLanguageModel> LMs = new ArrayList<>();
+
+ private static List<String> idToString;
+ private static Map<String, Integer> stringToId;
+ private static final StampedLock lock = new StampedLock();
+
+ static final int UNKNOWN_ID = 0;
+ static final String UNKNOWN_WORD = "<unk>";
+
+ public static final String START_SYM = "<s>";
+ public static final String STOP_SYM = "</s>";
+
+ static {
+ clear();
+ }
+
+ public static boolean registerLanguageModel(NGramLanguageModel lm) {
+ long lock_stamp = lock.writeLock();
+ try {
+ // Store the language model.
+ LMs.add(lm);
+ // Notify it of all the existing words.
+ boolean collision = false;
+ for (int i = idToString.size() - 1; i > 0; i--)
+ collision = collision || lm.registerWord(idToString.get(i), i);
+ return collision;
+ } finally {
+ lock.unlockWrite(lock_stamp);
+ }
+ }
+
+ /**
+ * Reads a vocabulary from file. This deletes any additions to the vocabulary made prior to
+ * reading the file.
+ *
+ * @param vocab_file path to a vocabulary file
+ * @return Returns true if vocabulary was read without mismatches or collisions.
+ * @throws IOException of the file cannot be found or read properly
+ */
+ public static boolean read(final File vocab_file) throws IOException {
+ DataInputStream vocab_stream =
+ new DataInputStream(new BufferedInputStream(new FileInputStream(vocab_file)));
+ int size = vocab_stream.readInt();
+ LOG.info("Read {} entries from the vocabulary", size);
+ clear();
+ for (int i = 0; i < size; i++) {
+ int id = vocab_stream.readInt();
+ String token = vocab_stream.readUTF();
+ if (id != Math.abs(id(token))) {
+ vocab_stream.close();
+ return false;
+ }
+ }
+ vocab_stream.close();
+ return (size + 1 == idToString.size());
+ }
+
+ public static void write(String file_name) throws IOException {
+ long lock_stamp =lock.readLock();
+ try {
+ File vocab_file = new File(file_name);
+ DataOutputStream vocab_stream =
+ new DataOutputStream(new BufferedOutputStream(new FileOutputStream(vocab_file)));
+ vocab_stream.writeInt(idToString.size() - 1);
+ LOG.info("Writing vocabulary: {} tokens", idToString.size() - 1);
+ for (int i = 1; i < idToString.size(); i++) {
+ vocab_stream.writeInt(i);
+ vocab_stream.writeUTF(idToString.get(i));
+ }
+ vocab_stream.close();
+ }
+ finally{
+ lock.unlockRead(lock_stamp);
+ }
+ }
+
+ /**
+ * Get the id of the token if it already exists, new id is created otherwise.
+ *
+ * TODO: currently locks for every call. Separate constant (frozen) ids from
+ * changing (e.g. OOV) ids. Constant ids could be immutable -> no locking.
+ * Alternatively: could we use ConcurrentHashMap to not have to lock if
+ * actually contains it and only lock for modifications?
+ *
+ * @param token a token to obtain an id for
+ * @return the token id
+ */
+ public static int id(String token) {
+ // First attempt an optimistic read
+ long attempt_read_lock = lock.tryOptimisticRead();
+ if (stringToId.containsKey(token)) {
+ int resultId = stringToId.get(token);
+ if (lock.validate(attempt_read_lock)) {
+ return resultId;
+ }
+ }
+
+ // The optimistic read failed, try a read with a stamped read lock
+ long read_lock_stamp = lock.readLock();
+ try {
+ if (stringToId.containsKey(token)) {
+ return stringToId.get(token);
+ }
+ } finally {
+ lock.unlockRead(read_lock_stamp);
+ }
+
+ // Looks like the id we want is not there, let's get a write lock and add it
+ long write_lock_stamp = lock.writeLock();
+ try {
+ if (stringToId.containsKey(token)) {
+ return stringToId.get(token);
+ }
+ int id = idToString.size() * (FormatUtils.isNonterminal(token) ? -1 : 1);
+
+ // register this (token,id) mapping with each language
+ // model, so that they can map it to their own private
+ // vocabularies
+ for (NGramLanguageModel lm : LMs)
+ lm.registerWord(token, Math.abs(id));
+
+ idToString.add(token);
+ stringToId.put(token, id);
+ return id;
+ } finally {
+ lock.unlockWrite(write_lock_stamp);
+ }
+ }
+
+ public static boolean hasId(int id) {
+ long lock_stamp = lock.readLock();
+ try {
+ id = Math.abs(id);
+ return (id < idToString.size());
+ }
+ finally{
+ lock.unlockRead(lock_stamp);
+ }
+ }
+
+ public static int[] addAll(String sentence) {
+ return addAll(sentence.split("\\s+"));
+ }
+
+ public static int[] addAll(String[] tokens) {
+ int[] ids = new int[tokens.length];
+ for (int i = 0; i < tokens.length; i++)
+ ids[i] = id(tokens[i]);
+ return ids;
+ }
+
+ public static String word(int id) {
+ long lock_stamp = lock.readLock();
+ try {
+ id = Math.abs(id);
+ return idToString.get(id);
+ }
+ finally{
+ lock.unlockRead(lock_stamp);
+ }
+ }
+
+ public static String getWords(int[] ids) {
+ return getWords(ids, " ");
+ }
+
+ public static String getWords(int[] ids, final String separator) {
+ if (ids.length == 0) {
+ return "";
+ }
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < ids.length - 1; i++) {
+ sb.append(word(ids[i])).append(separator);
+ }
+ return sb.append(word(ids[ids.length - 1])).toString();
+ }
+
+ public static String getWords(final Iterable<Integer> ids) {
+ StringBuilder sb = new StringBuilder();
+ for (int id : ids)
+ sb.append(word(id)).append(" ");
+ return sb.deleteCharAt(sb.length() - 1).toString();
+ }
+
+ public static int getUnknownId() {
+ return UNKNOWN_ID;
+ }
+
+ public static String getUnknownWord() {
+ return UNKNOWN_WORD;
+ }
+
+ public static int size() {
+ long lock_stamp = lock.readLock();
+ try {
+ return idToString.size();
+ } finally {
+ lock.unlockRead(lock_stamp);
+ }
+ }
+
+ public static synchronized int getTargetNonterminalIndex(int id) {
+ return FormatUtils.getNonterminalIndex(word(id));
+ }
+
+ /**
+ * Clears the vocabulary and initializes it with an unknown word. Registered
+ * language models are left unchanged.
+ */
+ public static void clear() {
+ long lock_stamp = lock.writeLock();
+ try {
+ idToString = new ArrayList<String>();
+ stringToId = new HashMap<String, Integer>();
+
+ idToString.add(UNKNOWN_ID, UNKNOWN_WORD);
+ stringToId.put(UNKNOWN_WORD, UNKNOWN_ID);
+ } finally {
+ lock.unlockWrite(lock_stamp);
+ }
+ }
+
+ public static void unregisterLanguageModels() {
+ LMs.clear();
+ }
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ // TODO Auto-generated method stub
+
+ }
+
+ @Override
+ public void readExternal(ObjectInput in)
+ throws IOException, ClassNotFoundException {
+ // TODO Auto-generated method stub
+
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if(getClass() == o.getClass()) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+}