You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2021/08/06 12:18:14 UTC

[systemds] branch master updated: [SYSTEMDS-3084] Basic Auto-differentiation for Affine layer

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new c2c8864  [SYSTEMDS-3084] Basic Auto-differentiation for Affine layer
c2c8864 is described below

commit c2c88645ac45b4106979841bbd3261ab3cc30169
Author: Shafaq Siddiqi <sh...@tugraz.at>
AuthorDate: Fri Aug 6 14:14:31 2021 +0200

    [SYSTEMDS-3084] Basic Auto-differentiation for Affine layer
    
    This patch introduces autoDiff builtin. autoDiff takes the output
    and lineage trace of the last layer and a list of weights and
    biases, and returns their derivatives.
    It internally uses the lineage trace of the forward layer to
    construct the Hop dags for the derivatives (reuse common
    sub-dags), compile those and execute to produce the outputs.
    Current support is limited to Affine layer and local execution.
    
    AMLS project SS2021
    Closes #1350.
---
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 src/main/java/org/apache/sysds/common/Types.java   |   2 +-
 .../apache/sysds/hops/ParameterizedBuiltinOp.java  |   3 +-
 .../apache/sysds/lops/ParameterizedBuiltin.java    |   8 +-
 .../org/apache/sysds/parser/DMLTranslator.java     |   2 +-
 .../ParameterizedBuiltinFunctionExpression.java    |  23 +-
 .../sysds/runtime/functionobjects/Builtin.java     |   3 +-
 .../functionobjects/ParameterizedBuiltin.java      |   5 +-
 .../runtime/instructions/CPInstructionParser.java  |   3 +-
 .../runtime/instructions/SPInstructionParser.java  |   1 +
 .../cp/ParameterizedBuiltinCPInstruction.java      |  32 +--
 .../org/apache/sysds/runtime/util/AutoDiff.java    | 262 +++++++++++++++++++++
 .../sysds/test/functions/builtin/AutoDiffTest.java |  70 ++++++
 src/test/scripts/functions/builtin/autoDiff.dml    |  58 +++++
 14 files changed, 449 insertions(+), 24 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index a063b47..f2b6c6a 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -272,6 +272,7 @@ public enum Builtins {
 	XOR("xor", false),
 
 	//parameterized builtin functions
+	AUTODIFF("autoDiff", false, true),
 	CDF("cdf", false, true),
 	CVLM("cvlm", true, false),
 	GROUPEDAGG("aggregate", "groupedAggregate", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index da53091..d15adad 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -466,7 +466,7 @@ public class Types
 	}
 	
 	public enum ParamBuiltinOp {
-		INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND,
+		AUTODIFF, INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND,
 		LOWER_TRI, UPPER_TRI,
 		TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA,
 		TOKENIZE, TOSTRING, LIST, PARAMSERV
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index b5be675..8c9666b 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -196,7 +196,8 @@ public class ParameterizedBuiltinOp extends MultiThreadedHop {
 			case TRANSFORMMETA:
 			case TOSTRING:
 			case PARAMSERV:
-			case LIST: {
+			case LIST:
+			case AUTODIFF:{
 				ExecType et = optFindExecType();
 				ParameterizedBuiltin pbilop = new ParameterizedBuiltin(
 					inputlops, _op, getDataType(), getValueType(), et);
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
index 7c39548..a0f9331 100644
--- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
@@ -171,7 +171,7 @@ public class ParameterizedBuiltin extends Lop
 					
 					sb.append(OPERAND_DELIMITOR);
 				}
-				
+
 				break;
 
 			case TOKENIZE:
@@ -184,6 +184,12 @@ public class ParameterizedBuiltin extends Lop
 				sb.append(compileGenericParamMap(_inputParams));
 				break;
 			}
+			case AUTODIFF: {
+				sb.append("autoDiff"); //opcode
+				sb.append(OPERAND_DELIMITOR);
+				sb.append(compileGenericParamMap(_inputParams));
+				break;
+			}
 			case LIST: {
 				sb.append("nvlist"); //opcode
 				sb.append(OPERAND_DELIMITOR);
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 89fc4ca..430fe7f 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2008,6 +2008,7 @@ public class DMLTranslator
 			case TRANSFORMCOLMAP:
 			case TRANSFORMMETA:
 			case PARAMSERV:
+			case AUTODIFF:
 				currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(),
 					target.getValueType(), ParamBuiltinOp.valueOf(source.getOpCode().name()), paramHops);
 				break;
@@ -2029,7 +2030,6 @@ public class DMLTranslator
 						target.getValueType(), ParamBuiltinOp.TOSTRING, paramHops) :
 					HopRewriteUtils.createBinary(paramHops.get("target"), new LiteralOp(""), OpOp2.PLUS);
 				break;
-			
 			case LISTNV:
 				currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(),
 					target.getValueType(), ParamBuiltinOp.LIST, paramHops);
diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index d074d0d..26a54ac 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -48,12 +48,14 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 	public static final String TF_FN_PARAM_DATA = "target";
 	public static final String TF_FN_PARAM_MTD2 = "meta";
 	public static final String TF_FN_PARAM_SPEC = "spec";
+	public static final String LINEAGE_TRACE = "lineage";
 	public static final String TF_FN_PARAM_MTD = "transformPath"; //NOTE MB: for backwards compatibility
 	
 	public static HashMap<Builtins, ParamBuiltinOp> pbHopMap;
 	static {
 		pbHopMap = new HashMap<>();
 		
+		pbHopMap.put(Builtins.AUTODIFF, ParamBuiltinOp.AUTODIFF);
 		pbHopMap.put(Builtins.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG);
 		pbHopMap.put(Builtins.RMEMPTY, ParamBuiltinOp.RMEMPTY);
 		pbHopMap.put(Builtins.REPLACE, ParamBuiltinOp.REPLACE);
@@ -231,7 +233,10 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 		case TOSTRING:
 			validateCastAsString(output, conditional);
 			break;
-		
+
+		case AUTODIFF:
+			validateAutoDiff(output, conditional);
+			break;
 		case LISTNV:
 			validateNamedList(output, conditional);
 			break;
@@ -251,6 +256,22 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 		}
 	}
 
