You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/27 14:52:10 UTC
[42/51] [partial] mahout git commit: MAHOUT-2042 and MAHOUT-2045
Delete directories which were moved/no longer in use
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/ARFFVectorIterable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/ARFFVectorIterable.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/ARFFVectorIterable.java
new file mode 100644
index 0000000..180a1e1
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/ARFFVectorIterable.java
@@ -0,0 +1,155 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.arff;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.Reader;
+import java.io.StringReader;
+import java.nio.charset.Charset;
+import java.text.DateFormat;
+import java.text.SimpleDateFormat;
+import java.util.Iterator;
+import java.util.Locale;
+
+import com.google.common.io.Files;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Read in ARFF (http://www.cs.waikato.ac.nz/~ml/weka/arff.html) and create {@link Vector}s
+ * <p/>
+ * Attribute type handling:
+ * <ul>
+ * <li>Numeric -> As is</li>
+ * <li>Nominal -> ordinal(value), numbered from 1 in declaration order, i.e. @attribute lumber {'\'(-inf-0.5]\'','\'(0.5-inf)\''}
+ * will convert -inf-0.5 -> 1, and 0.5-inf -> 2</li>
+ * <li>Dates -> Convert to time as a long</li>
+ * <li>Strings -> Create a map of String -> long</li>
+ * </ul>
+ * NOTE: This class does not set the label bindings on every vector. If you want the label
+ * bindings, call {@link MapBackedARFFModel#getLabelBindings()}, as they are the same for every vector.
+ * <p/>
+ * The header (everything up to and including the {@code @data} line) is consumed by the constructor.
+ * {@link #iterator()} then hands the same reader to an {@link ARFFIterator}. Every iterator shares
+ * that single reader, so the data section can presumably only be walked once.
+ */
+public class ARFFVectorIterable implements Iterable<Vector> {
+
+  private final BufferedReader buff;
+  private final ARFFModel model;
+
+  public ARFFVectorIterable(File file, ARFFModel model) throws IOException {
+    this(file, Charsets.UTF_8, model);
+  }
+
+  public ARFFVectorIterable(File file, Charset encoding, ARFFModel model) throws IOException {
+    this(Files.newReader(file, encoding), model);
+  }
+
+  public ARFFVectorIterable(String arff, ARFFModel model) throws IOException {
+    this(new StringReader(arff), model);
+  }
+
+  /**
+   * Parses the ARFF header from {@code reader} into {@code model}. On return, the reader is
+   * positioned at the first line after {@code @data}.
+   *
+   * @param reader the ARFF content; wrapped in a {@link BufferedReader} unless it already is one
+   * @param model  receives the relation name, attribute labels, types, nominal values and date formats
+   * @throws IOException if reading fails
+   * @throws UnsupportedOperationException if a header line is missing its argument or declares an
+   *         unknown attribute type
+   */
+  public ARFFVectorIterable(Reader reader, ARFFModel model) throws IOException {
+    if (reader instanceof BufferedReader) {
+      buff = (BufferedReader) reader;
+    } else {
+      buff = new BufferedReader(reader);
+    }
+    this.model = model;
+
+    // grab the attributes, then leave the reader at the first line of data
+    int labelNumber = 0;
+    String line;
+    while ((line = buff.readLine()) != null) {
+      line = line.trim();
+      if (line.isEmpty() || line.startsWith(ARFFModel.ARFF_COMMENT)) {
+        continue;
+      }
+      String[] lineParts = line.split("[\\s\\t]+", 2);
+      String keyword = lineParts[0];
+      if (keyword.equalsIgnoreCase(ARFFModel.RELATION)) {
+        model.setRelation(ARFFType.removeQuotes(requireArgument(lineParts, line)));
+      } else if (keyword.equalsIgnoreCase(ARFFModel.ATTRIBUTE)) {
+        parseAttribute(requireArgument(lineParts, line), labelNumber, model);
+        labelNumber++;
+      } else if (keyword.equalsIgnoreCase(ARFFModel.DATA)) {
+        break; // the data section starts on the next line
+      }
+    }
+  }
+
+  /**
+   * Returns the text after a header keyword. Throws a descriptive exception instead of letting a
+   * bare {@code @relation} / {@code @attribute} line fail with an array index error.
+   */
+  private static String requireArgument(String[] lineParts, String line) {
+    if (lineParts.length < 2) {
+      throw new UnsupportedOperationException("Missing argument in ARFF header line: " + line);
+    }
+    return lineParts[1];
+  }
+
+  /**
+   * Parses the text following an {@code @attribute} keyword and records the attribute's label,
+   * index and type in {@code model}, plus its nominal values or date format where applicable.
+   *
+   * @param declaration the attribute name followed by its type description
+   * @param index       zero-based position of the attribute in the header
+   * @param model       the model to populate
+   * @throws UnsupportedOperationException if the type is missing or not recognized
+   */
+  private static void parseAttribute(String declaration, int index, ARFFModel model) {
+    // split the name of the attribute and its description
+    String[] attrParts = declaration.split("[\\s\\t]+", 2);
+    if (attrParts.length < 2) {
+      throw new UnsupportedOperationException("No type for attribute found: " + declaration);
+    }
+
+    // label is attribute name; Locale.ENGLISH keeps lower-casing independent of the JVM locale
+    String label = ARFFType.removeQuotes(attrParts[0].toLowerCase(Locale.ENGLISH));
+    String description = attrParts[1];
+    String lowerDescription = description.toLowerCase(Locale.ENGLISH);
+    ARFFType type;
+    if (description.equalsIgnoreCase(ARFFType.NUMERIC.getIndicator())) {
+      type = ARFFType.NUMERIC;
+    } else if (description.equalsIgnoreCase(ARFFType.INTEGER.getIndicator())) {
+      type = ARFFType.INTEGER;
+    } else if (description.equalsIgnoreCase(ARFFType.REAL.getIndicator())) {
+      type = ARFFType.REAL;
+    } else if (description.equalsIgnoreCase(ARFFType.STRING.getIndicator())) {
+      type = ARFFType.STRING;
+    } else if (lowerDescription.startsWith(ARFFType.NOMINAL.getIndicator())) {
+      type = ARFFType.NOMINAL;
+      // nominal example:
+      // @ATTRIBUTE class {Iris-setosa,'Iris versicolor',Iris-virginica}
+      String[] classes = ARFFIterator.splitCSV(description.substring(1, description.length() - 1));
+      for (int i = 0; i < classes.length; i++) {
+        // nominal ordinals are 1-based
+        model.addNominal(label, ARFFType.removeQuotes(classes[i]), i + 1);
+      }
+    } else if (lowerDescription.startsWith(ARFFType.DATE.getIndicator())) {
+      type = ARFFType.DATE;
+      //@attribute <name> date [<date-format>]
+      model.addDateFormat(index, parseDateFormat(description));
+    } else {
+      throw new UnsupportedOperationException("Invalid attribute: " + description);
+    }
+    model.addLabel(label, index);
+    model.addType(index, type);
+  }
+
+  /**
+   * Builds the {@link DateFormat} for a date attribute description such as
+   * {@code date "yyyy-MM-dd HH:mm:ss"}. Without an explicit pattern, {@code yyyy-MM-dd'T'HH:mm:ss}
+   * is used. A pattern wrapped in matching double or single quotes is unwrapped. Otherwise
+   * {@link SimpleDateFormat} would treat the surrounding single quotes as literal text.
+   */
+  private static DateFormat parseDateFormat(String description) {
+    String pattern = description.substring(ARFFType.DATE.getIndicator().length()).trim();
+    if (pattern.isEmpty()) {
+      //TODO: DateFormatter map
+      return new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", Locale.ENGLISH);
+    }
+    if (pattern.length() >= 2) {
+      char first = pattern.charAt(0);
+      if ((first == '"' || first == '\'') && pattern.charAt(pattern.length() - 1) == first) {
+        pattern = pattern.substring(1, pattern.length() - 1);
+      }
+    }
+    return new SimpleDateFormat(pattern, Locale.ENGLISH);
+  }
+
+  @Override
+  public Iterator<Vector> iterator() {
+    return new ARFFIterator(buff, model);
+  }
+
+  /**
+   * Returns info about the ARFF content that was parsed.
+   *
+   * @return the model
+   */
+  public ARFFModel getModel() {
+    return model;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/Driver.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/Driver.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/Driver.java
new file mode 100644
index 0000000..ccecbb1
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/Driver.java
@@ -0,0 +1,263 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.arff;
+
+import java.io.File;
+import java.io.FilenameFilter;
+import java.io.IOException;
+import java.io.Writer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+
+import com.google.common.io.Files;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.io.Charsets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.utils.vectors.io.SequenceFileVectorWriter;
+import org.apache.mahout.utils.vectors.io.VectorWriter;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Command-line tool that converts ARFF files into Mahout vector sequence files.
+ * <p/>
+ * Each input {@code .arff} file is written to {@code <output>/<file name>.mvc} as a {@link SequenceFile}
+ * of {@link LongWritable} keys and {@link VectorWritable} values. The label bindings (attribute indices
+ * and nominal values) are appended to the {@code dictOut} file, either as delimited text or as JSON.
+ */
+public final class Driver {
+
+  private static final Logger log = LoggerFactory.getLogger(Driver.class);
+
+  /** used for JSON serialization/deserialization */
+  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+  /** Orders map entries by their integer value (attribute index or nominal ordinal). */
+  private static final Comparator<Entry<String, Integer>> BY_VALUE = new Comparator<Entry<String, Integer>>() {
+    @Override
+    public int compare(Entry<String, Integer> a, Entry<String, Integer> b) {
+      return a.getValue().compareTo(b.getValue());
+    }
+  };
+
+  private Driver() {
+  }
+
+  public static void main(String[] args) throws IOException {
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option inputOpt = obuilder
+        .withLongName("input")
+        .withRequired(true)
+        .withArgument(abuilder.withName("input").withMinimum(1).withMaximum(1).create())
+        .withDescription(
+            "The file or directory containing the ARFF files. If it is a directory, all .arff files will be converted")
+        .withShortName("d").create();
+
+    Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
+        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The output directory. Files will have the same name as the input, but with the extension .mvc")
+        .withShortName("o").create();
+
+    Option maxOpt = obuilder.withLongName("max").withRequired(false).withArgument(
+        abuilder.withName("max").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The maximum number of vectors to output. If not specified, then it will loop over all docs")
+        .withShortName("m").create();
+
+    Option dictOutOpt = obuilder.withLongName("dictOut").withRequired(true).withArgument(
+        abuilder.withName("dictOut").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The file to output the label bindings").withShortName("t").create();
+
+    Option jsonDictonaryOpt = obuilder.withLongName("json-dictonary").withRequired(false)
+        .withDescription("Write dictonary in JSON format").withShortName("j").create();
+
+    Option delimiterOpt = obuilder.withLongName("delimiter").withRequired(false).withArgument(
+        abuilder.withName("delimiter").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The delimiter for outputing the dictionary").withShortName("l").create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(maxOpt)
+        .withOption(helpOpt).withOption(dictOutOpt).withOption(jsonDictonaryOpt).withOption(delimiterOpt)
+        .create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+      if (cmdLine.hasOption(inputOpt)) {
+        File input = new File(cmdLine.getValue(inputOpt).toString());
+        long maxDocs = Long.MAX_VALUE;
+        if (cmdLine.hasOption(maxOpt)) {
+          maxDocs = Long.parseLong(cmdLine.getValue(maxOpt).toString());
+        }
+        if (maxDocs < 0) {
+          throw new IllegalArgumentException("maxDocs must be >= 0");
+        }
+        String outDir = cmdLine.getValue(outputOpt).toString();
+        log.info("Output Dir: {}", outDir);
+
+        String delimiter = cmdLine.hasOption(delimiterOpt) ? cmdLine.getValue(delimiterOpt).toString() : "\t";
+        File dictOut = new File(cmdLine.getValue(dictOutOpt).toString());
+        boolean jsonDictonary = cmdLine.hasOption(jsonDictonaryOpt);
+        ARFFModel model = new MapBackedARFFModel();
+        if (input.exists() && input.isDirectory()) {
+          File[] files = input.listFiles(new FilenameFilter() {
+            @Override
+            public boolean accept(File file, String name) {
+              return name.endsWith(".arff");
+            }
+          });
+          // listFiles returns null (not an empty array) when the directory cannot be read
+          if (files == null) {
+            throw new IOException("Unable to list ARFF files in directory: " + input);
+          }
+          // all files share one word map, so process them in a reproducible order
+          java.util.Arrays.sort(files);
+          for (File file : files) {
+            writeFile(outDir, file, maxDocs, model, dictOut, delimiter, jsonDictonary);
+          }
+        } else {
+          writeFile(outDir, input, maxDocs, model, dictOut, delimiter, jsonDictonary);
+        }
+      }
+
+    } catch (OptionException e) {
+      log.error("Exception", e);
+      CommandLineUtil.printHelp(group);
+    }
+  }
+
+  /**
+   * Appends the label bindings of {@code arffModel} to {@code dictOut} (UTF-8), as JSON or as
+   * delimited text.
+   */
+  protected static void writeLabelBindings(File dictOut, ARFFModel arffModel, String delimiter, boolean jsonDictonary)
+    throws IOException {
+    try (Writer writer = Files.newWriterSupplier(dictOut, Charsets.UTF_8, true).getOutput()) {
+      if (jsonDictonary) {
+        writeLabelBindingsJSON(writer, arffModel);
+      } else {
+        writeLabelBindings(writer, arffModel, delimiter);
+      }
+    }
+  }
+
+  /**
+   * Writes one JSON object per attribute, in attribute-index order. Each object has the keys
+   * {@code attribute}, {@code type} ({@code categorical} or {@code numerical}), {@code label}
+   * ({@code "true"} only for the last attribute) and, for nominal attributes, {@code values}.
+   */
+  protected static void writeLabelBindingsJSON(Writer writer, ARFFModel arffModel) throws IOException {
+
+    // Turn the map of labels into a list ordered by order of appearance
+    List<Entry<String, Integer>> attributes = new ArrayList<>(arffModel.getLabelBindings().entrySet());
+    Collections.sort(attributes, BY_VALUE);
+
+    // write a map for each object
+    List<Map<String, Object>> jsonObjects = new LinkedList<>();
+    int lastIndex = attributes.size() - 1;
+    for (int i = 0; i < attributes.size(); i++) {
+      String attribute = attributes.get(i).getKey();
+      Map<String, Object> jsonRepresentation = new HashMap<>();
+      jsonObjects.add(jsonRepresentation);
+      // the last one is the class label
+      jsonRepresentation.put("label", String.valueOf(i == lastIndex));
+      jsonRepresentation.put("attribute", attribute);
+      Map<String, Integer> nominalValues = arffModel.getNominalMap().get(attribute);
+
+      if (nominalValues != null) {
+        jsonRepresentation.put("values", valuesInOrdinalOrder(nominalValues));
+        jsonRepresentation.put("type", "categorical");
+      } else {
+        jsonRepresentation.put("type", "numerical");
+      }
+    }
+    writer.write(OBJECT_MAPPER.writeValueAsString(jsonObjects));
+  }
+
+  /**
+   * Returns the nominal values sorted by their ordinal, so that the array order matches the numeric
+   * codes in the vectors. Plain key-set order of a hash map is arbitrary.
+   */
+  private static String[] valuesInOrdinalOrder(Map<String, Integer> nominalValues) {
+    List<Entry<String, Integer>> entries = new ArrayList<>(nominalValues.entrySet());
+    Collections.sort(entries, BY_VALUE);
+    String[] values = new String[entries.size()];
+    for (int i = 0; i < values.length; i++) {
+      values[i] = entries.get(i).getKey();
+    }
+    return values;
+  }
+
+  /**
+   * Writes the label bindings as delimited text, followed by the allowed values of every nominal
+   * attribute with their ordinals.
+   */
+  protected static void writeLabelBindings(Writer writer, ARFFModel arffModel, String delimiter) throws IOException {
+
+    Map<String, Integer> labels = arffModel.getLabelBindings();
+    writer.write("Label bindings for Relation " + arffModel.getRelation() + '\n');
+    for (Map.Entry<String, Integer> entry : labels.entrySet()) {
+      writer.write(entry.getKey());
+      writer.write(delimiter);
+      writer.write(String.valueOf(entry.getValue()));
+      writer.write('\n');
+    }
+    writer.write('\n');
+    writer.write("Values for nominal attributes\n");
+    // emit allowed values for NOMINAL/categorical/enumerated attributes
+    Map<String, Map<String, Integer>> nominalMap = arffModel.getNominalMap();
+    // how many nominal attributes
+    writer.write(String.valueOf(nominalMap.size()) + "\n");
+
+    for (Entry<String, Map<String, Integer>> entry : nominalMap.entrySet()) {
+      // the label of this attribute
+      writer.write(entry.getKey() + "\n");
+      Set<Entry<String, Integer>> attributeValues = entry.getValue().entrySet();
+      // how many values does this attribute have
+      writer.write(attributeValues.size() + "\n");
+      for (Map.Entry<String, Integer> value : attributeValues) {
+        // the value and the value index
+        writer.write(String.format("%s%s%s\n", value.getKey(), delimiter, value.getValue().toString()));
+      }
+    }
+  }
+
+  /**
+   * Converts one ARFF file to {@code outDir/<name>.mvc} and appends its label bindings to {@code dictOut}.
+   * <p/>
+   * The per-file model shares the word and nominal maps of {@code arffModel}. New string ids
+   * therefore start after every id already handed out, so ids stay unique across files.
+   */
+  protected static void writeFile(String outDir,
+                                  File file,
+                                  long maxDocs,
+                                  ARFFModel arffModel,
+                                  File dictOut,
+                                  String delimiter,
+                                  boolean jsonDictonary) throws IOException {
+    log.info("Converting File: {}", file);
+    ARFFModel model = new MapBackedARFFModel(arffModel.getWords(), nextWordId(arffModel), arffModel
+        .getNominalMap());
+    String outFile = outDir + '/' + file.getName() + ".mvc";
+
+    // own the input reader here so it is closed once the file has been converted
+    try (java.io.Reader reader = Files.newReader(file, Charsets.UTF_8)) {
+      // parse the header before creating the output, as before
+      Iterable<Vector> iteratable = new ARFFVectorIterable(reader, model);
+      try (VectorWriter vectorWriter = getSeqFileWriter(outFile)) {
+        long numDocs = vectorWriter.write(iteratable, maxDocs);
+        writeLabelBindings(dictOut, model, delimiter, jsonDictonary);
+        log.info("Wrote: {} vectors", numDocs);
+      }
+    }
+  }
+
+  /**
+   * First string id that cannot clash with ids already stored in {@code arffModel}'s word map:
+   * one past the larger of the model's counter and the largest existing id.
+   */
+  private static long nextWordId(ARFFModel arffModel) {
+    long maxId = arffModel.getWordCount();
+    for (Long id : arffModel.getWords().values()) {
+      if (id != null && id > maxId) {
+        maxId = id;
+      }
+    }
+    return maxId + 1;
+  }
+
+  private static VectorWriter getSeqFileWriter(String outFile) throws IOException {
+    Path path = new Path(outFile);
+    Configuration conf = new Configuration();
+    FileSystem fs = FileSystem.get(conf);
+    SequenceFile.Writer seqWriter = SequenceFile.createWriter(fs, conf, path, LongWritable.class,
+        VectorWritable.class);
+    return new SequenceFileVectorWriter(seqWriter);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/MapBackedARFFModel.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/MapBackedARFFModel.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/MapBackedARFFModel.java
new file mode 100644
index 0000000..e911b1a
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/MapBackedARFFModel.java
@@ -0,0 +1,282 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.arff;
+
+import java.text.DateFormat;
+import java.text.NumberFormat;
+import java.text.ParseException;
+import java.text.ParsePosition;
+import java.text.SimpleDateFormat;
+import java.util.Collections;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.Locale;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+/**
+ * Holds ARFF information in {@link Map}s: attribute labels and indices, attribute types, date
+ * formats, nominal value ordinals and the ids assigned to string values.
+ * <p/>
+ * No synchronization is performed; the maps and the {@link DateFormat}s used for parsing are mutable.
+ */
+public class MapBackedARFFModel implements ARFFModel {
+
+  private static final Pattern QUOTE_PATTERN = Pattern.compile("\"");
+
+  /** Next id to hand out to a previously unseen string value. */
+  private long wordCount;
+
+  private String relation;
+
+  private final Map<String,Integer> labelBindings;
+  private final Map<Integer,String> idxLabel;
+  private final Map<Integer,ARFFType> typeMap; // key is the vector index, value is the type
+  private final Map<Integer,DateFormat> dateMap;
+  private final Map<String,Map<String,Integer>> nominalMap;
+  private final Map<String,Long> words;
+
+  public MapBackedARFFModel() {
+    this(new HashMap<String,Long>(), 1, new HashMap<String,Map<String,Integer>>());
+  }
+
+  /**
+   * @param words      string value -> id map; stored by reference, so it may be shared between models
+   * @param wordCount  the id to assign to the next unseen string value
+   * @param nominalMap attribute label -> (nominal value -> ordinal); stored by reference
+   */
+  public MapBackedARFFModel(Map<String,Long> words, long wordCount, Map<String,Map<String,Integer>> nominalMap) {
+    this.words = words;
+    this.wordCount = wordCount;
+    labelBindings = new HashMap<>();
+    idxLabel = new HashMap<>();
+    typeMap = new HashMap<>();
+    dateMap = new HashMap<>();
+    this.nominalMap = nominalMap;
+  }
+
+  @Override
+  public String getRelation() {
+    return relation;
+  }
+
+  @Override
+  public void setRelation(String relation) {
+    this.relation = relation;
+  }
+
+  /**
+   * Convert a piece of String data at a specific spot into a value
+   *
+   * @param data
+   *          The data to convert; double quotes are stripped and the result trimmed
+   * @param idx
+   *          The position in the ARFF data
+   * @return A double representing the data
+   * @throws IllegalArgumentException if no type is registered for {@code idx}
+   */
+  @Override
+  public double getValue(String data, int idx) {
+    ARFFType type = typeMap.get(idx);
+    if (type == null) {
+      throw new IllegalArgumentException("Attribute type cannot be NULL, attribute index was: " + idx);
+    }
+    data = QUOTE_PATTERN.matcher(data).replaceAll("");
+    data = data.trim();
+    double result;
+    switch (type) {
+      case NUMERIC:
+      case INTEGER:
+      case REAL:
+        result = processNumeric(data);
+        break;
+      case DATE:
+        result = processDate(data, idx);
+        break;
+      case STRING:
+        // may have quotes
+        result = processString(data);
+        break;
+      case NOMINAL:
+        String label = idxLabel.get(idx);
+        result = processNominal(label, data);
+        break;
+      default:
+        throw new IllegalStateException("Unknown type: " + type);
+    }
+    return result;
+  }
+
+  /**
+   * Maps a nominal value to its ordinal.
+   *
+   * @throws IllegalArgumentException if {@code label} has no nominal values registered
+   * @throws IllegalStateException    if {@code data} is not one of the label's values
+   */
+  protected double processNominal(String label, String data) {
+    double result;
+    Map<String,Integer> classes = nominalMap.get(label);
+    if (classes != null) {
+      Integer ord = classes.get(ARFFType.removeQuotes(data));
+      if (ord != null) {
+        result = ord;
+      } else {
+        throw new IllegalStateException("Invalid nominal: " + data + " for label: " + label);
+      }
+    } else {
+      throw new IllegalArgumentException("Invalid nominal label: " + label + " Data: " + data);
+    }
+
+    return result;
+  }
+
+  /**
+   * Maps a string value to a long id, assigning the next id the first time a value is seen.
+   * Not sure how scalable this is going to be: every distinct value is kept in memory.
+   */
+  protected double processString(String data) {
+    data = QUOTE_PATTERN.matcher(data).replaceAll("");
+    // map it to an long
+    Long theLong = words.get(data);
+    if (theLong == null) {
+      theLong = wordCount++;
+      words.put(data, theLong);
+    }
+    return theLong;
+  }
+
+  /**
+   * Parses a numeric attribute value.
+   * <p/>
+   * Parsing is locale-independent and accepts everything {@link Double#parseDouble(String)} accepts,
+   * including scientific notation such as {@code 1.0E-4}. Anything unparseable, e.g. the ARFF
+   * missing-value marker {@code ?}, an empty field or {@code null}, becomes {@link Double#NaN}.
+   */
+  protected static double processNumeric(String data) {
+    if (data == null) {
+      return Double.NaN;
+    }
+    try {
+      return Double.parseDouble(data);
+    } catch (NumberFormatException e) {
+      return Double.NaN;
+    }
+  }
+
+  /**
+   * Returns whether the whole of {@code str} can be parsed by the default-locale {@link NumberFormat}.
+   * Note this is locale-sensitive and rejects scientific notation. Numeric attribute values are
+   * parsed by {@link #processNumeric(String)} without consulting it.
+   *
+   * @return false for {@code null} or empty input
+   */
+  public static boolean isNumeric(String str) {
+    if (str == null || str.isEmpty()) {
+      return false;
+    }
+    NumberFormat formatter = NumberFormat.getInstance();
+    ParsePosition parsePosition = new ParsePosition(0);
+    formatter.parse(str, parsePosition);
+    return str.length() == parsePosition.getIndex();
+  }
+
+  /**
+   * Parses a date value using the format registered for {@code idx}, falling back to
+   * {@code yyyy-MM-dd'T'HH:mm:ss}, and returns its epoch milliseconds.
+   *
+   * @throws IllegalArgumentException wrapping the {@link ParseException} if the value does not match
+   */
+  protected double processDate(String data, int idx) {
+    DateFormat format = dateMap.get(idx);
+    if (format == null) {
+      format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", Locale.ENGLISH);
+    }
+    double result;
+    try {
+      Date date = format.parse(data);
+      // a double holds every long up to 2^53 exactly (~285,000 years of milliseconds), so no loss here
+      result = date.getTime();
+    } catch (ParseException e) {
+      throw new IllegalArgumentException(e);
+    }
+    return result;
+  }
+
+  /**
+   * The vector attributes (labels in Mahout speak), unmodifiable
+   *
+   * @return the map
+   */
+  @Override
+  public Map<String,Integer> getLabelBindings() {
+    return Collections.unmodifiableMap(labelBindings);
+  }
+
+  /**
+   * The map of types encountered, unmodifiable
+   *
+   * @return the map
+   */
+  public Map<Integer,ARFFType> getTypeMap() {
+    return Collections.unmodifiableMap(typeMap);
+  }
+
+  /**
+   * Map of Date formatters used, unmodifiable
+   *
+   * @return the map
+   */
+  public Map<Integer,DateFormat> getDateMap() {
+    return Collections.unmodifiableMap(dateMap);
+  }
+
+  /**
+   * Map nominals to ids. Should only be modified by calling {@link ARFFModel#addNominal(String, String, int)}
+   *
+   * @return the live (modifiable) map
+   */
+  @Override
+  public Map<String,Map<String,Integer>> getNominalMap() {
+    return nominalMap;
+  }
+
+  /**
+   * The live map of words to the long id used for those words. It is not copied, so it can be
+   * shared with another model.
+   *
+   * @return The map
+   */
+  @Override
+  public Map<String,Long> getWords() {
+    return words;
+  }
+
+  /**
+   * @return the ordinal of {@code nominal} for {@code label}, or null if either is unknown
+   */
+  @Override
+  public Integer getNominalValue(String label, String nominal) {
+    Map<String,Integer> values = nominalMap.get(label);
+    return values == null ? null : values.get(nominal);
+  }
+
+  @Override
+  public void addNominal(String label, String nominal, int idx) {
+    Map<String,Integer> noms = nominalMap.get(label);
+    if (noms == null) {
+      noms = new HashMap<>();
+      nominalMap.put(label, noms);
+    }
+    noms.put(nominal, idx);
+  }
+
+  @Override
+  public DateFormat getDateFormat(Integer idx) {
+    return dateMap.get(idx);
+  }
+
+  @Override
+  public void addDateFormat(Integer idx, DateFormat format) {
+    dateMap.put(idx, format);
+  }
+
+  @Override
+  public Integer getLabelIndex(String label) {
+    return labelBindings.get(label);
+  }
+
+  @Override
+  public void addLabel(String label, Integer idx) {
+    labelBindings.put(label, idx);
+    idxLabel.put(idx, label);
+  }
+
+  @Override
+  public ARFFType getARFFType(Integer idx) {
+    return typeMap.get(idx);
+  }
+
+  @Override
+  public void addType(Integer idx, ARFFType type) {
+    typeMap.put(idx, type);
+  }
+
+  /**
+   * The next id to be assigned to an unseen string value (starts at the constructor's {@code wordCount})
+   *
+   * @return the count
+   */
+  @Override
+  public long getWordCount() {
+    return wordCount;
+  }
+
+  @Override
+  public int getLabelSize() {
+    return labelBindings.size();
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/csv/CSVVectorIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/csv/CSVVectorIterator.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/csv/CSVVectorIterator.java
new file mode 100644
index 0000000..3c583fd
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/csv/CSVVectorIterator.java
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.csv;
+
+import java.io.IOException;
+import java.io.Reader;
+
+import com.google.common.collect.AbstractIterator;
+import org.apache.commons.csv.CSVParser;
+import org.apache.commons.csv.CSVStrategy;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Turns each record of a CSV stream into a dense {@link org.apache.mahout.math.Vector} with one
+ * element per field. Each field is parsed with {@link Double#parseDouble(String)}.
+ * <br/>
+ * Calling {@link java.util.Iterator#remove()} on this iterator throws {@link UnsupportedOperationException}.
+ * <p/>
+ * Output is always dense for now; mapping columns to a sparse format may be added in the future.
+ * <p/>
+ * Instances are not thread-safe.
+ */
+public class CSVVectorIterator extends AbstractIterator<Vector> {
+
+  /** Source of the CSV records. */
+  private final CSVParser parser;
+
+  /** Parses {@code reader} without specifying a CSV strategy. */
+  public CSVVectorIterator(Reader reader) {
+    this.parser = new CSVParser(reader);
+  }
+
+  /** Parses {@code reader} with the given {@code strategy}. */
+  public CSVVectorIterator(Reader reader, CSVStrategy strategy) {
+    this.parser = new CSVParser(reader, strategy);
+  }
+
+  @Override
+  protected Vector computeNext() {
+    String[] fields = readRecord();
+    if (fields == null) {
+      return endOfData();
+    }
+    Vector vector = new DenseVector(fields.length);
+    int column = 0;
+    for (String field : fields) {
+      vector.setQuick(column++, Double.parseDouble(field));
+    }
+    return vector;
+  }
+
+  /**
+   * Reads the next record, or returns null once the parser reports no more records. I/O failures
+   * are rethrown as {@link IllegalStateException}.
+   */
+  private String[] readRecord() {
+    try {
+      return parser.getLine();
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/DelimitedTermInfoWriter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/DelimitedTermInfoWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/DelimitedTermInfoWriter.java
new file mode 100644
index 0000000..b5f9f2b
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/DelimitedTermInfoWriter.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.io;
+
+import java.io.IOException;
+import java.io.Writer;
+import java.util.Iterator;
+
+import com.google.common.io.Closeables;
+import org.apache.mahout.utils.vectors.TermEntry;
+import org.apache.mahout.utils.vectors.TermInfo;
+
+/**
+ * Writes a {@link TermInfo} to a {@link Writer} as delimited text. The output is the total term
+ * count for the configured field, then a {@code #term<delimiter>doc freq<delimiter>idx} header line,
+ * then one line per term entry.
+ * <p/>
+ * {@link #write(TermInfo)} closes the wrapped writer when it finishes, even on failure, so each
+ * instance can write at most one {@link TermInfo}. {@link #close()} itself does nothing.
+ */
+public class DelimitedTermInfoWriter implements TermInfoWriter {
+
+  private final Writer writer;
+  private final String delimiter;
+  private final String field;
+
+  public DelimitedTermInfoWriter(Writer writer, String delimiter, String field) {
+    this.writer = writer;
+    this.delimiter = delimiter;
+    this.field = field;
+  }
+
+  @Override
+  public void write(TermInfo ti) throws IOException {
+    Iterator<TermEntry> entries = ti.getAllEntries();
+    try {
+      writeLine(String.valueOf(ti.totalTerms(field)));
+      writeLine("#term" + delimiter + "doc freq" + delimiter + "idx");
+      while (entries.hasNext()) {
+        writeEntry(entries.next());
+      }
+    } finally {
+      Closeables.close(writer, false);
+    }
+  }
+
+  /** Writes {@code text} followed by a newline. */
+  private void writeLine(String text) throws IOException {
+    writer.write(text);
+    writer.write('\n');
+  }
+
+  /** Writes one {@code term<delimiter>docFreq<delimiter>termIdx} line. */
+  private void writeEntry(TermEntry entry) throws IOException {
+    writer.write(entry.getTerm());
+    writer.write(delimiter);
+    writer.write(String.valueOf(entry.getDocFreq()));
+    writer.write(delimiter);
+    writer.write(String.valueOf(entry.getTermIdx()));
+    writer.write('\n');
+  }
+
+  /**
+   * Does NOT close the underlying writer
+   */
+  @Override
+  public void close() {
+    // intentionally empty
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/SequenceFileVectorWriter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/SequenceFileVectorWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/SequenceFileVectorWriter.java
new file mode 100644
index 0000000..0d763a1
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/SequenceFileVectorWriter.java
@@ -0,0 +1,75 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.io;
+
+import java.io.IOException;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+
+/**
+ * Writes out Vectors to a SequenceFile.
+ *
+ * Closes the writer when done
+ */
+public class SequenceFileVectorWriter implements VectorWriter {
+ private final SequenceFile.Writer writer;
+ private long recNum = 0;
+ public SequenceFileVectorWriter(SequenceFile.Writer writer) {
+ this.writer = writer;
+ }
+
+ @Override
+ public long write(Iterable<Vector> iterable, long maxDocs) throws IOException {
+
+ for (Vector point : iterable) {
+ if (recNum >= maxDocs) {
+ break;
+ }
+ if (point != null) {
+ writer.append(new LongWritable(recNum++), new VectorWritable(point));
+ }
+
+ }
+ return recNum;
+ }
+
+ @Override
+ public void write(Vector vector) throws IOException {
+ writer.append(new LongWritable(recNum++), new VectorWritable(vector));
+
+ }
+
+ @Override
+ public long write(Iterable<Vector> iterable) throws IOException {
+ return write(iterable, Long.MAX_VALUE);
+ }
+
+ @Override
+ public void close() throws IOException {
+ Closeables.close(writer, false);
+ }
+
+ public SequenceFile.Writer getWriter() {
+ return writer;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TermInfoWriter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TermInfoWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TermInfoWriter.java
new file mode 100644
index 0000000..e165b45
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TermInfoWriter.java
@@ -0,0 +1,29 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.io;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import org.apache.mahout.utils.vectors.TermInfo;
+
+/**
+ * Sink that serializes a {@link TermInfo} term dictionary. As a {@link Closeable}, it should be
+ * closed once writing is finished.
+ */
+public interface TermInfoWriter extends Closeable {
+
+  /**
+   * Serializes the given term dictionary.
+   *
+   * @param ti the dictionary to write
+   * @throws IOException if the destination cannot be written
+   */
+  void write(TermInfo ti) throws IOException;
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TextualVectorWriter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TextualVectorWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TextualVectorWriter.java
new file mode 100644
index 0000000..cc27d1d
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TextualVectorWriter.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.io;
+
+import java.io.IOException;
+import java.io.Writer;
+
+import com.google.common.io.Closeables;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Write out the vectors to any {@link Writer} using {@link Vector#asFormatString()},
+ * one per line by default.
+ * <p/>
+ * {@link #close()} closes the wrapped {@link Writer}.
+ */
+public class TextualVectorWriter implements VectorWriter {
+
+  private final Writer writer;
+
+  public TextualVectorWriter(Writer writer) {
+    this.writer = writer;
+  }
+
+  /** @return the wrapped writer, exposed to subclasses */
+  protected Writer getWriter() {
+    return writer;
+  }
+
+  /** Writes every non-null vector of {@code iterable}; returns how many were written. */
+  @Override
+  public long write(Iterable<Vector> iterable) throws IOException {
+    return write(iterable, Long.MAX_VALUE);
+  }
+
+  /**
+   * Writes up to {@code maxDocs} vectors from {@code iterable}. {@code null} elements are skipped
+   * and do not count toward the limit, the same way SequenceFileVectorWriter treats them.
+   *
+   * @return the number of vectors written by this call
+   */
+  @Override
+  public long write(Iterable<Vector> iterable, long maxDocs) throws IOException {
+    long result = 0;
+    for (Vector vector : iterable) {
+      if (result >= maxDocs) {
+        break;
+      }
+      // Vector iterators may yield null (e.g. when a document maps to no vector); writing it
+      // would fail with a NullPointerException in asFormatString().
+      if (vector != null) {
+        write(vector);
+        result++;
+      }
+    }
+    return result;
+  }
+
+  /** Writes {@code vector.asFormatString()} followed by a newline. */
+  @Override
+  public void write(Vector vector) throws IOException {
+    writer.write(vector.asFormatString());
+    writer.write('\n');
+  }
+
+  /** Closes the wrapped {@link Writer}. */
+  @Override
+  public void close() throws IOException {
+    Closeables.close(writer, false);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/VectorWriter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/VectorWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/VectorWriter.java
new file mode 100644
index 0000000..923e270
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/VectorWriter.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.io;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * Destination for {@link Vector}s. Being {@link Closeable}, it should be closed when writing is
+ * finished.
+ */
+public interface VectorWriter extends Closeable {
+
+  /**
+   * Writes a single vector.
+   *
+   * @param vector the {@link org.apache.mahout.math.Vector} to write
+   * @throws IOException if the output cannot be written
+   */
+  void write(Vector vector) throws IOException;
+
+  /**
+   * Writes every vector produced by {@code iterable}.
+   *
+   * @param iterable the {@link Iterable} to loop over
+   * @return the number of docs written
+   * @throws IOException if there was a problem writing
+   */
+  long write(Iterable<Vector> iterable) throws IOException;
+
+  /**
+   * Writes at most the first {@code maxDocs} vectors produced by {@code iterable}.
+   *
+   * @param iterable the {@link Iterable} to loop over
+   * @param maxDocs  the maximum number of docs to write
+   * @return the number of docs written
+   * @throws IOException if there was a problem writing
+   */
+  long write(Iterable<Vector> iterable, long maxDocs) throws IOException;
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/AbstractLuceneIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/AbstractLuceneIterator.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/AbstractLuceneIterator.java
new file mode 100644
index 0000000..ff61a70
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/AbstractLuceneIterator.java
@@ -0,0 +1,140 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.lucene;
+
+import com.google.common.collect.AbstractIterator;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.util.BytesRef;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.utils.Bump125;
+import org.apache.mahout.utils.vectors.TermInfo;
+import org.apache.mahout.vectorizer.Weight;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * Iterate over a Lucene index, extracting term vectors.
+ * Subclasses define how much information to retrieve from the Lucene index.
+ */
+public abstract class AbstractLuceneIterator extends AbstractIterator<Vector> {
+ private static final Logger log = LoggerFactory.getLogger(LuceneIterator.class);
+ protected final IndexReader indexReader;
+ protected final String field;
+ protected final TermInfo terminfo;
+ protected final double normPower;
+ protected final Weight weight;
+ protected final Bump125 bump = new Bump125();
+ protected int nextDocId;
+ protected int maxErrorDocs;
+ protected int numErrorDocs;
+ protected long nextLogRecord = bump.increment();
+ protected int skippedErrorMessages;
+
+ public AbstractLuceneIterator(TermInfo terminfo, double normPower, IndexReader indexReader, Weight weight,
+ double maxPercentErrorDocs, String field) {
+ this.terminfo = terminfo;
+ this.normPower = normPower;
+ this.indexReader = indexReader;
+
+ this.weight = weight;
+ this.nextDocId = 0;
+ this.maxErrorDocs = (int) (maxPercentErrorDocs * indexReader.numDocs());
+ this.field = field;
+ }
+
+ /**
+ * Given the document name, derive a name for the vector. This may involve
+ * reading the document from Lucene and setting up any other state that the
+ * subclass wants. This will be called once for each document that the
+ * iterator processes.
+ * @param documentIndex the lucene document index.
+ * @return the name to store in the vector.
+ */
+ protected abstract String getVectorName(int documentIndex) throws IOException;
+
+ @Override
+ protected Vector computeNext() {
+ try {
+ int doc;
+ Terms termFreqVector;
+ String name;
+
+ do {
+ doc = this.nextDocId;
+ nextDocId++;
+
+ if (doc >= indexReader.maxDoc()) {
+ return endOfData();
+ }
+
+ termFreqVector = indexReader.getTermVector(doc, field);
+ name = getVectorName(doc);
+
+ if (termFreqVector == null) {
+ numErrorDocs++;
+ if (numErrorDocs >= maxErrorDocs) {
+ log.error("There are too many documents that do not have a term vector for {}", field);
+ throw new IllegalStateException("There are too many documents that do not have a term vector for "
+ + field);
+ }
+ if (numErrorDocs >= nextLogRecord) {
+ if (skippedErrorMessages == 0) {
+ log.warn("{} does not have a term vector for {}", name, field);
+ } else {
+ log.warn("{} documents do not have a term vector for {}", numErrorDocs, field);
+ }
+ nextLogRecord = bump.increment();
+ skippedErrorMessages = 0;
+ } else {
+ skippedErrorMessages++;
+ }
+ }
+ } while (termFreqVector == null);
+
+ // The loop exits with termFreqVector and name set.
+
+ TermsEnum te = termFreqVector.iterator();
+ BytesRef term;
+ TFDFMapper mapper = new TFDFMapper(indexReader.numDocs(), weight, this.terminfo);
+ mapper.setExpectations(field, termFreqVector.size());
+ while ((term = te.next()) != null) {
+ mapper.map(term, (int) te.totalTermFreq());
+ }
+ Vector result = mapper.getVector();
+ if (result == null) {
+ // TODO is this right? last version would produce null in the iteration in this case, though it
+ // seems like that may not be desirable
+ return null;
+ }
+
+ if (normPower == LuceneIterable.NO_NORMALIZING) {
+ result = new NamedVector(result, name);
+ } else {
+ result = new NamedVector(result.normalize(normPower), name);
+ }
+ return result;
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/CachedTermInfo.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/CachedTermInfo.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/CachedTermInfo.java
new file mode 100644
index 0000000..0b59ed6
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/CachedTermInfo.java
@@ -0,0 +1,79 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.lucene;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.util.BytesRef;
+import org.apache.mahout.utils.vectors.TermEntry;
+import org.apache.mahout.utils.vectors.TermInfo;
+
+
+/**
+ * Caches TermEntries from a single field. Materializes all values in the TermEnum to memory (much like FieldCache)
+ * <p/>
+ * A term is kept when its document frequency {@code df} satisfies
+ * {@code minDf <= df <= numDocs * maxDfPercent / 100}. Kept terms receive consecutive indices,
+ * starting at 0, in the order the {@link TermsEnum} returns them.
+ */
+public class CachedTermInfo implements TermInfo {
+
+  private final Map<String, TermEntry> termEntries;
+  private final String field;
+
+  /**
+   * @param reader       index to read the terms from
+   * @param field        the single field to cache
+   * @param minDf        minimum document frequency (inclusive) for a term to be kept
+   * @param maxDfPercent maximum document frequency (inclusive), as a percentage of {@code reader.numDocs()}
+   * @throws IOException if the index cannot be read
+   */
+  public CachedTermInfo(IndexReader reader, String field, int minDf, int maxDfPercent) throws IOException {
+    this.field = field;
+    // Insertion-ordered, so getAllEntries() returns entries in index order.
+    termEntries = new LinkedHashMap<>();
+
+    Terms t = MultiFields.getTerms(reader, field);
+    if (t == null) {
+      // The field is not indexed (getTerms returns null): nothing to cache.
+      return;
+    }
+    TermsEnum te = t.iterator();
+
+    int numDocs = reader.numDocs();
+    // Widen before multiplying: numDocs * maxDfPercent in int arithmetic overflows on large indexes.
+    double percent = (double) numDocs * maxDfPercent / 100.0;
+    int count = 0;
+    BytesRef text;
+    while ((text = te.next()) != null) {
+      int df = te.docFreq();
+      if (df >= minDf && df <= percent) {
+        TermEntry entry = new TermEntry(text.utf8ToString(), count++, df);
+        termEntries.put(entry.getTerm(), entry);
+      }
+    }
+  }
+
+  /** Returns the number of cached terms. The {@code field} argument is not consulted. */
+  @Override
+  public int totalTerms(String field) {
+    return termEntries.size();
+  }
+
+  /** Returns the cached entry for {@code term}, or null if {@code field} is not the cached field. */
+  @Override
+  public TermEntry getTermEntry(String field, String term) {
+    if (!this.field.equals(field)) {
+      return null;
+    }
+    return termEntries.get(term);
+  }
+
+  /** Iterates the cached entries in index order. */
+  @Override
+  public Iterator<TermEntry> getAllEntries() {
+    return termEntries.values().iterator();
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/ClusterLabels.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/ClusterLabels.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/ClusterLabels.java
new file mode 100644
index 0000000..b2568e7
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/ClusterLabels.java
@@ -0,0 +1,381 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.lucene;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.nio.file.Paths;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.io.Charsets;
+import org.apache.hadoop.fs.Path;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.FixedBitSet;
+import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.stats.LogLikelihood;
+import org.apache.mahout.utils.clustering.ClusterDumper;
+import org.apache.mahout.utils.vectors.TermEntry;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Get labels for the cluster using Log Likelihood Ratio (LLR).
+ * <p/>
+ *"The most useful way to think of this (LLR) is as the percentage of in-cluster documents that have the
+ * feature (term) versus the percentage out, keeping in mind that both percentages are uncertain since we have
+ * only a sample of all possible documents." - Ted Dunning
+ * <p/>
+ * More about LLR can be found at : http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html
+ */
+public class ClusterLabels {
+
+ private static final Logger log = LoggerFactory.getLogger(ClusterLabels.class);
+
+ public static final int DEFAULT_MIN_IDS = 50;
+ public static final int DEFAULT_MAX_LABELS = 25;
+
+ private final String indexDir;
+ private final String contentField;
+ private String idField;
+ private final Map<Integer, List<WeightedPropertyVectorWritable>> clusterIdToPoints;
+ private String output;
+ private final int minNumIds;
+ private final int maxLabels;
+
+ public ClusterLabels(Path seqFileDir,
+ Path pointsDir,
+ String indexDir,
+ String contentField,
+ int minNumIds,
+ int maxLabels) {
+ this.indexDir = indexDir;
+ this.contentField = contentField;
+ this.minNumIds = minNumIds;
+ this.maxLabels = maxLabels;
+ ClusterDumper clusterDumper = new ClusterDumper(seqFileDir, pointsDir);
+ this.clusterIdToPoints = clusterDumper.getClusterIdToPoints();
+ }
+
+ public void getLabels() throws IOException {
+
+ try (Writer writer = (this.output == null) ?
+ new OutputStreamWriter(System.out, Charsets.UTF_8) : Files.newWriter(new File(this.output), Charsets.UTF_8)){
+ for (Map.Entry<Integer, List<WeightedPropertyVectorWritable>> integerListEntry : clusterIdToPoints.entrySet()) {
+ List<WeightedPropertyVectorWritable> wpvws = integerListEntry.getValue();
+ List<TermInfoClusterInOut> termInfos = getClusterLabels(integerListEntry.getKey(), wpvws);
+ if (termInfos != null) {
+ writer.write('\n');
+ writer.write("Top labels for Cluster ");
+ writer.write(String.valueOf(integerListEntry.getKey()));
+ writer.write(" containing ");
+ writer.write(String.valueOf(wpvws.size()));
+ writer.write(" vectors");
+ writer.write('\n');
+ writer.write("Term \t\t LLR \t\t In-ClusterDF \t\t Out-ClusterDF ");
+ writer.write('\n');
+ for (TermInfoClusterInOut termInfo : termInfos) {
+ writer.write(termInfo.getTerm());
+ writer.write("\t\t");
+ writer.write(String.valueOf(termInfo.getLogLikelihoodRatio()));
+ writer.write("\t\t");
+ writer.write(String.valueOf(termInfo.getInClusterDF()));
+ writer.write("\t\t");
+ writer.write(String.valueOf(termInfo.getOutClusterDF()));
+ writer.write('\n');
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Get the list of labels, sorted by best score.
+ */
+ protected List<TermInfoClusterInOut> getClusterLabels(Integer integer,
+ Collection<WeightedPropertyVectorWritable> wpvws) throws IOException {
+
+ if (wpvws.size() < minNumIds) {
+ log.info("Skipping small cluster {} with size: {}", integer, wpvws.size());
+ return null;
+ }
+
+ log.info("Processing Cluster {} with {} documents", integer, wpvws.size());
+ Directory dir = FSDirectory.open(Paths.get(this.indexDir));
+ IndexReader reader = DirectoryReader.open(dir);
+
+
+ log.info("# of documents in the index {}", reader.numDocs());
+
+ Collection<String> idSet = new HashSet<>();
+ for (WeightedPropertyVectorWritable wpvw : wpvws) {
+ Vector vector = wpvw.getVector();
+ if (vector instanceof NamedVector) {
+ idSet.add(((NamedVector) vector).getName());
+ }
+ }
+
+ int numDocs = reader.numDocs();
+
+ FixedBitSet clusterDocBitset = getClusterDocBitset(reader, idSet, this.idField);
+
+ log.info("Populating term infos from the index");
+
+ /**
+ * This code is as that of CachedTermInfo, with one major change, which is to get the document frequency.
+ *
+ * Since we have deleted the documents out of the cluster, the document frequency for a term should only
+ * include the in-cluster documents. The document frequency obtained from TermEnum reflects the frequency
+ * in the entire index. To get the in-cluster frequency, we need to query the index to get the term
+ * frequencies in each document. The number of results of this call will be the in-cluster document
+ * frequency.
+ */
+ Terms t = MultiFields.getTerms(reader, contentField);
+ TermsEnum te = t.iterator();
+ Map<String, TermEntry> termEntryMap = new LinkedHashMap<>();
+ Bits liveDocs = MultiFields.getLiveDocs(reader); //WARNING: returns null if there are no deletions
+
+
+ int count = 0;
+ BytesRef term;
+ while ((term = te.next()) != null) {
+ FixedBitSet termBitset = new FixedBitSet(reader.maxDoc());
+ PostingsEnum docsEnum = MultiFields.getTermDocsEnum(reader, contentField, term);
+ int docID;
+ while ((docID = docsEnum.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+ //check to see if we don't have an deletions (null) or if document is live
+ if (liveDocs != null && !liveDocs.get(docID)) {
+ // document is deleted...
+ termBitset.set(docsEnum.docID());
+ }
+ }
+ // AND the term's bitset with cluster doc bitset to get the term's in-cluster frequency.
+ // This modifies the termBitset, but that's fine as we are not using it anywhere else.
+ termBitset.and(clusterDocBitset);
+ int inclusterDF = (int) termBitset.cardinality();
+
+ TermEntry entry = new TermEntry(term.utf8ToString(), count++, inclusterDF);
+ termEntryMap.put(entry.getTerm(), entry);
+
+ }
+
+ List<TermInfoClusterInOut> clusteredTermInfo = new LinkedList<>();
+
+ int clusterSize = wpvws.size();
+
+ for (TermEntry termEntry : termEntryMap.values()) {
+
+ int corpusDF = reader.docFreq(new Term(this.contentField,termEntry.getTerm()));
+ int outDF = corpusDF - termEntry.getDocFreq();
+ int inDF = termEntry.getDocFreq();
+ double logLikelihoodRatio = scoreDocumentFrequencies(inDF, outDF, clusterSize, numDocs);
+ TermInfoClusterInOut termInfoCluster =
+ new TermInfoClusterInOut(termEntry.getTerm(), inDF, outDF, logLikelihoodRatio);
+ clusteredTermInfo.add(termInfoCluster);
+ }
+
+ Collections.sort(clusteredTermInfo);
+ // Cleanup
+ Closeables.close(reader, true);
+ termEntryMap.clear();
+
+ return clusteredTermInfo.subList(0, Math.min(clusteredTermInfo.size(), maxLabels));
+ }
+
+ private static FixedBitSet getClusterDocBitset(IndexReader reader,
+ Collection<String> idSet,
+ String idField) throws IOException {
+ int numDocs = reader.numDocs();
+
+ FixedBitSet bitset = new FixedBitSet(numDocs);
+
+ Set<String> idFieldSelector = null;
+ if (idField != null) {
+ idFieldSelector = new TreeSet<>();
+ idFieldSelector.add(idField);
+ }
+
+
+ for (int i = 0; i < numDocs; i++) {
+ String id;
+ // Use Lucene's internal ID if idField is not specified. Else, get it from the document.
+ if (idField == null) {
+ id = Integer.toString(i);
+ } else {
+ id = reader.document(i, idFieldSelector).get(idField);
+ }
+ if (idSet.contains(id)) {
+ bitset.set(i);
+ }
+ }
+ log.info("Created bitset for in-cluster documents : {}", bitset.cardinality());
+ return bitset;
+ }
+
+ private static double scoreDocumentFrequencies(long inDF, long outDF, long clusterSize, long corpusSize) {
+ long k12 = clusterSize - inDF;
+ long k22 = corpusSize - clusterSize - outDF;
+
+ return LogLikelihood.logLikelihoodRatio(inDF, k12, outDF, k22);
+ }
+
+ public String getIdField() {
+ return idField;
+ }
+
+ public void setIdField(String idField) {
+ this.idField = idField;
+ }
+
+ public String getOutput() {
+ return output;
+ }
+
+ public void setOutput(String output) {
+ this.output = output;
+ }
+
+ public static void main(String[] args) {
+
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option indexOpt = obuilder.withLongName("dir").withRequired(true).withArgument(
+ abuilder.withName("dir").withMinimum(1).withMaximum(1).create())
+ .withDescription("The Lucene index directory").withShortName("d").create();
+
+ Option outputOpt = obuilder.withLongName("output").withRequired(false).withArgument(
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The output file. If not specified, the result is printed on console.").withShortName("o").create();
+
+ Option fieldOpt = obuilder.withLongName("field").withRequired(true).withArgument(
+ abuilder.withName("field").withMinimum(1).withMaximum(1).create())
+ .withDescription("The content field in the index").withShortName("f").create();
+
+ Option idFieldOpt = obuilder.withLongName("idField").withRequired(false).withArgument(
+ abuilder.withName("idField").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The field for the document ID in the index. If null, then the Lucene internal doc "
+ + "id is used which is prone to error if the underlying index changes").withShortName("i").create();
+
+ Option seqOpt = obuilder.withLongName("seqFileDir").withRequired(true).withArgument(
+ abuilder.withName("seqFileDir").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The directory containing Sequence Files for the Clusters").withShortName("s").create();
+
+ Option pointsOpt = obuilder.withLongName("pointsDir").withRequired(true).withArgument(
+ abuilder.withName("pointsDir").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The directory containing points sequence files mapping input vectors to their cluster. ")
+ .withShortName("p").create();
+ Option minClusterSizeOpt = obuilder.withLongName("minClusterSize").withRequired(false).withArgument(
+ abuilder.withName("minClusterSize").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The minimum number of points required in a cluster to print the labels for").withShortName("m").create();
+ Option maxLabelsOpt = obuilder.withLongName("maxLabels").withRequired(false).withArgument(
+ abuilder.withName("maxLabels").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The maximum number of labels to print per cluster").withShortName("x").create();
+ Option helpOpt = DefaultOptionCreator.helpOption();
+
+ Group group = gbuilder.withName("Options").withOption(indexOpt).withOption(idFieldOpt).withOption(outputOpt)
+ .withOption(fieldOpt).withOption(seqOpt).withOption(pointsOpt).withOption(helpOpt)
+ .withOption(maxLabelsOpt).withOption(minClusterSizeOpt).create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return;
+ }
+
+ Path seqFileDir = new Path(cmdLine.getValue(seqOpt).toString());
+ Path pointsDir = new Path(cmdLine.getValue(pointsOpt).toString());
+ String indexDir = cmdLine.getValue(indexOpt).toString();
+ String contentField = cmdLine.getValue(fieldOpt).toString();
+
+ String idField = null;
+
+ if (cmdLine.hasOption(idFieldOpt)) {
+ idField = cmdLine.getValue(idFieldOpt).toString();
+ }
+ String output = null;
+ if (cmdLine.hasOption(outputOpt)) {
+ output = cmdLine.getValue(outputOpt).toString();
+ }
+ int maxLabels = DEFAULT_MAX_LABELS;
+ if (cmdLine.hasOption(maxLabelsOpt)) {
+ maxLabels = Integer.parseInt(cmdLine.getValue(maxLabelsOpt).toString());
+ }
+ int minSize = DEFAULT_MIN_IDS;
+ if (cmdLine.hasOption(minClusterSizeOpt)) {
+ minSize = Integer.parseInt(cmdLine.getValue(minClusterSizeOpt).toString());
+ }
+ ClusterLabels clusterLabel = new ClusterLabels(seqFileDir, pointsDir, indexDir, contentField, minSize, maxLabels);
+
+ if (idField != null) {
+ clusterLabel.setIdField(idField);
+ }
+ if (output != null) {
+ clusterLabel.setOutput(output);
+ }
+
+ clusterLabel.getLabels();
+
+ } catch (OptionException e) {
+ log.error("Exception", e);
+ CommandLineUtil.printHelp(group);
+ } catch (IOException e) {
+ log.error("Exception", e);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/Driver.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/Driver.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/Driver.java
new file mode 100644
index 0000000..876816f
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/Driver.java
@@ -0,0 +1,349 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.lucene;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.Writer;
+import java.nio.file.Paths;
+import java.util.Iterator;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Files;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.io.Charsets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.FSDirectory;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.utils.vectors.TermEntry;
+import org.apache.mahout.utils.vectors.TermInfo;
+import org.apache.mahout.utils.vectors.io.DelimitedTermInfoWriter;
+import org.apache.mahout.utils.vectors.io.SequenceFileVectorWriter;
+import org.apache.mahout.utils.vectors.io.VectorWriter;
+import org.apache.mahout.vectorizer.TF;
+import org.apache.mahout.vectorizer.TFIDF;
+import org.apache.mahout.vectorizer.Weight;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class Driver {
+
+ private static final Logger log = LoggerFactory.getLogger(Driver.class);
+
+ private String luceneDir;
+ private String outFile;
+ private String field;
+ private String idField;
+ private String dictOut;
+ private String seqDictOut = "";
+ private String weightType = "tfidf";
+ private String delimiter = "\t";
+ private double norm = LuceneIterable.NO_NORMALIZING;
+ private long maxDocs = Long.MAX_VALUE;
+ private int minDf = 1;
+ private int maxDFPercent = 99;
+ private double maxPercentErrorDocs = 0.0;
+
+ public void dumpVectors() throws IOException {
+
+ File file = new File(luceneDir);
+ Preconditions.checkArgument(file.isDirectory(),
+ "Lucene directory: " + file.getAbsolutePath()
+ + " does not exist or is not a directory");
+ Preconditions.checkArgument(maxDocs >= 0, "maxDocs must be >= 0");
+ Preconditions.checkArgument(minDf >= 1, "minDf must be >= 1");
+ Preconditions.checkArgument(maxDFPercent <= 99, "maxDFPercent must be <= 99");
+
+ Directory dir = FSDirectory.open(Paths.get(file.getAbsolutePath()));
+ IndexReader reader = DirectoryReader.open(dir);
+
+
+ Weight weight;
+ if ("tf".equalsIgnoreCase(weightType)) {
+ weight = new TF();
+ } else if ("tfidf".equalsIgnoreCase(weightType)) {
+ weight = new TFIDF();
+ } else {
+ throw new IllegalArgumentException("Weight type " + weightType + " is not supported");
+ }
+
+ TermInfo termInfo = new CachedTermInfo(reader, field, minDf, maxDFPercent);
+
+ LuceneIterable iterable;
+ if (norm == LuceneIterable.NO_NORMALIZING) {
+ iterable = new LuceneIterable(reader, idField, field, termInfo, weight, LuceneIterable.NO_NORMALIZING,
+ maxPercentErrorDocs);
+ } else {
+ iterable = new LuceneIterable(reader, idField, field, termInfo, weight, norm, maxPercentErrorDocs);
+ }
+
+ log.info("Output File: {}", outFile);
+
+ try (VectorWriter vectorWriter = getSeqFileWriter(outFile)) {
+ long numDocs = vectorWriter.write(iterable, maxDocs);
+ log.info("Wrote: {} vectors", numDocs);
+ }
+
+ File dictOutFile = new File(dictOut);
+ log.info("Dictionary Output file: {}", dictOutFile);
+ Writer writer = Files.newWriter(dictOutFile, Charsets.UTF_8);
+ try (DelimitedTermInfoWriter tiWriter = new DelimitedTermInfoWriter(writer, delimiter, field)) {
+ tiWriter.write(termInfo);
+ }
+
+ if (!"".equals(seqDictOut)) {
+ log.info("SequenceFile Dictionary Output file: {}", seqDictOut);
+
+ Path path = new Path(seqDictOut);
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(conf);
+ try (SequenceFile.Writer seqWriter = SequenceFile.createWriter(fs, conf, path, Text.class, IntWritable.class)) {
+ Text term = new Text();
+ IntWritable termIndex = new IntWritable();
+ Iterator<TermEntry> termEntries = termInfo.getAllEntries();
+ while (termEntries.hasNext()) {
+ TermEntry termEntry = termEntries.next();
+ term.set(termEntry.getTerm());
+ termIndex.set(termEntry.getTermIdx());
+ seqWriter.append(term, termIndex);
+ }
+ }
+ }
+ }
+
+ public static void main(String[] args) throws IOException {
+
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option inputOpt = obuilder.withLongName("dir").withRequired(true).withArgument(
+ abuilder.withName("dir").withMinimum(1).withMaximum(1).create())
+ .withDescription("The Lucene directory").withShortName("d").create();
+
+ Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription("The output file")
+ .withShortName("o").create();
+
+ Option fieldOpt = obuilder.withLongName("field").withRequired(true).withArgument(
+ abuilder.withName("field").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The field in the index").withShortName("f").create();
+
+ Option idFieldOpt = obuilder.withLongName("idField").withRequired(false).withArgument(
+ abuilder.withName("idField").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The field in the index containing the index. If null, then the Lucene internal doc "
+ + "id is used which is prone to error if the underlying index changes").create();
+
+ Option dictOutOpt = obuilder.withLongName("dictOut").withRequired(true).withArgument(
+ abuilder.withName("dictOut").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The output of the dictionary").withShortName("t").create();
+
+ Option seqDictOutOpt = obuilder.withLongName("seqDictOut").withRequired(false).withArgument(
+ abuilder.withName("seqDictOut").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The output of the dictionary as sequence file").withShortName("st").create();
+
+ Option weightOpt = obuilder.withLongName("weight").withRequired(false).withArgument(
+ abuilder.withName("weight").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The kind of weight to use. Currently TF or TFIDF").withShortName("w").create();
+
+ Option delimiterOpt = obuilder.withLongName("delimiter").withRequired(false).withArgument(
+ abuilder.withName("delimiter").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The delimiter for outputting the dictionary").withShortName("l").create();
+
+ Option powerOpt = obuilder.withLongName("norm").withRequired(false).withArgument(
+ abuilder.withName("norm").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The norm to use, expressed as either a double or \"INF\" if you want to use the Infinite norm. "
+ + "Must be greater or equal to 0. The default is not to normalize").withShortName("n").create();
+
+ Option maxOpt = obuilder.withLongName("max").withRequired(false).withArgument(
+ abuilder.withName("max").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The maximum number of vectors to output. If not specified, then it will loop over all docs")
+ .withShortName("m").create();
+
+ Option minDFOpt = obuilder.withLongName("minDF").withRequired(false).withArgument(
+ abuilder.withName("minDF").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The minimum document frequency. Default is 1").withShortName("md").create();
+
+ Option maxDFPercentOpt = obuilder.withLongName("maxDFPercent").withRequired(false).withArgument(
+ abuilder.withName("maxDFPercent").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The max percentage of docs for the DF. Can be used to remove really high frequency terms."
+ + " Expressed as an integer between 0 and 100. Default is 99.").withShortName("x").create();
+
+ Option maxPercentErrorDocsOpt = obuilder.withLongName("maxPercentErrorDocs").withRequired(false).withArgument(
+ abuilder.withName("maxPercentErrorDocs").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The max percentage of docs that can have a null term vector. These are noise document and can occur if the "
+ + "analyzer used strips out all terms in the target field. This percentage is expressed as a value "
+ + "between 0 and 1. The default is 0.").withShortName("err").create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(idFieldOpt).withOption(
+ outputOpt).withOption(delimiterOpt).withOption(helpOpt).withOption(fieldOpt).withOption(maxOpt)
+ .withOption(dictOutOpt).withOption(seqDictOutOpt).withOption(powerOpt).withOption(maxDFPercentOpt)
+ .withOption(weightOpt).withOption(minDFOpt).withOption(maxPercentErrorDocsOpt).create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption(helpOpt)) {
+
+ CommandLineUtil.printHelp(group);
+ return;
+ }
+
+ if (cmdLine.hasOption(inputOpt)) { // Lucene case
+ Driver luceneDriver = new Driver();
+ luceneDriver.setLuceneDir(cmdLine.getValue(inputOpt).toString());
+
+ if (cmdLine.hasOption(maxOpt)) {
+ luceneDriver.setMaxDocs(Long.parseLong(cmdLine.getValue(maxOpt).toString()));
+ }
+
+ if (cmdLine.hasOption(weightOpt)) {
+ luceneDriver.setWeightType(cmdLine.getValue(weightOpt).toString());
+ }
+
+ luceneDriver.setField(cmdLine.getValue(fieldOpt).toString());
+
+ if (cmdLine.hasOption(minDFOpt)) {
+ luceneDriver.setMinDf(Integer.parseInt(cmdLine.getValue(minDFOpt).toString()));
+ }
+
+ if (cmdLine.hasOption(maxDFPercentOpt)) {
+ luceneDriver.setMaxDFPercent(Integer.parseInt(cmdLine.getValue(maxDFPercentOpt).toString()));
+ }
+
+ if (cmdLine.hasOption(powerOpt)) {
+ String power = cmdLine.getValue(powerOpt).toString();
+ if ("INF".equals(power)) {
+ luceneDriver.setNorm(Double.POSITIVE_INFINITY);
+ } else {
+ luceneDriver.setNorm(Double.parseDouble(power));
+ }
+ }
+
+ if (cmdLine.hasOption(idFieldOpt)) {
+ luceneDriver.setIdField(cmdLine.getValue(idFieldOpt).toString());
+ }
+
+ if (cmdLine.hasOption(maxPercentErrorDocsOpt)) {
+ luceneDriver.setMaxPercentErrorDocs(Double.parseDouble(cmdLine.getValue(maxPercentErrorDocsOpt).toString()));
+ }
+
+ luceneDriver.setOutFile(cmdLine.getValue(outputOpt).toString());
+
+ luceneDriver.setDelimiter(cmdLine.hasOption(delimiterOpt) ? cmdLine.getValue(delimiterOpt).toString() : "\t");
+
+ luceneDriver.setDictOut(cmdLine.getValue(dictOutOpt).toString());
+
+ if (cmdLine.hasOption(seqDictOutOpt)) {
+ luceneDriver.setSeqDictOut(cmdLine.getValue(seqDictOutOpt).toString());
+ }
+
+ luceneDriver.dumpVectors();
+ }
+ } catch (OptionException e) {
+ log.error("Exception", e);
+ CommandLineUtil.printHelp(group);
+ }
+ }
+
+ private static VectorWriter getSeqFileWriter(String outFile) throws IOException {
+ Path path = new Path(outFile);
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(conf);
+ // TODO: Make this parameter driven
+
+ SequenceFile.Writer seqWriter = SequenceFile.createWriter(fs, conf, path, LongWritable.class,
+ VectorWritable.class);
+
+ return new SequenceFileVectorWriter(seqWriter);
+ }
+
+ public void setLuceneDir(String luceneDir) {
+ this.luceneDir = luceneDir;
+ }
+
+ public void setMaxDocs(long maxDocs) {
+ this.maxDocs = maxDocs;
+ }
+
+ public void setWeightType(String weightType) {
+ this.weightType = weightType;
+ }
+
+ public void setField(String field) {
+ this.field = field;
+ }
+
+ public void setMinDf(int minDf) {
+ this.minDf = minDf;
+ }
+
+ public void setMaxDFPercent(int maxDFPercent) {
+ this.maxDFPercent = maxDFPercent;
+ }
+
+ public void setNorm(double norm) {
+ this.norm = norm;
+ }
+
+ public void setIdField(String idField) {
+ this.idField = idField;
+ }
+
+ public void setOutFile(String outFile) {
+ this.outFile = outFile;
+ }
+
+ public void setDelimiter(String delimiter) {
+ this.delimiter = delimiter;
+ }
+
+ public void setDictOut(String dictOut) {
+ this.dictOut = dictOut;
+ }
+
+ public void setSeqDictOut(String seqDictOut) {
+ this.seqDictOut = seqDictOut;
+ }
+
+ public void setMaxPercentErrorDocs(double maxPercentErrorDocs) {
+ this.maxPercentErrorDocs = maxPercentErrorDocs;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/LuceneIterable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/LuceneIterable.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/LuceneIterable.java
new file mode 100644
index 0000000..1af0ed0
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/LuceneIterable.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.vectors.lucene;
+
+import org.apache.lucene.index.IndexReader;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.utils.vectors.TermInfo;
+import org.apache.mahout.vectorizer.Weight;
+
+import java.util.Iterator;
+
+/**
+ * {@link Iterable} view over a Lucene index field: each call to {@link #iterator()} hands out a new
+ * {@link LuceneIterator} built from the settings captured at construction time.
+ */
+public final class LuceneIterable implements Iterable<Vector> {
+
+  /** Sentinel for {@code normPower} meaning the produced vectors are left unnormalized. */
+  public static final double NO_NORMALIZING = -1.0;
+
+  private final IndexReader reader;
+  private final String idFieldName;
+  private final String fieldName;
+  private final TermInfo termInfo;
+  private final Weight termWeight;
+  private final double normPower;
+  private final double maxPercentErrorDocs;
+
+  /**
+   * Convenience constructor: no normalization and a {@code maxPercentErrorDocs} of 0.
+   */
+  public LuceneIterable(IndexReader reader, String idField, String field, TermInfo terminfo, Weight weight) {
+    this(reader, idField, field, terminfo, weight, NO_NORMALIZING, 0.0);
+  }
+
+  /**
+   * Convenience constructor with an explicit normalization power and a {@code maxPercentErrorDocs} of 0.
+   */
+  public LuceneIterable(IndexReader indexReader, String idField, String field, TermInfo terminfo, Weight weight,
+                        double normPower) {
+    this(indexReader, idField, field, terminfo, weight, normPower, 0.0);
+  }
+
+  /**
+   * Full constructor; the arguments are stored as-is and forwarded to every {@link LuceneIterator} created.
+   *
+   * @param indexReader         {@link org.apache.lucene.index.IndexReader} supplying the documents
+   * @param idField             name of the field holding the document id; may be null
+   * @param field               name of the field the vectors are built from
+   * @param terminfo            term dictionary used when building vectors
+   * @param weight              term weighting scheme
+   * @param normPower           normalization power; nonnegative, or {@link #NO_NORMALIZING}
+   * @param maxPercentErrorDocs fraction of documents in the index allowed to have a null term vector
+   */
+  public LuceneIterable(IndexReader indexReader,
+                        String idField,
+                        String field,
+                        TermInfo terminfo,
+                        Weight weight,
+                        double normPower,
+                        double maxPercentErrorDocs) {
+    this.reader = indexReader;
+    this.idFieldName = idField;
+    this.fieldName = field;
+    this.termInfo = terminfo;
+    this.termWeight = weight;
+    this.normPower = normPower;
+    this.maxPercentErrorDocs = maxPercentErrorDocs;
+  }
+
+  @Override
+  public Iterator<Vector> iterator() {
+    return new LuceneIterator(reader, idFieldName, fieldName, termInfo, termWeight, normPower, maxPercentErrorDocs);
+  }
+}