You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/27 13:14:33 UTC
[07/24] mahout git commit: MAHOUT-2034 Split MR and New Examples into
separate modules
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
deleted file mode 100644
index e762924..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
+++ /dev/null
@@ -1,265 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import com.google.common.base.Preconditions;
-import com.google.common.io.Closeables;
-import java.io.DataInput;
-import java.io.DataInputStream;
-import java.io.DataOutput;
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import org.apache.hadoop.io.Writable;
-
-/**
- * Encapsulates everything we need to know about a model and how it reads and vectorizes its input.
- * This encapsulation allows us to coherently save and restore a model from a file. This also
- * allows us to keep command line arguments that affect learning in a coherent way.
- */
-public class LogisticModelParameters implements Writable {
- private String targetVariable;
- private Map<String, String> typeMap;
- private int numFeatures;
- private boolean useBias;
- private int maxTargetCategories;
- private List<String> targetCategories;
- private double lambda;
- private double learningRate;
- private CsvRecordFactory csv;
- private OnlineLogisticRegression lr;
-
- /**
- * Returns a CsvRecordFactory compatible with this logistic model. The reason that this is tied
- * in here is so that we have access to the list of target categories when it comes time to save
- * the model. If the input isn't CSV, then calling setTargetCategories before calling saveTo will
- * suffice.
- *
- * @return The CsvRecordFactory.
- */
- public CsvRecordFactory getCsvRecordFactory() {
- if (csv == null) {
- csv = new CsvRecordFactory(getTargetVariable(), getTypeMap())
- .maxTargetValue(getMaxTargetCategories())
- .includeBiasTerm(useBias());
- if (targetCategories != null) {
- csv.defineTargetCategories(targetCategories);
- }
- }
- return csv;
- }
-
- /**
- * Creates a logistic regression trainer using the parameters collected here.
- *
- * @return The newly allocated OnlineLogisticRegression object
- */
- public OnlineLogisticRegression createRegression() {
- if (lr == null) {
- lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1())
- .lambda(getLambda())
- .learningRate(getLearningRate())
- .alpha(1 - 1.0e-3);
- }
- return lr;
- }
-
- /**
- * Saves a model to an output stream.
- */
- public void saveTo(OutputStream out) throws IOException {
- Closeables.close(lr, false);
- targetCategories = getCsvRecordFactory().getTargetCategories();
- write(new DataOutputStream(out));
- }
-
- /**
- * Reads a model from a stream.
- */
- public static LogisticModelParameters loadFrom(InputStream in) throws IOException {
- LogisticModelParameters result = new LogisticModelParameters();
- result.readFields(new DataInputStream(in));
- return result;
- }
-
- /**
- * Reads a model from a file.
- * @throws IOException If there is an error opening or closing the file.
- */
- public static LogisticModelParameters loadFrom(File in) throws IOException {
- try (InputStream input = new FileInputStream(in)) {
- return loadFrom(input);
- }
- }
-
-
- @Override
- public void write(DataOutput out) throws IOException {
- out.writeUTF(targetVariable);
- out.writeInt(typeMap.size());
- for (Map.Entry<String,String> entry : typeMap.entrySet()) {
- out.writeUTF(entry.getKey());
- out.writeUTF(entry.getValue());
- }
- out.writeInt(numFeatures);
- out.writeBoolean(useBias);
- out.writeInt(maxTargetCategories);
-
- if (targetCategories == null) {
- out.writeInt(0);
- } else {
- out.writeInt(targetCategories.size());
- for (String category : targetCategories) {
- out.writeUTF(category);
- }
- }
- out.writeDouble(lambda);
- out.writeDouble(learningRate);
- // skip csv
- lr.write(out);
- }
-
- @Override
- public void readFields(DataInput in) throws IOException {
- targetVariable = in.readUTF();
- int typeMapSize = in.readInt();
- typeMap = new HashMap<>(typeMapSize);
- for (int i = 0; i < typeMapSize; i++) {
- String key = in.readUTF();
- String value = in.readUTF();
- typeMap.put(key, value);
- }
- numFeatures = in.readInt();
- useBias = in.readBoolean();
- maxTargetCategories = in.readInt();
- int targetCategoriesSize = in.readInt();
- targetCategories = new ArrayList<>(targetCategoriesSize);
- for (int i = 0; i < targetCategoriesSize; i++) {
- targetCategories.add(in.readUTF());
- }
- lambda = in.readDouble();
- learningRate = in.readDouble();
- csv = null;
- lr = new OnlineLogisticRegression();
- lr.readFields(in);
- }
-
- /**
- * Sets the types of the predictors. This will later be used when reading CSV data. If you don't
- * use the CSV data and convert to vectors on your own, you don't need to call this.
- *
- * @param predictorList The list of variable names.
- * @param typeList The list of types in the format preferred by CsvRecordFactory.
- */
- public void setTypeMap(Iterable<String> predictorList, List<String> typeList) {
- Preconditions.checkArgument(!typeList.isEmpty(), "Must have at least one type specifier");
- typeMap = new HashMap<>();
- Iterator<String> iTypes = typeList.iterator();
- String lastType = null;
- for (Object x : predictorList) {
- // type list can be short .. we just repeat last spec
- if (iTypes.hasNext()) {
- lastType = iTypes.next();
- }
- typeMap.put(x.toString(), lastType);
- }
- }
-
- /**
- * Sets the target variable. If you don't use the CSV record factory, then this is irrelevant.
- *
- * @param targetVariable The name of the target variable.
- */
- public void setTargetVariable(String targetVariable) {
- this.targetVariable = targetVariable;
- }
-
- /**
- * Sets the number of target categories to be considered.
- *
- * @param maxTargetCategories The number of target categories.
- */
- public void setMaxTargetCategories(int maxTargetCategories) {
- this.maxTargetCategories = maxTargetCategories;
- }
-
- public void setNumFeatures(int numFeatures) {
- this.numFeatures = numFeatures;
- }
-
- public void setTargetCategories(List<String> targetCategories) {
- this.targetCategories = targetCategories;
- maxTargetCategories = targetCategories.size();
- }
-
- public List<String> getTargetCategories() {
- return this.targetCategories;
- }
-
- public void setUseBias(boolean useBias) {
- this.useBias = useBias;
- }
-
- public boolean useBias() {
- return useBias;
- }
-
- public String getTargetVariable() {
- return targetVariable;
- }
-
- public Map<String, String> getTypeMap() {
- return typeMap;
- }
-
- public void setTypeMap(Map<String, String> map) {
- this.typeMap = map;
- }
-
- public int getNumFeatures() {
- return numFeatures;
- }
-
- public int getMaxTargetCategories() {
- return maxTargetCategories;
- }
-
- public double getLambda() {
- return lambda;
- }
-
- public void setLambda(double lambda) {
- this.lambda = lambda;
- }
-
- public double getLearningRate() {
- return learningRate;
- }
-
- public void setLearningRate(double learningRate) {
- this.learningRate = learningRate;
- }
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
deleted file mode 100644
index 3ec6a06..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import com.google.common.base.Preconditions;
-
-import java.io.BufferedReader;
-
-/**
- * Uses the same logic as TrainLogistic and RunLogistic for finding an input, but instead
- * of processing the input, this class just prints the input to standard out.
- */
-public final class PrintResourceOrFile {
-
- private PrintResourceOrFile() {
- }
-
- public static void main(String[] args) throws Exception {
- Preconditions.checkArgument(args.length == 1, "Must have a single argument that names a file or resource.");
- try (BufferedReader in = TrainLogistic.open(args[0])){
- String line;
- while ((line = in.readLine()) != null) {
- System.out.println(line);
- }
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
deleted file mode 100644
index 678a8f5..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
-import org.apache.mahout.ep.State;
-import org.apache.mahout.math.SequentialAccessSparseVector;
-import org.apache.mahout.math.Vector;
-
-import java.io.BufferedReader;
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.HashMap;
-import java.util.Map;
-
-public final class RunAdaptiveLogistic {
-
- private static String inputFile;
- private static String modelFile;
- private static String outputFile;
- private static String idColumn;
- private static boolean maxScoreOnly;
-
- private RunAdaptiveLogistic() {
- }
-
- public static void main(String[] args) throws Exception {
- mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
-
- static void mainToOutput(String[] args, PrintWriter output) throws Exception {
- if (!parseArgs(args)) {
- return;
- }
- AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
- .loadFromFile(new File(modelFile));
-
- CsvRecordFactory csv = lmp.getCsvRecordFactory();
- csv.setIdName(idColumn);
-
- AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
-
- State<Wrapper, CrossFoldLearner> best = lr.getBest();
- if (best == null) {
- output.println("AdaptiveLogisticRegression has not be trained probably.");
- return;
- }
- CrossFoldLearner learner = best.getPayload().getLearner();
-
- BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
- int k = 0;
-
- try (BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFile),
- Charsets.UTF_8))) {
- out.write(idColumn + ",target,score");
- out.newLine();
-
- String line = in.readLine();
- csv.firstLine(line);
- line = in.readLine();
- Map<String, Double> results = new HashMap<>();
- while (line != null) {
- Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
- csv.processLine(line, v, false);
- Vector scores = learner.classifyFull(v);
- results.clear();
- if (maxScoreOnly) {
- results.put(csv.getTargetLabel(scores.maxValueIndex()),
- scores.maxValue());
- } else {
- for (int i = 0; i < scores.size(); i++) {
- results.put(csv.getTargetLabel(i), scores.get(i));
- }
- }
-
- for (Map.Entry<String, Double> entry : results.entrySet()) {
- out.write(csv.getIdString(line) + ',' + entry.getKey() + ',' + entry.getValue());
- out.newLine();
- }
- k++;
- if (k % 100 == 0) {
- output.println(k + " records processed");
- }
- line = in.readLine();
- }
- out.flush();
- }
- output.println(k + " records processed totally.");
- }
-
- private static boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
-
- Option help = builder.withLongName("help")
- .withDescription("print this list").create();
-
- Option quiet = builder.withLongName("quiet")
- .withDescription("be extra quiet").create();
-
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option inputFileOption = builder
- .withLongName("input")
- .withRequired(true)
- .withArgument(
- argumentBuilder.withName("input").withMaximum(1)
- .create())
- .withDescription("where to get training data").create();
-
- Option modelFileOption = builder
- .withLongName("model")
- .withRequired(true)
- .withArgument(
- argumentBuilder.withName("model").withMaximum(1)
- .create())
- .withDescription("where to get the trained model").create();
-
- Option outputFileOption = builder
- .withLongName("output")
- .withRequired(true)
- .withDescription("the file path to output scores")
- .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
- .create();
-
- Option idColumnOption = builder
- .withLongName("idcolumn")
- .withRequired(true)
- .withDescription("the name of the id column for each record")
- .withArgument(argumentBuilder.withName("idcolumn").withMaximum(1).create())
- .create();
-
- Option maxScoreOnlyOption = builder
- .withLongName("maxscoreonly")
- .withDescription("only output the target label with max scores")
- .create();
-
- Group normalArgs = new GroupBuilder()
- .withOption(help).withOption(quiet)
- .withOption(inputFileOption).withOption(modelFileOption)
- .withOption(outputFileOption).withOption(idColumnOption)
- .withOption(maxScoreOnlyOption)
- .create();
-
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
- CommandLine cmdLine = parser.parseAndHelp(args);
-
- if (cmdLine == null) {
- return false;
- }
-
- inputFile = getStringArgument(cmdLine, inputFileOption);
- modelFile = getStringArgument(cmdLine, modelFileOption);
- outputFile = getStringArgument(cmdLine, outputFileOption);
- idColumn = getStringArgument(cmdLine, idColumnOption);
- maxScoreOnly = getBooleanArgument(cmdLine, maxScoreOnlyOption);
- return true;
- }
-
- private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
- return cmdLine.hasOption(option);
- }
-
- private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
- return (String) cmdLine.getValue(inputFile);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
deleted file mode 100644
index 2d57016..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
+++ /dev/null
@@ -1,163 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.classifier.evaluation.Auc;
-import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.SequentialAccessSparseVector;
-import org.apache.mahout.math.Vector;
-
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.Locale;
-
-public final class RunLogistic {
-
- private static String inputFile;
- private static String modelFile;
- private static boolean showAuc;
- private static boolean showScores;
- private static boolean showConfusion;
-
- private RunLogistic() {
- }
-
- public static void main(String[] args) throws Exception {
- mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
-
- static void mainToOutput(String[] args, PrintWriter output) throws Exception {
- if (parseArgs(args)) {
- if (!showAuc && !showConfusion && !showScores) {
- showAuc = true;
- showConfusion = true;
- }
-
- Auc collector = new Auc();
- LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile));
-
- CsvRecordFactory csv = lmp.getCsvRecordFactory();
- OnlineLogisticRegression lr = lmp.createRegression();
- BufferedReader in = TrainLogistic.open(inputFile);
- String line = in.readLine();
- csv.firstLine(line);
- line = in.readLine();
- if (showScores) {
- output.println("\"target\",\"model-output\",\"log-likelihood\"");
- }
- while (line != null) {
- Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
- int target = csv.processLine(line, v);
-
- double score = lr.classifyScalar(v);
- if (showScores) {
- output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
- }
- collector.add(target, score);
- line = in.readLine();
- }
-
- if (showAuc) {
- output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc());
- }
- if (showConfusion) {
- Matrix m = collector.confusion();
- output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n",
- m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
- m = collector.entropy();
- output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n",
- m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
- }
- }
- }
-
- private static boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
-
- Option help = builder.withLongName("help").withDescription("print this list").create();
-
- Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
-
- Option auc = builder.withLongName("auc").withDescription("print AUC").create();
- Option confusion = builder.withLongName("confusion").withDescription("print confusion matrix").create();
-
- Option scores = builder.withLongName("scores").withDescription("print scores").create();
-
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option inputFileOption = builder.withLongName("input")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
- .withDescription("where to get training data")
- .create();
-
- Option modelFileOption = builder.withLongName("model")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
- .withDescription("where to get a model")
- .create();
-
- Group normalArgs = new GroupBuilder()
- .withOption(help)
- .withOption(quiet)
- .withOption(auc)
- .withOption(scores)
- .withOption(confusion)
- .withOption(inputFileOption)
- .withOption(modelFileOption)
- .create();
-
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
- CommandLine cmdLine = parser.parseAndHelp(args);
-
- if (cmdLine == null) {
- return false;
- }
-
- inputFile = getStringArgument(cmdLine, inputFileOption);
- modelFile = getStringArgument(cmdLine, modelFileOption);
- showAuc = getBooleanArgument(cmdLine, auc);
- showScores = getBooleanArgument(cmdLine, scores);
- showConfusion = getBooleanArgument(cmdLine, confusion);
-
- return true;
- }
-
- private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
- return cmdLine.hasOption(option);
- }
-
- private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
- return (String) cmdLine.getValue(inputFile);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
deleted file mode 100644
index c657803..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
+++ /dev/null
@@ -1,151 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import com.google.common.collect.Multiset;
-import org.apache.mahout.classifier.NewsgroupHelper;
-import org.apache.mahout.ep.State;
-import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.function.DoubleFunction;
-import org.apache.mahout.math.function.Functions;
-import org.apache.mahout.vectorizer.encoders.Dictionary;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import java.util.Set;
-import java.util.TreeMap;
-
-public final class SGDHelper {
-
- private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};
-
- private SGDHelper() {
- }
-
- public static void dissect(int leakType,
- Dictionary dictionary,
- AdaptiveLogisticRegression learningAlgorithm,
- Iterable<File> files, Multiset<String> overallCounts) throws IOException {
- CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
- model.close();
-
- Map<String, Set<Integer>> traceDictionary = new TreeMap<>();
- ModelDissector md = new ModelDissector();
-
- NewsgroupHelper helper = new NewsgroupHelper();
- helper.getEncoder().setTraceDictionary(traceDictionary);
- helper.getBias().setTraceDictionary(traceDictionary);
-
- for (File file : permute(files, helper.getRandom()).subList(0, 500)) {
- String ng = file.getParentFile().getName();
- int actual = dictionary.intern(ng);
-
- traceDictionary.clear();
- Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
- md.update(v, traceDictionary, model);
- }
-
- List<String> ngNames = new ArrayList<>(dictionary.values());
- List<ModelDissector.Weight> weights = md.summary(100);
- System.out.println("============");
- System.out.println("Model Dissection");
- for (ModelDissector.Weight w : weights) {
- System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s%n",
- w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
- w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
- }
- }
-
- public static List<File> permute(Iterable<File> files, Random rand) {
- List<File> r = new ArrayList<>();
- for (File file : files) {
- int i = rand.nextInt(r.size() + 1);
- if (i == r.size()) {
- r.add(file);
- } else {
- r.add(r.get(i));
- r.set(i, file);
- }
- }
- return r;
- }
-
- static void analyzeState(SGDInfo info, int leakType, int k, State<AdaptiveLogisticRegression.Wrapper,
- CrossFoldLearner> best) throws IOException {
- int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length];
- int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length));
- double maxBeta;
- double nonZeros;
- double positive;
- double norm;
-
- double lambda = 0;
- double mu = 0;
-
- if (best != null) {
- CrossFoldLearner state = best.getPayload().getLearner();
- info.setAverageCorrect(state.percentCorrect());
- info.setAverageLL(state.logLikelihood());
-
- OnlineLogisticRegression model = state.getModels().get(0);
- // finish off pending regularization
- model.close();
-
- Matrix beta = model.getBeta();
- maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
- nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
- @Override
- public double apply(double v) {
- return Math.abs(v) > 1.0e-6 ? 1 : 0;
- }
- });
- positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
- @Override
- public double apply(double v) {
- return v > 0 ? 1 : 0;
- }
- });
- norm = beta.aggregate(Functions.PLUS, Functions.ABS);
-
- lambda = best.getMappedParams()[0];
- mu = best.getMappedParams()[1];
- } else {
- maxBeta = 0;
- nonZeros = 0;
- positive = 0;
- norm = 0;
- }
- if (k % (bump * scale) == 0) {
- if (best != null) {
- File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group-" + k + ".model");
- ModelSerializer.writeBinary(modelFile.getAbsolutePath(), best.getPayload().getLearner().getModels().get(0));
- }
-
- info.setStep(info.getStep() + 0.25);
- System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
- System.out.printf("%d\t%.3f\t%.2f\t%s%n",
- k, info.getAverageLL(), info.getAverageCorrect() * 100, LEAK_LABELS[leakType % 3]);
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
deleted file mode 100644
index be55d43..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
+++ /dev/null
@@ -1,59 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-final class SGDInfo {
-
- private double averageLL;
- private double averageCorrect;
- private double step;
- private int[] bumps = {1, 2, 5};
-
- double getAverageLL() {
- return averageLL;
- }
-
- void setAverageLL(double averageLL) {
- this.averageLL = averageLL;
- }
-
- double getAverageCorrect() {
- return averageCorrect;
- }
-
- void setAverageCorrect(double averageCorrect) {
- this.averageCorrect = averageCorrect;
- }
-
- double getStep() {
- return step;
- }
-
- void setStep(double step) {
- this.step = step;
- }
-
- int[] getBumps() {
- return bumps;
- }
-
- void setBumps(int[] bumps) {
- this.bumps = bumps;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java
deleted file mode 100644
index b3da452..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java
+++ /dev/null
@@ -1,283 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import com.google.common.base.Joiner;
-import com.google.common.base.Splitter;
-import com.google.common.collect.Lists;
-import com.google.common.io.Closeables;
-import com.google.common.io.Files;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.list.IntArrayList;
-import org.apache.mahout.math.stats.OnlineSummarizer;
-import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
-import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.BufferedReader;
-import java.io.Closeable;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-
-/**
- * Shows how different encoding choices can make big speed differences.
- * <p/>
- * Run with command line options --generate 1000000 test.csv to generate a million data lines in
- * test.csv.
- * <p/>
- * Run with command line options --parser test.csv to time how long it takes to parse and encode
- * those million data points
- * <p/>
- * Run with command line options --fast test.csv to time how long it takes to parse and encode those
- * million data points using byte-level parsing and direct value encoding.
- * <p/>
- * This doesn't demonstrate text encoding which is subject to somewhat different tricks. The basic
- * idea of caching hash locations and byte level parsing still very much applies to text, however.
- */
-public final class SimpleCsvExamples {
-
- public static final char SEPARATOR_CHAR = '\t';
- private static final int FIELDS = 100;
-
- private static final Logger log = LoggerFactory.getLogger(SimpleCsvExamples.class);
-
- private SimpleCsvExamples() {}
-
- public static void main(String[] args) throws IOException {
- FeatureVectorEncoder[] encoder = new FeatureVectorEncoder[FIELDS];
- for (int i = 0; i < FIELDS; i++) {
- encoder[i] = new ConstantValueEncoder("v" + 1);
- }
-
- OnlineSummarizer[] s = new OnlineSummarizer[FIELDS];
- for (int i = 0; i < FIELDS; i++) {
- s[i] = new OnlineSummarizer();
- }
- long t0 = System.currentTimeMillis();
- Vector v = new DenseVector(1000);
- if ("--generate".equals(args[0])) {
- try (PrintWriter out =
- new PrintWriter(new OutputStreamWriter(new FileOutputStream(new File(args[2])), Charsets.UTF_8))) {
- int n = Integer.parseInt(args[1]);
- for (int i = 0; i < n; i++) {
- Line x = Line.generate();
- out.println(x);
- }
- }
- } else if ("--parse".equals(args[0])) {
- try (BufferedReader in = Files.newReader(new File(args[1]), Charsets.UTF_8)){
- String line = in.readLine();
- while (line != null) {
- v.assign(0);
- Line x = new Line(line);
- for (int i = 0; i < FIELDS; i++) {
- s[i].add(x.getDouble(i));
- encoder[i].addToVector(x.get(i), v);
- }
- line = in.readLine();
- }
- }
- String separator = "";
- for (int i = 0; i < FIELDS; i++) {
- System.out.printf("%s%.3f", separator, s[i].getMean());
- separator = ",";
- }
- } else if ("--fast".equals(args[0])) {
- try (FastLineReader in = new FastLineReader(new FileInputStream(args[1]))){
- FastLine line = in.read();
- while (line != null) {
- v.assign(0);
- for (int i = 0; i < FIELDS; i++) {
- double z = line.getDouble(i);
- s[i].add(z);
- encoder[i].addToVector((byte[]) null, z, v);
- }
- line = in.read();
- }
- }
-
- String separator = "";
- for (int i = 0; i < FIELDS; i++) {
- System.out.printf("%s%.3f", separator, s[i].getMean());
- separator = ",";
- }
- }
- System.out.printf("\nElapsed time = %.3f%n", (System.currentTimeMillis() - t0) / 1000.0);
- }
-
-
- private static final class Line {
- private static final Splitter ON_TABS = Splitter.on(SEPARATOR_CHAR).trimResults();
- public static final Joiner WITH_COMMAS = Joiner.on(SEPARATOR_CHAR);
-
- public static final Random RAND = RandomUtils.getRandom();
-
- private final List<String> data;
-
- private Line(CharSequence line) {
- data = Lists.newArrayList(ON_TABS.split(line));
- }
-
- private Line() {
- data = new ArrayList<>();
- }
-
- public double getDouble(int field) {
- return Double.parseDouble(data.get(field));
- }
-
- /**
- * Generate a random line with 20 fields each with integer values.
- *
- * @return A new line with data.
- */
- public static Line generate() {
- Line r = new Line();
- for (int i = 0; i < FIELDS; i++) {
- double mean = ((i + 1) * 257) % 50 + 1;
- r.data.add(Integer.toString(randomValue(mean)));
- }
- return r;
- }
-
- /**
- * Returns a random exponentially distributed integer with a particular mean value. This is
- * just a way to create more small numbers than big numbers.
- *
- * @param mean mean of the distribution
- * @return random exponentially distributed integer with the specific mean
- */
- private static int randomValue(double mean) {
- return (int) (-mean * Math.log1p(-RAND.nextDouble()));
- }
-
- @Override
- public String toString() {
- return WITH_COMMAS.join(data);
- }
-
- public String get(int field) {
- return data.get(field);
- }
- }
-
- private static final class FastLine {
-
- private final ByteBuffer base;
- private final IntArrayList start = new IntArrayList();
- private final IntArrayList length = new IntArrayList();
-
- private FastLine(ByteBuffer base) {
- this.base = base;
- }
-
- public static FastLine read(ByteBuffer buf) {
- FastLine r = new FastLine(buf);
- r.start.add(buf.position());
- int offset = buf.position();
- while (offset < buf.limit()) {
- int ch = buf.get();
- offset = buf.position();
- switch (ch) {
- case '\n':
- r.length.add(offset - r.start.get(r.length.size()) - 1);
- return r;
- case SEPARATOR_CHAR:
- r.length.add(offset - r.start.get(r.length.size()) - 1);
- r.start.add(offset);
- break;
- default:
- // nothing to do for now
- }
- }
- throw new IllegalArgumentException("Not enough bytes in buffer");
- }
-
- public double getDouble(int field) {
- int offset = start.get(field);
- int size = length.get(field);
- switch (size) {
- case 1:
- return base.get(offset) - '0';
- case 2:
- return (base.get(offset) - '0') * 10 + base.get(offset + 1) - '0';
- default:
- double r = 0;
- for (int i = 0; i < size; i++) {
- r = 10 * r + base.get(offset + i) - '0';
- }
- return r;
- }
- }
- }
-
- private static final class FastLineReader implements Closeable {
- private final InputStream in;
- private final ByteBuffer buf = ByteBuffer.allocate(100000);
-
- private FastLineReader(InputStream in) throws IOException {
- this.in = in;
- buf.limit(0);
- fillBuffer();
- }
-
- public FastLine read() throws IOException {
- fillBuffer();
- if (buf.remaining() > 0) {
- return FastLine.read(buf);
- } else {
- return null;
- }
- }
-
- private void fillBuffer() throws IOException {
- if (buf.remaining() < 10000) {
- buf.compact();
- int n = in.read(buf.array(), buf.position(), buf.remaining());
- if (n == -1) {
- buf.flip();
- } else {
- buf.limit(buf.position() + n);
- buf.position(0);
- }
- }
- }
-
- @Override
- public void close() {
- try {
- Closeables.close(in, true);
- } catch (IOException e) {
- log.error(e.getMessage(), e);
- }
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
deleted file mode 100644
index 074f774..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
+++ /dev/null
@@ -1,152 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.fs.PathFilter;
-import org.apache.hadoop.io.Text;
-import org.apache.mahout.classifier.ClassifierResult;
-import org.apache.mahout.classifier.ResultAnalyzer;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.vectorizer.encoders.Dictionary;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-
-/**
- * Run the ASF email, as trained by TrainASFEmail
- */
-public final class TestASFEmail {
-
- private String inputFile;
- private String modelFile;
-
- private TestASFEmail() {}
-
- public static void main(String[] args) throws IOException {
- TestASFEmail runner = new TestASFEmail();
- if (runner.parseArgs(args)) {
- runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
- }
-
- public void run(PrintWriter output) throws IOException {
-
- File base = new File(inputFile);
- //contains the best model
- OnlineLogisticRegression classifier =
- ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class);
-
-
- Dictionary asfDictionary = new Dictionary();
- Configuration conf = new Configuration();
- PathFilter testFilter = new PathFilter() {
- @Override
- public boolean accept(Path path) {
- return path.getName().contains("test");
- }
- };
- SequenceFileDirIterator<Text, VectorWritable> iter =
- new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter,
- null, true, conf);
-
- long numItems = 0;
- while (iter.hasNext()) {
- Pair<Text, VectorWritable> next = iter.next();
- asfDictionary.intern(next.getFirst().toString());
- numItems++;
- }
-
- System.out.println(numItems + " test files");
- ResultAnalyzer ra = new ResultAnalyzer(asfDictionary.values(), "DEFAULT");
- iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter,
- null, true, conf);
- while (iter.hasNext()) {
- Pair<Text, VectorWritable> next = iter.next();
- String ng = next.getFirst().toString();
-
- int actual = asfDictionary.intern(ng);
- Vector result = classifier.classifyFull(next.getSecond().get());
- int cat = result.maxValueIndex();
- double score = result.maxValue();
- double ll = classifier.logLikelihood(actual, next.getSecond().get());
- ClassifierResult cr = new ClassifierResult(asfDictionary.values().get(cat), score, ll);
- ra.addInstance(asfDictionary.values().get(actual), cr);
-
- }
- output.println(ra);
- }
-
- boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
-
- Option help = builder.withLongName("help").withDescription("print this list").create();
-
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option inputFileOption = builder.withLongName("input")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
- .withDescription("where to get training data")
- .create();
-
- Option modelFileOption = builder.withLongName("model")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
- .withDescription("where to get a model")
- .create();
-
- Group normalArgs = new GroupBuilder()
- .withOption(help)
- .withOption(inputFileOption)
- .withOption(modelFileOption)
- .create();
-
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
- CommandLine cmdLine = parser.parseAndHelp(args);
-
- if (cmdLine == null) {
- return false;
- }
-
- inputFile = (String) cmdLine.getValue(inputFileOption);
- modelFile = (String) cmdLine.getValue(modelFileOption);
- return true;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
deleted file mode 100644
index f0316e9..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
+++ /dev/null
@@ -1,141 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.classifier.ClassifierResult;
-import org.apache.mahout.classifier.NewsgroupHelper;
-import org.apache.mahout.classifier.ResultAnalyzer;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.vectorizer.encoders.Dictionary;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * Run the 20 news groups test data through SGD, as trained by {@link org.apache.mahout.classifier.sgd.TrainNewsGroups}.
- */
-public final class TestNewsGroups {
-
- private String inputFile;
- private String modelFile;
-
- private TestNewsGroups() {
- }
-
- public static void main(String[] args) throws IOException {
- TestNewsGroups runner = new TestNewsGroups();
- if (runner.parseArgs(args)) {
- runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
- }
-
- public void run(PrintWriter output) throws IOException {
-
- File base = new File(inputFile);
- //contains the best model
- OnlineLogisticRegression classifier =
- ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class);
-
- Dictionary newsGroups = new Dictionary();
- Multiset<String> overallCounts = HashMultiset.create();
-
- List<File> files = new ArrayList<>();
- for (File newsgroup : base.listFiles()) {
- if (newsgroup.isDirectory()) {
- newsGroups.intern(newsgroup.getName());
- files.addAll(Arrays.asList(newsgroup.listFiles()));
- }
- }
- System.out.println(files.size() + " test files");
- ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
- for (File file : files) {
- String ng = file.getParentFile().getName();
-
- int actual = newsGroups.intern(ng);
- NewsgroupHelper helper = new NewsgroupHelper();
- //no leak type ensures this is a normal vector
- Vector input = helper.encodeFeatureVector(file, actual, 0, overallCounts);
- Vector result = classifier.classifyFull(input);
- int cat = result.maxValueIndex();
- double score = result.maxValue();
- double ll = classifier.logLikelihood(actual, input);
- ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll);
- ra.addInstance(newsGroups.values().get(actual), cr);
-
- }
- output.println(ra);
- }
-
- boolean parseArgs(String[] args) {
- DefaultOptionBuilder builder = new DefaultOptionBuilder();
-
- Option help = builder.withLongName("help").withDescription("print this list").create();
-
- ArgumentBuilder argumentBuilder = new ArgumentBuilder();
- Option inputFileOption = builder.withLongName("input")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
- .withDescription("where to get training data")
- .create();
-
- Option modelFileOption = builder.withLongName("model")
- .withRequired(true)
- .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
- .withDescription("where to get a model")
- .create();
-
- Group normalArgs = new GroupBuilder()
- .withOption(help)
- .withOption(inputFileOption)
- .withOption(modelFileOption)
- .create();
-
- Parser parser = new Parser();
- parser.setHelpOption(help);
- parser.setHelpTrigger("--help");
- parser.setGroup(normalArgs);
- parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
- CommandLine cmdLine = parser.parseAndHelp(args);
-
- if (cmdLine == null) {
- return false;
- }
-
- inputFile = (String) cmdLine.getValue(inputFileOption);
- modelFile = (String) cmdLine.getValue(modelFileOption);
- return true;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
deleted file mode 100644
index e681f92..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
+++ /dev/null
@@ -1,137 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-import com.google.common.collect.Ordering;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.fs.PathFilter;
-import org.apache.hadoop.io.Text;
-import org.apache.mahout.common.AbstractJob;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
-import org.apache.mahout.ep.State;
-import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.vectorizer.encoders.Dictionary;
-
-import java.io.File;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-
-public final class TrainASFEmail extends AbstractJob {
-
- private TrainASFEmail() {
- }
-
- @Override
- public int run(String[] args) throws Exception {
- addInputOption();
- addOutputOption();
- addOption("categories", "nc", "The number of categories to train on", true);
- addOption("cardinality", "c", "The size of the vectors to use", "100000");
- addOption("threads", "t", "The number of threads to use in the learner", "20");
- addOption("poolSize", "p", "The number of CrossFoldLearners to use in the AdaptiveLogisticRegression. "
- + "Higher values require more memory.", "5");
- if (parseArguments(args) == null) {
- return -1;
- }
-
- File base = new File(getInputPath().toString());
-
- Multiset<String> overallCounts = HashMultiset.create();
- File output = new File(getOutputPath().toString());
- output.mkdirs();
- int numCats = Integer.parseInt(getOption("categories"));
- int cardinality = Integer.parseInt(getOption("cardinality", "100000"));
- int threadCount = Integer.parseInt(getOption("threads", "20"));
- int poolSize = Integer.parseInt(getOption("poolSize", "5"));
- Dictionary asfDictionary = new Dictionary();
- AdaptiveLogisticRegression learningAlgorithm =
- new AdaptiveLogisticRegression(numCats, cardinality, new L1(), threadCount, poolSize);
- learningAlgorithm.setInterval(800);
- learningAlgorithm.setAveragingWindow(500);
-
- //We ran seq2encoded and split input already, so let's just build up the dictionary
- Configuration conf = new Configuration();
- PathFilter trainFilter = new PathFilter() {
- @Override
- public boolean accept(Path path) {
- return path.getName().contains("training");
- }
- };
- SequenceFileDirIterator<Text, VectorWritable> iter =
- new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter, null, true, conf);
- long numItems = 0;
- while (iter.hasNext()) {
- Pair<Text, VectorWritable> next = iter.next();
- asfDictionary.intern(next.getFirst().toString());
- numItems++;
- }
-
- System.out.println(numItems + " training files");
-
- SGDInfo info = new SGDInfo();
-
- iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter,
- null, true, conf);
- int k = 0;
- while (iter.hasNext()) {
- Pair<Text, VectorWritable> next = iter.next();
- String ng = next.getFirst().toString();
- int actual = asfDictionary.intern(ng);
- //we already have encoded
- learningAlgorithm.train(actual, next.getSecond().get());
- k++;
- State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
-
- SGDHelper.analyzeState(info, 0, k, best);
- }
- learningAlgorithm.close();
- //TODO: how to dissection since we aren't processing the files here
- //SGDHelper.dissect(leakType, asfDictionary, learningAlgorithm, files, overallCounts);
- System.out.println("exiting main, writing model to " + output);
-
- ModelSerializer.writeBinary(output + "/asf.model",
- learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
-
- List<Integer> counts = new ArrayList<>();
- System.out.println("Word counts");
- for (String count : overallCounts.elementSet()) {
- counts.add(overallCounts.count(count));
- }
- Collections.sort(counts, Ordering.natural().reverse());
- k = 0;
- for (Integer count : counts) {
- System.out.println(k + "\t" + count);
- k++;
- if (k > 1000) {
- break;
- }
- }
- return 0;
- }
-
- public static void main(String[] args) throws Exception {
- TrainASFEmail trainer = new TrainASFEmail();
- trainer.run(args);
- }
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
deleted file mode 100644
index defb5b9..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
+++ /dev/null
@@ -1,377 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import com.google.common.io.Resources;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
-import org.apache.mahout.ep.State;
-import org.apache.mahout.math.RandomAccessSparseVector;
-import org.apache.mahout.math.Vector;
-
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.io.OutputStream;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Locale;
-
-public final class TrainAdaptiveLogistic {
-
- private static String inputFile;
- private static String outputFile;
- private static AdaptiveLogisticModelParameters lmp;
- private static int passes;
- private static boolean showperf;
- private static int skipperfnum = 99;
- private static AdaptiveLogisticRegression model;
-
- private TrainAdaptiveLogistic() {
- }
-
- public static void main(String[] args) throws Exception {
- mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
- }
-
- static void mainToOutput(String[] args, PrintWriter output) throws Exception {
- if (parseArgs(args)) {
-
- CsvRecordFactory csv = lmp.getCsvRecordFactory();
- model = lmp.createAdaptiveLogisticRegression();
- State<Wrapper, CrossFoldLearner> best;
- CrossFoldLearner learner = null;
-
- int k = 0;
- for (int pass = 0; pass < passes; pass++) {
- BufferedReader in = open(inputFile);
-
- // read variable names
- csv.firstLine(in.readLine());
-
- String line = in.readLine();
- while (line != null) {
- // for each new line, get target and predictors
- Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
- int targetValue = csv.processLine(line, input);
-
- // update model
- model.train(targetValue, input);
- k++;
-
- if (showperf && (k % (skipperfnum + 1) == 0)) {
-
- best = model.getBest();
- if (best != null) {
- learner = best.getPayload().getLearner();
- }
- if (learner != null) {
- double averageCorrect = learner.percentCorrect();
- double averageLL = learner.logLikelihood();
- output.printf("%d\t%.3f\t%.2f%n",
- k, averageLL, averageCorrect * 100);
- } else {
- output.printf(Locale.ENGLISH,
- "%10d %2d %s%n", k, targetValue,
- "AdaptiveLogisticRegression has not found a good model ......");
- }
- }
- line = in.readLine();
- }
- in.close();
- }
-
- best = model.getBest();
- if (best != null) {
- learner = best.getPayload().getLearner();
- }
- if (learner == null) {
- output.println("AdaptiveLogisticRegression has failed to train a model.");
- return;
- }
-
- try (OutputStream modelOutput = new FileOutputStream(outputFile)) {
- lmp.saveTo(modelOutput);
- }
-
- OnlineLogisticRegression lr = learner.getModels().get(0);
- output.println(lmp.getNumFeatures());
- output.println(lmp.getTargetVariable() + " ~ ");
- String sep = "";
- for (String v : csv.getTraceDictionary().keySet()) {
- double weight = predictorWeight(lr, 0, csv, v);
- if (weight != 0) {
- output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
- sep = " + ";
- }
- }
- output.printf("%n");
-
- for (int row = 0; row < lr.getBeta().numRows(); row++) {
- for (String key : csv.getTraceDictionary().keySet()) {
- double weight = predictorWeight(lr, row, csv, key);
- if (weight != 0) {
- output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
- }
- }
- for (int column = 0; column < lr.getBeta().numCols(); column++) {
- output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
- }
- output.println();
- }
- }
-
- }
-
- private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
- double weight = 0;
- for (Integer column : csv.getTraceDictionary().get(predictor)) {
- weight += lr.getBeta().get(row, column);
- }
- return weight;
- }
-
- /**
-  * Declares the command-line options of this trainer, parses {@code args} and copies the parsed
-  * values into the static fields {@code inputFile}, {@code outputFile}, {@code showperf},
-  * {@code skipperfnum} and {@code passes}, and into a newly created {@code lmp}.
-  *
-  * <p>{@code --input}, {@code --output}, {@code --predictors}, {@code --types}, {@code --target}
-  * and {@code --categories} are declared as required; the remaining valued options carry the
-  * defaults given below. {@code --quiet} is accepted, but its value is not read in this method.
-  *
-  * @param args raw command-line arguments
-  * @return {@code false} when {@code Parser.parseAndHelp} hands back {@code null} (no usable
-  *         command line), {@code true} once every field has been populated and
-  *         {@code lmp.checkParameters()} has returned
-  */
- private static boolean parseArgs(String[] args) {
-   DefaultOptionBuilder builder = new DefaultOptionBuilder();
-
-   Option help = builder.withLongName("help")
-       .withDescription("print this list").create();
-
-   Option quiet = builder.withLongName("quiet")
-       .withDescription("be extra quiet").create();
-
-   ArgumentBuilder argumentBuilder = new ArgumentBuilder();
-   Option showperf = builder
-       .withLongName("showperf")
-       .withDescription("output performance measures during training")
-       .create();
-
-   Option inputFile = builder
-       .withLongName("input")
-       .withRequired(true)
-       .withArgument(
-           argumentBuilder.withName("input").withMaximum(1)
-               .create())
-       .withDescription("where to get training data").create();
-
-   Option outputFile = builder
-       .withLongName("output")
-       .withRequired(true)
-       .withArgument(
-           argumentBuilder.withName("output").withMaximum(1)
-               .create())
-       .withDescription("where to write the model content").create();
-
-   Option threads = builder.withLongName("threads")
-       .withArgument(
-           argumentBuilder.withName("threads").withDefault("4").create())
-       .withDescription("the number of threads AdaptiveLogisticRegression uses")
-       .create();
-
-   Option predictors = builder.withLongName("predictors")
-       .withRequired(true)
-       .withArgument(argumentBuilder.withName("predictors").create())
-       .withDescription("a list of predictor variables").create();
-
-   Option types = builder
-       .withLongName("types")
-       .withRequired(true)
-       .withArgument(argumentBuilder.withName("types").create())
-       .withDescription(
-           "a list of predictor variable types (numeric, word, or text)")
-       .create();
-
-   Option target = builder
-       .withLongName("target")
-       .withDescription("the name of the target variable")
-       .withRequired(true)
-       .withArgument(
-           argumentBuilder.withName("target").withMaximum(1)
-               .create())
-       .create();
-
-   Option targetCategories = builder
-       .withLongName("categories")
-       .withDescription("the number of target categories to be considered")
-       .withRequired(true)
-       .withArgument(argumentBuilder.withName("categories").withMaximum(1).create())
-       .create();
-
-   Option features = builder
-       .withLongName("features")
-       .withDescription("the number of internal hashed features to use")
-       .withArgument(
-           argumentBuilder.withName("numFeatures")
-               .withDefault("1000").withMaximum(1).create())
-       .create();
-
-   Option passes = builder
-       .withLongName("passes")
-       .withDescription("the number of times to pass over the input data")
-       .withArgument(
-           argumentBuilder.withName("passes").withDefault("2")
-               .withMaximum(1).create())
-       .create();
-
-   Option interval = builder.withLongName("interval")
-       .withArgument(
-           argumentBuilder.withName("interval").withDefault("500").create())
-       .withDescription("the interval property of AdaptiveLogisticRegression")
-       .create();
-
-   // Help text previously read "propery"; this string is shown to the user by --help.
-   Option window = builder.withLongName("window")
-       .withArgument(
-           argumentBuilder.withName("window").withDefault("800").create())
-       .withDescription("the average property of AdaptiveLogisticRegression")
-       .create();
-
-   Option skipperfnum = builder.withLongName("skipperfnum")
-       .withArgument(
-           argumentBuilder.withName("skipperfnum").withDefault("99").create())
-       .withDescription("show performance measures every (skipperfnum + 1) rows")
-       .create();
-
-   Option prior = builder.withLongName("prior")
-       .withArgument(
-           argumentBuilder.withName("prior").withDefault("L1").create())
-       .withDescription("the prior algorithm to use: L1, L2, ebp, tp, up")
-       .create();
-
-   // No default here: the value is only forwarded to lmp when the user supplied one.
-   Option priorOption = builder.withLongName("prioroption")
-       .withArgument(
-           argumentBuilder.withName("prioroption").create())
-       .withDescription("constructor parameter for ElasticBandPrior and TPrior")
-       .create();
-
-   Option auc = builder.withLongName("auc")
-       .withArgument(
-           argumentBuilder.withName("auc").withDefault("global").create())
-       .withDescription("the auc to use: global or grouped")
-       .create();
-
-   Group normalArgs = new GroupBuilder().withOption(help)
-       .withOption(quiet).withOption(inputFile).withOption(outputFile)
-       .withOption(target).withOption(targetCategories)
-       .withOption(predictors).withOption(types).withOption(passes)
-       .withOption(interval).withOption(window).withOption(threads)
-       .withOption(prior).withOption(features).withOption(showperf)
-       .withOption(skipperfnum).withOption(priorOption).withOption(auc)
-       .create();
-
-   Parser parser = new Parser();
-   parser.setHelpOption(help);
-   parser.setHelpTrigger("--help");
-   parser.setGroup(normalArgs);
-   parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
-   CommandLine cmdLine = parser.parseAndHelp(args);
-
-   if (cmdLine == null) {
-     return false;
-   }
-
-   TrainAdaptiveLogistic.inputFile = getStringArgument(cmdLine, inputFile);
-   TrainAdaptiveLogistic.outputFile = getStringArgument(cmdLine,
-       outputFile);
-
-   // --types and --predictors may each carry several values; keep them in the order given.
-   List<String> typeList = new ArrayList<>();
-   for (Object x : cmdLine.getValues(types)) {
-     typeList.add(x.toString());
-   }
-
-   List<String> predictorList = new ArrayList<>();
-   for (Object x : cmdLine.getValues(predictors)) {
-     predictorList.add(x.toString());
-   }
-
-   lmp = new AdaptiveLogisticModelParameters();
-   lmp.setTargetVariable(getStringArgument(cmdLine, target));
-   lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
-   lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
-   lmp.setInterval(getIntegerArgument(cmdLine, interval));
-   lmp.setAverageWindow(getIntegerArgument(cmdLine, window));
-   lmp.setThreads(getIntegerArgument(cmdLine, threads));
-   lmp.setAuc(getStringArgument(cmdLine, auc));
-   lmp.setPrior(getStringArgument(cmdLine, prior));
-   if (cmdLine.getValue(priorOption) != null) {
-     lmp.setPriorOption(getDoubleArgument(cmdLine, priorOption));
-   }
-   lmp.setTypeMap(predictorList, typeList);
-   TrainAdaptiveLogistic.showperf = getBooleanArgument(cmdLine, showperf);
-   TrainAdaptiveLogistic.skipperfnum = getIntegerArgument(cmdLine, skipperfnum);
-   TrainAdaptiveLogistic.passes = getIntegerArgument(cmdLine, passes);
-
-   lmp.checkParameters();
-
-   return true;
- }
-
- /**
-  * Reads the value recorded for an option and returns it as a string.
-  *
-  * @param cmdLine   parsed command line
-  * @param inputFile option whose value is wanted (any option, despite the parameter name)
-  * @return the value cast to {@code String}
-  */
- private static String getStringArgument(CommandLine cmdLine,
-     Option inputFile) {
-   Object raw = cmdLine.getValue(inputFile);
-   return (String) raw;
- }
-
- /**
-  * Reports whether the command line contains the given option.
-  *
-  * @param cmdLine parsed command line
-  * @param option  option to look for
-  * @return the result of {@code cmdLine.hasOption(option)}
-  */
- private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
-   boolean present = cmdLine.hasOption(option);
-   return present;
- }
-
- /**
-  * Reads the value recorded for an option and parses it as an {@code int}.
-  *
-  * @param cmdLine  parsed command line
-  * @param features option whose value is wanted (any option, despite the parameter name)
-  * @return the value parsed with {@link Integer#parseInt(String)}
-  */
- private static int getIntegerArgument(CommandLine cmdLine, Option features) {
-   String text = (String) cmdLine.getValue(features);
-   return Integer.parseInt(text);
- }
-
- /**
-  * Reads the value recorded for an option and parses it as a {@code double}.
-  *
-  * @param cmdLine parsed command line
-  * @param op      option whose value is wanted
-  * @return the value parsed with {@link Double#parseDouble(String)}
-  */
- private static double getDoubleArgument(CommandLine cmdLine, Option op) {
-   String text = (String) cmdLine.getValue(op);
-   return Double.parseDouble(text);
- }
-
- /**
-  * Returns the current value of the static {@code model} field.
-  *
-  * @return the stored {@code AdaptiveLogisticRegression}, exactly as held by the field
-  */
- public static AdaptiveLogisticRegression getModel() {
-   return TrainAdaptiveLogistic.model;
- }
-
- /**
-  * Returns the current value of the static {@code lmp} field, which {@code parseArgs} assigns.
-  *
-  * @return the stored model parameters, exactly as held by the field
-  */
- public static LogisticModelParameters getParameters() {
-   return TrainAdaptiveLogistic.lmp;
- }
-
- /**
-  * Opens the named input as UTF-8 text.
-  *
-  * <p>The name is first handed to {@code Resources.getResource}; if that call is rejected with an
-  * {@code IllegalArgumentException}, the name is opened as a file path instead.
-  *
-  * @param inputFile resource name or file path
-  * @return a buffered reader over the opened stream; the caller is responsible for closing it
-  * @throws IOException if the stream cannot be opened
-  */
- static BufferedReader open(String inputFile) throws IOException {
-   InputStream stream;
-   try {
-     stream = Resources.getResource(inputFile).openStream();
-   } catch (IllegalArgumentException e) {
-     stream = new FileInputStream(inputFile);
-   }
-   InputStreamReader decoded = new InputStreamReader(stream, Charsets.UTF_8);
-   return new BufferedReader(decoded);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
deleted file mode 100644
index f4b8bcb..0000000
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
+++ /dev/null
@@ -1,311 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sgd;
-
-import com.google.common.io.Resources;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
-import org.apache.commons.cli2.util.HelpFormatter;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.math.RandomAccessSparseVector;
-import org.apache.mahout.math.Vector;
-
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.io.OutputStream;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Locale;
-
-/**
- * Train a logistic regression for the examples from Chapter 13 of Mahout in Action.
- *
- * <p>The command line is parsed into static fields, the input is read {@code passes} times while
- * an {@code OnlineLogisticRegression} is trained line by line, the model parameters are written
- * to the output file and a summary of the learned weights is printed.
- */
-public final class TrainLogistic {
-
-  /** Value of {@code --input}: where the training data is read from. */
-  private static String inputFile;
-  /** Value of {@code --output}: where the model parameters are written. */
-  private static String outputFile;
-  /** Parameters assembled by {@link #parseArgs(String[])}. */
-  private static LogisticModelParameters lmp;
-  /** Value of {@code --passes}: how many times the input is read. */
-  private static int passes;
-  /** Whether {@code --scores} was given; enables a diagnostic line per training row. */
-  private static boolean scores;
-  /** The regression trained by the most recent {@link #mainToOutput} run. */
-  private static OnlineLogisticRegression model;
-
-  private TrainLogistic() {
-  }
-
-  /**
-   * Command-line entry point; reports to {@code System.out} as UTF-8 with auto-flush.
-   *
-   * @param args command-line arguments, see {@link #parseArgs(String[])}
-   * @throws Exception whatever {@link #mainToOutput} throws
-   */
-  public static void main(String[] args) throws Exception {
-    mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
-  }
-
-  /**
-   * Parses the arguments, trains the model and prints a summary to {@code output}.
-   * Nothing happens when {@link #parseArgs(String[])} returns {@code false}.
-   *
-   * @param args   command-line arguments
-   * @param output destination for diagnostics and the final weight summary
-   * @throws Exception on I/O failure or any error raised while training
-   */
-  static void mainToOutput(String[] args, PrintWriter output) throws Exception {
-    if (parseArgs(args)) {
-      double logPEstimate = 0;
-      int samples = 0;
-
-      CsvRecordFactory csv = lmp.getCsvRecordFactory();
-      OnlineLogisticRegression lr = lmp.createRegression();
-      for (int pass = 0; pass < passes; pass++) {
-        try (BufferedReader in = open(inputFile)) {
-          // read variable names
-          csv.firstLine(in.readLine());
-
-          String line = in.readLine();
-          while (line != null) {
-            // for each new line, get target and predictors
-            Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
-            int targetValue = csv.processLine(line, input);
-
-            // check performance while this is still news
-            double logP = lr.logLikelihood(targetValue, input);
-            if (!Double.isInfinite(logP)) {
-              // plain running mean for the first 20 finite samples, then an
-              // exponential moving average weighting the newest sample at 5%
-              if (samples < 20) {
-                logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
-              } else {
-                logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
-              }
-              samples++;
-            }
-            double p = lr.classifyScalar(input);
-            if (scores) {
-              output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n",
-                samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
-            }
-
-            // now update model
-            lr.train(targetValue, input);
-
-            line = in.readLine();
-          }
-        }
-      }
-
-      try (OutputStream modelOutput = new FileOutputStream(outputFile)) {
-        lmp.saveTo(modelOutput);
-      }
-
-      output.println(lmp.getNumFeatures());
-      output.println(lmp.getTargetVariable() + " ~ ");
-      // one-line formula built from the non-zero predictor weights of row 0
-      String sep = "";
-      for (String v : csv.getTraceDictionary().keySet()) {
-        double weight = predictorWeight(lr, 0, csv, v);
-        if (weight != 0) {
-          output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
-          sep = " + ";
-        }
-      }
-      output.printf("%n");
-      model = lr;
-      // per row: the non-zero predictor weights, then every raw coefficient of that row
-      for (int row = 0; row < lr.getBeta().numRows(); row++) {
-        for (String key : csv.getTraceDictionary().keySet()) {
-          double weight = predictorWeight(lr, row, csv, key);
-          if (weight != 0) {
-            output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
-          }
-        }
-        for (int column = 0; column < lr.getBeta().numCols(); column++) {
-          output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
-        }
-        output.println();
-      }
-    }
-  }
-
-  /**
-   * Sums the coefficients in one row of the model's beta matrix over every column that the
-   * record factory's trace dictionary lists for the given predictor.
-   *
-   * @param lr        model whose beta matrix is read
-   * @param row       row of the beta matrix to read
-   * @param csv       record factory supplying the trace dictionary
-   * @param predictor key looked up in the trace dictionary
-   * @return the summed coefficient values
-   */
-  private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
-    double weight = 0;
-    for (Integer column : csv.getTraceDictionary().get(predictor)) {
-      weight += lr.getBeta().get(row, column);
-    }
-    return weight;
-  }
-
-  /**
-   * Declares the command-line options, parses {@code args} and stores the results in the static
-   * fields of this class and in a newly created {@code lmp}.
-   *
-   * <p>{@code --quiet} is accepted, but its value is not read in this method.
-   *
-   * @param args raw command-line arguments
-   * @return {@code false} when {@code Parser.parseAndHelp} hands back {@code null} (no usable
-   *         command line), {@code true} otherwise
-   */
-  private static boolean parseArgs(String[] args) {
-    DefaultOptionBuilder builder = new DefaultOptionBuilder();
-
-    Option help = builder.withLongName("help").withDescription("print this list").create();
-
-    Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
-    Option scores = builder.withLongName("scores").withDescription("output score diagnostics during training").create();
-
-    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
-    Option inputFile = builder.withLongName("input")
-            .withRequired(true)
-            .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
-            .withDescription("where to get training data")
-            .create();
-
-    // Description used to repeat the --input text; the value is the file the model is saved to.
-    Option outputFile = builder.withLongName("output")
-            .withRequired(true)
-            .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
-            .withDescription("where to write the model content")
-            .create();
-
-    Option predictors = builder.withLongName("predictors")
-            .withRequired(true)
-            .withArgument(argumentBuilder.withName("p").create())
-            .withDescription("a list of predictor variables")
-            .create();
-
-    Option types = builder.withLongName("types")
-            .withRequired(true)
-            .withArgument(argumentBuilder.withName("t").create())
-            .withDescription("a list of predictor variable types (numeric, word, or text)")
-            .create();
-
-    Option target = builder.withLongName("target")
-            .withRequired(true)
-            .withArgument(argumentBuilder.withName("target").withMaximum(1).create())
-            .withDescription("the name of the target variable")
-            .create();
-
-    Option features = builder.withLongName("features")
-            .withArgument(
-                    argumentBuilder.withName("numFeatures")
-                            .withDefault("1000")
-                            .withMaximum(1).create())
-            .withDescription("the number of internal hashed features to use")
-            .create();
-
-    Option passes = builder.withLongName("passes")
-            .withArgument(
-                    argumentBuilder.withName("passes")
-                            .withDefault("2")
-                            .withMaximum(1).create())
-            .withDescription("the number of times to pass over the input data")
-            .create();
-
-    Option lambda = builder.withLongName("lambda")
-            .withArgument(argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1).create())
-            .withDescription("the amount of coefficient decay to use")
-            .create();
-
-    Option rate = builder.withLongName("rate")
-            .withArgument(argumentBuilder.withName("learningRate").withDefault("1e-3").withMaximum(1).create())
-            .withDescription("the learning rate")
-            .create();
-
-    Option noBias = builder.withLongName("noBias")
-            .withDescription("don't include a bias term")
-            .create();
-
-    Option targetCategories = builder.withLongName("categories")
-            .withRequired(true)
-            .withArgument(argumentBuilder.withName("number").withMaximum(1).create())
-            .withDescription("the number of target categories to be considered")
-            .create();
-
-    // scores must be registered here; an option left out of the group cannot be
-    // supplied on the command line, so the diagnostics could never be enabled.
-    Group normalArgs = new GroupBuilder()
-            .withOption(help)
-            .withOption(quiet)
-            .withOption(scores)
-            .withOption(inputFile)
-            .withOption(outputFile)
-            .withOption(target)
-            .withOption(targetCategories)
-            .withOption(predictors)
-            .withOption(types)
-            .withOption(passes)
-            .withOption(lambda)
-            .withOption(rate)
-            .withOption(noBias)
-            .withOption(features)
-            .create();
-
-    Parser parser = new Parser();
-    parser.setHelpOption(help);
-    parser.setHelpTrigger("--help");
-    parser.setGroup(normalArgs);
-    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
-    CommandLine cmdLine = parser.parseAndHelp(args);
-
-    if (cmdLine == null) {
-      return false;
-    }
-
-    TrainLogistic.inputFile = getStringArgument(cmdLine, inputFile);
-    TrainLogistic.outputFile = getStringArgument(cmdLine, outputFile);
-
-    // --types and --predictors may each carry several values; keep them in the order given.
-    List<String> typeList = new ArrayList<>();
-    for (Object x : cmdLine.getValues(types)) {
-      typeList.add(x.toString());
-    }
-
-    List<String> predictorList = new ArrayList<>();
-    for (Object x : cmdLine.getValues(predictors)) {
-      predictorList.add(x.toString());
-    }
-
-    lmp = new LogisticModelParameters();
-    lmp.setTargetVariable(getStringArgument(cmdLine, target));
-    lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
-    lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
-    lmp.setUseBias(!getBooleanArgument(cmdLine, noBias));
-    lmp.setTypeMap(predictorList, typeList);
-
-    lmp.setLambda(getDoubleArgument(cmdLine, lambda));
-    lmp.setLearningRate(getDoubleArgument(cmdLine, rate));
-
-    TrainLogistic.scores = getBooleanArgument(cmdLine, scores);
-    TrainLogistic.passes = getIntegerArgument(cmdLine, passes);
-
-    return true;
-  }
-
-  /** Returns the value recorded for {@code option}, cast to {@code String}. */
-  private static String getStringArgument(CommandLine cmdLine, Option option) {
-    return (String) cmdLine.getValue(option);
-  }
-
-  /** Returns whether {@code option} is present on the command line. */
-  private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
-    return cmdLine.hasOption(option);
-  }
-
-  /** Returns the value recorded for {@code option}, parsed with {@link Integer#parseInt(String)}. */
-  private static int getIntegerArgument(CommandLine cmdLine, Option option) {
-    return Integer.parseInt((String) cmdLine.getValue(option));
-  }
-
-  /** Returns the value recorded for {@code option}, parsed with {@link Double#parseDouble(String)}. */
-  private static double getDoubleArgument(CommandLine cmdLine, Option option) {
-    return Double.parseDouble((String) cmdLine.getValue(option));
-  }
-
-  /**
-   * Returns the regression stored by the last {@link #mainToOutput} run.
-   *
-   * @return the static {@code model} field, exactly as currently held
-   */
-  public static OnlineLogisticRegression getModel() {
-    return model;
-  }
-
-  /**
-   * Returns the parameters assembled by the last {@link #parseArgs(String[])} call.
-   *
-   * @return the static {@code lmp} field, exactly as currently held
-   */
-  public static LogisticModelParameters getParameters() {
-    return lmp;
-  }
-
-  /**
-   * Opens the named input as UTF-8 text.
-   *
-   * <p>The name is first handed to {@code Resources.getResource}; if that call is rejected with an
-   * {@code IllegalArgumentException}, the name is opened as a file path instead.
-   *
-   * @param inputFile resource name or file path
-   * @return a buffered reader over the opened stream; the caller is responsible for closing it
-   * @throws IOException if the stream cannot be opened
-   */
-  static BufferedReader open(String inputFile) throws IOException {
-    InputStream in;
-    try {
-      in = Resources.getResource(inputFile).openStream();
-    } catch (IllegalArgumentException e) {
-      in = new FileInputStream(new File(inputFile));
-    }
-    return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8));
-  }
-}