You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 08:01:57 UTC

[07/50] [abbrv] incubator-hivemall git commit: add array_top_k_indices

add array_top_k_indices



Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e9d1a94f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e9d1a94f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e9d1a94f

Branch: refs/heads/JIRA-22/pr-385
Commit: e9d1a94f29f31e2910a54add7c2625825d715318
Parents: 7b07e4a
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 20 16:55:57 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 20 18:37:38 2016 +0900

----------------------------------------------------------------------
 .../tools/array/ArrayTopKIndicesUDF.java        | 96 ++++++++++++++++++++
 1 file changed, 96 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9d1a94f/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java b/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
new file mode 100644
index 0000000..bf9fe15
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
@@ -0,0 +1,96 @@
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.IntWritable;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+
+/** Hive UDF that returns the 0-based indices of the top-k (largest) elements of a numeric array. */
+@Description(name = "array_top_k_indices",
+        value = "_FUNC_(array<number> array, const int k) - Returns indices array of top-k as array<int>")
+public class ArrayTopKIndicesUDF extends GenericUDF {
+    private ListObjectInspector arrayOI;
+    private PrimitiveObjectInspector elementOI;
+    private PrimitiveObjectInspector kOI;
+
+    /** Checks that the arguments are (array<number>, int) and declares a list-of-int result. */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+        if (OIs.length != 2) {
+            throw new UDFArgumentLengthException("Specify exactly two arguments.");
+        }
+        if (!HiveUtils.isNumberListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
+                    + OIs[0].getTypeName() + " was passed as `array`");
+        }
+        if (!HiveUtils.isIntegerOI(OIs[1])) {
+            throw new UDFArgumentTypeException(1, "Only int type argument is acceptable but "
+                    + OIs[1].getTypeName() + " was passed as `k`");
+        }
+
+        arrayOI = HiveUtils.asListOI(OIs[0]);
+        elementOI = HiveUtils.asDoubleCompatibleOI(arrayOI.getListElementObjectInspector());
+        kOI = HiveUtils.asIntegerOI(OIs[1]);
+
+        return ObjectInspectorFactory.getStandardListObjectInspector(
+                PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+    }
+
+    /**
+     * Returns the indices of the {@code k} largest elements of the input array, largest first.
+     * Order follows {@link Double#compare} (NaN ranks highest); the sort is stable, so ties keep
+     * ascending index order. {@code k} above the array length is rejected; negative {@code k} gives [].
+     */
+    @Override
+    public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
+        final double[] array = HiveUtils.asDoubleArray(dObj[0].get(), arrayOI, elementOI);
+        final int k = PrimitiveObjectInspectorUtils.getInt(dObj[1].get(), kOI);
+        Preconditions.checkNotNull(array);
+        Preconditions.checkArgument(array.length >= k);
+
+        List<Map.Entry<Integer, Double>> list = new ArrayList<Map.Entry<Integer, Double>>();
+        for (int i = 0; i < array.length; i++) {
+            list.add(new AbstractMap.SimpleEntry<Integer, Double>(i, array[i]));
+        }
+        list.sort(new Comparator<Map.Entry<Integer, Double>>() {
+            @Override
+            public int compare(Map.Entry<Integer, Double> o1, Map.Entry<Integer, Double> o2) {
+                // Reversed operands give descending order; ties yield 0 as Comparator requires.
+                return Double.compare(o2.getValue(), o1.getValue());
+            }
+        });
+
+        List<IntWritable> result = new ArrayList<IntWritable>();
+        for (int i = 0; i < k; i++) {
+            result.add(new IntWritable(list.get(i).getKey()));
+        }
+        return result;
+    }
+
+    /** Renders the call as {@code array_top_k_indices(child0, child1, ...)}. */
+    @Override
+    public String getDisplayString(String[] children) {
+        final StringBuilder sb = new StringBuilder("array_top_k_indices(");
+        for (int i = 0; i < children.length; i++) {
+            sb.append(i == 0 ? "" : ", ").append(children[i]);
+        }
+        return sb.append(")").toString();
+    }
+}