You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/08/01 02:17:26 UTC

[1/4] incubator-systemml git commit: [SYSTEMML-824] Performance scalar/unary/binary ops (nnz, sparse, skip)

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 7f3327658 -> 11a85775f


[SYSTEMML-824] Performance scalar/unary/binary ops (nnz, sparse, skip)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/a528f5e4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/a528f5e4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/a528f5e4

Branch: refs/heads/master
Commit: a528f5e4ee3e4717ff0dfe27850533b94043c8fe
Parents: 7f33276
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Fri Jul 29 23:22:39 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Jul 30 16:23:17 2016 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixBincell.java   | 140 ++++++++++---------
 .../sysml/runtime/matrix/data/MatrixBlock.java  |  56 ++++++--
 .../matrix/operators/ScalarOperator.java        |  32 ++---
 .../runtime/matrix/operators/UnaryOperator.java |  10 +-
 4 files changed, 126 insertions(+), 112 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a528f5e4/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
index 6ad08d8..f9269e8 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java
@@ -29,11 +29,13 @@ import org.apache.sysml.runtime.functionobjects.GreaterThanEquals;
 import org.apache.sysml.runtime.functionobjects.LessThan;
 import org.apache.sysml.runtime.functionobjects.LessThanEquals;
 import org.apache.sysml.runtime.functionobjects.Minus;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
 import org.apache.sysml.runtime.functionobjects.Multiply;
 import org.apache.sysml.runtime.functionobjects.Multiply2;
 import org.apache.sysml.runtime.functionobjects.NotEquals;
 import org.apache.sysml.runtime.functionobjects.Or;
 import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.functionobjects.Power2;
 import org.apache.sysml.runtime.functionobjects.ValueFunction;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
