You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/09 21:32:22 UTC
[10/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75]
Support Sparse Vector Format as the input of RandomForest
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java b/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java
deleted file mode 100644
index 1c7a9a1..0000000
--- a/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import hivemall.utils.lang.Preconditions;
-
-import java.util.Arrays;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-/**
- * Read-only CSR Matrix.
- *
- * @see http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
- */
-public final class ReadOnlyCSRMatrix extends Matrix {
-
- @Nonnull
- private final int[] rowPointers;
- @Nonnull
- private final int[] columnIndices;
- @Nonnull
- private final double[] values;
-
- @Nonnegative
- private final int numRows;
- @Nonnegative
- private final int numColumns;
-
- public ReadOnlyCSRMatrix(@Nonnull int[] rowPointers, @Nonnull int[] columnIndices,
- @Nonnull double[] values, @Nonnegative int numColumns) {
- super();
- Preconditions.checkArgument(rowPointers.length >= 1,
- "rowPointers must be greather than 0: " + rowPointers.length);
- Preconditions.checkArgument(columnIndices.length == values.length, "#columnIndices ("
- + columnIndices.length + ") must be equals to #values (" + values.length + ")");
- this.rowPointers = rowPointers;
- this.columnIndices = columnIndices;
- this.values = values;
- this.numRows = rowPointers.length - 1;
- this.numColumns = numColumns;
- }
-
- @Override
- public boolean readOnly() {
- return true;
- }
-
- @Override
- public int numRows() {
- return numRows;
- }
-
- @Override
- public int numColumns() {
- return numColumns;
- }
-
- @Override
- public int numColumns(@Nonnegative final int row) {
- checkRowIndex(row, numRows);
-
- int columns = rowPointers[row + 1] - rowPointers[row];
- return columns;
- }
-
- @Override
- public double get(@Nonnegative final int row, @Nonnegative final int col,
- final double defaultValue) {
- checkIndex(row, col, numRows, numColumns);
-
- final int index = getIndex(row, col);
- if (index < 0) {
- return defaultValue;
- }
- return values[index];
- }
-
- @Override
- public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
- final double value) {
- checkIndex(row, col, numRows, numColumns);
-
- final int index = getIndex(row, col);
- if (index < 0) {
- throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
- + col);
- }
-
- double old = values[index];
- values[index] = value;
- return old;
- }
-
- @Override
- public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
- checkIndex(row, col, numRows, numColumns);
-
- final int index = getIndex(row, col);
- if (index < 0) {
- throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
- + col);
- }
- values[index] = value;
- }
-
- private int getIndex(@Nonnegative final int row, @Nonnegative final int col) {
- int leftIn = rowPointers[row];
- int rightEx = rowPointers[row + 1];
- final int index = Arrays.binarySearch(columnIndices, leftIn, rightEx, col);
- if (index >= 0 && index >= values.length) {
- throw new IndexOutOfBoundsException("Value index " + index + " out of range "
- + values.length);
- }
- return index;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java b/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java
deleted file mode 100644
index 040fef8..0000000
--- a/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-public final class ReadOnlyDenseMatrix2d extends Matrix {
-
- @Nonnull
- private final double[][] data;
-
- @Nonnegative
- private final int numRows;
- @Nonnegative
- private final int numColumns;
-
- public ReadOnlyDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numColumns) {
- this.data = data;
- this.numRows = data.length;
- this.numColumns = numColumns;
- }
-
- @Override
- public boolean readOnly() {
- return true;
- }
-
- @Override
- public void setDefaultValue(double value) {
- throw new UnsupportedOperationException("The defaultValue of a DenseMatrix is fixed to 0.d");
- }
-
- @Override
- public int numRows() {
- return numRows;
- }
-
- @Override
- public int numColumns() {
- return numColumns;
- }
-
- @Override
- public int numColumns(@Nonnegative final int row) {
- checkRowIndex(row, numRows);
-
- return data[row].length;
- }
-
- @Override
- public double get(@Nonnegative final int row, @Nonnegative final int col,
- final double defaultValue) {
- checkIndex(row, col, numRows, numColumns);
-
- final double[] rowData = data[row];
- if (col >= rowData.length) {
- return defaultValue;
- }
- return rowData[col];
- }
-
- @Override
- public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
- final double value) {
- checkIndex(row, col, numRows, numColumns);
-
- final double[] rowData = data[row];
- checkColIndex(col, rowData.length);
-
- double old = rowData[col];
- rowData[col] = value;
- return old;
- }
-
- @Override
- public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
- checkIndex(row, col, numRows, numColumns);
-
- final double[] rowData = data[row];
- checkColIndex(col, rowData.length);
-
- rowData[col] = value;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/mf/FactorizedModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/FactorizedModel.java b/core/src/main/java/hivemall/mf/FactorizedModel.java
index b92a5d8..a4bea00 100644
--- a/core/src/main/java/hivemall/mf/FactorizedModel.java
+++ b/core/src/main/java/hivemall/mf/FactorizedModel.java
@@ -18,7 +18,7 @@
*/
package hivemall.mf;
-import hivemall.utils.collections.IntOpenHashMap;
+import hivemall.utils.collections.maps.IntOpenHashMap;
import hivemall.utils.math.MathUtils;
import java.util.Random;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/model/AbstractPredictionModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/AbstractPredictionModel.java b/core/src/main/java/hivemall/model/AbstractPredictionModel.java
index 37b69da..b48282b 100644
--- a/core/src/main/java/hivemall/model/AbstractPredictionModel.java
+++ b/core/src/main/java/hivemall/model/AbstractPredictionModel.java
@@ -21,8 +21,8 @@ package hivemall.model;
import hivemall.mix.MixedWeight;
import hivemall.mix.MixedWeight.WeightWithCovar;
import hivemall.mix.MixedWeight.WeightWithDelta;
-import hivemall.utils.collections.IntOpenHashMap;
-import hivemall.utils.collections.OpenHashMap;
+import hivemall.utils.collections.maps.IntOpenHashMap;
+import hivemall.utils.collections.maps.OpenHashMap;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/model/SparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java
index 96e1d5a..a2b4708 100644
--- a/core/src/main/java/hivemall/model/SparseModel.java
+++ b/core/src/main/java/hivemall/model/SparseModel.java
@@ -22,7 +22,7 @@ import hivemall.model.WeightValueWithClock.WeightValueParamsF1Clock;
import hivemall.model.WeightValueWithClock.WeightValueParamsF2Clock;
import hivemall.model.WeightValueWithClock.WeightValueWithCovarClock;
import hivemall.utils.collections.IMapIterator;
-import hivemall.utils.collections.OpenHashMap;
+import hivemall.utils.collections.maps.OpenHashMap;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/ModelType.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/ModelType.java b/core/src/main/java/hivemall/smile/ModelType.java
deleted file mode 100644
index 8925075..0000000
--- a/core/src/main/java/hivemall/smile/ModelType.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.smile;
-
-public enum ModelType {
-
- // not compressed
- opscode(1, false), javascript(2, false), serialization(3, false),
- // compressed
- opscode_compressed(-1, true), javascript_compressed(-2, true),
- serialization_compressed(-3, true);
-
- private final int id;
- private final boolean compressed;
-
- private ModelType(int id, boolean compressed) {
- this.id = id;
- this.compressed = compressed;
- }
-
- public int getId() {
- return id;
- }
-
- public boolean isCompressed() {
- return compressed;
- }
-
- public static ModelType resolve(String name, boolean compressed) {
- name = name.toLowerCase();
- if ("opscode".equals(name) || "vm".equals(name)) {
- return compressed ? opscode_compressed : opscode;
- } else if ("javascript".equals(name) || "js".equals(name)) {
- return compressed ? javascript_compressed : javascript;
- } else if ("serialization".equals(name) || "ser".equals(name)) {
- return compressed ? serialization_compressed : serialization;
- } else {
- throw new IllegalStateException("Unexpected output type: " + name);
- }
- }
-
- public static ModelType resolve(final int id) {
- final ModelType type;
- switch (id) {
- case 1:
- type = opscode;
- break;
- case -1:
- type = opscode_compressed;
- break;
- case 2:
- type = javascript;
- break;
- case -2:
- type = javascript_compressed;
- break;
- case 3:
- type = serialization;
- break;
- case -3:
- type = serialization_compressed;
- break;
- default:
- throw new IllegalStateException("Unexpected ID for ModelType: " + id);
- }
- return type;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/DecisionTree.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
index 6b22473..2d086b9 100644
--- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java
+++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
@@ -33,100 +33,94 @@
*/
package hivemall.smile.classification;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.SparseVector;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
import hivemall.smile.data.Attribute;
import hivemall.smile.data.Attribute.AttributeType;
import hivemall.smile.utils.SmileExtUtils;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.ObjectUtils;
-import hivemall.utils.lang.StringUtils;
+import hivemall.utils.sampling.IntReservoirSampler;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
-import java.util.ArrayList;
import java.util.Arrays;
-import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.roaringbitmap.IntConsumer;
+import org.roaringbitmap.RoaringBitmap;
import smile.classification.Classifier;
import smile.math.Math;
-import smile.math.Random;
/**
- * Decision tree for classification. A decision tree can be learned by splitting the training set
- * into subsets based on an attribute value test. This process is repeated on each derived subset in
- * a recursive manner called recursive partitioning. The recursion is completed when the subset at a
- * node all has the same value of the target variable, or when splitting no longer adds value to the
- * predictions.
+ * Decision tree for classification. A decision tree can be learned by splitting the training set into subsets based on an attribute value test. This
+ * process is repeated on each derived subset in a recursive manner called recursive partitioning. The recursion is completed when the subset at a
+ * node all has the same value of the target variable, or when splitting no longer adds value to the predictions.
* <p>
- * The algorithms that are used for constructing decision trees usually work top-down by choosing a
- * variable at each step that is the next best variable to use in splitting the set of items. "Best"
- * is defined by how well the variable splits the set into homogeneous subsets that have the same
- * value of the target variable. Different algorithms use different formulae for measuring "best".
- * Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen element
- * from the set would be incorrectly labeled if it were randomly labeled according to the
- * distribution of labels in the subset. Gini impurity can be computed by summing the probability of
- * each item being chosen times the probability of a mistake in categorizing that item. It reaches
- * its minimum (zero) when all cases in the node fall into a single target category. Information
- * gain is another popular measure, used by the ID3, C4.5 and C5.0 algorithms. Information gain is
- * based on the concept of entropy used in information theory. For categorical variables with
- * different number of levels, however, information gain are biased in favor of those attributes
- * with more levels. Instead, one may employ the information gain ratio, which solves the drawback
- * of information gain.
+ * The algorithms that are used for constructing decision trees usually work top-down by choosing a variable at each step that is the next best
+ * variable to use in splitting the set of items. "Best" is defined by how well the variable splits the set into homogeneous subsets that have the
+ * same value of the target variable. Different algorithms use different formulae for measuring "best". Used by the CART algorithm, Gini impurity is a
+ * measure of how often a randomly chosen element from the set would be incorrectly labeled if it were randomly labeled according to the distribution
+ * of labels in the subset. Gini impurity can be computed by summing the probability of each item being chosen times the probability of a mistake in
+ * categorizing that item. It reaches its minimum (zero) when all cases in the node fall into a single target category. Information gain is another
+ * popular measure, used by the ID3, C4.5 and C5.0 algorithms. Information gain is based on the concept of entropy used in information theory. For
+ * categorical variables with different number of levels, however, information gain are biased in favor of those attributes with more levels. Instead,
+ * one may employ the information gain ratio, which solves the drawback of information gain.
* <p>
- * Classification and Regression Tree techniques have a number of advantages over many of those
- * alternative techniques.
+ * Classification and Regression Tree techniques have a number of advantages over many of those alternative techniques.
* <dl>
* <dt>Simple to understand and interpret.</dt>
- * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This
- * simplicity is useful not only for purposes of rapid classification of new observations, but can
- * also often yield a much simpler "model" for explaining why observations are classified or
- * predicted in a particular manner.</dd>
+ * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This simplicity is useful not only for purposes of rapid
+ * classification of new observations, but can also often yield a much simpler "model" for explaining why observations are classified or predicted in
+ * a particular manner.</dd>
* <dt>Able to handle both numerical and categorical data.</dt>
- * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of
- * variable.</dd>
+ * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of variable.</dd>
* <dt>Tree methods are nonparametric and nonlinear.</dt>
- * <dd>The final results of using tree methods for classification or regression can be summarized in
- * a series of (usually few) logical if-then conditions (tree nodes). Therefore, there is no
- * implicit assumption that the underlying relationships between the predictor variables and the
- * dependent variable are linear, follow some specific non-linear link function, or that they are
- * even monotonic in nature. Thus, tree methods are particularly well suited for data mining tasks,
- * where there is often little a priori knowledge nor any coherent set of theories or predictions
- * regarding which variables are related and how. In those types of data analytics, tree methods can
- * often reveal simple relationships between just a few variables that could have easily gone
- * unnoticed using other analytic techniques.</dd>
+ * <dd>The final results of using tree methods for classification or regression can be summarized in a series of (usually few) logical if-then
+ * conditions (tree nodes). Therefore, there is no implicit assumption that the underlying relationships between the predictor variables and the
+ * dependent variable are linear, follow some specific non-linear link function, or that they are even monotonic in nature. Thus, tree methods are
+ * particularly well suited for data mining tasks, where there is often little a priori knowledge nor any coherent set of theories or predictions
+ * regarding which variables are related and how. In those types of data analytics, tree methods can often reveal simple relationships between just a
+ * few variables that could have easily gone unnoticed using other analytic techniques.</dd>
* </dl>
- * One major problem with classification and regression trees is their high variance. Often a small
- * change in the data can result in a very different series of splits, making interpretation
- * somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause
- * over-fitting. Mechanisms such as pruning are necessary to avoid this problem. Another limitation
- * of trees is the lack of smoothness of the prediction surface.
+ * One major problem with classification and regression trees is their high variance. Often a small change in the data can result in a very different
+ * series of splits, making interpretation somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause over-fitting.
+ * Mechanisms such as pruning are necessary to avoid this problem. Another limitation of trees is the lack of smoothness of the prediction surface.
* <p>
- * Some techniques such as bagging, boosting, and random forest use more than one decision tree for
- * their analysis.
+ * Some techniques such as bagging, boosting, and random forest use more than one decision tree for their analysis.
*/
-public final class DecisionTree implements Classifier<double[]> {
+public final class DecisionTree implements Classifier<Vector> {
/**
* The attributes of independent variable.
*/
+ @Nonnull
private final Attribute[] _attributes;
private final boolean _hasNumericType;
/**
- * Variable importance. Every time a split of a node is made on variable the (GINI, information
- * gain, etc.) impurity criterion for the two descendant nodes is less than the parent node.
- * Adding up the decreases for each individual variable over the tree gives a simple measure of
+ * Variable importance. Every time a split of a node is made on variable the (GINI, information gain, etc.) impurity criterion for the two
+ * descendant nodes is less than the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of
* variable importance.
*/
- private final double[] _importance;
+ @Nonnull
+ private final Vector _importance;
/**
* The root of the regression tree
*/
+ @Nonnull
private final Node _root;
/**
* The maximum number of the tree depth
@@ -135,6 +129,7 @@ public final class DecisionTree implements Classifier<double[]> {
/**
* The splitting rule.
*/
+ @Nonnull
private final SplitRule _rule;
/**
* The number of classes.
@@ -153,24 +148,23 @@ public final class DecisionTree implements Classifier<double[]> {
*/
private final int _minLeafSize;
/**
- * The index of training values in ascending order. Note that only numeric attributes will be
- * sorted.
+ * The index of training values in ascending order. Note that only numeric attributes will be sorted.
*/
- private final int[][] _order;
+ @Nonnull
+ private final ColumnMajorIntMatrix _order;
- private final Random _rnd;
+ @Nonnull
+ private final PRNG _rnd;
/**
* The criterion to choose variable to split instances.
*/
public static enum SplitRule {
/**
- * Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen
- * element from the set would be incorrectly labeled if it were randomly labeled according
- * to the distribution of labels in the subset. Gini impurity can be computed by summing the
- * probability of each item being chosen times the probability of a mistake in categorizing
- * that item. It reaches its minimum (zero) when all cases in the node fall into a single
- * target category.
+ * Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen element from the set would be incorrectly labeled if
+ * it were randomly labeled according to the distribution of labels in the subset. Gini impurity can be computed by summing the probability of
+ * each item being chosen times the probability of a mistake in categorizing that item. It reaches its minimum (zero) when all cases in the
+ * node fall into a single target category.
*/
GINI,
/**
@@ -193,6 +187,11 @@ public final class DecisionTree implements Classifier<double[]> {
*/
int output = -1;
/**
+ * Posteriori probability based on sample ratios in this node.
+ */
+ @Nullable
+ double[] posteriori = null;
+ /**
* The split feature for this node.
*/
int splitFeature = -1;
@@ -227,28 +226,35 @@ public final class DecisionTree implements Classifier<double[]> {
public Node() {}// for Externalizable
- /**
- * Constructor.
- */
- public Node(int output) {
+ public Node(int output, @Nonnull double[] posteriori) {
this.output = output;
+ this.posteriori = posteriori;
+ }
+
+ private boolean isLeaf() {
+ return posteriori != null;
+ }
+
+ @VisibleForTesting
+ public int predict(@Nonnull final double[] x) {
+ return predict(new DenseVector(x));
}
/**
* Evaluate the regression tree over an instance.
*/
- public int predict(final double[] x) {
+ public int predict(@Nonnull final Vector x) {
if (trueChild == null && falseChild == null) {
return output;
} else {
if (splitFeatureType == AttributeType.NOMINAL) {
- if (x[splitFeature] == splitValue) {
+ if (x.get(splitFeature, Double.NaN) == splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
} else if (splitFeatureType == AttributeType.NUMERIC) {
- if (x[splitFeature] <= splitValue) {
+ if (x.get(splitFeature, Double.NaN) <= splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
@@ -260,6 +266,32 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
+ /**
+ * Evaluate the regression tree over an instance.
+ */
+ public void predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) {
+ if (trueChild == null && falseChild == null) {
+ handler.handle(output, posteriori);
+ } else {
+ if (splitFeatureType == AttributeType.NOMINAL) {
+ if (x.get(splitFeature, Double.NaN) == splitValue) {
+ trueChild.predict(x, handler);
+ } else {
+ falseChild.predict(x, handler);
+ }
+ } else if (splitFeatureType == AttributeType.NUMERIC) {
+ if (x.get(splitFeature, Double.NaN) <= splitValue) {
+ trueChild.predict(x, handler);
+ } else {
+ falseChild.predict(x, handler);
+ }
+ } else {
+ throw new IllegalStateException("Unsupported attribute type: "
+ + splitFeatureType);
+ }
+ }
+ }
+
public void jsCodegen(@Nonnull final StringBuilder builder, final int depth) {
if (trueChild == null && falseChild == null) {
indent(builder, depth);
@@ -298,99 +330,71 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
- public int opCodegen(final List<String> scripts, int depth) {
- int selfDepth = 0;
- final StringBuilder buf = new StringBuilder();
- if (trueChild == null && falseChild == null) {
- buf.append("push ").append(output);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("goto last");
- scripts.add(buf.toString());
- selfDepth += 2;
- } else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- buf.append("push ").append("x[").append(splitFeature).append("]");
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("push ").append(splitValue);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("ifeq ");
- scripts.add(buf.toString());
- depth += 3;
- selfDepth += 3;
- int trueDepth = trueChild.opCodegen(scripts, depth);
- selfDepth += trueDepth;
- scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth));
- int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
- selfDepth += falseDepth;
- } else if (splitFeatureType == AttributeType.NUMERIC) {
- buf.append("push ").append("x[").append(splitFeature).append("]");
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("push ").append(splitValue);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("ifle ");
- scripts.add(buf.toString());
- depth += 3;
- selfDepth += 3;
- int trueDepth = trueChild.opCodegen(scripts, depth);
- selfDepth += trueDepth;
- scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth));
- int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
- selfDepth += falseDepth;
- } else {
- throw new IllegalStateException("Unsupported attribute type: "
- + splitFeatureType);
- }
- }
- return selfDepth;
- }
-
@Override
public void writeExternal(ObjectOutput out) throws IOException {
- out.writeInt(output);
out.writeInt(splitFeature);
if (splitFeatureType == null) {
- out.writeInt(-1);
+ out.writeByte(-1);
} else {
- out.writeInt(splitFeatureType.getTypeId());
+ out.writeByte(splitFeatureType.getTypeId());
}
out.writeDouble(splitValue);
- if (trueChild == null) {
- out.writeBoolean(false);
- } else {
+
+ if (isLeaf()) {
out.writeBoolean(true);
- trueChild.writeExternal(out);
- }
- if (falseChild == null) {
- out.writeBoolean(false);
+
+ out.writeInt(output);
+ out.writeInt(posteriori.length);
+ for (int i = 0; i < posteriori.length; i++) {
+ out.writeDouble(posteriori[i]);
+ }
} else {
- out.writeBoolean(true);
- falseChild.writeExternal(out);
+ out.writeBoolean(false);
+
+ if (trueChild == null) {
+ out.writeBoolean(false);
+ } else {
+ out.writeBoolean(true);
+ trueChild.writeExternal(out);
+ }
+ if (falseChild == null) {
+ out.writeBoolean(false);
+ } else {
+ out.writeBoolean(true);
+ falseChild.writeExternal(out);
+ }
}
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this.output = in.readInt();
this.splitFeature = in.readInt();
- int typeId = in.readInt();
+ byte typeId = in.readByte();
if (typeId == -1) {
this.splitFeatureType = null;
} else {
this.splitFeatureType = AttributeType.resolve(typeId);
}
this.splitValue = in.readDouble();
- if (in.readBoolean()) {
- this.trueChild = new Node();
- trueChild.readExternal(in);
- }
- if (in.readBoolean()) {
- this.falseChild = new Node();
- falseChild.readExternal(in);
+
+ if (in.readBoolean()) {//isLeaf
+ this.output = in.readInt();
+
+ final int size = in.readInt();
+ final double[] posteriori = new double[size];
+ for (int i = 0; i < size; i++) {
+ posteriori[i] = in.readDouble();
+ }
+ this.posteriori = posteriori;
+ } else {
+ if (in.readBoolean()) {
+ this.trueChild = new Node();
+ trueChild.readExternal(in);
+ }
+ if (in.readBoolean()) {
+ this.falseChild = new Node();
+ falseChild.readExternal(in);
+ }
}
}
@@ -413,7 +417,7 @@ public final class DecisionTree implements Classifier<double[]> {
/**
* Training dataset.
*/
- final double[][] x;
+ final Matrix x;
/**
* class labels.
*/
@@ -426,7 +430,7 @@ public final class DecisionTree implements Classifier<double[]> {
/**
* Constructor.
*/
- public TrainNode(Node node, double[][] x, int[] y, int[] bags, int depth) {
+ public TrainNode(Node node, Matrix x, int[] y, int[] bags, int depth) {
this.node = node;
this.x = x;
this.y = y;
@@ -466,21 +470,12 @@ public final class DecisionTree implements Classifier<double[]> {
final double impurity = impurity(count, numSamples, _rule);
- final int p = _attributes.length;
- final int[] variableIndex = new int[p];
- for (int i = 0; i < p; i++) {
- variableIndex[i] = i;
- }
- if (_numVars < p) {
- SmileExtUtils.shuffle(variableIndex, _rnd);
- }
-
- final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.length)
+ final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.numRows())
: null;
final int[] falseCount = new int[_k];
- for (int j = 0; j < _numVars; j++) {
- Node split = findBestSplit(numSamples, count, falseCount, impurity,
- variableIndex[j], samples);
+ for (int varJ : variableIndex(x, bags)) {
+ final Node split = findBestSplit(numSamples, count, falseCount, impurity, varJ,
+ samples);
if (split.splitScore > node.splitScore) {
node.splitFeature = split.splitFeature;
node.splitFeatureType = split.splitFeatureType;
@@ -491,7 +486,33 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
- return (node.splitFeature != -1);
+ return node.splitFeature != -1;
+ }
+
+ @Nonnull
+ private int[] variableIndex(@Nonnull final Matrix x, @Nonnull final int[] bags) {
+ final IntReservoirSampler sampler = new IntReservoirSampler(_numVars, _rnd.nextLong());
+ if (x.isSparse()) {
+ final RoaringBitmap cols = new RoaringBitmap();
+ final VectorProcedure proc = new VectorProcedure() {
+ public void apply(final int col) {
+ cols.add(col);
+ }
+ };
+ for (final int row : bags) {
+ x.eachColumnIndexInRow(row, proc);
+ }
+ cols.forEach(new IntConsumer() {
+ public void accept(final int k) {
+ sampler.add(k);
+ }
+ });
+ } else {
+ for (int i = 0, size = _attributes.length; i < size; i++) {
+ sampler.add(i);
+ }
+ }
+ return sampler.getSample();
}
private boolean sampleCount(@Nonnull final int[] count) {
@@ -530,7 +551,11 @@ public final class DecisionTree implements Classifier<double[]> {
for (int i = 0, size = bags.length; i < size; i++) {
int index = bags[i];
- int x_ij = (int) x[index][j];
+ final double v = x.get(index, j, Double.NaN);
+ if (Double.isNaN(v)) {
+ continue;
+ }
+ int x_ij = (int) v;
trueCount[x_ij][y[index]]++;
}
@@ -563,21 +588,28 @@ public final class DecisionTree implements Classifier<double[]> {
}
} else if (_attributes[j].type == AttributeType.NUMERIC) {
final int[] trueCount = new int[_k];
- double prevx = Double.NaN;
- int prevy = -1;
-
- assert (samples != null);
- for (final int i : _order[j]) {
- final int sample = samples[i];
- if (sample > 0) {
- final double x_ij = x[i][j];
+
+ _order.eachNonNullInColumn(j, new VectorProcedure() {
+ double prevx = Double.NaN;
+ int prevy = -1;
+
+ public void apply(final int row, final int i) {
+ final int sample = samples[i];
+ if (sample == 0) {
+ return;
+ }
+
+ final double x_ij = x.get(i, j, Double.NaN);
+ if (Double.isNaN(x_ij)) {
+ return;
+ }
final int y_i = y[i];
if (Double.isNaN(prevx) || x_ij == prevx || y_i == prevy) {
prevx = x_ij;
prevy = y_i;
trueCount[y_i] += sample;
- continue;
+ return;
}
final int tc = Math.sum(trueCount);
@@ -588,7 +620,7 @@ public final class DecisionTree implements Classifier<double[]> {
prevx = x_ij;
prevy = y_i;
trueCount[y_i] += sample;
- continue;
+ return;
}
for (int l = 0; l < _k; l++) {
@@ -612,8 +644,8 @@ public final class DecisionTree implements Classifier<double[]> {
prevx = x_ij;
prevy = y_i;
trueCount[y_i] += sample;
- }
- }
+ }//apply()
+ });
} else {
throw new IllegalStateException("Unsupported attribute type: "
+ _attributes[j].type);
@@ -634,7 +666,9 @@ public final class DecisionTree implements Classifier<double[]> {
int childBagSize = (int) (bags.length * 0.4);
IntArrayList trueBags = new IntArrayList(childBagSize);
IntArrayList falseBags = new IntArrayList(childBagSize);
- int tc = splitSamples(trueBags, falseBags);
+ double[] trueChildPosteriori = new double[_k];
+ double[] falseChildPosteriori = new double[_k];
+ int tc = splitSamples(trueBags, falseBags, trueChildPosteriori, falseChildPosteriori);
int fc = bags.length - tc;
this.bags = null; // help GC for recursive call
@@ -647,7 +681,12 @@ public final class DecisionTree implements Classifier<double[]> {
return false;
}
- node.trueChild = new Node(node.trueChildOutput);
+ for (int i = 0; i < _k; i++) {
+ trueChildPosteriori[i] /= tc;
+ falseChildPosteriori[i] /= fc;
+ }
+
+ node.trueChild = new Node(node.trueChildOutput, trueChildPosteriori);
TrainNode trueChild = new TrainNode(node.trueChild, x, y, trueBags.toArray(), depth + 1);
trueBags = null; // help GC for recursive call
if (tc >= _minSplit && trueChild.findBestSplit()) {
@@ -658,7 +697,7 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
- node.falseChild = new Node(node.falseChildOutput);
+ node.falseChild = new Node(node.falseChildOutput, falseChildPosteriori);
TrainNode falseChild = new TrainNode(node.falseChild, x, y, falseBags.toArray(),
depth + 1);
falseBags = null; // help GC for recursive call
@@ -670,27 +709,33 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
- _importance[node.splitFeature] += node.splitScore;
+ _importance.incr(node.splitFeature, node.splitScore);
+ node.posteriori = null; // posteriori is not needed for non-leaf nodes
return true;
}
/**
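+ * @param trueChildPosteriori receives the per-class sample counts of the true child
+ * @param falseChildPosteriori receives the per-class sample counts of the false child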
* @return the number of true samples
*/
private int splitSamples(@Nonnull final IntArrayList trueBags,
- @Nonnull final IntArrayList falseBags) {
+ @Nonnull final IntArrayList falseBags, @Nonnull final double[] trueChildPosteriori,
+ @Nonnull final double[] falseChildPosteriori) {
int tc = 0;
if (node.splitFeatureType == AttributeType.NOMINAL) {
final int splitFeature = node.splitFeature;
final double splitValue = node.splitValue;
for (int i = 0, size = bags.length; i < size; i++) {
final int index = bags[i];
- if (x[index][splitFeature] == splitValue) {
+ if (x.get(index, splitFeature, Double.NaN) == splitValue) {
trueBags.add(index);
+ trueChildPosteriori[y[index]]++;
tc++;
} else {
falseBags.add(index);
+ falseChildPosteriori[y[index]]++;
}
}
} else if (node.splitFeatureType == AttributeType.NUMERIC) {
@@ -698,11 +743,13 @@ public final class DecisionTree implements Classifier<double[]> {
final double splitValue = node.splitValue;
for (int i = 0, size = bags.length; i < size; i++) {
final int index = bags[i];
- if (x[index][splitFeature] <= splitValue) {
+ if (x.get(index, splitFeature, Double.NaN) <= splitValue) {
trueBags.add(index);
+ trueChildPosteriori[y[index]]++;
tc++;
} else {
falseBags.add(index);
+ falseChildPosteriori[y[index]]++;
}
}
} else {
@@ -714,7 +761,6 @@ public final class DecisionTree implements Classifier<double[]> {
}
-
/**
* Returns the impurity of a node.
*
@@ -731,8 +777,9 @@ public final class DecisionTree implements Classifier<double[]> {
case GINI: {
impurity = 1.0;
for (int i = 0; i < count.length; i++) {
- if (count[i] > 0) {
- double p = (double) count[i] / n;
+ final int count_i = count[i];
+ if (count_i > 0) {
+ double p = (double) count_i / n;
impurity -= p * p;
}
}
@@ -740,8 +787,9 @@ public final class DecisionTree implements Classifier<double[]> {
}
case ENTROPY: {
for (int i = 0; i < count.length; i++) {
- if (count[i] > 0) {
- double p = (double) count[i] / n;
+ final int count_i = count[i];
+ if (count_i > 0) {
+ double p = (double) count_i / n;
impurity -= p * Math.log2(p);
}
}
@@ -750,8 +798,9 @@ public final class DecisionTree implements Classifier<double[]> {
case CLASSIFICATION_ERROR: {
impurity = 0.d;
for (int i = 0; i < count.length; i++) {
- if (count[i] > 0) {
- impurity = Math.max(impurity, (double) count[i] / n);
+ final int count_i = count[i];
+ if (count_i > 0) {
+ impurity = Math.max(impurity, (double) count_i / n);
}
}
impurity = Math.abs(1.d - impurity);
@@ -762,14 +811,14 @@ public final class DecisionTree implements Classifier<double[]> {
return impurity;
}
- public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y,
+ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y,
int numLeafs) {
- this(attributes, x, y, x[0].length, Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, null);
+ this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, null);
}
- public DecisionTree(@Nullable Attribute[] attributes, @Nullable double[][] x,
- @Nullable int[] y, int numLeafs, @Nullable smile.math.Random rand) {
- this(attributes, x, y, x[0].length, Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, rand);
+ public DecisionTree(@Nullable Attribute[] attributes, @Nullable Matrix x, @Nullable int[] y,
+ int numLeafs, @Nullable PRNG rand) {
+ this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, rand);
}
/**
@@ -778,21 +827,20 @@ public final class DecisionTree implements Classifier<double[]> {
* @param attributes the attribute properties.
* @param x the training instances.
* @param y the response variable.
- * @param numVars the number of input variables to pick to split on at each node. It seems that
- * dim/3 give generally good performance, where dim is the number of variables.
+ * @param numVars the number of input variables to pick to split on at each node. It seems that dim/3 gives generally good performance, where dim
+ * is the number of variables.
* @param maxLeafs the maximum number of leaf nodes in the tree.
* @param minSplits the number of minimum elements in a node to split
* @param minLeafSize the minimum size of leaf nodes.
- * @param order the index of training values in ascending order. Note that only numeric
- * attributes need be sorted.
+ * @param order the index of training values in ascending order. Note that only numeric attributes need to be sorted.
* @param bags the sample set of instances for stochastic learning.
* @param rule the splitting rule.
* @param seed
*/
- public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y,
+ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y,
int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
- @Nullable int[] bags, @Nullable int[][] order, @Nonnull SplitRule rule,
- @Nullable smile.math.Random rand) {
+ @Nullable int[] bags, @Nullable ColumnMajorIntMatrix order, @Nonnull SplitRule rule,
+ @Nullable PRNG rand) {
checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);
this._k = Math.max(y) + 1;
@@ -801,7 +849,7 @@ public final class DecisionTree implements Classifier<double[]> {
}
this._attributes = SmileExtUtils.attributeTypes(attributes, x);
- if (attributes.length != x[0].length) {
+ if (attributes.length != x.numColumns()) {
throw new IllegalArgumentException("-attrs option is invliad: "
+ Arrays.toString(attributes));
}
@@ -813,8 +861,8 @@ public final class DecisionTree implements Classifier<double[]> {
this._minLeafSize = minLeafSize;
this._rule = rule;
this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
- this._importance = new double[_attributes.length];
- this._rnd = (rand == null) ? new smile.math.Random() : rand;
+ this._importance = x.isSparse() ? new SparseVector() : new DenseVector(_attributes.length);
+ this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;
final int n = y.length;
final int[] count = new int[_k];
@@ -825,13 +873,17 @@ public final class DecisionTree implements Classifier<double[]> {
count[y[i]]++;
}
} else {
- for (int i = 0; i < n; i++) {
+ for (int i = 0, size = bags.length; i < size; i++) {
int index = bags[i];
count[y[index]]++;
}
}
- this._root = new Node(Math.whichMax(count));
+ final double[] posteriori = new double[_k];
+ for (int i = 0; i < _k; i++) {
+ posteriori[i] = (double) count[i] / n;
+ }
+ this._root = new Node(Math.whichMax(count), posteriori);
final TrainNode trainRoot = new TrainNode(_root, x, y, bags, 1);
if (maxLeafs == Integer.MAX_VALUE) {
@@ -858,13 +910,13 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
- private static void checkArgument(@Nonnull double[][] x, @Nonnull int[] y, int numVars,
+ private static void checkArgument(@Nonnull Matrix x, @Nonnull int[] y, int numVars,
int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
- if (x.length != y.length) {
+ if (x.numRows() != y.length) {
throw new IllegalArgumentException(String.format(
- "The sizes of X and Y don't match: %d != %d", x.length, y.length));
+ "The sizes of X and Y don't match: %d != %d", x.numRows(), y.length));
}
- if (numVars <= 0 || numVars > x[0].length) {
+ if (numVars <= 0 || numVars > x.numColumns()) {
throw new IllegalArgumentException(
"Invalid number of variables to split on at a node of the tree: " + numVars);
}
@@ -885,28 +937,31 @@ public final class DecisionTree implements Classifier<double[]> {
}
/**
- * Returns the variable importance. Every time a split of a node is made on variable the (GINI,
- * information gain, etc.) impurity criterion for the two descendent nodes is less than the
- * parent node. Adding up the decreases for each individual variable over the tree gives a
- * simple measure of variable importance.
+ * Returns the variable importance. Every time a node is split on a variable, the (GINI, information gain, etc.) impurity criterion for the
+ * two descendant nodes is less than that of the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of
+ * variable importance.
*
* @return the variable importance
*/
- public double[] importance() {
+ @Nonnull
+ public Vector importance() {
return _importance;
}
+ @VisibleForTesting
+ public int predict(@Nonnull final double[] x) {
+ return predict(new DenseVector(x));
+ }
+
@Override
- public int predict(final double[] x) {
+ public int predict(@Nonnull final Vector x) {
return _root.predict(x);
}
/**
- * Predicts the class label of an instance and also calculate a posteriori probabilities. Not
- * supported.
+ * Predicts the class label of an instance and also calculates the a posteriori probabilities. Not supported.
*/
- @Override
- public int predict(double[] x, double[] posteriori) {
+ public int predict(Vector x, double[] posteriori) {
throw new UnsupportedOperationException("Not supported.");
}
@@ -916,14 +971,6 @@ public final class DecisionTree implements Classifier<double[]> {
return buf.toString();
}
- public String predictOpCodegen(String sep) {
- List<String> opslist = new ArrayList<String>();
- _root.opCodegen(opslist, 0);
- opslist.add("call end");
- String scripts = StringUtils.concat(opslist, sep);
- return scripts;
- }
-
@Nonnull
public byte[] predictSerCodegen(boolean compress) throws HiveException {
try {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
index 3a0924e..a380a11 100644
--- a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
@@ -19,24 +19,27 @@
package hivemall.smile.classification;
import hivemall.UDTFWithOptions;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.builders.MatrixBuilder;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.Vector;
import hivemall.smile.data.Attribute;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
-import hivemall.smile.vm.StackMachine;
import hivemall.utils.codec.Base91;
-import hivemall.utils.codec.DeflateCodec;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
-import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.Primitives;
+import hivemall.utils.math.MathUtils;
-import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
-import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -63,7 +66,7 @@ import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.Reporter;
@Description(name = "train_gradient_tree_boosting_classifier",
- value = "_FUNC_(double[] features, int label [, string options]) - "
+ value = "_FUNC_(array<double|string> features, int label [, string options]) - "
+ "Returns a relation consists of "
+ "<int iteration, int model_type, array<string> pred_models, double intercept, "
+ "double shrinkage, array<double> var_importance, float oob_error_rate>")
@@ -74,7 +77,8 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
private PrimitiveObjectInspector featureElemOI;
private PrimitiveObjectInspector labelOI;
- private List<double[]> featuresList;
+ private boolean denseInput;
+ private MatrixBuilder matrixBuilder;
private IntArrayList labels;
/**
* The number of trees for each task
@@ -104,7 +108,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
private int _minSamplesLeaf;
private long _seed;
private Attribute[] _attributes;
- private ModelType _outputType;
@Nullable
private Reporter _progressReporter;
@@ -134,10 +137,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
- opts.addOption("output", "output_type", true,
- "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]");
- opts.addOption("disable_compression", false,
- "Whether to disable compression of the output script [default: false]");
return opts;
}
@@ -149,8 +148,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
double eta = 0.05d, subsample = 0.7d;
Attribute[] attrs = null;
long seed = -1L;
- String output = "serialization";
- boolean compress = true;
CommandLine cl = null;
if (argOIs.length >= 3) {
@@ -171,10 +168,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
minSamplesLeaf);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
- output = cl.getOptionValue("output", output);
- if (cl.hasOption("disable_compression")) {
- compress = false;
- }
}
this._numTrees = trees;
@@ -187,7 +180,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
this._minSamplesLeaf = minSamplesLeaf;
this._seed = seed;
this._attributes = attrs;
- this._outputType = ModelType.resolve(output, compress);
return cl;
}
@@ -197,19 +189,29 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
if (argOIs.length != 2 && argOIs.length != 3) {
throw new UDFArgumentException(
getClass().getSimpleName()
- + " takes 2 or 3 arguments: double[] features, int label [, const string options]: "
+ + " takes 2 or 3 arguments: array<double|string> features, int label [, const string options]: "
+ argOIs.length);
}
ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureListOI = listOI;
- this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ if (HiveUtils.isNumberOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ this.denseInput = true;
+ this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
+ } else if (HiveUtils.isStringOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asStringOI(elemOI);
+ this.denseInput = false;
+ this.matrixBuilder = new CSRMatrixBuilder(8192);
+ } else {
+ throw new UDFArgumentException(
+ "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
+ }
this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
processOptions(argOIs);
- this.featuresList = new ArrayList<double[]>(1024);
this.labels = new IntArrayList(1024);
ArrayList<String> fieldNames = new ArrayList<String>(6);
@@ -217,8 +219,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
fieldNames.add("iteration");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
- fieldNames.add("model_type");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("pred_models");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector));
fieldNames.add("intercept");
@@ -238,13 +238,36 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
if (args[0] == null) {
throw new HiveException("array<double> features was null");
}
- double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
+ parseFeatures(args[0], matrixBuilder);
int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
-
- featuresList.add(features);
labels.add(label);
}
+ private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
+ if (denseInput) {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
+ builder.nextColumn(i, v);
+ }
+ } else {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ String fv = o.toString();
+ builder.nextColumn(fv);
+ }
+ }
+ builder.nextRow();
+ }
+
@Override
public void close() throws HiveException {
this._progressReporter = getReporter();
@@ -252,14 +275,15 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
"hivemall.smile.GradientTreeBoostingClassifier$Counter", "iteration");
reportProgress(_progressReporter);
- int numExamples = featuresList.size();
- double[][] x = featuresList.toArray(new double[numExamples][]);
- this.featuresList = null;
- int[] y = labels.toArray();
- this.labels = null;
+ if (!labels.isEmpty()) {
+ Matrix x = matrixBuilder.buildMatrix();
+ this.matrixBuilder = null;
+ int[] y = labels.toArray();
+ this.labels = null;
- // run training
- train(x, y);
+ // run training
+ train(x, y);
+ }
// clean up
this.featureListOI = null;
@@ -287,25 +311,25 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
* @param x features
* @param y label
*/
- private void train(@Nonnull final double[][] x, @Nonnull final int[] y) throws HiveException {
- if (x.length != y.length) {
+ private void train(@Nonnull Matrix x, @Nonnull final int[] y) throws HiveException {
+ final int numRows = x.numRows();
+ if (numRows != y.length) {
throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
- x.length, y.length));
+ numRows, y.length));
}
checkOptions();
this._attributes = SmileExtUtils.attributeTypes(_attributes, x);
// Shuffle training samples
- SmileExtUtils.shuffle(x, y, _seed);
+ x = SmileExtUtils.shuffle(x, y, _seed);
final int k = smile.math.Math.max(y) + 1;
if (k < 2) {
throw new UDFArgumentException("Only one class or negative class labels.");
}
if (k == 2) {
- int n = x.length;
- final int[] y2 = new int[n];
- for (int i = 0; i < n; i++) {
+ final int[] y2 = new int[numRows];
+ for (int i = 0; i < numRows; i++) {
if (y[i] == 1) {
y2[i] = 1;
} else {
@@ -318,7 +342,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
}
}
- private void train2(@Nonnull final double[][] x, @Nonnull final int[] y) throws HiveException {
+ private void train2(@Nonnull final Matrix x, @Nonnull final int[] y) throws HiveException {
final int numVars = SmileExtUtils.computeNumInputVars(_numVars, x);
if (logger.isInfoEnabled()) {
logger.info("k: " + 2 + ", numTrees: " + _numTrees + ", shirinkage: " + _eta
@@ -327,7 +351,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
+ _maxLeafNodes + ", seed: " + _seed);
}
- final int numInstances = x.length;
+ final int numInstances = x.numRows();
final int numSamples = (int) Math.round(numInstances * _subsample);
final double[] h = new double[numInstances]; // current F(x_i)
@@ -340,7 +364,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
h[i] = intercept;
}
- final int[][] order = SmileExtUtils.sort(_attributes, x);
+ final ColumnMajorIntMatrix order = SmileExtUtils.sort(_attributes, x);
final RegressionTree.NodeOutput output = new L2NodeOutput(response);
final BitSet sampled = new BitSet(numInstances);
@@ -351,10 +375,11 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
}
long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
- : new smile.math.Random(_seed).nextLong();
- final smile.math.Random rnd1 = new smile.math.Random(s);
- final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());
+ : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
+ final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
+ final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
+ final Vector xProbe = x.rowVector();
for (int m = 0; m < _numTrees; m++) {
reportProgress(_progressReporter);
@@ -373,7 +398,8 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
_maxLeafNodes, _minSamplesSplit, _minSamplesLeaf, order, bag, output, rnd2);
for (int i = 0; i < numInstances; i++) {
- h[i] += _eta * tree.predict(x[i]);
+ x.getRow(i, xProbe);
+ h[i] += _eta * tree.predict(xProbe);
}
// out-of-bag error estimate
@@ -398,7 +424,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
/**
* Train L-k tree boost.
*/
- private void traink(final double[][] x, final int[] y, final int k) throws HiveException {
+ private void traink(final Matrix x, final int[] y, final int k) throws HiveException {
final int numVars = SmileExtUtils.computeNumInputVars(_numVars, x);
if (logger.isInfoEnabled()) {
logger.info("k: " + k + ", numTrees: " + _numTrees + ", shirinkage: " + _eta
@@ -407,14 +433,14 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
+ ", maxLeafs: " + _maxLeafNodes + ", seed: " + _seed);
}
- final int numInstances = x.length;
+ final int numInstances = x.numRows();
final int numSamples = (int) Math.round(numInstances * _subsample);
final double[][] h = new double[k][numInstances]; // boost tree output.
final double[][] p = new double[k][numInstances]; // posteriori probabilities.
final double[][] response = new double[k][numInstances]; // pseudo response.
- final int[][] order = SmileExtUtils.sort(_attributes, x);
+ final ColumnMajorIntMatrix order = SmileExtUtils.sort(_attributes, x);
final RegressionTree.NodeOutput[] output = new LKNodeOutput[k];
for (int i = 0; i < k; i++) {
output[i] = new LKNodeOutput(response[i], k);
@@ -422,19 +448,16 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
final BitSet sampled = new BitSet(numInstances);
final int[] bag = new int[numSamples];
- final int[] perm = new int[numSamples];
- for (int i = 0; i < numSamples; i++) {
- perm[i] = i;
- }
+ final int[] perm = MathUtils.permutation(numInstances);
long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
- : new smile.math.Random(_seed).nextLong();
- final smile.math.Random rnd1 = new smile.math.Random(s);
- final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());
+ : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
+ final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
+ final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
// out-of-bag prediction
final int[] prediction = new int[numInstances];
-
+ final Vector xProbe = x.rowVector();
for (int m = 0; m < _numTrees; m++) {
for (int i = 0; i < numInstances; i++) {
double max = Double.NEGATIVE_INFINITY;
@@ -490,7 +513,8 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
trees[j] = tree;
for (int i = 0; i < numInstances; i++) {
- double h_ji = h_j[i] + _eta * tree.predict(x[i]);
+ x.getRow(i, xProbe);
+ double h_ji = h_j[i] + _eta * tree.predict(xProbe);
h_j[i] += h_ji;
if (h_ji > max_h) {
max_h = h_ji;
@@ -524,7 +548,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
*/
private void forward(final int m, final double intercept, final double shrinkage,
final float oobErrorRate, @Nonnull final RegressionTree... trees) throws HiveException {
- Text[] models = getModel(trees, _outputType);
+ Text[] models = getModel(trees);
double[] importance = new double[_attributes.length];
for (RegressionTree tree : trees) {
@@ -534,14 +558,13 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
}
}
- Object[] forwardObjs = new Object[7];
+ Object[] forwardObjs = new Object[6];
forwardObjs[0] = new IntWritable(m);
- forwardObjs[1] = new IntWritable(_outputType.getId());
- forwardObjs[2] = models;
- forwardObjs[3] = new DoubleWritable(intercept);
- forwardObjs[4] = new DoubleWritable(shrinkage);
- forwardObjs[5] = WritableUtils.toWritableList(importance);
- forwardObjs[6] = new FloatWritable(oobErrorRate);
+ forwardObjs[1] = models;
+ forwardObjs[2] = new DoubleWritable(intercept);
+ forwardObjs[3] = new DoubleWritable(shrinkage);
+ forwardObjs[4] = WritableUtils.toWritableList(importance);
+ forwardObjs[5] = new FloatWritable(oobErrorRate);
forward(forwardObjs);
@@ -551,67 +574,14 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
logger.info("Forwarded the output of " + m + "-th Boosting iteration out of " + _numTrees);
}
- private static Text[] getModel(@Nonnull final RegressionTree[] trees,
- @Nonnull final ModelType outputType) throws HiveException {
+ @Nonnull
+ private static Text[] getModel(@Nonnull final RegressionTree[] trees) throws HiveException {
final int m = trees.length;
final Text[] models = new Text[m];
- switch (outputType) {
- case serialization:
- case serialization_compressed: {
- for (int i = 0; i < m; i++) {
- byte[] b = trees[i].predictSerCodegen(outputType.isCompressed());
- b = Base91.encode(b);
- models[i] = new Text(b);
- }
- break;
- }
- case opscode:
- case opscode_compressed: {
- for (int i = 0; i < m; i++) {
- String s = trees[i].predictOpCodegen(StackMachine.SEP);
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- models[i] = new Text(b);
- } else {
- models[i] = new Text(s);
- }
- }
- break;
- }
- case javascript:
- case javascript_compressed: {
- for (int i = 0; i < m; i++) {
- String s = trees[i].predictJsCodegen();
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- models[i] = new Text(b);
- } else {
- models[i] = new Text(s);
- }
- }
- break;
- }
- default:
- throw new HiveException("Unexpected output type: " + outputType
- + ". Use javascript for the output instead");
+ for (int i = 0; i < m; i++) {
+ byte[] b = trees[i].predictSerCodegen(true);
+ b = Base91.encode(b);
+ models[i] = new Text(b);
}
return models;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/PredictionHandler.java b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
new file mode 100644
index 0000000..84ef244
--- /dev/null
+++ b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.classification;
+
+import javax.annotation.Nonnull;
+
+public interface PredictionHandler {
+
+ void handle(int output, @Nonnull double[] posteriori);
+
+}