You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/07/18 07:51:42 UTC

[incubator-hivemall] branch master updated: Add test of sparse input for randomforest classifier

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new 4ef09e4  Add test of sparse input for randomforest classifier
4ef09e4 is described below

commit 4ef09e4a768d8a4ecd01bf0a864006b895a634ec
Author: Makoto Yui <my...@apache.org>
AuthorDate: Thu Jul 18 16:51:33 2019 +0900

    Add test of sparse input for randomforest classifier
---
 .../RandomForestClassifierUDTFTest.java            | 56 ++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
index aa839fa..f7b0285 100644
--- a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
@@ -464,6 +464,62 @@ public class RandomForestClassifierUDTFTest {
     }
 
     @Test
+    public void testSparseRandomForestClassifier() throws HiveException {
+        RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector});
+        // each row: sparse "index:value" string features with binary values, then an int label
+        udtf.process(new Object[] {new String[] {"1:1.0", "4:1.0", "7:1.0", "12:1.0"}, 1}); // 0
+        udtf.process(new Object[] {new String[] {"2:1.0", "4:1.0", "5:1.0", "11:1.0"}, 1}); // 1
+        udtf.process(new Object[] {
+                new String[] {"1:1.0", "4:1.0", "7:1.0", "113:1.0", "497:1.0", "635:1.0"}, 0}); // 2
+        udtf.process(new Object[] {
+                new String[] {"1:1.0", "4:1.0", "5:1.0", "7:1.0", "10:1.0", "14:1.0"}, 1}); // 3
+        udtf.process(new Object[] {new String[] {"1:1.0", "2:1.0", "4:1.0", "7:1.0", "8:1.0"}, 1}); // 4
+        udtf.process(new Object[] {new String[] {"13:1.0", "18:1.0", "25:1.0", "27:1.0", "65:1.0",
+                "116:1.0", "200:1.0", "468:1.0", "585:1.0", "715:1.0"}, 0}); // 5
+
+        final int[] forwarded = {0}; // counts rows forwarded to the collector (checked after close)
+        udtf.setCollector(new Collector() {
+            @Override
+            public void collect(Object input) throws HiveException { forwarded[0]++; }
+        });
+        udtf.close();
+        org.junit.Assert.assertTrue("no model was forwarded", forwarded[0] > 0);
+    }
+
+    @Test
+    public void testSparseRandomForestClassifierL2Normalized() throws HiveException {
+        RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector});
+        // each value is ~1/sqrt(non-zeros in its row), so every row has ~unit L2 norm
+        udtf.process(new Object[] {new String[] {"1:0.5", "4:0.5", "7:0.5", "12:0.5"}, 1}); // 0
+        udtf.process(new Object[] {new String[] {"2:0.5", "4:0.5", "5:0.5", "11:0.5"}, 1}); // 1
+        udtf.process(new Object[] {new String[] {"1:0.40824828", "4:0.40824828", "7:0.40824828",
+                "113:0.40824828", "497:0.40824828", "635:0.40824828"}, 0}); // 2
+        udtf.process(new Object[] {new String[] {"1:0.40824828", "4:0.40824828", "5:0.40824828",
+                "7:0.40824828", "10:0.40824828", "14:0.40824828"}, 1}); // 3
+        udtf.process(new Object[] {new String[] {"1:0.4472136", "2:0.4472136", "4:0.4472136",
+                "7:0.4472136", "8:0.4472136"}, 1}); // 4
+        udtf.process(new Object[] {new String[] {"13:0.31622776", "18:0.31622776", "25:0.31622776",
+                "27:0.31622776", "65:0.31622776", "116:0.31622776", "200:0.31622776",
+                "468:0.31622776", "585:0.31622776", "715:0.31622776"}, 0}); // 5
+
+        final int[] forwarded = {0}; // counts rows forwarded to the collector (checked after close)
+        udtf.setCollector(new Collector() {
+            @Override
+            public void collect(Object input) throws HiveException { forwarded[0]++; }
+        });
+        udtf.close();
+        org.junit.Assert.assertTrue("no model was forwarded", forwarded[0] > 0);
+    }
+
+    @Test
     public void testSerialization() throws HiveException, IOException, ParseException {
         URL url = new URL(
             "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");