@@ -332,7 +334,8 @@ public class LibMatrixBincell
 				}
 			}
 			else if( !ret.sparse && (m1.sparse || m2.sparse) &&
-					(op.fn instanceof Plus || op.fn instanceof Minus || 
+					(op.fn instanceof Plus || op.fn instanceof Minus ||
+					op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply ||
 					(op.fn instanceof Multiply && !m2.sparse )))
 			{
 				//specific case in order to prevent binary search on sparse inputs (see quickget and quickset)
@@ -464,7 +467,8 @@ public class LibMatrixBincell
 		double[] a = m1.denseBlock;
 		double[] b = m2.denseBlock;
 		double[] c = ret.denseBlock;
-
+		int nnz = 0;
+		
 		if( atype == BinaryAccessType.MATRIX_COL_VECTOR )
 		{
 			for( int i=0, ix=0; i<rlen; i++, ix+=clen )
@@ -474,33 +478,39 @@ public class LibMatrixBincell
 				if( skipEmpty && v2 == 0 ) //skip empty rows
 					continue;
 					
-				if( isMultiply && v2 == 1 ) //ROW COPY
-				{
+				if( isMultiply && v2 == 1 ) { //ROW COPY
 					//a guaranteed to be non-null (see early abort)
 					System.arraycopy(a, ix, c, ix, clen);
+					nnz += m1.recomputeNonZeros(i, i, 0, clen-1);
 				}
-				else //GENERAL CASE
-				{
+				else { //GENERAL CASE
 					if( a != null )
-						for( int j=0; j<clen; j++ )
+						for( int j=0; j<clen; j++ ) {
 							c[ix+j] = op.fn.execute( a[ix+j], v2 );	
-					else
-						Arrays.fill(c, ix, ix+clen, op.fn.execute( 0, v2 ));	
+							nnz += (c[ix+j] != 0) ? 1 : 0;
+						}
+					else {
+						double val = op.fn.execute( 0, v2 );
+						Arrays.fill(c, ix, ix+clen, val);
+						nnz += (val != 0) ? clen : 0;
+					}
 				}
 			}
 		}
 		else if( atype == BinaryAccessType.MATRIX_ROW_VECTOR )
 		{
-			if( a==null && b==null ) //both empty
-			{
+			if( a==null && b==null ) { //both empty
 				double v = op.fn.execute( 0, 0 );
 				Arrays.fill(c, 0, rlen*clen, v);
+				nnz += (v != 0) ? rlen*clen : 0;
 			}
 			else if( a==null ) //left empty
 			{
 				//compute first row
-				for( int j=0; j<clen; j++ )
+				for( int j=0; j<clen; j++ ) {
 					c[j] = op.fn.execute( 0, b[j] );
+					nnz += (c[j] != 0) ? rlen : 0;
+				}
 				//copy first to all other rows
 				for( int i=1, ix=clen; i<rlen; i++, ix+=clen )
 					System.arraycopy(c, 0, c, ix, clen);
@@ -508,12 +518,14 @@ public class LibMatrixBincell
 			else //default case (incl right empty) 
 			{
 				for( int i=0, ix=0; i<rlen; i++, ix+=clen )
-					for( int j=0; j<clen; j++ )
+					for( int j=0; j<clen; j++ ) {
 						c[ix+j] = op.fn.execute( a[ix+j], ((b!=null) ? b[j] : 0) );	
+						nnz += (c[ix+j] != 0) ? 1 : 0;
+					}
 			}
 		}
 		
-		ret.recomputeNonZeros();
+		ret.nonZeros = nnz;
 	}
 	
 	/**
@@ -608,7 +620,7 @@ public class LibMatrixBincell
 					for( int j=apos; j<apos+alen; j++ )
 					{
 						//empty left
-						for( int k = lastIx+1; k<aix[j]; k++ ){
+						for( int k=lastIx+1; !skipEmpty&&k<aix[j]; k++ ){
 							double v2 = m2.quickGetValue(0, k);
 							double v = op.fn.execute( 0, v2 );
 							ret.appendValue(i, k, v);
@@ -622,7 +634,7 @@ public class LibMatrixBincell
 				}
 				
 				//empty left
-				for( int k = lastIx+1; k<clen; k++ ){
+				for( int k=lastIx+1; !skipEmpty&&k<clen; k++ ){
 					double v2 = m2.quickGetValue(0, k);
 					double v = op.fn.execute( 0, v2 );
 					ret.appendValue(i, k, v);
@@ -945,43 +957,43 @@ public class LibMatrixBincell
 			ret.allocateSparseRowsBlock();
 			SparseBlock a = m1.sparseBlock;
 			SparseBlock c = ret.sparseBlock;
+			int rlen = Math.min(m1.rlen, a.numRows());
 			
-			for(int r=0; r<Math.min(m1.rlen, a.numRows()); r++) {
-				if( !a.isEmpty(r) )
-				{
-					int apos = a.pos(r);
-					int alen = a.size(r);
-					int[] aix = a.indexes(r);
-					double[] avals = a.values(r);
+			long nnz = 0;
+			for(int r=0; r<rlen; r++) {
+				if( a.isEmpty(r) ) continue;
+				
+				int apos = a.pos(r);
+				int alen = a.size(r);
+				int[] aix = a.indexes(r);
+				double[] avals = a.values(r);
+				
+				if( copyOnes ) { //SPECIAL CASE: e.g., (X != 0) 
+					//create sparse row without repeated resizing
+					SparseRow crow = new SparseRow(alen);
+					crow.setSize(alen);
 					
-					if( copyOnes ) //SPECIAL CASE: e.g., (X != 0) 
-					{
-						//create sparse row without repeated resizing
-						SparseRow crow = new SparseRow(alen);
-						crow.setSize(alen);
-						
-						//memcopy/memset of indexes/values (sparseblock guarantees absence of 0s) 
-						System.arraycopy(aix, apos, crow.indexes(), 0, alen);
-						Arrays.fill(crow.values(), 0, alen, 1);
-						c.set(r, crow, false);
-						ret.nonZeros+=alen;
+					//memcopy/memset of indexes/values (sparseblock guarantees absence of 0s) 
+					System.arraycopy(aix, apos, crow.indexes(), 0, alen);
+					Arrays.fill(crow.values(), 0, alen, 1);
+					c.set(r, crow, false);
+					nnz += alen;
+				}
+				else { //GENERAL CASE
+					//create sparse row without repeated resizing for specific ops
+					if( op.fn instanceof Multiply || op.fn instanceof Multiply2 
+						|| op.fn instanceof Power2  ) {
+						c.allocate(r, alen);
 					}
-					else //GENERAL CASE
-					{
-						//create sparse row without repeated resizing for specific ops
-						if( op.fn instanceof Multiply || op.fn instanceof Multiply2 
-							|| op.fn instanceof Power2  )
-						{
-							c.allocate(r, alen);
-						}
-						
-						for(int j=apos; j<apos+alen; j++) {
-							double val = op.executeScalar(avals[j]);
-							ret.appendValue(r, aix[j], val);
-						}
+					
+					for(int j=apos; j<apos+alen; j++) {
+						double val = op.executeScalar(avals[j]);
+						c.append(r, aix[j], val);
+						nnz += (val != 0) ? 1 : 0; 
 					}
 				}
 			}
+			ret.nonZeros = nnz;
 		}
 		else //DENSE <- DENSE
 		{
@@ -992,12 +1004,12 @@ public class LibMatrixBincell
 			double[] c = ret.denseBlock;
 			
 			int limit = m1.rlen*m1.clen;
-			for( int i=0; i<limit; i++ )
-			{
+			int nnz = 0;
+			for( int i=0; i<limit; i++ ) {
 				c[i] = op.executeScalar( a[i] );
-				if( c[i] != 0 )
-					ret.nonZeros++;
+				nnz += (c[i] != 0) ? 1 : 0;
 			}
+			ret.nonZeros = nnz;
 		}
 		
 	}
@@ -1040,8 +1052,8 @@ public class LibMatrixBincell
 			Arrays.fill(c, cval0);
 			
 			//compute non-zero input values
-			for(int i=0, cix=0; i<m; i++, cix+=n) 
-			{
+			int nnz = m*n;
+			for(int i=0, cix=0; i<m; i++, cix+=n) {
 				if( !a.isEmpty(i) ) {
 					int apos = a.pos(i);
 					int alen = a.size(i);
@@ -1050,12 +1062,11 @@ public class LibMatrixBincell
 					for(int j=apos; j<apos+alen; j++) {
 						double val = op.executeScalar(avals[j]);
 						c[ cix+aix[j] ] = val;
+						nnz -= (val==0) ? 1 : 0;
 					}
 				}
 			}
-		
-			//recompute non zeros 
-			ret.recomputeNonZeros();
+			ret.nonZeros = nnz;
 		}
 		else //DENSE MATRIX
 		{
@@ -1067,12 +1078,12 @@ public class LibMatrixBincell
 			
 			//compute scalar operation, incl nnz maintenance
 			int limit = m1.rlen*m1.clen;
-			for( int i=0; i<limit; i++ )
-			{
+			int nnz = 0;
+			for( int i=0; i<limit; i++ ) {
 				c[i] = op.executeScalar( a[i] );
-				if( c[i] != 0 )
-					ret.nonZeros++;
+				nnz += (c[i] != 0) ? 1 : 0;
 			}
+			ret.nonZeros = nnz;
 		}
 	}
 
@@ -1270,21 +1281,18 @@ public class LibMatrixBincell
 		while( p1<size1 && p2< size2 )
 		{
 			double value = 0;
-			if(cols1[pos1+p1]<cols2[pos2+p2])
-			{
+			if(cols1[pos1+p1]<cols2[pos2+p2]) {
 				value = op.fn.execute(values1[pos1+p1], 0);
 				column = cols1[pos1+p1];
 				p1++;
 			}
-			else if(cols1[pos1+p1]==cols2[pos2+p2])
-			{
+			else if(cols1[pos1+p1]==cols2[pos2+p2]) {
 				value = op.fn.execute(values1[pos1+p1], values2[pos2+p2]);
 				column = cols1[pos1+p1];
 				p1++;
 				p2++;
 			}
-			else
-			{
+			else {
 				value = op.fn.execute(0, values2[pos2+p2]);
 				column = cols2[pos2+p2];
 				p2++;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a528f5e4/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index 452eddb..5884768 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -2901,24 +2901,48 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		final int m = rlen;
 		final int n = clen;
 		
-		if( sparse ) //SPARSE <- SPARSE
+		if( sparse && ret.sparse ) //SPARSE <- SPARSE
 		{
+			ret.allocateSparseRowsBlock();
 			SparseBlock a = sparseBlock;
+			SparseBlock c = ret.sparseBlock;
+		
+			long nnz = 0;
+			for(int i=0; i<m; i++) {
+				if( a.isEmpty(i) ) continue;
+				
+				int apos = a.pos(i);
+				int alen = a.size(i);
+				int[] aix = a.indexes(i);
+				double[] avals = a.values(i);
+				
+				c.allocate(i, alen); //avoid repeated alloc
+				for( int j=apos; j<apos+alen; j++ ) {
+					double val = op.fn.execute(avals[j]);
+					c.append(i, aix[j], val);
+					nnz += (val != 0) ? 1 : 0;
+				}
+			}
+			ret.nonZeros = nnz;
+		}
+		else if( sparse ) //DENSE <- SPARSE
+		{
+			SparseBlock a = sparseBlock;			
 			
 			for(int i=0; i<m; i++) {
-				if( !a.isEmpty(i) )
-				{
-					int apos = a.pos(i);
-					int alen = a.size(i);
-					int[] aix = a.indexes(i);
-					double[] avals = a.values(i);
-					
-					for( int j=apos; j<apos+alen; j++ ) {
-						double val = op.fn.execute(avals[j]);
-						ret.appendValue(i, aix[j], val);
-					}
+				if( a.isEmpty(i) ) continue;
+			
+				int apos = a.pos(i);
+				int alen = a.size(i);
+				int[] aix = a.indexes(i);
+				double[] avals = a.values(i);
+				
+				for( int j=apos; j<apos+alen; j++ ) {
+					double val = op.fn.execute(avals[j]);
+					ret.appendValue(i, aix[j], val);
 				}
 			}
+			//nnz maintained on appendValue
 		}
 		else //DENSE <- DENSE
 		{
@@ -2926,13 +2950,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 			ret.allocateDenseBlock();						
 			double[] a = denseBlock;
 			double[] c = ret.denseBlock;
+			int len = m * n;
 			
 			//unary op, incl nnz maintenance
-			int len = m*n;
+			int nnz = 0;
 			for( int i=0; i<len; i++ ) {
 				c[i] = op.fn.execute(a[i]);
-				ret.nonZeros += (c[i] != 0) ? 1 : 0;
-			}			
+				nnz += (c[i] != 0) ? 1 : 0;
+			}
+			ret.nonZeros = nnz;
 		}
 	}
 	

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a528f5e4/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java
index 4a45569..75891d3 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java
@@ -52,32 +52,23 @@ public class ScalarOperator  extends Operator
 		
 		//as long as (0 op v)=0, then op is sparsesafe
 		//note: additional functionobjects might qualify according to constant
-		if(   fn instanceof Multiply || fn instanceof Multiply2 
-		   || fn instanceof Power || fn instanceof Power2 
-		   || fn instanceof And || fn instanceof MinusNz
-		   || (fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.LOG_NZ)) 
-		{
-			sparseSafe=true;
-		}
-		else
-		{
-			sparseSafe=false;
-		}
+		sparseSafe = (fn instanceof Multiply || fn instanceof Multiply2 
+				|| fn instanceof Power || fn instanceof Power2 
+				|| fn instanceof And || fn instanceof MinusNz
+				|| (fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.LOG_NZ));
 	}
 	
-	public double getConstant()
-	{
+	public double getConstant() {
 		return _constant;
 	}
 	
-	public void setConstant(double cst) 
-	{
+	public void setConstant(double cst) {
 		//set constant
 		_constant = cst;
 		
 		//revisit sparse safe decision according to known constant
 		//note: there would be even more potential if we take left/right op into account
-		if(    fn instanceof Multiply || fn instanceof Multiply2 
+		sparseSafe = ( fn instanceof Multiply || fn instanceof Multiply2 
 			|| fn instanceof Power || fn instanceof Power2 
 			|| fn instanceof And || fn instanceof MinusNz
 			|| fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.LOG_NZ
@@ -88,14 +79,7 @@ public class ScalarOperator  extends Operator
 			|| (fn instanceof Minus && _constant==0)
 			|| (fn instanceof Minus && _constant==0)
 			|| (fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.MAX && _constant<=0)
-			|| (fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.MIN && _constant>=0))
-		{
-			sparseSafe = true;
-		}
-		else
-		{
-			sparseSafe = false;
-		}
+			|| (fn instanceof Builtin && ((Builtin)fn).getBuiltinFunctionCode()==BuiltinFunctionCode.MIN && _constant>=0));
 	}
 	
 	public double executeScalar(double in) throws DMLRuntimeException {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a528f5e4/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
index 2f65fb4..a736c8b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java
@@ -40,17 +40,13 @@ public class UnaryOperator extends Operator
 		sparseSafe = false;
 		k = numThreads;
 		
-		if(fn instanceof Builtin)
-		{
+		if( fn instanceof Builtin ) {
 			Builtin f=(Builtin)fn;
-			if(f.bFunc==Builtin.BuiltinFunctionCode.SIN || f.bFunc==Builtin.BuiltinFunctionCode.TAN 
+			sparseSafe = (f.bFunc==Builtin.BuiltinFunctionCode.SIN || f.bFunc==Builtin.BuiltinFunctionCode.TAN 
 					|| f.bFunc==Builtin.BuiltinFunctionCode.ROUND || f.bFunc==Builtin.BuiltinFunctionCode.ABS
 					|| f.bFunc==Builtin.BuiltinFunctionCode.SQRT || f.bFunc==Builtin.BuiltinFunctionCode.SPROP
 					|| f.bFunc==Builtin.BuiltinFunctionCode.SELP || f.bFunc==Builtin.BuiltinFunctionCode.LOG_NZ
-					|| f.bFunc==Builtin.BuiltinFunctionCode.SIGN )
-			{
-				sparseSafe = true;
-			}
+					|| f.bFunc==Builtin.BuiltinFunctionCode.SIGN );
 		}
 	}
 	


[4/4] incubator-systemml git commit: [SYSTEMML-833] Additional cleanup rewrites (unnecess. cast, reorg, agg)

Posted by mb...@apache.org.
[SYSTEMML-833] Additional cleanup rewrites (unnecess. cast, reorg, agg)

This patch adds various additional cleanup rewrites in order to simplify
debugging. In detail, this includes:

(1) Unnecessary data type casts (e.g., as.scalar(as.matrix))
(2) Unnecessary reorg operations (e.g., t(X), iff X 1x1 dims)
(3) Unnecessary aggregation (e.g., sum(X) iff X 1x1 dims)
(4) Pushdown of scalar casts (e.g., as.scalar(X*s)->as.scalar(X)*s)

Note that these rewrites enable each other; e.g., once (2), (3), and (4)
are performed, unnecessary casts (1) can be removed, avoiding long chains
of unnecessary operations such as sum(t(as.matrix(t(X))*7)) ->
as.scalar(X)*7.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/11a85775
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/11a85775
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/11a85775

Branch: refs/heads/master
Commit: 11a85775f11e4490d957fe4f9fab4bfd8ea7a138
Parents: 461184a
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Jul 30 23:48:51 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sun Jul 31 19:16:59 2016 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  3 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     | 11 ++--
 .../RewriteAlgebraicSimplificationDynamic.java  | 67 ++++++++++++++++----
 .../RewriteAlgebraicSimplificationStatic.java   | 47 ++++++++++++++
 .../rewrite/RewriteRemoveUnnecessaryCasts.java  | 21 +++++-
 .../cp/ArithmeticBinaryCPInstruction.java       | 14 ++--
 6 files changed, 134 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index f7e4656..3bfdcb5 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -512,7 +512,8 @@ public class HopRewriteUtils
 	public static UnaryOp createUnary(Hop input, OpOp1 type) 
 		throws HopsException
 	{
-		UnaryOp unary = new UnaryOp(input.getName(), input.getDataType(), input.getValueType(), type, input);
+		DataType dt = (type==OpOp1.CAST_AS_SCALAR) ? DataType.SCALAR : input.getDataType();
+		UnaryOp unary = new UnaryOp(input.getName(), dt, input.getValueType(), type, input);
 		HopRewriteUtils.setOutputBlocksizes(unary, input.getRowsInBlock(), input.getColsInBlock());
 		HopRewriteUtils.copyLineNumbers(input, unary);
 		unary.refreshSizeInformation();	

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 8e645dc..e7b03c4 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -130,12 +130,13 @@ public class ProgramRewriter
 				_dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic()      ); //dependencies: cse
 				_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic()       ); //dependencies: cse
 			}
-			
-			//reapply cse after rewrites because (1) applied rewrites on operators w/ multiple parents, and
-			//(2) newly introduced operators potentially created redundancy (incl leaf merge to allow for cse)
-			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )             
-				_dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); //dependency: simplifications 			
 		}
+		
+		// cleanup after all rewrites applied 
+		// (newly introduced operators, introduced redundancy after rewrites w/ multiple parents) 
+		_dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()             );		
+		if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )             
+			_dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); 			
 	}
 	
 	/**

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 10953f5..793bc25 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -69,6 +69,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 	//valid aggregation operation types for empty (sparse-safe) operations (not all operations apply)
 	//AggOp.MEAN currently not due to missing count/corrections
 	private static AggOp[] LOOKUP_VALID_EMPTY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE};
+	private static AggOp[] LOOKUP_VALID_UNNECESSARY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE};
 	
 	//valid unary operation types for empty (sparse-safe) operations (not all operations apply)
 	private static OpOp1[] LOOKUP_VALID_EMPTY_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.SIN, OpOp1.TAN, OpOp1.SQRT, OpOp1.ROUND, OpOp1.CUMSUM}; 
@@ -149,13 +150,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			hi = removeUnnecessaryLeftIndexing(hop, hi, i);   //e.g., X[,1]=Y -> Y, if output == input dims 
 			hi = fuseLeftIndexingChainToAppend(hop, hi, i);   //e.g., X[,1]=A; X[,2]=B -> X=cbind(A,B), iff ncol(X)==2 and col1/2 lix
 			hi = removeUnnecessaryCumulativeOp(hop, hi, i);   //e.g., cumsum(X) -> X, if nrow(X)==1;
-			hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., matrix(X) -> X, if output == input dims
+			hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., matrix(X) -> X, if dims(in)==dims(out); r(X)->X, if 1x1 dims
 			hi = removeUnnecessaryOuterProduct(hop, hi, i);   //e.g., X*(Y%*%matrix(1,...) -> X*Y, if Y col vector
 			hi = fuseDatagenAndReorgOperation(hop, hi, i);    //e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1
 			hi = simplifyColwiseAggregate(hop, hi, i);        //e.g., colsums(X) -> sum(X) or X, if col/row vector
 			hi = simplifyRowwiseAggregate(hop, hi, i);        //e.g., rowsums(X) -> sum(X) or X, if row/col vector
 			hi = simplifyColSumsMVMult(hop, hi, i);           //e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector
 			hi = simplifyRowSumsMVMult(hop, hi, i);           //e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector
+			hi = simplifyUnnecessaryAggregate(hop, hi, i);    //e.g., sum(X) -> as.scalar(X), if 1x1 dims
 			hi = simplifyEmptyAggregate(hop, hi, i);          //e.g., sum(X) -> 0, if nnz(X)==0
 			hi = simplifyEmptyUnaryOperation(hop, hi, i);     //e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0			
 			hi = simplifyEmptyReorgOperation(hop, hi, i);     //e.g., t(X) -> matrix(0, ncol(X), nrow(X)) 
@@ -428,22 +430,26 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 	 */
 	private Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
 	{
-		if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp() == ReOrgOp.RESHAPE ) //reshape operation
+		if( hi instanceof ReorgOp ) 
 		{
+			ReorgOp rop = (ReorgOp) hi;
 			Hop input = hi.getInput().get(0); 
-
-			if( HopRewriteUtils.isEqualSize(hi, input) ) //equal dims
-			{
-				//equal dims of reshape input and output -> no need for reshape because 
-				//byrow always refers to both input/output and hence gives the same result
-				
-				//remove unnecessary right indexing
-				HopRewriteUtils.removeChildReference(parent, hi);				
+			boolean apply = false;
+			
+			//equal dims of reshape input and output -> no need for reshape because 
+			//byrow always refers to both input/output and hence gives the same result
+			apply |= (rop.getOp()==ReOrgOp.RESHAPE && HopRewriteUtils.isEqualSize(hi, input));
+			
+			//1x1 dimensions of transpose/reshape -> no need for reorg 	
+			apply |= ((rop.getOp()==ReOrgOp.TRANSPOSE || rop.getOp()==ReOrgOp.RESHAPE) 
+					&& rop.getDim1()==1 && rop.getDim2()==1);
+			
+			if( apply ) {
+				HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);				
 				HopRewriteUtils.addChildReference(parent, input, pos);
 				parent.refreshSizeInformation();
 				hi = input;
-				
-				LOG.debug("Applied removeUnnecessaryReshape");
+				LOG.debug("Applied removeUnnecessaryReorg.");
 			}			
 		}
 		
@@ -841,6 +847,43 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 	 * @return
 	 * @throws HopsException
 	 */
+	private Hop simplifyUnnecessaryAggregate(Hop parent, Hop hi, int pos) 
+		throws HopsException
+	{
+		//e.g., sum(X) -> as.scalar(X) if 1x1 (applies to sum, min, max, prod, trace)
+		if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol  ) 
+		{
+			AggUnaryOp uhi = (AggUnaryOp)hi;
+			Hop input = uhi.getInput().get(0);
+			
+			if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_UNNECESSARY_AGGREGATE) ){		
+				
+				if( input.getDim1()==1 && input.getDim2()==1 )
+				{
+					UnaryOp cast = HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_SCALAR);
+					
+					//remove unnecessary aggregation 
+					HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+					HopRewriteUtils.addChildReference(parent, cast, pos);
+					parent.refreshSizeInformation();
+					hi = cast;
+					
+					LOG.debug("Applied simplifyUnncessaryAggregate");
+				}
+			}			
+		}
+		
+		return hi;
+	}
+	
+	/**
+	 * 
+	 * @param parent
+	 * @param hi
+	 * @param pos
+	 * @return
+	 * @throws HopsException
+	 */
 	private Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos) 
 		throws HopsException
 	{

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index f23686c..784d678 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -144,6 +144,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
  			hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
  			hi = simplifyBushyBinaryOperation(hop, hi, i);       //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
  			hi = simplifyUnaryAggReorgOperation(hop, hi, i);     //e.g., sum(t(X)) -> sum(X)
+ 			hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
  			hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X))
  			hi = pushdownSumBinaryMult(hop, hi, i);              //e.g., sum(lamda*X) -> lamda*sum(X)
  			hi = simplifyUnaryPPredOperation(hop, hi, i);        //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
