You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/02/16 08:34:34 UTC
[2/4] systemml git commit: [SYSTEMML-2149] New simplification rewrite
for replace zero w/ scalar
[SYSTEMML-2149] New simplification rewrite for replace zero w/ scalar
There are multiple scripts that emulate the replacement of zeros with a
scalar via X + (X==0) * s. We now rewrite this pattern to the builtin
function replace(X, 0, s), which avoids an unnecessary intermediate matrix
and, for distributed operations, the (partitioning-preserving) join between
X and that intermediate.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/62e590ce
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/62e590ce
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/62e590ce
Branch: refs/heads/master
Commit: 62e590ced04900364bdc294538e78de6af3f4988
Parents: 72830f0
Author: Matthias Boehm <mb...@gmail.com>
Authored: Thu Feb 15 18:49:00 2018 -0800
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Thu Feb 15 18:49:00 2018 -0800
----------------------------------------------------------------------
.../RewriteAlgebraicSimplificationStatic.java | 32 +++++++++++++++++---
1 file changed, 28 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/62e590ce/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index ac45e77..3a3235d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -175,6 +175,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = fuseOrderOperationChain(hi); //e.g., order(order(X,2),1) -> order(X,(12))
hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
hi = simplifyTransposeAggBinBinaryChains(hop, hi, i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C)
+ hi = simplifyReplaceZeroOperation(hop, hi, i); //e.g., X + (X==0) * s -> replace(X, 0, s)
hi = removeUnnecessaryMinus(hop, hi, i); //e.g., -(-X)->X; potentially introduced by simplify binary or dyn rewrites
hi = simplifyGroupedAggregate(hi); //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count")
if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
@@ -1556,14 +1557,14 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
{
if( HopRewriteUtils.isTransposeOperation(hi)
&& hi.getInput().get(0) instanceof BinaryOp //basic binary
- && ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations())
+ && ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations())
{
Hop left = hi.getInput().get(0).getInput().get(0);
Hop C = hi.getInput().get(0).getInput().get(1);
//check matrix mult and both inputs transposes w/ single consumer
if( left instanceof AggBinaryOp && C.getDataType().isMatrix()
- && HopRewriteUtils.isTransposeOperation(left.getInput().get(0))
+ && HopRewriteUtils.isTransposeOperation(left.getInput().get(0))
&& left.getInput().get(0).getParent().size()==1
&& HopRewriteUtils.isTransposeOperation(left.getInput().get(1))
&& left.getInput().get(1).getParent().size()==1 )
@@ -1578,13 +1579,36 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
hi = bop;
- LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+").");
- }
+ LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+").");
+ }
}
return hi;
}
+ /**
+ * Patterns: X + (X==0) * s -> replace(X, 0, s), including the commuted
+ * forms (X==0) * s + X, X + s * (X==0), and s * (X==0) + X, since both
+ * the plus and the mult are commutative cell-wise operations.
+ *
+ * NOTE(review): the equivalence assumes a finite s; for s = NaN or +/-Inf
+ * the original expression yields NaN at non-zero cells (0 * Inf = NaN),
+ * whereas replace(X, 0, s) leaves those cells unchanged - confirm this is
+ * acceptable for the optimizer.
+ *
+ * @param parent parent hop that references hi as its input at position pos
+ * @param hi candidate plus hop
+ * @param pos input position of hi within parent
+ * @return the new replace hop if the rewrite applied, otherwise hi
+ * @throws HopsException presumably propagated from the hop utility calls
+ */
+ private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int pos)
+ throws HopsException
+ {
+ if( !HopRewriteUtils.isBinary(hi, OpOp2.PLUS) )
+ return hi;
+ //probe both operand orders of the commutative plus
+ for( int i=0; i<2; i++ ) {
+ Hop X = hi.getInput().get(i);
+ Hop s = getReplaceZeroScalar(X, hi.getInput().get(1-i));
+ if( s == null )
+ continue;
+ HashMap<String, Hop> args = new HashMap<>();
+ args.put("target", X);
+ args.put("pattern", new LiteralOp(0));
+ args.put("replacement", s);
+ Hop replace = HopRewriteUtils.createParameterizedBuiltinOp(
+ X, args, ParamBuiltinOp.REPLACE);
+ HopRewriteUtils.replaceChildReference(parent, hi, replace, pos);
+ hi = replace;
+ LOG.debug("Applied simplifyReplaceZeroOperation (line "+hi.getBeginLine()+").");
+ return hi;
+ }
+ return hi;
+ }
+
+ /**
+ * Matches mult against (X==0) * s or s * (X==0) for the given matrix X.
+ *
+ * @param X candidate target matrix of the replace
+ * @param mult candidate multiplication hop
+ * @return the scalar operand s if the pattern matches, otherwise null
+ * @throws HopsException presumably propagated from the hop utility calls
+ */
+ private static Hop getReplaceZeroScalar(Hop X, Hop mult)
+ throws HopsException
+ {
+ if( !X.isMatrix() || !HopRewriteUtils.isBinary(mult, OpOp2.MULT) )
+ return null;
+ //probe both operand orders of the commutative mult
+ for( int j=0; j<2; j++ ) {
+ Hop eq = mult.getInput().get(j);
+ Hop s = mult.getInput().get(1-j);
+ if( s.isScalar()
+ && HopRewriteUtils.isBinaryMatrixScalar(eq, OpOp2.EQUAL, 0)
+ && eq.getInput().contains(X) )
+ return s;
+ }
+ return null;
+ }
+
/**
* Pattners: t(t(X)) -> X, rev(rev(X)) -> X
*