Posted to commits@systemds.apache.org by mb...@apache.org on 2021/07/22 11:37:55 UTC

[systemds] branch master updated: [SYSTEMDS-3069] Extended rewrites for splitting DAGs after compression

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 07c69e6  [SYSTEMDS-3069] Extended rewrites for splitting DAGs after compression
07c69e6 is described below

commit 07c69e62449a95ad889f9453bac8410d667fe689
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Thu Jul 22 13:36:50 2021 +0200

    [SYSTEMDS-3069] Extended rewrites for splitting DAGs after compression
    
    This patch extends the existing 'split-DAG after data-dependent
    operators' rewrite and the IPA integration of workload-aware compression
    in order to allow recompilation according to compression results (e.g.,
    compiling local instead of distributed operations for highly compressible
    data).
---
 .../sysds/hops/ipa/InterProceduralAnalysis.java    |  3 +-
 .../RewriteSplitDagDataDependentOperators.java     | 34 +++++++++++++---------
 .../spark/AggregateUnarySPInstruction.java         |  2 +-
 .../compress/workload/WorkloadAlgorithmTest.java   | 11 ++-----
 4 files changed, 27 insertions(+), 23 deletions(-)
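
For context, the sketch below (plain Java with hypothetical names, not SystemDS
code) models why splitting the DAG after a compression operator enables the
recompilation described above: once the compressed intermediate has been
materialized, its actual size is known, and the operations in the second DAG
can be recompiled to local rather than distributed execution.

    // Hypothetical, self-contained illustration; not part of this patch.
    public class SplitRecompileSketch {
      enum ExecType { CP, SPARK }

      // pick an execution type from an (estimated or exact) size in bytes
      static ExecType chooseExecType(long sizeInBytes, long localBudget) {
        return sizeInBytes <= localBudget ? ExecType.CP : ExecType.SPARK;
      }

      public static void main(String[] args) {
        long localBudget       = 2L << 30;   // assume a 2 GB local memory budget
        long worstCaseEstimate = 8L << 30;   // estimate before compression results are known
        long compressedActual  = 512L << 20; // actual size after the compression op executed

        // single DAG: planned with the conservative estimate -> SPARK
        System.out.println("without split: " + chooseExecType(worstCaseEstimate, localBudget));
        // split DAG: second part recompiled with the known compressed size -> CP
        System.out.println("with split:    " + chooseExecType(compressedActual, localBudget));
      }
    }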

diff --git a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index 309d823..0b47a19 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -241,7 +241,8 @@ public class InterProceduralAnalysis
 		FunctionCallGraph graph2 = new FunctionCallGraph(_prog);
 		List<IPAPass> fpasses = Arrays.asList(
 			new IPAPassRemoveUnusedFunctions(),
-			new IPAPassCompressionWorkloadAnalysis());
+			new IPAPassCompressionWorkloadAnalysis(), // workload-aware compression
+			new IPAPassApplyStaticAndDynamicHopRewrites());  //split after compress
 		for(IPAPass pass : fpasses)
 			if( pass.isApplicable(graph2) )
 				pass.rewriteProgram(_prog, graph2, null);
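
The ordering of these passes matters: the workload analysis marks operators for
compression, and re-applying the static and dynamic HOP rewrites afterwards lets
the extended split rewrite (next file) act on those operators. A minimal,
self-contained sketch of this ordered-pass pattern, with hypothetical pass names
rather than the actual SystemDS classes:

    import java.util.Arrays;
    import java.util.List;

    // Hypothetical sketch of an ordered rewrite-pass pipeline.
    public class PassPipelineSketch {
      interface Pass {
        boolean isApplicable(String program);
        String rewrite(String program);
      }

      public static void main(String[] args) {
        List<Pass> passes = Arrays.asList(
          new Pass() { // stands in for the workload-aware compression analysis
            public boolean isApplicable(String p) { return true; }
            public String rewrite(String p) { return p + "+compress"; }
          },
          new Pass() { // stands in for re-applying static/dynamic HOP rewrites
            public boolean isApplicable(String p) { return p.contains("+compress"); }
            public String rewrite(String p) { return p + "+split"; }
          });

        String program = "prog";
        for (Pass pass : passes)        // same apply-if-applicable loop as above
          if (pass.isApplicable(program))
            program = pass.rewrite(program);
        System.out.println(program);    // prints: prog+compress+split
      }
    }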
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index fe00ae0..ecc3f39 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -34,6 +34,7 @@ import org.apache.sysds.common.Types.OpOpN;
 import org.apache.sysds.common.Types.ParamBuiltinOp;
 import org.apache.sysds.common.Types.ReOrgOp;
 import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.DataOp;
 import org.apache.sysds.hops.Hop;
@@ -42,6 +43,7 @@ import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.ParameterizedBuiltinOp;
 import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.lops.Compression.CompressConfig;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.StatementBlock;
 import org.apache.sysds.parser.VariableSet;
@@ -75,7 +77,10 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
 	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state)
 	{
 		//DAG splits not required for forced single node
-		if( DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE
+		CompressConfig compress = CompressConfig.valueOf(ConfigurationManager
+			.getDMLConfig().getTextValue(DMLConfig.COMPRESSED_LINALG).toUpperCase());
+		if( (DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE
+			&& !(compress != CompressConfig.FALSE) )
 			|| !HopRewriteUtils.isLastLevelStatementBlock(sb) )
 			return Arrays.asList(sb);
 		
@@ -225,7 +230,8 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
 			return;
 		
 		//prevent unnecessary dag split (dims known or no consumer operations)
-		boolean noSplitRequired = ( hop.dimsKnown() || HopRewriteUtils.hasOnlyWriteParents(hop, true, true) );
+		boolean noSplitRequired = (HopRewriteUtils.hasOnlyWriteParents(hop, true, true)
+			|| hop.dimsKnown() || DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE);
 		boolean investigateChilds = true;
 		
 		//collect data dependent operations (to be extended as necessary)
@@ -294,14 +300,8 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
 			}
 		}
 		
-		//#4 second-order eval function
-		if( HopRewriteUtils.isNary(hop, OpOpN.EVAL) && !noSplitRequired ) {
-			cand.add(hop);
-			investigateChilds = false;
-		}
-		
-		//#5 sql
-		if( hop instanceof DataOp && ((DataOp) hop).getOp() == OpOpData.SQLREAD && !noSplitRequired) {
+		//#4 other data dependent operators (default handling)
+		if( isBasicDataDependentOperator(hop, noSplitRequired) ) {
 			cand.add(hop);
 			investigateChilds = false;
 		}
@@ -314,6 +314,14 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
 		
 		hop.setVisited();
 	}
+	
+	private static boolean isBasicDataDependentOperator(Hop hop, boolean noSplitRequired) {
+		return (HopRewriteUtils.isNary(hop, OpOpN.EVAL) & !noSplitRequired)
+			|| (HopRewriteUtils.isData(hop, OpOpData.SQLREAD) & !noSplitRequired)
+			|| (hop.requiresCompression() & !HopRewriteUtils.hasOnlyWriteParents(hop, true, true));
+		//note: for compression we probe for write parents (part of noSplitRequired) directly
+		// because we want to split even if the dimensions are known 
+	}
 
 	private static boolean hasTransientWriteParents( Hop hop ) {
 		for( Hop p : hop.getParent() )
@@ -393,7 +401,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
 			for( Hop c : hop.getInput() )
 				rAddHopsToProbeSet(c, probeSet);
 	
-		hop.setVisited();	
+		hop.setVisited();
 	}
 	
 	/**
@@ -417,11 +425,11 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
 					rProbeAndAddHopsToCandidateSet(c, probeSet, candSet);
 				else
 				{
-					candSet.add(new Pair<>(hop,c)); 
+					candSet.add(new Pair<>(hop,c));
 				}
 			}
 		
-		hop.setVisited();	
+		hop.setVisited();
 	}
 	
 	private void collectCandidateChildOperators( ArrayList<Hop> cand, HashSet<Hop> candChilds )
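
Taken together, the changes in this file mean that eval and SQL-read candidates
still honor noSplitRequired (which now also covers forced single-node execution),
whereas operators flagged with requiresCompression() only check for write-only
parents, so they trigger a split even when dimensions are known or single-node
mode is forced. The non-short-circuiting & in the new helper yields the same
result as && here because all operands are plain booleans. A self-contained model
of that decision, with hypothetical parameter names:

    // Hypothetical model of the split decision after this patch.
    public class SplitDecisionSketch {

      // mirrors the extended noSplitRequired computation
      static boolean noSplitRequired(boolean onlyWriteParents, boolean dimsKnown, boolean singleNode) {
        return onlyWriteParents || dimsKnown || singleNode;
      }

      // mirrors isBasicDataDependentOperator: eval/sqlread vs. compression candidates
      static boolean splitsDag(boolean evalOrSqlRead, boolean requiresCompression,
          boolean onlyWriteParents, boolean dimsKnown, boolean singleNode) {
        boolean noSplit = noSplitRequired(onlyWriteParents, dimsKnown, singleNode);
        return (evalOrSqlRead && !noSplit)
          || (requiresCompression && !onlyWriteParents);
      }

      public static void main(String[] args) {
        // a compression candidate with known dims in single-node mode still splits
        System.out.println(splitsDag(false, true, false, true, true));  // true
        // an eval call with known dims does not require a split
        System.out.println(splitsDag(true, false, false, true, true));  // false
      }
    }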
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index cecbd3d..c135b01 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -212,7 +212,7 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction {
 		
 		@Override
 		public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) 
-			throws Exception 
+			throws Exception
 		{
 			MatrixIndexes ixIn = arg0._1();
 			MatrixBlock blkIn = arg0._2();
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index c257a57..eaa9a73 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -33,8 +33,6 @@ import org.junit.Test;
 
 public class WorkloadAlgorithmTest extends AutomatedTestBase {
 
-	// private static final Log LOG = LogFactory.getLog(WorkloadAnalysisTest.class.getName());
-
 	private final static String TEST_NAME1 = "WorkloadAnalysisMLogReg";
 	private final static String TEST_NAME2 = "WorkloadAnalysisLm";
 	private final static String TEST_NAME3 = "WorkloadAnalysisPCA";
@@ -55,7 +53,6 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 		runWorkloadAnalysisTest(TEST_NAME1, ExecMode.HYBRID, 2);
 	}
 
-
 	@Test
 	public void testLmSP() {
 		runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SPARK, 2);
@@ -80,12 +77,12 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 		ExecMode oldPlatform = setExecMode(mode);
 
 		try {
-
 			loadTestConfiguration(getTestConfiguration(testname));
 
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + testname + ".dml";
-			programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), output("B")};
+			programArgs = new String[] {"-explain","-stats",
+				"20", "-args", input("X"), input("y"), output("B")};
 
 			double[][] X = TestUtils.round(getRandomMatrix(10000, 20, 0, 10, 1.0, 7));
 			writeInputMatrixWithMTD("X", X, false);
@@ -95,9 +92,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 			}
 			writeInputMatrixWithMTD("y", y, false);
 
-			String ret = runTest(null).toString();
-			if(ret.contains("ERROR:"))
-				fail(ret);
+			runTest(null);
 
 			// check various additional expectations
 			long actualCompressionCount = mode == ExecMode.HYBRID ? Statistics