You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2020/09/05 20:43:05 UTC

[systemds] branch master updated: [SYSTEMDS-2650] Non-recursive construction of HOPs from Lineage

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 5596fcf  [SYSTEMDS-2650] Non-recursive construction of HOPs from Lineage
5596fcf is described below

commit 5596fcf0d77946cf11ff34085e8879552dd852be
Author: arnabp <ar...@tugraz.at>
AuthorDate: Sat Sep 5 22:37:22 2020 +0200

    [SYSTEMDS-2650] Non-recursive construction of HOPs from Lineage
    
    This patch implements a non-recursive version of the HOP DAG
    construction from a lineage DAG, which fixes the stack overflow
    that occurred while recomputing from lineage.
---
 .../runtime/lineage/LineageRecomputeUtils.java     | 325 ++++++++++++++++++++-
 .../functions/lineage/LineageTraceDedupTest.java   |   7 +-
 2 files changed, 314 insertions(+), 18 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
index fffc2dc..0df1651 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
@@ -25,8 +25,10 @@ import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Stack;
 import java.util.stream.Collectors;
 
+import org.apache.commons.lang3.mutable.MutableInt;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.OpOp1;
@@ -100,7 +102,7 @@ public class LineageRecomputeUtils {
 		root.resetVisitStatusNR();
 		Map<Long, Hop> operands = new HashMap<>();
 		Map<String, Hop> partDagRoots = new HashMap<>();
-		rConstructHops(root, operands, partDagRoots, prog);
+		constructHopsNR(root, operands, partDagRoots, prog);
 		Hop out = HopRewriteUtils.createTransientWrite(
 			varname, operands.get(rootId));
 		
@@ -134,17 +136,38 @@ public class LineageRecomputeUtils {
 		prog.addProgramBlock(pb);
 	}
 	
-	
-	private static void rConstructHops(LineageItem item, Map<Long, Hop> operands, Map<String, Hop> partDagRoots, Program prog) 
+	private static void constructHopsNR(LineageItem item, Map<Long, Hop> operands, Map<String, Hop> partDagRoots, Program prog) 
+	{
+		//NOTE: Iterative depth-first traversal with two parallel stacks (same skeleton as
+		//explainLineageItemNR): stackPos holds, per stacked item, the index of its next input
+		Stack<LineageItem> stackItem = new Stack<>();
+		Stack<MutableInt> stackPos = new Stack<>();
+		stackItem.push(item); stackPos.push(new MutableInt(0));
+		while (!stackItem.empty()) {
+			LineageItem tmpItem = stackItem.peek();
+			MutableInt tmpPos = stackPos.peek();
+			//ascent - item already visited: pop without processing it again
+			if (tmpItem.isVisited()) {
+				stackItem.pop(); stackPos.pop();
+			}
+			//ascent - no inputs or all inputs already pushed: construct the hop, pop, mark visited
+			else if( tmpItem.getInputs() == null 
+				|| tmpItem.getInputs().length <= tmpPos.intValue() ) {
+				constructSingleHop(tmpItem, operands, partDagRoots, prog);
+				stackItem.pop(); stackPos.pop();
+				tmpItem.setVisited();
+			}
+			//descent - push the next input and advance this item's input position
+			else if( tmpItem.getInputs() != null ) { //always true here: null is handled by the branch above
+				stackItem.push(tmpItem.getInputs()[tmpPos.intValue()]);
+				tmpPos.increment();
+				stackPos.push(new MutableInt(0));
+			}
+		}
+	}
+
+	private static void constructSingleHop(LineageItem item, Map<Long, Hop> operands, Map<String, Hop> partDagRoots, Program prog) 
 	{
-		if (item.isVisited())
-			return;
-		
-		//recursively process children (ordering by data dependencies)
-		if (!item.isLeaf())
-			for (LineageItem c : item.getInputs())
-				rConstructHops(c, operands, partDagRoots, prog);
-		
 		//process current lineage item
 		//NOTE: we generate instructions from hops (but without rewrites) to automatically
 		//handle execution types, rmvar instructions, and rewiring of inputs/outputs
@@ -406,8 +429,6 @@ public class LineageRecomputeUtils {
 				break;
 			}
 		}
-		
-		item.setVisited();
 	}
 
 	// Construct and compile the function body
