You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/02/16 08:34:34 UTC

[2/4] systemml git commit: [SYSTEMML-2149] New simplification rewrite for replace zero w/ scalar

[SYSTEMML-2149] New simplification rewrite for replace zero w/ scalar

Several scripts emulate replacing zeros with a scalar via
X + (X==0) * s. We now rewrite this pattern to the builtin function
replace(X, 0, s), which avoids the unnecessary intermediate (X==0) * s
and, for distributed operations, the (partitioning-preserving) join of
X with that intermediate.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/62e590ce
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/62e590ce
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/62e590ce

Branch: refs/heads/master
Commit: 62e590ced04900364bdc294538e78de6af3f4988
Parents: 72830f0
Author: Matthias Boehm <mb...@gmail.com>
Authored: Thu Feb 15 18:49:00 2018 -0800
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Thu Feb 15 18:49:00 2018 -0800

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationStatic.java   | 32 +++++++++++++++++---
 1 file changed, 28 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/62e590ce/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index ac45e77..3a3235d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -175,6 +175,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			hi = fuseOrderOperationChain(hi);                    //e.g., order(order(X,2),1) -> order(X,(12))
 			hi = removeUnnecessaryReorgOperation(hop, hi, i);    //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
 			hi = simplifyTransposeAggBinBinaryChains(hop, hi, i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C)
+			hi = simplifyReplaceZeroOperation(hop, hi, i);       //e.g., X + (X==0) * s -> replace(X, 0, s)
 			hi = removeUnnecessaryMinus(hop, hi, i);             //e.g., -(-X)->X; potentially introduced by simplify binary or dyn rewrites
 			hi = simplifyGroupedAggregate(hi);          	     //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count")
 			if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
@@ -1556,14 +1557,14 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 	{
 		if( HopRewriteUtils.isTransposeOperation(hi)
 			&& hi.getInput().get(0) instanceof BinaryOp                       //basic binary
-			&& ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()) 
+			&& ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations())
 		{
 			Hop left = hi.getInput().get(0).getInput().get(0);
 			Hop C = hi.getInput().get(0).getInput().get(1);
 			
 			//check matrix mult and both inputs transposes w/ single consumer
 			if( left instanceof AggBinaryOp && C.getDataType().isMatrix()
-				&& HopRewriteUtils.isTransposeOperation(left.getInput().get(0))     
+				&& HopRewriteUtils.isTransposeOperation(left.getInput().get(0))
 				&& left.getInput().get(0).getParent().size()==1 
 				&& HopRewriteUtils.isTransposeOperation(left.getInput().get(1))
 				&& left.getInput().get(1).getParent().size()==1 )
@@ -1578,13 +1579,36 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 				HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
 				
 				hi = bop;
-				LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+").");						
-			}  
+				LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+").");
+			}
 		}
 		
 		return hi;
 	}
 	
+	// Rewrite pattern: X + (X==0) * s -> replace(X, 0, s) (matrix X, scalar s); returns the replace hop if applied, else hi unchanged
+	private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int pos) 
+		throws HopsException
+	{
+		if( HopRewriteUtils.isBinary(hi, OpOp2.PLUS) && hi.getInput().get(0).isMatrix()   //X + ... with matrix X on the left; commuted forms like (X==0)*s + X are not matched
+			&& HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT)   //right operand is a product (X==0) * s
+			&& hi.getInput().get(1).getInput().get(1).isScalar()   //right factor s is a scalar; NOTE(review): for s=NaN/Inf, 0*s is NaN, so equivalence assumes finite s
+			&& HopRewriteUtils.isBinaryMatrixScalar(hi.getInput().get(1).getInput().get(0), OpOp2.EQUAL, 0)   //left factor is the comparison (X==0)
+			&& hi.getInput().get(1).getInput().get(0).getInput().contains(hi.getInput().get(0)) )   //comparison is over the same hop X (List.contains; presumably reference identity)
+		{
+			HashMap<String, Hop> args = new HashMap<>();   //named arguments of replace(target, pattern, replacement)
+			args.put("target", hi.getInput().get(0));
+			args.put("pattern", new LiteralOp(0));
+			args.put("replacement", hi.getInput().get(1).getInput().get(1));
+			Hop replace = HopRewriteUtils.createParameterizedBuiltinOp(
+				hi.getInput().get(0), args, ParamBuiltinOp.REPLACE);
+			HopRewriteUtils.replaceChildReference(parent, hi, replace, pos);   //rewire parent's input at pos from hi to the new replace hop
+			hi = replace;
+			LOG.debug("Applied simplifyReplaceZeroOperation (line "+hi.getBeginLine()+").");
+		}
+		return hi;
+	}
+	
 	/**
 	 * Pattners: t(t(X)) -> X, rev(rev(X)) -> X
 	 *