@@ -890,6 +891,52 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 	 * @param hi
 	 * @param pos
 	 * @return
+	 * @throws HopsException
+	 */
+	private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos ) 
+		throws HopsException
+	{
+		if(   hi instanceof UnaryOp && ((UnaryOp)hi).getOp()==OpOp1.CAST_AS_SCALAR  
+		   && hi.getInput().get(0) instanceof BinaryOp ) 
+		{
+			BinaryOp bin = (BinaryOp) hi.getInput().get(0);
+			BinaryOp bout = null;
+			
+			//as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
+			if( bin.getInput().get(0).getDataType()==DataType.MATRIX 
+				&& bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
+				UnaryOp cast1 = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
+				UnaryOp cast2 = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
+				bout = HopRewriteUtils.createBinary(cast1, cast2, bin.getOp());
+			}
+			//as.scalar(X*s) -> as.scalar(X) * s
+			else if( bin.getInput().get(0).getDataType()==DataType.MATRIX ) {
+				UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
+				bout = HopRewriteUtils.createBinary(cast, bin.getInput().get(1), bin.getOp());
+			}
+			//as.scalar(s*X) -> s * as.scalar(X)
+			else if ( bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
+				UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
+				bout = HopRewriteUtils.createBinary(bin.getInput().get(0), cast, bin.getOp());
+			}
+			
+			if( bout != null ) {
+				HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+				HopRewriteUtils.addChildReference(parent, bout, pos);
+				
+				LOG.debug("Applied simplifyBinaryMatrixScalarOperation.");
+			}
+		}
+		
+		return hi;
+	}
+	
+	/**
+	 * 
+	 * @param parent
+	 * @param hi
+	 * @param pos
+	 * @return
 	 */
 	private Hop pushdownUnaryAggTransposeOperation( Hop parent, Hop hi, int pos )
 	{

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
index 36d8712..a8001f8 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.Hop.VisitStatus;
 import org.apache.sysml.hops.UnaryOp;
@@ -73,6 +74,7 @@ public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule
 	 * 
 	 * @param hop
 	 */
+	@SuppressWarnings("unchecked")
 	private void rule_RemoveUnnecessaryCasts( Hop hop )
 	{
 		//check mark processed
@@ -84,7 +86,7 @@ public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule
 		for( int i=0; i<inputs.size(); i++ )
 			rule_RemoveUnnecessaryCasts( inputs.get(i) );
 		
-		//remove cast if unnecessary
+		//remove unnecessary value type cast 
 		if( hop instanceof UnaryOp && HopRewriteUtils.isValueTypeCast(((UnaryOp)hop).getOp()) )
 		{
 			Hop in = hop.getInput().get(0);
@@ -116,6 +118,23 @@ public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule
 			}
 		}
 		
+		//remove unnecessary data type casts
+		if( hop instanceof UnaryOp && hop.getInput().get(0) instanceof UnaryOp ) {
+			UnaryOp uop1 = (UnaryOp) hop;
+			UnaryOp uop2 = (UnaryOp) hop.getInput().get(0);
+			if( (uop1.getOp()==OpOp1.CAST_AS_MATRIX && uop2.getOp()==OpOp1.CAST_AS_SCALAR) 
+				|| (uop1.getOp()==OpOp1.CAST_AS_SCALAR && uop2.getOp()==OpOp1.CAST_AS_MATRIX) ) {
+				Hop input = uop2.getInput().get(0);
+				//rewire parents
+				ArrayList<Hop> parents = (ArrayList<Hop>) hop.getParent().clone();
+				for( Hop p : parents ) {
+					int ix = HopRewriteUtils.getChildReferencePos(p, hop);
+					HopRewriteUtils.removeChildReference(p, hop);
+					HopRewriteUtils.addChildReference(p, input, ix);
+				}
+			}
+		}
+		
 		//mark processed
 		hop.setVisited( VisitStatus.DONE );
 	}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