@@ -428,7 +449,7 @@ public class LineageRecomputeUtils {
 		for (int i=0; i<inputs.length; i++)
 			operands.put((long)i, HopRewriteUtils.createTransientRead(inputs[i], inpHops.get(i))); //order preserving
 		// Construct the Hop dag.
-		rConstructHops(patchRoot, operands, null, null);
+		constructHopsNR(patchRoot, operands, null, null);
 		// TWrite the func return (pass dag root to copy datatype)
 		Hop out = HopRewriteUtils.createTransientWrite(outname, operands.get(patchRoot.getId()));
 		// Save the Hop dag
@@ -518,6 +539,282 @@ public class LineageRecomputeUtils {
 		throw new DMLRuntimeException("Unsupported opcode: "+item.getOpcode());
 	}
 	
+	@Deprecated //superseded by the non-recursive constructHopsNR
+	@SuppressWarnings("unused")
+	private static void rConstructHops(LineageItem item, Map<Long, Hop> operands, Map<String, Hop> partDagRoots, Program prog) 
+	{
+		if (item.isVisited())
+			return;
+		
+		//recursively process inputs first, so their hops are in operands (ordering by data dependencies)
+		if (!item.isLeaf())
+			for (LineageItem c : item.getInputs())
+				rConstructHops(c, operands, partDagRoots, prog);
+		
+		//process current lineage item
+		//NOTE: we generate instructions from hops (but without rewrites) to automatically
+		//handle execution types, rmvar instructions, and rewiring of inputs/outputs
+		switch (item.getType()) {
+			case Creation: {
+				if (item.getData().startsWith(LPLACEHOLDER)) {
+					long phId = Long.parseLong(item.getData().substring(3)); //skips the placeholder prefix (presumably "IN#", length 3 - TODO confirm)
+					Hop input = operands.get(phId);
+					operands.remove(phId);
+					// Replace the placeholders with TReads
+					operands.put(item.getId(), input); // order preserving
+					break;
+				}
+				Instruction inst = InstructionParser.parseSingleInstruction(item.getData());
+				
+				if (inst instanceof DataGenCPInstruction) {
+					DataGenCPInstruction rand = (DataGenCPInstruction) inst;
+					HashMap<String, Hop> params = new HashMap<>();
+					if( rand.getOpcode().equals("rand") ) {
+						if( rand.output.getDataType() == DataType.TENSOR)
+							params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+						else {
+							params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+							params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+						}
+						params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+						params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+						params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+						params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+						params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+						params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+					}
+					else if( rand.getOpcode().equals("seq") ) {
+						params.put(Statement.SEQ_FROM, new LiteralOp(rand.getFrom()));
+						params.put(Statement.SEQ_TO, new LiteralOp(rand.getTo()));
+						params.put(Statement.SEQ_INCR, new LiteralOp(rand.getIncr()));
+					}
+					Hop datagen = new DataGenOp(OpOpDG.valueOf(rand.getOpcode().toUpperCase()),
+						new DataIdentifier("tmp"), params);
+					datagen.setBlocksize(rand.getBlocksize());
+					operands.put(item.getId(), datagen);
+				} else if (inst instanceof VariableCPInstruction
+						&& ((VariableCPInstruction) inst).isCreateVariable()) {
+					String parts[] = InstructionUtils.getInstructionPartsWithValueType(inst.toString());
+					DataType dt = DataType.valueOf(parts[4]);
+					ValueType vt = dt == DataType.MATRIX ? ValueType.FP64 : ValueType.STRING;
+					HashMap<String, Hop> params = new HashMap<>();
+					params.put(DataExpression.IO_FILENAME, new LiteralOp(parts[2]));
+					params.put(DataExpression.READROWPARAM, new LiteralOp(Long.parseLong(parts[6])));
+					params.put(DataExpression.READCOLPARAM, new LiteralOp(Long.parseLong(parts[7])));
+					params.put(DataExpression.READNNZPARAM, new LiteralOp(Long.parseLong(parts[8])));
+					params.put(DataExpression.FORMAT_TYPE, new LiteralOp(parts[5]));
+					DataOp pread = new DataOp(parts[1].substring(5), dt, vt, OpOpData.PERSISTENTREAD, params);
+					pread.setFileName(parts[2]);
+					operands.put(item.getId(), pread);
+				}
+				else if  (inst instanceof RandSPInstruction) {
+					RandSPInstruction rand = (RandSPInstruction) inst;
+					HashMap<String, Hop> params = new HashMap<>();
+					if (rand.output.getDataType() == DataType.TENSOR)
+						params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+					else {
+						params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+						params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+					}
+					params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+					params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+					params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+					params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+					params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+					params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+					Hop datagen = new DataGenOp(OpOpDG.RAND, new DataIdentifier("tmp"), params);
+					datagen.setBlocksize(rand.getBlocksize());
+					operands.put(item.getId(), datagen);
+				}
+				break;
+			}
+			case Dedup: {
+				// Create a function call for each dedup entry
+				String[] parts = item.getOpcode().split(LineageDedupUtils.DEDUP_DELIM); //e.g. dedup_R_SB13_0
+				String name = parts[2] + parts[1] + parts[3];  //loopId + outVar + pathId
+				List<Hop> finputs = Arrays.stream(item.getInputs())
+						.map(inp -> operands.get(inp.getId())).collect(Collectors.toList());
+				String[] inputNames = new String[item.getInputs().length];
+				for (int i=0; i<item.getInputs().length; i++)
+					inputNames[i] = LPLACEHOLDER + i;  //e.g. IN#0, IN#1
+				Hop funcOp = new FunctionOp(FunctionType.DML, DMLProgram.DEFAULT_NAMESPACE, 
+						name, inputNames, finputs, new String[] {parts[1]}, false);
+
+				// Cut the Hop dag after the function call (the call becomes a root of the current partial dag)
+				partDagRoots.put(parts[1], funcOp);
+				// Compile the dag and save
+				constructBasicBlock(partDagRoots, parts[1], prog);
+
+				// Construct a Hop dag for the function body from the dedup patch, and compile
+				Hop output = constructHopsDedupPatch(parts, inputNames, finputs, prog);
+				// Create a TRead on the function o/p as a leaf for the next Hop dag
+				// Use the function body root/return hop to propagate right data type
+				operands.put(item.getId(), HopRewriteUtils.createTransientRead(parts[1], output));
+				break;
+			}
+			case Instruction: {
+				CPType ctype = InstructionUtils.getCPTypeByOpcode(item.getOpcode());
+				SPType stype = InstructionUtils.getSPTypeByOpcode(item.getOpcode());
+				
+				if (ctype != null) {
+					switch (ctype) {
+						case AggregateUnary: {
+							Hop input = operands.get(item.getInputs()[0].getId());
+							Hop aggunary = InstructionUtils.isUnaryMetadata(item.getOpcode()) ?
+								HopRewriteUtils.createUnary(input, OpOp1.valueOfByOpcode(item.getOpcode())) :
+								HopRewriteUtils.createAggUnaryOp(input, item.getOpcode());
+							operands.put(item.getId(), aggunary);
+							break;
+						}
+						case AggregateBinary: {
+							Hop input1 = operands.get(item.getInputs()[0].getId());
+							Hop input2 = operands.get(item.getInputs()[1].getId());
+							Hop aggbinary = HopRewriteUtils.createMatrixMultiply(input1, input2);
+							operands.put(item.getId(), aggbinary);
+							break;
+						}
+						case AggregateTernary: {
+							Hop input1 = operands.get(item.getInputs()[0].getId());
+							Hop input2 = operands.get(item.getInputs()[1].getId());
+							Hop input3 = operands.get(item.getInputs()[2].getId());
+							Hop aggternary = HopRewriteUtils.createSum(
+								HopRewriteUtils.createBinary(
+								HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT),
+								input3, OpOp2.MULT));
+							operands.put(item.getId(), aggternary);
+							break;
+						}
+						case Unary:
+						case Builtin: {
+							Hop input = operands.get(item.getInputs()[0].getId());
+							Hop unary = HopRewriteUtils.createUnary(input, item.getOpcode());
+							operands.put(item.getId(), unary);
+							break;
+						}
+						case Reorg: {
+							operands.put(item.getId(), HopRewriteUtils.createReorg(
+								operands.get(item.getInputs()[0].getId()), item.getOpcode()));
+							break;
+						}
+						case Reshape: {
+							ArrayList<Hop> inputs = new ArrayList<>();
+							for(int i=0; i<5; i++) //the first five inputs are passed to the reshape hop
+								inputs.add(operands.get(item.getInputs()[i].getId()));
+							operands.put(item.getId(), HopRewriteUtils.createReorg(inputs, ReOrgOp.RESHAPE));
+							break;
+						}
+						case Binary: {
+							//handle special cases of binary operations: "^2" and "*2" map to "^" and "*"
+							String opcode = ("^2".equals(item.getOpcode()) 
+								|| "*2".equals(item.getOpcode())) ? 
+								item.getOpcode().substring(0, 1) : item.getOpcode();
+							Hop input1 = operands.get(item.getInputs()[0].getId());
+							Hop input2 = operands.get(item.getInputs()[1].getId());
+							Hop binary = HopRewriteUtils.createBinary(input1, input2, opcode);
+							operands.put(item.getId(), binary);
+							break;
+						}
+						case Ternary: {
+							operands.put(item.getId(), HopRewriteUtils.createTernary(
+								operands.get(item.getInputs()[0].getId()), 
+								operands.get(item.getInputs()[1].getId()), 
+								operands.get(item.getInputs()[2].getId()), item.getOpcode()));
+							break;
+						}
+						case Ctable: { //e.g., ctable; NOTE: other input counts than 3 or 5 add no operand
+							if( item.getInputs().length==3 )
+								operands.put(item.getId(), HopRewriteUtils.createTernary(
+									operands.get(item.getInputs()[0].getId()),
+									operands.get(item.getInputs()[1].getId()),
+									operands.get(item.getInputs()[2].getId()), OpOp3.CTABLE));
+							else if( item.getInputs().length==5 )
+								operands.put(item.getId(), HopRewriteUtils.createTernary(
+									operands.get(item.getInputs()[0].getId()),
+									operands.get(item.getInputs()[1].getId()),
+									operands.get(item.getInputs()[2].getId()),
+									operands.get(item.getInputs()[3].getId()),
+									operands.get(item.getInputs()[4].getId()), OpOp3.CTABLE));
+							break;
+						}
+						case BuiltinNary: {
+							String opcode = item.getOpcode().equals("n+") ? "plus" : item.getOpcode();
+							operands.put(item.getId(), HopRewriteUtils.createNary(
+								OpOpN.valueOf(opcode.toUpperCase()), createNaryInputs(item, operands)));
+							break;
+						}
+						case ParameterizedBuiltin: {
+							operands.put(item.getId(), constructParameterizedBuiltinOp(item, operands));
+							break;
+						}
+						case MatrixIndexing: {
+							operands.put(item.getId(), constructIndexingOp(item, operands));
+							break;
+						}
+						case MMTSJ: {
+							//TODO handling of tsmm type left and right -> placement transpose
+							Hop input = operands.get(item.getInputs()[0].getId());
+							Hop aggunary = HopRewriteUtils.createMatrixMultiply(
+								HopRewriteUtils.createTranspose(input), input);
+							operands.put(item.getId(), aggunary);
+							break;
+						}
+						case Variable: {
+							if( item.getOpcode().startsWith("cast") )
+								operands.put(item.getId(), HopRewriteUtils.createUnary(
+									operands.get(item.getInputs()[0].getId()),
+									OpOp1.valueOfByOpcode(item.getOpcode())));
+							else //cpvar, write
+								operands.put(item.getId(), operands.get(item.getInputs()[0].getId()));
+							break;
+						}
+						default:
+							throw new DMLRuntimeException("Unsupported instruction "
+								+ "type: " + ctype.name() + " (" + item.getOpcode() + ").");
+					}
+				}
+				else if( stype != null ) {
+					switch(stype) {
+						case Reblock: {
+							Hop input = operands.get(item.getInputs()[0].getId());
+							input.setBlocksize(ConfigurationManager.getBlocksize());
+							input.setRequiresReblock(true);
+							operands.put(item.getId(), input); //reuses the input hop, flagged for reblock above
+							break;
+						}
+						case Checkpoint: {
+							Hop input = operands.get(item.getInputs()[0].getId());
+							operands.put(item.getId(), input); //pass-through: no new hop is created
+							break;
+						}
+						case MatrixIndexing: {
+							operands.put(item.getId(), constructIndexingOp(item, operands));
+							break;
+						}
+						case GAppend: {
+							operands.put(item.getId(), HopRewriteUtils.createBinary(
+								operands.get(item.getInputs()[0].getId()),
+								operands.get(item.getInputs()[1].getId()), OpOp2.CBIND));
+							break;
+						}
+						default:
+							throw new DMLRuntimeException("Unsupported instruction "
+								+ "type: " + stype.name() + " (" + item.getOpcode() + ").");
+					}
+				}
+				else
+					throw new DMLRuntimeException("Unsupported instruction: " + item.getOpcode());
+				break;
+			}
+			case Literal: {
+				CPOperand op = new CPOperand(item.getData());
+				operands.put(item.getId(), ScalarObjectFactory
+					.createLiteralOp(op.getValueType(), op.getName()));
+				break;
+			}
+		}
+		
+		item.setVisited(); //later calls on this item return early at the isVisited check above
+	}
 	
 	// Below class represents a single loop and contains related data
 	// that are needed for recomputation.
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
index 18da399..3b1ae65 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
@@ -95,12 +95,11 @@ public class LineageTraceDedupTest extends AutomatedTestBase
 		testLineageTrace(TEST_NAME5);
 	}
 	
-	/*@Test
+	@Test
 	public void testLineageTrace6() {
 		testLineageTrace(TEST_NAME6);
-	}*/
-	//FIXME: stack overflow only when ran the full package
-	
+	}
+
 	@Test
 	public void testLineageTrace7() {
 		testLineageTrace(TEST_NAME7);