You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2011/10/23 21:26:20 UTC
svn commit: r1187953 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/df/builder/
core/src/main/java/org/apache/mahout/df/data/
core/src/main/java/org/apache/mahout/df/mapreduce/
core/src/main/java/org/apache/mahout/df/node/ core/src/main/java/...
Author: adeneche
Date: Sun Oct 23 19:26:19 2011
New Revision: 1187953
URL: http://svn.apache.org/viewvc?rev=1187953&view=rev
Log:
MAHOUT-840 target attribute can now be numerical, although regression is still not supported
Removed:
mahout/trunk/core/src/main/java/org/apache/mahout/df/node/MockLeaf.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataConverterTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataLoaderTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DatasetTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/data/Utils.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/split/DefaultIgSplitTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/split/OptIgSplitTest.java
mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java Sun Oct 23 19:26:19 2011
@@ -69,6 +69,7 @@ public class DefaultTreeBuilder implemen
if (selected == null) {
selected = new boolean[data.getDataset().nbAttributes()];
+ selected[data.getDataset().getLabelId()] = true; // never select the label
}
if (data.isEmpty()) {
@@ -78,7 +79,7 @@ public class DefaultTreeBuilder implemen
return new Leaf(data.majorityLabel(rng));
}
if (data.identicalLabel()) {
- return new Leaf(data.get(0).getLabel());
+ return new Leaf(data.getDataset().getLabel(data.get(0)));
}
int[] attributes = randomAttributes(rng, selected, m);
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java Sun Oct 23 19:26:19 2011
@@ -215,9 +215,9 @@ public class Data implements Cloneable {
return true;
}
- int label = get(0).getLabel();
+ int label = dataset.getLabel(get(0));
for (int index = 1; index < size(); index++) {
- if (get(index).getLabel() != label) {
+ if (dataset.getLabel(get(index)) != label) {
return false;
}
}
@@ -278,7 +278,7 @@ public class Data implements Cloneable {
int[] labels = new int[size()];
for (int index = 0; index < labels.length; index++) {
- labels[index] = get(index).getLabel();
+ labels[index] = dataset.getLabel(get(index));
}
return labels;
@@ -300,10 +300,12 @@ public class Data implements Cloneable {
int[] labels = new int[dataset.nbInstances()];
DataConverter converter = new DataConverter(dataset);
+ int labelId = dataset.getLabelId();
+
try {
int index = 0;
while (iterator.hasNext()) {
- labels[index++] = converter.convert(0, iterator.next()).getLabel();
+ labels[index++] = (int) converter.convert(0, iterator.next()).get(labelId);
}
} finally {
Closeables.closeQuietly(iterator);
@@ -322,7 +324,7 @@ public class Data implements Cloneable {
int[] counts = new int[dataset.nblabels()];
for (int index = 0; index < size(); index++) {
- counts[get(index).getLabel()]++;
+ counts[dataset.getLabel(get(index))]++;
}
// find the label values that appears the most
@@ -337,7 +339,7 @@ public class Data implements Cloneable {
*/
public void countLabels(int[] counts) {
for (int index = 0; index < size(); index++) {
- counts[get(index).getLabel()]++;
+ counts[dataset.getLabel(get(index))]++;
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java Sun Oct 23 19:26:19 2011
@@ -43,8 +43,8 @@ public class DataConverter {
}
public Instance convert(int id, CharSequence string) {
- // all attributes (categorical, numerical), ignored, label
- int nball = dataset.nbAttributes() + dataset.getIgnored().length + 1;
+ // all attributes (categorical, numerical, label), ignored
+ int nball = dataset.nbAttributes() + dataset.getIgnored().length;
String[] tokens = COMMA_SPACE.split(string);
Preconditions.checkArgument(tokens.length == nball, "Wrong number of attributes in the string");
@@ -55,26 +55,28 @@ public class DataConverter {
int aId = 0;
int label = -1;
for (int attr = 0; attr < nball; attr++) {
- String token = tokens[attr].trim();
-
if (ArrayUtils.contains(dataset.getIgnored(), attr)) {
continue; // IGNORED
}
+
+ String token = tokens[attr].trim();
if ("?".equals(token)) {
// missing value
return null;
}
- if (attr == dataset.getLabelId()) {
+ if (aId == dataset.getLabelId()) {
label = dataset.labelCode(token);
if (label == -1) {
log.error("label token: {} dataset.labels: {}", token, Arrays.toString(dataset.labels()));
throw new IllegalStateException("Label value (" + token + ") not known");
}
- } else if (dataset.isNumerical(aId)) {
+ }
+
+ if (dataset.isNumerical(aId)) {
vector.set(aId++, Double.parseDouble(token));
- } else {
+ } else { // CATEGORICAL/LABEL
vector.set(aId, dataset.valueOf(aId, token));
aId++;
}
@@ -85,6 +87,6 @@ public class DataConverter {
throw new IllegalStateException("Label not found!");
}
- return new Instance(id, vector, label);
+ return new Instance(id, vector);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java Sun Oct 23 19:26:19 2011
@@ -112,7 +112,7 @@ public final class DataLoader {
throw new IllegalStateException("Label not found!");
}
- return new Instance(id, vector, label);
+ return new Instance(id, vector);
}
/**
@@ -188,12 +188,14 @@ public final class DataLoader {
*
* @param descriptor
* attributes description
+ * @param regression
+ * if true, the label is numerical
* @param fs
* file system
* @param path
* data path
*/
- public static Dataset generateDataset(String descriptor, FileSystem fs, Path path) throws DescriptorException,
+ public static Dataset generateDataset(String descriptor, boolean regression, FileSystem fs, Path path) throws DescriptorException,
IOException {
Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
@@ -217,7 +219,7 @@ public final class DataLoader {
scanner.close();
- return new Dataset(attrs, values, id);
+ return new Dataset(attrs, values, id, regression);
}
/**
@@ -226,7 +228,7 @@ public final class DataLoader {
* @param descriptor
* attributes description
*/
- public static Dataset generateDataset(String descriptor, String[] data) throws DescriptorException {
+ public static Dataset generateDataset(String descriptor, boolean regression, String[] data) throws DescriptorException {
Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
// used to convert CATEGORICAL and LABEL attributes to Integer
@@ -243,7 +245,7 @@ public final class DataLoader {
}
}
- return new Dataset(attrs, values, id);
+ return new Dataset(attrs, values, id, regression);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java Sun Oct 23 19:26:19 2011
@@ -68,16 +68,13 @@ public class Dataset implements Writable
private Attribute[] attributes;
- /** all distinct labels */
- private String[] labels;
-
/** list of ignored attributes */
private int[] ignored;
/** distinct values (CATEGORIAL attributes only) */
private String[][] values;
- /** index of the label attribute in the original data */
+ /** index of the label attribute in the loaded data (without ignored attributes) */
private int labelId;
/** number of instances in the dataset */
@@ -94,7 +91,8 @@ public class Dataset implements Writable
* distinct values for all CATEGORICAL attributes
* @param nbInstances
*/
- protected Dataset(Attribute[] attrs, List<String>[] values, int nbInstances) {
+ protected Dataset(Attribute[] attrs, List<String>[] values, int nbInstances, boolean regression) {
+ Preconditions.checkArgument(regression == false, "Regression Problems not supported");
validateValues(attrs, values);
int nbattrs = countAttributes(attrs);
@@ -102,7 +100,7 @@ public class Dataset implements Writable
// the label values are set apart
attributes = new Attribute[nbattrs];
this.values = new String[nbattrs][];
- ignored = new int[attrs.length - (nbattrs + 1)]; // nbignored = total - (nbattrs + label)
+ ignored = new int[attrs.length - nbattrs]; // nbignored = total - nbattrs
labelId = -1;
int ignoredId = 0;
@@ -117,11 +115,10 @@ public class Dataset implements Writable
if (labelId != -1) {
throw new IllegalStateException("Label found more than once");
}
- labelId = attr;
- continue;
+ labelId = ind;
}
- if (attrs[attr].isCategorical()) {
+ if (attrs[attr].isCategorical() || (!regression && attrs[attr].isLabel())) {
this.values[ind] = new String[values[attr].size()];
values[attr].toArray(this.values[ind]);
}
@@ -133,24 +130,25 @@ public class Dataset implements Writable
throw new IllegalStateException("Label not found");
}
- labels = new String[values[labelId].size()];
- values[labelId].toArray(labels);
-
this.nbInstances = nbInstances;
}
public String[] labels() {
- return Arrays.copyOf(labels, labels.length);
+ return Arrays.copyOf(values[labelId], nblabels());
}
public int nblabels() {
- return labels.length;
+ return values[labelId].length;
}
public int getLabelId() {
return labelId;
}
+ public int getLabel(Instance instance) {
+ return (int) instance.get(getLabelId());
+ }
+
public int nbInstances() {
return nbInstances;
}
@@ -163,12 +161,15 @@ public class Dataset implements Writable
* @return label's code
*/
public int labelCode(String label) {
- return ArrayUtils.indexOf(labels, label);
+ return ArrayUtils.indexOf(values[labelId], label);
}
- public String getLabel(int code) {
- // TODO should handle the case (prediction == -1)
- return labels[code];
+ public String getLabelString(int code) {
+ // handle the case (prediction == -1)
+ if (code == -1) {
+ return "unknown";
+ }
+ return values[labelId][code];
}
/**
@@ -189,15 +190,13 @@ public class Dataset implements Writable
/**
- * Counts the number of attributes, except IGNORED and LABEL
- *
- * @return number of attributes that are not IGNORED or LABEL
+ * @return number of attributes that are not IGNORED
*/
protected static int countAttributes(Attribute[] attrs) {
int nbattrs = 0;
- for (Attribute attr1 : attrs) {
- if (attr1.isNumerical() || attr1.isCategorical()) {
+ for (Attribute attr : attrs) {
+ if (!attr.isIgnored()) {
nbattrs++;
}
}
@@ -208,7 +207,7 @@ public class Dataset implements Writable
private static void validateValues(Attribute[] attrs, List<String>[] values) {
Preconditions.checkArgument(attrs.length == values.length, "attrs.length != values.length");
for (int attr = 0; attr < attrs.length; attr++) {
- Preconditions.checkArgument(!attrs[attr].isCategorical() || values[attr] != null,
+ Preconditions.checkArgument(!(attrs[attr].isCategorical() || attrs[attr].isLabel()) || values[attr] != null,
"values not found for attribute " + attr);
}
}
@@ -246,10 +245,6 @@ public class Dataset implements Writable
return false;
}
- if (!Arrays.equals(labels, dataset.labels)) {
- return false;
- }
-
for (int attr = 0; attr < nbAttributes(); attr++) {
if (!Arrays.equals(values[attr], dataset.values[attr])) {
return false;
@@ -265,10 +260,8 @@ public class Dataset implements Writable
for (Attribute attr : attributes) {
hashCode = 31 * hashCode + attr.hashCode();
}
- for (String label : labels) {
- hashCode = 31 * hashCode + label.hashCode();
- }
for (String[] valueRow : values) {
+ if (valueRow == null) continue;
for (String value : valueRow) {
hashCode = 31 * hashCode + value.hashCode();
}
@@ -305,14 +298,12 @@ public class Dataset implements Writable
attributes[attr] = Attribute.valueOf(name);
}
- labels = WritableUtils.readStringArray(in);
-
ignored = DFUtils.readIntArray(in);
- // only CATEGORICAL attributes have values
+ // only CATEGORICAL/LABEL attributes have values
values = new String[nbAttributes][];
for (int attr = 0; attr < nbAttributes; attr++) {
- if (attributes[attr].isCategorical()) {
+ if (attributes[attr].isCategorical() || attributes[attr].isLabel()) {
values[attr] = WritableUtils.readStringArray(in);
}
}
@@ -328,8 +319,6 @@ public class Dataset implements Writable
WritableUtils.writeString(out, attr.name());
}
- WritableUtils.writeStringArray(out, labels);
-
DFUtils.writeArray(out, ignored);
// only CATEGORICAL attributes have values
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java Sun Oct 23 19:26:19 2011
@@ -29,12 +29,9 @@ public class Instance {
/** attributes, except LABEL and IGNORED */
private final Vector attrs;
- private final int label;
-
- public Instance(int id, Vector attrs, int label) {
+ public Instance(int id, Vector attrs) {
this.id = id;
this.attrs = attrs;
- this.label = label;
}
/**
@@ -70,26 +67,17 @@ public class Instance {
Instance instance = (Instance) obj;
- return id == instance.id && label == instance.label && attrs.equals(instance.attrs);
+ return id == instance.id && attrs.equals(instance.attrs);
}
@Override
public int hashCode() {
- return id + label + attrs.hashCode();
+ return id + attrs.hashCode();
}
/** instance unique id */
public int getId() {
return id;
}
-
- /**
- * instance label code.<br>
- * use Dataset.labels to get the real label value
- *
- */
- public int getLabel() {
- return label;
- }
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java Sun Oct 23 19:26:19 2011
@@ -173,8 +173,8 @@ public class Classifier {
ofile.writeChar('\n');
if (analyzer != null) {
- analyzer.addInstance(dataset.getLabel(key),
- new ClassifierResult(dataset.getLabel(Integer.parseInt(value)), 1.0));
+ analyzer.addInstance(dataset.getLabelString(key),
+ new ClassifierResult(dataset.getLabelString(Integer.parseInt(value)), 1.0));
}
}
}
@@ -204,6 +204,7 @@ public class Classifier {
private final Random rng = RandomUtils.getRandom();
private boolean first = true;
private final Text lvalue = new Text();
+ private Dataset dataset;
@Override
protected void setup(Context context) throws IOException, InterruptedException {
@@ -216,8 +217,8 @@ public class Classifier {
if (files == null || files.length < 2) {
throw new IOException("not enough paths in the DistributedCache");
}
-
- Dataset dataset = Dataset.load(conf, new Path(files[0].getPath()));
+
+ dataset = Dataset.load(conf, new Path(files[0].getPath()));
converter = new DataConverter(dataset);
@@ -242,7 +243,7 @@ public class Classifier {
if (!line.isEmpty()) {
Instance instance = converter.convert(0, line);
int prediction = forest.classify(rng, instance);
- key.set(instance.getLabel());
+ key.set(dataset.getLabel(instance));
lvalue.set(Integer.toString(prediction));
context.write(key, lvalue);
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java Sun Oct 23 19:26:19 2011
@@ -30,7 +30,6 @@ import org.apache.mahout.df.data.Instanc
public abstract class Node implements Writable {
protected enum Type {
- MOCKLEAF,
LEAF,
NUMERICAL,
CATEGORICAL
@@ -60,9 +59,6 @@ public abstract class Node implements Wr
Node node;
switch (type) {
- case MOCKLEAF:
- node = new MockLeaf();
- break;
case LEAF:
node = new Leaf();
break;
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java Sun Oct 23 19:26:19 2011
@@ -22,6 +22,7 @@ import java.util.Arrays;
import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataUtils;
+import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.data.Instance;
/**
@@ -52,11 +53,13 @@ public class OptIgSplit extends IgSplit
int[][] counts = new int[values.length][data.getDataset().nblabels()];
int[] countAll = new int[data.getDataset().nblabels()];
+ Dataset dataset = data.getDataset();
+
// compute frequencies
for (int index = 0; index < data.size(); index++) {
Instance instance = data.get(index);
- counts[ArrayUtils.indexOf(values, instance.get(attr))][instance.getLabel()]++;
- countAll[instance.getLabel()]++;
+ counts[ArrayUtils.indexOf(values, instance.get(attr))][dataset.getLabel(instance)]++;
+ countAll[dataset.getLabel(instance)]++;
}
int size = data.size();
@@ -93,10 +96,12 @@ public class OptIgSplit extends IgSplit
}
protected void computeFrequencies(Data data, int attr, double[] values) {
+ Dataset dataset = data.getDataset();
+
for (int index = 0; index < data.size(); index++) {
Instance instance = data.get(index);
- counts[ArrayUtils.indexOf(values, instance.get(attr))][instance.getLabel()]++;
- countAll[instance.getLabel()]++;
+ counts[ArrayUtils.indexOf(values, instance.get(attr))][dataset.getLabel(instance)]++;
+ countAll[dataset.getLabel(instance)]++;
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java Sun Oct 23 19:26:19 2011
@@ -68,11 +68,14 @@ public final class Describe {
abuilder.withName("file").withMinimum(1).withMaximum(1).create()).withDescription(
"Path to generated descriptor file").create();
+ Option regOpt = obuilder.withLongName("regression").withDescription("Regression Problem").withShortName("r")
+ .create();
+
Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
.create();
Group group = gbuilder.withName("Options").withOption(pathOpt).withOption(descPathOpt).withOption(
- descriptorOpt).withOption(helpOpt).create();
+ descriptorOpt).withOption(regOpt).withOption(helpOpt).create();
try {
Parser parser = new Parser();
@@ -87,19 +90,21 @@ public final class Describe {
String dataPath = cmdLine.getValue(pathOpt).toString();
String descPath = cmdLine.getValue(descPathOpt).toString();
List<String> descriptor = convert(cmdLine.getValues(descriptorOpt));
+ boolean regression = cmdLine.hasOption(regOpt);
log.debug("Data path : {}", dataPath);
log.debug("Descriptor path : {}", descPath);
log.debug("Descriptor : {}", descriptor);
+ log.debug("Regression : {}", regression);
- runTool(dataPath, descriptor, descPath);
+ runTool(dataPath, descriptor, descPath, regression);
} catch (OptionException e) {
log.warn(e.toString());
CommandLineUtil.printHelp(group);
}
}
- private static void runTool(String dataPath, Iterable<String> description, String filePath)
+ private static void runTool(String dataPath, Iterable<String> description, String filePath, boolean regression)
throws DescriptorException, IOException {
log.info("Generating the descriptor...");
String descriptor = DescriptorUtils.generateDescriptor(description);
@@ -107,17 +112,17 @@ public final class Describe {
Path fPath = validateOutput(filePath);
log.info("generating the dataset...");
- Dataset dataset = generateDataset(descriptor, dataPath);
+ Dataset dataset = generateDataset(descriptor, dataPath, regression);
log.info("storing the dataset description");
DFUtils.storeWritable(new Configuration(), fPath, dataset);
}
- private static Dataset generateDataset(String descriptor, String dataPath) throws IOException, DescriptorException {
+ private static Dataset generateDataset(String descriptor, String dataPath, boolean regression) throws IOException, DescriptorException {
Path path = new Path(dataPath);
FileSystem fs = path.getFileSystem(new Configuration());
- return DataLoader.generateDataset(descriptor, fs, path);
+ return DataLoader.generateDataset(descriptor, regression, fs, path);
}
private static Path validateOutput(String filePath) throws IOException {
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java Sun Oct 23 19:26:19 2011
@@ -163,12 +163,13 @@ public class FrequenciesJob {
private LongWritable firstId;
private DataConverter converter;
+ private Dataset dataset;
@Override
protected void setup(Context context) throws IOException, InterruptedException {
Configuration conf = context.getConfiguration();
- Dataset dataset = Builder.loadDataset(conf);
+ dataset = Builder.loadDataset(conf);
setup(dataset);
}
@@ -188,7 +189,7 @@ public class FrequenciesJob {
Instance instance = converter.convert((int) key.get(), value.toString());
- context.write(firstId, new IntWritable(instance.getLabel()));
+ context.write(firstId, new IntWritable(dataset.getLabel(instance)));
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java Sun Oct 23 19:26:19 2011
@@ -175,7 +175,7 @@ public final class UDistrib {
// write the tuple in files[tuple.label]
Instance instance = converter.convert(id++, line);
- int label = instance.getLabel();
+ int label = dataset.getLabel(instance);
files[currents[label]].writeBytes(line);
files[currents[label]].writeChar('\n');
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java Sun Oct 23 19:26:19 2011
@@ -47,7 +47,7 @@ public final class InfiniteRecursionTest
String[] source = Utils.double2String(dData);
String descriptor = "N N N N N N N N L";
- Dataset dataset = DataLoader.generateDataset(descriptor, source);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, source);
Data data = DataLoader.loadData(dataset, source);
builder.build(rng, data);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataConverterTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataConverterTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataConverterTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataConverterTest.java Sun Oct 23 19:26:19 2011
@@ -34,9 +34,9 @@ public final class DataConverterTest ext
Random rng = RandomUtils.getRandom();
String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
- double[][] source = Utils.randomDoubles(rng, descriptor, INSTANCE_COUNT);
+ double[][] source = Utils.randomDoubles(rng, descriptor, false, INSTANCE_COUNT);
String[] sData = Utils.double2String(source);
- Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
Data data = DataLoader.loadData(dataset, sData);
DataConverter converter = new DataConverter(dataset);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataLoaderTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataLoaderTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataLoaderTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataLoaderTest.java Sun Oct 23 19:26:19 2011
@@ -49,10 +49,10 @@ public final class DataLoaderTest extend
Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
// prepare the data
- double[][] data = Utils.randomDoubles(rng, descriptor, datasize);
+ double[][] data = Utils.randomDoubles(rng, descriptor, false, datasize);
Collection<Integer> missings = Lists.newArrayList();
String[] sData = prepareData(data, attrs, missings);
- Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
Data loaded = DataLoader.loadData(dataset, sData);
testLoadedData(data, attrs, missings, loaded);
@@ -73,12 +73,12 @@ public final class DataLoaderTest extend
Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
// prepare the data
- double[][] data = Utils.randomDoubles(rng, descriptor, datasize);
+ double[][] data = Utils.randomDoubles(rng, descriptor, false, datasize);
Collection<Integer> missings = Lists.newArrayList();
String[] sData = prepareData(data, attrs, missings);
- Dataset expected = DataLoader.generateDataset(descriptor, sData);
+ Dataset expected = DataLoader.generateDataset(descriptor, false, sData);
- Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
assertEquals(expected, dataset);
}
@@ -157,13 +157,13 @@ public final class DataLoaderTest extend
if (attrs[attr].isNumerical()) {
assertEquals(vector[attr], instance.get(aId++), EPSILON);
- } else if (attrs[attr].isCategorical()) {
+ } else if (attrs[attr].isCategorical()||attrs[attr].isLabel()) {
checkCategorical(data, missings, loaded, attr, aId, vector[attr],
instance.get(aId));
aId++;
- } else if (attrs[attr].isLabel()) {
+ } /*else if (attrs[attr].isLabel()) {
checkLabel(data, missings, loaded, attr, vector[attr]);
- }
+ }*/
}
lind++;
@@ -192,7 +192,7 @@ public final class DataLoaderTest extend
int aId = 0;
for (int attr = 0; attr < nbAttributes; attr++) {
- if (attrs[attr].isIgnored() || attrs[attr].isLabel()) {
+ if (attrs[attr].isIgnored()) {
continue;
}
@@ -220,10 +220,10 @@ public final class DataLoaderTest extend
Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
// prepare the data
- double[][] source = Utils.randomDoubles(rng, descriptor, datasize);
+ double[][] source = Utils.randomDoubles(rng, descriptor, false, datasize);
Collection<Integer> missings = Lists.newArrayList();
String[] sData = prepareData(source, attrs, missings);
- Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
Path dataPath = Utils.writeDataToTestFile(sData);
FileSystem fs = dataPath.getFileSystem(new Configuration());
@@ -246,15 +246,15 @@ public final class DataLoaderTest extend
Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
// prepare the data
- double[][] source = Utils.randomDoubles(rng, descriptor, datasize);
+ double[][] source = Utils.randomDoubles(rng, descriptor, false, datasize);
Collection<Integer> missings = Lists.newArrayList();
String[] sData = prepareData(source, attrs, missings);
- Dataset expected = DataLoader.generateDataset(descriptor, sData);
+ Dataset expected = DataLoader.generateDataset(descriptor, false, sData);
Path path = Utils.writeDataToTestFile(sData);
FileSystem fs = path.getFileSystem(new Configuration());
- Dataset dataset = DataLoader.generateDataset(descriptor, fs, path);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, fs, path);
assertEquals(expected, dataset);
}
@@ -304,6 +304,8 @@ public final class DataLoaderTest extend
Data loaded,
int labelInd,
double value) {
+ Dataset dataset = loaded.getDataset();
+
// label's code that corresponds to the value
int code = loaded.getDataset().labelCode(Double.toString(value));
@@ -315,9 +317,9 @@ public final class DataLoaderTest extend
}
if (source[index][labelInd] == value) {
- assertEquals(code, loaded.get(lind).getLabel());
+ assertEquals(code, dataset.getLabel(loaded.get(lind)));
} else {
- assertFalse(code == loaded.get(lind).getLabel());
+ assertFalse(code == dataset.getLabel(loaded.get(lind)));
}
lind++;
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataTest.java Sun Oct 23 19:26:19 2011
@@ -39,7 +39,7 @@ public class DataTest extends MahoutTest
public void setUp() throws Exception {
super.setUp();
rng = RandomUtils.getRandom();
- data = Utils.randomData(rng, ATTRIBUTE_COUNT, DATA_SIZE);
+ data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
}
/**
@@ -82,7 +82,7 @@ public class DataTest extends MahoutTest
@Test
public void testValues() throws Exception {
- Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, DATA_SIZE);
+ Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
for (int attr = 0; attr < data.getDataset().nbAttributes(); attr++) {
double[] values = data.values(attr);
@@ -108,14 +108,14 @@ public class DataTest extends MahoutTest
@Test
public void testIdenticalTrue() throws Exception {
// generate a small data, only to get the dataset
- Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, 1).getDataset();
+ Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, false, 1).getDataset();
// test empty data
Data empty = new Data(dataset);
assertTrue(empty.isIdentical());
// test identical data, except for the labels
- Data identical = Utils.randomData(rng, ATTRIBUTE_COUNT, DATA_SIZE);
+ Data identical = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
Instance model = identical.get(0);
for (int index = 1; index < DATA_SIZE; index++) {
for (int attr = 0; attr < identical.getDataset().nbAttributes(); attr++) {
@@ -131,7 +131,7 @@ public class DataTest extends MahoutTest
int n = 10;
for (int nloop = 0; nloop < n; nloop++) {
- Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, DATA_SIZE);
+ Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
// choose a random instance
int index = rng.nextInt(DATA_SIZE);
@@ -148,7 +148,7 @@ public class DataTest extends MahoutTest
@Test
public void testIdenticalLabelTrue() throws Exception {
// generate a small data, only to get a dataset
- Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, 1).getDataset();
+ Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, false, 1).getDataset();
// test empty data
Data empty = new Data(dataset);
@@ -156,11 +156,11 @@ public class DataTest extends MahoutTest
// test identical labels
String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
- double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor,
+ double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false,
DATA_SIZE, rng.nextInt());
String[] sData = Utils.double2String(source);
- dataset = DataLoader.generateDataset(descriptor, sData);
+ dataset = DataLoader.generateDataset(descriptor, false, sData);
Data data = DataLoader.loadData(dataset, sData);
assertTrue(data.identicalLabel());
@@ -173,7 +173,7 @@ public class DataTest extends MahoutTest
for (int nloop = 0; nloop < n; nloop++) {
String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
int label = Utils.findLabel(descriptor);
- double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor,
+ double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false,
DATA_SIZE, rng.nextInt());
// choose a random vector and change its label
int index = rng.nextInt(DATA_SIZE);
@@ -181,7 +181,7 @@ public class DataTest extends MahoutTest
String[] sData = Utils.double2String(source);
- Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
Data data = DataLoader.loadData(dataset, sData);
assertFalse(data.identicalLabel());
@@ -237,8 +237,9 @@ public class DataTest extends MahoutTest
@Test
public void testCountLabel() throws Exception {
- Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, DATA_SIZE);
- int[] counts = new int[data.getDataset().nblabels()];
+ Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
+ Dataset dataset = data.getDataset();
+ int[] counts = new int[dataset.nblabels()];
int n = 10;
@@ -247,7 +248,7 @@ public class DataTest extends MahoutTest
data.countLabels(counts);
for (int index=0;index<data.size();index++) {
- counts[data.get(index).getLabel()]--;
+ counts[dataset.getLabel(data.get(index))]--;
}
for (int label = 0; label < data.getDataset().nblabels(); label++) {
@@ -264,11 +265,11 @@ public class DataTest extends MahoutTest
int label = Utils.findLabel(descriptor);
int label1 = rng.nextInt();
- double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, 100,
+ double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100,
label1);
String[] sData = Utils.double2String(source);
- Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
Data data = DataLoader.loadData(dataset, sData);
int code1 = dataset.labelCode(Double.toString(label1));
@@ -286,7 +287,7 @@ public class DataTest extends MahoutTest
}
}
sData = Utils.double2String(source);
- dataset = DataLoader.generateDataset(descriptor, sData);
+ dataset = DataLoader.generateDataset(descriptor, false, sData);
data = DataLoader.loadData(dataset, sData);
int code2 = dataset.labelCode(Double.toString(label2));
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DatasetTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DatasetTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DatasetTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DatasetTest.java Sun Oct 23 19:26:19 2011
@@ -51,7 +51,7 @@ public final class DatasetTest extends M
for (int nloop = 0; nloop < n; nloop++) {
byteOutStream.reset();
- Dataset dataset = Utils.randomData(rng, NUM_ATTRIBUTES, 1).getDataset();
+ Dataset dataset = Utils.randomData(rng, NUM_ATTRIBUTES, false, 1).getDataset();
dataset.write(out);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/data/Utils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/data/Utils.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/data/Utils.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/data/Utils.java Sun Oct 23 19:26:19 2011
@@ -106,16 +106,17 @@ public final class Utils {
*
* @param rng Random number generator
* @param nbAttributes number of attributes
+ * @param regression true if the label is numerical
* @param number of data lines to generate
*/
- public static double[][] randomDoubles(Random rng, int nbAttributes,int number) throws DescriptorException {
+ public static double[][] randomDoubles(Random rng, int nbAttributes, boolean regression, int number) throws DescriptorException {
String descriptor = randomDescriptor(rng, nbAttributes);
Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
double[][] data = new double[number][];
for (int index = 0; index < number; index++) {
- data[index] = randomVector(rng, attrs);
+ data[index] = randomVector(rng, attrs, regression);
}
return data;
@@ -128,13 +129,13 @@ public final class Utils {
* @param descriptor attributes description
* @param number number of data lines to generate
*/
- public static double[][] randomDoubles(Random rng, CharSequence descriptor, int number) throws DescriptorException {
+ public static double[][] randomDoubles(Random rng, CharSequence descriptor, boolean regression, int number) throws DescriptorException {
Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
double[][] data = new double[number][];
for (int index = 0; index < number; index++) {
- data[index] = randomVector(rng, attrs);
+ data[index] = randomVector(rng, attrs, regression);
}
return data;
@@ -145,13 +146,14 @@ public final class Utils {
*
* @param rng Random number generator
* @param nbAttributes number of attributes
+ * @param regression true if the label should be numerical
* @param size data size
*/
- public static Data randomData(Random rng, int nbAttributes, int size) throws DescriptorException {
+ public static Data randomData(Random rng, int nbAttributes, boolean regression, int size) throws DescriptorException {
String descriptor = randomDescriptor(rng, nbAttributes);
- double[][] source = randomDoubles(rng, descriptor, size);
+ double[][] source = randomDoubles(rng, descriptor, regression, size);
String[] sData = double2String(source);
- Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+ Dataset dataset = DataLoader.generateDataset(descriptor, regression, sData);
return DataLoader.loadData(dataset, sData);
}
@@ -168,7 +170,7 @@ public final class Utils {
*
* @param attrs attributes description
*/
- private static double[] randomVector(Random rng, Attribute[] attrs) {
+ private static double[] randomVector(Random rng, Attribute[] attrs, boolean regression) {
double[] vector = new double[attrs.length];
for (int attr = 0; attr < attrs.length; attr++) {
@@ -176,9 +178,14 @@ public final class Utils {
vector[attr] = Double.NaN;
} else if (attrs[attr].isNumerical()) {
vector[attr] = rng.nextDouble();
- } else {
- // CATEGORICAL or LABEL
+ } else if (attrs[attr].isCategorical()){
vector[attr] = rng.nextInt(CATEGORICAL_RANGE);
+ } else { // LABEL
+ if (regression) {
+ vector[attr] = rng.nextDouble();
+ } else {
+ vector[attr] = rng.nextInt(CATEGORICAL_RANGE);
+ }
}
}
@@ -222,14 +229,15 @@ public final class Utils {
*
* @param rng
* @param descriptor
+ * @param regression
* @param number data size
* @param value label value
*/
public static double[][] randomDoublesWithSameLabel(Random rng,
- String descriptor, int number, int value) throws DescriptorException {
+ String descriptor, boolean regression, int number, int value) throws DescriptorException {
int label = findLabel(descriptor);
- double[][] source = randomDoubles(rng, descriptor, number);
+ double[][] source = randomDoubles(rng, descriptor, regression, number);
for (int index = 0; index < number; index++) {
source[index][label] = value;
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java Sun Oct 23 19:26:19 2011
@@ -90,9 +90,9 @@ public final class Step1MapperTest exten
// prepare the data
String descriptor = Utils.randomDescriptor(rng, NUM_ATTRIBUTES);
- double[][] source = Utils.randomDoubles(rng, descriptor, NUM_INSTANCES);
+ double[][] source = Utils.randomDoubles(rng, descriptor, false, NUM_INSTANCES);
String[] sData = Utils.double2String(source);
- Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
String[][] splits = Utils.splitData(sData, NUM_MAPPERS);
MockTreeBuilder treeBuilder = new MockTreeBuilder();
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/split/DefaultIgSplitTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/split/DefaultIgSplitTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/split/DefaultIgSplitTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/split/DefaultIgSplitTest.java Sun Oct 23 19:26:19 2011
@@ -38,9 +38,9 @@ public final class DefaultIgSplitTest ex
int label = Utils.findLabel(descriptor);
// all the vectors have the same label (0)
- double[][] temp = Utils.randomDoublesWithSameLabel(rng, descriptor, 100, 0);
+ double[][] temp = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100, 0);
String[] sData = Utils.double2String(temp);
- Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
Data data = DataLoader.loadData(dataset, sData);
DefaultIgSplit iG = new DefaultIgSplit();
@@ -53,7 +53,7 @@ public final class DefaultIgSplitTest ex
temp[index][label] = 1.0;
}
sData = Utils.double2String(temp);
- dataset = DataLoader.generateDataset(descriptor, sData);
+ dataset = DataLoader.generateDataset(descriptor, false, sData);
data = DataLoader.loadData(dataset, sData);
iG = new DefaultIgSplit();
@@ -67,7 +67,7 @@ public final class DefaultIgSplitTest ex
temp[index][label] = 2.0;
}
sData = Utils.double2String(temp);
- dataset = DataLoader.generateDataset(descriptor, sData);
+ dataset = DataLoader.generateDataset(descriptor, false, sData);
data = DataLoader.loadData(dataset, sData);
iG = new DefaultIgSplit();
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/split/OptIgSplitTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/split/OptIgSplitTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/split/OptIgSplitTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/split/OptIgSplitTest.java Sun Oct 23 19:26:19 2011
@@ -37,7 +37,7 @@ public final class OptIgSplitTest extend
IgSplit opt = new OptIgSplit();
Random rng = RandomUtils.getRandom();
- Data data = Utils.randomData(rng, NUM_ATTRIBUTES, NUM_INSTANCES);
+ Data data = Utils.randomData(rng, NUM_ATTRIBUTES, false, NUM_INSTANCES);
for (int nloop = 0; nloop < 100; nloop++) {
int attr = rng.nextInt(data.getDataset().nbAttributes());
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java Sun Oct 23 19:26:19 2011
@@ -54,7 +54,9 @@ public class BreimanExample extends Conf
private static final Logger log = LoggerFactory.getLogger(BreimanExample.class);
/** sum test error */
- private double sumTestErr;
+ private double sumTestErrM;
+
+ private double sumTestErrOne;
/** mean time to build a forest with m=log2(M)+1 */
private long sumTimeM;
@@ -113,9 +115,12 @@ public class BreimanExample extends Conf
// compute the test set error (Selection Error), and mean tree error (One Tree Error),
int[] testLabels = test.extractLabels();
int[] predictions = new int[test.size()];
+
forestM.classify(test, predictions);
+ sumTestErrM += ErrorEstimate.errorRate(testLabels, predictions);
- sumTestErr += ErrorEstimate.errorRate(testLabels, predictions);
+ forestOne.classify(test, predictions);
+ sumTestErrOne += ErrorEstimate.errorRate(testLabels, predictions);
}
public static void main(String[] args) throws Exception {
@@ -194,7 +199,8 @@ public class BreimanExample extends Conf
}
log.info("********************************************");
- log.info("Selection error : {}", sumTestErr / nbIterations);
+ log.info("Random Input Test Error : {}", sumTestErrM / nbIterations);
+ log.info("Single Input Test Error : {}", sumTestErrOne / nbIterations);
log.info("Mean Random Input Time : {}", DFUtils.elapsedTime(sumTimeM / nbIterations));
log.info("Mean Single Input Time : {}", DFUtils.elapsedTime(sumTimeOne / nbIterations));
log.info("Mean Random Input Num Nodes : {}", numNodesM / nbIterations);
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sun Oct 23 19:26:19 2011
@@ -261,8 +261,8 @@ public class TestForest extends Configur
}
if (analyzer != null) {
- analyzer.addInstance(dataset.getLabel(instance.getLabel()),
- new ClassifierResult(dataset.getLabel(prediction), 1.0));
+ analyzer.addInstance(dataset.getLabelString(dataset.getLabel(instance)),
+ new ClassifierResult(dataset.getLabelString(prediction), 1.0));
}
}