You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 07:04:10 UTC
[08/50] [abbrv] incubator-hivemall git commit: add array_top_k_indices
add array_top_k_indices
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e9d1a94f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e9d1a94f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e9d1a94f
Branch: refs/heads/JIRA-22/pr-385
Commit: e9d1a94f29f31e2910a54add7c2625825d715318
Parents: 7b07e4a
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 20 16:55:57 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 20 18:37:38 2016 +0900
----------------------------------------------------------------------
.../tools/array/ArrayTopKIndicesUDF.java | 96 ++++++++++++++++++++
1 file changed, 96 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9d1a94f/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java b/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
new file mode 100644
index 0000000..bf9fe15
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
@@ -0,0 +1,96 @@
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.IntWritable;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Hive UDF {@code array_top_k_indices(array<number> array, const int k)}.
+ *
+ * <p>Returns, as {@code array<int>}, the 0-based positions of the {@code k} largest elements of
+ * {@code array}, ordered from the largest value to the smallest. Elements that compare equal keep
+ * their original relative order, so the result is deterministic.
+ */
+@Description(name = "array_top_k_indices",
+        value = "_FUNC_(array<number> array, const int k) - Returns indices array of top-k as array<int>")
+public class ArrayTopKIndicesUDF extends GenericUDF {
+    /** Inspector for the first argument (the list itself); set in {@link #initialize}. */
+    private ListObjectInspector arrayOI;
+    /** Inspector for the elements of the first argument; set in {@link #initialize}. */
+    private PrimitiveObjectInspector elementOI;
+    /** Inspector for the second argument ({@code k}); set in {@link #initialize}. */
+    private PrimitiveObjectInspector kOI;
+
+    /**
+     * Validates the argument count and types and caches the object inspectors used by
+     * {@link #evaluate}.
+     *
+     * @param OIs inspectors for {@code array} and {@code k}, in that order
+     * @return a standard list inspector over writable ints, i.e. the {@code array<int>} result
+     * @throws UDFArgumentLengthException if the number of arguments is not exactly two
+     * @throws UDFArgumentTypeException if {@code array} is not a list of numbers or {@code k} is
+     *         not an int
+     */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+        if (OIs.length != 2) {
+            // The function takes exactly two arguments; the message must say so.
+            throw new UDFArgumentLengthException("Specify two arguments.");
+        }
+
+        if (!HiveUtils.isNumberListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
+                    + OIs[0].getTypeName() + " was passed as `array`");
+        }
+        if (!HiveUtils.isIntegerOI(OIs[1])) {
+            throw new UDFArgumentTypeException(1, "Only int type argument is acceptable but "
+                    + OIs[1].getTypeName() + " was passed as `k`");
+        }
+
+        arrayOI = HiveUtils.asListOI(OIs[0]);
+        elementOI = HiveUtils.asDoubleCompatibleOI(arrayOI.getListElementObjectInspector());
+        kOI = HiveUtils.asIntegerOI(OIs[1]);
+
+        return ObjectInspectorFactory.getStandardListObjectInspector(
+            PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+    }
+
+    /**
+     * Computes the indices of the {@code k} largest values of the input array.
+     *
+     * <p>The array must be non-null and contain at least {@code k} elements; both conditions are
+     * enforced through {@code Preconditions} (NOTE(review): the exact exception types depend on
+     * Hivemall's {@code Preconditions} implementation - confirm). A non-positive {@code k} yields
+     * an empty list.
+     *
+     * @param dObj deferred values of {@code array} and {@code k}, in that order
+     * @return a {@code List<IntWritable>} holding {@code k} 0-based indices, largest value first
+     * @throws HiveException if a deferred argument cannot be evaluated
+     */
+    @Override
+    public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
+        // NOTE(review): asDoubleArray presumably yields null for a NULL list - confirm.
+        final double[] array = HiveUtils.asDoubleArray(dObj[0].get(), arrayOI, elementOI);
+        final int k = PrimitiveObjectInspectorUtils.getInt(dObj[1].get(), kOI);
+
+        Preconditions.checkNotNull(array);
+        Preconditions.checkArgument(array.length >= k);
+
+        // Pair every value with its original position so the position survives the sort.
+        List<Map.Entry<Integer, Double>> list =
+                new ArrayList<Map.Entry<Integer, Double>>(array.length);
+        for (int i = 0; i < array.length; i++) {
+            list.add(new AbstractMap.SimpleEntry<Integer, Double>(i, array[i]));
+        }
+        list.sort(new Comparator<Map.Entry<Integer, Double>>() {
+            @Override
+            public int compare(Map.Entry<Integer, Double> o1, Map.Entry<Integer, Double> o2) {
+                // Descending by value. Double.compare returns 0 for equal values and gives NaN
+                // a consistent rank (greatest), so the Comparator contract holds; a comparator
+                // that never returns 0 can make the sort fail or order ties arbitrarily.
+                return Double.compare(o2.getValue(), o1.getValue());
+            }
+        });
+
+        List<IntWritable> result = new ArrayList<IntWritable>();
+        for (int i = 0; i < k; i++) {
+            result.add(new IntWritable(list.get(i).getKey()));
+        }
+        return result;
+    }
+
+    /**
+     * Renders the call for query plans as {@code array_top_k_indices(arg0, arg1, ...)}.
+     *
+     * @param children display strings of the argument expressions
+     * @return the function name followed by the comma-separated arguments in parentheses
+     */
+    @Override
+    public String getDisplayString(String[] children) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("array_top_k_indices");
+        sb.append("(");
+        if (children.length > 0) {
+            sb.append(children[0]);
+            for (int i = 1; i < children.length; i++) {
+                sb.append(", ");
+                sb.append(children[i]);
+            }
+        }
+        sb.append(")");
+        return sb.toString();
+    }
+}