index 38ba9dd..c9545ac 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java
@@ -60,16 +60,10 @@ public abstract class ArithmeticBinaryCPInstruction extends BinaryCPInstruction
 		//make sure these checks belong here
 		//if either input is a matrix, then output
 		//has to be a matrix
-		if((dt1 == DataType.MATRIX 
-			|| dt2 == DataType.MATRIX) 
-		   && dt3 != DataType.MATRIX)
-			throw new DMLRuntimeException("Element-wise matrix operations between variables "
-										  + in1.getName()
-										  + " and "
-										  + in2.getName()
-										  + " must produce a matrix, which "
-										  + out.getName()
-										  + "is not");
+		if((dt1 == DataType.MATRIX  || dt2 == DataType.MATRIX) && dt3 != DataType.MATRIX) {
+			throw new DMLRuntimeException("Element-wise matrix operations between variables " + in1.getName() + 
+					" and " + in2.getName() + " must produce a matrix, which " + out.getName() + "is not");
+		}
 		
 		Operator operator = (dt1 != dt2) ?
 					InstructionUtils.parseScalarBinaryOperator(opcode, (dt1 == DataType.SCALAR)) : 


[3/4] incubator-systemml git commit: [SYSTEMML-694] Improved binary-unary rewrites (added max(0, X)->selp(X))

