You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/02/05 04:56:01 UTC

[incubator-hivemall] branch master updated: [HIVEMALL-233-2] RandomForest regressor accepts sparse vector input

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e1104c  [HIVEMALL-233-2] RandomForest regressor accepts sparse vector input
2e1104c is described below

commit 2e1104c1eeb4598ba5cc8e74dfba5d36699344f3
Author: Takuya Kitazawa <k....@gmail.com>
AuthorDate: Tue Feb 5 13:55:55 2019 +0900

    [HIVEMALL-233-2] RandomForest regressor accepts sparse vector input
    
    ## What changes were proposed in this pull request?
    
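    Enable RandomForestRegressor (`train_randomforest_regressor`) to accept sparse vector input, as RandomForestClassifier already does. For sparse input, `var_importance` is emitted as a `map<int, double>` keyed by feature index instead of an `array<double>`; GradientTreeBoostingClassifierUDTF receives the same treatment. The `pred_model` output column of the regressor is renamed to `model`.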
    
    This closes #178
    
    ## What type of PR is it?
    
    Improvement
    
    ## What is the Jira issue?
    
    https://issues.apache.org/jira/browse/HIVEMALL-233
    
    ## How was this patch tested?
    
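    Manual tests on EMR, plus the unit tests added in this commit (RandomForestRegressionUDTFTest and GradientTreeBoostingClassifierUDTFTest).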
    
    ## How to use this feature?
    
    ```sql
    with customers as (
      select 1 as id, "male" as gender, 23 as age, "Japan" as country, 12 as num_purchases
      union all
      select 2 as id, "female" as gender, 43 as age, "US" as country, 4 as num_purchases
      union all
      select 3 as id, "other" as gender, 19 as age, "UK" as country, 2 as num_purchases
      union all
      select 4 as id, "male" as gender, 31 as age, "US" as country, 20 as num_purchases
      union all
      select 5 as id, "female" as gender, 37 as age, "Australia" as country, 9 as num_purchases
    ),
    training as (
      select
        array_concat(
          quantitative_features(
            array("age"),
            age
          ),
          categorical_features(
            array("country", "gender"),
            country, gender
          )
        ) as features,
        num_purchases
      from
        customers
    )
    select
      train_randomforest_regressor(
        feature_hashing(features), -- feature vector
        num_purchases, -- target value
        '-trees 40 -seed 31' -- hyper-parameters
      )
    from
      training
    ;
    ```
    
    ## Checklist
    
    - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
    - [ ] Did you run system tests on Hive (or Spark)?
    
    Author: Takuya Kitazawa <k....@gmail.com>
    Author: Makoto Yui <my...@apache.org>
    
    Closes #181 from myui/HIVEMALL-233-2.
---
 .../GradientTreeBoostingClassifierUDTF.java        |  36 ++-
 .../regression/RandomForestRegressionUDTF.java     |  44 ++-
 .../hivemall/smile/regression/RegressionTree.java  |   9 +-
 .../GradientTreeBoostingClassifierUDTFTest.java    | 181 ++++++++++++
 .../regression/RandomForestRegressionUDTFTest.java | 308 +++++++++++++++++++++
 5 files changed, 555 insertions(+), 23 deletions(-)

diff --git a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
index 1a4fa1c..9edc63d 100644
--- a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
@@ -26,7 +26,10 @@ import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
 import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
 import hivemall.math.random.PRNG;
 import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.SparseVector;
 import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
 import hivemall.smile.data.Attribute;
 import hivemall.smile.regression.RegressionTree;
 import hivemall.smile.utils.SmileExtUtils;
@@ -40,6 +43,8 @@ import hivemall.utils.math.MathUtils;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.BitSet;
+import java.util.Map;
+import java.util.HashMap;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
@@ -227,8 +232,14 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
         fieldNames.add("shrinkage");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
         fieldNames.add("var_importance");
-        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
-            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        if (denseInput) {
+            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        } else {
+            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+                PrimitiveObjectInspectorFactory.writableIntObjectInspector,
+                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        }
         fieldNames.add("oob_error_rate");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
 
@@ -555,11 +566,11 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
             final float oobErrorRate, @Nonnull final RegressionTree... trees) throws HiveException {
         Text[] models = getModel(trees);
 
-        double[] importance = new double[_attributes.length];
+        Vector importance = denseInput ? new DenseVector(_attributes.length) : new SparseVector();
         for (RegressionTree tree : trees) {
-            double[] imp = tree.importance();
-            for (int i = 0; i < imp.length; i++) {
-                importance[i] += imp[i];
+            Vector imp = tree.importance();
+            for (int i = 0, size = imp.size(); i < size; i++) {
+                importance.incr(i, imp.get(i));
             }
         }
 
@@ -568,7 +579,18 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
         forwardObjs[1] = models;
         forwardObjs[2] = new DoubleWritable(intercept);
         forwardObjs[3] = new DoubleWritable(shrinkage);
-        forwardObjs[4] = WritableUtils.toWritableList(importance);
+        if (denseInput) {
+            forwardObjs[4] = WritableUtils.toWritableList(importance.toArray());
+        } else {
+            final Map<IntWritable, DoubleWritable> map =
+                    new HashMap<IntWritable, DoubleWritable>(importance.size());
+            importance.each(new VectorProcedure() {
+                public void apply(int i, double value) {
+                    map.put(new IntWritable(i), new DoubleWritable(value));
+                }
+            });
+            forwardObjs[4] = map;
+        }
         forwardObjs[5] = new FloatWritable(oobErrorRate);
 
         forward(forwardObjs);
diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
index df5d55b..dc148e2 100644
--- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
+++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
@@ -27,6 +27,7 @@ import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
 import hivemall.math.random.PRNG;
 import hivemall.math.random.RandomNumberGeneratorFactory;
 import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
 import hivemall.smile.data.Attribute;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.smile.utils.SmileTaskExecutor;
@@ -40,7 +41,9 @@ import hivemall.utils.lang.RandomUtils;
 
 import java.util.ArrayList;
 import java.util.BitSet;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.Callable;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -71,10 +74,10 @@ import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.Counters.Counter;
 import org.apache.hadoop.mapred.Reporter;
 
-@Description(name = "train_randomforest_regression",
+@Description(name = "train_randomforest_regressor",
         value = "_FUNC_(array<double|string> features, double target [, string options]) - "
                 + "Returns a relation consists of "
-                + "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>")
+                + "<int model_id, int model_type, string model, array<double> var_importance, double oob_errors, int oob_tests>")
 public final class RandomForestRegressionUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(RandomForestRegressionUDTF.class);
 
@@ -206,18 +209,24 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
 
         this.targets = new DoubleArrayList(1024);
 
-        ArrayList<String> fieldNames = new ArrayList<String>(5);
-        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(5);
+        final ArrayList<String> fieldNames = new ArrayList<String>(6);
+        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
 
         fieldNames.add("model_id");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
         fieldNames.add("model_err");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
-        fieldNames.add("pred_model");
+        fieldNames.add("model");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
         fieldNames.add("var_importance");
-        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
-            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        if (denseInput) {
+            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        } else {
+            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+                PrimitiveObjectInspectorFactory.writableIntObjectInspector,
+                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        }
         fieldNames.add("oob_errors");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
         fieldNames.add("oob_tests");
@@ -359,7 +368,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
      * @param error
      */
     synchronized void forward(final int taskId, @Nonnull final Text model,
-            @Nonnull final double[] importance, @Nonnegative final double error, final double[] y,
+            @Nonnull final Vector importance, @Nonnegative final double error, final double[] y,
             final double[] prediction, final int[] oob, final boolean lastTask)
             throws HiveException {
         double oobErrors = 0.d;
@@ -379,7 +388,18 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
         forwardObjs[0] = new Text(modelId);
         forwardObjs[1] = new DoubleWritable(error);
         forwardObjs[2] = model;
-        forwardObjs[3] = WritableUtils.toWritableList(importance);
+        if (denseInput) {
+            forwardObjs[3] = WritableUtils.toWritableList(importance.toArray());
+        } else {
+            final Map<IntWritable, DoubleWritable> map =
+                    new HashMap<IntWritable, DoubleWritable>(importance.size());
+            importance.each(new VectorProcedure() {
+                public void apply(int i, double value) {
+                    map.put(new IntWritable(i), new DoubleWritable(value));
+                }
+            });
+            forwardObjs[3] = map;
+        }
         forwardObjs[4] = new DoubleWritable(oobErrors);
         forwardObjs[5] = new IntWritable(oobTests);
         forward(forwardObjs);
@@ -403,7 +423,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
          */
         private final Matrix _x;
         /**
-         * Training sample labels.
+         * Training sample target values.
          */
         private final double[] _y;
         /**
@@ -476,7 +496,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
                 oob++;
                 _x.getRow(i, xProbe);
                 final double pred = tree.predict(xProbe);
-                synchronized (_prediction) {
+                synchronized (_udtf) {
                     _prediction[i] += pred;
                     _oob[i]++;
                 }
@@ -488,7 +508,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
 
             stopwatch.reset().start();
             Text model = getModel(tree);
-            double[] importance = tree.importance();
+            Vector importance = tree.importance();
             tree = null; // help GC
             int remain = _remainingTasks.decrementAndGet();
             boolean lastTask = (remain == 0);
diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
index 61964ae..f1fe7b0 100755
--- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
+++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -25,6 +25,7 @@ import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
 import hivemall.math.random.PRNG;
 import hivemall.math.random.RandomNumberGeneratorFactory;
 import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.SparseVector;
 import hivemall.math.vector.Vector;
 import hivemall.math.vector.VectorProcedure;
 import hivemall.smile.data.Attribute;
@@ -106,7 +107,7 @@ public final class RegressionTree implements Regression<Vector> {
      * for the two descendant nodes is less than the parent node. Adding up the decreases for each
      * individual variable over the tree gives a simple measure of variable importance.
      */
-    private final double[] _importance;
+    private final Vector _importance;
     /**
      * The root of the regression tree
      */
@@ -782,7 +783,7 @@ public final class RegressionTree implements Regression<Vector> {
                 }
             }
 
-            _importance[node.splitFeature] += node.splitScore;
+            _importance.incr(node.splitFeature, node.splitScore);
 
             return true;
         }
@@ -876,7 +877,7 @@ public final class RegressionTree implements Regression<Vector> {
         this._minSplit = minSplits;
         this._minLeafSize = minLeafSize;
         this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
-        this._importance = new double[_attributes.length];
+        this._importance = x.isSparse() ? new SparseVector() : new DenseVector(_attributes.length);
         this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;
         this._nodeOutput = output;
 
@@ -963,7 +964,7 @@ public final class RegressionTree implements Regression<Vector> {
      *
      * @return the variable importance
      */
-    public double[] importance() {
+    public Vector importance() {
         return _importance;
     }
 
diff --git a/core/src/test/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTFTest.java
new file mode 100644
index 0000000..6fc56cd
--- /dev/null
+++ b/core/src/test/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTFTest.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.classification;
+
+import hivemall.TestUtils;
+import hivemall.classifier.KernelExpansionPassiveAggressiveUDTF;
+import hivemall.utils.lang.mutable.MutableInt;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.Collector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+import smile.data.AttributeDataset;
+import smile.data.parser.ArffParser;
+
+import javax.annotation.Nonnull;
+import java.io.BufferedInputStream;
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.net.URL;
+import java.text.ParseException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.zip.GZIPInputStream;
+
+public class GradientTreeBoostingClassifierUDTFTest {
+
+    @Test
+    public void testIrisDense() throws IOException, ParseException, HiveException {
+        URL url = new URL(
+            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+        InputStream is = new BufferedInputStream(url.openStream());
+
+        ArffParser arffParser = new ArffParser();
+        arffParser.setResponseIndex(4);
+
+        AttributeDataset iris = arffParser.parse(is);
+        int size = iris.size();
+        double[][] x = iris.toArray(new double[size][]);
+        int[] y = iris.toArray(new int[size]);
+
+        GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
+        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+        final List<Double> xi = new ArrayList<Double>(x[0].length);
+        for (int i = 0; i < size; i++) {
+            for (int j = 0; j < x[i].length; j++) {
+                xi.add(j, x[i][j]);
+            }
+            udtf.process(new Object[] {xi, y[i]});
+            xi.clear();
+        }
+
+        final MutableInt count = new MutableInt(0);
+        Collector collector = new Collector() {
+            public void collect(Object input) throws HiveException {
+                count.addValue(1);
+            }
+        };
+
+        udtf.setCollector(collector);
+        udtf.close();
+
+        Assert.assertEquals(490, count.getValue());
+    }
+
+    @Test
+    public void testIrisSparse() throws IOException, ParseException, HiveException {
+        URL url = new URL(
+            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+        InputStream is = new BufferedInputStream(url.openStream());
+
+        ArffParser arffParser = new ArffParser();
+        arffParser.setResponseIndex(4);
+
+        AttributeDataset iris = arffParser.parse(is);
+        int size = iris.size();
+        double[][] x = iris.toArray(new double[size][]);
+        int[] y = iris.toArray(new int[size]);
+
+        GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
+        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+        final List<String> xi = new ArrayList<String>(x[0].length);
+        for (int i = 0; i < size; i++) {
+            double[] row = x[i];
+            for (int j = 0; j < row.length; j++) {
+                xi.add(j + ":" + row[j]);
+            }
+            udtf.process(new Object[] {xi, y[i]});
+            xi.clear();
+        }
+
+        final MutableInt count = new MutableInt(0);
+        Collector collector = new Collector() {
+            public void collect(Object input) throws HiveException {
+                count.addValue(1);
+            }
+        };
+
+        udtf.setCollector(collector);
+        udtf.close();
+
+        Assert.assertEquals(490, count.getValue());
+    }
+
+    @Test
+    public void testSerialization() throws HiveException, IOException, ParseException {
+        URL url = new URL(
+            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+        InputStream is = new BufferedInputStream(url.openStream());
+
+        ArffParser arffParser = new ArffParser();
+        arffParser.setResponseIndex(4);
+
+        AttributeDataset iris = arffParser.parse(is);
+        int size = iris.size();
+        double[][] x = iris.toArray(new double[size][]);
+        int[] y = iris.toArray(new int[size]);
+
+        final Object[][] rows = new Object[size][2];
+        for (int i = 0; i < size; i++) {
+            double[] row = x[i];
+            final List<String> xi = new ArrayList<String>(x[0].length);
+            for (int j = 0; j < row.length; j++) {
+                xi.add(j + ":" + row[j]);
+            }
+            rows[i][0] = xi;
+            rows[i][1] = y[i];
+        }
+
+        TestUtils.testGenericUDTFSerialization(GradientTreeBoostingClassifierUDTF.class,
+            new ObjectInspector[] {
+                    ObjectInspectorFactory.getStandardListObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                    PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+                    ObjectInspectorUtils.getConstantObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490")},
+            rows);
+    }
+
+    @Nonnull
+    private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
+        InputStream is = KernelExpansionPassiveAggressiveUDTF.class.getResourceAsStream(fileName);
+        if (fileName.endsWith(".gz")) {
+            is = new GZIPInputStream(is);
+        }
+        return new BufferedReader(new InputStreamReader(is));
+    }
+}
diff --git a/core/src/test/java/hivemall/smile/regression/RandomForestRegressionUDTFTest.java b/core/src/test/java/hivemall/smile/regression/RandomForestRegressionUDTFTest.java
new file mode 100644
index 0000000..9dc79b6
--- /dev/null
+++ b/core/src/test/java/hivemall/smile/regression/RandomForestRegressionUDTFTest.java
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.regression;
+
+import hivemall.TestUtils;
+import hivemall.utils.codec.Base91;
+import hivemall.utils.hashing.MurmurHash3;
+import hivemall.utils.lang.mutable.MutableInt;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.Collector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.Text;
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.annotation.Nonnull;
+import java.io.IOException;
+import java.text.ParseException;
+import java.util.ArrayList;
+import java.util.List;
+
+public class RandomForestRegressionUDTFTest {
+
+    @Test
+    public void testDense() throws IOException, ParseException, HiveException {
+        double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
+                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
+                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
+                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
+                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
+                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
+                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
+                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
+                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
+
+        double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};
+
+        RandomForestRegressionUDTF udtf = new RandomForestRegressionUDTF();
+        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, param});
+
+        final List<Double> xi = new ArrayList<Double>(x[0].length);
+        for (int i = 0; i < x.length; i++) {
+            for (int j = 0; j < x[i].length; j++) {
+                xi.add(j, x[i][j]);
+            }
+            udtf.process(new Object[] {xi, y[i]});
+            xi.clear();
+        }
+
+        final MutableInt count = new MutableInt(0);
+        Collector collector = new Collector() {
+            public void collect(Object input) throws HiveException {
+                count.addValue(1);
+            }
+        };
+
+        udtf.setCollector(collector);
+        udtf.close();
+
+        Assert.assertEquals(49, count.getValue());
+    }
+
+    @Test
+    public void testSparse() throws IOException, ParseException, HiveException {
+        String[] featureNames = {"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"};
+
+        double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
+                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
+                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
+                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
+                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
+                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
+                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
+                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
+                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
+
+        double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};
+
+        RandomForestRegressionUDTF udtf = new RandomForestRegressionUDTF();
+        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, param});
+
+        final List<String> xi = new ArrayList<String>(x[0].length);
+        for (int i = 0; i < x.length; i++) {
+            double[] row = x[i];
+            for (int j = 0; j < row.length; j++) {
+                xi.add(mhash(featureNames[j]) + ":" + row[j]);
+            }
+            udtf.process(new Object[] {xi, y[i]});
+            xi.clear();
+        }
+
+        final MutableInt count = new MutableInt(0);
+        Collector collector = new Collector() {
+            public void collect(Object input) throws HiveException {
+                count.addValue(1);
+            }
+        };
+
+        udtf.setCollector(collector);
+        udtf.close();
+
+        Assert.assertEquals(49, count.getValue());
+    }
+
+    @Test
+    public void testSparseDenseEquals() throws IOException, ParseException, HiveException {
+        RegressionTree.Node denseNode = getRegressionTreeFromDenseInput();
+        RegressionTree.Node sparseNode = getRegressionTreeFromSparseInput();
+
+        double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
+                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
+                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
+                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
+                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
+                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
+                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
+                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
+                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
+
+        int diff = 0;
+        for (int i = 0; i < x.length; i++) {
+            if (denseNode.predict(x[i]) != sparseNode.predict(x[i])) {
+                diff++;
+            }
+        }
+
+        Assert.assertTrue("large diff " + diff + " between two predictions", diff < 10);
+    }
+
+    /**
+     * Builds a regression tree from dense (list of doubles) feature rows.
+     *
+     * <p>Nine rows of six numeric features are passed one by one to a
+     * {@link RandomForestRegressionUDTF} initialized with {@code "-trees 1 -seed 71"}. When the
+     * UDTF is closed, the element at index 2 of the forwarded row (presumably the serialized
+     * model text) is captured, Base91-decoded and handed to {@code RegressionTree.deserialize}.
+     *
+     * @return the root node of the deserialized tree
+     */
+    private static RegressionTree.Node getRegressionTreeFromDenseInput()
+            throws IOException, ParseException, HiveException {
+        final double[][] samples = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
+                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
+                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
+                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
+                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
+                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
+                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
+                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
+                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
+
+        final double[] targets = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};
+
+        final RandomForestRegressionUDTF trainer = new RandomForestRegressionUDTF();
+        final ObjectInspector optionsOI = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
+        final ObjectInspector denseFeaturesOI =
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
+        trainer.initialize(new ObjectInspector[] {denseFeaturesOI,
+                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, optionsOI});
+
+        // A single buffer is refilled for every row and emptied after it has been processed.
+        final List<Double> rowBuffer = new ArrayList<Double>(samples[0].length);
+        for (int rowIdx = 0; rowIdx < samples.length; rowIdx++) {
+            for (double value : samples[rowIdx]) {
+                rowBuffer.add(value);
+            }
+            trainer.process(new Object[] {rowBuffer, targets[rowIdx]});
+            rowBuffer.clear();
+        }
+
+        // One-element array so that the anonymous collector can hand the value back.
+        final Text[] captured = new Text[1];
+        trainer.setCollector(new Collector() {
+            @Override
+            public void collect(Object input) throws HiveException {
+                captured[0] = (Text) ((Object[]) input)[2];
+            }
+        });
+        trainer.close();
+
+        final Text encodedModel = captured[0];
+        Assert.assertNotNull(encodedModel);
+
+        // Only the first getLength() bytes of the array returned by getBytes() are decoded.
+        final byte[] decoded =
+                Base91.decode(encodedModel.getBytes(), 0, encodedModel.getLength());
+        return RegressionTree.deserialize(decoded, decoded.length, true);
+    }
+
+    /**
+     * Builds a regression tree from sparse ({@code "index:value"} string) feature rows.
+     *
+     * <p>Each of the nine rows is encoded as a list of strings of the form
+     * {@code mhash(featureName) + ":" + value} and passed to a
+     * {@link RandomForestRegressionUDTF} initialized with {@code "-trees 1 -seed 71"}. When the
+     * UDTF is closed, the element at index 2 of the forwarded row (presumably the serialized
+     * model text) is captured, Base91-decoded and handed to {@code RegressionTree.deserialize}.
+     *
+     * @return the root node of the deserialized tree
+     */
+    private static RegressionTree.Node getRegressionTreeFromSparseInput()
+            throws IOException, ParseException, HiveException {
+        // Only the first six names are used because every row has six columns.
+        final String[] names = {"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"};
+
+        final double[][] samples = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
+                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
+                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
+                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
+                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
+                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
+                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
+                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
+                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
+
+        final double[] targets = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};
+
+        final RandomForestRegressionUDTF trainer = new RandomForestRegressionUDTF();
+        final ObjectInspector optionsOI = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
+        final ObjectInspector sparseFeaturesOI =
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector);
+        trainer.initialize(new ObjectInspector[] {sparseFeaturesOI,
+                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, optionsOI});
+
+        // A single buffer is refilled for every row and emptied after it has been processed.
+        final List<String> rowBuffer = new ArrayList<String>(samples[0].length);
+        for (int rowIdx = 0; rowIdx < samples.length; rowIdx++) {
+            final double[] values = samples[rowIdx];
+            for (int col = 0; col < values.length; col++) {
+                final int featureIndex = mhash(names[col]);
+                rowBuffer.add(featureIndex + ":" + values[col]);
+            }
+            trainer.process(new Object[] {rowBuffer, targets[rowIdx]});
+            rowBuffer.clear();
+        }
+
+        // One-element array so that the anonymous collector can hand the value back.
+        final Text[] captured = new Text[1];
+        trainer.setCollector(new Collector() {
+            @Override
+            public void collect(Object input) throws HiveException {
+                captured[0] = (Text) ((Object[]) input)[2];
+            }
+        });
+        trainer.close();
+
+        final Text encodedModel = captured[0];
+        Assert.assertNotNull(encodedModel);
+
+        // Only the first getLength() bytes of the array returned by getBytes() are decoded.
+        final byte[] decoded =
+                Base91.decode(encodedModel.getBytes(), 0, encodedModel.getLength());
+        return RegressionTree.deserialize(decoded, decoded.length, true);
+    }
+
+    @Test
+    public void testSerialization() throws HiveException, IOException, ParseException {
+        String[] featureNames = {"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"};
+
+        double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
+                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
+                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
+                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
+                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
+                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
+                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
+                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
+                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
+
+        double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};
+
+        final Object[][] rows = new Object[x.length][2];
+        for (int i = 0; i < x.length; i++) {
+            double[] row = x[i];
+            final List<String> xi = new ArrayList<String>(x[0].length);
+            for (int j = 0; j < row.length; j++) {
+                xi.add(mhash(featureNames[j]) + ":" + row[j]);
+            }
+            rows[i][0] = xi;
+            rows[i][1] = y[i];
+        }
+
+        TestUtils.testGenericUDTFSerialization(RandomForestRegressionUDTF.class,
+            new ObjectInspector[] {
+                    ObjectInspectorFactory.getStandardListObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
+                    ObjectInspectorUtils.getConstantObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49")},
+            rows);
+    }
+
+    /**
+     * Hashes a word with {@code MurmurHash3.murmurhash3_x86_32} (seed {@code 0x9747b28c}) and
+     * folds the result into the range {@code [1, 16777217]}.
+     *
+     * @param word the string to hash
+     * @return the non-negative remainder of the hash modulo 16777217, plus one
+     */
+    private static int mhash(@Nonnull final String word) {
+        final int numBuckets = 16777217; // 2^24 + 1
+        final int hash = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c);
+        final int remainder = hash % numBuckets;
+        // Java's % keeps the sign of the dividend, so shift negative remainders up.
+        final int bucket = (remainder < 0) ? remainder + numBuckets : remainder;
+        return bucket + 1;
+    }
+
+}