You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/09 21:32:21 UTC

[09/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75] Support Sparse Vector Format as the input of RandomForest

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index 03db65c..5a831df 100644
--- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -19,30 +19,43 @@
 package hivemall.smile.classification;
 
 import hivemall.UDTFWithOptions;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.MatrixUtils;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.builders.MatrixBuilder;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.matrix.ints.DoKIntMatrix;
+import hivemall.math.matrix.ints.IntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
 import hivemall.smile.classification.DecisionTree.SplitRule;
 import hivemall.smile.data.Attribute;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.smile.utils.SmileTaskExecutor;
-import hivemall.smile.vm.StackMachine;
 import hivemall.utils.codec.Base91;
-import hivemall.utils.codec.DeflateCodec;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
 import hivemall.utils.hadoop.HiveUtils;
 import hivemall.utils.hadoop.WritableUtils;
-import hivemall.utils.io.IOUtils;
+import hivemall.utils.lang.Preconditions;
 import hivemall.utils.lang.Primitives;
 import hivemall.utils.lang.RandomUtils;
 
-import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.BitSet;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.Callable;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
 
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
@@ -52,7 +65,9 @@ import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.MapredContext;
 import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
@@ -67,9 +82,9 @@ import org.apache.hadoop.mapred.Reporter;
 
 @Description(
         name = "train_randomforest_classifier",
-        value = "_FUNC_(double[] features, int label [, string options]) - "
+        value = "_FUNC_(array<double|string> features, int label [, const string options [, const array<double> classWeights]]) - "
                 + "Returns a relation consists of "
-                + "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>")
+                + "<string model_id, double model_weight, string model, array<double>|map<int,double> var_importance, int oob_errors, int oob_tests>")
 public final class RandomForestClassifierUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(RandomForestClassifierUDTF.class);
 
@@ -77,8 +92,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
     private PrimitiveObjectInspector featureElemOI;
     private PrimitiveObjectInspector labelOI;
 
-    private List<double[]> featuresList;
+    private boolean denseInput;
+    private MatrixBuilder matrixBuilder;
     private IntArrayList labels;
+
     /**
      * The number of trees for each task
      */
@@ -99,8 +116,12 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
     private int _minSamplesLeaf;
     private long _seed;
     private Attribute[] _attributes;
-    private ModelType _outputType;
     private SplitRule _splitRule;
+    private boolean _stratifiedSampling;
+    private double _subsample;
+
+    @Nullable
+    private double[] _classWeight;
 
     @Nullable
     private Reporter _progressReporter;
@@ -126,11 +147,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
         opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
         opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
                 + "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
-        opts.addOption("output", "output_type", true,
-            "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]");
         opts.addOption("rule", "split_rule", true, "Split algorithm [default: GINI, ENTROPY]");
-        opts.addOption("disable_compression", false,
-            "Whether to disable compression of the output script [default: false]");
+        opts.addOption("stratified", "stratified_sampling", false,
+            "Enable Stratified sampling for unbalanced data");
+        opts.addOption("subsample", true, "Sampling rate in range (0.0,1.0]");
         return opts;
     }
 
@@ -141,9 +161,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
         float numVars = -1.f;
         Attribute[] attrs = null;
         long seed = -1L;
-        String output = "serialization";
         SplitRule splitRule = SplitRule.GINI;
-        boolean compress = true;
+        double[] classWeight = null;
+        boolean stratifiedSampling = false;
+        double subsample = 1.0d;
 
         CommandLine cl = null;
         if (argOIs.length >= 3) {
@@ -162,10 +183,26 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
                 minSamplesLeaf);
             seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
             attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
-            output = cl.getOptionValue("output", output);
             splitRule = SmileExtUtils.resolveSplitRule(cl.getOptionValue("split_rule", "GINI"));
-            if (cl.hasOption("disable_compression")) {
-                compress = false;
+            stratifiedSampling = cl.hasOption("stratified_sampling");
+            subsample = Primitives.parseDouble(cl.getOptionValue("subsample"), 1.0d);
+            Preconditions.checkArgument(subsample > 0.d && subsample <= 1.0d,
+                UDFArgumentException.class, "Invalid -subsample value: " + subsample);
+
+            if (argOIs.length >= 4) {
+                classWeight = HiveUtils.getConstDoubleArray(argOIs[3]);
+                if (classWeight != null) {
+                    for (int i = 0; i < classWeight.length; i++) {
+                        double v = classWeight[i];
+                        if (Double.isNaN(v)) {
+                            classWeight[i] = 1.0d;
+                        } else if (v <= 0.d) {
+                            throw new UDFArgumentTypeException(3,
+                                "each classWeight must be greather than 0: "
+                                        + Arrays.toString(classWeight));
+                        }
+                    }
+                }
             }
         }
 
@@ -177,43 +214,60 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
         this._minSamplesLeaf = minSamplesLeaf;
         this._seed = seed;
         this._attributes = attrs;
-        this._outputType = ModelType.resolve(output, compress);
         this._splitRule = splitRule;
+        this._stratifiedSampling = stratifiedSampling;
+        this._subsample = subsample;
+        this._classWeight = classWeight;
 
         return cl;
     }
 
     @Override
     public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
-        if (argOIs.length != 2 && argOIs.length != 3) {
+        if (argOIs.length < 2 || argOIs.length > 4) {
             throw new UDFArgumentException(
-                getClass().getSimpleName()
-                        + " takes 2 or 3 arguments: double[] features, int label [, const string options]: "
+                "_FUNC_ takes 2 ~ 4 arguments: array<double|string> features, int label [, const string options, const array<double> classWeight]: "
                         + argOIs.length);
         }
 
         ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
         ObjectInspector elemOI = listOI.getListElementObjectInspector();
         this.featureListOI = listOI;
-        this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+        if (HiveUtils.isNumberOI(elemOI)) {
+            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+            this.denseInput = true;
+            this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
+        } else if (HiveUtils.isStringOI(elemOI)) {
+            this.featureElemOI = HiveUtils.asStringOI(elemOI);
+            this.denseInput = false;
+            this.matrixBuilder = new CSRMatrixBuilder(8192);
+        } else {
+            throw new UDFArgumentException(
+                "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
+        }
         this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
 
         processOptions(argOIs);
 
-        this.featuresList = new ArrayList<double[]>(1024);
         this.labels = new IntArrayList(1024);
 
-        ArrayList<String> fieldNames = new ArrayList<String>(6);
-        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
+        final ArrayList<String> fieldNames = new ArrayList<String>(6);
+        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
 
         fieldNames.add("model_id");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
-        fieldNames.add("model_type");
-        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
-        fieldNames.add("pred_model");
+        fieldNames.add("model_weight");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+        fieldNames.add("model");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
         fieldNames.add("var_importance");
-        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        if (denseInput) {
+            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        } else {
+            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+                PrimitiveObjectInspectorFactory.writableIntObjectInspector,
+                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        }
         fieldNames.add("oob_errors");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
         fieldNames.add("oob_tests");
@@ -227,13 +281,36 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
         if (args[0] == null) {
             throw new HiveException("array<double> features was null");
         }
-        double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
+        parseFeatures(args[0], matrixBuilder);
         int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
-
-        featuresList.add(features);
         labels.add(label);
     }
 
+    /**
+     * Appends one training row to {@code builder} and closes it with {@code nextRow()}.
+     * <p>
+     * Dense input adds {@code (index, double)} cells; sparse input hands each element's string
+     * form to {@code nextColumn(String)}. {@code null} list elements are skipped in both modes.
+     *
+     * @param argObj the feature list, read through {@code featureListOI}
+     * @param builder the matrix builder that receives the row
+     */
+    private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
+        final int numElems = featureListOI.getListLength(argObj);
+        for (int col = 0; col < numElems; col++) {
+            final Object elem = featureListOI.getListElement(argObj, col);
+            if (elem == null) {
+                continue;
+            }
+            if (denseInput) {
+                builder.nextColumn(col, PrimitiveObjectInspectorUtils.getDouble(elem, featureElemOI));
+            } else {
+                builder.nextColumn(elem.toString());
+            }
+        }
+        builder.nextRow();
+    }
+
     @Override
     public void close() throws HiveException {
         this._progressReporter = getReporter();
@@ -242,10 +319,9 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
                     "finishedTreeBuildTasks");
         reportProgress(_progressReporter);
 