Posted by mb...@apache.org.
[SYSTEMML-694] Improved binary-unary rewrites (added max(0,X)->selp(X))

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/461184aa
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/461184aa
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/461184aa

Branch: refs/heads/master
Commit: 461184aa5576bef29b7561d8dfc1d6de37bcff83
Parents: 382df65
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Jul 30 16:23:01 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Jul 30 18:56:46 2016 -0700

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationStatic.java   | 21 +++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/461184aa/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index ae9c073..f23686c 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1184,6 +1184,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 				//by definition, either left or right or none applies. 
 				//note: if there are multiple consumers on the intermediate tmp=(X>0), it's still beneficial
 				//to replace the X*tmp with selp(X) due to lower memory requirements and simply sparsity propagation 
+				boolean applied = false;
 				
 				if( left instanceof BinaryOp ) //(X>0)*X
 				{
@@ -1206,11 +1207,12 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 							HopRewriteUtils.removeAllChildReferences(left);
 						
 						hi = unary;
+						applied = true;
 						
 						LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp1");
 					}
 				}				
-				if( right instanceof BinaryOp ) //X*(X>0)
+				if( !applied && right instanceof BinaryOp ) //X*(X>0)
 				{
 					BinaryOp bright = (BinaryOp)right;
 					Hop right1 = bright.getInput().get(0);
@@ -1231,6 +1233,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 							HopRewriteUtils.removeAllChildReferences(right);
 						
 						hi = unary;
+						applied= true;
 						
 						LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp2");
 					}
