You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2020/08/06 07:05:46 UTC

[incubator-hivemall] branch master updated: [HIVEMALL-297] Fixed null element handling in feature vector

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new dceff6b  [HIVEMALL-297] Fixed null element handling in feature vector
dceff6b is described below

commit dceff6b8bc6c1e28016c75200e7a4cc1adf9cfa5
Author: Makoto Yui <my...@apache.org>
AuthorDate: Thu Aug 6 16:05:37 2020 +0900

    [HIVEMALL-297] Fixed null element handling in feature vector
    
    ## What changes were proposed in this pull request?
    
    Fixed null element handling in feature vector
    
    ## What type of PR is it?
    
    Bug Fix
    
    ## What is the Jira issue?
    
    https://issues.apache.org/jira/browse/HIVEMALL-297
    
    ## How was this patch tested?
    
    unit tests
    
    ## Checklist
    
    (Please remove this section if not needed; check `x` for YES, blank for NO)
    
    - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
    - [ ] Did you run system tests on Hive (or Spark)?
    
    Author: Makoto Yui <my...@apache.org>
    
    Closes #231 from myui/HIVEMALL-297.
---
 .../main/java/hivemall/GeneralLearnerBaseUDTF.java |  8 ++--
 .../utils/collections/CollectionUtils.java         | 39 ++++++++++++++++
 .../java/hivemall/GeneralLearnerBaseUDTFTest.java  | 52 ++++++++++++++++++++++
 3 files changed, 96 insertions(+), 3 deletions(-)

diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
index 8400a2c..c8675e5 100644
--- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
@@ -18,6 +18,8 @@
  */
 package hivemall;
 
+import static hivemall.utils.collections.CollectionUtils.countNonNulls;
+
 import hivemall.annotations.VisibleForTesting;
 import hivemall.common.ConversionState;
 import hivemall.model.FeatureValue;
@@ -418,8 +420,8 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
         }
 
         final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();
-        final FeatureValue[] featureVector = new FeatureValue[size];
-        for (int i = 0; i < size; i++) {
+        final FeatureValue[] featureVector = new FeatureValue[countNonNulls(features)];
+        for (int i = 0, j = 0; i < size; i++) {
             Object f = features.get(i);
             if (f == null) {
                 continue;
@@ -433,7 +435,7 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
                     ObjectInspectorCopyOption.JAVA); // should be Integer or Long
                 fv = new FeatureValue(k, 1.f);
             }
-            featureVector[i] = fv;
+            featureVector[j++] = fv;
         }
         return featureVector;
     }
diff --git a/core/src/main/java/hivemall/utils/collections/CollectionUtils.java b/core/src/main/java/hivemall/utils/collections/CollectionUtils.java
new file mode 100644
index 0000000..9a6ae06
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/CollectionUtils.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections;
+
+import java.util.Collection;
+
+import javax.annotation.Nonnull;
+
+public final class CollectionUtils { // holder for static collection helpers; declared final so it cannot be subclassed
+
+    private CollectionUtils() {} // private constructor: the class cannot be instantiated from outside
+
+    public static int countNonNulls(@Nonnull final Collection<?> col) { // returns the number of elements in col that are not null
+        int cnt = 0;
+        for (Object e : col) { // single pass over every element of the collection
+            if (e != null) {
+                cnt++;
+            }
+        }
+        return cnt; // 0 when col is empty or contains only nulls
+    }
+
+}
diff --git a/core/src/test/java/hivemall/GeneralLearnerBaseUDTFTest.java b/core/src/test/java/hivemall/GeneralLearnerBaseUDTFTest.java
new file mode 100644
index 0000000..e18257e
--- /dev/null
+++ b/core/src/test/java/hivemall/GeneralLearnerBaseUDTFTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall;
+
+import hivemall.classifier.GeneralClassifierUDTF;
+import hivemall.model.FeatureValue;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class GeneralLearnerBaseUDTFTest {
+
+    @Test
+    public void testNullFeature() throws UDFArgumentException { // checks that null entries in the input list are left out of the parsed feature array
+        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
+        udtf.initialize(new ObjectInspector[] { // two argument inspectors: a list of Java strings, then a Java int
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector});
+
+        List<String> features = Arrays.asList("1", "2", null, "3", null); // three non-null entries mixed with two nulls
+
+        FeatureValue[] array = udtf.parseFeatures(features);
+        Assert.assertEquals(3, array.length); // result length equals the number of non-null inputs
+
+        Assert.assertEquals(0, udtf.parseFeatures(Arrays.asList(new String[] {null})).length); // a list holding only a single null parses to an empty array
+    }
+
+}