+	private void validateAutoDiff(DataIdentifier output, boolean conditional) {
+		//the 'lineage' parameter must be a list; NOTE(review): the first argument looks like it is meant to be the function name ("autoDiff") - confirm
+		checkDataType("lineage", LINEAGE_TRACE, DataType.LIST, conditional);
+
+		//check the same parameter again, now together with its value type (UNKNOWN); the 'output' parameter is not checked here
+		checkDataValueType(false, "lineage", LINEAGE_TRACE, DataType.LIST, ValueType.UNKNOWN, conditional);
+		HashMap<String, Expression> varParams = getVarParams();
+		// the result is a list; its row count is set to the number of named parameters passed to the call
+		output.setDataType(DataType.LIST);
+		output.setValueType(ValueType.UNKNOWN);
+		// TODO dimension should be set to -1 but could not be set due to a lineage parsing error in the Spark context
+		output.setDimensions(varParams.size(), 1);
+		// output.setDimensions(-1, 1);
+		output.setBlocksize(-1);
+	}
+
 	@Override
 	public void validateExpression(MultiAssignmentStatement stmt, HashMap<String, DataIdentifier> ids, HashMap<String, ConstIdentifier> constVars, boolean conditional)
 	{
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index 46acba4..45b6c33 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -47,7 +47,7 @@ public class Builtin extends ValueFunction
 {
 	private static final long serialVersionUID = 3836744687789840574L;
 	
-	public enum BuiltinCode { SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN,
+	public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN,
 		MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
 		STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
 		TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, DROP_INVALID_LENGTH, MAP,
@@ -61,6 +61,7 @@ public class Builtin extends ValueFunction
 	static public HashMap<String, BuiltinCode> String2BuiltinCode;
 	static {
 		String2BuiltinCode = new HashMap<>();
+		String2BuiltinCode.put( "autoDiff"    , BuiltinCode.AUTODIFF);
 		String2BuiltinCode.put( "sin"    , BuiltinCode.SIN);
 		String2BuiltinCode.put( "cos"    , BuiltinCode.COS);
 		String2BuiltinCode.put( "tan"    , BuiltinCode.TAN);
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java
index c15f6da..d800efd 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java
@@ -43,7 +43,7 @@ public class ParameterizedBuiltin extends ValueFunction
 	private static final long serialVersionUID = -7987603644903675052L;
 	
 	public enum ParameterizedBuiltinCode { 
-		CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI,
+		AUTODIFF, CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI,
 		TOKENIZE, TRANSFORMAPPLY, TRANSFORMDECODE, PARAMSERV }
 	public enum ProbabilityDistributionCode { 
 		INVALID, NORMAL, EXP, CHISQ, F, T }
@@ -185,6 +185,9 @@ public class ParameterizedBuiltin extends ValueFunction
 
 			case PARAMSERV:
 				return new ParameterizedBuiltin(ParameterizedBuiltinCode.PARAMSERV);
+
+			case AUTODIFF:
+				return new ParameterizedBuiltin(ParameterizedBuiltinCode.AUTODIFF);
 				
 			default:
 				throw new DMLRuntimeException("Invalid parameterized builtin code: " + code);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index b07cefa..323fac1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -160,7 +160,7 @@ public class CPInstructionParser extends InstructionParser
 		String2CPInstructionType.put( "nmax", CPType.BuiltinNary);
 		String2CPInstructionType.put( "nmin", CPType.BuiltinNary);
 		String2CPInstructionType.put( "n+"  , CPType.BuiltinNary);
-		
+
 		String2CPInstructionType.put( "exp"   , CPType.Unary);
 		String2CPInstructionType.put( "abs"   , CPType.Unary);
 		String2CPInstructionType.put( "sin"   , CPType.Unary);
@@ -203,6 +203,7 @@ public class CPInstructionParser extends InstructionParser
 		String2CPInstructionType.put( "list",   CPType.BuiltinNary);
 		
 		// Parameterized Builtin Functions
+		String2CPInstructionType.put( "autoDiff" , CPType.ParameterizedBuiltin);
 		String2CPInstructionType.put("paramserv",       CPType.ParameterizedBuiltin);
 		String2CPInstructionType.put( "nvlist",         CPType.ParameterizedBuiltin);
 		String2CPInstructionType.put( "cdf",            CPType.ParameterizedBuiltin);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index d1eb4a7..a02490f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -262,6 +262,7 @@ public class SPInstructionParser extends InstructionParser
 		String2SPInstructionType.put( "isinf", SPType.Unary);
 
 		// Parameterized Builtin Functions
+		String2SPInstructionType.put( "autoDiff"   , SPType.ParameterizedBuiltin);
 		String2SPInstructionType.put( "groupedagg",     SPType.ParameterizedBuiltin);
 		String2SPInstructionType.put( "mapgroupedagg",  SPType.ParameterizedBuiltin);
 		String2SPInstructionType.put( "rmempty",        SPType.ParameterizedBuiltin);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index f115b52..6de5878 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -19,14 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -37,12 +29,9 @@ import org.apache.sysds.lops.Lop;
 import org.apache.sysds.parser.ParameterizedBuiltinFunctionExpression;
 import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
-import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
+import org.apache.sysds.runtime.controlprogram.caching.*;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysds.runtime.data.TensorBlock;
 import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -61,8 +50,13 @@ import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
 import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
 import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
 import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
+import org.apache.sysds.runtime.util.AutoDiff;
 import org.apache.sysds.runtime.util.DataConverter;
 
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
 public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction {
 	private static final Log LOG = LogFactory.getLog(ParameterizedBuiltinCPInstruction.class.getName());
 	private static final int TOSTRING_MAXROWS = 100;
@@ -91,7 +85,6 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
 	public static LinkedHashMap<String, String> constructParameterMap(String[] params) {
 		// process all elements in "params" except first(opcode) and last(output)
 		LinkedHashMap<String, String> paramMap = new LinkedHashMap<>();
-
 		// all parameters are of form <name=value>
 		String[] parts;
 		for(int i = 1; i <= params.length - 2; i++) {
@@ -150,7 +143,7 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
 		}
 		else if(opcode.equals("transformapply") || opcode.equals("transformdecode") ||
 			opcode.equals("transformcolmap") || opcode.equals("transformmeta") || opcode.equals("tokenize") ||
-			opcode.equals("toString") || opcode.equals("nvlist")) {
+			opcode.equals("toString") || opcode.equals("nvlist") || opcode.equals("autoDiff")) {
 			return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str);
 		}
 		else if("paramserv".equals(opcode)) {
@@ -178,6 +171,13 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
 			sores = new DoubleObject(result);
 			ec.setScalarOutput(output.getName(), sores);
 		}
+		else if(opcode.equalsIgnoreCase("autoDiff"))
+		{
+			ArrayList<Data> lineage = (ArrayList<Data>) ec.getListObject(params.get("lineage")).getData();
+			MatrixObject mo = ec.getMatrixObject(params.get("output"));
+			ListObject diffs = AutoDiff.getBackward(mo, lineage, ExecutionContextFactory.createContext());
+			ec.setVariable(output.getName(), diffs);
+		}
 		else if(opcode.equalsIgnoreCase("groupedagg")) {
 			// acquire locks
 			MatrixBlock target = ec.getMatrixInput(params.get(Statement.GAGG_TARGET));
@@ -501,7 +501,7 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
 			return Pair.of(output.getName(),
 				new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, meta, spec)));
 		}
-		else if (opcode.equalsIgnoreCase("nvlist")) {
+		else if (opcode.equalsIgnoreCase("nvlist") || opcode.equalsIgnoreCase("autoDiff")) {
 			List<String> names = new ArrayList<>(params.keySet());
 			CPOperand[] listOperands = names.stream().map(n -> ec.containsVariable(params.get(n)) 
 					? new CPOperand(n, ec.getVariable(params.get(n))) 
diff --git a/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java b/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java
new file mode 100644
index 0000000..2178a13
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.util;
+
+import org.apache.commons.lang3.mutable.MutableInt;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.*;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionParser;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.*;
+import org.apache.sysds.runtime.instructions.spark.RandSPInstruction;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageParser;
+import org.apache.sysds.utils.Explain;
+
+import java.util.*;
+
+public class AutoDiff {
+	private static final String ADVARPREFIX = "adVar";
+	private static final boolean DEBUG = false;
+
+	public static ListObject getBackward(MatrixObject mo, ArrayList<Data> lineage, ExecutionContext adec) {
+		// mo: matrix passed as 'output'; lineage: list whose first entry is the lineage trace; adec: context used for execution
+		if (lineage == null || lineage.isEmpty())
+			throw new DMLRuntimeException("autoDiff: the lineage list must contain the lineage trace.");
+		// names of the derivatives, filled while the HOP DAGs are constructed
+		ArrayList<String> names = new ArrayList<>();
+		// strip the "foo" marker that scripts append to keep the trace from being parsed as instructions
+		String lin = lineage.get(0).toString().replace("foo", "");
+		return new ListObject(parseNComputeAutoDiffFromLineage(mo, lin, names, adec), names);
+	}
+
+	public static List<Data> parseNComputeAutoDiffFromLineage(MatrixObject mo, String mainTrace,
+		ArrayList<String> names, ExecutionContext ec ) {
+		// Parses mainTrace, builds one derivative HOP DAG per entry added to names, runs them on ec and returns the results.
+		LineageItem root = LineageParser.parseLineageTrace(mainTrace);
+		if (DEBUG) {
+			System.out.println("Lineage trace of the forward pass");
+			System.out.println(mainTrace);
+		}
+		// clear the visited flags that the traversal in constructHopsNR relies on
+		root.resetVisitStatusNR();
+		Map<Long, Hop> operands = new HashMap<>();
+		// bind mo to the variable "X" and feed it into the DAGs through a transient read
+		ec.setVariable("X", mo);
+		DataOp input = HopRewriteUtils.createTransientRead("X", mo);
+		// every differentiated instruction contributes its own hop(s); allHops and names are filled in the same order
+		ArrayList<Hop> allHops = constructHopsNR(root, operands, input, names);
+		ArrayList<Data> results = new ArrayList<>();
+		for(int i=0; i< allHops.size(); i++) {
+			String outName = ADVARPREFIX + i; // one name for both the transient write and the lookup below
+			DataOp dop = HopRewriteUtils.createTransientWrite(outName, allHops.get(i));
+			ArrayList<Instruction> dInst = Recompiler
+				.recompileHopsDag(dop, ec.getVariables(), null, true, true, 0);
+			if (DEBUG) {
+				System.out.println("HOP Dag and instructions for " + names.get(i));
+				System.out.println(Explain.explain(dop));
+				System.out.println(Explain.explain(dInst));
+			}
+			// execute the compiled derivative instructions and collect the variable they wrote
+			executeInst(dInst, ec);
+			results.add(ec.getVariable(outName));
+		}
+		return results;
+	}
+
+	public static ArrayList<Hop> constructHopsNR(LineageItem item, Map<Long, Hop> operands, Hop mo, ArrayList<String> names)
+	{
+		// Iterative walk over the lineage DAG rooted at item: an item is handed
+		// to constructSingleHop only after all of its inputs were pushed and
+		// unwound, and the visited flag keeps shared sub-DAGs from being redone.
+		ArrayList<Hop> derivatives = new ArrayList<>();
+		Deque<LineageItem> pendingItems = new ArrayDeque<>();
+		Deque<MutableInt> nextInputPos = new ArrayDeque<>();
+		pendingItems.push(item); nextInputPos.push(new MutableInt(0));
+		while (!pendingItems.isEmpty()) {
+			LineageItem current = pendingItems.peek();
+			MutableInt pos = nextInputPos.peek();
+			if (current.isVisited()) {
+				// already handled through another parent: unwind without processing
+				pendingItems.pop(); nextInputPos.pop();
+				continue;
+			}
+			LineageItem[] inputs = current.getInputs();
+			if (inputs == null || inputs.length <= pos.intValue()) {
+				// leaf, or all inputs done: build the hop(s) for this item, then unwind
+				constructSingleHop(current, operands, mo, derivatives, names);
+				pendingItems.pop(); nextInputPos.pop();
+				current.setVisited();
+			}
+			else {
+				// descend into the next unprocessed input and remember how far we got
+				pendingItems.push(inputs[pos.intValue()]);
+				pos.increment();
+				nextInputPos.push(new MutableInt(0));
+			}
+		}
+		return derivatives;
+	}
+
+	private static void constructSingleHop(LineageItem item, Map<Long, Hop> operands, Hop mo,
+		ArrayList<Hop> allHops, ArrayList<String> names)
+	{
+		// Translates one lineage item: inputs are registered in operands, supported instructions append derivative hops to allHops/names
+		switch (item.getType()) {
+			case Creation: {
+				if(item.getData().startsWith(ADVARPREFIX)) {
+					// the placeholder id follows the complete prefix (a fixed offset of 3 would cut into the prefix itself)
+					long phId = Long.parseLong(item.getData().substring(ADVARPREFIX.length()));
+					Hop input = operands.remove(phId);
+					// re-register the hop of the placeholder under the id of this item
+					operands.put(item.getId(), input); // order preserving
+					break;
+				}
+				Instruction inst = InstructionParser.parseSingleInstruction(item.getData());
+				// rebuild the hop that created this input: CP datagen, createvar (persistent read) or Spark rand
+				if(inst instanceof DataGenCPInstruction) {
+					DataGenCPInstruction rand = (DataGenCPInstruction) inst;
+					HashMap<String, Hop> params = new HashMap<>();
+					if(rand.getOpcode().equals("rand")) {
+						if(rand.output.getDataType() == Types.DataType.TENSOR)
+							params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+						else {
+							params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+							params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+						}
+						params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+						params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+						params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+						params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+						params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+						params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+					}
+					Hop datagen = new DataGenOp(Types.OpOpDG.valueOf(rand.getOpcode().toUpperCase()),
+						new DataIdentifier("tmp"), params);
+					datagen.setBlocksize(rand.getBlocksize());
+					operands.put(item.getId(), datagen);
+				}
+				else if(inst instanceof VariableCPInstruction && ((VariableCPInstruction) inst).isCreateVariable()) {
+					String parts[] = InstructionUtils.getInstructionPartsWithValueType(inst.toString());
+					Types.DataType dt = Types.DataType.valueOf(parts[4]);
+					Types.ValueType vt = dt == Types.DataType.MATRIX ? Types.ValueType.FP64 : Types.ValueType.STRING;
+					HashMap<String, Hop> params = new HashMap<>();
+					params.put(DataExpression.IO_FILENAME, new LiteralOp(parts[2]));
+					params.put(DataExpression.READROWPARAM, new LiteralOp(Long.parseLong(parts[6])));
+					params.put(DataExpression.READCOLPARAM, new LiteralOp(Long.parseLong(parts[7])));
+					params.put(DataExpression.READNNZPARAM, new LiteralOp(Long.parseLong(parts[8])));
+					params.put(DataExpression.FORMAT_TYPE, new LiteralOp(parts[5]));
+					DataOp pread = new DataOp(parts[1].substring(5), dt, vt, Types.OpOpData.PERSISTENTREAD, params);
+					pread.setFileName(parts[2]);
+					operands.put(item.getId(), pread);
+				}
+				else if(inst instanceof RandSPInstruction) {
+					RandSPInstruction rand = (RandSPInstruction) inst;
+					HashMap<String, Hop> params = new HashMap<>();
+					if(rand.output.getDataType() == Types.DataType.TENSOR)
+						params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+					else {
+						params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+						params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+					}
+					params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+					params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+					params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+					params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+					params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+					params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+					Hop datagen = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("tmp"), params);
+					datagen.setBlocksize(rand.getBlocksize());
+					operands.put(item.getId(), datagen);
+				}
+				break;
+			}
+			case Instruction: {
+				CPInstruction.CPType ctype = InstructionUtils.getCPTypeByOpcode(item.getOpcode());
+				// opcodes without a CP instruction type are skipped without producing a derivative
+				if(ctype != null) {
+					switch(ctype) {
+						case AggregateBinary: {
+							Hop input1 = operands.get(item.getInputs()[0].getId());
+							Hop input2 = operands.get(item.getInputs()[1].getId());
+							// for input1 %*% input2 with mo as incoming gradient: dX = mo %*% t(input2), dW = t(input1) %*% mo
+							ReorgOp transposedX = HopRewriteUtils.createTranspose(input1);
+							ReorgOp transposedW = HopRewriteUtils.createTranspose(input2);
+							Hop dX = HopRewriteUtils.createMatrixMultiply(mo, transposedW);
+							Hop dW = HopRewriteUtils.createMatrixMultiply(transposedX, mo);
+							operands.put(item.getId(), dX);
+							operands.put(item.getId() + 1, dW); // NOTE(review): id+1 may be the id of another lineage item - confirm
+							allHops.add(dX);
+							allHops.add(dW);
+							names.add("dX");
+							names.add("dW");
+							break;
+						}
+						case Binary: {
+							// only the addition (bias of the affine layer) is supported: dB is the column-wise sum of mo;
+							// fail early instead of registering a null hop under the name dB
+							if(!"+".equals(item.getOpcode()))
+								throw new DMLRuntimeException("Unsupported autoDiff binary operation: " + item.getOpcode());
+							Hop output = HopRewriteUtils.createAggUnaryOp(mo, Types.AggOp.SUM, Types.Direction.Col);
+							operands.put(item.getId(), output);
+							allHops.add(output);
+							names.add("dB");
+							break;
+						}
+						default:
+							throw new DMLRuntimeException(
+								"Unsupported autoDiff instruction " + "type: " + ctype.name() + " (" + item.getOpcode() + ").");
+					}
+				}
+				break;
+			}
+			case Literal: {
+				CPOperand op = new CPOperand(item.getData());
+				operands.put(item.getId(), ScalarObjectFactory
+					.createLiteralOp(op.getValueType(), op.getName()));
+				break;
+			}
+			default:
+				throw new DMLRuntimeException("Lineage type " + item.getType() + " is not supported");
+		}
+	}
+	private static void executeInst(ArrayList<Instruction> newInst, ExecutionContext lrwec)
+	{
+		// Runs newInst as the single basic block of a fresh program on lrwec; any failure is rethrown with its cause.
+		try {
+			BasicProgramBlock block = new BasicProgramBlock(new Program());
+			block.setInstructions(newInst);
+			block.execute(lrwec);
+		}
+		catch (Exception failure) {
+			throw new DMLRuntimeException("Error executing autoDiff instruction" , failure);
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/AutoDiffTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/AutoDiffTest.java
new file mode 100644
index 0000000..ab4a373
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/AutoDiffTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+// Runs autoDiff.dml with lineage tracing and compares dX of affine::backward with the dX returned by autoDiff.
+public class AutoDiffTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "autoDiff";
+	private final static String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + AutoDiffTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
+	}
+
+	@Test
+	public void testAutoDiffCP1() {
+		runAutoDiffTest(Types.ExecType.CP);
+	}
+	// Executes the script for the given execution type and restores all global settings it changed.
+	private void runAutoDiffTest(Types.ExecType instType)
+	{
+		ExecMode platformOld = setExecMode(instType);
+		boolean ipaOld = OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS;
+		try {
+			OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = false;
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			// $1 receives dX of the backward script, $2 the dX computed by autoDiff
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[]{"-lineage", "-args", output("dX"), output("ad_dX")};
+			runTest(true, false, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> dml_dX = readDMLMatrixFromOutputDir("dX");
+			HashMap<MatrixValue.CellIndex, Double> autoDiff_dX = readDMLMatrixFromOutputDir("ad_dX");
+			TestUtils.compareMatrices(dml_dX, autoDiff_dX, 1e-6, "Stat-DML", "Stat-AutoDiff");
+		}
+		finally {
+			// the IPA flag is a global static: restore it so later tests are not affected
+			OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = ipaOld;
+			rtplatform = platformOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/autoDiff.dml b/src/test/scripts/functions/builtin/autoDiff.dml
new file mode 100644
index 0000000..26922dc
--- /dev/null
+++ b/src/test/scripts/functions/builtin/autoDiff.dml
@@ -0,0 +1,58 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+source("nn/layers/affine.dml") as affine
+
+# Compares the derivatives returned by autoDiff with those of the hand-written affine::backward.
+# Inputs of a single affine layer; NOTE(review): the earlier note here mentioned parsing issues with rand inside the lineage trace
+M = 5; N = 5
+X_batch = rand(rows=M, cols=N, sparsity=1)
+# NOTE(review): b_1 is sized with cols=M while W_1 has N columns; presumably this only lines up because M == N
+W_1 = rand(rows=M, cols=N, sparsity=1)
+b_1 = matrix(0, rows=1, cols=M)
+
+prob = affine::forward(X_batch, W_1, b_1)
+lin = lineage(prob)
+
+# TODO stop the instruction parser from parsing the lineage string;
+# for now this is prevented by appending the marker string "foo" as a work-around
+if(sum(prob) > 0)
+  lin = lin+"foo"
+
+# Even with the "foo" marker appended the compiler keeps parsing the lineage string,
+# so the trace is additionally wrapped in a list before it is passed on.
+# Compute the derivatives from the lineage trace; prob is passed as the output argument.
+diffs = autoDiff(output=prob, lineage=list(lin));
+
+ad_dX = as.matrix(diffs['dX'])
+ad_dW = as.matrix(diffs['dW'])
+ad_dB = as.matrix(diffs['dB'])
+
+# Reference derivatives from the backward script of the affine layer (prob is passed as its first argument)
+[dX, dW, dB] = affine::backward(prob, X_batch, W_1, b_1)
+# despite their names, these hold the cell-wise inequality (!=) between the reference and the autoDiff results
+sameX = dX != ad_dX
+sameW = dW != ad_dW
+sameB = dB != ad_dB
+# NOTE: output is computed but not used afterwards; only dX and ad_dX are written below
+output = ((sum(sameX) == 0) & (sum(sameW) == 0) & (sum(sameB) == 0))
+
+write(dX, $1)
+write(ad_dX, $2)
\ No newline at end of file