You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2021/08/06 12:18:14 UTC
[systemds] branch master updated: [SYSTEMDS-3084] Basic
Auto-differentiation for Affine layer
This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new c2c8864 [SYSTEMDS-3084] Basic Auto-differentiation for Affine layer
c2c8864 is described below
commit c2c88645ac45b4106979841bbd3261ab3cc30169
Author: Shafaq Siddiqi <sh...@tugraz.at>
AuthorDate: Fri Aug 6 14:14:31 2021 +0200
[SYSTEMDS-3084] Basic Auto-differentiation for Affine layer
This patch introduces autoDiff builtin. autoDiff takes the output
and lineage trace of the last layer and a list of weights and
biases, and returns their derivatives.
It internally uses the lineage trace of the forward layer to
construct the Hop dags for the derivatives (reuse common
sub-dags), compile those and execute to produce the outputs.
Current support is limited to Affine layer and local execution.
AMLS project SS2021
Closes #1350.
---
.../java/org/apache/sysds/common/Builtins.java | 1 +
src/main/java/org/apache/sysds/common/Types.java | 2 +-
.../apache/sysds/hops/ParameterizedBuiltinOp.java | 3 +-
.../apache/sysds/lops/ParameterizedBuiltin.java | 8 +-
.../org/apache/sysds/parser/DMLTranslator.java | 2 +-
.../ParameterizedBuiltinFunctionExpression.java | 23 +-
.../sysds/runtime/functionobjects/Builtin.java | 3 +-
.../functionobjects/ParameterizedBuiltin.java | 5 +-
.../runtime/instructions/CPInstructionParser.java | 3 +-
.../runtime/instructions/SPInstructionParser.java | 1 +
.../cp/ParameterizedBuiltinCPInstruction.java | 32 +--
.../org/apache/sysds/runtime/util/AutoDiff.java | 262 +++++++++++++++++++++
.../sysds/test/functions/builtin/AutoDiffTest.java | 70 ++++++
src/test/scripts/functions/builtin/autoDiff.dml | 58 +++++
14 files changed, 449 insertions(+), 24 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index a063b47..f2b6c6a 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -272,6 +272,7 @@ public enum Builtins {
XOR("xor", false),
//parameterized builtin functions
+ AUTODIFF("autoDiff", false, true),
CDF("cdf", false, true),
CVLM("cvlm", true, false),
GROUPEDAGG("aggregate", "groupedAggregate", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index da53091..d15adad 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -466,7 +466,7 @@ public class Types
}
public enum ParamBuiltinOp {
- INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND,
+ AUTODIFF, INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND,
LOWER_TRI, UPPER_TRI,
TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA,
TOKENIZE, TOSTRING, LIST, PARAMSERV
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index b5be675..8c9666b 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -196,7 +196,8 @@ public class ParameterizedBuiltinOp extends MultiThreadedHop {
case TRANSFORMMETA:
case TOSTRING:
case PARAMSERV:
- case LIST: {
+ case LIST:
+ case AUTODIFF:{
ExecType et = optFindExecType();
ParameterizedBuiltin pbilop = new ParameterizedBuiltin(
inputlops, _op, getDataType(), getValueType(), et);
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
index 7c39548..a0f9331 100644
--- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
@@ -171,7 +171,7 @@ public class ParameterizedBuiltin extends Lop
sb.append(OPERAND_DELIMITOR);
}
-
+
break;
case TOKENIZE:
@@ -184,6 +184,12 @@ public class ParameterizedBuiltin extends Lop
sb.append(compileGenericParamMap(_inputParams));
break;
}
+ case AUTODIFF: {
+ sb.append("autoDiff"); //opcode
+ sb.append(OPERAND_DELIMITOR);
+ sb.append(compileGenericParamMap(_inputParams));
+ break;
+ }
case LIST: {
sb.append("nvlist"); //opcode
sb.append(OPERAND_DELIMITOR);
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 89fc4ca..430fe7f 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2008,6 +2008,7 @@ public class DMLTranslator
case TRANSFORMCOLMAP:
case TRANSFORMMETA:
case PARAMSERV:
+ case AUTODIFF:
currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(),
target.getValueType(), ParamBuiltinOp.valueOf(source.getOpCode().name()), paramHops);
break;
@@ -2029,7 +2030,6 @@ public class DMLTranslator
target.getValueType(), ParamBuiltinOp.TOSTRING, paramHops) :
HopRewriteUtils.createBinary(paramHops.get("target"), new LiteralOp(""), OpOp2.PLUS);
break;
-
case LISTNV:
currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(),
target.getValueType(), ParamBuiltinOp.LIST, paramHops);
diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index d074d0d..26a54ac 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -48,12 +48,14 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
public static final String TF_FN_PARAM_DATA = "target";
public static final String TF_FN_PARAM_MTD2 = "meta";
public static final String TF_FN_PARAM_SPEC = "spec";
+ public static final String LINEAGE_TRACE = "lineage";
public static final String TF_FN_PARAM_MTD = "transformPath"; //NOTE MB: for backwards compatibility
public static HashMap<Builtins, ParamBuiltinOp> pbHopMap;
static {
pbHopMap = new HashMap<>();
+ pbHopMap.put(Builtins.AUTODIFF, ParamBuiltinOp.AUTODIFF);
pbHopMap.put(Builtins.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG);
pbHopMap.put(Builtins.RMEMPTY, ParamBuiltinOp.RMEMPTY);
pbHopMap.put(Builtins.REPLACE, ParamBuiltinOp.REPLACE);
@@ -231,7 +233,10 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
case TOSTRING:
validateCastAsString(output, conditional);
break;
-
+
+ case AUTODIFF:
+ validateAutoDiff(output, conditional);
+ break;
case LISTNV:
validateNamedList(output, conditional);
break;
@@ -251,6 +256,22 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
}
}
+ private void validateAutoDiff(DataIdentifier output, boolean conditional) {
+ // the "lineage" parameter must be of data type LIST
+ checkDataType("lineage", LINEAGE_TRACE, DataType.LIST, conditional);
+ // NOTE(review): the first argument is presumably the function name used in error messages -- confirm whether it should be "autoDiff"
+ // presumably repeats the LIST check while also requiring value type UNKNOWN -- TODO confirm against checkDataValueType
+ checkDataValueType(false, "lineage", LINEAGE_TRACE, DataType.LIST, ValueType.UNKNOWN, conditional);
+ HashMap<String, Expression> varParams = getVarParams();
+ // set output characteristics
+ output.setDataType(DataType.LIST);
+ output.setValueType(ValueType.UNKNOWN);
+ // TODO dimension should be set to -1 but could not be set due to a lineage parsing error in the Spark context
+ output.setDimensions(varParams.size(), 1); // rows = number of entries in the parameter map
+ // output.setDimensions(-1, 1);
+ output.setBlocksize(-1);
+ }
+
@Override
public void validateExpression(MultiAssignmentStatement stmt, HashMap<String, DataIdentifier> ids, HashMap<String, ConstIdentifier> constVars, boolean conditional)
{
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index 46acba4..45b6c33 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -47,7 +47,7 @@ public class Builtin extends ValueFunction
{
private static final long serialVersionUID = 3836744687789840574L;
- public enum BuiltinCode { SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN,
+ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN,
MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, DROP_INVALID_LENGTH, MAP,
@@ -61,6 +61,7 @@ public class Builtin extends ValueFunction
static public HashMap<String, BuiltinCode> String2BuiltinCode;
static {
String2BuiltinCode = new HashMap<>();
+ String2BuiltinCode.put( "autoDiff" , BuiltinCode.AUTODIFF);
String2BuiltinCode.put( "sin" , BuiltinCode.SIN);
String2BuiltinCode.put( "cos" , BuiltinCode.COS);
String2BuiltinCode.put( "tan" , BuiltinCode.TAN);
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java
index c15f6da..d800efd 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/ParameterizedBuiltin.java
@@ -43,7 +43,7 @@ public class ParameterizedBuiltin extends ValueFunction
private static final long serialVersionUID = -7987603644903675052L;
public enum ParameterizedBuiltinCode {
- CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI,
+ AUTODIFF, CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI,
TOKENIZE, TRANSFORMAPPLY, TRANSFORMDECODE, PARAMSERV }
public enum ProbabilityDistributionCode {
INVALID, NORMAL, EXP, CHISQ, F, T }
@@ -185,6 +185,9 @@ public class ParameterizedBuiltin extends ValueFunction
case PARAMSERV:
return new ParameterizedBuiltin(ParameterizedBuiltinCode.PARAMSERV);
+
+ case AUTODIFF:
+ return new ParameterizedBuiltin(ParameterizedBuiltinCode.AUTODIFF);
default:
throw new DMLRuntimeException("Invalid parameterized builtin code: " + code);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index b07cefa..323fac1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -160,7 +160,7 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "nmax", CPType.BuiltinNary);
String2CPInstructionType.put( "nmin", CPType.BuiltinNary);
String2CPInstructionType.put( "n+" , CPType.BuiltinNary);
-
+
String2CPInstructionType.put( "exp" , CPType.Unary);
String2CPInstructionType.put( "abs" , CPType.Unary);
String2CPInstructionType.put( "sin" , CPType.Unary);
@@ -203,6 +203,7 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "list", CPType.BuiltinNary);
// Parameterized Builtin Functions
+ String2CPInstructionType.put( "autoDiff" , CPType.ParameterizedBuiltin);
String2CPInstructionType.put("paramserv", CPType.ParameterizedBuiltin);
String2CPInstructionType.put( "nvlist", CPType.ParameterizedBuiltin);
String2CPInstructionType.put( "cdf", CPType.ParameterizedBuiltin);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index d1eb4a7..a02490f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -262,6 +262,7 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "isinf", SPType.Unary);
// Parameterized Builtin Functions
+ String2SPInstructionType.put( "autoDiff" , SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "groupedagg", SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "mapgroupedagg", SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "rmempty", SPType.ParameterizedBuiltin);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index f115b52..6de5878 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -19,14 +19,6 @@
package org.apache.sysds.runtime.instructions.cp;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -37,12 +29,9 @@ import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.ParameterizedBuiltinFunctionExpression;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
-import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
+import org.apache.sysds.runtime.controlprogram.caching.*;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -61,8 +50,13 @@ import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
+import org.apache.sysds.runtime.util.AutoDiff;
import org.apache.sysds.runtime.util.DataConverter;
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction {
private static final Log LOG = LogFactory.getLog(ParameterizedBuiltinCPInstruction.class.getName());
private static final int TOSTRING_MAXROWS = 100;
@@ -91,7 +85,6 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
public static LinkedHashMap<String, String> constructParameterMap(String[] params) {
// process all elements in "params" except first(opcode) and last(output)
LinkedHashMap<String, String> paramMap = new LinkedHashMap<>();
-
// all parameters are of form <name=value>
String[] parts;
for(int i = 1; i <= params.length - 2; i++) {
@@ -150,7 +143,7 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
}
else if(opcode.equals("transformapply") || opcode.equals("transformdecode") ||
opcode.equals("transformcolmap") || opcode.equals("transformmeta") || opcode.equals("tokenize") ||
- opcode.equals("toString") || opcode.equals("nvlist")) {
+ opcode.equals("toString") || opcode.equals("nvlist") || opcode.equals("autoDiff")) {
return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str);
}
else if("paramserv".equals(opcode)) {
@@ -178,6 +171,13 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
sores = new DoubleObject(result);
ec.setScalarOutput(output.getName(), sores);
}
+ else if(opcode.equalsIgnoreCase("autoDiff"))
+ {
+ ArrayList<Data> lineage = (ArrayList<Data>) ec.getListObject(params.get("lineage")).getData();
+ MatrixObject mo = ec.getMatrixObject(params.get("output"));
+ ListObject diffs = AutoDiff.getBackward(mo, lineage, ExecutionContextFactory.createContext());
+ ec.setVariable(output.getName(), diffs);
+ }
else if(opcode.equalsIgnoreCase("groupedagg")) {
// acquire locks
MatrixBlock target = ec.getMatrixInput(params.get(Statement.GAGG_TARGET));
@@ -501,7 +501,7 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
return Pair.of(output.getName(),
new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, meta, spec)));
}
- else if (opcode.equalsIgnoreCase("nvlist")) {
+ else if (opcode.equalsIgnoreCase("nvlist") || opcode.equalsIgnoreCase("autoDiff")) {
List<String> names = new ArrayList<>(params.keySet());
CPOperand[] listOperands = names.stream().map(n -> ec.containsVariable(params.get(n))
? new CPOperand(n, ec.getVariable(params.get(n)))
diff --git a/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java b/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java
new file mode 100644
index 0000000..2178a13
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/util/AutoDiff.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.util;
+
+import org.apache.commons.lang3.mutable.MutableInt;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.*;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionParser;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.*;
+import org.apache.sysds.runtime.instructions.spark.RandSPInstruction;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageParser;
+import org.apache.sysds.utils.Explain;
+
+import java.util.*;
+
+public class AutoDiff {
+ private static final String ADVARPREFIX = "adVar";
+ private static final boolean DEBUG = false;
+
+ public static ListObject getBackward(MatrixObject mo, ArrayList<Data> lineage, ExecutionContext adec) {
+ // Entry point: computes the derivatives from the lineage trace and returns them as a named list.
+ ArrayList<String> names = new ArrayList<String>();
+ // only the first list element is read; its string form is used as the lineage trace
+ String lin = lineage.get(0).toString();
+ // drop the "foo" marker; note that replace() removes every occurrence of "foo", not only a trailing one
+ lin = lin.replace("foo", "");
+ List<Data> data = parseNComputeAutoDiffFromLineage(mo, lin, names, adec);
+ return new ListObject(data, names);
+ }
+
+ public static List<Data> parseNComputeAutoDiffFromLineage(MatrixObject mo, String mainTrace,
+ ArrayList<String> names, ExecutionContext ec ) {
+ // Parses the trace, builds one derivative hop DAG per entry of allHops, then compiles and executes each DAG.
+ LineageItem root = LineageParser.parseLineageTrace(mainTrace);
+ if (DEBUG) {
+ System.out.println("Lineage trace of the forward pass");
+ System.out.println(mainTrace);
+ }
+ // clear the visit status, which constructHopsNR relies on during its non-recursive traversal
+ root.resetVisitStatusNR();
+ Map<Long, Hop> operands = new HashMap<>();
+ // bind mo to the variable name "X" and expose it to the hops through a transient read
+ ec.setVariable("X", mo);
+ DataOp input = HopRewriteUtils.createTransientRead("X", mo);
+ // each derivative hop is stored separately; names receives one entry per derivative in the same order
+ ArrayList<Hop> allHops = constructHopsNR(root, operands, input, names);
+ // compile and run every derivative DAG; derivative i is written to and fetched from variable "advar"+i
+ ArrayList<Data> results = new ArrayList<>();
+ for(int i=0; i< allHops.size(); i++) {
+ DataOp dop = HopRewriteUtils.createTransientWrite("advar"+i, allHops.get(i));
+ ArrayList<Instruction> dInst = Recompiler
+ .recompileHopsDag(dop, ec.getVariables(), null, true, true, 0);
+ if (DEBUG) {
+ System.out.println("HOP Dag and instructions for " + names.get(i));
+ System.out.println(Explain.explain(dop));
+ System.out.println(Explain.explain(dInst));
+ }
+ // execute the compiled instructions of this derivative
+ executeInst(dInst, ec);
+ results.add(ec.getVariable("advar"+i));
+ }
+ return results;
+ }
+
+ public static ArrayList<Hop> constructHopsNR(LineageItem item, Map<Long, Hop> operands, Hop mo, ArrayList<String> names)
+ {
+ // Hop dags for the derivatives share common sub-dags with the lineage
+ // dag of the forward pass. This method walks the lineage dag without
+ // recursion (two explicit stacks) and calls constructSingleHop for an
+ // item once all of its inputs were handled; returns the derivative hops.
+ ArrayList<Hop> allHops = new ArrayList<>();
+ Stack<LineageItem> stackItem = new Stack<>(); // items on the current path
+ Stack<MutableInt> stackPos = new Stack<>(); // per item: index of the next input to descend into
+ stackItem.push(item); stackPos.push(new MutableInt(0));
+ while (!stackItem.empty()) {
+ LineageItem tmpItem = stackItem.peek();
+ MutableInt tmpPos = stackPos.peek();
+ // check ascent condition - item already visited, no item processing
+ if (tmpItem.isVisited()) {
+ stackItem.pop(); stackPos.pop();
+ }
+ // check ascent condition - no inputs or all inputs consumed, append item
+ else if( tmpItem.getInputs() == null
+ || tmpItem.getInputs().length <= tmpPos.intValue() ) {
+ constructSingleHop(tmpItem, operands, mo, allHops, names);
+ stackItem.pop(); stackPos.pop();
+ tmpItem.setVisited();
+ }
+ // check descent condition - push the next unprocessed input and advance the position
+ else if( tmpItem.getInputs() != null ) {
+ stackItem.push(tmpItem.getInputs()[tmpPos.intValue()]);
+ tmpPos.increment();
+ stackPos.push(new MutableInt(0));
+ }
+ }
+ return allHops;
+ }
+
+	// Translates one lineage item into hops: Creation and Literal items register an operand hop,
+	// supported Instruction items append derivative hops to allHops and matching names to names.
+	private static void constructSingleHop(LineageItem item, Map<Long, Hop> operands, Hop mo,
+		ArrayList<Hop> allHops, ArrayList<String> names)
+	{
+		//process current lineage item
+		switch (item.getType()) {
+			case Creation: {
+				if(item.getData().startsWith(ADVARPREFIX)) {
+					// the placeholder id follows the prefix, so skip the full prefix length (not a fixed 3 chars)
+					long phId = Long.parseLong(item.getData().substring(ADVARPREFIX.length()));
+					Hop input = operands.get(phId);
+					operands.remove(phId);
+					operands.put(item.getId(), input); // re-register the placeholder under this item's id, order preserving
+					break;
+				}
+				Instruction inst = InstructionParser.parseSingleInstruction(item.getData());
+				if(inst instanceof DataGenCPInstruction) {
+					DataGenCPInstruction rand = (DataGenCPInstruction) inst;
+					HashMap<String, Hop> params = new HashMap<>();
+					if(rand.getOpcode().equals("rand")) {
+						if(rand.output.getDataType() == Types.DataType.TENSOR)
+							params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+						else {
+							params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+							params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+						}
+						params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+						params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+						params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+						params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+						params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+						params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+					}
+					Hop datagen = new DataGenOp(Types.OpOpDG.valueOf(rand.getOpcode().toUpperCase()),
+						new DataIdentifier("tmp"), params);
+					datagen.setBlocksize(rand.getBlocksize());
+					operands.put(item.getId(), datagen);
+				}
+				else if(inst instanceof VariableCPInstruction && ((VariableCPInstruction) inst).isCreateVariable()) {
+					String parts[] = InstructionUtils.getInstructionPartsWithValueType(inst.toString());
+					Types.DataType dt = Types.DataType.valueOf(parts[4]);
+					Types.ValueType vt = dt == Types.DataType.MATRIX ? Types.ValueType.FP64 : Types.ValueType.STRING;
+					HashMap<String, Hop> params = new HashMap<>();
+					params.put(DataExpression.IO_FILENAME, new LiteralOp(parts[2]));
+					params.put(DataExpression.READROWPARAM, new LiteralOp(Long.parseLong(parts[6])));
+					params.put(DataExpression.READCOLPARAM, new LiteralOp(Long.parseLong(parts[7])));
+					params.put(DataExpression.READNNZPARAM, new LiteralOp(Long.parseLong(parts[8])));
+					params.put(DataExpression.FORMAT_TYPE, new LiteralOp(parts[5]));
+					DataOp pread = new DataOp(parts[1].substring(5), dt, vt, Types.OpOpData.PERSISTENTREAD, params);
+					pread.setFileName(parts[2]);
+					operands.put(item.getId(), pread);
+				}
+				else if(inst instanceof RandSPInstruction) {
+					RandSPInstruction rand = (RandSPInstruction) inst;
+					HashMap<String, Hop> params = new HashMap<>();
+					if(rand.output.getDataType() == Types.DataType.TENSOR)
+						params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+					else {
+						params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+						params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+					}
+					params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+					params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+					params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+					params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+					params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+					params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+					Hop datagen = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("tmp"), params);
+					datagen.setBlocksize(rand.getBlocksize());
+					operands.put(item.getId(), datagen);
+				}
+				break;
+			}
+			case Instruction: {
+				CPInstruction.CPType ctype = InstructionUtils.getCPTypeByOpcode(item.getOpcode());
+				if(ctype != null) {
+					switch(ctype) {
+						case AggregateBinary: {
+							Hop input1 = operands.get(item.getInputs()[0].getId());
+							Hop input2 = operands.get(item.getInputs()[1].getId());
+							//Build the hops for the derivatives: dX = mo %*% t(input2), dW = t(input1) %*% mo
+							ReorgOp trasnX = HopRewriteUtils.createTranspose(input1);
+							ReorgOp trasnW = HopRewriteUtils.createTranspose(input2);
+							Hop dX = HopRewriteUtils.createMatrixMultiply(mo, trasnW);
+							Hop dW = HopRewriteUtils.createMatrixMultiply(trasnX, mo);
+							operands.put(item.getId(), dX);
+							operands.put(item.getId() + 1, dW); // NOTE(review): id+1 may collide with another item's id -- confirm
+							allHops.add(dX);
+							allHops.add(dW);
+							names.add("dX");
+							names.add("dW");
+							break;
+						}
+						case Binary: {
+							// only "+" is differentiated (dB = column sums of mo); fail fast instead of emitting a null hop
+							String opcode = item.getOpcode();
+							if(!opcode.equals("+"))
+								throw new DMLRuntimeException("Unsupported autoDiff binary opcode: " + opcode);
+							Hop output = HopRewriteUtils.createAggUnaryOp(mo, Types.AggOp.SUM, Types.Direction.Col);
+							operands.put(item.getId(), output);
+							allHops.add(output);
+							names.add("dB");
+							break;
+						}
+						default:
+							throw new DMLRuntimeException(
+								"Unsupported autoDiff instruction " + "type: " + ctype.name() + " (" + item.getOpcode() + ").");
+					}
+				}
+				break;
+			}
+			case Literal: {
+				CPOperand op = new CPOperand(item.getData());
+				operands.put(item.getId(), ScalarObjectFactory
+					.createLiteralOp(op.getValueType(), op.getName()));
+				break;
+			}
+			default:
+				throw new DMLRuntimeException("Lineage type " + item.getType() + " is not supported");
+		}
+	}
+ private static void executeInst(ArrayList<Instruction> newInst, ExecutionContext lrwec)
+ {
+ try {
+ // wrap the instructions in a basic program block of a new, empty Program and run it on the given context
+ BasicProgramBlock pb = new BasicProgramBlock(new Program());
+ pb.setInstructions(newInst);
+ pb.execute(lrwec);
+ }
+ catch (Exception e) { // any failure is rethrown as DMLRuntimeException with the cause attached
+ throw new DMLRuntimeException("Error executing autoDiff instruction" , e);
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/AutoDiffTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/AutoDiffTest.java
new file mode 100644
index 0000000..ab4a373
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/AutoDiffTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class AutoDiffTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME = "autoDiff";
+ private final static String TEST_DIR = "functions/builtin/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + AutoDiffTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
+ }
+
+ @Test
+ public void testAutoDiffCP1() {
+ runAutoDiffTest(Types.ExecType.CP);
+ }
+
+ private void runAutoDiffTest(Types.ExecType instType)
+ {
+ ExecMode platformOld = setExecMode(instType);
+
+ try
+ {
+ OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = false;
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-lineage", "-args", output("dX"), output("ad_dX")};
+ runTest(true, false, null, -1);
+ HashMap<MatrixValue.CellIndex, Double> dml_dX = readDMLMatrixFromOutputDir("dX");
+ HashMap<MatrixValue.CellIndex, Double> autoDiff_dX = readDMLMatrixFromOutputDir("ad_dX");
+ TestUtils.compareMatrices(dml_dX, autoDiff_dX, 1e-6, "Stat-DML", "Stat-AutoDiff");
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+}
diff --git a/src/test/scripts/functions/builtin/autoDiff.dml b/src/test/scripts/functions/builtin/autoDiff.dml
new file mode 100644
index 0000000..26922dc
--- /dev/null
+++ b/src/test/scripts/functions/builtin/autoDiff.dml
@@ -0,0 +1,58 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+source("nn/layers/affine.dml") as affine
+
+
+# # # inputs; NOTE(review): this comment originally said the matrices are initialized by hand because of parsing issues with rand within lineage, but the code below uses rand -- confirm
+M = 5; N = 5
+X_batch = rand(rows=M, cols=N, sparsity=1)
+
+W_1 = rand(rows=M, cols=N, sparsity=1)
+b_1 = matrix(0, rows=1, cols=M)
+
+prob = affine::forward(X_batch, W_1, b_1)
+lin = lineage(prob)
+
+# # TODO stop the instruction parser from parsing the lineage string
+# # for now this is worked around by appending the string "foo" to it
+if(sum(prob) > 0)
+ lin = lin+"foo"
+
+# # The lineage is passed as a list item because, even after appending "foo", the
+# # compiler keeps parsing the lineage string; wrapping it in a list avoids that
+# # # compute the derivatives with autoDiff by parsing the lineage trace
+diffs = autoDiff(output=prob, lineage=list(lin));
+
+ad_dX = as.matrix(diffs['dX'])
+ad_dW = as.matrix(diffs['dW'])
+ad_dB = as.matrix(diffs['dB'])
+
+# # # # compute the reference derivatives with the affine backward script
+[dX, dW, dB] = affine::backward(prob, X_batch, W_1, b_1)
+# # element-wise inequality: a sum of 0 over a "same*" matrix means the two matrices are equal
+sameX = dX != ad_dX
+sameW = dW != ad_dW
+sameB = dB != ad_dB
+# # NOTE: "output" is computed but neither printed nor written; only dX and ad_dX are written below
+output = ((sum(sameX) == 0) & (sum(sameW) == 0) & (sum(sameB) == 0))
+
+write(dX, $1)
+write(ad_dX, $2)
\ No newline at end of file