You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/02/05 04:56:01 UTC
[incubator-hivemall] branch master updated: [HIVEMALL-233-2] RandomForest regressor accepts sparse vector input
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new 2e1104c [HIVEMALL-233-2] RandomForest regressor accepts sparse vector input
2e1104c is described below
commit 2e1104c1eeb4598ba5cc8e74dfba5d36699344f3
Author: Takuya Kitazawa <k....@gmail.com>
AuthorDate: Tue Feb 5 13:55:55 2019 +0900
[HIVEMALL-233-2] RandomForest regressor accepts sparse vector input
## What changes were proposed in this pull request?
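Enable RandomForestRegressor (`train_randomforest_regressor`) to accept sparse vector input, as RandomForestClassifier already does. For sparse input, `var_importance` is now emitted as a `map<int,double>` (feature index to importance) instead of an `array<double>`, both by the regressor and by GradientTreeBoostingClassifier. The regressor's `pred_model` output column is also renamed to `model`.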
This closes #178
## What type of PR is it?
Improvement
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-233
## How was this patch tested?
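Manual tests on EMR. New unit tests were also added in `RandomForestRegressionUDTFTest` and `GradientTreeBoostingClassifierUDTFTest`.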
## How to use this feature?
```sql
-- Toy customer table; num_purchases is the value to be predicted.
with customers as (
select 1 as id, "male" as gender, 23 as age, "Japan" as country, 12 as num_purchases
union all
select 2 as id, "female" as gender, 43 as age, "US" as country, 4 as num_purchases
union all
select 3 as id, "other" as gender, 19 as age, "UK" as country, 2 as num_purchases
union all
select 4 as id, "male" as gender, 31 as age, "US" as country, 20 as num_purchases
union all
select 5 as id, "female" as gender, 37 as age, "Australia" as country, 9 as num_purchases
),
-- Feature array: one quantitative feature (age) concatenated with two categorical
-- features (country, gender).
training as (
select
array_concat(
quantitative_features(
array("age"),
age
),
categorical_features(
array("country", "gender"),
country, gender
)
) as features,
num_purchases
from
customers
)
-- feature_hashing presumably maps each feature name to a hashed integer index, producing the
-- sparse "index:value" input that train_randomforest_regressor accepts after this change.
select
train_randomforest_regressor(
feature_hashing(features), -- feature vector
num_purchases, -- target value
'-trees 40 -seed 31' -- hyper-parameters
)
from
training
;
```
## Checklist
- [x] Did you apply the source code formatter, i.e., `./bin/format_code.sh`, to your commit?
- [ ] Did you run system tests on Hive (or Spark)?
Author: Takuya Kitazawa <k....@gmail.com>
Author: Makoto Yui <my...@apache.org>
Closes #181 from myui/HIVEMALL-233-2.
---
.../GradientTreeBoostingClassifierUDTF.java | 36 ++-
.../regression/RandomForestRegressionUDTF.java | 44 ++-
.../hivemall/smile/regression/RegressionTree.java | 9 +-
.../GradientTreeBoostingClassifierUDTFTest.java | 181 ++++++++++++
.../regression/RandomForestRegressionUDTFTest.java | 308 +++++++++++++++++++++
5 files changed, 555 insertions(+), 23 deletions(-)
diff --git a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
index 1a4fa1c..9edc63d 100644
--- a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
@@ -26,7 +26,10 @@ import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.SparseVector;
import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
import hivemall.smile.data.Attribute;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
@@ -40,6 +43,8 @@ import hivemall.utils.math.MathUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
+import java.util.Map;
+import java.util.HashMap;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -227,8 +232,14 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
fieldNames.add("shrinkage");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
fieldNames.add("var_importance");
- fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
- PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ if (denseInput) {
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ } else {
+ fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.writableIntObjectInspector,
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ }
fieldNames.add("oob_error_rate");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
@@ -555,11 +566,11 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
final float oobErrorRate, @Nonnull final RegressionTree... trees) throws HiveException {
Text[] models = getModel(trees);
- double[] importance = new double[_attributes.length];
+ Vector importance = denseInput ? new DenseVector(_attributes.length) : new SparseVector();
for (RegressionTree tree : trees) {
- double[] imp = tree.importance();
- for (int i = 0; i < imp.length; i++) {
- importance[i] += imp[i];
+ Vector imp = tree.importance();
+ for (int i = 0, size = imp.size(); i < size; i++) {
+ importance.incr(i, imp.get(i));
}
}
@@ -568,7 +579,18 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
forwardObjs[1] = models;
forwardObjs[2] = new DoubleWritable(intercept);
forwardObjs[3] = new DoubleWritable(shrinkage);
- forwardObjs[4] = WritableUtils.toWritableList(importance);
+ if (denseInput) {
+ forwardObjs[4] = WritableUtils.toWritableList(importance.toArray());
+ } else {
+ final Map<IntWritable, DoubleWritable> map =
+ new HashMap<IntWritable, DoubleWritable>(importance.size());
+ importance.each(new VectorProcedure() {
+ public void apply(int i, double value) {
+ map.put(new IntWritable(i), new DoubleWritable(value));
+ }
+ });
+ forwardObjs[4] = map;
+ }
forwardObjs[5] = new FloatWritable(oobErrorRate);
forward(forwardObjs);
diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
index df5d55b..dc148e2 100644
--- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
+++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
@@ -27,6 +27,7 @@ import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
import hivemall.smile.data.Attribute;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
@@ -40,7 +41,9 @@ import hivemall.utils.lang.RandomUtils;
import java.util.ArrayList;
import java.util.BitSet;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
@@ -71,10 +74,10 @@ import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.Reporter;
-@Description(name = "train_randomforest_regression",
+@Description(name = "train_randomforest_regressor",
value = "_FUNC_(array<double|string> features, double target [, string options]) - "
+ "Returns a relation consists of "
- + "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>")
+ + "<int model_id, int model_type, string model, array<double> var_importance, double oob_errors, int oob_tests>")
public final class RandomForestRegressionUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(RandomForestRegressionUDTF.class);
@@ -206,18 +209,24 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
this.targets = new DoubleArrayList(1024);
- ArrayList<String> fieldNames = new ArrayList<String>(5);
- ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(5);
+ final ArrayList<String> fieldNames = new ArrayList<String>(6);
+ final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
fieldNames.add("model_id");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
fieldNames.add("model_err");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
- fieldNames.add("pred_model");
+ fieldNames.add("model");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
fieldNames.add("var_importance");
- fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
- PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ if (denseInput) {
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ } else {
+ fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.writableIntObjectInspector,
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ }
fieldNames.add("oob_errors");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
fieldNames.add("oob_tests");
@@ -359,7 +368,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
* @param error
*/
synchronized void forward(final int taskId, @Nonnull final Text model,
- @Nonnull final double[] importance, @Nonnegative final double error, final double[] y,
+ @Nonnull final Vector importance, @Nonnegative final double error, final double[] y,
final double[] prediction, final int[] oob, final boolean lastTask)
throws HiveException {
double oobErrors = 0.d;
@@ -379,7 +388,18 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
forwardObjs[0] = new Text(modelId);
forwardObjs[1] = new DoubleWritable(error);
forwardObjs[2] = model;
- forwardObjs[3] = WritableUtils.toWritableList(importance);
+ if (denseInput) {
+ forwardObjs[3] = WritableUtils.toWritableList(importance.toArray());
+ } else {
+ final Map<IntWritable, DoubleWritable> map =
+ new HashMap<IntWritable, DoubleWritable>(importance.size());
+ importance.each(new VectorProcedure() {
+ public void apply(int i, double value) {
+ map.put(new IntWritable(i), new DoubleWritable(value));
+ }
+ });
+ forwardObjs[3] = map;
+ }
forwardObjs[4] = new DoubleWritable(oobErrors);
forwardObjs[5] = new IntWritable(oobTests);
forward(forwardObjs);
@@ -403,7 +423,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
*/
private final Matrix _x;
/**
- * Training sample labels.
+ * Training sample target values.
*/
private final double[] _y;
/**
@@ -476,7 +496,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
oob++;
_x.getRow(i, xProbe);
final double pred = tree.predict(xProbe);
- synchronized (_prediction) {
+ synchronized (_udtf) {
_prediction[i] += pred;
_oob[i]++;
}
@@ -488,7 +508,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
stopwatch.reset().start();
Text model = getModel(tree);
- double[] importance = tree.importance();
+ Vector importance = tree.importance();
tree = null; // help GC
int remain = _remainingTasks.decrementAndGet();
boolean lastTask = (remain == 0);
diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
index 61964ae..f1fe7b0 100755
--- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
+++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -25,6 +25,7 @@ import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.SparseVector;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
import hivemall.smile.data.Attribute;
@@ -106,7 +107,7 @@ public final class RegressionTree implements Regression<Vector> {
* for the two descendant nodes is less than the parent node. Adding up the decreases for each
* individual variable over the tree gives a simple measure of variable importance.
*/
- private final double[] _importance;
+ private final Vector _importance;
/**
* The root of the regression tree
*/
@@ -782,7 +783,7 @@ public final class RegressionTree implements Regression<Vector> {
}
}
- _importance[node.splitFeature] += node.splitScore;
+ _importance.incr(node.splitFeature, node.splitScore);
return true;
}
@@ -876,7 +877,7 @@ public final class RegressionTree implements Regression<Vector> {
this._minSplit = minSplits;
this._minLeafSize = minLeafSize;
this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
- this._importance = new double[_attributes.length];
+ this._importance = x.isSparse() ? new SparseVector() : new DenseVector(_attributes.length);
this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;
this._nodeOutput = output;
@@ -963,7 +964,7 @@ public final class RegressionTree implements Regression<Vector> {
*
* @return the variable importance
*/
- public double[] importance() {
+ public Vector importance() {
return _importance;
}
diff --git a/core/src/test/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTFTest.java
new file mode 100644
index 0000000..6fc56cd
--- /dev/null
+++ b/core/src/test/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTFTest.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.classification;
+
+import hivemall.TestUtils;
+import hivemall.utils.lang.mutable.MutableInt;
+
+import java.io.BufferedInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.text.ParseException;
+import java.util.ArrayList;
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.Collector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+import smile.data.AttributeDataset;
+import smile.data.parser.ArffParser;
+
+/**
+ * Smoke tests for {@link GradientTreeBoostingClassifierUDTF} on the iris dataset with dense
+ * ({@code array<double>}) and sparse ({@code array<string>} of {@code "index:value"}) features.
+ * <p>
+ * NOTE: the iris fixture is downloaded from GitHub, so these tests require network access.
+ */
+public class GradientTreeBoostingClassifierUDTFTest {
+
+    /** Location of the iris ARFF fixture. */
+    private static final String IRIS_URL =
+            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
+    /** Column index of the class label in the iris ARFF file. */
+    private static final int IRIS_RESPONSE_INDEX = 4;
+    /** Number of boosting iterations requested through {@link #OPTIONS}. */
+    private static final int NUM_TREES = 490;
+    private static final String OPTIONS = "-trees " + NUM_TREES;
+
+    @Test
+    public void testIrisDense() throws IOException, ParseException, HiveException {
+        AttributeDataset iris = loadIris();
+        int size = iris.size();
+        double[][] x = iris.toArray(new double[size][]);
+        int[] y = iris.toArray(new int[size]);
+
+        GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector, optionsOI()});
+
+        for (int i = 0; i < size; i++) {
+            udtf.process(new Object[] {toDenseFeatures(x[i]), y[i]});
+        }
+
+        // the UDTF is expected to forward exactly NUM_TREES rows
+        Assert.assertEquals(NUM_TREES, closeAndCountForwardedRows(udtf));
+    }
+
+    @Test
+    public void testIrisSparse() throws IOException, ParseException, HiveException {
+        AttributeDataset iris = loadIris();
+        int size = iris.size();
+        double[][] x = iris.toArray(new double[size][]);
+        int[] y = iris.toArray(new int[size]);
+
+        GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector, optionsOI()});
+
+        for (int i = 0; i < size; i++) {
+            udtf.process(new Object[] {toSparseFeatures(x[i]), y[i]});
+        }
+
+        Assert.assertEquals(NUM_TREES, closeAndCountForwardedRows(udtf));
+    }
+
+    @Test
+    public void testSerialization() throws HiveException, IOException, ParseException {
+        AttributeDataset iris = loadIris();
+        int size = iris.size();
+        double[][] x = iris.toArray(new double[size][]);
+        int[] y = iris.toArray(new int[size]);
+
+        final Object[][] rows = new Object[size][2];
+        for (int i = 0; i < size; i++) {
+            rows[i][0] = toSparseFeatures(x[i]);
+            rows[i][1] = y[i];
+        }
+
+        TestUtils.testGenericUDTFSerialization(GradientTreeBoostingClassifierUDTF.class,
+            new ObjectInspector[] {
+                    ObjectInspectorFactory.getStandardListObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                    PrimitiveObjectInspectorFactory.javaIntObjectInspector, optionsOI()},
+            rows);
+    }
+
+    /** Downloads and parses the iris dataset, always closing the underlying HTTP stream. */
+    @Nonnull
+    private static AttributeDataset loadIris() throws IOException, ParseException {
+        ArffParser arffParser = new ArffParser();
+        arffParser.setResponseIndex(IRIS_RESPONSE_INDEX);
+        InputStream is = new BufferedInputStream(new URL(IRIS_URL).openStream());
+        try {
+            return arffParser.parse(is);
+        } finally {
+            is.close();
+        }
+    }
+
+    /** Returns the constant {@code options} argument ({@code "-trees 490"}). */
+    @Nonnull
+    private static ObjectInspector optionsOI() {
+        return ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, OPTIONS);
+    }
+
+    /** Boxes a row into the dense {@code array<double>} feature representation. */
+    @Nonnull
+    private static List<Double> toDenseFeatures(@Nonnull final double[] row) {
+        final List<Double> features = new ArrayList<Double>(row.length);
+        for (double v : row) {
+            features.add(v);
+        }
+        return features;
+    }
+
+    /** Encodes a row as sparse {@code "index:value"} strings, using the column position as index. */
+    @Nonnull
+    private static List<String> toSparseFeatures(@Nonnull final double[] row) {
+        final List<String> features = new ArrayList<String>(row.length);
+        for (int j = 0; j < row.length; j++) {
+            features.add(j + ":" + row[j]);
+        }
+        return features;
+    }
+
+    /**
+     * Closes {@code udtf} and returns how many rows it forwarded to the collector during
+     * {@code close()}.
+     */
+    private static int closeAndCountForwardedRows(
+            @Nonnull final GradientTreeBoostingClassifierUDTF udtf) throws HiveException {
+        final MutableInt count = new MutableInt(0);
+        udtf.setCollector(new Collector() {
+            public void collect(Object input) throws HiveException {
+                count.addValue(1);
+            }
+        });
+        udtf.close();
+        return count.getValue();
+    }
+}
diff --git a/core/src/test/java/hivemall/smile/regression/RandomForestRegressionUDTFTest.java b/core/src/test/java/hivemall/smile/regression/RandomForestRegressionUDTFTest.java
new file mode 100644
index 0000000..9dc79b6
--- /dev/null
+++ b/core/src/test/java/hivemall/smile/regression/RandomForestRegressionUDTFTest.java
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.regression;
+
+import hivemall.TestUtils;
+import hivemall.utils.codec.Base91;
+import hivemall.utils.hashing.MurmurHash3;
+import hivemall.utils.lang.mutable.MutableInt;
+
+import java.io.IOException;
+import java.text.ParseException;
+import java.util.ArrayList;
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.Collector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.Text;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link RandomForestRegressionUDTF} with dense ({@code array<double>}) and sparse
+ * ({@code array<string>} of {@code "index:value"}) features on a small 9-row dataset.
+ */
+public class RandomForestRegressionUDTFTest {
+
+    /** Training features: 9 rows of 6 numeric columns. */
+    private static final double[][] X = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
+            {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
+            {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
+            {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
+            {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
+            {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
+            {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
+            {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
+            {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
+
+    /** Regression targets, one per row of {@link #X}. */
+    private static final double[] Y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};
+
+    /** Feature names hashed into sparse indices, one per column of {@link #X}. */
+    private static final String[] FEATURE_NAMES = {"f1", "f2", "f3", "f4", "f5", "f6"};
+
+    @Test
+    public void testDense() throws IOException, ParseException, HiveException {
+        Assert.assertEquals(49, countForwardedRows(false, "-trees 49"));
+    }
+
+    @Test
+    public void testSparse() throws IOException, ParseException, HiveException {
+        Assert.assertEquals(49, countForwardedRows(true, "-trees 49"));
+    }
+
+    /**
+     * Trains one tree with the same seed from dense and from sparse input and counts the rows of
+     * {@link #X} on which their predictions differ.
+     * <p>
+     * NOTE(review): with only 9 rows the bound {@code diff < 10} can never fail, so this only checks
+     * that both models train, deserialize and predict without error. The sparse model presumably
+     * splits on hashed feature indices while {@code X[i]} is indexed 0..5, so evaluating it on
+     * dense rows is not a like-for-like comparison. TODO: confirm the intended check.
+     */
+    @Test
+    public void testSparseDenseEquals() throws IOException, ParseException, HiveException {
+        RegressionTree.Node denseNode = trainSingleTree(false);
+        RegressionTree.Node sparseNode = trainSingleTree(true);
+
+        int diff = 0;
+        for (double[] row : X) {
+            if (denseNode.predict(row) != sparseNode.predict(row)) {
+                diff++;
+            }
+        }
+
+        Assert.assertTrue("large diff " + diff + " between two predictions", diff < 10);
+    }
+
+    @Test
+    public void testSerialization() throws HiveException, IOException, ParseException {
+        final Object[][] rows = new Object[X.length][2];
+        for (int i = 0; i < X.length; i++) {
+            rows[i][0] = toSparseFeatures(X[i]);
+            rows[i][1] = Y[i];
+        }
+
+        TestUtils.testGenericUDTFSerialization(RandomForestRegressionUDTF.class,
+            new ObjectInspector[] {
+                    ObjectInspectorFactory.getStandardListObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
+                    ObjectInspectorUtils.getConstantObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49")},
+            rows);
+    }
+
+    /** Trains on {@link #X}/{@link #Y} with {@code options} and returns the number of forwarded rows. */
+    private static int countForwardedRows(final boolean sparse, @Nonnull final String options)
+            throws HiveException {
+        final MutableInt count = new MutableInt(0);
+        train(sparse, options, new Collector() {
+            public void collect(Object input) throws HiveException {
+                count.addValue(1);
+            }
+        });
+        return count.getValue();
+    }
+
+    /**
+     * Trains a single tree ({@code -trees 1 -seed 71}) and deserializes the Base91-encoded model
+     * that the UDTF forwards in column 2 ({@code model}).
+     */
+    @Nonnull
+    private static RegressionTree.Node trainSingleTree(final boolean sparse)
+            throws IOException, ParseException, HiveException {
+        final Text[] placeholder = new Text[1];
+        train(sparse, "-trees 1 -seed 71", new Collector() {
+            public void collect(Object input) throws HiveException {
+                Object[] forward = (Object[]) input;
+                placeholder[0] = (Text) forward[2];
+            }
+        });
+
+        Text modelTxt = placeholder[0];
+        Assert.assertNotNull(modelTxt);
+
+        byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
+        return RegressionTree.deserialize(b, b.length, true);
+    }
+
+    /**
+     * Initializes a {@link RandomForestRegressionUDTF} with {@code options}, feeds it every row of
+     * {@link #X}/{@link #Y} as dense or sparse features, then closes it so that the forwarded rows
+     * reach {@code collector}.
+     */
+    private static void train(final boolean sparse, @Nonnull final String options,
+            @Nonnull final Collector collector) throws HiveException {
+        final ObjectInspector elementOI = sparse
+                ? PrimitiveObjectInspectorFactory.javaStringObjectInspector
+                : PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+
+        RandomForestRegressionUDTF udtf = new RandomForestRegressionUDTF();
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(elementOI),
+                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, options)});
+
+        for (int i = 0; i < X.length; i++) {
+            final Object features = sparse ? toSparseFeatures(X[i]) : toDenseFeatures(X[i]);
+            udtf.process(new Object[] {features, Y[i]});
+        }
+
+        udtf.setCollector(collector);
+        udtf.close();
+    }
+
+    /** Boxes a row into the dense {@code array<double>} feature representation. */
+    @Nonnull
+    private static List<Double> toDenseFeatures(@Nonnull final double[] row) {
+        final List<Double> features = new ArrayList<Double>(row.length);
+        for (double v : row) {
+            features.add(v);
+        }
+        return features;
+    }
+
+    /** Encodes a row as sparse {@code "mhash(name):value"} strings, one per column. */
+    @Nonnull
+    private static List<String> toSparseFeatures(@Nonnull final double[] row) {
+        final List<String> features = new ArrayList<String>(row.length);
+        for (int j = 0; j < row.length; j++) {
+            features.add(mhash(FEATURE_NAMES[j]) + ":" + row[j]);
+        }
+        return features;
+    }
+
+    /**
+     * Hashes {@code word} with MurmurHash3 (seed {@code 0x9747b28c}) to a 1-based feature index
+     * in [1, 2^24]; 2^24 is assumed to be hivemall's default mhash feature space (verify).
+     */
+    private static int mhash(@Nonnull final String word) {
+        final int n = 16777216; // 2^24
+        int r = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c) % n;
+        if (r < 0) {
+            r += n;
+        }
+        return r + 1;
+    }
+}