You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/11/25 16:53:36 UTC

[incubator-hivemall] branch master updated: [HIVEMALL-165] Fixed to accept any primitive

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new afa8b13  [HIVEMALL-165] Fixed to accept any primitive
afa8b13 is described below

commit afa8b133824cd0e9cd3c5dcce6fe7da601dc16a5
Author: Makoto Yui <my...@apache.org>
AuthorDate: Tue Nov 26 01:53:29 2019 +0900

    [HIVEMALL-165] Fixed to accept any primitive
    
    ## What changes were proposed in this pull request?
    
    Fix a bug where the `array_remove` UDF throws an exception when the first argument is null
    
    ## What type of PR is it?
    
    Bug Fix
    
    ## What is the Jira issue?
    
    https://issues.apache.org/jira/browse/HIVEMALL-165
    
    ## How was this patch tested?
    
    manual tests on EMR
    
    ## How to use this feature?
    
    ```sql
    WITH data4 as (
      select false as n, array(2.0, 3.0, 4.0) as nums
      union all
      select true as n, array(2.0, 3.0, 4.0) as nums
    )
    select
      array_remove(if(n = true, null, nums), 2.0) as c1,
      array_remove(if(n = true, null, nums), array(3.0,2.0)) as c2,
      array_remove(if(n = false, null, nums), 2.0) as c3
    from
      data4;
    > c1      c2      c3
    > [3,4]   [4]     NULL
    > NULL    NULL    [3,4]
    
    select array_remove(array(2.0,2.1,3.0,4.0,2.0),2), array_remove(array(2.0,3.0,4.0),array(3,2.0));
    > [2.1,3,4]       [4]
    
    SELECT array_remove(array(1,null,3),null);
    > [1,3]
    
    SELECT array_remove(array(1,null,3,null,5),null);
    > [1,3,5]
    
    SELECT array_remove(array(1,null,3),array(null));
    > [1,3]
    
    SELECT array_remove(array('aaa','bbb'),'bbb');
    > ["aaa"]
    
    SELECT array_remove(array('aaa','bbb','ccc','bbb'), array('bbb','ccc'));
    > ["aaa"]
    
    select array_remove(array(null),null);
    > []
    
    select array_remove(array(null,'bbb'),'aaa');
    > [null,"bbb"]
    ```
    
    ## Checklist
    
    - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
    - [x] Did you run system tests on Hive (or Spark)?
    
    Author: Makoto Yui <my...@apache.org>
    
    Closes #217 from myui/HIVEMALL-165.
---
 .../java/hivemall/tools/array/ArrayRemoveUDF.java  | 130 ++++++++++++++++++---
 .../main/java/hivemall/utils/hadoop/HiveUtils.java |  78 +++++++++++++
 docs/gitbook/misc/generic_funcs.md                 |  26 ++++-
 3 files changed, 213 insertions(+), 21 deletions(-)

diff --git a/core/src/main/java/hivemall/tools/array/ArrayRemoveUDF.java b/core/src/main/java/hivemall/tools/array/ArrayRemoveUDF.java
index 207c398..5dda0ff 100644
--- a/core/src/main/java/hivemall/tools/array/ArrayRemoveUDF.java
+++ b/core/src/main/java/hivemall/tools/array/ArrayRemoveUDF.java
@@ -18,35 +18,131 @@
  */
 package hivemall.tools.array;
 
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.StringUtils;
+
+import java.util.Collections;
 import java.util.List;
 
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
 import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 
