You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/09/20 11:27:38 UTC

[systemds] 02/03: [SYSTEMDS-3123] Rewrite c bind 0 Matrix Multiplication

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit e607544ea6ab1f2ec1c9f2c8370c38c86c346170
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Tue Sep 7 19:14:56 2021 +0200

    [SYSTEMDS-3123] Rewrite c bind 0 Matrix Multiplication
    
    ```
     cbind((X %*% Y), matrix(0, nrow(X), 1))
    
     ->
    
     X %*% (cbind(Y, matrix(0, nrow(Y), 1)))
    ```
    
    This commit adds a rewrite that pushes the cbind below the matrix
    multiplication if the number of rows of X is more than twice the
    number of columns of the cbind output.
    
    This rewrite affects MLogReg (line 215), where it avoids allocating
    two large intermediates with as many rows as X.
---
 .../RewriteAlgebraicSimplificationDynamic.java     |  54 +++++---
 .../compress/workload/WorkloadAnalyzer.java        |  30 ++---
 .../compress/workload/WorkloadAlgorithmTest.java   |  34 ++---
 .../rewrite/RewriteMMCBindZeroVector.java          | 145 +++++++++++++++++++++
 .../compress/workload/WorkloadAnalysisMLogReg.dml  |  13 +-
 .../RewritMMCBindZeroVectorOp.dml}                 |  26 +---
 6 files changed, 229 insertions(+), 73 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 63a05a4..0b91e5d 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -26,6 +26,20 @@ import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 
+import org.apache.sysds.common.Types.AggOp;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.common.Types.OpOp3;
+import org.apache.sysds.common.Types.OpOp4;
+import org.apache.sysds.common.Types.OpOpDG;
+import org.apache.sysds.common.Types.OpOpN;
+import org.apache.sysds.common.Types.ParamBuiltinOp;
+import org.apache.sysds.common.Types.ReOrgOp;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.AggUnaryOp;
 import org.apache.sysds.hops.BinaryOp;
@@ -41,22 +55,8 @@ import org.apache.sysds.hops.QuaternaryOp;
 import org.apache.sysds.hops.ReorgOp;
 import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.common.Types.AggOp;
-import org.apache.sysds.common.Types.Direction;
-import org.apache.sysds.common.Types.OpOp1;
-import org.apache.sysds.common.Types.OpOp2;
-import org.apache.sysds.common.Types.OpOp3;
-import org.apache.sysds.common.Types.OpOp4;
-import org.apache.sysds.common.Types.OpOpDG;
-import org.apache.sysds.common.Types.OpOpN;
-import org.apache.sysds.common.Types.ParamBuiltinOp;
-import org.apache.sysds.common.Types.ReOrgOp;
 import org.apache.sysds.lops.MapMultChain.ChainType;
 import org.apache.sysds.parser.DataExpression;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ValueType;
