You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/11/25 16:53:36 UTC
[incubator-hivemall] branch master updated: [HIVEMALL-165] Fixed to accept any primitive
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new afa8b13 [HIVEMALL-165] Fixed to accept any primitive
afa8b13 is described below
commit afa8b133824cd0e9cd3c5dcce6fe7da601dc16a5
Author: Makoto Yui <my...@apache.org>
AuthorDate: Tue Nov 26 01:53:29 2019 +0900
[HIVEMALL-165] Fixed to accept any primitive
## What changes were proposed in this pull request?
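Fix a bug where the `array_remove` UDF throws an exception when its first argument is null. The UDF is also generalized to accept arrays of any primitive element type (previously only `int` and `string` were supported).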
## What type of PR is it?
Bug Fix
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-165
## How was this patch tested?
Manually tested on EMR.
## How to use this feature?
```sql
WITH data4 as (
select false as n, array(2.0, 3.0, 4.0) as nums
union all
select true as n, array(2.0, 3.0, 4.0) as nums
)
select
array_remove(if(n = true, null, nums), 2.0) as c1,
array_remove(if(n = true, null, nums), array(3.0,2.0)) as c2,
array_remove(if(n = false, null, nums), 2.0) as c3
from
data4;
> c1 c2 c3
> [3,4] [4] NULL
> NULL NULL [3,4]
select array_remove(array(2.0,2.1,3.0,4.0,2.0),2), array_remove(array(2.0,3.0,4.0),array(3,2.0));
> [2.1,3,4] [4]
SELECT array_remove(array(1,null,3),null);
> [1,3]
SELECT array_remove(array(1,null,3,null,5),null);
> [1,3,5]
SELECT array_remove(array(1,null,3),array(null));
> [1,3]
SELECT array_remove(array('aaa','bbb'),'bbb');
> ["aaa"]
SELECT array_remove(array('aaa','bbb','ccc','bbb'), array('bbb','ccc'));
> ["aaa"]
select array_remove(array(null),null);
> []
select array_remove(array(null,'bbb'),'aaa');
> [null,"bbb"]
```
## Checklist
- [x] Did you apply the source code formatter, i.e., `./bin/format_code.sh`, to your commit?
- [x] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <my...@apache.org>
Closes #217 from myui/HIVEMALL-165.
---
.../java/hivemall/tools/array/ArrayRemoveUDF.java | 130 ++++++++++++++++++---
.../main/java/hivemall/utils/hadoop/HiveUtils.java | 78 +++++++++++++
docs/gitbook/misc/generic_funcs.md | 26 ++++-
3 files changed, 213 insertions(+), 21 deletions(-)
diff --git a/core/src/main/java/hivemall/tools/array/ArrayRemoveUDF.java b/core/src/main/java/hivemall/tools/array/ArrayRemoveUDF.java
index 207c398..5dda0ff 100644
--- a/core/src/main/java/hivemall/tools/array/ArrayRemoveUDF.java
+++ b/core/src/main/java/hivemall/tools/array/ArrayRemoveUDF.java
@@ -18,35 +18,131 @@
*/
package hivemall.tools.array;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.StringUtils;
+
+import java.util.Collections;
import java.util.List;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+//@formatter:off
@Description(name = "array_remove",
- value = "_FUNC_(array<int|text> original, int|text|array<int> target)"
- + " - Returns an array that the target is removed " + "from the original array",
- extended = "SELECT array_remove(array(1,null,3),array(null));\n" + " [3]\n" + "\n"
- + "SELECT array_remove(array(\"aaa\",\"bbb\"),\"bbb\");\n" + " [\"aaa\"]")
+ value = "_FUNC_(array<PRIMITIVE> values, PRIMITIVE|array<PRIMITIVE> target)"
+ + " - Returns an array that the target elements are removed from the original array",
+ extended = "select array_remove(array(2.0,2.1,3.0,4.0,2.0),2), array_remove(array(2.0,3.0,4.0),array(3,2.0));\n" +
+ "[2.1,3,4] [4]\n" +
+ "\n" +
+ "SELECT array_remove(array(1,null,3),null);\n" +
+ "[1,3]\n" +
+ "\n" +
+ "SELECT array_remove(array(1,null,3,null,5),null);\n" +
+ "[1,3,5]\n" +
+ "\n" +
+ "SELECT array_remove(array(1,null,3),array(null));\n" +
+ "[1,3]\n" +
+ "\n" +
+ "SELECT array_remove(array('aaa','bbb'),'bbb');\n" +
+ "[\"aaa\"]\n" +
+ "\n" +
+ "SELECT array_remove(array('aaa','bbb','ccc','bbb'), array('bbb','ccc'));\n" +
+ "[\"aaa\"]\n" +
+ "\n" +
+ "select array_remove(array(null),null);\n" +
+ "[]\n" +
+ "\n" +
+ "select array_remove(array(null,'bbb'),'aaa');\n" +
+ "[null,\"bbb\"]")
+//@formatter:on
@UDFType(deterministic = true, stateful = false)
-public class ArrayRemoveUDF extends UDF {
+public final class ArrayRemoveUDF extends GenericUDF {
+
+ private ListObjectInspector valueListOI;
+ private PrimitiveObjectInspector valueElemOI;
+ private boolean isTargetList;
+ @Nullable
+ private ListObjectInspector targetListOI;
+ private PrimitiveObjectInspector targetElemOI;
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ if (argOIs.length != 2) {
+ throw new UDFArgumentLengthException("Expected 2 arguments, but got " + argOIs.length);
+ }
+
+ this.valueListOI = HiveUtils.asListOI(argOIs, 0);
+ this.valueElemOI =
+ HiveUtils.asPrimitiveObjectInspector(valueListOI.getListElementObjectInspector());
+
+ if (HiveUtils.isListOI(argOIs[1])) {
+ this.isTargetList = true;
+ this.targetListOI = HiveUtils.asListOI(argOIs, 1);
+ this.targetElemOI = HiveUtils.asPrimitiveObjectInspector(
+ targetListOI.getListElementObjectInspector());
+ } else {
+ this.isTargetList = false;
+ this.targetElemOI = HiveUtils.asPrimitiveObjectInspector(argOIs, 1);
+ }
+
+ return ObjectInspectorFactory.getStandardListObjectInspector(valueElemOI);
+ }
+
+ @Nullable
+ @Override
+ public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
+ assert (arguments.length == 2);
+
+ final List<?> values = HiveUtils.copyListObject(arguments[0], valueListOI);
+ if (values == null) {
+ return null;
+ }
+
+ final Object target = arguments[1].get();
+ if (target == null) {
+ values.removeAll(Collections.singletonList(null));
+ return values;
+ }
+
+ if (isTargetList) {
+ Converter converter = ObjectInspectorConverters.getConverter(targetListOI, valueListOI);
+ removeAll(values, target, converter, valueListOI);
+ } else {
+ Converter converter = ObjectInspectorConverters.getConverter(targetElemOI, valueElemOI);
+ removeAll(values, target, converter);
+ }
+ return values;
+ }
- public List<IntWritable> evaluate(List<IntWritable> original, IntWritable target) {
- while (original.remove(target));
- return original;
+ private static void removeAll(@Nonnull final List<?> values, @Nonnull final Object target,
+ @Nonnull final Converter converter, @Nonnull final ListObjectInspector valueListOI) {
+ Object converted = converter.convert(target);
+ List<?> convertedList = valueListOI.getList(converted);
+ values.removeAll(convertedList);
}
- public List<IntWritable> evaluate(List<IntWritable> original, List<IntWritable> targets) {
- original.removeAll(targets);
- return original;
+ private static void removeAll(@Nonnull final List<?> values, @Nonnull final Object target,
+ @Nonnull final Converter converter) {
+ Object converted = converter.convert(target);
+ values.removeAll(Collections.singleton(converted));
}
- public List<Text> evaluate(List<Text> original, Text target) {
- while (original.remove(target));
- return original;
+ @Override
+ public String getDisplayString(String[] children) {
+ return "array_remove(" + StringUtils.join(children, ',') + ')';
}
}
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 5c485cd..6c02974 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -31,6 +31,7 @@ import static hivemall.HivemallConstants.TINYINT_TYPE_NAME;
import static hivemall.HivemallConstants.VOID_TYPE_NAME;
import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
@@ -66,6 +67,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
import org.apache.hadoop.hive.serde2.objectinspector.StandardConstantListObjectInspector;
@@ -472,6 +474,82 @@ public final class HiveUtils {
return (ListTypeInfo) typeInfo;
}
+ public static boolean isSameCategoryGroup(@Nonnull final PrimitiveCategory cat1,
+ @Nonnull final PrimitiveCategory cat2) {
+ if (cat1 == cat2) {
+ return true;
+ }
+
+ switch (cat1) {
+ // integers
+ case BYTE:
+ case SHORT:
+ case INT:
+ case LONG: {
+ switch (cat2) {
+ case BYTE:
+ case SHORT:
+ case INT:
+ case LONG:
+ return true;
+ default:
+ return false;
+ }
+ }
+ // floating point number
+ case FLOAT:
+ case DOUBLE: {
+ switch (cat2) {
+ case FLOAT:
+ case DOUBLE:
+ return true;
+ default:
+ return false;
+ }
+ }
+ // string
+ case STRING:
+ case CHAR:
+ case VARCHAR:
+ switch (cat2) {
+ case STRING:
+ case CHAR:
+ case VARCHAR:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ break;
+ }
+ return false;
+ }
+
+ @Nullable
+ public static ArrayList<Object> copyListObject(@Nonnull final DeferredObject argument,
+ @Nonnull final ListObjectInspector loi) throws HiveException {
+ return copyListObject(argument, loi, ObjectInspectorCopyOption.DEFAULT);
+ }
+
+ @Nullable
+ public static ArrayList<Object> copyListObject(@Nonnull final DeferredObject argument,
+ @Nonnull final ListObjectInspector loi,
+ @Nonnull final ObjectInspectorCopyOption objectInspectorOption) throws HiveException {
+ final Object o = argument.get();
+ if (o == null) {
+ return null;
+ }
+
+ final int length = loi.getListLength(o);
+ final ArrayList<Object> list = new ArrayList<Object>(length);
+ for (int i = 0; i < length; i++) {
+ Object e = ObjectInspectorUtils.copyToStandardObject(loi.getListElement(o, i),
+ loi.getListElementObjectInspector(), objectInspectorOption);
+ list.add(e);
+ }
+ return list;
+ }
+
public static float getFloat(@Nullable Object o, @Nonnull PrimitiveObjectInspector oi) {
if (o == null) {
return 0.f;
diff --git a/docs/gitbook/misc/generic_funcs.md b/docs/gitbook/misc/generic_funcs.md
index 73983f5..fb0b240 100644
--- a/docs/gitbook/misc/generic_funcs.md
+++ b/docs/gitbook/misc/generic_funcs.md
@@ -155,13 +155,31 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
[3]
```
-- `array_remove(array<int|text> original, int|text|array<int> target)` - Returns an array that the target is removed from the original array
+- `array_remove(array<PRIMITIVE> values, PRIMITIVE|array<PRIMITIVE> target)` - Returns an array that the target elements are removed from the original array
```sql
+ select array_remove(array(2.0,2.1,3.0,4.0,2.0),2), array_remove(array(2.0,3.0,4.0),array(3,2.0));
+ [2.1,3,4] [4]
+
+ SELECT array_remove(array(1,null,3),null);
+ [1,3]
+
+ SELECT array_remove(array(1,null,3,null,5),null);
+ [1,3,5]
+
SELECT array_remove(array(1,null,3),array(null));
- [3]
+ [1,3]
+
+ SELECT array_remove(array('aaa','bbb'),'bbb');
+ ["aaa"]
+
+ SELECT array_remove(array('aaa','bbb','ccc','bbb'), array('bbb','ccc'));
+ ["aaa"]
+
+ select array_remove(array(null),null);
+ []
- SELECT array_remove(array("aaa","bbb"),"bbb");
- ["aaa"]
+ select array_remove(array(null,'bbb'),'aaa');
+ [null,"bbb"]
```
- `array_slice(array<ANY> values, int offset [, int length])` - Slices the given array by the given offset and length parameters.