You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2022/12/21 22:57:07 UTC
[systemds] branch main updated: [SYSTEMDS-3479] Mark Spark instructions to persist and locally cache
This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new f89a38da04 [SYSTEMDS-3479] Mark Spark instructions to persist and locally cache
f89a38da04 is described below
commit f89a38da041e685651afe016aafbe70288a31a66
Author: Arnab Phani <ph...@gmail.com>
AuthorDate: Wed Dec 21 23:56:24 2022 +0100
[SYSTEMDS-3479] Mark Spark instructions to persist and locally cache
This patch adds the compiler flags and runtime support to checkpoint
any Spark instruction that is marked for caching. During postprocessing
of a marked instruction, we first persist the RDD in place and then store
it in the local lineage cache for reuse. This patch also fixes a bug
from the previous commit that unpersisted the locally cached RDDs
during rmvar. Future commits will add rewrites that mark Spark
instructions for caching in a cost-based manner.
Caching the cpmm results in the executors makes hyperparameter
tuning of LmDS with 2.5k columns 22x faster.
Closes #1756
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 4 +-
src/main/java/org/apache/sysds/hops/Hop.java | 8 +
src/main/java/org/apache/sysds/lops/MMCJ.java | 2 +
src/main/java/org/apache/sysds/lops/MMRJ.java | 4 +-
.../org/apache/sysds/lops/OutputParameters.java | 6 +-
.../controlprogram/caching/MatrixObject.java | 2 +-
.../context/SparkExecutionContext.java | 4 +-
.../spark/AggregateBinarySPInstruction.java | 8 +-
.../instructions/spark/BinarySPInstruction.java | 7 +-
.../spark/ComputationSPInstruction.java | 40 ++++
.../instructions/spark/CpmmSPInstruction.java | 8 +-
.../instructions/spark/RmmSPInstruction.java | 8 +-
.../instructions/spark/data/LineageObject.java | 10 +
.../runtime/instructions/spark/data/RDDObject.java | 5 +
.../apache/sysds/runtime/lineage/LineageCache.java | 215 +++++++++++----------
.../sysds/runtime/lineage/LineageCacheConfig.java | 15 +-
.../sysds/runtime/lineage/LineageCacheEntry.java | 4 +
.../runtime/lineage/LineageCacheStatistics.java | 26 ++-
.../java/org/apache/sysds/utils/Statistics.java | 1 +
.../functions/async/LineageReuseSparkTest.java | 12 ++
.../scripts/functions/async/LineageReuseSpark2.dml | 53 +++++
21 files changed, 316 insertions(+), 126 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index dd04307229..9f80f7a683 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -217,7 +217,7 @@ public class AggBinaryOp extends MultiThreadedHop {
input1.getDim1(), input1.getDim2(), input1.getBlocksize(), input1.getNnz(),
input2.getDim1(), input2.getDim2(), input2.getBlocksize(), input2.getNnz(),
mmtsj, chain, _hasLeftPMInput, tmmRewrite );
- //dispatch SPARK lops construction
+ //dispatch SPARK lops construction
switch( _method )
{
case TSMM:
@@ -790,6 +790,7 @@ public class AggBinaryOp extends MultiThreadedHop {
Lop cpmm = new MMCJ(getInput().get(0).constructLops(), getInput().get(1).constructLops(),
getDataType(), getValueType(), _outputEmptyBlocks, aggtype, ExecType.SPARK);
setOutputDimensions( cpmm );
+ //setMarkForLineageCaching(cpmm);
setLineNumbers( cpmm );
setLops( cpmm );
}
@@ -823,6 +824,7 @@ public class AggBinaryOp extends MultiThreadedHop {
Lop rmm = new MMRJ(getInput().get(0).constructLops(),getInput().get(1).constructLops(),
getDataType(), getValueType(), ExecType.SPARK);
setOutputDimensions(rmm);
+ //setMarkForLineageCaching(rmm);
setLineNumbers( rmm );
setLops(rmm);
}
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 3988a6b59f..fa911749ee 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -57,6 +57,7 @@ import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -1235,6 +1236,13 @@ public abstract class Hop implements ParseInfo {
getDim1(), getDim2(), getBlocksize(), getNnz(), getUpdateType());
}
+ protected void setMarkForLineageCaching(Lop lop) {
+ //TODO: set the flag in the HOP via a rewrite
+ //lop.getOutputParameters().setLineageCacheCandidate(requiresLineageCaching());
+ if (!LineageCacheConfig.ReuseCacheType.isNone())
+ lop.getOutputParameters().setLineageCacheCandidate(true);
+ }
+
protected void setOutputDimensionsIncludeCompressedSize(Lop lop) {
lop.getOutputParameters().setDimensions(
getDim1(), getDim2(), getBlocksize(), getNnz(), getUpdateType(), getCompressedSize());
diff --git a/src/main/java/org/apache/sysds/lops/MMCJ.java b/src/main/java/org/apache/sysds/lops/MMCJ.java
index e804f84e9b..544e89341b 100644
--- a/src/main/java/org/apache/sysds/lops/MMCJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMCJ.java
@@ -109,6 +109,8 @@ public class MMCJ extends Lop
}
else
sb.append(_type.name());
+ sb.append( OPERAND_DELIMITOR );
+ sb.append(getOutputParameters().getLinCacheMarking());
return sb.toString();
}
diff --git a/src/main/java/org/apache/sysds/lops/MMRJ.java b/src/main/java/org/apache/sysds/lops/MMRJ.java
index e35cdead7e..21577eeddb 100644
--- a/src/main/java/org/apache/sysds/lops/MMRJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMRJ.java
@@ -59,11 +59,13 @@ public class MMRJ extends Lop
@Override
public String getInstructions(String input1, String input2, String output) {
+ boolean toCache = getOutputParameters().getLinCacheMarking();
return InstructionUtils.concatOperands(
getExecType().name(),
"rmm",
getInputs().get(0).prepInputOperand(input1),
getInputs().get(1).prepInputOperand(input2),
- prepOutputOperand(output));
+ prepOutputOperand(output),
+ Boolean.toString(toCache));
}
}
diff --git a/src/main/java/org/apache/sysds/lops/OutputParameters.java b/src/main/java/org/apache/sysds/lops/OutputParameters.java
index 64ba755395..9454f19a13 100644
--- a/src/main/java/org/apache/sysds/lops/OutputParameters.java
+++ b/src/main/java/org/apache/sysds/lops/OutputParameters.java
@@ -39,7 +39,7 @@ public class OutputParameters
private long _blocksize = -1;
private String _file_name = null;
private String _file_label = null;
- private boolean _linCacheCandidate = true;
+ private boolean _linCacheCandidate = false;
private long _compressedSize = -1;
FileFormat matrix_format = FileFormat.BINARY;
@@ -162,6 +162,10 @@ public class OutputParameters
public void setUpdateType(UpdateType update) {
_updateType = update;
}
+
+ public void setLineageCacheCandidate(boolean reqCaching) {
+ _linCacheCandidate = reqCaching;
+ }
public boolean getLinCacheMarking() {
return _linCacheCandidate;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index c723cc56fa..e0139f2a62 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -476,7 +476,7 @@ public class MatrixObject extends CacheableData<MatrixBlock> {
FileFormat fmt = iimd.getFileFormat();
MatrixBlock mb = null;
try {
- // prevent unnecessary collect through rdd checkpoint
+ // prevent unnecessary collect through rdd checkpoint (unless lineage cached)
if(rdd.allowsShortCircuitCollect()) {
lrdd = (RDDObject) rdd.getLineageChilds().get(0);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 48778cb4d4..77eca47640 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1522,7 +1522,9 @@ public class SparkExecutionContext extends ExecutionContext
if( lob instanceof RDDObject ) {
RDDObject rdd = (RDDObject)lob;
int rddID = rdd.getRDD().id();
- cleanupRDDVariable(rdd.getRDD());
+ //skip unpersisting if locally cached
+ if (!lob.isInLineageCache())
+ cleanupRDDVariable(rdd.getRDD());
if( rdd.getHDFSFilename()!=null ) { //deferred file removal
HDFSTool.deleteFileWithMTDIfExistOnHDFS(rdd.getHDFSFilename());
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
index 80de732505..4575b06252 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
@@ -26,8 +26,12 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
* Class to group the different MM <code>SPInstruction</code>s together.
*/
public abstract class AggregateBinarySPInstruction extends BinarySPInstruction {
- protected AggregateBinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out,
- String opcode, String istr) {
+ protected AggregateBinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
super(type, op, in1, in2, out, opcode, istr);
}
+
+ protected AggregateBinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+ String opcode, boolean toCache, String istr) {
+ super(type, op, in1, in2, out, opcode, toCache, istr);
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
index 16196d5e0d..3c70d4021a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
@@ -56,7 +56,12 @@ public abstract class BinarySPInstruction extends ComputationSPInstruction {
protected BinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
super(type, op, in1, in2, out, opcode, istr);
}
-
+
+ protected BinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2,
+ CPOperand out, String opcode, boolean toCache, String istr) {
+ super(type, op, in1, in2, out, opcode, toCache, istr);
+ }
+
public static BinarySPInstruction parseInstruction ( String str ) {
CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
index d380d913b2..465ce35820 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
@@ -20,7 +20,10 @@
package org.apache.sysds.runtime.instructions.spark;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.storage.StorageLevel;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.IndexFunction;
@@ -28,15 +31,22 @@ import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
+import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import java.util.Map;
+
public abstract class ComputationSPInstruction extends SPInstruction implements LineageTraceable {
public CPOperand output;
public CPOperand input1, input2, input3;
+ private boolean toPersistAndCache;
protected ComputationSPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
super(type, op, opcode, istr);
@@ -46,6 +56,15 @@ public abstract class ComputationSPInstruction extends SPInstruction implements
output = out;
}
+ protected ComputationSPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, boolean toCache, String istr) {
+ super(type, op, opcode, istr);
+ input1 = in1;
+ input2 = in2;
+ input3 = null;
+ output = out;
+ toPersistAndCache = toCache;
+ }
+
protected ComputationSPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
super(type, op, opcode, istr);
input1 = in1;
@@ -126,6 +145,27 @@ public abstract class ComputationSPInstruction extends SPInstruction implements
mcOut.set(1, mc1.getCols(), mc1.getBlocksize(), mc1.getBlocksize());
}
}
+
+ public boolean isRDDtoCache() {
+ return toPersistAndCache;
+ }
+
+ public void checkpointRDD(ExecutionContext ec) {
+ if (!toPersistAndCache)
+ return;
+
+ SparkExecutionContext sec = (SparkExecutionContext)ec;
+ CacheableData<?> cd = sec.getCacheableData(output.getName());
+ RDDObject inro = cd.getRDDHandle();
+ JavaPairRDD<?,?> outrdd = SparkUtils.copyBinaryBlockMatrix((JavaPairRDD<MatrixIndexes, MatrixBlock>)inro.getRDD(), false);
+ //TODO: remove shallow copying as short-circuit collect is disabled if locally cached
+ outrdd = outrdd.persist((StorageLevel.MEMORY_AND_DISK()));
+ RDDObject outro = new RDDObject(outrdd); //create new rdd object
+ outro.setCheckpointRDD(true); //mark as checkpointed
+ outro.addLineageChild(inro); //keep lineage to prevent cycles on cleanup
+ cd.setRDDHandle(outro);
+ sec.setVariable(output.getName(), cd);
+ }
@Override
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 79832eabe2..253480b482 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -66,8 +66,9 @@ public class CpmmSPInstruction extends AggregateBinarySPInstruction {
private final boolean _outputEmptyBlocks;
private final SparkAggType _aggtype;
- private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, boolean outputEmptyBlocks, SparkAggType aggtype, String opcode, String istr) {
- super(SPType.CPMM, op, in1, in2, out, opcode, istr);
+ private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+ boolean outputEmptyBlocks, SparkAggType aggtype, String opcode, boolean toCache, String istr) {
+ super(SPType.CPMM, op, in1, in2, out, opcode, toCache, istr);
_outputEmptyBlocks = outputEmptyBlocks;
_aggtype = aggtype;
}
@@ -83,7 +84,8 @@ public class CpmmSPInstruction extends AggregateBinarySPInstruction {
AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
boolean outputEmptyBlocks = Boolean.parseBoolean(parts[4]);
SparkAggType aggtype = SparkAggType.valueOf(parts[5]);
- return new CpmmSPInstruction(aggbin, in1, in2, out, outputEmptyBlocks, aggtype, opcode, str);
+ boolean toCache = parts.length == 7 ? Boolean.parseBoolean(parts[6]) : false;
+ return new CpmmSPInstruction(aggbin, in1, in2, out, outputEmptyBlocks, aggtype, opcode, toCache, str);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
index 70d4bef1dd..130045a890 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
@@ -49,8 +49,9 @@ import java.util.LinkedList;
public class RmmSPInstruction extends AggregateBinarySPInstruction {
- private RmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
- super(SPType.RMM, op, in1, in2, out, opcode, istr);
+ private RmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+ String opcode, boolean toCache, String istr) {
+ super(SPType.RMM, op, in1, in2, out, opcode, toCache, istr);
}
public static RmmSPInstruction parseInstruction( String str ) {
@@ -61,8 +62,9 @@ public class RmmSPInstruction extends AggregateBinarySPInstruction {
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
+ boolean toCache = parts.length == 5 ? Boolean.parseBoolean(parts[4]) : false;
- return new RmmSPInstruction(null, in1, in2, out, opcode, str);
+ return new RmmSPInstruction(null, in1, in2, out, opcode, toCache, str);
}
else {
throw new DMLRuntimeException("RmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
index 0882dd1d9a..f4b99bb03e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
@@ -28,6 +28,7 @@ public abstract class LineageObject
{
//basic lineage information
protected int _numRef = -1;
+ protected boolean _lineageCached = false;
protected final List<LineageObject> _childs;
//N:1 back reference to matrix/frame object
@@ -35,6 +36,7 @@ public abstract class LineageObject
protected LineageObject() {
_numRef = 0;
+ _lineageCached = false;
_childs = new ArrayList<>();
}
@@ -49,6 +51,14 @@ public abstract class LineageObject
public boolean hasBackReference() {
return (_cd != null);
}
+
+ public void setLineageCached() {
+ _lineageCached = true;
+ }
+
+ public boolean isInLineageCache() {
+ return _lineageCached;
+ }
public void incrementNumReferences() {
_numRef++;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java
index 2b03a00d31..04d021b6ff 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java
@@ -103,6 +103,11 @@ public class RDDObject extends LineageObject
public boolean allowsShortCircuitCollect()
{
+ // If the RDD is marked to be persisted and cached locally, we want to collect the RDD
+ // so that the next time we can reuse the RDD.
+ if (isInLineageCache())
+ return false;
+
return ( isCheckpointRDD() && getLineageChilds().size() == 1
&& getLineageChilds().get(0) instanceof RDDObject );
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index d4eb4b8f92..0391583b4c 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -53,7 +53,6 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
-import org.apache.sysds.runtime.instructions.spark.CheckpointSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
@@ -95,31 +94,10 @@ public class LineageCache
return false;
boolean reuse = false;
- //NOTE: the check for computation CP instructions ensures that the output
- // will always fit in memory and hence can be pinned unconditionally
- if (LineageCacheConfig.isReusable(inst, ec)) {
- ComputationCPInstruction cinst = inst instanceof ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
- ComputationFEDInstruction cfinst = inst instanceof ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
- ComputationSPInstruction cspinst = inst instanceof ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
- GPUInstruction gpuinst = inst instanceof GPUInstruction ? (GPUInstruction)inst : null;
- //TODO: Replace with generic type
-
- LineageItem instLI = (cinst != null) ? cinst.getLineageItem(ec).getValue()
- : (cfinst != null) ? cfinst.getLineageItem(ec).getValue()
- : (cspinst != null) ? cspinst.getLineageItem(ec).getValue()
- : gpuinst.getLineageItem(ec).getValue();
- List<MutablePair<LineageItem, LineageCacheEntry>> liList = null;
- if (inst instanceof MultiReturnBuiltinCPInstruction) {
- liList = new ArrayList<>();
- MultiReturnBuiltinCPInstruction mrInst = (MultiReturnBuiltinCPInstruction)inst;
- for (int i=0; i<mrInst.getNumOutputs(); i++) {
- String opcode = instLI.getOpcode() + String.valueOf(i);
- liList.add(MutablePair.of(new LineageItem(opcode, instLI.getInputs()), null));
- }
- }
- else
- liList = Arrays.asList(MutablePair.of(instLI, null));
-
+ if (LineageCacheConfig.isReusable(inst, ec))
+ {
+ List<MutablePair<LineageItem, LineageCacheEntry>> liList = getLineageItems(inst, ec);
+
//atomic try reuse full/partial and set placeholder, without
//obtaining value to avoid blocking in critical section
LineageCacheEntry e = null;
@@ -131,49 +109,28 @@ public class LineageCache
e = LineageCache.probe(item.getKey()) ? getIntern(item.getKey()) : null;
//TODO need to also move execution of compensation plan out of here
//(create lazily evaluated entry)
- if (e == null && LineageCacheConfig.getCacheType().isPartialReuse() && cspinst == null)
+ if (e == null && LineageCacheConfig.getCacheType().isPartialReuse()
+ && !(inst instanceof ComputationSPInstruction))
if( LineageRewriteReuse.executeRewrites(inst, ec) )
e = getIntern(item.getKey());
- //TODO: Partial reuse for Spark instructions
reuseAll &= (e != null);
item.setValue(e);
//create a placeholder if no reuse to avoid redundancy
//(e.g., concurrent threads that try to start the computation)
- if(e == null && isMarkedForCaching(inst, ec)) {
- if (cinst != null)
- putIntern(item.getKey(), cinst.output.getDataType(), null, null, 0);
- else if (cfinst != null)
- putIntern(item.getKey(), cfinst.output.getDataType(), null, null, 0);
- else if (cspinst != null)
- putIntern(item.getKey(), cspinst.output.getDataType(), null, null, 0);
- else if (gpuinst != null)
- putIntern(item.getKey(), gpuinst._output.getDataType(), null, null, 0);
- //FIXME: different o/p datatypes for MultiReturnBuiltins.
- }
+ if(e == null && isMarkedForCaching(inst, ec))
+ putInternPlaceholder(inst, item.getKey());
}
}
reuse = reuseAll;
if(reuse) { //reuse
- boolean gpuReuse = false;
- //put reuse value into symbol table (w/ blocking on placeholders)
+ //put reused value into symbol table (w/ blocking on placeholders)
for (MutablePair<LineageItem, LineageCacheEntry> entry : liList) {
e = entry.getValue();
- String outName = null;
- if (inst instanceof MultiReturnBuiltinCPInstruction)
- outName = ((MultiReturnBuiltinCPInstruction)inst).
- getOutput(entry.getKey().getOpcode().charAt(entry.getKey().getOpcode().length()-1)-'0').getName();
- else if (inst instanceof ComputationCPInstruction)
- outName = cinst.output.getName();
- else if (inst instanceof ComputationFEDInstruction)
- outName = cfinst.output.getName();
- else if (inst instanceof ComputationSPInstruction)
- outName = cspinst.output.getName();
- else if (inst instanceof GPUInstruction)
- outName = gpuinst._output.getName();
-
- if (e.isMatrixValue() && e._gpuObject == null) {
+ String outName = getOutputName(inst, entry.getKey());
+
+ if (e.isMatrixValue() && !e.isGPUObject()) {
MatrixBlock mb = e.getMBValue(); //wait if another thread is executing the same inst.
if (mb == null && e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
return false; //the executing thread removed this entry from cache
@@ -190,10 +147,13 @@ public class LineageCache
else if (e.isRDDPersist()) {
//Reuse the RDD which is also persisted in Spark
RDDObject rdd = e.getRDDObject();
+ if (!((SparkExecutionContext) ec).isRDDCached(rdd.getRDD().id()))
+ //Return if the RDD is not cached in the executors
+ return false;
if (rdd == null && e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
return false;
else
- ((SparkExecutionContext)ec).setRDDHandleForVariable(outName, rdd);
+ ((SparkExecutionContext) ec).setRDDHandleForVariable(outName, rdd);
}
else { //TODO handle locks on gpu objects
//shallow copy the cached GPUObj to the output MatrixObject
@@ -201,26 +161,15 @@ public class LineageCache
ec.getGPUContext(0).shallowCopyGPUObject(e._gpuObject, ec.getMatrixObject(outName)));
//Set dirty to true, so that it is later copied to the host for write
ec.getMatrixObject(outName).getGPUObject(ec.getGPUContext(0)).setDirty(true);
- gpuReuse = true;
}
-
- reuse = true;
-
- if (DMLScript.STATISTICS) //increment saved time
- LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
- }
- if (DMLScript.STATISTICS) {
- if (gpuReuse)
- LineageCacheStatistics.incrementGpuHits();
- else
- LineageCacheStatistics.incrementInstHits();
}
+ maintainReuseStatistics(inst, liList.get(0).getValue());
}
}
return reuse;
}
-
+
public static boolean reuse(List<String> outNames, List<DataIdentifier> outParams,
int numOutputs, LineageItem[] liInputs, String name, ExecutionContext ec)
{
@@ -532,9 +481,7 @@ public class LineageCache
//if (!isMarkedForCaching(inst, ec)) return;
List<Pair<LineageItem, Data>> liData = null;
GPUObject liGpuObj = null;
- RDDObject rddObj = null;
LineageItem instLI = ((LineageTraceable) inst).getLineageItem(ec).getValue();
- LineageItem instInputLI = null;
if (inst instanceof MultiReturnBuiltinCPInstruction) {
liData = new ArrayList<>();
MultiReturnBuiltinCPInstruction mrInst = (MultiReturnBuiltinCPInstruction)inst;
@@ -556,14 +503,9 @@ public class LineageCache
if (liGpuObj == null)
liData = Arrays.asList(Pair.of(instLI, ec.getVariable(((GPUInstruction)inst)._output)));
}
- else if (inst instanceof CheckpointSPInstruction) {
- // Get the lineage of the instruction being checkpointed
- instInputLI = ec.getLineageItem(((ComputationSPInstruction)inst).input1);
- // Get the RDD handle of the persisted RDD
- CacheableData<?> cd = ec.getCacheableData(((ComputationSPInstruction)inst).output.getName());
- rddObj = ((CacheableData<?>) cd).getRDDHandle();
- // Remove the lineage item of the chkpoint instruction
- removePlaceholder(instLI);
+ else if (inst instanceof ComputationSPInstruction && ((ComputationSPInstruction) inst).isRDDtoCache()) {
+ putValueRDD(inst, instLI, ec, computetime);
+ return;
}
else
if (inst instanceof ComputationCPInstruction)
@@ -573,12 +515,10 @@ public class LineageCache
else if (inst instanceof ComputationSPInstruction)
liData = Arrays.asList(Pair.of(instLI, ec.getVariable(((ComputationSPInstruction) inst).output)));
- if (liGpuObj == null && rddObj == null)
+ if (liGpuObj == null)
putValueCPU(inst, liData, computetime);
- if (liGpuObj != null)
+ else
putValueGPU(liGpuObj, instLI, computetime);
- if (rddObj != null)
- putValueRDD(rddObj, instInputLI, computetime);
}
}
@@ -607,13 +547,6 @@ public class LineageCache
continue;
}
- if (LineageCacheConfig.isToPersist(inst) && LineageCacheConfig.getCompAssRW()) {
- // The immediately following instruction must be a checkpoint, which will
- // fill the rdd in this cache entry.
- // TODO: Instead check if this instruction is marked for checkpointing
- continue;
- }
-
if (data instanceof MatrixObject && ((MatrixObject) data).hasRDDHandle()) {
// Avoid triggering pre-matured Spark instruction chains
removePlaceholder(item);
@@ -672,18 +605,21 @@ public class LineageCache
}
}
- private static void putValueRDD(RDDObject rdd, LineageItem instLI, long computetime) {
+ private static void putValueRDD(Instruction inst, LineageItem instLI, ExecutionContext ec, long computetime) {
synchronized( _cache ) {
- // Not available in the cache indicates this RDD is not marked for caching
if (!probe(instLI))
return;
+ // Call persist on the output RDD
+ ((ComputationSPInstruction) inst).checkpointRDD(ec);
+ // Get the RDD handle of the persisted RDD
+ CacheableData<?> cd = ec.getCacheableData(((ComputationSPInstruction)inst).output.getName());
+ RDDObject rddObj = ((CacheableData<?>) cd).getRDDHandle();
LineageCacheEntry centry = _cache.get(instLI);
- if (centry.isRDDPersist() && centry.getRDDObject().isCheckpointRDD())
- // Do nothing if the cached RDD is already checkpointed
- return;
-
- centry.setRDDValue(rdd, computetime);
+ // Set the RDD object in the cache
+ // TODO: Make space in the executors
+ rddObj.setLineageCached();
+ centry.setRDDValue(rddObj, computetime);
// Maintain order for eviction
LineageCacheEviction.addEntry(centry);
}
@@ -879,7 +815,24 @@ public class LineageCache
//----------------- INTERNAL CACHE LOGIC IMPLEMENTATION --------------//
-
+
+ private static void putInternPlaceholder(Instruction inst, LineageItem key) {
+ ComputationCPInstruction cinst = inst instanceof ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
+ ComputationFEDInstruction cfinst = inst instanceof ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
+ ComputationSPInstruction cspinst = inst instanceof ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
+ GPUInstruction gpuinst = inst instanceof GPUInstruction ? (GPUInstruction)inst : null;
+
+ if (cinst != null)
+ putIntern(key, cinst.output.getDataType(), null, null, 0);
+ else if (cfinst != null)
+ putIntern(key, cfinst.output.getDataType(), null, null, 0);
+ else if (cspinst != null)
+ putIntern(key, cspinst.output.getDataType(), null, null, 0);
+ else if (gpuinst != null)
+ putIntern(key, gpuinst._output.getDataType(), null, null, 0);
+ //FIXME: different o/p datatypes for MultiReturnBuiltins.
+ }
+
private static void putIntern(LineageItem key, DataType dt, MatrixBlock Mval, ScalarObject Sval, long computetime) {
if (_cache.containsKey(key))
//can come here if reuse_partial option is enabled
@@ -1130,4 +1083,70 @@ public class LineageCache
return nflops / (2L * 1024 * 1024 * 1024);
}
+
+
+ //----------------- UTILITY FUNCTIONS --------------------//
+
+ private static List<MutablePair<LineageItem, LineageCacheEntry>> getLineageItems(Instruction inst, ExecutionContext ec) {
+ // Returns one (lineage item, cache entry) pair per instruction output, with the entry left null.
+ // Multi-return builtins get one derived item per output whose opcode is suffixed with the
+ // output index; every other instruction yields a single pair for its own lineage item.
+ //TODO: Replace with generic type
+ LineageItem instLI;
+ if (inst instanceof ComputationCPInstruction)
+ instLI = ((ComputationCPInstruction) inst).getLineageItem(ec).getValue();
+ else if (inst instanceof ComputationFEDInstruction)
+ instLI = ((ComputationFEDInstruction) inst).getLineageItem(ec).getValue();
+ else if (inst instanceof ComputationSPInstruction)
+ instLI = ((ComputationSPInstruction) inst).getLineageItem(ec).getValue();
+ else if (inst instanceof GPUInstruction)
+ instLI = ((GPUInstruction) inst).getLineageItem(ec).getValue();
+ else
+ // Fail with a descriptive message instead of an opaque NullPointerException
+ throw new IllegalArgumentException("Unsupported instruction type for lineage caching: " + inst);
+
+ if (!(inst instanceof MultiReturnBuiltinCPInstruction))
+ return List.of(MutablePair.of(instLI, null));
+
+ MultiReturnBuiltinCPInstruction mrInst = (MultiReturnBuiltinCPInstruction) inst;
+ List<MutablePair<LineageItem, LineageCacheEntry>> liList = new ArrayList<>(mrInst.getNumOutputs());
+ for (int i=0; i<mrInst.getNumOutputs(); i++) {
+ String opcode = instLI.getOpcode() + String.valueOf(i);
+ liList.add(MutablePair.of(new LineageItem(opcode, instLI.getInputs()), null));
+ }
+ return liList;
+ }
+
+ private static String getOutputName(Instruction inst, LineageItem li) {
+ // Multi-return builtins: the output index is encoded as the last character of the
+ // item's opcode (getLineageItems appends the index to the base opcode).
+ // NOTE(review): this decoding only works for single-digit output indexes.
+ if (inst instanceof MultiReturnBuiltinCPInstruction) {
+ String liOpcode = li.getOpcode();
+ int outIdx = liOpcode.charAt(liOpcode.length()-1) - '0';
+ return ((MultiReturnBuiltinCPInstruction) inst).getOutput(outIdx).getName();
+ }
+ if (inst instanceof ComputationCPInstruction)
+ return ((ComputationCPInstruction) inst).output.getName();
+ if (inst instanceof ComputationFEDInstruction)
+ return ((ComputationFEDInstruction) inst).output.getName();
+ if (inst instanceof ComputationSPInstruction)
+ return ((ComputationSPInstruction) inst).output.getName();
+ if (inst instanceof GPUInstruction)
+ return ((GPUInstruction) inst)._output.getName();
+ // Any other instruction type has no resolvable output name
+ return null;
+ }
+
+ private static void maintainReuseStatistics(Instruction inst, LineageCacheEntry e) {
+ // Record reuse statistics for cache entry 'e' reused by 'inst'; no-op unless statistics are enabled
+ if (DMLScript.STATISTICS) {
+ LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
+ if (e.isGPUObject())
+ LineageCacheStatistics.incrementGpuHits();
+ if (e.isRDDPersist())
+ LineageCacheStatistics.incrementRDDHits();
+ boolean isLocalValue = e.isMatrixValue() || e.isScalarValue();
+ if (isLocalValue) {
+ // Single-block Spark instructions (sync/async) and prefetch count as Spark collect hits
+ boolean isSparkCollect = inst instanceof ComputationSPInstruction
+ || inst.getOpcode().equals("prefetch");
+ if (isSparkCollect)
+ LineageCacheStatistics.incrementSparkCollectHits();
+ else
+ LineageCacheStatistics.incrementInstHits();
+ }
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 52b8399fcc..04ed47093d 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -58,13 +58,7 @@ public class LineageCacheConfig
//TODO: Reuse everything.
};
- private static final String[] OPCODES_CP = new String[] {
- "cpmm", "rmm"
- //TODO: Instead mark an instruction to be checkpointed
- };
-
private static String[] REUSE_OPCODES = new String[] {};
- private static String[] OPCODES_CHECKPOINTS = new String[] {};
public enum ReuseCacheType {
REUSE_FULL,
@@ -196,10 +190,9 @@ public class LineageCacheConfig
static {
//setup static configuration parameters
REUSE_OPCODES = OPCODES;
- OPCODES_CHECKPOINTS = OPCODES_CP;
- //setSpill(true);
+ //setSpill(true);
setCachePolicy(LineageCachePolicy.COSTNSIZE);
- setCompAssRW(false);
+ // NOTE(review): this change flips the setCompAssRW default from false to true; confirm it is intended
+ setCompAssRW(true);
}
public static void setReusableOpcodes(String... ops) {
@@ -210,10 +203,6 @@ public class LineageCacheConfig
return REUSE_OPCODES;
}
- public static boolean isToPersist(Instruction inst) {
- return ArrayUtils.contains(OPCODES_CHECKPOINTS, inst.getOpcode());
- }
-
public static void resetReusableOpcodes() {
+ // Restore the default opcode list (OPCODES), the same one installed by the static initializer
REUSE_OPCODES = OPCODES;
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index 0042674e56..8efe57a162 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -161,6 +161,10 @@ public class LineageCacheEntry {
return _rddObject != null;
}
+ /** @return true if this entry references a GPU object (i.e., _gpuObject is set) */
+ public boolean isGPUObject() {
+ boolean hasGpuObject = _gpuObject != null;
+ return hasGpuObject;
+ }
+
public boolean isSerializedBytes() {
return _dt.isUnknown() && _key.getOpcode().equals(LineageItemUtils.SERIALIZATION_OPCODE);
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index fc34f7341a..01f2177b33 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -41,10 +41,13 @@ public class LineageCacheStatistics {
private static final LongAdder _ctimeFSWrite = new LongAdder();
private static final LongAdder _ctimeSaved = new LongAdder();
private static final LongAdder _ctimeMissed = new LongAdder();
- // Bellow entries are for specific to gpu lineage cache
+ // Below entries are specific to the GPU lineage cache
private static final LongAdder _numHitsGpu = new LongAdder();
private static final LongAdder _numAsyncEvictGpu= new LongAdder();
private static final LongAdder _numSyncEvictGpu = new LongAdder();
+ // Below entries are specific to Spark instructions
+ private static final LongAdder _numHitsRdd = new LongAdder();
+ private static final LongAdder _numHitsSparkActions = new LongAdder();
public static void reset() {
_numHitsMem.reset();
@@ -64,6 +67,8 @@ public class LineageCacheStatistics {
_numHitsGpu.reset();
_numAsyncEvictGpu.reset();
_numSyncEvictGpu.reset();
+ _numHitsRdd.reset();
+ _numHitsSparkActions.reset();
}
public static void incrementMemHits() {
@@ -197,6 +202,17 @@ public class LineageCacheStatistics {
_numSyncEvictGpu.increment();
}
+ /** Counts one more reuse of a persisted RDD. */
+ public static void incrementRDDHits() {
+ _numHitsRdd.increment();
+ }
+
+ /**
+ * Counts one more reuse of a Spark instruction result that was brought back to local,
+ * whether synchronously or asynchronously (e.g. tsmm, prefetch).
+ */
+ public static void incrementSparkCollectHits() {
+ _numHitsSparkActions.increment();
+ }
+
public static String displayHits() {
StringBuilder sb = new StringBuilder();
sb.append(_numHitsMem.longValue());
@@ -257,4 +273,12 @@ public class LineageCacheStatistics {
sb.append(_numSyncEvictGpu.longValue());
return sb.toString();
}
+
+ public static String displaySparkStats() {
+ // Rendered as "<Spark collect hits>/<persisted RDD hits>"
+ long collectHits = _numHitsSparkActions.longValue();
+ long rddHits = _numHitsRdd.longValue();
+ return collectHits + "/" + rddHits;
+ }
}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index ca359a46e6..fbbce8049b 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -639,6 +639,7 @@ public class Statistics
sb.append("LinCache hits (Mem/FS/Del): \t" + LineageCacheStatistics.displayHits() + ".\n");
sb.append("LinCache MultiLevel (Ins/SB/Fn):" + LineageCacheStatistics.displayMultiLevelHits() + ".\n");
sb.append("LinCache GPU (Hit/Async/Sync): \t" + LineageCacheStatistics.displayGpuStats() + ".\n");
+ sb.append("LinCache Spark (Col/RDD): \t\t" + LineageCacheStatistics.displaySparkStats() + ".\n");
sb.append("LinCache writes (Mem/FS/Del): \t" + LineageCacheStatistics.displayWtrites() + ".\n");
sb.append("LinCache FStimes (Rd/Wr): \t" + LineageCacheStatistics.displayFSTime() + " sec.\n");
sb.append("LinCache Computetime (S/M): \t" + LineageCacheStatistics.displayComputeTime() + " sec.\n");
diff --git a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
index 5b49bb82fa..98c4d14c83 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
@@ -35,6 +35,7 @@ package org.apache.sysds.test.functions.async;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
+ import org.junit.Ignore;
import org.junit.Test;
public class LineageReuseSparkTest extends AutomatedTestBase {
@@ -62,6 +63,13 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
runTest(TEST_NAME+"1", ExecMode.SPARK, 1);
}
+ @Ignore
+ @Test
+ public void testlmdsRDD() {
+ // Persist and locally cache the RDDs of shuffle-based Spark operations (e.g., rmm, cpmm);
+ // test id 2 makes runTest assert that sp_rmm executions drop when reuse is enabled.
+ final int testId = 2;
+ runTest(TEST_NAME + "2", ExecMode.HYBRID, testId);
+ }
+
public void runTest(String testname, ExecMode execMode, int testId) {
boolean old_simplification = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
boolean old_sum_product = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -90,6 +98,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
HashMap<MatrixValue.CellIndex, Double> R = readDMLScalarFromOutputDir("R");
long numTsmm = Statistics.getCPHeavyHitterCount("sp_tsmm");
long numMapmm = Statistics.getCPHeavyHitterCount("sp_mapmm");
+ long numRmm = Statistics.getCPHeavyHitterCount("sp_rmm");
proArgs.clear();
proArgs.add("-explain");
@@ -105,6 +114,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
HashMap<MatrixValue.CellIndex, Double> R_reused = readDMLScalarFromOutputDir("R");
long numTsmm_r = Statistics.getCPHeavyHitterCount("sp_tsmm");
long numMapmm_r = Statistics.getCPHeavyHitterCount("sp_mapmm");
+ long numRmm_r = Statistics.getCPHeavyHitterCount("sp_rmm");
//compare matrices
boolean matchVal = TestUtils.compareMatrices(R, R_reused, 1e-6, "Origin", "withPrefetch");
@@ -114,6 +124,8 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
Assert.assertTrue("Violated sp_tsmm reuse count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
Assert.assertTrue("Violated sp_mapmm reuse count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
}
+ if (testId == 2)
+ Assert.assertTrue("Violated sp_rmm reuse count: " + numRmm_r + " < " + numRmm, numRmm_r < numRmm);
} finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = old_simplification;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = old_sum_product;
diff --git a/src/test/scripts/functions/async/LineageReuseSpark2.dml b/src/test/scripts/functions/async/LineageReuseSpark2.dml
new file mode 100644
index 0000000000..22f127c07d
--- /dev/null
+++ b/src/test/scripts/functions/async/LineageReuseSpark2.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Regularized direct-solve linear regression: beta = solve(t(X)%*%X + lamda*I, t(X)%*%y).
+# N is the size of the diagonal regularizer; the caller passes ncol(X).
+SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y, Double lamda, Integer N) return (Matrix[double] beta)
+{
+ # Reuse sp_tsmm and sp_mapmm if not future-based
+ # NOTE(review): for test id 2 the Java test checks reuse of sp_rmm; confirm which
+ # Spark operators these X-only products compile to in HYBRID mode.
+ A = (t(X) %*% X) + diag(matrix(lamda, rows=N, cols=1));
+ b = t(X) %*% y;
+ beta = solve(A, b);
+}
+
+# Sweep lamda from 0.0001 towards 0.1 in no_lamda equal steps
+no_lamda = 10;
+
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+lim = 0.1;
+
+X = rand(rows=1500, cols=1500, seed=42);
+y = rand(rows=1500, cols=1, seed=43);
+N = ncol(X);
+# NOTE(review): two spare columns, presumably because floating-point accumulation of
+# lamda can allow one extra loop iteration beyond no_lamda; confirm
+R = matrix(0, rows=N, cols=no_lamda+2);
+i = 1;
+
+# X and y stay fixed across iterations, so intermediates that depend only on them
+# (e.g. t(X) %*% X) are candidates for lineage-cache reuse
+while (lamda < lim)
+{
+ beta = SimlinRegDS(X, y, lamda, N);
+ #beta = lmDS(X=X, y=y, reg=lamda);
+ R[,i] = beta;
+ lamda = lamda + stp;
+ i = i + 1;
+}
+
+# Reduce to a scalar, which the Java test reads back and compares across runs
+R = sum(R);
+write(R, $1, format="text");
+