-import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.conf.DMLConfig;
 
 /**
  * Rule: Algebraic Simplifications. Simplifies binary expressions
@@ -109,7 +109,6 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
 		if( root == null )
 			return root;
-		
 		//one pass rewrite-descend (rewrite created pattern)
 		rule_AlgebraicSimplification( root, false );
 		
@@ -197,6 +196,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			hi = simplifyNnzComputation(hop, hi, i);          //e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
 			hi = simplifyNrowNcolComputation(hop, hi, i);     //e.g., nrow(X) -> literal(nrow(X)), if nrow known to remove data dependency
 			hi = simplifyTableSeqExpand(hop, hi, i);          //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
+			hi = simplyfyMMCBindZeroVector(hop, hi, i);       //e.g., cbind(X %*% Y, matrix(0, nrow(X), 1)) -> X %*% cbind(Y, matrix(0, nrow(Y), 1)), if nrow(X) > 2 * ncol(output)
 			if( OptimizerUtils.ALLOW_OPERATOR_FUSION )
 				foldMultipleMinMaxOperations(hi);             //e.g., min(X,min(min(3,7),Y)) -> min(X,3,7,Y)
 			
@@ -2796,4 +2796,44 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 		
 		return hi;
 	}
+
+	/**
+	 * Rewrite: cbind(X %*% Y, matrix(0, nrow(X), k)) -> X %*% cbind(Y, matrix(0, nrow(Y), k)).
+	 * <p>
+	 * Both expressions are equal because X %*% matrix(0, nrow(Y), k) is an all-zero nrow(X) x k matrix. The rewrite
+	 * applies the cbind to Y instead of to the nrow(X)-row product, so only one intermediate with nrow(X) rows is
+	 * allocated instead of two (used in the first-level loop of MLogReg). It is only applied if nrow(X) is more than
+	 * twice the number of columns of the cbind output.
+	 * <p>
+	 * The rewrite is skipped if the dimensions of the cbind are unknown, because unknown dimensions are negative and
+	 * would trivially satisfy the size check, and if the matrix multiplication has consumers other than the cbind,
+	 * because it would then be kept alive and computed in addition to the new one.
+	 *
+	 * @param parent parent hop of hi
+	 * @param hi     candidate cbind hop
+	 * @param pos    input position of hi within parent
+	 * @return the new matrix multiplication if the rewrite was applied, otherwise hi
+	 */
+	private static Hop simplyfyMMCBindZeroVector(Hop parent, Hop hi, int pos) {
+		if(!HopRewriteUtils.isBinary(hi, OpOp2.CBIND) || !HopRewriteUtils.isMatrixMultiply(hi.getInput(0))
+			|| !HopRewriteUtils.isDataGenOpWithConstantValue(hi.getInput(1), 0))
+			return hi;
+
+		final Hop mm = hi.getInput(0);
+		if(!hi.dimsKnown() || hi.getDim1() <= hi.getDim2() * 2 || mm.getParent().size() > 1)
+			return hi;
+
+		final Hop x = mm.getInput(0);
+		final Hop y = mm.getInput(1);
+		// zero matrix with nrow(Y) rows and as many columns as the original zero matrix
+		final Hop newGen = HopRewriteUtils.createDataGenOp(y, hi.getInput(1), 0);
+		final Hop newCBind = HopRewriteUtils.createBinary(y, newGen, OpOp2.CBIND);
+		final Hop newMM = HopRewriteUtils.createMatrixMultiply(x, newCBind);
+
+		HopRewriteUtils.replaceChildReference(parent, hi, newMM, pos);
+		// detach the old cbind from its inputs once no consumer references it anymore
+		HopRewriteUtils.cleanupUnreferenced(hi);
+		LOG.debug("Applied MMCBind Zero algebraic simplification (line " + hi.getBeginLine() + ").");
+		return newMM;
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
index 31b3714..c865507 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
@@ -381,21 +381,21 @@ public class WorkloadAnalyzer {
 					transientCompressed.contains(in.get(1).getName());
 				OpSided ret = new OpSided(hop, left, right, transposedLeft, transposedRight);
 				if(ret.isRightMM()) {
-					HashSet<Long> overlapping2 = new HashSet<>();
-					overlapping2.add(hop.getHopID());
-					WorkloadAnalyzer overlappingAnalysis = new WorkloadAnalyzer(prog, overlapping2);
-					WTreeRoot r = overlappingAnalysis.createWorkloadTree(hop);
-
-					CostEstimatorBuilder b = new CostEstimatorBuilder(r);
-					if(LOG.isTraceEnabled())
-						LOG.trace("Workload for overlapping: " + r + "\n" + b);
-
-					if(b.shouldUseOverlap())
-						overlapping.add(hop.getHopID());
-					else {
-						decompressHops.add(hop);
-						ret.setOverlappingDecompression(true);
-					}
+					// HashSet<Long> overlapping2 = new HashSet<>();
+					// overlapping2.add(hop.getHopID());
+					// WorkloadAnalyzer overlappingAnalysis = new WorkloadAnalyzer(prog, overlapping2);
+					// WTreeRoot r = overlappingAnalysis.createWorkloadTree(hop);
+
+					// CostEstimatorBuilder b = new CostEstimatorBuilder(r);
+					// if(LOG.isTraceEnabled())
+					// 	LOG.trace("Workload for overlapping: " + r + "\n" + b);
+
+					// if(b.shouldUseOverlap())
+					overlapping.add(hop.getHopID());
+					// else {
+					// 	decompressHops.add(hop);
+					// 	ret.setOverlappingDecompression(true);
+					// }
 				}
 
 				return ret;
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index 5de8880..af05bdc 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -83,7 +83,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 
 	@Test
 	public void testLmCP() {
-		runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2, false);
+		runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SINGLE_NODE, 2, false);
 	}
 
 	@Test
