You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/07/26 05:49:00 UTC

[2/8] incubator-hivemall git commit: Create FFMPredictGenericUDAF

Create FFMPredictGenericUDAF


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/343f704d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/343f704d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/343f704d

Branch: refs/heads/HIVEMALL-24-2
Commit: 343f704dc4eeee25f195e298ee014368db1fef9e
Parents: 591e3b0
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Fri Mar 3 17:48:26 2017 +0900
Committer: Takuya Kitazawa <k....@gmail.com>
Committed: Fri Mar 3 17:48:26 2017 +0900

----------------------------------------------------------------------
 .../java/hivemall/fm/FFMPredictGenericUDAF.java | 272 +++++++++++++++++++
 1 file changed, 272 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/343f704d/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
new file mode 100644
index 0000000..91d1b6b
--- /dev/null
+++ b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.fm;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
+import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Aggregate function that sums the terms of a field-aware factorization machine (FFM) prediction.
+ * <p>
+ * Each input row contributes at most one term, selected by which of Xi / Xj are null:
+ * <ul>
+ * <li>Xi and Xj both null: Wi is added as the global bias (w0).</li>
+ * <li>only Xj null: Wi * Xi is added (linear term).</li>
+ * <li>Xi and Xj both non-null: &lt;Vifj, Vjfi&gt; * Xi * Xj is added (pairwise interaction).</li>
+ * </ul>
+ * Rows whose Wi is null are skipped entirely (including interaction rows). A row with a null Xi
+ * but a non-null Xj is rejected.
+ */
+@Description(
+        name = "ffm_predict",
+        value = "_FUNC_(float Wi, array<float> Vifj, array<float> Vjfi, float Xi, float Xj)"
+                + " - Returns a prediction value in Double")
+public final class FFMPredictGenericUDAF extends AbstractGenericUDAFResolver {
+
+    private FFMPredictGenericUDAF() {}
+
+    /**
+     * Validates the five arguments {@code (Wi, Vifj, Vjfi, Xi, Xj)} and returns an evaluator.
+     *
+     * @param typeInfo argument types in call order
+     * @return a new {@link Evaluator}
+     * @throws UDFArgumentLengthException if the number of arguments is not 5
+     * @throws UDFArgumentTypeException if an argument is not of the expected number/list type
+     */
+    @Override
+    public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
+        if (typeInfo.length != 5) {
+            throw new UDFArgumentLengthException(
+                "Expected argument length is 5 but given argument length was " + typeInfo.length);
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Number type is expected for the first argument Wi: " + typeInfo[0].getTypeName());
+        }
+        if (typeInfo[1].getCategory() != Category.LIST) {
+            throw new UDFArgumentTypeException(1,
+                "List type is expected for the second argument Vifj: " + typeInfo[1].getTypeName());
+        }
+        if (typeInfo[2].getCategory() != Category.LIST) {
+            throw new UDFArgumentTypeException(2,
+                "List type is expected for the third argument Vjfi: " + typeInfo[2].getTypeName());
+        }
+        ListTypeInfo typeInfo1 = (ListTypeInfo) typeInfo[1];
+        if (!HiveUtils.isNumberTypeInfo(typeInfo1.getListElementTypeInfo())) {
+            throw new UDFArgumentTypeException(1,
+                "Number type is expected for the element type of list Vifj: "
+                        + typeInfo1.getTypeName());
+        }
+        ListTypeInfo typeInfo2 = (ListTypeInfo) typeInfo[2];
+        if (!HiveUtils.isNumberTypeInfo(typeInfo2.getListElementTypeInfo())) {
+            throw new UDFArgumentTypeException(2,
+                "Number type is expected for the element type of list Vjfi: "
+                        + typeInfo2.getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) {
+            throw new UDFArgumentTypeException(3,
+                "Number type is expected for the fourth argument Xi: " + typeInfo[3].getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[4])) {
+            throw new UDFArgumentTypeException(4,
+                "Number type is expected for the fifth argument Xj: " + typeInfo[4].getTypeName());
+        }
+        return new Evaluator();
+    }
+
+    /**
+     * Evaluator that accumulates the FFM terms into a single double sum. The partial result
+     * exchanged between PARTIAL and FINAL stages is a struct with one double field {@code sum}.
+     */
+    public static class Evaluator extends GenericUDAFEvaluator {
+
+        // input OI (PARTIAL1 / COMPLETE)
+        private PrimitiveObjectInspector wiOI;
+        private ListObjectInspector vijOI;
+        private PrimitiveObjectInspector vijElemOI;
+        private ListObjectInspector vjiOI;
+        private PrimitiveObjectInspector vjiElemOI;
+        private PrimitiveObjectInspector xiOI;
+        private PrimitiveObjectInspector xjOI;
+
+        // merge OI (PARTIAL2 / FINAL)
+        private StructObjectInspector internalMergeOI;
+        private StructField sumField;
+        private PrimitiveObjectInspector sumOI;
+
+        public Evaluator() {}
+
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+            super.init(mode, parameters);
+
+            // initialize input
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+                assert (parameters.length == 5);
+                this.wiOI = HiveUtils.asDoubleCompatibleOI(parameters[0]);
+                this.vijOI = HiveUtils.asListOI(parameters[1]);
+                this.vijElemOI = HiveUtils.asDoubleCompatibleOI(vijOI.getListElementObjectInspector());
+                this.vjiOI = HiveUtils.asListOI(parameters[2]);
+                this.vjiElemOI = HiveUtils.asDoubleCompatibleOI(vjiOI.getListElementObjectInspector());
+                this.xiOI = HiveUtils.asDoubleCompatibleOI(parameters[3]);
+                this.xjOI = HiveUtils.asDoubleCompatibleOI(parameters[4]);
+            } else {// from partial aggregation: a single struct argument
+                assert (parameters.length == 1);
+                StructObjectInspector soi = (StructObjectInspector) parameters[0];
+                this.internalMergeOI = soi;
+                this.sumField = soi.getStructFieldRef("sum");
+                // read the sum through the field's own OI; the partial may not be a DoubleWritable
+                this.sumOI = HiveUtils.asDoubleCompatibleOI(sumField.getFieldObjectInspector());
+            }
+
+            // initialize output
+            final ObjectInspector outputOI;
+            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+                outputOI = internalMergeOI();
+            } else {
+                outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+            }
+            return outputOI;
+        }
+
+        /** Builds the OI of the partial result: {@code struct<sum:double>}. */
+        private static StructObjectInspector internalMergeOI() {
+            ArrayList<String> fieldNames = new ArrayList<String>();
+            ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+            fieldNames.add("sum");
+            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+        }
+
+        @Override
+        public FFMPredictAggregationBuffer getNewAggregationBuffer() throws HiveException {
+            FFMPredictAggregationBuffer myAggr = new FFMPredictAggregationBuffer();
+            reset(myAggr);
+            return myAggr;
+        }
+
+        @Override
+        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+            myAggr.reset();
+        }
+
+        /**
+         * Adds one row's term to the buffer. See the class comment for how the term is chosen.
+         *
+         * @throws UDFArgumentException if Xi is null while Xj is not, or if Vifj/Vjfi is null or
+         *         contains a null element in an interaction row
+         * @throws HiveException if Vifj and Vjfi have different lengths
+         */
+        @Override
+        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+                Object[] parameters) throws HiveException {
+            if (parameters[0] == null) {
+                return;
+            }
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+
+            final Object xiObj = parameters[3];
+            final Object xjObj = parameters[4];
+            double wi = PrimitiveObjectInspectorUtils.getDouble(parameters[0], wiOI);
+            if (xiObj == null && xjObj == null) {// Xi and Xj are null => global bias `w0`
+                myAggr.addW0(wi);
+            } else if (xjObj == null) {// Only Xi is nonnull => linear combination `wi` * `xi`
+                double xi = PrimitiveObjectInspectorUtils.getDouble(xiObj, xiOI);
+                myAggr.addWiXi(wi, xi);
+            } else if (xiObj == null) {// Xj without Xi has no defined meaning
+                throw new UDFArgumentException(
+                    "The fourth argument Xi must not be null when the fifth argument Xj is given");
+            } else {// both Xi and Xj are nonnull => <Vifj, Vjfi> Xi Xj
+                if (parameters[1] == null || parameters[2] == null) {
+                    throw new UDFArgumentException("The second and third arguments (Vij, Vji) must not be null");
+                }
+
+                final double[] vij = toDoubleArray(parameters[1], vijOI, vijElemOI, "Vifj");
+                final double[] vji = toDoubleArray(parameters[2], vjiOI, vjiElemOI, "Vjfi");
+
+                if (vij.length != vji.length) {
+                    throw new HiveException("Mismatch in the number of factors: Vifj has "
+                            + vij.length + " but Vjfi has " + vji.length);
+                }
+
+                double xi = PrimitiveObjectInspectorUtils.getDouble(xiObj, xiOI);
+                double xj = PrimitiveObjectInspectorUtils.getDouble(xjObj, xjOI);
+
+                myAggr.addViVjXiXj(vij, vji, xi, xj);
+            }
+        }
+
+        /**
+         * Copies a Hive list of numbers into a double array, reading each element through
+         * {@code elemOI} so that any numeric element representation is accepted.
+         *
+         * @throws UDFArgumentException if the list or one of its elements is null
+         */
+        @Nonnull
+        private static double[] toDoubleArray(@Nonnull final Object listObj,
+                @Nonnull final ListObjectInspector listOI,
+                @Nonnull final PrimitiveObjectInspector elemOI, @Nonnull final String argName)
+                throws UDFArgumentException {
+            final int size = listOI.getListLength(listObj);
+            if (size < 0) {// ListObjectInspector reports -1 for a null list
+                throw new UDFArgumentException(argName + " must not be null");
+            }
+            final double[] values = new double[size];
+            for (int i = 0; i < size; i++) {
+                Object elem = listOI.getListElement(listObj, i);
+                if (elem == null) {
+                    throw new UDFArgumentException(
+                        "Null element found at index " + i + " of " + argName);
+                }
+                values[i] = PrimitiveObjectInspectorUtils.getDouble(elem, elemOI);
+            }
+            return values;
+        }
+
+        /** Emits the running sum as a one-field struct matching {@link #internalMergeOI()}. */
+        @Override
+        public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+
+            final Object[] partialResult = new Object[1];
+            partialResult[0] = new DoubleWritable(myAggr.get());
+            return partialResult;
+        }
+
+        /** Adds a partial sum produced by {@link #terminatePartial} into the buffer. */
+        @Override
+        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
+                throws HiveException {
+            if (partial == null) {
+                return;
+            }
+
+            Object sumObj = internalMergeOI.getStructFieldData(partial, sumField);
+            if (sumObj == null) {
+                return;
+            }
+            double sum = PrimitiveObjectInspectorUtils.getDouble(sumObj, sumOI);
+
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+            myAggr.merge(sum);
+        }
+
+        /** Returns the accumulated prediction value. */
+        @Override
+        public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+            double result = myAggr.get();
+            return new DoubleWritable(result);
+        }
+
+    }
+
+    /** Aggregation state: a single running sum of all FFM terms seen so far. */
+    public static class FFMPredictAggregationBuffer extends AbstractAggregationBuffer {
+
+        double sum;
+
+        FFMPredictAggregationBuffer() {
+            super();
+        }
+
+        void reset() {
+            this.sum = 0.d;
+        }
+
+        /** Adds a partial sum from another buffer. */
+        void merge(final double o_sum) {
+            sum += o_sum;
+        }
+
+        double get() {
+            return sum;
+        }
+
+        /** Adds the global bias term. */
+        void addW0(final double W0) {
+            sum += W0;
+        }
+
+        /** Adds the linear term {@code Wi * Xi}. */
+        void addWiXi(final double Wi, final double Xi) {
+            sum += Wi * Xi;
+        }
+
+        /**
+         * Adds the interaction term {@code <Vifj, Vjfi> * Xi * Xj}. Iterates over
+         * {@code Vifj.size()} factors; elements must be non-null and {@code Vjfi} must be at
+         * least as long as {@code Vifj}.
+         */
+        void addViVjXiXj(@Nonnull final List<Float> Vifj, @Nonnull final List<Float> Vjfi,
+                     final double Xi, final double Xj) {
+            final int factors = Vifj.size();
+            double prod = 0.d;
+
+            // compute inner product <Vifj, Vjfi>
+            for (int i = 0; i < factors; i++) {
+                prod += (Vifj.get(i) * Vjfi.get(i));
+            }
+
+            sum += (prod * Xi * Xj);
+        }
+
+        /**
+         * Adds the interaction term {@code <Vifj, Vjfi> * Xi * Xj} computed in double precision.
+         * Iterates over {@code Vifj.length} factors; {@code Vjfi} must be at least as long.
+         */
+        void addViVjXiXj(@Nonnull final double[] Vifj, @Nonnull final double[] Vjfi,
+                final double Xi, final double Xj) {
+            final int factors = Vifj.length;
+            double prod = 0.d;
+
+            // compute inner product <Vifj, Vjfi>
+            for (int i = 0; i < factors; i++) {
+                prod += Vifj[i] * Vjfi[i];
+            }
+
+            sum += (prod * Xi * Xj);
+        }
+
+    }
+
+}