You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2022/12/21 22:57:07 UTC

[systemds] branch main updated: [SYSTEMDS-3479] Mark Spark instructions to persist and locally cache

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f89a38da04 [SYSTEMDS-3479] Mark Spark instructions to persist and locally cache
f89a38da04 is described below

commit f89a38da041e685651afe016aafbe70288a31a66
Author: Arnab Phani <ph...@gmail.com>
AuthorDate: Wed Dec 21 23:56:24 2022 +0100

    [SYSTEMDS-3479] Mark Spark instructions to persist and locally cache
    
    This patch adds the compiler flags and runtime support to checkpoint
    any Spark instruction that is marked for caching. During postprocessing
    of a marked instruction, we first persist the RDD in place and then store
    the RDD in the local lineage cache for reuse. This patch also fixes a
    bug introduced by the last commit, which unpersisted locally cached RDDs
    during rmvar. Future commits will add rewrites to mark the Spark
    instructions for caching in a cost-based manner.
    Hyperparameter tuning of LmDS with 2.5k columns improves by
    22x when the cpmm results are cached in the executors.
    
    Closes #1756
---
 .../java/org/apache/sysds/hops/AggBinaryOp.java    |   4 +-
 src/main/java/org/apache/sysds/hops/Hop.java       |   8 +
 src/main/java/org/apache/sysds/lops/MMCJ.java      |   2 +
 src/main/java/org/apache/sysds/lops/MMRJ.java      |   4 +-
 .../org/apache/sysds/lops/OutputParameters.java    |   6 +-
 .../controlprogram/caching/MatrixObject.java       |   2 +-
 .../context/SparkExecutionContext.java             |   4 +-
 .../spark/AggregateBinarySPInstruction.java        |   8 +-
 .../instructions/spark/BinarySPInstruction.java    |   7 +-
 .../spark/ComputationSPInstruction.java            |  40 ++++
 .../instructions/spark/CpmmSPInstruction.java      |   8 +-
 .../instructions/spark/RmmSPInstruction.java       |   8 +-
 .../instructions/spark/data/LineageObject.java     |  10 +
 .../runtime/instructions/spark/data/RDDObject.java |   5 +
 .../apache/sysds/runtime/lineage/LineageCache.java | 215 +++++++++++----------
 .../sysds/runtime/lineage/LineageCacheConfig.java  |  15 +-
 .../sysds/runtime/lineage/LineageCacheEntry.java   |   4 +
 .../runtime/lineage/LineageCacheStatistics.java    |  26 ++-
 .../java/org/apache/sysds/utils/Statistics.java    |   1 +
 .../functions/async/LineageReuseSparkTest.java     |  12 ++
 .../scripts/functions/async/LineageReuseSpark2.dml |  53 +++++
 21 files changed, 316 insertions(+), 126 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index dd04307229..9f80f7a683 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -217,7 +217,7 @@ public class AggBinaryOp extends MultiThreadedHop {
 						input1.getDim1(), input1.getDim2(), input1.getBlocksize(), input1.getNnz(),
 						input2.getDim1(), input2.getDim2(), input2.getBlocksize(), input2.getNnz(),
 						mmtsj, chain, _hasLeftPMInput, tmmRewrite );
-				//dispatch SPARK lops construction 
+				//dispatch SPARK lops construction
 				switch( _method )
 				{
 					case TSMM:
@@ -790,6 +790,7 @@ public class AggBinaryOp extends MultiThreadedHop {
 			Lop cpmm = new MMCJ(getInput().get(0).constructLops(), getInput().get(1).constructLops(),
 				getDataType(), getValueType(), _outputEmptyBlocks, aggtype, ExecType.SPARK);
 			setOutputDimensions( cpmm );
+			//setMarkForLineageCaching(cpmm);
 			setLineNumbers( cpmm );
 			setLops( cpmm );
 		}
@@ -823,6 +824,7 @@ public class AggBinaryOp extends MultiThreadedHop {
 		Lop rmm = new MMRJ(getInput().get(0).constructLops(),getInput().get(1).constructLops(), 
 			getDataType(), getValueType(), ExecType.SPARK);
 		setOutputDimensions(rmm);
+		//setMarkForLineageCaching(rmm);
 		setLineNumbers( rmm );
 		setLops(rmm);
 	}
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 3988a6b59f..fa911749ee 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -57,6 +57,7 @@ import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -1235,6 +1236,13 @@ public abstract class Hop implements ParseInfo {
 			getDim1(), getDim2(), getBlocksize(), getNnz(), getUpdateType());
 	}
 
+	protected void setMarkForLineageCaching(Lop lop) {
+		//TODO: set the flag in the HOP via a rewrite
+		//lop.getOutputParameters().setLineageCacheCandidate(requiresLineageCaching());
+		if (!LineageCacheConfig.ReuseCacheType.isNone())
+			lop.getOutputParameters().setLineageCacheCandidate(true);
+	}
+
 	protected void setOutputDimensionsIncludeCompressedSize(Lop lop) {
 		lop.getOutputParameters().setDimensions(
 			getDim1(), getDim2(), getBlocksize(), getNnz(), getUpdateType(), getCompressedSize());
diff --git a/src/main/java/org/apache/sysds/lops/MMCJ.java b/src/main/java/org/apache/sysds/lops/MMCJ.java
index e804f84e9b..544e89341b 100644
--- a/src/main/java/org/apache/sysds/lops/MMCJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMCJ.java
@@ -109,6 +109,8 @@ public class MMCJ extends Lop
 		}
 		else
 			sb.append(_type.name());
