You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/08/25 18:20:03 UTC
[systemds] branch master updated: [SYSTEMDS-2990] Extended
workload-tree extraction, sliceline test
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new a9fe3d9 [SYSTEMDS-2990] Extended workload-tree extraction, sliceline test
a9fe3d9 is described below
commit a9fe3d956f65555a18f0368c7e514a69941a1ad2
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Wed Aug 25 20:18:33 2021 +0200
[SYSTEMDS-2990] Extended workload-tree extraction, sliceline test
This patch makes some minor extensions of the CLA workload analyzer,
enabling more aggressive compression of intermediates, pruning of
unnecessary workload-tree construction (for already compressed
intermediates), and adds a related sliceline test and temporary fix for
initialization of MinMaxGroups.
---
.../ipa/IPAPassCompressionWorkloadAnalysis.java | 1 -
.../hops/rewrite/RewriteCompressedReblock.java | 20 +++-----
.../colgroup/dictionary/MatrixBlockDictionary.java | 7 ++-
.../compress/workload/WorkloadAnalyzer.java | 33 ++++++++++---
.../compress/workload/WorkloadAlgorithmTest.java | 41 +++++++++++-----
...iceFinder.dml => WorkloadAnalysisSliceLine.dml} | 57 ++++++++++++----------
6 files changed, 99 insertions(+), 60 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassCompressionWorkloadAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassCompressionWorkloadAnalysis.java
index 832e47e..71d4904 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassCompressionWorkloadAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassCompressionWorkloadAnalysis.java
@@ -66,6 +66,5 @@ public class IPAPassCompressionWorkloadAnalysis extends IPAPass {
}
return map != null;
-
}
}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java
index 9f67f19..6b51a51 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java
@@ -141,21 +141,17 @@ public class RewriteCompressedReblock extends StatementBlockRewriteRule {
return satisfies;
}
- private static boolean satisfiesAggressiveCompressionCondition(Hop hop) {
- boolean satisfies = false;
+ public static boolean satisfiesAggressiveCompressionCondition(Hop hop) {
+ //size-independent conditions (robust against unknowns)
+ boolean satisfies = HopRewriteUtils.isTernary(hop, OpOp3.CTABLE) //matrix (no vector) ctable
+ && hop.getInput(0).getDataType().isMatrix() && hop.getInput(1).getDataType().isMatrix();
+ //size-dependent conditions
if(satisfiesSizeConstraintsForCompression(hop)) {
satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD);
satisfies |= HopRewriteUtils.isUnary(hop, OpOp1.ROUND, OpOp1.FLOOR, OpOp1.NOT, OpOp1.CEIL);
- satisfies |= HopRewriteUtils.isBinary(hop,
- OpOp2.EQUAL,
- OpOp2.NOTEQUAL,
- OpOp2.LESS,
- OpOp2.LESSEQUAL,
- OpOp2.GREATER,
- OpOp2.GREATEREQUAL,
- OpOp2.AND,
- OpOp2.OR,
- OpOp2.MODULUS);
+ satisfies |= HopRewriteUtils.isBinary(hop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS,
+ OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.AND, OpOp2.OR, OpOp2.MODULUS);
+ satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE);
}
if(LOG.isDebugEnabled() && satisfies)
LOG.debug("Operation Satisfies: " + hop);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
index a705c78..7193d48 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
@@ -24,7 +24,6 @@ import java.io.DataOutput;
import java.io.IOException;
import org.apache.commons.lang.NotImplementedException;
-import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.DenseBlockFP64;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -50,7 +49,11 @@ public class MatrixBlockDictionary extends ADictionary {
@Override
public double[] getValues() {
- throw new DMLCompressionException("Get Values should not be called when you have a MatrixBlockDictionary");
+ // FIXME fix MinMaxGroup Initialization to avoid conversion to dense
+ if( !_data.isInSparseFormat() )
+ _data.sparseToDense();
+ return _data.getDenseBlockValues();
+ //throw new DMLCompressionException("Get Values should not be called when you have a MatrixBlockDictionary");
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
index e77d1a0..b37acc1 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
+import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -58,7 +59,12 @@ import org.apache.sysds.runtime.compress.workload.AWTreeNode.WTNodeType;
public class WorkloadAnalyzer {
private static final Log LOG = LogFactory.getLog(WorkloadAnalyzer.class.getName());
-
+ // indicator for more aggressive compression of intermediates
+ public static boolean ALLOW_INTERMEDIATE_CANDIDATES = false;
+ // avoid wtree construction for presumably already compressed intermediates
+ // (due to conditional control flow this might miss compression opportunities)
+ public static boolean PRUNE_COMPRESSED_INTERMEDIATES = true;
+
private final Set<Hop> visited;
private final Set<Long> compressed;
private final Set<Long> transposed;
@@ -69,16 +75,23 @@ public class WorkloadAnalyzer {
private final List<Hop> decompressHops;
public static Map<Long, WTreeRoot> getAllCandidateWorkloads(DMLProgram prog) {
- // extract all compression candidates from program
+ // extract all compression candidates from program (in program order)
List<Hop> candidates = getCandidates(prog);
-
+
// for each candidate, create pruned workload tree
- // TODO memoization of processed subtree if overlap
+ List<WorkloadAnalyzer> allWAs = new LinkedList<>();
Map<Long, WTreeRoot> map = new HashMap<>();
for(Hop cand : candidates) {
- WTreeRoot tree = new WorkloadAnalyzer(prog).createWorkloadTree(cand);
-
+ //prune already covered candidate (intermediate already compressed)
+ if( PRUNE_COMPRESSED_INTERMEDIATES )
+ if( allWAs.stream().anyMatch(w -> w.containsCompressed(cand)) )
+ continue; //intermediate already compressed
+
+ //construct workload tree for candidate
+ WorkloadAnalyzer wa = new WorkloadAnalyzer(prog);
+ WTreeRoot tree = wa.createWorkloadTree(cand);
map.put(cand.getHopID(), tree);
+ allWAs.add(wa);
}
return map;
@@ -128,6 +141,10 @@ public class WorkloadAnalyzer {
return main;
}
+ protected boolean containsCompressed(Hop hop) {
+ return compressed.contains(hop.getHopID());
+ }
+
private static List<Hop> getCandidates(DMLProgram prog) {
List<Hop> candidates = new ArrayList<>();
for(StatementBlock sb : prog.getStatementBlocks()) {
@@ -191,7 +208,9 @@ public class WorkloadAnalyzer {
if(hop.isVisited())
return;
// evaluate and add candidates (type and size)
- if(RewriteCompressedReblock.satisfiesCompressionCondition(hop))
+ if( ( RewriteCompressedReblock.satisfiesAggressiveCompressionCondition(hop)
+ & ALLOW_INTERMEDIATE_CANDIDATES)
+ || RewriteCompressedReblock.satisfiesCompressionCondition(hop))
cands.add(hop);
// recursively process children (inputs)
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index eaa9a73..04fe79e 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -24,6 +24,7 @@ import static org.junit.Assert.fail;
import java.io.File;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.compress.workload.WorkloadAnalyzer;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -36,6 +37,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
private final static String TEST_NAME1 = "WorkloadAnalysisMLogReg";
private final static String TEST_NAME2 = "WorkloadAnalysisLm";
private final static String TEST_NAME3 = "WorkloadAnalysisPCA";
+ private final static String TEST_NAME4 = "WorkloadAnalysisSliceLine";
private final static String TEST_DIR = "functions/compress/workload/";
private final static String TEST_CLASS_DIR = TEST_DIR + WorkloadAnalysisTest.class.getSimpleName() + "/";
@@ -45,40 +47,52 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"B"}));
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"B"}));
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"B"}));
-
+ addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"B"}));
}
@Test
public void testMLogRegCP() {
- runWorkloadAnalysisTest(TEST_NAME1, ExecMode.HYBRID, 2);
+ runWorkloadAnalysisTest(TEST_NAME1, ExecMode.HYBRID, 2, false);
}
@Test
public void testLmSP() {
- runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SPARK, 2);
+ runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SPARK, 2, false);
}
@Test
public void testLmCP() {
- runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2);
+ runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2, false);
}
@Test
public void testPCASP() {
- runWorkloadAnalysisTest(TEST_NAME3, ExecMode.SPARK, 1);
+ runWorkloadAnalysisTest(TEST_NAME3, ExecMode.SPARK, 1, false);
}
@Test
public void testPCACP() {
- runWorkloadAnalysisTest(TEST_NAME3, ExecMode.HYBRID, 1);
+ runWorkloadAnalysisTest(TEST_NAME3, ExecMode.HYBRID, 1, false);
}
- private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount) {
+ @Test
+ public void testSliceLineCP1() {
+ runWorkloadAnalysisTest(TEST_NAME4, ExecMode.HYBRID, 0, false);
+ }
+
+ @Test
+ public void testSliceLineCP2() {
+ runWorkloadAnalysisTest(TEST_NAME4, ExecMode.HYBRID, 2, true);
+ }
+
+ private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates) {
ExecMode oldPlatform = setExecMode(mode);
-
+ boolean oldIntermediates = WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES;
+
try {
loadTestConfiguration(getTestConfiguration(testname));
-
+ WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES = intermediates;
+
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[] {"-explain","-stats",
@@ -99,9 +113,11 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
.getCPHeavyHitterCount("compress") : Statistics.getCPHeavyHitterCount("sp_compress");
Assert.assertEquals(compressionCount, actualCompressionCount);
- Assert.assertTrue( mode == ExecMode.HYBRID ? heavyHittersContainsString("compress") : heavyHittersContainsString("sp_compress"));
- Assert.assertFalse(heavyHittersContainsString("m_scale"));
-
+ if( compressionCount > 0 )
+ Assert.assertTrue( mode == ExecMode.HYBRID ?
+ heavyHittersContainsString("compress") : heavyHittersContainsString("sp_compress"));
+ if( !testname.equals(TEST_NAME4) )
+ Assert.assertFalse(heavyHittersContainsString("m_scale"));
}
catch(Exception e) {
resetExecMode(oldPlatform);
@@ -109,6 +125,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
}
finally {
resetExecMode(oldPlatform);
+ WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES = oldIntermediates;
}
}
diff --git a/src/test/scripts/functions/compress/workload/WorkloadAnalysisSliceFinder.dml b/src/test/scripts/functions/compress/workload/WorkloadAnalysisSliceLine.dml
similarity index 52%
rename from src/test/scripts/functions/compress/workload/WorkloadAnalysisSliceFinder.dml
rename to src/test/scripts/functions/compress/workload/WorkloadAnalysisSliceLine.dml
index b287cbc..eeeba41 100644
--- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisSliceFinder.dml
+++ b/src/test/scripts/functions/compress/workload/WorkloadAnalysisSliceLine.dml
@@ -19,30 +19,35 @@
#
#-------------------------------------------------------------
-X = read($1) + 1;
-Y = read($2);
-
-
-print("")
-print("MLogReg")
-
-[X_s,s,c] = scale(X=X, scale=TRUE, center=TRUE);
-B = multiLogReg(X=X_s, Y=Y, verbose=FALSE, maxi=2, maxii=2);
-[nn, P, acc] = multiLogRegPredict(X=X_s, B=B, Y=Y)
-
-
-[nn, C] = confusionMatrix(P, Y)
-print("Confusion: ")
-print(toString(C))
-
-
-print("")
-print("SliceFinder")
-
-e = Y == P
-
-[tk,tkc,d] = slicefinder(X=X, e=e, maxL = 2, verbose=TRUE)
-
-print("tk :\n" + toString(tk))
-print("tkc :\n" + toString(tkc))
+# data preparation
+FXY = read("./src/test/resources/datasets/Salaries.csv",
+ data_type="frame", format="csv", header=TRUE);
+F = FXY[,1:ncol(FXY)-1];
+y = as.matrix(FXY[,ncol(FXY)]);
+jspec= "{ ids:true, recode:[1,2,3,6],bin:["
+ +"{id:4, method:equi-width, numbins:14},"
+ +"{id:5, method:equi-width, numbins:12}]}"
+[X,M] = transformencode(target=F, spec=jspec);
+X = X[,2:ncol(X)]
+
+m = nrow(X)
+n = ncol(X)
+fdom = colMaxs(X);
+foffb = t(cumsum(t(fdom))) - fdom;
+foffe = t(cumsum(t(fdom)))
+rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1)
+cix = matrix(X + foffb, m*n, 1);
+X2 = table(rix, cix); #one-hot encoded
+
+# learn model
+B = lm(X=X2, y=y, verbose=FALSE);
+yhat = X2 %*% B;
+e = (y-yhat)^2;
+
+# call slice finding
+[TS,TR,d] = slicefinder(X=X, e=e, k=10,
+ alpha=0.95, minSup=4, tpEval=TRUE, verbose=TRUE);
+
+print("TS:\n" + toString(TS))
+print("TR:\n" + toString(TR))
print("Debug matrix:\n" + toString(d))