You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/04 14:29:39 UTC
[37/53] [abbrv] [partial] mahout git commit: end of day 6-2-2018
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java
new file mode 100644
index 0000000..c3b2fa3
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java
@@ -0,0 +1,284 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Custom InputFormat that generates InputSplits given the desired number of trees.<br>
+ * Each input split covers a contiguous range of tree ids.<br>
+ * The number of generated splits is equal to the requested number of splits (when called by Hadoop,
+ * the value of {@code mapred.map.tasks}).
+ */
+@Deprecated
+public class InMemInputFormat extends InputFormat<IntWritable,NullWritable> {
+
+  private static final Logger log = LoggerFactory.getLogger(InMemInputFormat.class);
+
+  /** Generates the per-split seeds; null when no seed is configured or in single-seed mode. */
+  private Random rng;
+
+  /** Seed read from the configuration by the last call to getSplits(), or null if none. */
+  private Long seed;
+
+  /** When true (and a seed is available) every split reuses {@link #seed} as is. */
+  private boolean isSingleSeed;
+
+  /**
+   * Used for DEBUG purposes only. if true and a seed is available, all the mappers use the same seed, thus
+   * all the mapper should take the same time to build their trees.
+   */
+  private static boolean isSingleSeed(Configuration conf) {
+    return conf.getBoolean("debug.mahout.rf.single.seed", false);
+  }
+
+  @Override
+  public RecordReader<IntWritable,NullWritable> createRecordReader(InputSplit split, TaskAttemptContext context)
+    throws IOException, InterruptedException {
+    Preconditions.checkArgument(split instanceof InMemInputSplit, "Expected an InMemInputSplit but got: %s", split);
+    return new InMemRecordReader((InMemInputSplit) split);
+  }
+
+  /**
+   * Generates one split per map task, as configured by {@code mapred.map.tasks}.
+   *
+   * @throws IllegalArgumentException if {@code mapred.map.tasks} is unset or not positive
+   */
+  @Override
+  public List<InputSplit> getSplits(JobContext context) throws IOException, InterruptedException {
+    Configuration conf = context.getConfiguration();
+    int numSplits = conf.getInt("mapred.map.tasks", -1);
+
+    return getSplits(conf, numSplits);
+  }
+
+  /**
+   * Generates {@code numSplits} splits that together cover all the trees of the forest. Every split but
+   * the last one gets {@code nbTrees / numSplits} trees; the last split takes whatever remains.
+   *
+   * @param conf job configuration, used to read the number of trees and the (optional) random seed
+   * @param numSplits number of splits to generate, must be > 0
+   * @return the generated splits, ordered by increasing first tree id
+   * @throws IllegalArgumentException if {@code numSplits <= 0}
+   */
+  public List<InputSplit> getSplits(Configuration conf, int numSplits) {
+    // getSplits(JobContext) passes -1 when mapred.map.tasks is unset: fail with a clear message instead of
+    // an ArithmeticException (0) or an "Illegal Capacity" error from the ArrayList below (negative).
+    Preconditions.checkArgument(numSplits > 0, "Wrong number of splits: %s. It must be > 0", numSplits);
+
+    int nbTrees = Builder.getNbTrees(conf);
+    int splitSize = nbTrees / numSplits;
+
+    seed = Builder.getRandomSeed(conf);
+    isSingleSeed = isSingleSeed(conf);
+
+    if (rng != null && seed != null) {
+      log.warn("getSplits() was called more than once and the 'seed' is set, "
+                 + "this can lead to non-repeatable behavior");
+    }
+
+    rng = seed == null || isSingleSeed ? null : RandomUtils.getRandom(seed);
+
+    int id = 0;
+
+    List<InputSplit> splits = new ArrayList<>(numSplits);
+
+    for (int index = 0; index < numSplits - 1; index++) {
+      splits.add(new InMemInputSplit(id, splitSize, nextSeed()));
+      id += splitSize;
+    }
+
+    // take care of the remainder
+    splits.add(new InMemInputSplit(id, nbTrees - id, nextSeed()));
+
+    return splits;
+  }
+
+  /**
+   * @return the seed for the next InputSplit: null when no seed is configured, the configured seed in
+   *         single-seed mode, otherwise the next value drawn from {@link #rng}
+   */
+  private Long nextSeed() {
+    if (seed == null) {
+      return null;
+    } else if (isSingleSeed) {
+      return seed;
+    } else {
+      return rng.nextLong();
+    }
+  }
+
+  /**
+   * Emits one record per tree of its split: the key is the tree id, the value is always NullWritable.
+   */
+  public static class InMemRecordReader extends RecordReader<IntWritable,NullWritable> {
+
+    private final InMemInputSplit split;
+    /** number of records returned so far */
+    private int pos;
+    private IntWritable key;
+    private NullWritable value;
+
+    public InMemRecordReader(InMemInputSplit split) {
+      this.split = split;
+    }
+
+    @Override
+    public float getProgress() throws IOException {
+      return pos == 0 ? 0.0f : (float) (pos - 1) / split.nbTrees;
+    }
+
+    @Override
+    public IntWritable getCurrentKey() throws IOException, InterruptedException {
+      return key;
+    }
+
+    @Override
+    public NullWritable getCurrentValue() throws IOException, InterruptedException {
+      return value;
+    }
+
+    @Override
+    public void initialize(InputSplit arg0, TaskAttemptContext arg1) throws IOException, InterruptedException {
+      key = new IntWritable();
+      value = NullWritable.get();
+    }
+
+    @Override
+    public boolean nextKeyValue() throws IOException, InterruptedException {
+      if (pos < split.nbTrees) {
+        key.set(split.firstId + pos);
+        pos++;
+        return true;
+      } else {
+        return false;
+      }
+    }
+
+    @Override
+    public void close() throws IOException {
+    }
+
+  }
+
+  /**
+   * Custom InputSplit that indicates how many trees are built by each mapper
+   */
+  public static class InMemInputSplit extends InputSplit implements Writable {
+
+    private static final String[] NO_LOCATIONS = new String[0];
+
+    /** Id of the first tree of this split */
+    private int firstId;
+
+    /** Number of trees covered by this split */
+    private int nbTrees;
+
+    /** Random seed for this split, or null when none is available */
+    private Long seed;
+
+    public InMemInputSplit() { }
+
+    public InMemInputSplit(int firstId, int nbTrees, Long seed) {
+      this.firstId = firstId;
+      this.nbTrees = nbTrees;
+      this.seed = seed;
+    }
+
+    /**
+     * @return the Id of the first tree of this split
+     */
+    public int getFirstId() {
+      return firstId;
+    }
+
+    /**
+     * @return the number of trees
+     */
+    public int getNbTrees() {
+      return nbTrees;
+    }
+
+    /**
+     * @return the random seed or null if no seed is available
+     */
+    public Long getSeed() {
+      return seed;
+    }
+
+    @Override
+    public long getLength() throws IOException {
+      return nbTrees;
+    }
+
+    @Override
+    public String[] getLocations() throws IOException {
+      return NO_LOCATIONS;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (this == obj) {
+        return true;
+      }
+      if (!(obj instanceof InMemInputSplit)) {
+        return false;
+      }
+
+      InMemInputSplit split = (InMemInputSplit) obj;
+
+      if (firstId != split.firstId || nbTrees != split.nbTrees) {
+        return false;
+      }
+      if (seed == null) {
+        return split.seed == null;
+      } else {
+        return seed.equals(split.seed);
+      }
+
+    }
+
+    @Override
+    public int hashCode() {
+      // Mix the fields (31-multiplier scheme) instead of summing them so that e.g. (firstId, nbTrees) = (0, 10)
+      // and (10, 0) no longer collide; Long.hashCode() also folds in the seed's high 32 bits.
+      // Consistent with equals(): equal splits have equal fields, hence equal hashes.
+      int result = firstId;
+      result = 31 * result + nbTrees;
+      result = 31 * result + (seed == null ? 0 : seed.hashCode());
+      return result;
+    }
+
+    @Override
+    public String toString() {
+      return String.format(Locale.ENGLISH, "[firstId:%d, nbTrees:%d, seed:%d]", firstId, nbTrees, seed);
+    }
+
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      firstId = in.readInt();
+      nbTrees = in.readInt();
+      boolean isSeed = in.readBoolean();
+      seed = isSeed ? in.readLong() : null;
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+      out.writeInt(firstId);
+      out.writeInt(nbTrees);
+      // presence flag first, so readFields() knows whether a seed follows
+      out.writeBoolean(seed != null);
+      if (seed != null) {
+        out.writeLong(seed);
+      }
+    }
+
+    public static InMemInputSplit read(DataInput in) throws IOException {
+      InMemInputSplit split = new InMemInputSplit();
+      split.readFields(in);
+      return split;
+    }
+
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java
new file mode 100644
index 0000000..2fc67ba
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.mahout.classifier.df.Bagging;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.classifier.df.mapreduce.MapredMapper;
+import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
+import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Random;
+
+/**
+ * Mapper that grows decision trees from a complete, in-memory copy of the training data. Each call to
+ * map() grows a single tree; the keys handed to this mapper therefore decide how many trees it builds.
+ */
+@Deprecated
+public class InMemMapper extends MapredMapper<IntWritable,NullWritable,IntWritable,MapredOutput> {
+
+  private static final Logger log = LoggerFactory.getLogger(InMemMapper.class);
+
+  /** Tree grower over the full training set, created once in setup(). */
+  private Bagging bagging;
+
+  /** Lazily created by initRandom() on the first map() call. */
+  private Random rng;
+
+  /**
+   * Reads the whole training set from the distributed-cache file registered at index 1.
+   */
+  private static Data loadData(Configuration conf, Dataset dataset) throws IOException {
+    Path cachedData = Builder.getDistributedCacheFile(conf, 1);
+    return DataLoader.loadData(dataset, FileSystem.get(cachedData.toUri(), conf), cachedData);
+  }
+
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+
+    log.info("Loading the data...");
+    Data trainingData = loadData(context.getConfiguration(), getDataset());
+    log.info("Data loaded : {} instances", trainingData.size());
+
+    bagging = new Bagging(getTreeBuilder(), trainingData);
+  }
+
+  @Override
+  protected void map(IntWritable key, NullWritable value, Context context)
+    throws IOException, InterruptedException {
+    // the value is a NullWritable placeholder; only the key is used
+    map(key, context);
+  }
+
+  /**
+   * Grows one tree and, when output is enabled, writes it to the context under {@code key}.
+   */
+  void map(IntWritable key, Context context) throws IOException, InterruptedException {
+    initRandom((InMemInputSplit) context.getInputSplit());
+
+    log.debug("Building...");
+    Node tree = bagging.build(rng);
+
+    if (!isOutput()) {
+      return;
+    }
+
+    log.debug("Outputing...");
+    context.write(key, new MapredOutput(tree));
+  }
+
+  /**
+   * Creates the random number generator on the first call only: seeded with the split's seed when it has
+   * one, unseeded otherwise. Subsequent calls keep the existing generator.
+   */
+  void initRandom(InMemInputSplit split) {
+    if (rng != null) {
+      return; // already created by an earlier map() call
+    }
+
+    Long seed = split.getSeed();
+    log.debug("Initialising rng with seed : {}", seed);
+    if (seed == null) {
+      rng = RandomUtils.getRandom();
+    } else {
+      rng = RandomUtils.getRandom(seed);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java
new file mode 100644
index 0000000..61e65e8
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java
@@ -0,0 +1,22 @@
+/**
+ * <h2>In-memory mapreduce implementation of Random Decision Forests</h2>
+ *
+ * <p>Each mapper is responsible for growing a number of trees with a whole copy of the dataset loaded in memory;
+ * it uses the reference implementation's code to build each tree and estimate the oob error.</p>
+ *
+ * <p>The dataset is distributed to the slave nodes using the {@link org.apache.hadoop.filecache.DistributedCache}.
+ * A custom {@link org.apache.hadoop.mapreduce.InputFormat}
+ * ({@link org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat}) is configured with the
+ * desired number of trees and generates a number of {@link org.apache.hadoop.mapreduce.InputSplit}s
+ * equal to the configured number of maps.</p>
+ *
+ * <p>There is no need for reducers: each map outputs the trees it built and, for each tree, the labels the
+ * tree predicted for each out-of-bag instance. This step has to be done in the mapper because only there do we
+ * know which instances are out-of-bag.</p>
+ *
+ * <p>The Forest builder ({@link org.apache.mahout.classifier.df.mapreduce.inmem.InMemBuilder}) is responsible
+ * for configuring and launching the job.
+ * At the end of the job it parses the output files and builds the corresponding
+ * {@link org.apache.mahout.classifier.df.DecisionForest}.</p>
+ */
+package org.apache.mahout.classifier.df.mapreduce.inmem;
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java
new file mode 100644
index 0000000..9236af3
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java
@@ -0,0 +1,158 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Builds a random forest using partial data. Each mapper uses only the data given by its InputSplit
+ */
+@Deprecated
+public class PartialBuilder extends Builder {
+
+  private static final Logger log = LoggerFactory.getLogger(PartialBuilder.class);
+
+  public PartialBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed) {
+    this(treeBuilder, dataPath, datasetPath, seed, new Configuration());
+  }
+
+  public PartialBuilder(TreeBuilder treeBuilder,
+                        Path dataPath,
+                        Path datasetPath,
+                        Long seed,
+                        Configuration conf) {
+    super(treeBuilder, dataPath, datasetPath, seed, conf);
+  }
+
+  @Override
+  protected void configureJob(Job job) throws IOException {
+    Configuration conf = job.getConfiguration();
+
+    job.setJarByClass(PartialBuilder.class);
+
+    FileInputFormat.setInputPaths(job, getDataPath());
+    FileOutputFormat.setOutputPath(job, getOutputPath(conf));
+
+    job.setOutputKeyClass(TreeID.class);
+    job.setOutputValueClass(MapredOutput.class);
+
+    job.setMapperClass(Step1Mapper.class);
+    job.setNumReduceTasks(0); // no reducers
+
+    job.setInputFormatClass(TextInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+    // For this implementation to work, mapred.map.tasks needs to be set to the actual
+    // number of mappers Hadoop will use:
+    TextInputFormat inputFormat = new TextInputFormat();
+    List<?> splits = inputFormat.getSplits(job);
+    if (splits == null || splits.isEmpty()) {
+      log.warn("Unable to compute number of splits?");
+    } else {
+      int numSplits = splits.size();
+      log.info("Setting mapred.map.tasks = {}", numSplits);
+      conf.setInt("mapred.map.tasks", numSplits);
+    }
+  }
+
+  @Override
+  protected DecisionForest parseOutput(Job job) throws IOException {
+    Configuration conf = job.getConfiguration();
+
+    int numTrees = Builder.getNbTrees(conf);
+
+    Path outputPath = getOutputPath(conf);
+
+    TreeID[] keys = new TreeID[numTrees];
+    Node[] trees = new Node[numTrees];
+
+    processOutput(job, outputPath, keys, trees);
+
+    return new DecisionForest(Arrays.asList(trees));
+  }
+
+  /**
+   * Processes the output from the output path.<br>
+   * When {@code keys}/{@code trees} are given, the i-th record read is stored at index i of both arrays.
+   *
+   * @param job
+   *          job whose configuration is used to access the file system and read the output files
+   * @param outputPath
+   *          directory that contains the output of the job
+   * @param keys
+   *          can be null (then {@code trees} must be null too)
+   * @param trees
+   *          can be null (then {@code keys} must be null too)
+   * @throws java.io.IOException
+   *           if the output files cannot be listed or read
+   * @throws IllegalStateException
+   *           if the output does not contain exactly {@code keys.length} records
+   */
+  protected static void processOutput(JobContext job,
+                                      Path outputPath,
+                                      TreeID[] keys,
+                                      Node[] trees) throws IOException {
+    Preconditions.checkArgument(keys == null && trees == null || keys != null && trees != null,
+        "if keys is null, trees should also be null");
+    Preconditions.checkArgument(keys == null || keys.length == trees.length, "keys.length != trees.length");
+
+    Configuration conf = job.getConfiguration();
+
+    FileSystem fs = outputPath.getFileSystem(conf);
+
+    Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);
+
+    // read all the outputs
+    int index = 0;
+    for (Path path : outfiles) {
+      for (Pair<TreeID,MapredOutput> record : new SequenceFileIterable<TreeID, MapredOutput>(path, conf)) {
+        // keys and trees are either both null or both non-null (checked above)
+        if (keys != null) {
+          // More records than slots means the mappers produced more trees than expected; report it
+          // explicitly instead of failing with an ArrayIndexOutOfBoundsException.
+          if (index >= keys.length) {
+            throw new IllegalStateException("The output contains more key/values than expected ("
+                + keys.length + ")");
+          }
+          keys[index] = record.getFirst();
+          trees[index] = record.getSecond().getTree();
+        }
+        index++;
+      }
+    }
+
+    // make sure we got all the keys/values
+    if (keys != null && index != keys.length) {
+      throw new IllegalStateException("Some key/values are missing from the output");
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java
new file mode 100644
index 0000000..9474236
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java
@@ -0,0 +1,168 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.classifier.df.Bagging;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.classifier.df.mapreduce.MapredMapper;
+import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * First step of the Partial Data Builder. Collects the instances of the InputSplit in map(), then builds
+ * this mapper's share of the trees from them in cleanup() and emits each tree keyed by its {@link TreeID}.
+ */
+@Deprecated
+public class Step1Mapper extends MapredMapper<LongWritable,Text,TreeID,MapredOutput> {
+
+  private static final Logger log = LoggerFactory.getLogger(Step1Mapper.class);
+
+  /** used to convert input values to data instances */
+  private DataConverter converter;
+
+  private Random rng;
+
+  /** number of trees to be built by this mapper */
+  private int nbTrees;
+
+  /** id of the first tree */
+  private int firstTreeId;
+
+  /** mapper's partition */
+  private int partition;
+
+  /** will contain all instances of this mapper's split */
+  private final List<Instance> instances = new ArrayList<>();
+
+  public int getFirstTreeId() {
+    return firstTreeId;
+  }
+
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+
+    configure(Builder.getRandomSeed(conf), conf.getInt("mapred.task.partition", -1),
+        Builder.getNumMaps(conf), Builder.getNbTrees(conf));
+  }
+
+  /**
+   * Useful when testing
+   *
+   * @param seed
+   *          random seed, or null to use an unseeded generator
+   * @param partition
+   *          current mapper inputSplit partition, must be in [0, numMapTasks)
+   * @param numMapTasks
+   *          number of running map tasks, must be > 0
+   * @param numTrees
+   *          total number of trees in the forest
+   * @throws IllegalArgumentException
+   *           if {@code partition} or {@code numMapTasks} is out of range
+   */
+  protected void configure(Long seed, int partition, int numMapTasks, int numTrees) {
+    converter = new DataConverter(getDataset());
+
+    // prepare random-numbers generator
+    log.debug("seed : {}", seed);
+    if (seed == null) {
+      rng = RandomUtils.getRandom();
+    } else {
+      rng = RandomUtils.getRandom(seed);
+    }
+
+    // mapper's partition
+    Preconditions.checkArgument(partition >= 0, "Wrong partition ID: " + partition + ". Partition must be >= 0!");
+    // setup() passes -1 when the number of maps is not configured: without this check nbTrees() would
+    // return a negative count (so no tree would be built, silently) or divide by zero.
+    Preconditions.checkArgument(numMapTasks > 0,
+        "Wrong number of map tasks: %s. It must be > 0!", numMapTasks);
+    // A partition outside [0, numMapTasks) would build trees whose ids go beyond the forest size.
+    Preconditions.checkArgument(partition < numMapTasks,
+        "Wrong partition ID: %s. Partition must be < the number of map tasks (%s)!", partition, numMapTasks);
+    this.partition = partition;
+
+    // compute number of trees to build
+    nbTrees = nbTrees(numMapTasks, numTrees, partition);
+
+    // compute first tree id: sum of the trees built by all the previous partitions
+    firstTreeId = 0;
+    for (int p = 0; p < partition; p++) {
+      firstTreeId += nbTrees(numMapTasks, numTrees, p);
+    }
+
+    log.debug("partition : {}", partition);
+    log.debug("nbTrees : {}", nbTrees);
+    log.debug("firstTreeId : {}", firstTreeId);
+  }
+
+  /**
+   * Compute the number of trees for a given partition. The first {@code numTrees % numMaps} partitions
+   * build one more tree than the others, to account for the remainder.
+   *
+   * @param numMaps
+   *          total number of maps (partitions), must be > 0
+   * @param numTrees
+   *          total number of trees to build
+   * @param partition
+   *          partition to compute the number of trees for
+   * @return the number of trees to build in {@code partition}
+   * @throws IllegalArgumentException
+   *           if {@code numMaps <= 0}
+   */
+  public static int nbTrees(int numMaps, int numTrees, int partition) {
+    Preconditions.checkArgument(numMaps > 0, "Wrong number of maps: %s. It must be > 0!", numMaps);
+    int treesPerMapper = numTrees / numMaps;
+    int remainder = numTrees - numMaps * treesPerMapper;
+    return treesPerMapper + (partition < remainder ? 1 : 0);
+  }
+
+  @Override
+  protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
+    // only collect the instances here; the trees are built once the whole split is available (cleanup)
+    instances.add(converter.convert(value.toString()));
+  }
+
+  @Override
+  protected void cleanup(Context context) throws IOException, InterruptedException {
+    // prepare the data
+    log.debug("partition: {} numInstances: {}", partition, instances.size());
+
+    Data data = new Data(getDataset(), instances);
+    Bagging bagging = new Bagging(getTreeBuilder(), data);
+
+    // a single key instance is reused for every written record
+    TreeID key = new TreeID();
+
+    log.debug("Building {} trees", nbTrees);
+    for (int treeId = 0; treeId < nbTrees; treeId++) {
+      log.debug("Building tree number : {}", treeId);
+
+      Node tree = bagging.build(rng);
+
+      key.set(partition, firstTreeId + treeId);
+
+      if (isOutput()) {
+        MapredOutput emOut = new MapredOutput(tree);
+        context.write(key, emOut);
+      }
+
+      context.progress();
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java
new file mode 100644
index 0000000..c296061
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.io.LongWritable;
+
+/**
+ * Indicates both the tree and the data partition used to grow the tree.<br>
+ * Both values are packed into a single long as {@code partition * MAX_TREEID + treeId}, which is why the
+ * tree id must stay below {@link #MAX_TREEID}.
+ */
+@Deprecated
+public class TreeID extends LongWritable implements Cloneable {
+
+  /** Exclusive upper bound for tree ids, also used as the multiplier that packs the partition. */
+  public static final int MAX_TREEID = 100000;
+
+  public TreeID() { }
+
+  /**
+   * @param partition data partition used to grow the tree, must be >= 0
+   * @param treeId id of the tree, must be in [0, MAX_TREEID)
+   * @throws IllegalArgumentException if either value is out of range
+   */
+  public TreeID(int partition, int treeId) {
+    set(partition, treeId);
+  }
+
+  /**
+   * Packs {@code partition} and {@code treeId} into this writable.
+   *
+   * @param partition data partition used to grow the tree, must be >= 0
+   * @param treeId id of the tree, must be in [0, MAX_TREEID)
+   * @throws IllegalArgumentException if either value is out of range
+   */
+  public void set(int partition, int treeId) {
+    // Validate here and not only in the constructor: instances can be reused through set(), and a tree id
+    // >= MAX_TREEID would silently corrupt both partition() and treeId(). The message templates are only
+    // formatted on failure, so this stays cheap on the hot path.
+    Preconditions.checkArgument(partition >= 0, "Wrong partition: %s. Partition must be >= 0!", partition);
+    Preconditions.checkArgument(treeId >= 0, "Wrong treeId: %s. TreeId must be >= 0!", treeId);
+    Preconditions.checkArgument(treeId < MAX_TREEID,
+        "Wrong treeId: %s. TreeId must be < %s!", treeId, MAX_TREEID);
+    set((long) partition * MAX_TREEID + treeId);
+  }
+
+  /**
+   * Data partition (InputSplit's index) that was used to grow the tree
+   */
+  public int partition() {
+    return (int) (get() / MAX_TREEID);
+  }
+
+  /**
+   * Id of the tree
+   */
+  public int treeId() {
+    return (int) (get() % MAX_TREEID);
+  }
+
+  @Override
+  public TreeID clone() {
+    return new TreeID(partition(), treeId());
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java
new file mode 100644
index 0000000..e621c91
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java
@@ -0,0 +1,16 @@
+/**
+ * <h2>Partial-data mapreduce implementation of Random Decision Forests</h2>
+ *
+ * <p>The builder splits the data among the mappers, using the file splits computed by {@code TextInputFormat}.
+ * Building the forest and estimating the oob error takes two job steps.</p>
+ *
+ * <p>In the first step, each mapper is responsible for growing a number of trees with its partition's data,
+ * loading the data instances in its {@code map()} function, then building the trees in the {@code cleanup()} method.
+ * It uses the reference implementation's code to build each tree and estimate the oob error.</p>
+ *
+ * <p>The second step is needed when estimating the oob error. Each mapper loads all the trees that do not
+ * belong to its own partition (were not built using the partition's data) and uses them to classify the
+ * partition's data instances. The data instances are loaded in the {@code map()} method and the classification
+ * is performed in the {@code close()} method.</p>
+ */
+package org.apache.mahout.classifier.df.mapreduce.partial;
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
new file mode 100644
index 0000000..1f91842
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
@@ -0,0 +1,134 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Represents a node that splits on a categorical attribute. It holds a set of known attribute values
+ * and one child per value: an instance whose attribute equals {@code values[i]} is routed to
+ * {@code childs[i]}.
+ */
+@Deprecated
+public class CategoricalNode extends Node {
+
+  /** index of the categorical attribute this node splits on */
+  private int attr;
+  /** known values of the attribute; values[i] is routed to childs[i] */
+  private double[] values;
+  /** one child per entry of values */
+  private Node[] childs;
+
+  /** No-arg constructor used when deserializing through {@link #readFields(DataInput)}. */
+  public CategoricalNode() {
+  }
+
+  public CategoricalNode(int attr, double[] values, Node[] childs) {
+    this.attr = attr;
+    this.values = values;
+    this.childs = childs;
+  }
+
+  /**
+   * Routes the instance to the child that matches its attribute value.
+   *
+   * @return the child's prediction, or {@link Double#NaN} if the attribute value is not one of this
+   *         node's known values
+   */
+  @Override
+  public double classify(Instance instance) {
+    int index = ArrayUtils.indexOf(values, instance.get(attr));
+    if (index == -1) {
+      // value not available, we cannot predict
+      return Double.NaN;
+    }
+    return childs[index].classify(instance);
+  }
+
+  @Override
+  public long maxDepth() {
+    long max = 0;
+
+    for (Node child : childs) {
+      long depth = child.maxDepth();
+      if (depth > max) {
+        max = depth;
+      }
+    }
+
+    return 1 + max;
+  }
+
+  @Override
+  public long nbNodes() {
+    long nbNodes = 1;
+
+    for (Node child : childs) {
+      nbNodes += child.nbNodes();
+    }
+
+    return nbNodes;
+  }
+
+  @Override
+  protected Type getType() {
+    return Type.CATEGORICAL;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (!(obj instanceof CategoricalNode)) {
+      return false;
+    }
+
+    CategoricalNode node = (CategoricalNode) obj;
+
+    return attr == node.attr && Arrays.equals(values, node.values) && Arrays.equals(childs, node.childs);
+  }
+
+  /**
+   * Hash code consistent with {@link #equals(Object)}, which compares both arrays with
+   * {@link Arrays#equals}.
+   */
+  @Override
+  public int hashCode() {
+    // Arrays.hashCode(double[]) folds all 64 bits of each value. Casting doubleToLongBits(value) to int
+    // would keep only the low 32 bits, which are zero for small integral doubles such as 1.0 or 2.0,
+    // so such values would not contribute to the hash at all.
+    int hashCode = attr;
+    hashCode = 31 * hashCode + Arrays.hashCode(values);
+    hashCode = 31 * hashCode + Arrays.hashCode(childs);
+    return hashCode;
+  }
+
+  @Override
+  protected String getString() {
+    StringBuilder buffer = new StringBuilder();
+
+    for (Node child : childs) {
+      buffer.append(child).append(',');
+    }
+
+    return buffer.toString();
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    attr = in.readInt();
+    values = DFUtils.readDoubleArray(in);
+    childs = DFUtils.readNodeArray(in);
+  }
+
+  @Override
+  protected void writeNode(DataOutput out) throws IOException {
+    out.writeInt(attr);
+    DFUtils.writeArray(out, values);
+    DFUtils.writeArray(out, childs);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java
new file mode 100644
index 0000000..3360bb5
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Represents a Leaf node: a terminal node that predicts the same label for every instance.
+ */
+@Deprecated
+public class Leaf extends Node {
+
+  /** label returned by {@link #classify(Instance)} */
+  private double label;
+
+  /** No-arg constructor used when deserializing through {@link #readFields(DataInput)}. */
+  Leaf() { }
+
+  public Leaf(double label) {
+    this.label = label;
+  }
+
+  @Override
+  public double classify(Instance instance) {
+    return label;
+  }
+
+  @Override
+  public long maxDepth() {
+    return 1;
+  }
+
+  @Override
+  public long nbNodes() {
+    return 1;
+  }
+
+  @Override
+  protected Type getType() {
+    return Type.LEAF;
+  }
+
+  /**
+   * Two leaves are equal when their labels are identical, compared as by {@link Double#compare}.
+   * An absolute-tolerance comparison is deliberately not used. It is not transitive, and it cannot be
+   * made consistent with {@link #hashCode()}, which hashes the exact bit pattern of the label.
+   */
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (!(obj instanceof Leaf)) {
+      return false;
+    }
+
+    Leaf leaf = (Leaf) obj;
+
+    return Double.compare(label, leaf.label) == 0;
+  }
+
+  @Override
+  public int hashCode() {
+    long bits = Double.doubleToLongBits(label);
+    return (int) (bits ^ (bits >>> 32));
+  }
+
+  @Override
+  protected String getString() {
+    return "";
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    label = in.readDouble();
+  }
+
+  @Override
+  protected void writeNode(DataOutput out) throws IOException {
+    out.writeDouble(label);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java
new file mode 100644
index 0000000..73d516d
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Represents an abstract node of a decision tree
+ */
+@Deprecated
+public abstract class Node implements Writable {
+
+ protected enum Type {
+ LEAF,
+ NUMERICAL,
+ CATEGORICAL
+ }
+
+ /**
+ * predicts the label for the instance
+ *
+ * @return -1 if the label cannot be predicted
+ */
+ public abstract double classify(Instance instance);
+
+ /**
+ * @return the total number of nodes of the tree
+ */
+ public abstract long nbNodes();
+
+ /**
+ * @return the maximum depth of the tree
+ */
+ public abstract long maxDepth();
+
+ protected abstract Type getType();
+
+ public static Node read(DataInput in) throws IOException {
+ Type type = Type.values()[in.readInt()];
+ Node node;
+
+ switch (type) {
+ case LEAF:
+ node = new Leaf();
+ break;
+ case NUMERICAL:
+ node = new NumericalNode();
+ break;
+ case CATEGORICAL:
+ node = new CategoricalNode();
+ break;
+ default:
+ throw new IllegalStateException("This implementation is not currently supported");
+ }
+
+ node.readFields(in);
+
+ return node;
+ }
+
+ @Override
+ public final String toString() {
+ return getType() + ":" + getString() + ';';
+ }
+
+ protected abstract String getString();
+
+ @Override
+ public final void write(DataOutput out) throws IOException {
+ out.writeInt(getType().ordinal());
+ writeNode(out);
+ }
+
+ protected abstract void writeNode(DataOutput out) throws IOException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java
new file mode 100644
index 0000000..aa02089
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java
@@ -0,0 +1,115 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Represents a node that splits using a numerical attribute
+ */
+@Deprecated
+public class NumericalNode extends Node {
+ /** numerical attribute to split for */
+ private int attr;
+
+ /** split value */
+ private double split;
+
+ /** child node when attribute's value < split value */
+ private Node loChild;
+
+ /** child node when attribute's value >= split value */
+ private Node hiChild;
+
+ public NumericalNode() { }
+
+ public NumericalNode(int attr, double split, Node loChild, Node hiChild) {
+ this.attr = attr;
+ this.split = split;
+ this.loChild = loChild;
+ this.hiChild = hiChild;
+ }
+
+ @Override
+ public double classify(Instance instance) {
+ if (instance.get(attr) < split) {
+ return loChild.classify(instance);
+ } else {
+ return hiChild.classify(instance);
+ }
+ }
+
+ @Override
+ public long maxDepth() {
+ return 1 + Math.max(loChild.maxDepth(), hiChild.maxDepth());
+ }
+
+ @Override
+ public long nbNodes() {
+ return 1 + loChild.nbNodes() + hiChild.nbNodes();
+ }
+
+ @Override
+ protected Type getType() {
+ return Type.NUMERICAL;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof NumericalNode)) {
+ return false;
+ }
+
+ NumericalNode node = (NumericalNode) obj;
+
+ return attr == node.attr && split == node.split && loChild.equals(node.loChild) && hiChild.equals(node.hiChild);
+ }
+
+ @Override
+ public int hashCode() {
+ return attr + (int) Double.doubleToLongBits(split) + loChild.hashCode() + hiChild.hashCode();
+ }
+
+ @Override
+ protected String getString() {
+ return loChild.toString() + ',' + hiChild.toString();
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ attr = in.readInt();
+ split = in.readDouble();
+ loChild = Node.read(in);
+ hiChild = Node.read(in);
+ }
+
+ @Override
+ protected void writeNode(DataOutput out) throws IOException {
+ out.writeInt(attr);
+ out.writeDouble(split);
+ loChild.write(out);
+ hiChild.write(out);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java
new file mode 100644
index 0000000..7ef907e
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.ref;
+
+import org.apache.mahout.classifier.df.Bagging;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.node.Node;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Builds a Random Decision Forest using a given TreeBuilder to grow the trees
+ */
+@Deprecated
+public class SequentialBuilder {
+
+  private static final Logger log = LoggerFactory.getLogger(SequentialBuilder.class);
+
+  private final Random rng;
+
+  private final Bagging bagging;
+
+  /**
+   * Constructor
+   *
+   * @param rng
+   *          random-numbers generator
+   * @param treeBuilder
+   *          tree builder
+   * @param data
+   *          training data
+   */
+  public SequentialBuilder(Random rng, TreeBuilder treeBuilder, Data data) {
+    this.rng = rng;
+    bagging = new Bagging(treeBuilder, data);
+  }
+
+  /**
+   * Grows {@code nbTrees} trees one after the other, logging progress at every multiple of 10%.
+   *
+   * @param nbTrees number of trees to grow
+   * @return the forest made of the grown trees
+   */
+  public DecisionForest build(int nbTrees) {
+    List<Node> trees = new ArrayList<>();
+    int lastLoggedPercent = -1;
+
+    for (int treeId = 0; treeId < nbTrees; treeId++) {
+      trees.add(bagging.build(rng));
+      lastLoggedPercent = logProgress(treeId + 1, nbTrees, lastLoggedPercent);
+    }
+
+    return new DecisionForest(trees);
+  }
+
+  /**
+   * Logs the progress when it reaches a multiple of 10% not logged yet. Integer arithmetic avoids float
+   * rounding, e.g. 0.29f * 100 truncating to 28. Remembering the last logged value avoids repeating the
+   * same percentage when more than 100 trees are built.
+   *
+   * @return the last percentage that has been logged
+   */
+  private static int logProgress(int nbBuilt, int nbTrees, int lastLoggedPercent) {
+    int percent = (int) (nbBuilt * 100L / nbTrees);
+    if (percent % 10 == 0 && percent != lastLoggedPercent) {
+      log.info("Building {}%", percent);
+      return percent;
+    }
+    return lastLoggedPercent;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java
new file mode 100644
index 0000000..3f1cfdf
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java
@@ -0,0 +1,118 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+
+import java.util.Arrays;
+
+/**
+ * Default, not optimized, implementation of IgSplit.
+ *
+ * <p>An instance reuses an internal label-count buffer between calls, so it must not be shared
+ * between threads.</p>
+ */
+@Deprecated
+public class DefaultIgSplit extends IgSplit {
+
+  /** scratch buffer for the label counts, reused by entropy() */
+  private int[] counts;
+
+  /**
+   * For a numerical attribute, tries every value returned by {@code data.values(attr)} as split point
+   * and keeps the one with the highest information gain. For a categorical attribute, returns the
+   * information gain of splitting on all its values.
+   */
+  @Override
+  public Split computeSplit(Data data, int attr) {
+    if (data.getDataset().isNumerical(attr)) {
+      double[] values = data.values(attr);
+      double bestIg = -1;
+      double bestSplit = 0.0;
+
+      for (double value : values) {
+        double ig = numericalIg(data, attr, value);
+        if (ig > bestIg) {
+          bestIg = ig;
+          bestSplit = value;
+        }
+      }
+
+      return new Split(attr, bestIg, bestSplit);
+    } else {
+      double ig = categoricalIg(data, attr);
+
+      return new Split(attr, ig);
+    }
+  }
+
+  /**
+   * Computes the Information Gain for a CATEGORICAL attribute: H(Y) - H(Y|X).
+   */
+  double categoricalIg(Data data, int attr) {
+    double[] values = data.values(attr);
+    double hy = entropy(data); // H(Y)
+    double hyx = 0.0; // H(Y|X)
+    double invDataSize = 1.0 / data.size();
+
+    for (double value : values) {
+      Data subset = data.subset(Condition.equals(attr, value));
+      hyx += subset.size() * invDataSize * entropy(subset);
+    }
+
+    return hy - hyx;
+  }
+
+  /**
+   * Computes the Information Gain for a NUMERICAL attribute given a splitting value. The lo subset
+   * holds the instances with value {@code < split}, the hi subset those with value {@code >= split}.
+   */
+  double numericalIg(Data data, int attr, double split) {
+    double hy = entropy(data);
+    double invDataSize = 1.0 / data.size();
+
+    // LO subset
+    Data subset = data.subset(Condition.lesser(attr, split));
+    hy -= subset.size() * invDataSize * entropy(subset);
+
+    // HI subset
+    subset = data.subset(Condition.greaterOrEquals(attr, split));
+    hy -= subset.size() * invDataSize * entropy(subset);
+
+    return hy;
+  }
+
+  /**
+   * Computes the Entropy of the label distribution of {@code data}, in bits.
+   */
+  protected double entropy(Data data) {
+    double invDataSize = 1.0 / data.size();
+    int nbLabels = data.getDataset().nblabels();
+
+    // (Re)allocate when the label count differs from the cached buffer (e.g. this instance is reused
+    // with another dataset); a buffer sized for fewer labels would be too small for the per-label counts.
+    if (counts == null || counts.length != nbLabels) {
+      counts = new int[nbLabels];
+    }
+
+    Arrays.fill(counts, 0);
+    data.countLabels(counts);
+
+    double entropy = 0.0;
+    for (int label = 0; label < nbLabels; label++) {
+      int count = counts[label];
+      if (count == 0) {
+        continue; // otherwise we get a NaN
+      }
+      double p = count * invDataSize;
+      entropy += -p * Math.log(p) / LOG2;
+    }
+
+    return entropy;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java
new file mode 100644
index 0000000..aff94e1
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java
@@ -0,0 +1,35 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.classifier.df.data.Data;
+
+/**
+ * Base class of the split strategies that pick the split maximizing the Information Gain measure.
+ */
+@Deprecated
+public abstract class IgSplit {
+
+  /** ln(2); dividing a natural-log entropy by it yields the entropy in bits */
+  static final double LOG2 = Math.log(2);
+
+  /**
+   * Computes the best split for the given attribute.
+   *
+   * @param data data to split
+   * @param attr index of the attribute to split on
+   * @return the best split found for {@code attr}
+   */
+  public abstract Split computeSplit(Data data, int attr);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
new file mode 100644
index 0000000..56b1a04
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.commons.math3.stat.descriptive.rank.Percentile;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataUtils;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.TreeSet;
+
+/**
+ * <p>Optimized implementation of IgSplit.
+ * This class can be used when the criterion variable is the categorical attribute.</p>
+ *
+ * <p>This code was changed in MAHOUT-1419 to deal in sampled splits among numeric
+ * features to fix a performance problem. To generate some synthetic data that exercises
+ * the issue, try for example generating 4 features of Normal(0,1) values with a random
+ * boolean 0/1 categorical feature. In Scala:</p>
+ *
+ * {@code
+ * val r = new scala.util.Random()
+ * val pw = new java.io.PrintWriter("random.csv")
+ * (1 to 10000000).foreach(e =>
+ * pw.println(r.nextDouble() + "," +
+ * r.nextDouble() + "," +
+ * r.nextDouble() + "," +
+ * r.nextDouble() + "," +
+ * (if (r.nextBoolean()) 1 else 0))
+ * )
+ * pw.close()
+ * }
+ */
+@Deprecated
+public class OptIgSplit extends IgSplit {
+
+ private static final int MAX_NUMERIC_SPLITS = 16;
+
+ @Override
+ public Split computeSplit(Data data, int attr) {
+ if (data.getDataset().isNumerical(attr)) {
+ return numericalSplit(data, attr);
+ } else {
+ return categoricalSplit(data, attr);
+ }
+ }
+
+ /**
+ * Computes the split for a CATEGORICAL attribute
+ */
+ private static Split categoricalSplit(Data data, int attr) {
+ double[] values = data.values(attr).clone();
+
+ double[] splitPoints = chooseCategoricalSplitPoints(values);
+
+ int numLabels = data.getDataset().nblabels();
+ int[][] counts = new int[splitPoints.length][numLabels];
+ int[] countAll = new int[numLabels];
+
+ computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+ int size = data.size();
+ double hy = entropy(countAll, size); // H(Y)
+ double hyx = 0.0; // H(Y|X)
+ double invDataSize = 1.0 / size;
+
+ for (int index = 0; index < splitPoints.length; index++) {
+ size = DataUtils.sum(counts[index]);
+ hyx += size * invDataSize * entropy(counts[index], size);
+ }
+
+ double ig = hy - hyx;
+ return new Split(attr, ig);
+ }
+
+ static void computeFrequencies(Data data,
+ int attr,
+ double[] splitPoints,
+ int[][] counts,
+ int[] countAll) {
+ Dataset dataset = data.getDataset();
+
+ for (int index = 0; index < data.size(); index++) {
+ Instance instance = data.get(index);
+ int label = (int) dataset.getLabel(instance);
+ double value = instance.get(attr);
+ int split = 0;
+ while (split < splitPoints.length && value > splitPoints[split]) {
+ split++;
+ }
+ if (split < splitPoints.length) {
+ counts[split][label]++;
+ } // Otherwise it's in the last split, which we don't need to count
+ countAll[label]++;
+ }
+ }
+
+ /**
+ * Computes the best split for a NUMERICAL attribute
+ */
+ static Split numericalSplit(Data data, int attr) {
+ double[] values = data.values(attr).clone();
+ Arrays.sort(values);
+
+ double[] splitPoints = chooseNumericSplitPoints(values);
+
+ int numLabels = data.getDataset().nblabels();
+ int[][] counts = new int[splitPoints.length][numLabels];
+ int[] countAll = new int[numLabels];
+ int[] countLess = new int[numLabels];
+
+ computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+ int size = data.size();
+ double hy = entropy(countAll, size);
+ double invDataSize = 1.0 / size;
+
+ int best = -1;
+ double bestIg = -1.0;
+
+ // try each possible split value
+ for (int index = 0; index < splitPoints.length; index++) {
+ double ig = hy;
+
+ DataUtils.add(countLess, counts[index]);
+ DataUtils.dec(countAll, counts[index]);
+
+ // instance with attribute value < values[index]
+ size = DataUtils.sum(countLess);
+ ig -= size * invDataSize * entropy(countLess, size);
+ // instance with attribute value >= values[index]
+ size = DataUtils.sum(countAll);
+ ig -= size * invDataSize * entropy(countAll, size);
+
+ if (ig > bestIg) {
+ bestIg = ig;
+ best = index;
+ }
+ }
+
+ if (best == -1) {
+ throw new IllegalStateException("no best split found !");
+ }
+ return new Split(attr, bestIg, splitPoints[best]);
+ }
+
+ /**
+ * @return an array of values to split the numeric feature's values on when
+ * building candidate splits. When input size is <= MAX_NUMERIC_SPLITS + 1, it will
+ * return the averages between success values as split points. When larger, it will
+ * return MAX_NUMERIC_SPLITS approximate percentiles through the data.
+ */
+ private static double[] chooseNumericSplitPoints(double[] values) {
+ if (values.length <= 1) {
+ return values;
+ }
+ if (values.length <= MAX_NUMERIC_SPLITS + 1) {
+ double[] splitPoints = new double[values.length - 1];
+ for (int i = 1; i < values.length; i++) {
+ splitPoints[i-1] = (values[i] + values[i-1]) / 2.0;
+ }
+ return splitPoints;
+ }
+ Percentile distribution = new Percentile();
+ distribution.setData(values);
+ double[] percentiles = new double[MAX_NUMERIC_SPLITS];
+ for (int i = 0 ; i < percentiles.length; i++) {
+ double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0));
+ percentiles[i] = distribution.evaluate(p);
+ }
+ return percentiles;
+ }
+
+ private static double[] chooseCategoricalSplitPoints(double[] values) {
+ // There is no great reason to believe that categorical value order matters,
+ // but the original code worked this way, and it's not terrible in the absence
+ // of more sophisticated analysis
+ Collection<Double> uniqueOrderedCategories = new TreeSet<>();
+ for (double v : values) {
+ uniqueOrderedCategories.add(v);
+ }
+ double[] uniqueValues = new double[uniqueOrderedCategories.size()];
+ Iterator<Double> it = uniqueOrderedCategories.iterator();
+ for (int i = 0; i < uniqueValues.length; i++) {
+ uniqueValues[i] = it.next();
+ }
+ return uniqueValues;
+ }
+
+ /**
+ * Computes the Entropy
+ *
+ * @param counts counts[i] = numInstances with label i
+ * @param dataSize numInstances
+ */
+ private static double entropy(int[] counts, int dataSize) {
+ if (dataSize == 0) {
+ return 0.0;
+ }
+
+ double entropy = 0.0;
+
+ for (int count : counts) {
+ if (count > 0) {
+ double p = count / (double) dataSize;
+ entropy -= p * Math.log(p);
+ }
+ }
+
+ return entropy / LOG2;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
new file mode 100644
index 0000000..38695a3
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
@@ -0,0 +1,177 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Comparator;
+
+/**
+ * Regression problem implementation of IgSplit. This class can be used when the criterion variable is the numerical
+ * attribute.
+ * <p>
+ * The reported gain is the reduction of the labels' sum of squared deviations from their mean (accumulated with
+ * Welford's online update) obtained by partitioning the data on the candidate attribute.
+ */
+@Deprecated
+public class RegressionSplit extends IgSplit {
+
+  /**
+   * Comparator for Instance sort: orders instances by ascending value of a single attribute.
+   */
+  private static class InstanceComparator implements Comparator<Instance>, Serializable {
+    private final int attr;
+
+    InstanceComparator(int attr) {
+      this.attr = attr;
+    }
+
+    @Override
+    public int compare(Instance arg0, Instance arg1) {
+      return Double.compare(arg0.get(attr), arg1.get(attr));
+    }
+  }
+
+  @Override
+  public Split computeSplit(Data data, int attr) {
+    if (data.getDataset().isNumerical(attr)) {
+      return numericalSplit(data, attr);
+    } else {
+      return categoricalSplit(data, attr);
+    }
+  }
+
+  /**
+   * Adds {@code xk} to {@code ra} and returns the updated sum of squared deviations (Welford's online update).
+   *
+   * @param ra running average of the values accumulated so far; {@code xk} is added to it
+   * @param sk sum of squared deviations of the values already in {@code ra}
+   * @param xk the new value
+   * @return the sum of squared deviations including {@code xk} (0 for the first value)
+   */
+  private static double addDatum(FullRunningAverage ra, double sk, double xk) {
+    if (ra.getCount() == 0) {
+      ra.addDatum(xk);
+      return 0.0;
+    }
+    double mk = ra.getAverage();
+    ra.addDatum(xk);
+    return sk + (xk - mk) * (xk - ra.getAverage());
+  }
+
+  /**
+   * Computes the split for a CATEGORICAL attribute: the gain is the labels' total sum of squared deviations minus
+   * the sums computed separately within each attribute value.
+   */
+  private static Split categoricalSplit(Data data, int attr) {
+    int nbValues = data.getDataset().nbValues(attr);
+    FullRunningAverage[] ra = new FullRunningAverage[nbValues];
+    double[] sk = new double[nbValues];
+    for (int i = 0; i < ra.length; i++) {
+      ra[i] = new FullRunningAverage();
+    }
+    FullRunningAverage totalRa = new FullRunningAverage();
+    double totalSk = 0.0;
+
+    for (int i = 0; i < data.size(); i++) {
+      Instance instance = data.get(i);
+      // categorical values are used directly as an index into the per-value accumulators
+      int value = (int) instance.get(attr);
+      double xk = data.getDataset().getLabel(instance);
+      // variance within this attribute value
+      sk[value] = addDatum(ra[value], sk[value], xk);
+      // total variance
+      totalSk = addDatum(totalRa, totalSk, xk);
+    }
+
+    // computes the variance gain
+    double ig = totalSk;
+    for (double aSk : sk) {
+      ig -= aSk;
+    }
+
+    return new Split(attr, ig);
+  }
+
+  /**
+   * Computes the best split for a NUMERICAL attribute.
+   * <p>
+   * Instances are sorted by attribute value and moved one at a time from the right-hand partition (index 1) to the
+   * left-hand partition (index 0). Each boundary between two distinct consecutive values is a candidate threshold.
+   * If no candidate exists (all instances share the same value), the reported gain is 0 and the split value is NaN.
+   */
+  private static Split numericalSplit(Data data, int attr) {
+    FullRunningAverage[] ra = new FullRunningAverage[2];
+    for (int i = 0; i < ra.length; i++) {
+      ra[i] = new FullRunningAverage();
+    }
+
+    // Instance sort
+    Instance[] instances = new Instance[data.size()];
+    for (int i = 0; i < data.size(); i++) {
+      instances[i] = data.get(i);
+    }
+    Arrays.sort(instances, new InstanceComparator(attr));
+
+    // initially every instance belongs to the right-hand partition
+    double[] sk = new double[2];
+    for (Instance instance : instances) {
+      sk[1] = addDatum(ra[1], sk[1], data.getDataset().getLabel(instance));
+    }
+    double totalSk = sk[1];
+
+    // find the best split point
+    double split = Double.NaN;
+    double preSplit = Double.NaN;
+    double bestVal = Double.MAX_VALUE;
+    // without any candidate threshold nothing is gained: default to "no variance removed"
+    double bestSk = totalSk;
+
+    for (Instance instance : instances) {
+      double value = instance.get(attr);
+      double xk = data.getDataset().getLabel(instance);
+
+      // a threshold can only sit between two distinct values; with preSplit == NaN the test is false
+      if (value > preSplit) {
+        double curVal = sk[0] / ra[0].getCount() + sk[1] / ra[1].getCount();
+        if (curVal < bestVal) {
+          bestVal = curVal;
+          bestSk = sk[0] + sk[1];
+          split = (value + preSplit) / 2.0;
+        }
+      }
+
+      // move the instance from the right-hand to the left-hand partition
+      sk[0] = addDatum(ra[0], sk[0], xk);
+
+      double mk = ra[1].getAverage();
+      ra[1].removeDatum(xk);
+      sk[1] -= (xk - mk) * (xk - ra[1].getAverage());
+
+      preSplit = value;
+    }
+
+    // computes the variance gain
+    double ig = totalSk - bestSk;
+
+    return new Split(attr, ig, split);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
new file mode 100644
index 0000000..2a6a322
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import java.util.Locale;
+
+/**
+ * Contains enough information to identify each split
+ */
+@Deprecated
+public final class Split {
+
+ private final int attr;
+ private final double ig;
+ private final double split;
+
+ public Split(int attr, double ig, double split) {
+ this.attr = attr;
+ this.ig = ig;
+ this.split = split;
+ }
+
+ public Split(int attr, double ig) {
+ this(attr, ig, Double.NaN);
+ }
+
+ /**
+ * @return attribute to split for
+ */
+ public int getAttr() {
+ return attr;
+ }
+
+ /**
+ * @return Information Gain of the split
+ */
+ public double getIg() {
+ return ig;
+ }
+
+ /**
+ * @return split value for NUMERICAL attributes
+ */
+ public double getSplit() {
+ return split;
+ }
+
+ @Override
+ public String toString() {
+ return String.format(Locale.ENGLISH, "attr: %d, ig: %f, split: %f", attr, ig, split);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
new file mode 100644
index 0000000..f29faed
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
@@ -0,0 +1,166 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.DescriptorUtils;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Generates a file descriptor for a given dataset
+ */
+public final class Describe implements Tool {
+
+  private static final Logger log = LoggerFactory.getLogger(Describe.class);
+
+  /** Configuration supplied through {@link #setConf(Configuration)} (typically by {@link ToolRunner}). */
+  private Configuration conf;
+
+  private Describe() {}
+
+  /**
+   * Runs the tool through {@link ToolRunner}.
+   * <p>
+   * NOTE(review): because this method returns {@code int} rather than {@code void}, the {@code java} launcher
+   * will not accept it as a program entry point; it can only be invoked directly or reflectively. The signature
+   * is kept for compatibility with existing callers.
+   *
+   * @return the exit code produced by {@link #run(String[])}
+   */
+  public static int main(String[] args) throws Exception {
+    return ToolRunner.run(new Describe(), args);
+  }
+
+  /**
+   * Parses the command line and writes the dataset description.
+   *
+   * @return 0 on success, -1 when help was requested or the command line could not be parsed
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option pathOpt = obuilder.withLongName("path").withShortName("p").withRequired(true).withArgument(
+        abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+    Option descriptorOpt = obuilder.withLongName("descriptor").withShortName("d").withRequired(true)
+        .withArgument(abuilder.withName("descriptor").withMinimum(1).create()).withDescription(
+        "data descriptor").create();
+
+    Option descPathOpt = obuilder.withLongName("file").withShortName("f").withRequired(true).withArgument(
+        abuilder.withName("file").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Path to generated descriptor file").create();
+
+    Option regOpt = obuilder.withLongName("regression").withDescription("Regression Problem").withShortName("r")
+        .create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+
+    Group group = gbuilder.withName("Options").withOption(pathOpt).withOption(descPathOpt).withOption(
+        descriptorOpt).withOption(regOpt).withOption(helpOpt).create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return -1;
+      }
+
+      String dataPath = cmdLine.getValue(pathOpt).toString();
+      String descPath = cmdLine.getValue(descPathOpt).toString();
+      List<String> descriptor = convert(cmdLine.getValues(descriptorOpt));
+      boolean regression = cmdLine.hasOption(regOpt);
+
+      log.debug("Data path : {}", dataPath);
+      log.debug("Descriptor path : {}", descPath);
+      log.debug("Descriptor : {}", descriptor);
+      log.debug("Regression : {}", regression);
+
+      runTool(dataPath, descriptor, descPath, regression);
+    } catch (OptionException e) {
+      log.warn(e.toString());
+      CommandLineUtil.printHelp(group);
+      // an invalid command line is a failure, consistent with the help path above
+      return -1;
+    }
+    return 0;
+  }
+
+  /**
+   * Builds the descriptor, loads the dataset description from {@code dataPath} and stores it as JSON at
+   * {@code filePath}.
+   *
+   * @throws IllegalStateException if {@code filePath} already exists
+   */
+  private void runTool(String dataPath, Iterable<String> description, String filePath, boolean regression)
+    throws DescriptorException, IOException {
+    log.info("Generating the descriptor...");
+    String descriptor = DescriptorUtils.generateDescriptor(description);
+
+    // fail before doing any work on the data if the output would be overwritten
+    Path fPath = validateOutput(filePath);
+
+    log.info("generating the dataset...");
+    Dataset dataset = generateDataset(descriptor, dataPath, regression);
+
+    log.info("storing the dataset description");
+    String json = dataset.toJSON();
+    DFUtils.storeString(conf, fPath, json);
+  }
+
+  /**
+   * Loads the dataset description for the data located at {@code dataPath}.
+   */
+  private Dataset generateDataset(String descriptor, String dataPath, boolean regression) throws IOException,
+    DescriptorException {
+    Path path = new Path(dataPath);
+    FileSystem fs = path.getFileSystem(conf);
+
+    return DataLoader.generateDataset(descriptor, regression, fs, path);
+  }
+
+  /**
+   * @return {@code filePath} as a {@link Path}
+   * @throws IllegalStateException if the file already exists
+   */
+  private Path validateOutput(String filePath) throws IOException {
+    Path path = new Path(filePath);
+    FileSystem fs = path.getFileSystem(conf);
+    if (fs.exists(path)) {
+      throw new IllegalStateException("Descriptor's file already exists");
+    }
+
+    return path;
+  }
+
+  /**
+   * Converts raw command-line values to their string form.
+   */
+  private static List<String> convert(Collection<?> values) {
+    List<String> list = new ArrayList<>(values.size());
+    for (Object value : values) {
+      list.add(value.toString());
+    }
+    return list;
+  }
+
+  @Override
+  public void setConf(Configuration entries) {
+    this.conf = entries;
+  }
+
+  @Override
+  public Configuration getConf() {
+    return conf;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
new file mode 100644
index 0000000..b421c4e
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
@@ -0,0 +1,158 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This tool is to visualize the Decision Forest
+ */
+@Deprecated
+public final class ForestVisualizer {
+
+  private static final Logger log = LoggerFactory.getLogger(ForestVisualizer.class);
+
+  private ForestVisualizer() {
+  }
+
+  /**
+   * Decision Forest to String: renders each tree with {@link TreeVisualizer}, prefixed with {@code Tree[n]:}
+   * (1-based) and followed by a newline.
+   *
+   * @param forest the forest to render
+   * @param dataset dataset passed through to {@link TreeVisualizer}
+   * @param attrNames attribute names; may be {@code null}
+   * @throws IllegalStateException if the forest's trees cannot be read reflectively
+   */
+  public static String toString(DecisionForest forest, Dataset dataset, String[] attrNames) {
+    List<Node> trees = readTrees(forest);
+
+    int cnt = 1;
+    StringBuilder buff = new StringBuilder();
+    for (Node tree : trees) {
+      buff.append("Tree[").append(cnt).append("]:");
+      buff.append(TreeVisualizer.toString(tree, dataset, attrNames));
+      buff.append('\n');
+      cnt++;
+    }
+    return buff.toString();
+  }
+
+  /**
+   * Reads the trees of {@code forest} through its {@code getTrees()} method, invoked reflectively with
+   * {@code setAccessible(true)} (presumably because it is not public).
+   * The method is looked up on {@link DecisionForest} itself so that subclasses are handled too.
+   */
+  @SuppressWarnings("unchecked")
+  private static List<Node> readTrees(DecisionForest forest) {
+    try {
+      Method getTrees = DecisionForest.class.getDeclaredMethod("getTrees");
+      getTrees.setAccessible(true);
+      return (List<Node>) getTrees.invoke(forest);
+    } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+  /**
+   * Decision Forest to String
+   * @param forestPath
+   *          path to the Decision Forest
+   * @param datasetPath
+   *          dataset path
+   * @param attrNames
+   *          attribute names
+   */
+  public static String toString(String forestPath, String datasetPath, String[] attrNames) throws IOException {
+    Configuration conf = new Configuration();
+    DecisionForest forest = DecisionForest.load(conf, new Path(forestPath));
+    Dataset dataset = Dataset.load(conf, new Path(datasetPath));
+    return toString(forest, dataset, attrNames);
+  }
+
+  /**
+   * Print Decision Forest
+   * @param forestPath
+   *          path to the Decision Forest
+   * @param datasetPath
+   *          dataset path
+   * @param attrNames
+   *          attribute names
+   */
+  public static void print(String forestPath, String datasetPath, String[] attrNames) throws IOException {
+    System.out.println(toString(forestPath, datasetPath, attrNames));
+  }
+
+  /**
+   * Command-line entry point: prints the forest given by {@code --model} using the dataset given by
+   * {@code --dataset}, optionally with attribute names from {@code --names}.
+   */
+  public static void main(String[] args) {
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
+        .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
+        .withDescription("Dataset path").create();
+
+    Option modelOpt = obuilder.withLongName("model").withShortName("m").withRequired(true)
+        .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
+        .withDescription("Path to the Decision Forest").create();
+
+    Option attrNamesOpt = obuilder.withLongName("names").withShortName("n").withRequired(false)
+        .withArgument(abuilder.withName("names").withMinimum(1).create())
+        .withDescription("Optional, Attribute names").create();
+
+    Option helpOpt = obuilder.withLongName("help").withShortName("h")
+        .withDescription("Print out help").create();
+
+    Group group = gbuilder.withName("Options").withOption(datasetOpt).withOption(modelOpt)
+        .withOption(attrNamesOpt).withOption(helpOpt).create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      // look the option up by object: the option's triggers are "--help"/"-h", so the bare "help" never matches
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+
+      String datasetName = cmdLine.getValue(datasetOpt).toString();
+      String modelName = cmdLine.getValue(modelOpt).toString();
+      String[] attrNames = null;
+      if (cmdLine.hasOption(attrNamesOpt)) {
+        Collection<?> names = cmdLine.getValues(attrNamesOpt);
+        if (!names.isEmpty()) {
+          attrNames = new String[names.size()];
+          int i = 0;
+          for (Object name : names) {
+            attrNames[i++] = name.toString();
+          }
+        }
+      }
+
+      print(modelName, datasetName, attrNames);
+    } catch (Exception e) {
+      log.error("Exception", e);
+      CommandLineUtil.printHelp(group);
+    }
+  }
+}