+		sb.append( OPERAND_DELIMITOR );
+		sb.append(getOutputParameters().getLinCacheMarking());
 		
 		return sb.toString();
 	}
diff --git a/src/main/java/org/apache/sysds/lops/MMRJ.java b/src/main/java/org/apache/sysds/lops/MMRJ.java
index e35cdead7e..21577eeddb 100644
--- a/src/main/java/org/apache/sysds/lops/MMRJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMRJ.java
@@ -59,11 +59,13 @@ public class MMRJ extends Lop
 
 	@Override
 	public String getInstructions(String input1, String input2, String output) {
+		boolean toCache = getOutputParameters().getLinCacheMarking();
 		return InstructionUtils.concatOperands(
 			getExecType().name(),
 			"rmm",
 			getInputs().get(0).prepInputOperand(input1),
 			getInputs().get(1).prepInputOperand(input2),
-			prepOutputOperand(output));
+			prepOutputOperand(output),
+			Boolean.toString(toCache));
 	}
 }
diff --git a/src/main/java/org/apache/sysds/lops/OutputParameters.java b/src/main/java/org/apache/sysds/lops/OutputParameters.java
index 64ba755395..9454f19a13 100644
--- a/src/main/java/org/apache/sysds/lops/OutputParameters.java
+++ b/src/main/java/org/apache/sysds/lops/OutputParameters.java
@@ -39,7 +39,7 @@ public class OutputParameters
 	private long _blocksize = -1;
 	private String _file_name = null;
 	private String _file_label = null;
-	private boolean _linCacheCandidate = true;
+	private boolean _linCacheCandidate = false;
 	private long _compressedSize = -1;
 
 	FileFormat matrix_format = FileFormat.BINARY;
@@ -162,6 +162,10 @@ public class OutputParameters
 	public void setUpdateType(UpdateType update) {
 		_updateType = update;
 	}