-        int numExamples = featuresList.size();
-        if (numExamples > 0) {
-            double[][] x = featuresList.toArray(new double[numExamples][]);
-            this.featuresList = null;
+        if (!labels.isEmpty()) {
+            Matrix x = matrixBuilder.buildMatrix();
+            this.matrixBuilder = null;
             int[] y = labels.toArray();
             this.labels = null;
 
@@ -277,15 +353,16 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
      * @param numVars The number of variables to pick up in each node.
      * @param seed The seed number for Random Forest
      */
-    private void train(@Nonnull final double[][] x, @Nonnull final int[] y) throws HiveException {
-        if (x.length != y.length) {
+    private void train(@Nonnull Matrix x, @Nonnull final int[] y) throws HiveException {
+        final int numExamples = x.numRows();
+        if (numExamples != y.length) {
             throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
-                x.length, y.length));
+                numExamples, y.length));
         }
         checkOptions();
 
-        // Shuffle training samples    
-        SmileExtUtils.shuffle(x, y, _seed);
+        // Shuffle training samples
+        x = SmileExtUtils.shuffle(x, y, _seed);
 
         int[] labels = SmileExtUtils.classLables(y);
         Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
@@ -297,9 +374,8 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
                     + _maxLeafNodes + ", splitRule: " + _splitRule + ", seed: " + _seed);
         }
 
