You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/02/13 20:08:05 UTC
svn commit: r909871 [4/7] - in
/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout: analysis/
cf/taste/ejb/ cf/taste/example/ cf/taste/example/bookcrossing/
cf/taste/example/grouplens/ cf/taste/example/jester/
cf/taste/example/netflix/ classi...
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java Sat Feb 13 19:07:36 2010
@@ -17,6 +17,9 @@
package org.apache.mahout.df;
+import java.io.IOException;
+import java.util.Random;
+
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
@@ -45,111 +48,116 @@
import org.slf4j.LoggerFactory;
import org.uncommons.maths.Maths;
-import java.io.IOException;
-import java.util.Random;
-
/**
* Test procedure as described in Breiman's paper.<br>
* <b>Leo Breiman: Random Forests. Machine Learning 45(1): 5-32 (2001)</b>
*/
public class BreimanExample extends Configured implements Tool {
-
+
private static final Logger log = LoggerFactory.getLogger(BreimanExample.class);
-
+
/** sum test error */
private double sumTestErr;
-
+
/** sum mean tree error */
private double sumTreeErr;
-
+
/** sum test error with m=1 */
private double sumOneErr;
-
+
/** mean time to build a forest with m=log2(M)+1 */
private long sumTimeM;
-
+
/** mean time to build a forest with m=1 */
private long sumTimeOne;
-
+
/** mean number of nodes for all the trees grown with m=log2(M)+1 */
private long numNodesM;
-
+
/** mean number of nodes for all the trees grown with m=1 */
private long numNodesOne;
-
+
/**
* runs one iteration of the procedure.
- *
- * @param rng random numbers generator
- * @param data training data
- * @param m number of random variables to select at each tree-node
- * @param nbtrees number of trees to grow
- * @throws Exception if an error occured while growing the trees
+ *
+ * @param rng
+ * random numbers generator
+ * @param data
+ * training data
+ * @param m
+ * number of random variables to select at each tree-node
+ * @param nbtrees
+ * number of trees to grow
+ * @throws Exception
+ * if an error occured while growing the trees
*/
private void runIteration(Random rng, Data data, int m, int nbtrees) {
-
+
int nblabels = data.getDataset().nblabels();
-
- log.info("Splitting the data");
+
+ BreimanExample.log.info("Splitting the data");
Data train = data.clone();
Data test = train.rsplit(rng, (int) (data.size() * 0.1));
int[] trainLabels = train.extractLabels();
int[] testLabels = test.extractLabels();
-
+
DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
SequentialBuilder forestBuilder = new SequentialBuilder(rng, treeBuilder, train);
-
+
// grow a forest with m = log2(M)+1
- ForestPredictions errorM = new ForestPredictions(train.size(), nblabels); // oob error when using m = log2(M)+1
+ ForestPredictions errorM = new ForestPredictions(train.size(), nblabels); // oob error when using m =
+ // log2(M)+1
treeBuilder.setM(m);
-
+
long time = System.currentTimeMillis();
- log.info("Growing a forest with m={}", m);
+ BreimanExample.log.info("Growing a forest with m={}", m);
DecisionForest forestM = forestBuilder.build(nbtrees, errorM);
sumTimeM += System.currentTimeMillis() - time;
numNodesM += forestM.nbNodes();
-
- double oobM = ErrorEstimate.errorRate(trainLabels, errorM.computePredictions(rng)); // oob error estimate when m = log2(M)+1
-
+
+ double oobM = ErrorEstimate.errorRate(trainLabels, errorM.computePredictions(rng)); // oob error estimate
+ // when m = log2(M)+1
+
// grow a forest with m=1
ForestPredictions errorOne = new ForestPredictions(train.size(), nblabels); // oob error when using m = 1
treeBuilder.setM(1);
-
+
time = System.currentTimeMillis();
- log.info("Growing a forest with m=1");
+ BreimanExample.log.info("Growing a forest with m=1");
DecisionForest forestOne = forestBuilder.build(nbtrees, errorOne);
sumTimeOne += System.currentTimeMillis() - time;
numNodesOne += forestOne.nbNodes();
- double oobOne = ErrorEstimate.errorRate(trainLabels, errorOne.computePredictions(rng)); // oob error estimate when m = 1
-
+ double oobOne = ErrorEstimate.errorRate(trainLabels, errorOne.computePredictions(rng)); // oob error
+ // estimate when m
+ // = 1
+
// compute the test set error (Selection Error), and mean tree error (One Tree Error),
// using the lowest oob error forest
ForestPredictions testError = new ForestPredictions(test.size(), nblabels); // test set error
MeanTreeCollector treeError = new MeanTreeCollector(test, nbtrees); // mean tree error
-
+
// compute the test set error using m=1 (Single Input Error)
errorOne = new ForestPredictions(test.size(), nblabels);
-
+
if (oobM < oobOne) {
forestM.classify(test, new MultiCallback(testError, treeError));
forestOne.classify(test, errorOne);
} else {
- forestOne.classify(test,
- new MultiCallback(testError, treeError, errorOne));
+ forestOne.classify(test, new MultiCallback(testError, treeError, errorOne));
}
-
+
sumTestErr += ErrorEstimate.errorRate(testLabels, testError.computePredictions(rng));
sumOneErr += ErrorEstimate.errorRate(testLabels, errorOne.computePredictions(rng));
sumTreeErr += treeError.meanTreeError();
}
-
+
public static void main(String[] args) throws Exception {
ToolRunner.run(new Configuration(), new BreimanExample(), args);
}
-
+
@Override
public int run(String[] args) throws IOException {
@@ -157,28 +165,27 @@
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
- Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true)
- .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
- .withDescription("Data path").create();
-
- Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
- .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
- .withDescription("Dataset path").create();
-
- Option nbtreesOpt = obuilder.withLongName("nbtrees").withShortName("t").withRequired(true)
- .withArgument(abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create())
- .withDescription("Number of trees to grow, each iteration").create();
-
+ Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+ abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+ Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+ abuilder.withName("dataset").withMinimum(1).withMaximum(1).create()).withDescription("Dataset path")
+ .create();
+
+ Option nbtreesOpt = obuilder.withLongName("nbtrees").withShortName("t").withRequired(true).withArgument(
+ abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Number of trees to grow, each iteration").create();
+
Option nbItersOpt = obuilder.withLongName("iterations").withShortName("i").withRequired(true)
.withArgument(abuilder.withName("numIterations").withMinimum(1).withMaximum(1).create())
.withDescription("Number of times to repeat the test").create();
-
- Option helpOpt = obuilder.withLongName("help").withDescription("Print out help")
- .withShortName("h").create();
-
- Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt)
- .withOption(nbItersOpt).withOption(nbtreesOpt).withOption(helpOpt).create();
-
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt).withOption(
+ nbItersOpt).withOption(nbtreesOpt).withOption(helpOpt).create();
+
Path dataPath;
Path datasetPath;
int nbTrees;
@@ -188,21 +195,21 @@
Parser parser = new Parser();
parser.setGroup(group);
CommandLine cmdLine = parser.parse(args);
-
+
if (cmdLine.hasOption("help")) {
CommandLineUtil.printHelp(group);
return -1;
}
-
+
String dataName = cmdLine.getValue(dataOpt).toString();
String datasetName = cmdLine.getValue(datasetOpt).toString();
nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
nbIterations = Integer.parseInt(cmdLine.getValue(nbItersOpt).toString());
-
+
dataPath = new Path(dataName);
datasetPath = new Path(datasetName);
} catch (OptionException e) {
- log.error("Error while parsing options", e);
+ BreimanExample.log.error("Error while parsing options", e);
CommandLineUtil.printHelp(group);
return -1;
}
@@ -211,27 +218,27 @@
FileSystem fs = dataPath.getFileSystem(new Configuration());
Dataset dataset = Dataset.load(getConf(), datasetPath);
Data data = DataLoader.loadData(dataset, fs, dataPath);
-
+
// take m to be the first integer less than log2(M) + 1, where M is the
// number of inputs
int m = (int) Math.floor(Maths.log(2, data.getDataset().nbAttributes()) + 1);
-
+
Random rng = RandomUtils.getRandom();
for (int iteration = 0; iteration < nbIterations; iteration++) {
- log.info("Iteration {}", iteration);
+ BreimanExample.log.info("Iteration {}", iteration);
runIteration(rng, data, m, nbTrees);
}
-
- log.info("********************************************");
- log.info("Selection error : {}", sumTestErr / nbIterations);
- log.info("Single Input error : {}", sumOneErr / nbIterations);
- log.info("One Tree error : {}", sumTreeErr / nbIterations);
- log.info("Mean Random Input Time : {}", DFUtils.elapsedTime(sumTimeM / nbIterations));
- log.info("Mean Single Input Time : {}", DFUtils.elapsedTime(sumTimeOne / nbIterations));
- log.info("Mean Random Input Num Nodes : {}", numNodesM / nbIterations);
- log.info("Mean Single Input Num Nodes : {}", numNodesOne / nbIterations);
-
+
+ BreimanExample.log.info("********************************************");
+ BreimanExample.log.info("Selection error : {}", sumTestErr / nbIterations);
+ BreimanExample.log.info("Single Input error : {}", sumOneErr / nbIterations);
+ BreimanExample.log.info("One Tree error : {}", sumTreeErr / nbIterations);
+ BreimanExample.log.info("Mean Random Input Time : {}", DFUtils.elapsedTime(sumTimeM / nbIterations));
+ BreimanExample.log.info("Mean Single Input Time : {}", DFUtils.elapsedTime(sumTimeOne / nbIterations));
+ BreimanExample.log.info("Mean Random Input Num Nodes : {}", numNodesM / nbIterations);
+ BreimanExample.log.info("Mean Single Input Num Nodes : {}", numNodesOne / nbIterations);
+
return 0;
}
-
+
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapred/BuildForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapred/BuildForest.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapred/BuildForest.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapred/BuildForest.java Sat Feb 13 19:07:36 2010
@@ -17,6 +17,9 @@
package org.apache.mahout.df.mapred;
+import java.io.IOException;
+import java.util.Random;
+
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
@@ -46,174 +49,164 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
-import java.util.Random;
-
/**
- * Tool to builds a Random Forest using any given dataset (in UCI format). Can
- * use either the in-mem mapred or partial mapred implementations
+ * Tool to builds a Random Forest using any given dataset (in UCI format). Can use either the in-mem mapred or
+ * partial mapred implementations
*/
public class BuildForest extends Configured implements Tool {
-
+
private static final Logger log = LoggerFactory.getLogger(BuildForest.class);
-
+
private Path dataPath; // Data path
-
+
private Path datasetPath; // Dataset path
-
+
private int m; // Number of variables to select at each tree-node
-
+
private int nbTrees; // Number of trees to grow
-
+
private Long seed = null; // Random seed
-
+
private boolean isPartial; // use partial data implementation
-
+
private boolean isOob; // estimate oob error;
-
+
@Override
public int run(String[] args) throws IOException {
-
+
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
-
- Option oobOpt = obuilder.withShortName("oob").withRequired(false)
- .withDescription("Optional, estimate the out-of-bag error").create();
-
- Option dataOpt = obuilder.withLongName("data").withShortName("d")
- .withRequired(true).withArgument(
- abuilder.withName("path").withMinimum(1).withMaximum(1).create())
- .withDescription("Data path").create();
-
- Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
- .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
- .withDescription("Dataset path").create();
-
- Option selectionOpt = obuilder.withLongName("selection")
- .withShortName("sl").withRequired(true).withArgument(
- abuilder.withName("m").withMinimum(1).withMaximum(1).create())
- .withDescription("Number of variables to select randomly at each tree-node")
+
+ Option oobOpt = obuilder.withShortName("oob").withRequired(false).withDescription(
+ "Optional, estimate the out-of-bag error").create();
+
+ Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+ abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+ Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+ abuilder.withName("dataset").withMinimum(1).withMaximum(1).create()).withDescription("Dataset path")
.create();
-
- Option seedOpt = obuilder.withLongName("seed").withShortName("sd").withRequired(false)
- .withArgument(abuilder.withName("seed").withMinimum(1).withMaximum(1).create())
- .withDescription("Optional, seed value used to initialise the Random number generator")
+
+ Option selectionOpt = obuilder.withLongName("selection").withShortName("sl").withRequired(true)
+ .withArgument(abuilder.withName("m").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Number of variables to select randomly at each tree-node").create();
+
+ Option seedOpt = obuilder.withLongName("seed").withShortName("sd").withRequired(false).withArgument(
+ abuilder.withName("seed").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Optional, seed value used to initialise the Random number generator").create();
+
+ Option partialOpt = obuilder.withLongName("partial").withShortName("p").withRequired(false)
+ .withDescription("Optional, use the Partial Data implementation").create();
+
+ Option nbtreesOpt = obuilder.withLongName("nbtrees").withShortName("t").withRequired(true).withArgument(
+ abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Number of trees to grow").create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
.create();
-
- Option partialOpt = obuilder.withLongName("partial").withShortName("p")
- .withRequired(false).withDescription("Optional, use the Partial Data implementation").create();
-
- Option nbtreesOpt = obuilder.withLongName("nbtrees").withShortName("t").withRequired(true)
- .withArgument(abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create())
- .withDescription("Number of trees to grow").create();
-
- Option helpOpt = obuilder.withLongName("help").withDescription(
- "Print out help").withShortName("h").create();
-
- Group group = gbuilder.withName("Options").withOption(oobOpt).withOption(
- dataOpt).withOption(datasetOpt).withOption(selectionOpt).withOption(
- seedOpt).withOption(partialOpt).withOption(nbtreesOpt).withOption(
- helpOpt).create();
-
+
+ Group group = gbuilder.withName("Options").withOption(oobOpt).withOption(dataOpt).withOption(datasetOpt)
+ .withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt)
+ .withOption(helpOpt).create();
+
try {
Parser parser = new Parser();
parser.setGroup(group);
CommandLine cmdLine = parser.parse(args);
-
+
if (cmdLine.hasOption("help")) {
CommandLineUtil.printHelp(group);
return -1;
}
-
+
isPartial = cmdLine.hasOption(partialOpt);
isOob = cmdLine.hasOption(oobOpt);
String dataName = cmdLine.getValue(dataOpt).toString();
String datasetName = cmdLine.getValue(datasetOpt).toString();
m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
-
+
if (cmdLine.hasOption(seedOpt)) {
seed = Long.valueOf(cmdLine.getValue(seedOpt).toString());
}
-
- log.debug("data : {}", dataName);
- log.debug("dataset : {}", datasetName);
- log.debug("m : {}", m);
- log.debug("seed : {}", seed);
- log.debug("nbtrees : {}", nbTrees);
- log.debug("isPartial : {}", isPartial);
- log.debug("isOob : {}", isOob);
-
+
+ BuildForest.log.debug("data : {}", dataName);
+ BuildForest.log.debug("dataset : {}", datasetName);
+ BuildForest.log.debug("m : {}", m);
+ BuildForest.log.debug("seed : {}", seed);
+ BuildForest.log.debug("nbtrees : {}", nbTrees);
+ BuildForest.log.debug("isPartial : {}", isPartial);
+ BuildForest.log.debug("isOob : {}", isOob);
+
dataPath = new Path(dataName);
datasetPath = new Path(datasetName);
-
+
} catch (OptionException e) {
- log.error("Error while parsing options", e);
+ BuildForest.log.error("Error while parsing options", e);
CommandLineUtil.printHelp(group);
return -1;
}
-
+
buildForest();
-
+
return 0;
}
-
+
private DecisionForest buildForest() throws IOException {
DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
treeBuilder.setM(m);
-
+
Dataset dataset = Dataset.load(getConf(), datasetPath);
-
- ForestPredictions callback = (isOob) ? new ForestPredictions(dataset
- .nbInstances(), dataset.nblabels()) : null;
-
+
+ ForestPredictions callback = isOob ? new ForestPredictions(dataset.nbInstances(), dataset.nblabels())
+ : null;
+
Builder forestBuilder;
-
+
if (isPartial) {
- log.info("Partial Mapred implementation");
- forestBuilder = new PartialBuilder(treeBuilder, dataPath, datasetPath,
- seed, getConf());
+ BuildForest.log.info("Partial Mapred implementation");
+ forestBuilder = new PartialBuilder(treeBuilder, dataPath, datasetPath, seed, getConf());
} else {
- log.info("InMem Mapred implementation");
- forestBuilder = new InMemBuilder(treeBuilder, dataPath, datasetPath,
- seed, getConf());
+ BuildForest.log.info("InMem Mapred implementation");
+ forestBuilder = new InMemBuilder(treeBuilder, dataPath, datasetPath, seed, getConf());
}
-
- log.info("Building the forest...");
+
+ BuildForest.log.info("Building the forest...");
long time = System.currentTimeMillis();
-
+
DecisionForest forest = forestBuilder.build(nbTrees, callback);
-
+
time = System.currentTimeMillis() - time;
- log.info("Build Time: {}", DFUtils.elapsedTime(time));
-
+ BuildForest.log.info("Build Time: {}", DFUtils.elapsedTime(time));
+
if (isOob) {
Random rng;
- if (seed != null)
+ if (seed != null) {
rng = RandomUtils.getRandom(seed);
- else
+ } else {
rng = RandomUtils.getRandom();
-
+ }
+
FileSystem fs = dataPath.getFileSystem(getConf());
int[] labels = Data.extractLabels(dataset, fs, dataPath);
- log.info("oob error estimate : "
- + ErrorEstimate.errorRate(labels, callback.computePredictions(rng)));
+ BuildForest.log.info("oob error estimate : "
+ + ErrorEstimate.errorRate(labels, callback.computePredictions(rng)));
}
-
+
return forest;
}
-
+
protected static Data loadData(Configuration conf, Path dataPath, Dataset dataset) throws IOException {
- log.info("Loading the data...");
+ BuildForest.log.info("Loading the data...");
FileSystem fs = dataPath.getFileSystem(conf);
Data data = DataLoader.loadData(dataset, fs, dataPath);
- log.info("Data Loaded");
-
+ BuildForest.log.info("Data Loaded");
+
return data;
}
-
+
/**
* @param args
* @throws Exception
@@ -221,5 +214,5 @@
public static void main(String[] args) throws Exception {
ToolRunner.run(new Configuration(), new BuildForest(), args);
}
-
+
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java Sat Feb 13 19:07:36 2010
@@ -34,173 +34,151 @@
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.df.DFUtils;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.ErrorEstimate;
-import org.apache.mahout.df.DFUtils;
-import org.apache.mahout.df.mapreduce.inmem.InMemBuilder;
-import org.apache.mahout.df.mapreduce.partial.PartialBuilder;
import org.apache.mahout.df.builder.DefaultTreeBuilder;
import org.apache.mahout.df.callback.ForestPredictions;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataLoader;
import org.apache.mahout.df.data.Dataset;
-import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.df.mapreduce.inmem.InMemBuilder;
+import org.apache.mahout.df.mapreduce.partial.PartialBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
- * Tool to builds a Random Forest using any given dataset (in UCI format). Can
- * use either the in-mem mapred or partial mapred implementations
+ * Tool to builds a Random Forest using any given dataset (in UCI format). Can use either the in-mem mapred or
+ * partial mapred implementations
*/
public class BuildForest extends Configured implements Tool {
-
+
private static final Logger log = LoggerFactory.getLogger(BuildForest.class);
-
+
private Path dataPath; // Data path
-
+
private Path datasetPath; // Dataset path
-
+
private int m; // Number of variables to select at each tree-node
-
+
private int nbTrees; // Number of trees to grow
-
+
private Long seed = null; // Random seed
-
+
private boolean isPartial; // use partial data implementation
-
+
private boolean isOob; // estimate oob error;
-
+
@Override
public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
-
+
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
-
- Option oobOpt = obuilder.withShortName("oob").withRequired(false)
- .withDescription("Optional, estimate the out-of-bag error").create();
-
- Option dataOpt = obuilder.withLongName("data").withShortName("d")
- .withRequired(true).withArgument(
- abuilder.withName("path").withMinimum(1).withMaximum(1).create())
- .withDescription("Data path").create();
-
- Option datasetOpt = obuilder
- .withLongName("dataset")
- .withShortName("ds")
- .withRequired(true)
- .withArgument(
- abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
- .withDescription("Dataset path").create();
-
- Option selectionOpt = obuilder.withLongName("selection")
- .withShortName("sl").withRequired(true).withArgument(
- abuilder.withName("m").withMinimum(1).withMaximum(1).create())
- .withDescription(
- "Number of variables to select randomly at each tree-node")
+
+ Option oobOpt = obuilder.withShortName("oob").withRequired(false).withDescription(
+ "Optional, estimate the out-of-bag error").create();
+
+ Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+ abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+ Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+ abuilder.withName("dataset").withMinimum(1).withMaximum(1).create()).withDescription("Dataset path")
.create();
-
- Option seedOpt = obuilder
- .withLongName("seed")
- .withShortName("sd")
- .withRequired(false)
- .withArgument(
- abuilder.withName("seed").withMinimum(1).withMaximum(1).create())
- .withDescription(
- "Optional, seed value used to initialise the Random number generator")
+
+ Option selectionOpt = obuilder.withLongName("selection").withShortName("sl").withRequired(true)
+ .withArgument(abuilder.withName("m").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Number of variables to select randomly at each tree-node").create();
+
+ Option seedOpt = obuilder.withLongName("seed").withShortName("sd").withRequired(false).withArgument(
+ abuilder.withName("seed").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Optional, seed value used to initialise the Random number generator").create();
+
+ Option partialOpt = obuilder.withLongName("partial").withShortName("p").withRequired(false)
+ .withDescription("Optional, use the Partial Data implementation").create();
+
+ Option nbtreesOpt = obuilder.withLongName("nbtrees").withShortName("t").withRequired(true).withArgument(
+ abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Number of trees to grow").create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
.create();
-
- Option partialOpt = obuilder.withLongName("partial").withShortName("p")
- .withRequired(false).withDescription(
- "Optional, use the Partial Data implementation").create();
-
- Option nbtreesOpt = obuilder
- .withLongName("nbtrees")
- .withShortName("t")
- .withRequired(true)
- .withArgument(
- abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create())
- .withDescription("Number of trees to grow").create();
-
- Option helpOpt = obuilder.withLongName("help").withDescription(
- "Print out help").withShortName("h").create();
-
- Group group = gbuilder.withName("Options").withOption(oobOpt).withOption(
- dataOpt).withOption(datasetOpt).withOption(selectionOpt).withOption(
- seedOpt).withOption(partialOpt).withOption(nbtreesOpt).withOption(
- helpOpt).create();
-
+
+ Group group = gbuilder.withName("Options").withOption(oobOpt).withOption(dataOpt).withOption(datasetOpt)
+ .withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt)
+ .withOption(helpOpt).create();
+
try {
Parser parser = new Parser();
parser.setGroup(group);
CommandLine cmdLine = parser.parse(args);
-
+
if (cmdLine.hasOption("help")) {
CommandLineUtil.printHelp(group);
return -1;
}
-
+
isPartial = cmdLine.hasOption(partialOpt);
isOob = cmdLine.hasOption(oobOpt);
String dataName = cmdLine.getValue(dataOpt).toString();
String datasetName = cmdLine.getValue(datasetOpt).toString();
m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
-
+
if (cmdLine.hasOption(seedOpt)) {
seed = Long.valueOf(cmdLine.getValue(seedOpt).toString());
}
-
- log.debug("data : {}", dataName);
- log.debug("dataset : {}", datasetName);
- log.debug("m : {}", m);
- log.debug("seed : {}", seed);
- log.debug("nbtrees : {}", nbTrees);
- log.debug("isPartial : {}", isPartial);
- log.debug("isOob : {}", isOob);
-
+
+ BuildForest.log.debug("data : {}", dataName);
+ BuildForest.log.debug("dataset : {}", datasetName);
+ BuildForest.log.debug("m : {}", m);
+ BuildForest.log.debug("seed : {}", seed);
+ BuildForest.log.debug("nbtrees : {}", nbTrees);
+ BuildForest.log.debug("isPartial : {}", isPartial);
+ BuildForest.log.debug("isOob : {}", isOob);
+
dataPath = new Path(dataName);
datasetPath = new Path(datasetName);
-
+
} catch (OptionException e) {
System.err.println("Exception : " + e);
CommandLineUtil.printHelp(group);
return -1;
}
-
+
buildForest();
-
+
return 0;
}
-
+
private DecisionForest buildForest() throws IOException, ClassNotFoundException, InterruptedException {
DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
treeBuilder.setM(m);
-
- Dataset dataset = Dataset.load(getConf(), datasetPath);
-
- ForestPredictions callback = (isOob) ? new ForestPredictions(dataset
- .nbInstances(), dataset.nblabels()) : null;
-
+
+ Dataset dataset = Dataset.load(getConf(), datasetPath);
+
+ ForestPredictions callback = isOob ? new ForestPredictions(dataset.nbInstances(), dataset.nblabels())
+ : null;
+
Builder forestBuilder;
-
+
if (isPartial) {
- log.info("Partial Mapred implementation");
+ BuildForest.log.info("Partial Mapred implementation");
forestBuilder = new PartialBuilder(treeBuilder, dataPath, datasetPath, seed, getConf());
} else {
- log.info("InMem Mapred implementation");
- forestBuilder = new InMemBuilder(treeBuilder, dataPath, datasetPath,
- seed, getConf());
+ BuildForest.log.info("InMem Mapred implementation");
+ forestBuilder = new InMemBuilder(treeBuilder, dataPath, datasetPath, seed, getConf());
}
- log.info("Building the forest...");
+ BuildForest.log.info("Building the forest...");
long time = System.currentTimeMillis();
-
+
DecisionForest forest = forestBuilder.build(nbTrees, callback);
-
+
time = System.currentTimeMillis() - time;
- log.info("Build Time: {}", DFUtils.elapsedTime(time));
-
+ BuildForest.log.info("Build Time: {}", DFUtils.elapsedTime(time));
+
if (isOob) {
Random rng;
if (seed != null) {
@@ -208,26 +186,26 @@
} else {
rng = RandomUtils.getRandom();
}
-
+
FileSystem fs = dataPath.getFileSystem(getConf());
int[] labels = Data.extractLabels(dataset, fs, dataPath);
-
- log.info("oob error estimate : "
- + ErrorEstimate.errorRate(labels, callback.computePredictions(rng)));
+
+ BuildForest.log.info("oob error estimate : "
+ + ErrorEstimate.errorRate(labels, callback.computePredictions(rng)));
}
-
+
return forest;
}
-
+
protected static Data loadData(Configuration conf, Path dataPath, Dataset dataset) throws IOException {
- log.info("Loading the data...");
+ BuildForest.log.info("Loading the data...");
FileSystem fs = dataPath.getFileSystem(conf);
Data data = DataLoader.loadData(dataset, fs, dataPath);
- log.info("Data Loaded");
-
+ BuildForest.log.info("Data Loaded");
+
return data;
}
-
+
/**
* @param args
* @throws Exception
@@ -235,6 +213,5 @@
public static void main(String[] args) throws Exception {
ToolRunner.run(new Configuration(), new BuildForest(), args);
}
-
+
}
-
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/DeliciousTagsExample.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/DeliciousTagsExample.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/DeliciousTagsExample.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/DeliciousTagsExample.java Sat Feb 13 19:07:36 2010
@@ -33,34 +33,33 @@
import org.apache.mahout.fpm.pfpgrowth.dataset.KeyBasedStringTupleGrouper;
public final class DeliciousTagsExample {
- private DeliciousTagsExample() {
- }
-
+ private DeliciousTagsExample() { }
+
public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
Option inputDirOpt = DefaultOptionCreator.inputOption().create();
-
+
Option outputOpt = DefaultOptionCreator.outputOption().create();
-
+
Option helpOpt = DefaultOptionCreator.helpOption();
Option recordSplitterOpt = obuilder.withLongName("splitterPattern").withArgument(
- abuilder.withName("splitterPattern").withMinimum(1).withMaximum(1).create()).withDescription(
- "Regular Expression pattern used to split given line into fields."
- + " Default value splits comma or tab separated fields."
- + " Default Value: \"[ ,\\t]*\\t[ ,\\t]*\" ").withShortName("regex").create();
+ abuilder.withName("splitterPattern").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Regular Expression pattern used to split given line into fields."
+ + " Default value splits comma or tab separated fields."
+ + " Default Value: \"[ ,\\t]*\\t[ ,\\t]*\" ").withShortName("regex").create();
Option encodingOpt = obuilder.withLongName("encoding").withArgument(
- abuilder.withName("encoding").withMinimum(1).withMaximum(1).create()).withDescription(
- "(Optional) The file encoding. Default value: UTF-8").withShortName("e").create();
+ abuilder.withName("encoding").withMinimum(1).withMaximum(1).create()).withDescription(
+ "(Optional) The file encoding. Default value: UTF-8").withShortName("e").create();
Group group = gbuilder.withName("Options").withOption(inputDirOpt).withOption(outputOpt).withOption(
- helpOpt).withOption(recordSplitterOpt).withOption(encodingOpt).create();
-
+ helpOpt).withOption(recordSplitterOpt).withOption(encodingOpt).create();
+
try {
Parser parser = new Parser();
parser.setGroup(group);
CommandLine cmdLine = parser.parse(args);
-
+
if (cmdLine.hasOption(helpOpt)) {
CommandLineUtil.printHelp(group);
return;
@@ -69,7 +68,7 @@
if (cmdLine.hasOption(recordSplitterOpt)) {
params.set("splitPattern", (String) cmdLine.getValue(recordSplitterOpt));
}
-
+
String encoding = "UTF-8";
if (cmdLine.hasOption(encodingOpt)) {
encoding = (String) cmdLine.getValue(encodingOpt);
@@ -86,10 +85,10 @@
params.set("field0", "3");
params.set("maxTransactionLength", "100");
KeyBasedStringTupleGrouper.startJob(params);
-
+
} catch (OptionException ex) {
CommandLineUtil.printHelp(group);
}
-
+
}
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleCombiner.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleCombiner.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleCombiner.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleCombiner.java Sat Feb 13 19:07:36 2010
@@ -24,12 +24,11 @@
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.mahout.common.StringTuple;
-public class KeyBasedStringTupleCombiner extends
- Reducer<Text, StringTuple, Text, StringTuple> {
-
+public class KeyBasedStringTupleCombiner extends Reducer<Text,StringTuple,Text,StringTuple> {
+
@Override
- protected void reduce(Text key, Iterable<StringTuple> values,
- Context context) throws IOException, InterruptedException {
+ protected void reduce(Text key, Iterable<StringTuple> values, Context context) throws IOException,
+ InterruptedException {
HashSet<String> outputValues = new HashSet<String>();
for (StringTuple value : values) {
outputValues.addAll(value.getEntries());
@@ -37,4 +36,3 @@
context.write(key, new StringTuple(outputValues));
}
}
-
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleGrouper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleGrouper.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleGrouper.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleGrouper.java Sat Feb 13 19:07:36 2010
@@ -49,33 +49,31 @@
import org.apache.mahout.common.StringTuple;
public class KeyBasedStringTupleGrouper {
-
- private KeyBasedStringTupleGrouper() {
- }
-
+
+ private KeyBasedStringTupleGrouper() { }
+
public static void startJob(Parameters params) throws IOException,
- InterruptedException, ClassNotFoundException {
+ InterruptedException,
+ ClassNotFoundException {
Configuration conf = new Configuration();
conf.set("job.parameters", params.toString());
conf.set("mapred.compress.map.output", "true");
conf.set("mapred.output.compression.type", "BLOCK");
- conf.set("mapred.map.output.compression.codec",
- "org.apache.hadoop.io.compress.GzipCodec");
- conf.set("io.serializations",
- "org.apache.hadoop.io.serializer.JavaSerialization,"
- + "org.apache.hadoop.io.serializer.WritableSerialization");
-
+ conf.set("mapred.map.output.compression.codec", "org.apache.hadoop.io.compress.GzipCodec");
+ conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+
String input = params.get("input");
Job job = new Job(conf, "Generating dataset based from input" + input);
job.setJarByClass(KeyBasedStringTupleGrouper.class);
-
+
job.setMapOutputKeyClass(Text.class);
job.setMapOutputValueClass(StringTuple.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(Text.class);
-
+
FileInputFormat.addInputPath(job, new Path(input));
Path outPath = new Path(params.get("output"));
FileOutputFormat.setOutputPath(job, outPath);
@@ -84,13 +82,13 @@
if (dfs.exists(outPath)) {
dfs.delete(outPath, true);
}
-
+
job.setInputFormatClass(TextInputFormat.class);
job.setMapperClass(KeyBasedStringTupleMapper.class);
job.setCombinerClass(KeyBasedStringTupleCombiner.class);
job.setReducerClass(KeyBasedStringTupleReducer.class);
job.setOutputFormatClass(TextOutputFormat.class);
-
+
job.waitForCompletion(true);
}
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleMapper.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleMapper.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleMapper.java Sat Feb 13 19:07:36 2010
@@ -31,15 +31,12 @@
import org.slf4j.LoggerFactory;
/**
- * Splits the line using a {@link Pattern} and outputs key as given by the
- * groupingFields
+ * Splits the line using a {@link Pattern} and outputs key as given by the groupingFields
*
*/
-public class KeyBasedStringTupleMapper extends
- Mapper<LongWritable,Text,Text,StringTuple> {
+public class KeyBasedStringTupleMapper extends Mapper<LongWritable,Text,Text,StringTuple> {
- private static final Logger log = LoggerFactory
- .getLogger(KeyBasedStringTupleMapper.class);
+ private static final Logger log = LoggerFactory.getLogger(KeyBasedStringTupleMapper.class);
private Pattern splitter;
@@ -48,27 +45,22 @@
private int[] groupingFields;
@Override
- protected void map(LongWritable key,
- Text value,
- Context context) throws IOException,
- InterruptedException {
+ protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
String[] fields = splitter.split(value.toString());
if (fields.length != 4) {
- log.info("{} {}", fields.length, value.toString());
+ KeyBasedStringTupleMapper.log.info("{} {}", fields.length, value.toString());
context.getCounter("Map", "ERROR").increment(1);
return;
}
List<String> oKey = new ArrayList<String>();
- for (int i = 0, groupingFieldCount = groupingFields.length;
- i < groupingFieldCount; i++) {
- oKey.add(fields[groupingFields[i]]);
- context.setStatus(fields[groupingFields[i]]);
+ for (int groupingField : groupingFields) {
+ oKey.add(fields[groupingField]);
+ context.setStatus(fields[groupingField]);
}
List<String> oValue = new ArrayList<String>();
- for (int i = 0, selectedFieldCount = selectedFields.length;
- i < selectedFieldCount; i++) {
- oValue.add(fields[selectedFields[i]]);
+ for (int selectedField : selectedFields) {
+ oValue.add(fields[selectedField]);
}
context.write(new Text(oKey.toString()), new StringTuple(oValue));
@@ -76,22 +68,18 @@
}
@Override
- protected void setup(Context context) throws IOException,
- InterruptedException {
+ protected void setup(Context context) throws IOException, InterruptedException {
super.setup(context);
- Parameters params = Parameters.fromString(
- context.getConfiguration().get("job.parameters", ""));
+ Parameters params = Parameters.fromString(context.getConfiguration().get("job.parameters", ""));
splitter = Pattern.compile(params.get("splitPattern", "[ \t]*\t[ \t]*"));
- int selectedFieldCount = Integer.valueOf(
- params.get("selectedFieldCount", "0"));
+ int selectedFieldCount = Integer.valueOf(params.get("selectedFieldCount", "0"));
selectedFields = new int[selectedFieldCount];
for (int i = 0; i < selectedFieldCount; i++) {
selectedFields[i] = Integer.valueOf(params.get("field" + i, "0"));
}
- int groupingFieldCount = Integer.valueOf(
- params.get("groupingFieldCount", "0"));
+ int groupingFieldCount = Integer.valueOf(params.get("groupingFieldCount", "0"));
groupingFields = new int[groupingFieldCount];
for (int i = 0; i < groupingFieldCount; i++) {
groupingFields[i] = Integer.valueOf(params.get("gfield" + i, "0"));
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleReducer.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleReducer.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/fpm/pfpgrowth/dataset/KeyBasedStringTupleReducer.java Sat Feb 13 19:07:36 2010
@@ -26,16 +26,13 @@
import org.apache.mahout.common.Parameters;
import org.apache.mahout.common.StringTuple;
-public class KeyBasedStringTupleReducer extends
- Reducer<Text,StringTuple,Text,Text> {
+public class KeyBasedStringTupleReducer extends Reducer<Text,StringTuple,Text,Text> {
private int maxTransactionLength = 100;
@Override
- protected void reduce(Text key,
- Iterable<StringTuple> values,
- Context context) throws IOException,
- InterruptedException {
+ protected void reduce(Text key, Iterable<StringTuple> values, Context context) throws IOException,
+ InterruptedException {
Set<String> items = new HashSet<String>();
for (StringTuple value : values) {
@@ -69,12 +66,9 @@
}
@Override
- protected void setup(Context context) throws IOException,
- InterruptedException {
+ protected void setup(Context context) throws IOException, InterruptedException {
super.setup(context);
- Parameters params = Parameters.fromString(context.getConfiguration().get(
- "job.parameters", ""));
- maxTransactionLength = Integer.valueOf(params.get("maxTransactionLength",
- "100"));
+ Parameters params = Parameters.fromString(context.getConfiguration().get("job.parameters", ""));
+ maxTransactionLength = Integer.valueOf(params.get("maxTransactionLength", "100"));
}
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/Attribute.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/Attribute.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/Attribute.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/Attribute.java Sat Feb 13 19:07:36 2010
@@ -21,7 +21,7 @@
* An attribute for use with {@link DataSet}
*/
interface Attribute {
-
+
boolean isNumerical();
-
+
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDCrossover.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDCrossover.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDCrossover.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDCrossover.java Sat Feb 13 19:07:36 2010
@@ -17,46 +17,42 @@
package org.apache.mahout.ga.watchmaker.cd;
-import org.uncommons.maths.random.Probability;
-import org.uncommons.watchmaker.framework.operators.AbstractCrossover;
-
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
+import org.uncommons.maths.random.Probability;
+import org.uncommons.watchmaker.framework.operators.AbstractCrossover;
+
/**
* Crossover operator.
*/
public class CDCrossover extends AbstractCrossover<CDRule> {
-
+
public CDCrossover(int crossoverPoints) {
super(crossoverPoints);
}
-
+
public CDCrossover(int crossoverPoints, Probability crossoverProbability) {
super(crossoverPoints, crossoverProbability);
}
@Override
- protected List<CDRule> mate(CDRule parent1, CDRule parent2,
- int numberOfCrossoverPoints, Random rng) {
- if (parent1.getNbConditions() != parent2.getNbConditions())
- {
- throw new IllegalArgumentException("Cannot perform cross-over with parents of different size.");
+ protected List<CDRule> mate(CDRule parent1, CDRule parent2, int numberOfCrossoverPoints, Random rng) {
+ if (parent1.getNbConditions() != parent2.getNbConditions()) {
+ throw new IllegalArgumentException("Cannot perform cross-over with parents of different size.");
}
CDRule offspring1 = new CDRule(parent1);
CDRule offspring2 = new CDRule(parent2);
// Apply as many cross-overs as required.
- for (int i = 0; i < numberOfCrossoverPoints; i++)
- {
- // Cross-over index is always greater than zero and less than
- // the length of the parent so that we always pick a point that
- // will result in a meaningful cross-over.
- int crossoverIndex = (1 + rng.nextInt(parent1.getNbConditions() - 1));
- for (int j = 0; j < crossoverIndex; j++)
- {
- swap(offspring1, offspring2, j);
- }
+ for (int i = 0; i < numberOfCrossoverPoints; i++) {
+ // Cross-over index is always greater than zero and less than
+ // the length of the parent so that we always pick a point that
+ // will result in a meaningful cross-over.
+ int crossoverIndex = 1 + rng.nextInt(parent1.getNbConditions() - 1);
+ for (int j = 0; j < crossoverIndex; j++) {
+ CDCrossover.swap(offspring1, offspring2, j);
+ }
}
List<CDRule> result = new ArrayList<CDRule>(2);
@@ -64,9 +60,9 @@
result.add(offspring2);
return result;
}
-
+
static void swap(CDRule ind1, CDRule ind2, int index) {
-
+
// swap W
double dtemp = ind1.getW(index);
ind1.setW(index, ind2.getW(index));
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFactory.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFactory.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFactory.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFactory.java Sat Feb 13 19:07:36 2010
@@ -17,24 +17,24 @@
package org.apache.mahout.ga.watchmaker.cd;
-import org.uncommons.watchmaker.framework.factories.AbstractCandidateFactory;
-
import java.util.Random;
+import org.uncommons.watchmaker.framework.factories.AbstractCandidateFactory;
+
/**
* Factory used by Watchmaker to generate the initial population.
*/
public class CDFactory extends AbstractCandidateFactory<CDRule> {
-
+
private final double threshold;
-
+
/**
* @param threshold condition activation threshold
*/
public CDFactory(double threshold) {
this.threshold = threshold;
}
-
+
@Override
public CDRule generateRandomCandidate(Random rng) {
return new CDRule(threshold, rng);
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFitness.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFitness.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFitness.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFitness.java Sat Feb 13 19:07:36 2010
@@ -17,102 +17,104 @@
package org.apache.mahout.ga.watchmaker.cd;
-import org.apache.hadoop.io.Writable;
-
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
+import org.apache.hadoop.io.Writable;
+
/**
- * Fitness of the class discovery problem.
+ * Fitness of the class discovery problem.
*/
public class CDFitness implements Writable {
-
+
/** True positive */
private int tp;
-
+
/** False positive */
private int fp;
-
+
/** True negative */
private int tn;
-
+
/** False negative */
private int fn;
-
+
public CDFitness() {
-
+
}
-
+
public CDFitness(CDFitness f) {
tp = f.getTp();
fp = f.getFp();
tn = f.getTn();
fn = f.getFn();
}
-
+
public CDFitness(int tp, int fp, int tn, int fn) {
this.tp = tp;
this.fp = fp;
this.tn = tn;
this.fn = fn;
}
-
+
public int getTp() {
return tp;
}
-
+
public int getFp() {
return fp;
}
-
+
public int getTn() {
return tn;
}
-
+
public int getFn() {
return fn;
}
-
+
public void add(CDFitness f) {
tp += f.getTp();
fp += f.getFp();
tn += f.getTn();
fn += f.getFn();
}
-
+
@Override
public boolean equals(Object obj) {
- if (this == obj)
+ if (this == obj) {
return true;
- if (obj == null || !(obj instanceof CDFitness))
+ }
+ if (obj == null || !(obj instanceof CDFitness)) {
return false;
-
+ }
+
CDFitness f = (CDFitness) obj;
-
+
return tp == f.getTp() && fp == f.getFp() && tn == f.getTn() && fn == f.getFn();
}
-
+
@Override
public int hashCode() {
- return tp + 31 * (fp + 31 * (tn + 31 * fn));
+ return tp + 31 * (fp + 31 * (tn + 31 * fn));
}
-
+
@Override
public String toString() {
return "[TP=" + tp + ", FP=" + fp + ", TN=" + tn + ", FN=" + fn + ']';
}
-
+
/**
* Calculates the fitness corresponding to this evaluation.
*/
public double get() {
- double se = ((double) tp) / (tp + fn); // sensitivity
- double sp = ((double) tn) / (tn + fp); // specificity
+ double se = (double) tp / (tp + fn); // sensitivity
+ double sp = (double) tn / (tn + fp); // specificity
return se * sp;
}
-
+
@Override
public void readFields(DataInput in) throws IOException {
tp = in.readInt();
@@ -120,14 +122,14 @@
tn = in.readInt();
fn = in.readInt();
}
-
+
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(tp);
out.writeInt(fp);
out.writeInt(tn);
out.writeInt(fn);
-
+
}
-
+
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFitnessEvaluator.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFitnessEvaluator.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFitnessEvaluator.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDFitnessEvaluator.java Sat Feb 13 19:07:36 2010
@@ -17,29 +17,29 @@
package org.apache.mahout.ga.watchmaker.cd;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.hadoop.fs.Path;
import org.apache.mahout.ga.watchmaker.STFitnessEvaluator;
import org.apache.mahout.ga.watchmaker.cd.hadoop.CDMahoutEvaluator;
import org.apache.mahout.ga.watchmaker.cd.hadoop.DatasetSplit;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
/**
* Class Discovery Fitness Evaluator. Delegates to Mahout the task of evaluating
* the fitness.
*/
public class CDFitnessEvaluator extends STFitnessEvaluator<Rule> {
-
+
private final Path dataset;
-
+
private final DatasetSplit split;
-
+
private final List<CDFitness> evals = new ArrayList<CDFitness>();
private final int target;
-
+
/**
*
* @param dataset dataset path
@@ -50,25 +50,26 @@
this.target = target;
this.split = split;
}
-
+
@Override
public boolean isNatural() {
return true;
}
-
+
@Override
protected void evaluate(List<? extends Rule> population,
- List<Double> evaluations) {
+ List<Double> evaluations) {
evals.clear();
-
+
try {
CDMahoutEvaluator.evaluate(population, target, dataset, evals, split);
} catch (IOException e) {
throw new IllegalStateException("Exception while evaluating the population", e);
}
-
- for (CDFitness fitness : evals)
+
+ for (CDFitness fitness : evals) {
evaluations.add(fitness.get());
+ }
}
-
+
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDGA.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDGA.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDGA.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDGA.java Sat Feb 13 19:07:36 2010
@@ -17,20 +17,26 @@
package org.apache.mahout.ga.watchmaker.cd;
-import org.apache.hadoop.fs.Path;
-import org.apache.mahout.ga.watchmaker.cd.hadoop.CDMahoutEvaluator;
-import org.apache.mahout.ga.watchmaker.cd.hadoop.DatasetSplit;
-import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.common.CommandLineUtil;
-import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.Group;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.ga.watchmaker.cd.hadoop.CDMahoutEvaluator;
+import org.apache.mahout.ga.watchmaker.cd.hadoop.DatasetSplit;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import org.uncommons.watchmaker.framework.CandidateFactory;
import org.uncommons.watchmaker.framework.EvolutionEngine;
import org.uncommons.watchmaker.framework.EvolutionObserver;
@@ -42,31 +48,20 @@
import org.uncommons.watchmaker.framework.operators.EvolutionPipeline;
import org.uncommons.watchmaker.framework.selection.RouletteWheelSelection;
import org.uncommons.watchmaker.framework.termination.GenerationCount;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
/**
* Class Discovery Genetic Algorithm main class. Has the following parameters:
* <ul>
* <li>threshold<br>
- * Condition activation threshold. See Also
- * {@link org.apache.mahout.ga.watchmaker.cd.CDRule CDRule}
+ * Condition activation threshold. See Also {@link org.apache.mahout.ga.watchmaker.cd.CDRule CDRule}
* <li>nb cross point<br>
- * Number of points used by the{@link org.apache.mahout.ga.watchmaker.cd.CDCrossover CrossOver}
- * operator
+ * Number of points used by the{@link org.apache.mahout.ga.watchmaker.cd.CDCrossover CrossOver} operator
* <li>mutation rate<br>
- * mutation rate of the
- * {@link org.apache.mahout.ga.watchmaker.cd.CDMutation Mutation} operator
+ * mutation rate of the {@link org.apache.mahout.ga.watchmaker.cd.CDMutation Mutation} operator
* <li>mutation range<br>
- * mutation range of the
- * {@link org.apache.mahout.ga.watchmaker.cd.CDMutation Mutation} operator
+ * mutation range of the {@link org.apache.mahout.ga.watchmaker.cd.CDMutation Mutation} operator
* <li>mutation precision<br>
- * mutation precision of the
- * {@link org.apache.mahout.ga.watchmaker.cd.CDMutation Mutation} operator
+ * mutation precision of the {@link org.apache.mahout.ga.watchmaker.cd.CDMutation Mutation} operator
* <li>population size
* <li>generations count<br>
* number of generations the genetic algorithm will be run for.
@@ -74,165 +69,161 @@
* </ul>
*/
public class CDGA {
-
+
private static final Logger log = LoggerFactory.getLogger(CDGA.class);
-
- private CDGA() {
- }
-
+
+ private CDGA() { }
+
public static void main(String[] args) throws IOException {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
-
- Option inputOpt = obuilder.withLongName("input").withRequired(true)
- .withShortName("i").withArgument(
- abuilder.withName("input").withMinimum(1).withMaximum(1).create())
- .withDescription("The Path for input data directory.").create();
-
- Option labelOpt = obuilder.withLongName("label").withRequired(true)
- .withShortName("l").withArgument(
- abuilder.withName("index").withMinimum(1).withMaximum(1).create())
- .withDescription("label's index.").create();
-
- Option thresholdOpt = obuilder.withLongName("threshold").withRequired(false)
- .withShortName("t").withArgument(
- abuilder.withName("threshold").withMinimum(1).withMaximum(1).create())
- .withDescription("Condition activation threshold, default = 0.5.").create();
-
- Option crosspntsOpt = obuilder.withLongName("crosspnts").withRequired(false)
- .withShortName("cp").withArgument(
- abuilder.withName("points").withMinimum(1).withMaximum(1).create())
- .withDescription("Number of crossover points to use, default = 1.").create();
-
- Option mutrateOpt = obuilder.withLongName("mutrate").withRequired(true)
- .withShortName("m").withArgument(
- abuilder.withName("true").withMinimum(1).withMaximum(1).create())
- .withDescription("Mutation rate (float).").create();
-
- Option mutrangeOpt = obuilder.withLongName("mutrange").withRequired(false)
- .withShortName("mr").withArgument(
- abuilder.withName("range").withMinimum(1).withMaximum(1).create())
- .withDescription("Mutation range, default = 0.1 (10%).").create();
-
- Option mutprecOpt = obuilder.withLongName("mutprec").withRequired(false)
- .withShortName("mp").withArgument(
- abuilder.withName("precision").withMinimum(1).withMaximum(1).create())
- .withDescription("Mutation precision, default = 2.").create();
-
- Option popsizeOpt = obuilder.withLongName("popsize").withRequired(true)
- .withShortName("p").withArgument(
- abuilder.withName("size").withMinimum(1).withMaximum(1).create())
- .withDescription("Population size.").create();
-
- Option gencntOpt = obuilder.withLongName("gencnt").withRequired(true)
- .withShortName("g").withArgument(
- abuilder.withName("count").withMinimum(1).withMaximum(1).create())
- .withDescription("Generations count.").create();
-
+
+ Option inputOpt = obuilder.withLongName("input").withRequired(true).withShortName("i").withArgument(
+ abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Path for input data directory.").create();
+
+ Option labelOpt = obuilder.withLongName("label").withRequired(true).withShortName("l").withArgument(
+ abuilder.withName("index").withMinimum(1).withMaximum(1).create()).withDescription("label's index.")
+ .create();
+
+ Option thresholdOpt = obuilder.withLongName("threshold").withRequired(false).withShortName("t")
+ .withArgument(abuilder.withName("threshold").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Condition activation threshold, default = 0.5.").create();
+
+ Option crosspntsOpt = obuilder.withLongName("crosspnts").withRequired(false).withShortName("cp")
+ .withArgument(abuilder.withName("points").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Number of crossover points to use, default = 1.").create();
+
+ Option mutrateOpt = obuilder.withLongName("mutrate").withRequired(true).withShortName("m").withArgument(
+ abuilder.withName("true").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Mutation rate (float).").create();
+
+ Option mutrangeOpt = obuilder.withLongName("mutrange").withRequired(false).withShortName("mr")
+ .withArgument(abuilder.withName("range").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Mutation range, default = 0.1 (10%).").create();
+
+ Option mutprecOpt = obuilder.withLongName("mutprec").withRequired(false).withShortName("mp")
+ .withArgument(abuilder.withName("precision").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Mutation precision, default = 2.").create();
+
+ Option popsizeOpt = obuilder.withLongName("popsize").withRequired(true).withShortName("p").withArgument(
+ abuilder.withName("size").withMinimum(1).withMaximum(1).create()).withDescription("Population size.")
+ .create();
+
+ Option gencntOpt = obuilder.withLongName("gencnt").withRequired(true).withShortName("g").withArgument(
+ abuilder.withName("count").withMinimum(1).withMaximum(1).create())
+ .withDescription("Generations count.").create();
+
Option helpOpt = DefaultOptionCreator.helpOption();
-
- Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(
- helpOpt).withOption(labelOpt).withOption(thresholdOpt).withOption(crosspntsOpt).
- withOption(mutrateOpt).withOption(mutrangeOpt).withOption(mutprecOpt).
- withOption(popsizeOpt).withOption(gencntOpt).create();
-
+
+ Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(helpOpt).withOption(labelOpt)
+ .withOption(thresholdOpt).withOption(crosspntsOpt).withOption(mutrateOpt).withOption(mutrangeOpt)
+ .withOption(mutprecOpt).withOption(popsizeOpt).withOption(gencntOpt).create();
+
Parser parser = new Parser();
parser.setGroup(group);
-
+
try {
CommandLine cmdLine = parser.parse(args);
-
+
if (cmdLine.hasOption(helpOpt)) {
CommandLineUtil.printHelp(group);
return;
}
-
+
String dataset = cmdLine.getValue(inputOpt).toString();
int target = Integer.parseInt(cmdLine.getValue(labelOpt).toString());
- double threshold = (!cmdLine.hasOption(thresholdOpt)) ? 0.5 : Double.parseDouble(cmdLine.getValue(thresholdOpt).toString());
- int crosspnts = (!cmdLine.hasOption(crosspntsOpt)) ? 1 : Integer.parseInt(cmdLine.getValue(crosspntsOpt).toString());
+ double threshold = !cmdLine.hasOption(thresholdOpt) ? 0.5 : Double.parseDouble(cmdLine.getValue(
+ thresholdOpt).toString());
+ int crosspnts = !cmdLine.hasOption(crosspntsOpt) ? 1 : Integer.parseInt(cmdLine.getValue(crosspntsOpt)
+ .toString());
double mutrate = Double.parseDouble(cmdLine.getValue(mutrateOpt).toString());
- double mutrange = (!cmdLine.hasOption(mutrangeOpt)) ? 0.1 : Double.parseDouble(cmdLine.getValue(mutrangeOpt).toString());
- int mutprec = (!cmdLine.hasOption(mutprecOpt)) ? 2 : Integer.parseInt(cmdLine.getValue(mutprecOpt).toString());
+ double mutrange = !cmdLine.hasOption(mutrangeOpt) ? 0.1 : Double.parseDouble(cmdLine.getValue(
+ mutrangeOpt).toString());
+ int mutprec = !cmdLine.hasOption(mutprecOpt) ? 2 : Integer.parseInt(cmdLine.getValue(mutprecOpt)
+ .toString());
int popSize = Integer.parseInt(cmdLine.getValue(popsizeOpt).toString());
int genCount = Integer.parseInt(cmdLine.getValue(gencntOpt).toString());
-
+
long start = System.currentTimeMillis();
-
- runJob(dataset, target, threshold, crosspnts, mutrate, mutrange, mutprec,
- popSize, genCount);
-
+
+ CDGA.runJob(dataset, target, threshold, crosspnts, mutrate, mutrange, mutprec, popSize, genCount);
+
long end = System.currentTimeMillis();
-
- printElapsedTime(end - start);
+
+ CDGA.printElapsedTime(end - start);
} catch (OptionException e) {
- log.error("Error while parsing options", e);
+ CDGA.log.error("Error while parsing options", e);
CommandLineUtil.printHelp(group);
}
}
-
- private static void runJob(String dataset, int target, double threshold,
- int crosspnts, double mutrate, double mutrange, int mutprec, int popSize,
- int genCount) throws IOException {
+
+ private static void runJob(String dataset,
+ int target,
+ double threshold,
+ int crosspnts,
+ double mutrate,
+ double mutrange,
+ int mutprec,
+ int popSize,
+ int genCount) throws IOException {
Path inpath = new Path(dataset);
CDMahoutEvaluator.initializeDataSet(inpath);
-
+
// Candidate Factory
CandidateFactory<CDRule> factory = new CDFactory(threshold);
-
+
// Evolution Scheme
List<EvolutionaryOperator<CDRule>> operators = new ArrayList<EvolutionaryOperator<CDRule>>();
operators.add(new CDCrossover(crosspnts));
operators.add(new CDMutation(mutrate, mutrange, mutprec));
EvolutionPipeline<CDRule> pipeline = new EvolutionPipeline<CDRule>(operators);
-
+
// 75 % of the dataset is dedicated to training
DatasetSplit split = new DatasetSplit(0.75);
-
+
// Fitness Evaluator (defaults to training)
- FitnessEvaluator<? super CDRule> evaluator = new CDFitnessEvaluator(
- dataset, target, split);
+ FitnessEvaluator<? super CDRule> evaluator = new CDFitnessEvaluator(dataset, target, split);
// Selection Strategy
SelectionStrategy<? super CDRule> selection = new RouletteWheelSelection();
-
- EvolutionEngine<CDRule> engine = new SequentialEvolutionEngine<CDRule>(factory,
- pipeline, evaluator, selection, RandomUtils.getRandom());
-
+
+ EvolutionEngine<CDRule> engine = new SequentialEvolutionEngine<CDRule>(factory, pipeline, evaluator,
+ selection, RandomUtils.getRandom());
+
engine.addEvolutionObserver(new EvolutionObserver<CDRule>() {
@Override
public void populationUpdate(PopulationData<? extends CDRule> data) {
- log.info("Generation {}", data.getGenerationNumber());
+ CDGA.log.info("Generation {}", data.getGenerationNumber());
}
});
-
+
// evolve the rules over the training set
Rule solution = engine.evolve(popSize, 1, new GenerationCount(genCount));
-
+
// fitness over the training set
- CDFitness bestTrainFit = CDMahoutEvaluator.evaluate(solution, target,
- inpath, split);
-
+ CDFitness bestTrainFit = CDMahoutEvaluator.evaluate(solution, target, inpath, split);
+
// fitness over the testing set
split.setTraining(false);
- CDFitness bestTestFit = CDMahoutEvaluator.evaluate(solution, target,
- inpath, split);
-
+ CDFitness bestTestFit = CDMahoutEvaluator.evaluate(solution, target, inpath, split);
+
// evaluate the solution over the testing set
- log.info("Best solution fitness (train set) : {}", bestTrainFit);
- log.info("Best solution fitness (test set) : {}", bestTestFit);
+ CDGA.log.info("Best solution fitness (train set) : {}", bestTrainFit);
+ CDGA.log.info("Best solution fitness (test set) : {}", bestTestFit);
}
-
+
private static void printElapsedTime(long milli) {
long seconds = milli / 1000;
milli %= 1000;
-
+
long minutes = seconds / 60;
seconds %= 60;
-
+
long hours = minutes / 60;
minutes %= 60;
-
- log.info("Elapsed time (Hours:minutes:seconds:milli) : {}:{}:{}:{}", new Object[] {hours, minutes, seconds, milli});
+
+ CDGA.log.info("Elapsed time (Hours:minutes:seconds:milli) : {}:{}:{}:{}", new Object[] {hours, minutes,
+ seconds, milli});
}
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDMutation.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDMutation.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDMutation.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDMutation.java Sat Feb 13 19:07:36 2010
@@ -17,29 +17,29 @@
package org.apache.mahout.ga.watchmaker.cd;
-import org.uncommons.watchmaker.framework.EvolutionaryOperator;
-
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
+import org.uncommons.watchmaker.framework.EvolutionaryOperator;
+
/**
* Mutation operator.
*/
public class CDMutation implements EvolutionaryOperator<CDRule> {
-
+
/** probability of mutating a variable */
private final double rate;
-
+
/** max size of the change (step-size) for each mutated variable */
private final double range;
-
+
/**
* mutation precision. Defines indirectly the minimal step-size and the
* distribution of mutation steps inside the mutation range.
*/
private final int k;
-
+
/**
*
* @param rate probability of mutating a variable
@@ -50,18 +50,21 @@
* for more information about the parameters
*/
public CDMutation(double rate, double range, int k) {
- if (rate <= 0 || rate > 1)
+ if (rate <= 0 || rate > 1) {
throw new IllegalArgumentException("mutation rate must be in ]0, 1]");
- if (range <= 0 || range > 1)
+ }
+ if (range <= 0 || range > 1) {
throw new IllegalArgumentException("mutation range must be in ]0, 1]");
- if (k < 0)
+ }
+ if (k < 0) {
throw new IllegalArgumentException("mutation precision must be >= 0");
-
+ }
+
this.rate = rate;
this.range = range;
this.k = k;
}
-
+
@Override
public List<CDRule> apply(List<CDRule> selectedCandidates, Random rng) {
List<CDRule> mutatedPopulation = new ArrayList<CDRule>(selectedCandidates.size());
@@ -70,48 +73,49 @@
}
return mutatedPopulation;
}
-
+
protected CDRule mutate(CDRule rule, Random rng) {
DataSet dataset = DataSet.getDataSet();
-
+
for (int condInd = 0; condInd < rule.getNbConditions(); condInd++) {
- if (rng.nextDouble() > rate)
+ if (rng.nextDouble() > rate) {
continue;
-
+ }
+
int attrInd = CDRule.attributeIndex(condInd);
-
+
rule.setW(condInd, rndDouble(rule.getW(condInd), 0.0, 1.0, rng));
-
+
if (dataset.isNumerical(attrInd)) {
rule.setV(condInd, rndDouble(rule.getV(condInd), dataset
- .getMin(attrInd), dataset.getMax(attrInd), rng));
+ .getMin(attrInd), dataset.getMax(attrInd), rng));
} else {
- rule.setV(condInd, rndInt(rule.getV(condInd), dataset
- .getNbValues(attrInd), rng));
+ rule.setV(condInd, CDMutation.rndInt(rule.getV(condInd), dataset
+ .getNbValues(attrInd), rng));
}
}
-
+
return rule;
}
-
+
/**
* returns a random double in the interval [min, max ].
*/
double rndDouble(double value, double min, double max, Random rng) {
double s = rng.nextDouble() * 2.0 - 1.0; // [-1, +1]
- double r = range * ((max - min) / 2);
+ double r = range * (max - min) / 2;
double a = Math.pow(2, -k * rng.nextDouble());
double stp = s * r * a;
-
+
value += stp;
-
+
// clamp value to [min, max]
value = Math.max(min, value);
value = Math.min(max, value);
-
+
return value;
}
-
+
static int rndInt(double value, int nbcategories, Random rng) {
return rng.nextInt(nbcategories);
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDRule.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDRule.java?rev=909871&r1=909870&r2=909871&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDRule.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/ga/watchmaker/cd/CDRule.java Sat Feb 13 19:07:36 2010
@@ -17,10 +17,10 @@
package org.apache.mahout.ga.watchmaker.cd;
-import org.uncommons.maths.binary.BitString;
-
import java.util.Random;
+import org.uncommons.maths.binary.BitString;
+
/**
* Binary classification rule of the form:
*
@@ -33,46 +33,46 @@
*
* where conditioni = (wi): attributi oi vi <br>
* <ul>
- * <li> wi is the weight of the condition: <br>
+ * <li>wi is the weight of the condition: <br>
* <code>
* if (wi < a given threshold) then conditioni is not taken into
* consideration
- * </code>
- * </li>
- * <li> oi is an operator ('<' or '>=') </li>
+ * </code></li>
+ * <li>oi is an operator ('<' or '>=')</li>
* </ul>
*/
public class CDRule implements Rule {
-
- private double threshold;
-
- private int nbConditions;
-
- private double[] weights;
-
- private BitString operators;
-
- private double[] values;
-
+
+ private final double threshold;
+
+ private final int nbConditions;
+
+ private final double[] weights;
+
+ private final BitString operators;
+
+ private final double[] values;
+
/**
- * @param threshold condition activation threshold
+ * @param threshold
+ * condition activation threshold
*/
public CDRule(double threshold) {
// crossover needs at least 2 attributes
if (!(threshold >= 0 && threshold <= 1)) {
throw new IllegalArgumentException("bad threshold");
}
-
+
this.threshold = threshold;
-
+
// the label is not included in the conditions
this.nbConditions = DataSet.getDataSet().getNbAttributes() - 1;
-
+
weights = new double[nbConditions];
operators = new BitString(nbConditions);
values = new double[nbConditions];
}
-
+
/**
* Random rule.
*
@@ -83,30 +83,31 @@
this(threshold);
DataSet dataset = DataSet.getDataSet();
-
+
for (int condInd = 0; condInd < nbConditions; condInd++) {
- int attrInd = attributeIndex(condInd);
-
+ int attrInd = CDRule.attributeIndex(condInd);
+
setW(condInd, rng.nextDouble());
setO(condInd, rng.nextBoolean());
- if (dataset.isNumerical(attrInd))
- setV(condInd, randomNumerical(dataset, attrInd, rng));
- else
- setV(condInd, randomCategorical(dataset, attrInd, rng));
+ if (dataset.isNumerical(attrInd)) {
+ setV(condInd, CDRule.randomNumerical(dataset, attrInd, rng));
+ } else {
+ setV(condInd, CDRule.randomCategorical(dataset, attrInd, rng));
+ }
}
}
-
+
protected static double randomNumerical(DataSet dataset, int attrInd, Random rng) {
double max = dataset.getMax(attrInd);
double min = dataset.getMin(attrInd);
return rng.nextDouble() * (max - min) + min;
}
-
+
protected static double randomCategorical(DataSet dataset, int attrInd, Random rng) {
int nbcategories = dataset.getNbValues(attrInd);
return rng.nextInt(nbcategories);
}
-
+
/**
* Copy Constructor
*
@@ -115,106 +116,115 @@
public CDRule(CDRule ind) {
threshold = ind.threshold;
nbConditions = ind.nbConditions;
-
+
weights = ind.weights.clone();
operators = ind.operators.clone();
values = ind.values.clone();
}
-
+
/**
* if all the active conditions are met returns 1, else returns 0.
*/
@Override
public int classify(DataLine dl) {
for (int condInd = 0; condInd < nbConditions; condInd++) {
- if (!condition(condInd, dl))
+ if (!condition(condInd, dl)) {
return 0;
+ }
}
return 1;
}
-
+
/**
* Makes sure that the label is not handled by any condition.
*
- * @param condInd condition index
+ * @param condInd
+ * condition index
* @return attribute index
*/
public static int attributeIndex(int condInd) {
int labelpos = DataSet.getDataSet().getLabelIndex();
- return (condInd < labelpos) ? condInd : condInd + 1;
+ return condInd < labelpos ? condInd : condInd + 1;
}
-
+
/**
* Returns the value of the condition.
*
- * @param condInd index of the condition
+ * @param condInd
+ * index of the condition
* @return
*/
boolean condition(int condInd, DataLine dl) {
- int attrInd = attributeIndex(condInd);
-
+ int attrInd = CDRule.attributeIndex(condInd);
+
// is the condition active
- if (getW(condInd) < threshold)
+ if (getW(condInd) < threshold) {
return true; // no
-
- if (DataSet.getDataSet().isNumerical(attrInd))
+ }
+
+ if (DataSet.getDataSet().isNumerical(attrInd)) {
return numericalCondition(condInd, dl);
- else
+ } else {
return categoricalCondition(condInd, dl);
+ }
}
-
+
boolean numericalCondition(int condInd, DataLine dl) {
- int attrInd = attributeIndex(condInd);
-
- if (getO(condInd))
+ int attrInd = CDRule.attributeIndex(condInd);
+
+ if (getO(condInd)) {
return dl.getAttribut(attrInd) >= getV(condInd);
- else
+ } else {
return dl.getAttribut(attrInd) < getV(condInd);
+ }
}
-
+
boolean categoricalCondition(int condInd, DataLine dl) {
- int attrInd = attributeIndex(condInd);
-
- if (getO(condInd))
+ int attrInd = CDRule.attributeIndex(condInd);
+
+ if (getO(condInd)) {
return dl.getAttribut(attrInd) == getV(condInd);
- else
+ } else {
return dl.getAttribut(attrInd) != getV(condInd);
+ }
}
-
+
@Override
public String toString() {
StringBuilder buffer = new StringBuilder();
-
+
buffer.append("CDRule = [");
boolean empty = true;
for (int condInd = 0; condInd < nbConditions; condInd++) {
if (getW(condInd) >= threshold) {
- if (!empty)
+ if (!empty) {
buffer.append(" && ");
-
- buffer.append("attr").append(attributeIndex(condInd)).append(' ').append(getO(condInd) ? ">=" : "<");
+ }
+
+ buffer.append("attr").append(CDRule.attributeIndex(condInd)).append(' ').append(
+ getO(condInd) ? ">=" : "<");
buffer.append(' ').append(getV(condInd));
-
+
empty = false;
}
}
buffer.append(']');
-
+
return buffer.toString();
}
-
+
public int getNbConditions() {
return nbConditions;
}
-
+
public double getW(int index) {
return weights[index];
}
-
+
public void setW(int index, double w) {
weights[index] = w;
}
-
+
/**
* operator
*
@@ -224,41 +234,45 @@
public boolean getO(int index) {
return operators.getBit(index);
}
-
+
/**
* set the operator
*
* @param index
- * @param o true if '>='; false if '<'
+ * @param o
+ * true if '>='; false if '<'
*/
public void setO(int index, boolean o) {
operators.setBit(index, o);
}
-
+
public double getV(int index) {
return values[index];
}
-
+
public void setV(int index, double v) {
values[index] = v;
}
-
+
@Override
public boolean equals(Object obj) {
- if (this == obj)
+ if (this == obj) {
return true;
- if (obj == null || !(obj instanceof CDRule))
+ }
+ if (obj == null || !(obj instanceof CDRule)) {
return false;
+ }
CDRule rule = (CDRule) obj;
-
+
for (int index = 0; index < nbConditions; index++) {
- if (!areGenesEqual(this, rule, index))
+ if (!CDRule.areGenesEqual(this, rule, index)) {
return false;
+ }
}
-
+
return true;
}
-
+
@Override
public int hashCode() {
int value = 0;
@@ -268,30 +282,31 @@
}
return value;
}
-
+
/**
* Compares a given gene between two rules
*
* @param rule1
* @param rule2
- * @param index gene index
+ * @param index
+ * gene index
* @return true if the gene is the same
*/
public static boolean areGenesEqual(CDRule rule1, CDRule rule2, int index) {
- return rule1.getW(index) == rule2.getW(index)
- && rule1.getO(index) == rule2.getO(index)
- && rule1.getV(index) == rule2.getV(index);
+ return rule1.getW(index) == rule2.getW(index) && rule1.getO(index) == rule2.getO(index)
+ && rule1.getV(index) == rule2.getV(index);
}
-
+
/**
* Compares two genes from this Rule
*
- * @param index1 first gene index
- * @param index2 second gene index
+ * @param index1
+ * first gene index
+ * @param index2
+ * second gene index
* @return if the genes are equal
*/
public boolean areGenesEqual(int index1, int index2) {
- return getW(index1) == getW(index2) && getO(index1) == getO(index2)
- && getV(index1) == getV(index2);
+ return getW(index1) == getW(index2) && getO(index1) == getO(index2) && getV(index1) == getV(index2);
}
}