You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/08/01 02:17:26 UTC
[1/4] incubator-systemml git commit: [SYSTEMML-824] Performance
scalar/unary/binary ops (nnz, sparse, skip)
Repository: incubator-systemml
Updated Branches:
refs/heads/master 7f3327658 -> 11a85775f
[SYSTEMML-824] Performance scalar/unary/binary ops (nnz, sparse, skip)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/a528f5e4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/a528f5e4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/a528f5e4
Branch: refs/heads/master
Commit: a528f5e4ee3e4717ff0dfe27850533b94043c8fe
Parents: 7f33276
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Fri Jul 29 23:22:39 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Jul 30 16:23:17 2016 -0700
----------------------------------------------------------------------
.../runtime/matrix/data/LibMatrixBincell.java | 140 ++++++++++---------
.../sysml/runtime/matrix/data/MatrixBlock.java | 56 ++++++--
.../matrix/operators/ScalarOperator.java | 32 ++---
.../runtime/matrix/operators/UnaryOperator.java | 10 +-
4 files changed, 126 insertions(+), 112 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a528f5e4/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
index 6ad08d8..f9269e8 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
@@ -29,11 +29,13 @@ import org.apache.sysml.runtime.functionobjects.GreaterThanEquals;
import org.apache.sysml.runtime.functionobjects.LessThan;
import org.apache.sysml.runtime.functionobjects.LessThanEquals;
import org.apache.sysml.runtime.functionobjects.Minus;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Multiply2;
import org.apache.sysml.runtime.functionobjects.NotEquals;
import org.apache.sysml.runtime.functionobjects.Or;
import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
import org.apache.sysml.runtime.functionobjects.Power2;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
@@ -332,7 +334,8 @@ public class LibMatrixBincell
}
}
else if( !ret.sparse && (m1.sparse || m2.sparse) &&
- (op.fn instanceof Plus || op.fn instanceof Minus ||
+ (op.fn instanceof Plus || op.fn instanceof Minus ||
+ op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply ||
(op.fn instanceof Multiply && !m2.sparse )))
{
//specific case in order to prevent binary search on sparse inputs (see quickget and quickset)
@@ -464,7 +467,8 @@ public class LibMatrixBincell
double[] a = m1.denseBlock;
double[] b = m2.denseBlock;
double[] c = ret.denseBlock;
-
+ int nnz = 0;
+
if( atype == BinaryAccessType.MATRIX_COL_VECTOR )
{
for( int i=0, ix=0; i<rlen; i++, ix+=clen )
@@ -474,33 +478,39 @@ public class LibMatrixBincell
if( skipEmpty && v2 == 0 ) //skip empty rows
continue;
- if( isMultiply && v2 == 1 ) //ROW COPY
- {
+ if( isMultiply && v2 == 1 ) { //ROW COPY
//a guaranteed to be non-null (see early abort)
System.arraycopy(a, ix, c, ix, clen);
+ nnz += m1.recomputeNonZeros(i, i, 0, clen-1);
}
- else //GENERAL CASE
- {
+ else { //GENERAL CASE
if( a != null )
- for( int j=0; j<clen; j++ )
+ for( int j=0; j<clen; j++ ) {
c[ix+j] = op.fn.execute( a[ix+j], v2 );
- else
- Arrays.fill(c, ix, ix+clen, op.fn.execute( 0, v2 ));
+ nnz += (c[ix+j] != 0) ? 1 : 0;
+ }
+ else {
+ double val = op.fn.execute( 0, v2 );
+ Arrays.fill(c, ix, ix+clen, val);
+ nnz += (val != 0) ? clen : 0;
+ }
}
}
}
else if( atype == BinaryAccessType.MATRIX_ROW_VECTOR )
{
- if( a==null && b==null ) //both empty
- {
+ if( a==null && b==null ) { //both empty
double v = op.fn.execute( 0, 0 );
Arrays.fill(c, 0, rlen*clen, v);
+ nnz += (v != 0) ? rlen*clen : 0;
}
else if( a==null ) //left empty
{
//compute first row
- for( int j=0; j<clen; j++ )
+ for( int j=0; j<clen; j++ ) {
c[j] = op.fn.execute( 0, b[j] );
+ nnz += (c[j] != 0) ? rlen : 0;
+ }
//copy first to all other rows
for( int i=1, ix=clen; i<rlen; i++, ix+=clen )
System.arraycopy(c, 0, c, ix, clen);
@@ -508,12 +518,14 @@ public class LibMatrixBincell
else //default case (incl right empty)
{
for( int i=0, ix=0; i<rlen; i++, ix+=clen )
- for( int j=0; j<clen; j++ )
+ for( int j=0; j<clen; j++ ) {
c[ix+j] = op.fn.execute( a[ix+j], ((b!=null) ? b[j] : 0) );
+ nnz += (c[ix+j] != 0) ? 1 : 0;
+ }
}
}
- ret.recomputeNonZeros();
+ ret.nonZeros = nnz;
}
/**
@@ -608,7 +620,7 @@ public class LibMatrixBincell
for( int j=apos; j<apos+alen; j++ )
{
//empty left
- for( int k = lastIx+1; k<aix[j]; k++ ){
+ for( int k=lastIx+1; !skipEmpty&&k<aix[j]; k++ ){
double v2 = m2.quickGetValue(0, k);
double v = op.fn.execute( 0, v2 );
ret.appendValue(i, k, v);
@@ -622,7 +634,7 @@ public class LibMatrixBincell
}
//empty left
- for( int k = lastIx+1; k<clen; k++ ){
+ for( int k=lastIx+1; !skipEmpty&&k<clen; k++ ){
double v2 = m2.quickGetValue(0, k);
double v = op.fn.execute( 0, v2 );
ret.appendValue(i, k, v);
@@ -945,43 +957,43 @@ public class LibMatrixBincell
ret.allocateSparseRowsBlock();
SparseBlock a = m1.sparseBlock;
SparseBlock c = ret.sparseBlock;
+ int rlen = Math.min(m1.rlen, a.numRows());
- for(int r=0; r<Math.min(m1.rlen, a.numRows()); r++) {
- if( !a.isEmpty(r) )
- {
- int apos = a.pos(r);
- int alen = a.size(r);
- int[] aix = a.indexes(r);
- double[] avals = a.values(r);
+ long nnz = 0;
+ for(int r=0; r<rlen; r++) {
+ if( a.isEmpty(r) ) continue;
+
+ int apos = a.pos(r);
+ int alen = a.size(r);
+ int[] aix = a.indexes(r);
+ double[] avals = a.values(r);
+
+ if( copyOnes ) { //SPECIAL CASE: e.g., (X != 0)
+ //create sparse row without repeated resizing
+ SparseRow crow = new SparseRow(alen);
+ crow.setSize(alen);
- if( copyOnes ) //SPECIAL CASE: e.g., (X != 0)
- {
- //create sparse row without repeated resizing
- SparseRow crow = new SparseRow(alen);
- crow.setSize(alen);
-
- //memcopy/memset of indexes/values (sparseblock guarantees absence of 0s)
- System.arraycopy(aix, apos, crow.indexes(), 0, alen);
- Arrays.fill(crow.values(), 0, alen, 1);
- c.set(r, crow, false);
- ret.nonZeros+=alen;
+ //memcopy/memset of indexes/values (sparseblock guarantees absence of 0s)
+ System.arraycopy(aix, apos, crow.indexes(), 0, alen);
+ Arrays.fill(crow.values(), 0, alen, 1);
+ c.set(r, crow, false);
+ nnz += alen;
+ }
+ else { //GENERAL CASE
+ //create sparse row without repeated resizing for specific ops
+ if( op.fn instanceof Multiply || op.fn instanceof Multiply2
+ || op.fn instanceof Power2 ) {
+ c.allocate(r, alen);
}
- else //GENERAL CASE
- {
- //create sparse row without repeated resizing for specific ops
- if( op.fn instanceof Multiply || op.fn instanceof Multiply2
- || op.fn instanceof Power2 )
- {
- c.allocate(r, alen);
- }
-
- for(int j=apos; j<apos+alen; j++) {
- double val = op.executeScalar(avals[j]);
- ret.appendValue(r, aix[j], val);
- }
+
+ for(int j=apos; j<apos+alen; j++) {
+ double val = op.executeScalar(avals[j]);
+ c.append(r, aix[j], val);
+ nnz += (val != 0) ? 1 : 0;
}
}
}
+ ret.nonZeros = nnz;
}
else //DENSE <- DENSE
{
@@ -992,12 +1004,12 @@ public class LibMatrixBincell
double[] c = ret.denseBlock;
int limit = m1.rlen*m1.clen;
- for( int i=0; i<limit; i++ )
- {
+ int nnz = 0;
+ for( int i=0; i<limit; i++ ) {
c[i] = op.executeScalar( a[i] );
- if( c[i] != 0 )
- ret.nonZeros++;
+ nnz += (c[i] != 0) ? 1 : 0;
}
+ ret.nonZeros = nnz;
}
}
@@ -1040,8 +1052,8 @@ public class LibMatrixBincell
Arrays.fill(c, cval0);
//compute non-zero input values
- for(int i=0, cix=0; i<m; i++, cix+=n)
- {
+ int nnz = m*n;
+ for(int i=0, cix=0; i<m; i++, cix+=n) {
if( !a.isEmpty(i) ) {
int apos = a.pos(i);
int alen = a.size(i);
@@ -1050,12 +1062,11 @@ public class LibMatrixBincell
for(int j=apos; j<apos+alen; j++) {
double val = op.executeScalar(avals[j]);
c[ cix+aix[j] ] = val;
+ nnz -= (val==0) ? 1 : 0;
}
}
}
-
- //recompute non zeros
- ret.recomputeNonZeros();
+ ret.nonZeros = nnz;
}
else //DENSE MATRIX
{
@@ -1067,12 +1078,12 @@ public class LibMatrixBincell
//compute scalar operation, incl nnz maintenance
int limit = m1.rlen*m1.clen;
- for( int i=0; i<limit; i++ )
- {
+ int nnz = 0;
+ for( int i=0; i<limit; i++ ) {
c[i] = op.executeScalar( a[i] );
- if( c[i] != 0 )
- ret.nonZeros++;
+ nnz += (c[i] != 0) ? 1 : 0;
}
+ ret.nonZeros = nnz;
}
}
@@ -1270,21 +1281,18 @@ public class LibMatrixBincell
while( p1<size1 && p2< size2 )
{
double value = 0;
- if(cols1[pos1+p1]<cols2[pos2+p2])
- {
+ if(cols1[pos1+p1]<cols2[pos2+p2]) {
value = op.fn.execute(values1[pos1+p1], 0);
column = cols1[pos1+p1];
p1++;
}
- else if(cols1[pos1+p1]==cols2[pos2+p2])
- {
+ else if(cols1[pos1+p1]==cols2[pos2+p2]) {
value = op.fn.execute(values1[pos1+p1], values2[pos2+p2]);
column = cols1[pos1+p1];
p1++;
p2++;
}
- else
- {
+ else {
value = op.fn.execute(0, values2[pos2+p2]);
column = cols2[pos2+p2];
p2++;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a528f5e4/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index 452eddb..5884768 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -2901,24 +2901,48 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
final int m = rlen;
final int n = clen;
- if( sparse ) //SPARSE <- SPARSE
+ if( sparse && ret.sparse ) //SPARSE <- SPARSE
{
+ ret.allocateSparseRowsBlock();
SparseBlock a = sparseBlock;
+ SparseBlock c = ret.sparseBlock;
+
+ long nnz = 0;
+ for(int i=0; i<m; i++) {
+ if( a.isEmpty(i) ) continue;
+
+ int apos = a.pos(i);
+ int alen = a.size(i);
+ int[] aix = a.indexes(i);
+ double[] avals = a.values(i);
+
+ c.allocate(i, alen); //avoid repeated alloc
+ for( int j=apos; j<apos+alen; j++ ) {
+ double val = op.fn.execute(avals[j]);
+ c.append(i, aix[j], val);
+ nnz += (val != 0) ? 1 : 0;
+ }
+ }
+ ret.nonZeros = nnz;
+ }
+ else if( sparse ) //DENSE <- SPARSE
+ {
+ SparseBlock a = sparseBlock;
for(int i=0; i<m; i++) {
- if( !a.isEmpty(i) )
- {
- int apos = a.pos(i);
- int alen = a.size(i);
- int[] aix = a.indexes(i);
- double[] avals = a.values(i);
-
- for( int j=apos; j<apos+alen; j++ ) {
- double val = op.fn.execute(avals[j]);
- ret.appendValue(i, aix[j], val);
- }
+ if( a.isEmpty(i) ) continue;
+
+ int apos = a.pos(i);
+ int alen = a.size(i);
+ int[] aix = a.indexes(i);
+ double[] avals = a.values(i);
+
+ for( int j=apos; j<apos+alen; j++ ) {
+ double val = op.fn.execute(avals[j]);
+ ret.appendValue(i, aix[j], val);
}
}
+ //nnz maintained on appendValue
}
else //DENSE <- DENSE
{
@@ -2926,13 +2950,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
ret.allocateDenseBlock();
double[] a = denseBlock;
double[] c = ret.denseBlock;
+ int len = m * n;
//unary op, incl nnz maintenance
- int len = m*n;
+ int nnz = 0;
for( int i=0; i<len; i++ ) {
c[i] = op.fn.execute(a[i]);
- ret.nonZeros += (c[i] != 0) ? 1 : 0;
- }
+ nnz += (c[i] != 0) ? 1 : 0;
+ }
+ ret.nonZeros = nnz;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a528f5e4/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java
index 4a45569..75891d3 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java
@@ -52,32 +52,23 @@ public class ScalarOperator extends Operator
//as long as (0 op v)=0, then op is sparsesafe
//note: additional functionobjects might qualify according to constant
- if( fn instanceof Multiply || fn instanceof Multiply2
- || fn instanceof Power || fn instanceof Power2
- || fn instanceof And || fn instanceof MinusNz
- || (fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.LOG_NZ))
- {
- sparseSafe=true;
- }
- else
- {
- sparseSafe=false;
- }
+ sparseSafe = (fn instanceof Multiply || fn instanceof Multiply2
+ || fn instanceof Power || fn instanceof Power2
+ || fn instanceof And || fn instanceof MinusNz
+ || (fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.LOG_NZ));
}
- public double getConstant()
- {
+ public double getConstant() {
return _constant;
}
- public void setConstant(double cst)
- {
+ public void setConstant(double cst) {
//set constant
_constant = cst;
//revisit sparse safe decision according to known constant
//note: there would be even more potential if we take left/right op into account
- if( fn instanceof Multiply || fn instanceof Multiply2
+ sparseSafe = ( fn instanceof Multiply || fn instanceof Multiply2
|| fn instanceof Power || fn instanceof Power2
|| fn instanceof And || fn instanceof MinusNz
|| fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.LOG_NZ
@@ -88,14 +79,7 @@ public class ScalarOperator extends Operator
|| (fn instanceof Minus && _constant==0)
|| (fn instanceof Minus && _constant==0)
|| (fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.MAX && _constant<=0)
- || (fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.MIN && _constant>=0))
- {
- sparseSafe = true;
- }
- else
- {
- sparseSafe = false;
- }
+ || (fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.MIN && _constant>=0));
}
public double executeScalar(double in) throws DMLRuntimeException {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a528f5e4/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
index 2f65fb4..a736c8b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
@@ -40,17 +40,13 @@ public class UnaryOperator extends Operator
sparseSafe = false;
k = numThreads;
- if(fn instanceof Builtin)
- {
+ if( fn instanceof Builtin ) {
Builtin f=(Builtin)fn;
- if(f.bFunc==Builtin.BuiltinFunctionCode.SIN || f.bFunc==Builtin.BuiltinFunctionCode.TAN
+ sparseSafe = (f.bFunc==Builtin.BuiltinFunctionCode.SIN || f.bFunc==Builtin.BuiltinFunctionCode.TAN
|| f.bFunc==Builtin.BuiltinFunctionCode.ROUND || f.bFunc==Builtin.BuiltinFunctionCode.ABS
|| f.bFunc==Builtin.BuiltinFunctionCode.SQRT || f.bFunc==Builtin.BuiltinFunctionCode.SPROP
|| f.bFunc==Builtin.BuiltinFunctionCode.SELP || f.bFunc==Builtin.BuiltinFunctionCode.LOG_NZ
- || f.bFunc==Builtin.BuiltinFunctionCode.SIGN )
- {
- sparseSafe = true;
- }
+ || f.bFunc==Builtin.BuiltinFunctionCode.SIGN );
}
}
[4/4] incubator-systemml git commit: [SYSTEMML-833] Additional
cleanup rewrites (unnecess. cast, reorg, agg)
Posted by mb...@apache.org.
[SYSTEMML-833] Additional cleanup rewrites (unnecess. cast, reorg, agg)
This patch adds various additional cleanup rewrites in order to simplify
debugging. In detail, this includes:
(1) Removal of unnecessary data type casts (e.g., as.scalar(as.matrix(X)) -> X)
(2) Removal of unnecessary reorg operations (e.g., t(X) -> X, iff X has 1x1 dims)
(3) Removal of unnecessary aggregation (e.g., sum(X) -> as.scalar(X), iff X has 1x1 dims)
(4) Pushdown of scalar casts (e.g., as.scalar(X*s) -> as.scalar(X)*s)
Note that these rewrites enable each other; e.g., once (2), (3), and (4)
are performed, the unnecessary casts (1) can be removed, avoiding long chains
of unnecessary operations like sum(t(as.matrix(t(X))*7)) ->
as.scalar(X)*7.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/11a85775
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/11a85775
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/11a85775
Branch: refs/heads/master
Commit: 11a85775f11e4490d957fe4f9fab4bfd8ea7a138
Parents: 461184a
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Jul 30 23:48:51 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sun Jul 31 19:16:59 2016 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/HopRewriteUtils.java | 3 +-
.../sysml/hops/rewrite/ProgramRewriter.java | 11 ++--
.../RewriteAlgebraicSimplificationDynamic.java | 67 ++++++++++++++++----
.../RewriteAlgebraicSimplificationStatic.java | 47 ++++++++++++++
.../rewrite/RewriteRemoveUnnecessaryCasts.java | 21 +++++-
.../cp/ArithmeticBinaryCPInstruction.java | 14 ++--
6 files changed, 134 insertions(+), 29 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index f7e4656..3bfdcb5 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -512,7 +512,8 @@ public class HopRewriteUtils
public static UnaryOp createUnary(Hop input, OpOp1 type)
throws HopsException
{
- UnaryOp unary = new UnaryOp(input.getName(), input.getDataType(), input.getValueType(), type, input);
+ DataType dt = (type==OpOp1.CAST_AS_SCALAR) ? DataType.SCALAR : input.getDataType();
+ UnaryOp unary = new UnaryOp(input.getName(), dt, input.getValueType(), type, input);
HopRewriteUtils.setOutputBlocksizes(unary, input.getRowsInBlock(), input.getColsInBlock());
HopRewriteUtils.copyLineNumbers(input, unary);
unary.refreshSizeInformation();
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 8e645dc..e7b03c4 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -130,12 +130,13 @@ public class ProgramRewriter
_dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse
_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
}
-
- //reapply cse after rewrites because (1) applied rewrites on operators w/ multiple parents, and
- //(2) newly introduced operators potentially created redundancy (incl leaf merge to allow for cse)
- if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
- _dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); //dependency: simplifications
}
+
+ // cleanup after all rewrites applied
+ // (newly introduced operators, introduced redundancy after rewrites w/ multiple parents)
+ _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
+ if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
+ _dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) );
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 10953f5..793bc25 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -69,6 +69,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//valid aggregation operation types for empty (sparse-safe) operations (not all operations apply)
//AggOp.MEAN currently not due to missing count/corrections
private static AggOp[] LOOKUP_VALID_EMPTY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE};
+ private static AggOp[] LOOKUP_VALID_UNNECESSARY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE};
//valid unary operation types for empty (sparse-safe) operations (not all operations apply)
private static OpOp1[] LOOKUP_VALID_EMPTY_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.SIN, OpOp1.TAN, OpOp1.SQRT, OpOp1.ROUND, OpOp1.CUMSUM};
@@ -149,13 +150,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hi = removeUnnecessaryLeftIndexing(hop, hi, i); //e.g., X[,1]=Y -> Y, if output == input dims
hi = fuseLeftIndexingChainToAppend(hop, hi, i); //e.g., X[,1]=A; X[,2]=B -> X=cbind(A,B), iff ncol(X)==2 and col1/2 lix
hi = removeUnnecessaryCumulativeOp(hop, hi, i); //e.g., cumsum(X) -> X, if nrow(X)==1;
- hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., matrix(X) -> X, if output == input dims
+ hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., matrix(X) -> X, if dims(in)==dims(out); r(X)->X, if 1x1 dims
hi = removeUnnecessaryOuterProduct(hop, hi, i); //e.g., X*(Y%*%matrix(1,...) -> X*Y, if Y col vector
hi = fuseDatagenAndReorgOperation(hop, hi, i); //e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1
hi = simplifyColwiseAggregate(hop, hi, i); //e.g., colsums(X) -> sum(X) or X, if col/row vector
hi = simplifyRowwiseAggregate(hop, hi, i); //e.g., rowsums(X) -> sum(X) or X, if row/col vector
hi = simplifyColSumsMVMult(hop, hi, i); //e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector
hi = simplifyRowSumsMVMult(hop, hi, i); //e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector
+ hi = simplifyUnnecessaryAggregate(hop, hi, i); //e.g., sum(X) -> as.scalar(X), if 1x1 dims
hi = simplifyEmptyAggregate(hop, hi, i); //e.g., sum(X) -> 0, if nnz(X)==0
hi = simplifyEmptyUnaryOperation(hop, hi, i); //e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0
hi = simplifyEmptyReorgOperation(hop, hi, i); //e.g., t(X) -> matrix(0, ncol(X), nrow(X))
@@ -428,22 +430,26 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
*/
private Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
{
- if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp() == ReOrgOp.RESHAPE ) //reshape operation
+ if( hi instanceof ReorgOp )
{
+ ReorgOp rop = (ReorgOp) hi;
Hop input = hi.getInput().get(0);
-
- if( HopRewriteUtils.isEqualSize(hi, input) ) //equal dims
- {
- //equal dims of reshape input and output -> no need for reshape because
- //byrow always refers to both input/output and hence gives the same result
-
- //remove unnecessary right indexing
- HopRewriteUtils.removeChildReference(parent, hi);
+ boolean apply = false;
+
+ //equal dims of reshape input and output -> no need for reshape because
+ //byrow always refers to both input/output and hence gives the same result
+ apply |= (rop.getOp()==ReOrgOp.RESHAPE && HopRewriteUtils.isEqualSize(hi, input));
+
+ //1x1 dimensions of transpose/reshape -> no need for reorg
+ apply |= ((rop.getOp()==ReOrgOp.TRANSPOSE || rop.getOp()==ReOrgOp.RESHAPE)
+ && rop.getDim1()==1 && rop.getDim2()==1);
+
+ if( apply ) {
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
HopRewriteUtils.addChildReference(parent, input, pos);
parent.refreshSizeInformation();
hi = input;
-
- LOG.debug("Applied removeUnnecessaryReshape");
+ LOG.debug("Applied removeUnnecessaryReorg.");
}
}
@@ -841,6 +847,43 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
* @return
* @throws HopsException
*/
+ private Hop simplifyUnnecessaryAggregate(Hop parent, Hop hi, int pos)
+ throws HopsException
+ {
+ //e.g., sum(X) -> as.scalar(X) if 1x1 (applies to sum, min, max, prod, trace)
+ if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol )
+ {
+ AggUnaryOp uhi = (AggUnaryOp)hi;
+ Hop input = uhi.getInput().get(0);
+
+ if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_UNNECESSARY_AGGREGATE) ){
+
+ if( input.getDim1()==1 && input.getDim2()==1 )
+ {
+ UnaryOp cast = HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_SCALAR);
+
+ //remove unnecessary aggregation
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.addChildReference(parent, cast, pos);
+ parent.refreshSizeInformation();
+ hi = cast;
+
+ LOG.debug("Applied simplifyUnncessaryAggregate");
+ }
+ }
+ }
+
+ return hi;
+ }
+
+ /**
+ *
+ * @param parent
+ * @param hi
+ * @param pos
+ * @return
+ * @throws HopsException
+ */
private Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos)
throws HopsException
{
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index f23686c..784d678 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -144,6 +144,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
hi = simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X)
+ hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X))
hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lamda*X) -> lamda*sum(X)
hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
@@ -890,6 +891,52 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
* @param hi
* @param pos
* @return
+ * @throws HopsException
+ */
+ private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos )
+ throws HopsException
+ {
+ if( hi instanceof UnaryOp && ((UnaryOp)hi).getOp()==OpOp1.CAST_AS_SCALAR
+ && hi.getInput().get(0) instanceof BinaryOp )
+ {
+ BinaryOp bin = (BinaryOp) hi.getInput().get(0);
+ BinaryOp bout = null;
+
+ //as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
+ if( bin.getInput().get(0).getDataType()==DataType.MATRIX
+ && bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
+ UnaryOp cast1 = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
+ UnaryOp cast2 = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
+ bout = HopRewriteUtils.createBinary(cast1, cast2, bin.getOp());
+ }
+ //as.scalar(X*s) -> as.scalar(X) * s
+ else if( bin.getInput().get(0).getDataType()==DataType.MATRIX ) {
+ UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
+ bout = HopRewriteUtils.createBinary(cast, bin.getInput().get(1), bin.getOp());
+ }
+ //as.scalar(s*X) -> s * as.scalar(X)
+ else if ( bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
+ UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
+ bout = HopRewriteUtils.createBinary(bin.getInput().get(0), cast, bin.getOp());
+ }
+
+ if( bout != null ) {
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.addChildReference(parent, bout, pos);
+
+ LOG.debug("Applied simplifyBinaryMatrixScalarOperation.");
+ }
+ }
+
+ return hi;
+ }
+
+ /**
+ *
+ * @param parent
+ * @param hi
+ * @param pos
+ * @return
*/
private Hop pushdownUnaryAggTransposeOperation( Hop parent, Hop hi, int pos )
{
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
index 36d8712..a8001f8 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.Hop.VisitStatus;
import org.apache.sysml.hops.UnaryOp;
@@ -73,6 +74,7 @@ public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule
*
* @param hop
*/
+ @SuppressWarnings("unchecked")
private void rule_RemoveUnnecessaryCasts( Hop hop )
{
//check mark processed
@@ -84,7 +86,7 @@ public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule
for( int i=0; i<inputs.size(); i++ )
rule_RemoveUnnecessaryCasts( inputs.get(i) );
- //remove cast if unnecessary
+ //remove unnecessary value type cast
if( hop instanceof UnaryOp && HopRewriteUtils.isValueTypeCast(((UnaryOp)hop).getOp()) )
{
Hop in = hop.getInput().get(0);
@@ -116,6 +118,23 @@ public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule
}
}
+ //remove unnecessary data type casts
+ if( hop instanceof UnaryOp && hop.getInput().get(0) instanceof UnaryOp ) {
+ UnaryOp uop1 = (UnaryOp) hop;
+ UnaryOp uop2 = (UnaryOp) hop.getInput().get(0);
+ if( (uop1.getOp()==OpOp1.CAST_AS_MATRIX && uop2.getOp()==OpOp1.CAST_AS_SCALAR)
+ || (uop1.getOp()==OpOp1.CAST_AS_SCALAR && uop2.getOp()==OpOp1.CAST_AS_MATRIX) ) {
+ Hop input = uop2.getInput().get(0);
+ //rewire parents
+ ArrayList<Hop> parents = (ArrayList<Hop>) hop.getParent().clone();
+ for( Hop p : parents ) {
+ int ix = HopRewriteUtils.getChildReferencePos(p, hop);
+ HopRewriteUtils.removeChildReference(p, hop);
+ HopRewriteUtils.addChildReference(p, input, ix);
+ }
+ }
+ }
+
//mark processed
hop.setVisited( VisitStatus.DONE );
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
index 38ba9dd..c9545ac 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
@@ -60,16 +60,10 @@ public abstract class ArithmeticBinaryCPInstruction extends BinaryCPInstruction
//make sure these checks belong here
//if either input is a matrix, then output
//has to be a matrix
- if((dt1 == DataType.MATRIX
- || dt2 == DataType.MATRIX)
- && dt3 != DataType.MATRIX)
- throw new DMLRuntimeException("Element-wise matrix operations between variables "
- + in1.getName()
- + " and "
- + in2.getName()
- + " must produce a matrix, which "
- + out.getName()
- + "is not");
+ if((dt1 == DataType.MATRIX || dt2 == DataType.MATRIX) && dt3 != DataType.MATRIX) {
+ throw new DMLRuntimeException("Element-wise matrix operations between variables " + in1.getName() +
+ " and " + in2.getName() + " must produce a matrix, which " + out.getName() + "is not");
+ }
Operator operator = (dt1 != dt2) ?
InstructionUtils.parseScalarBinaryOperator(opcode, (dt1 == DataType.SCALAR)) :
[3/4] incubator-systemml git commit: [SYSTEMML-694] Improved
binary-unary rewrites (added max(0, X)->selp(X))
Posted by mb...@apache.org.
[SYSTEMML-694] Improved binary-unary rewrites (added max(0,X)->selp(X))
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/461184aa
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/461184aa
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/461184aa
Branch: refs/heads/master
Commit: 461184aa5576bef29b7561d8dfc1d6de37bcff83
Parents: 382df65
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Jul 30 16:23:01 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Jul 30 18:56:46 2016 -0700
----------------------------------------------------------------------
.../RewriteAlgebraicSimplificationStatic.java | 21 +++++++++++++++++++-
1 file changed, 20 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/461184aa/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index ae9c073..f23686c 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1184,6 +1184,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
//by definition, either left or right or none applies.
//note: if there are multiple consumers on the intermediate tmp=(X>0), it's still beneficial
//to replace the X*tmp with selp(X) due to lower memory requirements and simply sparsity propagation
+ boolean applied = false;
if( left instanceof BinaryOp ) //(X>0)*X
{
@@ -1206,11 +1207,12 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
HopRewriteUtils.removeAllChildReferences(left);
hi = unary;
+ applied = true;
LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp1");
}
}
- if( right instanceof BinaryOp ) //X*(X>0)
+ if( !applied && right instanceof BinaryOp ) //X*(X>0)
{
BinaryOp bright = (BinaryOp)right;
Hop right1 = bright.getInput().get(0);
@@ -1231,6 +1233,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
HopRewriteUtils.removeAllChildReferences(right);
hi = unary;
+ applied= true;
LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp2");
}
@@ -1252,6 +1255,22 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp3");
}
+
+ //select positive (selp) operator; pattern: max(0,X) -> selp+
+ if( bop.getOp() == OpOp2.MAX && right.getDataType()==DataType.MATRIX
+ && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==0 )
+ {
+ UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP);
+ HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos);
+ HopRewriteUtils.addChildReference(parent, unary, pos);
+
+ //cleanup if only consumer of intermediate
+ if( bop.getParent().isEmpty() )
+ HopRewriteUtils.removeAllChildReferences(bop);
+ hi = unary;
+
+ LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp4");
+ }
}
return hi;
[2/4] incubator-systemml git commit: [SYSTEMML-694] Improved wdivmm
rewrites (outer-product-like mm only)
Posted by mb...@apache.org.
[SYSTEMML-694] Improved wdivmm rewrites (outer-product-like mm only)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/382df653
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/382df653
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/382df653
Branch: refs/heads/master
Commit: 382df653a5438124b1cdff8c676a55b9c0d50976
Parents: a528f5e
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Fri Jul 29 23:25:01 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Jul 30 16:23:21 2016 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/HopRewriteUtils.java | 11 +++++++++
.../RewriteAlgebraicSimplificationDynamic.java | 24 +++++++++++---------
2 files changed, 24 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/382df653/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 385a888..f7e4656 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -785,6 +785,17 @@ public class HopRewriteUtils
: (hop.getDim1()>0 && hop.getDim1()<=hop.getRowsInBlock());
}
+ /**
+ * Checks if the hop is an AggBinaryOp whose first input has dim1 > dim2 and whose second input has dim1 < dim2.
+ * @param hop the hop to test; for an AggBinaryOp, its first two inputs are read via getInput()
+ * @return true only if all three conditions (AggBinaryOp, and both dimension comparisons) hold
+ */
+ public static boolean isOuterProductLikeMM( Hop hop ) {
+ return hop instanceof AggBinaryOp
+ && hop.getInput().get(0).getDim1() > hop.getInput().get(0).getDim2()
+ && hop.getInput().get(1).getDim1() < hop.getInput().get(1).getDim2();
+ }
+
public static boolean isEqualValue( LiteralOp hop1, LiteralOp hop2 )
throws HopsException
{
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/382df653/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index dbde506..10953f5 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -1852,7 +1852,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//alternative pattern: t(U) %*% (W*(U%*%t(V)))
if( right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv
- && right.getInput().get(1) instanceof AggBinaryOp
+ && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1))
&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
{
Hop W = right.getInput().get(0);
@@ -1886,6 +1886,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
&& right.getInput().get(1) instanceof BinaryOp
&& ((BinaryOp) right.getInput().get(1)).getOp() == Hop.OpOp2.PLUS
&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
+ && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
{
Hop W = right.getInput().get(0);
@@ -1917,7 +1918,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( !appliedPattern
&& left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv
- && left.getInput().get(1) instanceof AggBinaryOp
+ && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1))
&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
{
Hop W = left.getInput().get(0);
@@ -1948,6 +1949,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
&& left.getInput().get(1) instanceof BinaryOp
&& ((BinaryOp) left.getInput().get(1)).getOp() == Hop.OpOp2.PLUS
&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
+ && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
{
Hop W = left.getInput().get(0);
@@ -1975,8 +1977,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( !appliedPattern
&& right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
&& right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS
- && right.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
- && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+ && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
+ && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
{
Hop W = right.getInput().get(0);
@@ -2008,8 +2010,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( !appliedPattern
&& left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
&& left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS
- && left.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
- && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+ && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
+ && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
{
Hop W = left.getInput().get(0);
@@ -2038,8 +2040,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( !appliedPattern
&& right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
&& right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS
- && right.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
- && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+ && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
+ && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
{
Hop W = right.getInput().get(0);
@@ -2071,8 +2073,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( !appliedPattern
&& left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
&& left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS
- && left.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
- && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+ && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
+ && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
{
Hop W = left.getInput().get(0);
@@ -2105,7 +2107,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
&& hi.getDim2() > 1 //not applied for vector-vector mult
&& hi.getInput().get(0).getDataType() == DataType.MATRIX
&& hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
- && hi.getInput().get(1) instanceof AggBinaryOp
+ && HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1))
&& (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
&& HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
{