@@ -93,7 +93,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 
 	@Test
 	public void testLmDSCP() {
-		runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2, false);
+		runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SINGLE_NODE, 2, false);
 	}
 
 	@Test
@@ -103,41 +103,42 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 
 	@Test
 	public void testPCACP() {
-		runWorkloadAnalysisTest(TEST_NAME3, ExecMode.HYBRID, 1, false);
+		runWorkloadAnalysisTest(TEST_NAME3, ExecMode.SINGLE_NODE, 1, false);
 	}
 
 	@Test
 	public void testSliceLineCP1() {
-		runWorkloadAnalysisTest(TEST_NAME4, ExecMode.HYBRID, 0, false);
+		runWorkloadAnalysisTest(TEST_NAME4, ExecMode.SINGLE_NODE, 0, false);
 	}
 
 	@Test
 	public void testSliceLineCP2() {
-		runWorkloadAnalysisTest(TEST_NAME4, ExecMode.HYBRID, 2, true);
+		runWorkloadAnalysisTest(TEST_NAME4, ExecMode.SINGLE_NODE, 2, true);
 	}
 
 	@Test
 	public void testLmCGSP() {
 		runWorkloadAnalysisTest(TEST_NAME6, ExecMode.SPARK, 2, false);
 	}
-	
+
 	@Test
 	public void testLmCGCP() {
-		runWorkloadAnalysisTest(TEST_NAME6, ExecMode.HYBRID, 2, false);
+		runWorkloadAnalysisTest(TEST_NAME6, ExecMode.SINGLE_NODE, 2, false);
 	}
-	
+
 	// private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount) {
 	private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates) {
 		ExecMode oldPlatform = setExecMode(mode);
 		boolean oldIntermediates = WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES;
-		
+
 		try {
 			loadTestConfiguration(getTestConfiguration(testname));
 			WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES = intermediates;
-			
+
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + testname + ".dml";
-			programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), output("B")};
+			programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"),
+				output("B")};
 
 			writeInputMatrixWithMTD("X", X, false);
 			writeInputMatrixWithMTD("y", y, false);
@@ -149,11 +150,12 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 			long actualCompressionCount = (mode == ExecMode.HYBRID || mode == ExecMode.SINGLE_NODE) ? Statistics
 				.getCPHeavyHitterCount("compress") : Statistics.getCPHeavyHitterCount("sp_compress");
 
-			Assert.assertEquals(compressionCount, actualCompressionCount);
-			if( compressionCount > 0 )
-				Assert.assertTrue( mode == ExecMode.HYBRID ?
-					heavyHittersContainsString("compress") : heavyHittersContainsString("sp_compress"));
-			if( !testname.equals(TEST_NAME4) )
+			Assert.assertEquals("Expected compression count does not match actual: " + compressionCount + " vs "
+				+ actualCompressionCount, compressionCount, actualCompressionCount);
+			if(compressionCount > 0)
+				Assert.assertTrue(heavyHittersContainsString(
+					mode == ExecMode.SINGLE_NODE || mode == ExecMode.HYBRID ? "compress" : "sp_compress"));
+			if(!testname.equals(TEST_NAME4))
 				Assert.assertFalse(heavyHittersContainsString("m_scale"));
 
 		}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java
