You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 08:01:52 UTC

[02/50] [abbrv] incubator-hivemall git commit: add transpose_and_dot

add transpose_and_dot



Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/6f9b4fa0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/6f9b4fa0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/6f9b4fa0

Branch: refs/heads/JIRA-22/pr-385
Commit: 6f9b4fa0acebf604882240ccd5507d9df45bab2d
Parents: 56adf2d
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Sep 16 15:52:54 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Sep 16 15:52:54 2016 +0900

----------------------------------------------------------------------
 .../tools/matrix/TransposeAndDotUDAF.java       | 191 +++++++++++++++++++
 1 file changed, 191 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6f9b4fa0/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
new file mode 100644
index 0000000..4fa5ce4
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
@@ -0,0 +1,191 @@
+package hivemall.tools.matrix;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Aggregate function computing {@code dot(matrix0.T, matrix1)}.
+ * <p>
+ * Each input row supplies one row of {@code matrix0} and the corresponding row of
+ * {@code matrix1}. The result is the sum, over all rows, of the outer product
+ * {@code matrix0_row^T * matrix1_row}, i.e. a matrix of shape
+ * {@code (matrix0.#cols, matrix1.#cols)}. Rows in which either argument is NULL are
+ * skipped, and NULL is returned when no row has been aggregated.
+ */
+@Description(name = "transpose_and_dot",
+        value = "_FUNC_(array<number> matrix0_row, array<number> matrix1_row)" +
+                " - Returns dot(matrix0.T, matrix1) as array<array<double>>, shape = (matrix0.#cols, matrix1.#cols)")
+public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
+
+    /**
+     * Validates that exactly two {@code array<number>} arguments were passed and returns
+     * the evaluator.
+     *
+     * @throws SemanticException ({@link UDFArgumentLengthException} or
+     *         {@link UDFArgumentTypeException}) when the arguments are invalid
+     */
+    @Override
+    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
+        ObjectInspector[] OIs = info.getParameterObjectInspectors();
+
+        if (OIs.length != 2) {
+            throw new UDFArgumentLengthException("Specify two arguments.");
+        }
+
+        if (!HiveUtils.isNumberListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
+                    + OIs[0].getTypeName() + " was passed as `matrix0_row`");
+        }
+
+        if (!HiveUtils.isNumberListOI(OIs[1])) {
+            throw new UDFArgumentTypeException(1, "Only array<number> type argument is acceptable but "
+                    + OIs[1].getTypeName() + " was passed as `matrix1_row`");
+        }
+
+        return new TransposeAndDotUDAFEvaluator();
+    }
+
+    private static final class TransposeAndDotUDAFEvaluator extends GenericUDAFEvaluator {
+        // PARTIAL1 and COMPLETE: inspectors for the raw input rows
+        private ListObjectInspector matrix0RowOI;
+        private PrimitiveObjectInspector matrix0ElOI;
+        private ListObjectInspector matrix1RowOI;
+        private PrimitiveObjectInspector matrix1ElOI;
+
+        // PARTIAL2 and FINAL: inspectors for the partial aggregate (array<array<double>>)
+        private ListObjectInspector aggMatrixOI;
+        private ListObjectInspector aggMatrixRowOI;
+        private DoubleObjectInspector aggMatrixElOI;
+
+        // Scratch buffers reused across iterate() calls to avoid a per-row allocation;
+        // sized lazily from the first row that has both arguments non-NULL.
+        private double[] matrix0Row;
+        private double[] matrix1Row;
+
+        /** Running sum of outer products; {@code aggMatrix} stays null until first use. */
+        @AggregationType(estimable = true)
+        static class TransposeAndDotAggregationBuffer extends AbstractAggregationBuffer {
+            double[][] aggMatrix;
+
+            @Override
+            public int estimate() {
+                if (aggMatrix == null || aggMatrix.length == 0) {
+                    return 0;
+                }
+                // long arithmetic so that very large matrices cannot overflow int
+                final long bytes = 8L * aggMatrix.length * aggMatrix[0].length;
+                return (int) Math.min(bytes, Integer.MAX_VALUE);
+            }
+
+            /** Allocates a zero-filled {@code n x m} accumulator. */
+            public void init(int n, int m) {
+                aggMatrix = new double[n][m];
+            }
+
+            /** Zeroes the accumulator in place, keeping its current dimensions. */
+            public void reset() {
+                if (aggMatrix != null) {
+                    for (double[] row : aggMatrix) {
+                        Arrays.fill(row, 0.0);
+                    }
+                }
+            }
+        }
+
+        /**
+         * Captures the object inspectors for the given mode. Both the partial and the
+         * final output are {@code array<array<double>>}.
+         */
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
+            super.init(mode, OIs);
+
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
+                matrix0RowOI = HiveUtils.asListOI(OIs[0]);
+                matrix0ElOI = HiveUtils.asDoubleCompatibleOI(matrix0RowOI.getListElementObjectInspector());
+                matrix1RowOI = HiveUtils.asListOI(OIs[1]);
+                matrix1ElOI = HiveUtils.asDoubleCompatibleOI(matrix1RowOI.getListElementObjectInspector());
+            } else {
+                aggMatrixOI = HiveUtils.asListOI(OIs[0]);
+                aggMatrixRowOI = HiveUtils.asListOI(aggMatrixOI.getListElementObjectInspector());
+                aggMatrixElOI = HiveUtils.asDoubleOI(aggMatrixRowOI.getListElementObjectInspector());
+            }
+
+            return ObjectInspectorFactory.getStandardListObjectInspector(
+                    ObjectInspectorFactory.getStandardListObjectInspector(
+                            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        }
+
+        @Override
+        public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
+            TransposeAndDotAggregationBuffer myAgg = new TransposeAndDotAggregationBuffer();
+            reset(myAgg);
+            return myAgg;
+        }
+
+        @Override
+        public void reset(AggregationBuffer agg) throws HiveException {
+            TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+            myAgg.reset();
+        }
+
+        /**
+         * Adds the outer product {@code matrix0_row^T * matrix1_row} of one input row to the
+         * accumulator. Rows where either argument is NULL are ignored.
+         *
+         * @throws HiveException if a row's length differs from the first aggregated row
+         */
+        @Override
+        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
+            final Object row0 = parameters[0];
+            final Object row1 = parameters[1];
+            if (row0 == null || row1 == null) {
+                // Standard SQL aggregate semantics: skip NULL inputs rather than feeding a
+                // NULL into the buffer sizing / conversion below.
+                return;
+            }
+
+            final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+            final int len0 = matrix0RowOI.getListLength(row0);
+            final int len1 = matrix1RowOI.getListLength(row1);
+            if (matrix0Row == null) {
+                matrix0Row = new double[len0];
+            }
+            if (matrix1Row == null) {
+                matrix1Row = new double[len1];
+            }
+            checkLength(len0, matrix0Row.length, "matrix0_row");
+            checkLength(len1, matrix1Row.length, "matrix1_row");
+
+            HiveUtils.toDoubleArray(row0, matrix0RowOI, matrix0ElOI, matrix0Row, false);
+            HiveUtils.toDoubleArray(row1, matrix1RowOI, matrix1ElOI, matrix1Row, false);
+
+            if (myAgg.aggMatrix == null) {
+                myAgg.init(matrix0Row.length, matrix1Row.length);
+            }
+
+            final double[][] acc = myAgg.aggMatrix;
+            for (int i = 0; i < matrix0Row.length; i++) {
+                final double x = matrix0Row[i];
+                final double[] accRow = acc[i];
+                for (int j = 0; j < matrix1Row.length; j++) {
+                    accRow[j] += x * matrix1Row[j];
+                }
+            }
+        }
+
+        /**
+         * Adds a partial aggregate (an {@code array<array<double>>}) element-wise into the
+         * accumulator. NULL or empty partials (groups that saw no rows) are ignored.
+         *
+         * @throws HiveException if the partial's shape differs from the accumulator's
+         */
+        @Override
+        public void merge(AggregationBuffer agg, Object other) throws HiveException {
+            if (other == null) {
+                return;
+            }
+
+            final List<?> matrix = aggMatrixOI.getList(other);
+            if (matrix == null || matrix.isEmpty()) {
+                return;
+            }
+
+            final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+            final int n = matrix.size();
+            final double[] row = new double[aggMatrixRowOI.getListLength(matrix.get(0))];
+
+            if (myAgg.aggMatrix == null) {
+                myAgg.init(n, row.length);
+            } else if (myAgg.aggMatrix.length != n || myAgg.aggMatrix[0].length != row.length) {
+                throw new HiveException("Shape mismatch in partial aggregate. Expected: ("
+                        + myAgg.aggMatrix.length + ", " + myAgg.aggMatrix[0].length
+                        + "), Actual: (" + n + ", " + row.length + ")");
+            }
+
+            for (int i = 0; i < n; i++) {
+                HiveUtils.toDoubleArray(matrix.get(i), aggMatrixRowOI, aggMatrixElOI, row, false);
+                final double[] accRow = myAgg.aggMatrix[i];
+                for (int j = 0; j < row.length; j++) {
+                    accRow[j] += row[j];
+                }
+            }
+        }
+
+        /** The partial result has the same representation as the final result. */
+        @Override
+        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+            return terminate(agg);
+        }
+
+        /**
+         * Returns the accumulated matrix as {@code array<array<double>>}, or NULL when no
+         * row (and no non-empty partial) has been aggregated into this buffer.
+         */
+        @Override
+        public Object terminate(AggregationBuffer agg) throws HiveException {
+            final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+            if (myAgg.aggMatrix == null) {
+                return null;
+            }
+
+            final List<List<DoubleWritable>> result =
+                    new ArrayList<List<DoubleWritable>>(myAgg.aggMatrix.length);
+            for (double[] row : myAgg.aggMatrix) {
+                result.add(WritableUtils.toWritableList(row));
+            }
+            return result;
+        }
+
+        /** Throws if a row's length differs from the width fixed by the first row. */
+        private static void checkLength(final int actual, final int expected,
+                final String argName) throws HiveException {
+            if (actual != expected) {
+                throw new HiveException("Dimension mismatch in `" + argName + "`. Expected: "
+                        + expected + ", Actual: " + actual);
+            }
+        }
+    }
+}