You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/07/22 11:37:55 UTC
[systemds] branch master updated: [SYSTEMDS-3069] Extended rewrites for splitting DAGs after compression
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 07c69e6 [SYSTEMDS-3069] Extended rewrites for splitting DAGs after compression
07c69e6 is described below
commit 07c69e62449a95ad889f9453bac8410d667fe689
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Thu Jul 22 13:36:50 2021 +0200
[SYSTEMDS-3069] Extended rewrites for splitting DAGs after compression
This patch extends the existing 'split-DAG after data-dependent
operators' rewrite and the IPA integration of workload-aware compression
in order to allow recompilation according to compression results (e.g.,
compile local instead of distributed operations for highly compressible
data).
---
.../sysds/hops/ipa/InterProceduralAnalysis.java | 3 +-
.../RewriteSplitDagDataDependentOperators.java | 34 +++++++++++++---------
.../spark/AggregateUnarySPInstruction.java | 2 +-
.../compress/workload/WorkloadAlgorithmTest.java | 11 ++-----
4 files changed, 27 insertions(+), 23 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index 309d823..0b47a19 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -241,7 +241,8 @@ public class InterProceduralAnalysis
FunctionCallGraph graph2 = new FunctionCallGraph(_prog);
List<IPAPass> fpasses = Arrays.asList(
new IPAPassRemoveUnusedFunctions(),
- new IPAPassCompressionWorkloadAnalysis());
+ new IPAPassCompressionWorkloadAnalysis(), // workload-aware compression
+ new IPAPassApplyStaticAndDynamicHopRewrites()); //split after compress
for(IPAPass pass : fpasses)
if( pass.isApplicable(graph2) )
pass.rewriteProgram(_prog, graph2, null);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index fe00ae0..ecc3f39 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -34,6 +34,7 @@ import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
@@ -42,6 +43,7 @@ import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.lops.Compression.CompressConfig;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
@@ -75,7 +77,10 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state)
{
//DAG splits not required for forced single node
- if( DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE
+ CompressConfig compress = CompressConfig.valueOf(ConfigurationManager
+ .getDMLConfig().getTextValue(DMLConfig.COMPRESSED_LINALG).toUpperCase());
+ if( (DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE
+ && !(compress != CompressConfig.FALSE) )
|| !HopRewriteUtils.isLastLevelStatementBlock(sb) )
return Arrays.asList(sb);
@@ -225,7 +230,8 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
return;
//prevent unnecessary dag split (dims known or no consumer operations)
- boolean noSplitRequired = ( hop.dimsKnown() || HopRewriteUtils.hasOnlyWriteParents(hop, true, true) );
+ boolean noSplitRequired = (HopRewriteUtils.hasOnlyWriteParents(hop, true, true)
+ || hop.dimsKnown() || DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE);
boolean investigateChilds = true;
//collect data dependent operations (to be extended as necessary)
@@ -294,14 +300,8 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
}
}
- //#4 second-order eval function
- if( HopRewriteUtils.isNary(hop, OpOpN.EVAL) && !noSplitRequired ) {
- cand.add(hop);
- investigateChilds = false;
- }
-
- //#5 sql
- if( hop instanceof DataOp && ((DataOp) hop).getOp() == OpOpData.SQLREAD && !noSplitRequired) {
+ //#4 other data dependent operators (default handling)
+ if( isBasicDataDependentOperator(hop, noSplitRequired) ) {
cand.add(hop);
investigateChilds = false;
}
@@ -314,6 +314,14 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
hop.setVisited();
}
+
+ private static boolean isBasicDataDependentOperator(Hop hop, boolean noSplitRequired) {
+ return (HopRewriteUtils.isNary(hop, OpOpN.EVAL) & !noSplitRequired)
+ || (HopRewriteUtils.isData(hop, OpOpData.SQLREAD) & !noSplitRequired)
+ || (hop.requiresCompression() & !HopRewriteUtils.hasOnlyWriteParents(hop, true, true));
+ //note: for compression we probe for write parents (part of noSplitRequired) directly
+ // because we want to split even if the dimensions are known
+ }
private static boolean hasTransientWriteParents( Hop hop ) {
for( Hop p : hop.getParent() )
@@ -393,7 +401,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
for( Hop c : hop.getInput() )
rAddHopsToProbeSet(c, probeSet);
- hop.setVisited();
+ hop.setVisited();
}
/**
@@ -417,11 +425,11 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
rProbeAndAddHopsToCandidateSet(c, probeSet, candSet);
else
{
- candSet.add(new Pair<>(hop,c));
+ candSet.add(new Pair<>(hop,c));
}
}
- hop.setVisited();
+ hop.setVisited();
}
private void collectCandidateChildOperators( ArrayList<Hop> cand, HashSet<Hop> candChilds )
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index cecbd3d..c135b01 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -212,7 +212,7 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction {
@Override
public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 )
- throws Exception
+ throws Exception
{
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn = arg0._2();
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index c257a57..eaa9a73 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -33,8 +33,6 @@ import org.junit.Test;
public class WorkloadAlgorithmTest extends AutomatedTestBase {
- // private static final Log LOG = LogFactory.getLog(WorkloadAnalysisTest.class.getName());
-
private final static String TEST_NAME1 = "WorkloadAnalysisMLogReg";
private final static String TEST_NAME2 = "WorkloadAnalysisLm";
private final static String TEST_NAME3 = "WorkloadAnalysisPCA";
@@ -55,7 +53,6 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
runWorkloadAnalysisTest(TEST_NAME1, ExecMode.HYBRID, 2);
}
-
@Test
public void testLmSP() {
runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SPARK, 2);
@@ -80,12 +77,12 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
ExecMode oldPlatform = setExecMode(mode);
try {
-
loadTestConfiguration(getTestConfiguration(testname));
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), output("B")};
+ programArgs = new String[] {"-explain","-stats",
+ "20", "-args", input("X"), input("y"), output("B")};
double[][] X = TestUtils.round(getRandomMatrix(10000, 20, 0, 10, 1.0, 7));
writeInputMatrixWithMTD("X", X, false);
@@ -95,9 +92,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
}
writeInputMatrixWithMTD("y", y, false);
- String ret = runTest(null).toString();
- if(ret.contains("ERROR:"))
- fail(ret);
+ runTest(null);
// check various additional expectations
long actualCompressionCount = mode == ExecMode.HYBRID ? Statistics