+//@formatter:off
 @Description(name = "array_remove",
-        value = "_FUNC_(array<int|text> original, int|text|array<int> target)"
-                + " - Returns an array that the target is removed " + "from the original array",
-        extended = "SELECT array_remove(array(1,null,3),array(null));\n" + " [3]\n" + "\n"
-                + "SELECT array_remove(array(\"aaa\",\"bbb\"),\"bbb\");\n" + " [\"aaa\"]")
+        value = "_FUNC_(array<PRIMITIVE> values, PRIMITIVE|array<PRIMITIVE> target)"
+                + " - Returns an array that the target elements are removed from the original array",
+        extended = "select array_remove(array(2.0,2.1,3.0,4.0,2.0),2), array_remove(array(2.0,3.0,4.0),array(3,2.0));\n" + 
+                "[2.1,3,4]       [4]\n" + 
+                "\n" + 
+                "SELECT array_remove(array(1,null,3),null);\n" + 
+                "[1,3]\n" + 
+                "\n" + 
+                "SELECT array_remove(array(1,null,3,null,5),null);\n" + 
+                "[1,3,5]\n" + 
+                "\n" + 
+                "SELECT array_remove(array(1,null,3),array(null));\n" + 
+                "[1,3]\n" + 
+                "\n" + 
+                "SELECT array_remove(array('aaa','bbb'),'bbb');\n" + 
+                "[\"aaa\"]\n" + 
+                "\n" + 
+                "SELECT array_remove(array('aaa','bbb','ccc','bbb'), array('bbb','ccc'));\n" + 
+                "[\"aaa\"]\n" + 
+                "\n" + 
+                "select array_remove(array(null),null);\n" + 
+                "[]\n" + 
+                "\n" + 
+                "select array_remove(array(null,'bbb'),'aaa');\n" + 
+                "[null,\"bbb\"]")
+//@formatter:on
 @UDFType(deterministic = true, stateful = false)
-public class ArrayRemoveUDF extends UDF {
+public final class ArrayRemoveUDF extends GenericUDF {
+
+    private ListObjectInspector valueListOI;
+    private PrimitiveObjectInspector valueElemOI;
+    private boolean isTargetList;
+    @Nullable
+    private ListObjectInspector targetListOI;
+    private PrimitiveObjectInspector targetElemOI;
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length != 2) {
+            throw new UDFArgumentLengthException("Expected 2 arguments, but got " + argOIs.length);
+        }
+
+        this.valueListOI = HiveUtils.asListOI(argOIs, 0);
+        this.valueElemOI =
+                HiveUtils.asPrimitiveObjectInspector(valueListOI.getListElementObjectInspector());
+
+        if (HiveUtils.isListOI(argOIs[1])) {
+            this.isTargetList = true;
+            this.targetListOI = HiveUtils.asListOI(argOIs, 1);
+            this.targetElemOI = HiveUtils.asPrimitiveObjectInspector(
+                targetListOI.getListElementObjectInspector());
+        } else {
+            this.isTargetList = false;
+            this.targetElemOI = HiveUtils.asPrimitiveObjectInspector(argOIs, 1);
+        }
+
+        return ObjectInspectorFactory.getStandardListObjectInspector(valueElemOI);
+    }
+
+    @Nullable
+    @Override
+    public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
+        assert (arguments.length == 2);
+
+        final List<?> values = HiveUtils.copyListObject(arguments[0], valueListOI);
+        if (values == null) {
+            return null;
+        }
+
+        final Object target = arguments[1].get();
+        if (target == null) {
+            values.removeAll(Collections.singletonList(null));
+            return values;
+        }
+
+        if (isTargetList) {
+            Converter converter = ObjectInspectorConverters.getConverter(targetListOI, valueListOI);
+            removeAll(values, target, converter, valueListOI);
+        } else {
+            Converter converter = ObjectInspectorConverters.getConverter(targetElemOI, valueElemOI);
+            removeAll(values, target, converter);
+        }
+        return values;
+    }
 
