Posted to commits@hivemall.apache.org by my...@apache.org on 2019/07/10 07:17:29 UTC

[incubator-hivemall] branch master updated: Added sanity checks for training data in RandomForest

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new 72dca39  Added sanity checks for training data in RandomForest
72dca39 is described below

commit 72dca396c6851c9ea44df7eac86ba677ea21879e
Author: Makoto Yui <my...@apache.org>
AuthorDate: Wed Jul 10 16:17:20 2019 +0900

    Added sanity checks for training data in RandomForest
---
 .../classification/RandomForestClassifierUDTF.java |  10 ++
 .../RandomForestClassifierUDTFTest.java            | 101 ++++++++++++++++++++-
 2 files changed, 108 insertions(+), 3 deletions(-)

diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index 7f2966b..99396b7 100644
--- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -327,6 +327,16 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
             int[] y = labels.toArray();
             this.labels = null;
 
+            // sanity checks
+            if (x.numColumns() == 0) {
+                throw new HiveException(
+                    "No non-null features in the training examples. Revise training data");
+            }
+            if (x.numRows() != y.length) {
+                throw new HiveException("Illegal condition was met. y.length=" + y.length
+                        + ", X.length=" + x.numRows());
+            }
+
             // run training
             train(x, y);
         }
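
The new guard rejects the accumulated training data right before the forest is trained: once when the matrix contains no non-null feature columns, and once when the number of feature rows disagrees with the number of labels. Below is a minimal standalone sketch (not part of this commit) showing how the first case surfaces to a caller; the class name, the two all-null toy rows, and the "-trees 10" option are made up, while the API calls mirror the new unit tests in the diff that follows.

import java.util.Arrays;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.Collector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import hivemall.smile.classification.RandomForestClassifierUDTF;

public class AllNullFeaturesSketch {

    public static void main(String[] args) throws Exception {
        RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 10");
        udtf.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
        udtf.setCollector(new Collector() {
            public void collect(Object input) throws HiveException {
                // no forest rows are expected; the sanity check aborts training
            }
        });

        // Two toy rows whose features are entirely null: the accumulated
        // matrix ends up without any usable feature column.
        udtf.process(new Object[] {Arrays.asList((Double) null, null, null, null), 0});
        udtf.process(new Object[] {Arrays.asList((Double) null, null, null, null), 1});

        try {
            udtf.close(); // training starts here; the added sanity checks run first
        } catch (HiveException expected) {
            // e.g. "No non-null features in the training examples. Revise training data"
            System.out.println("Rejected as expected: " + expected.getMessage());
        }
    }
}

Either check turns malformed input into an immediate HiveException with an explicit message, which is exactly what the two new test methods below assert.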
diff --git a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
index 0793ae6..aa839fa 100644
--- a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
@@ -22,6 +22,8 @@ import hivemall.TestUtils;
 import hivemall.classifier.KernelExpansionPassiveAggressiveUDTF;
 import hivemall.utils.codec.Base91;
 import hivemall.utils.lang.mutable.MutableInt;
+import smile.data.AttributeDataset;
+import smile.data.parser.ArffParser;
 
 import java.io.BufferedInputStream;
 import java.io.BufferedReader;
@@ -32,6 +34,7 @@ import java.net.URL;
 import java.text.ParseException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Random;
 import java.util.StringTokenizer;
 import java.util.zip.GZIPInputStream;
 
@@ -48,9 +51,6 @@ import org.apache.hadoop.io.Text;
 import org.junit.Assert;
 import org.junit.Test;
 
-import smile.data.AttributeDataset;
-import smile.data.parser.ArffParser;
-
 public class RandomForestClassifierUDTFTest {
 
     @Test
@@ -98,6 +98,101 @@ public class RandomForestClassifierUDTFTest {
     }
 
     @Test
+    public void testIrisDenseSomeNullFeaturesTest()
+            throws IOException, ParseException, HiveException {
+        URL url = new URL(
+            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+        InputStream is = new BufferedInputStream(url.openStream());
+
+        ArffParser arffParser = new ArffParser();
+        arffParser.setResponseIndex(4);
+
+        AttributeDataset iris = arffParser.parse(is);
+        int size = iris.size();
+        double[][] x = iris.toArray(new double[size][]);
+        int[] y = iris.toArray(new int[size]);
+
+        RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+        final Random rand = new Random(43);
+        final List<Double> xi = new ArrayList<Double>(x[0].length);
+        for (int i = 0; i < size; i++) {
+            for (int j = 0; j < x[i].length; j++) {
+                if (rand.nextDouble() >= 0.7) {
+                    xi.add(j, null);
+                } else {
+                    xi.add(j, x[i][j]);
+                }
+            }
+            udtf.process(new Object[] {xi, y[i]});
+            xi.clear();
+        }
+
+        final MutableInt count = new MutableInt(0);
+        Collector collector = new Collector() {
+            public void collect(Object input) throws HiveException {
+                count.addValue(1);
+            }
+        };
+
+        udtf.setCollector(collector);
+        udtf.close();
+
+        Assert.assertEquals(49, count.getValue());
+    }
+
+    @Test(expected = HiveException.class)
+    public void testIrisDenseAllNullFeaturesTest()
+            throws IOException, ParseException, HiveException {
+        URL url = new URL(
+            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+        InputStream is = new BufferedInputStream(url.openStream());
+
+        ArffParser arffParser = new ArffParser();
+        arffParser.setResponseIndex(4);
+
+        AttributeDataset iris = arffParser.parse(is);
+        int size = iris.size();
+        double[][] x = iris.toArray(new double[size][]);
+        int[] y = iris.toArray(new int[size]);
+
+        RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+        final List<Double> xi = new ArrayList<Double>(x[0].length);
+        for (int i = 0; i < size; i++) {
+            for (int j = 0; j < x[i].length; j++) {
+                xi.add(j, null);
+            }
+            udtf.process(new Object[] {xi, y[i]});
+            xi.clear();
+        }
+
+        final MutableInt count = new MutableInt(0);
+        Collector collector = new Collector() {
+            public void collect(Object input) throws HiveException {
+                count.addValue(1);
+            }
+        };
+
+        udtf.setCollector(collector);
+        udtf.close();
+
+        Assert.fail("should not be called");
+    }
+
+    @Test
     public void testIrisSparse() throws IOException, ParseException, HiveException {
         URL url = new URL(
             "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");