You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/07/10 07:17:29 UTC
[incubator-hivemall] branch master updated: Added sanity checks for training data in RandomForest
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new 72dca39 Added sanity checks for training data in RandomForest
72dca39 is described below
commit 72dca396c6851c9ea44df7eac86ba677ea21879e
Author: Makoto Yui <my...@apache.org>
AuthorDate: Wed Jul 10 16:17:20 2019 +0900
Added sanity checks for training data in RandomForest
---
.../classification/RandomForestClassifierUDTF.java | 10 ++
.../RandomForestClassifierUDTFTest.java | 101 ++++++++++++++++++++-
2 files changed, 108 insertions(+), 3 deletions(-)
diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index 7f2966b..99396b7 100644
--- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -327,6 +327,16 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
int[] y = labels.toArray();
this.labels = null;
+ // sanity checks
+ if (x.numColumns() == 0) {
+ throw new HiveException(
+ "No non-null features in the training examples. Revise training data");
+ }
+ if (x.numRows() != y.length) {
+ throw new HiveException("Illegal condition was met. y.length=" + y.length
+ + ", X.length=" + x.numRows());
+ }
+
// run training
train(x, y);
}
diff --git a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
index 0793ae6..aa839fa 100644
--- a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
@@ -22,6 +22,8 @@ import hivemall.TestUtils;
import hivemall.classifier.KernelExpansionPassiveAggressiveUDTF;
import hivemall.utils.codec.Base91;
import hivemall.utils.lang.mutable.MutableInt;
+import smile.data.AttributeDataset;
+import smile.data.parser.ArffParser;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
@@ -32,6 +34,7 @@ import java.net.URL;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.List;
+import java.util.Random;
import java.util.StringTokenizer;
import java.util.zip.GZIPInputStream;
@@ -48,9 +51,6 @@ import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
-import smile.data.AttributeDataset;
-import smile.data.parser.ArffParser;
-
public class RandomForestClassifierUDTFTest {
@Test
@@ -98,6 +98,101 @@ public class RandomForestClassifierUDTFTest {
}
@Test
+ public void testIrisDenseSomeNullFeaturesTest()
+ throws IOException, ParseException, HiveException {
+ URL url = new URL(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+ InputStream is = new BufferedInputStream(url.openStream());
+
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(4);
+
+ AttributeDataset iris = arffParser.parse(is);
+ int size = iris.size();
+ double[][] x = iris.toArray(new double[size][]);
+ int[] y = iris.toArray(new int[size]);
+
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+ final Random rand = new Random(43);
+ final List<Double> xi = new ArrayList<Double>(x[0].length);
+ for (int i = 0; i < size; i++) {
+ for (int j = 0; j < x[i].length; j++) {
+ if (rand.nextDouble() >= 0.7) {
+ xi.add(j, null);
+ } else {
+ xi.add(j, x[i][j]);
+ }
+ }
+ udtf.process(new Object[] {xi, y[i]});
+ xi.clear();
+ }
+
+ final MutableInt count = new MutableInt(0);
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ count.addValue(1);
+ }
+ };
+
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Assert.assertEquals(49, count.getValue());
+ }
+
+ @Test(expected = HiveException.class)
+ public void testIrisDenseAllNullFeaturesTest()
+ throws IOException, ParseException, HiveException {
+ URL url = new URL(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+ InputStream is = new BufferedInputStream(url.openStream());
+
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(4);
+
+ AttributeDataset iris = arffParser.parse(is);
+ int size = iris.size();
+ double[][] x = iris.toArray(new double[size][]);
+ int[] y = iris.toArray(new int[size]);
+
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+ final List<Double> xi = new ArrayList<Double>(x[0].length);
+ for (int i = 0; i < size; i++) {
+ for (int j = 0; j < x[i].length; j++) {
+ xi.add(j, null);
+ }
+ udtf.process(new Object[] {xi, y[i]});
+ xi.clear();
+ }
+
+ final MutableInt count = new MutableInt(0);
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ count.addValue(1);
+ }
+ };
+
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Assert.fail("should not be called");
+ }
+
+ @Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
URL url = new URL(
"https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");