-        final int numExamples = x.length;
-        int[][] prediction = new int[numExamples][labels.length]; // placeholder for out-of-bag prediction
-        int[][] order = SmileExtUtils.sort(attributes, x);
+        IntMatrix prediction = new DoKIntMatrix(numExamples, labels.length); // placeholder for out-of-bag prediction
+        ColumnMajorIntMatrix order = SmileExtUtils.sort(attributes, x);
         AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
         List<TrainingTask> tasks = new ArrayList<TrainingTask>();
         for (int i = 0; i < _numTrees; i++) {
@@ -321,17 +397,19 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
 
     /**
      * Synchronized because {@link #forward(Object)} should be called from a single thread.
+     * 
+     * @param accuracy
      */
     synchronized void forward(final int taskId, @Nonnull final Text model,
-            @Nonnull final double[] importance, final int[] y, final int[][] prediction,
-            final boolean lastTask) throws HiveException {
+            @Nonnull final Vector importance, @Nonnegative final double accuracy, final int[] y,
+            @Nonnull final IntMatrix prediction, final boolean lastTask) throws HiveException {
         int oobErrors = 0;
         int oobTests = 0;
         if (lastTask) {
             // out-of-bag error estimate
             for (int i = 0; i < y.length; i++) {
-                final int pred = smile.math.Math.whichMax(prediction[i]);
-                if (prediction[i][pred] > 0) {
+                final int pred = MatrixUtils.whichMax(prediction, i);
+                if (pred != -1 && prediction.get(i, pred) > 0) {
                     oobTests++;
                     if (pred != y[i]) {
                         oobErrors++;
@@ -340,12 +418,23 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
             }
         }
 
-        String modelId = RandomUtils.getUUID();
         final Object[] forwardObjs = new Object[6];
+        String modelId = RandomUtils.getUUID();
         forwardObjs[0] = new Text(modelId);
-        forwardObjs[1] = new IntWritable(_outputType.getId());
+        forwardObjs[1] = new DoubleWritable(accuracy);
         forwardObjs[2] = model;
-        forwardObjs[3] = WritableUtils.toWritableList(importance);
+        if (denseInput) {
+            forwardObjs[3] = WritableUtils.toWritableList(importance.toArray());
+        } else {
+            final Map<IntWritable, DoubleWritable> map = new HashMap<IntWritable, DoubleWritable>(
+                importance.size());
+            importance.each(new VectorProcedure() {
+                public void apply(int i, double value) {
+                    map.put(new IntWritable(i), new DoubleWritable(value));
+                }
+            });
+            forwardObjs[3] = map;
+        }
         forwardObjs[4] = new IntWritable(oobErrors);
         forwardObjs[5] = new IntWritable(oobTests);
         forward(forwardObjs);
@@ -363,20 +452,23 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
         /**
          * Attribute properties.
          */
+        @Nonnull
         private final Attribute[] _attributes;
         /**
          * Training instances.
          */
-        private final double[][] _x;
+        @Nonnull
+        private final Matrix _x;
         /**
          * Training sample labels.
          */
+        @Nonnull
         private final int[] _y;
         /**
-         * The index of training values in ascending order. Note that only numeric attributes will
-         * be sorted.
+         * The index of training values in ascending order. Note that only numeric attributes will be sorted.
          */
-        private final int[][] _order;
+        @Nonnull
+        private final ColumnMajorIntMatrix _order;
         /**
          * The number of variables to pick up in each node.
          */
@@ -384,16 +476,21 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
         /**
          * The out-of-bag predictions.
          */
-        private final int[][] _prediction;
+        @Nonnull
+        @GuardedBy("_udtf")
+        private final IntMatrix _prediction;
 
+        @Nonnull
         private final RandomForestClassifierUDTF _udtf;
         private final int _taskId;
         private final long _seed;
+        @Nonnull
         private final AtomicInteger _remainingTasks;
 
-        TrainingTask(RandomForestClassifierUDTF udtf, int taskId, Attribute[] attributes,
-                double[][] x, int[] y, int numVars, int[][] order, int[][] prediction, long seed,
-                AtomicInteger remainingTasks) {
+        TrainingTask(@Nonnull RandomForestClassifierUDTF udtf, int taskId,
+                @Nonnull Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y, int numVars,
+                @Nonnull ColumnMajorIntMatrix order, @Nonnull IntMatrix prediction, long seed,
+                @Nonnull AtomicInteger remainingTasks) {
             this._udtf = udtf;
             this._taskId = taskId;
             this._attributes = attributes;
@@ -408,98 +505,107 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
 
         @Override
         public Integer call() throws HiveException {
-            long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() : new smile.math.Random(
-                _seed).nextLong();
-            final smile.math.Random rnd1 = new smile.math.Random(s);
-            final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());
-            final int N = _x.length;
+            long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() // -1 means "random seed"
+                    : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
+            final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s); // drives row sampling
+            final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong()); // drives tree growth
+            final int N = _x.numRows();
 
             // Training samples draw with replacement.
-            final int[] bags = new int[N];
             final BitSet sampled = new BitSet(N);
-            for (int i = 0; i < N; i++) {
-                int index = rnd1.nextInt(N);
-                bags[i] = index;
-                sampled.set(index);
-            }
+            final int[] bags = sampling(sampled, N, rnd1); // also sets the in-bag rows in sampled
 
             DecisionTree tree = new DecisionTree(_attributes, _x, _y, _numVars, _udtf._maxDepth,
                 _udtf._maxLeafNodes, _udtf._minSamplesSplit, _udtf._minSamplesLeaf, bags, _order,
                 _udtf._splitRule, rnd2);
 
             // out-of-bag prediction
+            int oob = 0; // number of out-of-bag rows (clear bits in sampled)
+            int correct = 0; // out-of-bag rows this tree predicts correctly
+            final Vector xProbe = _x.rowVector(); // row buffer refilled by getRow() for each probe
             for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) {
-                final int p = tree.predict(_x[i]);
-                synchronized (_prediction[i]) {
-                    _prediction[i][p]++;
+                oob++;
+                _x.getRow(i, xProbe);
+                final int p = tree.predict(xProbe);
+                if (p == _y[i]) {
+                    correct++;
+                }
+                synchronized (_udtf) { // _prediction is shared across tasks; @GuardedBy("_udtf")
+                    _prediction.incr(i, p);
                 }
             }
 
-            Text model = getModel(tree, _udtf._outputType);
-            double[] importance = tree.importance();
+            Text model = getModel(tree);
+            Vector importance = tree.importance();
+            double accuracy = (oob == 0) ? 1.0d : (double) correct / oob; // emitted as model_weight
             int remain = _remainingTasks.decrementAndGet();
             boolean lastTask = (remain == 0);
-            _udtf.forward(_taskId + 1, model, importance, _y, _prediction, lastTask);
+            _udtf.forward(_taskId + 1, model, importance, accuracy, _y, _prediction, lastTask);
 
             return Integer.valueOf(remain);
         }
 
-        private static Text getModel(@Nonnull final DecisionTree tree,
-                @Nonnull final ModelType outputType) throws HiveException {
-            final Text model;
-            switch (outputType) {
-                case serialization:
-                case serialization_compressed: {
-                    byte[] b = tree.predictSerCodegen(outputType.isCompressed());
-                    b = Base91.encode(b);
-                    model = new Text(b);
-                    break;
-                }
-                case opscode:
-                case opscode_compressed: {
-                    String s = tree.predictOpCodegen(StackMachine.SEP);
-                    if (outputType.isCompressed()) {
-                        byte[] b = s.getBytes();
-                        final DeflateCodec codec = new DeflateCodec(true, false);
-                        try {
-                            b = codec.compress(b);
-                        } catch (IOException e) {
-                            throw new HiveException("Failed to compressing a model", e);
-                        } finally {
-                            IOUtils.closeQuietly(codec);
-                        }
-                        b = Base91.encode(b);
-                        model = new Text(b);
-                    } else {
-                        model = new Text(s);
+        @Nonnull
+        private int[] sampling(@Nonnull final BitSet sampled, final int N, @Nonnull PRNG rnd) {
+            return _udtf._stratifiedSampling ? stratifiedSampling(sampled, N, _udtf._subsample, rnd)
+                    : uniformSampling(sampled, N, _udtf._subsample, rnd); // default without -stratified
+        } // either way: returns drawn row indices and marks them in sampled
+
+        @Nonnull
+        private static int[] uniformSampling(@Nonnull final BitSet sampled, final int N,
+                final double subsample, @Nonnull final PRNG rnd) {
+            final int size = (int) Math.round(N * subsample); // number of draws, with replacement
+            final int[] bags = new int[size]; // exactly one slot per draw: no zero tail biasing row 0
+            for (int i = 0; i < size; i++) {
+                final int index = rnd.nextInt(N);
+                bags[i] = index;
+                sampled.set(index); // in-bag rows; the clear bits are the out-of-bag rows
+            }
+            return bags;
+        }
+
+        /**
+         * Stratified sampling for unbalanced data: each class is bootstrapped separately, so class
+         * {@code l} with {@code n_l} rows gets about {@code n_l * subsample * classWeight[l]} draws.
+         * @see <a href="https://en.wikipedia.org/wiki/Stratified_sampling">Stratified sampling</a>
+         */
+        @Nonnull
+        private int[] stratifiedSampling(@Nonnull final BitSet sampled, final int N,
+                final double subsample, final PRNG rnd) {
+            final IntArrayList bagsList = new IntArrayList(N);
+            final int k = smile.math.Math.max(_y) + 1; // assumes labels are 0..k-1
+            final IntArrayList cj = new IntArrayList(N / k); // row indices of the current class
+            for (int l = 0; l < k; l++) {
+                int nj = 0;
+                for (int i = 0; i < N; i++) {
+                    if (_y[i] == l) {
+                        cj.add(i);
+                        nj++;
                     }
-                    break;
                 }
-                case javascript:
-                case javascript_compressed: {
-                    String s = tree.predictJsCodegen();
-                    if (outputType.isCompressed()) {
-                        byte[] b = s.getBytes();
-                        final DeflateCodec codec = new DeflateCodec(true, false);
-                        try {
-                            b = codec.compress(b);
-                        } catch (IOException e) {
-                            throw new HiveException("Failed to compressing a model", e);
-                        } finally {
-                            IOUtils.closeQuietly(codec);
-                        }
-                        b = Base91.encode(b);
-                        model = new Text(b);
-                    } else {
-                        model = new Text(s);
-                    }
-                    break;
+                if (subsample != 1.0d) {
+                    nj = (int) Math.round(nj * subsample);
+                }
+                final int size = (_udtf._classWeight == null || l >= _udtf._classWeight.length)
+                        ? nj : (int) Math.round(nj * _udtf._classWeight[l]); // missing weight = 1.0
+                for (int j = 0; j < size; j++) {
+                    final int xi = rnd.nextInt(cj.size()); // whole class, not a fixed prefix
+                    final int index = cj.get(xi);
+                    bagsList.add(index);
+                    sampled.set(index);
                 }
-                default:
-                    throw new HiveException("Unexpected output type: " + outputType
-                            + ". Use javascript for the output instead");
+                cj.clear();
             }
-            return model;
+            int[] bags = bagsList.toArray(true);
+            SmileExtUtils.shuffle(bags, rnd);
+            return bags;
+        }
+
+        @Nonnull
+        private static Text getModel(@Nonnull final DecisionTree tree) throws HiveException {
+            // serialize the tree (true = compressed form), then Base91-encode it into text
+            final byte[] serialized = tree.predictSerCodegen(true);
+            return new Text(Base91.encode(serialized));
         }
 
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/data/Attribute.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/data/Attribute.java b/core/src/main/java/hivemall/smile/data/Attribute.java
index be6651a..6569726 100644
--- a/core/src/main/java/hivemall/smile/data/Attribute.java
+++ b/core/src/main/java/hivemall/smile/data/Attribute.java
@@ -18,6 +18,9 @@
  */
 package hivemall.smile.data;
 
+import hivemall.annotations.Immutable;
+import hivemall.annotations.Mutable;
+
 import java.io.IOException;
 import java.io.ObjectInput;
 import java.io.ObjectOutput;
@@ -25,11 +28,9 @@ import java.io.ObjectOutput;
 public abstract class Attribute {
 
     public final AttributeType type;
-    public final int attrIndex;
 
-    Attribute(AttributeType type, int attrIndex) {
+    Attribute(AttributeType type) {
         this.type = type;
-        this.attrIndex = attrIndex;
     }
 
     public void setSize(int size) {
@@ -44,24 +45,23 @@ public abstract class Attribute {
     }
 
     public void writeTo(ObjectOutput out) throws IOException {
-        out.writeInt(type.getTypeId());
-        out.writeInt(attrIndex);
+        out.writeByte(type.getTypeId()); // single type byte; readFrom() reads it back via readByte()
     }
 
     public enum AttributeType {
-        NUMERIC(1), NOMINAL(2);
+        NUMERIC((byte) 1), NOMINAL((byte) 2);
 
-        private final int id;
+        private final byte id;
 
-        private AttributeType(int id) {
+        private AttributeType(byte id) {
             this.id = id;
         }
 
-        public int getTypeId() {
+        public byte getTypeId() {
             return id;
         }
 
-        public static AttributeType resolve(int id) {
+        public static AttributeType resolve(byte id) {
             final AttributeType type;
             switch (id) {
                 case 1:
@@ -78,25 +78,27 @@ public abstract class Attribute {
 
     }
 
+    @Immutable
     public static final class NumericAttribute extends Attribute {
 
-        public NumericAttribute(int attrIndex) {
-            super(AttributeType.NUMERIC, attrIndex);
+        public NumericAttribute() { // readFrom() reads nothing after the type byte for NUMERIC
+            super(AttributeType.NUMERIC);
         }
 
         @Override
         public String toString() {
-            return "NumericAttribute [type=" + type + ", attrIndex=" + attrIndex + "]";
+            return "NumericAttribute [type=" + type + "]";
         }
 
     }
 
+    @Mutable
     public static final class NominalAttribute extends Attribute {
 
         private int size;
 
-        public NominalAttribute(int attrIndex) {
-            super(AttributeType.NOMINAL, attrIndex);
+        public NominalAttribute() {
+            super(AttributeType.NOMINAL);
             this.size = -1;
         }
 
@@ -118,25 +120,23 @@ public abstract class Attribute {
 
         @Override
         public String toString() {
-            return "NominalAttribute [size=" + size + ", type=" + type + ", attrIndex=" + attrIndex
-                    + "]";
+            return "NominalAttribute [size=" + size + ", type=" + type + "]";
         }
 
     }
 
     public static Attribute readFrom(ObjectInput in) throws IOException {
-        int typeId = in.readInt();
-        int attrIndex = in.readInt();
-
         final Attribute attr;
+
+        byte typeId = in.readByte();
         final AttributeType type = AttributeType.resolve(typeId);
         switch (type) {
             case NUMERIC: {
-                attr = new NumericAttribute(attrIndex);
+                attr = new NumericAttribute();
                 break;
             }
             case NOMINAL: {
-                attr = new NominalAttribute(attrIndex);
+                attr = new NominalAttribute();
                 int size = in.readInt();
                 attr.setSize(size);
                 break;

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
index ebb58c6..557df21 100644
--- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
+++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
@@ -19,22 +19,25 @@
 package hivemall.smile.regression;
 
 import hivemall.UDTFWithOptions;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.builders.MatrixBuilder;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.Vector;
 import hivemall.smile.data.Attribute;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.smile.utils.SmileTaskExecutor;
-import hivemall.smile.vm.StackMachine;
 import hivemall.utils.codec.Base91;
-import hivemall.utils.codec.DeflateCodec;
-import hivemall.utils.collections.DoubleArrayList;
+import hivemall.utils.collections.lists.DoubleArrayList;
 import hivemall.utils.datetime.StopWatch;
 import hivemall.utils.hadoop.HiveUtils;
 import hivemall.utils.hadoop.WritableUtils;
-import hivemall.utils.io.IOUtils;
 import hivemall.utils.lang.Primitives;
 import hivemall.utils.lang.RandomUtils;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.BitSet;
 import java.util.List;
@@ -42,6 +45,7 @@ import java.util.concurrent.Callable;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
@@ -69,7 +73,7 @@ import org.apache.hadoop.mapred.Reporter;
 
 @Description(
         name = "train_randomforest_regression",
-        value = "_FUNC_(double[] features, double target [, string options]) - "
+        value = "_FUNC_(array<double|string> features, double target [, string options]) - "
                 + "Returns a relation consists of "
                 + "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>")
 public final class RandomForestRegressionUDTF extends UDTFWithOptions {
@@ -79,7 +83,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
     private PrimitiveObjectInspector featureElemOI;
     private PrimitiveObjectInspector targetOI;
 
-    private List<double[]> featuresList;
+    private boolean denseInput;
+    private MatrixBuilder matrixBuilder;
     private DoubleArrayList targets;
     /**
      * The number of trees for each task
@@ -101,7 +106,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
     private int _minSamplesLeaf;
     private long _seed;
     private Attribute[] _attributes;
-    private ModelType _outputType;
 
     @Nullable
     private Reporter _progressReporter;
@@ -131,10 +135,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
         opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
         opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
                 + "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
-        opts.addOption("output", "output_type", true,
-            "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]");
-        opts.addOption("disable_compression", false,
-            "Whether to disable compression of the output script [default: false]");
         return opts;
     }
 
@@ -145,8 +145,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
         float numVars = -1.f;
         Attribute[] attrs = null;
         long seed = -1L;
-        String output = "serialization";
-        boolean compress = true;
 
         CommandLine cl = null;
         if (argOIs.length >= 3) {
@@ -165,10 +163,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
                 minSamplesLeaf);
             seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
             attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
-            output = cl.getOptionValue("output", output);
-            if (cl.hasOption("disable_compression")) {
-                compress = false;
-            }
         }
 
         this._numTrees = trees;
@@ -179,7 +173,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
         this._minSamplesLeaf = minSamplesLeaf;
         this._seed = seed;
         this._attributes = attrs;
-        this._outputType = ModelType.resolve(output, compress);
 
         return cl;
     }
@@ -189,19 +182,29 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
         if (argOIs.length != 2 && argOIs.length != 3) {
             throw new UDFArgumentException(
                 getClass().getSimpleName()
-                        + " takes 2 or 3 arguments: double[] features, double target [, const string options]: "
+                        + " takes 2 or 3 arguments: array<double|string> features, double target [, const string options]: "
                         + argOIs.length);
         }
 
         ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
         ObjectInspector elemOI = listOI.getListElementObjectInspector();
         this.featureListOI = listOI;
-        this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+        if (HiveUtils.isNumberOI(elemOI)) {
+            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+            this.denseInput = true;
+            this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
+        } else if (HiveUtils.isStringOI(elemOI)) {
+            this.featureElemOI = HiveUtils.asStringOI(elemOI);
+            this.denseInput = false;
+            this.matrixBuilder = new CSRMatrixBuilder(8192);
+        } else {
+            throw new UDFArgumentException(
+                "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
+        }
         this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
 
         processOptions(argOIs);
 
-        this.featuresList = new ArrayList<double[]>(1024);
         this.targets = new DoubleArrayList(1024);
 
         ArrayList<String> fieldNames = new ArrayList<String>(5);
@@ -209,8 +212,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
 
         fieldNames.add("model_id");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
-        fieldNames.add("model_type");
-        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+        fieldNames.add("model_err");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
         fieldNames.add("pred_model");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
         fieldNames.add("var_importance");
@@ -228,13 +231,36 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
         if (args[0] == null) {
             throw new HiveException("array<double> features was null");
         }
-        double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
+        parseFeatures(args[0], matrixBuilder);
         double target = PrimitiveObjectInspectorUtils.getDouble(args[1], targetOI);
-
-        featuresList.add(features);
         targets.add(target);
     }
 
+    private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
+        if (denseInput) {
+            final int length = featureListOI.getListLength(argObj);
+            for (int i = 0; i < length; i++) {
+                Object o = featureListOI.getListElement(argObj, i);
+                if (o == null) {
+                    continue;
+                }
+                double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
+                builder.nextColumn(i, v);
+            }
+        } else {
+            final int length = featureListOI.getListLength(argObj);
+            for (int i = 0; i < length; i++) {
+                Object o = featureListOI.getListElement(argObj, i);
+                if (o == null) {
+                    continue;
+                }
+                String fv = o.toString();
+                builder.nextColumn(fv);
+            }
+        }
+        builder.nextRow();
+    }
+
     @Override
     public void close() throws HiveException {
         this._progressReporter = getReporter();
@@ -250,10 +276,9 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
 
         reportProgress(_progressReporter);
 
-        int numExamples = featuresList.size();
-        if (numExamples > 0) {
-            double[][] x = featuresList.toArray(new double[numExamples][]);
-            this.featuresList = null;
+        if (!targets.isEmpty()) {
+            Matrix x = matrixBuilder.buildMatrix();
+            this.matrixBuilder = null;
             double[] y = targets.toArray();
             this.targets = null;
 
@@ -285,15 +310,16 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
      * @param _numVars The number of variables to pick up in each node.
      * @param _seed The seed number for Random Forest
      */
-    private void train(@Nonnull final double[][] x, @Nonnull final double[] y) throws HiveException {
-        if (x.length != y.length) {
+    private void train(@Nonnull Matrix x, @Nonnull final double[] y) throws HiveException {
+        final int numExamples = x.numRows();
+        if (numExamples != y.length) {
             throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
-                x.length, y.length));
+                numExamples, y.length));
         }
         checkOptions();
 
-        // Shuffle training samples 
-        SmileExtUtils.shuffle(x, y, _seed);
+        // Shuffle training samples
+        x = SmileExtUtils.shuffle(x, y, _seed);
 
         Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
         int numInputVars = SmileExtUtils.computeNumInputVars(_numVars, x);
@@ -305,10 +331,9 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
                     + ", seed: " + _seed);
         }
 
-        int numExamples = x.length;
         double[] prediction = new double[numExamples]; // placeholder for out-of-bag prediction
         int[] oob = new int[numExamples];
-        int[][] order = SmileExtUtils.sort(attributes, x);
+        ColumnMajorIntMatrix order = SmileExtUtils.sort(attributes, x);
         AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
         List<TrainingTask> tasks = new ArrayList<TrainingTask>();
         for (int i = 0; i < _numTrees; i++) {
@@ -330,10 +355,13 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
 
     /**
      * Synchronized because {@link #forward(Object)} should be called from a single thread.
+     * 
+     * @param error
      */
     synchronized void forward(final int taskId, @Nonnull final Text model,
-            @Nonnull final double[] importance, final double[] y, final double[] prediction,
-            final int[] oob, final boolean lastTask) throws HiveException {
+            @Nonnull final double[] importance, @Nonnegative final double error, final double[] y,
+            final double[] prediction, final int[] oob, final boolean lastTask)
+            throws HiveException {
         double oobErrors = 0.d;
         int oobTests = 0;
         if (lastTask) {
@@ -349,7 +377,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
         String modelId = RandomUtils.getUUID();
         final Object[] forwardObjs = new Object[6];
         forwardObjs[0] = new Text(modelId);
-        forwardObjs[1] = new IntWritable(_outputType.getId());
+        forwardObjs[1] = new DoubleWritable(error);
         forwardObjs[2] = model;
         forwardObjs[3] = WritableUtils.toWritableList(importance);
         forwardObjs[4] = new DoubleWritable(oobErrors);
@@ -373,16 +401,15 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
         /**
          * Training instances.
          */
-        private final double[][] _x;
+        private final Matrix _x;
         /**
          * Training sample labels.
          */
         private final double[] _y;
         /**
-         * The index of training values in ascending order. Note that only numeric attributes will
-         * be sorted.
+         * The index of training values in ascending order. Note that only numeric attributes will be sorted.
          */
-        private final int[][] _order;
+        private final ColumnMajorIntMatrix _order;
         /**
          * The number of variables to pick up in each node.
          */
@@ -401,8 +428,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
         private final long _seed;
         private final AtomicInteger _remainingTasks;
 
-        TrainingTask(RandomForestRegressionUDTF udtf, int taskId, Attribute[] attributes,
-                double[][] x, double[] y, int numVars, int[][] order, double[] prediction,
+        TrainingTask(RandomForestRegressionUDTF udtf, int taskId, Attribute[] attributes, Matrix x,
+                double[] y, int numVars, ColumnMajorIntMatrix order, double[] prediction,
                 int[] oob, long seed, AtomicInteger remainingTasks) {
             this._udtf = udtf;
             this._taskId = taskId;
@@ -419,11 +446,11 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
 
         @Override
         public Integer call() throws HiveException {
-            long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() : new smile.math.Random(
-                _seed).nextLong();
-            final smile.math.Random rnd1 = new smile.math.Random(s);
-            final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());
-            final int N = _x.length;
+            long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
+                    : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
+            final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
+            final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
+            final int N = _x.numRows();
 
             // Training samples draw with replacement.
             final int[] bags = new int[N];
@@ -441,82 +468,40 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
             incrCounter(_udtf._treeConstuctionTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));
 
             // out-of-bag prediction
+            int oob = 0;
+            double error = 0.d;
+            final Vector xProbe = _x.rowVector();
             for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) {
-                double pred = tree.predict(_x[i]);
-                synchronized (_x[i]) {
+                oob++;
+                _x.getRow(i, xProbe);
+                final double pred = tree.predict(xProbe);
+                synchronized (_prediction) {
                     _prediction[i] += pred;
                     _oob[i]++;
                 }
+                error += Math.abs(pred - _y[i]);
+            }
+            if (oob != 0) {
+                error /= oob;
             }
 
             stopwatch.reset().start();
-            Text model = getModel(tree, _udtf._outputType);
+            Text model = getModel(tree);
             double[] importance = tree.importance();
             tree = null; // help GC
             int remain = _remainingTasks.decrementAndGet();
             boolean lastTask = (remain == 0);
-            _udtf.forward(_taskId + 1, model, importance, _y, _prediction, _oob, lastTask);
+            _udtf.forward(_taskId + 1, model, importance, error, _y, _prediction, _oob, lastTask);
             incrCounter(_udtf._treeSerializationTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));
 
             return Integer.valueOf(remain);
         }
 
-        private static Text getModel(@Nonnull final RegressionTree tree,
-                @Nonnull final ModelType outputType) throws HiveException {
-            final Text model;
-            switch (outputType) {
-                case serialization:
-                case serialization_compressed: {
-                    byte[] b = tree.predictSerCodegen(outputType.isCompressed());
-                    b = Base91.encode(b);
-                    model = new Text(b);
-                    break;
-                }
-                case opscode:
-                case opscode_compressed: {
-                    String s = tree.predictOpCodegen(StackMachine.SEP);
-                    if (outputType.isCompressed()) {
-                        byte[] b = s.getBytes();
-                        final DeflateCodec codec = new DeflateCodec(true, false);
-                        try {
-                            b = codec.compress(b);
-                        } catch (IOException e) {
-                            throw new HiveException("Failed to compressing a model", e);
-                        } finally {
-                            IOUtils.closeQuietly(codec);
-                        }
-                        b = Base91.encode(b);
-                        model = new Text(b);
-                    } else {
-                        model = new Text(s);
-                    }
-                    break;
-                }
-                case javascript:
-                case javascript_compressed: {
-                    String s = tree.predictJsCodegen();
-                    if (outputType.isCompressed()) {
-                        byte[] b = s.getBytes();
-                        final DeflateCodec codec = new DeflateCodec(true, false);
-                        try {
-                            b = codec.compress(b);
-                        } catch (IOException e) {
-                            throw new HiveException("Failed to compressing a model", e);
-                        } finally {
-                            IOUtils.closeQuietly(codec);
-                        }
-                        b = Base91.encode(b);
-                        model = new Text(b);
-                    } else {
-                        model = new Text(s);
-                    }
-                    break;
-                }
-                default:
-                    throw new HiveException("Unexpected output type: " + outputType
-                            + ". Use javascript for the output instead");
-            }
-            return model;
+        @Nonnull
+        private static Text getModel(@Nonnull final RegressionTree tree) throws HiveException {
+            byte[] b = tree.predictSerCodegen(true);
+            b = Base91.encode(b);
+            return new Text(b);
         }
 
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/regression/RegressionTree.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
index 07887c1..da7e80b 100755
--- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
+++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -33,20 +33,28 @@
  */
 package hivemall.smile.regression;
 
+import hivemall.annotations.VisibleForTesting;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
 import hivemall.smile.data.Attribute;
 import hivemall.smile.data.Attribute.AttributeType;
 import hivemall.smile.utils.SmileExtUtils;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
+import hivemall.utils.collections.sets.IntArraySet;
+import hivemall.utils.collections.sets.IntSet;
 import hivemall.utils.lang.ObjectUtils;
-import hivemall.utils.lang.StringUtils;
+import hivemall.utils.math.MathUtils;
 
 import java.io.Externalizable;
 import java.io.IOException;
 import java.io.ObjectInput;
 import java.io.ObjectOutput;
-import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.List;
 import java.util.PriorityQueue;
 
 import javax.annotation.Nonnull;
@@ -55,60 +63,48 @@ import javax.annotation.Nullable;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 
 import smile.math.Math;
-import smile.math.Random;
 import smile.regression.GradientTreeBoost;
 import smile.regression.RandomForest;
 import smile.regression.Regression;
 
 /**
- * Decision tree for regression. A decision tree can be learned by splitting the training set into
- * subsets based on an attribute value test. This process is repeated on each derived subset in a
- * recursive manner called recursive partitioning.
+ * Decision tree for regression. A decision tree can be learned by splitting the training set into subsets based on an attribute value test. This
+ * process is repeated on each derived subset in a recursive manner called recursive partitioning.
  * <p>
- * Classification and Regression Tree techniques have a number of advantages over many of those
- * alternative techniques.
+ * Classification and Regression Tree techniques have a number of advantages over many of those alternative techniques.
  * <dl>
  * <dt>Simple to understand and interpret.</dt>
- * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This
- * simplicity is useful not only for purposes of rapid classification of new observations, but can
- * also often yield a much simpler "model" for explaining why observations are classified or
- * predicted in a particular manner.</dd>
+ * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This simplicity is useful not only for purposes of rapid
+ * classification of new observations, but can also often yield a much simpler "model" for explaining why observations are classified or predicted in
+ * a particular manner.</dd>
  * <dt>Able to handle both numerical and categorical data.</dt>
- * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of
- * variable.</dd>
+ * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of variable.</dd>
  * <dt>Tree methods are nonparametric and nonlinear.</dt>
- * <dd>The final results of using tree methods for classification or regression can be summarized in
- * a series of (usually few) logical if-then conditions (tree nodes). Therefore, there is no
- * implicit assumption that the underlying relationships between the predictor variables and the
- * dependent variable are linear, follow some specific non-linear link function, or that they are
- * even monotonic in nature. Thus, tree methods are particularly well suited for data mining tasks,
- * where there is often little a priori knowledge nor any coherent set of theories or predictions
- * regarding which variables are related and how. In those types of data analytics, tree methods can
- * often reveal simple relationships between just a few variables that could have easily gone
- * unnoticed using other analytic techniques.</dd>
+ * <dd>The final results of using tree methods for classification or regression can be summarized in a series of (usually few) logical if-then
+ * conditions (tree nodes). Therefore, there is no implicit assumption that the underlying relationships between the predictor variables and the
+ * dependent variable are linear, follow some specific non-linear link function, or that they are even monotonic in nature. Thus, tree methods are
+ * particularly well suited for data mining tasks, where there is often little a priori knowledge nor any coherent set of theories or predictions
+ * regarding which variables are related and how. In those types of data analytics, tree methods can often reveal simple relationships between just a
+ * few variables that could have easily gone unnoticed using other analytic techniques.</dd>
  * </dl>
- * One major problem with classification and regression trees is their high variance. Often a small
- * change in the data can result in a very different series of splits, making interpretation
- * somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause
- * over-fitting. Mechanisms such as pruning are necessary to avoid this problem. Another limitation
- * of trees is the lack of smoothness of the prediction surface.
+ * One major problem with classification and regression trees is their high variance. Often a small change in the data can result in a very different
+ * series of splits, making interpretation somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause over-fitting.
+ * Mechanisms such as pruning are necessary to avoid this problem. Another limitation of trees is the lack of smoothness of the prediction surface.
  * <p>
- * Some techniques such as bagging, boosting, and random forest use more than one decision tree for
- * their analysis.
+ * Some techniques such as bagging, boosting, and random forest use more than one decision tree for their analysis.
  * 
  * @see GradientTreeBoost
  * @see RandomForest
  */
-public final class RegressionTree implements Regression<double[]> {
+public final class RegressionTree implements Regression<Vector> {
     /**
      * The attributes of independent variable.
      */
     private final Attribute[] _attributes;
     private final boolean _hasNumericType;
     /**
-     * Variable importance. Every time a split of a node is made on variable the impurity criterion
-     * for the two descendant nodes is less than the parent node. Adding up the decreases for each
-     * individual variable over the tree gives a simple measure of variable importance.
+     * Variable importance. Every time a split of a node is made on variable the impurity criterion for the two descendant nodes is less than the
+     * parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of variable importance.
      */
     private final double[] _importance;
     /**
@@ -120,8 +116,7 @@ public final class RegressionTree implements Regression<double[]> {
      */
     private final int _maxDepth;
     /**
-     * The number of instances in a node below which the tree will not split, setting S = 5
-     * generally gives good results.
+     * The number of instances in a node below which the tree will not split, setting S = 5 generally gives good results.
      */
     private final int _minSplit;
     /**
@@ -133,19 +128,17 @@ public final class RegressionTree implements Regression<double[]> {
      */
     private final int _numVars;
     /**
-     * The index of training values in ascending order. Note that only numeric attributes will be
-     * sorted.
+     * The index of training values in ascending order. Note that only numeric attributes will be sorted.
      */
-    private final int[][] _order;
+    private final ColumnMajorIntMatrix _order;
 
-    private final Random _rnd;
+    private final PRNG _rnd;
 
     private final NodeOutput _nodeOutput;
 
     /**
-     * An interface to calculate node output. Note that samples[i] is the number of sampling of
-     * dataset[i]. 0 means that the datum is not included and values of greater than 1 are possible
-     * because of sampling with replacement.
+     * An interface to calculate node output. Note that samples[i] is the number of sampling of dataset[i]. 0 means that the datum is not included and
+     * values of greater than 1 are possible because of sampling with replacement.
      */
     public interface NodeOutput {
         /**
@@ -205,22 +198,30 @@ public final class RegressionTree implements Regression<double[]> {
             this.output = output;
         }
 
+        private boolean isLeaf() {
+            return trueChild == null && falseChild == null;
+        }
+
+        @VisibleForTesting
+        public double predict(@Nonnull final double[] x) {
+            return predict(new DenseVector(x));
+        }
+
         /**
          * Evaluate the regression tree over an instance.
          */
-        public double predict(final double[] x) {
+        public double predict(@Nonnull final Vector x) {
             if (trueChild == null && falseChild == null) {
                 return output;
             } else {
                 if (splitFeatureType == AttributeType.NOMINAL) {
-                    // REVIEWME if(Math.equals(x[splitFeature], splitValue)) {
-                    if (x[splitFeature] == splitValue) {
+                    if (x.get(splitFeature, Double.NaN) == splitValue) {
                         return trueChild.predict(x);
                     } else {
                         return falseChild.predict(x);
                     }
                 } else if (splitFeatureType == AttributeType.NUMERIC) {
-                    if (x[splitFeature] <= splitValue) {
+                    if (x.get(splitFeature, Double.NaN) <= splitValue) {
                         return trueChild.predict(x);
                     } else {
                         return falseChild.predict(x);
@@ -283,99 +284,58 @@ public final class RegressionTree implements Regression<double[]> {
             }
         }
 
-        public int opCodegen(final List<String> scripts, int depth) {
-            int selfDepth = 0;
-            final StringBuilder buf = new StringBuilder();
-            if (trueChild == null && falseChild == null) {
-                buf.append("push ").append(output);
-                scripts.add(buf.toString());
-                buf.setLength(0);
-                buf.append("goto last");
-                scripts.add(buf.toString());
-                selfDepth += 2;
-            } else {
-                if (splitFeatureType == AttributeType.NOMINAL) {
-                    buf.append("push ").append("x[").append(splitFeature).append("]");
-                    scripts.add(buf.toString());
-                    buf.setLength(0);
-                    buf.append("push ").append(splitValue);
-                    scripts.add(buf.toString());
-                    buf.setLength(0);
-                    buf.append("ifeq ");
-                    scripts.add(buf.toString());
-                    depth += 3;
-                    selfDepth += 3;
-                    int trueDepth = trueChild.opCodegen(scripts, depth);
-                    selfDepth += trueDepth;
-                    scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth));
-                    int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
-                    selfDepth += falseDepth;
-                } else if (splitFeatureType == AttributeType.NUMERIC) {
-                    buf.append("push ").append("x[").append(splitFeature).append("]");
-                    scripts.add(buf.toString());
-                    buf.setLength(0);
-                    buf.append("push ").append(splitValue);
-                    scripts.add(buf.toString());
-                    buf.setLength(0);
-                    buf.append("ifle ");
-                    scripts.add(buf.toString());
-                    depth += 3;
-                    selfDepth += 3;
-                    int trueDepth = trueChild.opCodegen(scripts, depth);
-                    selfDepth += trueDepth;
-                    scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth));
-                    int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
-                    selfDepth += falseDepth;
-                } else {
-                    throw new IllegalStateException("Unsupported attribute type: "
-                            + splitFeatureType);
-                }
-            }
-            return selfDepth;
-        }
-
         @Override
         public void writeExternal(ObjectOutput out) throws IOException {
-            out.writeDouble(output);
             out.writeInt(splitFeature);
             if (splitFeatureType == null) {
-                out.writeInt(-1);
+                out.writeByte(-1);
             } else {
-                out.writeInt(splitFeatureType.getTypeId());
+                out.writeByte(splitFeatureType.getTypeId());
             }
             out.writeDouble(splitValue);
-            if (trueChild == null) {
-                out.writeBoolean(false);
-            } else {
+
+            if (isLeaf()) {
                 out.writeBoolean(true);
-                trueChild.writeExternal(out);
-            }
-            if (falseChild == null) {
-                out.writeBoolean(false);
+                out.writeDouble(output);
             } else {
-                out.writeBoolean(true);
-                falseChild.writeExternal(out);
+                out.writeBoolean(false);
+                if (trueChild == null) {
+                    out.writeBoolean(false);
+                } else {
+                    out.writeBoolean(true);
+                    trueChild.writeExternal(out);
+                }
+                if (falseChild == null) {
+                    out.writeBoolean(false);
+                } else {
+                    out.writeBoolean(true);
+                    falseChild.writeExternal(out);
+                }
             }
         }
 
         @Override
         public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
-            this.output = in.readDouble();
             this.splitFeature = in.readInt();
-            int typeId = in.readInt();
+            byte typeId = in.readByte();
             if (typeId == -1) {
                 this.splitFeatureType = null;
             } else {
                 this.splitFeatureType = AttributeType.resolve(typeId);
             }
             this.splitValue = in.readDouble();
-            if (in.readBoolean()) {
-                this.trueChild = new Node();
-                trueChild.readExternal(in);
-            }
-            if (in.readBoolean()) {
-                this.falseChild = new Node();
-                falseChild.readExternal(in);
+
+            if (in.readBoolean()) {// isLeaf()
+                this.output = in.readDouble();
+            } else {
+                if (in.readBoolean()) {
+                    this.trueChild = new Node();
+                    trueChild.readExternal(in);
+                }
+                if (in.readBoolean()) {
+                    this.falseChild = new Node();
+                    falseChild.readExternal(in);
+                }
             }
         }
     }
@@ -406,7 +366,7 @@ public final class RegressionTree implements Regression<double[]> {
         /**
          * Training dataset.
          */
-        final double[][] x;
+        final Matrix x;
         /**
          * Training data response value.
          */
@@ -419,7 +379,7 @@ public final class RegressionTree implements Regression<double[]> {
         /**
          * Constructor.
          */
-        public TrainNode(Node node, double[][] x, double[] y, int[] bags, int depth) {
+        public TrainNode(Node node, Matrix x, double[] y, int[] bags, int depth) {
             this.node = node;
             this.x = x;
             this.y = y;
@@ -452,8 +412,7 @@ public final class RegressionTree implements Regression<double[]> {
         }
 
         /**
-         * Finds the best attribute to split on at the current node. Returns true if a split exists
-         * to reduce squared error, false otherwise.
+         * Finds the best attribute to split on at the current node. Returns true if a split exists that reduces squared error, false otherwise.
          */
         public boolean findBestSplit() {
             // avoid split if tree depth is larger than threshold
@@ -467,22 +426,14 @@ public final class RegressionTree implements Regression<double[]> {
             }
 
             final double sum = node.output * numSamples;
-            final int p = _attributes.length;
-            final int[] variables = new int[p];
-            for (int i = 0; i < p; i++) {
-                variables[i] = i;
-            }
 
-            if (_numVars < p) {
-                SmileExtUtils.shuffle(variables, _rnd);
-            }
 
             // Loop through features and compute the reduction of squared error,
             // which is trueCount * trueMean^2 + falseCount * falseMean^2 - count * parentMean^2      
-            final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.length)
+            final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.numRows())
                     : null;
-            for (int j = 0; j < _numVars; j++) {
-                Node split = findBestSplit(numSamples, sum, variables[j], samples);
+            for (int varJ : variableIndex(x, bags)) {
+                final Node split = findBestSplit(numSamples, sum, varJ, samples);
                 if (split.splitScore > node.splitScore) {
                     node.splitFeature = split.splitFeature;
                     node.splitFeatureType = split.splitFeatureType;
@@ -496,6 +447,31 @@ public final class RegressionTree implements Regression<double[]> {
             return node.splitFeature != -1;
         }
 
+        private int[] variableIndex(@Nonnull final Matrix x, @Nonnull final int[] bags) {
+            final int[] variableIndex;
+            if (x.isSparse()) {
+                final IntSet cols = new IntArraySet(_numVars);
+                final VectorProcedure proc = new VectorProcedure() {
+                    public void apply(int col, double value) {
+                        cols.add(col);
+                    }
+                };
+                for (final int row : bags) {
+                    x.eachNonNullInRow(row, proc);
+                }
+                variableIndex = cols.toArray(false);
+            } else {
+                variableIndex = MathUtils.permutation(_attributes.length);
+            }
+
+            if (_numVars < variableIndex.length) {
+                SmileExtUtils.shuffle(variableIndex, _rnd);
+                return Arrays.copyOf(variableIndex, _numVars);
+
+            }
+            return variableIndex;
+        }
+
         /**
          * Finds the best split cutoff for attribute j at the current node.
          * 
@@ -517,7 +493,11 @@ public final class RegressionTree implements Regression<double[]> {
                     // For each true feature of this datum increment the
                     // sufficient statistics for the "true" branch to evaluate
                     // splitting on this feature.
-                    int index = (int) x[i][j];
+                    final double v = x.get(i, j, Double.NaN);
+                    if (Double.isNaN(v)) {
+                        continue;
+                    }
+                    int index = (int) v;
                     trueSum[index] += y[i];
                     ++trueCount[index];
                 }
@@ -548,28 +528,38 @@ public final class RegressionTree implements Regression<double[]> {
                     }
                 }
             } else if (_attributes[j].type == AttributeType.NUMERIC) {
-                double trueSum = 0.0;
-                int trueCount = 0;
-                double prevx = Double.NaN;
-
-                for (int i : _order[j]) {
-                    final int sample = samples[i];
-                    if (sample > 0) {
-                        if (Double.isNaN(prevx) || x[i][j] == prevx) {
-                            prevx = x[i][j];
-                            trueSum += sample * y[i];
+
+                _order.eachNonNullInColumn(j, new VectorProcedure() {
+                    double trueSum = 0.0;
+                    int trueCount = 0;
+                    double prevx = Double.NaN;
+
+                    public void apply(final int row, final int i) {
+                        final int sample = samples[i];
+                        if (sample == 0) {
+                            return;
+                        }
+                        final double x_ij = x.get(i, j, Double.NaN);
+                        if (Double.isNaN(x_ij)) {
+                            return;
+                        }
+                        final double y_i = y[i];
+
+                        if (Double.isNaN(prevx) || x_ij == prevx) {
+                            prevx = x_ij;
+                            trueSum += sample * y_i;
                             trueCount += sample;
-                            continue;
+                            return;
                         }
 
                         final double falseCount = n - trueCount;
 
                         // If either side is empty, skip this feature.
                         if (trueCount < _minSplit || falseCount < _minSplit) {
-                            prevx = x[i][j];
-                            trueSum += sample * y[i];
+                            prevx = x_ij;
+                            trueSum += sample * y_i;
                             trueCount += sample;
-                            continue;
+                            return;
                         }
 
                         // compute penalized means
@@ -586,17 +576,18 @@ public final class RegressionTree implements Regression<double[]> {
                             // new best split
                             split.splitFeature = j;
                             split.splitFeatureType = AttributeType.NUMERIC;
-                            split.splitValue = (x[i][j] + prevx) / 2;
+                            split.splitValue = (x_ij + prevx) / 2;
                             split.splitScore = gain;
                             split.trueChildOutput = trueMean;
                             split.falseChildOutput = falseMean;
                         }
 
-                        prevx = x[i][j];
-                        trueSum += sample * y[i];
+                        prevx = x_ij;
+                        trueSum += sample * y_i;
                         trueCount += sample;
-                    }
-                }
+                    }//apply
+                });
+
             } else {
                 throw new IllegalStateException("Unsupported attribute type: "
                         + _attributes[j].type);
@@ -672,7 +663,7 @@ public final class RegressionTree implements Regression<double[]> {
                 final double splitValue = node.splitValue;
                 for (int i = 0, size = bags.length; i < size; i++) {
                     final int index = bags[i];
-                    if (x[index][splitFeature] == splitValue) {
+                    if (x.get(index, splitFeature, Double.NaN) == splitValue) {
                         trueBags.add(index);
                         tc++;
                     } else {
@@ -684,7 +675,7 @@ public final class RegressionTree implements Regression<double[]> {
                 final double splitValue = node.splitValue;
                 for (int i = 0, size = bags.length; i < size; i++) {
                     final int index = bags[i];
-                    if (x[index][splitFeature] <= splitValue) {
+                    if (x.get(index, splitFeature, Double.NaN) <= splitValue) {
                         trueBags.add(index);
                         tc++;
                     } else {
@@ -700,20 +691,19 @@ public final class RegressionTree implements Regression<double[]> {
 
     }
 
-    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
-            @Nonnull double[] y, int maxLeafs) {
-        this(attributes, x, y, x[0].length, Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, null);
+    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+            int maxLeafs) {
+        this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, null);
     }
 
-    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
-            @Nonnull double[] y, int maxLeafs, @Nullable smile.math.Random rand) {
-        this(attributes, x, y, x[0].length, Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, rand);
+    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+            int maxLeafs, @Nullable PRNG rand) {
+        this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, rand);
     }
 
-    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
-            @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
-            int minLeafSize, @Nullable int[][] order, @Nullable int[] bags,
-            @Nullable smile.math.Random rand) {
+    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+            int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
+            @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags, @Nullable PRNG rand) {
         this(attributes, x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize, order, bags, null, rand);
     }
 
@@ -723,24 +713,22 @@ public final class RegressionTree implements Regression<double[]> {
      * @param attributes the attribute properties.
      * @param x the training instances.
      * @param y the response variable.
-     * @param numVars the number of input variables to pick to split on at each node. It seems that
-     *        dim/3 give generally good performance, where dim is the number of variables.
+     * @param numVars the number of input variables to pick to split on at each node. It seems that dim/3 generally gives good performance, where dim
+     *        is the number of variables.
      * @param maxLeafs the maximum number of leaf nodes in the tree.
-     * @param minSplits number of instances in a node below which the tree will not split, setting S
-     *        = 5 generally gives good results.
-     * @param order the index of training values in ascending order. Note that only numeric
-     *        attributes need be sorted.
+     * @param minSplits number of instances in a node below which the tree will not split, setting S = 5 generally gives good results.
+     * @param order the index of training values in ascending order. Note that only numeric attributes need be sorted.
      * @param bags the sample set of instances for stochastic learning.
      * @param output An interface to calculate node output.
      */
-    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
-            @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
-            int minLeafSize, @Nullable int[][] order, @Nullable int[] bags,
-            @Nullable NodeOutput output, @Nullable smile.math.Random rand) {
+    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+            int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
+            @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags,
+            @Nullable NodeOutput output, @Nullable PRNG rand) {
         checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);
 
         this._attributes = SmileExtUtils.attributeTypes(attributes, x);
-        if (_attributes.length != x[0].length) {
+        if (_attributes.length != x.numColumns()) {
             throw new IllegalArgumentException("-attrs option is invliad: "
                     + Arrays.toString(attributes));
         }
@@ -752,7 +740,7 @@ public final class RegressionTree implements Regression<double[]> {
         this._minLeafSize = minLeafSize;
         this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
         this._importance = new double[_attributes.length];
-        this._rnd = (rand == null) ? new smile.math.Random() : rand;
+        this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;
         this._nodeOutput = output;
 
         int n = 0;
@@ -803,13 +791,13 @@ public final class RegressionTree implements Regression<double[]> {
         }
     }
 
-    private static void checkArgument(@Nonnull double[][] x, @Nonnull double[] y, int numVars,
+    private static void checkArgument(@Nonnull Matrix x, @Nonnull double[] y, int numVars,
             int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
-        if (x.length != y.length) {
+        if (x.numRows() != y.length) {
             throw new IllegalArgumentException(String.format(
-                "The sizes of X and Y don't match: %d != %d", x.length, y.length));
+                "The sizes of X and Y don't match: %d != %d", x.numRows(), y.length));
         }
-        if (numVars <= 0 || numVars > x[0].length) {
+        if (numVars <= 0 || numVars > x.numColumns()) {
             throw new IllegalArgumentException(
                 "Invalid number of variables to split on at a node of the tree: " + numVars);
         }
@@ -830,10 +818,8 @@ public final class RegressionTree implements Regression<double[]> {
     }
 
     /**
-     * Returns the variable importance. Every time a split of a node is made on variable the
-     * impurity criterion for the two descendent nodes is less than the parent node. Adding up the
-     * decreases for each individual variable over the tree gives a simple measure of variable
-     * importance.
+     * Returns the variable importance. Every time a node is split on a variable, the impurity criterion for the two descendant nodes is less
+     * than that of the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of variable importance.
      *
      * @return the variable importance
      */
@@ -841,8 +827,13 @@ public final class RegressionTree implements Regression<double[]> {
         return _importance;
     }
 
+    @VisibleForTesting
+    public double predict(@Nonnull final double[] x) {
+        return predict(new DenseVector(x));
+    }
+
     @Override
-    public double predict(double[] x) {
+    public double predict(@Nonnull final Vector x) {
         return _root.predict(x);
     }
 
@@ -852,14 +843,6 @@ public final class RegressionTree implements Regression<double[]> {
         return buf.toString();
     }
 
-    public String predictOpCodegen(@Nonnull String sep) {
-        List<String> opslist = new ArrayList<String>();
-        _root.opCodegen(opslist, 0);
-        opslist.add("call end");
-        String scripts = StringUtils.concat(opslist, sep);
-        return scripts;
-    }
-
     @Nonnull
     public byte[] predictSerCodegen(boolean compress) throws HiveException {
         try {