You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/02/24 20:27:34 UTC
[2/6] incubator-systemml git commit: [SYSTEMML-1326] Cleanup hop
rewrites (removed redundancy, minor fixes)
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1fe1a02d/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 6ffcbd5..cc67cc1 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -24,7 +24,6 @@ import java.util.HashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
@@ -204,12 +203,9 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
HopRewriteUtils.isDimsKnown(hi)) //output dims known
{
//remove unnecessary right indexing
- HopRewriteUtils.removeChildReference(parent, hi);
-
Hop hnew = HopRewriteUtils.createDataGenOpByVal( new LiteralOp(hi.getDim1()),
new LiteralOp(hi.getDim2()), 0);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
LOG.debug("Applied removeEmptyRightIndexing");
@@ -232,9 +228,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//handling if out of range indexing)
//remove unnecessary right indexing
- HopRewriteUtils.removeChildReference(parent, hi);
- HopRewriteUtils.addChildReference(parent, input, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
hi = input;
LOG.debug("Applied removeUnnecessaryRightIndexing");
@@ -255,11 +249,9 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( input1.getNnz()==0 //nnz original known and empty
&& input2.getNnz()==0 ) //nnz input known and empty
{
- //remove unnecessary right indexing
- HopRewriteUtils.removeChildReference(parent, hi);
+ //remove unnecessary right indexing
Hop hnew = HopRewriteUtils.createDataGenOp( input1, 0);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
LOG.debug("Applied removeEmptyLeftIndexing");
@@ -279,10 +271,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
{
//equal dims of left indexing input and output -> no need for indexing
- //remove unnecessary right indexing
- HopRewriteUtils.removeChildReference(parent, hi);
- HopRewriteUtils.addChildReference(parent, input, pos);
- parent.refreshSizeInformation();
+ //remove unnecessary right indexing
+ HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
hi = input;
LOG.debug("Applied removeUnnecessaryLeftIndexing");
@@ -314,9 +304,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
&& input1.getDataType()!=DataType.SCALAR && input2.getDataType()!=DataType.SCALAR )
{
//create new cbind operation and rewrite inputs
- HopRewriteUtils.removeChildReference(parent, hi);
BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.CBIND);
- HopRewriteUtils.addChildReference(parent, bop, pos);
+ HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
hi = bop;
applied = true;
@@ -341,9 +330,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
&& input1.getDataType()!=DataType.SCALAR && input2.getDataType()!=DataType.SCALAR )
{
//create new cbind operation and rewrite inputs
- HopRewriteUtils.removeChildReference(parent, hi);
BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.RBIND);
- HopRewriteUtils.addChildReference(parent, bop, pos);
+ HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
hi = bop;
applied = true;
@@ -366,10 +354,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
{
OpOp1 op = ((UnaryOp)hi).getOp();
- //remove unnecessary unary cumsum operator
- HopRewriteUtils.removeChildReference(parent, hi);
- HopRewriteUtils.addChildReference(parent, input, pos);
- parent.refreshSizeInformation();
+ //remove unnecessary unary cumsum operator
+ HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
hi = input;
LOG.debug("Applied removeUnnecessaryCumulativeOp: "+op);
@@ -396,9 +382,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
&& rop.getDim1()==1 && rop.getDim2()==1);
if( apply ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, input, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
hi = input;
LOG.debug("Applied removeUnnecessaryReorg.");
}
@@ -414,7 +398,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
Hop right = hi.getInput().get(1);
//check for column replication
- if( right instanceof AggBinaryOp //matrix mult with datagen
+ if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen
&& right.getInput().get(1) instanceof DataGenOp
&& ((DataGenOp)right.getInput().get(1)).getOp()==DataGenMethod.RAND
&& ((DataGenOp)right.getInput().get(1)).hasConstantValue(1d)
@@ -422,31 +406,21 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
&& right.getInput().get(0).getDim2() == 1 ) //column vector for mv binary
{
//remove unnecessary outer product
- HopRewriteUtils.removeChildReference(hi, right);
- HopRewriteUtils.addChildReference(hi, right.getInput().get(0) );
- hi.refreshSizeInformation();
-
- //cleanup refs to matrix mult if no remaining consumers
- if( right.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( right );
+ HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(0), 1 );
+ HopRewriteUtils.cleanupUnreferenced(right);
LOG.debug("Applied removeUnnecessaryOuterProduct1 (line "+right.getBeginLine()+")");
}
//check for row replication
- else if( right instanceof AggBinaryOp //matrix mult with datagen
+ else if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen
&& right.getInput().get(0) instanceof DataGenOp
&& ((DataGenOp)right.getInput().get(0)).hasConstantValue(1d)
&& right.getInput().get(0).getDim2() == 1 //colunm vector for replication
&& right.getInput().get(1).getDim1() == 1 ) //row vector for mv binary
{
//remove unnecessary outer product
- HopRewriteUtils.removeChildReference(hi, right);
- HopRewriteUtils.addChildReference(hi, right.getInput().get(1) );
- hi.refreshSizeInformation();
-
- //cleanup refs to matrix mult if no remaining consumers
- if( right.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( right );
+ HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(1), 1 );
+ HopRewriteUtils.cleanupUnreferenced(right);
LOG.debug("Applied removeUnnecessaryOuterProduct2 (line "+right.getBeginLine()+")");
}
@@ -458,7 +432,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
@SuppressWarnings("unchecked")
private Hop fuseDatagenAndReorgOperation(Hop parent, Hop hi, int pos)
{
- if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.TRANSPOSE //transpose
+ if( HopRewriteUtils.isTransposeOperation(hi)
&& hi.getInput().get(0) instanceof DataGenOp //datagen
&& hi.getInput().get(0).getParent().size()==1 ) //transpose only consumer
{
@@ -512,17 +486,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
// the column variances will each be zero.
// Therefore, perform a rewrite from COLVAR(X) to a row vector of zeros.
Hop emptyRow = HopRewriteUtils.createDataGenOp(uhi, input, 0);
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, emptyRow, pos);
- parent.refreshSizeInformation();
-
- // cleanup
- if (hi.getParent().isEmpty())
- HopRewriteUtils.removeAllChildReferences(hi);
- if (input.getParent().isEmpty())
- HopRewriteUtils.removeAllChildReferences(input);
-
- // replace current HOP with new empty row HOP
+ HopRewriteUtils.replaceChildReference(parent, hi, emptyRow, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi, input);
hi = emptyRow;
LOG.debug("Applied simplifyColwiseAggregate for colVars");
@@ -530,13 +495,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
// All other valid column aggregations over a row vector will result
// in the row vector itself.
// Therefore, remove unnecessary col aggregation for 1 row.
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- if( hi.getParent().isEmpty() ) //no remaining consumers
- HopRewriteUtils.removeChildReference(hi, input);
- HopRewriteUtils.addChildReference(parent, input, pos);
- parent.refreshSizeInformation();
-
- // replace current HOP with input HOP
+ HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi);
hi = input;
LOG.debug("Applied simplifyColwiseAggregate1");
@@ -552,15 +512,12 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
uhi.setDataType(DataType.SCALAR);
//create cast to keep same output datatype
- UnaryOp cast = new UnaryOp(uhi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp1.CAST_AS_MATRIX, uhi);
- HopRewriteUtils.setOutputParameters(cast, 1, 1, ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), -1);
+ UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX);
//rehang cast under all parents
for( Hop p : parents ) {
int ix = HopRewriteUtils.getChildReferencePos(p, hi);
- HopRewriteUtils.removeChildReference(p, hi);
- HopRewriteUtils.addChildReference(p, cast, ix);
+ HopRewriteUtils.replaceChildReference(p, hi, cast, ix);
}
hi = cast;
@@ -594,15 +551,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
// Therefore, perform a rewrite from ROWVAR(X) to a column vector of
// zeros.
Hop emptyCol = HopRewriteUtils.createDataGenOp(input, uhi, 0);
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, emptyCol, pos);
- parent.refreshSizeInformation();
-
- // cleanup
- if (hi.getParent().isEmpty())
- HopRewriteUtils.removeAllChildReferences(hi);
- if (input.getParent().isEmpty())
- HopRewriteUtils.removeAllChildReferences(input);
+ HopRewriteUtils.replaceChildReference(parent, hi, emptyCol, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi, input);
// replace current HOP with new empty column HOP
hi = emptyCol;
@@ -612,13 +562,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
// All other valid row aggregations over a column vector will result
// in the column vector itself.
// Therefore, remove unnecessary row aggregation for 1 col
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- if( hi.getParent().isEmpty() ) //no remaining consumers
- HopRewriteUtils.removeChildReference(hi, input);
- HopRewriteUtils.addChildReference(parent, input, pos);
- parent.refreshSizeInformation();
-
- // replace current HOP with input HOP
+ HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi);
hi = input;
LOG.debug("Applied simplifyRowwiseAggregate1");
@@ -634,15 +579,12 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
uhi.setDataType(DataType.SCALAR);
//create cast to keep same output datatype
- UnaryOp cast = new UnaryOp(uhi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp1.CAST_AS_MATRIX, uhi);
- HopRewriteUtils.setOutputParameters(cast, 1, 1, ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), -1);
+ UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX);
//rehang cast under all parents
for( Hop p : parents ) {
int ix = HopRewriteUtils.getChildReferencePos(p, hi);
- HopRewriteUtils.removeChildReference(p, hi);
- HopRewriteUtils.addChildReference(p, cast, ix);
+ HopRewriteUtils.replaceChildReference(p, hi, cast, ix);
}
hi = cast;
@@ -666,36 +608,26 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
AggUnaryOp uhi = (AggUnaryOp)hi;
Hop input = uhi.getInput().get(0);
- if( uhi.getOp() == AggOp.SUM && uhi.getDirection() == Direction.Col ) //colsums
+ if( uhi.getOp() == AggOp.SUM && uhi.getDirection() == Direction.Col //colsums
+ && HopRewriteUtils.isBinary(input, OpOp2.MULT) ) //b(*)
{
- if( input instanceof BinaryOp && ((BinaryOp)input).getOp()==OpOp2.MULT ) //b(*)
+ Hop left = input.getInput().get(0);
+ Hop right = input.getInput().get(1);
+
+ if( left.getDim1()>1 && left.getDim2()>1
+ && right.getDim1()>1 && right.getDim2()==1 ) // MV (col vector)
{
- Hop left = input.getInput().get(0);
- Hop right = input.getInput().get(1);
+ //create new operators
+ ReorgOp trans = HopRewriteUtils.createTranspose(right);
+ AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, left);
- if( left.getDim1()>1 && left.getDim2()>1
- && right.getDim1()>1 && right.getDim2()==1 ) // MV (col vector)
- {
- //remove link parent to rowsums
- HopRewriteUtils.removeChildReference(parent, hi);
-
- //create new operators
- ReorgOp trans = HopRewriteUtils.createTranspose(right);
- AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, left);
-
- //relink new child
- HopRewriteUtils.addChildReference(parent, mmult, pos);
- hi = mmult;
-
- //cleanup old dag
- if( uhi.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences(uhi);
- if( input.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences(input);
-
- LOG.debug("Applied simplifyColSumsMVMult");
- }
- }
+ //relink new child
+ HopRewriteUtils.replaceChildReference(parent, hi, mmult, pos);
+ HopRewriteUtils.cleanupUnreferenced(uhi, input);
+ hi = mmult;
+
+ LOG.debug("Applied simplifyColSumsMVMult");
+ }
}
}
@@ -712,37 +644,27 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
AggUnaryOp uhi = (AggUnaryOp)hi;
Hop input = uhi.getInput().get(0);
- if( uhi.getOp() == AggOp.SUM && uhi.getDirection() == Direction.Row ) //rowsums
+ if( uhi.getOp() == AggOp.SUM && uhi.getDirection() == Direction.Row //rowsums
+ && HopRewriteUtils.isBinary(input, OpOp2.MULT) ) //b(*)
{
- if( input instanceof BinaryOp && ((BinaryOp)input).getOp()==OpOp2.MULT ) //b(*)
+ Hop left = input.getInput().get(0);
+ Hop right = input.getInput().get(1);
+
+ if( left.getDim1()>1 && left.getDim2()>1
+ && right.getDim1()==1 && right.getDim2()>1 ) // MV (row vector)
{
- Hop left = input.getInput().get(0);
- Hop right = input.getInput().get(1);
+ //create new operators
+ ReorgOp trans = HopRewriteUtils.createTranspose(right);
+ AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(left, trans);
- if( left.getDim1()>1 && left.getDim2()>1
- && right.getDim1()==1 && right.getDim2()>1 ) // MV (row vector)
- {
- //remove link parent to rowsums
- HopRewriteUtils.removeChildReference(parent, hi);
-
- //create new operators
- ReorgOp trans = HopRewriteUtils.createTranspose(right);
- AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(left, trans);
-
- //relink new child
- HopRewriteUtils.addChildReference(parent, mmult, pos);
- hi = mmult;
-
- //cleanup old dag
- if( uhi.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences(uhi);
- if( input.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences(input);
-
- LOG.debug("Applied simplifyRowSumsMVMult");
- }
- }
- }
+ //relink new child
+ HopRewriteUtils.replaceChildReference(parent, hi, mmult, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi, input);
+ hi = mmult;
+
+ LOG.debug("Applied simplifyRowSumsMVMult");
+ }
+ }
}
return hi;
@@ -764,9 +686,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
UnaryOp cast = HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_SCALAR);
//remove unnecessary aggregation
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, cast, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
hi = cast;
LOG.debug("Applied simplifyUnncessaryAggregate");
@@ -789,9 +709,6 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( HopRewriteUtils.isEmpty(input) )
{
- //remove unnecessary aggregation
- HopRewriteUtils.removeChildReference(parent, hi);
-
Hop hnew = null;
if( uhi.getDirection() == Direction.RowCol )
hnew = new LiteralOp(0.0);
@@ -801,8 +718,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = HopRewriteUtils.createDataGenOp(input, uhi, 0); //ncol(uhi)=1
//add new child to parent input
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
LOG.debug("Applied simplifyEmptyAggregate");
@@ -825,14 +741,9 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( HopRewriteUtils.isEmpty(input) )
{
- //remove unnecessary aggregation
- HopRewriteUtils.removeChildReference(parent, hi);
-
//create literal add it to parent
Hop hnew = HopRewriteUtils.createDataGenOp(input, 0);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- parent.refreshSizeInformation();
-
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
LOG.debug("Applied simplifyEmptyUnaryOperation");
@@ -873,9 +784,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//modify dag if one of the above rules applied
if( hnew != null ){
- HopRewriteUtils.removeChildReference(parent, hi);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
LOG.debug("Applied simplifyEmptyReorgOperation");
@@ -914,9 +823,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//modify dag if one of the above rules applied
if( hnew != null ){
- HopRewriteUtils.removeChildReference(parent, hi);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
LOG.debug("Applied simplifyEmptySortOperation (indexreturn="+ixret+").");
@@ -931,7 +838,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
private Hop simplifyEmptyMatrixMult(Hop parent, Hop hi, int pos)
throws HopsException
{
- if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply() ) //X%*%Y -> matrix(0, )
+ if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y -> matrix(0, )
{
Hop left = hi.getInput().get(0);
Hop right = hi.getInput().get(1);
@@ -939,15 +846,10 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( HopRewriteUtils.isEmpty(left) //one input empty
|| HopRewriteUtils.isEmpty(right) )
{
- //remove unnecessary matrix mult
- HopRewriteUtils.removeChildReference(parent, hi);
-
//create datagen and add it to parent
Hop hnew = HopRewriteUtils.createDataGenOp(left, right, 0);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- parent.refreshSizeInformation();
-
- hi = hnew;
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
+ hi = hnew;
LOG.debug("Applied simplifyEmptyMatrixMult");
}
@@ -959,7 +861,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
private Hop simplifyIdentityRepMatrixMult(Hop parent, Hop hi, int pos)
throws HopsException
{
- if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply() ) //X%*%Y -> X, if y is matrix(1,1,1)
+ if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y -> X, if y is matrix(1,1,1)
{
Hop left = hi.getInput().get(0);
Hop right = hi.getInput().get(1);
@@ -969,8 +871,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
right instanceof DataGenOp && ((DataGenOp)right).getOp()==DataGenMethod.RAND
&& ((DataGenOp)right).hasConstantValue(1.0)) //matrix(1,)
{
- HopRewriteUtils.removeChildReference(parent, hi);
- HopRewriteUtils.addChildReference(parent, left, pos);
+ HopRewriteUtils.replaceChildReference(parent, hi, left, pos);
hi = left;
LOG.debug("Applied simplifyIdentiyMatrixMult");
@@ -983,7 +884,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
private Hop simplifyScalarMatrixMult(Hop parent, Hop hi, int pos)
throws HopsException
{
- if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply() ) //X%*%Y
+ if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y
{
Hop left = hi.getInput().get(0);
Hop right = hi.getInput().get(1);
@@ -991,49 +892,27 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
// y %*% X -> as.scalar(y) * X
if( HopRewriteUtils.isDimsKnown(left) && left.getDim1()==1 && left.getDim2()==1 ) //scalar left
{
- //remove link from parent to matrix mult
- HopRewriteUtils.removeChildReference(parent, hi);
-
- UnaryOp cast = new UnaryOp(left.getName(), DataType.SCALAR, ValueType.DOUBLE,
- OpOp1.CAST_AS_SCALAR, left);
- HopRewriteUtils.setOutputParameters(cast, 0, 0, 0, 0, 0);
- BinaryOp mult = new BinaryOp(cast.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp2.MULT, cast, right);
- HopRewriteUtils.setOutputParameters(mult, right.getDim1(), right.getDim2(), right.getRowsInBlock(), right.getColsInBlock(), -1);
-
- //cleanup if only consumer of intermediate
- if( hi.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi );
+ UnaryOp cast = HopRewriteUtils.createUnary(left, OpOp1.CAST_AS_SCALAR);
+ BinaryOp mult = HopRewriteUtils.createBinary(cast, right, OpOp2.MULT);
//add mult to parent
- HopRewriteUtils.addChildReference(parent, mult, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, mult, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi);
hi = mult;
-
LOG.debug("Applied simplifyScalarMatrixMult1");
}
// X %*% y -> X * as.scalar(y)
else if( HopRewriteUtils.isDimsKnown(right) && right.getDim1()==1 && right.getDim2()==1 ) //scalar right
{
- //remove link from parent to matrix mult
- HopRewriteUtils.removeChildReference(parent, hi);
-
- UnaryOp cast = new UnaryOp(right.getName(), DataType.SCALAR, ValueType.DOUBLE,
- OpOp1.CAST_AS_SCALAR, right);
- HopRewriteUtils.setOutputParameters(cast, 0, 0, 0, 0, 0);
- BinaryOp mult = new BinaryOp(cast.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp2.MULT, cast, left);
- HopRewriteUtils.setOutputParameters(mult, left.getDim1(), left.getDim2(), left.getRowsInBlock(), left.getColsInBlock(), -1);
-
- //cleanup if only consumer of intermediate
- if( hi.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi );
+ UnaryOp cast = HopRewriteUtils.createUnary(right, OpOp1.CAST_AS_SCALAR);
+ BinaryOp mult = HopRewriteUtils.createBinary(cast, left, OpOp2.MULT);
//add mult to parent
- HopRewriteUtils.addChildReference(parent, mult, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, mult, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi);
hi = mult;
-
LOG.debug("Applied simplifyScalarMatrixMult2");
}
}
@@ -1046,7 +925,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
{
Hop hnew = null;
- if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply() ) //X%*%Y
+ if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y
{
Hop left = hi.getInput().get(0);
@@ -1061,36 +940,22 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( right.getDim2()==1 ) //right column vector
{
- //remove link from parent to matrix mult
- HopRewriteUtils.removeChildReference(parent, hi);
-
//create binary operation over input and right
Hop input = left.getInput().get(0); //diag input
- hnew = new BinaryOp(input.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp2.MULT, input, right);
- HopRewriteUtils.setOutputParameters(hnew, left.getDim1(), right.getDim2(), left.getRowsInBlock(), left.getColsInBlock(), -1);
-
+ hnew = HopRewriteUtils.createBinary(input, right, OpOp2.MULT);
+
LOG.debug("Applied simplifyMatrixMultDiag1");
}
else if( right.getDim2()>1 ) //multi column vector
{
- //remove link from parent to matrix mult
- HopRewriteUtils.removeChildReference(parent, hi);
-
//create binary operation over input and right; in contrast to above rewrite,
//we need to switch the order because MV binary cell operations require vector on the right
Hop input = left.getInput().get(0); //diag input
- hnew = new BinaryOp(input.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp2.MULT, right, input);
- HopRewriteUtils.setOutputParameters(hnew, left.getDim1(), right.getDim2(), left.getRowsInBlock(), left.getColsInBlock(), -1);
+ hnew = HopRewriteUtils.createBinary(right, input, OpOp2.MULT);
+
+ //NOTE: previously to MV binary cell operations we replicated the left
+ //(if moderate number of columns: 2), but this is no longer required
- //NOTE: previously to MV binary cell operations we replicated the left (if moderate number of columns: 2)
- //create binary operation over input and right
- //Hop input = left.getInput().get(0);
- //Hop ones = HopRewriteUtils.createDataGenOpByVal(new LiteralOp("1",1), new LiteralOp(String.valueOf(right.getDim2()),right.getDim2()), 1);
- //Hop repmat = new AggBinaryOp( input.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp2.MULT, AggOp.SUM, input, ones );
- //HopRewriteUtils.setOutputParameters(repmat, input.getDim1(), ones.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), -1);
- //hnew = new BinaryOp(input.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp2.MULT, repmat, right);
- //HopRewriteUtils.setOutputParameters(hnew, right.getDim1(), right.getDim2(), right.getRowsInBlock(), right.getColsInBlock(), -1);
-
LOG.debug("Applied simplifyMatrixMultDiag2");
}
}
@@ -1100,13 +965,9 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//if one of the above rewrites applied
if( hnew !=null ){
- //cleanup if only consumer of intermediate
- if( hi.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi );
-
//add mult to parent
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- parent.refreshSizeInformation();
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi);
hi = hnew;
}
@@ -1119,41 +980,21 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.DIAG && hi.getDim2()==1 ) //diagM2V
{
Hop hi2 = hi.getInput().get(0);
- if( hi2 instanceof AggBinaryOp && ((AggBinaryOp)hi2).isMatrixMultiply() ) //X%*%Y
+ if( HopRewriteUtils.isMatrixMultiply(hi2) ) //X%*%Y
{
Hop left = hi2.getInput().get(0);
Hop right = hi2.getInput().get(1);
- //remove link from parent to diag
- HopRewriteUtils.removeChildReference(parent, hi);
-
- //remove links to inputs to matrix mult
- //removeChildReference(hi2, left);
- //removeChildReference(hi2, right);
-
//create new operators (incl refresh size inside for transpose)
ReorgOp trans = HopRewriteUtils.createTranspose(right);
- BinaryOp mult = new BinaryOp(right.getName(), right.getDataType(), right.getValueType(), OpOp2.MULT, left, trans);
- mult.setRowsInBlock(right.getRowsInBlock());
- mult.setColsInBlock(right.getColsInBlock());
- mult.refreshSizeInformation();
- AggUnaryOp rowSum = new AggUnaryOp(right.getName(), right.getDataType(), right.getValueType(), AggOp.SUM, Direction.Row, mult);
- rowSum.setRowsInBlock(right.getRowsInBlock());
- rowSum.setColsInBlock(right.getColsInBlock());
- rowSum.refreshSizeInformation();
+ BinaryOp mult = HopRewriteUtils.createBinary(left, trans, OpOp2.MULT);
+ AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(mult, AggOp.SUM, Direction.Row);
//rehang new subdag under parent node
- HopRewriteUtils.addChildReference(parent, rowSum, pos);
- parent.refreshSizeInformation();
-
- //cleanup if only consumer of intermediate
- if( hi.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi );
- if( hi2.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi2 );
-
- hi = rowSum;
+ HopRewriteUtils.replaceChildReference(parent, hi, rowSum, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi, hi2);
+ hi = rowSum;
LOG.debug("Applied simplifyDiagMatrixMult");
}
}
@@ -1174,16 +1015,12 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
Hop hi3 = hi2.getInput().get(0);
//remove diag operator
- HopRewriteUtils.removeChildReference(au, hi2);
- HopRewriteUtils.addChildReference(au, hi3, 0);
+ HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0);
+ HopRewriteUtils.cleanupUnreferenced(hi2);
//change sum to trace
au.setOp( AggOp.TRACE );
- //cleanup if only consumer of intermediate
- if( hi2.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi2 );
-
LOG.debug("Applied simplifySumDiagToTrace");
}
}
@@ -1198,7 +1035,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
{
//diag(X)*7 --> diag(X*7) in order to (1) reduce required memory for b(*) and
//(2) in order to make the binary operation more efficient (dense vector vs sparse matrix)
- if( hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MULT )
+ if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) )
{
Hop left = hi.getInput().get(0);
Hop right = hi.getInput().get(1);
@@ -1279,7 +1116,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
*/
private Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
{
- //all patterns headed by fiull sum over binary operation
+ //all patterns headed by full sum over binary operation
if( hi instanceof AggUnaryOp //full sum root over binaryop
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM
@@ -1305,13 +1142,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);
//rewire new subdag
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, newBin, pos);
- if( hi.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences(hi);
- if( bop.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences(bop);
-
+ HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi, bop);
hi = newBin;
@@ -1362,17 +1194,16 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
//alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
- if( bop.getOp()==OpOp2.MULT && bop.getInput().get(1) instanceof BinaryOp
+ if( bop.getOp()==OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW)
&& bop.getInput().get(0).getDataType()==DataType.MATRIX
&& HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) //prevent mv
- && ((BinaryOp)bop.getInput().get(1)).getOp()==OpOp2.POW
&& bop.getInput().get(1).getInput().get(1) instanceof LiteralOp
&& HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1).getInput().get(1))==2)
{
Hop W = bop.getInput().get(0);
Hop tmp = bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V))
- if( tmp instanceof BinaryOp && ((BinaryOp)tmp).getOp()==OpOp2.MINUS
+ if( HopRewriteUtils.isBinary(tmp, OpOp2.MINUS)
&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
&& tmp.getInput().get(0).getDataType() == DataType.MATRIX )
{
@@ -1424,9 +1255,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( !appliedPattern
&& bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
&& HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2
- && bop.getInput().get(0) instanceof BinaryOp
+ && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
&& bop.getInput().get(0).getDataType()==DataType.MATRIX
- && ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
&& bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
{
@@ -1479,9 +1309,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( !appliedPattern
&& bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
&& HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2
- && bop.getInput().get(0) instanceof BinaryOp
+ && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
&& bop.getInput().get(0).getDataType()==DataType.MATRIX
- && ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
&& bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
{
@@ -1529,8 +1358,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//relink new hop into original position
if( hnew != null ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
}
@@ -1542,8 +1370,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
{
Hop hnew = null;
- if( hi instanceof BinaryOp //all patterns subrooted by W *
- && ((BinaryOp) hi).getOp()==OpOp2.MULT
+ if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) //all patterns subrooted by W *
&& hi.getDim2() > 1 //not applied for vector-vector mult
&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
&& hi.getInput().get(0).getDataType()==DataType.MATRIX
@@ -1569,7 +1396,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WSIGMOID, W, Y, tX, false, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -1579,8 +1406,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
if( !appliedPattern
&& uop.getOp() == OpOp1.SIGMOID
- && uop.getInput().get(0) instanceof BinaryOp
- && ((BinaryOp)uop.getInput().get(0)).getOp()==OpOp2.MINUS
+ && HopRewriteUtils.isBinary(uop.getInput().get(0), OpOp2.MINUS)
&& uop.getInput().get(0).getInput().get(0) instanceof LiteralOp
&& HopRewriteUtils.getDoubleValueSafe(
(LiteralOp)uop.getInput().get(0).getInput().get(0))==0
@@ -1599,7 +1425,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WSIGMOID, W, Y, tX, false, true);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -1609,8 +1435,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
if( !appliedPattern
&& uop.getOp() == OpOp1.LOG
- && uop.getInput().get(0) instanceof UnaryOp
- && ((UnaryOp)uop.getInput().get(0)).getOp() == OpOp1.SIGMOID
+ && HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)
&& uop.getInput().get(0).getInput().get(0) instanceof AggBinaryOp
&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0).getInput().get(0),true) )
{
@@ -1626,7 +1451,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WSIGMOID, W, Y, tX, true, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -1636,14 +1461,12 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
if( !appliedPattern
&& uop.getOp() == OpOp1.LOG
- && uop.getInput().get(0) instanceof UnaryOp
- && ((UnaryOp)uop.getInput().get(0)).getOp() == OpOp1.SIGMOID
- && uop.getInput().get(0).getInput().get(0) instanceof BinaryOp )
+ && HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)
+ && HopRewriteUtils.isBinary(uop.getInput().get(0).getInput().get(0), OpOp2.MINUS) )
{
BinaryOp bop = (BinaryOp) uop.getInput().get(0).getInput().get(0);
- if( bop.getOp() == OpOp2.MINUS
- && bop.getInput().get(0) instanceof LiteralOp
+ if( bop.getInput().get(0) instanceof LiteralOp
&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)bop.getInput().get(0))==0
&& bop.getInput().get(1) instanceof AggBinaryOp
&& HopRewriteUtils.isSingleBlock(bop.getInput().get(1).getInput().get(0),true))
@@ -1660,7 +1483,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WSIGMOID, W, Y, tX, true, true);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -1671,8 +1494,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//relink new hop into original position
if( hnew != null ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
}
@@ -1687,7 +1509,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//left/right patterns rooted by 'ab - b(div)' or 'ab - b(mult)'
//note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops)
- if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply()
+ if( HopRewriteUtils.isMatrixMultiply(hi)
&& (hi.getInput().get(0) instanceof BinaryOp
&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
|| hi.getInput().get(1) instanceof BinaryOp
@@ -1718,7 +1540,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
boolean mult = ((BinaryOp)right).getOp() == OpOp2.MULT;
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 1, mult, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
@@ -1731,10 +1553,9 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 1e) t(U) %*% (W/(U%*%t(V) + x))
if( !appliedPattern
- && right instanceof BinaryOp && ((BinaryOp)right).getOp() == LOOKUP_VALID_WDIVMM_BINARY[1] //DIV
+ && HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv
- && right.getInput().get(1) instanceof BinaryOp
- && ((BinaryOp) right.getInput().get(1)).getOp() == Hop.OpOp2.PLUS
+ && HopRewriteUtils.isBinary(right.getInput().get(1), Hop.OpOp2.PLUS)
&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
@@ -1753,7 +1574,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WDIVMM, W, U, V, X, 3, false, false); // 3=>DIV_LEFT_EPS
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
@@ -1786,7 +1607,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
boolean mult = ((BinaryOp)left).getOp() == OpOp2.MULT;
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 2, mult, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -1796,10 +1617,9 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 2e) (W/(U%*%t(V) + x)) %*% V
if( !appliedPattern
- && left instanceof BinaryOp && ((BinaryOp)left).getOp() == LOOKUP_VALID_WDIVMM_BINARY[1] //DIV
+ && HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv
- && left.getInput().get(1) instanceof BinaryOp
- && ((BinaryOp) left.getInput().get(1)).getOp() == Hop.OpOp2.PLUS
+ && HopRewriteUtils.isBinary(left.getInput().get(1), Hop.OpOp2.PLUS)
&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
@@ -1818,7 +1638,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WDIVMM, W, U, V, X, 4, false, false); // 4=>DIV_RIGHT_EPS
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -1828,8 +1648,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
if( !appliedPattern
- && right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
- && right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS
+ && HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
+ && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS)
&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
@@ -1849,7 +1669,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 1, true, true);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
@@ -1862,8 +1682,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
if( !appliedPattern
- && left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
- && left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS
+ && HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
+ && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS)
&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
@@ -1883,7 +1703,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 2, true, true);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -1893,8 +1713,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 5) t(U) %*% (W*(U%*%t(V)-X))
if( !appliedPattern
- && right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
- && right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS
+ && HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
+ && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS)
&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
@@ -1914,7 +1734,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//note: x and w exchanged compared to patterns 1-4, 7
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WDIVMM, W, U, V, X, 1, true, true);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
@@ -1927,8 +1747,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 6) (W*(U%*%t(V)-X)) %*% V
if( !appliedPattern
- && left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
- && left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS
+ && HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
+ && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS)
&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
@@ -1948,7 +1768,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//note: x and w exchanged compared to patterns 1-4, 7
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WDIVMM, W, U, V, X, 2, true, true);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -1959,7 +1779,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 7) (W*(U%*%t(V)))
if( !appliedPattern
- && hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
+ && HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
&& hi.getDim2() > 1 //not applied for vector-vector mult
&& hi.getInput().get(0).getDataType() == DataType.MATRIX
@@ -1982,7 +1802,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -1992,8 +1812,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//relink new hop into original position
if( hnew != null ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
}
@@ -2018,7 +1837,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//Pattern 1) sum( X * log(U %*% t(V)))
if( bop.getOp()==OpOp2.MULT && left.getDataType()==DataType.MATRIX
&& HopRewriteUtils.isEqualSize(left, right) //prevent mb
- && right instanceof UnaryOp && ((UnaryOp)right).getOp()==OpOp1.LOG
+ && HopRewriteUtils.isUnary(right, OpOp1.LOG)
&& right.getInput().get(0) instanceof AggBinaryOp //ba gurantees matrices
&& HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
{
@@ -2033,7 +1852,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V,
new LiteralOp(0.0), 0, false, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, X.getRowsInBlock(), X.getColsInBlock());
+ hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
appliedPattern = true;
LOG.debug("Applied simplifyWeightedCEMM (line "+hi.getBeginLine()+")");
@@ -2043,9 +1862,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( !appliedPattern
&& bop.getOp()==OpOp2.MULT && left.getDataType()==DataType.MATRIX
&& HopRewriteUtils.isEqualSize(left, right)
- && right instanceof UnaryOp && ((UnaryOp)right).getOp()==OpOp1.LOG
- && right.getInput().get(0) instanceof BinaryOp
- && ((BinaryOp)right.getInput().get(0)).getOp() == OpOp2.PLUS
+ && HopRewriteUtils.isUnary(right, OpOp1.LOG)
+ && HopRewriteUtils.isBinary(right.getInput().get(0), OpOp2.PLUS)
&& right.getInput().get(0).getInput().get(0) instanceof AggBinaryOp
&& right.getInput().get(0).getInput().get(1) instanceof LiteralOp
&& right.getInput().get(0).getInput().get(1).getDataType() == DataType.SCALAR
@@ -2063,7 +1881,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE,
OpOp4.WCEMM, X, U, V, eps, 1, false, false); // 1 => BASIC_EPS
- HopRewriteUtils.setOutputBlocksizes(hnew, X.getRowsInBlock(), X.getColsInBlock());
+ hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
LOG.debug("Applied simplifyWeightedCEMMEps (line "+hi.getBeginLine()+")");
}
@@ -2071,8 +1889,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//relink new hop into original position
if( hnew != null ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
}
@@ -2109,7 +1926,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WUMM, W, U, V, mult, op, null);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -2162,7 +1979,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
OpOp4.WUMM, W, U, V, mult, null, op);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
@@ -2173,8 +1990,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//relink new hop into original position
if( hnew != null ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
}
@@ -2207,7 +2023,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
Hop hi2 = hi.getInput().get(0); //check for ^2 w/o multiple consumers
//check for sum(v^2), might have been rewritten from sum(v*v)
- if( hi2 instanceof BinaryOp && ((BinaryOp)hi2).getOp()==OpOp2.POW
+ if( HopRewriteUtils.isBinary(hi2, OpOp2.POW)
&& hi2.getInput().get(1) instanceof LiteralOp
&& HopRewriteUtils.getDoubleValue((LiteralOp)hi2.getInput().get(1))==2
&& hi2.getParent().size() == 1 ) //no other consumer than sum
@@ -2217,11 +2033,10 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
baRight = input;
}
//check for sum(v1*v2), but prevent to rewrite sum(v1*v2*v3) which is later compiled into a ta+* lop
- else if( hi2 instanceof BinaryOp && ((BinaryOp)hi2).getOp()==OpOp2.MULT
+ else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) //no other consumer than sum
&& hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1
- && hi2.getParent().size() == 1 //no other consumer than sum
- && !(hi2.getInput().get(0) instanceof BinaryOp && ((BinaryOp)hi2.getInput().get(0)).getOp()==OpOp2.MULT)
- && !(hi2.getInput().get(1) instanceof BinaryOp && ((BinaryOp)hi2.getInput().get(1)).getOp()==OpOp2.MULT))
+ && !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
+ && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT) )
{
baLeft = hi2.getInput().get(0);
baRight = hi2.getInput().get(1);
@@ -2230,25 +2045,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//perform actual rewrite (if necessary)
if( baLeft != null && baRight != null )
{
- //remove link from parent to diag
- HopRewriteUtils.removeChildReference(parent, hi);
-
//create new operator chain
ReorgOp trans = HopRewriteUtils.createTranspose(baLeft);
AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, baRight);
-
- UnaryOp cast = new UnaryOp(baLeft.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp1.CAST_AS_SCALAR, mmult);
- HopRewriteUtils.setOutputParameters(cast, 0, 0, 0, 0, -1);
+ UnaryOp cast = HopRewriteUtils.createUnary(mmult, OpOp1.CAST_AS_SCALAR);
//rehang new subdag under parent node
- HopRewriteUtils.addChildReference(parent, cast, pos);
- parent.refreshSizeInformation();
-
- //cleanup if only consumer of intermediate
- if( hi.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi );
- if( hi2.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi2 );
+ HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi, hi2);
hi = cast;
@@ -2277,7 +2081,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
Hop sumInput = hi.getInput().get(0);
// if input to SUM is POW(X,2), and no other consumers of the POW(X,2) HOP
- if (sumInput instanceof BinaryOp && ((BinaryOp) sumInput).getOp() == OpOp2.POW
+ if( HopRewriteUtils.isBinary(sumInput, OpOp2.POW)
&& sumInput.getInput().get(1) instanceof LiteralOp
&& HopRewriteUtils.getDoubleValue((LiteralOp) sumInput.getInput().get(1)) == 2
&& sumInput.getParent().size() == 1) {
@@ -2286,24 +2090,13 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
// if X is NOT a column vector
if (x.getDim2() > 1) {
// perform rewrite from SUM(POW(X,2)) to SUM_SQ(X)
- DataType dt = hi.getDataType();
- ValueType vt = hi.getValueType();
Direction dir = ((AggUnaryOp) hi).getDirection();
- long brlen = hi.getRowsInBlock();
- long bclen = hi.getColsInBlock();
- AggUnaryOp sumSq = new AggUnaryOp("sumSq", dt, vt, AggOp.SUM_SQ, dir, x);
- HopRewriteUtils.setOutputBlocksizes(sumSq, brlen, bclen);
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, sumSq, pos);
-
- // cleanup
- if (hi.getParent().isEmpty())
- HopRewriteUtils.removeAllChildReferences(hi);
- if(sumInput.getParent().isEmpty())
- HopRewriteUtils.removeAllChildReferences(sumInput);
-
- // replace current HOP with new SUM_SQ HOP
+ AggUnaryOp sumSq = HopRewriteUtils.createAggUnaryOp(x, AggOp.SUM_SQ, dir);
+ HopRewriteUtils.replaceChildReference(parent, hi, sumSq, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi, sumInput);
hi = sumSq;
+
+ LOG.debug("Applied fuseSumSquared.");
}
}
}
@@ -2358,8 +2151,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//rewire parent-child operators if rewrite applied
if( ternop != null ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, ternop, pos);
+ HopRewriteUtils.replaceChildReference(parent, hi, ternop, pos);
hi = ternop;
}
}
@@ -2421,15 +2213,9 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hnew = null;
}
- if( hnew != null )
- {
- //remove unnecessary matrix mult
- HopRewriteUtils.removeChildReference(parent, hi);
-
+ if( hnew != null ) {
//create datagen and add it to parent
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- parent.refreshSizeInformation();
-
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
LOG.debug("Applied simplifyEmptyBinaryOperation");
@@ -2460,12 +2246,12 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
private Hop reorderMinusMatrixMult(Hop parent, Hop hi, int pos)
throws HopsException
{
- if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply() ) //X%*%Y
+ if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y
{
Hop hileft = hi.getInput().get(0);
Hop hiright = hi.getInput().get(1);
- if( hileft instanceof BinaryOp && ((BinaryOp)hileft).getOp()==OpOp2.MINUS //X=-Z
+ if( HopRewriteUtils.isBinary(hileft, OpOp2.MINUS) //X=-Z
&& hileft.getInput().get(0) instanceof LiteralOp
&& HopRewriteUtils.getDoubleValue((LiteralOp)hileft.getInput().get(0))==0.0
&& hi.dimsKnown() && hileft.getInput().get(1).dimsKnown() //size comparison
@@ -2480,9 +2266,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone();
//create new operators
- BinaryOp minus = new BinaryOp(hi.getName(), hi.getDataType(), hi.getValueType(), OpOp2.MINUS, new LiteralOp(0), hi);
- minus.setRowsInBlock(hi.getRowsInBlock());
- minus.setColsInBlock(hi.getColsInBlock());
+ BinaryOp minus = HopRewriteUtils.createBinary(new LiteralOp(0), hi, OpOp2.MINUS);
//rehang minus under all parents
for( Hop p : parents ) {
@@ -2495,14 +2279,13 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
HopRewriteUtils.addChildReference(hi, hi2, 0);
//cleanup if only consumer of minus
- if( hileft.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hileft );
+ HopRewriteUtils.cleanupUnreferenced(hileft);
hi = minus;
LOG.debug("Applied reorderMinusMatrixMult (line "+hi.getBeginLine()+").");
}
- else if( hiright instanceof BinaryOp && ((BinaryOp)hiright).getOp()==OpOp2.MINUS //X=-Z
+ else if( HopRewriteUtils.isBinary(hiright, OpOp2.MINUS) //X=-Z
&& hiright.getInput().get(0) instanceof LiteralOp
&& HopRewriteUtils.getDoubleValue((LiteralOp)hiright.getInput().get(0))==0.0
&& hi.dimsKnown() && hiright.getInput().get(1).dimsKnown() //size comparison
@@ -2517,9 +2300,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone();
//create new operators
- BinaryOp minus = new BinaryOp(hi.getName(), hi.getDataType(), hi.getValueType(), OpOp2.MINUS, new LiteralOp(0), hi);
- minus.setRowsInBlock(hi.getRowsInBlock());
- minus.setColsInBlock(hi.getColsInBlock());
+ BinaryOp minus = HopRewriteUtils.createBinary(new LiteralOp(0), hi, OpOp2.MINUS);
//rehang minus under all parents
for( Hop p : parents ) {
@@ -2532,8 +2313,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
HopRewriteUtils.addChildReference(hi, hi2, 1);
//cleanup if only consumer of minus
- if( hiright.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hiright );
+ HopRewriteUtils.cleanupUnreferenced(hiright);
hi = minus;
@@ -2592,8 +2372,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hi.refreshSizeInformation();
//cleanup if only consumer of intermediate
- if( hi2.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi2 );
+ HopRewriteUtils.cleanupUnreferenced(hi2);
}
return hi;
@@ -2612,11 +2391,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( HopRewriteUtils.isDimsKnown(right) && right.getDim1()==1 && right.getDim2()==1 ) //scalar right
{
//remove link to right child and introduce cast
- HopRewriteUtils.removeChildReference(hi, right);
- UnaryOp cast = new UnaryOp(right.getName(), DataType.SCALAR, ValueType.DOUBLE,
- OpOp1.CAST_AS_SCALAR, right);
- HopRewriteUtils.setOutputParameters(cast, 0, 0, 0, 0, 0);
- HopRewriteUtils.addChildReference(hi, cast, 1);
+ UnaryOp cast = HopRewriteUtils.createUnary(right, OpOp1.CAST_AS_SCALAR);
+ HopRewriteUtils.replaceChildReference(hi, right, cast, 1);
LOG.debug("Applied simplifyScalarMVBinaryOperation.");
}
@@ -2631,8 +2407,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum
&& ((AggUnaryOp)hi).getDirection() == Direction.RowCol //full aggregate
- && hi.getInput().get(0) instanceof BinaryOp
- && ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.NOTEQUAL )
+ && HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.NOTEQUAL) )
{
Hop ppred = hi.getInput().get(0);
Hop X = null;
@@ -2650,13 +2425,10 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//apply rewrite if known nnz
if( X != null && X.getNnz() > 0 ){
Hop hnew = new LiteralOp(X.getNnz());
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
-
- if( hi.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi );
-
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi);
hi = hnew;
+
LOG.debug("Applied simplifyNnzComputation.");
}
}