You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/27 13:14:44 UTC
[18/24] mahout git commit: MAHOUT-2034 Split MR and New Examples into
separate modules
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
new file mode 100644
index 0000000..b2ce8b1
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.GroupedOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+public class AdaptiveLogisticModelParameters extends LogisticModelParameters {
+
+ private AdaptiveLogisticRegression alr;
+ private int interval = 800;
+ private int averageWindow = 500;
+ private int threads = 4;
+ private String prior = "L1";
+ private double priorOption = Double.NaN;
+ private String auc = null;
+
+ /**
+  * Returns the adaptive learner for these parameters, building it on first use.
+  * A new learner is constructed from the target category count, the feature count and the
+  * configured prior, then given the interval, averaging window, thread count and AUC
+  * evaluator held by this object. Later calls return the same instance.
+  *
+  * @return the cached or newly created AdaptiveLogisticRegression
+  */
+ public AdaptiveLogisticRegression createAdaptiveLogisticRegression() {
+   if (alr != null) {
+     // Already built (or loaded by readFields); reuse it untouched.
+     return alr;
+   }
+   alr = new AdaptiveLogisticRegression(
+       getMaxTargetCategories(), getNumFeatures(), createPrior(prior, priorOption));
+   alr.setInterval(interval);
+   alr.setAveragingWindow(averageWindow);
+   alr.setThreadCount(threads);
+   alr.setAucEvaluator(createAUC(auc));
+   return alr;
+ }
+
+ public void checkParameters() {
+ if (prior != null) {
+ String priorUppercase = prior.toUpperCase(Locale.ENGLISH).trim();
+ if (("TP".equals(priorUppercase) || "EBP".equals(priorUppercase)) && Double.isNaN(priorOption)) {
+ throw new IllegalArgumentException("You must specify a double value for TPrior and ElasticBandPrior.");
+ }
+ }
+ }
+
+ /**
+  * Maps a prior name to a PriorFunction instance.
+  * The name is upper-cased with the English locale and trimmed before matching.
+  *
+  * @param cmd the prior name: "L1", "L2", "UP", "TP" or "EBP"; may be null
+  * @param priorOption the numeric argument passed to TPrior and ElasticBandPrior
+  * @return the matching prior, or null when {@code cmd} is null or not recognised
+  */
+ private static PriorFunction createPrior(String cmd, double priorOption) {
+   if (cmd == null) {
+     return null;
+   }
+   switch (cmd.toUpperCase(Locale.ENGLISH).trim()) {
+     case "L1":
+       return new L1();
+     case "L2":
+       return new L2();
+     case "UP":
+       return new UniformPrior();
+     case "TP":
+       return new TPrior(priorOption);
+     case "EBP":
+       return new ElasticBandPrior(priorOption);
+     default:
+       return null;
+   }
+ }
+
+ /**
+  * Maps an AUC evaluator name to an OnlineAuc instance.
+  * The name is upper-cased with the English locale and trimmed before matching.
+  *
+  * @param cmd the evaluator name: "GLOBAL" or "GROUPED"; may be null
+  * @return the matching evaluator, or null when {@code cmd} is null or not recognised
+  */
+ private static OnlineAuc createAUC(String cmd) {
+   if (cmd == null) {
+     return null;
+   }
+   String kind = cmd.toUpperCase(Locale.ENGLISH).trim();
+   if ("GLOBAL".equals(kind)) {
+     return new GlobalOnlineAuc();
+   }
+   if ("GROUPED".equals(kind)) {
+     return new GroupedOnlineAuc();
+   }
+   return null;
+ }
+
+ @Override
+ public void saveTo(OutputStream out) throws IOException {
+ if (alr != null) {
+ alr.close();
+ }
+ setTargetCategories(getCsvRecordFactory().getTargetCategories());
+ write(new DataOutputStream(out));
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeUTF(getTargetVariable());
+ out.writeInt(getTypeMap().size());
+ for (Map.Entry<String, String> entry : getTypeMap().entrySet()) {
+ out.writeUTF(entry.getKey());
+ out.writeUTF(entry.getValue());
+ }
+ out.writeInt(getNumFeatures());
+ out.writeInt(getMaxTargetCategories());
+ out.writeInt(getTargetCategories().size());
+ for (String category : getTargetCategories()) {
+ out.writeUTF(category);
+ }
+
+ out.writeInt(interval);
+ out.writeInt(averageWindow);
+ out.writeInt(threads);
+ out.writeUTF(prior);
+ out.writeDouble(priorOption);
+ out.writeUTF(auc);
+
+ // skip csv
+ alr.write(out);
+ }
+
+ /**
+  * Restores these parameters and the learner from serialized form, reading the fields in the
+  * same order in which {@link #write(DataOutput)} emits them.
+  *
+  * @param in the source to read from
+  * @throws IOException if reading fails
+  */
+ @Override
+ public void readFields(DataInput in) throws IOException {
+   setTargetVariable(in.readUTF());
+
+   int predictorCount = in.readInt();
+   Map<String, String> types = new HashMap<>(predictorCount);
+   for (int remaining = predictorCount; remaining > 0; remaining--) {
+     String name = in.readUTF();
+     String type = in.readUTF();
+     types.put(name, type);
+   }
+   setTypeMap(types);
+
+   setNumFeatures(in.readInt());
+   setMaxTargetCategories(in.readInt());
+
+   int categoryCount = in.readInt();
+   List<String> categories = new ArrayList<>(categoryCount);
+   for (int remaining = categoryCount; remaining > 0; remaining--) {
+     categories.add(in.readUTF());
+   }
+   setTargetCategories(categories);
+
+   interval = in.readInt();
+   averageWindow = in.readInt();
+   threads = in.readInt();
+   prior = in.readUTF();
+   priorOption = in.readDouble();
+   auc = in.readUTF();
+
+   // The learner state follows the scalar parameters in the stream.
+   alr = new AdaptiveLogisticRegression();
+   alr.readFields(in);
+ }
+
+
+ /**
+  * Reads a parameter set (including its learner) from a stream.
+  *
+  * @param in the stream to read; it is wrapped but not closed here
+  * @return the populated parameters
+  * @throws IOException if reading fails
+  */
+ private static AdaptiveLogisticModelParameters loadFromStream(InputStream in) throws IOException {
+   AdaptiveLogisticModelParameters params = new AdaptiveLogisticModelParameters();
+   DataInputStream data = new DataInputStream(in);
+   params.readFields(data);
+   return params;
+ }
+
+ public static AdaptiveLogisticModelParameters loadFromFile(File in) throws IOException {
+ try (InputStream input = new FileInputStream(in)) {
+ return loadFromStream(input);
+ }
+ }
+
+ // Plain accessors for the learner settings. The values are applied to the learner only when
+ // createAdaptiveLogisticRegression() builds it, so changing them afterwards does not touch an
+ // already-created learner.
+
+ /** Returns the value passed to the learner's {@code setInterval} when it is created. */
+ public int getInterval() {
+ return interval;
+ }
+
+ /** Sets the value passed to the learner's {@code setInterval} when it is created. */
+ public void setInterval(int interval) {
+ this.interval = interval;
+ }
+
+ /** Returns the value passed to the learner's {@code setAveragingWindow} when it is created. */
+ public int getAverageWindow() {
+ return averageWindow;
+ }
+
+ /** Sets the value passed to the learner's {@code setAveragingWindow} when it is created. */
+ public void setAverageWindow(int averageWindow) {
+ this.averageWindow = averageWindow;
+ }
+
+ /** Returns the value passed to the learner's {@code setThreadCount} when it is created. */
+ public int getThreads() {
+ return threads;
+ }
+
+ /** Sets the value passed to the learner's {@code setThreadCount} when it is created. */
+ public void setThreads(int threads) {
+ this.threads = threads;
+ }
+
+ /** Returns the prior name; names recognised when the learner is built are L1, L2, UP, TP and EBP. */
+ public String getPrior() {
+ return prior;
+ }
+
+ /** Sets the prior name; an unrecognised or null name results in a null prior being passed to the learner. */
+ public void setPrior(String prior) {
+ this.prior = prior;
+ }
+
+ /** Returns the AUC evaluator name; names recognised when the learner is built are GLOBAL and GROUPED. */
+ public String getAuc() {
+ return auc;
+ }
+
+ /** Sets the AUC evaluator name; an unrecognised or null name results in a null evaluator being passed to the learner. */
+ public void setAuc(String auc) {
+ this.auc = auc;
+ }
+
+ /** Returns the numeric argument used for the TP and EBP priors (NaN until set). */
+ public double getPriorOption() {
+ return priorOption;
+ }
+
+ /** Sets the numeric argument used for the TP and EBP priors. */
+ public void setPriorOption(double priorOption) {
+ this.priorOption = priorOption;
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
new file mode 100644
index 0000000..e762924
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Encapsulates everything we need to know about a model and how it reads and vectorizes its input.
+ * This encapsulation allows us to coherently save and restore a model from a file. This also
+ * allows us to keep command line arguments that affect learning in a coherent way.
+ */
+public class LogisticModelParameters implements Writable {
+ private String targetVariable;
+ private Map<String, String> typeMap;
+ private int numFeatures;
+ private boolean useBias;
+ private int maxTargetCategories;
+ private List<String> targetCategories;
+ private double lambda;
+ private double learningRate;
+ private CsvRecordFactory csv;
+ private OnlineLogisticRegression lr;
+
+ /**
+  * Returns a CsvRecordFactory compatible with this logistic model, creating it on first use.
+  * The factory lives here so that the list of target categories is available when the model is
+  * saved. If the input isn't CSV, calling setTargetCategories before saveTo is sufficient.
+  * When target categories are already known they are defined on the new factory.
+  *
+  * @return The CsvRecordFactory.
+  */
+ public CsvRecordFactory getCsvRecordFactory() {
+   if (csv != null) {
+     return csv;
+   }
+   csv = new CsvRecordFactory(getTargetVariable(), getTypeMap())
+       .maxTargetValue(getMaxTargetCategories())
+       .includeBiasTerm(useBias());
+   if (targetCategories != null) {
+     csv.defineTargetCategories(targetCategories);
+   }
+   return csv;
+ }
+
+ /**
+  * Returns the logistic regression trainer for the parameters collected here, allocating it
+  * on first use with an L1 prior, the configured lambda and learning rate, and alpha set to
+  * {@code 1 - 1.0e-3}. Later calls return the same instance.
+  *
+  * @return The OnlineLogisticRegression object
+  */
+ public OnlineLogisticRegression createRegression() {
+   if (lr != null) {
+     return lr;
+   }
+   lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1())
+       .lambda(getLambda())
+       .learningRate(getLearningRate())
+       .alpha(1 - 1.0e-3);
+   return lr;
+ }
+
+ /**
+ * Saves a model to an output stream.
+ */
+ public void saveTo(OutputStream out) throws IOException {
+ Closeables.close(lr, false);
+ targetCategories = getCsvRecordFactory().getTargetCategories();
+ write(new DataOutputStream(out));
+ }
+
+ /**
+  * Reads a model from a stream.
+  *
+  * @param in the stream to read; it is wrapped but not closed here
+  * @return the populated parameters
+  * @throws IOException if reading fails
+  */
+ public static LogisticModelParameters loadFrom(InputStream in) throws IOException {
+   LogisticModelParameters params = new LogisticModelParameters();
+   DataInputStream data = new DataInputStream(in);
+   params.readFields(data);
+   return params;
+ }
+
+ /**
+ * Reads a model from a file.
+ * @throws IOException If there is an error opening or closing the file.
+ */
+ public static LogisticModelParameters loadFrom(File in) throws IOException {
+ try (InputStream input = new FileInputStream(in)) {
+ return loadFrom(input);
+ }
+ }
+
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeUTF(targetVariable);
+ out.writeInt(typeMap.size());
+ for (Map.Entry<String,String> entry : typeMap.entrySet()) {
+ out.writeUTF(entry.getKey());
+ out.writeUTF(entry.getValue());
+ }
+ out.writeInt(numFeatures);
+ out.writeBoolean(useBias);
+ out.writeInt(maxTargetCategories);
+
+ if (targetCategories == null) {
+ out.writeInt(0);
+ } else {
+ out.writeInt(targetCategories.size());
+ for (String category : targetCategories) {
+ out.writeUTF(category);
+ }
+ }
+ out.writeDouble(lambda);
+ out.writeDouble(learningRate);
+ // skip csv
+ lr.write(out);
+ }
+
+ /**
+  * Restores these parameters and the regression from serialized form, reading the fields in
+  * the same order in which {@link #write(DataOutput)} emits them. The cached CSV record
+  * factory is discarded so that it is rebuilt from the restored values.
+  *
+  * @param in the source to read from
+  * @throws IOException if reading fails
+  */
+ @Override
+ public void readFields(DataInput in) throws IOException {
+   targetVariable = in.readUTF();
+
+   int predictorCount = in.readInt();
+   typeMap = new HashMap<>(predictorCount);
+   for (int remaining = predictorCount; remaining > 0; remaining--) {
+     String name = in.readUTF();
+     String type = in.readUTF();
+     typeMap.put(name, type);
+   }
+
+   numFeatures = in.readInt();
+   useBias = in.readBoolean();
+   maxTargetCategories = in.readInt();
+
+   int categoryCount = in.readInt();
+   targetCategories = new ArrayList<>(categoryCount);
+   for (int remaining = categoryCount; remaining > 0; remaining--) {
+     targetCategories.add(in.readUTF());
+   }
+
+   lambda = in.readDouble();
+   learningRate = in.readDouble();
+   csv = null;
+   lr = new OnlineLogisticRegression();
+   lr.readFields(in);
+ }
+
+ /**
+ * Sets the types of the predictors. This will later be used when reading CSV data. If you don't
+ * use the CSV data and convert to vectors on your own, you don't need to call this.
+ *
+ * @param predictorList The list of variable names.
+ * @param typeList The list of types in the format preferred by CsvRecordFactory.
+ */
+ public void setTypeMap(Iterable<String> predictorList, List<String> typeList) {
+ Preconditions.checkArgument(!typeList.isEmpty(), "Must have at least one type specifier");
+ typeMap = new HashMap<>();
+ Iterator<String> iTypes = typeList.iterator();
+ String lastType = null;
+ for (Object x : predictorList) {
+ // type list can be short .. we just repeat last spec
+ if (iTypes.hasNext()) {
+ lastType = iTypes.next();
+ }
+ typeMap.put(x.toString(), lastType);
+ }
+ }
+
+ /**
+ * Sets the target variable. If you don't use the CSV record factory, then this is irrelevant.
+ *
+ * @param targetVariable The name of the target variable.
+ */
+ public void setTargetVariable(String targetVariable) {
+ this.targetVariable = targetVariable;
+ }
+
+ /**
+ * Sets the number of target categories to be considered.
+ *
+ * @param maxTargetCategories The number of target categories.
+ */
+ public void setMaxTargetCategories(int maxTargetCategories) {
+ this.maxTargetCategories = maxTargetCategories;
+ }
+
+ /** Sets the feature count that is passed to the regression when it is created. */
+ public void setNumFeatures(int numFeatures) {
+ this.numFeatures = numFeatures;
+ }
+
+ /**
+ * Sets the target categories and, as a side effect, resets the maximum number of target
+ * categories to the size of the given list.
+ *
+ * @param targetCategories The category labels; must not be null because its size is read.
+ */
+ public void setTargetCategories(List<String> targetCategories) {
+ this.targetCategories = targetCategories;
+ maxTargetCategories = targetCategories.size();
+ }
+
+ /** Returns the target categories, or null if none have been set or loaded. */
+ public List<String> getTargetCategories() {
+ return this.targetCategories;
+ }
+
+ /** Sets the flag that is passed to the CSV record factory's {@code includeBiasTerm}. */
+ public void setUseBias(boolean useBias) {
+ this.useBias = useBias;
+ }
+
+ /** Returns the flag that is passed to the CSV record factory's {@code includeBiasTerm}. */
+ public boolean useBias() {
+ return useBias;
+ }
+
+ /** Returns the name of the target variable. */
+ public String getTargetVariable() {
+ return targetVariable;
+ }
+
+ /** Returns the predictor-name to type mapping; the internal map itself is returned, not a copy. */
+ public Map<String, String> getTypeMap() {
+ return typeMap;
+ }
+
+ /** Replaces the predictor-name to type mapping; the given map is stored as-is, not copied. */
+ public void setTypeMap(Map<String, String> map) {
+ this.typeMap = map;
+ }
+
+ /** Returns the feature count that is passed to the regression when it is created. */
+ public int getNumFeatures() {
+ return numFeatures;
+ }
+
+ /** Returns the maximum number of target categories. */
+ public int getMaxTargetCategories() {
+ return maxTargetCategories;
+ }
+
+ /** Returns the lambda value applied to the regression when it is created. */
+ public double getLambda() {
+ return lambda;
+ }
+
+ /** Sets the lambda value applied to the regression when it is created. */
+ public void setLambda(double lambda) {
+ this.lambda = lambda;
+ }
+
+ /** Returns the learning rate applied to the regression when it is created. */
+ public double getLearningRate() {
+ return learningRate;
+ }
+
+ /** Sets the learning rate applied to the regression when it is created. */
+ public void setLearningRate(double learningRate) {
+ this.learningRate = learningRate;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
new file mode 100644
index 0000000..3ec6a06
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Preconditions;
+
+import java.io.BufferedReader;
+
+/**
+ * Uses the same logic as TrainLogistic and RunLogistic for finding an input, but instead
+ * of processing the input, this class just prints the input to standard out.
+ */
+public final class PrintResourceOrFile {
+
+ private PrintResourceOrFile() {
+ }
+
+ public static void main(String[] args) throws Exception {
+ Preconditions.checkArgument(args.length == 1, "Must have a single argument that names a file or resource.");
+ try (BufferedReader in = TrainLogistic.open(args[0])){
+ String line;
+ while ((line = in.readLine()) != null) {
+ System.out.println(line);
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
new file mode 100644
index 0000000..678a8f5
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.HashMap;
+import java.util.Map;
+
+public final class RunAdaptiveLogistic {
+
+ private static String inputFile;
+ private static String modelFile;
+ private static String outputFile;
+ private static String idColumn;
+ private static boolean maxScoreOnly;
+
+ private RunAdaptiveLogistic() {
+ }
+
+ public static void main(String[] args) throws Exception {
+ mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+
+ /**
+  * Scores every record of the input with the best learner of a saved adaptive model and writes
+  * the scores to the output file as {@code id,target,score} lines (after a header line).
+  * With {@code maxScoreOnly} set, only the highest-scoring label is written per record;
+  * otherwise one line is written per target label.
+  *
+  * <p>Both the input reader and the output writer are opened in a try-with-resources block, so
+  * they are closed even when scoring fails part-way through.
+  *
+  * @param args the command-line arguments; nothing is done if they cannot be parsed
+  * @param output receives progress and status messages
+  * @throws Exception if the model, input or output cannot be read or written
+  */
+ static void mainToOutput(String[] args, PrintWriter output) throws Exception {
+   if (!parseArgs(args)) {
+     return;
+   }
+   AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
+       .loadFromFile(new File(modelFile));
+
+   CsvRecordFactory csv = lmp.getCsvRecordFactory();
+   csv.setIdName(idColumn);
+
+   AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
+
+   State<Wrapper, CrossFoldLearner> best = lr.getBest();
+   if (best == null) {
+     output.println("AdaptiveLogisticRegression has not be trained probably.");
+     return;
+   }
+   CrossFoldLearner learner = best.getPayload().getLearner();
+
+   int k = 0;
+
+   // The reader is declared first so that the output file is only created once the input opened.
+   try (BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
+        BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFile),
+            Charsets.UTF_8))) {
+     out.write(idColumn + ",target,score");
+     out.newLine();
+
+     // The first line is the CSV header; data starts on the second line.
+     String line = in.readLine();
+     csv.firstLine(line);
+     line = in.readLine();
+     Map<String, Double> results = new HashMap<>();
+     while (line != null) {
+       Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
+       csv.processLine(line, v, false);
+       Vector scores = learner.classifyFull(v);
+       results.clear();
+       if (maxScoreOnly) {
+         results.put(csv.getTargetLabel(scores.maxValueIndex()),
+             scores.maxValue());
+       } else {
+         for (int i = 0; i < scores.size(); i++) {
+           results.put(csv.getTargetLabel(i), scores.get(i));
+         }
+       }
+
+       // The id is the same for every label of this record, so extract it only once.
+       String id = csv.getIdString(line);
+       for (Map.Entry<String, Double> entry : results.entrySet()) {
+         out.write(id + ',' + entry.getKey() + ',' + entry.getValue());
+         out.newLine();
+       }
+       k++;
+       if (k % 100 == 0) {
+         output.println(k + " records processed");
+       }
+       line = in.readLine();
+     }
+     // No explicit flush: closing the BufferedWriter flushes it.
+   }
+   output.println(k + " records processed totally.");
+ }
+
+ /**
+ * Parses the command line and stores the results in the static fields inputFile, modelFile,
+ * outputFile, idColumn and maxScoreOnly.
+ *
+ * @param args the raw command-line arguments
+ * @return true if parsing produced a command line; false if the parser returned none, in which
+ * case the static fields are left untouched
+ */
+ private static boolean parseArgs(String[] args) {
+ // One builder instance is reused to create every option below.
+ DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+ Option help = builder.withLongName("help")
+ .withDescription("print this list").create();
+
+ // NOTE(review): "quiet" is accepted by the parser but its value is never read in this method.
+ Option quiet = builder.withLongName("quiet")
+ .withDescription("be extra quiet").create();
+
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+ // NOTE(review): the description says "training data", but this tool reads the file to score it.
+ Option inputFileOption = builder
+ .withLongName("input")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("input").withMaximum(1)
+ .create())
+ .withDescription("where to get training data").create();
+
+ Option modelFileOption = builder
+ .withLongName("model")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("model").withMaximum(1)
+ .create())
+ .withDescription("where to get the trained model").create();
+
+ Option outputFileOption = builder
+ .withLongName("output")
+ .withRequired(true)
+ .withDescription("the file path to output scores")
+ .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
+ .create();
+
+ Option idColumnOption = builder
+ .withLongName("idcolumn")
+ .withRequired(true)
+ .withDescription("the name of the id column for each record")
+ .withArgument(argumentBuilder.withName("idcolumn").withMaximum(1).create())
+ .create();
+
+ // Flag without an argument: its presence alone switches on max-score-only output.
+ Option maxScoreOnlyOption = builder
+ .withLongName("maxscoreonly")
+ .withDescription("only output the target label with max scores")
+ .create();
+
+ Group normalArgs = new GroupBuilder()
+ .withOption(help).withOption(quiet)
+ .withOption(inputFileOption).withOption(modelFileOption)
+ .withOption(outputFileOption).withOption(idColumnOption)
+ .withOption(maxScoreOnlyOption)
+ .create();
+
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+ CommandLine cmdLine = parser.parseAndHelp(args);
+
+ // Presumably null means help was shown or parsing failed; confirm against commons-cli2.
+ if (cmdLine == null) {
+ return false;
+ }
+
+ inputFile = getStringArgument(cmdLine, inputFileOption);
+ modelFile = getStringArgument(cmdLine, modelFileOption);
+ outputFile = getStringArgument(cmdLine, outputFileOption);
+ idColumn = getStringArgument(cmdLine, idColumnOption);
+ maxScoreOnly = getBooleanArgument(cmdLine, maxScoreOnlyOption);
+ return true;
+ }
+
+ /** Returns true when the given flag option is present on the parsed command line. */
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+ return cmdLine.hasOption(option);
+ }
+
+ /**
+ * Returns the value of the given option cast to String. Despite the parameter name, this is
+ * used for every string-valued option, not only the input file.
+ */
+ private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
+ return (String) cmdLine.getValue(inputFile);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
new file mode 100644
index 0000000..2d57016
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.Locale;
+
+public final class RunLogistic {
+
+ private static String inputFile;
+ private static String modelFile;
+ private static boolean showAuc;
+ private static boolean showScores;
+ private static boolean showConfusion;
+
+ private RunLogistic() {
+ }
+
+ public static void main(String[] args) throws Exception {
+ mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+
+ /**
+  * Runs a saved logistic model over the input and reports the requested statistics.
+  * Each data line is scored; depending on the flags, the per-record scores, the AUC, and the
+  * confusion and entropy matrices are printed. If none of the three output flags was given,
+  * AUC and confusion output are switched on.
+  *
+  * <p>The input reader is opened in a try-with-resources block, so it is closed even when
+  * scoring fails part-way through.
+  *
+  * @param args the command-line arguments; nothing is done if they cannot be parsed
+  * @param output receives all results
+  * @throws Exception if the model or the input cannot be read
+  */
+ static void mainToOutput(String[] args, PrintWriter output) throws Exception {
+   if (!parseArgs(args)) {
+     return;
+   }
+   if (!showAuc && !showConfusion && !showScores) {
+     showAuc = true;
+     showConfusion = true;
+   }
+
+   Auc collector = new Auc();
+   LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile));
+
+   CsvRecordFactory csv = lmp.getCsvRecordFactory();
+   OnlineLogisticRegression lr = lmp.createRegression();
+   try (BufferedReader in = TrainLogistic.open(inputFile)) {
+     // The first line is the CSV header; data starts on the second line.
+     String line = in.readLine();
+     csv.firstLine(line);
+     line = in.readLine();
+     if (showScores) {
+       output.println("\"target\",\"model-output\",\"log-likelihood\"");
+     }
+     while (line != null) {
+       Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
+       int target = csv.processLine(line, v);
+
+       double score = lr.classifyScalar(v);
+       if (showScores) {
+         output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
+       }
+       collector.add(target, score);
+       line = in.readLine();
+     }
+   }
+
+   if (showAuc) {
+     output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc());
+   }
+   if (showConfusion) {
+     Matrix m = collector.confusion();
+     output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n",
+         m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
+     m = collector.entropy();
+     output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n",
+         m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
+   }
+ }
+
+ /**
+ * Parses the command line and stores the results in the static fields inputFile, modelFile,
+ * showAuc, showScores and showConfusion.
+ *
+ * @param args the raw command-line arguments
+ * @return true if parsing produced a command line; false if the parser returned none, in which
+ * case the static fields are left untouched
+ */
+ private static boolean parseArgs(String[] args) {
+ // One builder instance is reused to create every option below.
+ DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+ Option help = builder.withLongName("help").withDescription("print this list").create();
+
+ // NOTE(review): "quiet" is accepted by the parser but its value is never read in this method.
+ Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
+
+ // Output-selection flags; none of them takes an argument.
+ Option auc = builder.withLongName("auc").withDescription("print AUC").create();
+ Option confusion = builder.withLongName("confusion").withDescription("print confusion matrix").create();
+
+ Option scores = builder.withLongName("scores").withDescription("print scores").create();
+
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+ // NOTE(review): the description says "training data", but this tool reads the file to evaluate it.
+ Option inputFileOption = builder.withLongName("input")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+ .withDescription("where to get training data")
+ .create();
+
+ Option modelFileOption = builder.withLongName("model")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
+ .withDescription("where to get a model")
+ .create();
+
+ Group normalArgs = new GroupBuilder()
+ .withOption(help)
+ .withOption(quiet)
+ .withOption(auc)
+ .withOption(scores)
+ .withOption(confusion)
+ .withOption(inputFileOption)
+ .withOption(modelFileOption)
+ .create();
+
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+ CommandLine cmdLine = parser.parseAndHelp(args);
+
+ // Presumably null means help was shown or parsing failed; confirm against commons-cli2.
+ if (cmdLine == null) {
+ return false;
+ }
+
+ inputFile = getStringArgument(cmdLine, inputFileOption);
+ modelFile = getStringArgument(cmdLine, modelFileOption);
+ showAuc = getBooleanArgument(cmdLine, auc);
+ showScores = getBooleanArgument(cmdLine, scores);
+ showConfusion = getBooleanArgument(cmdLine, confusion);
+
+ return true;
+ }
+
+ /** Returns true when the given flag option is present on the parsed command line. */
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+ return cmdLine.hasOption(option);
+ }
+
+ /**
+ * Returns the value of the given option cast to String. Despite the parameter name, this is
+ * used for every string-valued option, not only the input file.
+ */
+ private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
+ return (String) cmdLine.getValue(inputFile);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
new file mode 100644
index 0000000..c657803
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
@@ -0,0 +1,151 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Multiset;
+import org.apache.mahout.classifier.NewsgroupHelper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.TreeMap;
+
+public final class SGDHelper {
+
+ private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};
+
+ private SGDHelper() {
+ }
+
+ /**
+  * Prints a "model dissection" for the best learner found so far.
+  * <p>
+  * The best {@link CrossFoldLearner} is closed, then a random sample of the given files (at most
+  * 500) is re-encoded with tracing enabled so that {@link ModelDissector} can accumulate
+  * per-feature information. The 100 entries returned by {@code md.summary(100)} are written to
+  * standard output, one tab-separated row each.
+  *
+  * @param leakType          passed through to {@code NewsgroupHelper.encodeFeatureVector}
+  * @param dictionary        maps newsgroup names to category indexes
+  * @param learningAlgorithm the trained learner whose best model is dissected
+  * @param files             candidate files; each file's parent directory name is its newsgroup
+  * @param overallCounts     passed through to {@code NewsgroupHelper.encodeFeatureVector}
+  * @throws IOException if encoding a file fails
+  */
+ public static void dissect(int leakType,
+                            Dictionary dictionary,
+                            AdaptiveLogisticRegression learningAlgorithm,
+                            Iterable<File> files, Multiset<String> overallCounts) throws IOException {
+   CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
+   model.close();
+
+   Map<String, Set<Integer>> traceDictionary = new TreeMap<>();
+   ModelDissector md = new ModelDissector();
+
+   NewsgroupHelper helper = new NewsgroupHelper();
+   helper.getEncoder().setTraceDictionary(traceDictionary);
+   helper.getBias().setTraceDictionary(traceDictionary);
+
+   // Sample up to 500 files; bounding by the list size avoids an IndexOutOfBoundsException
+   // from subList when fewer than 500 files are available.
+   List<File> shuffled = permute(files, helper.getRandom());
+   int sampleSize = Math.min(500, shuffled.size());
+   for (File file : shuffled.subList(0, sampleSize)) {
+     String ng = file.getParentFile().getName();
+     int actual = dictionary.intern(ng);
+
+     // the trace dictionary is per-document, so reset it before each encoding
+     traceDictionary.clear();
+     Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
+     md.update(v, traceDictionary, model);
+   }
+
+   List<String> ngNames = new ArrayList<>(dictionary.values());
+   List<ModelDissector.Weight> weights = md.summary(100);
+   System.out.println("============");
+   System.out.println("Model Dissection");
+   for (ModelDissector.Weight w : weights) {
+     // NOTE(review): the "+ 1" offset on getMaxImpact() and the pairing of %.1f with
+     // getCategory(n) / %s with getWeight(n) look suspicious - confirm against ModelDissector.Weight.
+     System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s%n",
+         w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
+         w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
+   }
+ }
+
+ /**
+  * Returns the given files in random order.
+  * <p>
+  * Each incoming file is dropped into a random slot of the list built so far; whatever occupied
+  * that slot is moved to the end.
+  *
+  * @param files the files to shuffle; consumed once, in iteration order
+  * @param rand  source of randomness
+  * @return a new list holding every input file exactly once
+  */
+ public static List<File> permute(Iterable<File> files, Random rand) {
+   List<File> shuffled = new ArrayList<>();
+   for (File next : files) {
+     int slot = rand.nextInt(shuffled.size() + 1);
+     if (slot < shuffled.size()) {
+       // relocate the current occupant to the tail, then take over its slot
+       shuffled.add(shuffled.get(slot));
+       shuffled.set(slot, next);
+     } else {
+       shuffled.add(next);
+     }
+   }
+   return shuffled;
+ }
+
+ /**
+ * Updates the running statistics in {@code info} from the current best learner and, on a
+ * schedule derived from {@code info}, saves a model snapshot and prints one progress row.
+ * <p>
+ * A row is produced when {@code k} is a multiple of {@code bump * scale}; each time that happens
+ * the step stored in {@code info} is advanced by 0.25.
+ *
+ * @param info running averages, current step and bump table; mutated by this call
+ * @param leakType selects the label printed in the last column (index taken modulo 3)
+ * @param k number of training examples seen so far
+ * @param best best state found so far, or {@code null} if none is available yet
+ * @throws IOException if writing the model snapshot fails
+ */
+ static void analyzeState(SGDInfo info, int leakType, int k, State<AdaptiveLogisticRegression.Wrapper,
+ CrossFoldLearner> best) throws IOException {
+ // bump cycles through the bump table using the whole part of the step ...
+ int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length];
+ // ... and scale is a power of ten that grows each time the step passes another table length
+ int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length));
+ double maxBeta;
+ double nonZeros;
+ double positive;
+ double norm;
+
+ double lambda = 0;
+ double mu = 0;
+
+ if (best != null) {
+ CrossFoldLearner state = best.getPayload().getLearner();
+ info.setAverageCorrect(state.percentCorrect());
+ info.setAverageLL(state.logLikelihood());
+
+ // only the first of the learner's models is inspected
+ OnlineLogisticRegression model = state.getModels().get(0);
+ // finish off pending regularization
+ model.close();
+
+ Matrix beta = model.getBeta();
+ // largest absolute coefficient
+ maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
+ // number of coefficients whose magnitude exceeds 1.0e-6
+ nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
+ @Override
+ public double apply(double v) {
+ return Math.abs(v) > 1.0e-6 ? 1 : 0;
+ }
+ });
+ // number of strictly positive coefficients
+ positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
+ @Override
+ public double apply(double v) {
+ return v > 0 ? 1 : 0;
+ }
+ });
+ // sum of absolute coefficients
+ norm = beta.aggregate(Functions.PLUS, Functions.ABS);
+
+ // the first two mapped parameters are reported as lambda and mu
+ lambda = best.getMappedParams()[0];
+ mu = best.getMappedParams()[1];
+ } else {
+ maxBeta = 0;
+ nonZeros = 0;
+ positive = 0;
+ norm = 0;
+ }
+ if (k % (bump * scale) == 0) {
+ if (best != null) {
+ // snapshot the current best model into the JVM temp directory
+ File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group-" + k + ".model");
+ ModelSerializer.writeBinary(modelFile.getAbsolutePath(), best.getPayload().getLearner().getModels().get(0));
+ }
+
+ info.setStep(info.getStep() + 0.25);
+ // row layout: maxBeta, nonZeros, positive, norm, lambda, mu, k, averageLL, percent correct, leak label
+ System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
+ System.out.printf("%d\t%.3f\t%.2f\t%s%n",
+ k, info.getAverageLL(), info.getAverageCorrect() * 100, LEAK_LABELS[leakType % 3]);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
new file mode 100644
index 0000000..be55d43
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+/**
+ * Mutable holder for the progress-reporting state shared between a training loop and
+ * {@code SGDHelper.analyzeState}: running averages, the current reporting step and the bump table.
+ */
+final class SGDInfo {
+
+  private double averageLL;
+  private double averageCorrect;
+  private double step;
+  private int[] bumps = {1, 2, 5};
+
+  /** @return the most recently stored average log-likelihood */
+  double getAverageLL() {
+    return this.averageLL;
+  }
+
+  /** @param averageLL the average log-likelihood to store */
+  void setAverageLL(double averageLL) {
+    this.averageLL = averageLL;
+  }
+
+  /** @return the most recently stored average fraction correct */
+  double getAverageCorrect() {
+    return this.averageCorrect;
+  }
+
+  /** @param averageCorrect the average fraction correct to store */
+  void setAverageCorrect(double averageCorrect) {
+    this.averageCorrect = averageCorrect;
+  }
+
+  /** @return the current reporting step */
+  double getStep() {
+    return this.step;
+  }
+
+  /** @param step the new reporting step */
+  void setStep(double step) {
+    this.step = step;
+  }
+
+  /** @return the bump table itself (not a copy); initially {@code {1, 2, 5}} */
+  int[] getBumps() {
+    return this.bumps;
+  }
+
+  /** @param bumps the array to use as bump table; stored without copying */
+  void setBumps(int[] bumps) {
+    this.bumps = bumps;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java
new file mode 100644
index 0000000..b3da452
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Joiner;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.list.IntArrayList;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.Closeable;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Shows how different encoding choices can make big speed differences.
+ * <p/>
+ * Run with command line options --generate 1000000 test.csv to generate a million data lines in
+ * test.csv.
+ * <p/>
+ * Run with command line options --parse test.csv to time how long it takes to parse and encode
+ * those million data points.
+ * <p/>
+ * Run with command line options --fast test.csv to time how long it takes to parse and encode those
+ * million data points using byte-level parsing and direct value encoding.
+ * <p/>
+ * This doesn't demonstrate text encoding which is subject to somewhat different tricks. The basic
+ * idea of caching hash locations and byte level parsing still very much applies to text, however.
+ */
+public final class SimpleCsvExamples {
+
+ public static final char SEPARATOR_CHAR = '\t';
+ private static final int FIELDS = 100;
+
+ private static final Logger log = LoggerFactory.getLogger(SimpleCsvExamples.class);
+
+ private SimpleCsvExamples() {}
+
+ /**
+  * Runs one of three modes selected by {@code args[0]}:
+  * <ul>
+  *   <li>{@code --generate n file} writes {@code n} random lines to {@code file};</li>
+  *   <li>{@code --parse file} reads {@code file} with string splitting and string-based encoding;</li>
+  *   <li>{@code --fast file} reads {@code file} with byte-level parsing and direct value encoding.</li>
+  * </ul>
+  * The two reading modes print the per-column means; every mode prints the elapsed time.
+  *
+  * @param args mode flag followed by its arguments, as listed above
+  * @throws IOException if the file cannot be read or written
+  */
+ public static void main(String[] args) throws IOException {
+   FeatureVectorEncoder[] encoder = new FeatureVectorEncoder[FIELDS];
+   for (int i = 0; i < FIELDS; i++) {
+     // give each column its own encoder name; "v" + 1 named every encoder "v1"
+     encoder[i] = new ConstantValueEncoder("v" + i);
+   }
+
+   OnlineSummarizer[] s = new OnlineSummarizer[FIELDS];
+   for (int i = 0; i < FIELDS; i++) {
+     s[i] = new OnlineSummarizer();
+   }
+   long t0 = System.currentTimeMillis();
+   Vector v = new DenseVector(1000);
+   if ("--generate".equals(args[0])) {
+     try (PrintWriter out =
+              new PrintWriter(new OutputStreamWriter(new FileOutputStream(new File(args[2])), Charsets.UTF_8))) {
+       int n = Integer.parseInt(args[1]);
+       for (int i = 0; i < n; i++) {
+         Line x = Line.generate();
+         out.println(x);
+       }
+     }
+   } else if ("--parse".equals(args[0])) {
+     try (BufferedReader in = Files.newReader(new File(args[1]), Charsets.UTF_8)) {
+       String line = in.readLine();
+       while (line != null) {
+         v.assign(0);
+         Line x = new Line(line);
+         for (int i = 0; i < FIELDS; i++) {
+           s[i].add(x.getDouble(i));
+           encoder[i].addToVector(x.get(i), v);
+         }
+         line = in.readLine();
+       }
+     }
+     printMeans(s);
+   } else if ("--fast".equals(args[0])) {
+     try (FastLineReader in = new FastLineReader(new FileInputStream(args[1]))) {
+       FastLine line = in.read();
+       while (line != null) {
+         v.assign(0);
+         for (int i = 0; i < FIELDS; i++) {
+           double z = line.getDouble(i);
+           s[i].add(z);
+           // no original form is supplied; the numeric value is encoded directly
+           encoder[i].addToVector((byte[]) null, z, v);
+         }
+         line = in.read();
+       }
+     }
+     printMeans(s);
+   }
+   System.out.printf("\nElapsed time = %.3f%n", (System.currentTimeMillis() - t0) / 1000.0);
+ }
+
+ /** Prints the mean of every summarizer on one line, comma separated, with three decimals. */
+ private static void printMeans(OnlineSummarizer[] s) {
+   String separator = "";
+   for (int i = 0; i < FIELDS; i++) {
+     System.out.printf("%s%.3f", separator, s[i].getMean());
+     separator = ",";
+   }
+ }
+
+
+ /**
+  * One record held as a list of string fields, either split from a text line or randomly
+  * generated.
+  */
+ private static final class Line {
+   private static final Splitter ON_TABS = Splitter.on(SEPARATOR_CHAR).trimResults();
+   // Despite the name, fields are joined with SEPARATOR_CHAR, not with commas.
+   public static final Joiner WITH_COMMAS = Joiner.on(SEPARATOR_CHAR);
+
+   public static final Random RAND = RandomUtils.getRandom();
+
+   private final List<String> data;
+
+   /** Splits {@code line} on the separator character, trimming each field. */
+   private Line(CharSequence line) {
+     Iterable<String> fields = ON_TABS.split(line);
+     data = Lists.newArrayList(fields);
+   }
+
+   /** Creates a record with no fields yet. */
+   private Line() {
+     data = new ArrayList<>();
+   }
+
+   /**
+    * @param field zero-based field index
+    * @return the field parsed with {@link Double#parseDouble(String)}
+    */
+   public double getDouble(int field) {
+     String raw = data.get(field);
+     return Double.parseDouble(raw);
+   }
+
+   /**
+    * Generates a random line with {@code FIELDS} fields, each holding an integer value.
+    *
+    * @return A new line with data.
+    */
+   public static Line generate() {
+     Line line = new Line();
+     int field = 0;
+     while (field < FIELDS) {
+       // the mean differs from column to column and always lies in 1..50
+       double mean = ((field + 1) * 257) % 50 + 1;
+       line.data.add(Integer.toString(randomValue(mean)));
+       field++;
+     }
+     return line;
+   }
+
+   /**
+    * Draws an exponentially distributed value and truncates it to an integer, which yields more
+    * small numbers than big ones.
+    *
+    * @param mean mean of the exponential distribution before truncation
+    * @return the truncated random draw
+    */
+   private static int randomValue(double mean) {
+     double u = RAND.nextDouble();
+     double sample = -mean * Math.log1p(-u);
+     return (int) sample;
+   }
+
+   /** @return the fields joined with the separator character */
+   @Override
+   public String toString() {
+     return WITH_COMMAS.join(data);
+   }
+
+   /**
+    * @param field zero-based field index
+    * @return the field's text as stored
+    */
+   public String get(int field) {
+     return data.get(field);
+   }
+ }
+
+ /**
+  * A record located inside a shared byte buffer: the bytes are not copied, only the start offset
+  * and length of every field are remembered.
+  */
+ private static final class FastLine {
+
+   private final ByteBuffer base;
+   private final IntArrayList start = new IntArrayList();
+   private final IntArrayList length = new IntArrayList();
+
+   private FastLine(ByteBuffer base) {
+     this.base = base;
+   }
+
+   /**
+    * Consumes bytes from {@code buf} up to and including the next newline and records where each
+    * separator-delimited field begins and how long it is.
+    *
+    * @param buf buffer positioned at the first byte of a line
+    * @return the parsed line
+    * @throws IllegalArgumentException if the buffer runs out before a newline is seen
+    */
+   public static FastLine read(ByteBuffer buf) {
+     FastLine line = new FastLine(buf);
+     line.start.add(buf.position());
+     while (buf.hasRemaining()) {
+       int ch = buf.get();
+       int next = buf.position();
+       boolean endOfLine = ch == '\n';
+       if (endOfLine || ch == SEPARATOR_CHAR) {
+         // the delimiter just consumed sits at next - 1, hence the extra -1
+         int fieldStart = line.start.get(line.length.size());
+         line.length.add(next - fieldStart - 1);
+         if (endOfLine) {
+           return line;
+         }
+         line.start.add(next);
+       }
+     }
+     throw new IllegalArgumentException("Not enough bytes in buffer");
+   }
+
+   /**
+    * Converts a field to a number by treating every byte as a decimal digit; signs, decimal
+    * points and exponents are not interpreted.
+    *
+    * @param field zero-based field index
+    * @return the numeric value of the field's digits
+    */
+   public double getDouble(int field) {
+     int offset = start.get(field);
+     int size = length.get(field);
+     // dedicated paths for one- and two-digit fields
+     if (size == 1) {
+       return base.get(offset) - '0';
+     }
+     if (size == 2) {
+       return (base.get(offset) - '0') * 10 + base.get(offset + 1) - '0';
+     }
+     double r = 0;
+     for (int pos = offset; pos < offset + size; pos++) {
+       r = 10 * r + base.get(pos) - '0';
+     }
+     return r;
+   }
+ }
+
+ /**
+ * Reads {@link FastLine}s from an input stream through a fixed 100000-byte buffer that is
+ * topped up whenever fewer than 10000 unread bytes remain.
+ */
+ private static final class FastLineReader implements Closeable {
+ private final InputStream in;
+ private final ByteBuffer buf = ByteBuffer.allocate(100000);
+
+ /**
+ * @param in stream to read from; closed by {@link #close()}
+ * @throws IOException if the initial read fails
+ */
+ private FastLineReader(InputStream in) throws IOException {
+ this.in = in;
+ // start with nothing readable so the first fillBuffer() call loads data
+ buf.limit(0);
+ fillBuffer();
+ }
+
+ /**
+ * @return the next line, or {@code null} when no unread bytes remain after refilling
+ * @throws IOException if reading from the stream fails
+ */
+ public FastLine read() throws IOException {
+ fillBuffer();
+ if (buf.remaining() > 0) {
+ return FastLine.read(buf);
+ } else {
+ return null;
+ }
+ }
+
+ /**
+ * Tops up the buffer when fewer than 10000 unread bytes are left: unread bytes are moved to
+ * the front, new bytes are read in behind them, and the buffer is left ready for reading
+ * from position 0.
+ */
+ private void fillBuffer() throws IOException {
+ if (buf.remaining() < 10000) {
+ // move unread bytes to the start; position now marks the end of the retained data
+ buf.compact();
+ int n = in.read(buf.array(), buf.position(), buf.remaining());
+ if (n == -1) {
+ // end of stream: expose only the retained bytes
+ buf.flip();
+ } else {
+ // expose retained bytes plus the n bytes just read
+ buf.limit(buf.position() + n);
+ buf.position(0);
+ }
+ }
+ }
+
+ /**
+ * Closes the underlying stream via {@code Closeables.close(in, true)}; any
+ * {@link IOException} that still reaches this method is logged rather than rethrown.
+ */
+ @Override
+ public void close() {
+ try {
+ Closeables.close(in, true);
+ } catch (IOException e) {
+ log.error(e.getMessage(), e);
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
new file mode 100644
index 0000000..074f774
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
@@ -0,0 +1,152 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+
+/**
+ * Runs the ASF email test data through the model trained by {@link TrainASFEmail}.
+ */
+public final class TestASFEmail {
+
+  private String inputFile;
+  private String modelFile;
+
+  private TestASFEmail() {}
+
+  /**
+   * Parses the command line and, if it is valid, prints the evaluation report to standard output.
+   *
+   * @param args command-line arguments; see {@link #parseArgs(String[])}
+   * @throws IOException if the model or the test data cannot be read
+   */
+  public static void main(String[] args) throws IOException {
+    TestASFEmail runner = new TestASFEmail();
+    if (runner.parseArgs(args)) {
+      runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+    }
+  }
+
+  /**
+   * Loads the model, classifies every record found in sequence files whose name contains
+   * "test" under the input path, and prints the resulting {@link ResultAnalyzer} to
+   * {@code output}. The input is scanned twice: once to build the label dictionary and once to
+   * classify.
+   *
+   * @param output where the final report is written
+   * @throws IOException if the model or the test data cannot be read
+   */
+  public void run(PrintWriter output) throws IOException {
+
+    File base = new File(inputFile);
+    //contains the best model
+    OnlineLogisticRegression classifier;
+    // the stream is opened here, so it is closed here instead of relying on the callee to do it
+    try (FileInputStream modelStream = new FileInputStream(modelFile)) {
+      classifier = ModelSerializer.readBinary(modelStream, OnlineLogisticRegression.class);
+    }
+
+    Dictionary asfDictionary = new Dictionary();
+    Configuration conf = new Configuration();
+    PathFilter testFilter = new PathFilter() {
+      @Override
+      public boolean accept(Path path) {
+        return path.getName().contains("test");
+      }
+    };
+    SequenceFileDirIterator<Text, VectorWritable> iter =
+        new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter,
+            null, true, conf);
+
+    // first pass: register every label so the analyzer knows all categories up front
+    long numItems = 0;
+    while (iter.hasNext()) {
+      Pair<Text, VectorWritable> next = iter.next();
+      asfDictionary.intern(next.getFirst().toString());
+      numItems++;
+    }
+
+    System.out.println(numItems + " test files");
+    ResultAnalyzer ra = new ResultAnalyzer(asfDictionary.values(), "DEFAULT");
+    // second pass: classify each record and compare against its label
+    iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter,
+        null, true, conf);
+    while (iter.hasNext()) {
+      Pair<Text, VectorWritable> next = iter.next();
+      String ng = next.getFirst().toString();
+
+      int actual = asfDictionary.intern(ng);
+      Vector result = classifier.classifyFull(next.getSecond().get());
+      int cat = result.maxValueIndex();
+      double score = result.maxValue();
+      double ll = classifier.logLikelihood(actual, next.getSecond().get());
+      ClassifierResult cr = new ClassifierResult(asfDictionary.values().get(cat), score, ll);
+      ra.addInstance(asfDictionary.values().get(actual), cr);
+
+    }
+    output.println(ra);
+  }
+
+  /**
+   * Parses {@code --input} and {@code --model} (both required) plus {@code --help}.
+   *
+   * @param args raw command-line arguments
+   * @return {@code true} if parsing produced a command line; {@code false} if
+   *         {@code parseAndHelp} returned {@code null}
+   */
+  boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help").withDescription("print this list").create();
+
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    // this tool evaluates a model, so the input is test data, not training data
+    Option inputFileOption = builder.withLongName("input")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+        .withDescription("where to get test data")
+        .create();
+
+    Option modelFileOption = builder.withLongName("model")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
+        .withDescription("where to get a model")
+        .create();
+
+    Group normalArgs = new GroupBuilder()
+        .withOption(help)
+        .withOption(inputFileOption)
+        .withOption(modelFileOption)
+        .create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    inputFile = (String) cmdLine.getValue(inputFileOption);
+    modelFile = (String) cmdLine.getValue(modelFileOption);
+    return true;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
new file mode 100644
index 0000000..f0316e9
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
@@ -0,0 +1,141 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.NewsgroupHelper;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Run the 20 news groups test data through SGD, as trained by {@link org.apache.mahout.classifier.sgd.TrainNewsGroups}.
+ */
+public final class TestNewsGroups {
+
+  private String inputFile;
+  private String modelFile;
+
+  private TestNewsGroups() {
+  }
+
+  /**
+   * Parses the command line and, if it is valid, prints the evaluation report to standard output.
+   *
+   * @param args command-line arguments; see {@link #parseArgs(String[])}
+   * @throws IOException if the model or the test data cannot be read
+   */
+  public static void main(String[] args) throws IOException {
+    TestNewsGroups runner = new TestNewsGroups();
+    if (runner.parseArgs(args)) {
+      runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+    }
+  }
+
+  /**
+   * Loads the model, classifies every file found in the sub-directories of the input directory
+   * (the sub-directory name is the expected newsgroup) and prints the resulting
+   * {@link ResultAnalyzer} to {@code output}.
+   *
+   * @param output where the final report is written
+   * @throws IOException if the model cannot be read, a directory cannot be listed, or a file
+   *                     cannot be encoded
+   */
+  public void run(PrintWriter output) throws IOException {
+
+    File base = new File(inputFile);
+    //contains the best model
+    OnlineLogisticRegression classifier;
+    // the stream is opened here, so it is closed here instead of relying on the callee to do it
+    try (FileInputStream modelStream = new FileInputStream(modelFile)) {
+      classifier = ModelSerializer.readBinary(modelStream, OnlineLogisticRegression.class);
+    }
+
+    Dictionary newsGroups = new Dictionary();
+    Multiset<String> overallCounts = HashMultiset.create();
+
+    List<File> files = new ArrayList<>();
+    for (File newsgroup : listDirectory(base)) {
+      if (newsgroup.isDirectory()) {
+        newsGroups.intern(newsgroup.getName());
+        files.addAll(Arrays.asList(listDirectory(newsgroup)));
+      }
+    }
+    System.out.println(files.size() + " test files");
+    ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
+    for (File file : files) {
+      String ng = file.getParentFile().getName();
+
+      int actual = newsGroups.intern(ng);
+      NewsgroupHelper helper = new NewsgroupHelper();
+      //no leak type ensures this is a normal vector
+      Vector input = helper.encodeFeatureVector(file, actual, 0, overallCounts);
+      Vector result = classifier.classifyFull(input);
+      int cat = result.maxValueIndex();
+      double score = result.maxValue();
+      double ll = classifier.logLikelihood(actual, input);
+      ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll);
+      ra.addInstance(newsGroups.values().get(actual), cr);
+
+    }
+    output.println(ra);
+  }
+
+  /**
+   * Lists a directory, turning the {@code null} that {@link File#listFiles()} returns for a
+   * non-directory or on an I/O error into an {@link IOException} instead of a later
+   * {@link NullPointerException}.
+   */
+  private static File[] listDirectory(File dir) throws IOException {
+    File[] entries = dir.listFiles();
+    if (entries == null) {
+      throw new IOException("Cannot list directory " + dir);
+    }
+    return entries;
+  }
+
+  /**
+   * Parses {@code --input} and {@code --model} (both required) plus {@code --help}.
+   *
+   * @param args raw command-line arguments
+   * @return {@code true} if parsing produced a command line; {@code false} if
+   *         {@code parseAndHelp} returned {@code null}
+   */
+  boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help").withDescription("print this list").create();
+
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    // this tool evaluates a model, so the input is test data, not training data
+    Option inputFileOption = builder.withLongName("input")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+        .withDescription("where to get test data")
+        .create();
+
+    Option modelFileOption = builder.withLongName("model")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
+        .withDescription("where to get a model")
+        .create();
+
+    Group normalArgs = new GroupBuilder()
+        .withOption(help)
+        .withOption(inputFileOption)
+        .withOption(modelFileOption)
+        .create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    inputFile = (String) cmdLine.getValue(inputFileOption);
+    modelFile = (String) cmdLine.getValue(modelFileOption);
+    return true;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
new file mode 100644
index 0000000..e681f92
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+import com.google.common.collect.Ordering;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+public final class TrainASFEmail extends AbstractJob {
+
+ private TrainASFEmail() {
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption("categories", "nc", "The number of categories to train on", true);
+ addOption("cardinality", "c", "The size of the vectors to use", "100000");
+ addOption("threads", "t", "The number of threads to use in the learner", "20");
+ addOption("poolSize", "p", "The number of CrossFoldLearners to use in the AdaptiveLogisticRegression. "
+ + "Higher values require more memory.", "5");
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ File base = new File(getInputPath().toString());
+
+ Multiset<String> overallCounts = HashMultiset.create();
+ File output = new File(getOutputPath().toString());
+ output.mkdirs();
+ int numCats = Integer.parseInt(getOption("categories"));
+ int cardinality = Integer.parseInt(getOption("cardinality", "100000"));
+ int threadCount = Integer.parseInt(getOption("threads", "20"));
+ int poolSize = Integer.parseInt(getOption("poolSize", "5"));
+ Dictionary asfDictionary = new Dictionary();
+ AdaptiveLogisticRegression learningAlgorithm =
+ new AdaptiveLogisticRegression(numCats, cardinality, new L1(), threadCount, poolSize);
+ learningAlgorithm.setInterval(800);
+ learningAlgorithm.setAveragingWindow(500);
+
+ //We ran seq2encoded and split input already, so let's just build up the dictionary
+ Configuration conf = new Configuration();
+ PathFilter trainFilter = new PathFilter() {
+ @Override
+ public boolean accept(Path path) {
+ return path.getName().contains("training");
+ }
+ };
+ SequenceFileDirIterator<Text, VectorWritable> iter =
+ new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter, null, true, conf);
+ long numItems = 0;
+ while (iter.hasNext()) {
+ Pair<Text, VectorWritable> next = iter.next();
+ asfDictionary.intern(next.getFirst().toString());
+ numItems++;
+ }
+
+ System.out.println(numItems + " training files");
+
+ SGDInfo info = new SGDInfo();
+
+ iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter,
+ null, true, conf);
+ int k = 0;
+ while (iter.hasNext()) {
+ Pair<Text, VectorWritable> next = iter.next();
+ String ng = next.getFirst().toString();
+ int actual = asfDictionary.intern(ng);
+ //we already have encoded
+ learningAlgorithm.train(actual, next.getSecond().get());
+ k++;
+ State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
+
+ SGDHelper.analyzeState(info, 0, k, best);
+ }
+ learningAlgorithm.close();
+ //TODO: how to dissection since we aren't processing the files here
+ //SGDHelper.dissect(leakType, asfDictionary, learningAlgorithm, files, overallCounts);
+ System.out.println("exiting main, writing model to " + output);
+
+ ModelSerializer.writeBinary(output + "/asf.model",
+ learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
+
+ List<Integer> counts = new ArrayList<>();
+ System.out.println("Word counts");
+ for (String count : overallCounts.elementSet()) {
+ counts.add(overallCounts.count(count));
+ }
+ Collections.sort(counts, Ordering.natural().reverse());
+ k = 0;
+ for (Integer count : counts) {
+ System.out.println(k + "\t" + count);
+ k++;
+ if (k > 1000) {
+ break;
+ }
+ }
+ return 0;
+ }
+
+ public static void main(String[] args) throws Exception {
+ TrainASFEmail trainer = new TrainASFEmail();
+ trainer.run(args);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
new file mode 100644
index 0000000..defb5b9
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
@@ -0,0 +1,377 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.io.Resources;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+public final class TrainAdaptiveLogistic {
+
+ // Static configuration filled in by parseArgs(...) and consumed by mainToOutput(...).
+
+ // value of --input; passed to open(...) once per training pass
+ private static String inputFile;
+ // value of --output; the model parameters are saved to this file
+ private static String outputFile;
+ // model parameters assembled from the command line
+ private static AdaptiveLogisticModelParameters lmp;
+ // value of --passes: how many times the input is read
+ private static int passes;
+ // true when --showperf was given
+ private static boolean showperf;
+ // value of --skipperfnum: performance is printed every (skipperfnum + 1) rows
+ private static int skipperfnum = 99;
+ // learner created in mainToOutput(...); exposed through getModel()
+ private static AdaptiveLogisticRegression model;
+
+ /** Private constructor: every member of this class is static. */
+ private TrainAdaptiveLogistic() {}
+
+ public static void main(String[] args) throws Exception {
+ mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+
+ /**
+  * Runs the trainer: parses {@code args}, trains an {@link AdaptiveLogisticRegression} on the
+  * CSV input for the requested number of passes, saves the model parameters to the output
+  * file and prints a summary of the learned weights to {@code output}.
+  *
+  * <p>Nothing happens when {@code parseArgs} returns false. If no learner is available after
+  * training, a failure message is printed and nothing is saved.
+  *
+  * @param args command-line arguments, interpreted by {@code parseArgs}
+  * @param output destination for progress lines and the model summary
+  * @throws Exception if reading the input, training, or writing the model fails
+  */
+ static void mainToOutput(String[] args, PrintWriter output) throws Exception {
+   if (!parseArgs(args)) {
+     return;
+   }
+
+   CsvRecordFactory csv = lmp.getCsvRecordFactory();
+   model = lmp.createAdaptiveLogisticRegression();
+   State<Wrapper, CrossFoldLearner> best;
+   CrossFoldLearner learner = null;
+
+   // k counts training rows across all passes
+   int k = 0;
+   for (int pass = 0; pass < passes; pass++) {
+     // try-with-resources: the reader is closed even when parsing or training throws
+     try (BufferedReader in = open(inputFile)) {
+       // read variable names
+       csv.firstLine(in.readLine());
+
+       String line = in.readLine();
+       while (line != null) {
+         // for each new line, get target and predictors
+         Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
+         int targetValue = csv.processLine(line, input);
+
+         // update model
+         model.train(targetValue, input);
+         k++;
+
+         if (showperf && (k % (skipperfnum + 1) == 0)) {
+           best = model.getBest();
+           if (best != null) {
+             learner = best.getPayload().getLearner();
+           }
+           if (learner != null) {
+             double averageCorrect = learner.percentCorrect();
+             double averageLL = learner.logLikelihood();
+             // explicit locale, as the other numeric output lines in this method use
+             output.printf(Locale.ENGLISH, "%d\t%.3f\t%.2f%n",
+                 k, averageLL, averageCorrect * 100);
+           } else {
+             output.printf(Locale.ENGLISH,
+                 "%10d %2d %s%n", k, targetValue,
+                 "AdaptiveLogisticRegression has not found a good model ......");
+           }
+         }
+         line = in.readLine();
+       }
+     }
+   }
+
+   best = model.getBest();
+   if (best != null) {
+     learner = best.getPayload().getLearner();
+   }
+   if (learner == null) {
+     output.println("AdaptiveLogisticRegression has failed to train a model.");
+     return;
+   }
+
+   try (OutputStream modelOutput = new FileOutputStream(outputFile)) {
+     lmp.saveTo(modelOutput);
+   }
+
+   // one-line formula built from the first row of the first model's coefficients
+   OnlineLogisticRegression lr = learner.getModels().get(0);
+   output.println(lmp.getNumFeatures());
+   output.println(lmp.getTargetVariable() + " ~ ");
+   String sep = "";
+   for (String v : csv.getTraceDictionary().keySet()) {
+     double weight = predictorWeight(lr, 0, csv, v);
+     if (weight != 0) {
+       output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
+       sep = " + ";
+     }
+   }
+   output.printf("%n");
+
+   // per row: the non-zero predictor weights, then every raw coefficient of that row
+   for (int row = 0; row < lr.getBeta().numRows(); row++) {
+     for (String key : csv.getTraceDictionary().keySet()) {
+       double weight = predictorWeight(lr, row, csv, key);
+       if (weight != 0) {
+         output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
+       }
+     }
+     for (int column = 0; column < lr.getBeta().numCols(); column++) {
+       output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
+     }
+     output.println();
+   }
+ }
+
+ /**
+  * Adds up the coefficients in {@code row} of the model's beta matrix over every column that
+  * the record factory's trace dictionary lists for {@code predictor}.
+  *
+  * @param lr model whose beta matrix is read
+  * @param row row of the beta matrix to read
+  * @param csv record factory supplying the trace dictionary
+  * @param predictor trace-dictionary key whose columns are summed
+  * @return the sum of the selected coefficients
+  */
+ private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
+   double total = 0;
+   for (int col : csv.getTraceDictionary().get(predictor)) {
+     double coefficient = lr.getBeta().get(row, col);
+     total += coefficient;
+   }
+   return total;
+ }
+
+ /**
+  * Parses the command line and stores the result in the static fields of this class.
+  *
+  * <p>On success the input and output file names, the reporting settings and the number of
+  * passes are recorded, and an {@link AdaptiveLogisticModelParameters} instance is built
+  * and checked with {@code checkParameters()}.
+  *
+  * @param args raw command-line arguments
+  * @return false if {@code Parser.parseAndHelp} returned no command line, true once all
+  *         settings have been stored
+  */
+ private static boolean parseArgs(String[] args) {
+   DefaultOptionBuilder builder = new DefaultOptionBuilder();
+   ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+   // Local option variables carry an "Option" suffix so they do not shadow the static
+   // fields (inputFile, outputFile, showperf, skipperfnum, passes) they populate.
+   Option helpOption = builder.withLongName("help")
+       .withDescription("print this list").create();
+
+   // registered with the parser, but its value is not read in this method
+   Option quietOption = builder.withLongName("quiet")
+       .withDescription("be extra quiet").create();
+
+   Option showperfOption = builder
+       .withLongName("showperf")
+       .withDescription("output performance measures during training")
+       .create();
+
+   Option inputFileOption = builder
+       .withLongName("input")
+       .withRequired(true)
+       .withArgument(
+           argumentBuilder.withName("input").withMaximum(1)
+               .create())
+       .withDescription("where to get training data").create();
+
+   Option outputFileOption = builder
+       .withLongName("output")
+       .withRequired(true)
+       .withArgument(
+           argumentBuilder.withName("output").withMaximum(1)
+               .create())
+       .withDescription("where to write the model content").create();
+
+   Option threadsOption = builder.withLongName("threads")
+       .withArgument(
+           argumentBuilder.withName("threads").withDefault("4").create())
+       .withDescription("the number of threads AdaptiveLogisticRegression uses")
+       .create();
+
+   Option predictorsOption = builder.withLongName("predictors")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("predictors").create())
+       .withDescription("a list of predictor variables").create();
+
+   Option typesOption = builder
+       .withLongName("types")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("types").create())
+       .withDescription(
+           "a list of predictor variable types (numeric, word, or text)")
+       .create();
+
+   Option targetOption = builder
+       .withLongName("target")
+       .withDescription("the name of the target variable")
+       .withRequired(true)
+       .withArgument(
+           argumentBuilder.withName("target").withMaximum(1)
+               .create())
+       .create();
+
+   Option targetCategoriesOption = builder
+       .withLongName("categories")
+       .withDescription("the number of target categories to be considered")
+       .withRequired(true)
+       .withArgument(argumentBuilder.withName("categories").withMaximum(1).create())
+       .create();
+
+   Option featuresOption = builder
+       .withLongName("features")
+       .withDescription("the number of internal hashed features to use")
+       .withArgument(
+           argumentBuilder.withName("numFeatures")
+               .withDefault("1000").withMaximum(1).create())
+       .create();
+
+   Option passesOption = builder
+       .withLongName("passes")
+       .withDescription("the number of times to pass over the input data")
+       .withArgument(
+           argumentBuilder.withName("passes").withDefault("2")
+               .withMaximum(1).create())
+       .create();
+
+   Option intervalOption = builder.withLongName("interval")
+       .withArgument(
+           argumentBuilder.withName("interval").withDefault("500").create())
+       .withDescription("the interval property of AdaptiveLogisticRegression")
+       .create();
+
+   Option windowOption = builder.withLongName("window")
+       .withArgument(
+           argumentBuilder.withName("window").withDefault("800").create())
+       .withDescription("the average property of AdaptiveLogisticRegression")
+       .create();
+
+   Option skipperfnumOption = builder.withLongName("skipperfnum")
+       .withArgument(
+           argumentBuilder.withName("skipperfnum").withDefault("99").create())
+       .withDescription("show performance measures every (skipperfnum + 1) rows")
+       .create();
+
+   Option priorOption = builder.withLongName("prior")
+       .withArgument(
+           argumentBuilder.withName("prior").withDefault("L1").create())
+       .withDescription("the prior algorithm to use: L1, L2, ebp, tp, up")
+       .create();
+
+   Option priorParameterOption = builder.withLongName("prioroption")
+       .withArgument(
+           argumentBuilder.withName("prioroption").create())
+       .withDescription("constructor parameter for ElasticBandPrior and TPrior")
+       .create();
+
+   Option aucOption = builder.withLongName("auc")
+       .withArgument(
+           argumentBuilder.withName("auc").withDefault("global").create())
+       .withDescription("the auc to use: global or grouped")
+       .create();
+
+   Group normalArgs = new GroupBuilder().withOption(helpOption)
+       .withOption(quietOption).withOption(inputFileOption).withOption(outputFileOption)
+       .withOption(targetOption).withOption(targetCategoriesOption)
+       .withOption(predictorsOption).withOption(typesOption).withOption(passesOption)
+       .withOption(intervalOption).withOption(windowOption).withOption(threadsOption)
+       .withOption(priorOption).withOption(featuresOption).withOption(showperfOption)
+       .withOption(skipperfnumOption).withOption(priorParameterOption).withOption(aucOption)
+       .create();
+
+   Parser parser = new Parser();
+   parser.setHelpOption(helpOption);
+   parser.setHelpTrigger("--help");
+   parser.setGroup(normalArgs);
+   parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+   CommandLine cmdLine = parser.parseAndHelp(args);
+
+   if (cmdLine == null) {
+     return false;
+   }
+
+   inputFile = getStringArgument(cmdLine, inputFileOption);
+   outputFile = getStringArgument(cmdLine, outputFileOption);
+
+   List<String> typeList = new ArrayList<>();
+   for (Object x : cmdLine.getValues(typesOption)) {
+     typeList.add(x.toString());
+   }
+
+   List<String> predictorList = new ArrayList<>();
+   for (Object x : cmdLine.getValues(predictorsOption)) {
+     predictorList.add(x.toString());
+   }
+
+   lmp = new AdaptiveLogisticModelParameters();
+   lmp.setTargetVariable(getStringArgument(cmdLine, targetOption));
+   lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategoriesOption));
+   lmp.setNumFeatures(getIntegerArgument(cmdLine, featuresOption));
+   lmp.setInterval(getIntegerArgument(cmdLine, intervalOption));
+   lmp.setAverageWindow(getIntegerArgument(cmdLine, windowOption));
+   lmp.setThreads(getIntegerArgument(cmdLine, threadsOption));
+   lmp.setAuc(getStringArgument(cmdLine, aucOption));
+   lmp.setPrior(getStringArgument(cmdLine, priorOption));
+   // --prioroption has no default, so it is only applied when a value was supplied
+   if (cmdLine.getValue(priorParameterOption) != null) {
+     lmp.setPriorOption(getDoubleArgument(cmdLine, priorParameterOption));
+   }
+   lmp.setTypeMap(predictorList, typeList);
+   showperf = getBooleanArgument(cmdLine, showperfOption);
+   skipperfnum = getIntegerArgument(cmdLine, skipperfnumOption);
+   passes = getIntegerArgument(cmdLine, passesOption);
+
+   lmp.checkParameters();
+
+   return true;
+ }
+
+ /** Returns the command-line value of the given option, cast to {@code String}. */
+ private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
+   Object value = cmdLine.getValue(inputFile);
+   return (String) value;
+ }
+
+ /** Returns whether the given option is present on the command line. */
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+   boolean present = cmdLine.hasOption(option);
+   return present;
+ }
+
+ /** Returns the command-line value of the given option, parsed with {@link Integer#parseInt(String)}. */
+ private static int getIntegerArgument(CommandLine cmdLine, Option features) {
+   String raw = (String) cmdLine.getValue(features);
+   return Integer.parseInt(raw);
+ }
+
+ /** Returns the command-line value of the given option, parsed with {@link Double#parseDouble(String)}. */
+ private static double getDoubleArgument(CommandLine cmdLine, Option op) {
+   String raw = (String) cmdLine.getValue(op);
+   return Double.parseDouble(raw);
+ }
+
+ /**
+  * Returns the static {@code model} field, which {@code mainToOutput} assigns after the
+  * arguments have been parsed successfully.
+  */
+ public static AdaptiveLogisticRegression getModel() {
+   return TrainAdaptiveLogistic.model;
+ }
+
+ /** Returns the static {@code lmp} field, the model parameters assembled by {@code parseArgs}. */
+ public static LogisticModelParameters getParameters() {
+   return TrainAdaptiveLogistic.lmp;
+ }
+
+ /**
+  * Opens {@code inputFile} as UTF-8 text. The name is first looked up with
+  * {@code Resources.getResource}; if that lookup throws {@link IllegalArgumentException},
+  * the name is opened as a file instead.
+  *
+  * @param inputFile resource name or file path
+  * @return a buffered reader over the opened stream; closing it is left to the caller
+  * @throws IOException if the resource or file cannot be opened
+  */
+ static BufferedReader open(String inputFile) throws IOException {
+   InputStream stream;
+   try {
+     stream = Resources.getResource(inputFile).openStream();
+   } catch (IllegalArgumentException e) {
+     File file = new File(inputFile);
+     stream = new FileInputStream(file);
+   }
+   InputStreamReader utf8Reader = new InputStreamReader(stream, Charsets.UTF_8);
+   return new BufferedReader(utf8Reader);
+ }
+
+}