@@ -1252,6 +1255,22 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 				
 				LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp3");
 			}
+			
+			//select positive (selp) operator; pattern: max(0,X) -> selp+
+			if( bop.getOp() == OpOp2.MAX && right.getDataType()==DataType.MATRIX 
+					&& left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==0 )
+			{
+				UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP);
+				HopRewriteUtils.removeChildReferenceByPos(parent, bop, pos);
+				HopRewriteUtils.addChildReference(parent, unary, pos);
+				
+				//cleanup if only consumer of intermediate
+				if( bop.getParent().isEmpty() )
+					HopRewriteUtils.removeAllChildReferences(bop);					
+				hi = unary;
+				
+				LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp4");
+			}
 		}
 		
 		return hi;


[2/4] incubator-systemml git commit: [SYSTEMML-694] Improved wdivmm rewrites (outer-product-like mm only)

Posted by mb...@apache.org.
[SYSTEMML-694] Improved wdivmm rewrites (outer-product-like mm only)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/382df653
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/382df653
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/382df653

Branch: refs/heads/master
Commit: 382df653a5438124b1cdff8c676a55b9c0d50976
Parents: a528f5e
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Fri Jul 29 23:25:01 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Jul 30 16:23:21 2016 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 11 +++++++++
 .../RewriteAlgebraicSimplificationDynamic.java  | 24 +++++++++++---------
 2 files changed, 24 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/382df653/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 385a888..f7e4656 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -785,6 +785,17 @@ public class HopRewriteUtils
 				    : (hop.getDim1()>0 && hop.getDim1()<=hop.getRowsInBlock());
 	}
 	
