You are viewing a plain text version of this content. The canonical link for it (a hyperlink in the original archive page) is not preserved in this plain-text copy.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/09/13 09:23:08 UTC
[incubator-hivemall] branch master updated: [HIVEMALL-245] Refactor RandomForest for Sparse Data handling
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new af2afeb [HIVEMALL-245] Refactor RandomForest for Sparse Data handling
af2afeb is described below
commit af2afeb641ec729b1dbb505be9471d0f422f8dda
Author: Makoto Yui <my...@apache.org>
AuthorDate: Fri Sep 13 18:23:00 2019 +0900
[HIVEMALL-245] Refactor RandomForest for Sparse Data handling
## What changes were proposed in this pull request?
Refactor RandomForest for Sparse Data handling
## What type of PR is it?
Refactoring
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-245
https://issues.apache.org/jira/browse/HIVEMALL-171
## How was this patch tested?
Unit tests and manual tests on EMR.
## Checklist
(Please remove this section if not needed; check `x` for YES, blank for NO)
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [ ] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <my...@apache.org>
Closes #198 from myui/HIVEMALL-245.
---
.travis.yml | 1 +
.../math/matrix/ints/AbstractIntMatrix.java | 13 +
.../matrix/ints/ColumnMajorDenseIntMatrix2d.java | 27 +
.../math/matrix/ints/ColumnMajorIntMatrix.java | 7 +
.../java/hivemall/math/vector/AbstractVector.java | 5 +
.../src/main/java/hivemall/math/vector/Vector.java | 2 +
.../java/hivemall/math/vector/VectorProcedure.java | 3 +
.../smile/classification/DecisionTree.java | 785 ++++++++++++++-------
.../GradientTreeBoostingClassifierUDTF.java | 102 +--
.../classification/RandomForestClassifierUDTF.java | 118 ++--
.../java/hivemall/smile/data/AttributeType.java | 67 --
.../regression/RandomForestRegressionUDTF.java | 96 +--
.../hivemall/smile/regression/RegressionTree.java | 740 ++++++++++++-------
.../hivemall/smile/tools/TreePredictUDFv1.java | 44 +-
.../java/hivemall/smile/utils/SmileExtUtils.java | 135 ++--
.../java/hivemall/smile/utils/VariableOrder.java | 54 ++
.../utils/collections/arrays/DenseIntArray.java | 11 +-
.../utils/collections/arrays/IntArray.java | 4 +
.../utils/collections/arrays/SparseIntArray.java | 194 ++++-
.../function/Consumer.java} | 21 +-
.../IntArray.java => function/IntPredicate.java} | 31 +-
.../java/hivemall/utils/hadoop/SerdeUtils.java | 63 ++
.../main/java/hivemall/utils/lang/ArrayUtils.java | 111 +++
.../utils/lang/mutable/MutableBoolean.java | 87 +++
.../hivemall/utils/lang/mutable/MutableInt.java | 20 +-
.../main/java/hivemall/utils/math/MathUtils.java | 9 +
.../smile/classification/DecisionTreeTest.java | 67 +-
.../RandomForestClassifierUDTFTest.java | 4 +-
.../smile/regression/RegressionTreeTest.java | 15 +-
.../hivemall/smile/tools/TreePredictUDFTest.java | 10 +-
.../hivemall/smile/tools/TreePredictUDFv1Test.java | 10 +-
.../hivemall/smile/utils/SmileExtUtilsTest.java} | 31 +-
.../utils/collections/arrays/IntArrayTest.java | 1 +
.../collections/arrays/SparseIntArrayTest.java | 113 ++-
.../java/hivemall/utils/lang/ArrayUtilsTest.java | 68 ++
35 files changed, 2145 insertions(+), 924 deletions(-)
diff --git a/.travis.yml b/.travis.yml
index dd0eedf..8b2427e 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -16,6 +16,7 @@ env:
# build_command: "mvn -DskipTests=true compile"
# branch_pattern: master
+dist: trusty
language: java
jdk:
# - openjdk7
diff --git a/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java
index 1c5fb6e..eedd616 100644
--- a/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java
@@ -91,6 +91,19 @@ public abstract class AbstractIntMatrix implements IntMatrix {
}
}
+ protected static void rangeCheck(final int length, final int fromIndex, final int toIndex) {
+ if (fromIndex > toIndex) {
+ throw new IllegalArgumentException(
+ "fromIndex(" + fromIndex + ") > toIndex(" + toIndex + ")");
+ }
+ if (fromIndex < 0) {
+ throw new ArrayIndexOutOfBoundsException(fromIndex);
+ }
+ if (toIndex > length) {
+ throw new ArrayIndexOutOfBoundsException(toIndex);
+ }
+ }
+
@Override
public void eachInRow(final int row, @Nonnull final VectorProcedure procedure) {
eachInRow(row, procedure, true);
diff --git a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java
index d028d47..bd45481 100644
--- a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java
+++ b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java
@@ -169,4 +169,31 @@ public final class ColumnMajorDenseIntMatrix2d extends ColumnMajorIntMatrix {
}
}
+ @Override
+ public void eachNonNullInColumn(final int col, final int startRow, final int endRow,
+ @Nonnull final VectorProcedure procedure) {
+ checkColIndex(col, numColumns);
+ rangeCheck(numRows, startRow, endRow);
+
+ final int[] colData = data[col];
+ if (colData == null) {
+ return;
+ }
+
+ for (int row = startRow, end = Math.min(endRow, colData.length); row < end; row++) {
+ procedure.apply(row, colData[row]);
+ }
+ }
+
+ @Override
+ public void eachRow(@Nonnull final VectorProcedure procedure) {
+ for (int col = 0; col < data.length; col++) {
+ final int[] row = data[col];
+ if (row == null) {
+ continue;
+ }
+ procedure.apply(col, row);
+ }
+ }
+
}
diff --git a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
index e0b3b4b..ff230a2 100644
--- a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
@@ -20,6 +20,8 @@ package hivemall.math.matrix.ints;
import hivemall.math.vector.VectorProcedure;
+import javax.annotation.Nonnull;
+
public abstract class ColumnMajorIntMatrix extends AbstractIntMatrix {
public ColumnMajorIntMatrix() {
@@ -36,4 +38,9 @@ public abstract class ColumnMajorIntMatrix extends AbstractIntMatrix {
throw new UnsupportedOperationException();
}
+ public abstract void eachNonNullInColumn(final int col, final int startRow, final int endRow,
+ @Nonnull final VectorProcedure procedure);
+
+ public abstract void eachRow(@Nonnull final VectorProcedure procedure);
+
}
diff --git a/core/src/main/java/hivemall/math/vector/AbstractVector.java b/core/src/main/java/hivemall/math/vector/AbstractVector.java
index 7c4579f..53f65b3 100644
--- a/core/src/main/java/hivemall/math/vector/AbstractVector.java
+++ b/core/src/main/java/hivemall/math/vector/AbstractVector.java
@@ -39,6 +39,11 @@ public abstract class AbstractVector implements Vector {
set(index, (double) value);
}
+ @Override
+ public void decr(int index, double delta) {
+ incr(index, -delta);
+ }
+
protected static final void checkIndex(final int index) {
if (index < 0) {
throw new IndexOutOfBoundsException("Invalid index " + index);
diff --git a/core/src/main/java/hivemall/math/vector/Vector.java b/core/src/main/java/hivemall/math/vector/Vector.java
index d1d3ebc..a032e86 100644
--- a/core/src/main/java/hivemall/math/vector/Vector.java
+++ b/core/src/main/java/hivemall/math/vector/Vector.java
@@ -41,6 +41,8 @@ public interface Vector {
public void incr(@Nonnegative int index, double delta);
+ public void decr(@Nonnegative int index, double delta);
+
public void each(@Nonnull VectorProcedure procedure);
public int size();
diff --git a/core/src/main/java/hivemall/math/vector/VectorProcedure.java b/core/src/main/java/hivemall/math/vector/VectorProcedure.java
index 4978885..8ce473b 100644
--- a/core/src/main/java/hivemall/math/vector/VectorProcedure.java
+++ b/core/src/main/java/hivemall/math/vector/VectorProcedure.java
@@ -19,6 +19,7 @@
package hivemall.math.vector;
import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
public abstract class VectorProcedure {
@@ -40,4 +41,6 @@ public abstract class VectorProcedure {
public void apply(@Nonnegative int i) {}
+ public void apply(@Nonnegative int i, @Nonnull int[] values) {}
+
}
diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
index 00ebae3..95b4b2a 100644
--- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java
+++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
@@ -17,21 +17,26 @@
// https://github.com/haifengl/smile/blob/master/core/src/main/java/smile/classification/DecisionTree.java
package hivemall.smile.classification;
+import static hivemall.smile.utils.SmileExtUtils.NOMINAL;
+import static hivemall.smile.utils.SmileExtUtils.NUMERIC;
import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName;
import static hivemall.smile.utils.SmileExtUtils.resolveName;
import hivemall.annotations.VisibleForTesting;
import hivemall.math.matrix.Matrix;
-import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.DenseVector;
import hivemall.math.vector.SparseVector;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
-import hivemall.smile.data.AttributeType;
import hivemall.smile.utils.SmileExtUtils;
+import hivemall.smile.utils.VariableOrder;
+import hivemall.utils.collections.arrays.SparseIntArray;
import hivemall.utils.collections.lists.IntArrayList;
+import hivemall.utils.function.Consumer;
+import hivemall.utils.function.IntPredicate;
+import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.lang.ObjectUtils;
import hivemall.utils.lang.StringUtils;
import hivemall.utils.lang.mutable.MutableInt;
@@ -53,6 +58,8 @@ import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.roaringbitmap.IntConsumer;
import org.roaringbitmap.RoaringBitmap;
@@ -110,13 +117,46 @@ import org.roaringbitmap.RoaringBitmap;
* Some techniques such as bagging, boosting, and random forest use more than one decision tree for
* their analysis.
*/
-public final class DecisionTree implements Classifier<Vector> {
+public class DecisionTree implements Classifier<Vector> {
+ private static final Log logger = LogFactory.getLog(DecisionTree.class);
+
+ /**
+ * Training dataset.
+ */
+ @Nonnull
+ private final Matrix _X;
+ /**
+ * class labels.
+ */
+ @Nonnull
+ private final int[] _y;
+ /**
+ * The samples for training this node. Note that samples[i] is the number of sampling of
+ * dataset[i]. 0 means that the datum is not included and values of greater than 1 are possible
+ * because of sampling with replacement.
+ */
+ @Nonnull
+ private final int[] _samples;
+ /**
+ * An index of training values. Initially, order[j] is a set of indices that iterate through the
+ * training values for attribute j in ascending order. During training, the array is rearranged
+ * so that all values for each leaf node occupy a contiguous range, but within that range they
+ * maintain the original ordering. Note that only numeric attributes will be sorted; non-numeric
+ * attributes will have a null in the corresponding place in the array.
+ */
+ @Nonnull
+ private final VariableOrder _order;
+ /**
+ * An index that maps their current position in the {@link #_order} to their original locations
+ * in {@link #_samples}.
+ */
+ @Nonnull
+ private final int[] _sampleIndex;
/**
* The attributes of independent variable.
*/
@Nonnull
- private final AttributeType[] _attributes;
- private final boolean _hasNumericType;
+ private final RoaringBitmap _nominalAttrs;
/**
* Variable importance. Every time a split of a node is made on variable the (GINI, information
* gain, etc.) impurity criterion for the two descendant nodes is less than the parent node.
@@ -150,19 +190,15 @@ public final class DecisionTree implements Classifier<Vector> {
/**
* The number of instances in a node below which the tree will not split.
*/
- private final int _minSplit;
+ private final int _minSamplesSplit;
/**
- * The minimum number of samples in a leaf node
+ * The minimum number of samples in a leaf node.
*/
- private final int _minLeafSize;
+ private final int _minSamplesLeaf;
/**
- * The index of training values in ascending order. Note that only numeric attributes will be
- * sorted.
+ * The random number generator.
*/
@Nonnull
- private final ColumnMajorIntMatrix _order;
-
- @Nonnull
private final PRNG _rnd;
/**
@@ -209,7 +245,7 @@ public final class DecisionTree implements Classifier<Vector> {
/**
* The type of split feature
*/
- AttributeType splitFeatureType = null;
+ boolean quantitativeFeature = true;
/**
* The split value.
*/
@@ -226,24 +262,28 @@ public final class DecisionTree implements Classifier<Vector> {
* Children node.
*/
Node falseChild = null;
- /**
- * Predicted output for children node.
- */
- int trueChildOutput = -1;
- /**
- * Predicted output for children node.
- */
- int falseChildOutput = -1;
public Node() {}// for Externalizable
+ public Node(@Nonnull double[] posteriori) {
+ this(Math.whichMax(posteriori), posteriori);
+ }
+
public Node(int output, @Nonnull double[] posteriori) {
this.output = output;
this.posteriori = posteriori;
}
private boolean isLeaf() {
- return posteriori != null;
+ return trueChild == null && falseChild == null;
+ }
+
+ private void markAsLeaf() {
+ this.splitFeature = -1;
+ this.splitValue = Double.NaN;
+ this.splitScore = 0.0;
+ this.trueChild = null;
+ this.falseChild = null;
}
@VisibleForTesting
@@ -255,24 +295,21 @@ public final class DecisionTree implements Classifier<Vector> {
* Evaluate the regression tree over an instance.
*/
public int predict(@Nonnull final Vector x) {
- if (trueChild == null && falseChild == null) {
+ if (isLeaf()) {
return output;
} else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- if (x.get(splitFeature, Double.NaN) == splitValue) {
+ if (quantitativeFeature) {
+ if (x.get(splitFeature, Double.NaN) <= splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
- } else if (splitFeatureType == AttributeType.NUMERIC) {
- if (x.get(splitFeature, Double.NaN) <= splitValue) {
+ } else {
+ if (x.get(splitFeature, Double.NaN) == splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
- } else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
}
}
}
@@ -281,24 +318,21 @@ public final class DecisionTree implements Classifier<Vector> {
* Evaluate the regression tree over an instance.
*/
public void predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) {
- if (trueChild == null && falseChild == null) {
+ if (isLeaf()) {
handler.handle(output, posteriori);
} else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- if (x.get(splitFeature, Double.NaN) == splitValue) {
+ if (quantitativeFeature) {
+ if (x.get(splitFeature, Double.NaN) <= splitValue) {
trueChild.predict(x, handler);
} else {
falseChild.predict(x, handler);
}
- } else if (splitFeatureType == AttributeType.NUMERIC) {
- if (x.get(splitFeature, Double.NaN) <= splitValue) {
+ } else {
+ if (x.get(splitFeature, Double.NaN) == splitValue) {
trueChild.predict(x, handler);
} else {
falseChild.predict(x, handler);
}
- } else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
}
}
}
@@ -306,42 +340,39 @@ public final class DecisionTree implements Classifier<Vector> {
public void exportJavascript(@Nonnull final StringBuilder builder,
@Nullable final String[] featureNames, @Nullable final String[] classNames,
final int depth) {
- if (trueChild == null && falseChild == null) {
+ if (isLeaf()) {
indent(builder, depth);
builder.append("").append(resolveName(output, classNames)).append(";\n");
} else {
indent(builder, depth);
- if (splitFeatureType == AttributeType.NOMINAL) {
+ if (quantitativeFeature) {
if (featureNames == null) {
builder.append("if( x[")
.append(splitFeature)
- .append("] == ")
+ .append("] <= ")
.append(splitValue)
.append(" ) {\n");
} else {
builder.append("if( ")
.append(resolveFeatureName(splitFeature, featureNames))
- .append(" == ")
+ .append(" <= ")
.append(splitValue)
.append(" ) {\n");
}
- } else if (splitFeatureType == AttributeType.NUMERIC) {
+ } else {
if (featureNames == null) {
builder.append("if( x[")
.append(splitFeature)
- .append("] <= ")
+ .append("] == ")
.append(splitValue)
.append(" ) {\n");
} else {
builder.append("if( ")
.append(resolveFeatureName(splitFeature, featureNames))
- .append(" <= ")
+ .append(" == ")
.append(splitValue)
.append(" ) {\n");
}
- } else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
}
trueChild.exportJavascript(builder, featureNames, classNames, depth + 1);
indent(builder, depth);
@@ -358,7 +389,7 @@ public final class DecisionTree implements Classifier<Vector> {
@Nonnull final MutableInt nodeIdGenerator, final int parentNodeId) {
final int myNodeId = nodeIdGenerator.getValue();
- if (trueChild == null && falseChild == null) {
+ if (isLeaf()) {
// fillcolor=h,s,v
// https://en.wikipedia.org/wiki/HSL_and_HSV
// http://www.graphviz.org/doc/info/attrs.html#k:colorList
@@ -382,21 +413,17 @@ public final class DecisionTree implements Classifier<Vector> {
builder.append(";\n");
}
} else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- builder.append(
- String.format(" %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n", myNodeId,
- resolveFeatureName(splitFeature, featureNames),
- Double.toString(splitValue)));
- } else if (splitFeatureType == AttributeType.NUMERIC) {
+ if (quantitativeFeature) {
builder.append(
String.format(" %d [label=<%s ≤ %s>, fillcolor=\"#00000000\"];\n",
myNodeId, resolveFeatureName(splitFeature, featureNames),
Double.toString(splitValue)));
} else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
+ builder.append(
+ String.format(" %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n", myNodeId,
+ resolveFeatureName(splitFeature, featureNames),
+ Double.toString(splitValue)));
}
-
if (myNodeId != parentNodeId) {
builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId);
if (parentNodeId == 0) {//only draw edge label on top
@@ -424,7 +451,7 @@ public final class DecisionTree implements Classifier<Vector> {
public int opCodegen(@Nonnull final List<String> scripts, int depth) {
int selfDepth = 0;
final StringBuilder buf = new StringBuilder();
- if (trueChild == null && falseChild == null) {
+ if (isLeaf()) {
buf.append("push ").append(output);
scripts.add(buf.toString());
buf.setLength(0);
@@ -432,41 +459,38 @@ public final class DecisionTree implements Classifier<Vector> {
scripts.add(buf.toString());
selfDepth += 2;
} else {
- if (splitFeatureType == AttributeType.NOMINAL) {
+ if (quantitativeFeature) {
buf.append("push ").append("x[").append(splitFeature).append("]");
scripts.add(buf.toString());
buf.setLength(0);
buf.append("push ").append(splitValue);
scripts.add(buf.toString());
buf.setLength(0);
- buf.append("ifeq ");
+ buf.append("ifle ");
scripts.add(buf.toString());
depth += 3;
selfDepth += 3;
int trueDepth = trueChild.opCodegen(scripts, depth);
selfDepth += trueDepth;
- scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth));
+ scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth));
int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
selfDepth += falseDepth;
- } else if (splitFeatureType == AttributeType.NUMERIC) {
+ } else {
buf.append("push ").append("x[").append(splitFeature).append("]");
scripts.add(buf.toString());
buf.setLength(0);
buf.append("push ").append(splitValue);
scripts.add(buf.toString());
buf.setLength(0);
- buf.append("ifle ");
+ buf.append("ifeq ");
scripts.add(buf.toString());
depth += 3;
selfDepth += 3;
int trueDepth = trueChild.opCodegen(scripts, depth);
selfDepth += trueDepth;
- scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth));
+ scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth));
int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
selfDepth += falseDepth;
- } else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
}
}
return selfDepth;
@@ -475,11 +499,7 @@ public final class DecisionTree implements Classifier<Vector> {
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(splitFeature);
- if (splitFeatureType == null) {
- out.writeByte(-1);
- } else {
- out.writeByte(splitFeatureType.getTypeId());
- }
+ out.writeByte(quantitativeFeature ? NUMERIC : NOMINAL);
out.writeDouble(splitValue);
if (isLeaf()) {
@@ -511,12 +531,8 @@ public final class DecisionTree implements Classifier<Vector> {
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
this.splitFeature = in.readInt();
- byte typeId = in.readByte();
- if (typeId == -1) {
- this.splitFeatureType = null;
- } else {
- this.splitFeatureType = AttributeType.resolve(typeId);
- }
+ final byte typeId = in.readByte();
+ this.quantitativeFeature = (typeId == NUMERIC);
this.splitValue = in.readDouble();
if (in.readBoolean()) {//isLeaf
@@ -555,29 +571,44 @@ public final class DecisionTree implements Classifier<Vector> {
/**
* The associated regression tree node.
*/
+ @Nonnull
final Node node;
/**
- * Training dataset.
+ * Depth of the node in the tree
*/
- final Matrix x;
+ final int depth;
+ /**
+ * The lower bound (inclusive) in the order array of the samples belonging to this node.
+ */
+ final int low;
/**
- * class labels.
+ * The upper bound (exclusive) in the order array of the samples belonging to this node.
*/
- final int[] y;
+ final int high;
+ /**
+ * The number of samples
+ */
+ final int samples;
- int[] bags;
+ @Nullable
+ int[] constFeatures;
- final int depth;
+ public TrainNode(@Nonnull Node node, int depth, int low, int high, int samples) {
+ this(node, depth, low, high, samples, new int[0]);
+ }
- /**
- * Constructor.
- */
- public TrainNode(Node node, Matrix x, int[] y, int[] bags, int depth) {
+ public TrainNode(@Nonnull Node node, int depth, int low, int high, int samples,
+ @Nonnull int[] constFeatures) {
+ if (low >= high) {
+ throw new IllegalArgumentException(
+ "Unexpected condition was met. low=" + low + ", high=" + high);
+ }
this.node = node;
- this.x = x;
- this.y = y;
- this.bags = bags;
this.depth = depth;
+ this.low = low;
+ this.high = high;
+ this.samples = samples;
+ this.constFeatures = constFeatures;
}
@Override
@@ -596,35 +627,30 @@ public final class DecisionTree implements Classifier<Vector> {
return false;
}
// avoid split if the number of samples is less than threshold
- final int numSamples = bags.length;
- if (numSamples <= _minSplit) {
+ if (samples <= _minSamplesSplit) {
return false;
}
// Sample count in each class.
final int[] count = new int[_k];
- final boolean pure = sampleCount(count);
-
- // Since all instances have same label, stop splitting.
- if (pure) {
+ final boolean pure = countSamples(count);
+ if (pure) {// if all instances have same label, stop splitting.
return false;
}
- final double impurity = impurity(count, numSamples, _rule);
-
- final int[] samples =
- _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.numRows()) : null;
+ final int[] constFeatures_ = this.constFeatures; // this.constFeatures may be replace in findBestSplit but it's accepted
+ final double impurity = impurity(count, samples, _rule);
final int[] falseCount = new int[_k];
- for (int varJ : variableIndex(x, bags)) {
- final Node split =
- findBestSplit(numSamples, count, falseCount, impurity, varJ, samples);
+ for (int varJ : variableIndex()) {
+ if (ArrayUtils.contains(constFeatures_, varJ)) {
+ continue; // skip constant features
+ }
+ final Node split = findBestSplit(samples, count, falseCount, impurity, varJ);
if (split.splitScore > node.splitScore) {
node.splitFeature = split.splitFeature;
- node.splitFeatureType = split.splitFeatureType;
+ node.quantitativeFeature = split.quantitativeFeature;
node.splitValue = split.splitValue;
node.splitScore = split.splitScore;
- node.trueChildOutput = split.trueChildOutput;
- node.falseChildOutput = split.falseChildOutput;
}
}
@@ -632,17 +658,22 @@ public final class DecisionTree implements Classifier<Vector> {
}
@Nonnull
- private int[] variableIndex(@Nonnull final Matrix x, @Nonnull final int[] bags) {
+ private int[] variableIndex() {
+ final Matrix X = _X;
final IntReservoirSampler sampler = new IntReservoirSampler(_numVars, _rnd.nextLong());
- if (x.isSparse()) {
+ if (X.isSparse()) {
+ // sample columns from sampled examples
final RoaringBitmap cols = new RoaringBitmap();
final VectorProcedure proc = new VectorProcedure() {
public void apply(final int col) {
cols.add(col);
}
};
- for (final int row : bags) {
- x.eachColumnIndexInRow(row, proc);
+ final int[] sampleIndex = _sampleIndex;
+ for (int i = low, end = high; i < end; i++) {
+ int row = sampleIndex[i];
+ assert (_samples[row] != 0) : row;
+ X.eachColumnIndexInRow(row, proc);
}
cols.forEach(new IntConsumer() {
public void accept(final int k) {
@@ -650,20 +681,25 @@ public final class DecisionTree implements Classifier<Vector> {
}
});
} else {
- for (int i = 0, size = _attributes.length; i < size; i++) {
+ final int ncols = X.numColumns();
+ for (int i = 0; i < ncols; i++) {
sampler.add(i);
}
}
return sampler.getSample();
}
- private boolean sampleCount(@Nonnull final int[] count) {
- int label = -1;
+ private boolean countSamples(@Nonnull final int[] count) {
+ final int[] sampleIndex = _sampleIndex;
+ final int[] samples = _samples;
+ final int[] y = _y;
+
boolean pure = true;
- for (int i = 0; i < bags.length; i++) {
- int index = bags[i];
+
+ for (int i = low, end = high, label = -1; i < end; i++) {
+ int index = sampleIndex[i];
int y_i = y[index];
- count[y_i]++;
+ count[y_i] += samples[index];
if (label == -1) {
label = y_i;
@@ -671,6 +707,7 @@ public final class DecisionTree implements Classifier<Vector> {
pure = false;
}
}
+
return pure;
}
@@ -684,24 +721,44 @@ public final class DecisionTree implements Classifier<Vector> {
* @param j the attribute index to split on.
*/
private Node findBestSplit(final int n, final int[] count, final int[] falseCount,
- final double impurity, final int j, @Nullable final int[] samples) {
+ final double impurity, final int j) {
+ final int[] samples = _samples;
+ final int[] sampleIndex = _sampleIndex;
+ final Matrix X = _X;
+ final int[] y = _y;
+ final int classes = _k;
+
final Node splitNode = new Node();
- if (_attributes[j] == AttributeType.NOMINAL) {
+ if (_nominalAttrs.contains(j)) {
final Int2ObjectMap<int[]> trueCount = new Int2ObjectOpenHashMap<int[]>();
- for (int i = 0, size = bags.length; i < size; i++) {
- int index = bags[i];
- final double v = x.get(index, j, Double.NaN);
+ int countNaN = 0;
+ for (int i = low, end = high; i < end; i++) {
+ final int index = sampleIndex[i];
+ final int numSamples = samples[index];
+ if (numSamples == 0) {
+ continue;
+ }
+
+ final double v = X.get(index, j, Double.NaN);
if (Double.isNaN(v)) {
+ countNaN++;
continue;
}
int x_ij = (int) v;
+
int[] tc_x = trueCount.get(x_ij);
if (tc_x == null) {
- tc_x = new int[_k];
+ tc_x = new int[classes];
+ trueCount.put(x_ij, tc_x);
}
- tc_x[y[index]]++;
+ int y_i = y[index];
+ tc_x[y_i] += numSamples;
+ }
+ final int countDistinctX = trueCount.size() + (countNaN == 0 ? 0 : 1);
+ if (countDistinctX <= 1) { // mark as a constant feature
+ this.constFeatures = ArrayUtils.sortedArraySet(constFeatures, j);
}
for (Int2ObjectMap.Entry<int[]> e : trueCount.int2ObjectEntrySet()) {
@@ -712,12 +769,12 @@ public final class DecisionTree implements Classifier<Vector> {
final int fc = n - tc;
// skip splitting this feature.
- if (tc < _minSplit || fc < _minSplit) {
+ if (tc < _minSamplesSplit || fc < _minSamplesSplit) {
continue;
}
- for (int q = 0; q < _k; q++) {
- falseCount[q] = count[q] - trueCount_l[q];
+ for (int k = 0; k < classes; k++) {
+ falseCount[k] = count[k] - trueCount_l[k];
}
final double gain =
@@ -727,36 +784,42 @@ public final class DecisionTree implements Classifier<Vector> {
if (gain > splitNode.splitScore) {
// new best split
splitNode.splitFeature = j;
- splitNode.splitFeatureType = AttributeType.NOMINAL;
+ splitNode.quantitativeFeature = false;
splitNode.splitValue = l;
splitNode.splitScore = gain;
- splitNode.trueChildOutput = Math.whichMax(trueCount_l);
- splitNode.falseChildOutput = Math.whichMax(falseCount);
}
}
- } else if (_attributes[j] == AttributeType.NUMERIC) {
- final int[] trueCount = new int[_k];
+ } else {
+ final int[] trueCount = new int[classes];
+ final MutableInt countNaN = new MutableInt(0);
+ final MutableInt replaceCount = new MutableInt(0);
- _order.eachNonNullInColumn(j, new VectorProcedure() {
- double prevx = Double.NaN;
+ _order.eachNonNullInColumn(j, low, high, new Consumer() {
+ double prevx = Double.NaN, lastx = Double.NaN;
int prevy = -1;
- public void apply(final int row, final int i) {
- final int sample = samples[i];
- if (sample == 0) {
+ @Override
+ public void accept(int pos, final int i) {
+ final int numSamples = samples[i];
+ if (numSamples == 0) {
return;
}
- final double x_ij = x.get(i, j, Double.NaN);
+ final double x_ij = X.get(i, j, Double.NaN);
if (Double.isNaN(x_ij)) {
+ countNaN.incr();
return;
}
- final int y_i = y[i];
+ if (lastx != x_ij) {
+ lastx = x_ij;
+ replaceCount.incr();
+ }
+ final int y_i = y[i];
if (Double.isNaN(prevx) || x_ij == prevx || y_i == prevy) {
prevx = x_ij;
prevy = y_i;
- trueCount[y_i] += sample;
+ trueCount[y_i] += numSamples;
return;
}
@@ -764,14 +827,14 @@ public final class DecisionTree implements Classifier<Vector> {
final int fc = n - tc;
// skip splitting this feature.
- if (tc < _minSplit || fc < _minSplit) {
+ if (tc < _minSamplesSplit || fc < _minSamplesSplit) {
prevx = x_ij;
prevy = y_i;
- trueCount[y_i] += sample;
+ trueCount[y_i] += numSamples;
return;
}
- for (int l = 0; l < _k; l++) {
+ for (int l = 0; l < classes; l++) {
falseCount[l] = count[l] - trueCount[l];
}
@@ -782,20 +845,21 @@ public final class DecisionTree implements Classifier<Vector> {
if (gain > splitNode.splitScore) {
// new best split
splitNode.splitFeature = j;
- splitNode.splitFeatureType = AttributeType.NUMERIC;
+ splitNode.quantitativeFeature = true;
splitNode.splitValue = (x_ij + prevx) / 2.d;
splitNode.splitScore = gain;
- splitNode.trueChildOutput = Math.whichMax(trueCount);
- splitNode.falseChildOutput = Math.whichMax(falseCount);
}
prevx = x_ij;
prevy = y_i;
- trueCount[y_i] += sample;
+ trueCount[y_i] += numSamples;
}//apply()
});
- } else {
- throw new IllegalStateException("Unsupported attribute type: " + _attributes[j]);
+
+ final int countDistinctX = replaceCount.get() + (countNaN.get() == 0 ? 0 : 1);
+ if (countDistinctX <= 1) { // mark as a constant feature
+ this.constFeatures = ArrayUtils.sortedArraySet(constFeatures, j);
+ }
}
return splitNode;
@@ -803,110 +867,240 @@ public final class DecisionTree implements Classifier<Vector> {
/**
* Split the node into two children nodes. Returns true if split success.
+ *
+ * @return true if split occurred. false if the node is set to leaf.
*/
public boolean split(@Nullable final PriorityQueue<TrainNode> nextSplits) {
if (node.splitFeature < 0) {
throw new IllegalStateException("Split a node with invalid feature.");
}
- // split sample bags
- int childBagSize = (int) (bags.length * 0.4);
- IntArrayList trueBags = new IntArrayList(childBagSize);
- IntArrayList falseBags = new IntArrayList(childBagSize);
- double[] trueChildPosteriori = new double[_k];
- double[] falseChildPosteriori = new double[_k];
- int tc = splitSamples(trueBags, falseBags, trueChildPosteriori, falseChildPosteriori);
- int fc = bags.length - tc;
- this.bags = null; // help GC for recursive call
-
- if (tc < _minLeafSize || fc < _minLeafSize) {
- // set the node as leaf
- node.splitFeature = -1;
- node.splitFeatureType = null;
- node.splitValue = Double.NaN;
- node.splitScore = 0.0;
+ final IntPredicate goesLeft = getPredicate();
+
+ // split samples
+ final int tc, fc, pivot;
+ final double[] trueChildPosteriori = new double[_k],
+ falseChildPosteriori = new double[_k];
+ {
+ MutableInt tc_ = new MutableInt(0);
+ MutableInt fc_ = new MutableInt(0);
+ pivot = splitSamples(tc_, fc_, trueChildPosteriori, falseChildPosteriori, goesLeft);
+ tc = tc_.get();
+ fc = fc_.get();
+ }
+
+ if (tc < _minSamplesLeaf || fc < _minSamplesLeaf) {
+ node.markAsLeaf();
return false;
}
for (int i = 0; i < _k; i++) {
- trueChildPosteriori[i] /= tc;
+ trueChildPosteriori[i] /= tc; // divide by zero never happens
falseChildPosteriori[i] /= fc;
}
- node.trueChild = new Node(node.trueChildOutput, trueChildPosteriori);
+ partitionOrder(low, pivot, high, goesLeft);
+
+ int leaves = 0;
+
+ node.trueChild = new Node(trueChildPosteriori);
TrainNode trueChild =
- new TrainNode(node.trueChild, x, y, trueBags.toArray(), depth + 1);
- trueBags = null; // help GC for recursive call
- if (tc >= _minSplit && trueChild.findBestSplit()) {
+ new TrainNode(node.trueChild, depth + 1, low, pivot, tc, constFeatures.clone());
+ node.falseChild = new Node(falseChildPosteriori);
+ TrainNode falseChild =
+ new TrainNode(node.falseChild, depth + 1, pivot, high, fc, constFeatures);
+ this.constFeatures = null;
+
+ if (tc >= _minSamplesSplit && trueChild.findBestSplit()) {
if (nextSplits != null) {
nextSplits.add(trueChild);
} else {
- trueChild.split(null);
+ if (trueChild.split(null) == false) {
+ leaves++;
+ }
}
+ } else {
+ leaves++;
}
- node.falseChild = new Node(node.falseChildOutput, falseChildPosteriori);
- TrainNode falseChild =
- new TrainNode(node.falseChild, x, y, falseBags.toArray(), depth + 1);
- falseBags = null; // help GC for recursive call
- if (fc >= _minSplit && falseChild.findBestSplit()) {
+ if (fc >= _minSamplesSplit && falseChild.findBestSplit()) {
if (nextSplits != null) {
nextSplits.add(falseChild);
} else {
- falseChild.split(null);
+ if (falseChild.split(null) == false) {
+ leaves++;
+ }
+ }
+ } else {
+ leaves++;
+ }
+
+ // Prune meaningless branches
+ if (leaves == 2) {// both left and right child are leaf node
+ if (node.trueChild.output == node.falseChild.output) {// found a meaningless branch
+ node.markAsLeaf();
+ return false;
}
}
_importance.incr(node.splitFeature, node.splitScore);
- node.posteriori = null; // a posteriori is not needed for non-leaf nodes
+ if (nextSplits == null) {
+ // For depth-first splitting, a posteriori is not needed for non-leaf nodes
+ node.posteriori = null;
+ }
return true;
}
/**
- * @param falseChildPosteriori
- * @param trueChildPosteriori
- * @return the number of true samples
+ * @return Pivot to split samples
*/
- private int splitSamples(@Nonnull final IntArrayList trueBags,
- @Nonnull final IntArrayList falseBags, @Nonnull final double[] trueChildPosteriori,
- @Nonnull final double[] falseChildPosteriori) {
- int tc = 0;
- if (node.splitFeatureType == AttributeType.NOMINAL) {
- final int splitFeature = node.splitFeature;
- final double splitValue = node.splitValue;
- for (int i = 0, size = bags.length; i < size; i++) {
- final int index = bags[i];
- if (x.get(index, splitFeature, Double.NaN) == splitValue) {
- trueBags.add(index);
- trueChildPosteriori[y[index]]++;
- tc++;
- } else {
- falseBags.add(index);
- falseChildPosteriori[y[index]]++;
- }
+ private int splitSamples(@Nonnull final MutableInt tc, @Nonnull final MutableInt fc,
+ @Nonnull final double[] trueChildPosteriori,
+ @Nonnull final double[] falseChildPosteriori,
+ @Nonnull final IntPredicate goesLeft) {
+ final int[] sampleIndex = _sampleIndex;
+ final int[] samples = _samples;
+ final int[] y = _y;
+
+ int pivot = low;
+ for (int k = low, end = high; k < end; k++) {
+ final int i = sampleIndex[k];
+ final int numSamples = samples[i];
+ final int yi = y[i];
+ if (goesLeft.test(i)) {
+ tc.addValue(numSamples);
+ trueChildPosteriori[yi] += numSamples;
+ pivot++;
+ } else {
+ fc.addValue(numSamples);
+ falseChildPosteriori[yi] += numSamples;
}
- } else if (node.splitFeatureType == AttributeType.NUMERIC) {
- final int splitFeature = node.splitFeature;
- final double splitValue = node.splitValue;
- for (int i = 0, size = bags.length; i < size; i++) {
- final int index = bags[i];
- if (x.get(index, splitFeature, Double.NaN) <= splitValue) {
- trueBags.add(index);
- trueChildPosteriori[y[index]]++;
- tc++;
- } else {
- falseBags.add(index);
- falseChildPosteriori[y[index]]++;
- }
+ }
+ return pivot;
+ }
+
+ /**
+ * Modifies {@link #_order} and {@link #_sampleIndex} by partitioning the range from low
+ * (inclusive) to high (exclusive) so that all elements i for which goesLeft(i) is true come
+ * before all elements for which it is false, but element ordering is otherwise preserved.
+ * The number of true values returned by goesLeft must equal split-low.
+ *
+ * @param low the low bound of the segment of the order arrays which will be partitioned.
+ * @param split where the partition's split point will end up.
+ * @param high the high bound of the segment of the order arrays which will be partitioned.
+ * @param goesLeft whether an element goes to the left side or the right side of the
+ * partition.
+ * @param buffer scratch space large enough to hold all elements for which goesLeft is
+ * false.
+ */
+ private void partitionOrder(final int low, final int pivot, final int high,
+ @Nonnull final IntPredicate goesLeft) {
+ final int[] buf = new int[high - pivot];
+ _order.eachRow(new Consumer() {
+ @Override
+ public void accept(int col, @Nonnull final SparseIntArray row) {
+ partitionArray(row, low, pivot, high, goesLeft, buf);
}
+ });
+ partitionArray(_sampleIndex, low, pivot, high, goesLeft, buf);
+ }
+
+ @Nonnull
+ private IntPredicate getPredicate() {
+ if (node.quantitativeFeature) {
+ return new IntPredicate() {
+ @Override
+ public boolean test(int i) {
+ return _X.get(i, node.splitFeature, Double.NaN) <= node.splitValue;
+ }
+ };
+ } else {
+ return new IntPredicate() {
+ @Override
+ public boolean test(int i) {
+ return _X.get(i, node.splitFeature, Double.NaN) == node.splitValue;
+ }
+ };
+ }
+ }
+
+ }
+
+ private static void partitionArray(@Nonnull final SparseIntArray a, final int low,
+ final int pivot, final int high, @Nonnull final IntPredicate goesLeft,
+ @Nonnull final int[] buf) {
+ final int[] rowIndexes = a.keys();
+ final int[] rowPtrs = a.values();
+ final int size = a.size();
+
+ final int startPos = ArrayUtils.insertionPoint(rowIndexes, size, low);
+ final int endPos = ArrayUtils.insertionPoint(rowIndexes, size, high);
+ int pos = startPos, k = 0, j = low;
+ for (int i = startPos; i < endPos; i++) {
+ final int rowPtr = rowPtrs[i];
+ if (goesLeft.test(rowPtr)) {
+ rowIndexes[pos] = j;
+ rowPtrs[pos] = rowPtr;
+ pos++;
+ j++;
} else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + node.splitFeatureType);
+ if (k >= buf.length) {
+ throw new IndexOutOfBoundsException(String.format(
+ "low=%d, pivot=%d, high=%d, a.size()=%d, buf.length=%d, i=%d, j=%d, k=%d, startPos=%d, endPos=%d\na=%s\nbuf=%s",
+ low, pivot, high, a.size(), buf.length, i, j, k, startPos, endPos,
+ a.toString(), Arrays.toString(buf)));
+ }
+ buf[k++] = rowPtr;
}
- return tc;
}
+ for (int i = 0; i < k; i++) {
+ rowIndexes[pos] = pivot + i;
+ rowPtrs[pos] = buf[i];
+ pos++;
+ }
+ if (pos != endPos) {
+ throw new IllegalStateException(
+ String.format("pos=%d, startPos=%d, endPos=%d, k=%d\na=%s", pos, startPos, endPos,
+ k, a.toString()));
+ }
+ }
+ /**
+ * Modifies an array in-place by partitioning the range from low (inclusive) to high (exclusive)
+ * so that all elements i for which goesLeft(i) is true come before all elements for which it is
+ * false, but element ordering is otherwise preserved. The number of true values returned by
+ * goesLeft must equal split-low. buf is scratch space large enough (i.e., at least high-split
+ * long) to hold all elements for which goesLeft is false.
+ */
+ private static void partitionArray(@Nonnull final int[] a, final int low, final int pivot,
+ final int high, @Nonnull final IntPredicate goesLeft, @Nonnull final int[] buf) {
+ int j = low;
+ int k = 0;
+ for (int i = low; i < high; i++) {
+ if (i >= a.length) {
+ throw new IndexOutOfBoundsException(String.format(
+ "low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d", low,
+ pivot, high, a.length, buf.length, i, j, k));
+ }
+ final int rowPtr = a[i];
+ if (goesLeft.test(rowPtr)) {
+ a[j++] = rowPtr;
+ } else {
+ if (k >= buf.length) {
+ throw new IndexOutOfBoundsException(String.format(
+ "low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d",
+ low, pivot, high, a.length, buf.length, i, j, k));
+ }
+ buf[k++] = rowPtr;
+ }
+ }
+ if (k != high - pivot || j != pivot) {
+ throw new IndexOutOfBoundsException(
+ String.format("low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, j=%d, k=%d",
+ low, pivot, high, a.length, buf.length, j, k));
+ }
+ System.arraycopy(buf, 0, a, pivot, k);
}
/**
@@ -924,8 +1118,7 @@ public final class DecisionTree implements Classifier<Vector> {
switch (rule) {
case GINI: {
impurity = 1.0;
- for (int i = 0; i < count.length; i++) {
- final int count_i = count[i];
+ for (int count_i : count) {
if (count_i > 0) {
double p = (double) count_i / n;
impurity -= p * p;
@@ -934,8 +1127,7 @@ public final class DecisionTree implements Classifier<Vector> {
break;
}
case ENTROPY: {
- for (int i = 0; i < count.length; i++) {
- final int count_i = count[i];
+ for (int count_i : count) {
if (count_i > 0) {
double p = (double) count_i / n;
impurity -= p * Math.log2(p);
@@ -945,8 +1137,7 @@ public final class DecisionTree implements Classifier<Vector> {
}
case CLASSIFICATION_ERROR: {
impurity = 0.d;
- for (int i = 0; i < count.length; i++) {
- final int count_i = count[i];
+ for (int count_i : count) {
if (count_i > 0) {
impurity = Math.max(impurity, (double) count_i / n);
}
@@ -959,74 +1150,121 @@ public final class DecisionTree implements Classifier<Vector> {
return impurity;
}
- public DecisionTree(@Nullable AttributeType[] attributes, @Nonnull Matrix x, @Nonnull int[] y,
- int numLeafs) {
- this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, null);
+ /**
+ * Prunes redundant leaves from the tree. In some cases, a node is split into two leaves that
+ * get assigned the same label, so this recursively combines leaves when it notices this
+ * situation.
+ */
+ private static void pruneRedundantLeaves(@Nonnull final Node node, @Nonnull Vector importance) {
+ if (node.isLeaf()) {
+ return;
+ }
+
+ // The children might not be leaves now, but might collapse into leaves given the chance.
+ pruneRedundantLeaves(node.trueChild, importance);
+ pruneRedundantLeaves(node.falseChild, importance);
+
+ if (node.trueChild.isLeaf() && node.falseChild.isLeaf()
+ && node.trueChild.output == node.falseChild.output) {
+ node.trueChild = null;
+ node.falseChild = null;
+ importance.decr(node.splitFeature, node.splitScore);
+ } else {
+ // a posteriori is not needed for non-leaf nodes
+ node.posteriori = null;
+ }
+ }
+
+ public DecisionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull int[] y,
+ int numSamplesLeaf) {
+ this(nominalAttrs, x, y, x.numColumns(), Integer.MAX_VALUE, numSamplesLeaf, 2, 1, null, SplitRule.GINI, null);
}
- public DecisionTree(@Nullable AttributeType[] attributes, @Nullable Matrix x, @Nullable int[] y,
- int numLeafs, @Nullable PRNG rand) {
- this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, rand);
+ public DecisionTree(@Nullable RoaringBitmap nominalAttrs, @Nullable Matrix x, @Nullable int[] y,
+ int numSamplesLeaf, @Nullable PRNG rand) {
+ this(nominalAttrs, x, y, x.numColumns(), Integer.MAX_VALUE, numSamplesLeaf, 2, 1, null, SplitRule.GINI, rand);
}
/**
* Constructor. Learns a classification tree for random forest.
*
- * @param attributes the attribute properties.
+ * @param nominalAttrs the attribute properties.
* @param x the training instances.
* @param y the response variable.
* @param numVars the number of input variables to pick to split on at each node. It seems that
* dim/3 give generally good performance, where dim is the number of variables.
- * @param maxLeafs the maximum number of leaf nodes in the tree.
- * @param minSplits the number of minimum elements in a node to split
- * @param minLeafSize the minimum size of leaf nodes.
- * @param order the index of training values in ascending order. Note that only numeric
- * attributes need be sorted.
- * @param bags the sample set of instances for stochastic learning.
+ * @param maxLeafNodes the maximum number of leaf nodes in the tree.
+ * @param minSamplesSplit the number of minimum elements in a node to split
+ * @param minSamplesLeaf The minimum number of samples in a leaf node
+ * @param samples the sample set of instances for stochastic learning. samples[i] is the number
+ * of sampling for instance i.
* @param rule the splitting rule.
- * @param seed
+ * @param rand random number generator
*/
- public DecisionTree(@Nullable AttributeType[] attributes, @Nonnull Matrix x, @Nonnull int[] y,
- int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
- @Nullable int[] bags, @Nullable ColumnMajorIntMatrix order, @Nonnull SplitRule rule,
- @Nullable PRNG rand) {
- checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);
+ public DecisionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull int[] y,
+ int numVars, int maxDepth, int maxLeafNodes, int minSamplesSplit, int minSamplesLeaf,
+ @Nullable int[] samples, @Nonnull SplitRule rule, @Nullable PRNG rand) {
+ checkArgument(x, y, numVars, maxDepth, maxLeafNodes, minSamplesSplit, minSamplesLeaf);
+
+ this._X = x;
+ this._y = y;
this._k = Math.max(y) + 1;
if (_k < 2) {
throw new IllegalArgumentException("Only one class or negative class labels.");
}
- this._attributes = SmileExtUtils.attributeTypes(attributes, x);
- if (attributes.length != x.numColumns()) {
- throw new IllegalArgumentException(
- "-attrs option is invalid: " + Arrays.toString(attributes));
+ if (nominalAttrs == null) {
+ nominalAttrs = new RoaringBitmap();
}
- this._hasNumericType = SmileExtUtils.containsNumericType(_attributes);
+ this._nominalAttrs = nominalAttrs;
this._numVars = numVars;
this._maxDepth = maxDepth;
- this._minSplit = minSplits;
- this._minLeafSize = minLeafSize;
+ // min_sample_leaf >= 2 is satisfied iff min_sample_split >= 4
+ // So, split only happens when samples in intermediate nodes has >= 2 * min_sample_leaf nodes.
+ if (minSamplesSplit < minSamplesLeaf * 2) {
+ if (logger.isInfoEnabled()) {
+ logger.info(String.format(
+ "min_sample_leaf = %d replaces min_sample_split = %d with min_sample_split = %d",
+ minSamplesLeaf, minSamplesSplit, minSamplesLeaf * 2));
+ }
+ minSamplesSplit = minSamplesLeaf * 2;
+ }
+ this._minSamplesSplit = minSamplesSplit;
+ this._minSamplesLeaf = minSamplesLeaf;
this._rule = rule;
- this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
- this._importance = x.isSparse() ? new SparseVector() : new DenseVector(_attributes.length);
+ this._importance = x.isSparse() ? new SparseVector() : new DenseVector(x.numColumns());
this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;
final int n = y.length;
final int[] count = new int[_k];
- if (bags == null) {
- bags = new int[n];
+ final int[] sampleIndex;
+ int totalNumSamples = 0;
+ if (samples == null) {
+ samples = new int[n];
+ sampleIndex = new int[n];
for (int i = 0; i < n; i++) {
- bags[i] = i;
+ samples[i] = 1;
count[y[i]]++;
+ sampleIndex[i] = i;
}
+ totalNumSamples = n;
} else {
- for (int i = 0, size = bags.length; i < size; i++) {
- int index = bags[i];
- count[y[index]]++;
+ final IntArrayList positions = new IntArrayList(n);
+ for (int i = 0; i < n; i++) {
+ final int sample = samples[i];
+ if (sample != 0) {
+ count[y[i]] += sample;
+ positions.add(i);
+ totalNumSamples += sample;
+ }
}
+ sampleIndex = positions.toArray(true);
}
+ this._samples = samples;
+ this._order = SmileExtUtils.sort(nominalAttrs, x, samples);
+ this._sampleIndex = sampleIndex;
final double[] posteriori = new double[_k];
for (int i = 0; i < _k; i++) {
@@ -1034,12 +1272,13 @@ public final class DecisionTree implements Classifier<Vector> {
}
this._root = new Node(Math.whichMax(count), posteriori);
- final TrainNode trainRoot = new TrainNode(_root, x, y, bags, 1);
- if (maxLeafs == Integer.MAX_VALUE) {
+ final TrainNode trainRoot =
+ new TrainNode(_root, 1, 0, _sampleIndex.length, totalNumSamples);
+ if (maxLeafNodes == Integer.MAX_VALUE) { // depth-first split
if (trainRoot.findBestSplit()) {
trainRoot.split(null);
}
- } else {
+ } else { // best-first split
// Priority queue for best-first tree growing.
final PriorityQueue<TrainNode> nextSplits = new PriorityQueue<TrainNode>();
// Now add splits to the tree until max tree size is reached
@@ -1048,14 +1287,17 @@ public final class DecisionTree implements Classifier<Vector> {
}
// Pop best leaf from priority queue, split it, and push
// children nodes into the queue if possible.
- for (int leaves = 1; leaves < maxLeafs; leaves++) {
+ for (int leaves = 1; leaves < maxLeafNodes; leaves++) {
// parent is the leaf to split
- TrainNode parent = nextSplits.poll();
- if (parent == null) {
+ TrainNode node = nextSplits.poll();
+ if (node == null) {
break;
}
- parent.split(nextSplits); // Split the parent node into two children nodes
+ if (!node.split(nextSplits)) { // Split the parent node into two children nodes
+ leaves--;
+ }
}
+ pruneRedundantLeaves(_root, _importance);
}
}
@@ -1065,11 +1307,14 @@ public final class DecisionTree implements Classifier<Vector> {
}
private static void checkArgument(@Nonnull Matrix x, @Nonnull int[] y, int numVars,
- int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
+ int maxDepth, int maxLeafNodes, int minSamplesSplit, int minSamplesLeaf) {
if (x.numRows() != y.length) {
throw new IllegalArgumentException(
String.format("The sizes of X and Y don't match: %d != %d", x.numRows(), y.length));
}
+ if (y.length == 0) {
+ throw new IllegalArgumentException("No training example given");
+ }
if (numVars <= 0 || numVars > x.numColumns()) {
throw new IllegalArgumentException(
"Invalid number of variables to split on at a node of the tree: " + numVars);
@@ -1077,17 +1322,17 @@ public final class DecisionTree implements Classifier<Vector> {
if (maxDepth < 2) {
throw new IllegalArgumentException("maxDepth should be greater than 1: " + maxDepth);
}
- if (maxLeafs < 2) {
- throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafs);
+ if (maxLeafNodes < 2) {
+ throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafNodes);
}
- if (minSplits < 2) {
+ if (minSamplesSplit < 2) {
throw new IllegalArgumentException(
"Invalid minimum number of samples required to split an internal node: "
- + minSplits);
+ + minSamplesSplit);
}
- if (minLeafSize < 1) {
+ if (minSamplesLeaf < 1) {
throw new IllegalArgumentException(
- "Invalid minimum size of leaf nodes: " + minLeafSize);
+ "Invalid minimum size of leaf nodes: " + minSamplesLeaf);
}
}
diff --git a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
index 5feaa36..a25ab44 100644
--- a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
@@ -23,26 +23,24 @@ import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.builders.CSRMatrixBuilder;
import hivemall.math.matrix.builders.MatrixBuilder;
import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
-import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.DenseVector;
import hivemall.math.vector.SparseVector;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
-import hivemall.smile.data.AttributeType;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.codec.Base91;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.SerdeUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.math.MathUtils;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.BitSet;
import java.util.HashMap;
import java.util.Map;
@@ -69,6 +67,7 @@ import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.Reporter;
+import org.roaringbitmap.RoaringBitmap;
@Description(name = "train_gradient_tree_boosting_classifier",
value = "_FUNC_(array<double|string> features, int label [, string options]) - "
@@ -112,12 +111,12 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
private int _minSamplesSplit;
private int _minSamplesLeaf;
private long _seed;
- private AttributeType[] _attributes;
+ private byte[] _nominalAttrs;
@Nullable
- private Reporter _progressReporter;
+ private transient Reporter _progressReporter;
@Nullable
- private Counter _iterationCounter;
+ private transient Counter _iterationCounter;
@Override
protected Options getOptions() {
@@ -142,16 +141,18 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
+ opts.addOption("nominal_attr_indicies", "categorical_attr_indicies", true,
+ "Comma seperated indicies of categorical attributes, e.g., [3,5,6]");
return opts;
}
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
int trees = 500, maxDepth = 8;
- int maxLeafs = Integer.MAX_VALUE, minSplit = 5, minSamplesLeaf = 1;
+ int maxLeafNodes = Integer.MAX_VALUE, minSamplesSplit = 5, minSamplesLeaf = 1;
float numVars = -1.f;
double eta = 0.05d, subsample = 0.7d;
- AttributeType[] attrs = null;
+ RoaringBitmap attrs = new RoaringBitmap();
long seed = -1L;
CommandLine cl = null;
@@ -167,12 +168,23 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
subsample = Primitives.parseDouble(cl.getOptionValue("subsample"), subsample);
numVars = Primitives.parseFloat(cl.getOptionValue("num_variables"), numVars);
maxDepth = Primitives.parseInt(cl.getOptionValue("max_depth"), maxDepth);
- maxLeafs = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), maxLeafs);
- minSplit = Primitives.parseInt(cl.getOptionValue("min_split"), minSplit);
+ maxLeafNodes = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), maxLeafNodes);
+ String min_samples_split = cl.getOptionValue("min_samples_split");
+ if (min_samples_split == null) {
+ minSamplesSplit =
+ Primitives.parseInt(cl.getOptionValue("min_split"), minSamplesSplit);
+ } else {
+ minSamplesSplit = Integer.parseInt(min_samples_split);
+ }
minSamplesLeaf =
Primitives.parseInt(cl.getOptionValue("min_samples_leaf"), minSamplesLeaf);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
- attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
+ String nominal_attr_indicies = cl.getOptionValue("nominal_attr_indicies");
+ if (nominal_attr_indicies != null) {
+ attrs = SmileExtUtils.parseNominalAttributeIndicies(nominal_attr_indicies);
+ } else {
+ attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
+ }
}
this._numTrees = trees;
@@ -180,11 +192,11 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
this._subsample = subsample;
this._numVars = numVars;
this._maxDepth = maxDepth;
- this._maxLeafNodes = maxLeafs;
- this._minSamplesSplit = minSplit;
+ this._maxLeafNodes = maxLeafNodes;
+ this._minSamplesSplit = minSamplesSplit;
this._minSamplesLeaf = minSamplesLeaf;
this._seed = seed;
- this._attributes = attrs;
+ this._nominalAttrs = SerdeUtils.serializeRoaring(attrs);
return cl;
}
@@ -303,7 +315,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
this.featureListOI = null;
this.featureElemOI = null;
this.labelOI = null;
- this._attributes = null;
}
private void checkOptions() throws HiveException {
@@ -332,7 +343,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
String.format("The sizes of X and Y don't match: %d != %d", numRows, y.length));
}
checkOptions();
- this._attributes = SmileExtUtils.attributeTypes(_attributes, x);
// Shuffle training samples
x = SmileExtUtils.shuffle(x, y, _seed);
@@ -378,38 +388,36 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
h[i] = intercept;
}
- final ColumnMajorIntMatrix order = SmileExtUtils.sort(_attributes, x);
final RegressionTree.NodeOutput output = new L2NodeOutput(response);
- final BitSet sampled = new BitSet(numInstances);
- final int[] bag = new int[numSamples];
- final int[] perm = new int[numSamples];
- for (int i = 0; i < numSamples; i++) {
- perm[i] = i;
- }
+ final int[] samples = new int[numInstances];
+ final int[] perm = MathUtils.permutation(numInstances);
long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
: RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
+ final RoaringBitmap nominalAttrs = SerdeUtils.deserializeRoaring(_nominalAttrs);
+ this._nominalAttrs = null;
+
final Vector xProbe = x.rowVector();
for (int m = 0; m < _numTrees; m++) {
reportProgress(_progressReporter);
+ Arrays.fill(samples, 0);
SmileExtUtils.shuffle(perm, rnd1);
for (int i = 0; i < numSamples; i++) {
int index = perm[i];
- bag[i] = index;
- sampled.set(index);
+ samples[index] += 1;
}
for (int i = 0; i < numInstances; i++) {
response[i] = 2.0d * y[i] / (1.d + Math.exp(2.d * y[i] * h[i]));
}
- RegressionTree tree = new RegressionTree(_attributes, x, response, numVars, _maxDepth,
- _maxLeafNodes, _minSamplesSplit, _minSamplesLeaf, order, bag, output, rnd2);
+ RegressionTree tree = new RegressionTree(nominalAttrs, x, response, numVars, _maxDepth,
+ _maxLeafNodes, _minSamplesSplit, _minSamplesLeaf, samples, output, rnd2);
for (int i = 0; i < numInstances; i++) {
x.getRow(i, xProbe);
@@ -418,8 +426,10 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
// out-of-bag error estimate
int oobTests = 0, oobErrors = 0;
- for (int i = sampled.nextClearBit(0); i < numInstances; i =
- sampled.nextClearBit(i + 1)) {
+ for (int i = 0; i < samples.length; i++) {
+ if (samples[i] != 0) {
+ continue;
+ }
oobTests++;
final int pred = (h[i] > 0.d) ? 1 : 0;
if (pred != y[i]) {
@@ -431,8 +441,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
oobErrorRate = ((float) oobErrors) / oobTests;
}
- forward(m + 1, intercept, _eta, oobErrorRate, tree);
- sampled.clear();
+ forward(m + 1, intercept, _eta, oobErrorRate, x.numColumns(), tree);
}
}
@@ -455,14 +464,12 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
final double[][] p = new double[k][numInstances]; // a posteriori probabilities.
final double[][] response = new double[k][numInstances]; // pseudo response.
- final ColumnMajorIntMatrix order = SmileExtUtils.sort(_attributes, x);
final RegressionTree.NodeOutput[] output = new LKNodeOutput[k];
for (int i = 0; i < k; i++) {
output[i] = new LKNodeOutput(response[i], k);
}
- final BitSet sampled = new BitSet(numInstances);
- final int[] bag = new int[numSamples];
+ final int[] samples = new int[numInstances];
final int[] perm = MathUtils.permutation(numInstances);
long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
@@ -470,6 +477,9 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
+ final RoaringBitmap nominalAttrs = SerdeUtils.deserializeRoaring(_nominalAttrs);
+ this._nominalAttrs = null;
+
// out-of-bag prediction
final int[] prediction = new int[numInstances];
final Vector xProbe = x.rowVector();
@@ -515,16 +525,16 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
response_j[i] -= p_j[i];
}
+ Arrays.fill(samples, 0);
SmileExtUtils.shuffle(perm, rnd1);
for (int i = 0; i < numSamples; i++) {
int index = perm[i];
- bag[i] = index;
- sampled.set(i);
+ samples[index] += 1;
}
- RegressionTree tree = new RegressionTree(_attributes, x, response[j], numVars,
- _maxDepth, _maxLeafNodes, _minSamplesSplit, _minSamplesLeaf, order, bag,
- output[j], rnd2);
+ RegressionTree tree = new RegressionTree(nominalAttrs, x, response[j], numVars,
+ _maxDepth, _maxLeafNodes, _minSamplesSplit, _minSamplesLeaf, samples, output[j],
+ rnd2);
trees[j] = tree;
for (int i = 0; i < numInstances; i++) {
@@ -540,21 +550,22 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
} // for each k
// out-of-bag error estimate
- for (int i = sampled.nextClearBit(0); i < numInstances; i =
- sampled.nextClearBit(i + 1)) {
+ for (int i = 0; i < samples.length; i++) {
+ if (samples[i] != 0) {
+ continue;
+ }
oobTests++;
if (prediction[i] != y[i]) {
oobErrors++;
}
}
- sampled.clear();
float oobErrorRate = 0.f;
if (oobTests > 0) {
oobErrorRate = ((float) oobErrors) / oobTests;
}
// forward a row
- forward(m + 1, 0.d, _eta, oobErrorRate, trees);
+ forward(m + 1, 0.d, _eta, oobErrorRate, x.numColumns(), trees);
} // for each m
}
@@ -563,10 +574,11 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
* @param m m-th boosting iteration
*/
private void forward(final int m, final double intercept, final double shrinkage,
- final float oobErrorRate, @Nonnull final RegressionTree... trees) throws HiveException {
+ final float oobErrorRate, final int numColumns, @Nonnull final RegressionTree... trees)
+ throws HiveException {
Text[] models = getModel(trees);
- Vector importance = denseInput ? new DenseVector(_attributes.length) : new SparseVector();
+ Vector importance = denseInput ? new DenseVector(numColumns) : new SparseVector();
for (RegressionTree tree : trees) {
Vector imp = tree.importance();
for (int i = 0, size = imp.size(); i < size; i++) {
diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index 99396b7..7ae6f10 100644
--- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -24,7 +24,6 @@ import hivemall.math.matrix.MatrixUtils;
import hivemall.math.matrix.builders.CSRMatrixBuilder;
import hivemall.math.matrix.builders.MatrixBuilder;
import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
-import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.matrix.ints.DoKIntMatrix;
import hivemall.math.matrix.ints.IntMatrix;
import hivemall.math.random.PRNG;
@@ -32,12 +31,12 @@ import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
import hivemall.smile.classification.DecisionTree.SplitRule;
-import hivemall.smile.data.AttributeType;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
import hivemall.utils.codec.Base91;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.SerdeUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.Primitives;
@@ -45,7 +44,6 @@ import hivemall.utils.lang.RandomUtils;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -79,6 +77,7 @@ import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.Reporter;
+import org.roaringbitmap.RoaringBitmap;
@Description(name = "train_randomforest_classifier",
value = "_FUNC_(array<double|string> features, int label [, const string options, const array<double> classWeights])"
@@ -114,7 +113,7 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
private int _minSamplesSplit;
private int _minSamplesLeaf;
private long _seed;
- private AttributeType[] _attributes;
+ private byte[] _nominalAttrs;
private SplitRule _splitRule;
private boolean _stratifiedSampling;
private double _subsample;
@@ -123,9 +122,9 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
private double[] _classWeight;
@Nullable
- private Reporter _progressReporter;
+ private transient Reporter _progressReporter;
@Nullable
- private Counter _treeBuildTaskCounter;
+ private transient Counter _treeBuildTaskCounter;
@Override
protected Options getOptions() {
@@ -146,6 +145,8 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
+ opts.addOption("nominal_attr_indicies", "categorical_attr_indicies", true,
+ "Comma seperated indicies of categorical attributes, e.g., [3,5,6]");
opts.addOption("rule", "split_rule", true,
"Split algorithm [default: GINI, ENTROPY, CLASSIFICATION_ERROR]");
opts.addOption("stratified", "stratified_sampling", false,
@@ -157,9 +158,9 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
int trees = 50, maxDepth = Integer.MAX_VALUE;
- int numLeafs = Integer.MAX_VALUE, minSplits = 2, minSamplesLeaf = 1;
+ int maxLeafNodes = Integer.MAX_VALUE, minSamplesSplit = 2, minSamplesLeaf = 1;
float numVars = -1.f;
- AttributeType[] attrs = null;
+ RoaringBitmap attrs = new RoaringBitmap();
long seed = -1L;
SplitRule splitRule = SplitRule.GINI;
double[] classWeight = null;
@@ -177,12 +178,23 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
}
numVars = Primitives.parseFloat(cl.getOptionValue("num_variables"), numVars);
maxDepth = Primitives.parseInt(cl.getOptionValue("max_depth"), maxDepth);
- numLeafs = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), numLeafs);
- minSplits = Primitives.parseInt(cl.getOptionValue("min_split"), minSplits);
+ maxLeafNodes = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), maxLeafNodes);
+ String min_samples_split = cl.getOptionValue("min_samples_split");
+ if (min_samples_split == null) {
+ minSamplesSplit =
+ Primitives.parseInt(cl.getOptionValue("min_split"), minSamplesSplit);
+ } else {
+ minSamplesSplit = Integer.parseInt(min_samples_split);
+ }
minSamplesLeaf =
Primitives.parseInt(cl.getOptionValue("min_samples_leaf"), minSamplesLeaf);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
- attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
+ String nominal_attr_indicies = cl.getOptionValue("nominal_attr_indicies");
+ if (nominal_attr_indicies != null) {
+ attrs = SmileExtUtils.parseNominalAttributeIndicies(nominal_attr_indicies);
+ } else {
+ attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
+ }
splitRule = SmileExtUtils.resolveSplitRule(cl.getOptionValue("split_rule", "GINI"));
stratifiedSampling = cl.hasOption("stratified_sampling");
subsample = Primitives.parseDouble(cl.getOptionValue("subsample"), 1.0d);
@@ -209,11 +221,11 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
this._numTrees = trees;
this._numVars = numVars;
this._maxDepth = maxDepth;
- this._maxLeafNodes = numLeafs;
- this._minSamplesSplit = minSplits;
+ this._maxLeafNodes = maxLeafNodes;
+ this._minSamplesSplit = minSamplesSplit;
this._minSamplesLeaf = minSamplesLeaf;
this._seed = seed;
- this._attributes = attrs;
+ this._nominalAttrs = SerdeUtils.serializeRoaring(attrs);
this._splitRule = splitRule;
this._stratifiedSampling = stratifiedSampling;
this._subsample = subsample;
@@ -345,7 +357,6 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
this.featureListOI = null;
this.featureElemOI = null;
this.labelOI = null;
- this._attributes = null;
}
private void checkOptions() throws HiveException {
@@ -377,7 +388,6 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
x = SmileExtUtils.shuffle(x, y, _seed);
int[] labels = SmileExtUtils.classLabels(y);
- AttributeType[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
int numInputVars = SmileExtUtils.computeNumInputVars(_numVars, x);
if (logger.isInfoEnabled()) {
@@ -386,14 +396,15 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
+ _maxLeafNodes + ", splitRule: " + _splitRule + ", seed: " + _seed);
}
- IntMatrix prediction = new DoKIntMatrix(numExamples, labels.length); // placeholder for out-of-bag prediction
- ColumnMajorIntMatrix order = SmileExtUtils.sort(attributes, x);
- AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
- List<TrainingTask> tasks = new ArrayList<TrainingTask>();
+ final RoaringBitmap nominalAttrs = SerdeUtils.deserializeRoaring(_nominalAttrs);
+ this._nominalAttrs = null;
+ final IntMatrix prediction = new DoKIntMatrix(numExamples, labels.length); // placeholder for out-of-bag prediction
+ final AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
+ final List<TrainingTask> tasks = new ArrayList<TrainingTask>();
for (int i = 0; i < _numTrees; i++) {
long s = (_seed == -1L) ? -1L : _seed + i;
- tasks.add(new TrainingTask(this, i, attributes, x, y, numInputVars, order, prediction,
- s, remainingTasks));
+ tasks.add(new TrainingTask(this, i, nominalAttrs, x, y, numInputVars, prediction, s,
+ remainingTasks));
}
MapredContext mapredContext = MapredContextAccessor.get();
@@ -449,8 +460,9 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
}
forwardObjs[4] = new IntWritable(oobErrors);
forwardObjs[5] = new IntWritable(oobTests);
- forward(forwardObjs);
-
+ synchronized (this) {
+ forward(forwardObjs);
+ }
reportProgress(_progressReporter);
incrCounter(_treeBuildTaskCounter, 1);
@@ -465,7 +477,7 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
* Attribute properties.
*/
@Nonnull
- private final AttributeType[] _attributes;
+ private final RoaringBitmap _nominalAttrs;
/**
* Training instances.
*/
@@ -477,12 +489,6 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
@Nonnull
private final int[] _y;
/**
- * The index of training values in ascending order. Note that only numeric attributes will
- * be sorted.
- */
- @Nonnull
- private final ColumnMajorIntMatrix _order;
- /**
* The number of variables to pick up in each node.
*/
private final int _numVars;
@@ -501,15 +507,14 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
private final AtomicInteger _remainingTasks;
TrainingTask(@Nonnull RandomForestClassifierUDTF udtf, int taskId,
- @Nonnull AttributeType[] attributes, @Nonnull Matrix x, @Nonnull int[] y,
- int numVars, @Nonnull ColumnMajorIntMatrix order, @Nonnull IntMatrix prediction,
- long seed, @Nonnull AtomicInteger remainingTasks) {
+ @Nonnull RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull int[] y,
+ int numVars, @Nonnull IntMatrix prediction, long seed,
+ @Nonnull AtomicInteger remainingTasks) {
this._udtf = udtf;
this._taskId = taskId;
- this._attributes = attributes;
+ this._nominalAttrs = nominalAttrs;
this._x = x;
this._y = y;
- this._order = order;
this._numVars = numVars;
this._prediction = prediction;
this._seed = seed;
@@ -525,18 +530,20 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
final int N = _x.numRows();
// Training samples draw with replacement.
- final BitSet sampled = new BitSet(N);
- final int[] bags = sampling(sampled, N, rnd1);
+ final int[] samples = sampling(N, rnd1);
- DecisionTree tree = new DecisionTree(_attributes, _x, _y, _numVars, _udtf._maxDepth,
- _udtf._maxLeafNodes, _udtf._minSamplesSplit, _udtf._minSamplesLeaf, bags, _order,
+ DecisionTree tree = new DecisionTree(_nominalAttrs, _x, _y, _numVars, _udtf._maxDepth,
+ _udtf._maxLeafNodes, _udtf._minSamplesSplit, _udtf._minSamplesLeaf, samples,
_udtf._splitRule, rnd2);
// out-of-bag prediction
int oob = 0;
int correct = 0;
final Vector xProbe = _x.rowVector();
- for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) {
+ for (int i = 0; i < samples.length; i++) {
+ if (samples[i] != 0) {
+ continue;
+ }
oob++;
_x.getRow(i, xProbe);
final int p = tree.predict(xProbe);
@@ -559,22 +566,20 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
}
@Nonnull
- private int[] sampling(@Nonnull final BitSet sampled, final int N, @Nonnull PRNG rnd) {
- return _udtf._stratifiedSampling ? stratifiedSampling(sampled, N, _udtf._subsample, rnd)
- : uniformSampling(sampled, N, _udtf._subsample, rnd);
+ private int[] sampling(final int N, @Nonnull PRNG rnd) {
+ return _udtf._stratifiedSampling ? stratifiedSampling(N, _udtf._subsample, rnd)
+ : uniformSampling(N, _udtf._subsample, rnd);
}
@Nonnull
- private static int[] uniformSampling(@Nonnull final BitSet sampled, final int N,
- final double subsample, final PRNG rnd) {
+ private static int[] uniformSampling(final int N, final double subsample, final PRNG rnd) {
final int size = (int) Math.round(N * subsample);
- final int[] bags = new int[N];
+ final int[] samples = new int[N];
for (int i = 0; i < size; i++) {
int index = rnd.nextInt(N);
- bags[i] = index;
- sampled.set(index);
+ samples[index] += 1;
}
- return bags;
+ return samples;
}
/**
@@ -583,9 +588,8 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
* @link https://en.wikipedia.org/wiki/Stratified_sampling
*/
@Nonnull
- private int[] stratifiedSampling(@Nonnull final BitSet sampled, final int N,
- final double subsample, final PRNG rnd) {
- final IntArrayList bagsList = new IntArrayList(N);
+ private int[] stratifiedSampling(final int N, final double subsample, final PRNG rnd) {
+ final int[] samples = new int[N];
final int k = smile.math.Math.max(_y) + 1;
final IntArrayList cj = new IntArrayList(N / k);
for (int l = 0; l < k; l++) {
@@ -604,14 +608,12 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
for (int j = 0; j < size; j++) {
int xi = rnd.nextInt(nj);
int index = cj.get(xi);
- bagsList.add(index);
- sampled.set(index);
+ samples[index] += 1;
}
cj.clear();
}
- int[] bags = bagsList.toArray(true);
- SmileExtUtils.shuffle(bags, rnd);
- return bags;
+ // SmileExtUtils.shuffle(samples, rnd); // not needed in DecisionTrees
+ return samples;
}
@Nonnull
diff --git a/core/src/main/java/hivemall/smile/data/AttributeType.java b/core/src/main/java/hivemall/smile/data/AttributeType.java
deleted file mode 100644
index 7aa0ef0..0000000
--- a/core/src/main/java/hivemall/smile/data/AttributeType.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.smile.data;
-
-import hivemall.annotations.BackwardCompatibility;
-
-public enum AttributeType {
- NUMERIC((byte) 1), NOMINAL((byte) 2);
-
- private final byte id;
-
- private AttributeType(byte id) {
- this.id = id;
- }
-
- public byte getTypeId() {
- return id;
- }
-
- public static AttributeType resolve(final byte id) {
- final AttributeType type;
- switch (id) {
- case 1:
- type = NUMERIC;
- break;
- case 2:
- type = NOMINAL;
- break;
- default:
- throw new IllegalStateException("Unexpected type: " + id);
- }
- return type;
- }
-
- @BackwardCompatibility
- public static AttributeType resolve(final int id) {
- final AttributeType type;
- switch (id) {
- case 1:
- type = NUMERIC;
- break;
- case 2:
- type = NOMINAL;
- break;
- default:
- throw new IllegalStateException("Unexpected type: " + id);
- }
- return type;
- }
-
-}
diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
index ec2e25d..19cec91 100644
--- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
+++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
@@ -23,24 +23,22 @@ import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.builders.CSRMatrixBuilder;
import hivemall.math.matrix.builders.MatrixBuilder;
import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
-import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
-import hivemall.smile.data.AttributeType;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
import hivemall.utils.codec.Base91;
import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.SerdeUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.RandomUtils;
import java.util.ArrayList;
-import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -73,6 +71,7 @@ import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.Reporter;
+import org.roaringbitmap.RoaringBitmap;
@Description(name = "train_randomforest_regressor",
value = "_FUNC_(array<double|string> features, double target [, string options]) - "
@@ -107,16 +106,16 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
private int _minSamplesSplit;
private int _minSamplesLeaf;
private long _seed;
- private AttributeType[] _attributes;
+ private byte[] _nominalAttrs;
@Nullable
- private Reporter _progressReporter;
+ private transient Reporter _progressReporter;
@Nullable
- private Counter _treeBuildTaskCounter;
+ private transient Counter _treeBuildTaskCounter;
@Nullable
- private Counter _treeConstructionTimeCounter;
+ private transient Counter _treeConstructionTimeCounter;
@Nullable
- private Counter _treeSerializationTimeCounter;
+ private transient Counter _treeSerializationTimeCounter;
@Override
protected Options getOptions() {
@@ -130,6 +129,9 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
"The maximum number of the tree depth [default: Integer.MAX_VALUE]");
opts.addOption("leafs", "max_leaf_nodes", true,
"The maximum number of leaf nodes [default: Integer.MAX_VALUE]");
+ opts.addOption("min_samples_split", true,
+ "A node that has greater than or equals to `min_split` examples will split [default: 5]");
+ // synonym of min_samples_split
opts.addOption("split", "min_split", true,
"A node that has greater than or equals to `min_split` examples will split [default: 5]");
opts.addOption("min_samples_leaf", true,
@@ -137,15 +139,17 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
+ opts.addOption("nominal_attr_indicies", "categorical_attr_indicies", true,
+ "Comma seperated indicies of categorical attributes, e.g., [3,5,6]");
return opts;
}
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
int trees = 50, maxDepth = Integer.MAX_VALUE;
- int maxLeafs = Integer.MAX_VALUE, minSplit = 5, minSamplesLeaf = 1;
+ int maxLeafNodes = Integer.MAX_VALUE, minSamplesSplit = 5, minSamplesLeaf = 1;
float numVars = -1.f;
- AttributeType[] attrs = null;
+ RoaringBitmap attrs = new RoaringBitmap();
long seed = -1L;
CommandLine cl = null;
@@ -159,22 +163,33 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
}
numVars = Primitives.parseFloat(cl.getOptionValue("num_variables"), numVars);
maxDepth = Primitives.parseInt(cl.getOptionValue("max_depth"), maxDepth);
- maxLeafs = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), maxLeafs);
- minSplit = Primitives.parseInt(cl.getOptionValue("min_split"), minSplit);
+ maxLeafNodes = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), maxLeafNodes);
+ String min_samples_split = cl.getOptionValue("min_samples_split");
+ if (min_samples_split == null) {
+ minSamplesSplit =
+ Primitives.parseInt(cl.getOptionValue("min_split"), minSamplesSplit);
+ } else {
+ minSamplesSplit = Integer.parseInt(min_samples_split);
+ }
minSamplesLeaf =
Primitives.parseInt(cl.getOptionValue("min_samples_leaf"), minSamplesLeaf);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
- attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
+ String nominal_attr_indicies = cl.getOptionValue("nominal_attr_indicies");
+ if (nominal_attr_indicies != null) {
+ attrs = SmileExtUtils.parseNominalAttributeIndicies(nominal_attr_indicies);
+ } else {
+ attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
+ }
}
this._numTrees = trees;
this._numVars = numVars;
this._maxDepth = maxDepth;
- this._maxLeafNodes = maxLeafs;
- this._minSamplesSplit = minSplit;
+ this._maxLeafNodes = maxLeafNodes;
+ this._minSamplesSplit = minSamplesSplit;
this._minSamplesLeaf = minSamplesLeaf;
this._seed = seed;
- this._attributes = attrs;
+ this._nominalAttrs = SerdeUtils.serializeRoaring(attrs);
return cl;
}
@@ -299,7 +314,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
this.featureListOI = null;
this.featureElemOI = null;
this.targetOI = null;
- this._attributes = null;
+ this._nominalAttrs = null;
}
private void checkOptions() throws HiveException {
@@ -330,7 +345,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
// Shuffle training samples
x = SmileExtUtils.shuffle(x, y, _seed);
- AttributeType[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
int numInputVars = SmileExtUtils.computeNumInputVars(_numVars, x);
if (logger.isInfoEnabled()) {
@@ -340,15 +354,16 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
+ ", seed: " + _seed);
}
- double[] prediction = new double[numExamples]; // placeholder for out-of-bag prediction
- int[] oob = new int[numExamples];
- ColumnMajorIntMatrix order = SmileExtUtils.sort(attributes, x);
- AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
+ final RoaringBitmap nominalAttrs = SerdeUtils.deserializeRoaring(_nominalAttrs);
+ this._nominalAttrs = null;
+ final double[] prediction = new double[numExamples]; // placeholder for out-of-bag prediction
+ final int[] oob = new int[numExamples];
+ final AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
List<TrainingTask> tasks = new ArrayList<TrainingTask>();
for (int i = 0; i < _numTrees; i++) {
long s = (_seed == -1L) ? -1L : _seed + i;
- tasks.add(new TrainingTask(this, i, attributes, x, y, numInputVars, order, prediction,
- oob, s, remainingTasks));
+ tasks.add(new TrainingTask(this, i, nominalAttrs, x, y, numInputVars, prediction, oob,
+ s, remainingTasks));
}
MapredContext mapredContext = MapredContextAccessor.get();
@@ -417,7 +432,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
/**
* Attribute properties.
*/
- private final AttributeType[] _attributes;
+ private final RoaringBitmap _nominalAttrs;
/**
* Training instances.
*/
@@ -427,11 +442,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
*/
private final double[] _y;
/**
- * The index of training values in ascending order. Note that only numeric attributes will
- * be sorted.
- */
- private final ColumnMajorIntMatrix _order;
- /**
* The number of variables to pick up in each node.
*/
private final int _numVars;
@@ -449,15 +459,14 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
private final long _seed;
private final AtomicInteger _remainingTasks;
- TrainingTask(RandomForestRegressionUDTF udtf, int taskId, AttributeType[] attributes,
- Matrix x, double[] y, int numVars, ColumnMajorIntMatrix order, double[] prediction,
- int[] oob, long seed, AtomicInteger remainingTasks) {
+ TrainingTask(RandomForestRegressionUDTF udtf, int taskId, RoaringBitmap nominalAttrs,
+ Matrix x, double[] y, int numVars, double[] prediction, int[] oob, long seed,
+ AtomicInteger remainingTasks) {
this._udtf = udtf;
this._taskId = taskId;
- this._attributes = attributes;
+ this._nominalAttrs = nominalAttrs;
this._x = x;
this._y = y;
- this._order = order;
this._numVars = numVars;
this._prediction = prediction;
this._oob = oob;
@@ -474,25 +483,26 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
final int N = _x.numRows();
// Training samples draw with replacement.
- final int[] bags = new int[N];
- final BitSet sampled = new BitSet(N);
+ final int[] samples = new int[N];
for (int i = 0; i < N; i++) {
int index = rnd1.nextInt(N);
- bags[i] = index;
- sampled.set(index);
+ samples[index] += 1;
}
StopWatch stopwatch = new StopWatch();
- RegressionTree tree = new RegressionTree(_attributes, _x, _y, _numVars, _udtf._maxDepth,
- _udtf._maxLeafNodes, _udtf._minSamplesSplit, _udtf._minSamplesLeaf, _order, bags,
- rnd2);
+ RegressionTree tree = new RegressionTree(_nominalAttrs, _x, _y, _numVars,
+ _udtf._maxDepth, _udtf._maxLeafNodes, _udtf._minSamplesSplit, _udtf._minSamplesLeaf,
+ samples, rnd2);
incrCounter(_udtf._treeConstructionTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));
// out-of-bag prediction
int oob = 0;
double error = 0.d;
final Vector xProbe = _x.rowVector();
- for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) {
+ for (int i = 0; i < samples.length; i++) {
+ if (samples[i] != 0) {
+ continue;
+ }
oob++;
_x.getRow(i, xProbe);
final double pred = tree.predict(xProbe);
diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
index 0e42094..764c352 100755
--- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
+++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -17,26 +17,29 @@
// https://github.com/haifengl/smile/blob/master/core/src/main/java/smile/regression/RegressionTree.java
package hivemall.smile.regression;
+import static hivemall.smile.utils.SmileExtUtils.NOMINAL;
+import static hivemall.smile.utils.SmileExtUtils.NUMERIC;
import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName;
import hivemall.annotations.VisibleForTesting;
import hivemall.math.matrix.Matrix;
-import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.DenseVector;
import hivemall.math.vector.SparseVector;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
-import hivemall.smile.data.AttributeType;
import hivemall.smile.utils.SmileExtUtils;
+import hivemall.smile.utils.VariableOrder;
+import hivemall.utils.collections.arrays.SparseIntArray;
import hivemall.utils.collections.lists.IntArrayList;
-import hivemall.utils.collections.sets.IntArraySet;
-import hivemall.utils.collections.sets.IntSet;
+import hivemall.utils.function.Consumer;
+import hivemall.utils.function.IntPredicate;
+import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.lang.ObjectUtils;
import hivemall.utils.lang.StringUtils;
import hivemall.utils.lang.mutable.MutableInt;
-import hivemall.utils.math.MathUtils;
+import hivemall.utils.sampling.IntReservoirSampler;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap.Entry;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
@@ -50,14 +53,17 @@ import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.roaringbitmap.IntConsumer;
+import org.roaringbitmap.RoaringBitmap;
/**
* Decision tree for regression. A decision tree can be learned by splitting the training set into
@@ -99,11 +105,40 @@ import org.apache.hadoop.hive.ql.metadata.HiveException;
* @see RandomForest
*/
public final class RegressionTree implements Regression<Vector> {
+ private static final Log logger = LogFactory.getLog(RegressionTree.class);
+
+ /**
+ * Training dataset.
+ */
+ private final Matrix _X;
+ /**
+ * Training data response value.
+ */
+ private final double[] _y;
+ /**
+ * The samples for training this node. Note that samples[i] is the number of sampling of
+ * dataset[i]. 0 means that the datum is not included and values of greater than 1 are possible
+ * because of sampling with replacement.
+ */
+ @Nonnull
+ private final int[] _samples;
+ /**
+ * The index of training values in ascending order. Note that only numeric attributes will be
+ * sorted.
+ */
+ @Nonnull
+ private final VariableOrder _order;
+ /**
+ * An index that maps their current position in the {@link #_order} to their original locations
+ * in {@link #_samples}.
+ */
+ @Nonnull
+ private final int[] _sampleIndex;
/**
* The attributes of independent variable.
*/
- private final AttributeType[] _attributes;
- private final boolean _hasNumericType;
+ @Nonnull
+ private final RoaringBitmap _nominalAttrs;
/**
* Variable importance. Every time a split of a node is made on variable the impurity criterion
* for the two descendant nodes is less than the parent node. Adding up the decreases for each
@@ -122,25 +157,20 @@ public final class RegressionTree implements Regression<Vector> {
* The number of instances in a node below which the tree will not split, setting S = 5
* generally gives good results.
*/
- private final int _minSplit;
+ private final int _minSamplesSplit;
/**
* The minimum number of samples in a leaf node
*/
- private final int _minLeafSize;
+ private final int _minSamplesLeaf;
/**
* The number of input variables to be used to determine the decision at a node of the tree.
*/
private final int _numVars;
/**
- * The index of training values in ascending order. Note that only numeric attributes will be
- * sorted.
+ * The random number generator.
*/
- private final ColumnMajorIntMatrix _order;
-
private final PRNG _rnd;
- private final NodeOutput _nodeOutput;
-
/**
* An interface to calculate node output. Note that samples[i] is the number of sampling of
* dataset[i]. 0 means that the datum is not included and values of greater than 1 are possible
@@ -153,7 +183,7 @@ public final class RegressionTree implements Regression<Vector> {
* @param samples the samples in the node.
* @return the node output
*/
- public double calculate(int[] samples);
+ double calculate(int[] samples);
}
/**
@@ -172,7 +202,7 @@ public final class RegressionTree implements Regression<Vector> {
/**
* The type of split feature
*/
- AttributeType splitFeatureType = null;
+ boolean quantitativeFeature = true;
/**
* The split value.
*/
@@ -208,6 +238,14 @@ public final class RegressionTree implements Regression<Vector> {
return trueChild == null && falseChild == null;
}
+ private void markAsLeaf() {
+ this.splitFeature = -1;
+ this.splitValue = Double.NaN;
+ this.splitScore = 0.0;
+ this.trueChild = null;
+ this.falseChild = null;
+ }
+
@VisibleForTesting
public double predict(@Nonnull final double[] x) {
return predict(new DenseVector(x));
@@ -217,24 +255,21 @@ public final class RegressionTree implements Regression<Vector> {
* Evaluate the regression tree over an instance.
*/
public double predict(@Nonnull final Vector x) {
- if (trueChild == null && falseChild == null) {
+ if (isLeaf()) {
return output;
} else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- if (x.get(splitFeature, Double.NaN) == splitValue) {
+ if (quantitativeFeature) {
+ if (x.get(splitFeature, Double.NaN) <= splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
- } else if (splitFeatureType == AttributeType.NUMERIC) {
- if (x.get(splitFeature, Double.NaN) <= splitValue) {
+ } else {
+ if (x.get(splitFeature, Double.NaN) == splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
- } else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
}
}
}
@@ -243,7 +278,7 @@ public final class RegressionTree implements Regression<Vector> {
* Evaluate the regression tree over an instance.
*/
public double predict(final int[] x) {
- if (trueChild == null && falseChild == null) {
+ if (isLeaf()) {
return output;
} else if (x[splitFeature] == (int) splitValue) {
return trueChild.predict(x);
@@ -254,22 +289,22 @@ public final class RegressionTree implements Regression<Vector> {
public void exportJavascript(@Nonnull final StringBuilder builder,
@Nullable final String[] featureNames, final int depth) {
- if (trueChild == null && falseChild == null) {
+ if (isLeaf()) {
indent(builder, depth);
builder.append(output).append(";\n");
} else {
- if (splitFeatureType == AttributeType.NOMINAL) {
+ if (quantitativeFeature) {
indent(builder, depth);
if (featureNames == null) {
builder.append("if( x[")
.append(splitFeature)
- .append("] == ")
+ .append("] <= ")
.append(splitValue)
.append(") {\n");
} else {
builder.append("if( ")
.append(resolveFeatureName(splitFeature, featureNames))
- .append(" == ")
+ .append(" <= ")
.append(splitValue)
.append(") {\n");
}
@@ -279,18 +314,18 @@ public final class RegressionTree implements Regression<Vector> {
falseChild.exportJavascript(builder, featureNames, depth + 1);
indent(builder, depth);
builder.append("}\n");
- } else if (splitFeatureType == AttributeType.NUMERIC) {
+ } else {
indent(builder, depth);
if (featureNames == null) {
builder.append("if( x[")
.append(splitFeature)
- .append("] <= ")
+ .append("] == ")
.append(splitValue)
.append(") {\n");
} else {
builder.append("if( ")
.append(resolveFeatureName(splitFeature, featureNames))
- .append(" <= ")
+ .append(" == ")
.append(splitValue)
.append(") {\n");
}
@@ -300,9 +335,6 @@ public final class RegressionTree implements Regression<Vector> {
falseChild.exportJavascript(builder, featureNames, depth + 1);
indent(builder, depth);
builder.append("}\n");
- } else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
}
}
}
@@ -312,7 +344,7 @@ public final class RegressionTree implements Regression<Vector> {
final @Nonnull MutableInt nodeIdGenerator, final int parentNodeId) {
final int myNodeId = nodeIdGenerator.getValue();
- if (trueChild == null && falseChild == null) {
+ if (isLeaf()) {
builder.append(String.format(
" %d [label=<%s = %s>, fillcolor=\"#00000000\", shape=ellipse];\n", myNodeId,
outputName, Double.toString(output)));
@@ -331,19 +363,16 @@ public final class RegressionTree implements Regression<Vector> {
builder.append(";\n");
}
} else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- builder.append(
- String.format(" %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n", myNodeId,
- resolveFeatureName(splitFeature, featureNames),
- Double.toString(splitValue)));
- } else if (splitFeatureType == AttributeType.NUMERIC) {
+ if (quantitativeFeature) {
builder.append(
String.format(" %d [label=<%s ≤ %s>, fillcolor=\"#00000000\"];\n",
myNodeId, resolveFeatureName(splitFeature, featureNames),
Double.toString(splitValue)));
} else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
+ builder.append(
+ String.format(" %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n", myNodeId,
+ resolveFeatureName(splitFeature, featureNames),
+ Double.toString(splitValue)));
}
if (myNodeId != parentNodeId) {
@@ -373,7 +402,7 @@ public final class RegressionTree implements Regression<Vector> {
public int opCodegen(@Nonnull final List<String> scripts, int depth) {
int selfDepth = 0;
final StringBuilder buf = new StringBuilder();
- if (trueChild == null && falseChild == null) {
+ if (isLeaf()) {
buf.append("push ").append(output);
scripts.add(buf.toString());
buf.setLength(0);
@@ -381,41 +410,38 @@ public final class RegressionTree implements Regression<Vector> {
scripts.add(buf.toString());
selfDepth += 2;
} else {
- if (splitFeatureType == AttributeType.NOMINAL) {
+ if (quantitativeFeature) {
buf.append("push ").append("x[").append(splitFeature).append("]");
scripts.add(buf.toString());
buf.setLength(0);
buf.append("push ").append(splitValue);
scripts.add(buf.toString());
buf.setLength(0);
- buf.append("ifeq ");
+ buf.append("ifle ");
scripts.add(buf.toString());
depth += 3;
selfDepth += 3;
int trueDepth = trueChild.opCodegen(scripts, depth);
selfDepth += trueDepth;
- scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth));
+ scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth));
int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
selfDepth += falseDepth;
- } else if (splitFeatureType == AttributeType.NUMERIC) {
+ } else {
buf.append("push ").append("x[").append(splitFeature).append("]");
scripts.add(buf.toString());
buf.setLength(0);
buf.append("push ").append(splitValue);
scripts.add(buf.toString());
buf.setLength(0);
- buf.append("ifle ");
+ buf.append("ifeq ");
scripts.add(buf.toString());
depth += 3;
selfDepth += 3;
int trueDepth = trueChild.opCodegen(scripts, depth);
selfDepth += trueDepth;
- scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth));
+ scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth));
int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
selfDepth += falseDepth;
- } else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
}
}
return selfDepth;
@@ -424,11 +450,7 @@ public final class RegressionTree implements Regression<Vector> {
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(splitFeature);
- if (splitFeatureType == null) {
- out.writeByte(-1);
- } else {
- out.writeByte(splitFeatureType.getTypeId());
- }
+ out.writeByte(quantitativeFeature ? NUMERIC : NOMINAL);
out.writeDouble(splitValue);
if (isLeaf()) {
@@ -455,11 +477,7 @@ public final class RegressionTree implements Regression<Vector> {
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
this.splitFeature = in.readInt();
byte typeId = in.readByte();
- if (typeId == -1) {
- this.splitFeatureType = null;
- } else {
- this.splitFeatureType = AttributeType.resolve(typeId);
- }
+ this.quantitativeFeature = (typeId == NUMERIC);
this.splitValue = in.readDouble();
if (in.readBoolean()) {// isLeaf()
@@ -491,37 +509,54 @@ public final class RegressionTree implements Regression<Vector> {
/**
* The associated regression tree node.
*/
+ @Nonnull
final Node node;
/**
- * Child node that passes the test.
+ * Depth of the node in the tree
*/
- TrainNode trueChild;
+ final int depth;
/**
- * Child node that fails the test.
+ * The lower bound (inclusive) in the order array of the samples belonging to this node.
*/
- TrainNode falseChild;
+ final int low;
/**
- * Training dataset.
+ * The upper bound (exclusive) in the order array of the samples belonging to this node.
+ */
+ final int high;
+ /**
+ * The number of samples
+ */
+ final int samples;
+ /**
+ * Child node that passes the test.
*/
- final Matrix x;
+ @Nullable
+ TrainNode trueChild;
/**
- * Training data response value.
+ * Child node that fails the test.
*/
- final double[] y;
+ @Nullable
+ TrainNode falseChild;
- int[] bags;
+ @Nullable
+ int[] constFeatures;
- final int depth;
+ public TrainNode(@Nonnull Node node, int depth, int low, int high, int samples) {
+ this(node, depth, low, high, samples, new int[0]);
+ }
- /**
- * Constructor.
- */
- public TrainNode(Node node, Matrix x, double[] y, int[] bags, int depth) {
+ public TrainNode(@Nonnull Node node, int depth, int low, int high, int samples,
+ @Nonnull int[] constFeatures) {
+ if (low >= high) {
+ throw new IllegalArgumentException(
+ "Unexpected condition was met. low=" + low + ", high=" + high);
+ }
this.node = node;
- this.x = x;
- this.y = y;
- this.bags = bags;
this.depth = depth;
+ this.low = low;
+ this.high = high;
+ this.samples = samples;
+ this.constFeatures = constFeatures;
}
@Override
@@ -536,7 +571,7 @@ public final class RegressionTree implements Regression<Vector> {
*/
public void calculateOutput(final NodeOutput output) {
if (node.trueChild == null && node.falseChild == null) {
- int[] samples = SmileExtUtils.bagsToSamples(bags);
+ int[] samples = getSamples();
node.output = output.calculate(samples);
} else {
if (trueChild != null) {
@@ -548,6 +583,24 @@ public final class RegressionTree implements Regression<Vector> {
}
}
+ @Nonnull
+ private int[] getSamples() {
+ int size = high - low;
+ final IntArrayList result = new IntArrayList(size);
+
+ final int[] sampleIndex = _sampleIndex;
+ final int[] samples = _samples;
+ for (int i = low, end = high; i < end; i++) {
+ int index = sampleIndex[i];
+ int sample = samples[index];
+ if (sample > 0) {
+ result.add(index);
+ }
+ }
+
+ return result.toArray(true);
+ }
+
/**
* Finds the best attribute to split on at the current node. Returns true if a split exists
* to reduce squared error, false otherwise.
@@ -558,23 +611,23 @@ public final class RegressionTree implements Regression<Vector> {
return false;
}
// avoid split if the number of samples is less than threshold
- final int numSamples = bags.length;
- if (numSamples <= _minSplit) {
+ if (samples <= _minSamplesSplit) {
return false;
}
- final double sum = node.output * numSamples;
-
+ final int[] constFeatures_ = this.constFeatures;
// Loop through features and compute the reduction of squared error,
// which is trueCount * trueMean^2 + falseCount * falseMean^2 - count * parentMean^2
- final int[] samples =
- _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.numRows()) : null;
- for (int varJ : variableIndex(x, bags)) {
- final Node split = findBestSplit(numSamples, sum, varJ, samples);
+ final double sum = node.output * samples;
+ for (int varJ : variableIndex()) {
+ if (ArrayUtils.contains(constFeatures_, varJ)) {
+ continue;
+ }
+ final Node split = findBestSplit(samples, sum, varJ);
if (split.splitScore > node.splitScore) {
node.splitFeature = split.splitFeature;
- node.splitFeatureType = split.splitFeatureType;
+ node.quantitativeFeature = split.quantitativeFeature;
node.splitValue = split.splitValue;
node.splitScore = split.splitScore;
node.trueChildOutput = split.trueChildOutput;
@@ -585,29 +638,35 @@ public final class RegressionTree implements Regression<Vector> {
return node.splitFeature != -1;
}
- private int[] variableIndex(@Nonnull final Matrix x, @Nonnull final int[] bags) {
- final int[] variableIndex;
- if (x.isSparse()) {
- final IntSet cols = new IntArraySet(_numVars);
+ @Nonnull
+ private int[] variableIndex() {
+ final Matrix X = _X;
+ final IntReservoirSampler sampler = new IntReservoirSampler(_numVars, _rnd.nextLong());
+ if (X.isSparse()) {
+ // sample columns from sampled examples
+ final RoaringBitmap cols = new RoaringBitmap();
final VectorProcedure proc = new VectorProcedure() {
- public void apply(int col, double value) {
+ public void apply(final int col) {
cols.add(col);
}
};
- for (final int row : bags) {
- x.eachNonNullInRow(row, proc);
+ final int[] sampleIndex = _sampleIndex;
+ for (int i = low, end = high; i < end; i++) {
+ int row = sampleIndex[i];
+ X.eachColumnIndexInRow(row, proc);
}
- variableIndex = cols.toArray(false);
+ cols.forEach(new IntConsumer() {
+ public void accept(final int k) {
+ sampler.add(k);
+ }
+ });
} else {
- variableIndex = MathUtils.permutation(_attributes.length);
- }
-
- if (_numVars < variableIndex.length) {
- SmileExtUtils.shuffle(variableIndex, _rnd);
- return Arrays.copyOf(variableIndex, _numVars);
-
+ final int ncols = X.numColumns();
+ for (int i = 0; i < ncols; i++) {
+ sampler.add(i);
+ }
}
- return variableIndex;
+ return sampler.getSample();
}
/**
@@ -618,29 +677,42 @@ public final class RegressionTree implements Regression<Vector> {
         * @param impurity the impurity of this node.
         * @param j the attribute to split on.
         */
-        private Node findBestSplit(final int n, final double sum, final int j,
-                @Nullable final int[] samples) {
+        private Node findBestSplit(final int n, final double sum, final int j) {
+            final int[] samples = _samples;
+            final int[] sampleIndex = _sampleIndex;
+            final Matrix X = _X;
+            final double[] y = _y;
+
            final Node split = new Node(0.d);
-            if (_attributes[j] == AttributeType.NOMINAL) {
-                //final int m = _attributes[j].getSize();
-                //final double[] trueSum = new double[m];
-                //final int[] trueCount = new int[m];
+
+            if (_nominalAttrs.contains(j)) {// nominal
                final Int2DoubleOpenHashMap trueSum = new Int2DoubleOpenHashMap();
                final Int2IntOpenHashMap trueCount = new Int2IntOpenHashMap();
-                for (int b = 0, size = bags.length; b < size; b++) {
-                    int i = bags[b];
+                int countNaN = 0;
+                for (int i = low, end = high; i < end; i++) {
+                    final int index = sampleIndex[i];
+                    final int numSamples = samples[index];
+                    if (numSamples == 0) {
+                        continue;
+                    }
+
                    // For each true feature of this datum increment the
                    // sufficient statistics for the "true" branch to evaluate
                    // splitting on this feature.
-                    final double v = x.get(i, j, Double.NaN);
+                    final double v = X.get(index, j, Double.NaN); // i is a position; index is the row
                    if (Double.isNaN(v)) {
+                        countNaN++;
                        continue;
                    }
-                    int index = (int) v;
+                    int x_ij = (int) v;
-                    trueSum.addTo(index, y[i]);
-                    trueCount.addTo(index, 1);
+                    trueSum.addTo(x_ij, numSamples * y[index]); // weighted like `sum`
+                    trueCount.addTo(x_ij, numSamples); // weighted like `n`, so n - tc is valid
+                }
+                final int countDistinctX = trueCount.size() + (countNaN == 0 ? 0 : 1);
+                if (countDistinctX <= 1) { // mark as a constant feature
+                    this.constFeatures = ArrayUtils.sortedArraySet(constFeatures, j);
                }
                for (Entry e : trueCount.int2IntEntrySet()) {
@@ -650,7 +722,7 @@ public final class RegressionTree implements Regression<Vector> {
final double fc = n - tc;
// skip splitting
- if (tc < _minSplit || fc < _minSplit) {
+ if (tc < _minSamplesSplit || fc < _minSamplesSplit) {
continue;
}
@@ -664,45 +736,53 @@ public final class RegressionTree implements Regression<Vector> {
if (gain > split.splitScore) {
// new best split
split.splitFeature = j;
- split.splitFeatureType = AttributeType.NOMINAL;
+ split.quantitativeFeature = false;
split.splitValue = k;
split.splitScore = gain;
split.trueChildOutput = trueMean;
split.falseChildOutput = falseMean;
}
}
- } else if (_attributes[j] == AttributeType.NUMERIC) {
+ } else {
+ final MutableInt countNaN = new MutableInt(0);
+ final MutableInt replaceCount = new MutableInt(0);
- _order.eachNonNullInColumn(j, new VectorProcedure() {
+ _order.eachNonNullInColumn(j, low, high, new Consumer() {
double trueSum = 0.0;
int trueCount = 0;
- double prevx = Double.NaN;
+ double prevx = Double.NaN, lastx = Double.NaN;
- public void apply(final int row, final int i) {
- final int sample = samples[i];
- if (sample == 0) {
+ public void accept(int pos, final int i) {
+ final int numSamples = samples[i];
+ if (numSamples == 0) {
return;
}
- final double x_ij = x.get(i, j, Double.NaN);
+
+ final double x_ij = _X.get(i, j, Double.NaN);
if (Double.isNaN(x_ij)) {
+ countNaN.incr();
return;
}
- final double y_i = y[i];
+ if (lastx != x_ij) {
+ lastx = x_ij;
+ replaceCount.incr();
+ }
+ final double y_i = _y[i];
if (Double.isNaN(prevx) || x_ij == prevx) {
prevx = x_ij;
- trueSum += sample * y_i;
- trueCount += sample;
+ trueSum += numSamples * y_i;
+ trueCount += numSamples;
return;
}
final double falseCount = n - trueCount;
// If either side is empty, skip this feature.
- if (trueCount < _minSplit || falseCount < _minSplit) {
+ if (trueCount < _minSamplesSplit || falseCount < _minSamplesSplit) {
prevx = x_ij;
- trueSum += sample * y_i;
- trueCount += sample;
+ trueSum += numSamples * y_i;
+ trueCount += numSamples;
return;
}
@@ -719,7 +799,7 @@ public final class RegressionTree implements Regression<Vector> {
if (gain > split.splitScore) {
// new best split
split.splitFeature = j;
- split.splitFeatureType = AttributeType.NUMERIC;
+ split.quantitativeFeature = true;
split.splitValue = (x_ij + prevx) / 2;
split.splitScore = gain;
split.trueChildOutput = trueMean;
@@ -727,13 +807,15 @@ public final class RegressionTree implements Regression<Vector> {
}
prevx = x_ij;
- trueSum += sample * y_i;
- trueCount += sample;
+ trueSum += numSamples * y_i;
+ trueCount += numSamples;
}//apply
});
- } else {
- throw new IllegalStateException("Unsupported attribute type: " + _attributes[j]);
+ final int countDistinctX = replaceCount.get() + (countNaN.get() == 0 ? 0 : 1);
+ if (countDistinctX <= 1) { // mark as a constant feature
+ this.constFeatures = ArrayUtils.sortedArraySet(constFeatures, j);
+ }
}
return split;
@@ -742,51 +824,69 @@ public final class RegressionTree implements Regression<Vector> {
/**
* Split the node into two children nodes. Returns true if split success.
*/
- public boolean split(final PriorityQueue<TrainNode> nextSplits) {
+ public boolean split(@Nullable final PriorityQueue<TrainNode> nextSplits) {
if (node.splitFeature < 0) {
throw new IllegalStateException("Split a node with invalid feature.");
}
- // split sample bags
- int childBagSize = (int) (bags.length * 0.4);
- IntArrayList trueBags = new IntArrayList(childBagSize);
- IntArrayList falseBags = new IntArrayList(childBagSize);
- int tc = splitSamples(trueBags, falseBags);
- int fc = bags.length - tc;
-
- if (tc < _minLeafSize || fc < _minLeafSize) {
- // set as a leaf node
- node.splitFeature = -1;
- node.splitFeatureType = null;
- node.splitValue = Double.NaN;
- node.splitScore = 0.0;
- if (_nodeOutput == null) {
- this.bags = null;
- }
+ final IntPredicate goesLeft = getPredicate();
+
+ // split samples
+ final int tc, fc, pivot;
+ {
+ MutableInt tc_ = new MutableInt(0);
+ MutableInt fc_ = new MutableInt(0);
+ pivot = splitSamples(tc_, fc_, goesLeft);
+ tc = tc_.get();
+ fc = fc_.get();
+ }
+
+ if (tc < _minSamplesLeaf || fc < _minSamplesLeaf) {
+ node.markAsLeaf();
return false;
}
- this.bags = null; // help GC for recursive call
+ partitionOrder(low, pivot, high, goesLeft);
+
+ int leaves = 0;
node.trueChild = new Node(node.trueChildOutput);
- this.trueChild = new TrainNode(node.trueChild, x, y, trueBags.toArray(), depth + 1);
- trueBags = null; // help GC for recursive call
- if (tc >= _minSplit && trueChild.findBestSplit()) {
+ this.trueChild =
+ new TrainNode(node.trueChild, depth + 1, low, pivot, tc, constFeatures.clone());
+ node.falseChild = new Node(node.falseChildOutput);
+ this.falseChild =
+ new TrainNode(node.falseChild, depth + 1, pivot, high, fc, constFeatures);
+ this.constFeatures = null;
+
+ if (tc >= _minSamplesSplit && trueChild.findBestSplit()) {
if (nextSplits != null) {
nextSplits.add(trueChild);
} else {
- trueChild.split(null);
+ if (trueChild.split(null) == false) {
+ leaves++;
+ }
}
+ } else {
+ leaves++;
}
- node.falseChild = new Node(node.falseChildOutput);
- this.falseChild = new TrainNode(node.falseChild, x, y, falseBags.toArray(), depth + 1);
- falseBags = null; // help GC for recursive call
- if (fc >= _minSplit && falseChild.findBestSplit()) {
+ if (fc >= _minSamplesSplit && falseChild.findBestSplit()) {
if (nextSplits != null) {
nextSplits.add(falseChild);
} else {
- falseChild.split(null);
+ if (falseChild.split(null) == false) {
+ leaves++;
+ }
+ }
+ } else {
+ leaves++;
+ }
+
+ // Prune meaningless branches
+ if (leaves == 2) {// both left and right child is leaf node
+ if (node.trueChild.output == node.falseChild.output) {// found meaningless branch
+ node.markAsLeaf();
+ return false;
}
}
@@ -796,120 +896,262 @@ public final class RegressionTree implements Regression<Vector> {
}
/**
- * @return the number of true samples
+ * @return Pivot to split samples
*/
- private int splitSamples(@Nonnull final IntArrayList trueBags,
- @Nonnull final IntArrayList falseBags) {
- int tc = 0;
- if (node.splitFeatureType == AttributeType.NOMINAL) {
- final int splitFeature = node.splitFeature;
- final double splitValue = node.splitValue;
- for (int i = 0, size = bags.length; i < size; i++) {
- final int index = bags[i];
- if (x.get(index, splitFeature, Double.NaN) == splitValue) {
- trueBags.add(index);
- tc++;
- } else {
- falseBags.add(index);
- }
+ private int splitSamples(@Nonnull final MutableInt tc, @Nonnull final MutableInt fc,
+ @Nonnull final IntPredicate goesLeft) {
+ final int[] sampleIndex = _sampleIndex;
+ final int[] samples = _samples;
+
+ int pivot = low;
+ for (int k = low, end = high; k < end; k++) {
+ final int i = sampleIndex[k];
+ final int numSamples = samples[i];
+ if (goesLeft.test(i)) {
+ tc.addValue(numSamples);
+ pivot++;
+ } else {
+ fc.addValue(numSamples);
}
- } else if (node.splitFeatureType == AttributeType.NUMERIC) {
- final int splitFeature = node.splitFeature;
- final double splitValue = node.splitValue;
- for (int i = 0, size = bags.length; i < size; i++) {
- final int index = bags[i];
- if (x.get(index, splitFeature, Double.NaN) <= splitValue) {
- trueBags.add(index);
- tc++;
- } else {
- falseBags.add(index);
+ }
+ return pivot;
+ }
+
+        /**
+         * Modifies {@link #_order} and {@link #_sampleIndex} by partitioning the range from low
+         * (inclusive) to high (exclusive) so that all elements i for which goesLeft(i) is true come
+         * before all elements for which it is false, but element ordering is otherwise preserved.
+         * The number of true values returned by goesLeft must equal pivot - low.
+         *
+         * @param low the low bound of the segment of the order arrays which will be partitioned.
+         * @param pivot where the partition's split point will end up.
+         * @param high the high bound of the segment of the order arrays which will be partitioned.
+         * @param goesLeft whether an element goes to the left side or the right side of the
+         *        partition.
+         */
+        private void partitionOrder(final int low, final int pivot, final int high,
+                @Nonnull final IntPredicate goesLeft) {
+            final int[] buf = new int[high - pivot]; // scratch for right-side elements, reused per row
+            _order.eachRow(new Consumer() {
+                @Override
+                public void accept(int col, @Nonnull final SparseIntArray row) {
+                    partitionArray(row, low, pivot, high, goesLeft, buf);
+                }
+            });
+            partitionArray(_sampleIndex, low, pivot, high, goesLeft, buf);
+        }
+
+ @Nonnull
+ private IntPredicate getPredicate() {
+ if (node.quantitativeFeature) {
+ return new IntPredicate() {
+ @Override
+ public boolean test(int i) {
+ return _X.get(i, node.splitFeature, Double.NaN) <= node.splitValue;
}
+ };
+ } else {
+ return new IntPredicate() {
+ @Override
+ public boolean test(int i) {
+ return _X.get(i, node.splitFeature, Double.NaN) == node.splitValue;
+ }
+ };
+ }
+ }
+
+ }
+
+ private static void partitionArray(@Nonnull final SparseIntArray a, final int low,
+ final int pivot, final int high, @Nonnull final IntPredicate goesLeft,
+ @Nonnull final int[] buf) {
+ final int[] keys = a.keys();
+ final int[] values = a.values();
+ final int size = a.size();
+
+ final int startPos = ArrayUtils.insertionPoint(keys, size, low);
+ final int endPos = ArrayUtils.insertionPoint(keys, size, high);
+ int pos = startPos, k = 0;
+ for (int i = startPos, j = 0; i < endPos; i++) {
+ final int a_i = values[i];
+ if (goesLeft.test(a_i)) {
+ keys[pos] = low + j;
+ values[pos] = a_i;
+ pos++;
+ j++;
+ } else {
+ if (k >= buf.length) {
+ throw new IndexOutOfBoundsException(String.format(
+ "low=%d, pivot=%d, high=%d, a.size()=%d, buf.length=%d, i=%d, j=%d, k=%d",
+ low, pivot, high, a.size(), buf.length, i, j, k));
}
+ buf[k++] = a_i;
+ }
+ }
+ for (int i = 0; i < k; i++) {
+ keys[pos] = pivot + i;
+ values[pos] = buf[i];
+ pos++;
+ }
+ if (pos != endPos) {
+ throw new IllegalStateException(
+ String.format("pos=%d, startPos=%d, endPos=%d, k=%d", pos, startPos, endPos, k));
+ }
+ }
+
+    /**
+     * Modifies an array in-place by partitioning the range from low (inclusive) to high (exclusive)
+     * so that all elements i for which goesLeft(i) is true come before all elements for which it is
+     * false, but element ordering is otherwise preserved. The number of true values returned by
+     * goesLeft must equal pivot-low. buf is scratch space large enough (i.e., at least high-pivot
+     * long) to hold all elements for which goesLeft is false.
+     */
+    private static void partitionArray(@Nonnull final int[] a, final int low, final int pivot,
+            final int high, @Nonnull final IntPredicate goesLeft, @Nonnull final int[] buf) {
+        int j = low;
+        int k = 0;
+        for (int i = low; i < high; i++) {
+            if (i >= a.length) {
+                throw new IndexOutOfBoundsException(String.format(
+                    "low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d", low,
+                    pivot, high, a.length, buf.length, i, j, k));
+            }
+            final int a_i = a[i];
+            if (goesLeft.test(a_i)) {
+                a[j++] = a_i;
            } else {
-                throw new IllegalStateException(
-                    "Unsupported attribute type: " + node.splitFeatureType);
+                if (k >= buf.length) {
+                    throw new IndexOutOfBoundsException(String.format(
+                        "low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d",
+                        low, pivot, high, a.length, buf.length, i, j, k));
+                }
+                buf[k++] = a_i;
            }
-            return tc;
        }
+        if (k != high - pivot || j != pivot) {
+            throw new IndexOutOfBoundsException(
+                String.format("low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, j=%d, k=%d",
+                    low, pivot, high, a.length, buf.length, j, k));
+        }
+        System.arraycopy(buf, 0, a, pivot, k); // right-side elements follow the left side
+    }
+ /**
+ * Prunes redundant leaves from the tree. In some cases, a node is split into two leaves that
+ * get assigned the same label, so this recursively combines leaves when it notices this
+ * situation.
+ */
+ private static void pruneRedundantLeaves(@Nonnull final Node node, @Nonnull Vector importance) {
+ if (node.isLeaf()) {
+ return;
+ }
+
+ // The children might not be leaves now, but might collapse into leaves given the chance.
+ pruneRedundantLeaves(node.trueChild, importance);
+ pruneRedundantLeaves(node.falseChild, importance);
+
+ if (node.trueChild.isLeaf() && node.falseChild.isLeaf()
+ && node.trueChild.output == node.falseChild.output) {
+ node.trueChild = null;
+ node.falseChild = null;
+ importance.decr(node.splitFeature, node.splitScore);
+ }
}
- public RegressionTree(@Nullable AttributeType[] attributes, @Nonnull Matrix x,
+
+ public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x,
@Nonnull double[] y, int maxLeafs) {
- this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, null);
+ this(nominalAttrs, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null);
}
- public RegressionTree(@Nullable AttributeType[] attributes, @Nonnull Matrix x,
+ public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x,
@Nonnull double[] y, int maxLeafs, @Nullable PRNG rand) {
- this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, rand);
+ this(nominalAttrs, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, rand);
}
- public RegressionTree(@Nullable AttributeType[] attributes, @Nonnull Matrix x,
- @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
- int minLeafSize, @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags,
- @Nullable PRNG rand) {
- this(attributes, x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize, order, bags, null, rand);
+ public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x,
+ @Nonnull double[] y, int numVars, int maxDepth, int maxLeafNodes, int minSamplesSplit,
+ int minSamplesLeaf, @Nullable int[] samples, @Nullable PRNG rand) {
+ this(nominalAttrs, x, y, numVars, maxDepth, maxLeafNodes, minSamplesSplit, minSamplesLeaf, samples, null, rand);
}
/**
* Constructor. Learns a regression tree for gradient tree boosting.
*
- * @param attributes the attribute properties.
+ * @param nominalAttrs the attribute properties.
* @param x the training instances.
* @param y the response variable.
* @param numVars the number of input variables to pick to split on at each node. It seems that
* dim/3 give generally good performance, where dim is the number of variables.
- * @param maxLeafs the maximum number of leaf nodes in the tree.
- * @param minSplits number of instances in a node below which the tree will not split, setting S
- * = 5 generally gives good results.
- * @param order the index of training values in ascending order. Note that only numeric
- * attributes need be sorted.
- * @param bags the sample set of instances for stochastic learning.
+ * @param maxLeafNodes the maximum number of leaf nodes in the tree.
+ * @param minSamplesLeaf number of instances in a node below which the tree will not split,
+ * setting 5 generally gives good results.
+ * @param samples the sample set of instances for stochastic learning.
* @param output An interface to calculate node output.
*/
- public RegressionTree(@Nullable AttributeType[] attributes, @Nonnull Matrix x,
- @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
- int minLeafSize, @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags,
- @Nullable NodeOutput output, @Nullable PRNG rand) {
- checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);
-
- this._attributes = SmileExtUtils.attributeTypes(attributes, x);
- if (_attributes.length != x.numColumns()) {
- throw new IllegalArgumentException(
- "-attrs option is invalid: " + Arrays.toString(attributes));
+ public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x,
+ @Nonnull double[] y, int numVars, int maxDepth, int maxLeafNodes, int minSamplesSplit,
+ int minSamplesLeaf, @Nullable int[] samples, @Nullable NodeOutput output,
+ @Nullable PRNG rand) {
+ checkArgument(x, y, numVars, maxDepth, maxLeafNodes, minSamplesSplit, minSamplesLeaf);
+
+ this._X = x;
+ this._y = y;
+
+ if (nominalAttrs == null) {
+ nominalAttrs = new RoaringBitmap();
}
- this._hasNumericType = SmileExtUtils.containsNumericType(_attributes);
+ this._nominalAttrs = nominalAttrs;
this._numVars = numVars;
this._maxDepth = maxDepth;
- this._minSplit = minSplits;
- this._minLeafSize = minLeafSize;
- this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
- this._importance = x.isSparse() ? new SparseVector() : new DenseVector(_attributes.length);
+ // min_sample_leaf >= 2 is satisfied iff min_sample_split >= 4
+ // So, split only happens when samples in intermediate nodes has >= 2 * min_sample_leaf nodes.
+ if (minSamplesSplit < minSamplesLeaf * 2) {
+ if (logger.isInfoEnabled()) {
+ logger.info(String.format(
+ "min_sample_leaf = %d replaces min_sample_split = %d with min_sample_split = %d",
+ minSamplesLeaf, minSamplesSplit, minSamplesLeaf * 2));
+ }
+ minSamplesSplit = minSamplesLeaf * 2;
+ }
+ this._minSamplesSplit = minSamplesSplit;
+ this._minSamplesLeaf = minSamplesLeaf;
+ this._importance = x.isSparse() ? new SparseVector() : new DenseVector(x.numColumns());
this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;
- this._nodeOutput = output;
int n = 0;
double sum = 0.0;
- if (bags == null) {
+ final int[] sampleIndex;
+ if (samples == null) {
n = y.length;
- bags = new int[n];
+ samples = new int[n];
+ sampleIndex = new int[n];
for (int i = 0; i < n; i++) {
- bags[i] = i;
+ samples[i] = 1;
sum += y[i];
+ sampleIndex[i] = i;
}
} else {
- n = bags.length;
- for (int i = 0; i < n; i++) {
- int index = bags[i];
- sum += y[index];
+ final IntArrayList positions = new IntArrayList(n);
+ for (int i = 0, end = y.length; i < end; i++) {
+ final int sample = samples[i];
+ if (sample != 0) {
+ n += sample;
+ sum += sample * y[i];
+ positions.add(i);
+ }
}
+ sampleIndex = positions.toArray(true);
}
+ this._samples = samples;
+ this._order = SmileExtUtils.sort(nominalAttrs, x, samples);
+ this._sampleIndex = sampleIndex;
this._root = new Node(sum / n);
- TrainNode trainRoot = new TrainNode(_root, x, y, bags, 1);
- if (maxLeafs == Integer.MAX_VALUE) {
+ TrainNode trainRoot = new TrainNode(_root, 1, 0, _sampleIndex.length, n);
+ if (maxLeafNodes == Integer.MAX_VALUE) {
if (trainRoot.findBestSplit()) {
trainRoot.split(null);
}
@@ -922,14 +1164,17 @@ public final class RegressionTree implements Regression<Vector> {
}
// Pop best leaf from priority queue, split it, and push
// children nodes into the queue if possible.
- for (int leaves = 1; leaves < maxLeafs; leaves++) {
+ for (int leaves = 1; leaves < maxLeafNodes; leaves++) {
// parent is the leaf to split
TrainNode node = nextSplits.poll();
if (node == null) {
break;
}
- node.split(nextSplits); // Split the parent node into two children nodes
+ if (!node.split(nextSplits)) { // Split the parent node into two children nodes
+ leaves--;
+ }
}
+ pruneRedundantLeaves(_root, _importance);
}
if (output != null) {
@@ -938,11 +1183,14 @@ public final class RegressionTree implements Regression<Vector> {
}
private static void checkArgument(@Nonnull Matrix x, @Nonnull double[] y, int numVars,
- int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
+ int maxDepth, int maxLeafNodes, int minSamplesSplit, int minSamplesLeaf) {
if (x.numRows() != y.length) {
throw new IllegalArgumentException(
String.format("The sizes of X and Y don't match: %d != %d", x.numRows(), y.length));
}
+ if (y.length == 0) {
+ throw new IllegalArgumentException("No training example given");
+ }
if (numVars <= 0 || numVars > x.numColumns()) {
throw new IllegalArgumentException(
"Invalid number of variables to split on at a node of the tree: " + numVars);
@@ -950,17 +1198,17 @@ public final class RegressionTree implements Regression<Vector> {
if (maxDepth < 2) {
throw new IllegalArgumentException("maxDepth should be greater than 1: " + maxDepth);
}
- if (maxLeafs < 2) {
- throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafs);
+ if (maxLeafNodes < 2) {
+ throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafNodes);
}
- if (minSplits < 2) {
+ if (minSamplesSplit < 2) {
throw new IllegalArgumentException(
"Invalid minimum number of samples required to split an internal node: "
- + minSplits);
+ + minSamplesSplit);
}
- if (minLeafSize < 1) {
+ if (minSamplesLeaf < 1) {
throw new IllegalArgumentException(
- "Invalid minimum size of leaf nodes: " + minLeafSize);
+ "Invalid minimum size of leaf nodes: " + minSamplesLeaf);
}
}
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
index 12afa7c..64d5c3b 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
@@ -18,9 +18,10 @@
*/
package hivemall.smile.tools;
+import static hivemall.smile.utils.SmileExtUtils.NUMERIC;
+
import hivemall.annotations.Since;
import hivemall.annotations.VisibleForTesting;
-import hivemall.smile.data.AttributeType;
import hivemall.smile.vm.StackMachine;
import hivemall.smile.vm.VMRuntimeException;
import hivemall.utils.codec.Base91;
@@ -372,7 +373,7 @@ public final class TreePredictUDFv1 extends GenericUDF {
/**
* The type of split feature
*/
- AttributeType splitFeatureType = null;
+ boolean quantitativeFeature = true;
/**
* The split value.
*/
@@ -414,21 +415,18 @@ public final class TreePredictUDFv1 extends GenericUDF {
if (trueChild == null && falseChild == null) {
return output;
} else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- if (x[splitFeature] == splitValue) {
+ if (quantitativeFeature) {
+ if (x[splitFeature] <= splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
- } else if (splitFeatureType == AttributeType.NUMERIC) {
- if (x[splitFeature] <= splitValue) {
+ } else {
+ if (x[splitFeature] == splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
- } else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
}
}
}
@@ -443,11 +441,8 @@ public final class TreePredictUDFv1 extends GenericUDF {
this.output = in.readInt();
this.splitFeature = in.readInt();
int typeId = in.readInt();
- if (typeId == -1) {
- this.splitFeatureType = null;
- } else {
- this.splitFeatureType = AttributeType.resolve(typeId);
- }
+
+ this.quantitativeFeature = (typeId == NUMERIC);
this.splitValue = in.readDouble();
if (in.readBoolean()) {
this.trueChild = new DtNodeV1();
@@ -477,7 +472,7 @@ public final class TreePredictUDFv1 extends GenericUDF {
/**
* The type of split feature
*/
- AttributeType splitFeatureType = null;
+ boolean quantitativeFeature = true;
/**
* The split value.
*/
@@ -516,22 +511,19 @@ public final class TreePredictUDFv1 extends GenericUDF {
if (trueChild == null && falseChild == null) {
return output;
} else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- // REVIEWME if(Math.equals(x[splitFeature], splitValue)) {
- if (x[splitFeature] == splitValue) {
+ if (quantitativeFeature) {
+ if (x[splitFeature] <= splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
- } else if (splitFeatureType == AttributeType.NUMERIC) {
- if (x[splitFeature] <= splitValue) {
+ } else {
+ // REVIEWME if(Math.equals(x[splitFeature], splitValue)) {
+ if (x[splitFeature] == splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
- } else {
- throw new IllegalStateException(
- "Unsupported attribute type: " + splitFeatureType);
}
}
}
@@ -546,11 +538,7 @@ public final class TreePredictUDFv1 extends GenericUDF {
this.output = in.readDouble();
this.splitFeature = in.readInt();
int typeId = in.readInt();
- if (typeId == -1) {
- this.splitFeatureType = null;
- } else {
- this.splitFeatureType = AttributeType.resolve(typeId);
- }
+ this.quantitativeFeature = (typeId == NUMERIC);
this.splitValue = in.readDouble();
if (in.readBoolean()) {
this.trueChild = new RtNodeV1();
diff --git a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
index 0c72866..a0844f4 100644
--- a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
+++ b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
@@ -22,15 +22,14 @@ import hivemall.annotations.VisibleForTesting;
import hivemall.math.matrix.ColumnMajorMatrix;
import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.MatrixUtils;
-import hivemall.math.matrix.ints.ColumnMajorDenseIntMatrix2d;
-import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.VectorProcedure;
import hivemall.smile.classification.DecisionTree.SplitRule;
-import hivemall.smile.data.AttributeType;
+import hivemall.utils.collections.arrays.SparseIntArray;
import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.collections.lists.IntArrayList;
+import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.math.MathUtils;
import smile.data.NominalAttribute;
@@ -45,79 +44,90 @@ import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.roaringbitmap.RoaringBitmap;
public final class SmileExtUtils {
+ public static final byte NUMERIC = (byte) 1;
+ public static final byte NOMINAL = (byte) 2;
private SmileExtUtils() {}
/**
* Q for {@link NumericAttribute}, C for {@link NominalAttribute}.
*/
- @Nullable
- public static AttributeType[] resolveAttributes(@Nullable final String opt)
+ @Nonnull
+ public static RoaringBitmap resolveAttributes(@Nullable final String opt)
throws UDFArgumentException {
+ final RoaringBitmap attr = new RoaringBitmap();
if (opt == null) {
- return null;
+ return attr;
}
final String[] opts = opt.split(",");
final int size = opts.length;
- final AttributeType[] attr = new AttributeType[size];
for (int i = 0; i < size; i++) {
final String type = opts[i];
if ("Q".equals(type)) {
- attr[i] = AttributeType.NUMERIC;
+ continue;
} else if ("C".equals(type)) {
- attr[i] = AttributeType.NOMINAL;
+ attr.add(i);
} else {
- throw new UDFArgumentException("Unexpected type: " + type);
+ throw new UDFArgumentException("Unsupported attribute type: " + type);
}
}
return attr;
}
+ /**
+ * Parses a comma-separated list of nominal (categorical) column indices, e.g., "0,3,5".
+ */
@Nonnull
- public static AttributeType[] attributeTypes(@Nullable final AttributeType[] attributes,
- @Nonnull final Matrix x) {
- if (attributes == null) {
- int p = x.numColumns();
- AttributeType[] newAttributes = new AttributeType[p];
- Arrays.fill(newAttributes, AttributeType.NUMERIC);
- return newAttributes;
+ public static RoaringBitmap parseNominalAttributeIndicies(@Nullable final String opt)
+ throws UDFArgumentException {
+ final RoaringBitmap attr = new RoaringBitmap();
+ if (opt == null) {
+ return attr; // no option given: empty bitmap, i.e., no nominal columns
}
- return attributes;
+ for (String s : opt.split(",")) { // NOTE(review): tokens are not trimmed, so "1, 2" is rejected
+ if (NumberUtils.isDigits(s)) {
+ int index = NumberUtils.parseInt(s);
+ attr.add(index); // not bounds-checked against the number of columns
+ } else {
+ throw new UDFArgumentException("Expected integer but got " + s);
+ }
+ }
+ return attr;
}
@VisibleForTesting
@Nonnull
- public static AttributeType[] convertAttributeTypes(
+ public static RoaringBitmap convertAttributeTypes(
@Nonnull final smile.data.Attribute[] original) {
final int size = original.length;
- final AttributeType[] dst = new AttributeType[size];
+ final RoaringBitmap nominalAttrs = new RoaringBitmap();
for (int i = 0; i < size; i++) {
smile.data.Attribute o = original[i];
switch (o.type) {
case NOMINAL: {
- dst[i] = AttributeType.NOMINAL;
+ nominalAttrs.add(i);
break;
}
case NUMERIC: {
- dst[i] = AttributeType.NUMERIC;
break;
}
default:
throw new UnsupportedOperationException("Unsupported type: " + o.type);
}
}
- return dst;
+ return nominalAttrs;
}
@Nonnull
- public static ColumnMajorIntMatrix sort(@Nonnull final AttributeType[] attributes,
- @Nonnull final Matrix x) {
+ public static VariableOrder sort(@Nonnull final RoaringBitmap nominalAttrs,
+ @Nonnull final Matrix x, @Nonnull final int[] samples) {
final int n = x.numRows();
final int p = x.numColumns();
- final int[][] index = new int[p][];
+ final SparseIntArray[] index = new SparseIntArray[p];
if (x.isSparse()) {
int initSize = n / 10;
final DoubleArrayList dlist = new DoubleArrayList(initSize);
@@ -125,6 +135,9 @@ public final class SmileExtUtils {
final VectorProcedure proc = new VectorProcedure() {
@Override
public void apply(final int i, final double v) {
+ if (samples[i] == 0) {
+ return;
+ }
dlist.add(v);
ilist.add(i);
}
@@ -132,32 +145,48 @@ public final class SmileExtUtils {
final ColumnMajorMatrix x2 = x.toColumnMajorMatrix();
for (int j = 0; j < p; j++) {
- if (attributes[j] != AttributeType.NUMERIC) {
- continue;
+ if (nominalAttrs.contains(j)) {
+ continue; // nop for categorical columns
}
+ // sort only numerical columns
x2.eachNonNullInColumn(j, proc);
if (ilist.isEmpty()) {
continue;
}
- int[] indexJ = ilist.toArray();
- QuickSort.sort(dlist.array(), indexJ, indexJ.length);
- index[j] = indexJ;
+ int[] rowPtrs = ilist.toArray();
+ QuickSort.sort(dlist.array(), rowPtrs, rowPtrs.length);
+ index[j] = new SparseIntArray(rowPtrs);
dlist.clear();
ilist.clear();
}
} else {
- final double[] a = new double[n];
+ final DoubleArrayList dlist = new DoubleArrayList(n);
+ final IntArrayList ilist = new IntArrayList(n);
for (int j = 0; j < p; j++) {
- if (attributes[j] == AttributeType.NUMERIC) {
- for (int i = 0; i < n; i++) {
- a[i] = x.get(i, j);
+ if (nominalAttrs.contains(j)) {
+ continue; // nop for categorical columns
+ }
+ // sort only numerical columns
+ for (int i = 0; i < n; i++) {
+ if (samples[i] == 0) {
+ continue;
}
- index[j] = QuickSort.sort(a);
+ double x_ij = x.get(i, j);
+ dlist.add(x_ij);
+ ilist.add(i);
+ }
+ if (ilist.isEmpty()) {
+ continue;
}
+ int[] rowPtrs = ilist.toArray();
+ QuickSort.sort(dlist.array(), rowPtrs, rowPtrs.length);
+ index[j] = new SparseIntArray(rowPtrs);
+ dlist.clear();
+ ilist.clear();
}
}
- return new ColumnMajorDenseIntMatrix2d(index, n);
+ return new VariableOrder(index);
}
@Nonnull
@@ -300,33 +329,11 @@ public final class SmileExtUtils {
x[j] = s;
}
- @Nonnull
- public static int[] bagsToSamples(@Nonnull final int[] bags) {
- int maxIndex = -1;
- for (int e : bags) {
- if (e > maxIndex) {
- maxIndex = e;
- }
- }
- return bagsToSamples(bags, maxIndex + 1);
- }
-
- @Nonnull
- public static int[] bagsToSamples(@Nonnull final int[] bags, final int samplesLength) {
- final int[] samples = new int[samplesLength];
- for (int i = 0, size = bags.length; i < size; i++) {
- samples[bags[i]]++;
- }
- return samples;
- }
-
- public static boolean containsNumericType(@Nonnull final AttributeType[] attributes) {
- for (AttributeType attr : attributes) {
- if (attr == AttributeType.NUMERIC) {
- return true;
- }
- }
- return false;
+ public static boolean containsNumericType(@Nonnull final Matrix x,
+ @Nonnull final RoaringBitmap attributes) {
+ int numColumns = x.numColumns();
+ long numCategoricalCols = (numColumns == 0) ? 0L : attributes.rank(numColumns - 1); // count only indices < numColumns
+ return numColumns != numCategoricalCols; // contains at least one numerical column
}
@Nonnull
diff --git a/core/src/main/java/hivemall/smile/utils/VariableOrder.java b/core/src/main/java/hivemall/smile/utils/VariableOrder.java
new file mode 100644
index 0000000..f175c96
--- /dev/null
+++ b/core/src/main/java/hivemall/smile/utils/VariableOrder.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.utils;
+
+import hivemall.utils.collections.arrays.SparseIntArray;
+import hivemall.utils.function.Consumer;
+
+import javax.annotation.Nonnull;
+
+public final class VariableOrder { // per-column sorted row orders, as built by SmileExtUtils.sort
+
+ @Nonnull
+ private final SparseIntArray[] cols; // col => row indices ordered by the column's value; null = column not indexed
+
+ public VariableOrder(@Nonnull SparseIntArray[] cols) {
+ this.cols = cols;
+ }
+
+ public void eachRow(@Nonnull final Consumer consumer) { // NOTE: despite its name, visits each non-null column as accept(colIndex, order)
+ for (int j = 0; j < cols.length; j++) {
+ final SparseIntArray row = cols[j];
+ if (row == null) {
+ continue;
+ }
+ consumer.accept(j, row);
+ }
+ }
+
+ public void eachNonNullInColumn(final int col, final int startRow, final int endRow,
+ @Nonnull final Consumer consumer) { // visits entries of column `col` whose keys lie in [startRow, endRow)
+ final SparseIntArray row = cols[col];
+ if (row == null) {
+ return; // column was not indexed (e.g., nominal or had no sampled values)
+ }
+ row.forEach(startRow, endRow, consumer);
+ }
+
+}
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/DenseIntArray.java b/core/src/main/java/hivemall/utils/collections/arrays/DenseIntArray.java
index 0869ff2..a5ba0c2 100644
--- a/core/src/main/java/hivemall/utils/collections/arrays/DenseIntArray.java
+++ b/core/src/main/java/hivemall/utils/collections/arrays/DenseIntArray.java
@@ -18,6 +18,8 @@
*/
package hivemall.utils.collections.arrays;
+import hivemall.utils.function.Consumer;
+
import java.util.Arrays;
import javax.annotation.Nonnull;
@@ -67,7 +69,7 @@ public final class DenseIntArray implements IntArray {
@Override
public int size() {
- return array.length;
+ return size;
}
@Override
@@ -89,4 +91,11 @@ public final class DenseIntArray implements IntArray {
}
}
+ @Override
+ public void forEach(@Nonnull final Consumer consumer) { // visits (index, value) for the logical elements
+ for (int i = 0; i < size; i++) { // bound by size (as size() reports), not array.length
+ consumer.accept(i, array[i]);
+ }
+ }
+
}
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/IntArray.java b/core/src/main/java/hivemall/utils/collections/arrays/IntArray.java
index 8edb0d4..58e0519 100644
--- a/core/src/main/java/hivemall/utils/collections/arrays/IntArray.java
+++ b/core/src/main/java/hivemall/utils/collections/arrays/IntArray.java
@@ -18,6 +18,8 @@
*/
package hivemall.utils.collections.arrays;
+import hivemall.utils.function.Consumer;
+
import java.io.Serializable;
import javax.annotation.Nonnull;
@@ -42,4 +44,6 @@ public interface IntArray extends Serializable {
@Nonnull
public int[] toArray(boolean copy);
+ public void forEach(@Nonnull Consumer consumer);
+
}
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/SparseIntArray.java b/core/src/main/java/hivemall/utils/collections/arrays/SparseIntArray.java
index 8de5476..3ead15a 100644
--- a/core/src/main/java/hivemall/utils/collections/arrays/SparseIntArray.java
+++ b/core/src/main/java/hivemall/utils/collections/arrays/SparseIntArray.java
@@ -18,11 +18,14 @@
*/
package hivemall.utils.collections.arrays;
+import hivemall.utils.function.Consumer;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.lang.Preconditions;
+import hivemall.utils.math.MathUtils;
import java.util.Arrays;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
public final class SparseIntArray implements IntArray {
@@ -32,22 +35,49 @@ public final class SparseIntArray implements IntArray {
private int[] mValues;
private int mSize;
- public SparseIntArray() {
- this(10);
+ public SparseIntArray() { this(10); } // allocate so put()/increment()/append() work without a prior init()
+
+ public SparseIntArray(@Nonnegative int initialCapacity) {
+ this.mKeys = new int[initialCapacity];
+ this.mValues = new int[initialCapacity];
+ this.mSize = 0;
}
- public SparseIntArray(int initialCapacity) {
- mKeys = new int[initialCapacity];
- mValues = new int[initialCapacity];
- mSize = 0;
+ public SparseIntArray(@Nonnull final int[] values) {
+ this.mKeys = MathUtils.permutation(values.length);
+ this.mValues = values;
+ this.mSize = values.length;
}
- private SparseIntArray(int[] mKeys, int[] mValues, int mSize) {
+ public SparseIntArray(@Nonnull int[] mKeys, @Nonnull int[] mValues, @Nonnegative int mSize) {
this.mKeys = mKeys;
this.mValues = mValues;
this.mSize = mSize;
}
+ public void init(@Nonnull int[] keys, @Nonnull int[] values) {
+ init(keys, values, keys.length);
+ }
+
+ public void init(@Nonnull int[] keys, @Nonnull int[] values, final int size) {
+ if (size < 0 || size > keys.length || size > values.length) {
+ throw new IllegalArgumentException(String.format(
+ "Illegal size was specified... size = %d, keys.length = %d, values.length = %d", size,
+ keys.length, values.length));
+ }
+ this.mKeys = keys; // adopted without copying; keys[0..size) must be ascending for binarySearch (not checked)
+ this.mValues = values;
+ this.mSize = size;
+ }
+
+ public int[] keys() {
+ return mKeys;
+ }
+
+ public int[] values() {
+ return mValues;
+ }
+
public IntArray deepCopy() {
int[] newKeys = new int[mSize];
int[] newValues = new int[mSize];
@@ -84,16 +114,38 @@ public final class SparseIntArray implements IntArray {
mSize--;
}
+ public void removeRange(@Nonnegative final int start, @Nonnegative final int end) {
+ Preconditions.checkArgument(start <= end);
+
+ int startPos = indexOfKey(start);
+ if (startPos < 0) {
+ startPos = ~startPos;
+ }
+ int endPos = indexOfKey(end);
+ if (endPos < 0) {
+ endPos = ~endPos;
+ }
+
+ final int sizeToRemove = endPos - startPos;
+ if (sizeToRemove <= 0) {
+ return;
+ }
+
+ ArrayUtils.clearRange(mKeys, startPos, endPos, 0);
+ ArrayUtils.clearRange(mValues, startPos, endPos, 0);
+ this.mSize -= sizeToRemove;
+ }
+
@Override
public void put(int key, int value) {
int i = Arrays.binarySearch(mKeys, 0, mSize, key);
if (i >= 0) {
- mValues[i] = value;
+ this.mValues[i] = value;
} else {
i = ~i;
- mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
- mValues = ArrayUtils.insert(mValues, mSize, i, value);
- mSize++;
+ this.mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
+ this.mValues = ArrayUtils.insert(mValues, mSize, i, value);
+ this.mSize++;
}
}
@@ -101,12 +153,12 @@ public final class SparseIntArray implements IntArray {
public void increment(int key, int value) {
int i = Arrays.binarySearch(mKeys, 0, mSize, key);
if (i >= 0) {
- mValues[i] += value;
+ this.mValues[i] += value;
} else {
i = ~i;
- mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
- mValues = ArrayUtils.insert(mValues, mSize, i, value);
- mSize++;
+ this.mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
+ this.mValues = ArrayUtils.insert(mValues, mSize, i, value);
+ this.mSize++;
}
}
@@ -115,6 +167,20 @@ public final class SparseIntArray implements IntArray {
return mSize;
}
+ public int firstKey() {
+ if (mSize == 0) {
+ return -1;
+ }
+ return mKeys[0];
+ }
+
+ public int lastKey() {
+ if (mSize == 0) {
+ return -1;
+ }
+ return mKeys[mSize - 1];
+ }
+
@Override
public int keyAt(int index) {
return mKeys[index];
@@ -124,8 +190,12 @@ public final class SparseIntArray implements IntArray {
return mValues[index];
}
+ public void setKeyAt(int index, int key) {
+ this.mKeys[index] = key;
+ }
+
public void setValueAt(int index, int value) {
- mValues[index] = value;
+ this.mValues[index] = value;
}
public int indexOfKey(int key) {
@@ -146,7 +216,7 @@ public final class SparseIntArray implements IntArray {
}
public void clear(boolean zeroFill) {
- mSize = 0;
+ this.mSize = 0;
if (zeroFill) {
Arrays.fill(mKeys, 0);
Arrays.fill(mValues, 0);
@@ -158,9 +228,90 @@ public final class SparseIntArray implements IntArray {
put(key, value);
return;
}
- mKeys = ArrayUtils.append(mKeys, mSize, key);
- mValues = ArrayUtils.append(mValues, mSize, value);
- mSize++;
+ this.mKeys = ArrayUtils.append(mKeys, mSize, key);
+ this.mValues = ArrayUtils.append(mValues, mSize, value);
+ this.mSize++;
+ }
+
+ public void append(@Nonnegative final int dstPos, @Nonnull final int[] values) {
+ if (mSize == 0) {
+ this.mKeys = MathUtils.permutation(dstPos, values.length);
+ this.mValues = values.clone();
+ this.mSize = values.length;
+ return;
+ }
+
+ final int lastKey = mKeys[mSize - 1];
+ for (int i = 0; i < values.length; i++) {
+ final int key = dstPos + i;
+ if (key <= lastKey) {
+ put(key, values[i]);
+ } else {// append
+ int size = values.length - i;
+ int[] appendKeys = MathUtils.permutation(key, size);
+ this.mKeys = ArrayUtils.concat(mKeys, 0, mSize, appendKeys, 0, appendKeys.length);
+ this.mValues = ArrayUtils.concat(mValues, 0, mSize, values, i, size);
+ this.mSize += size;
+ break;
+ }
+ }
+ }
+
+ public void append(@Nonnegative final int dstPos, @Nonnull final int[] values, final int offset,
+ final int length) { // writes values[offset, offset+length) at keys [dstPos, dstPos+length)
+ if (mSize == 0) {
+ this.mKeys = MathUtils.permutation(dstPos, length);
+ this.mValues = Arrays.copyOfRange(values, offset, offset + length); // 3rd arg is an end index, not a length
+ this.mSize = length;
+ return;
+ }
+
+ final int lastKey = mKeys[mSize - 1];
+ for (int i = 0; i < length; i++) {
+ final int valuePos = offset + i;
+ final int key = dstPos + i;
+ if (key <= lastKey) {
+ put(key, values[valuePos]);
+ } else {// append
+ int size = length - i;
+ int[] appendKeys = MathUtils.permutation(key, size);
+ this.mKeys = ArrayUtils.concat(mKeys, 0, mSize, appendKeys, 0, appendKeys.length);
+ this.mValues = ArrayUtils.concat(mValues, 0, mSize, values, valuePos, size);
+ this.mSize += size;
+ break;
+ }
+ }
+ }
+
+ public void forEach(@Nonnegative final int start, @Nonnegative final int end,
+ @Nonnull final Consumer consumer) {
+ int startPos = indexOfKey(start);
+ if (startPos < 0) {
+ startPos = ~startPos;
+ }
+ int endPos = indexOfKey(end);
+ if (endPos < 0) {
+ endPos = ~endPos;
+ }
+ final int[] keys = mKeys;
+ final int[] values = mValues;
+ for (int i = startPos; i < endPos; i++) {
+ int k = keys[i];
+ int v = values[i];
+ consumer.accept(k, v);
+ }
+ }
+
+ @Override
+ public void forEach(@Nonnull final Consumer consumer) {
+ final int size = mSize;
+ final int[] keys = mKeys;
+ final int[] values = mValues;
+ for (int i = 0; i < size; i++) {
+ int k = keys[i];
+ int v = values[i];
+ consumer.accept(k, v);
+ }
}
@Nonnull
@@ -187,7 +338,7 @@ public final class SparseIntArray implements IntArray {
@Override
public String toString() {
- if (size() <= 0) {
+ if (mSize == 0) {
return "{}";
}
@@ -207,5 +358,4 @@ public final class SparseIntArray implements IntArray {
return buffer.toString();
}
-
}
diff --git a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java b/core/src/main/java/hivemall/utils/function/Consumer.java
similarity index 68%
copy from core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
copy to core/src/main/java/hivemall/utils/function/Consumer.java
index e0b3b4b..af165be 100644
--- a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
+++ b/core/src/main/java/hivemall/utils/function/Consumer.java
@@ -16,23 +16,26 @@
* specific language governing permissions and limitations
* under the License.
*/
-package hivemall.math.matrix.ints;
+package hivemall.utils.function;
-import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.arrays.SparseIntArray;
-public abstract class ColumnMajorIntMatrix extends AbstractIntMatrix {
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
- public ColumnMajorIntMatrix() {
- super();
+public abstract class Consumer {
+
+ public Consumer() {}
+
+ public void accept(int value) {
+ throw new UnsupportedOperationException();
}
- @Override
- public void eachInRow(int row, VectorProcedure procedure, boolean nullOutput) {
+ public void accept(int i, int value) {
throw new UnsupportedOperationException();
}
- @Override
- public void eachNonZeroInRow(int row, VectorProcedure procedure) {
+ public void accept(@Nonnegative int i, @Nonnull SparseIntArray values) {
throw new UnsupportedOperationException();
}
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/IntArray.java b/core/src/main/java/hivemall/utils/function/IntPredicate.java
similarity index 62%
copy from core/src/main/java/hivemall/utils/collections/arrays/IntArray.java
copy to core/src/main/java/hivemall/utils/function/IntPredicate.java
index 8edb0d4..cce2eca 100644
--- a/core/src/main/java/hivemall/utils/collections/arrays/IntArray.java
+++ b/core/src/main/java/hivemall/utils/function/IntPredicate.java
@@ -16,30 +16,15 @@
* specific language governing permissions and limitations
* under the License.
*/
-package hivemall.utils.collections.arrays;
+package hivemall.utils.function;
-import java.io.Serializable;
+public interface IntPredicate {
-import javax.annotation.Nonnull;
-
-public interface IntArray extends Serializable {
-
- public int get(int key);
-
- public int get(int key, int valueIfKeyNotFound);
-
- public void put(int key, int value);
-
- public void increment(int key, int value);
-
- public int size();
-
- public int keyAt(int index);
-
- @Nonnull
- public int[] toArray();
-
- @Nonnull
- public int[] toArray(boolean copy);
+ /**
+ * Evaluates this predicate on the given argument.
+ *
+ * @return true if the input argument matches the predicate, otherwise false
+ */
+ boolean test(int value);
}
diff --git a/core/src/main/java/hivemall/utils/hadoop/SerdeUtils.java b/core/src/main/java/hivemall/utils/hadoop/SerdeUtils.java
new file mode 100644
index 0000000..9b82996
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/hadoop/SerdeUtils.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.hadoop;
+
+import hivemall.utils.io.FastByteArrayInputStream;
+import hivemall.utils.io.FastByteArrayOutputStream;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Objects;
+
+import javax.annotation.CheckForNull;
+import javax.annotation.Nonnull;
+
+import org.roaringbitmap.RoaringBitmap;
+
+public final class SerdeUtils {
+
+ private SerdeUtils() {} // static utility class; not instantiable
+
+ @Nonnull
+ public static byte[] serializeRoaring(@Nonnull final RoaringBitmap r) {
+ r.runOptimize(); // might improve compression; note this mutates the given bitmap's internal encoding
+ // allocate a byte array sized exactly to the serialized form
+ final byte[] buf = new byte[r.serializedSizeInBytes()];
+ // then we can serialize on a custom OutputStream
+ try {
+ r.serialize(new DataOutputStream(new FastByteArrayOutputStream(buf)));
+ } catch (IOException e) {
+ throw new IllegalStateException("Failed to serialize RoaringBitmap", e);
+ }
+ return buf;
+ }
+
+ @Nonnull
+ public static RoaringBitmap deserializeRoaring(@CheckForNull final byte[] b) {
+ final RoaringBitmap bitmap = new RoaringBitmap();
+ try {
+ bitmap.deserialize(
+ new DataInputStream(new FastByteArrayInputStream(Objects.requireNonNull(b))));
+ } catch (IOException e) {
+ throw new IllegalStateException("Failed to deserialize RoaringBitmap", e);
+ }
+ return bitmap;
+ }
+}
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index 5df63d9..4e73ebc 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -62,6 +62,22 @@ public final class ArrayUtils {
}
@Nonnull
+ public static int[] sortedArraySet(@Nonnull final int[] sorted,
+ @Nonnegative final int element) {
+ final int i = Arrays.binarySearch(sorted, element);
+ if (i >= 0) {// found element
+ return sorted;
+ } else {
+ return insert(sorted, ~i, element);
+ }
+ }
+
+ public static boolean contains(@Nonnull final int[] sorted, @Nonnegative final int element) {
+ int i = Arrays.binarySearch(sorted, element);
+ return i >= 0;
+ }
+
+ @Nonnull
public static float[] toArray(@Nonnull final List<Float> lst) {
final int ndim = lst.size();
final float[] ary = new float[ndim];
@@ -271,6 +287,24 @@ public final class ArrayUtils {
return INDEX_NOT_FOUND;
}
+ public static int insertionPoint(@Nullable final int[] a, final int key) {
+ final int pos = Arrays.binarySearch(a, key);
+ if (pos < 0) {
+ return ~pos;
+ } else {
+ return pos;
+ }
+ }
+
+ public static int insertionPoint(@Nullable final int[] a, final int size, final int key) {
+ final int pos = Arrays.binarySearch(a, 0, size, key);
+ if (pos < 0) {
+ return ~pos;
+ } else {
+ return pos;
+ }
+ }
+
@Nonnull
public static byte[] copyOf(@Nonnull final byte[] original, final int newLength) {
final byte[] copy = new byte[newLength];
@@ -294,6 +328,15 @@ public final class ArrayUtils {
}
@Nonnull
+ public static int[] append(@Nonnull final int[] array, final int element) {
+ int size = array.length;
+ final int[] newArray = new int[size + 1];
+ System.arraycopy(array, 0, newArray, 0, size);
+ newArray[size] = element;
+ return newArray;
+ }
+
+ @Nonnull
public static int[] append(@Nonnull int[] array, final int currentSize, final int element) {
if (currentSize + 1 > array.length) {
int[] newArray = new int[currentSize * 2];
@@ -329,6 +372,50 @@ public final class ArrayUtils {
}
@Nonnull
+ public static int[] concat(@Nonnull final int[] array1, @Nonnull final int... array2) { // new array: array1 ++ array2
+ final int[] joinedArray = new int[array1.length + array2.length];
+ System.arraycopy(array1, 0, joinedArray, 0, array1.length);
+ System.arraycopy(array2, 0, joinedArray, array1.length, array2.length);
+ return joinedArray;
+ }
+
+ @Nonnull
+ public static int[] concat(@Nonnull final int[] array1, @Nonnull final int[] array2,
+ @Nonnegative final int offset, @Nonnegative final int length) { // array1 ++ array2[offset, offset+length)
+ final int[] joinedArray = new int[array1.length + length];
+ System.arraycopy(array1, 0, joinedArray, 0, array1.length);
+ System.arraycopy(array2, offset, joinedArray, array1.length, length);
+ return joinedArray;
+ }
+
+ @Nonnull
+ public static int[] concat(@Nonnull final int[] array1, @Nonnegative final int offset1,
+ @Nonnegative final int length1, @Nonnull final int[] array2, @Nonnegative final int offset2,
+ @Nonnegative final int length2) { // array1[offset1, +length1) ++ array2[offset2, +length2)
+ final int[] joinedArray = new int[length1 + length2];
+ System.arraycopy(array1, offset1, joinedArray, 0, length1);
+ System.arraycopy(array2, offset2, joinedArray, length1, length2);
+ return joinedArray;
+ }
+
+ @Nonnull
+ public static int[] insert(@Nonnull final int[] array, final int index, final int element) {
+ final int size = array.length;
+ if (index > size) {
+ throw new IllegalArgumentException(String.format(
+ "index should be less than or equals to array.length: index=%d, array.length=%d",
+ index, array.length));
+ }
+ final int[] newArray = new int[size + 1];
+ System.arraycopy(array, 0, newArray, 0, Math.min(index, size));
+ newArray[index] = element;
+ if (index != size) {
+ System.arraycopy(array, index, newArray, index + 1, size - index);
+ }
+ return newArray;
+ }
+
+ @Nonnull
public static int[] insert(@Nonnull final int[] array, final int currentSize, final int index,
final int element) {
if (currentSize + 1 <= array.length) {
@@ -373,6 +460,30 @@ public final class ArrayUtils {
return newArray;
}
+ /**
+ * Removes from {@code array} all of the elements whose index is between {@code fromIndex},
+ * inclusive, and {@code toIndex}, exclusive. The elements at and after {@code toIndex} are
+ * shifted left by {@code toIndex - fromIndex} positions and the vacated slots at the end of the
+ * array are filled with {@code fillVal}; the array length is unchanged.
+ *
+ * @param array the array to modify in place
+ * @param fromIndex index of first element to be removed
+ * @param toIndex index after last element to be removed
+ * @param fillVal value written into the {@code toIndex - fromIndex} trailing slots
+ * @throws IllegalArgumentException if {@code fromIndex} or {@code toIndex} is out of range
+ * ({@code fromIndex < 0 ||
+ * fromIndex >= array.length ||
+ * toIndex > array.length ||
+ * toIndex < fromIndex})
+ */
+ public static void clearRange(@Nonnull final int[] array, @Nonnegative final int fromIndex,
+ @Nonnegative final int toIndex, final int fillVal) {
+ final int size = array.length;
+ if (fromIndex < 0 || fromIndex >= size || toIndex > size || toIndex < fromIndex) {
+ throw new IllegalArgumentException(String.format(
+ "fromIndex: %d, toIndex: %d, array.length=%d", fromIndex, toIndex, size));
+ }
+
+ final int numRemoved = toIndex - fromIndex;
+ System.arraycopy(array, toIndex, array, fromIndex, size - toIndex);
+ // After the shift the live elements end at (size - numRemoved). Only the slots after that
+ // are stale. Filling from toIndex would clobber shifted elements or leave stale copies.
+ Arrays.fill(array, size - numRemoved, size, fillVal);
+ }
+
public static boolean equals(@Nonnull final float[] array, final float value) {
for (int i = 0, size = array.length; i < size; i++) {
if (array[i] != value) {
diff --git a/core/src/main/java/hivemall/utils/lang/mutable/MutableBoolean.java b/core/src/main/java/hivemall/utils/lang/mutable/MutableBoolean.java
new file mode 100644
index 0000000..0f8a8cb
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/lang/mutable/MutableBoolean.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.lang.mutable;
+
+import java.io.Serializable;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+/**
+ * A mutable holder for a {@code boolean} value. Instances are not synchronized.
+ */
+public final class MutableBoolean implements Comparable<MutableBoolean>, Serializable {
+ private static final long serialVersionUID = -8946436031470563775L;
+
+ private boolean value;
+
+ /** Creates a holder initialized to {@code false}. */
+ public MutableBoolean() {
+ this(false);
+ }
+
+ /** Creates a holder initialized to {@code value}. */
+ public MutableBoolean(boolean value) {
+ this.value = value;
+ }
+
+ /** Returns the current value; identical to {@link #booleanValue()}. */
+ public boolean get() {
+ return value;
+ }
+
+ /** Returns the current value. */
+ public boolean booleanValue() {
+ return value;
+ }
+
+ /** Replaces the current value. */
+ public void setValue(boolean value) {
+ this.value = value;
+ }
+
+ /**
+ * Replaces the current value with the unboxed {@code value}.
+ *
+ * @throws NullPointerException if {@code value} is null
+ */
+ public void setValue(@Nonnull Boolean value) {
+ this.value = value.booleanValue();
+ }
+
+ /** Sets the value to {@code false}. */
+ public void setFalse() {
+ this.value = false;
+ }
+
+ /** Sets the value to {@code true}. */
+ public void setTrue() {
+ this.value = true;
+ }
+
+ /** Hashes like the equivalent {@link Boolean}, consistent with {@link #equals(Object)}. */
+ @Override
+ public int hashCode() {
+ return value ? Boolean.TRUE.hashCode() : Boolean.FALSE.hashCode();
+ }
+
+ /** Two instances are equal when both are {@code MutableBoolean}s holding the same value. */
+ @Override
+ public boolean equals(@Nullable Object other) {
+ if (this == other) {
+ return true;
+ }
+ if (other == null) {
+ return false;
+ }
+ if (other instanceof MutableBoolean) {
+ return value == ((MutableBoolean) other).booleanValue();
+ }
+ return false;
+ }
+
+ /** Orders {@code false} before {@code true}, as {@link Boolean#compare(boolean, boolean)}. */
+ @Override
+ public int compareTo(@Nonnull MutableBoolean o) {
+ return Boolean.compare(value, o.value);
+ }
+
+ /** Returns {@code "true"} or {@code "false"}, matching {@link Boolean#toString(boolean)}. */
+ @Override
+ public String toString() {
+ return String.valueOf(value);
+ }
+
+}
diff --git a/core/src/main/java/hivemall/utils/lang/mutable/MutableInt.java b/core/src/main/java/hivemall/utils/lang/mutable/MutableInt.java
index e527f2b..4fa9360 100644
--- a/core/src/main/java/hivemall/utils/lang/mutable/MutableInt.java
+++ b/core/src/main/java/hivemall/utils/lang/mutable/MutableInt.java
@@ -42,8 +42,26 @@ public final class MutableInt extends Number
this.value = value.intValue();
}
+ /**
+ * Returns the current value, then increments the stored value by one (post-increment).
+ */
+ public int getAndIncrement() {
+ return this.value++;
+ }
+
+ /** Increments the stored value by one. */
+ public void incr() {
+ this.value++;
+ }
+
+ /** Decrements the stored value by one. */
+ public void decr() {
+ this.value--;
+ }
+
public void addValue(int o) {
- value += o;
+ this.value += o;
+ }
+
+ /** Returns the stored value. */
+ public int get() {
+ return this.value;
 }
public int getValue() {
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index e3f32f5..ed5b2fd 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -327,6 +327,15 @@ public final class MathUtils {
return perm;
}
+ /**
+ * Returns the ascending sequence {@code [start, start + 1, ..., start + size - 1]} as a new
+ * array of length {@code size}.
+ */
+ @Nonnull
+ public static int[] permutation(@Nonnegative final int start, @Nonnegative final int size) {
+ final int[] result = new int[size];
+ int next = start;
+ for (int idx = 0; idx < result.length; idx++) {
+ result[idx] = next++;
+ }
+ return result;
+ }
+
public static double sum(@Nullable final float[] arr) {
if (arr == null) {
return 0.d;
diff --git a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
index 3f287af..c3601eb 100644
--- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
@@ -23,15 +23,18 @@ import static org.junit.Assert.assertEquals;
import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.builders.CSRMatrixBuilder;
import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
+import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.smile.classification.DecisionTree.Node;
-import hivemall.smile.data.AttributeType;
+import hivemall.smile.classification.DecisionTree.SplitRule;
import hivemall.smile.tools.TreeExportUDF.Evaluator;
import hivemall.smile.tools.TreeExportUDF.OutputType;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.codec.Base91;
import smile.data.AttributeDataset;
+import smile.data.NominalAttribute;
import smile.data.parser.ArffParser;
+import smile.data.parser.DelimitedTextParser;
import smile.math.Math;
import smile.validation.LOOCV;
@@ -47,6 +50,7 @@ import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
+import org.roaringbitmap.RoaringBitmap;
public class DecisionTreeTest {
private static final boolean DEBUG = false;
@@ -165,7 +169,7 @@ public class DecisionTreeTest {
double[][] x = ds.toArray(new double[ds.size()][]);
int[] y = ds.toArray(new int[ds.size()]);
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
DecisionTree tree = new DecisionTree(attrs, matrix(x, dense), y, numLeafs,
RandomNumberGeneratorFactory.createPRNG(31));
@@ -196,7 +200,7 @@ public class DecisionTreeTest {
double[][] trainx = Math.slice(x, loocv.train[i]);
int[] trainy = Math.slice(y, loocv.train[i]);
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
DecisionTree tree = new DecisionTree(attrs, matrix(trainx, dense), trainy, numLeafs,
RandomNumberGeneratorFactory.createPRNG(i));
if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) {
@@ -226,12 +230,13 @@ public class DecisionTreeTest {
double[][] trainx = Math.slice(x, loocv.train[i]);
int[] trainy = Math.slice(y, loocv.train[i]);
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
DecisionTree dtree = new DecisionTree(attrs, matrix(trainx, true), trainy, numLeafs,
RandomNumberGeneratorFactory.createPRNG(i));
DecisionTree stree = new DecisionTree(attrs, matrix(trainx, false), trainy, numLeafs,
RandomNumberGeneratorFactory.createPRNG(i));
Assert.assertEquals(dtree.predict(x[loocv.test[i]]), stree.predict(x[loocv.test[i]]));
+ Assert.assertEquals(dtree.toString(), stree.toString());
}
}
@@ -253,7 +258,7 @@ public class DecisionTreeTest {
double[][] trainx = Math.slice(x, loocv.train[i]);
int[] trainy = Math.slice(y, loocv.train[i]);
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);
byte[] b = tree.serialize(false);
@@ -280,7 +285,7 @@ public class DecisionTreeTest {
double[][] trainx = Math.slice(x, loocv.train[i]);
int[] trainy = Math.slice(y, loocv.train[i]);
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);
byte[] b1 = tree.serialize(true);
@@ -292,6 +297,56 @@ public class DecisionTreeTest {
}
}
+ @Test
+ public void testTitanicPruning() throws IOException, ParseException {
+ // NOTE(review): the dataset is downloaded over HTTP, so this test requires network access.
+ String datasetUrl =
+ "https://gist.githubusercontent.com/myui/7cd82c443db84ba7e7add1523d0247a9/raw/f2d3e3051b0292577e8c01a1759edabaa95c5781/titanic_train.tsv";
+
+ DelimitedTextParser parser = new DelimitedTextParser();
+ parser.setColumnNames(true);
+ // NOTE(review): the file name ends in .tsv but is parsed as comma-delimited; confirm.
+ parser.setDelimiter(",");
+ parser.setResponseIndex(new NominalAttribute("survived"), 0);
+
+ URL url = new URL(datasetUrl);
+ final AttributeDataset train;
+ InputStream is = new BufferedInputStream(url.openStream());
+ try {
+ train = parser.parse("titanic train", is);
+ } finally {
+ // Always release the underlying connection, even when parsing fails.
+ is.close();
+ }
+ double[][] x_ = train.toArray(new double[train.size()][]);
+ int[] y = train.toArray(new int[train.size()]);
+
+ // pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked
+ // C,C,C,Q,Q,Q,C,Q,C,C
+ RoaringBitmap nominalAttrs = new RoaringBitmap();
+ nominalAttrs.add(0);
+ nominalAttrs.add(1);
+ nominalAttrs.add(2);
+ nominalAttrs.add(6);
+ nominalAttrs.add(8);
+ nominalAttrs.add(9);
+
+ int columns = x_[0].length;
+ Matrix x = new RowMajorDenseMatrix2d(x_, columns);
+ int numVars = (int) Math.ceil(Math.sqrt(columns));
+ int maxDepth = Integer.MAX_VALUE;
+ int maxLeafs = Integer.MAX_VALUE;
+ int minSplits = 2;
+ int minLeafSize = 1;
+ int[] samples = null;
+ PRNG rand = RandomNumberGeneratorFactory.createPRNG(43L);
+
+ final String[] featureNames = new String[] {"pclass", "name", "sex", "age", "sibsp",
+ "parch", "ticket", "fare", "cabin", "embarked"};
+ final String[] classNames = new String[] {"yes", "no"};
+ DecisionTree tree = new DecisionTree(nominalAttrs, x, y, numVars, maxDepth, maxLeafs,
+ minSplits, minLeafSize, samples, SplitRule.GINI, rand) {
+ @Override
+ public String toString() {
+ return predictJsCodegen(featureNames, classNames);
+ }
+ };
+ // Smoke test: only checks that JavaScript code generation runs; the output is not checked.
+ tree.toString();
+ }
+
@Nonnull
private static Matrix matrix(@Nonnull final double[][] x, boolean dense) {
if (dense) {
diff --git a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
index f7b0285..efb3eea 100644
--- a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
@@ -396,7 +396,7 @@ public class RandomForestClassifierUDTFTest {
final MutableInt oobErrors = new MutableInt(0);
final MutableInt oobTests = new MutableInt(0);
Collector collector = new Collector() {
- public void collect(Object input) throws HiveException {
+ public synchronized void collect(Object input) throws HiveException {
Object[] forward = (Object[]) input;
oobErrors.addValue(((IntWritable) forward[4]).get());
oobTests.addValue(((IntWritable) forward[5]).get());
@@ -448,7 +448,7 @@ public class RandomForestClassifierUDTFTest {
final MutableInt oobErrors = new MutableInt(0);
final MutableInt oobTests = new MutableInt(0);
Collector collector = new Collector() {
- public void collect(Object input) throws HiveException {
+ public synchronized void collect(Object input) throws HiveException {
Object[] forward = (Object[]) input;
oobErrors.addValue(((IntWritable) forward[4]).get());
oobTests.addValue(((IntWritable) forward[5]).get());
diff --git a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
index 75aa65a..3f1b782 100644
--- a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
@@ -22,7 +22,6 @@ import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.builders.CSRMatrixBuilder;
import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
import hivemall.math.random.RandomNumberGeneratorFactory;
-import hivemall.smile.data.AttributeType;
import hivemall.smile.tools.TreeExportUDF.Evaluator;
import hivemall.smile.tools.TreeExportUDF.OutputType;
import hivemall.utils.codec.Base91;
@@ -31,7 +30,6 @@ import smile.validation.LOOCV;
import java.io.IOException;
import java.text.ParseException;
-import java.util.Arrays;
import javax.annotation.Nonnull;
@@ -39,6 +37,7 @@ import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
+import org.roaringbitmap.RoaringBitmap;
public class RegressionTreeTest {
private static final boolean DEBUG = false;
@@ -66,8 +65,7 @@ public class RegressionTreeTest {
double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
112.6, 114.2, 115.7, 116.9};
- AttributeType[] attrs = new AttributeType[longley[0].length];
- Arrays.fill(attrs, AttributeType.NUMERIC);
+ RoaringBitmap attrs = new RoaringBitmap();
int n = longley.length;
LOOCV loocv = new LOOCV(n);
@@ -109,8 +107,7 @@ public class RegressionTreeTest {
double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
112.6, 114.2, 115.7, 116.9};
- AttributeType[] attrs = new AttributeType[longley[0].length];
- Arrays.fill(attrs, AttributeType.NUMERIC);
+ RoaringBitmap attrs = new RoaringBitmap();
int n = longley.length;
LOOCV loocv = new LOOCV(n);
@@ -152,8 +149,7 @@ public class RegressionTreeTest {
double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
112.6, 114.2, 115.7, 116.9};
- AttributeType[] attrs = new AttributeType[longley[0].length];
- Arrays.fill(attrs, AttributeType.NUMERIC);
+ RoaringBitmap attrs = new RoaringBitmap();
int n = longley.length;
LOOCV loocv = new LOOCV(n);
@@ -204,8 +200,7 @@ public class RegressionTreeTest {
private static String graphvizOutput(double[][] x, double[] y, int maxLeafs, boolean dense,
String[] featureNames, String outputName)
throws IOException, HiveException, ParseException {
- AttributeType[] attrs = new AttributeType[x[0].length];
- Arrays.fill(attrs, AttributeType.NUMERIC);
+ RoaringBitmap attrs = new RoaringBitmap();
RegressionTree tree = new RegressionTree(attrs, matrix(x, dense), y, maxLeafs);
Text model = new Text(Base91.encode(tree.serialize(true)));
diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
index f44b9ec..8ccdc03 100644
--- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
+++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
@@ -21,7 +21,6 @@ package hivemall.smile.tools;
import hivemall.TestUtils;
import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
import hivemall.smile.classification.DecisionTree;
-import hivemall.smile.data.AttributeType;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.codec.Base91;
@@ -51,6 +50,7 @@ import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
+import org.roaringbitmap.RoaringBitmap;
public class TreePredictUDFTest {
private static final boolean DEBUG = false;
@@ -76,7 +76,7 @@ public class TreePredictUDFTest {
double[][] trainx = Math.slice(x, loocv.train[i]);
int[] trainy = Math.slice(y, loocv.train[i]);
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
DecisionTree tree = new DecisionTree(attrs,
new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4);
Assert.assertEquals(tree.predict(x[loocv.test[i]]),
@@ -105,7 +105,7 @@ public class TreePredictUDFTest {
double[] trainy = Math.slice(datay, cv.train[i]);
double[][] testx = Math.slice(datax, cv.test[i]);
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
RegressionTree tree = new RegressionTree(attrs,
new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
@@ -145,7 +145,7 @@ public class TreePredictUDFTest {
testy[i - m] = datay[index[i]];
}
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
RegressionTree tree = new RegressionTree(attrs,
new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));
@@ -240,7 +240,7 @@ public class TreePredictUDFTest {
testy[i - m] = datay[index[i]];
}
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
RegressionTree tree = new RegressionTree(attrs,
new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java b/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java
index 25e1cc6..a9a0f26 100644
--- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java
+++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java
@@ -23,7 +23,6 @@ import static org.junit.Assert.assertEquals;
import hivemall.TestUtils;
import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
import hivemall.smile.classification.DecisionTree;
-import hivemall.smile.data.AttributeType;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.tools.TreePredictUDFv1.DtNodeV1;
import hivemall.smile.tools.TreePredictUDFv1.JavaSerializationEvaluator;
@@ -58,6 +57,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.apache.hadoop.io.IntWritable;
import org.junit.Assert;
import org.junit.Test;
+import org.roaringbitmap.RoaringBitmap;
@SuppressWarnings("deprecation")
public class TreePredictUDFv1Test {
@@ -97,7 +97,7 @@ public class TreePredictUDFv1Test {
double[][] trainx = Math.slice(x, loocv.train[i]);
int[] trainy = Math.slice(y, loocv.train[i]);
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
DecisionTree tree = new DecisionTree(attrs,
new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4);
assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]]));
@@ -125,7 +125,7 @@ public class TreePredictUDFv1Test {
double[] trainy = Math.slice(datay, cv.train[i]);
double[][] testx = Math.slice(datax, cv.test[i]);
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
RegressionTree tree = new RegressionTree(attrs,
new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
@@ -165,7 +165,7 @@ public class TreePredictUDFv1Test {
testy[i - m] = datay[index[i]];
}
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
RegressionTree tree = new RegressionTree(attrs,
new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));
@@ -260,7 +260,7 @@ public class TreePredictUDFv1Test {
testy[i - m] = datay[index[i]];
}
- AttributeType[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
+ RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
RegressionTree tree = new RegressionTree(attrs,
new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
String opScript = tree.predictOpCodegen(StackMachine.SEP);
diff --git a/core/src/main/java/hivemall/math/vector/VectorProcedure.java b/core/src/test/java/hivemall/smile/utils/SmileExtUtilsTest.java
similarity index 51%
copy from core/src/main/java/hivemall/math/vector/VectorProcedure.java
copy to core/src/test/java/hivemall/smile/utils/SmileExtUtilsTest.java
index 4978885..471fc35 100644
--- a/core/src/main/java/hivemall/math/vector/VectorProcedure.java
+++ b/core/src/test/java/hivemall/smile/utils/SmileExtUtilsTest.java
@@ -16,28 +16,25 @@
* specific language governing permissions and limitations
* under the License.
*/
-package hivemall.math.vector;
+package hivemall.smile.utils;
-import javax.annotation.Nonnegative;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.junit.Assert;
+import org.junit.Test;
-public abstract class VectorProcedure {
+public class SmileExtUtilsTest {
- public VectorProcedure() {}
-
- public void apply(@Nonnegative int i, @Nonnegative int j, float value) {
- apply(i, j, (double) value);
+ @Test
+ public void testResolveAttributes() throws UDFArgumentException {
+ // Per the assertions below, only the positions marked "C" end up in the resulting bitmap.
+ final String allQuantitative = "Q,Q,Q";
+ final String fourCategorical = "Q,C,C,Q,C,Q,C";
+ Assert.assertTrue(SmileExtUtils.resolveAttributes(allQuantitative).isEmpty());
+ Assert.assertEquals(4, SmileExtUtils.resolveAttributes(fourCategorical).getCardinality());
+ // The "C" positions of "Q,C,C,Q,C" are 1, 2 and 4.
+ Assert.assertEquals(SmileExtUtils.resolveAttributes("Q,C,C,Q,C"),
+ SmileExtUtils.parseNominalAttributeIndicies("1,2,4"));
 }
- public void apply(@Nonnegative int i, @Nonnegative int j, double value) {}
-
- public void apply(@Nonnegative int i, float value) {
- apply(i, (double) value);
+ @Test(expected = UDFArgumentException.class)
+ public void testResolveAttributesInvalidFormat() throws UDFArgumentException {
+ // "3" is not a valid attribute-type token, so this call is expected to throw.
+ // There is no return value to assert on.
+ SmileExtUtils.resolveAttributes("Q,Q,3,Q");
 }
- public void apply(@Nonnegative int i, double value) {}
-
- public void apply(@Nonnegative int i, int value) {}
-
- public void apply(@Nonnegative int i) {}
-
}
diff --git a/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java b/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java
index 0ce3912..a84b396 100644
--- a/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java
+++ b/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java
@@ -64,6 +64,7 @@ public class IntArrayTest {
for (int i = 0; i < 10; i++) {
array.put(i, 10 + i);
}
+ Assert.assertEquals(10, array.size());
array.clear();
Assert.assertEquals(0, array.size());
Assert.assertEquals(0, array.get(0));
diff --git a/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java b/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java
index db3c8eb..70a6f62 100644
--- a/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java
+++ b/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java
@@ -18,8 +18,8 @@
*/
package hivemall.utils.collections.arrays;
-import hivemall.utils.collections.arrays.IntArray;
-import hivemall.utils.collections.arrays.SparseIntArray;
+import hivemall.utils.function.Consumer;
+import hivemall.utils.lang.mutable.MutableInt;
import java.util.Random;
@@ -61,4 +61,113 @@ public class SparseIntArrayTest {
Assert.assertEquals(expected[key], actual.get(key, 0));
}
}
+
+ @Test
+ public void testAppend() {
+ final int n1 = 500;
+ final int n2 = 100;
+ final int overlap = 9;
+
+ // a1 = {0, 1, ..., n1 - 1}
+ int[] a1 = new int[n1];
+ for (int k = 0; k < n1; k++) {
+ a1[k] = k;
+ }
+ SparseIntArray array = new SparseIntArray(a1);
+ for (int k = 0; k < n1; k++) {
+ Assert.assertEquals(a1[k], array.get(k));
+ }
+
+ // a2 continues counting from the last element of a1.
+ final int last = a1[n1 - 1];
+ int[] a2 = new int[n2];
+ for (int k = 0; k < n2; k++) {
+ a2[k] = last + k;
+ }
+
+ // Appending at key (n1 - overlap) replaces the last `overlap` entries of a1.
+ array.append(n1 - overlap, a2);
+ Assert.assertEquals(n1 + n2 - overlap, array.size());
+ }
+
+ @Test
+ public void testAppend2() {
+ final int n1 = 500;
+ final int n2 = 100;
+ final int overlap = 9;
+
+ int[] a1 = new int[n1];
+ for (int k = 0; k < n1; k++) {
+ a1[k] = k;
+ }
+ SparseIntArray array = new SparseIntArray(a1);
+ for (int k = 0; k < n1; k++) {
+ Assert.assertEquals(a1[k], array.get(k));
+ }
+
+ final int last = a1[n1 - 1];
+ int[] a2 = new int[n2];
+ for (int k = 0; k < n2; k++) {
+ a2[k] = last + k;
+ }
+
+ // Same as testAppend, but through the (key, src, offset, length) overload covering all of a2.
+ array.append(n1 - overlap, a2, 0, a2.length);
+ Assert.assertEquals(n1 + n2 - overlap, array.size());
+ }
+
+ @Test
+ public void testConsume() {
+ // Exercises forEach(from, to, consumer) with a consumer that writes into the array being
+ // iterated, then append() starting at an already-present key.
+ final Random rng = new Random(43L);
+ // 500 entries at even keys 0, 2, ..., 998 with seeded random values.
+ int[] keys = new int[500];
+ int[] values = new int[keys.length];
+ for (int i = 0; i < keys.length; i++) {
+ keys[i] = i * 2;
+ values[i] = rng.nextInt(1000);
+ }
+ final SparseIntArray actual = new SparseIntArray(keys, values, keys.length);
+ Assert.assertEquals(500, actual.size());
+
+ // NOTE(review): presumably visits the keys in [10, 30). Copying each even key's value to
+ // the following odd key adds 10 entries (500 -> 510); confirm against forEach's contract.
+ actual.forEach(10, 30, new Consumer() {
+ @Override
+ public void accept(int i, int value) {
+ actual.put(i, value);
+ actual.put(i + 1, value);
+ }
+ });
+
+ int lastKey = actual.lastKey();
+ Assert.assertEquals(998, lastKey);
+ // Writes -1, -2, -3 at keys 998, 999, 1000: key 998 is overwritten and two keys are added
+ // (510 -> 512).
+ actual.append(lastKey, new int[] {-1, -2, -3});
+
+ Assert.assertEquals(512, actual.size());
+ Assert.assertEquals(-1, actual.get(998));
+ Assert.assertEquals(-2, actual.get(999));
+ Assert.assertEquals(-3, actual.get(1000));
+
+ // Each even key in [10, 30) must now share its value with the following odd key.
+ for (int i = 10; i < 30; i += 2) {
+ Assert.assertEquals(actual.get(i), actual.get(i + 1));
+ }
+ }
+
+ @Test
+ public void testRemoveRange() {
+ SparseIntArray actual = new SparseIntArray(2);
+ for (int key : new int[] {3, 4, 6, 7, 8, 9}) {
+ actual.append(key, key);
+ }
+
+ // removeRange(5, 8) drops the keys in [5, 8), i.e. 6 and 7; key 8 survives.
+ actual.removeRange(5, 8);
+ Assert.assertEquals(4, actual.size());
+ Assert.assertEquals(3, actual.get(3));
+ Assert.assertEquals(4, actual.get(4));
+ Assert.assertEquals(8, actual.get(8));
+ Assert.assertEquals(9, actual.get(9));
+ }
+
+ @Test
+ public void testAppendRange() {
+ SparseIntArray sparse = new SparseIntArray(2);
+ sparse.append(3, 3);
+ sparse.append(4, 4);
+ sparse.append(6, 6);
+ // Append the 2 source values starting at offset 1 ({8, 9}) right after the last key.
+ final int[] source = new int[] {7, 8, 9, 10};
+ sparse.append(sparse.lastKey() + 1, source, 1, 2);
+
+ Assert.assertEquals(5, sparse.size());
+
+ // Collect the values in the order forEach visits them.
+ final int[] collected = new int[5];
+ final MutableInt cursor = new MutableInt(0);
+ sparse.forEach(new Consumer() {
+ @Override
+ public void accept(int key, int value) {
+ collected[cursor.getAndIncrement()] = value;
+ }
+ });
+ Assert.assertArrayEquals(new int[] {3, 4, 6, 8, 9}, collected);
+ }
}
diff --git a/core/src/test/java/hivemall/utils/lang/ArrayUtilsTest.java b/core/src/test/java/hivemall/utils/lang/ArrayUtilsTest.java
new file mode 100644
index 0000000..e38be8c
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/lang/ArrayUtilsTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.lang;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ArrayUtilsTest {
+
+ @Test
+ public void testSortedArraySet() {
+ final int[] original = new int[] {3, 7, 10};
+ Assert.assertSame(original, ArrayUtils.sortedArraySet(original, 7));
+ Assert.assertSame(original, ArrayUtils.sortedArraySet(original, 3));
+ Assert.assertSame(original, ArrayUtils.sortedArraySet(original, 10));
+ Assert.assertArrayEquals(new int[] {3, 7, 8, 10},
+ ArrayUtils.sortedArraySet(new int[] {3, 7, 10}, 8));
+ Assert.assertArrayEquals(new int[] {3, 7, 7, 8, 10},
+ ArrayUtils.sortedArraySet(new int[] {3, 7, 7, 10}, 8));
+ Assert.assertArrayEquals(new int[] {3, 7, 7, 10},
+ ArrayUtils.sortedArraySet(new int[] {3, 7, 7, 10}, 7));
+ Assert.assertArrayEquals(new int[] {3, 7, 10, 11},
+ ArrayUtils.sortedArraySet(new int[] {3, 7, 10}, 11));
+ Assert.assertArrayEquals(new int[] {-2, 3, 7, 10},
+ ArrayUtils.sortedArraySet(new int[] {3, 7, 10}, -2));
+ }
+
+ @Test
+ public void testAppendIntArrayInt() {
+ Assert.assertArrayEquals(new int[] {3, 7, 10, 8},
+ ArrayUtils.append(new int[] {3, 7, 10}, 8));
+ }
+
+ @Test
+ public void testConcat() {
+ // all of array1 followed by all of array2
+ Assert.assertArrayEquals(new int[] {1, 2, 3, 4},
+ ArrayUtils.concat(new int[] {1, 2}, new int[] {3, 4}));
+ // all of array1 followed by array2[offset, offset + length)
+ Assert.assertArrayEquals(new int[] {1, 2, 6, 7},
+ ArrayUtils.concat(new int[] {1, 2}, new int[] {5, 6, 7}, 1, 2));
+ // array1[offset1, offset1 + length1) followed by array2[offset2, offset2 + length2)
+ Assert.assertArrayEquals(new int[] {2, 3, 4},
+ ArrayUtils.concat(new int[] {1, 2, 3}, 1, 2, new int[] {4, 5, 6}, 0, 1));
+ }
+
+ @Test
+ public void testInsert() {
+ final int[] original = new int[] {3, 7, 10};
+ Assert.assertArrayEquals(new int[] {3, 7, 8, 10}, ArrayUtils.insert(original, 2, 8));
+ Assert.assertArrayEquals(new int[] {1, 3, 7, 10}, ArrayUtils.insert(original, 0, 1));
+ Assert.assertArrayEquals(new int[] {3, 3, 7, 10}, ArrayUtils.insert(original, 0, 3));
+ Assert.assertArrayEquals(new int[] {3, 5, 7, 10}, ArrayUtils.insert(original, 1, 5));
+ Assert.assertArrayEquals(new int[] {3, 7, 10, 11},
+ ArrayUtils.insert(original, original.length, 11));
+ // insert() returns a new array and must leave its input untouched.
+ Assert.assertArrayEquals(new int[] {3, 7, 10}, original);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testInsertFail() {
+ final int[] original = new int[] {3, 7, 10};
+ // An index greater than array.length must be rejected.
+ ArrayUtils.insert(original, original.length + 1, 11);
+ }
+
+}