You are viewing a plain-text version of this content. (The link to the canonical version was lost in the plain-text conversion.)
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/09 21:32:21 UTC
[09/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75] Support Sparse Vector Format as the input of RandomForest
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index 03db65c..5a831df 100644
--- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -19,30 +19,43 @@
package hivemall.smile.classification;
import hivemall.UDTFWithOptions;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.MatrixUtils;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.builders.MatrixBuilder;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.matrix.ints.DoKIntMatrix;
+import hivemall.math.matrix.ints.IntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
import hivemall.smile.classification.DecisionTree.SplitRule;
import hivemall.smile.data.Attribute;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
-import hivemall.smile.vm.StackMachine;
import hivemall.utils.codec.Base91;
-import hivemall.utils.codec.DeflateCodec;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
-import hivemall.utils.io.IOUtils;
+import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.RandomUtils;
-import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.BitSet;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
@@ -52,7 +65,9 @@ import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
@@ -67,9 +82,9 @@ import org.apache.hadoop.mapred.Reporter;
@Description(
name = "train_randomforest_classifier",
- value = "_FUNC_(double[] features, int label [, string options]) - "
+ value = "_FUNC_(array<double|string> features, int label [, const array<double> classWeights, const string options]) - "
+ "Returns a relation consists of "
- + "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>")
+ + "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests, double weight>")
public final class RandomForestClassifierUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(RandomForestClassifierUDTF.class);
@@ -77,8 +92,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
private PrimitiveObjectInspector featureElemOI;
private PrimitiveObjectInspector labelOI;
- private List<double[]> featuresList;
+ private boolean denseInput;
+ private MatrixBuilder matrixBuilder;
private IntArrayList labels;
+
/**
* The number of trees for each task
*/
@@ -99,8 +116,12 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
private int _minSamplesLeaf;
private long _seed;
private Attribute[] _attributes;
- private ModelType _outputType;
private SplitRule _splitRule;
+ private boolean _stratifiedSampling;
+ private double _subsample;
+
+ @Nullable
+ private double[] _classWeight;
@Nullable
private Reporter _progressReporter;
@@ -126,11 +147,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
- opts.addOption("output", "output_type", true,
- "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]");
opts.addOption("rule", "split_rule", true, "Split algorithm [default: GINI, ENTROPY]");
- opts.addOption("disable_compression", false,
- "Whether to disable compression of the output script [default: false]");
+ opts.addOption("stratified", "stratified_sampling", false,
+ "Enable Stratified sampling for unbalanced data");
+ opts.addOption("subsample", true, "Sampling rate in range (0.0,1.0]");
return opts;
}
@@ -141,9 +161,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
float numVars = -1.f;
Attribute[] attrs = null;
long seed = -1L;
- String output = "serialization";
SplitRule splitRule = SplitRule.GINI;
- boolean compress = true;
+ double[] classWeight = null;
+ boolean stratifiedSampling = false;
+ double subsample = 1.0d;
CommandLine cl = null;
if (argOIs.length >= 3) {
@@ -162,10 +183,26 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
minSamplesLeaf);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
- output = cl.getOptionValue("output", output);
splitRule = SmileExtUtils.resolveSplitRule(cl.getOptionValue("split_rule", "GINI"));
- if (cl.hasOption("disable_compression")) {
- compress = false;
+ stratifiedSampling = cl.hasOption("stratified_sampling");
+ subsample = Primitives.parseDouble(cl.getOptionValue("subsample"), 1.0d);
+ Preconditions.checkArgument(subsample > 0.d && subsample <= 1.0d,
+ UDFArgumentException.class, "Invalid -subsample value: " + subsample);
+
+ if (argOIs.length >= 4) {
+ classWeight = HiveUtils.getConstDoubleArray(argOIs[3]);
+ if (classWeight != null) {
+ for (int i = 0; i < classWeight.length; i++) {
+ double v = classWeight[i];
+ if (Double.isNaN(v)) {
+ classWeight[i] = 1.0d;
+ } else if (v <= 0.d) {
+ throw new UDFArgumentTypeException(3,
+ "each classWeight must be greather than 0: "
+ + Arrays.toString(classWeight));
+ }
+ }
+ }
}
}
@@ -177,43 +214,60 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
this._minSamplesLeaf = minSamplesLeaf;
this._seed = seed;
this._attributes = attrs;
- this._outputType = ModelType.resolve(output, compress);
this._splitRule = splitRule;
+ this._stratifiedSampling = stratifiedSampling;
+ this._subsample = subsample;
+ this._classWeight = classWeight;
return cl;
}
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- if (argOIs.length != 2 && argOIs.length != 3) {
+ if (argOIs.length < 2 || argOIs.length > 4) {
throw new UDFArgumentException(
- getClass().getSimpleName()
- + " takes 2 or 3 arguments: double[] features, int label [, const string options]: "
+ "_FUNC_ takes 2 ~ 4 arguments: array<double|string> features, int label [, const string options, const array<double> classWeight]: "
+ argOIs.length);
}
ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureListOI = listOI;
- this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ if (HiveUtils.isNumberOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ this.denseInput = true;
+ this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
+ } else if (HiveUtils.isStringOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asStringOI(elemOI);
+ this.denseInput = false;
+ this.matrixBuilder = new CSRMatrixBuilder(8192);
+ } else {
+ throw new UDFArgumentException(
+ "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
+ }
this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
processOptions(argOIs);
- this.featuresList = new ArrayList<double[]>(1024);
this.labels = new IntArrayList(1024);
- ArrayList<String> fieldNames = new ArrayList<String>(6);
- ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
+ final ArrayList<String> fieldNames = new ArrayList<String>(6);
+ final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
fieldNames.add("model_id");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
- fieldNames.add("model_type");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
- fieldNames.add("pred_model");
+ fieldNames.add("model_weight");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("model");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
fieldNames.add("var_importance");
- fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ if (denseInput) {
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ } else {
+ fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.writableIntObjectInspector,
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ }
fieldNames.add("oob_errors");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("oob_tests");
@@ -227,13 +281,36 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
if (args[0] == null) {
throw new HiveException("array<double> features was null");
}
- double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
+ parseFeatures(args[0], matrixBuilder);
int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
-
- featuresList.add(features);
labels.add(label);
}
+ private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
+ if (denseInput) {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
+ builder.nextColumn(i, v);
+ }
+ } else {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ String fv = o.toString();
+ builder.nextColumn(fv);
+ }
+ }
+ builder.nextRow();
+ }
+
@Override
public void close() throws HiveException {
this._progressReporter = getReporter();
@@ -242,10 +319,9 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
"finishedTreeBuildTasks");
reportProgress(_progressReporter);
- int numExamples = featuresList.size();
- if (numExamples > 0) {
- double[][] x = featuresList.toArray(new double[numExamples][]);
- this.featuresList = null;
+ if (!labels.isEmpty()) {
+ Matrix x = matrixBuilder.buildMatrix();
+ this.matrixBuilder = null;
int[] y = labels.toArray();
this.labels = null;
@@ -277,15 +353,16 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
* @param numVars The number of variables to pick up in each node.
* @param seed The seed number for Random Forest
*/
- private void train(@Nonnull final double[][] x, @Nonnull final int[] y) throws HiveException {
- if (x.length != y.length) {
+ private void train(@Nonnull Matrix x, @Nonnull final int[] y) throws HiveException {
+ final int numExamples = x.numRows();
+ if (numExamples != y.length) {
throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
- x.length, y.length));
+ numExamples, y.length));
}
checkOptions();
- // Shuffle training samples
- SmileExtUtils.shuffle(x, y, _seed);
+ // Shuffle training samples
+ x = SmileExtUtils.shuffle(x, y, _seed);
int[] labels = SmileExtUtils.classLables(y);
Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
@@ -297,9 +374,8 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
+ _maxLeafNodes + ", splitRule: " + _splitRule + ", seed: " + _seed);
}
- final int numExamples = x.length;
- int[][] prediction = new int[numExamples][labels.length]; // placeholder for out-of-bag prediction
- int[][] order = SmileExtUtils.sort(attributes, x);
+ IntMatrix prediction = new DoKIntMatrix(numExamples, labels.length); // placeholder for out-of-bag prediction
+ ColumnMajorIntMatrix order = SmileExtUtils.sort(attributes, x);
AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
List<TrainingTask> tasks = new ArrayList<TrainingTask>();
for (int i = 0; i < _numTrees; i++) {
@@ -321,17 +397,19 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
/**
* Synchronized because {@link #forward(Object)} should be called from a single thread.
+ *
+ * @param accuracy
*/
synchronized void forward(final int taskId, @Nonnull final Text model,
- @Nonnull final double[] importance, final int[] y, final int[][] prediction,
- final boolean lastTask) throws HiveException {
+ @Nonnull final Vector importance, @Nonnegative final double accuracy, final int[] y,
+ @Nonnull final IntMatrix prediction, final boolean lastTask) throws HiveException {
int oobErrors = 0;
int oobTests = 0;
if (lastTask) {
// out-of-bag error estimate
for (int i = 0; i < y.length; i++) {
- final int pred = smile.math.Math.whichMax(prediction[i]);
- if (prediction[i][pred] > 0) {
+ final int pred = MatrixUtils.whichMax(prediction, i);
+ if (pred != -1 && prediction.get(i, pred) > 0) {
oobTests++;
if (pred != y[i]) {
oobErrors++;
@@ -340,12 +418,23 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
}
}
- String modelId = RandomUtils.getUUID();
final Object[] forwardObjs = new Object[6];
+ String modelId = RandomUtils.getUUID();
forwardObjs[0] = new Text(modelId);
- forwardObjs[1] = new IntWritable(_outputType.getId());
+ forwardObjs[1] = new DoubleWritable(accuracy);
forwardObjs[2] = model;
- forwardObjs[3] = WritableUtils.toWritableList(importance);
+ if (denseInput) {
+ forwardObjs[3] = WritableUtils.toWritableList(importance.toArray());
+ } else {
+ final Map<IntWritable, DoubleWritable> map = new HashMap<IntWritable, DoubleWritable>(
+ importance.size());
+ importance.each(new VectorProcedure() {
+ public void apply(int i, double value) {
+ map.put(new IntWritable(i), new DoubleWritable(value));
+ }
+ });
+ forwardObjs[3] = map;
+ }
forwardObjs[4] = new IntWritable(oobErrors);
forwardObjs[5] = new IntWritable(oobTests);
forward(forwardObjs);
@@ -363,20 +452,23 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
/**
* Attribute properties.
*/
+ @Nonnull
private final Attribute[] _attributes;
/**
* Training instances.
*/
- private final double[][] _x;
+ @Nonnull
+ private final Matrix _x;
/**
* Training sample labels.
*/
+ @Nonnull
private final int[] _y;
/**
- * The index of training values in ascending order. Note that only numeric attributes will
- * be sorted.
+ * The index of training values in ascending order. Note that only numeric attributes will be sorted.
*/
- private final int[][] _order;
+ @Nonnull
+ private final ColumnMajorIntMatrix _order;
/**
* The number of variables to pick up in each node.
*/
@@ -384,16 +476,21 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
/**
* The out-of-bag predictions.
*/
- private final int[][] _prediction;
+ @Nonnull
+ @GuardedBy("_udtf")
+ private final IntMatrix _prediction;
+ @Nonnull
private final RandomForestClassifierUDTF _udtf;
private final int _taskId;
private final long _seed;
+ @Nonnull
private final AtomicInteger _remainingTasks;
- TrainingTask(RandomForestClassifierUDTF udtf, int taskId, Attribute[] attributes,
- double[][] x, int[] y, int numVars, int[][] order, int[][] prediction, long seed,
- AtomicInteger remainingTasks) {
+ TrainingTask(@Nonnull RandomForestClassifierUDTF udtf, int taskId,
+ @Nonnull Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y, int numVars,
+ @Nonnull ColumnMajorIntMatrix order, @Nonnull IntMatrix prediction, long seed,
+ @Nonnull AtomicInteger remainingTasks) {
this._udtf = udtf;
this._taskId = taskId;
this._attributes = attributes;
@@ -408,98 +505,107 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
@Override
public Integer call() throws HiveException {
- long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() : new smile.math.Random(
- _seed).nextLong();
- final smile.math.Random rnd1 = new smile.math.Random(s);
- final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());
- final int N = _x.length;
+ long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
+ : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
+ final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
+ final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
+ final int N = _x.numRows();
// Training samples draw with replacement.
- final int[] bags = new int[N];
final BitSet sampled = new BitSet(N);
- for (int i = 0; i < N; i++) {
- int index = rnd1.nextInt(N);
- bags[i] = index;
- sampled.set(index);
- }
+ final int[] bags = sampling(sampled, N, rnd1);
DecisionTree tree = new DecisionTree(_attributes, _x, _y, _numVars, _udtf._maxDepth,
_udtf._maxLeafNodes, _udtf._minSamplesSplit, _udtf._minSamplesLeaf, bags, _order,
_udtf._splitRule, rnd2);
// out-of-bag prediction
+ int oob = 0;
+ int correct = 0;
+ final Vector xProbe = _x.rowVector();
for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) {
- final int p = tree.predict(_x[i]);
- synchronized (_prediction[i]) {
- _prediction[i][p]++;
+ oob++;
+ _x.getRow(i, xProbe);
+ final int p = tree.predict(xProbe);
+ if (p == _y[i]) {
+ correct++;
+ }
+ synchronized (_udtf) {
+ _prediction.incr(i, p);
}
}
- Text model = getModel(tree, _udtf._outputType);
- double[] importance = tree.importance();
+ Text model = getModel(tree);
+ Vector importance = tree.importance();
+ double accuracy = (oob == 0) ? 1.0d : (double) correct / oob;
int remain = _remainingTasks.decrementAndGet();
boolean lastTask = (remain == 0);
- _udtf.forward(_taskId + 1, model, importance, _y, _prediction, lastTask);
+ _udtf.forward(_taskId + 1, model, importance, accuracy, _y, _prediction, lastTask);
return Integer.valueOf(remain);
}
- private static Text getModel(@Nonnull final DecisionTree tree,
- @Nonnull final ModelType outputType) throws HiveException {
- final Text model;
- switch (outputType) {
- case serialization:
- case serialization_compressed: {
- byte[] b = tree.predictSerCodegen(outputType.isCompressed());
- b = Base91.encode(b);
- model = new Text(b);
- break;
- }
- case opscode:
- case opscode_compressed: {
- String s = tree.predictOpCodegen(StackMachine.SEP);
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- model = new Text(b);
- } else {
- model = new Text(s);
+ @Nonnull
+ private int[] sampling(@Nonnull final BitSet sampled, final int N, @Nonnull PRNG rnd) {
+ return _udtf._stratifiedSampling ? stratifiedSampling(sampled, N, _udtf._subsample, rnd)
+ : uniformSampling(sampled, N, _udtf._subsample, rnd);
+ }
+
+ @Nonnull
+ private static int[] uniformSampling(@Nonnull final BitSet sampled, final int N,
+ final double subsample, final PRNG rnd) {
+ final int size = (int) Math.round(N * subsample);
+ final int[] bags = new int[N];
+ for (int i = 0; i < size; i++) {
+ int index = rnd.nextInt(N);
+ bags[i] = index;
+ sampled.set(index);
+ }
+ return bags;
+ }
+
+ /**
+ * Stratified sampling for unbalanced data.
+ *
+ * @link https://en.wikipedia.org/wiki/Stratified_sampling
+ */
+ @Nonnull
+ private int[] stratifiedSampling(@Nonnull final BitSet sampled, final int N,
+ final double subsample, final PRNG rnd) {
+ final IntArrayList bagsList = new IntArrayList(N);
+ final int k = smile.math.Math.max(_y) + 1;
+ final IntArrayList cj = new IntArrayList(N / k);
+ for (int l = 0; l < k; l++) {
+ int nj = 0;
+ for (int i = 0; i < N; i++) {
+ if (_y[i] == l) {
+ cj.add(i);
+ nj++;
}
- break;
}
- case javascript:
- case javascript_compressed: {
- String s = tree.predictJsCodegen();
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- model = new Text(b);
- } else {
- model = new Text(s);
- }
- break;
+ if (subsample != 1.0d) {
+ nj = (int) Math.round(nj * subsample);
+ }
+ final int size = (_udtf._classWeight == null) ? nj : (int) Math.round(nj
+ * _udtf._classWeight[l]);
+ for (int j = 0; j < size; j++) {
+ int xi = rnd.nextInt(nj);
+ int index = cj.get(xi);
+ bagsList.add(index);
+ sampled.set(index);
}
- default:
- throw new HiveException("Unexpected output type: " + outputType
- + ". Use javascript for the output instead");
+ cj.clear();
}
- return model;
+ int[] bags = bagsList.toArray(true);
+ SmileExtUtils.shuffle(bags, rnd);
+ return bags;
+ }
+
+ @Nonnull
+ private static Text getModel(@Nonnull final DecisionTree tree) throws HiveException {
+ byte[] b = tree.predictSerCodegen(true);
+ b = Base91.encode(b);
+ return new Text(b);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/data/Attribute.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/data/Attribute.java b/core/src/main/java/hivemall/smile/data/Attribute.java
index be6651a..6569726 100644
--- a/core/src/main/java/hivemall/smile/data/Attribute.java
+++ b/core/src/main/java/hivemall/smile/data/Attribute.java
@@ -18,6 +18,9 @@
*/
package hivemall.smile.data;
+import hivemall.annotations.Immutable;
+import hivemall.annotations.Mutable;
+
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
@@ -25,11 +28,9 @@ import java.io.ObjectOutput;
public abstract class Attribute {
public final AttributeType type;
- public final int attrIndex;
- Attribute(AttributeType type, int attrIndex) {
+ Attribute(AttributeType type) {
this.type = type;
- this.attrIndex = attrIndex;
}
public void setSize(int size) {
@@ -44,24 +45,23 @@ public abstract class Attribute {
}
public void writeTo(ObjectOutput out) throws IOException {
- out.writeInt(type.getTypeId());
- out.writeInt(attrIndex);
+ out.writeByte(type.getTypeId());
}
public enum AttributeType {
- NUMERIC(1), NOMINAL(2);
+ NUMERIC((byte) 1), NOMINAL((byte) 2);
- private final int id;
+ private final byte id;
- private AttributeType(int id) {
+ private AttributeType(byte id) {
this.id = id;
}
- public int getTypeId() {
+ public byte getTypeId() {
return id;
}
- public static AttributeType resolve(int id) {
+ public static AttributeType resolve(byte id) {
final AttributeType type;
switch (id) {
case 1:
@@ -78,25 +78,27 @@ public abstract class Attribute {
}
+ @Immutable
public static final class NumericAttribute extends Attribute {
- public NumericAttribute(int attrIndex) {
- super(AttributeType.NUMERIC, attrIndex);
+ public NumericAttribute() {
+ super(AttributeType.NUMERIC);
}
@Override
public String toString() {
- return "NumericAttribute [type=" + type + ", attrIndex=" + attrIndex + "]";
+ return "NumericAttribute [type=" + type + "]";
}
}
+ @Mutable
public static final class NominalAttribute extends Attribute {
private int size;
- public NominalAttribute(int attrIndex) {
- super(AttributeType.NOMINAL, attrIndex);
+ public NominalAttribute() {
+ super(AttributeType.NOMINAL);
this.size = -1;
}
@@ -118,25 +120,23 @@ public abstract class Attribute {
@Override
public String toString() {
- return "NominalAttribute [size=" + size + ", type=" + type + ", attrIndex=" + attrIndex
- + "]";
+ return "NominalAttribute [size=" + size + ", type=" + type + "]";
}
}
public static Attribute readFrom(ObjectInput in) throws IOException {
- int typeId = in.readInt();
- int attrIndex = in.readInt();
-
final Attribute attr;
+
+ byte typeId = in.readByte();
final AttributeType type = AttributeType.resolve(typeId);
switch (type) {
case NUMERIC: {
- attr = new NumericAttribute(attrIndex);
+ attr = new NumericAttribute();
break;
}
case NOMINAL: {
- attr = new NominalAttribute(attrIndex);
+ attr = new NominalAttribute();
int size = in.readInt();
attr.setSize(size);
break;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
index ebb58c6..557df21 100644
--- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
+++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
@@ -19,22 +19,25 @@
package hivemall.smile.regression;
import hivemall.UDTFWithOptions;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.builders.MatrixBuilder;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.Vector;
import hivemall.smile.data.Attribute;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
-import hivemall.smile.vm.StackMachine;
import hivemall.utils.codec.Base91;
-import hivemall.utils.codec.DeflateCodec;
-import hivemall.utils.collections.DoubleArrayList;
+import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
-import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.RandomUtils;
-import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
@@ -42,6 +45,7 @@ import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -69,7 +73,7 @@ import org.apache.hadoop.mapred.Reporter;
@Description(
name = "train_randomforest_regression",
- value = "_FUNC_(double[] features, double target [, string options]) - "
+ value = "_FUNC_(array<double|string> features, double target [, string options]) - "
+ "Returns a relation consists of "
+ "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>")
public final class RandomForestRegressionUDTF extends UDTFWithOptions {
@@ -79,7 +83,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
private PrimitiveObjectInspector featureElemOI;
private PrimitiveObjectInspector targetOI;
- private List<double[]> featuresList;
+ private boolean denseInput;
+ private MatrixBuilder matrixBuilder;
private DoubleArrayList targets;
/**
* The number of trees for each task
@@ -101,7 +106,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
private int _minSamplesLeaf;
private long _seed;
private Attribute[] _attributes;
- private ModelType _outputType;
@Nullable
private Reporter _progressReporter;
@@ -131,10 +135,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
- opts.addOption("output", "output_type", true,
- "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]");
- opts.addOption("disable_compression", false,
- "Whether to disable compression of the output script [default: false]");
return opts;
}
@@ -145,8 +145,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
float numVars = -1.f;
Attribute[] attrs = null;
long seed = -1L;
- String output = "serialization";
- boolean compress = true;
CommandLine cl = null;
if (argOIs.length >= 3) {
@@ -165,10 +163,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
minSamplesLeaf);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
- output = cl.getOptionValue("output", output);
- if (cl.hasOption("disable_compression")) {
- compress = false;
- }
}
this._numTrees = trees;
@@ -179,7 +173,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
this._minSamplesLeaf = minSamplesLeaf;
this._seed = seed;
this._attributes = attrs;
- this._outputType = ModelType.resolve(output, compress);
return cl;
}
@@ -189,19 +182,29 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
if (argOIs.length != 2 && argOIs.length != 3) {
throw new UDFArgumentException(
getClass().getSimpleName()
- + " takes 2 or 3 arguments: double[] features, double target [, const string options]: "
+ + " takes 2 or 3 arguments: array<double|string> features, double target [, const string options]: "
+ argOIs.length);
}
ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureListOI = listOI;
- this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ if (HiveUtils.isNumberOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ this.denseInput = true;
+ this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
+ } else if (HiveUtils.isStringOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asStringOI(elemOI);
+ this.denseInput = false;
+ this.matrixBuilder = new CSRMatrixBuilder(8192);
+ } else {
+ throw new UDFArgumentException(
+ "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
+ }
this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
processOptions(argOIs);
- this.featuresList = new ArrayList<double[]>(1024);
this.targets = new DoubleArrayList(1024);
ArrayList<String> fieldNames = new ArrayList<String>(5);
@@ -209,8 +212,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
fieldNames.add("model_id");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
- fieldNames.add("model_type");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("model_err");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
fieldNames.add("pred_model");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
fieldNames.add("var_importance");
@@ -228,13 +231,36 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
if (args[0] == null) {
throw new HiveException("array<double> features was null");
}
- double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
+ parseFeatures(args[0], matrixBuilder);
double target = PrimitiveObjectInspectorUtils.getDouble(args[1], targetOI);
-
- featuresList.add(features);
targets.add(target);
}
+ /**
+ * Parses one feature list and appends it to {@code builder} as a single row.
+ * <p>
+ * When {@code denseInput} is set, each non-null element is converted to double through
+ * {@code featureElemOI} and stored at its list position. Otherwise, each non-null element is
+ * passed to the builder in its string form. Null elements are skipped in both modes.
+ * {@code builder.nextRow()} is always called once at the end.
+ *
+ * @param argObj the feature list object, read through {@code featureListOI}
+ * @param builder the matrix builder that receives the parsed row
+ */
+ private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
+ if (denseInput) {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
+ // the element's list position i is used as its column index
+ builder.nextColumn(i, v);
+ }
+ } else {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ // the raw string is handed to the builder, which parses it
+ // (presumably an "index:value" token -- TODO confirm against CSRMatrixBuilder)
+ String fv = o.toString();
+ builder.nextColumn(fv);
+ }
+ }
+ // the row is closed even if every element was null
+ builder.nextRow();
+ }
+
@Override
public void close() throws HiveException {
this._progressReporter = getReporter();
@@ -250,10 +276,9 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
reportProgress(_progressReporter);
- int numExamples = featuresList.size();
- if (numExamples > 0) {
- double[][] x = featuresList.toArray(new double[numExamples][]);
- this.featuresList = null;
+ if (!targets.isEmpty()) {
+ Matrix x = matrixBuilder.buildMatrix();
+ this.matrixBuilder = null;
double[] y = targets.toArray();
this.targets = null;
@@ -285,15 +310,16 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
* @param _numVars The number of variables to pick up in each node.
* @param _seed The seed number for Random Forest
*/
- private void train(@Nonnull final double[][] x, @Nonnull final double[] y) throws HiveException {
- if (x.length != y.length) {
+ private void train(@Nonnull Matrix x, @Nonnull final double[] y) throws HiveException {
+ final int numExamples = x.numRows();
+ if (numExamples != y.length) {
throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
- x.length, y.length));
+ numExamples, y.length));
}
checkOptions();
- // Shuffle training samples
- SmileExtUtils.shuffle(x, y, _seed);
+ // Shuffle training samples
+ x = SmileExtUtils.shuffle(x, y, _seed);
Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
int numInputVars = SmileExtUtils.computeNumInputVars(_numVars, x);
@@ -305,10 +331,9 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
+ ", seed: " + _seed);
}
- int numExamples = x.length;
double[] prediction = new double[numExamples]; // placeholder for out-of-bag prediction
int[] oob = new int[numExamples];
- int[][] order = SmileExtUtils.sort(attributes, x);
+ ColumnMajorIntMatrix order = SmileExtUtils.sort(attributes, x);
AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
List<TrainingTask> tasks = new ArrayList<TrainingTask>();
for (int i = 0; i < _numTrees; i++) {
@@ -330,10 +355,13 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
/**
* Synchronized because {@link #forward(Object)} should be called from a single thread.
+ *
+ * @param error
*/
synchronized void forward(final int taskId, @Nonnull final Text model,
- @Nonnull final double[] importance, final double[] y, final double[] prediction,
- final int[] oob, final boolean lastTask) throws HiveException {
+ @Nonnull final double[] importance, @Nonnegative final double error, final double[] y,
+ final double[] prediction, final int[] oob, final boolean lastTask)
+ throws HiveException {
double oobErrors = 0.d;
int oobTests = 0;
if (lastTask) {
@@ -349,7 +377,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
String modelId = RandomUtils.getUUID();
final Object[] forwardObjs = new Object[6];
forwardObjs[0] = new Text(modelId);
- forwardObjs[1] = new IntWritable(_outputType.getId());
+ forwardObjs[1] = new DoubleWritable(error);
forwardObjs[2] = model;
forwardObjs[3] = WritableUtils.toWritableList(importance);
forwardObjs[4] = new DoubleWritable(oobErrors);
@@ -373,16 +401,15 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
/**
* Training instances.
*/
- private final double[][] _x;
+ private final Matrix _x;
/**
* Training sample labels.
*/
private final double[] _y;
/**
- * The index of training values in ascending order. Note that only numeric attributes will
- * be sorted.
+ * The index of training values in ascending order. Note that only numeric attributes will be sorted.
*/
- private final int[][] _order;
+ private final ColumnMajorIntMatrix _order;
/**
* The number of variables to pick up in each node.
*/
@@ -401,8 +428,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
private final long _seed;
private final AtomicInteger _remainingTasks;
- TrainingTask(RandomForestRegressionUDTF udtf, int taskId, Attribute[] attributes,
- double[][] x, double[] y, int numVars, int[][] order, double[] prediction,
+ TrainingTask(RandomForestRegressionUDTF udtf, int taskId, Attribute[] attributes, Matrix x,
+ double[] y, int numVars, ColumnMajorIntMatrix order, double[] prediction,
int[] oob, long seed, AtomicInteger remainingTasks) {
this._udtf = udtf;
this._taskId = taskId;
@@ -419,11 +446,11 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
@Override
public Integer call() throws HiveException {
- long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() : new smile.math.Random(
- _seed).nextLong();
- final smile.math.Random rnd1 = new smile.math.Random(s);
- final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());
- final int N = _x.length;
+ long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
+ : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
+ final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
+ final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
+ final int N = _x.numRows();
// Training samples draw with replacement.
final int[] bags = new int[N];
@@ -441,82 +468,40 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
incrCounter(_udtf._treeConstuctionTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));
// out-of-bag prediction
+ int oob = 0;
+ double error = 0.d;
+ final Vector xProbe = _x.rowVector();
for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) {
- double pred = tree.predict(_x[i]);
- synchronized (_x[i]) {
+ oob++;
+ _x.getRow(i, xProbe);
+ final double pred = tree.predict(xProbe);
+ synchronized (_prediction) {
_prediction[i] += pred;
_oob[i]++;
}
+ error += Math.abs(pred - _y[i]);
+ }
+ if (oob != 0) {
+ error /= oob;
}
stopwatch.reset().start();
- Text model = getModel(tree, _udtf._outputType);
+ Text model = getModel(tree);
double[] importance = tree.importance();
tree = null; // help GC
int remain = _remainingTasks.decrementAndGet();
boolean lastTask = (remain == 0);
- _udtf.forward(_taskId + 1, model, importance, _y, _prediction, _oob, lastTask);
+ _udtf.forward(_taskId + 1, model, importance, error, _y, _prediction, _oob, lastTask);
incrCounter(_udtf._treeSerializationTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));
return Integer.valueOf(remain);
}
- private static Text getModel(@Nonnull final RegressionTree tree,
- @Nonnull final ModelType outputType) throws HiveException {
- final Text model;
- switch (outputType) {
- case serialization:
- case serialization_compressed: {
- byte[] b = tree.predictSerCodegen(outputType.isCompressed());
- b = Base91.encode(b);
- model = new Text(b);
- break;
- }
- case opscode:
- case opscode_compressed: {
- String s = tree.predictOpCodegen(StackMachine.SEP);
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- model = new Text(b);
- } else {
- model = new Text(s);
- }
- break;
- }
- case javascript:
- case javascript_compressed: {
- String s = tree.predictJsCodegen();
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- model = new Text(b);
- } else {
- model = new Text(s);
- }
- break;
- }
- default:
- throw new HiveException("Unexpected output type: " + outputType
- + ". Use javascript for the output instead");
- }
- return model;
+ /**
+ * Serializes a trained tree into the text form emitted as {@code pred_model}.
+ * <p>
+ * Produces the bytes of {@code tree.predictSerCodegen(true)} encoded with Base91.
+ * NOTE(review): the {@code true} argument presumably requests compression (the previous
+ * version passed {@code outputType.isCompressed()} in this position) -- confirm.
+ *
+ * @param tree the trained regression tree
+ * @return the Base91-encoded serialized tree
+ */
+ @Nonnull
+ private static Text getModel(@Nonnull final RegressionTree tree) throws HiveException {
+ byte[] b = tree.predictSerCodegen(true);
+ b = Base91.encode(b);
+ return new Text(b);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/regression/RegressionTree.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
index 07887c1..da7e80b 100755
--- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
+++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -33,20 +33,28 @@
*/
package hivemall.smile.regression;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
import hivemall.smile.data.Attribute;
import hivemall.smile.data.Attribute.AttributeType;
import hivemall.smile.utils.SmileExtUtils;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
+import hivemall.utils.collections.sets.IntArraySet;
+import hivemall.utils.collections.sets.IntSet;
import hivemall.utils.lang.ObjectUtils;
-import hivemall.utils.lang.StringUtils;
+import hivemall.utils.math.MathUtils;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
-import java.util.ArrayList;
import java.util.Arrays;
-import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
@@ -55,60 +63,48 @@ import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import smile.math.Math;
-import smile.math.Random;
import smile.regression.GradientTreeBoost;
import smile.regression.RandomForest;
import smile.regression.Regression;
/**
- * Decision tree for regression. A decision tree can be learned by splitting the training set into
- * subsets based on an attribute value test. This process is repeated on each derived subset in a
- * recursive manner called recursive partitioning.
+ * Decision tree for regression. A decision tree can be learned by splitting the training set into subsets based on an attribute value test. This
+ * process is repeated on each derived subset in a recursive manner called recursive partitioning.
* <p>
- * Classification and Regression Tree techniques have a number of advantages over many of those
- * alternative techniques.
+ * Classification and Regression Tree techniques have a number of advantages over many of those alternative techniques.
* <dl>
* <dt>Simple to understand and interpret.</dt>
- * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This
- * simplicity is useful not only for purposes of rapid classification of new observations, but can
- * also often yield a much simpler "model" for explaining why observations are classified or
- * predicted in a particular manner.</dd>
+ * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This simplicity is useful not only for purposes of rapid
+ * classification of new observations, but can also often yield a much simpler "model" for explaining why observations are classified or predicted in
+ * a particular manner.</dd>
* <dt>Able to handle both numerical and categorical data.</dt>
- * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of
- * variable.</dd>
+ * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of variable.</dd>
* <dt>Tree methods are nonparametric and nonlinear.</dt>
- * <dd>The final results of using tree methods for classification or regression can be summarized in
- * a series of (usually few) logical if-then conditions (tree nodes). Therefore, there is no
- * implicit assumption that the underlying relationships between the predictor variables and the
- * dependent variable are linear, follow some specific non-linear link function, or that they are
- * even monotonic in nature. Thus, tree methods are particularly well suited for data mining tasks,
- * where there is often little a priori knowledge nor any coherent set of theories or predictions
- * regarding which variables are related and how. In those types of data analytics, tree methods can
- * often reveal simple relationships between just a few variables that could have easily gone
- * unnoticed using other analytic techniques.</dd>
+ * <dd>The final results of using tree methods for classification or regression can be summarized in a series of (usually few) logical if-then
+ * conditions (tree nodes). Therefore, there is no implicit assumption that the underlying relationships between the predictor variables and the
+ * dependent variable are linear, follow some specific non-linear link function, or that they are even monotonic in nature. Thus, tree methods are
+ * particularly well suited for data mining tasks, where there is often little a priori knowledge nor any coherent set of theories or predictions
+ * regarding which variables are related and how. In those types of data analytics, tree methods can often reveal simple relationships between just a
+ * few variables that could have easily gone unnoticed using other analytic techniques.</dd>
* </dl>
- * One major problem with classification and regression trees is their high variance. Often a small
- * change in the data can result in a very different series of splits, making interpretation
- * somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause
- * over-fitting. Mechanisms such as pruning are necessary to avoid this problem. Another limitation
- * of trees is the lack of smoothness of the prediction surface.
+ * One major problem with classification and regression trees is their high variance. Often a small change in the data can result in a very different
+ * series of splits, making interpretation somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause over-fitting.
+ * Mechanisms such as pruning are necessary to avoid this problem. Another limitation of trees is the lack of smoothness of the prediction surface.
* <p>
- * Some techniques such as bagging, boosting, and random forest use more than one decision tree for
- * their analysis.
+ * Some techniques such as bagging, boosting, and random forest use more than one decision tree for their analysis.
*
* @see GradientTreeBoost
* @see RandomForest
*/
-public final class RegressionTree implements Regression<double[]> {
+public final class RegressionTree implements Regression<Vector> {
/**
* The attributes of independent variable.
*/
private final Attribute[] _attributes;
private final boolean _hasNumericType;
/**
- * Variable importance. Every time a split of a node is made on variable the impurity criterion
- * for the two descendant nodes is less than the parent node. Adding up the decreases for each
- * individual variable over the tree gives a simple measure of variable importance.
+ * Variable importance. Every time a split of a node is made on variable the impurity criterion for the two descendant nodes is less than the
+ * parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of variable importance.
*/
private final double[] _importance;
/**
@@ -120,8 +116,7 @@ public final class RegressionTree implements Regression<double[]> {
*/
private final int _maxDepth;
/**
- * The number of instances in a node below which the tree will not split, setting S = 5
- * generally gives good results.
+ * The number of instances in a node below which the tree will not split, setting S = 5 generally gives good results.
*/
private final int _minSplit;
/**
@@ -133,19 +128,17 @@ public final class RegressionTree implements Regression<double[]> {
*/
private final int _numVars;
/**
- * The index of training values in ascending order. Note that only numeric attributes will be
- * sorted.
+ * The index of training values in ascending order. Note that only numeric attributes will be sorted.
*/
- private final int[][] _order;
+ private final ColumnMajorIntMatrix _order;
- private final Random _rnd;
+ private final PRNG _rnd;
private final NodeOutput _nodeOutput;
/**
- * An interface to calculate node output. Note that samples[i] is the number of sampling of
- * dataset[i]. 0 means that the datum is not included and values of greater than 1 are possible
- * because of sampling with replacement.
+ * An interface to calculate node output. Note that samples[i] is the number of sampling of dataset[i]. 0 means that the datum is not included and
+ * values of greater than 1 are possible because of sampling with replacement.
*/
public interface NodeOutput {
/**
@@ -205,22 +198,30 @@ public final class RegressionTree implements Regression<double[]> {
this.output = output;
}
+ /**
+ * Returns {@code true} when this node has neither a true child nor a false child.
+ */
+ private boolean isLeaf() {
+ return trueChild == null && falseChild == null;
+ }
+
+ /**
+ * Evaluates the tree on a dense feature array.
+ * <p>
+ * Wraps {@code x} in a {@link DenseVector} and delegates to {@link #predict(Vector)}.
+ *
+ * @param x the feature values, indexed by column
+ * @return the predicted value
+ */
+ @VisibleForTesting
+ public double predict(@Nonnull final double[] x) {
+ return predict(new DenseVector(x));
+ }
+
/**
* Evaluate the regression tree over an instance.
*/
- public double predict(final double[] x) {
+ public double predict(@Nonnull final Vector x) {
if (trueChild == null && falseChild == null) {
return output;
} else {
if (splitFeatureType == AttributeType.NOMINAL) {
- // REVIEWME if(Math.equals(x[splitFeature], splitValue)) {
- if (x[splitFeature] == splitValue) {
+ if (x.get(splitFeature, Double.NaN) == splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
} else if (splitFeatureType == AttributeType.NUMERIC) {
- if (x[splitFeature] <= splitValue) {
+ if (x.get(splitFeature, Double.NaN) <= splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
@@ -283,99 +284,58 @@ public final class RegressionTree implements Regression<double[]> {
}
}
- public int opCodegen(final List<String> scripts, int depth) {
- int selfDepth = 0;
- final StringBuilder buf = new StringBuilder();
- if (trueChild == null && falseChild == null) {
- buf.append("push ").append(output);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("goto last");
- scripts.add(buf.toString());
- selfDepth += 2;
- } else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- buf.append("push ").append("x[").append(splitFeature).append("]");
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("push ").append(splitValue);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("ifeq ");
- scripts.add(buf.toString());
- depth += 3;
- selfDepth += 3;
- int trueDepth = trueChild.opCodegen(scripts, depth);
- selfDepth += trueDepth;
- scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth));
- int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
- selfDepth += falseDepth;
- } else if (splitFeatureType == AttributeType.NUMERIC) {
- buf.append("push ").append("x[").append(splitFeature).append("]");
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("push ").append(splitValue);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("ifle ");
- scripts.add(buf.toString());
- depth += 3;
- selfDepth += 3;
- int trueDepth = trueChild.opCodegen(scripts, depth);
- selfDepth += trueDepth;
- scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth));
- int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
- selfDepth += falseDepth;
- } else {
- throw new IllegalStateException("Unsupported attribute type: "
- + splitFeatureType);
- }
- }
- return selfDepth;
- }
-
+ /**
+ * Writes this node and, recursively, its children.
+ * <p>
+ * Fields are written in this order:
+ * <ul>
+ * <li>split feature (int);</li>
+ * <li>split feature type id (one byte, -1 when the type is null);</li>
+ * <li>split value (double);</li>
+ * <li>a leaf flag (boolean).</li>
+ * </ul>
+ * A leaf is followed by its output (double). A non-leaf is followed by a presence flag and
+ * the subtree for the true child, then the same for the false child.
+ */
@Override
public void writeExternal(ObjectOutput out) throws IOException {
- out.writeDouble(output);
out.writeInt(splitFeature);
+ // NOTE(review): writeByte keeps only the low 8 bits, so attribute type ids must fit in a
+ // signed byte for readExternal to round-trip them -- confirm AttributeType ids are small
if (splitFeatureType == null) {
- out.writeInt(-1);
+ out.writeByte(-1);
} else {
- out.writeInt(splitFeatureType.getTypeId());
+ out.writeByte(splitFeatureType.getTypeId());
}
out.writeDouble(splitValue);
- if (trueChild == null) {
- out.writeBoolean(false);
- } else {
+
+ // leaf flag: true => leaf output follows; false => optional true/false subtrees follow
+ if (isLeaf()) {
out.writeBoolean(true);
- trueChild.writeExternal(out);
- }
- if (falseChild == null) {
- out.writeBoolean(false);
+ out.writeDouble(output);
} else {
- out.writeBoolean(true);
- falseChild.writeExternal(out);
+ out.writeBoolean(false);
+ if (trueChild == null) {
+ out.writeBoolean(false);
+ } else {
+ out.writeBoolean(true);
+ trueChild.writeExternal(out);
+ }
+ if (falseChild == null) {
+ out.writeBoolean(false);
+ } else {
+ out.writeBoolean(true);
+ falseChild.writeExternal(out);
+ }
}
}
+ /**
+ * Reads a node in the layout produced by {@link #writeExternal(ObjectOutput)}.
+ * <p>
+ * Reads the split feature, type id, split value and leaf flag. A leaf then reads its
+ * output. Otherwise a child {@link Node} is created and read recursively for each presence
+ * flag that is set.
+ */
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this.output = in.readDouble();
this.splitFeature = in.readInt();
- int typeId = in.readInt();
+ byte typeId = in.readByte();
+ // readByte sign-extends, so the -1 marker written for a null type reads back as -1
if (typeId == -1) {
this.splitFeatureType = null;
} else {
this.splitFeatureType = AttributeType.resolve(typeId);
}
this.splitValue = in.readDouble();
- if (in.readBoolean()) {
- this.trueChild = new Node();
- trueChild.readExternal(in);
- }
- if (in.readBoolean()) {
- this.falseChild = new Node();
- falseChild.readExternal(in);
+
+ if (in.readBoolean()) {// isLeaf()
+ this.output = in.readDouble();
+ } else {
+ if (in.readBoolean()) {
+ this.trueChild = new Node();
+ trueChild.readExternal(in);
+ }
+ if (in.readBoolean()) {
+ this.falseChild = new Node();
+ falseChild.readExternal(in);
+ }
}
}
}
@@ -406,7 +366,7 @@ public final class RegressionTree implements Regression<double[]> {
/**
* Training dataset.
*/
- final double[][] x;
+ final Matrix x;
/**
* Training data response value.
*/
@@ -419,7 +379,7 @@ public final class RegressionTree implements Regression<double[]> {
/**
* Constructor.
*/
- public TrainNode(Node node, double[][] x, double[] y, int[] bags, int depth) {
+ public TrainNode(Node node, Matrix x, double[] y, int[] bags, int depth) {
this.node = node;
this.x = x;
this.y = y;
@@ -452,8 +412,7 @@ public final class RegressionTree implements Regression<double[]> {
}
/**
- * Finds the best attribute to split on at the current node. Returns true if a split exists
- * to reduce squared error, false otherwise.
+ * Finds the best attribute to split on at the current node. Returns true if a split exists to reduce squared error, false otherwise.
*/
public boolean findBestSplit() {
// avoid split if tree depth is larger than threshold
@@ -467,22 +426,14 @@ public final class RegressionTree implements Regression<double[]> {
}
final double sum = node.output * numSamples;
- final int p = _attributes.length;
- final int[] variables = new int[p];
- for (int i = 0; i < p; i++) {
- variables[i] = i;
- }
- if (_numVars < p) {
- SmileExtUtils.shuffle(variables, _rnd);
- }
// Loop through features and compute the reduction of squared error,
// which is trueCount * trueMean^2 + falseCount * falseMean^2 - count * parentMean^2
- final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.length)
+ final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.numRows())
: null;
- for (int j = 0; j < _numVars; j++) {
- Node split = findBestSplit(numSamples, sum, variables[j], samples);
+ for (int varJ : variableIndex(x, bags)) {
+ final Node split = findBestSplit(numSamples, sum, varJ, samples);
if (split.splitScore > node.splitScore) {
node.splitFeature = split.splitFeature;
node.splitFeatureType = split.splitFeatureType;
@@ -496,6 +447,31 @@ public final class RegressionTree implements Regression<double[]> {
return node.splitFeature != -1;
}
+ /**
+ * Chooses the candidate variable (column) indices to evaluate when splitting a node.
+ * <p>
+ * For a sparse {@code x}, the candidates are the distinct columns that hold a non-null
+ * value in at least one row listed in {@code bags}. For a dense {@code x}, they come from
+ * {@code MathUtils.permutation(_attributes.length)} (presumably 0..p-1 -- TODO confirm).
+ * When there are more than {@code _numVars} candidates, they are shuffled with {@code _rnd}
+ * and the first {@code _numVars} are returned. Otherwise all candidates are returned.
+ *
+ * @param x the training matrix
+ * @param bags row indices of the samples belonging to the node being split
+ * @return the column indices to try as split features
+ */
+ private int[] variableIndex(@Nonnull final Matrix x, @Nonnull final int[] bags) {
+ final int[] variableIndex;
+ if (x.isSparse()) {
+ // only columns that actually carry a value in one of the node's rows are candidates
+ final IntSet cols = new IntArraySet(_numVars);
+ final VectorProcedure proc = new VectorProcedure() {
+ public void apply(int col, double value) {
+ cols.add(col);
+ }
+ };
+ for (final int row : bags) {
+ x.eachNonNullInRow(row, proc);
+ }
+ variableIndex = cols.toArray(false);
+ } else {
+ variableIndex = MathUtils.permutation(_attributes.length);
+ }
+
+ // keep a random subset of _numVars candidates: shuffle, then truncate
+ if (_numVars < variableIndex.length) {
+ SmileExtUtils.shuffle(variableIndex, _rnd);
+ return Arrays.copyOf(variableIndex, _numVars);
+
+ }
+ return variableIndex;
+ }
+
/**
* Finds the best split cutoff for attribute j at the current node.
*
@@ -517,7 +493,11 @@ public final class RegressionTree implements Regression<double[]> {
// For each true feature of this datum increment the
// sufficient statistics for the "true" branch to evaluate
// splitting on this feature.
- int index = (int) x[i][j];
+ final double v = x.get(i, j, Double.NaN);
+ if (Double.isNaN(v)) {
+ continue;
+ }
+ int index = (int) v;
trueSum[index] += y[i];
++trueCount[index];
}
@@ -548,28 +528,38 @@ public final class RegressionTree implements Regression<double[]> {
}
}
} else if (_attributes[j].type == AttributeType.NUMERIC) {
- double trueSum = 0.0;
- int trueCount = 0;
- double prevx = Double.NaN;
-
- for (int i : _order[j]) {
- final int sample = samples[i];
- if (sample > 0) {
- if (Double.isNaN(prevx) || x[i][j] == prevx) {
- prevx = x[i][j];
- trueSum += sample * y[i];
+
+ _order.eachNonNullInColumn(j, new VectorProcedure() {
+ double trueSum = 0.0;
+ int trueCount = 0;
+ double prevx = Double.NaN;
+
+ public void apply(final int row, final int i) {
+ final int sample = samples[i];
+ if (sample == 0) {
+ return;
+ }
+ final double x_ij = x.get(i, j, Double.NaN);
+ if (Double.isNaN(x_ij)) {
+ return;
+ }
+ final double y_i = y[i];
+
+ if (Double.isNaN(prevx) || x_ij == prevx) {
+ prevx = x_ij;
+ trueSum += sample * y_i;
trueCount += sample;
- continue;
+ return;
}
final double falseCount = n - trueCount;
// If either side is empty, skip this feature.
if (trueCount < _minSplit || falseCount < _minSplit) {
- prevx = x[i][j];
- trueSum += sample * y[i];
+ prevx = x_ij;
+ trueSum += sample * y_i;
trueCount += sample;
- continue;
+ return;
}
// compute penalized means
@@ -586,17 +576,18 @@ public final class RegressionTree implements Regression<double[]> {
// new best split
split.splitFeature = j;
split.splitFeatureType = AttributeType.NUMERIC;
- split.splitValue = (x[i][j] + prevx) / 2;
+ split.splitValue = (x_ij + prevx) / 2;
split.splitScore = gain;
split.trueChildOutput = trueMean;
split.falseChildOutput = falseMean;
}
- prevx = x[i][j];
- trueSum += sample * y[i];
+ prevx = x_ij;
+ trueSum += sample * y_i;
trueCount += sample;
- }
- }
+ }//apply
+ });
+
} else {
throw new IllegalStateException("Unsupported attribute type: "
+ _attributes[j].type);
@@ -672,7 +663,7 @@ public final class RegressionTree implements Regression<double[]> {
final double splitValue = node.splitValue;
for (int i = 0, size = bags.length; i < size; i++) {
final int index = bags[i];
- if (x[index][splitFeature] == splitValue) {
+ if (x.get(index, splitFeature, Double.NaN) == splitValue) {
trueBags.add(index);
tc++;
} else {
@@ -684,7 +675,7 @@ public final class RegressionTree implements Regression<double[]> {
final double splitValue = node.splitValue;
for (int i = 0, size = bags.length; i < size; i++) {
final int index = bags[i];
- if (x[index][splitFeature] <= splitValue) {
+ if (x.get(index, splitFeature, Double.NaN) <= splitValue) {
trueBags.add(index);
tc++;
} else {
@@ -700,20 +691,19 @@ public final class RegressionTree implements Regression<double[]> {
}
- public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
- @Nonnull double[] y, int maxLeafs) {
- this(attributes, x, y, x[0].length, Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, null);
+ /**
+ * Builds a tree with these settings:
+ * <ul>
+ * <li>every column is a split candidate ({@code numVars = x.numColumns()});</li>
+ * <li>no depth limit ({@code maxDepth = Integer.MAX_VALUE});</li>
+ * <li>{@code minSplits = 5} and {@code minLeafSize = 1};</li>
+ * <li>no precomputed order or bags;</li>
+ * <li>a default random number generator.</li>
+ * </ul>
+ *
+ * @param attributes the attribute properties, or {@code null}
+ * @param x the training instances
+ * @param y the response variable
+ * @param maxLeafs the maximum number of leaf nodes in the tree
+ */
+ public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+ int maxLeafs) {
+ this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, null);
}
- public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
- @Nonnull double[] y, int maxLeafs, @Nullable smile.math.Random rand) {
- this(attributes, x, y, x[0].length, Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, rand);
+ /**
+ * Builds a tree with these settings:
+ * <ul>
+ * <li>every column is a split candidate ({@code numVars = x.numColumns()});</li>
+ * <li>no depth limit;</li>
+ * <li>{@code minSplits = 5} and {@code minLeafSize = 1};</li>
+ * <li>no precomputed order or bags;</li>
+ * <li>the given random number generator.</li>
+ * </ul>
+ *
+ * @param attributes the attribute properties, or {@code null}
+ * @param x the training instances
+ * @param y the response variable
+ * @param maxLeafs the maximum number of leaf nodes in the tree
+ * @param rand the random number generator; {@code null} falls back to a default one
+ */
+ public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+ int maxLeafs, @Nullable PRNG rand) {
+ this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, rand);
}
- public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
- @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
- int minLeafSize, @Nullable int[][] order, @Nullable int[] bags,
- @Nullable smile.math.Random rand) {
+ /**
+ * Same as the full constructor, but without a custom {@link NodeOutput}
+ * ({@code output = null}).
+ */
+ public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+ int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
+ @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags, @Nullable PRNG rand) {
this(attributes, x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize, order, bags, null, rand);
}
@@ -723,24 +713,22 @@ public final class RegressionTree implements Regression<double[]> {
* @param attributes the attribute properties.
* @param x the training instances.
* @param y the response variable.
- * @param numVars the number of input variables to pick to split on at each node. It seems that
- * dim/3 give generally good performance, where dim is the number of variables.
+ * @param numVars the number of input variables to pick to split on at each node. It seems that dim/3 give generally good performance, where dim
+ * is the number of variables.
* @param maxLeafs the maximum number of leaf nodes in the tree.
- * @param minSplits number of instances in a node below which the tree will not split, setting S
- * = 5 generally gives good results.
- * @param order the index of training values in ascending order. Note that only numeric
- * attributes need be sorted.
+ * @param minSplits number of instances in a node below which the tree will not split, setting S = 5 generally gives good results.
+ * @param order the index of training values in ascending order. Note that only numeric attributes need be sorted.
* @param bags the sample set of instances for stochastic learning.
* @param output An interface to calculate node output.
*/
- public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
- @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
- int minLeafSize, @Nullable int[][] order, @Nullable int[] bags,
- @Nullable NodeOutput output, @Nullable smile.math.Random rand) {
+ public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+ int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
+ @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags,
+ @Nullable NodeOutput output, @Nullable PRNG rand) {
checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);
this._attributes = SmileExtUtils.attributeTypes(attributes, x);
- if (_attributes.length != x[0].length) {
+ if (_attributes.length != x.numColumns()) {
throw new IllegalArgumentException("-attrs option is invliad: "
+ Arrays.toString(attributes));
}
@@ -752,7 +740,7 @@ public final class RegressionTree implements Regression<double[]> {
this._minLeafSize = minLeafSize;
this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
this._importance = new double[_attributes.length];
- this._rnd = (rand == null) ? new smile.math.Random() : rand;
+ this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;
this._nodeOutput = output;
int n = 0;
@@ -803,13 +791,13 @@ public final class RegressionTree implements Regression<double[]> {
}
}
- private static void checkArgument(@Nonnull double[][] x, @Nonnull double[] y, int numVars,
+ private static void checkArgument(@Nonnull Matrix x, @Nonnull double[] y, int numVars,
int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
- if (x.length != y.length) {
+ if (x.numRows() != y.length) {
throw new IllegalArgumentException(String.format(
- "The sizes of X and Y don't match: %d != %d", x.length, y.length));
+ "The sizes of X and Y don't match: %d != %d", x.numRows(), y.length));
}
- if (numVars <= 0 || numVars > x[0].length) {
+ if (numVars <= 0 || numVars > x.numColumns()) {
throw new IllegalArgumentException(
"Invalid number of variables to split on at a node of the tree: " + numVars);
}
@@ -830,10 +818,8 @@ public final class RegressionTree implements Regression<double[]> {
}
/**
- * Returns the variable importance. Every time a split of a node is made on variable the
- * impurity criterion for the two descendent nodes is less than the parent node. Adding up the
- * decreases for each individual variable over the tree gives a simple measure of variable
- * importance.
+ * Returns the variable importance. Every time a split of a node is made on a variable, the impurity criterion for the two descendant nodes is less
+ * than that of the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of variable importance.
*
* @return the variable importance
*/
@@ -841,8 +827,13 @@ public final class RegressionTree implements Regression<double[]> {
return _importance;
}
+ @VisibleForTesting
+ public double predict(@Nonnull final double[] x) { // Convenience overload for raw double[] input: wraps x in a DenseVector and delegates to predict(Vector).
+ return predict(new DenseVector(x)); // NOTE(review): constructs a new DenseVector on every call — whether it copies x is not visible here.
+ }
+
@Override
- public double predict(double[] x) {
+ public double predict(@Nonnull final Vector x) { // Predicts the response for x by delegating to _root.predict(x).
return _root.predict(x);
}
@@ -852,14 +843,6 @@ public final class RegressionTree implements Regression<double[]> {
return buf.toString();
}
- public String predictOpCodegen(@Nonnull String sep) {
- List<String> opslist = new ArrayList<String>();
- _root.opCodegen(opslist, 0);
- opslist.add("call end");
- String scripts = StringUtils.concat(opslist, sep);
- return scripts;
- }
-
@Nonnull
public byte[] predictSerCodegen(boolean compress) throws HiveException {
try {