You are viewing a plain text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/02/13 21:27:30 UTC
svn commit: r909900 [1/4] - in
/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df: ./ builder/
callback/ data/ data/conditions/ mapred/ mapred/inmem/ mapred/partial/
mapreduce/ mapreduce/inmem/ mapreduce/partial/ node/ ref/ split/ tools/
Author: robinanil
Date: Sat Feb 13 20:27:25 2010
New Revision: 909900
URL: http://svn.apache.org/viewvc?rev=909900&view=rev
Log:
MAHOUT-291
Submitting StyleChanges for DF
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/Bagging.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/ErrorEstimate.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/TreeBuilder.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/ForestPredictions.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/MeanTreeCollector.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/MultiCallback.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/PredictionCallback.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/SingleTreePredictions.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataUtils.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DescriptorUtils.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Condition.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Equals.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/GreaterOrEquals.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Lesser.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/Builder.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/MapredMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemBuilder.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemInputFormat.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/partial/PartialBuilder.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/partial/Step0Job.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/partial/Step1Mapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/partial/Step2Job.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/partial/Step2Mapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredOutput.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormat.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/InterResults.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/PartialBuilder.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step0Job.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step1Mapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step2Job.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step2Mapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/TreeID.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/node/CategoricalNode.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Leaf.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/node/MockLeaf.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/node/NumericalNode.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/ref/SequentialBuilder.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/DefaultIgSplit.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/IgSplit.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Frequencies.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/Bagging.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/Bagging.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/Bagging.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/Bagging.java Sat Feb 13 20:27:25 2010
@@ -31,41 +31,43 @@
* Builds a tree using bagging
*/
public class Bagging {
-
+
private static final Logger log = LoggerFactory.getLogger(Bagging.class);
-
+
private final TreeBuilder treeBuilder;
-
+
private final Data data;
-
+
private final boolean[] sampled;
-
+
public Bagging(TreeBuilder treeBuilder, Data data) {
this.treeBuilder = treeBuilder;
this.data = data;
sampled = new boolean[data.size()];
}
-
+
/**
* Builds one tree
*
- * @param treeId tree identifier
+ * @param treeId
+ * tree identifier
* @param rng
* @param callback
* @return
- * @throws RuntimeException if the data is not set
+ * @throws RuntimeException
+ * if the data is not set
*/
public Node build(int treeId, Random rng, PredictionCallback callback) {
- log.debug("Bagging...");
+ Bagging.log.debug("Bagging...");
Arrays.fill(sampled, false);
Data bag = data.bagging(rng, sampled);
-
- log.debug("Building...");
+
+ Bagging.log.debug("Building...");
Node tree = treeBuilder.build(rng, bag);
-
+
// predict the label for the out-of-bag elements
if (callback != null) {
- log.debug("Oob error estimation");
+ Bagging.log.debug("Oob error estimation");
for (int index = 0; index < data.size(); index++) {
if (sampled[index] == false) {
int prediction = tree.classify(data.get(index));
@@ -73,8 +75,8 @@
}
}
}
-
+
return tree;
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java Sat Feb 13 20:27:25 2010
@@ -30,9 +30,8 @@
* Utility class that contains various helper methods
*/
public class DFUtils {
- private DFUtils() {
- }
-
+ private DFUtils() {}
+
/**
* Writes an Node[] into a DataOutput
*
@@ -40,8 +39,7 @@
* @param array
* @throws IOException
*/
- public static void writeArray(DataOutput out, Node[] array)
- throws IOException {
+ public static void writeArray(DataOutput out, Node[] array) throws IOException {
out.writeInt(array.length);
for (Node w : array) {
w.write(out);
@@ -61,10 +59,10 @@
for (int index = 0; index < length; index++) {
nodes[index] = Node.read(in);
}
-
+
return nodes;
}
-
+
/**
* Writes a double[] into a DataOutput
*
@@ -72,14 +70,13 @@
* @param array
* @throws IOException
*/
- public static void writeArray(DataOutput out, double[] array)
- throws IOException {
+ public static void writeArray(DataOutput out, double[] array) throws IOException {
out.writeInt(array.length);
for (double value : array) {
out.writeDouble(value);
}
}
-
+
/**
* Reads a double[] from a DataInput
*
@@ -93,10 +90,10 @@
for (int index = 0; index < length; index++) {
array[index] = in.readDouble();
}
-
+
return array;
}
-
+
/**
* Writes an int[] into a DataOutput
*
@@ -110,7 +107,7 @@
out.writeInt(value);
}
}
-
+
/**
* Reads an int[] from a DataInput
*
@@ -124,10 +121,10 @@
for (int index = 0; index < length; index++) {
array[index] = in.readInt();
}
-
+
return array;
}
-
+
/**
* Return a list of all files in the output directory
*
@@ -135,21 +132,20 @@
* @param outputPath
* @return
* @throws IOException
- * @throws RuntimeException if no file is found
+ * @throws RuntimeException
+ * if no file is found
*/
- public static Path[] listOutputFiles(FileSystem fs, Path outputPath)
- throws IOException {
+ public static Path[] listOutputFiles(FileSystem fs, Path outputPath) throws IOException {
Path[] outfiles = OutputUtils.listOutputFiles(fs, outputPath);
if (outfiles.length == 0) {
throw new IOException("No output found !");
}
-
+
return outfiles;
}
-
+
/**
- * Formats a time interval in milliseconds to a String in the form
- * "hours:minutes:seconds:millis"
+ * Formats a time interval in milliseconds to a String in the form "hours:minutes:seconds:millis"
*
* @param milli
* @return
@@ -157,14 +153,14 @@
public static String elapsedTime(long milli) {
long seconds = milli / 1000;
milli %= 1000;
-
+
long minutes = seconds / 60;
seconds %= 60;
-
+
long hours = minutes / 60;
minutes %= 60;
-
+
return hours + "h " + minutes + "m " + seconds + "s " + milli;
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java Sat Feb 13 20:27:25 2010
@@ -31,21 +31,21 @@
* Represents a forest of decision trees.
*/
public class DecisionForest {
-
+
private final List<Node> trees;
-
+
protected DecisionForest() {
trees = new ArrayList<Node>();
}
-
+
public DecisionForest(List<Node> trees) {
- if (!(trees != null && !trees.isEmpty())) {
+ if (!((trees != null) && !trees.isEmpty())) {
throw new IllegalArgumentException("trees argument must not be null or empty");
}
-
+
this.trees = trees;
}
-
+
public List<Node> getTrees() {
return trees;
}
@@ -60,42 +60,46 @@
if (callback == null) {
throw new IllegalArgumentException("callback must not be null");
}
-
- if (data.isEmpty())
+
+ if (data.isEmpty()) {
return; // nothing to classify
-
+ }
+
for (int treeId = 0; treeId < trees.size(); treeId++) {
Node tree = trees.get(treeId);
-
+
for (int index = 0; index < data.size(); index++) {
int prediction = tree.classify(data.get(index));
callback.prediction(treeId, index, prediction);
}
}
}
-
+
/**
* predicts the label for the instance
*
- * @param rng Random number generator, used to break ties randomly
+ * @param rng
+ * Random number generator, used to break ties randomly
* @param instance
* @return -1 if the label cannot be predicted
*/
public int classify(Random rng, Instance instance) {
int[] predictions = new int[trees.size()];
-
+
for (Node tree : trees) {
int prediction = tree.classify(instance);
- if (prediction != -1)
+ if (prediction != -1) {
predictions[prediction]++;
+ }
}
-
- if (DataUtils.sum(predictions) == 0)
+
+ if (DataUtils.sum(predictions) == 0) {
return -1; // no prediction available
-
+ }
+
return DataUtils.maxindex(rng, predictions);
}
-
+
/**
* Mean number of nodes per tree
*
@@ -103,14 +107,14 @@
*/
public long meanNbNodes() {
long sum = 0;
-
+
for (Node tree : trees) {
sum += tree.nbNodes();
}
-
+
return sum / trees.size();
}
-
+
/**
* Total number of nodes in all the trees
*
@@ -118,14 +122,14 @@
*/
public long nbNodes() {
long sum = 0;
-
+
for (Node tree : trees) {
sum += tree.nbNodes();
}
-
+
return sum;
}
-
+
/**
* Mean maximum depth per tree
*
@@ -133,29 +137,31 @@
*/
public long meanMaxDepth() {
long sum = 0;
-
+
for (Node tree : trees) {
sum += tree.maxDepth();
}
-
+
return sum / trees.size();
}
-
+
@Override
public boolean equals(Object obj) {
- if (this == obj)
+ if (this == obj) {
return true;
- if (!(obj instanceof DecisionForest))
+ }
+ if (!(obj instanceof DecisionForest)) {
return false;
+ }
- DecisionForest rf = (DecisionForest)obj;
+ DecisionForest rf = (DecisionForest) obj;
- return trees.size() == rf.getTrees().size() && trees.containsAll(rf.getTrees());
+ return (trees.size() == rf.getTrees().size()) && trees.containsAll(rf.getTrees());
}
-
+
@Override
public int hashCode() {
return trees.hashCode();
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/ErrorEstimate.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/ErrorEstimate.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/ErrorEstimate.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/ErrorEstimate.java Sat Feb 13 20:27:25 2010
@@ -21,46 +21,49 @@
* various methods to compute from the output of a random forest
*/
public class ErrorEstimate {
- private ErrorEstimate() {
- }
-
+ private ErrorEstimate() {}
+
public static double errorRate(int[] labels, int[] predictions) {
if (labels.length != predictions.length) {
throw new IllegalArgumentException("labels.length != predictions.length");
}
-
+
double nberrors = 0; // number of instance that got bad predictions
double datasize = 0; // number of classified instances
-
+
for (int index = 0; index < labels.length; index++) {
- if (predictions[index] == -1)
+ if (predictions[index] == -1) {
continue; // instance not classified
-
- if (predictions[index] != labels[index])
+ }
+
+ if (predictions[index] != labels[index]) {
nberrors++;
-
+ }
+
datasize++;
}
-
+
return nberrors / datasize;
}
-
+
/**
* Counts the number of classified instances (prediction != -1)
+ *
* @param predictions
* @return
*/
public static int nbPredicted(int[] predictions) {
int nbpredicted = 0;
-
+
for (int prediction : predictions) {
- if (prediction != -1)
+ if (prediction != -1) {
nbpredicted++;
+ }
}
-
+
return nbpredicted;
}
-
+
/**
* Counts the number of instance that got bad predictions
*
@@ -72,17 +75,19 @@
if (labels.length != predictions.length) {
throw new IllegalArgumentException("labels.length != predictions.length");
}
-
+
int nberrors = 0;
-
+
for (int index = 0; index < labels.length; index++) {
- if (predictions[index] == -1)
+ if (predictions[index] == -1) {
continue; // instance not classified
-
- if (predictions[index] != labels[index])
+ }
+
+ if (predictions[index] != labels[index]) {
nberrors++;
+ }
}
-
+
return nberrors;
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java Sat Feb 13 20:27:25 2010
@@ -34,146 +34,159 @@
/**
* Builds a Decision Tree <br>
- * Based on the algorithm described in the "Decision Trees" tutorials by Andrew
- * W. Moore, available at:<br>
+ * Based on the algorithm described in the "Decision Trees" tutorials by Andrew W. Moore, available at:<br>
* <br>
* http://www.cs.cmu.edu/~awm/tutorials
*/
public class DefaultTreeBuilder implements TreeBuilder {
-
+
private static final Logger log = LoggerFactory.getLogger(DefaultTreeBuilder.class);
-
+
/** indicates which CATEGORICAL attributes have already been selected in the parent nodes */
private boolean[] selected;
-
+
/** number of attributes to select randomly at each node */
private int m = 1;
-
+
/** IgSplit implementation */
private IgSplit igSplit;
-
+
public DefaultTreeBuilder() {
igSplit = new OptIgSplit();
}
-
+
public void setM(int m) {
this.m = m;
}
-
+
public void setIgSplit(IgSplit igSplit) {
this.igSplit = igSplit;
}
-
+
@Override
public Node build(Random rng, Data data) {
-
+
if (selected == null) {
selected = new boolean[data.getDataset().nbAttributes()];
}
-
- if (data.isEmpty())
+
+ if (data.isEmpty()) {
return new Leaf(-1);
- if (isIdentical(data))
+ }
+ if (isIdentical(data)) {
return new Leaf(data.majorityLabel(rng));
- if (data.identicalLabel())
+ }
+ if (data.identicalLabel()) {
return new Leaf(data.get(0).label);
-
- int[] attributes = randomAttributes(rng, selected, m);
-
+ }
+
+ int[] attributes = DefaultTreeBuilder.randomAttributes(rng, selected, m);
+
// find the best split
Split best = null;
for (int attr : attributes) {
Split split = igSplit.computeSplit(data, attr);
- if (best == null || best.ig < split.ig)
+ if ((best == null) || (best.ig < split.ig)) {
best = split;
+ }
}
-
+
boolean alreadySelected = selected[best.attr];
-
+
if (alreadySelected) {
// attribute already selected
- log.warn("attribute {} already selected in a parent node", best.attr);
+ DefaultTreeBuilder.log.warn("attribute {} already selected in a parent node", best.attr);
}
-
+
Node childNode;
if (data.getDataset().isNumerical(best.attr)) {
Data loSubset = data.subset(Condition.lesser(best.attr, best.split));
Node loChild = build(rng, loSubset);
-
- Data hiSubset = data.subset(Condition.greaterOrEquals(best.attr,
- best.split));
+
+ Data hiSubset = data.subset(Condition.greaterOrEquals(best.attr, best.split));
Node hiChild = build(rng, hiSubset);
-
+
childNode = new NumericalNode(best.attr, best.split, loChild, hiChild);
} else { // CATEGORICAL attribute
selected[best.attr] = true;
double[] values = data.values(best.attr);
Node[] childs = new Node[values.length];
-
+
for (int index = 0; index < values.length; index++) {
Data subset = data.subset(Condition.equals(best.attr, values[index]));
childs[index] = build(rng, subset);
}
-
+
childNode = new CategoricalNode(best.attr, values, childs);
-
+
if (!alreadySelected) {
selected[best.attr] = false;
}
}
-
+
return childNode;
}
-
+
/**
* checks if all the vectors have identical attribute values. Ignore selected attributes.
- *
+ *
* @return true is all the vectors are identical or the data is empty<br>
* false otherwise
*/
private boolean isIdentical(Data data) {
- if (data.isEmpty()) return true;
-
+ if (data.isEmpty()) {
+ return true;
+ }
+
Instance instance = data.get(0);
for (int attr = 0; attr < selected.length; attr++) {
- if (selected[attr]) continue;
-
- for (int index = 1; index < data.size(); index++) {
- if (data.get(index).get(attr) != instance.get(attr))
- return false;
+ if (selected[attr]) {
+ continue;
+ }
+
+ for (int index = 1; index < data.size(); index++) {
+ if (data.get(index).get(attr) != instance.get(attr)) {
+ return false;
+ }
}
}
-
+
return true;
}
-
+
/**
- * Randomly selects m attributes to consider for split, excludes IGNORED and
- * LABEL attributes
+ * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
*
- * @param rng random-numbers generator
- * @param selected attributes' state (selected or not)
- * @param m number of attributes to choose
+ * @param rng
+ * random-numbers generator
+ * @param selected
+ * attributes' state (selected or not)
+ * @param m
+ * number of attributes to choose
* @return
*/
protected static int[] randomAttributes(Random rng, boolean[] selected, int m) {
int nbNonSelected = 0; // number of non selected attributes
for (boolean sel : selected) {
- if (!sel) nbNonSelected++;
+ if (!sel) {
+ nbNonSelected++;
+ }
}
-
+
if (nbNonSelected == 0) {
- log.warn("All attributes are selected !");
+ DefaultTreeBuilder.log.warn("All attributes are selected !");
}
-
+
int[] result;
if (nbNonSelected <= m) {
// return all non selected attributes
result = new int[nbNonSelected];
int index = 0;
for (int attr = 0; attr < selected.length; attr++) {
- if (!selected[attr]) result[index++] = attr;
+ if (!selected[attr]) {
+ result[index++] = attr;
+ }
}
} else {
result = new int[m];
@@ -183,17 +196,17 @@
do {
rind = rng.nextInt(selected.length);
} while (selected[rind]);
-
+
result[index] = rind;
selected[rind] = true; // temporarely set the choosen attribute to be selected
}
-
+
// the choosen attributes are not yet selected
for (int attr : result) {
selected[attr] = false;
}
}
-
+
return result;
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/TreeBuilder.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/TreeBuilder.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/TreeBuilder.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/TreeBuilder.java Sat Feb 13 20:27:25 2010
@@ -30,10 +30,12 @@
/**
* Builds a Decision tree using the training data
*
- * @param rng random-numbers generator
- * @param data training data
+ * @param rng
+ * random-numbers generator
+ * @param data
+ * training data
* @return root Node
*/
Node build(Random rng, Data data);
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/ForestPredictions.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/ForestPredictions.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/ForestPredictions.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/ForestPredictions.java Sat Feb 13 20:27:25 2010
@@ -26,45 +26,48 @@
* Collects a forest's predictions
*/
public class ForestPredictions implements PredictionCallback {
-
+
/** predictions[n][label] = number of times instance n was classified 'label' */
private final int[][] predictions;
-
+
public ForestPredictions(int nbInstances, int nblabels) {
predictions = new int[nbInstances][];
for (int index = 0; index < predictions.length; index++) {
predictions[index] = new int[nblabels];
}
}
-
+
@Override
public void prediction(int treeId, int instanceId, int prediction) {
- if (prediction != -1)
+ if (prediction != -1) {
predictions[instanceId][prediction]++;
+ }
}
-
+
@Override
public boolean equals(Object obj) {
- if (this == obj)
+ if (this == obj) {
return true;
- if (!(obj instanceof ForestPredictions))
+ }
+ if (!(obj instanceof ForestPredictions)) {
return false;
-
+ }
+
ForestPredictions fp = (ForestPredictions) obj;
-
+
if (predictions.length != fp.predictions.length) {
return false;
}
-
+
for (int i = 0; i < predictions.length; i++) {
if (!Arrays.equals(predictions[i], fp.predictions[i])) {
return false;
}
}
-
+
return true;
}
-
+
@Override
public int hashCode() {
int hashCode = 1;
@@ -75,10 +78,10 @@
}
return hashCode;
}
-
+
/**
- * compute the prediction for each instance. the prediction of an instance is
- * the index of the label that got most of the votes
+ * compute the prediction for each instance. the prediction of an instance is the index of the label that
+ * got most of the votes
*
* @param rng
* @return
@@ -86,15 +89,16 @@
public int[] computePredictions(Random rng) {
int[] result = new int[predictions.length];
Arrays.fill(result, -1);
-
+
for (int index = 0; index < predictions.length; index++) {
- if (DataUtils.sum(predictions[index]) == 0)
+ if (DataUtils.sum(predictions[index]) == 0) {
continue; // this instance has not been classified
-
+ }
+
result[index] = DataUtils.maxindex(rng, predictions[index]);
}
-
+
return result;
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/MeanTreeCollector.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/MeanTreeCollector.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/MeanTreeCollector.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/MeanTreeCollector.java Sat Feb 13 20:27:25 2010
@@ -23,13 +23,13 @@
* Computes the error rate for each tree, and returns the mean of all the trees
*/
public class MeanTreeCollector implements PredictionCallback {
-
+
/** number of errors for each tree */
private final int[] nbErrors;
-
+
/** number of predictions for each tree */
private final int[] nbPredictions;
-
+
private final Data data;
public MeanTreeCollector(Data data, int nbtrees) {
@@ -37,29 +37,32 @@
nbPredictions = new int[nbtrees];
this.data = data;
}
-
+
public double meanTreeError() {
double sumerror = 0.0;
-
+
for (int treeId = 0; treeId < nbErrors.length; treeId++) {
- if (nbPredictions[treeId] == 0)
+ if (nbPredictions[treeId] == 0) {
continue; // this tree has 0 predictions
-
+ }
+
sumerror += (double) nbErrors[treeId] / nbPredictions[treeId];
}
-
+
return sumerror / nbErrors.length;
}
-
+
@Override
public void prediction(int treeId, int instanceId, int prediction) {
- if (prediction == -1)
+ if (prediction == -1) {
return;
-
+ }
+
nbPredictions[treeId]++;
-
- if (data.get(instanceId).label != prediction)
+
+ if (data.get(instanceId).label != prediction) {
nbErrors[treeId]++;
+ }
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/MultiCallback.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/MultiCallback.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/MultiCallback.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/MultiCallback.java Sat Feb 13 20:27:25 2010
@@ -17,22 +17,22 @@
package org.apache.mahout.df.callback;
-
/**
* Combines many callbacks, that will be called when a prediction is done.
*/
public class MultiCallback implements PredictionCallback {
-
+
private final PredictionCallback[] callbacks;
-
+
public MultiCallback(PredictionCallback... callbacks) {
this.callbacks = callbacks;
}
-
+
@Override
public void prediction(int treeId, int instanceId, int prediction) {
- for (PredictionCallback callback : callbacks)
+ for (PredictionCallback callback : callbacks) {
callback.prediction(treeId, instanceId, prediction);
+ }
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/PredictionCallback.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/PredictionCallback.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/PredictionCallback.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/PredictionCallback.java Sat Feb 13 20:27:25 2010
@@ -17,18 +17,20 @@
package org.apache.mahout.df.callback;
-
/**
* Called each time an instance has been classified
*/
public interface PredictionCallback {
-
+
/**
* called when an instance has been classified
*
- * @param treeId tree that classified the instance
- * @param instanceId classified instance
- * @param prediction predicted label
+ * @param treeId
+ * tree that classified the instance
+ * @param instanceId
+ * classified instance
+ * @param prediction
+ * predicted label
*/
void prediction(int treeId, int instanceId, int prediction);
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/SingleTreePredictions.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/SingleTreePredictions.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/SingleTreePredictions.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/SingleTreePredictions.java Sat Feb 13 20:27:25 2010
@@ -23,10 +23,10 @@
* Collects the predictions for a single tree
*/
public class SingleTreePredictions implements PredictionCallback {
-
+
/** predictions[n] = 'label' predicted for instance 'n' */
private final int[] predictions;
-
+
/** used to assert that all the predictions belong to the same tree */
private Integer treeId;
@@ -47,9 +47,9 @@
predictions[instanceId] = prediction;
}
-
+
public int[] getPredictions() {
return predictions;
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java Sat Feb 13 20:27:25 2010
@@ -31,21 +31,21 @@
import org.apache.mahout.df.data.conditions.Condition;
/**
- * Holds a list of vectors and their corresponding Dataset. contains various
- * operations that deals with the vectors (subset, count,...)
+ * Holds a list of vectors and their corresponding Dataset. contains various operations that deals with the
+ * vectors (subset, count,...)
*
*/
public class Data implements Cloneable {
-
+
private final List<Instance> instances;
-
+
private final Dataset dataset;
-
+
public Data(Dataset dataset, List<Instance> instances) {
this.dataset = dataset;
this.instances = new ArrayList<Instance>(instances);
}
-
+
/**
* Returns the number of elements
*
@@ -54,7 +54,7 @@
public int size() {
return instances.size();
}
-
+
/**
* Returns true is this data contains no element
*
@@ -63,38 +63,42 @@
public boolean isEmpty() {
return instances.isEmpty();
}
-
+
/**
* Returns true is this data contains the specified element.
*
- * @param v element whose presence in this list if to be searched
+ * @param v
+ * element whose presence in this list if to be searched
* @return
*/
public boolean contains(Instance v) {
return instances.contains(v);
}
-
+
/**
* Returns the index of the first occurrence of the element in this data
*
- * @param v element to search for
+ * @param v
+ * element to search for
* @return -1 if the element is not found
*/
public int indexof(Instance v) {
return instances.indexOf(v);
}
-
+
/**
* Returns the element at the specified position
*
- * @param index index of element to return
+ * @param index
+ * index of element to return
* @return the element at the specified position
- * @throws IndexOutOfBoundsException if the index is out of range
+ * @throws IndexOutOfBoundsException
+ * if the index is out of range
*/
public Instance get(int index) {
return instances.get(index);
}
-
+
/**
* Returns the subset from this data that matches the given condition
*
@@ -103,33 +107,37 @@
*/
public Data subset(Condition condition) {
List<Instance> subset = new ArrayList<Instance>();
-
+
for (Instance instance : instances) {
- if (condition.isTrueFor(instance))
+ if (condition.isTrueFor(instance)) {
subset.add(instance);
+ }
}
-
+
return new Data(dataset, subset);
}
-
+
/**
* Returns a random subset without modifying the current data
*
- * @param rng Random number generator
- * @param ratio [0,1]
+ * @param rng
+ * Random number generator
+ * @param ratio
+ * [0,1]
* @return
*/
public Data rsubset(Random rng, double ratio) {
List<Instance> subset = new ArrayList<Instance>();
-
+
for (Instance instance : instances) {
- if (rng.nextDouble() < ratio)
+ if (rng.nextDouble() < ratio) {
subset.add(instance);
+ }
}
-
+
return new Data(dataset, subset);
}
-
+
/**
* if data has N cases, sample N cases at random -but with replacement.
*
@@ -139,53 +147,52 @@
public Data bagging(Random rng) {
int datasize = size();
List<Instance> bag = new ArrayList<Instance>(datasize);
-
+
for (int i = 0; i < datasize; i++) {
bag.add(instances.get(rng.nextInt(datasize)));
}
-
+
return new Data(dataset, bag);
}
-
+
/**
* if data has N cases, sample N cases at random -but with replacement.
*
* @param rng
- * @param sampled indicating which instance has been sampled
+ * @param sampled
+ * indicating which instance has been sampled
*
* @return sampled data
*/
public Data bagging(Random rng, boolean[] sampled) {
int datasize = size();
List<Instance> bag = new ArrayList<Instance>(datasize);
-
+
for (int i = 0; i < datasize; i++) {
int index = rng.nextInt(datasize);
bag.add(instances.get(index));
sampled[index] = true;
}
-
+
return new Data(dataset, bag);
}
-
+
/**
- * Splits the data in two, returns one part, and this gets the rest of the
- * data. <b>VERY SLOW!</b>
+ * Splits the data in two, returns one part, and this gets the rest of the data. <b>VERY SLOW!</b>
*
* @param rng
* @return
*/
public Data rsplit(Random rng, int subsize) {
List<Instance> subset = new ArrayList<Instance>(subsize);
-
+
for (int i = 0; i < subsize; i++) {
subset.add(instances.remove(rng.nextInt(instances.size())));
}
-
+
return new Data(dataset, subset);
}
-
-
+
/**
* checks if all the vectors have identical attribute values
*
@@ -193,39 +200,42 @@
* false otherwise
*/
public boolean isIdentical() {
- if (isEmpty())
+ if (isEmpty()) {
return true;
-
+ }
+
Instance instance = get(0);
for (int attr = 0; attr < dataset.nbAttributes(); attr++) {
for (int index = 1; index < size(); index++) {
- if (get(index).get(attr) != instance.get(attr))
+ if (get(index).get(attr) != instance.get(attr)) {
return false;
+ }
}
}
-
+
return true;
}
-
-
+
/**
* checks if all the vectors have identical label values
*
* @return
*/
public boolean identicalLabel() {
- if (isEmpty())
+ if (isEmpty()) {
return true;
-
+ }
+
int label = get(0).label;
for (int index = 1; index < size(); index++) {
- if (get(index).label != label)
+ if (get(index).label != label) {
return false;
+ }
}
-
+
return true;
}
-
+
/**
* finds all distinct values of a given attribute
*
@@ -234,43 +244,45 @@
*/
public double[] values(int attr) {
Set<Double> result = new HashSet<Double>();
-
+
for (Instance instance : instances) {
result.add(instance.get(attr));
}
-
+
double[] values = new double[result.size()];
-
+
int index = 0;
for (Double value : result) {
values[index++] = value;
}
-
+
return values;
}
-
+
@Override
public Data clone() {
return new Data(dataset, new ArrayList<Instance>(instances));
}
-
+
@Override
public boolean equals(Object obj) {
- if (this == obj)
+ if (this == obj) {
return true;
- if (obj == null || !(obj instanceof Data))
+ }
+ if ((obj == null) || !(obj instanceof Data)) {
return false;
+ }
- Data data = (Data)obj;
+ Data data = (Data) obj;
return instances.equals(data.instances) && dataset.equals(data.dataset);
}
-
+
@Override
public int hashCode() {
return instances.hashCode() + dataset.hashCode();
}
-
+
/**
* extract the labels of all instances
*
@@ -278,27 +290,28 @@
*/
public int[] extractLabels() {
int[] labels = new int[size()];
-
+
for (int index = 0; index < labels.length; index++) {
labels[index] = get(index).label;
}
-
+
return labels;
}
-
-
+
/**
* extract the labels of all instances from a data file
*
* @param dataset
- * @param fs file system
- * @param path data path
+ * @param fs
+ * file system
+ * @param path
+ * data path
* @return
*/
public static int[] extractLabels(Dataset dataset, FileSystem fs, Path path) throws IOException {
FSDataInputStream input = fs.open(path);
FileLineIterator iterator = new FileLineIterator(input);
-
+
int[] labels = new int[dataset.nbInstances()];
DataConverter converter = new DataConverter(dataset);
@@ -312,6 +325,7 @@
return labels;
}
+
/**
* finds the majority label, breaking ties randomly
*
@@ -320,27 +334,27 @@
public int majorityLabel(Random rng) {
// count the frequency of each label value
int[] counts = new int[dataset.nblabels()];
-
+
for (int index = 0; index < size(); index++) {
counts[get(index).label]++;
}
-
+
// find the label values that appears the most
return DataUtils.maxindex(rng, counts);
}
-
+
/**
* Counts the number of occurrences of each label value
-
- * @param counts will contain the results, supposed to be initialized at 0
+ *
+ * @param counts
+ * will contain the results, supposed to be initialized at 0
*/
public void countLabels(int[] counts) {
for (int index = 0; index < size(); index++) {
counts[get(index).label]++;
}
}
-
-
+
public Dataset getDataset() {
return dataset;
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java Sat Feb 13 20:27:25 2010
@@ -17,8 +17,8 @@
package org.apache.mahout.df.data;
-import java.util.StringTokenizer;
import java.util.Arrays;
+import java.util.StringTokenizer;
import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.math.DenseVector;
@@ -29,27 +29,27 @@
* Converts String to Instance using a Dataset
*/
public class DataConverter {
-
+
private static final Logger log = LoggerFactory.getLogger(DataConverter.class);
-
+
private final Dataset dataset;
-
+
public DataConverter(Dataset dataset) {
this.dataset = dataset;
}
-
+
public Instance convert(int id, String string) {
// all attributes (categorical, numerical), ignored, label
int nball = dataset.nbAttributes() + dataset.getIgnored().length + 1;
-
+
StringTokenizer tokenizer = new StringTokenizer(string, ", ");
if (tokenizer.countTokens() != nball) {
throw new IllegalArgumentException("Wrong number of attributes in the string");
}
-
+
int nbattrs = dataset.nbAttributes();
DenseVector vector = new DenseVector(nbattrs);
-
+
int aId = 0;
int label = -1;
for (int attr = 0; attr < nball; attr++) {
@@ -58,7 +58,7 @@
if (ArrayUtils.contains(dataset.getIgnored(), attr)) {
continue; // IGNORED
}
-
+
if ("?".equals(token)) {
// missing value
return null;
@@ -67,9 +67,9 @@
if (attr == dataset.getLabelId()) {
label = dataset.labelCode(token);
if (label == -1) {
- log.error(String.format("label token: %s\ndataset.labels: %s",
- token, Arrays.toString(dataset.labels())));
- throw new IllegalStateException("Label value ("+token+") not known");
+ DataConverter.log.error(String.format("label token: %s\ndataset.labels: %s", token, Arrays
+ .toString(dataset.labels())));
+ throw new IllegalStateException("Label value (" + token + ") not known");
}
} else if (dataset.isNumerical(aId)) {
vector.set(aId++, Double.parseDouble(token));
@@ -78,12 +78,12 @@
aId++;
}
}
-
+
if (label == -1) {
- log.error(String.format("Label not found, instance id : %d, \nstring : %s", id, string));
+ DataConverter.log.error(String.format("Label not found, instance id : %d, \nstring : %s", id, string));
throw new IllegalStateException("Label not found!");
}
-
+
return new Instance(id, vector, label);
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java Sat Feb 13 20:27:25 2010
@@ -32,76 +32,81 @@
import org.slf4j.LoggerFactory;
/**
- * Converts the input data to a Vector Array using the information given by the
- * Dataset.<br>
+ * Converts the input data to a Vector Array using the information given by the Dataset.<br>
* Generates for each line a Vector that contains :<br>
* <ul>
- * <li> double parsed value for NUMERICAL attributes </li>
- * <li> int value for CATEGORICAL and LABEL attributes </li>
+ * <li>double parsed value for NUMERICAL attributes</li>
+ * <li>int value for CATEGORICAL and LABEL attributes</li>
* </ul>
* <br>
- * adds an IGNORED first attribute that will contain a unique id for each
- * instance, which is the line number of the instance in the input data
+ * adds an IGNORED first attribute that will contain a unique id for each instance, which is the line number
+ * of the instance in the input data
*/
public class DataLoader {
-
+
private static final Logger log = LoggerFactory.getLogger(DataLoader.class);
-
- private DataLoader() {
- }
-
+
+ private DataLoader() {}
+
/**
* Converts a comma-separated String to a Vector.
*
- * @param id unique id for the current instance
- * @param attrs attributes description
- * @param values used to convert CATEGORICAL attribute values to Integer
+ * @param id
+ * unique id for the current instance
+ * @param attrs
+ * attributes description
+ * @param values
+ * used to convert CATEGORICAL attribute values to Integer
* @param string
* @return null if there are missing values '?'
*/
- private static Instance parseString(int id, Attribute[] attrs,
- List<String>[] values, String string) {
+ private static Instance parseString(int id, Attribute[] attrs, List<String>[] values, String string) {
StringTokenizer tokenizer = new StringTokenizer(string, ", ");
if (tokenizer.countTokens() != attrs.length) {
- log.error("{}: {}", id, string);
+ DataLoader.log.error("{}: {}", id, string);
throw new IllegalArgumentException("Wrong number of attributes in the string");
}
-
+
// extract tokens and check is there is any missing value
String[] tokens = new String[attrs.length];
for (int attr = 0; attr < attrs.length; attr++) {
String token = tokenizer.nextToken();
-
- if (attrs[attr].isIgnored())
+
+ if (attrs[attr].isIgnored()) {
continue;
-
- if ("?".equals(token))
+ }
+
+ if ("?".equals(token)) {
return null; // missing value
-
+ }
+
tokens[attr] = token;
}
-
+
int nbattrs = Dataset.countAttributes(attrs);
-
+
DenseVector vector = new DenseVector(nbattrs);
-
+
int aId = 0;
int label = -1;
for (int attr = 0; attr < attrs.length; attr++) {
- if (attrs[attr].isIgnored())
+ if (attrs[attr].isIgnored()) {
continue;
-
+ }
+
String token = tokens[attr];
-
+
if (attrs[attr].isNumerical()) {
vector.set(aId++, Double.parseDouble(token));
} else { // CATEGORICAL or LABEL
// update values
- if (values[attr] == null)
+ if (values[attr] == null) {
values[attr] = new ArrayList<String>();
- if (!values[attr].contains(token))
+ }
+ if (!values[attr].contains(token)) {
values[attr].add(token);
-
+ }
+
if (attrs[attr].isCategorical()) {
vector.set(aId++, values[attr].indexOf(token));
} else { // LABEL
@@ -109,21 +114,25 @@
}
}
}
-
- if (label == -1)
+
+ if (label == -1) {
throw new IllegalStateException("Label not found!");
-
+ }
+
return new Instance(id, vector, label);
}
-
+
/**
* Loads the data from a file
*
* @param dataset
- * @param fs file system
- * @param fpath data file path
+ * @param fs
+ * file system
+ * @param fpath
+ * data file path
* @return
- * @throws IOException if any problem is encountered
+ * @throws IOException
+ * if any problem is encountered
*/
public static Data loadData(Dataset dataset, FileSystem fs, Path fpath) throws IOException {
@@ -131,49 +140,49 @@
Scanner scanner = new Scanner(input);
List<Instance> instances = new ArrayList<Instance>();
-
+
DataConverter converter = new DataConverter(dataset);
while (scanner.hasNextLine()) {
String line = scanner.nextLine();
if (line.isEmpty()) {
- log.warn("{}: empty string", instances.size());
+ DataLoader.log.warn("{}: empty string", instances.size());
continue;
}
Instance instance = converter.convert(instances.size(), line);
if (instance == null) {
// missing values found
- log.warn("{}: missing values", instances.size());
+ DataLoader.log.warn("{}: missing values", instances.size());
continue;
}
instances.add(instance);
}
-
+
scanner.close();
return new Data(dataset, instances);
}
-
+
/**
* Loads the data from a String array
*/
public static Data loadData(Dataset dataset, String[] data) {
List<Instance> instances = new ArrayList<Instance>();
-
+
DataConverter converter = new DataConverter(dataset);
for (String line : data) {
if (line.isEmpty()) {
- log.warn("{}: empty string", instances.size());
+ DataLoader.log.warn("{}: empty string", instances.size());
continue;
}
Instance instance = converter.convert(instances.size(), line);
if (instance == null) {
// missing values found
- log.warn("{}: missing values", instances.size());
+ DataLoader.log.warn("{}: missing values", instances.size());
continue;
}
@@ -182,20 +191,24 @@
return new Data(dataset, instances);
}
-
+
/**
* Generates the Dataset by parsing the entire data
- * @param descriptor attributes description
- * @param fs file system
- * @param path data path
+ *
+ * @param descriptor
+ * attributes description
+ * @param fs
+ * file system
+ * @param path
+ * data path
*/
- public static Dataset generateDataset(String descriptor, FileSystem fs, Path path)
- throws DescriptorException, IOException {
+ public static Dataset generateDataset(String descriptor, FileSystem fs, Path path) throws DescriptorException,
+ IOException {
Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
-
+
FSDataInputStream input = fs.open(path);
Scanner scanner = new Scanner(input);
-
+
// used to convert CATEGORICAL attribute to Integer
List<String>[] values = new List[attrs.length];
@@ -206,54 +219,59 @@
continue;
}
- if (parseString(id, attrs, values, line) != null) {
+ if (DataLoader.parseString(id, attrs, values, line) != null) {
id++;
}
}
-
+
scanner.close();
-
+
return new Dataset(attrs, values, id);
}
-
+
/**
* Generates the Dataset by parsing the entire data
- * @param descriptor attributes description
+ *
+ * @param descriptor
+ * attributes description
* @param data
*/
public static Dataset generateDataset(String descriptor, String[] data) throws DescriptorException {
Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
-
+
// used to convert CATEGORICAL and LABEL attributes to Integer
List<String>[] values = new List[attrs.length];
-
+
int id = 0;
for (String aData : data) {
if (aData.isEmpty()) {
continue;
}
-
- if (parseString(id, attrs, values, aData) != null) {
+
+ if (DataLoader.parseString(id, attrs, values, aData) != null) {
id++;
}
}
-
+
return new Dataset(attrs, values, id);
}
/**
* constructs the data
*
- * @param attrs attributes description
- * @param vectors data elements
- * @param values used to convert CATEGORICAL attributes to Integer
+ * @param attrs
+ * attributes description
+ * @param vectors
+ * data elements
+ * @param values
+ * used to convert CATEGORICAL attributes to Integer
* @return
- * @throws RuntimeException if no LABEL is found in the attributes description
+ * @throws RuntimeException
+ * if no LABEL is found in the attributes description
*/
- protected static Data constructData(Attribute[] attrs,
- List<Instance> vectors, List<String>[] values) {
+ protected static Data constructData(Attribute[] attrs, List<Instance> vectors, List<String>[] values) {
Dataset dataset = new Dataset(attrs, values, vectors.size());
-
+
return new Data(dataset, vectors);
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataUtils.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataUtils.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataUtils.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataUtils.java Sat Feb 13 20:27:25 2010
@@ -25,9 +25,8 @@
* Helper methods that deals with data lists and arrays of values
*/
public class DataUtils {
- private DataUtils() {
- }
-
+ private DataUtils() {}
+
/**
* Computes the sum of the values
*
@@ -39,12 +38,13 @@
for (int value : values) {
sum += value;
}
-
+
return sum;
}
-
+
/**
* foreach i : array1[i] += array2[i]
+ *
* @param array1
* @param array2
*/
@@ -57,9 +57,10 @@
array1[index] += array2[index];
}
}
-
+
/**
* foreach i : array1[i] -= array2[i]
+ *
* @param array1
* @param array2
*/
@@ -72,18 +73,19 @@
array1[index] -= array2[index];
}
}
-
+
/**
* return the index of the maximum of the array, breaking ties randomly
*
- * @param rng used to break ties
+ * @param rng
+ * used to break ties
* @param values
* @return index of the maximum
*/
public static int maxindex(Random rng, int[] values) {
int max = 0;
List<Integer> maxindices = new ArrayList<Integer>();
-
+
for (int index = 0; index < values.length; index++) {
if (values[index] > max) {
max = values[index];
@@ -93,7 +95,7 @@
maxindices.add(index);
}
}
-
+
int bestind;
if (maxindices.size() > 1) {
// break ties randomly
@@ -101,7 +103,7 @@
} else {
bestind = maxindices.get(0);
}
-
+
return bestind;
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java Sat Feb 13 20:27:25 2010
@@ -37,55 +37,58 @@
*
*/
public class Dataset implements Writable {
-
+
/**
* Attributes type
*/
public enum Attribute {
- IGNORED, NUMERICAL, CATEGORICAL, LABEL;
-
+ IGNORED,
+ NUMERICAL,
+ CATEGORICAL,
+ LABEL;
+
public boolean isNumerical() {
return this == NUMERICAL;
}
-
+
public boolean isCategorical() {
return this == CATEGORICAL;
}
-
+
public boolean isLabel() {
return this == LABEL;
}
-
+
public boolean isIgnored() {
- return this == IGNORED;
+ return this == IGNORED;
}
}
-
+
private Attribute[] attributes;
-
+
/** all distinct labels */
private String[] labels;
-
+
/** list of ignored attributes */
private int[] ignored;
/** distinct values (CATEGORIAL attributes only) */
private String[][] values;
-
+
/** index of the label attribute in the original data */
private int labelId;
/** number of instances in the dataset */
private int nbInstances;
-
+
public String[] labels() {
return Arrays.copyOf(labels, labels.length);
}
-
+
public int nblabels() {
return labels.length;
}
-
+
public int getLabelId() {
return labelId;
}
@@ -97,20 +100,23 @@
/**
* Returns the code used to represent the label value in the data
*
- * @param label label's value to code
+ * @param label
+ * label's value to code
* @return label's code
*/
public int labelCode(String label) {
return ArrayUtils.indexOf(labels, label);
}
-
+
public String getLabel(int code) {
return labels[code];
}
/**
* Converts a token to its corresponding int code for a given attribute
- * @param attr attribute's index
+ *
+ * @param attr
+ * attribute's index
* @param token
* @return
*/
@@ -121,34 +127,35 @@
if (values == null) {
throw new IllegalStateException("Values not found");
}
-
+
return ArrayUtils.indexOf(values[attr], token);
}
public int[] getIgnored() {
return ignored;
}
-
- private Dataset() {
- }
-
+
+ private Dataset() {}
+
/**
* Should only be called by a DataLoader
*
- * @param attrs attributes description
- * @param values distinct values for all CATEGORICAL attributes
+ * @param attrs
+ * attributes description
+ * @param values
+ * distinct values for all CATEGORICAL attributes
* @param nbInstances
*/
protected Dataset(Attribute[] attrs, List<String>[] values, int nbInstances) {
- validateValues(attrs, values);
-
- int nbattrs = countAttributes(attrs);
+ Dataset.validateValues(attrs, values);
+
+ int nbattrs = Dataset.countAttributes(attrs);
// the label values are set apart
attributes = new Attribute[nbattrs];
this.values = new String[nbattrs][];
ignored = new int[attrs.length - (nbattrs + 1)]; // nbignored = total - (nbattrs + label)
-
+
labelId = -1;
int ignoredId = 0;
int ind = 0;
@@ -170,20 +177,20 @@
this.values[ind] = new String[values[attr].size()];
values[attr].toArray(this.values[ind]);
}
-
+
attributes[ind++] = attrs[attr];
}
if (labelId == -1) {
throw new IllegalStateException("Label not found");
}
-
+
labels = new String[values[labelId].size()];
values[labelId].toArray(labels);
this.nbInstances = nbInstances;
}
-
+
/**
* Counts the number of attributes, except IGNORED and LABEL
*
@@ -192,30 +199,30 @@
*/
protected static int countAttributes(Attribute[] attrs) {
int nbattrs = 0;
-
+
for (Attribute attr1 : attrs) {
- if (attr1.isNumerical() || attr1.isCategorical())
+ if (attr1.isNumerical() || attr1.isCategorical()) {
nbattrs++;
+ }
}
-
+
return nbattrs;
}
-
+
private static void validateValues(Attribute[] attrs, List<String>[] values) {
if (attrs.length != values.length) {
throw new IllegalArgumentException("attrs.length != values.length");
}
-
+
for (int attr = 0; attr < attrs.length; attr++) {
if (attrs[attr].isCategorical()) {
if (values[attr] == null) {
- throw new IllegalArgumentException("values not found for attribute N° "
- + attr);
+ throw new IllegalArgumentException("values not found for attribute N° " + attr);
}
}
}
}
-
+
/**
* Number of attributes
*
@@ -224,24 +231,27 @@
public int nbAttributes() {
return attributes.length;
}
-
+
/**
* Is this a numerical attribute ?
*
- * @param attr index of the attribute to check
+ * @param attr
+ * index of the attribute to check
* @return true if the attribute is numerical
*/
public boolean isNumerical(int attr) {
return attributes[attr].isNumerical();
}
-
+
@Override
public boolean equals(Object obj) {
- if (this == obj)
+ if (this == obj) {
return true;
- if (obj == null || !(obj instanceof Dataset))
+ }
+ if ((obj == null) || !(obj instanceof Dataset)) {
return false;
-
+ }
+
Dataset dataset = (Dataset) obj;
if (!Arrays.equals(attributes, dataset.attributes)) {
@@ -251,16 +261,17 @@
if (!Arrays.equals(labels, dataset.labels)) {
return false;
}
-
+
for (int attr = 0; attr < nbAttributes(); attr++) {
if (!Arrays.equals(values[attr], dataset.values[attr])) {
return false;
}
}
-
- return labelId == dataset.labelId && nbInstances == dataset.nbInstances;
+
+ return (labelId == dataset.labelId) && (nbInstances == dataset.nbInstances);
}
-
+
+ @Override
public int hashCode() {
int hashCode = labelId + 31 * nbInstances;
for (Attribute attr : attributes) {
@@ -276,19 +287,20 @@
}
return hashCode;
}
-
+
/**
* Loads the dataset from a file
+ *
* @throws IOException
*/
public static Dataset load(Configuration conf, Path path) throws IOException {
FileSystem fs = path.getFileSystem(conf);
FSDataInputStream input = fs.open(path);
-
- Dataset dataset = read(input);
+
+ Dataset dataset = Dataset.read(input);
input.close();
-
+
return dataset;
}
@@ -323,11 +335,11 @@
labelId = in.readInt();
nbInstances = in.readInt();
}
-
+
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(attributes.length); // nb attributes
- for (Attribute attr:attributes) {
+ for (Attribute attr : attributes) {
WritableUtils.writeString(out, attr.name());
}
@@ -345,5 +357,5 @@
out.writeInt(labelId);
out.writeInt(nbInstances);
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DescriptorUtils.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DescriptorUtils.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DescriptorUtils.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DescriptorUtils.java Sat Feb 13 20:27:25 2010
@@ -28,41 +28,43 @@
* Contains various methods that deal with descriptor strings
*/
public class DescriptorUtils {
- private DescriptorUtils() {
- }
-
+ private DescriptorUtils() {}
+
/**
* Parses a descriptor string and generates the corresponding array of Attributes
*
* @param descriptor
* @return
- * @throws DescriptorException if a bad token is encountered
+ * @throws DescriptorException
+ * if a bad token is encountered
*/
public static Attribute[] parseDescriptor(String descriptor) throws DescriptorException {
StringTokenizer tokenizer = new StringTokenizer(descriptor);
Attribute[] attributes = new Attribute[tokenizer.countTokens()];
-
+
for (int attr = 0; attr < attributes.length; attr++) {
String token = tokenizer.nextToken().toUpperCase(Locale.ENGLISH);
- if ("I".equals(token))
+ if ("I".equals(token)) {
attributes[attr] = Attribute.IGNORED;
- else if ("N".equals(token))
+ } else if ("N".equals(token)) {
attributes[attr] = Attribute.NUMERICAL;
- else if ("C".equals(token))
+ } else if ("C".equals(token)) {
attributes[attr] = Attribute.CATEGORICAL;
- else if ("L".equals(token)) {
+ } else if ("L".equals(token)) {
attributes[attr] = Attribute.LABEL;
- } else
+ } else {
throw new DescriptorException("Bad Token : " + token);
+ }
}
-
+
return attributes;
}
-
+
/**
* Generates a valid descriptor string from a user-friendly representation.<br>
* for example "3 N I N N 2 C L 5 I" generates "N N N I N N C C L I I I I I".<br>
* this useful when describing datasets with a large number of attributes
+ *
* @param description
* @return
* @throws DescriptorException
@@ -75,11 +77,12 @@
tokens.add(tokenizer.nextToken());
}
- return generateDescriptor(tokens);
+ return DescriptorUtils.generateDescriptor(tokens);
}
-
+
/**
* Generates a valid descriptor string from a list of tokens
+ *
* @param tokens
* @return
* @throws DescriptorException
@@ -88,14 +91,14 @@
StringBuilder descriptor = new StringBuilder();
int multiplicator = 0;
-
+
for (String token : tokens) {
try {
// try to parse an integer
int number = Integer.parseInt(token);
-
+
if (number <= 0) {
- throw new DescriptorException("Multiplicator ("+number+") must be > 0");
+ throw new DescriptorException("Multiplicator (" + number + ") must be > 0");
}
if (multiplicator > 0) {
throw new DescriptorException("A multiplicator cannot be followed by another multiplicator");
@@ -108,7 +111,7 @@
multiplicator = 1;
}
- for (int index=0;index<multiplicator; index++) {
+ for (int index = 0; index < multiplicator; index++) {
descriptor.append(token).append(' ');
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java Sat Feb 13 20:27:25 2010
@@ -23,63 +23,66 @@
* Represents one data instance.
*/
public class Instance {
-
+
/** instance unique id */
public final int id;
-
+
/** attributes, except LABEL and IGNORED */
private final Vector attrs;
-
+
/**
* instance label code.<br>
* use Dataset.labels to get the real label value
*
*/
public final int label;
-
+
public Instance(int id, Vector attrs, int label) {
this.id = id;
this.attrs = attrs;
this.label = label;
}
-
+
/**
* Return the attribute at the specified position
*
- * @param index position of the attribute to retrieve
+ * @param index
+ * position of the attribute to retrieve
* @return value of the attribute
*/
public double get(int index) {
return attrs.getQuick(index);
}
-
+
/**
* Set the value at the given index
*
* @param index
- * @param value a double value to set
+ * @param value
+ * a double value to set
*/
public void set(int index, double value) {
attrs.set(index, value);
}
-
+
@Override
public boolean equals(Object obj) {
- if (this == obj)
+ if (this == obj) {
return true;
- if (obj == null || !(obj instanceof Instance))
+ }
+ if ((obj == null) || !(obj instanceof Instance)) {
return false;
+ }
- Instance instance = (Instance)obj;
+ Instance instance = (Instance) obj;
- return id == instance.id && label == instance.label && attrs.equals(instance.attrs);
+ return (id == instance.id) && (label == instance.label) && attrs.equals(instance.attrs);
}
-
+
@Override
public int hashCode() {
return id + label + attrs.hashCode();
}
-
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Condition.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Condition.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Condition.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Condition.java Sat Feb 13 20:27:25 2010
@@ -23,18 +23,18 @@
* Condition on Instance
*/
public abstract class Condition {
-
+
/**
* Returns true is the checked instance matches the condition
*
- * @param instance checked instance
+ * @param instance
+ * checked instance
* @return true is the checked instance matches the condition
*/
public abstract boolean isTrueFor(Instance instance);
-
+
/**
- * Condition that checks if the given attribute has a value "equal" to the
- * given value
+ * Condition that checks if the given attribute has a value "equal" to the given value
*
* @param attr
* @param value
@@ -43,10 +43,9 @@
public static Condition equals(int attr, double value) {
return new Equals(attr, value);
}
-
+
/**
- * Condition that checks if the given attribute has a value "lesser" than the
- * given value
+ * Condition that checks if the given attribute has a value "lesser" than the given value
*
* @param attr
* @param value
@@ -55,10 +54,9 @@
public static Condition lesser(int attr, double value) {
return new Lesser(attr, value);
}
-
+
/**
- * Condition that checks if the given attribute has a value "greater or equal"
- * than the given value
+ * Condition that checks if the given attribute has a value "greater or equal" than the given value
*
* @param attr
* @param value
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Equals.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Equals.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Equals.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Equals.java Sat Feb 13 20:27:25 2010
@@ -23,19 +23,19 @@
* True if a given attribute has a given value
*/
public class Equals extends Condition {
-
+
private final int attr;
-
+
private final double value;
-
+
public Equals(int attr, double value) {
this.attr = attr;
this.value = value;
}
-
+
@Override
public boolean isTrueFor(Instance instance) {
return instance.get(attr) == value;
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/GreaterOrEquals.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/GreaterOrEquals.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/GreaterOrEquals.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/GreaterOrEquals.java Sat Feb 13 20:27:25 2010
@@ -23,19 +23,19 @@
* True if a given attribute has a value "greater or equal" than a given value
*/
public class GreaterOrEquals extends Condition {
-
+
private final int attr;
-
+
private final double value;
-
+
public GreaterOrEquals(int attr, double value) {
this.attr = attr;
this.value = value;
}
-
+
@Override
public boolean isTrueFor(Instance v) {
return v.get(attr) >= value;
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Lesser.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Lesser.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Lesser.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/conditions/Lesser.java Sat Feb 13 20:27:25 2010
@@ -23,19 +23,19 @@
* True if a given attribute has a value "lesser" than a given value
*/
public class Lesser extends Condition {
-
+
private final int attr;
-
+
private final double value;
-
+
public Lesser(int attr, double value) {
this.attr = attr;
this.value = value;
}
-
+
@Override
public boolean isTrueFor(Instance instance) {
return instance.get(attr) < value;
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/Builder.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/Builder.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/Builder.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/Builder.java Sat Feb 13 20:27:25 2010
@@ -29,69 +29,65 @@
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.common.StringUtils;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.builder.TreeBuilder;
import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.data.Dataset;
-import org.apache.mahout.common.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
- * Base class for Mapred DecisionForest builders. Takes care of storing the
- * parameters common to the mapred implementations.<br>
+ * Base class for Mapred DecisionForest builders. Takes care of storing the parameters common to the mapred
+ * implementations.<br>
* The child classes must implement at least :
* <ul>
- * <li> void configureJob(JobConf) : to further configure the job before its
- * launch; and </li>
- * <li> DecisionForest parseOutput(JobConf, PredictionCallback) : in order to
- * convert the job outputs into a DecisionForest and its corresponding oob
- * predictions </li>
+ * <li>void configureJob(JobConf, int, boolean) : to further configure the job before its launch; and</li>
+ * <li>DecisionForest parseOutput(JobConf, PredictionCallback) : in order to convert the job outputs into a
+ * DecisionForest and its corresponding oob predictions</li>
* </ul>
*
*/
public abstract class Builder {
-
+
private static final Logger log = LoggerFactory.getLogger(Builder.class);
-
+
/** Tree Builder Component */
private final TreeBuilder treeBuilder;
-
+
private final Path dataPath;
-
+
private final Path datasetPath;
-
+
private final Long seed;
-
+
private final Configuration conf;
-
+
private String outputDirName = "output";
-
-
+
protected TreeBuilder getTreeBuilder() {
return treeBuilder;
}
-
+
protected Path getDataPath() {
return dataPath;
}
-
+
protected Path getDatasetPath() {
return datasetPath;
}
-
+
protected Long getSeed() {
return seed;
}
-
+
protected Configuration getConf() {
return conf;
}
-
/**
- * Used only for DEBUG purposes. if false, the mappers doesn't output anything,
- * so the builder has nothing to process
+ * Used only for DEBUG purposes. If false, the mappers don't output anything, so the builder has nothing
+ * to process
*
* @param conf
* @return
@@ -99,15 +95,15 @@
protected static boolean isOutput(Configuration conf) {
return conf.getBoolean("debug.mahout.rf.output", true);
}
-
+
protected static boolean isOobEstimate(Configuration conf) {
return conf.getBoolean("mahout.rf.oob", false);
}
-
+
private static void setOobEstimate(Configuration conf, boolean value) {
conf.setBoolean("mahout.rf.oob", value);
}
-
+
/**
* Returns the random seed
*
@@ -116,12 +112,13 @@
*/
public static Long getRandomSeed(Configuration conf) {
String seed = conf.get("mahout.rf.random.seed");
- if (seed == null)
+ if (seed == null) {
return null;
-
+ }
+
return Long.valueOf(seed);
}
-
+
/**
* Sets the random seed value
*
@@ -131,19 +128,20 @@
private static void setRandomSeed(Configuration conf, long seed) {
conf.setLong("mahout.rf.random.seed", seed);
}
-
+
public static TreeBuilder getTreeBuilder(Configuration conf) {
String string = conf.get("mahout.rf.treebuilder");
- if (string == null)
+ if (string == null) {
return null;
-
+ }
+
return StringUtils.fromString(string);
}
-
+
private static void setTreeBuilder(Configuration conf, TreeBuilder treeBuilder) {
conf.set("mahout.rf.treebuilder", StringUtils.toString(treeBuilder));
}
-
+
/**
* Get the number of trees for the map-reduce job. The default value is 100
*
@@ -153,31 +151,35 @@
public static int getNbTrees(Configuration conf) {
return conf.getInt("mahout.rf.nbtrees", -1);
}
-
+
/**
* Set the number of trees to grow for the map-reduce job
*
* @param conf
* @param nbTrees
- * @throws IllegalArgumentException if (nbTrees <= 0)
+ * @throws IllegalArgumentException
+ * if (nbTrees <= 0)
*/
public static void setNbTrees(Configuration conf, int nbTrees) {
- if (nbTrees <= 0)
+ if (nbTrees <= 0) {
throw new IllegalArgumentException("nbTrees should be greater than 0");
-
+ }
+
conf.setInt("mahout.rf.nbtrees", nbTrees);
}
-
+
/**
* Sets the Output directory name, will be creating in the working directory
+ *
* @param name
*/
public void setOutputDirName(String name) {
outputDirName = name;
}
-
+
/**
* Output Directory name
+ *
* @param conf
* @return
* @throws IOException
@@ -188,35 +190,34 @@
FileSystem fs = FileSystem.get(conf);
return new Path(fs.getWorkingDirectory(), outputDirName);
}
-
- protected Builder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath,
- Long seed, Configuration conf) {
+
+ protected Builder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed, Configuration conf) {
this.treeBuilder = treeBuilder;
this.dataPath = dataPath;
this.datasetPath = datasetPath;
this.seed = seed;
this.conf = conf;
}
-
+
/**
* Helper method. Get a path from the DistributedCache
*
* @param job
- * @param index index of the path in the DistributedCache files
+ * @param index
+ * index of the path in the DistributedCache files
* @return
* @throws IOException
*/
- public static Path getDistributedCacheFile(Configuration job, int index)
- throws IOException {
+ public static Path getDistributedCacheFile(Configuration job, int index) throws IOException {
URI[] files = DistributedCache.getCacheFiles(job);
-
- if (files == null || files.length < index) {
+
+ if ((files == null) || (files.length < index)) {
throw new IOException("path not found in the DistributedCache");
}
-
+
return new Path(files[index].getPath());
}
-
+
/**
* Helper method. Load a Dataset stored in the DistributedCache
*
@@ -225,80 +226,81 @@
* @throws IOException
*/
public static Dataset loadDataset(JobConf job) throws IOException {
- Path datasetPath = getDistributedCacheFile(job, 0);
-
+ Path datasetPath = Builder.getDistributedCacheFile(job, 0);
+
return Dataset.load(job, datasetPath);
}
-
+
/**
* Used by the inheriting classes to configure the job
*
* @param conf
- * @param nbTrees number of trees to grow
- * @param oobEstimate true, if oob error should be estimated
+ * @param nbTrees
+ * number of trees to grow
+ * @param oobEstimate
+ * true, if oob error should be estimated
* @throws IOException
*/
- protected abstract void configureJob(JobConf conf, int nbTrees,
- boolean oobEstimate) throws IOException;
-
+ protected abstract void configureJob(JobConf conf, int nbTrees, boolean oobEstimate) throws IOException;
+
/**
- * Sequential implementation should override this method to simulate the job
- * execution
+ * Sequential implementation should override this method to simulate the job execution
*/
protected void runJob(JobConf job) throws IOException {
JobClient.runJob(job);
}
-
+
/**
- * Parse the output files to extract the trees and pass the predictions to the
- * callback
+ * Parse the output files to extract the trees and pass the predictions to the callback
*
* @param job
- * @param callback can be null
+ * @param callback
+ * can be null
* @return
* @throws IOException
*/
- protected abstract DecisionForest parseOutput(JobConf job,
- PredictionCallback callback) throws IOException;
-
+ protected abstract DecisionForest parseOutput(JobConf job, PredictionCallback callback) throws IOException;
+
public DecisionForest build(int nbTrees, PredictionCallback callback) throws IOException {
JobConf job = new JobConf(conf, Builder.class);
-
+
Path outputPath = getOutputPath(job);
FileSystem fs = outputPath.getFileSystem(job);
-
+
// check the output
- if (fs.exists(outputPath))
+ if (fs.exists(outputPath)) {
throw new IOException("Output path already exists : " + outputPath);
-
- if (seed != null)
- setRandomSeed(job, seed);
- setNbTrees(job, nbTrees);
- setTreeBuilder(job, treeBuilder);
- setOobEstimate(job, callback != null);
-
+ }
+
+ if (seed != null) {
+ Builder.setRandomSeed(job, seed);
+ }
+ Builder.setNbTrees(job, nbTrees);
+ Builder.setTreeBuilder(job, treeBuilder);
+ Builder.setOobEstimate(job, callback != null);
+
// put the dataset into the DistributedCache
DistributedCache.addCacheFile(datasetPath.toUri(), job);
-
- log.debug("Configuring the job...");
+
+ Builder.log.debug("Configuring the job...");
configureJob(job, nbTrees, callback != null);
-
- log.debug("Running the job...");
+
+ Builder.log.debug("Running the job...");
runJob(job);
-
- if (isOutput(job)) {
- log.debug("Parsing the output...");
+
+ if (Builder.isOutput(job)) {
+ Builder.log.debug("Parsing the output...");
DecisionForest forest = parseOutput(job, callback);
-
+
// delete the output path
fs.delete(outputPath, true);
-
+
return forest;
}
-
+
return null;
}
-
+
/**
* sort the splits into order based on size, so that the biggest go first.<br>
* This is the same code used by Hadoop's JobClient.
@@ -325,5 +327,5 @@
}
});
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/MapredMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/MapredMapper.java?rev=909900&r1=909899&r2=909900&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/MapredMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/MapredMapper.java Sat Feb 13 20:27:25 2010
@@ -28,15 +28,15 @@
* Base class for Mapred mappers. Loads common parameters from the job
*/
public class MapredMapper extends MapReduceBase {
-
+
private boolean noOutput;
-
+
private boolean oobEstimate;
-
+
private TreeBuilder treeBuilder;
-
+
private Dataset dataset;
-
+
/**
*
* @return if false, the mapper does not output
@@ -44,7 +44,7 @@
protected boolean isOobEstimate() {
return oobEstimate;
}
-
+
/**
*
* @return if false, the mapper does not estimate and output predictions
@@ -52,27 +52,27 @@
protected boolean isNoOutput() {
return noOutput;
}
-
+
protected TreeBuilder getTreeBuilder() {
return treeBuilder;
}
-
+
protected Dataset getDataset() {
return dataset;
}
-
+
@Override
public void configure(JobConf conf) {
super.configure(conf);
-
+
try {
- configure(!Builder.isOutput(conf), Builder.isOobEstimate(conf), Builder
- .getTreeBuilder(conf), Builder.loadDataset(conf));
+ configure(!Builder.isOutput(conf), Builder.isOobEstimate(conf), Builder.getTreeBuilder(conf), Builder
+ .loadDataset(conf));
} catch (IOException e) {
throw new IllegalStateException("Exception caught while configuring the mapper: ", e);
}
}
-
+
/**
* Useful for testing
*
@@ -81,16 +81,15 @@
* @param treeBuilder
* @param dataset
*/
- protected void configure(boolean noOutput, boolean oobEstimate,
- TreeBuilder treeBuilder, Dataset dataset) {
+ protected void configure(boolean noOutput, boolean oobEstimate, TreeBuilder treeBuilder, Dataset dataset) {
this.noOutput = noOutput;
this.oobEstimate = oobEstimate;
-
+
if (treeBuilder == null) {
throw new IllegalArgumentException("TreeBuilder not found in the Job parameters");
}
this.treeBuilder = treeBuilder;
-
+
this.dataset = dataset;
}
}