+	/**
+	 * Checks whether the given hop is an AggBinaryOp whose first input has more rows than columns (dim1 > dim2) and whose second input has fewer rows than columns (dim1 < dim2).
+	 * @param hop the hop to test; its inputs at positions 0 and 1 are only accessed if it is an AggBinaryOp (NOTE(review): dims are compared as-is; confirm callers handle unknown dimensions)
+	 * @return true if the hop is an AggBinaryOp and both dimension comparisons hold, false otherwise
+	 */
+	public static boolean isOuterProductLikeMM( Hop hop ) {
+		return hop instanceof AggBinaryOp
+			&& hop.getInput().get(0).getDim1() > hop.getInput().get(0).getDim2()
+			&& hop.getInput().get(1).getDim1() < hop.getInput().get(1).getDim2();
+	}
+	
 	public static boolean isEqualValue( LiteralOp hop1, LiteralOp hop2 ) 
 		throws HopsException
 	{

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/382df653/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index dbde506..10953f5 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -1852,7 +1852,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			//alternative pattern: t(U) %*% (W*(U%*%t(V)))
 			if( right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY)	
 				&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv
-				&& right.getInput().get(1) instanceof AggBinaryOp
+				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1))
 				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
 			{
 				Hop W = right.getInput().get(0); 
@@ -1886,6 +1886,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 				&& right.getInput().get(1) instanceof BinaryOp
 				&& ((BinaryOp) right.getInput().get(1)).getOp() == Hop.OpOp2.PLUS
 				&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
+				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
 				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
 			{
 				Hop W = right.getInput().get(0); 
@@ -1917,7 +1918,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			if( !appliedPattern
 				&& left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)	
 				&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv
-				&& left.getInput().get(1) instanceof AggBinaryOp
+				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1))
 				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
 			{
 				Hop W = left.getInput().get(0); 
@@ -1948,6 +1949,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 				&& left.getInput().get(1) instanceof BinaryOp
 				&& ((BinaryOp) left.getInput().get(1)).getOp() == Hop.OpOp2.PLUS
 				&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
+				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
 				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
 			{
 				Hop W = left.getInput().get(0); 
@@ -1975,8 +1977,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			if( !appliedPattern
 				&& right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
 				&& right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS	
-				&& right.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
-                && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
+				&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
 				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
 			{
 				Hop W = right.getInput().get(0); 
@@ -2008,8 +2010,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			if( !appliedPattern
 				&& left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT	
 				&& left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS	
-				&& left.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
-                && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
+				&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
 				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
 			{
 				Hop W = left.getInput().get(0); 
@@ -2038,8 +2040,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			if( !appliedPattern
 				&& right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
 				&& right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS	
-				&& right.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
-                && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
+				&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
 				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
 			{
 				Hop W = right.getInput().get(0); 
@@ -2071,8 +2073,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			if( !appliedPattern
 				&& left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT	
 				&& left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS	
-				&& left.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
-                && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
+				&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
 				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
 			{
 				Hop W = left.getInput().get(0); 
@@ -2105,7 +2107,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			&& hi.getDim2() > 1 //not applied for vector-vector mult
 			&& hi.getInput().get(0).getDataType() == DataType.MATRIX 
 			&& hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
-			&& hi.getInput().get(1) instanceof AggBinaryOp
+			&& HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1))
 			&& (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
 			&& HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
 		{