You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/08/30 21:04:53 UTC
[08/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/tools/GrammarPacker.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/tools/GrammarPacker.java
index 5861052,0000000..838279b
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/tools/GrammarPacker.java
+++ b/joshua-core/src/main/java/org/apache/joshua/tools/GrammarPacker.java
@@@ -1,936 -1,0 +1,940 @@@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.tools;
+
+import static org.apache.joshua.decoder.ff.tm.OwnerMap.UNKNOWN_OWNER_ID;
+import static org.apache.joshua.decoder.ff.tm.packed.PackedGrammar.VOCABULARY_FILENAME;
+
+import java.io.BufferedOutputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Queue;
+import java.util.TreeMap;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.RuleFactory;
+import org.apache.joshua.decoder.ff.tm.format.HieroFormatReader;
+import org.apache.joshua.decoder.ff.tm.format.MosesFormatReader;
+import org.apache.joshua.util.FormatUtils;
+import org.apache.joshua.util.encoding.EncoderConfiguration;
+import org.apache.joshua.util.encoding.FeatureTypeAnalyzer;
+import org.apache.joshua.util.encoding.IntEncoder;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+public class GrammarPacker {
+
+ private static final Logger LOG = LoggerFactory.getLogger(GrammarPacker.class);
+
+ /**
+ * The packed grammar version number. Increment this any time you add new features, and update
+ * the documentation.
- *
++ *
+ * Version history:
- *
++ *
+ * - 3 (May 2016). This was the first version that was marked. It removed the special phrase-
+ * table packing that packed phrases without the [X,1] on the source and target sides, which
+ * then required special handling in the decoder to use for phrase-based decoding.
- *
++ *
+ * - 4 (August 2016). Phrase-based decoding rewritten to represent phrases without a builtin
+ * nonterminal. Instead, cost-less glue rules are used in phrase-based decoding. This eliminates
+ * the need for special handling of phrase grammars (except for having to add a LHS), and lets
+ * phrase grammars be used in both hierarchical and phrase-based decoding without conversion.
+ *
++ * The value is written to the "config" file of every packed grammar as "version = N".
+ */
+ public static final int VERSION = 4;
-
++
+ // Size limit for slice in bytes.
++ // Compared against a data buffer's write position in PackingBuffer.overflowing(); once reached,
++ // binarize() starts looking for a point at which to close the current slice.
+ private static int DATA_SIZE_LIMIT = (int) (Integer.MAX_VALUE * 0.8);
+ // Estimated average number of feature entries for one rule.
++ // NOTE(review): PackingBuffer.allocate() multiplies this by approximateMaximumSliceSize to size a
++ // byte array, so it effectively acts as an estimated number of bytes per rule.
+ private static int DATA_SIZE_ESTIMATE = 20;
+
++ // Joins the first two source words into the key binarize() uses to decide where a slice may end.
+ private static final String SOURCE_WORDS_SEPARATOR = " ||| ";
+
+ // Output directory name.
+ private String output;
+
+ // Input grammar to be packed.
+ private String grammar;
+
++ /** @return the path of the grammar file being packed. */
+ public String getGrammar() {
+ return grammar;
+ }
+
++ /** @return the directory into which the packed grammar is written. */
+ public String getOutputDirectory() {
+ return output;
+ }
+
+ // Approximate maximum size of a slice in number of rules
++ // (set from the constructor argument; a "slice_size" entry in the config file overrides it).
+ private int approximateMaximumSliceSize;
+
++ // Passed to types.inferTypes(); the constructor always sets it to true.
+ private boolean labeled;
+
++ // packAlignments: whether alignments are packed at all.
++ // grammarAlignments: whether alignments are taken from the grammar rules themselves.
++ // alignments: path of a separate alignment file, or null.
+ private boolean packAlignments;
+ private boolean grammarAlignments;
+ private String alignments;
+
++ // types collects feature statistics during exploration and writes the encoding; encoderConfig is
++ // that encoding read back from disk before the packing pass.
+ private FeatureTypeAnalyzer types;
+ private EncoderConfiguration encoderConfig;
+
++ // Optional file to which the feature statistics are dumped after exploration, or null.
+ private String dump;
+
++ // Maximum rule.getSource().length seen during exploration; written to the "config" file.
+ private int max_source_len;
+
++ /**
++ * Creates a packer for the given grammar and prepares the output directory.
++ *
++ * @param grammar_filename path of the grammar to pack
++ * @param config_filename optional packer config file, or null to auto-detect feature types; when
++ * given it is read both for "slice_size" and by the feature-type analyzer
++ * @param output_filename output directory; created if it does not exist yet
++ * @param alignments_filename optional separate alignment file, or null
++ * @param featuredump_filename optional file to dump feature statistics to, or null
++ * @param grammar_alignments whether alignments are read from the grammar rules themselves
++ * @param approximateMaximumSliceSize approximate number of rules per slice (overridden by a
++ * "slice_size" entry in the config file)
++ * @throws IOException if the config file cannot be read
++ * @throws RuntimeException if the alignment file is missing or the output directory cannot be
++ * created
++ */
+ public GrammarPacker(String grammar_filename, String config_filename, String output_filename,
+ String alignments_filename, String featuredump_filename, boolean grammar_alignments,
+ int approximateMaximumSliceSize)
+ throws IOException {
+ this.labeled = true;
+ this.grammar = grammar_filename;
+ this.output = output_filename;
+ this.dump = featuredump_filename;
+ this.grammarAlignments = grammar_alignments;
+ this.approximateMaximumSliceSize = approximateMaximumSliceSize;
+ this.max_source_len = 0;
+
+ // TODO: Always open encoder config? This is debatable.
+ this.types = new FeatureTypeAnalyzer(true);
+
+ this.alignments = alignments_filename;
++ // Alignments are packed if they are embedded in the grammar or supplied as a separate file.
+ packAlignments = grammarAlignments || (alignments != null);
+ if (!packAlignments) {
+ LOG.info("No alignments file or grammar specified, skipping.");
+ } else if (alignments != null && !new File(alignments_filename).exists()) {
+ throw new RuntimeException("Alignments file does not exist: " + alignments);
+ }
+
+ if (config_filename != null) {
+ readConfig(config_filename);
+ types.readConfig(config_filename);
+ } else {
+ LOG.info("No config specified. Attempting auto-detection of feature types.");
+ }
- LOG.info("Approximate maximum slice size (in # of rules) set to {}", approximateMaximumSliceSize);
++ // Log the field, not the parameter of the same name: readConfig() may have overridden it.
++ LOG.info("Approximate maximum slice size (in # of rules) set to {}", this.approximateMaximumSliceSize);
+
+ File working_dir = new File(output);
+ working_dir.mkdir();
- if (!working_dir.exists()) {
++ // mkdir()'s result is ignored so that an existing directory is accepted; an existing regular
++ // file with the same name is rejected here rather than later, when files are written inside it.
++ if (!working_dir.isDirectory()) {
+ throw new RuntimeException("Failed creating output directory.");
+ }
+ }
+
++ /**
++ * Reads packer options from a config file of whitespace-separated "key value" lines. Anything after
++ * a '#' is treated as a comment, and lines that end up empty are skipped. Only the "slice_size"
++ * key is interpreted here; it overrides {@code approximateMaximumSliceSize}. Other keys are ignored
++ * by this method (the constructor also hands the same file to the feature-type analyzer).
++ *
++ * @param config_filename path of the config file
++ * @throws IOException if the file cannot be read
++ * @throws RuntimeException if a non-empty line has fewer than two fields
++ * @throws NumberFormatException if the "slice_size" value is not an integer
++ */
+ private void readConfig(String config_filename) throws IOException {
- LineReader reader = new LineReader(config_filename);
- while (reader.hasNext()) {
- // Clean up line, chop comments off and skip if the result is empty.
- String line = reader.next().trim();
- if (line.indexOf('#') != -1)
- line = line.substring(0, line.indexOf('#'));
- if (line.isEmpty())
- continue;
- String[] fields = line.split("[\\s]+");
-
- if (fields.length < 2) {
- throw new RuntimeException("Incomplete line in config.");
- }
- if ("slice_size".equals(fields[0])) {
- // Number of records to concurrently load into memory for sorting.
- approximateMaximumSliceSize = Integer.parseInt(fields[1]);
++ // try-with-resources closes the reader even when a malformed line aborts parsing.
++ try(LineReader reader = new LineReader(config_filename);) {
++ while (reader.hasNext()) {
++ // Clean up line, chop comments off and skip if the result is empty.
++ String line = reader.next().trim();
++ if (line.indexOf('#') != -1)
++ line = line.substring(0, line.indexOf('#'));
++ if (line.isEmpty())
++ continue;
++ String[] fields = line.split("[\\s]+");
++
++ if (fields.length < 2) {
++ throw new RuntimeException("Incomplete line in config.");
++ }
++ if ("slice_size".equals(fields[0])) {
++ // Number of records to concurrently load into memory for sorting.
++ approximateMaximumSliceSize = Integer.parseInt(fields[1]);
++ }
+ }
+ }
- reader.close();
+ }
+
+ /**
+ * Executes the packing.
- *
++ *
++ * Runs two passes over the grammar: an exploration pass that learns the longest source side and
++ * feature-value statistics, and a packing pass that writes the binarized slices. In between it
++ * writes the feature encoding, the vocabulary and a small "config" file (version and
++ * max-source-len) into the output directory. The dump, config and alignment files opened here are
++ * closed even if an error occurs.
++ *
+ * @throws IOException if there is an error reading the grammar
++ * or writing any of the output files
+ */
+ public void pack() throws IOException {
+ LOG.info("Beginning exploration pass.");
+
+ // Explore pass. Learn vocabulary and feature value histograms.
+ LOG.info("Exploring: {}", grammar);
+
++ // NOTE(review): the grammar readers are never closed explicitly; confirm whether
++ // HieroFormatReader releases its file handle once iteration completes.
+ HieroFormatReader grammarReader = getGrammarReader();
+ explore(grammarReader);
+
+ LOG.info("Exploration pass complete. Freezing vocabulary and finalizing encoders.");
+ if (dump != null) {
- PrintWriter dump_writer = new PrintWriter(dump);
- dump_writer.println(types.toString());
- dump_writer.close();
++ try (PrintWriter dump_writer = new PrintWriter(dump)) {
++ dump_writer.println(types.toString());
++ }
+ }
+
+ types.inferTypes(this.labeled);
+ LOG.info("Type inference complete.");
+
+ LOG.info("Finalizing encoding.");
+
+ LOG.info("Writing encoding.");
+ types.write(output + File.separator + "encoding");
+
+ writeVocabulary();
+
+ String configFile = output + File.separator + "config";
+ LOG.info("Writing config to '{}'", configFile);
+ // Write config options
- FileWriter config = new FileWriter(configFile);
- config.write(String.format("version = %d\n", VERSION));
- config.write(String.format("max-source-len = %d\n", max_source_len));
- config.close();
++ try (FileWriter config = new FileWriter(configFile)) {
++ config.write(String.format("version = %d\n", VERSION));
++ config.write(String.format("max-source-len = %d\n", max_source_len));
++ }
+
+ // Read previously written encoder configuration to match up to changed
+ // vocabulary id's.
+ LOG.info("Reading encoding.");
+ encoderConfig = new EncoderConfiguration();
+ encoderConfig.load(output + File.separator + "encoding");
+
+ LOG.info("Beginning packing pass.");
+ // Actual binarization pass. Slice and pack source, target and data.
+ grammarReader = getGrammarReader();
- LineReader alignment_reader = null;
- if (packAlignments && !grammarAlignments)
- alignment_reader = new LineReader(alignments);
- binarize(grammarReader, alignment_reader);
++ // A separate alignment file is only read when alignments are not embedded in the grammar.
++ // try-with-resources skips close() for a null resource, so this is safe when no file is used.
++ try (LineReader alignment_reader = (packAlignments && !grammarAlignments)
++ ? new LineReader(alignments) : null) {
++ binarize(grammarReader, alignment_reader);
++ }
+ LOG.info("Packing complete.");
+
+ LOG.info("Packed grammar in: {}", output);
+ LOG.info("Done.");
+ }
+
+ /**
+ * Returns a reader that turns whatever file format is found into unowned Hiero grammar rules.
+ * This means, features are NOT prepended with an owner string at packing time.
- *
++ *
++ * The format is chosen by peeking at the first line of the grammar file: a line starting with '['
++ * selects {@link HieroFormatReader}, anything else {@link MosesFormatReader}.
++ *
- * @param grammarFile
+ * @return GrammarReader of correct Format
- * @throws IOException
++ * @throws IOException if the grammar file cannot be read or is empty
+ */
+ private HieroFormatReader getGrammarReader() throws IOException {
- LineReader reader = new LineReader(grammar);
- String line = reader.next();
- if (line.startsWith("[")) {
- return new HieroFormatReader(grammar, UNKNOWN_OWNER_ID);
- } else {
- return new MosesFormatReader(grammar, UNKNOWN_OWNER_ID);
++ // The probe reader only looks at the first line and is closed when this method returns; the
++ // returned reader is constructed separately from the file name.
++ try (LineReader reader = new LineReader(grammar);) {
++ // Fail with a clear message on an empty grammar rather than inside the first-line probe.
++ if (!reader.hasNext())
++ throw new IOException("Grammar file is empty: " + grammar);
++ String line = reader.next();
++ if (line.startsWith("["))
++ return new HieroFormatReader(grammar, UNKNOWN_OWNER_ID);
++ else
++ return new MosesFormatReader(grammar, UNKNOWN_OWNER_ID);
+ }
+ }
+
+ /**
- * This first pass over the grammar
++ * Performs the first (exploration) pass over the grammar: records the longest source side in
++ * {@code max_source_len} and feeds every feature value to the feature-type analyzer, so that
++ * {@code pack()} can infer feature types and encodings before the packing pass. Per the note in
++ * the loop below, vocabulary entries are added by the format reader while it parses rules.
+ * @param reader
++ * the grammar reader to iterate; rules are inspected, not stored
+ */
+ private void explore(HieroFormatReader reader) {
+
+ // We always assume a labeled grammar. Unlabeled features are assumed to be dense and to always
+ // appear in the same order. They are assigned numeric names in order of appearance.
+ this.types.setLabeled(true);
+
+ for (Rule rule : reader) {
+
+ max_source_len = Math.max(max_source_len, rule.getSource().length);
+
+ /* Add symbols to vocabulary.
+ * NOTE: In case of nonterminals, we add both stripped versions ("[X]")
+ * and "[X,1]" to the vocabulary.
- *
++ *
+ * TODO: MJP May 2016: Is it necessary to add [X,1]? This is currently being done in
- * {@link HieroFormatReader}, which is called by {@link MosesFormatReader}.
++ * {@link HieroFormatReader}, which is called by {@link MosesFormatReader}.
+ */
+
+ // pass the value through the appropriate encoder.
+ for (final Entry<Integer, Float> entry : rule.getFeatureVector().entrySet()) {
+ types.observe(entry.getKey(), entry.getValue());
+ }
+ }
+ }
+
+ /**
+ * Returns a String encoding the first two source words.
+ * If there is only one source word, use empty string for the second.
++ * binarize() compares these keys to decide where a slice may be closed.
++ *
++ * @param source_words the rule's source side split on whitespace; index 0 is read unconditionally
++ * @return the first word, SOURCE_WORDS_SEPARATOR, then the second word (or "" if absent)
+ */
+ private String getFirstTwoSourceWords(final String[] source_words) {
+ return source_words[0] + SOURCE_WORDS_SEPARATOR + ((source_words.length > 1) ? source_words[1] : "");
+ }
+
++ /**
++ * Second (packing) pass. For every rule, adds the source side to a source trie, the reversed
++ * target side (nonterminals encoded as negative indices) to a target trie, and the non-zero
++ * features (plus alignments, if packed) to data buffers. The accumulated tries and buffers are
++ * written out as a slice via {@link #flush} once the slice is large enough.
++ *
++ * A slice is only closed after a size limit has been hit AND the first two source words change,
++ * so rules that share their first two source words are not split at the cut point. (This
++ * presumably relies on the grammar being sorted by source side; NOTE(review): confirm.)
++ *
++ * @param grammarReader reader over the grammar rules
++ * @param alignment_reader reader over a separate alignment file with one line per rule, or null
++ * when alignments are not packed or come from the grammar itself
++ * @throws IOException if writing a slice fails
++ * @throws RuntimeException if the alignment file has fewer lines than the grammar has rules, or
++ * if feature and alignment block indices diverge
++ */
+ private void binarize(HieroFormatReader grammarReader, LineReader alignment_reader) throws IOException {
+ int counter = 0;
+ int slice_counter = 0;
+ int num_slices = 0;
+
+ boolean ready_to_flush = false;
+ // to determine when flushing is possible
+ String prev_first_two_source_words = null;
+
+ PackingTrie<SourceValue> source_trie = new PackingTrie<SourceValue>();
+ PackingTrie<TargetValue> target_trie = new PackingTrie<TargetValue>();
+ FeatureBuffer feature_buffer = new FeatureBuffer();
+
+ AlignmentBuffer alignment_buffer = null;
+ if (packAlignments)
+ alignment_buffer = new AlignmentBuffer();
+
++ // Cleared and refilled for every rule.
+ TreeMap<Integer, Float> features = new TreeMap<Integer, Float>();
+ for (Rule rule : grammarReader) {
+ counter++;
+ slice_counter++;
+
+ String lhs_word = Vocabulary.word(rule.getLHS());
+ String[] source_words = rule.getSourceWords().split("\\s+");
+ String[] target_words = rule.getTargetWords().split("\\s+");
+
+ // Reached slice limit size, indicate that we're closing up.
+ if (!ready_to_flush
+ && (slice_counter > approximateMaximumSliceSize
+ || feature_buffer.overflowing()
+ || (packAlignments && alignment_buffer.overflowing()))) {
+ ready_to_flush = true;
+ // store the first two source words when slice size limit was reached
+ prev_first_two_source_words = getFirstTwoSourceWords(source_words);
+ }
+ // ready to flush
+ if (ready_to_flush) {
+ final String first_two_source_words = getFirstTwoSourceWords(source_words);
+ // the grammar can only be partitioned at the level of first two source word changes.
+ // Thus, we can only flush if the current first two source words differ from the ones
+ // when the slice size limit was reached.
+ if (!first_two_source_words.equals(prev_first_two_source_words)) {
+ LOG.warn("ready to flush and first two words have changed ({} vs. {})",
+ prev_first_two_source_words, first_two_source_words);
++ // NOTE(review): slice_counter already includes the current rule, which is added to the
++ // next slice after this flush, so the logged count is one too high and the next slice's
++ // counter starts one low.
+ LOG.info("flushing {} rules to slice.", slice_counter);
+ flush(source_trie, target_trie, feature_buffer, alignment_buffer, num_slices);
+ source_trie.clear();
+ target_trie.clear();
+ feature_buffer.clear();
+ if (packAlignments)
+ alignment_buffer.clear();
+
+ num_slices++;
+ slice_counter = 0;
+ ready_to_flush = false;
+ }
+ }
+
+ int alignment_index = -1;
+ // If present, process alignments.
+ if (packAlignments) {
+ byte[] alignments = null;
+ if (grammarAlignments) {
+ alignments = rule.getAlignment();
+ } else {
+ if (!alignment_reader.hasNext()) {
+ LOG.error("No more alignments starting in line {}", counter);
+ throw new RuntimeException("No more alignments starting in line " + counter);
+ }
+ alignments = RuleFactory.parseAlignmentString(alignment_reader.next().trim());
+ }
+ alignment_index = alignment_buffer.add(alignments);
+ }
+
+ // Process features.
+ // Implicitly sort via TreeMap, write to data buffer, remember position
+ // to pass on to the source trie node.
++ // Zero-valued features are dropped here; FeatureBuffer.add() relies on that for its count.
+ features.clear();
+ for (Entry<Integer, Float> entry : rule.getFeatureVector().entrySet()) {
+ int featureId = entry.getKey();
+ float featureValue = entry.getValue();
+ if (featureValue != 0f) {
+ features.put(encoderConfig.innerId(featureId), featureValue);
+ }
+ }
+
+ int features_index = feature_buffer.add(features);
+
+ // Sanity check on the data block index.
+ if (packAlignments && features_index != alignment_index) {
+ LOG.error("Block index mismatch between features ({}) and alignments ({}).",
+ features_index, alignment_index);
+ throw new RuntimeException("Data block index mismatch.");
+ }
+
+ // Process source side.
++ // Nonterminals are stored by their label with the ",N" index stripped.
+ SourceValue sv = new SourceValue(Vocabulary.id(lhs_word), features_index);
+ int[] source = new int[source_words.length];
+ for (int i = 0; i < source_words.length; i++) {
+ if (FormatUtils.isNonterminal(source_words[i]))
+ source[i] = Vocabulary.id(FormatUtils.stripNonTerminalIndex(source_words[i]));
+ else
+ source[i] = Vocabulary.id(source_words[i]);
+ }
+ source_trie.add(source, sv);
+
+ // Process target side.
++ // The target is inserted reversed, and nonterminals become their negated index.
+ TargetValue tv = new TargetValue(sv);
+ int[] target = new int[target_words.length];
+ for (int i = 0; i < target_words.length; i++) {
+ if (FormatUtils.isNonterminal(target_words[i])) {
+ target[target_words.length - (i + 1)] = -FormatUtils.getNonterminalIndex(target_words[i]);
+ } else {
+ target[target_words.length - (i + 1)] = Vocabulary.id(target_words[i]);
+ }
+ }
+ target_trie.add(target, tv);
+ }
+ // flush last slice and clear buffers
++ // (The slice is empty only if the grammar had no rules. NOTE(review): despite the comment above,
++ // the buffers are not cleared here, and surplus lines in a separate alignment file go unnoticed.)
+ flush(source_trie, target_trie, feature_buffer, alignment_buffer, num_slices);
+ }
+
+ /**
+ * Serializes the source, target and feature data structures into interlinked binary files. Target
- * is written first, into a skeletal (node don't carry any data) upward-pointing trie, updating
++ * is written first, into a skeletal (nodes don't carry any data) upward-pointing trie, updating
+ * the linking source trie nodes with the position once it is known. Source and feature data are
+ * written simultaneously. The source structure is written into a downward-pointing trie and
+ * stores the rule's lhs as well as links to the target and feature stream. The feature stream is
- * prompted to write out a block
++ * prompted to write out a block for each rule value, in the order the source trie is traversed
++ * (breadth-first); the feature and alignment data are then written in that order at the end.
- *
++ *
+ * @param source_trie
++ * root of the source trie for this slice
+ * @param target_trie
++ * root of the target trie for this slice
+ * @param feature_buffer
++ * feature data for the rules in this slice
++ * @param alignment_buffer
++ * alignment data for the rules in this slice, or null when alignments are not packed
+ * @param id
++ * slice number; the files are named "slice_" followed by the number zero-padded to five digits
+ * @throws IOException
++ * if a slice file cannot be created or written
+ */
+ private void flush(PackingTrie<SourceValue> source_trie,
+ PackingTrie<TargetValue> target_trie, FeatureBuffer feature_buffer,
+ AlignmentBuffer alignment_buffer, int id) throws IOException {
++ // NOTE(review): the streams below are closed only on the success path; an exception part-way
++ // through leaves them open.
+ // Make a slice object for this piece of the grammar.
+ PackingFileTuple slice = new PackingFileTuple("slice_" + String.format("%05d", id));
+ // Pull out the streams for source, target and data output.
+ DataOutputStream source_stream = slice.getSourceOutput();
+ DataOutputStream target_stream = slice.getTargetOutput();
+ DataOutputStream target_lookup_stream = slice.getTargetLookupOutput();
+ DataOutputStream feature_stream = slice.getFeatureOutput();
+ DataOutputStream alignment_stream = slice.getAlignmentOutput();
+
+ Queue<PackingTrie<TargetValue>> target_queue;
+ Queue<PackingTrie<SourceValue>> source_queue;
+
+ // The number of bytes both written into the source stream and
+ // buffered in the source queue.
+ int source_position;
+ // The number of bytes written into the target stream.
++ // NOTE(review): both positions are advanced by PackingTrie.size(), which counts ints (every field
++ // is written with writeInt), so they appear to be int offsets rather than byte offsets.
+ int target_position;
+
+ // Add trie root into queue, set target position to 0 and set cumulated
+ // size to size of trie root.
+ target_queue = new LinkedList<PackingTrie<TargetValue>>();
+ target_queue.add(target_trie);
+ target_position = 0;
+
+ // Target lookup table for trie levels.
++ // Records the cumulative target position at the end of each breadth-first level.
+ int current_level_size = 1;
+ int next_level_size = 0;
+ ArrayList<Integer> target_lookup = new ArrayList<Integer>();
+
+ // Packing loop for upwards-pointing target trie.
+ while (!target_queue.isEmpty()) {
+ // Pop top of queue.
+ PackingTrie<TargetValue> node = target_queue.poll();
+ // Register that this is where we're writing the node to.
+ node.address = target_position;
+ // Tell source nodes that we're writing to this position in the file.
+ for (TargetValue tv : node.values)
+ tv.parent.target = node.address;
+ // Write link to parent.
+ if (node.parent != null)
+ target_stream.writeInt(node.parent.address);
+ else
+ target_stream.writeInt(-1);
+ target_stream.writeInt(node.symbol);
+ // Enqueue children.
+ for (int k : node.children.descendingKeySet()) {
+ PackingTrie<TargetValue> child = node.children.get(k);
+ target_queue.add(child);
+ }
+ target_position += node.size(false, true);
+ next_level_size += node.children.descendingKeySet().size();
+
+ current_level_size--;
+ if (current_level_size == 0) {
+ target_lookup.add(target_position);
+ current_level_size = next_level_size;
+ next_level_size = 0;
+ }
+ }
++ // Lookup file layout: number of levels, then the end position of each level.
+ target_lookup_stream.writeInt(target_lookup.size());
+ for (int i : target_lookup)
+ target_lookup_stream.writeInt(i);
+ target_lookup_stream.close();
+
+ // Setting up for source and data writing.
+ source_queue = new LinkedList<PackingTrie<SourceValue>>();
+ source_queue.add(source_trie);
+ source_position = source_trie.size(true, false);
++ // NOTE(review): the root's address is set here but not written to any stream in this method.
+ source_trie.address = target_position;
+
+ // Ready data buffers for writing.
+ feature_buffer.initialize();
+ if (packAlignments)
+ alignment_buffer.initialize();
+
+ // Packing loop for downwards-pointing source trie.
+ while (!source_queue.isEmpty()) {
+ // Pop top of queue.
+ PackingTrie<SourceValue> node = source_queue.poll();
+ // Write number of children.
+ source_stream.writeInt(node.children.size());
+ // Write links to children.
+ for (int k : node.children.descendingKeySet()) {
+ PackingTrie<SourceValue> child = node.children.get(k);
+ // Enqueue child.
+ source_queue.add(child);
+ // Child's address will be at the current end of the queue.
+ child.address = source_position;
+ // Advance cumulated size by child's size.
+ source_position += child.size(true, false);
+ // Write the link.
+ source_stream.writeInt(k);
+ source_stream.writeInt(child.address);
+ }
+ // Write number of data items.
+ source_stream.writeInt(node.values.size());
+ // Write lhs and links to target and data.
++ // sv.target was filled in by the target loop above (via TargetValue.parent).
+ for (SourceValue sv : node.values) {
+ int feature_block_index = feature_buffer.write(sv.data);
+ if (packAlignments) {
+ int alignment_block_index = alignment_buffer.write(sv.data);
+ if (alignment_block_index != feature_block_index) {
+ LOG.error("Block index mismatch.");
+ throw new RuntimeException("Block index mismatch: alignment (" + alignment_block_index
+ + ") and features (" + feature_block_index + ") don't match.");
+ }
+ }
+ source_stream.writeInt(sv.lhs);
+ source_stream.writeInt(sv.target);
+ source_stream.writeInt(feature_block_index);
+ }
+ }
+ // Flush the data stream.
+ feature_buffer.flush(feature_stream);
+ if (packAlignments)
+ alignment_buffer.flush(alignment_stream);
+
+ target_stream.close();
+ source_stream.close();
+ feature_stream.close();
+ if (packAlignments)
+ alignment_stream.close();
+ }
+
++ /**
++ * Writes the global {@link Vocabulary} to the file named {@code VOCABULARY_FILENAME} inside the
++ * output directory.
++ *
++ * @throws IOException propagated from {@code Vocabulary.write} if the file cannot be written
++ */
+ public void writeVocabulary() throws IOException {
+ final String vocabularyFilename = output + File.separator + VOCABULARY_FILENAME;
+ LOG.info("Writing vocabulary to {}", vocabularyFilename);
+ Vocabulary.write(vocabularyFilename);
+ }
+
+ /**
+ * Integer-labeled, doubly-linked trie with some provisions for packing.
- *
++ *
+ * @author Juri Ganitkevitch
- *
++ *
+ * @param <D> The trie's value type.
+ */
+ class PackingTrie<D extends PackingTrieValue> {
++ // Label of the edge leading to this node from its parent (0 for the root).
+ int symbol;
+ PackingTrie<D> parent;
+
++ // Children keyed by symbol; a TreeMap so that iteration order is deterministic.
+ TreeMap<Integer, PackingTrie<D>> children;
++ // Values attached to the path that ends at this node.
+ List<D> values;
+
++ // Position assigned while the slice is written; -1 until then.
+ int address;
+
+ PackingTrie() {
+ address = -1;
+
+ symbol = 0;
+ parent = null;
+
+ children = new TreeMap<Integer, PackingTrie<D>>();
+ values = new ArrayList<D>();
+ }
+
+ PackingTrie(PackingTrie<D> parent, int symbol) {
+ this();
+ this.parent = parent;
+ this.symbol = symbol;
+ }
+
++ /** Stores value at the node reached by following path from this node, creating nodes as needed. */
+ void add(int[] path, D value) {
+ add(path, 0, value);
+ }
+
+ private void add(int[] path, int index, D value) {
+ if (index == path.length)
+ this.values.add(value);
+ else {
+ PackingTrie<D> child = children.get(path[index]);
+ if (child == null) {
+ child = new PackingTrie<D>(this, path[index]);
+ children.put(path[index], child);
+ }
+ child.add(path, index + 1, value);
+ }
+ }
+
+ /**
+ * Calculate the size (in ints) of a packed trie node. Distinguishes downwards pointing (parent
+ * points to children) from upwards pointing (children point to parent) tries, as well as
+ * skeletal (no data, just the labeled links) and non-skeletal (nodes have a data block)
+ * packing.
- *
++ *
+ * @param downwards Are we packing into a downwards-pointing trie?
+ * @param skeletal Are we packing into a skeletal trie?
- *
++ *
- * @return Number of bytes the trie node would occupy.
++ * @return Number of ints (4-byte words) the trie node would occupy.
+ */
+ int size(boolean downwards, boolean skeletal) {
+ int size = 0;
+ if (downwards) {
+ // Number of children and links to children.
+ size = 1 + 2 * children.size();
+ } else {
+ // Link to parent.
++ // (two ints: the parent's address and this node's symbol)
+ size += 2;
+ }
+ // Non-skeletal packing: number of data items.
+ if (!skeletal)
+ size += 1;
+ // Non-skeletal packing: write size taken up by data items.
+ if (!skeletal && !values.isEmpty())
+ size += values.size() * values.get(0).size();
+
+ return size;
+ }
+
++ // Drops children and values; symbol, parent and address are left unchanged.
+ void clear() {
+ children.clear();
+ values.clear();
+ }
+ }
+
++ /** A value stored in a {@link PackingTrie} node. */
+ interface PackingTrieValue {
++ /** @return number of ints this value occupies when its node is packed non-skeletally. */
+ int size();
+ }
+
++ /**
++ * Per-rule record stored in the source trie. It is written as three ints: the LHS id, the target
++ * trie address, and the feature block index.
++ */
+ class SourceValue implements PackingTrieValue {
++ // Vocabulary id of the rule's left-hand side.
+ int lhs;
++ // Index of the rule's block in the feature (and alignment) buffers.
+ int data;
++ // Address of the rule's target node; filled in while the target trie is written.
+ int target;
+
+ public SourceValue() {
+ }
+
+ SourceValue(int lhs, int data) {
+ this.lhs = lhs;
+ this.data = data;
+ }
+
++ // NOTE(review): flush() assigns the target field directly; this setter has no visible caller.
+ void setTarget(int target) {
+ this.target = target;
+ }
+
+ @Override
+ public int size() {
+ return 3;
+ }
+ }
+
++ /**
++ * Value stored in the target trie. It links back to the rule's {@link SourceValue} so the target
++ * node's address can be recorded there. It contributes no data of its own (size 0), because the
++ * target trie is packed skeletally.
++ */
+ class TargetValue implements PackingTrieValue {
+ SourceValue parent;
+
+ TargetValue(SourceValue parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public int size() {
+ return 0;
+ }
+ }
+
++ /**
++ * Growable byte buffer that holds one variable-length data block per rule. It also keeps the
++ * bookkeeping needed to write those blocks to disk, behind an offset header, in a caller-chosen
++ * order.
++ *
++ * @param <T> type of the per-rule item that subclasses serialize into the buffer
++ */
+ abstract class PackingBuffer<T> {
+ private byte[] backing;
+ protected ByteBuffer buffer;
+
++ // memoryLookup: start offset in the buffer of each block, indexed by block index.
++ // totalSize: number of bytes used in the buffer (the end of the last block).
++ // onDiskOrder: block indices in the order flush() will write them.
+ protected ArrayList<Integer> memoryLookup;
+ protected int totalSize;
+ protected ArrayList<Integer> onDiskOrder;
+
+ PackingBuffer() throws IOException {
+ allocate();
+ memoryLookup = new ArrayList<Integer>();
+ onDiskOrder = new ArrayList<Integer>();
+ totalSize = 0;
+ }
+
++ /** Serializes item as a new block and returns that block's index. */
+ abstract int add(T item);
+
+ // Allocate a reasonably-sized buffer for the feature data.
++ // NOTE(review): the product is computed in int arithmetic and could overflow for very large
++ // slice sizes.
+ private void allocate() {
+ backing = new byte[approximateMaximumSliceSize * DATA_SIZE_ESTIMATE];
+ buffer = ByteBuffer.wrap(backing);
+ }
+
+ // Reallocate the backing array and buffer, copies data over.
++ // Doubles the capacity, capped at Integer.MAX_VALUE. NOTE(review): once the cap is reached this
++ // returns without growing, so a later write that does not fit will presumably fail with a
++ // BufferOverflowException. Most JVMs also cannot allocate an array of exactly Integer.MAX_VALUE.
+ protected void reallocate() {
+ if (backing.length == Integer.MAX_VALUE)
+ return;
+ long attempted_length = backing.length * 2l;
+ int new_length;
+ // Detect overflow.
+ if (attempted_length >= Integer.MAX_VALUE)
+ new_length = Integer.MAX_VALUE;
+ else
+ new_length = (int) attempted_length;
+ byte[] new_backing = new byte[new_length];
+ System.arraycopy(backing, 0, new_backing, 0, backing.length);
+ int old_position = buffer.position();
+ ByteBuffer new_buffer = ByteBuffer.wrap(new_backing);
+ new_buffer.position(old_position);
+ buffer = new_buffer;
+ backing = new_backing;
+ }
+
+ /**
+ * Prepare the data buffer for disk writing.
+ */
+ void initialize() {
+ onDiskOrder.clear();
+ }
+
+ /**
+ * Enqueue a data block for later writing.
- *
++ *
+ * @param block_index The index of the data block to add to writing queue.
+ * @return The to-be-written block's output index.
+ */
+ int write(int block_index) {
+ onDiskOrder.add(block_index);
+ return onDiskOrder.size() - 1;
+ }
+
+ /**
+ * Performs the actual writing to disk in the order specified by calls to write() since the last
+ * call to initialize().
- *
++ *
+ * @param out
++ * a fresh stream (nothing written yet); the header is written first
+ * @throws IOException
++ * if writing to out fails
+ */
+ void flush(DataOutputStream out) throws IOException {
+ writeHeader(out);
+ int size;
+ int block_address;
+ for (int block_index : onDiskOrder) {
+ block_address = memoryLookup.get(block_index);
+ size = blockSize(block_index);
+ out.write(backing, block_address, size);
+ }
+ }
+
++ // Resets the buffer for the next slice; the backing array keeps its current (possibly grown) size.
+ void clear() {
+ buffer.clear();
+ memoryLookup.clear();
+ onDiskOrder.clear();
+ }
+
++ /** @return true once the write position has reached DATA_SIZE_LIMIT. */
+ boolean overflowing() {
+ return (buffer.position() >= DATA_SIZE_LIMIT);
+ }
+
++ // Header layout: block count, totalSize, then the on-disk offset of each block in write order.
++ // Offsets count from the start of the stream, so the first one equals headerSize().
+ private void writeHeader(DataOutputStream out) throws IOException {
+ if (out.size() == 0) {
+ out.writeInt(onDiskOrder.size());
+ out.writeInt(totalSize);
+ int disk_position = headerSize();
+ for (int block_index : onDiskOrder) {
+ out.writeInt(disk_position);
+ disk_position += blockSize(block_index);
+ }
+ } else {
+ throw new RuntimeException("Got a used stream for header writing.");
+ }
+ }
+
+ private int headerSize() {
+ // One integer for each data block, plus number of blocks and total size.
+ return 4 * (onDiskOrder.size() + 2);
+ }
+
++ // A block extends to the start of the next block, or to totalSize for the last one.
+ private int blockSize(int block_index) {
+ int block_address = memoryLookup.get(block_index);
+ return (block_index < memoryLookup.size() - 1 ? memoryLookup.get(block_index + 1) : totalSize)
+ - block_address;
+ }
+ }
+
++ /**
++ * Packing buffer for per-rule feature vectors. Each block is the feature count followed by
++ * (feature id, encoded value) pairs in descending id order. Ids are written with the analyzer's id
++ * encoder and values with the per-feature encoder from encoderConfig.
++ */
+ class FeatureBuffer extends PackingBuffer<TreeMap<Integer, Float>> {
+
+ private IntEncoder idEncoder;
+
+ FeatureBuffer() throws IOException {
+ super();
+ idEncoder = types.getIdEncoder();
+ LOG.info("Encoding feature ids in: {}", idEncoder.getKey());
+ }
+
+ /**
+ * Add a block of features to the buffer.
- *
++ *
+ * @param features TreeMap with the features for one rule.
+ * @return The index of the resulting data block.
+ */
+ @Override
+ int add(TreeMap<Integer, Float> features) {
+ int data_position = buffer.position();
+
+ // Over-estimate how much room this addition will need: for each
+ // feature (ID_SIZE for label, "upper bound" of 4 for the value), plus ID_SIZE for
+ // the number of features. If this won't fit, reallocate the buffer.
++ // NOTE(review): reallocate() grows the buffer only once, which assumes one doubling always
++ // makes room for a single rule.
+ int size_estimate = (4 + EncoderConfiguration.ID_SIZE) * features.size()
+ + EncoderConfiguration.ID_SIZE;
+ if (buffer.capacity() - buffer.position() <= size_estimate)
+ reallocate();
+
+ // Write features to buffer.
+ idEncoder.write(buffer, features.size());
+ for (Integer k : features.descendingKeySet()) {
+ float v = features.get(k);
+ // Sparse features.
++ // NOTE(review): the count written above includes any zero-valued entries skipped here, so
++ // callers must remove zeros beforehand (binarize() does).
+ if (v != 0.0) {
+ idEncoder.write(buffer, k);
+ encoderConfig.encoder(k).write(buffer, v);
+ }
+ }
+ // Store position the block was written to.
+ memoryLookup.add(data_position);
+ // Update total size (in bytes).
+ totalSize = buffer.position();
+
+ // Return block index.
+ return memoryLookup.size() - 1;
+ }
+ }
+
++ /**
++ * Packing buffer for per-rule word alignments. Each block is a one-byte point count followed by
++ * the raw alignment bytes.
++ */
+ class AlignmentBuffer extends PackingBuffer<byte[]> {
+
+ AlignmentBuffer() throws IOException {
+ super();
+ }
+
+ /**
- * Add a rule alignments to the buffer.
++ * Add one rule's alignment points to the buffer.
- *
++ *
+ * @param alignments a byte array with the alignment points for one rule.
+ * @return The index of the resulting data block.
+ */
+ @Override
+ int add(byte[] alignments) {
+ int data_position = buffer.position();
+ int size_estimate = alignments.length + 1;
+ if (buffer.capacity() - buffer.position() <= size_estimate)
+ reallocate();
+
+ // Write alignment points to buffer.
++ // The count is alignments.length / 2 (presumably one source and one target index per point),
++ // stored in a single byte. NOTE(review): counts above 127 do not fit in a signed byte; confirm
++ // how the reader decodes it.
+ buffer.put((byte) (alignments.length / 2));
+ buffer.put(alignments);
+
+ // Store position the block was written to.
+ memoryLookup.add(data_position);
+ // Update total size (in bytes).
+ totalSize = buffer.position();
+ // Return block index.
+ return memoryLookup.size() - 1;
+ }
+ }
+
++ /**
++ * The files making up one packed slice inside the output directory: {@code <prefix>.source},
++ * {@code .target}, {@code .target.lookup}, {@code .features} and, only when alignments are packed,
++ * {@code .alignments}. Output streams are created on demand and must be closed by the caller.
++ */
+ class PackingFileTuple implements Comparable<PackingFileTuple> {
+ private File sourceFile;
+ private File targetLookupFile;
+ private File targetFile;
+
+ private File featureFile;
+ private File alignmentFile;
+
++ /**
++ * @param prefix base name of the slice files, e.g. "slice_00000"
++ */
+ PackingFileTuple(String prefix) {
+ sourceFile = new File(output + File.separator + prefix + ".source");
+ targetFile = new File(output + File.separator + prefix + ".target");
+ targetLookupFile = new File(output + File.separator + prefix + ".target.lookup");
+ featureFile = new File(output + File.separator + prefix + ".features");
+
+ alignmentFile = null;
+ if (packAlignments)
+ alignmentFile = new File(output + File.separator + prefix + ".alignments");
+
+ LOG.info("Allocated slice: {}", sourceFile.getAbsolutePath());
+ }
+
+ DataOutputStream getSourceOutput() throws IOException {
+ return getOutput(sourceFile);
+ }
+
+ DataOutputStream getTargetOutput() throws IOException {
+ return getOutput(targetFile);
+ }
+
+ DataOutputStream getTargetLookupOutput() throws IOException {
+ return getOutput(targetLookupFile);
+ }
+
+ DataOutputStream getFeatureOutput() throws IOException {
+ return getOutput(featureFile);
+ }
+
++ // Returns null when alignments are not being packed.
+ DataOutputStream getAlignmentOutput() throws IOException {
+ if (alignmentFile != null)
+ return getOutput(alignmentFile);
+ return null;
+ }
+
++ // createNewFile() returns false if the file already exists, so existing slice files are never
++ // overwritten.
+ private DataOutputStream getOutput(File file) throws IOException {
+ if (file.createNewFile()) {
+ return new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)));
+ } else {
- throw new RuntimeException("File doesn't exist: " + file.getName());
++ // Report the actual cause: the file is already there.
++ throw new RuntimeException("File already exists: " + file.getPath());
+ }
+ }
+
++ // Combined size of the source, target and feature files.
++ // NOTE(review): the .target.lookup and .alignments files are not included.
+ long getSize() {
+ return sourceFile.length() + targetFile.length() + featureFile.length();
+ }
+
++ // Orders larger slices first (descending by getSize()).
+ @Override
+ public int compareTo(PackingFileTuple o) {
+ if (getSize() > o.getSize()) {
+ return -1;
+ } else if (getSize() < o.getSize()) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/tools/LabelPhrases.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/tools/LabelPhrases.java
index 2fd2b3f,0000000..8f15b0e
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/tools/LabelPhrases.java
+++ b/joshua-core/src/main/java/org/apache/joshua/tools/LabelPhrases.java
@@@ -1,111 -1,0 +1,111 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.tools;
+
+import java.io.IOException;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.corpus.syntax.ArraySyntaxTree;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Finds labeling for a set of phrases.
- *
++ *
+ * @author Juri Ganitkevitch
+ */
+public class LabelPhrases {
+
+ private static final Logger LOG = LoggerFactory.getLogger(LabelPhrases.class);
+
+ /**
+ * Labels each phrase with a syntactic category taken from its sentence's parse; prints "label\tline".
- *
++ *
+ * @param args command-line arguments; {@code -p phrase_file} names the phrase-sentence file to process
+ * @throws IOException if there is an error reading the phrase file
+ */
+ public static void main(String[] args) throws IOException {
+
+ if (args.length < 1 || args[0].equals("-h")) {
+ System.err.println("Usage: " + LabelPhrases.class.toString());
+ System.err.println(" -p phrase_file phrase-sentence file to process");
+ System.err.println();
+ System.exit(-1);
+ }
+
+ String phrase_file_name = null;
+
+ for (int i = 0; i < args.length; i++) {
+ if ("-p".equals(args[i]) && i + 1 < args.length) phrase_file_name = args[++i];
+ }
+ if (phrase_file_name == null) {
+ LOG.error("a phrase file is required for operation");
+ System.exit(-1);
+ }
+
- LineReader phrase_reader = new LineReader(phrase_file_name);
-
- while (phrase_reader.ready()) {
- String line = phrase_reader.readLine();
-
- String[] fields = line.split("\\t");
- if (fields.length != 3 || fields[2].equals("()")) {
- System.err.println("[FAIL] Empty parse in line:\t" + line);
- continue;
- }
-
- String[] phrase_strings = fields[0].split("\\s");
- int[] phrase_ids = new int[phrase_strings.length];
- for (int i = 0; i < phrase_strings.length; i++)
- phrase_ids[i] = Vocabulary.id(phrase_strings[i]);
++ try (LineReader phrase_reader = new LineReader(phrase_file_name);) {
++ while (phrase_reader.ready()) {
++ String line = phrase_reader.readLine();
+
- ArraySyntaxTree syntax = new ArraySyntaxTree(fields[2]);
- int[] sentence_ids = syntax.getTerminals();
++ String[] fields = line.split("\\t");
++ if (fields.length != 3 || fields[2].equals("()")) {
++ System.err.println("[FAIL] Empty parse in line:\t" + line);
++ continue;
++ }
+
- int match_start = -1;
- int match_end = -1;
- for (int i = 0; i < sentence_ids.length; i++) {
- if (phrase_ids[0] == sentence_ids[i]) {
- match_start = i;
- int j = 0;
- while (j < phrase_ids.length && phrase_ids[j] == sentence_ids[i + j]) {
- j++;
- }
- if (j == phrase_ids.length) {
- match_end = i + j;
- break;
++ String[] phrase_strings = fields[0].split("\\s");
++ int[] phrase_ids = new int[phrase_strings.length];
++ for (int i = 0; i < phrase_strings.length; i++)
++ phrase_ids[i] = Vocabulary.id(phrase_strings[i]);
++
++ ArraySyntaxTree syntax = new ArraySyntaxTree(fields[2]);
++ int[] sentence_ids = syntax.getTerminals();
++
++ int match_start = -1;
++ int match_end = -1;
++ for (int i = 0; i < sentence_ids.length; i++) {
++ if (phrase_ids[0] == sentence_ids[i]) {
++ int j = 0; // match length at i; the loop also stops at the sentence end
++ while (j < phrase_ids.length && i + j < sentence_ids.length && phrase_ids[j] == sentence_ids[i + j]) {
++ j++;
++ }
++ if (j == phrase_ids.length) { // record the span only for a complete match
++ match_start = i;
++ match_end = i + j;
++ break;
++ }
+ }
+ }
- }
+
- int label = syntax.getOneConstituent(match_start, match_end);
- if (label == 0) label = syntax.getOneSingleConcatenation(match_start, match_end);
- if (label == 0) label = syntax.getOneRightSideCCG(match_start, match_end);
- if (label == 0) label = syntax.getOneLeftSideCCG(match_start, match_end);
- if (label == 0) label = syntax.getOneDoubleConcatenation(match_start, match_end);
- if (label == 0) {
- System.err.println("[FAIL] No label found in line:\t" + line);
- continue;
- }
++ int label = (match_start < 0) ? 0 : syntax.getOneConstituent(match_start, match_end); // skip tree queries if phrase not found
++ if (label == 0 && match_start >= 0) label = syntax.getOneSingleConcatenation(match_start, match_end);
++ if (label == 0 && match_start >= 0) label = syntax.getOneRightSideCCG(match_start, match_end);
++ if (label == 0 && match_start >= 0) label = syntax.getOneLeftSideCCG(match_start, match_end);
++ if (label == 0 && match_start >= 0) label = syntax.getOneDoubleConcatenation(match_start, match_end);
++ if (label == 0) {
++ System.err.println("[FAIL] No label found in line:\t" + line);
++ continue;
++ }
+
- System.out.println(Vocabulary.word(label) + "\t" + line);
++ System.out.println(Vocabulary.word(label) + "\t" + line);
++ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/tools/TestSetFilter.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/tools/TestSetFilter.java
index ecb2e6e,0000000..f73f02d
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/tools/TestSetFilter.java
+++ b/joshua-core/src/main/java/org/apache/joshua/tools/TestSetFilter.java
@@@ -1,383 -1,0 +1,384 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.tools;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Pattern;
+
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class TestSetFilter {
+
+ private static final Logger LOG = LoggerFactory.getLogger(TestSetFilter.class);
+
+ private Filter filter = new FastFilter(); // fast mode is the documented default ("-f")
+
+ // for caching of accepted rules
+ private String lastSourceSide;
+ private boolean acceptedLastSourceSide;
+
+ public int cached = 0;
+ public int RULE_LENGTH = 12;
+ public boolean verbose = false;
+ public boolean parallel = false;
+
+ private static final String DELIMITER = "|||";
+ private static final String DELIMITER_REGEX = " \\|\\|\\| ";
+ public static final String DELIM = String.format(" %s ", DELIMITER);
+ public static final Pattern P_DELIM = Pattern.compile(DELIMITER_REGEX);
+ private final String NT_REGEX = "\\[[^\\]]+?\\]";
+
+ public TestSetFilter() {
+ acceptedLastSourceSide = false;
+ lastSourceSide = null;
+ }
-
++
+ public String getFilterName() {
+ if (filter != null)
+ if (filter instanceof FastFilter)
+ return "fast";
+ else if (filter instanceof LooseFilter)
+ return "loose";
+ else
+ return "exact";
+ return "null";
+ }
+
+ public void setVerbose(boolean value) {
+ verbose = value;
+ }
+
+ public void setParallel(boolean value) {
+ parallel = value;
+ }
+
+ public void setFilter(String type) {
+ if (type.equals("fast"))
+ filter = new FastFilter();
+ else if (type.equals("exact"))
+ filter = new ExactFilter();
+ else if (type.equals("loose"))
+ filter = new LooseFilter();
+ else
+ throw new RuntimeException(String.format("Invalid filter type '%s'", type));
+ }
+
+ public void setRuleLength(int value) {
+ RULE_LENGTH = value;
+ }
+
+ private void loadTestSentences(String filename) throws IOException {
+ int count = 0;
+ // try-with-resources so the reader is closed once every sentence has been added.
+ try (LineReader reader = new LineReader(filename)) {
+ for (String line : reader) {
+ filter.addSentence(line);
+ count++;
+ }
+ } catch (FileNotFoundException e) {
+ LOG.error(e.getMessage(), e);
+ }
+
+ if (verbose)
+ System.err.println(String.format("Added %d sentences.\n", count));
+ }
+
+ /**
- * Top-level filter, responsible for calling the fast or exact version. Takes the source side
++ * Top-level filter, responsible for calling the fast or exact version. Takes the source side
+ * of a rule and determines whether there is any sentence in the test set that can match it.
+ * @param sourceSide the source side of a grammar rule
+ * @return true if any sentence in the test set can match the source side
+ */
+ public boolean inTestSet(String sourceSide) {
+ if (!sourceSide.equals(lastSourceSide)) {
+ lastSourceSide = sourceSide;
+ acceptedLastSourceSide = filter.permits(sourceSide);
+ } else {
+ cached++;
+ }
+
+ return acceptedLastSourceSide;
+ }
-
++
+ /**
+ * Determines whether a rule is an abstract rule. An abstract rule is one that has no terminals on
+ * its source side.
- *
++ *
+ * Returns true if every whitespace-separated token is a nonterminal and there is at least one token; false otherwise.
+ */
+ private boolean isAbstract(String source) {
+ int nonterminalCount = 0;
+ for (String t : source.split("\\s+")) {
+ if (!t.matches(NT_REGEX))
+ return false;
+ nonterminalCount++;
+ }
+ return nonterminalCount != 0;
+ }
+
+ private interface Filter {
+ /* Tell the filter about a sentence in the test set being filtered to */
+ public void addSentence(String sentence);
-
++
+ /* Returns true if the filter permits the specified source side */
+ public boolean permits(String sourceSide);
+ }
+
+ private class FastFilter implements Filter {
+ private Set<String> ngrams = null;
+
+ public FastFilter() {
+ ngrams = new HashSet<String>();
+ }
-
++
+ @Override
+ public boolean permits(String source) {
+ for (String chunk : source.split(NT_REGEX)) {
+ chunk = chunk.trim();
+ /* Important: you need to make sure the string isn't empty. */
+ if (!chunk.equals("") && !ngrams.contains(chunk))
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public void addSentence(String sentence) {
+ String[] tokens = sentence.trim().split("\\s+");
+ int maxOrder = RULE_LENGTH < tokens.length ? RULE_LENGTH : tokens.length;
+ for (int order = 1; order <= maxOrder; order++) {
+ for (int start = 0; start < tokens.length - order + 1; start++)
+ ngrams.add(createNGram(tokens, start, order));
+ }
+ }
+
+ private String createNGram(String[] tokens, int start, int order) {
+ if (order < 1 || start + order > tokens.length) {
+ return "";
+ }
+ String result = tokens[start];
+ for (int i = 1; i < order; i++)
+ result += " " + tokens[start + i];
+ return result;
+ }
+ }
+
+ private class LooseFilter implements Filter {
+ List<String> testSentences = null;
+
+ public LooseFilter() {
+ testSentences = new ArrayList<String>();
+ }
-
++
+ @Override
+ public void addSentence(String source) {
+ testSentences.add(source);
+ }
+
+ @Override
+ public boolean permits(String source) {
+ Pattern pattern = getPattern(source);
+ for (String testSentence : testSentences) {
+ if (pattern.matcher(testSentence).find()) {
+ return true;
+ }
+ }
+ return isAbstract(source);
+ }
+
+ protected Pattern getPattern(String source) {
+ String pattern = source;
+ pattern = pattern.replaceAll(String.format("\\s*%s\\s*", NT_REGEX), ".+");
+ pattern = pattern.replaceAll("\\s+", ".*");
+// System.err.println(String.format("PATTERN(%s) = %s", source, pattern));
+ return Pattern.compile(pattern);
+ }
+ }
+
+ /**
+ * This class is the same as LooseFilter except with a tighter regex for matching rules.
+ */
+ private class ExactFilter implements Filter {
+ private FastFilter fastFilter = null;
+ private Map<String, Set<Integer>> sentencesByWord;
+ List<String> testSentences = null;
-
++
+ public ExactFilter() {
+ fastFilter = new FastFilter();
+ sentencesByWord = new HashMap<String, Set<Integer>>();
+ testSentences = new ArrayList<String>();
+ }
-
++
+ @Override
+ public void addSentence(String source) {
+ fastFilter.addSentence(source);
+ addSentenceToWordHash(source, testSentences.size());
+ testSentences.add(source);
+ }
+
+ /**
+ * Always permit abstract rules. Otherwise, query the fast filter, and if that passes, apply
- *
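++ * the exact pattern from getPattern() to each test sentence containing all of the rule's indexed terminals.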
+ */
+ @Override
+ public boolean permits(String sourceSide) {
+ if (isAbstract(sourceSide))
+ return true;
-
++
+ if (fastFilter.permits(sourceSide)) {
+ Pattern pattern = getPattern(sourceSide);
+ for (int i : getSentencesForRule(sourceSide)) {
+ if (pattern.matcher(testSentences.get(i)).find()) {
+ return true;
+ }
+ }
- }
++ }
+ return false;
+ }
-
++
+ protected Pattern getPattern(String source) {
+ String pattern = Pattern.quote(source);
+ pattern = pattern.replaceAll(NT_REGEX, "\\\\E.+\\\\Q");
+ pattern = pattern.replaceAll("\\\\Q\\\\E", "");
+ pattern = "(?:^|\\s)" + pattern + "(?:$|\\s)";
+ return Pattern.compile(pattern);
+ }
-
++
+ /*
+ * Map words to all the sentences they appear in.
+ */
+ private void addSentenceToWordHash(String sentence, int index) {
+ String[] tokens = sentence.split("\\s+");
+ for (String t : tokens) {
+ if (! sentencesByWord.containsKey(t))
+ sentencesByWord.put(t, new HashSet<Integer>());
+ sentencesByWord.get(t).add(index);
+ }
+ }
-
++
+ private Set<Integer> getSentencesForRule(String source) {
+ Set<Integer> sentences = null;
+ for (String token : source.split("\\s+")) {
+ if (!token.matches(NT_REGEX)) {
+ if (sentencesByWord.containsKey(token)) {
+ if (sentences == null)
+ sentences = new HashSet<Integer>(sentencesByWord.get(token));
+ else
+ sentences.retainAll(sentencesByWord.get(token));
+ }
+ }
+ }
-
++ // No terminal of the rule is indexed: return an empty set instead of null so permits() can iterate.
+ return (sentences == null) ? new HashSet<Integer>() : sentences;
+ }
+ }
+
+ public static void main(String[] argv) throws IOException {
+ // do some setup
+ if (argv.length < 1) {
+ System.err.println("usage: TestSetFilter [-v|-p|-f|-e|-l|-n N|-g grammar] test_set1 [test_set2 ...]");
+ System.err.println(" -g grammar file (can also be on STDIN)");
+ System.err.println(" -v verbose output");
+ System.err.println(" -p parallel compatibility");
+ System.err.println(" -f fast mode (default)");
+ System.err.println(" -e exact mode (slower)");
+ System.err.println(" -l loose mode");
+ System.err.println(" -n max n-gram to compare to (default 12)");
+ return;
+ }
-
++
+ String grammarFile = null;
+
+ TestSetFilter filter = new TestSetFilter();
+
+ for (int i = 0; i < argv.length; i++) {
+ if (argv[i].equals("-v")) {
+ filter.setVerbose(true);
+ continue;
+ } else if (argv[i].equals("-p")) {
+ filter.setParallel(true);
+ continue;
+ } else if (argv[i].equals("-g")) {
+ grammarFile = argv[++i];
+ continue;
+ } else if (argv[i].equals("-f")) {
+ filter.setFilter("fast");
+ continue;
+ } else if (argv[i].equals("-e")) {
+ filter.setFilter("exact");
+ continue;
+ } else if (argv[i].equals("-l")) {
+ filter.setFilter("loose");
+ continue;
+ } else if (argv[i].equals("-n")) {
+ filter.setRuleLength(Integer.parseInt(argv[i + 1]));
+ i++;
+ continue;
+ }
+
+ filter.loadTestSentences(argv[i]);
+ }
+
+ int rulesIn = 0;
+ int rulesOut = 0;
+ if (filter.verbose) {
+ System.err.println(String.format("Filtering rules with the %s filter...", filter.getFilterName()));
+// System.err.println("Using at max " + filter.RULE_LENGTH + " n-grams...");
+ }
- LineReader reader = (grammarFile != null)
++ try(LineReader reader = (grammarFile != null)
+ ? new LineReader(grammarFile, filter.verbose)
- : new LineReader(System.in);
- for (String rule: reader) {
- rulesIn++;
-
- String[] parts = P_DELIM.split(rule);
- if (parts.length >= 4) {
- // the source is the second field for thrax grammars, first field for phrasal ones
- String source = rule.startsWith("[") ? parts[1].trim() : parts[0].trim();
- if (filter.inTestSet(source)) {
- System.out.println(rule);
- if (filter.parallel)
++ : new LineReader(System.in);) {
++ for (String rule: reader) {
++ rulesIn++;
++
++ String[] parts = P_DELIM.split(rule);
++ if (parts.length >= 4) {
++ // the source is the second field for thrax grammars, first field for phrasal ones
++ String source = rule.startsWith("[") ? parts[1].trim() : parts[0].trim();
++ if (filter.inTestSet(source)) {
++ System.out.println(rule);
++ if (filter.parallel)
++ System.out.flush();
++ rulesOut++;
++ } else if (filter.parallel) {
++ System.out.println("");
+ System.out.flush();
- rulesOut++;
- } else if (filter.parallel) {
- System.out.println("");
- System.out.flush();
++ }
+ }
+ }
- }
- if (filter.verbose) {
- System.err.println("[INFO] Total rules read: " + rulesIn);
- System.err.println("[INFO] Rules kept: " + rulesOut);
- System.err.println("[INFO] Rules dropped: " + (rulesIn - rulesOut));
- System.err.println("[INFO] cached queries: " + filter.cached);
- }
++ if (filter.verbose) {
++ System.err.println("[INFO] Total rules read: " + rulesIn);
++ System.err.println("[INFO] Rules kept: " + rulesOut);
++ System.err.println("[INFO] Rules dropped: " + (rulesIn - rulesOut));
++ System.err.println("[INFO] cached queries: " + filter.cached);
++ }
+
- return;
++ return;
++ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/Constants.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/Constants.java
index bcabfe4,0000000..669023b
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/Constants.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/Constants.java
@@@ -1,44 -1,0 +1,45 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util;
+
+/***
+ * One day, all constants should be moved here (many are in Vocabulary).
- *
++ *
+ * @author Matt Post post@cs.jhu.edu
+ */
+
+public final class Constants {
- public static String defaultNT = "[X]";
++ public static final String defaultNT = "[X]";
+
+ public static final String START_SYM = "<s>";
+ public static final String STOP_SYM = "</s>";
+ public static final String UNKNOWN_WORD = "<unk>";
-
++
+ public static final String fieldDelimiterPattern = "\\s\\|{3}\\s";
+ public static final String fieldDelimiter = " ||| ";
++
+ public static final String spaceSeparator = "\\s+";
-
++
+ public static final String NT_REGEX = "\\[[^\\]]+?\\]";
-
++
+ public static final String TM_PREFIX = "tm";
-
++
+ public static final String labeledFeatureSeparator = "=";
-
++
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/FileUtility.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/FileUtility.java
index a36b07f,0000000..0f13e6a
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/FileUtility.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/FileUtility.java
@@@ -1,318 -1,0 +1,63 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util;
+
- import java.io.BufferedReader;
+import java.io.BufferedWriter;
- import java.io.Closeable;
+import java.io.File;
+import java.io.FileDescriptor;
- import java.io.FileInputStream;
+import java.io.FileOutputStream;
- import java.io.FileReader;
+import java.io.IOException;
- import java.io.InputStream;
- import java.io.OutputStream;
+import java.io.OutputStreamWriter;
- import java.nio.charset.Charset;
- import java.util.LinkedList;
- import java.util.List;
- import java.util.Scanner;
++import java.nio.charset.StandardCharsets;
+
+/**
+ * utility functions for file operations
- *
++ *
+ * @author Zhifei Li, zhifei.work@gmail.com
+ * @author wren ng thornton wren@users.sourceforge.net
+ * @since 28 February 2009
+ */
+public class FileUtility {
- public static String DEFAULT_ENCODING = "UTF-8";
-
- /*
- * Note: charset name is case-agnostic "UTF-8" is the canonical name "UTF8", "unicode-1-1-utf-8"
- * are aliases Java doesn't distinguish utf8 vs UTF-8 like Perl does
- */
- private static final Charset FILE_ENCODING = Charset.forName(DEFAULT_ENCODING);
+
+ /**
+ * Warning, will truncate/overwrite existing files
+ * @param filename a file for which to obtain a writer
+ * @return the buffered writer object
+ * @throws IOException if there is a problem opening the output file
+ */
+ public static BufferedWriter getWriteFileStream(String filename) throws IOException {
+ return new BufferedWriter(new OutputStreamWriter(
+ // TODO: add GZIP
+ filename.equals("-") ? new FileOutputStream(FileDescriptor.out) : new FileOutputStream(
- filename, false), FILE_ENCODING));
- }
-
- /**
- * Recursively delete the specified file or directory.
- *
- * @param f File or directory to delete
- * @return <code>true</code> if the specified file or directory was deleted, <code>false</code>
- * otherwise
- */
- public static boolean deleteRecursively(File f) {
- if (null != f) {
- if (f.isDirectory())
- for (File child : f.listFiles())
- deleteRecursively(child);
- return f.delete();
- } else {
- return false;
- }
- }
-
- /**
- * Writes data from the integer array to disk as raw bytes, overwriting the old file if present.
- *
- * @param data The integer array to write to disk.
- * @param filename The filename where the data should be written.
- * @throws IOException if there is a problem writing to the output file
- * @return the FileOutputStream on which the bytes were written
- */
- public static FileOutputStream writeBytes(int[] data, String filename) throws IOException {
- FileOutputStream out = new FileOutputStream(filename, false);
- writeBytes(data, out);
- return out;
- }
-
- /**
- * Writes data from the integer array to disk as raw bytes.
- *
- * @param data The integer array to write to disk.
- * @param out The output stream where the data should be written.
- * @throws IOException if there is a problem writing bytes
- */
- public static void writeBytes(int[] data, OutputStream out) throws IOException {
-
- byte[] b = new byte[4];
-
- for (int word : data) {
- b[0] = (byte) ((word >>> 24) & 0xFF);
- b[1] = (byte) ((word >>> 16) & 0xFF);
- b[2] = (byte) ((word >>> 8) & 0xFF);
- b[3] = (byte) ((word >>> 0) & 0xFF);
-
- out.write(b);
- }
- }
-
- public static void copyFile(String srFile, String dtFile) throws IOException {
- try {
- File f1 = new File(srFile);
- File f2 = new File(dtFile);
- copyFile(f1, f2);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- public static void copyFile(File srFile, File dtFile) throws IOException {
- try {
-
- InputStream in = new FileInputStream(srFile);
-
- // For Append the file.
- // OutputStream out = new FileOutputStream(f2,true);
-
- // For Overwrite the file.
- OutputStream out = new FileOutputStream(dtFile);
-
- byte[] buf = new byte[1024];
- int len;
- while ((len = in.read(buf)) > 0) {
- out.write(buf, 0, len);
- }
- in.close();
- out.close();
- System.out.println("File copied.");
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- static public boolean deleteFile(String fileName) {
-
- File f = new File(fileName);
-
- // Make sure the file or directory exists and isn't write protected
- if (!f.exists())
- System.out.println("Delete: no such file or directory: " + fileName);
-
- if (!f.canWrite())
- System.out.println("Delete: write protected: " + fileName);
-
- // If it is a directory, make sure it is empty
- if (f.isDirectory()) {
- String[] files = f.list();
- if (files.length > 0)
- System.out.println("Delete: directory not empty: " + fileName);
- }
-
- // Attempt to delete it
- boolean success = f.delete();
-
- if (!success)
- System.out.println("Delete: deletion failed");
-
- return success;
-
++ filename, false), StandardCharsets.UTF_8));
+ }
+
+ /**
+ * Returns the base directory of the file. For example, dirname('/usr/local/bin/emacs') ->
+ * '/usr/local/bin'
+ * @param fileName the input path
+ * @return the parent path
+ */
+ static public String dirname(String fileName) {
+ if (fileName.indexOf(File.separator) != -1)
+ return fileName.substring(0, fileName.lastIndexOf(File.separator));
+
+ return ".";
+ }
-
- public static void createFolderIfNotExisting(String folderName) {
- File f = new File(folderName);
- if (!f.isDirectory()) {
- System.out.println(" createFolderIfNotExisting -- Making directory: " + folderName);
- f.mkdirs();
- } else {
- System.out.println(" createFolderIfNotExisting -- Directory: " + folderName
- + " already existed");
- }
- }
-
- public static void closeCloseableIfNotNull(Closeable fileWriter) {
- if (fileWriter != null) {
- try {
- fileWriter.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
- }
-
- /**
- * Returns the directory were the program has been started,
- * the base directory you will implicitly get when specifying no
- * full path when e.g. opening a file
- * @return the current 'user.dir'
- */
- public static String getWorkingDirectory() {
- return System.getProperty("user.dir");
- }
-
- /**
- * Method to handle standard IO exceptions. catch (Exception e) {Utility.handleIO_exception(e);}
- * @param e an input {@link java.lang.Exception}
- */
- public static void handleExceptions(Exception e) {
- throw new RuntimeException(e);
- }
-
- /**
- * Convenience method to get a full file as a String
- * @param file the input {@link java.io.File}
- * @return The file as a String. Lines are separated by newline character.
- */
- public static String getFileAsString(File file) {
- String result = "";
- List<String> lines = getLines(file, true);
- for (int i = 0; i < lines.size() - 1; i++) {
- result += lines.get(i) + "\n";
- }
- if (!lines.isEmpty()) {
- result += lines.get(lines.size() - 1);
- }
- return result;
- }
-
- /**
- * This method returns a List of String. Each element of the list corresponds to a line from the
- * input file. The boolean keepDuplicates in the input determines if duplicate lines are allowed
- * in the output LinkedList or not.
- * @param file the input file
- * @param keepDuplicates whether to retain duplicate lines
- * @return a {@link java.util.List} of lines
- */
- static public List<String> getLines(File file, boolean keepDuplicates) {
- LinkedList<String> list = new LinkedList<String>();
- String line = "";
- try {
- BufferedReader InputReader = new BufferedReader(new FileReader(file));
- for (;;) { // this loop writes writes in a Sting each sentence of
- // the file and process it
- int current = InputReader.read();
- if (current == -1 || current == '\n') {
- if (keepDuplicates || !list.contains(line))
- list.add(line);
- line = "";
- if (current == -1)
- break; // EOF
- } else
- line += (char) current;
- }
- InputReader.close();
- } catch (Exception e) {
- handleExceptions(e);
- }
- return list;
- }
-
- /**
- * Returns a Scanner of the inputFile using a specific encoding
- *
- * @param inputFile the file for which to get a {@link java.util.Scanner} object
- * @param encoding the encoding to use within the Scanner
- * @return a {@link java.util.Scanner} object for a given file
- */
- public static Scanner getScanner(File inputFile, String encoding) {
- Scanner scan = null;
- try {
- scan = new Scanner(inputFile, encoding);
- } catch (IOException e) {
- FileUtility.handleExceptions(e);
- }
- return scan;
- }
-
- /**
- * Returns a Scanner of the inputFile using default encoding
- *
- * @param inputFile the file for which to get a {@link java.util.Scanner} object
- * @return a {@link java.util.Scanner} object for a given file
- */
- public static Scanner getScanner(File inputFile) {
- return getScanner(inputFile, DEFAULT_ENCODING);
- }
-
- static public String getFirstLineInFile(File inputFile) {
- Scanner scan = FileUtility.getScanner(inputFile);
- if (!scan.hasNextLine())
- return null;
- String line = scan.nextLine();
- scan.close();
- return line;
- }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/ListUtil.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/ListUtil.java
index afb5af1,0000000..14154e8
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/ListUtil.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/ListUtil.java
@@@ -1,95 -1,0 +1,44 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util;
+
+import java.util.List;
+
+public class ListUtil {
+
- /**
- * Static method to generate a list representation for an ArrayList of Strings S1,...,Sn
- *
- * @param list A list of Strings
- * @return A String consisting of the original list of strings concatenated and separated by
- * commas, and enclosed by square brackets i.e. '[S1,S2,...,Sn]'
- */
- public static String stringListString(List<String> list) {
-
- String result = "[";
- for (int i = 0; i < list.size() - 1; i++) {
- result += list.get(i) + ",";
- }
-
- if (list.size() > 0) {
- // get the generated word for the last target position
- result += list.get(list.size() - 1);
- }
-
- result += "]";
-
- return result;
-
- }
-
- public static <E> String objectListString(List<E> list) {
- String result = "[";
- for (int i = 0; i < list.size() - 1; i++) {
- result += list.get(i) + ",";
- }
- if (list.size() > 0) {
- // get the generated word for the last target position
- result += list.get(list.size() - 1);
- }
- result += "]";
- return result;
- }
-
- /**
- * Static method to generate a simple concatenated representation for an ArrayList of Strings
- * S1,...,Sn
- *
- * @param list A list of Strings
- * @return todo
- */
- public static String stringListStringWithoutBrackets(List<String> list) {
- return stringListStringWithoutBracketsWithSpecifiedSeparator(list, " ");
- }
-
+ public static String stringListStringWithoutBracketsCommaSeparated(List<String> list) {
+ return stringListStringWithoutBracketsWithSpecifiedSeparator(list, ",");
+ }
+
+ public static String stringListStringWithoutBracketsWithSpecifiedSeparator(List<String> list,
+ String separator) {
+
+ String result = "";
+ for (int i = 0; i < list.size() - 1; i++) {
+ result += list.get(i) + separator;
+ }
+
+ if (list.size() > 0) {
+ // get the generated word for the last target position
+ result += list.get(list.size() - 1);
+ }
+
+ return result;
-
+ }
-
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
index ad2910c,0000000..d9bab66
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
@@@ -1,235 -1,0 +1,236 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util.encoding;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.TreeMap;
+
+import org.apache.joshua.util.io.LineReader;
+
+public class Analyzer {
+
+ private TreeMap<Float, Integer> histogram;
+ private int total;
+
+ public Analyzer() {
+ histogram = new TreeMap<Float, Integer>();
+ initialize();
+ }
+
+ public void initialize() {
+ histogram.clear();
+ // TODO: drop zero bucket; we won't encode zero-valued features anyway.
+ histogram.put(0.0f, 0);
+ total = 0;
+ }
+
+ public void add(float key) {
+ if (histogram.containsKey(key))
+ histogram.put(key, histogram.get(key) + 1);
+ else
+ histogram.put(key, 1);
+ total++;
+ }
+
+ public float[] quantize(int num_bits) {
+ float[] buckets = new float[1 << num_bits];
+
+ // We make sure that 0.0f always has its own bucket, so the bucket
+ // size is determined excluding the zero values.
+ int size = (total - histogram.get(0.0f)) / (buckets.length - 1);
+ buckets[0] = 0.0f;
+
+ int old_size = -1;
+ while (old_size != size) {
+ int sum = 0;
+ int count = buckets.length - 1;
+ for (float key : histogram.keySet()) {
+ int entry_count = histogram.get(key);
+ if (entry_count < size && key != 0)
+ sum += entry_count;
+ else
+ count--;
+ }
+ old_size = size;
+ size = sum / count;
+ }
+
+ float last_key = Float.MAX_VALUE;
+
+ int index = 1;
+ int count = 0;
+ float sum = 0.0f;
+
+ int value;
+ for (float key : histogram.keySet()) {
+ value = histogram.get(key);
+ // Special bucket termination cases: zero boundary and histogram spikes.
+ if (key == 0 || (last_key < 0 && key > 0) || (value >= size)) {
+ // If the count is not 0, i.e. there were negative values, we should
+ // not bucket them with the positive ones. Close out the bucket now.
+ if (count != 0 && index < buckets.length - 2) {
- buckets[index++] = (float) sum / count;
++ buckets[index++] = sum / count;
+ count = 0;
+ sum = 0;
+ }
+ if (key == 0)
+ continue;
+ }
+ count += value;
+ sum += key * value;
+ // Check if the bucket is full.
+ if (count >= size && index < buckets.length - 2) {
- buckets[index++] = (float) sum / count;
++ buckets[index++] = sum / count;
+ count = 0;
+ sum = 0;
+ }
+ last_key = key;
+ }
+ if (count > 0 && index < buckets.length - 1)
- buckets[index++] = (float) sum / count;
-
++ buckets[index++] = sum / count;
++
+ float[] shortened = new float[index];
+ for (int i = 0; i < shortened.length; ++i)
+ shortened[i] = buckets[i];
+ return shortened;
+ }
+
+ public boolean isBoolean() {
+ for (float value : histogram.keySet())
+ if (value != 0 && value != 1)
+ return false;
+ return true;
+ }
+
+ public boolean isByte() {
+ for (float value : histogram.keySet())
+ if (Math.ceil(value) != value || value < Byte.MIN_VALUE || value > Byte.MAX_VALUE)
+ return false;
+ return true;
+ }
+
+ public boolean isShort() {
+ for (float value : histogram.keySet())
+ if (Math.ceil(value) != value || value < Short.MIN_VALUE || value > Short.MAX_VALUE)
+ return false;
+ return true;
+ }
+
+ public boolean isChar() {
+ for (float value : histogram.keySet())
+ if (Math.ceil(value) != value || value < Character.MIN_VALUE || value > Character.MAX_VALUE)
+ return false;
+ return true;
+ }
+
+ public boolean isInt() {
+ for (float value : histogram.keySet())
+ if (Math.ceil(value) != value)
+ return false;
+ return true;
+ }
+
+ public boolean is8Bit() {
+ return (histogram.keySet().size() <= 256);
+ }
+
+ public FloatEncoder inferUncompressedType() {
+ if (isBoolean())
+ return PrimitiveFloatEncoder.BOOLEAN;
+ if (isByte())
+ return PrimitiveFloatEncoder.BYTE;
+ if (is8Bit())
+ return new EightBitQuantizer(this.quantize(8));
+ if (isChar())
+ return PrimitiveFloatEncoder.CHAR;
+ if (isShort())
+ return PrimitiveFloatEncoder.SHORT;
+ if (isInt())
+ return PrimitiveFloatEncoder.INT;
+ return PrimitiveFloatEncoder.FLOAT;
+ }
-
++
+ public FloatEncoder inferType(int bits) {
+ if (isBoolean())
+ return PrimitiveFloatEncoder.BOOLEAN;
+ if (isByte())
+ return PrimitiveFloatEncoder.BYTE;
+ if (bits == 8 || is8Bit())
+ return new EightBitQuantizer(this.quantize(8));
+ // TODO: Could add sub-8-bit encoding here (or larger).
+ if (isChar())
+ return PrimitiveFloatEncoder.CHAR;
+ if (isShort())
+ return PrimitiveFloatEncoder.SHORT;
+ if (isInt())
+ return PrimitiveFloatEncoder.INT;
+ return PrimitiveFloatEncoder.FLOAT;
+ }
+
+ public String toString(String label) {
+ StringBuilder sb = new StringBuilder();
+ for (float val : histogram.keySet())
+ sb.append(label + "\t" + String.format("%.5f", val) + "\t" + histogram.get(val) + "\n");
+ return sb.toString();
+ }
-
++
+ public static void main(String[] args) throws IOException {
- LineReader reader = new LineReader(args[0]);
- ArrayList<Float> s = new ArrayList<Float>();
-
- System.out.println("Initialized.");
- while (reader.hasNext())
- s.add(Float.parseFloat(reader.next().trim()));
- System.out.println("Data read.");
- int n = s.size();
- byte[] c = new byte[n];
- ByteBuffer b = ByteBuffer.wrap(c);
- Analyzer q = new Analyzer();
-
- q.initialize();
- for (int i = 0; i < n; i++)
- q.add(s.get(i));
- EightBitQuantizer eb = new EightBitQuantizer(q.quantize(8));
- System.out.println("Quantizer learned.");
-
- for (int i = 0; i < n; i++)
- eb.write(b, s.get(i));
- b.rewind();
- System.out.println("Quantization complete.");
-
- float avg_error = 0;
- float error = 0;
- int count = 0;
- for (int i = -4; i < n - 4; i++) {
- float coded = eb.read(b, i);
- if (s.get(i + 4) != 0) {
- error = Math.abs(s.get(i + 4) - coded);
- avg_error += error;
- count++;
++ try (LineReader reader = new LineReader(args[0]);) {
++ ArrayList<Float> s = new ArrayList<Float>();
++
++ System.out.println("Initialized.");
++ while (reader.hasNext())
++ s.add(Float.parseFloat(reader.next().trim()));
++ System.out.println("Data read.");
++ int n = s.size();
++ byte[] c = new byte[n];
++ ByteBuffer b = ByteBuffer.wrap(c);
++ Analyzer q = new Analyzer();
++
++ q.initialize();
++ for (int i = 0; i < n; i++)
++ q.add(s.get(i));
++ EightBitQuantizer eb = new EightBitQuantizer(q.quantize(8));
++ System.out.println("Quantizer learned.");
++
++ for (int i = 0; i < n; i++)
++ eb.write(b, s.get(i));
++ b.rewind();
++ System.out.println("Quantization complete.");
++
++ float avg_error = 0;
++ float error = 0;
++ int count = 0;
++ for (int i = -4; i < n - 4; i++) {
++ float coded = eb.read(b, i);
++ if (s.get(i + 4) != 0) {
++ error = Math.abs(s.get(i + 4) - coded);
++ avg_error += error;
++ count++;
++ }
+ }
- }
- avg_error /= count;
- System.out.println("Evaluation complete.");
++ avg_error /= count;
++ System.out.println("Evaluation complete.");
+
- System.out.println("Average quanitization error over " + n + " samples is: " + avg_error);
++ System.out.println("Average quanitization error over " + n + " samples is: " + avg_error);
++ }
+ }
+}