new file mode 100644
index 0000000..1cc1cca
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import static org.junit.Assert.fail;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests the rewrite
+ * 
+ * <pre>
+ * cbind((X %*% Y), matrix(0, nrow(X), 1)) -> X %*% (cbind(Y, matrix(0, nrow(Y), 1)))
+ * </pre>
+ * 
+ * which, for an X with many rows, avoids allocating both the product X %*% Y and its column-extended copy. The
+ * motivating use case is MLogReg. The test checks the explain output: with the rewrite enabled, the append (cbind)
+ * must be executed before the matrix multiplication, otherwise after it.
+ */
+public class RewriteMMCBindZeroVector extends AutomatedTestBase {
+
+	private static final String TEST_NAME1 = "RewritMMCBindZeroVectorOp";
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteMMCBindZeroVector.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+	}
+
+	@Test
+	public void testNoRewritesCP() {
+		testRewrite(TEST_NAME1, false, ExecType.CP, 100, 3, 10);
+	}
+
+	@Test
+	public void testNoRewritesSP() {
+		testRewrite(TEST_NAME1, false, ExecType.SPARK, 100, 3, 10);
+	}
+
+	@Test
+	public void testRewritesCP() {
+		testRewrite(TEST_NAME1, true, ExecType.CP, 100, 3, 10);
+	}
+
+	@Test
+	public void testRewritesSP() {
+		testRewrite(TEST_NAME1, true, ExecType.SPARK, 100, 3, 10);
+	}
+
+	/**
+	 * Runs the test script on random inputs and checks the order of append and matrix multiplication.
+	 *
+	 * @param testname  name of the test configuration and DML script
+	 * @param rewrites  whether algebraic simplification rewrites are enabled
+	 * @param et        SPARK runs the script in SPARK mode, any other value in HYBRID mode
+	 * @param leftRows  number of rows of X
+	 * @param rightCols number of columns of Y
+	 * @param shared    number of columns of X and rows of Y
+	 */
+	private void testRewrite(String testname, boolean rewrites, ExecType et, int leftRows, int rightCols, int shared) {
+		ExecMode platformOld = rtplatform;
+		rtplatform = (et == ExecType.SPARK) ? ExecMode.SPARK : ExecMode.HYBRID;
+
+		// enabled for both modes used by this test (SPARK and HYBRID)
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+		boolean rewritesOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+		try {
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[] {"-explain", "-stats", "-args", input("X"), input("Y"),
+				output("R")};
+
+			double[][] X = getRandomMatrix(leftRows, shared, -1, 1, 0.97d, 7);
+			double[][] Y = getRandomMatrix(shared, rightCols, -1, 1, 0.9d, 3);
+			writeInputMatrixWithMTD("X", X, false);
+			writeInputMatrixWithMTD("Y", Y, false);
+
+			String out = runTest(null).toString();
+			if(rewrites)
+				assertFirstOperator(out, "append", "ba+*",
+					"invalid execution matrix multiplication is done before append, therefore the rewrite did not trigger.");
+			else
+				assertFirstOperator(out, "ba+*", "append",
+					"invalid execution append was done before multiplication, therefore the rewrite did trigger when not allowed.");
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewritesOld;
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+	/**
+	 * Scans the output line by line and asserts that a line containing {@code expected} appears before any line
+	 * containing {@code unexpected}.
+	 *
+	 * @param out        captured output of the test run
+	 * @param expected   opcode that has to appear first
+	 * @param unexpected opcode that must not appear before {@code expected}
+	 * @param message    failure message used if {@code unexpected} appears first
+	 */
+	private static void assertFirstOperator(String out, String expected, String unexpected, String message) {
+		for(String line : out.split("\n")) {
+			if(line.contains(expected))
+				return;
+			if(line.contains(unexpected))
+				fail(message + "\n\n" + out);
+		}
+		// fail instead of silently passing if the output contains neither operator
+		fail("Neither " + expected + " nor " + unexpected + " found in the output.\n\n" + out);
+	}
+}
diff --git a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
index 12d9dd5..d427506 100644
--- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
+++ b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
@@ -22,21 +22,20 @@
 X = read($1);
 Y = read($2);
 
-
 print("")
 print("MLogReg")
 
 X = scale(X=X, scale=TRUE, center=TRUE);
-B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2);
+B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2, icpt=0);
 
 [nn, P, acc] = multiLogRegPredict(X=X, B=B, Y=Y)
-
 [nn, C] = confusionMatrix(P, Y)
-print("Confusion: ")
-print(toString(C))
 
+print("Confusion:")
+print(toString(C))
+print("")
 print(acc)
 
-if(acc < 50){
+if(acc < 50)
     stop("MLogReg Accuracy achieved is not high enough")
-}
+
diff --git a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml b/src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml
similarity index 71%
copy from src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
copy to src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml
index 12d9dd5..e6b0498 100644
--- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
+++ b/src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,24 +19,10 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-Y = read($2);
-
-
-print("")
-print("MLogReg")
-
-X = scale(X=X, scale=TRUE, center=TRUE);
-B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2);
-
-[nn, P, acc] = multiLogRegPredict(X=X, B=B, Y=Y)
 
-[nn, C] = confusionMatrix(P, Y)
-print("Confusion: ")
-print(toString(C))
+X = read($1)
+Y = read($2)
 
-print(acc)
+res = cbind((X %*% Y), matrix (0, nrow(X), 1));
 
-if(acc < 50){
-    stop("MLogReg Accuracy achieved is not high enough")
-}
+print(sum(res))