You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/01/14 18:18:00 UTC
[2/2] incubator-systemml git commit: New Cholesky CP builtin function, by Shirish, including cleanups
New Cholesky CP (control program) builtin function, contributed by Shirish, including cleanups
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/4f4e94ec
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/4f4e94ec
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/4f4e94ec
Branch: refs/heads/master
Commit: 4f4e94ec11e4dcffbcafc731d7eb685d6d0524ee
Parents: 0280696
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Wed Jan 13 20:13:43 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Wed Jan 13 20:13:43 2016 -0800
----------------------------------------------------------------------
src/main/java/org/apache/sysml/hops/Hop.java | 3 +-
.../java/org/apache/sysml/hops/UnaryOp.java | 5 +-
.../hops/cost/CostEstimatorStaticRuntime.java | 1 +
src/main/java/org/apache/sysml/lops/Unary.java | 34 ++++--
.../sysml/parser/BuiltinFunctionExpression.java | 20 ++++
.../org/apache/sysml/parser/DMLTranslator.java | 7 +-
.../org/apache/sysml/parser/Expression.java | 1 +
.../instructions/CPInstructionParser.java | 1 +
.../runtime/matrix/data/LibCommonsMath.java | 119 ++++++++++---------
9 files changed, 122 insertions(+), 69 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4f4e94ec/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 3d2c76b..e031a87 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1026,7 +1026,7 @@ public abstract class Hop
public enum OpOp1 {
NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SIGN, SQRT, LOG, EXP,
CAST_AS_SCALAR, CAST_AS_MATRIX, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN,
- PRINT, EIGEN, NROW, NCOL, LENGTH, ROUND, IQM, STOP, CEIL, FLOOR, MEDIAN, INVERSE,
+ PRINT, EIGEN, NROW, NCOL, LENGTH, ROUND, IQM, STOP, CEIL, FLOOR, MEDIAN, INVERSE, CHOLESKY,
//cumulative sums, products, extreme values
CUMSUM, CUMPROD, CUMMIN, CUMMAX,
//fused ML-specific operators for performance
@@ -1245,6 +1245,7 @@ public abstract class Hop
HopsOpOp1LopsU.put(OpOp1.CUMMIN, org.apache.sysml.lops.Unary.OperationTypes.CUMMIN);
HopsOpOp1LopsU.put(OpOp1.CUMMAX, org.apache.sysml.lops.Unary.OperationTypes.CUMMAX);
HopsOpOp1LopsU.put(OpOp1.INVERSE, org.apache.sysml.lops.Unary.OperationTypes.INVERSE);
+ HopsOpOp1LopsU.put(OpOp1.CHOLESKY, org.apache.sysml.lops.Unary.OperationTypes.CHOLESKY);
HopsOpOp1LopsU.put(OpOp1.CAST_AS_SCALAR, org.apache.sysml.lops.Unary.OperationTypes.NOTSUPPORTED);
HopsOpOp1LopsU.put(OpOp1.CAST_AS_MATRIX, org.apache.sysml.lops.Unary.OperationTypes.NOTSUPPORTED);
HopsOpOp1LopsU.put(OpOp1.SPROP, org.apache.sysml.lops.Unary.OperationTypes.SPROP);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4f4e94ec/src/main/java/org/apache/sysml/hops/UnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java
index a3ca530..4437cdb 100644
--- a/src/main/java/org/apache/sysml/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java
@@ -674,8 +674,11 @@ public class UnaryOp extends Hop
setRequiresRecompile();
//ensure cp exec type for single-node operations
- if( _op == OpOp1.PRINT || _op == OpOp1.STOP || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN )
+ if( _op == OpOp1.PRINT || _op == OpOp1.STOP
+ || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY )
+ {
_etype = ExecType.CP;
+ }
return _etype;
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4f4e94ec/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java b/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
index 73203f0..b68285f 100644
--- a/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
+++ b/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
@@ -1005,6 +1005,7 @@ public class CostEstimatorStaticRuntime extends CostEstimator
case BuiltinUnary: //opcodes: exp, abs, sin, cos, tan, sign, sqrt, plogp, print, round, sprop, sigmoid
+ //TODO add cost functions for commons math builtins: inverse, cholesky
if( optype.equals("print") ) //scalar only
return 1;
else
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4f4e94ec/src/main/java/org/apache/sysml/lops/Unary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Unary.java b/src/main/java/org/apache/sysml/lops/Unary.java
index 8375202..1f56d66 100644
--- a/src/main/java/org/apache/sysml/lops/Unary.java
+++ b/src/main/java/org/apache/sysml/lops/Unary.java
@@ -40,7 +40,7 @@ public class Unary extends Lop
ADD, SUBTRACT, SUBTRACTRIGHT, MULTIPLY, MULTIPLY2, DIVIDE, MODULUS, INTDIV, MINUS1_MULTIPLY,
POW, POW2, LOG, MAX, MIN, NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SIGN, SQRT, EXP, Over,
LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS,
- ROUND, CEIL, FLOOR, MR_IQM, INVERSE,
+ ROUND, CEIL, FLOOR, MR_IQM, INVERSE, CHOLESKY,
CUMSUM, CUMPROD, CUMMIN, CUMMAX,
SPROP, SIGMOID, SELP, SUBTRACT_NZ, LOG_NZ,
NOTSUPPORTED
@@ -107,27 +107,32 @@ public class Unary extends Lop
*
* @param input1
* @param op
+ * @throws LopsException
*/
- public Unary(Lop input1, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
+ public Unary(Lop input1, OperationTypes op, DataType dt, ValueType vt, ExecType et)
+ throws LopsException
+ {
super(Lop.Type.UNARY, dt, vt);
init(input1, op, dt, vt, et);
}
- public Unary(Lop input1, OperationTypes op, DataType dt, ValueType vt) {
+ public Unary(Lop input1, OperationTypes op, DataType dt, ValueType vt)
+ throws LopsException
+ {
super(Lop.Type.UNARY, dt, vt);
init(input1, op, dt, vt, ExecType.MR);
}
- private ExecType forceExecType(OperationTypes op, ExecType et) {
- if ( op == OperationTypes.INVERSE )
- return ExecType.CP;
- return et;
- }
- private void init(Lop input1, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
- operation = op;
-
- et = forceExecType(op, et);
+ private void init(Lop input1, OperationTypes op, DataType dt, ValueType vt, ExecType et)
+ throws LopsException
+ {
+ //sanity check
+ if ( (op == OperationTypes.INVERSE || op == OperationTypes.CHOLESKY)
+ && (et == ExecType.SPARK || et == ExecType.MR) ) {
+ throw new LopsException("Invalid exection type "+et.toString()+" for operation "+op.toString());
+ }
+ operation = op;
valInput = null;
this.addInput(input1);
@@ -301,7 +306,10 @@ public class Unary extends Lop
case INVERSE:
return "inverse";
-
+
+ case CHOLESKY:
+ return "cholesky";
+
case MR_IQM:
return "qpick";
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4f4e94ec/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index d56aa08..9273590 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -949,6 +949,24 @@ public class BuiltinFunctionExpression extends DataIdentifier
output.setDimensions(in.getDim1(), in.getDim2());
output.setBlockDimensions(in.getRowsInBlock(), in.getColumnsInBlock());
break;
+
+ case CHOLESKY:
+ {
+ // A = L%*%t(L) where L is the lower triangular matrix
+ checkNumParameters(1);
+ checkMatrixParam(getFirstExpr());
+
+ output.setDataType(DataType.MATRIX);
+ output.setValueType(ValueType.DOUBLE);
+
+ Identifier inA = getFirstExpr().getOutput();
+ if(inA.dimsKnown() && inA.getDim1() != inA.getDim2())
+ raiseValidateError("Input to cholesky() must be square matrix -- given: a " + inA.getDim1() + "x" + inA.getDim2() + " matrix.", conditional);
+
+ output.setDimensions(inA.getDim1(), inA.getDim2());
+ output.setBlockDimensions(inA.getRowsInBlock(), inA.getColumnsInBlock());
+ break;
+ }
case OUTER:
Identifier id2 = this.getSecondExpr().getOutput();
@@ -1424,6 +1442,8 @@ public class BuiltinFunctionExpression extends DataIdentifier
bifop = Expression.BuiltinFunctionOp.MEDIAN;
else if (functionName.equals("inv"))
bifop = Expression.BuiltinFunctionOp.INVERSE;
+ else if (functionName.equals("cholesky"))
+ bifop = Expression.BuiltinFunctionOp.CHOLESKY;
else if (functionName.equals("sample"))
bifop = Expression.BuiltinFunctionOp.SAMPLE;
else if ( functionName.equals("outer") )
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4f4e94ec/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index c643396..017c246 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2708,10 +2708,15 @@ public class DMLTranslator
break;
case INVERSE:
- currBuiltinOp=new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
+ currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
Hop.OpOp1.INVERSE, expr);
break;
+ case CHOLESKY:
+ currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
+ Hop.OpOp1.CHOLESKY, expr);
+ break;
+
case OUTER:
if( !(expr3 instanceof LiteralOp) )
throw new HopsException("Operator for outer builtin function must be a constant: "+expr3);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4f4e94ec/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index 709a581..99edb74 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -60,6 +60,7 @@ public abstract class Expression
CAST_AS_SCALAR,
CBIND, //previously APPEND
CEIL,
+ CHOLESKY,
COLMAX,
COLMEAN,
COLMIN,
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4f4e94ec/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index c70f6a9..4cb67e3 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -167,6 +167,7 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "ucummax", CPINSTRUCTION_TYPE.BuiltinUnary);
String2CPInstructionType.put( "stop" , CPINSTRUCTION_TYPE.BuiltinUnary);
String2CPInstructionType.put( "inverse", CPINSTRUCTION_TYPE.BuiltinUnary);
+ String2CPInstructionType.put( "cholesky",CPINSTRUCTION_TYPE.BuiltinUnary);
String2CPInstructionType.put( "sprop", CPINSTRUCTION_TYPE.BuiltinUnary);
String2CPInstructionType.put( "sigmoid", CPINSTRUCTION_TYPE.BuiltinUnary);
String2CPInstructionType.put( "sel+", CPINSTRUCTION_TYPE.BuiltinUnary);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4f4e94ec/src/main/java/org/apache/sysml/runtime/matrix/data/LibCommonsMath.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibCommonsMath.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibCommonsMath.java
index 52569d1..cb163f8 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibCommonsMath.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibCommonsMath.java
@@ -20,12 +20,12 @@
package org.apache.sysml.runtime.matrix.data;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
-
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.util.DataConverter;
@@ -38,62 +38,52 @@ import org.apache.sysml.runtime.util.DataConverter;
* matrix inverse, matrix decompositions (QR, LU, Eigen), solve
*/
public class LibCommonsMath
-{
-
- public static boolean isSupportedUnaryOperation( String opcode )
- {
- if ( opcode.equals("inverse") ) {
- return true;
- }
- return false;
+{
+ private LibCommonsMath() {
+ //prevent instantiation via private constructor
}
- public static boolean isSupportedMultiReturnOperation( String opcode )
- {
-
- if ( opcode.equals("qr") || opcode.equals("lu") || opcode.equals("eigen") ) {
- return true;
- }
- return false;
+ public static boolean isSupportedUnaryOperation( String opcode ) {
+ return ( opcode.equals("inverse") || opcode.equals("cholesky") );
}
- public static boolean isSupportedMatrixMatrixOperation( String opcode )
- {
- if ( opcode.equals("solve") ) {
- return true;
- }
- return false;
+ public static boolean isSupportedMultiReturnOperation( String opcode ) {
+ return ( opcode.equals("qr") || opcode.equals("lu") || opcode.equals("eigen") );
}
- private LibCommonsMath() {
- //prevent instantiation via private constructor
+ public static boolean isSupportedMatrixMatrixOperation( String opcode ) {
+ return ( opcode.equals("solve") );
}
-
- public static MatrixBlock unaryOperations(MatrixObject inj, String opcode) throws DMLRuntimeException {
+
+ public static MatrixBlock unaryOperations(MatrixObject inj, String opcode)
+ throws DMLRuntimeException
+ {
Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(inj);
- MatrixBlock out = null;
if(opcode.equals("inverse"))
- out = computeMatrixInverse(matrixInput);
-
- return out;
+ return computeMatrixInverse(matrixInput);
+ else if (opcode.equals("cholesky"))
+ return computeCholesky(matrixInput);
+ return null;
}
- public static MatrixBlock[] multiReturnOperations(MatrixObject in, String opcode) throws DMLRuntimeException {
- MatrixBlock[] out = null;
+ public static MatrixBlock[] multiReturnOperations(MatrixObject in, String opcode)
+ throws DMLRuntimeException
+ {
if(opcode.equals("qr"))
- out = computeQR(in);
+ return computeQR(in);
else if (opcode.equals("lu"))
- out = computeLU(in);
+ return computeLU(in);
else if (opcode.equals("eigen"))
- out = computeEigen(in);
- return out;
+ return computeEigen(in);
+ return null;
}
- public static MatrixBlock matrixMatrixOperations(MatrixObject in1, MatrixObject in2, String opcode) throws DMLRuntimeException {
- MatrixBlock out = null;
+ public static MatrixBlock matrixMatrixOperations(MatrixObject in1, MatrixObject in2, String opcode)
+ throws DMLRuntimeException
+ {
if(opcode.equals("solve"))
- out = computeSolve(in1, in2);
- return out;
+ return computeSolve(in1, in2);
+ return null;
}
/**
@@ -104,7 +94,9 @@ public class LibCommonsMath
* @return
* @throws DMLRuntimeException
*/
- private static MatrixBlock computeSolve(MatrixObject in1, MatrixObject in2) throws DMLRuntimeException {
+ private static MatrixBlock computeSolve(MatrixObject in1, MatrixObject in2)
+ throws DMLRuntimeException
+ {
Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in1);
Array2DRowRealMatrix vectorInput = DataConverter.convertToArray2DRowRealMatrix(in2);
@@ -118,9 +110,7 @@ public class LibCommonsMath
// Invoke solve
RealMatrix solutionMatrix = solver.solve(vectorInput);
- MatrixBlock solution = DataConverter.convertToMatrixBlock(solutionMatrix.getData());
-
- return solution;
+ return DataConverter.convertToMatrixBlock(solutionMatrix.getData());
}
/**
@@ -130,7 +120,9 @@ public class LibCommonsMath
* @return
* @throws DMLRuntimeException
*/
- private static MatrixBlock[] computeQR(MatrixObject in) throws DMLRuntimeException {
+ private static MatrixBlock[] computeQR(MatrixObject in)
+ throws DMLRuntimeException
+ {
Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
// Perform QR decomposition
@@ -152,7 +144,9 @@ public class LibCommonsMath
* @return
* @throws DMLRuntimeException
*/
- private static MatrixBlock[] computeLU(MatrixObject in) throws DMLRuntimeException {
+ private static MatrixBlock[] computeLU(MatrixObject in)
+ throws DMLRuntimeException
+ {
if ( in.getNumRows() != in.getNumColumns() ) {
throw new DMLRuntimeException("LU Decomposition can only be done on a square matrix. Input matrix is rectangular (rows=" + in.getNumRows() + ", cols="+ in.getNumColumns() +")");
}
@@ -181,7 +175,9 @@ public class LibCommonsMath
* @return
* @throws DMLRuntimeException
*/
- private static MatrixBlock[] computeEigen(MatrixObject in) throws DMLRuntimeException {
+ private static MatrixBlock[] computeEigen(MatrixObject in)
+ throws DMLRuntimeException
+ {
if ( in.getNumRows() != in.getNumColumns() ) {
throw new DMLRuntimeException("Eigen Decomposition can only be done on a square matrix. Input matrix is rectangular (rows=" + in.getNumRows() + ", cols="+ in.getNumColumns() +")");
}
@@ -228,8 +224,9 @@ public class LibCommonsMath
* @return
* @throws DMLRuntimeException
*/
- private static MatrixBlock computeMatrixInverse(Array2DRowRealMatrix in) throws DMLRuntimeException {
-
+ private static MatrixBlock computeMatrixInverse(Array2DRowRealMatrix in)
+ throws DMLRuntimeException
+ {
if ( !in.isSquare() )
throw new DMLRuntimeException("Input to inv() must be square matrix -- given: a " + in.getRowDimension() + "x" + in.getColumnDimension() + " matrix.");
@@ -237,10 +234,26 @@ public class LibCommonsMath
DecompositionSolver solver = qrdecompose.getSolver();
RealMatrix inverseMatrix = solver.getInverse();
- MatrixBlock inverse = DataConverter.convertToMatrixBlock(inverseMatrix.getData());
-
- return inverse;
+ return DataConverter.convertToMatrixBlock(inverseMatrix.getData());
}
-}
+ /**
+ * Function to compute Cholesky decomposition of the given input matrix.
+ * The input must be a real symmetric positive-definite matrix.
+ *
+ * @param in
+ * @return
+ * @throws DMLRuntimeException
+ */
+ private static MatrixBlock computeCholesky(Array2DRowRealMatrix in)
+ throws DMLRuntimeException
+ {
+ if ( !in.isSquare() )
+ throw new DMLRuntimeException("Input to cholesky() must be square matrix -- given: a " + in.getRowDimension() + "x" + in.getColumnDimension() + " matrix.");
+ CholeskyDecomposition cholesky = new CholeskyDecomposition(in);
+ RealMatrix rmL = cholesky.getL();
+
+ return DataConverter.convertToMatrixBlock(rmL.getData());
+ }
+}