-    public List<IntWritable> evaluate(List<IntWritable> original, IntWritable target) {
-        while (original.remove(target));
-        return original;
+    private static void removeAll(@Nonnull final List<?> values, @Nonnull final Object target,
+            @Nonnull final Converter converter, @Nonnull final ListObjectInspector valueListOI) {
+        Object converted = converter.convert(target);
+        List<?> convertedList = valueListOI.getList(converted);
+        values.removeAll(convertedList);
     }
 
-    public List<IntWritable> evaluate(List<IntWritable> original, List<IntWritable> targets) {
-        original.removeAll(targets);
-        return original;
+    private static void removeAll(@Nonnull final List<?> values, @Nonnull final Object target,
+            @Nonnull final Converter converter) {
+        Object converted = converter.convert(target);
+        values.removeAll(Collections.singleton(converted));
     }
 
-    public List<Text> evaluate(List<Text> original, Text target) {
-        while (original.remove(target));
-        return original;
+    @Override
+    public String getDisplayString(String[] children) {
+        return "array_remove(" + StringUtils.join(children, ',') + ')';
     }
 
 }
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 5c485cd..6c02974 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -31,6 +31,7 @@ import static hivemall.HivemallConstants.TINYINT_TYPE_NAME;
 import static hivemall.HivemallConstants.VOID_TYPE_NAME;
 
 import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.BitSet;
 import java.util.Collections;
@@ -66,6 +67,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
 import org.apache.hadoop.hive.serde2.objectinspector.StandardConstantListObjectInspector;
@@ -472,6 +474,82 @@ public final class HiveUtils {
         return (ListTypeInfo) typeInfo;
     }
 
+    public static boolean isSameCategoryGroup(@Nonnull final PrimitiveCategory cat1,
+            @Nonnull final PrimitiveCategory cat2) {
+        if (cat1 == cat2) {
+            return true;
+        }
+
+        switch (cat1) {
+            // integers
+            case BYTE:
+            case SHORT:
+            case INT:
+            case LONG: {
+                switch (cat2) {
+                    case BYTE:
+                    case SHORT:
+                    case INT:
+                    case LONG:
+                        return true;
+                    default:
+                        return false;
+                }
+            }
+            // floating point number
+            case FLOAT:
+            case DOUBLE: {
+                switch (cat2) {
+                    case FLOAT:
+                    case DOUBLE:
+                        return true;
+                    default:
+                        return false;
+                }
+            }
+            // string
+            case STRING:
+            case CHAR:
+            case VARCHAR:
+                switch (cat2) {
+                    case STRING:
+                    case CHAR:
+                    case VARCHAR:
+                        return true;
+                    default:
+                        return false;
+                }
+            default:
+                break;
+        }
+        return false;
+    }
+
+    @Nullable
+    public static ArrayList<Object> copyListObject(@Nonnull final DeferredObject argument,
+            @Nonnull final ListObjectInspector loi) throws HiveException {
+        return copyListObject(argument, loi, ObjectInspectorCopyOption.DEFAULT);
+    }
+
+    @Nullable
+    public static ArrayList<Object> copyListObject(@Nonnull final DeferredObject argument,
+            @Nonnull final ListObjectInspector loi,
+            @Nonnull final ObjectInspectorCopyOption objectInspectorOption) throws HiveException {
+        final Object o = argument.get();
+        if (o == null) {
+            return null;
+        }
+
+        final int length = loi.getListLength(o);
+        final ArrayList<Object> list = new ArrayList<Object>(length);
+        for (int i = 0; i < length; i++) {
+            Object e = ObjectInspectorUtils.copyToStandardObject(loi.getListElement(o, i),
+                loi.getListElementObjectInspector(), objectInspectorOption);
+            list.add(e);
+        }
+        return list;
+    }
+
     public static float getFloat(@Nullable Object o, @Nonnull PrimitiveObjectInspector oi) {
         if (o == null) {
             return 0.f;
diff --git a/docs/gitbook/misc/generic_funcs.md b/docs/gitbook/misc/generic_funcs.md
index 73983f5..fb0b240 100644
--- a/docs/gitbook/misc/generic_funcs.md
+++ b/docs/gitbook/misc/generic_funcs.md
@@ -155,13 +155,31 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
    [3]
   ```
 
-- `array_remove(array<int|text> original, int|text|array<int> target)` - Returns an array that the target is removed from the original array
+- `array_remove(array<PRIMITIVE> values, PRIMITIVE|array<PRIMITIVE> target)` - Returns an array that the target elements are removed from the original array
   ```sql
+  select array_remove(array(2.0,2.1,3.0,4.0,2.0),2), array_remove(array(2.0,3.0,4.0),array(3,2.0));
+  [2.1,3,4]       [4]
+
+  SELECT array_remove(array(1,null,3),null);
+  [1,3]
+
+  SELECT array_remove(array(1,null,3,null,5),null);
+  [1,3,5]
+
   SELECT array_remove(array(1,null,3),array(null));
-   [3]
+  [1,3]
+
+  SELECT array_remove(array('aaa','bbb'),'bbb');
+  ["aaa"]
+
+  SELECT array_remove(array('aaa','bbb','ccc','bbb'), array('bbb','ccc'));
+  ["aaa"]
+
+  select array_remove(array(null),null);
+  []
 
-  SELECT array_remove(array("aaa","bbb"),"bbb");
-   ["aaa"]
+  select array_remove(array(null,'bbb'),'aaa');
+  [null,"bbb"]
   ```
 
 - `array_slice(array<ANY> values, int offset [, int length])` - Slices the given array by the given offset and length parameters.