You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/07/17 00:10:42 UTC
incubator-systemml git commit: [SYSTEMML-766] New simplification
rewrite/runtime axpy (+*, -*)
Repository: incubator-systemml
Updated Branches:
refs/heads/master 01d9fdb45 -> b0d3c6c85
[SYSTEMML-766] New simplification rewrite/runtime axpy (+*, -*)
Closes #179.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b0d3c6c8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b0d3c6c8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b0d3c6c8
Branch: refs/heads/master
Commit: b0d3c6c85135c51177dbe67c4f944b1bf7dcc498
Parents: 01d9fdb
Author: tgamal <ta...@gmail.com>
Authored: Sat Jul 16 17:08:50 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Jul 16 17:08:50 2016 -0700
----------------------------------------------------------------------
src/main/java/org/apache/sysml/hops/Hop.java | 4 +-
.../java/org/apache/sysml/hops/TernaryOp.java | 36 +++-
.../RewriteAlgebraicSimplificationStatic.java | 45 ++++-
src/main/java/org/apache/sysml/lops/Lop.java | 1 +
.../java/org/apache/sysml/lops/PlusMult.java | 107 ++++++++++++
.../runtime/functionobjects/MinusMultiply.java | 43 +++++
.../runtime/functionobjects/PlusMultiply.java | 43 +++++
.../ValueFunctionWithConstant.java | 38 +++++
.../instructions/CPInstructionParser.java | 12 +-
.../instructions/SPInstructionParser.java | 16 +-
.../instructions/cp/PlusMultCPInstruction.java | 64 ++++++++
.../spark/PlusMultSPInstruction.java | 87 ++++++++++
.../misc/RewriteFuseBinaryOpChainTest.java | 164 +++++++++++++++++++
.../RewriteSimplifyRowColSumMVMultTest.java | 8 +-
.../misc/RewriteFuseBinaryOpChainTest1.R | 28 ++++
.../misc/RewriteFuseBinaryOpChainTest1.dml | 27 +++
.../misc/RewriteFuseBinaryOpChainTest2.R | 28 ++++
.../misc/RewriteFuseBinaryOpChainTest2.dml | 27 +++
.../functions/misc/ZPackageSuite.java | 2 +
19 files changed, 764 insertions(+), 16 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 8c4999e..144ca20 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1064,7 +1064,7 @@ public abstract class Hop
// Operations that require 3 operands
public enum OpOp3 {
- QUANTILE, INTERQUANTILE, CTABLE, CENTRALMOMENT, COVARIANCE, INVALID
+ QUANTILE, INTERQUANTILE, CTABLE, CENTRALMOMENT, COVARIANCE, INVALID, PLUS_MULT, MINUS_MULT
+ //PLUS_MULT ("+*"): X + s*Y, MINUS_MULT ("-*"): X - s*Y, with scalar s and matrices X, Y
};
// Operations that require 4 operands
@@ -1416,6 +1416,8 @@ public abstract class Hop
HopsOpOp3String.put(OpOp3.CTABLE, "ctable");
HopsOpOp3String.put(OpOp3.CENTRALMOMENT, "cm");
HopsOpOp3String.put(OpOp3.COVARIANCE, "cov");
+ HopsOpOp3String.put(OpOp3.PLUS_MULT, "+*");
+ HopsOpOp3String.put(OpOp3.MINUS_MULT, "-*");
}
protected static final HashMap<Hop.OpOp4, String> HopsOpOp4String;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/hops/TernaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/TernaryOp.java b/src/main/java/org/apache/sysml/hops/TernaryOp.java
index e353273..72e7624 100644
--- a/src/main/java/org/apache/sysml/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/TernaryOp.java
@@ -30,6 +30,7 @@ import org.apache.sysml.lops.Group;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.PickByCount;
+import org.apache.sysml.lops.PlusMult;
import org.apache.sysml.lops.SortKeys;
import org.apache.sysml.lops.Ternary;
import org.apache.sysml.lops.UnaryCP;
@@ -138,6 +139,11 @@ public class TernaryOp extends Hop
case CTABLE:
constructLopsCtable();
break;
+
+ case PLUS_MULT:
+ case MINUS_MULT:
+ constructLopsPlusMult();
+ break;
default:
throw new HopsException(this.printErrorLocation() + "Unknown TernaryOp (" + _op + ") while constructing Lops \n");
@@ -621,7 +627,16 @@ public class TernaryOp extends Hop
}
}
}
-
+
+	/**
+	 * Constructs the fused PlusMult lop for X +* s Y (PLUS_MULT) or X -* s Y (MINUS_MULT).
+	 * The hop inputs are consumed in the order (matrix X, scalar s, matrix Y).
+	 *
+	 * @throws HopsException if this hop is neither a PLUS_MULT nor a MINUS_MULT operation
+	 * @throws LopsException propagated from lop construction
+	 */
+	private void constructLopsPlusMult() throws HopsException, LopsException {
+		if ( _op != OpOp3.PLUS_MULT && _op != OpOp3.MINUS_MULT )
+			throw new HopsException(this.printErrorLocation() + "Unexpected operation: " + _op
+				+ ", expecting " + OpOp3.PLUS_MULT + " or " + OpOp3.MINUS_MULT);
+
+		ExecType et = optFindExecType();
+		//inputs in order (matrix X, scalar s, matrix Y); constructed in input order as before
+		Lop matrix1 = getInput().get(0).constructLops();
+		Lop scalar = getInput().get(1).constructLops();
+		Lop matrix2 = getInput().get(2).constructLops();
+		PlusMult plusmult = new PlusMult(matrix1, scalar, matrix2, _op, getDataType(), getValueType(), et);
+		setOutputDimensions(plusmult);
+		setLineNumbers(plusmult);
+		setLops(plusmult);
+	}
+
@Override
public String getOpString() {
String s = new String("");
@@ -667,7 +682,10 @@ public class TernaryOp extends Hop
// This part of the code is executed only when a vector of quantiles are computed
// Output is a vector of length = #of quantiles to be computed, and it is likely to be dense.
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
-
+ case PLUS_MULT:
+ case MINUS_MULT:
+ sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
+ return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
default:
throw new RuntimeException("Memory for operation (" + _op + ") can not be estimated.");
}
@@ -742,7 +760,12 @@ public class TernaryOp extends Hop
if( mc[2].dimsKnown() )
return new long[]{mc[2].getRows(), 1, mc[2].getRows()};
break;
-
+ case PLUS_MULT:
+ case MINUS_MULT:
+ //compute back NNz
+ double sp1 = OptimizerUtils.getSparsity(mc[0].getRows(), mc[0].getRows(), mc[0].getNonZeros());
+ double sp2 = OptimizerUtils.getSparsity(mc[2].getRows(), mc[2].getRows(), mc[2].getNonZeros());
+ return new long[]{mc[0].getRows(), mc[0].getCols(), (long) Math.min(sp1+sp2,1)};
default:
throw new RuntimeException("Memory for operation (" + _op + ") can not be estimated.");
}
@@ -845,7 +868,12 @@ public class TernaryOp extends Hop
// Output is a vector of length = #of quantiles to be computed, and it is likely to be dense.
// TODO qx1
break;
-
+
+ case PLUS_MULT:
+ case MINUS_MULT:
+ setDim1( getInput().get(0)._dim1 );
+ setDim2( getInput().get(0)._dim2 );
+ break;
default:
throw new RuntimeException("Size information for operation (" + _op + ") can not be updated.");
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index e903a03..43d5791 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -25,6 +25,8 @@ import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
@@ -32,6 +34,7 @@ import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.IndexingOp;
+import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.Hop.AggOp;
@@ -162,8 +165,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)
hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
-
-
+ hi = fuseBinaryOperationChain(hop, hi, i); //e.g., X + lamda*Y -> X +* lambda Y
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
@@ -174,7 +176,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hop.setVisited(Hop.VisitStatus.DONE);
}
-
/**
*
* @param hi
@@ -1908,4 +1909,42 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
return hi;
}
+
+	/**
+	 * Fuses a scalar-matrix multiplication into an enclosing matrix addition/subtraction:
+	 * X + s*Y -> X +* s Y and X - s*Y -> X -* s Y (likewise for Y*s), where X and Y are
+	 * matrices and s is a scalar. Applied only for single-node or Spark execution.
+	 *
+	 * @param parent parent hop of hi
+	 * @param hi current hop, the candidate outer plus/minus
+	 * @param pos position of hi within the inputs of parent
+	 * @return the new fused ternary hop if the rewrite applied, otherwise hi
+	 */
+	private Hop fuseBinaryOperationChain(Hop parent, Hop hi, int pos) {
+		//pattern: X + lambda*Y -> X +* lambda Y, X - lambda*Y -> X -* lambda Y
+		if( hi instanceof BinaryOp
+			&& (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS)
+			&& hi.getInput().get(0).getDataType()==DataType.MATRIX
+			&& hi.getInput().get(1) instanceof BinaryOp
+			//the inner operation must be a multiplication; otherwise, e.g., X + Y/lambda
+			//would be rewritten into the semantically different X + lambda*Y
+			&& ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT
+			&& (DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE || OptimizerUtils.isSparkExecutionMode()) )
+		{
+			BinaryOp outer = (BinaryOp) hi;
+			Hop left = outer.getInput().get(0);
+			Hop inner = outer.getInput().get(1);
+			Hop in1 = inner.getInput().get(0);
+			Hop in2 = inner.getInput().get(1);
+
+			//check that the inner multiplication is scalar*matrix or matrix*scalar
+			boolean scalarFirst = in1.getDataType()==DataType.SCALAR && in2.getDataType()==DataType.MATRIX;
+			boolean scalarSecond = in1.getDataType()==DataType.MATRIX && in2.getDataType()==DataType.SCALAR;
+			if( scalarFirst || scalarSecond )
+			{
+				Hop lambda = scalarFirst ? in1 : in2;
+				Hop matrix = scalarFirst ? in2 : in1;
+				OpOp3 operator = (outer.getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT;
+				TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, operator, left, lambda, matrix);
+
+				HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+				HopRewriteUtils.addChildReference(parent, ternOp, pos);
+
+				LOG.debug("Applied fuseBinaryOperationChain. (line " +hi.getBeginLine()+")");
+				return ternOp;
+			}
+		}
+		return hi;
+	}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/lops/Lop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java
index d930da0..412e6d4 100644
--- a/src/main/java/org/apache/sysml/lops/Lop.java
+++ b/src/main/java/org/apache/sysml/lops/Lop.java
@@ -59,6 +59,7 @@ public abstract class Lop
WeightedSquaredLoss, WeightedSigmoid, WeightedDivMM, WeightedCeMM, WeightedUMM,
SortKeys, PickValues,
Checkpoint, //Spark persist into storage level
+ PlusMult, MinusMult, //CP
};
/**
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/lops/PlusMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/PlusMult.java b/src/main/java/org/apache/sysml/lops/PlusMult.java
new file mode 100644
index 0000000..2dc16e9
--- /dev/null
+++ b/src/main/java/org/apache/sysml/lops/PlusMult.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.lops;
+
+import org.apache.sysml.hops.Hop.OpOp3;
+import org.apache.sysml.lops.LopProperties.ExecLocation;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.lops.compile.JobType;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.parser.Expression.*;
+
+
+/**
+ * Lop for the fused operations X +* s Y (= X + s*Y) and X -* s Y (= X - s*Y),
+ * i.e., the sum/difference of a matrix and another matrix multiplied by a scalar.
+ * Only CP and SPARK execution types are supported.
+ */
+public class PlusMult extends Lop
+{
+	/**
+	 * Wires the three inputs and sets the lop properties for the given exec type.
+	 *
+	 * @throws LopsException if et is neither CP nor SPARK
+	 */
+	private void init(Lop input1, Lop input2, Lop input3, ExecType et)
+		throws LopsException
+	{
+		this.addInput(input1);
+		this.addInput(input2);
+		this.addInput(input3);
+		input1.addOutput(this);
+		input2.addOutput(this);
+		input3.addOutput(this);
+
+		boolean breaksAlignment = false;
+		boolean aligner = false;
+		boolean definesMRJob = false;
+
+		if ( et == ExecType.CP || et == ExecType.SPARK ){
+			lps.addCompatibility(JobType.INVALID);
+			this.lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+		}
+		else {
+			//only CP and Spark instructions exist for +*/-*; fail early instead of
+			//leaving this lop without execution properties
+			throw new LopsException("Unsupported execution type for " + getPlusMultOpcode() + ": " + et);
+		}
+	}
+
+	/**
+	 * @param input1 lop producing the first matrix (X)
+	 * @param input2 lop producing the scalar (s)
+	 * @param input3 lop producing the second matrix (Y)
+	 * @param op PLUS_MULT or MINUS_MULT; any other value is treated as PLUS_MULT
+	 * @param dt output data type
+	 * @param vt output value type
+	 * @param et execution type, CP or SPARK
+	 * @throws LopsException if et is neither CP nor SPARK
+	 */
+	public PlusMult(Lop input1, Lop input2, Lop input3, OpOp3 op, DataType dt, ValueType vt, ExecType et)
+		throws LopsException
+	{
+		super(Lop.Type.PlusMult, dt, vt);
+		if( op == OpOp3.MINUS_MULT )
+			type = Lop.Type.MinusMult;
+		init(input1, input2, input3, et);
+	}
+
+	/** Returns the runtime opcode of this lop: "+*" for PlusMult, "-*" otherwise. */
+	private String getPlusMultOpcode() {
+		return (type == Lop.Type.PlusMult) ? "+*" : "-*";
+	}
+
+	@Override
+	public String toString() {
+		return "Operation = " + ((type == Lop.Type.PlusMult) ? "PlusMult" : "MinusMult");
+	}
+
+	/**
+	 * Generates the instruction for X +* s Y or X -* s Y: exec type, opcode, matrix X,
+	 * scalar s, matrix Y, and output, separated by OPERAND_DELIMITOR.
+	 *
+	 * @param input1 name of the first matrix (X)
+	 * @param input2 name of the scalar (s)
+	 * @param input3 name of the second matrix (Y)
+	 * @param output name of the output
+	 * @return the instruction string
+	 */
+	@Override
+	public String getInstructions(String input1, String input2, String input3, String output) {
+		StringBuilder sb = new StringBuilder();
+		sb.append( getExecType() );
+		sb.append( OPERAND_DELIMITOR );
+		sb.append( getPlusMultOpcode() );
+		sb.append( OPERAND_DELIMITOR );
+
+		// Matrix1 (X)
+		sb.append( getInputs().get(0).prepInputOperand(input1) );
+		sb.append( OPERAND_DELIMITOR );
+
+		// Scalar (s)
+		sb.append( getInputs().get(1).prepScalarInputOperand(input2) );
+		sb.append( OPERAND_DELIMITOR );
+
+		// Matrix2 (Y)
+		sb.append( getInputs().get(2).prepInputOperand(input3) );
+		sb.append( OPERAND_DELIMITOR );
+
+		sb.append( prepOutputOperand(output) );
+
+		return sb.toString();
+	}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
new file mode 100644
index 0000000..ee7a8fb
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.functionobjects;
+
+import java.io.Serializable;
+
+/**
+ * Function object computing {@code in1 - c*in2}, where the scalar constant c is set via
+ * {@link ValueFunctionWithConstant#setConstant(double)}. Used for the fused "-*" operation.
+ */
+public class MinusMultiply extends ValueFunctionWithConstant implements Serializable
+{
+	private static final long serialVersionUID = 2801982061205871665L;
+
+	public MinusMultiply() {
+		// nothing to do here; the constant is assigned via setConstant
+	}
+
+	@Override
+	public Object clone() throws CloneNotSupportedException {
+		// cloning is not supported for this function object
+		throw new CloneNotSupportedException();
+	}
+
+	/** Returns {@code in1 - c*in2} for the currently set constant c. */
+	@Override
+	public double execute(double in1, double in2)
+	{
+		return in1 - _constant*in2;
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
new file mode 100644
index 0000000..87eb47b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.functionobjects;
+
+import java.io.Serializable;
+
+/**
+ * Function object computing {@code in1 + c*in2}, where the scalar constant c is set via
+ * {@link ValueFunctionWithConstant#setConstant(double)}. Used for the fused "+*" operation.
+ */
+public class PlusMultiply extends ValueFunctionWithConstant implements Serializable
+{
+	private static final long serialVersionUID = 2801982061205871665L;
+
+	public PlusMultiply() {
+		// nothing to do here; the constant is assigned via setConstant
+	}
+
+	@Override
+	public Object clone() throws CloneNotSupportedException {
+		// cloning is not supported for this function object
+		throw new CloneNotSupportedException();
+	}
+
+	/** Returns {@code in1 + c*in2} for the currently set constant c. */
+	@Override
+	public double execute(double in1, double in2)
+	{
+		return in1 + _constant*in2;
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
new file mode 100644
index 0000000..2820875
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.functionobjects;
+
+import java.io.Serializable;
+
+/**
+ * Base class for value functions that are parameterized by one scalar constant,
+ * e.g., the s in X +* s Y.
+ */
+public abstract class ValueFunctionWithConstant extends ValueFunction implements Serializable
+{
+	private static final long serialVersionUID = -4985988545393861058L;
+
+	/** scalar constant read by subclasses in their execute implementations */
+	protected double _constant;
+
+	/** @param constant the scalar constant to use from now on */
+	public void setConstant(double constant) {
+		this._constant = constant;
+	}
+
+	/** @return the currently set scalar constant */
+	public double getConstant() {
+		return this._constant;
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index f3a1621..c91ad8c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -51,6 +51,7 @@ import org.apache.sysml.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
import org.apache.sysml.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysml.runtime.instructions.cp.PMMJCPInstruction;
import org.apache.sysml.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction;
import org.apache.sysml.runtime.instructions.cp.QuantilePickCPInstruction;
import org.apache.sysml.runtime.instructions.cp.QuantileSortCPInstruction;
import org.apache.sysml.runtime.instructions.cp.QuaternaryCPInstruction;
@@ -63,6 +64,7 @@ import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE;
import org.apache.sysml.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction;
import org.apache.sysml.runtime.instructions.cpfile.ParameterizedBuiltinCPFileInstruction;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
public class CPInstructionParser extends InstructionParser
{
@@ -120,7 +122,9 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "^2" , CPINSTRUCTION_TYPE.ArithmeticBinary); //special ^ case
String2CPInstructionType.put( "*2" , CPINSTRUCTION_TYPE.ArithmeticBinary); //special * case
String2CPInstructionType.put( "-nz" , CPINSTRUCTION_TYPE.ArithmeticBinary); //special - case
-
+ String2CPInstructionType.put( "+*" , CPINSTRUCTION_TYPE.ArithmeticBinary);
+ String2CPInstructionType.put( "-*" , CPINSTRUCTION_TYPE.ArithmeticBinary);
+
// Boolean Instruction Opcodes
String2CPInstructionType.put( "&&" , CPINSTRUCTION_TYPE.BooleanBinary);
@@ -306,7 +310,11 @@ public class CPInstructionParser extends InstructionParser
return AggregateTernaryCPInstruction.parseInstruction(str);
case ArithmeticBinary:
- return ArithmeticBinaryCPInstruction.parseInstruction(str);
+ String opcode = InstructionUtils.getOpCode(str);
+ if( opcode.equals("+*") || opcode.equals("-*") )
+ return PlusMultCPInstruction.parseInstruction(str);
+ else
+ return ArithmeticBinaryCPInstruction.parseInstruction(str);
case Ternary:
return TernaryCPInstruction.parseInstruction(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index e0c4631..a9a34f5 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -34,6 +34,9 @@ import org.apache.sysml.lops.WeightedSquaredLossR;
import org.apache.sysml.lops.WeightedUnaryMM;
import org.apache.sysml.lops.WeightedUnaryMMR;
import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.cp.ArithmeticBinaryCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE;
import org.apache.sysml.runtime.instructions.spark.AggregateTernarySPInstruction;
import org.apache.sysml.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysml.runtime.instructions.spark.AppendGAlignedSPInstruction;
@@ -59,6 +62,7 @@ import org.apache.sysml.runtime.instructions.spark.MatrixReshapeSPInstruction;
import org.apache.sysml.runtime.instructions.spark.MultiReturnParameterizedBuiltinSPInstruction;
import org.apache.sysml.runtime.instructions.spark.PMapmmSPInstruction;
import org.apache.sysml.runtime.instructions.spark.ParameterizedBuiltinSPInstruction;
+import org.apache.sysml.runtime.instructions.spark.PlusMultSPInstruction;
import org.apache.sysml.runtime.instructions.spark.PmmSPInstruction;
import org.apache.sysml.runtime.instructions.spark.QuantilePickSPInstruction;
import org.apache.sysml.runtime.instructions.spark.QuaternarySPInstruction;
@@ -148,7 +152,9 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "1-*" , SPINSTRUCTION_TYPE.ArithmeticBinary);
String2SPInstructionType.put( "^" , SPINSTRUCTION_TYPE.ArithmeticBinary);
String2SPInstructionType.put( "^2" , SPINSTRUCTION_TYPE.ArithmeticBinary);
- String2SPInstructionType.put( "*2" , SPINSTRUCTION_TYPE.ArithmeticBinary);
+ String2SPInstructionType.put( "*2" , SPINSTRUCTION_TYPE.ArithmeticBinary);
+ String2SPInstructionType.put( "+*" , SPINSTRUCTION_TYPE.ArithmeticBinary);
+ String2SPInstructionType.put( "-*" , SPINSTRUCTION_TYPE.ArithmeticBinary);
String2SPInstructionType.put( "map+" , SPINSTRUCTION_TYPE.ArithmeticBinary);
String2SPInstructionType.put( "map-" , SPINSTRUCTION_TYPE.ArithmeticBinary);
String2SPInstructionType.put( "map*" , SPINSTRUCTION_TYPE.ArithmeticBinary);
@@ -157,6 +163,8 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "map%/%" , SPINSTRUCTION_TYPE.ArithmeticBinary);
String2SPInstructionType.put( "map1-*" , SPINSTRUCTION_TYPE.ArithmeticBinary);
String2SPInstructionType.put( "map^" , SPINSTRUCTION_TYPE.ArithmeticBinary);
+ String2SPInstructionType.put( "map+*" , SPINSTRUCTION_TYPE.ArithmeticBinary);
+ String2SPInstructionType.put( "map-*" , SPINSTRUCTION_TYPE.ArithmeticBinary);
String2SPInstructionType.put( "map>" , SPINSTRUCTION_TYPE.RelationalBinary);
String2SPInstructionType.put( "map>=" , SPINSTRUCTION_TYPE.RelationalBinary);
String2SPInstructionType.put( "map<" , SPINSTRUCTION_TYPE.RelationalBinary);
@@ -326,7 +334,11 @@ public class SPInstructionParser extends InstructionParser
return ReorgSPInstruction.parseInstruction(str);
case ArithmeticBinary:
- return ArithmeticBinarySPInstruction.parseInstruction(str);
+ String opcode = InstructionUtils.getOpCode(str);
+ if( opcode.equals("+*") || opcode.equals("-*") )
+ return PlusMultSPInstruction.parseInstruction(str);
+ else
+ return ArithmeticBinarySPInstruction.parseInstruction(str);
case RelationalBinary:
return RelationalBinarySPInstruction.parseInstruction(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
new file mode 100644
index 0000000..8b01cb7
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
@@ -0,0 +1,64 @@
+package org.apache.sysml.runtime.instructions.cp;
+
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.functionobjects.CM;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
+import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.CMOperator;
+import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
+
+public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction {
+ public PlusMultCPInstruction(BinaryOperator op, CPOperand in1, CPOperand in2,
+ CPOperand in3, CPOperand out, String opcode, String str)
+ {
+ super(op, in1, in2, out, opcode, str);
+ input3=in3;
+ }
+ public static PlusMultCPInstruction parseInstruction(String str)
+ {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode=parts[0];
+ CPOperand operand1 = new CPOperand(parts[1]);
+ CPOperand operand2 = new CPOperand(parts[3]); //put the second matrix (parts[3]) in Operand2 to make using Binary matrix operations easier
+ CPOperand operand3 = new CPOperand(parts[2]);
+ CPOperand outOperand = new CPOperand(parts[4]);
+ BinaryOperator bOperator = null;
+ if(opcode.equals("+*"))
+ bOperator = new BinaryOperator(new PlusMultiply());
+ else if (opcode.equals("-*"))
+ bOperator = new BinaryOperator(new MinusMultiply());
+ return new PlusMultCPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str);
+
+ }
+ @Override
+ public void processInstruction( ExecutionContext ec )
+ throws DMLRuntimeException
+ {
+
+ String output_name = output.getName();
+
+ //get all the inputs
+ MatrixBlock matrix1 = ec.getMatrixInput(input1.getName());
+ MatrixBlock matrix2 = ec.getMatrixInput(input2.getName());
+ ScalarObject lambda = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral());
+
+
+ //execution
+ ((ValueFunctionWithConstant) ((BinaryOperator)_optr).fn).setConstant(lambda.getDoubleValue());
+ MatrixBlock out = (MatrixBlock) matrix1.binaryOperations((BinaryOperator) _optr, matrix2, new MatrixBlock());
+
+ //release the matrices
+ ec.releaseMatrixInput(input1.getName());
+ ec.releaseMatrixInput(input2.getName());
+
+ ec.setMatrixOutput(output_name, out);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
new file mode 100644
index 0000000..89de821
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
+import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ScalarObject;
+import org.apache.sysml.runtime.instructions.spark.functions.MatrixMatrixBinaryOpFunction;
+import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+
+public class PlusMultSPInstruction extends ArithmeticBinarySPInstruction
+{
+ /**
+ * Creates a Spark instruction for the fused operations +* and -*.
+ *
+ * @param op binary operator wrapping a PlusMultiply or MinusMultiply function
+ * @param in1 first matrix operand
+ * @param in2 second matrix operand (parts[3] of the instruction string)
+ * @param in3 scalar operand providing the constant (parts[2] of the instruction string)
+ * @param out output operand
+ * @param opcode instruction opcode, either "+*" or "-*"
+ * @param str full instruction string
+ * @throws DMLRuntimeException if the opcode is neither "+*" nor "-*"
+ */
+ public PlusMultSPInstruction(BinaryOperator op, CPOperand in1, CPOperand in2,
+ CPOperand in3, CPOperand out, String opcode, String str) throws DMLRuntimeException
+ {
+ super(op, in1, in2, out, opcode, str);
+ input3= in3;
+
+ //sanity check opcodes
+ if ( !( opcode.equalsIgnoreCase("+*") || opcode.equalsIgnoreCase("-*") ) )
+ {
+ throw new DMLRuntimeException("Unknown opcode in PlusMultSPInstruction: " + toString());
+ }
+ }
+
+ /**
+ * Parses a +* / -* instruction string into a Spark instruction.
+ *
+ * @param str instruction string (opcode, matrix, scalar, matrix, output)
+ * @return the parsed instruction
+ * @throws DMLRuntimeException if the opcode is neither "+*" nor "-*"
+ */
+ public static PlusMultSPInstruction parseInstruction(String str) throws DMLRuntimeException
+ {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode=parts[0];
+ CPOperand operand1 = new CPOperand(parts[1]);
+ //put the second matrix (parts[3]) in operand2 to reuse the generic matrix-matrix binary path;
+ //the scalar (parts[2]) becomes operand3
+ CPOperand operand2 = new CPOperand(parts[3]);
+ CPOperand operand3 = new CPOperand(parts[2]);
+ CPOperand outOperand = new CPOperand(parts[4]);
+ BinaryOperator bOperator = createOperator(opcode);
+ return new PlusMultSPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str);
+ }
+
+ /**
+ * Creates a new, unshared binary operator for the given fused opcode.
+ *
+ * @param opcode either "+*" or "-*"
+ * @return a binary operator wrapping a fresh PlusMultiply or MinusMultiply function
+ * @throws DMLRuntimeException if the opcode is neither "+*" nor "-*"
+ */
+ private static BinaryOperator createOperator(String opcode) throws DMLRuntimeException
+ {
+ if( opcode.equals("+*") )
+ return new BinaryOperator(new PlusMultiply());
+ else if( opcode.equals("-*") )
+ return new BinaryOperator(new MinusMultiply());
+ throw new DMLRuntimeException("Unknown opcode in PlusMultSPInstruction: " + opcode);
+ }
+
+ /**
+ * Executes the fused operation on RDDs.
+ * Binds the scalar input3 as constant of a freshly created +* / -* operator and
+ * delegates to the matrix-matrix binary code path of the parent class.
+ */
+ @Override
+ public void processInstruction(ExecutionContext ec)
+ throws DMLRuntimeException
+ {
+ SparkExecutionContext sec = (SparkExecutionContext)ec;
+
+ //resolve the scalar constant
+ ScalarObject constant = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral());
+
+ //NOTE: processMatrixMatrixBinaryInstruction presumably hands _optr to a lazily evaluated
+ //RDD transformation. Mutating one shared function object would then retroactively change
+ //the constant of earlier, not yet materialized outputs of this instruction (e.g., in loops).
+ //Hence, bind the constant to a fresh operator on every execution.
+ boolean minus = ((BinaryOperator)_optr).fn instanceof MinusMultiply;
+ BinaryOperator bop = createOperator(minus ? "-*" : "+*");
+ ((ValueFunctionWithConstant) bop.fn).setConstant(constant.getDoubleValue());
+ _optr = bop;
+
+ super.processMatrixMatrixBinaryInstruction(sec);
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
new file mode 100644
index 0000000..e010083
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Tests the simplification rewrite that fuses the binary operator chains
+ * X + lambda*Y and X - lambda*Y into the fused operations +* and -* (axpy).
+ * Runs in CP and Spark, with and without algebraic simplification rewrites,
+ * and compares the DML result against an R reference.
+ */
+public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
+{
+ private static final String TEST_NAME1 = "RewriteFuseBinaryOpChainTest1"; //S = X + lamda*Y
+ private static final String TEST_NAME2 = "RewriteFuseBinaryOpChainTest2"; //S = X - lamda*Y
+
+ private static final String TEST_DIR = "functions/misc/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/";
+
+ private static final double eps = Math.pow(10, -10);
+
+ @Override
+ public void setUp()
+ {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+ addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+ }
+
+ @Test
+ public void testFuseBinaryPlusNoRewrite()
+ {
+ testFuseBinaryChain( TEST_NAME1, false, ExecType.CP );
+ }
+
+ @Test
+ public void testFuseBinaryPlusRewrite()
+ {
+ testFuseBinaryChain( TEST_NAME1, true, ExecType.CP );
+ }
+
+ @Test
+ public void testFuseBinaryMinusNoRewrite()
+ {
+ testFuseBinaryChain( TEST_NAME2, false, ExecType.CP );
+ }
+
+ @Test
+ public void testFuseBinaryMinusRewrite()
+ {
+ testFuseBinaryChain( TEST_NAME2, true, ExecType.CP );
+ }
+
+ @Test
+ public void testSpFuseBinaryPlusNoRewrite()
+ {
+ testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK );
+ }
+
+ @Test
+ public void testSpFuseBinaryPlusRewrite()
+ {
+ testFuseBinaryChain( TEST_NAME1, true, ExecType.SPARK );
+ }
+
+ @Test
+ public void testSpFuseBinaryMinusNoRewrite()
+ {
+ testFuseBinaryChain( TEST_NAME2, false, ExecType.SPARK );
+ }
+
+ @Test
+ public void testSpFuseBinaryMinusRewrite()
+ {
+ testFuseBinaryChain( TEST_NAME2, true, ExecType.SPARK );
+ }
+
+ /**
+ * Runs the given DML test script and its R counterpart and compares the output S.
+ *
+ * @param testname test script name (TEST_NAME1 for X + lamda*Y, TEST_NAME2 for X - lamda*Y)
+ * @param rewrites whether algebraic simplification rewrites are enabled
+ * @param instType execution type, mapped to the corresponding runtime platform
+ */
+ private void testFuseBinaryChain( String testname, boolean rewrites, ExecType instType )
+ {
+ //configure runtime platform and rewrites, remembering the old global settings
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( instType ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ boolean rewritesOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+ try
+ {
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[]{"-explain", "-stats","-args", output("S") };
+
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(inputDir(), expectedDir());
+
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("S");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S");
+ Assert.assertTrue(TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"));
+
+ //TODO(review): with rewrites enabled, also assert that the fused +* / -* opcode was
+ //actually executed (e.g., via the already imported Statistics); otherwise this test
+ //passes even if the rewrite never fires.
+ }
+ finally
+ {
+ //restore the global configuration for subsequent tests
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewritesOld;
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java
index d68d3b7..2829bab 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java
@@ -56,25 +56,25 @@ public class RewriteSimplifyRowColSumMVMultTest extends AutomatedTestBase
}
@Test
- public void testRewriteRowSumsMVMultNoRewrite()
+ public void testMultiScalarToBinaryNoRewrite()
{
testRewriteRowColSumsMVMult( TEST_NAME1, false );
}
@Test
- public void testRewriteRowSumsMVMultRewrite()
+ public void testMultiScalarToBinaryRewrite()
{
testRewriteRowColSumsMVMult( TEST_NAME1, true );
}
@Test
- public void testRewriteColSumsMVMultNoRewrite()
+ public void testMultiBinaryToScalarNoRewrite()
{
testRewriteRowColSumsMVMult( TEST_NAME2, false );
}
@Test
- public void testRewriteColSumsMVMultRewrite()
+ public void testMultiBinaryToScalarRewrite()
{
testRewriteRowColSumsMVMult( TEST_NAME2, true );
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.R b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.R
new file mode 100644
index 0000000..c34948c
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.R
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+# R reference for the +* fusion test: S = X + lambda*Y on two 10x10 all-ones matrices.
+# args[2] is presumably the expected-output directory (second argument of getRCmd).
+lambda <- 7
+X <- matrix(1, 10, 10)
+Y <- matrix(1, 10, 10)
+S <- X + lambda * Y
+writeMM(as(S, "CsparseMatrix"), paste0(args[2], "S"))
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.dml b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.dml
new file mode 100644
index 0000000..077b8a9
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Input for the +* fusion test: S = X + lamda*Y on two 10x10 all-ones matrices.
+X=matrix(1,rows=10,cols=10)
+Y=matrix(1,rows=10,cols=10)
+# NOTE(review): presumably this empty if-statement cuts the script into separate statement
+# blocks so that the datagen ops of X and Y are not part of the DAG of S - TODO confirm
+if(1==1){}
+lamda=7
+S=X+lamda*Y
+# $1 is the output path passed via -args by the Java test
+write(S,$1)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.R b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.R
new file mode 100644
index 0000000..1caff09
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.R
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+# R reference for the -* fusion test: S = X - lambda*Y on two 10x10 all-ones matrices.
+# args[2] is presumably the expected-output directory (second argument of getRCmd).
+lambda <- 7
+X <- matrix(1, 10, 10)
+Y <- matrix(1, 10, 10)
+S <- X - lambda * Y
+writeMM(as(S, "CsparseMatrix"), paste0(args[2], "S"))
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.dml b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.dml
new file mode 100644
index 0000000..f3c6b9a
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Input for the -* fusion test: S = X - lamda*Y on two 10x10 all-ones matrices.
+X=matrix(1,rows=10,cols=10)
+Y=matrix(1,rows=10,cols=10)
+# NOTE(review): presumably this empty if-statement cuts the script into separate statement
+# blocks so that the datagen ops of X and Y are not part of the DAG of S - TODO confirm
+if(1==1){}
+lamda=7
+S=X-lamda*Y
+# $1 is the output path passed via -args by the Java test
+write(S,$1)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index 4720595..6c40dd7 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -49,8 +49,10 @@ import org.junit.runners.Suite;
RewriteFusedRandTest.class,
RewritePushdownSumOnBinaryTest.class,
RewritePushdownUaggTest.class,
+ RewritePushdownSumBinaryMult.class,
RewriteSimplifyRowColSumMVMultTest.class,
RewriteSlicedMatrixMultTest.class,
+ RewriteFuseBinaryOpChainTest.class,
ScalarAssignmentTest.class,
ScalarFunctionTest.class,
ScalarMatrixUnaryBinaryTermTest.class,