You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/07/19 17:59:13 UTC
incubator-systemml git commit: [SYSTEMML-766] Extended axpy
compiler/runtime support (mr, hybrid)
Repository: incubator-systemml
Updated Branches:
refs/heads/master c22f239e3 -> b584aecf6
[SYSTEMML-766] Extended axpy compiler/runtime support (mr, hybrid)
Includes a fix for the 'fused binary operation chain' rewrite (axpy).
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b584aecf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b584aecf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b584aecf
Branch: refs/heads/master
Commit: b584aecf6b3a1eb96ff83b78cc3ad7c7c6d15baa
Parents: c22f239
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Mon Jul 18 19:46:55 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Tue Jul 19 10:58:49 2016 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/BinaryOp.java | 2 +-
.../java/org/apache/sysml/hops/TernaryOp.java | 47 ++++++++++++-
.../RewriteAlgebraicSimplificationStatic.java | 12 ++--
.../java/org/apache/sysml/lops/PlusMult.java | 58 ++++++++++++----
.../runtime/functionobjects/MinusMultiply.java | 18 +++--
.../runtime/functionobjects/PlusMultiply.java | 18 +++--
.../ValueFunctionWithConstant.java | 6 +-
.../runtime/instructions/InstructionUtils.java | 6 ++
.../instructions/MRInstructionParser.java | 7 ++
.../instructions/cp/PlusMultCPInstruction.java | 17 +++--
.../runtime/instructions/mr/MRInstruction.java | 2 +-
.../instructions/mr/PlusMultInstruction.java | 69 ++++++++++++++++++++
.../spark/PlusMultSPInstruction.java | 12 ++--
.../misc/RewriteFuseBinaryOpChainTest.java | 46 +++++++++----
14 files changed, 249 insertions(+), 71 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/hops/BinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java
index 65e9232..edc327d 100644
--- a/src/main/java/org/apache/sysml/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java
@@ -1335,7 +1335,7 @@ public class BinaryOp extends Hop
* @param right
* @return
*/
- private static boolean requiresReplication( Hop left, Hop right )
+ public static boolean requiresReplication( Hop left, Hop right )
{
return (!(left.getDim2()>=1 && right.getDim2()>=1) //cols of any input unknown
||(left.getDim2() > 1 && right.getDim2()==1 && left.getDim2()>=left.getColsInBlock() ) //col MV and more than 1 block
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/hops/TernaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/TernaryOp.java b/src/main/java/org/apache/sysml/hops/TernaryOp.java
index 72e7624..626ad2c 100644
--- a/src/main/java/org/apache/sysml/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/TernaryOp.java
@@ -31,6 +31,7 @@ import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.PickByCount;
import org.apache.sysml.lops.PlusMult;
+import org.apache.sysml.lops.RepMat;
import org.apache.sysml.lops.SortKeys;
import org.apache.sysml.lops.Ternary;
import org.apache.sysml.lops.UnaryCP;
@@ -627,16 +628,58 @@ public class TernaryOp extends Hop
}
}
}
- private void constructLopsPlusMult() throws HopsException, LopsException {
+
+ /**
+ *
+ * @throws HopsException
+ * @throws LopsException
+ */
+ private void constructLopsPlusMult()
+ throws HopsException, LopsException
+ {
if ( _op != OpOp3.PLUS_MULT && _op != OpOp3.MINUS_MULT )
throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.PLUS_MULT + " or" + OpOp3.MINUS_MULT);
ExecType et = optFindExecType();
- PlusMult plusmult = new PlusMult(getInput().get(0).constructLops(),getInput().get(1).constructLops(),getInput().get(2).constructLops(), _op, getDataType(),getValueType(), et );
+ PlusMult plusmult = null;
+
+ if( et == ExecType.CP || et == ExecType.SPARK ) {
+ plusmult = new PlusMult(
+ getInput().get(0).constructLops(),
+ getInput().get(1).constructLops(),
+ getInput().get(2).constructLops(),
+ _op, getDataType(),getValueType(), et );
+ }
+ else { //MR
+ Hop left = getInput().get(0);
+ Hop right = getInput().get(2);
+ boolean requiresRep = BinaryOp.requiresReplication(left, right);
+
+ Lop rightLop = right.constructLops();
+ if( requiresRep ) {
+ Lop offset = createOffsetLop(left, (right.getDim2()<=1)); //ncol of left input (determines num replicates)
+ rightLop = new RepMat(rightLop, offset, (right.getDim2()<=1), right.getDataType(), right.getValueType());
+ setOutputDimensions(rightLop);
+ setLineNumbers(rightLop);
+ }
+
+ Group group1 = new Group(left.constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
+ setLineNumbers(group1);
+ setOutputDimensions(group1);
+
+ Group group2 = new Group(rightLop, Group.OperationTypes.Sort, getDataType(), getValueType());
+ setLineNumbers(group2);
+ setOutputDimensions(group2);
+
+ plusmult = new PlusMult(group1, getInput().get(1).constructLops(),
+ group2, _op, getDataType(),getValueType(), et );
+ }
+
setOutputDimensions(plusmult);
setLineNumbers(plusmult);
setLops(plusmult);
}
+
@Override
public String getOpString() {
String s = new String("");
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 816b55a..9ef2c05 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -25,8 +25,6 @@ import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
@@ -34,7 +32,6 @@ import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.IndexingOp;
-import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.Hop.AggOp;
@@ -1920,10 +1917,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
*/
private Hop fuseBinaryOperationChain(Hop parent, Hop hi, int pos) {
//pattern: X + lamda*Y -> X +* lambda Y
- if( hi instanceof BinaryOp
- && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS)
- && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1) instanceof BinaryOp
- && (DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE || OptimizerUtils.isSparkExecutionMode()) )
+ if( hi instanceof BinaryOp
+ && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS)
+ && hi.getInput().get(0).getDataType()==DataType.MATRIX
+ && hi.getInput().get(1) instanceof BinaryOp
+ && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT )
{
//Check that the inner binary Op is a product of Scalar times Matrix or viceversa
Hop innerBinaryOp = hi.getInput().get(1);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/lops/PlusMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/PlusMult.java b/src/main/java/org/apache/sysml/lops/PlusMult.java
index 65e6440..8ee8625 100644
--- a/src/main/java/org/apache/sysml/lops/PlusMult.java
+++ b/src/main/java/org/apache/sysml/lops/PlusMult.java
@@ -34,9 +34,9 @@ public class PlusMult extends Lop
{
private void init(Lop input1, Lop input2, Lop input3, ExecType et) {
- this.addInput(input1);
- this.addInput(input2);
- this.addInput(input3);
+ addInput(input1);
+ addInput(input2);
+ addInput(input3);
input1.addOutput(this);
input2.addOutput(this);
input3.addOutput(this);
@@ -47,7 +47,13 @@ public class PlusMult extends Lop
if ( et == ExecType.CP || et == ExecType.SPARK ){
lps.addCompatibility(JobType.INVALID);
- this.lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+ lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+ }
+ else if( et == ExecType.MR ) {
+ lps.addCompatibility(JobType.GMR);
+ lps.addCompatibility(JobType.DATAGEN);
+ lps.addCompatibility(JobType.REBLOCK);
+ lps.setProperties( inputs, et, ExecLocation.Reduce, breaksAlignment, aligner, definesMRJob );
}
}
@@ -60,13 +66,15 @@ public class PlusMult extends Lop
@Override
public String toString() {
-
return "Operation = PlusMult";
}
+ public String getOpString() {
+ return (type==Lop.Type.PlusMult) ? "+*" : "-*";
+ }
/**
- * Function to generate CP Sum of a matrix with another matrix multiplied by Scalar.
+ * Function to generate CP/Spark axpy.
*
* input1: matrix1
* input2: Scalar
@@ -75,23 +83,51 @@ public class PlusMult extends Lop
@Override
public String getInstructions(String input1, String input2, String input3, String output) {
StringBuilder sb = new StringBuilder();
+
sb.append( getExecType() );
sb.append( OPERAND_DELIMITOR );
- if(type==Lop.Type.PlusMult)
- sb.append( "+*" );
- else
- sb.append( "-*" );
+
+ sb.append(getOpString());
sb.append( OPERAND_DELIMITOR );
// Matrix1
sb.append( getInputs().get(0).prepInputOperand(input1) );
sb.append( OPERAND_DELIMITOR );
- // Matrix2
+ // Scalar
sb.append( getInputs().get(1).prepScalarInputOperand(input2) );
sb.append( OPERAND_DELIMITOR );
+ // Matrix2
+ sb.append( getInputs().get(2).prepInputOperand(input3));
+ sb.append( OPERAND_DELIMITOR );
+
+ sb.append( prepOutputOperand(output));
+
+ return sb.toString();
+ }
+
+ @Override
+ public String getInstructions(int input1, int input2, int input3, int output)
+ throws LopsException
+ {
+ StringBuilder sb = new StringBuilder();
+
+ sb.append( getExecType() );
+ sb.append( OPERAND_DELIMITOR );
+
+ sb.append(getOpString());
+ sb.append( OPERAND_DELIMITOR );
+
+ // Matrix1
+ sb.append( getInputs().get(0).prepInputOperand(input1) );
+ sb.append( OPERAND_DELIMITOR );
+
// Scalar
+ sb.append( getInputs().get(1).prepScalarLabel() );
+ sb.append( OPERAND_DELIMITOR );
+
+ // Matrix2
sb.append( getInputs().get(2).prepInputOperand(input3));
sb.append( OPERAND_DELIMITOR );
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
index ee7a8fb..2036cf6 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
@@ -23,21 +23,25 @@ import java.io.Serializable;
public class MinusMultiply extends ValueFunctionWithConstant implements Serializable
{
-
private static final long serialVersionUID = 2801982061205871665L;
- public MinusMultiply() {
+ private MinusMultiply() {
// nothing to do here
}
+
+ public static MinusMultiply getMinusMultiplyFnObject() {
+ //create new object as the constant is modified and hence
+ //cannot be shared across multiple threads (e.g., in parfor)
+ return new MinusMultiply();
+ }
+
public Object clone() throws CloneNotSupportedException {
// cloning is not supported for singleton classes
throw new CloneNotSupportedException();
}
+
@Override
- public double execute(double in1, double in2)
- {
- return in1 - _constant*in2;
-
+ public double execute(double in1, double in2) {
+ return in1 - _constant*in2;
}
-
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
index 87eb47b..2a1eea0 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
@@ -23,21 +23,25 @@ import java.io.Serializable;
public class PlusMultiply extends ValueFunctionWithConstant implements Serializable
{
-
private static final long serialVersionUID = 2801982061205871665L;
- public PlusMultiply() {
+ private PlusMultiply() {
// nothing to do here
}
+
+ public static PlusMultiply getPlusMultiplyFnObject() {
+ //create new object as the constant is modified and hence
+ //cannot be shared across multiple threads (e.g., in parfor)
+ return new PlusMultiply();
+ }
+
public Object clone() throws CloneNotSupportedException {
// cloning is not supported for singleton classes
throw new CloneNotSupportedException();
}
+
@Override
- public double execute(double in1, double in2)
- {
- return in1 + _constant*in2;
-
+ public double execute(double in1, double in2) {
+ return in1 + _constant*in2;
}
-
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
index 2820875..f23c29a 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
@@ -26,13 +26,11 @@ public abstract class ValueFunctionWithConstant extends ValueFunction implements
private static final long serialVersionUID = -4985988545393861058L;
protected double _constant;
- public void setConstant(double constant)
- {
+ public void setConstant(double constant) {
_constant = constant;
}
- public double getConstant()
- {
+ public double getConstant() {
return _constant;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
index d2f477d..a3a7c08 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
@@ -56,6 +56,7 @@ import org.apache.sysml.runtime.functionobjects.LessThanEquals;
import org.apache.sysml.runtime.functionobjects.Mean;
import org.apache.sysml.runtime.functionobjects.Minus;
import org.apache.sysml.runtime.functionobjects.Minus1Multiply;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
import org.apache.sysml.runtime.functionobjects.MinusNz;
import org.apache.sysml.runtime.functionobjects.Modulus;
import org.apache.sysml.runtime.functionobjects.Multiply;
@@ -63,6 +64,7 @@ import org.apache.sysml.runtime.functionobjects.Multiply2;
import org.apache.sysml.runtime.functionobjects.NotEquals;
import org.apache.sysml.runtime.functionobjects.Or;
import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
import org.apache.sysml.runtime.functionobjects.Power;
import org.apache.sysml.runtime.functionobjects.Power2;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
@@ -626,6 +628,10 @@ public class InstructionUtils
return new BinaryOperator(Builtin.getBuiltinFnObject("max"));
else if ( opcode.equalsIgnoreCase("min") )
return new BinaryOperator(Builtin.getBuiltinFnObject("min"));
+ else if ( opcode.equalsIgnoreCase("+*") )
+ return new BinaryOperator(PlusMultiply.getPlusMultiplyFnObject());
+ else if ( opcode.equalsIgnoreCase("-*") )
+ return new BinaryOperator(MinusMultiply.getMinusMultiplyFnObject());
throw new DMLRuntimeException("Unknown binary opcode " + opcode);
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
index 894e7e9..0b9cb7d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
@@ -63,6 +63,7 @@ import org.apache.sysml.runtime.instructions.mr.MatrixReshapeMRInstruction;
import org.apache.sysml.runtime.instructions.mr.PMMJMRInstruction;
import org.apache.sysml.runtime.instructions.mr.ParameterizedBuiltinMRInstruction;
import org.apache.sysml.runtime.instructions.mr.PickByCountInstruction;
+import org.apache.sysml.runtime.instructions.mr.PlusMultInstruction;
import org.apache.sysml.runtime.instructions.mr.QuaternaryInstruction;
import org.apache.sysml.runtime.instructions.mr.RandInstruction;
import org.apache.sysml.runtime.instructions.mr.RangeBasedReIndexInstruction;
@@ -182,6 +183,9 @@ public class MRInstructionParser extends InstructionParser
String2MRInstructionType.put( "^2" , MRINSTRUCTION_TYPE.ArithmeticBinary); //special ^ case
String2MRInstructionType.put( "*2" , MRINSTRUCTION_TYPE.ArithmeticBinary); //special * case
String2MRInstructionType.put( "-nz" , MRINSTRUCTION_TYPE.ArithmeticBinary); //special - case
+ String2MRInstructionType.put( "+*" , MRINSTRUCTION_TYPE.ArithmeticBinary2);
+ String2MRInstructionType.put( "-*" , MRINSTRUCTION_TYPE.ArithmeticBinary2);
+
String2MRInstructionType.put( "map+" , MRINSTRUCTION_TYPE.ArithmeticBinary);
String2MRInstructionType.put( "map-" , MRINSTRUCTION_TYPE.ArithmeticBinary);
String2MRInstructionType.put( "map*" , MRINSTRUCTION_TYPE.ArithmeticBinary);
@@ -333,6 +337,9 @@ public class MRInstructionParser extends InstructionParser
}
}
+ case ArithmeticBinary2:
+ return PlusMultInstruction.parseInstruction(str);
+
case AggregateBinary:
return AggregateBinaryInstruction.parseInstruction(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
index 212e0b7..12bc465 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
@@ -28,13 +28,15 @@ import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
-public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction {
+public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction
+{
public PlusMultCPInstruction(BinaryOperator op, CPOperand in1, CPOperand in2,
CPOperand in3, CPOperand out, String opcode, String str)
{
super(op, in1, in2, out, opcode, str);
input3=in3;
}
+
public static PlusMultCPInstruction parseInstruction(String str)
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
@@ -43,14 +45,11 @@ public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction {
CPOperand operand2 = new CPOperand(parts[3]); //put the second matrix (parts[3]) in Operand2 to make using Binary matrix operations easier
CPOperand operand3 = new CPOperand(parts[2]);
CPOperand outOperand = new CPOperand(parts[4]);
- BinaryOperator bOperator = null;
- if(opcode.equals("+*"))
- bOperator = new BinaryOperator(new PlusMultiply());
- else if (opcode.equals("-*"))
- bOperator = new BinaryOperator(new MinusMultiply());
+ BinaryOperator bOperator = new BinaryOperator(opcode.equals("+*") ?
+ PlusMultiply.getPlusMultiplyFnObject():MinusMultiply.getMinusMultiplyFnObject());
return new PlusMultCPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str);
-
}
+
@Override
public void processInstruction( ExecutionContext ec )
throws DMLRuntimeException
@@ -60,10 +59,10 @@ public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction {
//get all the inputs
MatrixBlock matrix1 = ec.getMatrixInput(input1.getName());
MatrixBlock matrix2 = ec.getMatrixInput(input2.getName());
- ScalarObject lambda = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral());
+ ScalarObject scalar = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral());
//execution
- ((ValueFunctionWithConstant) ((BinaryOperator)_optr).fn).setConstant(lambda.getDoubleValue());
+ ((ValueFunctionWithConstant) ((BinaryOperator)_optr).fn).setConstant(scalar.getDoubleValue());
MatrixBlock out = (MatrixBlock) matrix1.binaryOperations((BinaryOperator) _optr, matrix2, new MatrixBlock());
//release the matrices
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java
index ea47e96..62762c1 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java
@@ -31,7 +31,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
public abstract class MRInstruction extends Instruction
{
- public enum MRINSTRUCTION_TYPE { INVALID, Append, Aggregate, ArithmeticBinary, AggregateBinary, AggregateUnary,
+ public enum MRINSTRUCTION_TYPE { INVALID, Append, Aggregate, ArithmeticBinary, ArithmeticBinary2, AggregateBinary, AggregateUnary,
Rand, Seq, CSVReblock, CSVWrite, Transform,
Reblock, Reorg, Replicate, Unary, CombineBinary, CombineUnary, CombineTernary, PickByCount, Partition,
Ternary, Quaternary, CM_N_COV, Combine, MapGroupedAggregate, GroupedAggregate, RangeReIndex, ZeroOut, MMTSJ, PMMJ, MatrixReshape, ParameterizedBuiltin, Sort, MapMultChain,
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java
new file mode 100644
index 0000000..95ae817
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/PlusMultInstruction.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.mr;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue;
+import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
+import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+
+
+public class PlusMultInstruction extends BinaryInstruction
+{
+ public PlusMultInstruction(Operator op, byte in1, byte in2, byte out, String istr) {
+ super(op, in1, in2, out, istr);
+ }
+
+ /**
+ *
+ * @param str
+ * @return
+ * @throws DMLRuntimeException
+ */
+ public static PlusMultInstruction parseInstruction ( String str )
+ throws DMLRuntimeException
+ {
+ InstructionUtils.checkNumFields ( str, 4 );
+
+ String[] parts = InstructionUtils.getInstructionParts ( str );
+ String opcode = parts[0];
+ byte in1 = Byte.parseByte(parts[1]);
+ double scalar = Double.parseDouble(parts[2]);
+ byte in2 = Byte.parseByte(parts[3]);
+ byte out = Byte.parseByte(parts[4]);
+
+ BinaryOperator bop = InstructionUtils.parseBinaryOperator(opcode);
+ ((ValueFunctionWithConstant) bop.fn).setConstant(scalar);
+ return new PlusMultInstruction(bop, in1, in2, out, str);
+ }
+
+ @Override
+ public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues,
+ IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor)
+ throws DMLRuntimeException
+ {
+ //default binary mr instruction execution (custom logic encoded in operator)
+ super.processInstruction(valueClass, cachedValues, tempValue, zeroInput, blockRowFactor, blockColFactor);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
index 4b73679..c93ed0a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
@@ -44,6 +44,7 @@ public class PlusMultSPInstruction extends ArithmeticBinarySPInstruction
throw new DMLRuntimeException("Unknown opcode in PlusMultSPInstruction: " + toString());
}
}
+
public static PlusMultSPInstruction parseInstruction(String str) throws DMLRuntimeException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
@@ -52,15 +53,11 @@ public class PlusMultSPInstruction extends ArithmeticBinarySPInstruction
CPOperand operand2 = new CPOperand(parts[3]); //put the second matrix (parts[3]) in Operand2 to make using Binary matrix operations easier
CPOperand operand3 = new CPOperand(parts[2]);
CPOperand outOperand = new CPOperand(parts[4]);
- BinaryOperator bOperator = null;
- if(opcode.equals("+*"))
- bOperator = new BinaryOperator(new PlusMultiply());
- else if (opcode.equals("-*"))
- bOperator = new BinaryOperator(new MinusMultiply());
+ BinaryOperator bOperator = new BinaryOperator(opcode.equals("+*") ?
+ PlusMultiply.getPlusMultiplyFnObject():MinusMultiply.getMinusMultiplyFnObject());
return new PlusMultSPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str);
}
-
@Override
public void processInstruction(ExecutionContext ec)
throws DMLRuntimeException
@@ -74,5 +71,4 @@ public class PlusMultSPInstruction extends ArithmeticBinarySPInstruction
super.processMatrixMatrixBinaryInstruction(sec);
}
-
-}
\ No newline at end of file
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b584aecf/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
index 7fec6b0..890a3b2 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
@@ -46,8 +46,6 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
private static final String TEST_DIR = "functions/misc/";
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/";
- //private static final int rows = 1234;
- //private static final int cols = 567;
private static final double eps = Math.pow(10, -10);
@Override
@@ -58,44 +56,64 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
}
@Test
- public void testFuseBinaryPlusNoRewrite() {
+ public void testFuseBinaryPlusNoRewriteCP() {
testFuseBinaryChain( TEST_NAME1, false, ExecType.CP );
}
@Test
- public void testFuseBinaryPlusRewrite() {
+ public void testFuseBinaryPlusRewriteCP() {
testFuseBinaryChain( TEST_NAME1, true, ExecType.CP);
}
@Test
- public void testFuseBinaryMinusNoRewrite() {
+ public void testFuseBinaryMinusNoRewriteCP() {
testFuseBinaryChain( TEST_NAME2, false, ExecType.CP );
}
@Test
- public void testFuseBinaryMinusRewrite() {
+ public void testFuseBinaryMinusRewriteCP() {
testFuseBinaryChain( TEST_NAME2, true, ExecType.CP );
}
@Test
- public void testSpFuseBinaryPlusNoRewrite() {
+ public void testFuseBinaryPlusNoRewriteSP() {
testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK );
}
@Test
- public void testSpFuseBinaryPlusRewrite() {
+ public void testFuseBinaryPlusRewriteSP() {
testFuseBinaryChain( TEST_NAME1, true, ExecType.SPARK );
}
@Test
- public void testSpFuseBinaryMinusNoRewrite() {
- testFuseBinaryChain( TEST_NAME2, false, ExecType.SPARK );
+ public void testFuseBinaryMinusNoRewriteSP() {
+ testFuseBinaryChain( TEST_NAME2, false, ExecType.SPARK );
}
@Test
- public void testSpFuseBinaryMinusRewrite() {
- testFuseBinaryChain( TEST_NAME2, true, ExecType.SPARK );
+ public void testFuseBinaryMinusRewriteSP() {
+ testFuseBinaryChain( TEST_NAME2, true, ExecType.SPARK );
+ }
+
+ @Test
+ public void testFuseBinaryPlusNoRewriteMR() {
+ testFuseBinaryChain( TEST_NAME1, false, ExecType.MR );
+ }
+
+ @Test
+ public void testFuseBinaryPlusRewriteMR() {
+ testFuseBinaryChain( TEST_NAME1, true, ExecType.MR );
+ }
+
+ @Test
+ public void testFuseBinaryMinusNoRewriteMR() {
+ testFuseBinaryChain( TEST_NAME2, false, ExecType.MR );
+ }
+
+ @Test
+ public void testFuseBinaryMinusRewriteMR() {
+ testFuseBinaryChain( TEST_NAME2, true, ExecType.MR );
}
@@ -111,7 +129,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
switch( instType ){
case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
- default: rtplatform = RUNTIME_PLATFORM.SINGLE_NODE; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
}
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
@@ -142,7 +160,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
Assert.assertTrue(TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"));
//check for applies rewrites
- if( rewrites ) {
+ if( rewrites && instType!=ExecType.MR ) {
String prefix = (instType==ExecType.SPARK) ? Instruction.SP_INST_PREFIX : "";
Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes()
.contains(testname.equals(TEST_NAME1) ? prefix+"+*" : prefix+"-*" ));