You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/27 13:14:43 UTC
[17/24] mahout git commit: MAHOUT-2034 Split MR and New Examples into
separate modules
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
new file mode 100644
index 0000000..f4b8bcb
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.io.Resources;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+/**
+ * Train a logistic regression for the examples from Chapter 13 of Mahout in Action
+ */
+public final class TrainLogistic {
+
+ // Configuration captured by parseArgs(); static so that callers (e.g. tests)
+ // can retrieve the trained model and parameters via getModel()/getParameters().
+ private static String inputFile;
+ private static String outputFile;
+ private static LogisticModelParameters lmp;
+ private static int passes;
+ private static boolean scores;
+ private static OnlineLogisticRegression model;
+
+ // Utility class; not instantiable.
+ private TrainLogistic() {
+ }
+
+ // Entry point: writes all diagnostics and the model summary to stdout (UTF-8).
+ public static void main(String[] args) throws Exception {
+ mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+
+ // Trains the regression over `passes` sweeps of the input CSV, saves the
+ // model parameters to outputFile, then prints the non-zero predictor weights
+ // and the full beta matrix to `output`.
+ static void mainToOutput(String[] args, PrintWriter output) throws Exception {
+ if (parseArgs(args)) {
+ double logPEstimate = 0;
+ int samples = 0;
+
+ CsvRecordFactory csv = lmp.getCsvRecordFactory();
+ OnlineLogisticRegression lr = lmp.createRegression();
+ for (int pass = 0; pass < passes; pass++) {
+ try (BufferedReader in = open(inputFile)) {
+ // read variable names
+ csv.firstLine(in.readLine());
+
+ String line = in.readLine();
+ while (line != null) {
+ // for each new line, get target and predictors
+ Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
+ int targetValue = csv.processLine(line, input);
+
+ // check performance while this is still news
+ double logP = lr.logLikelihood(targetValue, input);
+ // Running estimate of log-likelihood: a plain average over the first
+ // 20 finite samples, then an exponentially weighted moving average.
+ if (!Double.isInfinite(logP)) {
+ if (samples < 20) {
+ logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
+ } else {
+ logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
+ }
+ samples++;
+ }
+ double p = lr.classifyScalar(input);
+ if (scores) {
+ output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n",
+ samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
+ }
+
+ // now update model
+ lr.train(targetValue, input);
+
+ line = in.readLine();
+ }
+ }
+ }
+
+ // Persist the learned parameters before printing the summary.
+ try (OutputStream modelOutput = new FileOutputStream(outputFile)) {
+ lmp.saveTo(modelOutput);
+ }
+
+ // Print an R-style formula built from the non-zero weights of row 0...
+ output.println(lmp.getNumFeatures());
+ output.println(lmp.getTargetVariable() + " ~ ");
+ String sep = "";
+ for (String v : csv.getTraceDictionary().keySet()) {
+ double weight = predictorWeight(lr, 0, csv, v);
+ if (weight != 0) {
+ output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
+ sep = " + ";
+ }
+ }
+ output.printf("%n");
+ model = lr;
+ // ...then, for every beta row, the non-zero per-predictor weights followed
+ // by the raw coefficients of that row.
+ for (int row = 0; row < lr.getBeta().numRows(); row++) {
+ for (String key : csv.getTraceDictionary().keySet()) {
+ double weight = predictorWeight(lr, row, csv, key);
+ if (weight != 0) {
+ output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
+ }
+ }
+ for (int column = 0; column < lr.getBeta().numCols(); column++) {
+ output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
+ }
+ output.println();
+ }
+ }
+ }
+
+ // Sums the beta coefficients of every hashed column that the trace
+ // dictionary attributes to `predictor`, for the given output row.
+ private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
+ double weight = 0;
+ for (Integer column : csv.getTraceDictionary().get(predictor)) {
+ weight += lr.getBeta().get(row, column);
+ }
+ return weight;
+ }
+
+ // Parses the command line into the static fields above; returns false when
+ // parsing failed or help was requested (parseAndHelp returns null then).
+ private static boolean parseArgs(String[] args) {
+ DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+ Option help = builder.withLongName("help").withDescription("print this list").create();
+
+ Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
+ Option scores = builder.withLongName("scores").withDescription("output score diagnostics during training").create();
+
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+ Option inputFile = builder.withLongName("input")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+ .withDescription("where to get training data")
+ .create();
+
+ // NOTE(review): the description below is a copy/paste of --input; it should
+ // read something like "where to write the trained model".
+ Option outputFile = builder.withLongName("output")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
+ .withDescription("where to get training data")
+ .create();
+
+ Option predictors = builder.withLongName("predictors")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("p").create())
+ .withDescription("a list of predictor variables")
+ .create();
+
+ Option types = builder.withLongName("types")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("t").create())
+ .withDescription("a list of predictor variable types (numeric, word, or text)")
+ .create();
+
+ Option target = builder.withLongName("target")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("target").withMaximum(1).create())
+ .withDescription("the name of the target variable")
+ .create();
+
+ Option features = builder.withLongName("features")
+ .withArgument(
+ argumentBuilder.withName("numFeatures")
+ .withDefault("1000")
+ .withMaximum(1).create())
+ .withDescription("the number of internal hashed features to use")
+ .create();
+
+ Option passes = builder.withLongName("passes")
+ .withArgument(
+ argumentBuilder.withName("passes")
+ .withDefault("2")
+ .withMaximum(1).create())
+ .withDescription("the number of times to pass over the input data")
+ .create();
+
+ Option lambda = builder.withLongName("lambda")
+ .withArgument(argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1).create())
+ .withDescription("the amount of coefficient decay to use")
+ .create();
+
+ Option rate = builder.withLongName("rate")
+ .withArgument(argumentBuilder.withName("learningRate").withDefault("1e-3").withMaximum(1).create())
+ .withDescription("the learning rate")
+ .create();
+
+ Option noBias = builder.withLongName("noBias")
+ .withDescription("don't include a bias term")
+ .create();
+
+ Option targetCategories = builder.withLongName("categories")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("number").withMaximum(1).create())
+ .withDescription("the number of target categories to be considered")
+ .create();
+
+ Group normalArgs = new GroupBuilder()
+ .withOption(help)
+ .withOption(quiet)
+ .withOption(inputFile)
+ .withOption(outputFile)
+ .withOption(target)
+ .withOption(targetCategories)
+ .withOption(predictors)
+ .withOption(types)
+ .withOption(passes)
+ .withOption(lambda)
+ .withOption(rate)
+ .withOption(noBias)
+ .withOption(features)
+ .create();
+
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+ CommandLine cmdLine = parser.parseAndHelp(args);
+
+ if (cmdLine == null) {
+ return false;
+ }
+
+ TrainLogistic.inputFile = getStringArgument(cmdLine, inputFile);
+ TrainLogistic.outputFile = getStringArgument(cmdLine, outputFile);
+
+ List<String> typeList = new ArrayList<>();
+ for (Object x : cmdLine.getValues(types)) {
+ typeList.add(x.toString());
+ }
+
+ List<String> predictorList = new ArrayList<>();
+ for (Object x : cmdLine.getValues(predictors)) {
+ predictorList.add(x.toString());
+ }
+
+ lmp = new LogisticModelParameters();
+ lmp.setTargetVariable(getStringArgument(cmdLine, target));
+ lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
+ lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
+ lmp.setUseBias(!getBooleanArgument(cmdLine, noBias));
+ lmp.setTypeMap(predictorList, typeList);
+
+ lmp.setLambda(getDoubleArgument(cmdLine, lambda));
+ lmp.setLearningRate(getDoubleArgument(cmdLine, rate));
+
+ TrainLogistic.scores = getBooleanArgument(cmdLine, scores);
+ TrainLogistic.passes = getIntegerArgument(cmdLine, passes);
+
+ return true;
+ }
+
+ // Small typed accessors around the CLI2 CommandLine values.
+ private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
+ return (String) cmdLine.getValue(inputFile);
+ }
+
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+ return cmdLine.hasOption(option);
+ }
+
+ private static int getIntegerArgument(CommandLine cmdLine, Option features) {
+ return Integer.parseInt((String) cmdLine.getValue(features));
+ }
+
+ private static double getDoubleArgument(CommandLine cmdLine, Option op) {
+ return Double.parseDouble((String) cmdLine.getValue(op));
+ }
+
+ public static OnlineLogisticRegression getModel() {
+ return model;
+ }
+
+ public static LogisticModelParameters getParameters() {
+ return lmp;
+ }
+
+ // Opens inputFile first as a classpath resource; if that name is not a
+ // resource (IllegalArgumentException), falls back to a filesystem path.
+ static BufferedReader open(String inputFile) throws IOException {
+ InputStream in;
+ try {
+ in = Resources.getResource(inputFile).openStream();
+ } catch (IllegalArgumentException e) {
+ in = new FileInputStream(new File(inputFile));
+ }
+ return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
new file mode 100644
index 0000000..632b32c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+import com.google.common.collect.Ordering;
+import org.apache.mahout.classifier.NewsgroupHelper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Reads and trains an adaptive logistic regression model on the 20 newsgroups data.
+ * The first command line argument gives the path of the directory holding the training
+ * data. The optional second argument, leakType, defines which classes of features to use.
+ * Importantly, leakType controls whether a synthetic date is injected into the data as
+ * a target leak and if so, how.
+ * <p/>
+ * The value of leakType % 3 determines whether the target leak is injected according to
+ * the following table:
+ * <p/>
+ * <table>
+ * <tr><td valign='top'>0</td><td>No leak injected</td></tr>
+ * <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format. This will be a single token and
+ * is a perfect target leak since each newsgroup is given a different month</td></tr>
+ * <tr><td valign='top'>2</td><td>Synthetic date injected in dd-MMM-yyyy HH:mm:ss format. The day varies
+ * and thus there are more leak symbols that need to be learned. Ultimately this is just
+ * as big a leak as case 1.</td></tr>
+ * </table>
+ * <p/>
+ * Leaktype also determines what other text will be indexed. If leakType is greater
+ * than or equal to 6, then neither headers nor text body will be used for features and the leak is the only
+ * source of data. If leakType is greater than or equal to 3, then subject words will be used as features.
+ * If leakType is less than 3, then both subject and body text will be used as features.
+ * <p/>
+ * A leakType of 0 gives no leak and all textual features.
+ * <p/>
+ * See the following table for a summary of commonly used values for leakType
+ * <p/>
+ * <table>
+ * <tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>0</td><td>no</td><td>yes</td><td>yes</td></tr>
+ * <tr><td>1</td><td>mmm-yyyy</td><td>yes</td><td>yes</td></tr>
+ * <tr><td>2</td><td>dd-mmm-yyyy</td><td>yes</td><td>yes</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>3</td><td>no</td><td>yes</td><td>no</td></tr>
+ * <tr><td>4</td><td>mmm-yyyy</td><td>yes</td><td>no</td></tr>
+ * <tr><td>5</td><td>dd-mmm-yyyy</td><td>yes</td><td>no</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>6</td><td>no</td><td>no</td><td>no</td></tr>
+ * <tr><td>7</td><td>mmm-yyyy</td><td>no</td><td>no</td></tr>
+ * <tr><td>8</td><td>dd-mmm-yyyy</td><td>no</td><td>no</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * </table>
+ */
+public final class TrainNewsGroups {
+
+ // Utility class; not instantiable.
+ private TrainNewsGroups() {
+ }
+
+ // args[0] = directory of training data (one subdirectory per newsgroup);
+ // args[1] (optional) = leakType; see the class javadoc for its encoding.
+ public static void main(String[] args) throws IOException {
+ File base = new File(args[0]);
+
+ // Corpus-wide word occurrence counts, reported at the end of the run.
+ Multiset<String> overallCounts = HashMultiset.create();
+
+ int leakType = 0;
+ if (args.length > 1) {
+ leakType = Integer.parseInt(args[1]);
+ }
+
+ Dictionary newsGroups = new Dictionary();
+
+ NewsgroupHelper helper = new NewsgroupHelper();
+ helper.getEncoder().setProbes(2);
+ AdaptiveLogisticRegression learningAlgorithm =
+ new AdaptiveLogisticRegression(20, NewsgroupHelper.FEATURES, new L1());
+ learningAlgorithm.setInterval(800);
+ learningAlgorithm.setAveragingWindow(500);
+
+ // Collect all training files, interning each newsgroup directory name as
+ // a target category id.
+ // NOTE(review): base.listFiles() returns null when args[0] is not a
+ // directory, which would NPE here — TODO confirm intended behavior.
+ List<File> files = new ArrayList<>();
+ for (File newsgroup : base.listFiles()) {
+ if (newsgroup.isDirectory()) {
+ newsGroups.intern(newsgroup.getName());
+ files.addAll(Arrays.asList(newsgroup.listFiles()));
+ }
+ }
+ Collections.shuffle(files);
+ System.out.println(files.size() + " training files");
+ SGDInfo info = new SGDInfo();
+
+ int k = 0;
+
+ for (File file : files) {
+ // The parent directory name is the target label for this document.
+ String ng = file.getParentFile().getName();
+ int actual = newsGroups.intern(ng);
+
+ Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
+ learningAlgorithm.train(actual, v);
+
+ k++;
+ State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
+
+ SGDHelper.analyzeState(info, leakType, k, best);
+ }
+ learningAlgorithm.close();
+ SGDHelper.dissect(leakType, newsGroups, learningAlgorithm, files, overallCounts);
+ System.out.println("exiting main");
+
+ // Persist the first model of the best learner found so far.
+ File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group.model");
+ ModelSerializer.writeBinary(modelFile.getAbsolutePath(),
+ learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
+
+ // Dump word counts in descending order, capped at just over 1000 rows.
+ List<Integer> counts = new ArrayList<>();
+ System.out.println("Word counts");
+ for (String count : overallCounts.elementSet()) {
+ counts.add(overallCounts.count(count));
+ }
+ Collections.sort(counts, Ordering.natural().reverse());
+ k = 0;
+ for (Integer count : counts) {
+ System.out.println(k + "\t" + count);
+ k++;
+ if (k > 1000) {
+ break;
+ }
+ }
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
new file mode 100644
index 0000000..7a74289
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.Locale;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.classifier.ConfusionMatrix;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/*
+ * Auc and averageLikelihood are always shown if possible. If the number of target values is more than 2,
+ * then the Auc and the entropy matrix are not shown, regardless of the showAuc and showEntropy values
+ * the user passes, because the current implementation only supports them for two-value targets.
+ * */
+public final class ValidateAdaptiveLogistic {
+
+ // Validation configuration captured by parseArgs().
+ private static String inputFile;
+ private static String modelFile;
+ private static String defaultCategory;
+ private static boolean showAuc;
+ private static boolean showScores;
+ private static boolean showConfusion;
+
+ // Utility class; not instantiable.
+ private ValidateAdaptiveLogistic() {
+ }
+
+ // Entry point: writes the validation report to stdout (UTF-8).
+ public static void main(String[] args) throws IOException {
+ mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+
+ // Loads a trained adaptive model, scores every line of the validation CSV,
+ // and prints log-likelihood summary stats plus (optionally) AUC, per-line
+ // scores, and a confusion matrix.
+ static void mainToOutput(String[] args, PrintWriter output) throws IOException {
+ if (parseArgs(args)) {
+ // Default report when no display flag is given: AUC + confusion matrix.
+ if (!showAuc && !showConfusion && !showScores) {
+ showAuc = true;
+ showConfusion = true;
+ }
+
+ Auc collector = null;
+ AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
+ .loadFromFile(new File(modelFile));
+ CsvRecordFactory csv = lmp.getCsvRecordFactory();
+ AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
+
+ // AUC is only collected for binary targets.
+ if (lmp.getTargetCategories().size() <= 2) {
+ collector = new Auc();
+ }
+
+ OnlineSummarizer slh = new OnlineSummarizer();
+ ConfusionMatrix cm = new ConfusionMatrix(lmp.getTargetCategories(), defaultCategory);
+
+ State<Wrapper, CrossFoldLearner> best = lr.getBest();
+ if (best == null) {
+ // NOTE(review): garbled message — should read "has not been trained properly".
+ output.println("AdaptiveLogisticRegression has not be trained probably.");
+ return;
+ }
+ CrossFoldLearner learner = best.getPayload().getLearner();
+
+ // NOTE(review): this reader is never closed; consider try-with-resources.
+ BufferedReader in = TrainLogistic.open(inputFile);
+ String line = in.readLine();
+ csv.firstLine(line);
+ line = in.readLine();
+ if (showScores) {
+ output.println("\"target\", \"model-output\", \"log-likelihood\", \"average-likelihood\"");
+ }
+ while (line != null) {
+ Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
+ //TODO: How to avoid extra target values not shown in the training process.
+ int target = csv.processLine(line, v);
+ double likelihood = learner.logLikelihood(target, v);
+ double score = learner.classifyFull(v).maxValue();
+
+ slh.add(likelihood);
+ cm.addInstance(csv.getTargetString(line), csv.getTargetLabel(target));
+
+ if (showScores) {
+ output.printf(Locale.ENGLISH, "%8d, %.12f, %.13f, %.13f%n", target,
+ score, learner.logLikelihood(target, v), slh.getMean());
+ }
+ if (collector != null) {
+ collector.add(target, score);
+ }
+ line = in.readLine();
+ }
+
+ // NOTE(review): "\n" instead of "%n", and no newline after the label, so
+ // the stats below print on the same line as "Log-likelihood:" — TODO
+ // confirm whether that layout is intended.
+ output.printf(Locale.ENGLISH,"\nLog-likelihood:");
+ output.printf(Locale.ENGLISH, "Min=%.2f, Max=%.2f, Mean=%.2f, Median=%.2f%n",
+ slh.getMin(), slh.getMax(), slh.getMean(), slh.getMedian());
+
+ if (collector != null) {
+ output.printf(Locale.ENGLISH, "%nAUC = %.2f%n", collector.auc());
+ }
+
+ if (showConfusion) {
+ output.printf(Locale.ENGLISH, "%n%s%n%n", cm.toString());
+
+ // The entropy matrix is only available when AUC was collected (binary target).
+ if (collector != null) {
+ Matrix m = collector.entropy();
+ output.printf(Locale.ENGLISH,
+ "Entropy Matrix: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0),
+ m.get(1, 0), m.get(0, 1), m.get(1, 1));
+ }
+ }
+
+ }
+ }
+
+ // Parses the command line into the static fields above; returns false when
+ // parsing failed or help was requested.
+ private static boolean parseArgs(String[] args) {
+ DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+ Option help = builder.withLongName("help")
+ .withDescription("print this list").create();
+
+ Option quiet = builder.withLongName("quiet")
+ .withDescription("be extra quiet").create();
+
+ Option auc = builder.withLongName("auc").withDescription("print AUC")
+ .create();
+ Option confusion = builder.withLongName("confusion")
+ .withDescription("print confusion matrix").create();
+
+ Option scores = builder.withLongName("scores")
+ .withDescription("print scores").create();
+
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+ Option inputFileOption = builder
+ .withLongName("input")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("input").withMaximum(1)
+ .create())
+ .withDescription("where to get validate data").create();
+
+ Option modelFileOption = builder
+ .withLongName("model")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("model").withMaximum(1)
+ .create())
+ .withDescription("where to get the trained model").create();
+
+ // NOTE(review): "Cagetory" is a typo for "Category" in this local name.
+ Option defaultCagetoryOption = builder
+ .withLongName("defaultCategory")
+ .withRequired(false)
+ .withArgument(
+ argumentBuilder.withName("defaultCategory").withMaximum(1).withDefault("unknown")
+ .create())
+ .withDescription("the default category value to use").create();
+
+ Group normalArgs = new GroupBuilder().withOption(help)
+ .withOption(quiet).withOption(auc).withOption(scores)
+ .withOption(confusion).withOption(inputFileOption)
+ .withOption(modelFileOption).withOption(defaultCagetoryOption).create();
+
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+ CommandLine cmdLine = parser.parseAndHelp(args);
+
+ if (cmdLine == null) {
+ return false;
+ }
+
+ inputFile = getStringArgument(cmdLine, inputFileOption);
+ modelFile = getStringArgument(cmdLine, modelFileOption);
+ defaultCategory = getStringArgument(cmdLine, defaultCagetoryOption);
+ showAuc = getBooleanArgument(cmdLine, auc);
+ showScores = getBooleanArgument(cmdLine, scores);
+ showConfusion = getBooleanArgument(cmdLine, confusion);
+
+ return true;
+ }
+
+ // Flag options are booleans: present = true.
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+ return cmdLine.hasOption(option);
+ }
+
+ private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
+ return (String) cmdLine.getValue(inputFile);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java
new file mode 100644
index 0000000..ab3c861
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd.bankmarketing;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.apache.mahout.classifier.sgd.L1;
+import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Uses the SGD classifier on the 'Bank marketing' dataset from UCI.
+ *
+ * See http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
+ *
+ * Learn when people accept or reject an offer from the bank via telephone based on income, age, education and more.
+ */
+public class BankMarketingClassificationMain {
+
+ // Binary target: offer accepted or rejected.
+ public static final int NUM_CATEGORIES = 2;
+
+ // Runs 20 independent train/evaluate rounds, each with a fresh random
+ // 90/10 train/held-out split, printing learning rate and AUC every 5 passes.
+ public static void main(String[] args) throws Exception {
+ List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));
+
+ double heldOutPercentage = 0.10;
+
+ for (int run = 0; run < 20; run++) {
+ // New random split per run; first 10% is held out for evaluation.
+ Collections.shuffle(calls);
+ int cutoff = (int) (heldOutPercentage * calls.size());
+ List<TelephoneCall> test = calls.subList(0, cutoff);
+ List<TelephoneCall> train = calls.subList(cutoff, calls.size());
+
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES, new L1())
+ .learningRate(1)
+ .alpha(1)
+ .lambda(0.000001)
+ .stepOffset(10000)
+ .decayExponent(0.2);
+ for (int pass = 0; pass < 20; pass++) {
+ for (TelephoneCall observation : train) {
+ lr.train(observation.getTarget(), observation.asVector());
+ }
+ // Evaluate AUC on the held-out split every 5 passes.
+ if (pass % 5 == 0) {
+ Auc eval = new Auc(0.5);
+ for (TelephoneCall testCall : test) {
+ eval.add(testCall.getTarget(), lr.classifyScalar(testCall.asVector()));
+ }
+ // NOTE(review): printf format uses "\n"; "%n" is the conventional
+ // platform-independent line terminator for format strings.
+ System.out.printf("%d, %.4f, %.4f\n", pass, lr.currentLearningRate(), eval.auc());
+ }
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java
new file mode 100644
index 0000000..728ec20
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd.bankmarketing;
+
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
/**
 * One telephone marketing call from the UCI bank marketing data set, vectorized
 * with Mahout's hashed feature encoders for online logistic regression.
 */
public class TelephoneCall {
  /** Cardinality of the hashed feature vector. */
  public static final int FEATURES = 100;
  // Encoders are shared across all instances.  NOTE(review): assumed safe for this
  // single-threaded example -- confirm thread-safety before concurrent use.
  private static final ConstantValueEncoder interceptEncoder = new ConstantValueEncoder("intercept");
  private static final FeatureVectorEncoder featureEncoder = new StaticWordValueEncoder("feature");

  // Hashed feature vector for this record, built once in the constructor.
  private RandomAccessSparseVector vector;

  // Raw field name -> raw string value, in file order; getTarget() reads "y" from here.
  private Map<String, String> fields = new LinkedHashMap<>();

  /**
   * Builds the feature vector by pairing each header name with the matching value.
   *
   * @param fieldNames header names, in file order
   * @param values     this record's raw values, aligned with {@code fieldNames}
   * @throws IllegalArgumentException if a field name is not recognized
   */
  public TelephoneCall(Iterable<String> fieldNames, Iterable<String> values) {
    vector = new RandomAccessSparseVector(FEATURES);
    Iterator<String> value = values.iterator();
    // Constant bias term present in every vector.
    interceptEncoder.addToVector("1", vector);
    for (String name : fieldNames) {
      String fieldValue = value.next();
      fields.put(name, fieldValue);

      switch (name) {
        case "age": {
          // Log-scale age to compress the upper range.
          double v = Double.parseDouble(fieldValue);
          featureEncoder.addToVector(name, Math.log(v), vector);
          break;
        }
        case "balance": {
          // Clamp below -2000, then log-shift so the argument stays positive; the
          // -8 offset roughly centers the result (magic constants from the example).
          double v;
          v = Double.parseDouble(fieldValue);
          if (v < -2000) {
            v = -2000;
          }
          featureEncoder.addToVector(name, Math.log(v + 2001) - 8, vector);
          break;
        }
        case "duration": {
          // Log-shift call duration; -5 offset roughly centers it.
          double v;
          v = Double.parseDouble(fieldValue);
          featureEncoder.addToVector(name, Math.log(v + 1) - 5, vector);
          break;
        }
        case "pdays": {
          // pdays may be -1 ("not previously contacted"); +2 keeps the log defined.
          double v;
          v = Double.parseDouble(fieldValue);
          featureEncoder.addToVector(name, Math.log(v + 2), vector);
          break;
        }
        // Categorical fields: encode "name:value" as a one-hot hashed feature.
        // NOTE(review): "campaign" and "previous" are numeric in the data set but
        // are encoded categorically here -- presumably intentional; confirm.
        case "job":
        case "marital":
        case "education":
        case "default":
        case "housing":
        case "loan":
        case "contact":
        case "campaign":
        case "previous":
        case "poutcome":
          featureEncoder.addToVector(name + ":" + fieldValue, 1, vector);
          break;
        case "day":
        case "month":
        case "y":
          // ignore these for vectorizing ("y" is the target, read in getTarget())
          break;
        default:
          throw new IllegalArgumentException(String.format("Bad field name: %s", name));
      }
    }
  }

  /** @return the hashed feature vector built from this record */
  public Vector asVector() {
    return vector;
  }

  /** @return 0 when the "y" field is "no" (offer rejected), 1 otherwise */
  public int getTarget() {
    return fields.get("y").equals("no") ? 0 : 1;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java
new file mode 100644
index 0000000..5ef6490
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd.bankmarketing;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Splitter;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.io.Resources;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.Iterator;
+
+/** Parses semi-colon separated data as TelephoneCalls */
+public class TelephoneCallParser implements Iterable<TelephoneCall> {
+
+ private final Splitter onSemi = Splitter.on(";").trimResults(CharMatcher.anyOf("\" ;"));
+ private String resourceName;
+
+ public TelephoneCallParser(String resourceName) throws IOException {
+ this.resourceName = resourceName;
+ }
+
+ @Override
+ public Iterator<TelephoneCall> iterator() {
+ try {
+ return new AbstractIterator<TelephoneCall>() {
+ BufferedReader input =
+ new BufferedReader(new InputStreamReader(Resources.getResource(resourceName).openStream()));
+ Iterable<String> fieldNames = onSemi.split(input.readLine());
+
+ @Override
+ protected TelephoneCall computeNext() {
+ try {
+ String line = input.readLine();
+ if (line == null) {
+ return endOfData();
+ }
+
+ return new TelephoneCall(fieldNames, onSemi.split(line));
+ } catch (IOException e) {
+ throw new RuntimeException("Error reading data", e);
+ }
+ }
+ };
+ } catch (IOException e) {
+ throw new RuntimeException("Error reading data", e);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java
new file mode 100644
index 0000000..a0b845f
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.display;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+
+final class ClustersFilter implements PathFilter {
+
+ @Override
+ public boolean accept(Path path) {
+ String pathString = path.toString();
+ return pathString.contains("/clusters-");
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
new file mode 100644
index 0000000..50dba99
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.display;
+
+import java.awt.BasicStroke;
+import java.awt.Color;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+
+/**
+ * Java desktop graphics class that runs canopy clustering and displays the results.
+ * This class generates random data and clusters it.
+ */
+@Deprecated
+public class DisplayCanopy extends DisplayClustering {
+
+ DisplayCanopy() {
+ initialize();
+ this.setTitle("Canopy Clusters (>" + (int) (significance * 100) + "% of population)");
+ }
+
+ @Override
+ public void paint(Graphics g) {
+ plotSampleData((Graphics2D) g);
+ plotClusters((Graphics2D) g);
+ }
+
+ protected static void plotClusters(Graphics2D g2) {
+ int cx = CLUSTERS.size() - 1;
+ for (List<Cluster> clusters : CLUSTERS) {
+ for (Cluster cluster : clusters) {
+ if (isSignificant(cluster)) {
+ g2.setStroke(new BasicStroke(1));
+ g2.setColor(Color.BLUE);
+ double[] t1 = {T1, T1};
+ plotEllipse(g2, cluster.getCenter(), new DenseVector(t1));
+ double[] t2 = {T2, T2};
+ plotEllipse(g2, cluster.getCenter(), new DenseVector(t2));
+ g2.setColor(COLORS[Math.min(DisplayClustering.COLORS.length - 1, cx)]);
+ g2.setStroke(new BasicStroke(cx == 0 ? 3 : 1));
+ plotEllipse(g2, cluster.getCenter(), cluster.getRadius().times(3));
+ }
+ }
+ cx--;
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ Path samples = new Path("samples");
+ Path output = new Path("output");
+ Configuration conf = new Configuration();
+ HadoopUtil.delete(conf, samples);
+ HadoopUtil.delete(conf, output);
+ RandomUtils.useTestSeed();
+ generateSamples();
+ writeSampleData(samples);
+ CanopyDriver.buildClusters(conf, samples, output, new ManhattanDistanceMeasure(), T1, T2, 0, true);
+ loadClustersWritable(output);
+
+ new DisplayCanopy();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
new file mode 100644
index 0000000..ad85c6a
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.display;
+
+import java.awt.*;
+import java.awt.event.WindowAdapter;
+import java.awt.event.WindowEvent;
+import java.awt.geom.AffineTransform;
+import java.awt.geom.Ellipse2D;
+import java.awt.geom.Rectangle2D;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.UncommonDistributions;
+import org.apache.mahout.clustering.classify.WeightedVectorWritable;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class DisplayClustering extends Frame {
+
+ private static final Logger log = LoggerFactory.getLogger(DisplayClustering.class);
+
+ protected static final int DS = 72; // default scale = 72 pixels per inch
+
+ protected static final int SIZE = 8; // screen size in inches
+
+ private static final Collection<Vector> SAMPLE_PARAMS = new ArrayList<>();
+
+ protected static final List<VectorWritable> SAMPLE_DATA = new ArrayList<>();
+
+ protected static final List<List<Cluster>> CLUSTERS = new ArrayList<>();
+
+ static final Color[] COLORS = { Color.red, Color.orange, Color.yellow, Color.green, Color.blue, Color.magenta,
+ Color.lightGray };
+
+ protected static final double T1 = 3.0;
+
+ protected static final double T2 = 2.8;
+
+ static double significance = 0.05;
+
+ protected static int res; // screen resolution
+
+ public DisplayClustering() {
+ initialize();
+ this.setTitle("Sample Data");
+ }
+
+ public void initialize() {
+ // Get screen resolution
+ res = Toolkit.getDefaultToolkit().getScreenResolution();
+
+ // Set Frame size in inches
+ this.setSize(SIZE * res, SIZE * res);
+ this.setVisible(true);
+ this.setTitle("Asymmetric Sample Data");
+
+ // Window listener to terminate program.
+ this.addWindowListener(new WindowAdapter() {
+ @Override
+ public void windowClosing(WindowEvent e) {
+ System.exit(0);
+ }
+ });
+ }
+
+ public static void main(String[] args) throws Exception {
+ RandomUtils.useTestSeed();
+ generateSamples();
+ new DisplayClustering();
+ }
+
+ // Override the paint() method
+ @Override
+ public void paint(Graphics g) {
+ Graphics2D g2 = (Graphics2D) g;
+ plotSampleData(g2);
+ plotSampleParameters(g2);
+ plotClusters(g2);
+ }
+
+ protected static void plotClusters(Graphics2D g2) {
+ int cx = CLUSTERS.size() - 1;
+ for (List<Cluster> clusters : CLUSTERS) {
+ g2.setStroke(new BasicStroke(cx == 0 ? 3 : 1));
+ g2.setColor(COLORS[Math.min(COLORS.length - 1, cx--)]);
+ for (Cluster cluster : clusters) {
+ plotEllipse(g2, cluster.getCenter(), cluster.getRadius().times(3));
+ }
+ }
+ }
+
+ protected static void plotSampleParameters(Graphics2D g2) {
+ Vector v = new DenseVector(2);
+ Vector dv = new DenseVector(2);
+ g2.setColor(Color.RED);
+ for (Vector param : SAMPLE_PARAMS) {
+ v.set(0, param.get(0));
+ v.set(1, param.get(1));
+ dv.set(0, param.get(2) * 3);
+ dv.set(1, param.get(3) * 3);
+ plotEllipse(g2, v, dv);
+ }
+ }
+
+ protected static void plotSampleData(Graphics2D g2) {
+ double sx = (double) res / DS;
+ g2.setTransform(AffineTransform.getScaleInstance(sx, sx));
+
+ // plot the axes
+ g2.setColor(Color.BLACK);
+ Vector dv = new DenseVector(2).assign(SIZE / 2.0);
+ plotRectangle(g2, new DenseVector(2).assign(2), dv);
+ plotRectangle(g2, new DenseVector(2).assign(-2), dv);
+
+ // plot the sample data
+ g2.setColor(Color.DARK_GRAY);
+ dv.assign(0.03);
+ for (VectorWritable v : SAMPLE_DATA) {
+ plotRectangle(g2, v.get(), dv);
+ }
+ }
+
+ /**
+ * This method plots points and colors them according to their cluster
+ * membership, rather than drawing ellipses.
+ *
+ * As of commit, this method is used only by K-means spectral clustering.
+ * Since the cluster assignments are set within the eigenspace of the data, it
+ * is not inherent that the original data cluster as they would in K-means:
+ * that is, as symmetric gaussian mixtures.
+ *
+ * Since Spectral K-Means uses K-Means to cluster the eigenspace data, the raw
+ * output is not directly usable. Rather, the cluster assignments from the raw
+ * output need to be transferred back to the original data. As such, this
+ * method will read the SequenceFile cluster results of K-means and transfer
+ * the cluster assignments to the original data, coloring them appropriately.
+ *
+ * @param g2
+ * @param data
+ */
+ protected static void plotClusteredSampleData(Graphics2D g2, Path data) {
+ double sx = (double) res / DS;
+ g2.setTransform(AffineTransform.getScaleInstance(sx, sx));
+
+ g2.setColor(Color.BLACK);
+ Vector dv = new DenseVector(2).assign(SIZE / 2.0);
+ plotRectangle(g2, new DenseVector(2).assign(2), dv);
+ plotRectangle(g2, new DenseVector(2).assign(-2), dv);
+
+ // plot the sample data, colored according to the cluster they belong to
+ dv.assign(0.03);
+
+ Path clusteredPointsPath = new Path(data, "clusteredPoints");
+ Path inputPath = new Path(clusteredPointsPath, "part-m-00000");
+ Map<Integer,Color> colors = new HashMap<>();
+ int point = 0;
+ for (Pair<IntWritable,WeightedVectorWritable> record : new SequenceFileIterable<IntWritable,WeightedVectorWritable>(
+ inputPath, new Configuration())) {
+ int clusterId = record.getFirst().get();
+ VectorWritable v = SAMPLE_DATA.get(point++);
+ Integer key = clusterId;
+ if (!colors.containsKey(key)) {
+ colors.put(key, COLORS[Math.min(COLORS.length - 1, colors.size())]);
+ }
+ plotClusteredRectangle(g2, v.get(), dv, colors.get(key));
+ }
+ }
+
+ /**
+ * Identical to plotRectangle(), but with the option of setting the color of
+ * the rectangle's stroke.
+ *
+ * NOTE: This should probably be refactored with plotRectangle() since most of
+ * the code here is direct copy/paste from that method.
+ *
+ * @param g2
+ * A Graphics2D context.
+ * @param v
+ * A vector for the rectangle's center.
+ * @param dv
+ * A vector for the rectangle's dimensions.
+ * @param color
+ * The color of the rectangle's stroke.
+ */
+ protected static void plotClusteredRectangle(Graphics2D g2, Vector v, Vector dv, Color color) {
+ double[] flip = {1, -1};
+ Vector v2 = v.times(new DenseVector(flip));
+ v2 = v2.minus(dv.divide(2));
+ int h = SIZE / 2;
+ double x = v2.get(0) + h;
+ double y = v2.get(1) + h;
+
+ g2.setStroke(new BasicStroke(1));
+ g2.setColor(color);
+ g2.draw(new Rectangle2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
+ }
+
+ /**
+ * Draw a rectangle on the graphics context
+ *
+ * @param g2
+ * a Graphics2D context
+ * @param v
+ * a Vector of rectangle center
+ * @param dv
+ * a Vector of rectangle dimensions
+ */
+ protected static void plotRectangle(Graphics2D g2, Vector v, Vector dv) {
+ double[] flip = {1, -1};
+ Vector v2 = v.times(new DenseVector(flip));
+ v2 = v2.minus(dv.divide(2));
+ int h = SIZE / 2;
+ double x = v2.get(0) + h;
+ double y = v2.get(1) + h;
+ g2.draw(new Rectangle2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
+ }
+
+ /**
+ * Draw an ellipse on the graphics context
+ *
+ * @param g2
+ * a Graphics2D context
+ * @param v
+ * a Vector of ellipse center
+ * @param dv
+ * a Vector of ellipse dimensions
+ */
+ protected static void plotEllipse(Graphics2D g2, Vector v, Vector dv) {
+ double[] flip = {1, -1};
+ Vector v2 = v.times(new DenseVector(flip));
+ v2 = v2.minus(dv.divide(2));
+ int h = SIZE / 2;
+ double x = v2.get(0) + h;
+ double y = v2.get(1) + h;
+ g2.draw(new Ellipse2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
+ }
+
+ protected static void generateSamples() {
+ generateSamples(500, 1, 1, 3);
+ generateSamples(300, 1, 0, 0.5);
+ generateSamples(300, 0, 2, 0.1);
+ }
+
+ protected static void generate2dSamples() {
+ generate2dSamples(500, 1, 1, 3, 1);
+ generate2dSamples(300, 1, 0, 0.5, 1);
+ generate2dSamples(300, 0, 2, 0.1, 0.5);
+ }
+
+ /**
+ * Generate random samples and add them to the sampleData
+ *
+ * @param num
+ * int number of samples to generate
+ * @param mx
+ * double x-value of the sample mean
+ * @param my
+ * double y-value of the sample mean
+ * @param sd
+ * double standard deviation of the samples
+ */
+ protected static void generateSamples(int num, double mx, double my, double sd) {
+ double[] params = {mx, my, sd, sd};
+ SAMPLE_PARAMS.add(new DenseVector(params));
+ log.info("Generating {} samples m=[{}, {}] sd={}", num, mx, my, sd);
+ for (int i = 0; i < num; i++) {
+ SAMPLE_DATA.add(new VectorWritable(new DenseVector(new double[] {UncommonDistributions.rNorm(mx, sd),
+ UncommonDistributions.rNorm(my, sd)})));
+ }
+ }
+
+ protected static void writeSampleData(Path output) throws IOException {
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(output.toUri(), conf);
+
+ try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, output, Text.class, VectorWritable.class)) {
+ int i = 0;
+ for (VectorWritable vw : SAMPLE_DATA) {
+ writer.append(new Text("sample_" + i++), vw);
+ }
+ }
+ }
+
+ protected static List<Cluster> readClustersWritable(Path clustersIn) {
+ List<Cluster> clusters = new ArrayList<>();
+ Configuration conf = new Configuration();
+ for (ClusterWritable value : new SequenceFileDirValueIterable<ClusterWritable>(clustersIn, PathType.LIST,
+ PathFilters.logsCRCFilter(), conf)) {
+ Cluster cluster = value.getValue();
+ log.info(
+ "Reading Cluster:{} center:{} numPoints:{} radius:{}",
+ cluster.getId(), AbstractCluster.formatVector(cluster.getCenter(), null),
+ cluster.getNumObservations(), AbstractCluster.formatVector(cluster.getRadius(), null));
+ clusters.add(cluster);
+ }
+ return clusters;
+ }
+
+ protected static void loadClustersWritable(Path output) throws IOException {
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(output.toUri(), conf);
+ for (FileStatus s : fs.listStatus(output, new ClustersFilter())) {
+ List<Cluster> clusters = readClustersWritable(s.getPath());
+ CLUSTERS.add(clusters);
+ }
+ }
+
+ /**
+ * Generate random samples and add them to the sampleData
+ *
+ * @param num
+ * int number of samples to generate
+ * @param mx
+ * double x-value of the sample mean
+ * @param my
+ * double y-value of the sample mean
+ * @param sdx
+ * double x-value standard deviation of the samples
+ * @param sdy
+ * double y-value standard deviation of the samples
+ */
+ protected static void generate2dSamples(int num, double mx, double my, double sdx, double sdy) {
+ double[] params = {mx, my, sdx, sdy};
+ SAMPLE_PARAMS.add(new DenseVector(params));
+ log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy);
+ for (int i = 0; i < num; i++) {
+ SAMPLE_DATA.add(new VectorWritable(new DenseVector(new double[] {UncommonDistributions.rNorm(mx, sdx),
+ UncommonDistributions.rNorm(my, sdy)})));
+ }
+ }
+
+ protected static boolean isSignificant(Cluster cluster) {
+ return (double) cluster.getNumObservations() / SAMPLE_DATA.size() > significance;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
new file mode 100644
index 0000000..f8ce7c7
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
@@ -0,0 +1,110 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.display;
+
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
+import org.apache.mahout.clustering.iterator.ClusterIterator;
+import org.apache.mahout.clustering.iterator.FuzzyKMeansClusteringPolicy;
+import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.collect.Lists;
+
/**
 * Runs fuzzy k-means over randomly generated sample data and displays the
 * resulting clusters in an AWT window.
 */
public class DisplayFuzzyKMeans extends DisplayClustering {

  /** Builds the frame and titles it with the significance cutoff used for display. */
  DisplayFuzzyKMeans() {
    initialize();
    this.setTitle("Fuzzy k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
  }

  // Override the paint() method
  @Override
  public void paint(Graphics g) {
    plotSampleData((Graphics2D) g);
    plotClusters((Graphics2D) g);
  }

  public static void main(String[] args) throws Exception {
    DistanceMeasure measure = new ManhattanDistanceMeasure();

    Path samples = new Path("samples");
    Path output = new Path("output");
    Configuration conf = new Configuration();
    // Clear any output from a previous run.
    HadoopUtil.delete(conf, output);
    HadoopUtil.delete(conf, samples);
    RandomUtils.useTestSeed();
    DisplayClustering.generateSamples();
    writeSampleData(samples);
    // Compile-time toggle: the clusterer path (random seeds + FuzzyKMeansDriver)
    // is active; flip runClusterer to false to exercise the classifier path.
    boolean runClusterer = true;
    int maxIterations = 10;
    float threshold = 0.001F;  // convergence delta
    float m = 1.1F;            // fuzziness exponent (> 1)
    if (runClusterer) {
      runSequentialFuzzyKClusterer(conf, samples, output, measure, maxIterations, m, threshold);
    } else {
      int numClusters = 3;
      runSequentialFuzzyKClassifier(conf, samples, output, measure, numClusters, maxIterations, m, threshold);
    }
    new DisplayFuzzyKMeans();
  }

  /**
   * Seeds soft clusters from the first numClusters sample points, then iterates a
   * ClusterClassifier sequentially and loads the resulting cluster generations.
   */
  private static void runSequentialFuzzyKClassifier(Configuration conf, Path samples, Path output,
      DistanceMeasure measure, int numClusters, int maxIterations, float m, double threshold) throws IOException {
    Collection<Vector> points = Lists.newArrayList();
    for (int i = 0; i < numClusters; i++) {
      points.add(SAMPLE_DATA.get(i).get());
    }
    List<Cluster> initialClusters = Lists.newArrayList();
    int id = 0;
    for (Vector point : points) {
      initialClusters.add(new SoftCluster(point, id++, measure));
    }
    ClusterClassifier prior = new ClusterClassifier(initialClusters, new FuzzyKMeansClusteringPolicy(m, threshold));
    Path priorPath = new Path(output, "classifier-0");
    prior.writeToSeqFiles(priorPath);

    ClusterIterator.iterateSeq(conf, samples, priorPath, output, maxIterations);
    loadClustersWritable(output);
  }

  /**
   * Seeds 3 random clusters, runs the sequential FuzzyKMeansDriver, and loads the
   * resulting cluster generations for display.
   */
  private static void runSequentialFuzzyKClusterer(Configuration conf, Path samples, Path output,
      DistanceMeasure measure, int maxIterations, float m, double threshold) throws IOException,
      ClassNotFoundException, InterruptedException {
    Path clustersIn = new Path(output, "random-seeds");
    RandomSeedGenerator.buildRandom(conf, samples, clustersIn, 3, measure);
    // NOTE(review): threshold is passed both as the convergence delta and as a
    // later threshold argument -- presumably intentional; confirm against
    // FuzzyKMeansDriver.run's signature.
    FuzzyKMeansDriver.run(samples, clustersIn, output, threshold, maxIterations, m, true, true, threshold,
        true);

    loadClustersWritable(output);
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
new file mode 100644
index 0000000..336d69e
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.display;
+
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.iterator.ClusterIterator;
+import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Swing demo that clusters the shared DisplayClustering sample data with k-means and paints
+ * the samples plus the clusters from each iteration. Run {@link #main} with no arguments.
+ */
+public class DisplayKMeans extends DisplayClustering {
+
+ DisplayKMeans() {
+ initialize();
+ this.setTitle("k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
+ }
+
+ /**
+ * Generates sample data, runs sequential k-means over it, and opens the display window.
+ * Deletes any previous "samples"/"output" directories first so reruns start clean.
+ */
+ public static void main(String[] args) throws Exception {
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ Path samples = new Path("samples");
+ Path output = new Path("output");
+ Configuration conf = new Configuration();
+ HadoopUtil.delete(conf, samples);
+ HadoopUtil.delete(conf, output);
+
+ // Fixed seed so the demo produces the same sample layout on every run.
+ RandomUtils.useTestSeed();
+ generateSamples();
+ writeSampleData(samples);
+ // Edit this constant-true toggle to exercise the classifier path instead of the driver path.
+ boolean runClusterer = true;
+ double convergenceDelta = 0.001;
+ int numClusters = 3;
+ int maxIterations = 10;
+ if (runClusterer) {
+ runSequentialKMeansClusterer(conf, samples, output, measure, numClusters, maxIterations, convergenceDelta);
+ } else {
+ runSequentialKMeansClassifier(conf, samples, output, measure, numClusters, maxIterations, convergenceDelta);
+ }
+ new DisplayKMeans();
+ }
+
+ /**
+ * Classifier-API variant: seeds one Kluster per cluster from the head of SAMPLE_DATA,
+ * writes the prior to disk, iterates sequentially, and loads the results for display.
+ */
+ private static void runSequentialKMeansClassifier(Configuration conf, Path samples, Path output,
+ DistanceMeasure measure, int numClusters, int maxIterations, double convergenceDelta) throws IOException {
+ Collection<Vector> points = Lists.newArrayList();
+ for (int i = 0; i < numClusters; i++) {
+ points.add(SAMPLE_DATA.get(i).get());
+ }
+ List<Cluster> initialClusters = Lists.newArrayList();
+ int id = 0;
+ for (Vector point : points) {
+ initialClusters.add(new org.apache.mahout.clustering.kmeans.Kluster(point, id++, measure));
+ }
+ ClusterClassifier prior = new ClusterClassifier(initialClusters, new KMeansClusteringPolicy(convergenceDelta));
+ Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
+ prior.writeToSeqFiles(priorPath);
+
+ ClusterIterator.iterateSeq(conf, samples, priorPath, output, maxIterations);
+ loadClustersWritable(output);
+ }
+
+ /**
+ * Driver variant: random-seeds the initial clusters, runs KMeansDriver sequentially,
+ * and loads the results for display.
+ */
+ private static void runSequentialKMeansClusterer(Configuration conf, Path samples, Path output,
+ DistanceMeasure measure, int numClusters, int maxIterations, double convergenceDelta)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ Path clustersIn = new Path(output, "random-seeds");
+ RandomSeedGenerator.buildRandom(conf, samples, clustersIn, numClusters, measure);
+ KMeansDriver.run(samples, clustersIn, output, convergenceDelta, maxIterations, true, 0.0, true);
+ loadClustersWritable(output);
+ }
+
+ // Override the paint() method
+ @Override
+ public void paint(Graphics g) {
+ plotSampleData((Graphics2D) g);
+ plotClusters((Graphics2D) g);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java
new file mode 100644
index 0000000..2b70749
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.display;
+
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.io.BufferedWriter;
+import java.io.FileWriter;
+import java.io.Writer;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.spectral.kmeans.SpectralKMeansDriver;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+
+/**
+ * Swing demo that clusters the shared DisplayClustering sample data with spectral k-means.
+ * Builds a dense pairwise-distance affinity file over all samples, runs SpectralKMeansDriver,
+ * and paints the clustered output. Run {@link #main} with no arguments.
+ */
+public class DisplaySpectralKMeans extends DisplayClustering {
+
+ protected static final String SAMPLES = "samples";
+ protected static final String OUTPUT = "output";
+ protected static final String TEMP = "tmp";
+ protected static final String AFFINITIES = "affinities";
+
+ DisplaySpectralKMeans() {
+ initialize();
+ setTitle("Spectral k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
+ }
+
+ public static void main(String[] args) throws Exception {
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ Path samples = new Path(SAMPLES);
+ Path output = new Path(OUTPUT);
+ Path tempDir = new Path(TEMP);
+ Configuration conf = new Configuration();
+ HadoopUtil.delete(conf, samples);
+ HadoopUtil.delete(conf, output);
+
+ // Fixed seed so the demo produces the same sample layout on every run.
+ RandomUtils.useTestSeed();
+ DisplayClustering.generateSamples();
+ writeSampleData(samples);
+ Path affinities = new Path(output, AFFINITIES);
+ FileSystem fs = FileSystem.get(output.toUri(), conf);
+ if (!fs.exists(output)) {
+ fs.mkdirs(output);
+ }
+
+ // NOTE(review): the affinity file is written with java.io.FileWriter (local filesystem,
+ // platform default charset) even though the directory above was created through the Hadoop
+ // FileSystem API — this presumably only works when the default FS is local; verify.
+ // Each line is "i,j,distance" — a full O(n^2) pairwise affinity matrix, including i == j.
+ try (Writer writer = new BufferedWriter(new FileWriter(affinities.toString()))){
+ for (int i = 0; i < SAMPLE_DATA.size(); i++) {
+ for (int j = 0; j < SAMPLE_DATA.size(); j++) {
+ writer.write(i + "," + j + ',' + measure.distance(SAMPLE_DATA.get(i).get(),
+ SAMPLE_DATA.get(j).get()) + '\n');
+ }
+ }
+ }
+
+ int maxIter = 10;
+ double convergenceDelta = 0.001;
+ // 3 = number of clusters requested from the spectral driver.
+ SpectralKMeansDriver.run(new Configuration(), affinities, output, SAMPLE_DATA.size(), 3, measure,
+ convergenceDelta, maxIter, tempDir);
+ new DisplaySpectralKMeans();
+ }
+
+ @Override
+ public void paint(Graphics g) {
+ // The spectral driver writes its final k-means result under OUTPUT/kmeans_out.
+ plotClusteredSampleData((Graphics2D) g, new Path(new Path(OUTPUT), "kmeans_out"));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/README.txt
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/README.txt b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/README.txt
new file mode 100644
index 0000000..470c16c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/display/README.txt
@@ -0,0 +1,22 @@
+The following classes can be run without parameters to generate a sample data set and
+run the reference clustering implementations over them:
+
+DisplayClustering - generates 1000 samples from three, symmetric distributions. This is the same
+ data set that is used by the following clustering programs. It displays the points on a screen
+ and superimposes the model parameters that were used to generate the points. You can edit the
+ generateSamples() method to change the sample points used by these programs.
+
+ * DisplayCanopy - uses Canopy clustering
+ * DisplayKMeans - uses k-Means clustering
+ * DisplayFuzzyKMeans - uses Fuzzy k-Means clustering
+
+ * NOTE: some of these programs display the sample points and then superimpose all of the clusters
+ from each iteration. The last iteration's clusters are in bold red and the previous several are
+ colored (orange, yellow, green, blue, violet) in order after which all earlier clusters are in
+ light grey. This helps to visualize how the clusters converge upon a solution over multiple
+ iterations.
+ * NOTE: by changing the parameter values (k, ALPHA_0, numIterations) and the display SIGNIFICANCE
+ you can obtain different results.
+
+
+
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java
new file mode 100644
index 0000000..c29cbc4
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.tools;
+
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.List;
+
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/**
+ * Command-line tool that summarizes the quality of one (or two, for comparison) sets of
+ * cluster centroids against train/test vector data. Writes per-cluster distance statistics
+ * as CSV to the --output file and prints Dunn and Davies-Bouldin indexes to stdout.
+ */
+public class ClusterQualitySummarizer extends AbstractJob {
+ private String outputFile;
+
+ private PrintWriter fileOut;
+
+ private String trainFile;
+ private String testFile;
+ private String centroidFile;
+ private String centroidCompareFile;
+ private boolean mahoutKMeansFormat;
+ private boolean mahoutKMeansFormatCompare;
+
+ private DistanceMeasure distanceMeasure = new SquaredEuclideanDistanceMeasure();
+
+ /** Convenience overload that writes to this job's CSV output file. */
+ public void printSummaries(List<OnlineSummarizer> summarizers, String type) {
+ printSummaries(summarizers, type, fileOut);
+ }
+
+ /**
+ * Prints per-cluster distance statistics to stdout and, when {@code fileOut} is non-null,
+ * appends one CSV row per cluster (mean, sd, quartiles 0-4, count, type tag).
+ * Clusters with fewer than 2 points are skipped: OnlineSummarizer cannot estimate
+ * quartiles from a single observation.
+ *
+ * @param summarizers one summarizer per cluster, indexed by cluster id
+ * @param type        tag written in the CSV "is.train" column ("train" or "test")
+ * @param fileOut     CSV sink, or null to print to stdout only
+ */
+ public static void printSummaries(List<OnlineSummarizer> summarizers, String type, PrintWriter fileOut) {
+ double maxDistance = 0;
+ for (int i = 0; i < summarizers.size(); ++i) {
+ OnlineSummarizer summarizer = summarizers.get(i);
+ if (summarizer.getCount() > 1) {
+ maxDistance = Math.max(maxDistance, summarizer.getMax());
+ System.out.printf("Average distance in cluster %d [%d]: %f\n", i, summarizer.getCount(), summarizer.getMean());
+ if (fileOut != null) {
+ fileOut.printf("%d,%f,%f,%f,%f,%f,%f,%f,%d,%s\n", i, summarizer.getMean(),
+ summarizer.getSD(),
+ summarizer.getQuartile(0),
+ summarizer.getQuartile(1),
+ summarizer.getQuartile(2),
+ summarizer.getQuartile(3),
+ summarizer.getQuartile(4), summarizer.getCount(), type);
+ }
+ } else {
+ // Fixed grammar of the original message ("is has ... atleast").
+ System.out.printf("Cluster %d has %d data point. Need at least 2 data points in a cluster for" +
+ " OnlineSummarizer.\n", i, summarizer.getCount());
+ }
+ }
+ System.out.printf("Num clusters: %d; maxDistance: %f\n", summarizers.size(), maxDistance);
+ }
+
+ /**
+ * Parses arguments, loads centroids and data points, writes CSV summaries, and prints
+ * Dunn / Davies-Bouldin indexes.
+ *
+ * @return 0 on success, -1 on bad arguments or I/O failure (the original returned 0
+ *         even after catching an IOException, which reported failures as success)
+ */
+ public int run(String[] args) throws IOException {
+ if (!parseArgs(args)) {
+ return -1;
+ }
+
+ Configuration conf = new Configuration();
+ try {
+ fileOut = new PrintWriter(new FileOutputStream(outputFile));
+ fileOut.printf("cluster,distance.mean,distance.sd,distance.q0,distance.q1,distance.q2,distance.q3,"
+ + "distance.q4,count,is.train\n");
+
+ // Reading in the centroids (both pairs, if they exist).
+ List<Centroid> centroids;
+ List<Centroid> centroidsCompare = null;
+ if (mahoutKMeansFormat) {
+ SequenceFileDirValueIterable<ClusterWritable> clusterIterable =
+ new SequenceFileDirValueIterable<>(new Path(centroidFile), PathType.GLOB, conf);
+ centroids = Lists.newArrayList(IOUtils.getCentroidsFromClusterWritableIterable(clusterIterable));
+ } else {
+ SequenceFileDirValueIterable<CentroidWritable> centroidIterable =
+ new SequenceFileDirValueIterable<>(new Path(centroidFile), PathType.GLOB, conf);
+ centroids = Lists.newArrayList(IOUtils.getCentroidsFromCentroidWritableIterable(centroidIterable));
+ }
+
+ if (centroidCompareFile != null) {
+ if (mahoutKMeansFormatCompare) {
+ SequenceFileDirValueIterable<ClusterWritable> clusterCompareIterable =
+ new SequenceFileDirValueIterable<>(new Path(centroidCompareFile), PathType.GLOB, conf);
+ centroidsCompare = Lists.newArrayList(
+ IOUtils.getCentroidsFromClusterWritableIterable(clusterCompareIterable));
+ } else {
+ SequenceFileDirValueIterable<CentroidWritable> centroidCompareIterable =
+ new SequenceFileDirValueIterable<>(new Path(centroidCompareFile), PathType.GLOB, conf);
+ centroidsCompare = Lists.newArrayList(
+ IOUtils.getCentroidsFromCentroidWritableIterable(centroidCompareIterable));
+ }
+ }
+
+ // Reading in the "training" set.
+ SequenceFileDirValueIterable<VectorWritable> trainIterable =
+ new SequenceFileDirValueIterable<>(new Path(trainFile), PathType.GLOB, conf);
+ Iterable<Vector> trainDatapoints = IOUtils.getVectorsFromVectorWritableIterable(trainIterable);
+ Iterable<Vector> datapoints = trainDatapoints;
+
+ printSummaries(ClusteringUtils.summarizeClusterDistances(trainDatapoints, centroids,
+ new SquaredEuclideanDistanceMeasure()), "train");
+
+ // Also adding in the "test" set.
+ if (testFile != null) {
+ SequenceFileDirValueIterable<VectorWritable> testIterable =
+ new SequenceFileDirValueIterable<>(new Path(testFile), PathType.GLOB, conf);
+ Iterable<Vector> testDatapoints = IOUtils.getVectorsFromVectorWritableIterable(testIterable);
+
+ printSummaries(ClusteringUtils.summarizeClusterDistances(testDatapoints, centroids,
+ new SquaredEuclideanDistanceMeasure()), "test");
+
+ datapoints = Iterables.concat(trainDatapoints, testDatapoints);
+ }
+
+ // At this point, all train/test CSVs have been written. We now compute quality metrics.
+ List<OnlineSummarizer> summaries =
+ ClusteringUtils.summarizeClusterDistances(datapoints, centroids, distanceMeasure);
+ List<OnlineSummarizer> compareSummaries = null;
+ if (centroidsCompare != null) {
+ compareSummaries = ClusteringUtils.summarizeClusterDistances(datapoints, centroidsCompare, distanceMeasure);
+ }
+ System.out.printf("[Dunn Index] First: %f", ClusteringUtils.dunnIndex(centroids, distanceMeasure, summaries));
+ if (compareSummaries != null) {
+ System.out.printf(" Second: %f\n", ClusteringUtils.dunnIndex(centroidsCompare, distanceMeasure, compareSummaries));
+ } else {
+ System.out.printf("\n");
+ }
+ System.out.printf("[Davies-Bouldin Index] First: %f",
+ ClusteringUtils.daviesBouldinIndex(centroids, distanceMeasure, summaries));
+ if (compareSummaries != null) {
+ System.out.printf(" Second: %f\n",
+ ClusteringUtils.daviesBouldinIndex(centroidsCompare, distanceMeasure, compareSummaries));
+ } else {
+ System.out.printf("\n");
+ }
+ } catch (IOException e) {
+ // Report the failure and signal it through the exit status instead of returning success.
+ System.out.println(e.getMessage());
+ return -1;
+ } finally {
+ Closeables.close(fileOut, false);
+ }
+ return 0;
+ }
+
+ /**
+ * Parses command-line options into this instance's fields.
+ *
+ * @return true when parsing succeeded, false when help was requested or parsing failed
+ */
+ private boolean parseArgs(String[] args) {
+ DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+ Option help = builder.withLongName("help").withDescription("print this list").create();
+
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+ Option inputFileOption = builder.withLongName("input")
+ .withShortName("i")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+ .withDescription("where to get seq files with the vectors (training set)")
+ .create();
+
+ Option testInputFileOption = builder.withLongName("testInput")
+ .withShortName("itest")
+ .withArgument(argumentBuilder.withName("testInput").withMaximum(1).create())
+ .withDescription("where to get seq files with the vectors (test set)")
+ .create();
+
+ Option centroidsFileOption = builder.withLongName("centroids")
+ .withShortName("c")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("centroids").withMaximum(1).create())
+ .withDescription("where to get seq files with the centroids (from Mahout KMeans or StreamingKMeansDriver)")
+ .create();
+
+ Option centroidsCompareFileOption = builder.withLongName("centroidsCompare")
+ .withShortName("cc")
+ .withRequired(false)
+ .withArgument(argumentBuilder.withName("centroidsCompare").withMaximum(1).create())
+ .withDescription("where to get seq files with the second set of centroids (from Mahout KMeans or "
+ + "StreamingKMeansDriver)")
+ .create();
+
+ Option outputFileOption = builder.withLongName("output")
+ .withShortName("o")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
+ .withDescription("where to dump the CSV file with the results")
+ .create();
+
+ Option mahoutKMeansFormatOption = builder.withLongName("mahoutkmeansformat")
+ .withShortName("mkm")
+ .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs")
+ .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create())
+ .create();
+
+ Option mahoutKMeansCompareFormatOption = builder.withLongName("mahoutkmeansformatCompare")
+ .withShortName("mkmc")
+ .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs")
+ .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create())
+ .create();
+
+ Group normalArgs = new GroupBuilder()
+ .withOption(help)
+ .withOption(inputFileOption)
+ .withOption(testInputFileOption)
+ .withOption(outputFileOption)
+ .withOption(centroidsFileOption)
+ .withOption(centroidsCompareFileOption)
+ .withOption(mahoutKMeansFormatOption)
+ .withOption(mahoutKMeansCompareFormatOption)
+ .create();
+
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 150));
+
+ CommandLine cmdLine = parser.parseAndHelp(args);
+ if (cmdLine == null) {
+ return false;
+ }
+
+ trainFile = (String) cmdLine.getValue(inputFileOption);
+ if (cmdLine.hasOption(testInputFileOption)) {
+ testFile = (String) cmdLine.getValue(testInputFileOption);
+ }
+ centroidFile = (String) cmdLine.getValue(centroidsFileOption);
+ if (cmdLine.hasOption(centroidsCompareFileOption)) {
+ centroidCompareFile = (String) cmdLine.getValue(centroidsCompareFileOption);
+ }
+ outputFile = (String) cmdLine.getValue(outputFileOption);
+ if (cmdLine.hasOption(mahoutKMeansFormatOption)) {
+ mahoutKMeansFormat = true;
+ }
+ if (cmdLine.hasOption(mahoutKMeansCompareFormatOption)) {
+ mahoutKMeansFormatCompare = true;
+ }
+ return true;
+ }
+
+ public static void main(String[] args) throws IOException {
+ new ClusterQualitySummarizer().run(args);
+ }
+}