You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/09/20 11:27:38 UTC
[systemds] 02/03: [SYSTEMDS-3123] Rewrite c bind 0 Matrix
Multiplication
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit e607544ea6ab1f2ec1c9f2c8370c38c86c346170
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Tue Sep 7 19:14:56 2021 +0200
[SYSTEMDS-3123] Rewrite c bind 0 Matrix Multiplication
```
cbind((X %*% Y), matrix(0, nrow(X), 1))
->
X %*% (cbind(Y, matrix(0, nrow(Y), 1)))
```
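This commit adds a rewrite that reorders the operations as shown above
if the number of rows in X is more than twice the number of columns of
the result. The rewrite affects line 215 of the MLogReg script, where it
avoids forcing a second allocation of a large intermediate with as many
rows as X.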
---
.../RewriteAlgebraicSimplificationDynamic.java | 54 +++++---
.../compress/workload/WorkloadAnalyzer.java | 30 ++---
.../compress/workload/WorkloadAlgorithmTest.java | 34 ++---
.../rewrite/RewriteMMCBindZeroVector.java | 145 +++++++++++++++++++++
.../compress/workload/WorkloadAnalysisMLogReg.dml | 13 +-
.../RewritMMCBindZeroVectorOp.dml} | 26 +---
6 files changed, 229 insertions(+), 73 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 63a05a4..0b91e5d 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -26,6 +26,20 @@ import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
+import org.apache.sysds.common.Types.AggOp;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.common.Types.OpOp3;
+import org.apache.sysds.common.Types.OpOp4;
+import org.apache.sysds.common.Types.OpOpDG;
+import org.apache.sysds.common.Types.OpOpN;
+import org.apache.sysds.common.Types.ParamBuiltinOp;
+import org.apache.sysds.common.Types.ReOrgOp;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
@@ -41,22 +55,8 @@ import org.apache.sysds.hops.QuaternaryOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.common.Types.AggOp;
-import org.apache.sysds.common.Types.Direction;
-import org.apache.sysds.common.Types.OpOp1;
-import org.apache.sysds.common.Types.OpOp2;
-import org.apache.sysds.common.Types.OpOp3;
-import org.apache.sysds.common.Types.OpOp4;
-import org.apache.sysds.common.Types.OpOpDG;
-import org.apache.sysds.common.Types.OpOpN;
-import org.apache.sysds.common.Types.ParamBuiltinOp;
-import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.lops.MapMultChain.ChainType;
import org.apache.sysds.parser.DataExpression;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ValueType;
-import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.conf.DMLConfig;
/**
* Rule: Algebraic Simplifications. Simplifies binary expressions
@@ -109,7 +109,6 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
if( root == null )
return root;
-
//one pass rewrite-descend (rewrite created pattern)
rule_AlgebraicSimplification( root, false );
@@ -197,6 +196,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hi = simplifyNnzComputation(hop, hi, i); //e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
hi = simplifyNrowNcolComputation(hop, hi, i); //e.g., nrow(X) -> literal(nrow(X)), if nrow known to remove data dependency
hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
+ hi = simplyfyMMCBindZeroVector(hop, hi, i); //e.g.. cbind((X %*% Y), matrix (0, nrow(X), 1)) -> X %*% (cbind(Y, matrix(0, nrow(Y), 1))) if nRows of x is larger than nCols of y
if( OptimizerUtils.ALLOW_OPERATOR_FUSION )
foldMultipleMinMaxOperations(hi); //e.g., min(X,min(min(3,7),Y)) -> min(X,3,7,Y)
@@ -2796,4 +2796,28 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
return hi;
}
+
+	/**
+	 * Moves a cbind with a constant-zero matrix below a matrix multiplication:
+	 *
+	 * <pre>
+	 * cbind((X %*% Y), matrix(0, nrow(X), k)) -&gt; X %*% (cbind(Y, matrix(0, nrow(Y), k)))
+	 * </pre>
+	 *
+	 * The rewrite only applies if the output dimensions are known, the output has more than twice as many rows as
+	 * columns, and the product X %*% Y is consumed only by this cbind. Appending the zeros to the small Y then avoids
+	 * materializing both X %*% Y and its cbind result (two intermediates with nrow(X) rows). It is used, for example,
+	 * in the first-level loop of MLogReg.
+	 *
+	 * @param parent the parent of hi
+	 * @param hi     the candidate cbind hop
+	 * @param pos    the input position of hi in parent
+	 * @return the new matrix multiplication if the rewrite applied, otherwise hi
+	 */
+	private static Hop simplyfyMMCBindZeroVector(Hop parent, Hop hi, int pos) {
+		if(HopRewriteUtils.isBinary(hi, OpOp2.CBIND) && HopRewriteUtils.isMatrixMultiply(hi.getInput(0))
+			// if X %*% Y had other consumers, the rewrite would add a second, larger product over X
+			&& hi.getInput(0).getParent().size() == 1
+			&& HopRewriteUtils.isDataGenOpWithConstantValue(hi.getInput(1), 0)
+			// unknown dimensions are -1 and would otherwise pass the size check below (e.g. -1 > -2)
+			&& hi.dimsKnown() && hi.getDim1() > hi.getDim2() * 2) {
+			final Hop mm = hi.getInput(0);
+			final Hop oldGen = hi.getInput(1);
+			final Hop x = mm.getInput(0);
+			final Hop y = mm.getInput(1);
+			// zero block with the rows of Y and the columns of the original zero block
+			final Hop newGen = HopRewriteUtils.createDataGenOp(y, oldGen, 0);
+			final Hop newCBind = HopRewriteUtils.createBinary(y, newGen, OpOp2.CBIND);
+			final Hop newMM = HopRewriteUtils.createMatrixMultiply(x, newCBind);
+
+			HopRewriteUtils.replaceChildReference(parent, hi, newMM, pos);
+			// detach the now unreferenced cbind, product and zero block from their inputs
+			HopRewriteUtils.cleanupUnreferenced(hi, mm, oldGen);
+			LOG.debug("Applied MMCBind Zero algebraic simplification (line " + hi.getBeginLine() + ").");
+			return newMM;
+		}
+		return hi;
+	}
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
index 31b3714..c865507 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
@@ -381,21 +381,21 @@ public class WorkloadAnalyzer {
transientCompressed.contains(in.get(1).getName());
OpSided ret = new OpSided(hop, left, right, transposedLeft, transposedRight);
if(ret.isRightMM()) {
- HashSet<Long> overlapping2 = new HashSet<>();
- overlapping2.add(hop.getHopID());
- WorkloadAnalyzer overlappingAnalysis = new WorkloadAnalyzer(prog, overlapping2);
- WTreeRoot r = overlappingAnalysis.createWorkloadTree(hop);
-
- CostEstimatorBuilder b = new CostEstimatorBuilder(r);
- if(LOG.isTraceEnabled())
- LOG.trace("Workload for overlapping: " + r + "\n" + b);
-
- if(b.shouldUseOverlap())
- overlapping.add(hop.getHopID());
- else {
- decompressHops.add(hop);
- ret.setOverlappingDecompression(true);
- }
+ // HashSet<Long> overlapping2 = new HashSet<>();
+ // overlapping2.add(hop.getHopID());
+ // WorkloadAnalyzer overlappingAnalysis = new WorkloadAnalyzer(prog, overlapping2);
+ // WTreeRoot r = overlappingAnalysis.createWorkloadTree(hop);
+
+ // CostEstimatorBuilder b = new CostEstimatorBuilder(r);
+ // if(LOG.isTraceEnabled())
+ // LOG.trace("Workload for overlapping: " + r + "\n" + b);
+
+ // if(b.shouldUseOverlap())
+ overlapping.add(hop.getHopID());
+ // else {
+ // decompressHops.add(hop);
+ // ret.setOverlappingDecompression(true);
+ // }
}
return ret;
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index 5de8880..af05bdc 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -83,7 +83,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
+	// The *CP tests below now force SINGLE_NODE execution (previously HYBRID).
@Test
public void testLmCP() {
- runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2, false);
+ runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SINGLE_NODE, 2, false);
}
@Test
@@ -93,7 +93,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
+	// NOTE(review): same arguments as testLmCP (TEST_NAME2, SINGLE_NODE, 2, false); confirm whether a
+	// different script was intended for the "DS" variant.
@Test
public void testLmDSCP() {
- runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2, false);
+ runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SINGLE_NODE, 2, false);
}
@Test
@@ -103,41 +103,42 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
@Test
public void testPCACP() {
- runWorkloadAnalysisTest(TEST_NAME3, ExecMode.HYBRID, 1, false);
+ runWorkloadAnalysisTest(TEST_NAME3, ExecMode.SINGLE_NODE, 1, false);
}
@Test
public void testSliceLineCP1() {
- runWorkloadAnalysisTest(TEST_NAME4, ExecMode.HYBRID, 0, false);
+ runWorkloadAnalysisTest(TEST_NAME4, ExecMode.SINGLE_NODE, 0, false);
}
@Test
public void testSliceLineCP2() {
- runWorkloadAnalysisTest(TEST_NAME4, ExecMode.HYBRID, 2, true);
+ runWorkloadAnalysisTest(TEST_NAME4, ExecMode.SINGLE_NODE, 2, true);
}
@Test
public void testLmCGSP() {
runWorkloadAnalysisTest(TEST_NAME6, ExecMode.SPARK, 2, false);
}
-
+
@Test
public void testLmCGCP() {
- runWorkloadAnalysisTest(TEST_NAME6, ExecMode.HYBRID, 2, false);
+ runWorkloadAnalysisTest(TEST_NAME6, ExecMode.SINGLE_NODE, 2, false);
}
-
+
// private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount) {
private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates) {
ExecMode oldPlatform = setExecMode(mode);
boolean oldIntermediates = WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES;
-
+
try {
loadTestConfiguration(getTestConfiguration(testname));
WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES = intermediates;
-
+
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), output("B")};
+ programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"),
+ output("B")};
writeInputMatrixWithMTD("X", X, false);
writeInputMatrixWithMTD("y", y, false);
@@ -149,11 +150,12 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
long actualCompressionCount = (mode == ExecMode.HYBRID || mode == ExecMode.SINGLE_NODE) ? Statistics
.getCPHeavyHitterCount("compress") : Statistics.getCPHeavyHitterCount("sp_compress");
- Assert.assertEquals(compressionCount, actualCompressionCount);
- if( compressionCount > 0 )
- Assert.assertTrue( mode == ExecMode.HYBRID ?
- heavyHittersContainsString("compress") : heavyHittersContainsString("sp_compress"));
- if( !testname.equals(TEST_NAME4) )
+ Assert.assertEquals("Assert that the compression counts expeted matches actual: " + compressionCount
+ + " vs " + actualCompressionCount, compressionCount, actualCompressionCount);
+ if(compressionCount > 0)
+ Assert.assertTrue(mode == ExecMode.SINGLE_NODE || mode == ExecMode.HYBRID ? heavyHittersContainsString(
+ "compress") : heavyHittersContainsString("sp_compress"));
+ if(!testname.equals(TEST_NAME4))
Assert.assertFalse(heavyHittersContainsString("m_scale"));
}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java
new file mode 100644
index 0000000..1cc1cca
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import static org.junit.Assert.fail;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests the rewrite
+ *
+ * <pre>
+ * res = cbind((X %*% Y), matrix(0, nrow(X), 1));
+ * -&gt;
+ * res = X %*% (cbind(Y, matrix(0, nrow(Y), 1)));
+ * </pre>
+ *
+ * If X has many rows, appending the zero column to the large product is expensive, whereas appending it to the
+ * small Y is cheap. The rewrite is applicable, for example, in MLogReg.
+ *
+ * The tests only check the order of the matrix multiplication and the append in the explain output; the result
+ * values are not compared.
+ */
+public class RewriteMMCBindZeroVector extends AutomatedTestBase {
+
+	private static final String TEST_NAME1 = "RewritMMCBindZeroVectorOp";
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteMMCBindZeroVector.class.getSimpleName() + "/";
+
+	/** Opcodes of matrix multiplication instructions in CP ("ba+*") and Spark ("mapmm", "cpmm", "zipmm"). */
+	private static final String[] MM_OPCODES = {"ba+*", "mapmm", "cpmm", "zipmm"};
+	/** Substring of the CP "append" opcode, also contained in the Spark append opcodes (e.g. "mappend"). */
+	private static final String APPEND_OPCODE = "append";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+	}
+
+	@Test
+	public void testNoRewritesCP() {
+		testRewrite(TEST_NAME1, false, ExecType.CP, 100, 3, 10);
+	}
+
+	@Test
+	public void testNoRewritesSP() {
+		testRewrite(TEST_NAME1, false, ExecType.SPARK, 100, 3, 10);
+	}
+
+	@Test
+	public void testRewritesCP() {
+		testRewrite(TEST_NAME1, true, ExecType.CP, 100, 3, 10);
+	}
+
+	@Test
+	public void testRewritesSP() {
+		testRewrite(TEST_NAME1, true, ExecType.SPARK, 100, 3, 10);
+	}
+
+	/**
+	 * Runs the test script with or without algebraic rewrites and checks, via the explain output, whether the append
+	 * happens before (rewrite applied) or after (rewrite not applied) the matrix multiplication.
+	 *
+	 * @param testname  name of the test configuration and DML script
+	 * @param rewrites  whether algebraic simplification rewrites are enabled
+	 * @param et        execution type; SPARK forces Spark execution, anything else runs in HYBRID mode
+	 * @param leftRows  number of rows of X
+	 * @param rightCols number of columns of Y
+	 * @param shared    number of columns of X and rows of Y
+	 */
+	private void testRewrite(String testname, boolean rewrites, ExecType et, int leftRows, int rightCols, int shared) {
+		ExecMode platformOld = rtplatform;
+		rtplatform = (et == ExecType.SPARK) ? ExecMode.SPARK : ExecMode.HYBRID;
+
+		// both SPARK and HYBRID may execute Spark instructions, which need the local Spark configuration
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+		boolean rewritesOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+		try {
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[] {"-explain", "-stats", "-args", input("X"), input("Y"), output("R")};
+
+			double[][] X = getRandomMatrix(leftRows, shared, -1, 1, 0.97d, 7);
+			double[][] Y = getRandomMatrix(shared, rightCols, -1, 1, 0.9d, 3);
+			writeInputMatrixWithMTD("X", X, false);
+			writeInputMatrixWithMTD("Y", Y, false);
+
+			// execute the script and inspect its output (explain plan and statistics)
+			String out = runTest(null).toString();
+			String[] lines = out.split("\n");
+			int mmLine = firstLineContaining(lines, MM_OPCODES);
+			int appendLine = firstLineContaining(lines, APPEND_OPCODE);
+
+			// without both operators the order check below would pass vacuously
+			if(mmLine < 0 || appendLine < 0)
+				fail("Expected both a matrix multiplication and an append in the explain output.\n\n" + out);
+			if(rewrites && mmLine < appendLine)
+				fail("Invalid execution: the matrix multiplication is done before the append, "
+					+ "therefore the rewrite did not trigger.\n\n" + out);
+			if(!rewrites && appendLine < mmLine)
+				fail("Invalid execution: the append is done before the matrix multiplication, "
+					+ "therefore the rewrite triggered although rewrites are disabled.\n\n" + out);
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewritesOld;
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+	/**
+	 * Returns the index of the first line that contains any of the given substrings, or -1 if no line does.
+	 */
+	private static int firstLineContaining(String[] lines, String... needles) {
+		for(int i = 0; i < lines.length; i++)
+			for(String needle : needles)
+				if(lines[i].contains(needle))
+					return i;
+		return -1;
+	}
+}
diff --git a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
index 12d9dd5..d427506 100644
--- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
+++ b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
@@ -22,21 +22,20 @@
+# Workload-analysis test script: scales X, trains multiLogReg and stops if the prediction accuracy is below 50.
X = read($1);
Y = read($2);
-
print("")
print("MLogReg")
X = scale(X=X, scale=TRUE, center=TRUE);
-B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2);
+B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2, icpt=0);
[nn, P, acc] = multiLogRegPredict(X=X, B=B, Y=Y)
-
[nn, C] = confusionMatrix(P, Y)
-print("Confusion: ")
-print(toString(C))
+print("Confusion:")
+print(toString(C))
+print("")
print(acc)
-if(acc < 50){
+if(acc < 50)
stop("MLogReg Accuracy achieved is not high enough")
-}
+
diff --git a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml b/src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml
similarity index 71%
copy from src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
copy to src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml
index 12d9dd5..e6b0498 100644
--- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
+++ b/src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,24 +19,10 @@
#
#-------------------------------------------------------------
-X = read($1);
-Y = read($2);
-
-
-print("")
-print("MLogReg")
-
-X = scale(X=X, scale=TRUE, center=TRUE);
-B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2);
-
-[nn, P, acc] = multiLogRegPredict(X=X, B=B, Y=Y)
-[nn, C] = confusionMatrix(P, Y)
-print("Confusion: ")
-print(toString(C))
+# Input for the rewrite test: cbind(X %*% Y, matrix(0, nrow(X), 1)) may be rewritten
+# to X %*% cbind(Y, matrix(0, nrow(Y), 1)) when X has many rows.
+X = read($1)
+Y = read($2)
-print(acc)
+res = cbind((X %*% Y), matrix (0, nrow(X), 1));
-if(acc < 50){
- stop("MLogReg Accuracy achieved is not high enough")
-}
+# print a scalar summary so that res is used by the script
+print(sum(res))