You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2020/09/05 20:43:05 UTC
[systemds] branch master updated: [SYSTEMDS-2650] Non-recursive
construction of HOPs from Lineage
This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 5596fcf [SYSTEMDS-2650] Non-recursive construction of HOPs from Lineage
5596fcf is described below
commit 5596fcf0d77946cf11ff34085e8879552dd852be
Author: arnabp <ar...@tugraz.at>
AuthorDate: Sat Sep 5 22:37:22 2020 +0200
[SYSTEMDS-2650] Non-recursive construction of HOPs from Lineage
This patch implements a non-recursive version of HOP DAG construction
from the lineage DAG, which fixes the stack overflow that occurred while
recomputing from lineage.
---
.../runtime/lineage/LineageRecomputeUtils.java | 325 ++++++++++++++++++++-
.../functions/lineage/LineageTraceDedupTest.java | 7 +-
2 files changed, 314 insertions(+), 18 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
index fffc2dc..0df1651 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
@@ -25,8 +25,10 @@ import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.Stack;
import java.util.stream.Collectors;
+import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp1;
@@ -100,7 +102,7 @@ public class LineageRecomputeUtils {
root.resetVisitStatusNR();
Map<Long, Hop> operands = new HashMap<>();
Map<String, Hop> partDagRoots = new HashMap<>();
- rConstructHops(root, operands, partDagRoots, prog);
+ constructHopsNR(root, operands, partDagRoots, prog);
Hop out = HopRewriteUtils.createTransientWrite(
varname, operands.get(rootId));
@@ -134,17 +136,38 @@ public class LineageRecomputeUtils {
prog.addProgramBlock(pb);
}
-
- private static void rConstructHops(LineageItem item, Map<Long, Hop> operands, Map<String, Hop> partDagRoots, Program prog)
+ private static void constructHopsNR(LineageItem item, Map<Long, Hop> operands, Map<String, Hop> partDagRoots, Program prog)
+ {
+ //NOTE: Iterative replacement for the recursive rConstructHops: inputs are processed before their
+ //consumer using two parallel explicit stacks; follows the same non-recursive skeleton as explainLineageItemNR
+ Stack<LineageItem> stackItem = new Stack<>(); //items on the current traversal path
+ Stack<MutableInt> stackPos = new Stack<>(); //per item: index of the next input to descend into
+ stackItem.push(item); stackPos.push(new MutableInt(0));
+ while (!stackItem.empty()) {
+ LineageItem tmpItem = stackItem.peek();
+ MutableInt tmpPos = stackPos.peek();
+ //check ascent condition - item already marked visited, nothing to construct
+ if (tmpItem.isVisited()) {
+ stackItem.pop(); stackPos.pop();
+ }
+ //check ascent condition - no inputs or all inputs processed: construct the hop for this item
+ else if( tmpItem.getInputs() == null
+ || tmpItem.getInputs().length <= tmpPos.intValue() ) {
+ constructSingleHop(tmpItem, operands, partDagRoots, prog);
+ stackItem.pop(); stackPos.pop();
+ tmpItem.setVisited(); //marked only after construction, as in the recursive variant
+ }
+ //check descent condition - push the next unprocessed input of this item
+ else if( tmpItem.getInputs() != null ) {
+ stackItem.push(tmpItem.getInputs()[tmpPos.intValue()]);
+ tmpPos.increment(); //advance so the following input is taken when this item is on top again
+ stackPos.push(new MutableInt(0)); //the pushed input starts at its first input
+ }
+ }
+ }
+
+ private static void constructSingleHop(LineageItem item, Map<Long, Hop> operands, Map<String, Hop> partDagRoots, Program prog)
{
- if (item.isVisited())
- return;
-
- //recursively process children (ordering by data dependencies)
- if (!item.isLeaf())
- for (LineageItem c : item.getInputs())
- rConstructHops(c, operands, partDagRoots, prog);
-
//process current lineage item
//NOTE: we generate instructions from hops (but without rewrites) to automatically
//handle execution types, rmvar instructions, and rewiring of inputs/outputs
@@ -406,8 +429,6 @@ public class LineageRecomputeUtils {
break;
}
}
-
- item.setVisited();
}
// Construct and compile the function body
@@ -428,7 +449,7 @@ public class LineageRecomputeUtils {
for (int i=0; i<inputs.length; i++)
operands.put((long)i, HopRewriteUtils.createTransientRead(inputs[i], inpHops.get(i))); //order preserving
// Construct the Hop dag.
- rConstructHops(patchRoot, operands, null, null);
+ constructHopsNR(patchRoot, operands, null, null);
// TWrite the func return (pass dag root to copy datatype)
Hop out = HopRewriteUtils.createTransientWrite(outname, operands.get(patchRoot.getId()));
// Save the Hop dag
@@ -518,6 +539,282 @@ public class LineageRecomputeUtils {
throw new DMLRuntimeException("Unsupported opcode: "+item.getOpcode());
}
+ @Deprecated
+ @SuppressWarnings("unused")
+ private static void rConstructHops(LineageItem item, Map<Long, Hop> operands, Map<String, Hop> partDagRoots, Program prog)
+ {
+ if (item.isVisited())
+ return;
+
+ //recursively process children (ordering by data dependencies); one stack frame per lineage level, superseded by constructHopsNR
+ if (!item.isLeaf())
+ for (LineageItem c : item.getInputs())
+ rConstructHops(c, operands, partDagRoots, prog);
+
+ //process current lineage item
+ //NOTE: we generate instructions from hops (but without rewrites) to automatically
+ //handle execution types, rmvar instructions, and rewiring of inputs/outputs
+ switch (item.getType()) {
+ case Creation: {
+ if (item.getData().startsWith(LPLACEHOLDER)) {
+ long phId = Long.parseLong(item.getData().substring(3));
+ Hop input = operands.get(phId);
+ operands.remove(phId);
+ // Replace the placeholder entry: re-register its hop under this item's id
+ operands.put(item.getId(), input); // order preserving
+ break;
+ }
+ Instruction inst = InstructionParser.parseSingleInstruction(item.getData());
+
+ if (inst instanceof DataGenCPInstruction) {
+ DataGenCPInstruction rand = (DataGenCPInstruction) inst;
+ HashMap<String, Hop> params = new HashMap<>();
+ if( rand.getOpcode().equals("rand") ) {
+ if( rand.output.getDataType() == DataType.TENSOR)
+ params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+ else {
+ params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+ params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+ }
+ params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+ params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+ params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+ params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+ params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+ params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+ }
+ else if( rand.getOpcode().equals("seq") ) {
+ params.put(Statement.SEQ_FROM, new LiteralOp(rand.getFrom()));
+ params.put(Statement.SEQ_TO, new LiteralOp(rand.getTo()));
+ params.put(Statement.SEQ_INCR, new LiteralOp(rand.getIncr()));
+ }
+ Hop datagen = new DataGenOp(OpOpDG.valueOf(rand.getOpcode().toUpperCase()),
+ new DataIdentifier("tmp"), params);
+ datagen.setBlocksize(rand.getBlocksize());
+ operands.put(item.getId(), datagen);
+ } else if (inst instanceof VariableCPInstruction
+ && ((VariableCPInstruction) inst).isCreateVariable()) {
+ String parts[] = InstructionUtils.getInstructionPartsWithValueType(inst.toString());
+ DataType dt = DataType.valueOf(parts[4]);
+ ValueType vt = dt == DataType.MATRIX ? ValueType.FP64 : ValueType.STRING;
+ HashMap<String, Hop> params = new HashMap<>();
+ params.put(DataExpression.IO_FILENAME, new LiteralOp(parts[2]));
+ params.put(DataExpression.READROWPARAM, new LiteralOp(Long.parseLong(parts[6])));
+ params.put(DataExpression.READCOLPARAM, new LiteralOp(Long.parseLong(parts[7])));
+ params.put(DataExpression.READNNZPARAM, new LiteralOp(Long.parseLong(parts[8])));
+ params.put(DataExpression.FORMAT_TYPE, new LiteralOp(parts[5]));
+ DataOp pread = new DataOp(parts[1].substring(5), dt, vt, OpOpData.PERSISTENTREAD, params);
+ pread.setFileName(parts[2]);
+ operands.put(item.getId(), pread);
+ }
+ else if (inst instanceof RandSPInstruction) {
+ RandSPInstruction rand = (RandSPInstruction) inst;
+ HashMap<String, Hop> params = new HashMap<>();
+ if (rand.output.getDataType() == DataType.TENSOR)
+ params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+ else {
+ params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+ params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+ }
+ params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+ params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+ params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+ params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+ params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+ params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+ Hop datagen = new DataGenOp(OpOpDG.RAND, new DataIdentifier("tmp"), params);
+ datagen.setBlocksize(rand.getBlocksize());
+ operands.put(item.getId(), datagen);
+ }
+ break;
+ }
+ case Dedup: {
+ // Create function call for each dedup entry
+ String[] parts = item.getOpcode().split(LineageDedupUtils.DEDUP_DELIM); //e.g. dedup_R_SB13_0
+ String name = parts[2] + parts[1] + parts[3]; //loopId + outVar + pathId
+ List<Hop> finputs = Arrays.stream(item.getInputs())
+ .map(inp -> operands.get(inp.getId())).collect(Collectors.toList());
+ String[] inputNames = new String[item.getInputs().length];
+ for (int i=0; i<item.getInputs().length; i++)
+ inputNames[i] = LPLACEHOLDER + i; //e.g. IN#0, IN#1
+ Hop funcOp = new FunctionOp(FunctionType.DML, DMLProgram.DEFAULT_NAMESPACE,
+ name, inputNames, finputs, new String[] {parts[1]}, false);
+
+ // Cut the Hop dag after function calls
+ partDagRoots.put(parts[1], funcOp);
+ // Compile the dag and save
+ constructBasicBlock(partDagRoots, parts[1], prog);
+
+ // Construct a Hop dag for the function body from the dedup patch, and compile
+ Hop output = constructHopsDedupPatch(parts, inputNames, finputs, prog);
+ // Create a TRead on the function o/p as a leaf for the next Hop dag
+ // Use the function body root/return hop to propagate right data type
+ operands.put(item.getId(), HopRewriteUtils.createTransientRead(parts[1], output));
+ break;
+ }
+ case Instruction: {
+ CPType ctype = InstructionUtils.getCPTypeByOpcode(item.getOpcode());
+ SPType stype = InstructionUtils.getSPTypeByOpcode(item.getOpcode());
+
+ if (ctype != null) {
+ switch (ctype) {
+ case AggregateUnary: {
+ Hop input = operands.get(item.getInputs()[0].getId());
+ Hop aggunary = InstructionUtils.isUnaryMetadata(item.getOpcode()) ?
+ HopRewriteUtils.createUnary(input, OpOp1.valueOfByOpcode(item.getOpcode())) :
+ HopRewriteUtils.createAggUnaryOp(input, item.getOpcode());
+ operands.put(item.getId(), aggunary);
+ break;
+ }
+ case AggregateBinary: {
+ Hop input1 = operands.get(item.getInputs()[0].getId());
+ Hop input2 = operands.get(item.getInputs()[1].getId());
+ Hop aggbinary = HopRewriteUtils.createMatrixMultiply(input1, input2);
+ operands.put(item.getId(), aggbinary);
+ break;
+ }
+ case AggregateTernary: {
+ Hop input1 = operands.get(item.getInputs()[0].getId());
+ Hop input2 = operands.get(item.getInputs()[1].getId());
+ Hop input3 = operands.get(item.getInputs()[2].getId());
+ Hop aggternary = HopRewriteUtils.createSum(
+ HopRewriteUtils.createBinary(
+ HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT),
+ input3, OpOp2.MULT));
+ operands.put(item.getId(), aggternary);
+ break;
+ }
+ case Unary:
+ case Builtin: {
+ Hop input = operands.get(item.getInputs()[0].getId());
+ Hop unary = HopRewriteUtils.createUnary(input, item.getOpcode());
+ operands.put(item.getId(), unary);
+ break;
+ }
+ case Reorg: {
+ operands.put(item.getId(), HopRewriteUtils.createReorg(
+ operands.get(item.getInputs()[0].getId()), item.getOpcode()));
+ break;
+ }
+ case Reshape: {
+ ArrayList<Hop> inputs = new ArrayList<>();
+ for(int i=0; i<5; i++)
+ inputs.add(operands.get(item.getInputs()[i].getId()));
+ operands.put(item.getId(), HopRewriteUtils.createReorg(inputs, ReOrgOp.RESHAPE));
+ break;
+ }
+ case Binary: {
+ //handle special cases of binary operations: opcodes "^2" and "*2" are reduced to "^" and "*"
+ String opcode = ("^2".equals(item.getOpcode())
+ || "*2".equals(item.getOpcode())) ?
+ item.getOpcode().substring(0, 1) : item.getOpcode();
+ Hop input1 = operands.get(item.getInputs()[0].getId());
+ Hop input2 = operands.get(item.getInputs()[1].getId());
+ Hop binary = HopRewriteUtils.createBinary(input1, input2, opcode);
+ operands.put(item.getId(), binary);
+ break;
+ }
+ case Ternary: {
+ operands.put(item.getId(), HopRewriteUtils.createTernary(
+ operands.get(item.getInputs()[0].getId()),
+ operands.get(item.getInputs()[1].getId()),
+ operands.get(item.getInputs()[2].getId()), item.getOpcode()));
+ break;
+ }
+ case Ctable: { //e.g., ctable
+ if( item.getInputs().length==3 )
+ operands.put(item.getId(), HopRewriteUtils.createTernary(
+ operands.get(item.getInputs()[0].getId()),
+ operands.get(item.getInputs()[1].getId()),
+ operands.get(item.getInputs()[2].getId()), OpOp3.CTABLE));
+ else if( item.getInputs().length==5 )
+ operands.put(item.getId(), HopRewriteUtils.createTernary(
+ operands.get(item.getInputs()[0].getId()),
+ operands.get(item.getInputs()[1].getId()),
+ operands.get(item.getInputs()[2].getId()),
+ operands.get(item.getInputs()[3].getId()),
+ operands.get(item.getInputs()[4].getId()), OpOp3.CTABLE));
+ break;
+ }
+ case BuiltinNary: {
+ String opcode = item.getOpcode().equals("n+") ? "plus" : item.getOpcode();
+ operands.put(item.getId(), HopRewriteUtils.createNary(
+ OpOpN.valueOf(opcode.toUpperCase()), createNaryInputs(item, operands)));
+ break;
+ }
+ case ParameterizedBuiltin: {
+ operands.put(item.getId(), constructParameterizedBuiltinOp(item, operands));
+ break;
+ }
+ case MatrixIndexing: {
+ operands.put(item.getId(), constructIndexingOp(item, operands));
+ break;
+ }
+ case MMTSJ: {
+ //TODO handling of tsmm type left and right -> placement transpose
+ Hop input = operands.get(item.getInputs()[0].getId());
+ Hop aggunary = HopRewriteUtils.createMatrixMultiply(
+ HopRewriteUtils.createTranspose(input), input);
+ operands.put(item.getId(), aggunary);
+ break;
+ }
+ case Variable: {
+ if( item.getOpcode().startsWith("cast") )
+ operands.put(item.getId(), HopRewriteUtils.createUnary(
+ operands.get(item.getInputs()[0].getId()),
+ OpOp1.valueOfByOpcode(item.getOpcode())));
+ else //cpvar, write: pass the input hop through unchanged
+ operands.put(item.getId(), operands.get(item.getInputs()[0].getId()));
+ break;
+ }
+ default:
+ throw new DMLRuntimeException("Unsupported instruction "
+ + "type: " + ctype.name() + " (" + item.getOpcode() + ").");
+ }
+ }
+ else if( stype != null ) {
+ switch(stype) {
+ case Reblock: {
+ Hop input = operands.get(item.getInputs()[0].getId());
+ input.setBlocksize(ConfigurationManager.getBlocksize());
+ input.setRequiresReblock(true);
+ operands.put(item.getId(), input);
+ break;
+ }
+ case Checkpoint: {
+ Hop input = operands.get(item.getInputs()[0].getId());
+ operands.put(item.getId(), input);
+ break;
+ }
+ case MatrixIndexing: {
+ operands.put(item.getId(), constructIndexingOp(item, operands));
+ break;
+ }
+ case GAppend: {
+ operands.put(item.getId(), HopRewriteUtils.createBinary(
+ operands.get(item.getInputs()[0].getId()),
+ operands.get(item.getInputs()[1].getId()), OpOp2.CBIND));
+ break;
+ }
+ default:
+ throw new DMLRuntimeException("Unsupported instruction "
+ + "type: " + stype.name() + " (" + item.getOpcode() + ").");
+ }
+ }
+ else
+ throw new DMLRuntimeException("Unsupported instruction: " + item.getOpcode());
+ break;
+ }
+ case Literal: {
+ CPOperand op = new CPOperand(item.getData());
+ operands.put(item.getId(), ScalarObjectFactory
+ .createLiteralOp(op.getValueType(), op.getName()));
+ break;
+ }
+ }
+
+ item.setVisited(); //later calls for this item return early at the isVisited check above
+ }
// Below class represents a single loop and contains related data
// that are needed for recomputation.
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
index 18da399..3b1ae65 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
@@ -95,12 +95,11 @@ public class LineageTraceDedupTest extends AutomatedTestBase
testLineageTrace(TEST_NAME5);
}
- /*@Test
+ @Test
public void testLineageTrace6() {
testLineageTrace(TEST_NAME6);
- }*/
- //FIXME: stack overflow only when ran the full package
-
+ }
+
@Test
public void testLineageTrace7() {
testLineageTrace(TEST_NAME7);