+
+	public void setLineageCacheCandidate(boolean reqCaching) {
+		_linCacheCandidate = reqCaching;
+	}
 	
 	public boolean getLinCacheMarking() {
 		return _linCacheCandidate;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index c723cc56fa..e0139f2a62 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -476,7 +476,7 @@ public class MatrixObject extends CacheableData<MatrixBlock> {
 		FileFormat fmt = iimd.getFileFormat();
 		MatrixBlock mb = null;
 		try {
-			// prevent unnecessary collect through rdd checkpoint
+			// prevent unnecessary collect through rdd checkpoint (unless lineage cached)
 			if(rdd.allowsShortCircuitCollect()) {
 				lrdd = (RDDObject) rdd.getLineageChilds().get(0);
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 48778cb4d4..77eca47640 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1522,7 +1522,9 @@ public class SparkExecutionContext extends ExecutionContext
 		if( lob instanceof RDDObject ) {
 			RDDObject rdd = (RDDObject)lob;
 			int rddID = rdd.getRDD().id();
-			cleanupRDDVariable(rdd.getRDD());
+			//skip unpersisting if locally cached
+			if (!lob.isInLineageCache())
+				cleanupRDDVariable(rdd.getRDD());
 			if( rdd.getHDFSFilename()!=null ) { //deferred file removal
 				HDFSTool.deleteFileWithMTDIfExistOnHDFS(rdd.getHDFSFilename());
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
index 80de732505..4575b06252 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateBinarySPInstruction.java
@@ -26,8 +26,12 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
  * Class to group the different MM <code>SPInstruction</code>s together.
  */
 public abstract class AggregateBinarySPInstruction extends BinarySPInstruction {
-	protected AggregateBinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out,
-			String opcode, String istr) {
+	protected AggregateBinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
 		super(type, op, in1, in2, out, opcode, istr);
 	}
+
+	protected AggregateBinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+		String opcode, boolean toCache, String istr) {
+		super(type, op, in1, in2, out, opcode, toCache, istr);
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
index 16196d5e0d..3c70d4021a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
@@ -56,7 +56,12 @@ public abstract class BinarySPInstruction extends ComputationSPInstruction {
 	protected BinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
 		super(type, op, in1, in2, out, opcode, istr);
 	}
-	
+
+	protected BinarySPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2,
+		CPOperand out, String opcode, boolean toCache, String istr) {
+		super(type, op, in1, in2, out, opcode, toCache, istr);
+	}
+
 	public static BinarySPInstruction parseInstruction ( String str ) {
 		CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
 		CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
index d380d913b2..465ce35820 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ComputationSPInstruction.java
@@ -20,7 +20,10 @@
 package org.apache.sysds.runtime.instructions.spark;
 
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.storage.StorageLevel;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.functionobjects.IndexFunction;
@@ -28,15 +31,22 @@ import org.apache.sysds.runtime.functionobjects.ReduceAll;
 import org.apache.sysds.runtime.functionobjects.ReduceCol;
 import org.apache.sysds.runtime.functionobjects.ReduceRow;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
+import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.lineage.LineageTraceable;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 
+import java.util.Map;
+
 public abstract class ComputationSPInstruction extends SPInstruction implements LineageTraceable {
 	public CPOperand output;
 	public CPOperand input1, input2, input3;
+	private boolean toPersistAndCache;
 
 	protected ComputationSPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
 		super(type, op, opcode, istr);
@@ -46,6 +56,15 @@ public abstract class ComputationSPInstruction extends SPInstruction implements
 		output = out;
 	}
 
+	protected ComputationSPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, boolean toCache, String istr) {
+		super(type, op, opcode, istr);
+		input1 = in1;
+		input2 = in2;
+		input3 = null;
+		output = out;
+		toPersistAndCache = toCache;
+	}
+
 	protected ComputationSPInstruction(SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
 		super(type, op, opcode, istr);
 		input1 = in1;
@@ -126,6 +145,27 @@ public abstract class ComputationSPInstruction extends SPInstruction implements
 				mcOut.set(1, mc1.getCols(), mc1.getBlocksize(), mc1.getBlocksize());
 		}
 	}
+
+	public boolean isRDDtoCache() {
+		return toPersistAndCache;
+	}
+
+	public void checkpointRDD(ExecutionContext ec) {
+		if (!toPersistAndCache)
+			return;
+
+		SparkExecutionContext sec = (SparkExecutionContext)ec;
+		CacheableData<?> cd = sec.getCacheableData(output.getName());
+		RDDObject inro =  cd.getRDDHandle();
+		JavaPairRDD<?,?> outrdd = SparkUtils.copyBinaryBlockMatrix((JavaPairRDD<MatrixIndexes, MatrixBlock>)inro.getRDD(), false);
+		//TODO: remove shallow copying as short-circuit collect is disabled if locally cached
+		outrdd = outrdd.persist((StorageLevel.MEMORY_AND_DISK()));
+		RDDObject outro = new RDDObject(outrdd); //create new rdd object
+		outro.setCheckpointRDD(true);            //mark as checkpointed
+		outro.addLineageChild(inro);             //keep lineage to prevent cycles on cleanup
+		cd.setRDDHandle(outro);
+		sec.setVariable(output.getName(), cd);
+	}
 	
 	@Override
 	public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 79832eabe2..253480b482 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -66,8 +66,9 @@ public class CpmmSPInstruction extends AggregateBinarySPInstruction {
 	private final boolean _outputEmptyBlocks;
 	private final SparkAggType _aggtype;
 	
-	private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, boolean outputEmptyBlocks, SparkAggType aggtype, String opcode, String istr) {
-		super(SPType.CPMM, op, in1, in2, out, opcode, istr);
+	private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+		boolean outputEmptyBlocks, SparkAggType aggtype, String opcode, boolean toCache, String istr) {
+		super(SPType.CPMM, op, in1, in2, out, opcode, toCache, istr);
 		_outputEmptyBlocks = outputEmptyBlocks;
 		_aggtype = aggtype;
 	}
@@ -83,7 +84,8 @@ public class CpmmSPInstruction extends AggregateBinarySPInstruction {
 		AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
 		boolean outputEmptyBlocks = Boolean.parseBoolean(parts[4]);
 		SparkAggType aggtype = SparkAggType.valueOf(parts[5]);
-		return new CpmmSPInstruction(aggbin, in1, in2, out, outputEmptyBlocks, aggtype, opcode, str);
+		boolean toCache = parts.length == 7 ? Boolean.parseBoolean(parts[6]) : false;
+		return new CpmmSPInstruction(aggbin, in1, in2, out, outputEmptyBlocks, aggtype, opcode, toCache, str);
 	}
 	
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
index 70d4bef1dd..130045a890 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/RmmSPInstruction.java
@@ -49,8 +49,9 @@ import java.util.LinkedList;
 
 public class RmmSPInstruction extends AggregateBinarySPInstruction {
 
-	private RmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
-		super(SPType.RMM, op, in1, in2, out, opcode, istr);
+	private RmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+		String opcode, boolean toCache, String istr) {
+		super(SPType.RMM, op, in1, in2, out, opcode, toCache, istr);
 	}
 
 	public static RmmSPInstruction parseInstruction( String str ) {
@@ -61,8 +62,9 @@ public class RmmSPInstruction extends AggregateBinarySPInstruction {
 			CPOperand in1 = new CPOperand(parts[1]);
 			CPOperand in2 = new CPOperand(parts[2]);
 			CPOperand out = new CPOperand(parts[3]);
+			boolean toCache = parts.length == 5 ? Boolean.parseBoolean(parts[4]) : false;
 			
-			return new RmmSPInstruction(null, in1, in2, out, opcode, str);
+			return new RmmSPInstruction(null, in1, in2, out, opcode, toCache, str);
 		} 
 		else {
 			throw new DMLRuntimeException("RmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
index 0882dd1d9a..f4b99bb03e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
@@ -28,6 +28,7 @@ public abstract class LineageObject
 {
 	//basic lineage information
 	protected int _numRef = -1;
+	protected boolean _lineageCached = false;
 	protected final List<LineageObject> _childs;
 	
 	//N:1 back reference to matrix/frame object
@@ -35,6 +36,7 @@ public abstract class LineageObject
 	
 	protected LineageObject() {
 		_numRef = 0;
+		_lineageCached = false;
 		_childs = new ArrayList<>();
 	}
 	
@@ -49,6 +51,14 @@ public abstract class LineageObject
 	public boolean hasBackReference() {
 		return (_cd != null);
 	}
+
+	public void setLineageCached() {
+		_lineageCached = true;
+	}
+
+	public boolean isInLineageCache() {
+		return _lineageCached;
+	}
 	
 	public void incrementNumReferences() {
 		_numRef++;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java
index 2b03a00d31..04d021b6ff 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/RDDObject.java
@@ -103,6 +103,11 @@ public class RDDObject extends LineageObject
 
 	public boolean allowsShortCircuitCollect()
 	{
+		// If the RDD is marked to be persisted and cached locally, we want to collect the RDD
+		// so that the next time we can reuse the RDD.
+		if (isInLineageCache())
+			return false;
+
 		return ( isCheckpointRDD() && getLineageChilds().size() == 1
 			     && getLineageChilds().get(0) instanceof RDDObject );
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index d4eb4b8f92..0391583b4c 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -53,7 +53,6 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
 import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
-import org.apache.sysds.runtime.instructions.spark.CheckpointSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
@@ -95,31 +94,10 @@ public class LineageCache
 			return false;
 		
 		boolean reuse = false;
-		//NOTE: the check for computation CP instructions ensures that the output
-		// will always fit in memory and hence can be pinned unconditionally
-		if (LineageCacheConfig.isReusable(inst, ec)) {
-			ComputationCPInstruction cinst = inst instanceof ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
-			ComputationFEDInstruction cfinst = inst instanceof ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
-			ComputationSPInstruction cspinst = inst instanceof ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
-			GPUInstruction gpuinst = inst instanceof GPUInstruction ? (GPUInstruction)inst : null;
-			//TODO: Replace with generic type
-				
-			LineageItem instLI = (cinst != null) ? cinst.getLineageItem(ec).getValue()
-					: (cfinst != null) ? cfinst.getLineageItem(ec).getValue()
-					: (cspinst != null) ? cspinst.getLineageItem(ec).getValue()
-					: gpuinst.getLineageItem(ec).getValue();
-			List<MutablePair<LineageItem, LineageCacheEntry>> liList = null;
-			if (inst instanceof MultiReturnBuiltinCPInstruction) {
-				liList = new ArrayList<>();
-				MultiReturnBuiltinCPInstruction mrInst = (MultiReturnBuiltinCPInstruction)inst;
-				for (int i=0; i<mrInst.getNumOutputs(); i++) {
-					String opcode = instLI.getOpcode() + String.valueOf(i);
-					liList.add(MutablePair.of(new LineageItem(opcode, instLI.getInputs()), null));
-				}
-			}
-			else
-				liList = Arrays.asList(MutablePair.of(instLI, null));
-			
+		if (LineageCacheConfig.isReusable(inst, ec))
+		{
+			List<MutablePair<LineageItem, LineageCacheEntry>> liList = getLineageItems(inst, ec);
+
 			//atomic try reuse full/partial and set placeholder, without
 			//obtaining value to avoid blocking in critical section
 			LineageCacheEntry e = null;
@@ -131,49 +109,28 @@ public class LineageCache
 						e = LineageCache.probe(item.getKey()) ? getIntern(item.getKey()) : null;
 					//TODO need to also move execution of compensation plan out of here
 					//(create lazily evaluated entry)
-					if (e == null && LineageCacheConfig.getCacheType().isPartialReuse() && cspinst == null)
+					if (e == null && LineageCacheConfig.getCacheType().isPartialReuse()
+						&& !(inst instanceof ComputationSPInstruction))
 						if( LineageRewriteReuse.executeRewrites(inst, ec) )
 							e = getIntern(item.getKey());
-					//TODO: Partial reuse for Spark instructions
 					reuseAll &= (e != null);
 					item.setValue(e);
 					
 					//create a placeholder if no reuse to avoid redundancy
 					//(e.g., concurrent threads that try to start the computation)
-					if(e == null && isMarkedForCaching(inst, ec)) {
-						if (cinst != null)
-							putIntern(item.getKey(), cinst.output.getDataType(), null, null,  0);
-						else if (cfinst != null)
-							putIntern(item.getKey(), cfinst.output.getDataType(), null, null,  0);
-						else if (cspinst != null)
-							putIntern(item.getKey(), cspinst.output.getDataType(), null, null,  0);
-						else if (gpuinst != null)
-							putIntern(item.getKey(), gpuinst._output.getDataType(), null, null,  0);
-						//FIXME: different o/p datatypes for MultiReturnBuiltins.
-					}
+					if(e == null && isMarkedForCaching(inst, ec))
+						putInternPlaceholder(inst, item.getKey());
 				}
 			}
 			reuse = reuseAll;
 			
 			if(reuse) { //reuse
-				boolean gpuReuse = false;
-				//put reuse value into symbol table (w/ blocking on placeholders)
+				//put reused value into symbol table (w/ blocking on placeholders)
 				for (MutablePair<LineageItem, LineageCacheEntry> entry : liList) {
 					e = entry.getValue();
-					String outName = null;
-					if (inst instanceof MultiReturnBuiltinCPInstruction)
-						outName = ((MultiReturnBuiltinCPInstruction)inst).
-							getOutput(entry.getKey().getOpcode().charAt(entry.getKey().getOpcode().length()-1)-'0').getName(); 
-					else if (inst instanceof ComputationCPInstruction)
-						outName = cinst.output.getName();
-					else if (inst instanceof ComputationFEDInstruction)
-						outName = cfinst.output.getName();
-					else if (inst instanceof ComputationSPInstruction)
-						outName = cspinst.output.getName();
-					else if (inst instanceof GPUInstruction)
-						outName = gpuinst._output.getName();
-					
-					if (e.isMatrixValue() && e._gpuObject == null) {
+					String outName = getOutputName(inst, entry.getKey());
+
+					if (e.isMatrixValue() && !e.isGPUObject()) {
 						MatrixBlock mb = e.getMBValue(); //wait if another thread is executing the same inst.
 						if (mb == null && e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
 							return false;  //the executing thread removed this entry from cache
@@ -190,10 +147,13 @@ public class LineageCache
 					else if (e.isRDDPersist()) {
 						//Reuse the RDD which is also persisted in Spark
 						RDDObject rdd = e.getRDDObject();
+						if (!((SparkExecutionContext) ec).isRDDCached(rdd.getRDD().id()))
+							//Return if the RDD is not cached in the executors
+							return false;
 						if (rdd == null && e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
 							return false;
 						else
-							((SparkExecutionContext)ec).setRDDHandleForVariable(outName, rdd);
+							((SparkExecutionContext) ec).setRDDHandleForVariable(outName, rdd);
 					}
 					else { //TODO handle locks on gpu objects
 						//shallow copy the cached GPUObj to the output MatrixObject
@@ -201,26 +161,15 @@ public class LineageCache
 								ec.getGPUContext(0).shallowCopyGPUObject(e._gpuObject, ec.getMatrixObject(outName)));
 						//Set dirty to true, so that it is later copied to the host for write
 						ec.getMatrixObject(outName).getGPUObject(ec.getGPUContext(0)).setDirty(true);
-						gpuReuse = true;
 					}
-
-					reuse = true;
-
-					if (DMLScript.STATISTICS) //increment saved time
-						LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
-				}
-				if (DMLScript.STATISTICS) {
-					if (gpuReuse)
-						LineageCacheStatistics.incrementGpuHits();
-					else
-						LineageCacheStatistics.incrementInstHits();
 				}
+				maintainReuseStatistics(inst, liList.get(0).getValue());
 			}
 		}
 		
 		return reuse;
 	}
-	
+
 	public static boolean reuse(List<String> outNames, List<DataIdentifier> outParams, 
 			int numOutputs, LineageItem[] liInputs, String name, ExecutionContext ec)
 	{
@@ -532,9 +481,7 @@ public class LineageCache
 			//if (!isMarkedForCaching(inst, ec)) return;
 			List<Pair<LineageItem, Data>> liData = null;
 			GPUObject liGpuObj = null;
-			RDDObject rddObj = null;
 			LineageItem instLI = ((LineageTraceable) inst).getLineageItem(ec).getValue();
-			LineageItem instInputLI = null;
 			if (inst instanceof MultiReturnBuiltinCPInstruction) {
 				liData = new ArrayList<>();
 				MultiReturnBuiltinCPInstruction mrInst = (MultiReturnBuiltinCPInstruction)inst;
@@ -556,14 +503,9 @@ public class LineageCache
 				if (liGpuObj == null)
 					liData = Arrays.asList(Pair.of(instLI, ec.getVariable(((GPUInstruction)inst)._output)));
 			}
-			else if (inst instanceof CheckpointSPInstruction) {
-				// Get the lineage of the instruction being checkpointed
-				instInputLI = ec.getLineageItem(((ComputationSPInstruction)inst).input1);
-				// Get the RDD handle of the persisted RDD
-				CacheableData<?> cd = ec.getCacheableData(((ComputationSPInstruction)inst).output.getName());
-				rddObj = ((CacheableData<?>) cd).getRDDHandle();
-				// Remove the lineage item of the chkpoint instruction
-				removePlaceholder(instLI);
+			else if (inst instanceof ComputationSPInstruction && ((ComputationSPInstruction) inst).isRDDtoCache()) {
+				putValueRDD(inst, instLI, ec, computetime);
+				return;
 			}
 			else
 				if (inst instanceof ComputationCPInstruction)
@@ -573,12 +515,10 @@ public class LineageCache
 				else if (inst instanceof ComputationSPInstruction)
 					liData = Arrays.asList(Pair.of(instLI, ec.getVariable(((ComputationSPInstruction) inst).output)));
 
-			if (liGpuObj == null && rddObj == null)
+			if (liGpuObj == null)
 				putValueCPU(inst, liData, computetime);
-			if (liGpuObj != null)
+			else
 				putValueGPU(liGpuObj, instLI, computetime);
-			if (rddObj != null)
-				putValueRDD(rddObj, instInputLI, computetime);
 		}
 	}
 	
@@ -607,13 +547,6 @@ public class LineageCache
 					continue;
 				}
 
-				if (LineageCacheConfig.isToPersist(inst) && LineageCacheConfig.getCompAssRW()) {
-					// The immediately following instruction must be a checkpoint, which will
-					// fill the rdd in this cache entry.
-					// TODO: Instead check if this instruction is marked for checkpointing
-					continue;
-				}
-
 				if (data instanceof MatrixObject && ((MatrixObject) data).hasRDDHandle()) {
 					// Avoid triggering pre-matured Spark instruction chains
 					removePlaceholder(item);
@@ -672,18 +605,21 @@ public class LineageCache
 		}
 	}
 
-	private static void putValueRDD(RDDObject rdd, LineageItem instLI, long computetime) {
+	private static void putValueRDD(Instruction inst, LineageItem instLI, ExecutionContext ec, long computetime) {
 		synchronized( _cache ) {
-			// Not available in the cache indicates this RDD is not marked for caching
 			if (!probe(instLI))
 				return;
+			// Call persist on the output RDD
+			((ComputationSPInstruction) inst).checkpointRDD(ec);
+			// Get the RDD handle of the persisted RDD
+			CacheableData<?> cd = ec.getCacheableData(((ComputationSPInstruction)inst).output.getName());
+			RDDObject rddObj = ((CacheableData<?>) cd).getRDDHandle();
 
 			LineageCacheEntry centry = _cache.get(instLI);
-			if (centry.isRDDPersist() && centry.getRDDObject().isCheckpointRDD())
-				// Do nothing if the cached RDD is already checkpointed
-				return;
-
-			centry.setRDDValue(rdd, computetime);
+			// Set the RDD object in the cache
+			// TODO: Make space in the executors
+			rddObj.setLineageCached();
+			centry.setRDDValue(rddObj, computetime);
 			// Maintain order for eviction
 			LineageCacheEviction.addEntry(centry);
 		}
@@ -879,7 +815,24 @@ public class LineageCache
 
 	
 	//----------------- INTERNAL CACHE LOGIC IMPLEMENTATION --------------//
-	
+
+	private static void putInternPlaceholder(Instruction inst, LineageItem key) {
+		ComputationCPInstruction cinst = inst instanceof ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
+		ComputationFEDInstruction cfinst = inst instanceof ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
+		ComputationSPInstruction cspinst = inst instanceof ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
+		GPUInstruction gpuinst = inst instanceof GPUInstruction ? (GPUInstruction)inst : null;
+
+		if (cinst != null)
+			putIntern(key, cinst.output.getDataType(), null, null,  0);
+		else if (cfinst != null)
+			putIntern(key, cfinst.output.getDataType(), null, null,  0);
+		else if (cspinst != null)
+			putIntern(key, cspinst.output.getDataType(), null, null,  0);
+		else if (gpuinst != null)
+			putIntern(key, gpuinst._output.getDataType(), null, null,  0);
+		//FIXME: different o/p datatypes for MultiReturnBuiltins.
+	}
+
 	private static void putIntern(LineageItem key, DataType dt, MatrixBlock Mval, ScalarObject Sval, long computetime) {
 		if (_cache.containsKey(key))
 			//can come here if reuse_partial option is enabled
@@ -1130,4 +1083,70 @@ public class LineageCache
 		
 		return nflops / (2L * 1024 * 1024 * 1024);
 	}
+
+
+	//----------------- UTILITY FUNCTIONS --------------------//
+
+	private static List<MutablePair<LineageItem, LineageCacheEntry>> getLineageItems(Instruction inst, ExecutionContext ec) {
+		ComputationCPInstruction cinst = inst instanceof ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
+		ComputationFEDInstruction cfinst = inst instanceof ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
+		ComputationSPInstruction cspinst = inst instanceof ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
+		GPUInstruction gpuinst = inst instanceof GPUInstruction ? (GPUInstruction)inst : null;
+		//TODO: Replace with generic type
+
+		List<MutablePair<LineageItem, LineageCacheEntry>> liList = null;
+		LineageItem instLI = (cinst != null) ? cinst.getLineageItem(ec).getValue()
+			: (cfinst != null) ? cfinst.getLineageItem(ec).getValue()
+			: (cspinst != null) ? cspinst.getLineageItem(ec).getValue()
+			: gpuinst.getLineageItem(ec).getValue();
+		if (inst instanceof MultiReturnBuiltinCPInstruction) {
+			liList = new ArrayList<>();
+			MultiReturnBuiltinCPInstruction mrInst = (MultiReturnBuiltinCPInstruction)inst;
+			for (int i=0; i<mrInst.getNumOutputs(); i++) {
+				String opcode = instLI.getOpcode() + String.valueOf(i);
+				liList.add(MutablePair.of(new LineageItem(opcode, instLI.getInputs()), null));
+			}
+		}
+		else
+			liList = List.of(MutablePair.of(instLI, null));
+
+		return liList;
+	}
+
+	private static String getOutputName(Instruction inst, LineageItem li) {
+		ComputationCPInstruction cinst = inst instanceof ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
+		ComputationFEDInstruction cfinst = inst instanceof ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
+		ComputationSPInstruction cspinst = inst instanceof ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
+		GPUInstruction gpuinst = inst instanceof GPUInstruction ? (GPUInstruction)inst : null;
+
+		String outName = null;
+		if (inst instanceof MultiReturnBuiltinCPInstruction)
+			outName = ((MultiReturnBuiltinCPInstruction)inst).
+				getOutput(li.getOpcode().charAt(li.getOpcode().length()-1)-'0').getName();
+		else if (inst instanceof ComputationCPInstruction)
+			outName = cinst.output.getName();
+		else if (inst instanceof ComputationFEDInstruction)
+			outName = cfinst.output.getName();
+		else if (inst instanceof ComputationSPInstruction)
+			outName = cspinst.output.getName();
+		else if (inst instanceof GPUInstruction)
+			outName = gpuinst._output.getName();
+
+		return outName;
+	}
+	private static void maintainReuseStatistics(Instruction inst, LineageCacheEntry e) {
+		if (!DMLScript.STATISTICS)
+			return;
+
+		LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
+		if (e.isGPUObject()) LineageCacheStatistics.incrementGpuHits();
+		if (e.isRDDPersist()) LineageCacheStatistics.incrementRDDHits();
+		if (e.isMatrixValue() || e.isScalarValue()) {
+			if (inst instanceof ComputationSPInstruction || inst.getOpcode().equals("prefetch"))
+				// Single_block Spark instructions (sync/async) and prefetch
+				LineageCacheStatistics.incrementSparkCollectHits();
+			else
+				LineageCacheStatistics.incrementInstHits();
+		}
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 52b8399fcc..04ed47093d 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -58,13 +58,7 @@ public class LineageCacheConfig
 		//TODO: Reuse everything. 
 	};
 
-	private static final String[] OPCODES_CP = new String[] {
-		"cpmm", "rmm"
-		//TODO: Instead mark an instruction to be checkpointed
-	};
-
 	private static String[] REUSE_OPCODES  = new String[] {};
-	private static String[] OPCODES_CHECKPOINTS  = new String[] {};
 
 	public enum ReuseCacheType {
 		REUSE_FULL,
@@ -196,10 +190,9 @@ public class LineageCacheConfig
 	static {
 		//setup static configuration parameters
 		REUSE_OPCODES = OPCODES;
-		OPCODES_CHECKPOINTS = OPCODES_CP;
-		//setSpill(true); 
+		//setSpill(true);
 		setCachePolicy(LineageCachePolicy.COSTNSIZE);
-		setCompAssRW(false);
+		setCompAssRW(true);
 	}
 
 	public static void setReusableOpcodes(String... ops) {
@@ -210,10 +203,6 @@ public class LineageCacheConfig
 		return REUSE_OPCODES;
 	}
 
-	public static boolean isToPersist(Instruction inst) {
-		return ArrayUtils.contains(OPCODES_CHECKPOINTS, inst.getOpcode());
-	}
-	
 	public static void resetReusableOpcodes() {
 		REUSE_OPCODES = OPCODES;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index 0042674e56..8efe57a162 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -161,6 +161,10 @@ public class LineageCacheEntry {
 		return _rddObject != null;
 	}
 
+	public boolean isGPUObject() {
+		return _gpuObject != null;
+	}
+
 	public boolean isSerializedBytes() {
 		return _dt.isUnknown() && _key.getOpcode().equals(LineageItemUtils.SERIALIZATION_OPCODE);
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index fc34f7341a..01f2177b33 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -41,10 +41,13 @@ public class LineageCacheStatistics {
 	private static final LongAdder _ctimeFSWrite    = new LongAdder();
 	private static final LongAdder _ctimeSaved      = new LongAdder();
 	private static final LongAdder _ctimeMissed     = new LongAdder();
-	// Bellow entries are for specific to gpu lineage cache
+	// Below entries are specific to gpu lineage cache
 	private static final LongAdder _numHitsGpu      = new LongAdder();
 	private static final LongAdder _numAsyncEvictGpu= new LongAdder();
 	private static final LongAdder _numSyncEvictGpu = new LongAdder();
+	// Below entries are specific to Spark instructions
+	private static final LongAdder _numHitsRdd      = new LongAdder();
+	private static final LongAdder _numHitsSparkActions = new LongAdder();
 
 	public static void reset() {
 		_numHitsMem.reset();
@@ -64,6 +67,8 @@ public class LineageCacheStatistics {
 		_numHitsGpu.reset();
 		_numAsyncEvictGpu.reset();
 		_numSyncEvictGpu.reset();
+		_numHitsRdd.reset();
+		_numHitsSparkActions.reset();
 	}
 	
 	public static void incrementMemHits() {
@@ -197,6 +202,17 @@ public class LineageCacheStatistics {
 		_numSyncEvictGpu.increment();
 	}
 
+	public static void incrementRDDHits() {
+		// Number of times a persisted RDD is reused.
+		_numHitsRdd.increment();
+	}
+
+	public static void incrementSparkCollectHits() {
+		// Spark instructions that bring intermediate back to local.
+		// Both synchronous and asynchronous (e.g. tsmm, prefetch)
+		_numHitsSparkActions.increment();
+	}
+
 	public static String displayHits() {
 		StringBuilder sb = new StringBuilder();
 		sb.append(_numHitsMem.longValue());
@@ -257,4 +273,12 @@ public class LineageCacheStatistics {
 		sb.append(_numSyncEvictGpu.longValue());
 		return sb.toString();
 	}
+
+	public static String displaySparkStats() {
+		StringBuilder sb = new StringBuilder();
+		sb.append(_numHitsSparkActions.longValue());
+		sb.append("/");
+		sb.append(_numHitsRdd.longValue());
+		return sb.toString();
+	}
 }
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index ca359a46e6..fbbce8049b 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -639,6 +639,7 @@ public class Statistics
 				sb.append("LinCache hits (Mem/FS/Del): \t" + LineageCacheStatistics.displayHits() + ".\n");
 				sb.append("LinCache MultiLevel (Ins/SB/Fn):" + LineageCacheStatistics.displayMultiLevelHits() + ".\n");
 				sb.append("LinCache GPU (Hit/Async/Sync): \t" + LineageCacheStatistics.displayGpuStats() + ".\n");
+				sb.append("LinCache Spark (Col/RDD): \t\t" + LineageCacheStatistics.displaySparkStats() + ".\n");
 				sb.append("LinCache writes (Mem/FS/Del): \t" + LineageCacheStatistics.displayWtrites() + ".\n");
 				sb.append("LinCache FStimes (Rd/Wr): \t" + LineageCacheStatistics.displayFSTime() + " sec.\n");
 				sb.append("LinCache Computetime (S/M): \t" + LineageCacheStatistics.displayComputeTime() + " sec.\n");
diff --git a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
index 5b49bb82fa..98c4d14c83 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
@@ -35,6 +35,7 @@ package org.apache.sysds.test.functions.async;
 	import org.apache.sysds.test.TestUtils;
 	import org.apache.sysds.utils.Statistics;
 	import org.junit.Assert;
+	import org.junit.Ignore;
 	import org.junit.Test;
 
 public class LineageReuseSparkTest extends AutomatedTestBase {
@@ -62,6 +63,13 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
 		runTest(TEST_NAME+"1", ExecMode.SPARK, 1);
 	}
 
+	@Ignore
+	@Test
+	public void testlmdsRDD() {
+		// Persist and cache RDDs of shuffle-based Spark operations (e.g. rmm, cpmm)
+		runTest(TEST_NAME+"2", ExecMode.HYBRID, 2);
+	}
+
 	public void runTest(String testname, ExecMode execMode, int testId) {
 		boolean old_simplification = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
 		boolean old_sum_product = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -90,6 +98,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
 			HashMap<MatrixValue.CellIndex, Double> R = readDMLScalarFromOutputDir("R");
 			long numTsmm = Statistics.getCPHeavyHitterCount("sp_tsmm");
 			long numMapmm = Statistics.getCPHeavyHitterCount("sp_mapmm");
+			long numRmm = Statistics.getCPHeavyHitterCount("sp_rmm");
 
 			proArgs.clear();
 			proArgs.add("-explain");
@@ -105,6 +114,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
 			HashMap<MatrixValue.CellIndex, Double> R_reused = readDMLScalarFromOutputDir("R");
 			long numTsmm_r = Statistics.getCPHeavyHitterCount("sp_tsmm");
 			long numMapmm_r = Statistics.getCPHeavyHitterCount("sp_mapmm");
+			long numRmm_r = Statistics.getCPHeavyHitterCount("sp_rmm");
 
 			//compare matrices
 			boolean matchVal = TestUtils.compareMatrices(R, R_reused, 1e-6, "Origin", "withPrefetch");
@@ -114,6 +124,8 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
 				Assert.assertTrue("Violated sp_tsmm reuse count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
 				Assert.assertTrue("Violated sp_mapmm reuse count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
 			}
+			if (testId == 2)
+				Assert.assertTrue("Violated sp_rmm reuse count: " + numRmm_r + " < " + numRmm, numRmm_r < numRmm);
 		} finally {
 			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = old_simplification;
 			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = old_sum_product;
diff --git a/src/test/scripts/functions/async/LineageReuseSpark2.dml b/src/test/scripts/functions/async/LineageReuseSpark2.dml
new file mode 100644
index 0000000000..22f127c07d
--- /dev/null
+++ b/src/test/scripts/functions/async/LineageReuseSpark2.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y, Double lamda, Integer N) return (Matrix[double] beta)
+{
+  # Reuse sp_tsmm and sp_mapmm if not future-based
+  A = (t(X) %*% X) + diag(matrix(lamda, rows=N, cols=1));
+  b = t(X) %*% y;
+  beta = solve(A, b);
+}
+
+no_lamda = 10;
+
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+lim = 0.1;
+
+X = rand(rows=1500, cols=1500, seed=42);
+y = rand(rows=1500, cols=1, seed=43);
+N = ncol(X);
+R = matrix(0, rows=N, cols=no_lamda+2);
+i = 1;
+
+while (lamda < lim)
+{
+  beta = SimlinRegDS(X, y, lamda, N);
+  #beta = lmDS(X=X, y=y, reg=lamda);
+  R[,i] = beta;
+  lamda = lamda + stp;
+  i = i + 1;
+}
+
+R = sum(R);
+write(R, $1, format="text");
+