You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/04/24 03:18:04 UTC
systemml git commit: [SYSTEMML-2277,224,423] New rewrite for hoisting loop-invariant ops
Repository: systemml
Updated Branches:
refs/heads/master afbedf3bf -> 1e1210b9e
[SYSTEMML-2277,224,423] New rewrite for hoisting loop-invariant ops
This patch introduces a new optional (still disabled) rewrite for code
motion, i.e., hoisting loop-invariant operations from while, for, or
parfor loops. These loop-invariant operations are defined as reads of
variables used read-only in the loop, and operations that have only
loop-invariant inputs (modulo some special cases such as rand without
seed). Furthermore, this also includes a cleanup of various rewrites
that deal with the creation of transient reads and writes.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1e1210b9
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1e1210b9
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1e1210b9
Branch: refs/heads/master
Commit: 1e1210b9ebdb68e76ad20ee08f132ef32483f829
Parents: afbedf3
Author: Matthias Boehm <mb...@gmail.com>
Authored: Mon Apr 23 19:13:12 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Mon Apr 23 19:13:12 2018 -0700
----------------------------------------------------------------------
.../org/apache/sysml/hops/OptimizerUtils.java | 6 +
.../sysml/hops/rewrite/HopRewriteUtils.java | 39 +++-
.../sysml/hops/rewrite/ProgramRewriter.java | 2 +
.../RewriteHoistLoopInvariantOperations.java | 209 +++++++++++++++++++
.../RewriteInjectSparkLoopCheckpointing.java | 5 +-
.../RewriteSplitDagDataDependentOperators.java | 41 +---
.../rewrite/RewriteSplitDagUnknownCSVRead.java | 20 +-
.../hops/rewrite/StatementBlockRewriteRule.java | 16 +-
.../org/apache/sysml/parser/StatementBlock.java | 4 +-
.../RewriteHoistingLoopInvariantOpsTest.java | 127 +++++++++++
.../functions/misc/RewriteCodeMotionFor.R | 37 ++++
.../functions/misc/RewriteCodeMotionFor.dml | 31 +++
.../functions/misc/RewriteCodeMotionWhile.R | 39 ++++
.../functions/misc/RewriteCodeMotionWhile.dml | 33 +++
.../functions/misc/ZPackageSuite.java | 1 +
15 files changed, 541 insertions(+), 69 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index 2d76759..e9af001 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -182,6 +182,12 @@ public class OptimizerUtils
*/
public static boolean ALLOW_LOOP_UPDATE_IN_PLACE = true;
+ /**
+ * Enables a specific rewrite for code motion, i.e., hoisting loop invariant code
+ * out of while, for, and parfor loops.
+ */
+ public static boolean ALLOW_CODE_MOTION = false;
+
/**
* Specifies a multiplier computing the degree of parallelism of parallel
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index c6c42ae..8abe90b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -499,7 +499,31 @@ public class HopRewriteUtils
public static Hop getDataGenOpConstantValue(Hop hop) {
return ((DataGenOp) hop).getConstantValue();
- }
+ }
+
+ public static DataOp createTransientRead(String name, Hop h) {
+ //note: different constructor necessary for formattype
+ DataOp tread = new DataOp(name, h.getDataType(), h.getValueType(),
+ DataOpTypes.TRANSIENTREAD, null, h.getDim1(), h.getDim2(), h.getNnz(),
+ h.getUpdateType(), h.getRowsInBlock(), h.getColsInBlock());
+ tread.setVisited();
+ copyLineNumbers(h, tread);
+ return tread;
+ }
+
+ public static DataOp createTransientWrite(String name, Hop in) {
+ return createDataOp(name, in, DataOpTypes.TRANSIENTWRITE);
+ }
+
+ public static DataOp createDataOp(String name, Hop in, DataOpTypes type) {
+ DataOp dop = new DataOp(name, in.getDataType(),
+ in.getValueType(), in, type, null);
+ dop.setVisited();
+ dop.setOutputParams(in.getDim1(), in.getDim2(), in.getNnz(),
+ in.getUpdateType(), in.getRowsInBlock(), in.getColsInBlock());
+ copyLineNumbers(in, dop);
+ return dop;
+ }
public static ReorgOp createTranspose(Hop input) {
return createReorg(input, ReOrgOp.TRANS);
@@ -684,14 +708,6 @@ public class HopRewriteUtils
return ternOp;
}
- public static DataOp createDataOp(String name, Hop input, DataOpTypes type) {
- DataOp dop = new DataOp(name, input.getDataType(), input.getValueType(), input, type, null);
- dop.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
- copyLineNumbers(input, dop);
- dop.refreshSizeInformation();
- return dop;
- }
-
public static void setOutputParameters( Hop hop, long rlen, long clen, int brlen, int bclen, long nnz ) {
hop.setDim1( rlen );
hop.setDim2( clen );
@@ -1295,6 +1311,11 @@ public class HopRewriteUtils
|| sb instanceof ForStatementBlock); //incl parfor
}
+ public static boolean isLoopStatementBlock(StatementBlock sb) {
+ return sb instanceof WhileStatementBlock
+ || sb instanceof ForStatementBlock; //incl parfor
+ }
+
public static long getMaxNrowInput(Hop hop) {
return getMaxInputDim(hop, true);
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index eb7d23c..2963e9d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -112,6 +112,8 @@ public class ProgramRewriter
if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
_sbRuleSet.add( new RewriteForLoopVectorization() ); //dependency: reblock (reblockop)
_sbRuleSet.add( new RewriteInjectSparkLoopCheckpointing(true) ); //dependency: reblock (blocksizes)
+ if( OptimizerUtils.ALLOW_CODE_MOTION )
+ _sbRuleSet.add( new RewriteHoistLoopInvariantOperations() ); //dependency: vectorize, but before inplace
if( OptimizerUtils.ALLOW_LOOP_UPDATE_IN_PLACE )
_sbRuleSet.add( new RewriteMarkLoopVariablesUpdateInPlace() );
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/RewriteHoistLoopInvariantOperations.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteHoistLoopInvariantOperations.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteHoistLoopInvariantOperations.java
new file mode 100644
index 0000000..3e77486
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteHoistLoopInvariantOperations.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import org.apache.sysml.hops.DataOp;
+import org.apache.sysml.hops.FunctionOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.DataGenMethod;
+import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.parser.DataIdentifier;
+import org.apache.sysml.parser.ForStatement;
+import org.apache.sysml.parser.ForStatementBlock;
+import org.apache.sysml.parser.IfStatementBlock;
+import org.apache.sysml.parser.StatementBlock;
+import org.apache.sysml.parser.VariableSet;
+import org.apache.sysml.parser.WhileStatement;
+import org.apache.sysml.parser.WhileStatementBlock;
+
+/**
+ * Rule: Simplify program structure by hoisting loop-invariant operations
+ * out of while, for, or parfor loops.
+ */
+public class RewriteHoistLoopInvariantOperations extends StatementBlockRewriteRule
+{
+ private final boolean _sideEffectFreeFuns;
+
+ public RewriteHoistLoopInvariantOperations() {
+ this(false);
+ }
+
+ public RewriteHoistLoopInvariantOperations(boolean noSideEffects) {
+ _sideEffectFreeFuns = noSideEffects;
+ }
+
+ @Override
+ public boolean createsSplitDag() {
+ return true;
+ }
+
+ @Override
+ public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+ //early abort if possible
+ if( sb == null || !HopRewriteUtils.isLoopStatementBlock(sb) )
+ return Arrays.asList(sb); //rewrite only applies to loops
+
+ //step 1: determine read-only variables
+ Set<String> candInputs = sb.variablesRead().getVariableNames().stream()
+ .filter(v -> !sb.variablesUpdated().containsVariable(v))
+ .collect(Collectors.toSet());
+
+ //step 2: collect loop-invariant operations along with their tmp names
+ Map<String, Hop> invariantOps = new HashMap<>();
+ collectOperations(sb, candInputs, invariantOps);
+
+ //step 3: create new statement block for all temporary intermediates
+ return invariantOps.isEmpty() ? Arrays.asList(sb) :
+ Arrays.asList(createStatementBlock(sb, invariantOps), sb);
+ }
+
+ @Override
+ public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
+ return sbs;
+ }
+
+ private void collectOperations(StatementBlock sb, Set<String> candInputs, Map<String, Hop> invariantOps) {
+
+ if( sb instanceof WhileStatementBlock ) {
+ WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
+ for( StatementBlock csb : wstmt.getBody() )
+ collectOperations(csb, candInputs, invariantOps);
+ }
+ else if( sb instanceof ForStatementBlock ) {
+ ForStatement fstmt = (ForStatement) sb.getStatement(0);
+ for( StatementBlock csb : fstmt.getBody() )
+ collectOperations(csb, candInputs, invariantOps);
+ }
+ else if( sb instanceof IfStatementBlock ) {
+ //note: for now we do not pull loop-invariant code out of
+ //if statement blocks because these operations are conditionally
+ //executed, so unconditional execution might be counter productive
+ }
+ else if( sb.getHops() != null ) {
+ //step a: bottom-up flagging of loop-invariant operations
+ //(these are defined operations whose inputs are read only
+ //variables or other loop-invariant operations)
+ Hop.resetVisitStatus(sb.getHops());
+ HashSet<Long> memo = new HashSet<>();
+ for( Hop hop : sb.getHops() )
+ rTagLoopInvariantOperations(hop, candInputs, memo);
+
+ //step b: copy hop sub dag and replace it via tread
+ Hop.resetVisitStatus(sb.getHops());
+ for( Hop hop : sb.getHops() )
+ rCollectAndReplaceOperations(hop, candInputs, memo, invariantOps);
+
+ if( !memo.isEmpty() ) {
+ LOG.debug("Applied hoistLoopInvariantOperations (lines "
+ +sb.getBeginLine()+"-"+sb.getEndLine()+"): "+memo.size()+".");
+ }
+ }
+ }
+
+ private void rTagLoopInvariantOperations(Hop hop, Set<String> candInputs, Set<Long> memo) {
+ if( hop.isVisited() )
+ return;
+
+ //process inputs first (depth first)
+ for( Hop c : hop.getInput() )
+ rTagLoopInvariantOperations(c, candInputs, memo);
+
+ //flag operation if all inputs are loop invariant
+ boolean invariant = !HopRewriteUtils.isDataGenOp(hop, DataGenMethod.RAND)
+ && (!(hop instanceof FunctionOp) || _sideEffectFreeFuns)
+ && !HopRewriteUtils.isData(hop, DataOpTypes.TRANSIENTREAD)
+ && !HopRewriteUtils.isData(hop, DataOpTypes.TRANSIENTWRITE);
+ for( Hop c : hop.getInput() ) {
+ invariant &= (candInputs.contains(c.getName())
+ || memo.contains(c.getHopID()) || c instanceof LiteralOp);
+ }
+ if( invariant )
+ memo.add(hop.getHopID());
+
+ hop.setVisited();
+ }
+
+ private void rCollectAndReplaceOperations(Hop hop, Set<String> candInputs, Set<Long> memo, Map<String, Hop> invariantOps) {
+ if( hop.isVisited() )
+ return;
+
+ //replace amenable inputs or process recursively
+ //(without iterators due to parent-child modifications)
+ for( int i=0; i<hop.getInput().size(); i++ ) {
+ Hop c = hop.getInput().get(i);
+ if( memo.contains(c.getHopID()) ) {
+ String tmpName = createCutVarName(false);
+ Hop tmp = Recompiler.deepCopyHopsDag(c);
+ tmp.getParent().clear();
+ invariantOps.put(tmpName, tmp);
+
+ //create read and replace all parent references
+ DataOp tread = HopRewriteUtils.createTransientRead(tmpName, c);
+ List<Hop> parents = new ArrayList<>(c.getParent());
+ for( Hop p : parents )
+ HopRewriteUtils.replaceChildReference(p, c, tread);
+ }
+ else {
+ rCollectAndReplaceOperations(c, candInputs, memo, invariantOps);
+ }
+ }
+
+ hop.setVisited();
+ }
+
+ private StatementBlock createStatementBlock(StatementBlock sb, Map<String, Hop> invariantOps) {
+ //create empty last-level statement block
+ StatementBlock ret = new StatementBlock();
+ ret.setDMLProg(sb.getDMLProg());
+ ret.setParseInfo(sb);
+ ret.setLiveIn(new VariableSet(sb.liveIn()));
+ ret.setLiveOut(new VariableSet(sb.liveIn()));
+
+ //append hops with custom
+ ArrayList<Hop> hops = new ArrayList<>();
+ for( Entry<String, Hop> e : invariantOps.entrySet() ) {
+ Hop h = e.getValue();
+ DataOp twrite = HopRewriteUtils.createTransientWrite(e.getKey(), h);
+ hops.add(twrite);
+ //update live variable analysis
+ DataIdentifier diVar = new DataIdentifier(e.getKey());
+ diVar.setDimensions(h.getDim1(), h.getDim2());
+ diVar.setBlockDimensions(h.getRowsInBlock(), h.getColsInBlock());
+ diVar.setDataType(h.getDataType());
+ diVar.setValueType(h.getValueType());
+ ret.liveOut().addVariable(e.getKey(), diVar);
+ sb.liveIn().addVariable(e.getKey(), diVar);
+ }
+ ret.setHops(hops);
+ return ret;
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java
index 6c3ad76..853a02d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java
@@ -102,10 +102,9 @@ public class RewriteInjectSparkLoopCheckpointing extends StatementBlockRewriteRu
long dim1 = (dat instanceof IndexedIdentifier) ? ((IndexedIdentifier)dat).getOrigDim1() : dat.getDim1();
long dim2 = (dat instanceof IndexedIdentifier) ? ((IndexedIdentifier)dat).getOrigDim2() : dat.getDim2();
DataOp tread = new DataOp(var, DataType.MATRIX, ValueType.DOUBLE, DataOpTypes.TRANSIENTREAD,
- dat.getFilename(), dim1, dim2, dat.getNnz(), blocksize, blocksize);
+ dat.getFilename(), dim1, dim2, dat.getNnz(), blocksize, blocksize);
tread.setRequiresCheckpoint(true);
- DataOp twrite = new DataOp(var, DataType.MATRIX, ValueType.DOUBLE, tread, DataOpTypes.TRANSIENTWRITE, null);
- HopRewriteUtils.setOutputParameters(twrite, dim1, dim2, blocksize, blocksize, dat.getNnz());
+ DataOp twrite = HopRewriteUtils.createTransientWrite(var, tread);
hops.add(twrite);
livein.addVariable(var, read.getVariable(var));
liveout.addVariable(var, read.getVariable(var));
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index a758ee0..afbf483 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -45,8 +45,6 @@ import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
-import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.matrix.data.Pair;
/**
@@ -68,10 +66,6 @@ import org.apache.sysml.runtime.matrix.data.Pair;
*/
public class RewriteSplitDagDataDependentOperators extends StatementBlockRewriteRule
{
- private static final String SB_CUT_PREFIX = "_sbcvar";
- private static final String FUN_CUT_PREFIX = "_funvar";
- private static IDSequence _seq = new IDSequence();
-
@Override
public boolean createsSplitDag() {
return true;
@@ -123,8 +117,6 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
String varname = null;
long rlen = c.getDim1();
long clen = c.getDim2();
- long nnz = c.getNnz();
- UpdateType update = c.getUpdateType();
int brlen = c.getRowsInBlock();
int bclen = c.getColsInBlock();
@@ -134,10 +126,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
varname = twrite.getName();
//create new transient read
- DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(),
- DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen);
- tread.setVisited();
- HopRewriteUtils.copyLineNumbers(c, tread);
+ DataOp tread = HopRewriteUtils.createTransientRead(varname, c);
//replace data-dependent operator with transient read
ArrayList<Hop> parents = new ArrayList<>(c.getParent());
@@ -160,10 +149,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
varname = createCutVarName(false);
//create new transient read
- DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(),
- DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen);
- tread.setVisited();
- HopRewriteUtils.copyLineNumbers(c, tread);
+ DataOp tread = HopRewriteUtils.createTransientRead(varname, c);
//replace data-dependent operator with transient read
ArrayList<Hop> parents = new ArrayList<>(c.getParent());
@@ -175,11 +161,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
}
//add data-dependent operator sub dag to first statement block
- DataOp twrite = new DataOp(varname, c.getDataType(),
- c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null);
- twrite.setVisited();
- twrite.setOutputParams(rlen, clen, nnz, update, brlen, bclen);
- HopRewriteUtils.copyLineNumbers(c, twrite);
+ DataOp twrite = HopRewriteUtils.createTransientWrite(varname, c);
sb1hops.add(twrite);
}
@@ -364,16 +346,10 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
if( tread == null ) {
String varname = createCutVarName(false);
- tread = new DataOp(varname, c.getDataType(), c.getValueType(), DataOpTypes.TRANSIENTREAD, null,
- c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock());
- tread.setVisited();
- HopRewriteUtils.copyLineNumbers(c, tread);
+ tread = HopRewriteUtils.createTransientRead(varname, c);
reuseTRead.put(c.getHopID(), tread);
- DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null);
- twrite.setVisited();
- twrite.setOutputParams(c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock());
- HopRewriteUtils.copyLineNumbers(c, twrite);
+ DataOp twrite = HopRewriteUtils.createTransientWrite(varname, c);
//update live in and out of new statement block (for piggybacking)
DataIdentifier diVar = new DataIdentifier(varname);
@@ -484,11 +460,4 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate) {
return sbs;
}
-
- public static String createCutVarName(boolean fun) {
- return fun ?
- FUN_CUT_PREFIX + _seq.getNextID() :
- SB_CUT_PREFIX + _seq.getNextID();
-
- }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
index 30631d6..a4c31d9 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
@@ -34,7 +34,6 @@ import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
/**
* Rule: Split Hop DAG after CSV reads with unknown size. This is
@@ -81,13 +80,6 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule
ArrayList<Hop> sb1hops = new ArrayList<>();
for( Hop reblock : cand )
{
- long rlen = reblock.getDim1();
- long clen = reblock.getDim2();
- long nnz = reblock.getNnz();
- UpdateType update = reblock.getUpdateType();
- int brlen = reblock.getRowsInBlock();
- int bclen = reblock.getColsInBlock();
-
//replace reblock inputs to avoid dangling references across dags
//(otherwise, for instance, literal ops are shared across dags)
for( int i=0; i<reblock.getInput().size(); i++ )
@@ -96,9 +88,7 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule
new LiteralOp((LiteralOp)reblock.getInput().get(i)));
//create new transient read
- DataOp tread = new DataOp(reblock.getName(), reblock.getDataType(), reblock.getValueType(),
- DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen);
- HopRewriteUtils.copyLineNumbers(reblock, tread);
+ DataOp tread = HopRewriteUtils.createTransientRead(reblock.getName(), reblock);
//replace reblock with transient read
ArrayList<Hop> parents = new ArrayList<>(reblock.getParent());
@@ -108,10 +98,7 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule
}
//add reblock sub dag to first statement block
- DataOp twrite = new DataOp(reblock.getName(), reblock.getDataType(), reblock.getValueType(),
- reblock, DataOpTypes.TRANSIENTWRITE, null);
- twrite.setOutputParams(rlen, clen, nnz, update, brlen, bclen);
- HopRewriteUtils.copyLineNumbers(reblock, twrite);
+ DataOp twrite = HopRewriteUtils.createTransientWrite(reblock.getName(), reblock);
sb1hops.add(twrite);
//update live in and out of new statement block (for piggybacking)
@@ -128,8 +115,7 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule
ret.add(sb); //statement block with remaining hops
sb.setSplitDag(true); //avoid later merge by other rewrites
}
- catch(Exception ex)
- {
+ catch(Exception ex) {
throw new HopsException("Failed to split hops dag for csv read with unknown size.", ex);
}
LOG.debug("Applied splitDagUnknownCSVRead.");
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java b/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java
index fe8d111..9f4b619 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java
@@ -25,6 +25,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.parser.StatementBlock;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
/**
* Base class for all hop rewrites in order to enable generic
@@ -34,6 +35,17 @@ import org.apache.sysml.parser.StatementBlock;
public abstract class StatementBlockRewriteRule
{
protected static final Log LOG = LogFactory.getLog(StatementBlockRewriteRule.class.getName());
+
+ private static final String SB_CUT_PREFIX = "_sbcvar";
+ private static final String FUN_CUT_PREFIX = "_funvar";
+ private static IDSequence _seq = new IDSequence();
+
+ public static String createCutVarName(boolean fun) {
+ return fun ?
+ FUN_CUT_PREFIX + _seq.getNextID() :
+ SB_CUT_PREFIX + _seq.getNextID();
+
+ }
/**
* Indicates if the rewrite potentially splits dags, which is used
@@ -52,7 +64,7 @@ public abstract class StatementBlockRewriteRule
* @param sate program rewrite status
* @return list of statement blocks
*/
- public abstract List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus sate);
+ public abstract List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state);
/**
* Handle a list of statement blocks. Specific type constraints have to be ensured
@@ -63,5 +75,5 @@ public abstract class StatementBlockRewriteRule
* @param sate program rewrite status
* @return list of statement blocks
*/
- public abstract List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate);
+ public abstract List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state);
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/parser/StatementBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java
index 2957482..190a481 100644
--- a/src/main/java/org/apache/sysml/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java
@@ -30,7 +30,7 @@ import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.recompile.Recompiler;
-import org.apache.sysml.hops.rewrite.RewriteSplitDagDataDependentOperators;
+import org.apache.sysml.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.FormatType;
@@ -537,7 +537,7 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
for( ParameterExpression pexpr : fexpr.getParamExprs() )
pexpr.setExpr(rHoistFunctionCallsFromExpressions(pexpr.getExpr(), false, tmp));
if( !root ) { //core hoisting
- String varname = RewriteSplitDagDataDependentOperators.createCutVarName(true);
+ String varname = StatementBlockRewriteRule.createCutVarName(true);
DataIdentifier di = new DataIdentifier(varname);
di.setDataType(fexpr.getDataType());
di.setValueType(fexpr.getValueType());
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteHoistingLoopInvariantOpsTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteHoistingLoopInvariantOpsTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteHoistingLoopInvariantOpsTest.java
new file mode 100644
index 0000000..2a28ae7
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteHoistingLoopInvariantOpsTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RewriteHoistingLoopInvariantOpsTest extends AutomatedTestBase
+{
+ private static final String TEST_NAME1 = "RewriteCodeMotionFor";
+ private static final String TEST_NAME2 = "RewriteCodeMotionWhile";
+
+ private static final String TEST_DIR = "functions/misc/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + RewriteHoistingLoopInvariantOpsTest.class.getSimpleName() + "/";
+
+ private static final int rows = 265;
+ private static final int cols = 132;
+ private static final int iters = 10;
+ private static final double sparsity = 0.1;
+ private static final double eps = Math.pow(10, -10);
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) );
+ addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) );
+ }
+
+ @Test
+ public void testCodeMotionForCP() {
+ testRewriteCodeMotion(TEST_NAME1, false, ExecType.CP);
+ }
+
+ @Test
+ public void testCodeMotionForRewriteCP() {
+ testRewriteCodeMotion(TEST_NAME1, true, ExecType.CP);
+ }
+
+ @Test
+ public void testCodeMotionWhileCP() {
+ testRewriteCodeMotion(TEST_NAME2, false, ExecType.CP);
+ }
+
+ @Test
+ public void testCodeMotionWhileRewriteCP() {
+ testRewriteCodeMotion(TEST_NAME2, true, ExecType.CP);
+ }
+
+ private void testRewriteCodeMotion(String testname, boolean rewrites, ExecType et)
+ {
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( et ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK || rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ boolean rewritesOld = OptimizerUtils.ALLOW_CODE_MOTION;
+ OptimizerUtils.ALLOW_CODE_MOTION = rewrites;
+
+ try
+ {
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[] { "-explain", "hops", "-stats", "-args",
+ input("X"), String.valueOf(iters), output("R") };
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(inputDir(), String.valueOf(iters), expectedDir());
+
+ double[][] X = getRandomMatrix(rows, cols, -1, 1, sparsity, 7);
+ writeInputMatrixWithMTD("X", X, true);
+
+ //execute tests
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+ //check applied code motion rewrites (moved sum and - from 10 to 1)
+ Assert.assertEquals(rewrites?1:10, Statistics.getCPHeavyHitterCount("uak+"));
+ Assert.assertEquals(rewrites?1:10, Statistics.getCPHeavyHitterCount("-"));
+ }
+ finally {
+ OptimizerUtils.ALLOW_CODE_MOTION = rewritesOld;
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test/scripts/functions/misc/RewriteCodeMotionFor.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCodeMotionFor.R b/src/test/scripts/functions/misc/RewriteCodeMotionFor.R
new file mode 100644
index 0000000..5d21bd1
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCodeMotionFor.R
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)  # args[1]: input dir prefix, args[2]: iteration count, args[3]: output dir prefix
+options(digits=22)
+library("Matrix")
+library("matrixStats")  # NOTE(review): no matrixStats function appears to be called below - confirm it is needed
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))  # read X from MatrixMarket and convert to a base matrix
+
+R = matrix(0, 1, 1);  # 1x1 accumulator
+for(i in 1:as.integer(args[2])) {
+ t1 = X - sum(X);  # does not depend on i; X is only read inside the loop
+ t2 = X + max(X/i);  # depends on the loop index i
+ R = R + min(t1 * t2);  # add this iteration's minimum to the accumulator
+}
+
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep=""));  # write R in MatrixMarket format
+
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test/scripts/functions/misc/RewriteCodeMotionFor.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCodeMotionFor.dml b/src/test/scripts/functions/misc/RewriteCodeMotionFor.dml
new file mode 100644
index 0000000..e1acd85
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCodeMotionFor.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);  # $1: input matrix X
+
+R = matrix(0, 1, 1);  # 1x1 accumulator
+for(i in 1:$2) {  # $2: number of iterations
+ t1 = X - sum(X);  # loop-invariant: X is only read inside the loop, so sum and - do not depend on i
+ t2 = X + max(X/i);  # depends on the loop index i
+ R = R + min(t1 * t2);  # add this iteration's minimum to the accumulator
+}
+
+write(R, $3);  # $3: output location for R
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test/scripts/functions/misc/RewriteCodeMotionWhile.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCodeMotionWhile.R b/src/test/scripts/functions/misc/RewriteCodeMotionWhile.R
new file mode 100644
index 0000000..1cfe05d
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCodeMotionWhile.R
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)  # args[1]: input dir prefix, args[2]: iteration count, args[3]: output dir prefix
+options(digits=22)
+library("Matrix")
+library("matrixStats")  # NOTE(review): no matrixStats function appears to be called below - confirm it is needed
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))  # read X from MatrixMarket and convert to a base matrix
+
+R = matrix(0, 1, 1);  # 1x1 accumulator
+i = 1;  # explicit loop counter for the while loop
+while( i <= as.integer(args[2]) ) {
+ t1 = X - sum(X);  # does not depend on i; X is only read inside the loop
+ t2 = X + max(X/i);  # depends on the loop counter i
+ R = R + min(t1 * t2);  # add this iteration's minimum to the accumulator
+ i = i + 1;  # advance the counter
+}
+
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep=""));  # write R in MatrixMarket format
+
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test/scripts/functions/misc/RewriteCodeMotionWhile.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCodeMotionWhile.dml b/src/test/scripts/functions/misc/RewriteCodeMotionWhile.dml
new file mode 100644
index 0000000..2e3f349
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCodeMotionWhile.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);  # $1: input matrix X
+
+R = matrix(0, 1, 1);  # 1x1 accumulator
+i = 1;  # explicit loop counter; updated in the loop body
+while( i <= $2 ) {  # $2: number of iterations
+ t1 = X - sum(X);  # loop-invariant: X is only read inside the loop, so sum and - do not depend on i
+ t2 = X + max(X/i);  # depends on the loop counter i
+ R = R + min(t1 * t2);  # add this iteration's minimum to the accumulator
+ i += 1;  # advance the counter
+}
+
+write(R, $3);  # $3: output location for R
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index b75b07a..6166e3d 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -61,6 +61,7 @@ import org.junit.runners.Suite;
RewriteFoldRCBindTest.class,
RewriteFuseBinaryOpChainTest.class,
RewriteFusedRandTest.class,
+ RewriteHoistingLoopInvariantOpsTest.class,
RewriteIndexingVectorizationTest.class,
RewriteLoopVectorization.class,
RewriteMatrixMultChainOptTest.class,