You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 08:01:52 UTC
[02/50] [abbrv] incubator-hivemall git commit: add transpose_and_dot
add transpose_and_dot
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/6f9b4fa0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/6f9b4fa0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/6f9b4fa0
Branch: refs/heads/JIRA-22/pr-385
Commit: 6f9b4fa0acebf604882240ccd5507d9df45bab2d
Parents: 56adf2d
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Sep 16 15:52:54 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Sep 16 15:52:54 2016 +0900
----------------------------------------------------------------------
.../tools/matrix/TransposeAndDotUDAF.java | 191 +++++++++++++++++++
1 file changed, 191 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6f9b4fa0/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
new file mode 100644
index 0000000..4fa5ce4
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
@@ -0,0 +1,191 @@
+package hivemall.tools.matrix;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+@Description(name = "transpose_and_dot",
+ value = "_FUNC_(array<number> matrix0_row, array<number> matrix1_row)" +
+ " - Returns dot(matrix0.T, matrix1) as array<array<double>>, shape = (matrix0.#cols, matrix1.#cols)")
+public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
+ /**
+  * Validates the call signature and returns a new evaluator.
+  *
+  * Exactly two arguments are accepted, and each must pass
+  * {@code HiveUtils.isNumberListOI}.
+  *
+  * @param info parameter metadata for the call site
+  * @return a new {@code TransposeAndDotUDAFEvaluator}
+  * @throws SemanticException declared by the overridden resolver method; argument problems
+  *         are reported through UDFArgumentLengthException / UDFArgumentTypeException
+  */
+ @Override
+ public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
+     final ObjectInspector[] argOIs = info.getParameterObjectInspectors();
+     if (argOIs.length != 2) {
+         throw new UDFArgumentLengthException("Specify two arguments.");
+     }
+
+     // Argument names as they appear in the function description, indexed by position.
+     final String[] argNames = {"matrix0_row", "matrix1_row"};
+     for (int pos = 0; pos < argNames.length; pos++) {
+         final ObjectInspector argOI = argOIs[pos];
+         if (HiveUtils.isNumberListOI(argOI)) {
+             continue;
+         }
+         throw new UDFArgumentTypeException(pos, "Only array<number> type argument is acceptable but "
+                 + argOI.getTypeName() + " was passed as `" + argNames[pos] + "`");
+     }
+
+     return new TransposeAndDotUDAFEvaluator();
+ }
+
+ private static final class TransposeAndDotUDAFEvaluator extends GenericUDAFEvaluator {
+ // PARTIAL1 and COMPLETE
+ private ListObjectInspector matrix0RowOI;
+ private PrimitiveObjectInspector matrix0ElOI;
+ private ListObjectInspector matrix1RowOI;
+ private PrimitiveObjectInspector matrix1ElOI;
+
+ // PARTIAL2 and FINAL
+ private ListObjectInspector aggMatrixOI;
+ private ListObjectInspector aggMatrixRowOI;
+ private DoubleObjectInspector aggMatrixElOI;
+
+ private double[] matrix0Row;
+ private double[] matrix1Row;
+
+ /**
+  * Aggregation state: a dense matrix accumulating the element-wise sums.
+  *
+  * The matrix is allocated lazily through {@link #init(int, int)}, so
+  * {@code aggMatrix} is {@code null} until the first allocation.
+  */
+ @AggregationType(estimable = true)
+ static class TransposeAndDotAggregationBuffer extends AbstractAggregationBuffer {
+     double[][] aggMatrix;
+
+     /**
+      * Estimates the payload size of the matrix in bytes.
+      *
+      * @return rows * cols * 8, or 0 when the matrix is unallocated or has no rows
+      */
+     @Override
+     public int estimate() {
+         // A zero-row matrix has no aggMatrix[0]; guard it to avoid ArrayIndexOutOfBoundsException.
+         if (aggMatrix == null || aggMatrix.length == 0) {
+             return 0;
+         }
+         // 8 bytes per double element; array object overhead is not counted.
+         return aggMatrix.length * aggMatrix[0].length * 8;
+     }
+
+     /**
+      * Allocates a zero-filled matrix, replacing any previous one.
+      *
+      * @param n number of rows
+      * @param m number of columns
+      */
+     public void init(int n, int m) {
+         aggMatrix = new double[n][m];
+     }
+
+     /**
+      * Zeroes every element of an allocated matrix while keeping its dimensions.
+      * Does nothing when the matrix has not been allocated yet.
+      */
+     public void reset() {
+         if (aggMatrix != null) {
+             for (double[] row : aggMatrix) {
+                 Arrays.fill(row, 0.0);
+             }
+         }
+     }
+ }
+
+ /**
+  * Records the object inspectors needed for the given mode and declares the output type.
+  *
+  * In PARTIAL1 and COMPLETE the two arguments are treated as the original row arrays.
+  * In the other modes the single argument is treated as a partial matrix
+  * (a list of lists of doubles).
+  *
+  * @param mode evaluation mode
+  * @param OIs  inspectors of the arguments for this mode
+  * @return inspector for a list of lists of writable doubles, regardless of mode
+  * @throws HiveException declared by the overridden method
+  */
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
+     super.init(mode, OIs);
+
+     final boolean consumesOriginalRows = (mode == Mode.PARTIAL1) || (mode == Mode.COMPLETE);
+     if (consumesOriginalRows) {
+         // Two arguments: one row of matrix0 and one row of matrix1.
+         matrix0RowOI = HiveUtils.asListOI(OIs[0]);
+         matrix0ElOI = HiveUtils.asDoubleCompatibleOI(
+             matrix0RowOI.getListElementObjectInspector());
+         matrix1RowOI = HiveUtils.asListOI(OIs[1]);
+         matrix1ElOI = HiveUtils.asDoubleCompatibleOI(
+             matrix1RowOI.getListElementObjectInspector());
+     } else {
+         // One argument: a partially aggregated matrix.
+         aggMatrixOI = HiveUtils.asListOI(OIs[0]);
+         aggMatrixRowOI = HiveUtils.asListOI(aggMatrixOI.getListElementObjectInspector());
+         aggMatrixElOI = HiveUtils.asDoubleOI(aggMatrixRowOI.getListElementObjectInspector());
+     }
+
+     // The output type is built inside-out: double -> list<double> -> list<list<double>>.
+     final ObjectInspector cellOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+     final ObjectInspector rowOI = ObjectInspectorFactory.getStandardListObjectInspector(cellOI);
+     return ObjectInspectorFactory.getStandardListObjectInspector(rowOI);
+ }
+
+ /**
+  * Creates a new aggregation buffer and passes it through {@code reset} before returning it.
+  *
+  * @return a new {@code TransposeAndDotAggregationBuffer}
+  * @throws HiveException declared by the overridden method
+  */
+ @Override
+ public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
+     final TransposeAndDotAggregationBuffer buffer = new TransposeAndDotAggregationBuffer();
+     reset(buffer);
+     return buffer;
+ }
+
+ /**
+  * Resets the given buffer by delegating to the buffer's own {@code reset()}.
+  *
+  * @param agg buffer to reset; cast to {@code TransposeAndDotAggregationBuffer}
+  * @throws HiveException declared by the overridden method
+  */
+ @Override
+ public void reset(AggregationBuffer agg) throws HiveException {
+     ((TransposeAndDotAggregationBuffer) agg).reset();
+ }
+
+ /**
+  * Adds the contribution of one input row pair to the aggregated matrix:
+  * {@code aggMatrix[i][j] += matrix0_row[i] * matrix1_row[j]}.
+  *
+  * The scratch arrays {@code matrix0Row} / {@code matrix1Row} are sized from the first
+  * row pair seen by this evaluator and reused for every later row.
+  *
+  * @param agg        aggregation buffer to update
+  * @param parameters {@code parameters[0]} is a row of matrix0 and {@code parameters[1]}
+  *                   is a row of matrix1; neither may be null
+  * @throws HiveException declared by the overridden method
+  */
+ @Override
+ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
+     final Object matrix0RowObj = parameters[0];
+     final Object matrix1RowObj = parameters[1];
+
+     // Validate the inputs before they are used. The previous checks targeted the
+     // scratch arrays, which are always non-null by that point, so a null row was
+     // never detected here.
+     Preconditions.checkNotNull(matrix0RowObj);
+     Preconditions.checkNotNull(matrix1RowObj);
+
+     final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+     if (matrix0Row == null) {
+         matrix0Row = new double[matrix0RowOI.getListLength(matrix0RowObj)];
+     }
+     if (matrix1Row == null) {
+         matrix1Row = new double[matrix1RowOI.getListLength(matrix1RowObj)];
+     }
+
+     // NOTE(review): presumably toDoubleArray rejects rows whose length differs from the
+     // scratch array - confirm against HiveUtils.
+     HiveUtils.toDoubleArray(matrix0RowObj, matrix0RowOI, matrix0ElOI, matrix0Row, false);
+     HiveUtils.toDoubleArray(matrix1RowObj, matrix1RowOI, matrix1ElOI, matrix1Row, false);
+
+     if (myAgg.aggMatrix == null) {
+         myAgg.init(matrix0Row.length, matrix1Row.length);
+     }
+
+     final double[][] aggMatrix = myAgg.aggMatrix;
+     for (int i = 0; i < matrix0Row.length; i++) {
+         final double left = matrix0Row[i];
+         final double[] aggRow = aggMatrix[i];
+         for (int j = 0; j < matrix1Row.length; j++) {
+             aggRow[j] += left * matrix1Row[j];
+         }
+     }
+ }
+
+ /**
+  * Adds a partial matrix element-wise into the aggregation buffer.
+  *
+  * A null partial, or one with no rows, leaves the buffer unchanged.
+  *
+  * @param agg   aggregation buffer to update; its matrix is allocated from the first
+  *              non-empty partial when it does not exist yet
+  * @param other partial result (a list of lists of doubles), possibly null
+  * @throws HiveException declared by the overridden method
+  */
+ @Override
+ public void merge(AggregationBuffer agg, Object other) throws HiveException {
+     if (other == null) {
+         return;
+     }
+
+     final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+     final List<?> matrix = aggMatrixOI.getList(other);
+     // An empty partial has no first row to size the scratch array from.
+     if (matrix == null || matrix.isEmpty()) {
+         return;
+     }
+
+     final int n = matrix.size();
+     // The scratch row is sized from the first row and reused for every row of the partial.
+     final double[] row = new double[aggMatrixRowOI.getListLength(matrix.get(0))];
+
+     if (myAgg.aggMatrix == null) {
+         myAgg.init(n, row.length);
+     }
+
+     final double[][] aggMatrix = myAgg.aggMatrix;
+     for (int i = 0; i < n; i++) {
+         HiveUtils.toDoubleArray(matrix.get(i), aggMatrixRowOI, aggMatrixElOI, row, false);
+
+         final double[] aggRow = aggMatrix[i];
+         for (int j = 0; j < row.length; j++) {
+             aggRow[j] += row[j];
+         }
+     }
+ }
+
+ /**
+  * Produces the partial result, which has the same form as the final result.
+  *
+  * @param agg aggregation buffer to read
+  * @return whatever {@code terminate(agg)} returns
+  * @throws HiveException declared by the overridden method
+  */
+ @Override
+ public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+     final Object partial = terminate(agg);
+     return partial;
+ }
+
+ /**
+  * Converts the aggregated matrix into a list of rows of {@code DoubleWritable}.
+  *
+  * @param agg aggregation buffer to read
+  * @return one list per matrix row, or {@code null} when the matrix has not been allocated
+  * @throws HiveException declared by the overridden method
+  */
+ @Override
+ public Object terminate(AggregationBuffer agg) throws HiveException {
+     final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+     final double[][] aggMatrix = myAgg.aggMatrix;
+     // If the matrix was never allocated, return null instead of failing with a
+     // NullPointerException in the loop below.
+     if (aggMatrix == null) {
+         return null;
+     }
+
+     final List<List<DoubleWritable>> result =
+         new ArrayList<List<DoubleWritable>>(aggMatrix.length);
+     for (double[] row : aggMatrix) {
+         result.add(WritableUtils.toWritableList(row));
+     }
+     return result;
+ }
+ }
+}