You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2022/06/05 21:16:36 UTC

[systemds] branch main updated: [SYSTEMDS-1622] Fix federated left indexing with scalar inputs

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new fa81f6a40a [SYSTEMDS-1622] Fix federated left indexing with scalar inputs
fa81f6a40a is described below

commit fa81f6a40ae15c9f50e00cf8ec96af2626684e4f
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sun Jun 5 23:16:07 2022 +0200

    [SYSTEMDS-1622] Fix federated left indexing with scalar inputs
    
    This patch generalizes the federated left indexing instruction
    to scalar inputs, and fixes a more general issue of replacing
    instruction operands in edge cases where the scalar matches
    federated input or output variable names.
    
    Closes #1622.
    
    Co-authored-by: Matthias Boehm <mb...@gmail.com>
---
 .../federated/FederatedLookupTable.java            |   4 +
 .../federated/FederatedWorkerHandler.java          |  10 +-
 .../controlprogram/federated/FederationMap.java    |   5 +
 .../controlprogram/federated/FederationUtils.java  |  21 ++-
 .../instructions/fed/IndexingFEDInstruction.java   | 143 +++++++++-----
 .../primitives/FederatedLeftIndexTest.java         | 205 +++++++++++----------
 .../federated/FederatedLeftIndexFrameFullTest.dml  |   2 -
 .../FederatedLeftIndexFrameFullTestReference.dml   |   2 -
 .../federated/FederatedLeftIndexFullTest.dml       |   2 -
 .../FederatedLeftIndexFullTestReference.dml        |   2 -
 ...llTest.dml => FederatedLeftIndexScalarTest.dml} |  17 +-
 ...l => FederatedLeftIndexScalarTestReference.dml} |  17 +-
 12 files changed, 252 insertions(+), 178 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
index 55ab9715e2..afba8ac42a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
@@ -47,6 +47,10 @@ public class FederatedLookupTable {
 		_lookup_table = new ConcurrentHashMap<>();
 	}
 
+	public void clear() {
+		_lookup_table.clear();
+	}
+	
 	/**
 	 * Get the ExecutionContextMap corresponding to the given host and pid of the
 	 * requesting coordinator from the lookup table. Create a new
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 592f77ccce..47cedd739c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -220,8 +220,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 				containsCLEAR = true;
 		}
 
-		if(containsCLEAR)
+		if(containsCLEAR) {
+			_flt.clear();
 			printStatistics();
+		}
 
 		return response;
 	}
@@ -398,7 +400,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 		checkNumParams(request.getNumParams(), 1, 2);
 		final String varName = String.valueOf(request.getID());
 		ExecutionContext ec = ecm.get(request.getTID());
-
+		
 		if(ec.containsVariable(varName)) {
 			final Data tgtData = ec.removeVariable(varName);
 			if(tgtData != null)
@@ -450,7 +452,6 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 
 	private FederatedResponse getVariable(FederatedRequest request, ExecutionContextMap ecm) {
 		try{
-
 			checkNumParams(request.getNumParams(), 0);
 			ExecutionContext ec = ecm.get(request.getTID());
 			if(!ec.containsVariable(String.valueOf(request.getID())))
@@ -494,7 +495,8 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 		//handle missing spark execution context
 		//TODO handling of spark instructions should be under control of federated site not coordinator
 		if(ins.getType() == IType.SPARK
-		&& !(ec instanceof SparkExecutionContext) ) {
+			&& !(ec instanceof SparkExecutionContext) )
+		{
 			ecm.convertToSparkCtx();
 			return ecm.get(id);
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 0053a8b2fe..fcef0d7984 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -351,6 +351,11 @@ public class FederationMap {
 		return ret.toArray(new Future[0]);
 	}
 
+	public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRange[] fedRange1,
+		FederatedRequest elseFr, FederatedRequest frSlice1, FederatedRequest frSlice2, FederatedRequest fr) {
+		return execute(tid, wait, fedRange1, elseFr, new FederatedRequest[]{frSlice1}, new FederatedRequest[]{frSlice2}, fr);
+	}
+
 	@SuppressWarnings("unchecked")
 	public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRange[] fedRange1, FederatedRequest elseFr, FederatedRequest[] frSlices1, FederatedRequest[] frSlices2, FederatedRequest... fr) {
 		// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 671cd0b744..82f78fe176 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -126,19 +126,26 @@ public class FederationUtils {
 		String[] linst = inst;
 		FederatedRequest[] fr = new FederatedRequest[inst.length];
 		for(int j=0; j<inst.length; j++) {
+			linst[j] = InstructionUtils.replaceOperand(linst[j], 0, type == null ?
+				InstructionUtils.getExecType(linst[j]).name() : type.name());
+			// replace inputs before outputs in order to prevent conflicts
+			// with an outputId that matches an input literal (because instructions
+			// mix input types, this replacement must also be applied to literal inputs)
 			for(int i = 0; i < varOldIn.length; i++) {
-				linst[j] = InstructionUtils.replaceOperand(linst[j], 0, type == null ? InstructionUtils.getExecType(linst[j]).name() : type.name());
-				linst[j] = linst[j].replace(
-					Lop.OPERAND_DELIMITOR + varOldOut.getName() + Lop.DATATYPE_PREFIX,
-					Lop.OPERAND_DELIMITOR + String.valueOf(outputId) + Lop.DATATYPE_PREFIX);
-
-				if(varOldIn[i] != null) {
+				if( varOldIn[i] != null ) {
 					linst[j] = linst[j].replace(
 						Lop.OPERAND_DELIMITOR + varOldIn[i].getName() + Lop.DATATYPE_PREFIX,
 						Lop.OPERAND_DELIMITOR + String.valueOf(varNewIn[i]) + Lop.DATATYPE_PREFIX);
-					linst[j] = linst[j].replace("=" + varOldIn[i].getName(), "=" + String.valueOf(varNewIn[i])); //parameterized
+					// handle parameterized builtin functions
+					linst[j] = linst[j].replace("=" + varOldIn[i].getName(), "=" + String.valueOf(varNewIn[i]));
 				}
 			}
+			for(int i = 0; i < varOldIn.length; i++) {
+				linst[j] = linst[j].replace(
+					Lop.OPERAND_DELIMITOR + varOldOut.getName() + Lop.DATATYPE_PREFIX,
+					Lop.OPERAND_DELIMITOR + String.valueOf(outputId) + Lop.DATATYPE_PREFIX);
+			}
+			
 			fr[j] = new FederatedRequest(RequestType.EXEC_INST, outputId, (Object) linst[j]);
 		}
 		return fr;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index bc70b398f9..4e4448ba97 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -25,8 +25,10 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
 import org.apache.sysds.lops.LeftIndex;
@@ -44,6 +46,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.IndexRange;
@@ -150,7 +153,7 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
 		List<Types.ValueType> schema = new ArrayList<>();
 		// replace old reshape values for each worker
 		int i = 0;
-		for(org.apache.commons.lang3.tuple.Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
+		for(Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
 			FederatedRange range = e.getKey();
 			long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
 				cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
@@ -204,7 +207,8 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
 	{
 		//get input and requested index range
 		CacheableData<?> in1 = ec.getCacheableData(input1);
-		CacheableData<?> in2 = ec.getCacheableData(input2);
+		CacheableData<?> in2 = null; // either in2 or scalar is set
+		ScalarObject scalar = null;
 		IndexRange ixrange = getIndexRange(ec);
 
 		//check bounds
@@ -213,11 +217,21 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
 			throw new DMLRuntimeException("Invalid values for matrix indexing: ["+(ixrange.rowStart+1)+":"+(ixrange.rowEnd+1)+","
 				+ (ixrange.colStart+1)+":"+(ixrange.colEnd+1)+"] " + "must be within matrix dimensions ["+in1.getNumRows()+","+in1.getNumColumns()+"].");
 		}
-		if( (ixrange.rowEnd-ixrange.rowStart+1) != in2.getNumRows() || (ixrange.colEnd-ixrange.colStart+1) != in2.getNumColumns()) {
-			throw new DMLRuntimeException("Invalid values for matrix indexing: " +
-				"dimensions of the source matrix ["+in2.getNumRows()+"x" + in2.getNumColumns() + "] " +
-				"do not match the shape of the matrix specified by indices [" +
-				(ixrange.rowStart+1) +":" + (ixrange.rowEnd+1) + ", " + (ixrange.colStart+1) + ":" + (ixrange.colEnd+1) + "].");
+
+		if(input2.getDataType() == DataType.SCALAR) {
+			if(!ixrange.isScalar())
+				throw new DMLRuntimeException("Invalid index range for leftindexing with scalar: " + ixrange.toString() + ".");
+
+			scalar = ec.getScalarInput(input2);
+		}
+		else {
+			in2 = ec.getCacheableData(input2);
+			if( (ixrange.rowEnd-ixrange.rowStart+1) != in2.getNumRows() || (ixrange.colEnd-ixrange.colStart+1) != in2.getNumColumns()) {
+				throw new DMLRuntimeException("Invalid values for matrix indexing: " +
+					"dimensions of the source matrix ["+in2.getNumRows()+"x" + in2.getNumColumns() + "] " +
+					"do not match the shape of the matrix specified by indices [" +
+					(ixrange.rowStart+1) +":" + (ixrange.rowEnd+1) + ", " + (ixrange.colStart+1) + ":" + (ixrange.colEnd+1) + "].");
+			}
 		}
 
 		FederationMap fedMap = in1.getFedMapping();
@@ -226,9 +240,13 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
 		int[][] sliceIxs = new int[fedMap.getSize()][];
 		FederatedRange[] ranges = new FederatedRange[fedMap.getSize()];
 
+		// instruction string for copying a partition at the federated site
+		int cpVarInstIx = fedMap.getSize();
+		String cpVarInstString = createCopyInstString();
+
 		// replace old reshape values for each worker
 		int i = 0, prev = 0, from = fedMap.getSize();
-		for(org.apache.commons.lang3.tuple.Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
+		for(Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
 			FederatedRange range = e.getKey();
 			long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
 				cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
@@ -239,29 +257,46 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
 
 			long[] newIx = new long[]{(int) rsn, (int) ren, (int) csn, (int) cen};
 
-			// find ranges where to apply  leftIndex
-			long to;
-			if(in1.isFederated(FType.ROW) && (to = (prev + ren - rsn)) >= 0 &&
-				to < in2.getNumRows() && ixrange.rowStart <= re) {
-				sliceIxs[i] = new int[] { prev, (int) to, 0, (int) in2.getNumColumns()-1};
-				prev = (int) (to + 1);
-
-				instStrings[i] = modifyIndices(newIx, 4, 8);
-				ranges[i] = range;
-				from = Math.min(i, from);
+			if(in2 != null) { // matrix, frame
+				// find ranges where to apply leftIndex
+				long to;
+				if(in1.isFederated(FType.ROW) && (to = (prev + ren - rsn)) >= 0 &&
+					to < in2.getNumRows() && ixrange.rowStart <= re) {
+					sliceIxs[i] = new int[] { prev, (int) to, 0, (int) in2.getNumColumns()-1};
+					prev = (int) (to + 1);
+
+					instStrings[i] = modifyIndices(newIx, 4, 8);
+					ranges[i] = range;
+					from = Math.min(i, from);
+				}
+				else if(in1.isFederated(FType.COL) && (to = (prev + cen - csn)) >= 0 &&
+					to < in2.getNumColumns() && ixrange.colStart <= ce) {
+					sliceIxs[i] = new int[] {0, (int) in2.getNumRows() - 1, prev, (int) to};
+					prev = (int) (to + 1);
+
+					instStrings[i] = modifyIndices(newIx, 4, 8);
+					ranges[i] = range;
+					from = Math.min(i, from);
+				}
+				else {
+					// TODO shallow copy, add more advanced update in place for federated
+					cpVarInstIx = Math.min(i, cpVarInstIx);
+					instStrings[i] = cpVarInstString;
+				}
 			}
-			else if(in1.isFederated(FType.COL) && (to = (prev + cen - csn)) >= 0 &&
-				to < in2.getNumColumns() && ixrange.colStart <= ce) {
-				sliceIxs[i] = new int[] {0, (int) in2.getNumRows() - 1, prev, (int) to};
-				prev = (int) (to + 1);
-
-				instStrings[i] = modifyIndices(newIx, 4, 8);
-				ranges[i] = range;
-				from = Math.min(i, from);
+			else { // scalar
+				if(ixrange.rowStart >= rs && ixrange.rowEnd < re
+					&& ixrange.colStart >= cs && ixrange.colEnd < ce) {
+					instStrings[i] = modifyIndices(newIx, 4, 8);
+					instStrings[i] = changeScalarLiteralFlag(instStrings[i], 3);
+					ranges[i] = range;
+					from = Math.min(i, from);
+				}
+				else {
+					cpVarInstIx = Math.min(i, cpVarInstIx);
+					instStrings[i] = cpVarInstString;
+				}
 			}
-			else
-				// TODO shallow copy, add more advanced update in place for federated
-				instStrings[i] = createCopyInstString();
 
 			i++;
 		}
@@ -269,35 +304,44 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
 		sliceIxs = Arrays.stream(sliceIxs).filter(Objects::nonNull).toArray(int[][] :: new);
 
 		long id = FederationUtils.getNextFedDataID();
+		//TODO remove explicit put (unnecessary in CP; only needed for Spark, which is about to be cleaned up)
 		FederatedRequest tmp = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), in1.getDataType());
 		fedMap.execute(getTID(), true, tmp);
 
-		FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, DMLScript.LINEAGE ? ec.getLineageItem(input2) : null,
-			input2.isFrame(), sliceIxs);
-		FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2},
-			new long[]{fedMap.getID(), fr1[0].getID()}, null);
-		FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1[0].getID());
+		if(in2 != null) { // matrix, frame
+			FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, DMLScript.LINEAGE ? ec.getLineageItem(input2) : null,
+				input2.isFrame(), sliceIxs);
+			FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2},
+				new long[]{fedMap.getID(), fr1[0].getID()}, null);
+			FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1[0].getID());
 
-		//execute federated instruction and cleanup intermediates
-		if(sliceIxs.length == fedMap.getSize())
-			fedMap.execute(getTID(), true, fr2, fr1, fr3);
-		else {
-			// get index of cpvar request
-			for(i = 0; i < fr2.length; i++)
-				if(i < from || i >= from + sliceIxs.length)
-					break;
-			fedMap.execute(getTID(), true, ranges, (fr2[i]), Arrays.copyOfRange(fr2, from, from + sliceIxs.length), fr1, fr3);
+			//execute federated instruction and cleanup intermediates
+			if(sliceIxs.length == fedMap.getSize())
+				fedMap.execute(getTID(), true, fr2, fr1, fr3);
+			else
+				fedMap.execute(getTID(), true, ranges, fr2[cpVarInstIx], Arrays.copyOfRange(fr2, from, from + sliceIxs.length), fr1, fr3);
+		}
+		else { // scalar
+			FederatedRequest fr1 = fedMap.broadcast(scalar);
+			FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2},
+				new long[]{fedMap.getID(), fr1.getID()}, null);
+			FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1.getID());
+
+			if(fr2.length == 1)
+				fedMap.execute(getTID(), true, fr2, fr1, fr3);
+			else
+				fedMap.execute(getTID(), true, ranges, fr2[cpVarInstIx], fr2[from], fr1, fr3);
 		}
 
 		if(input1.isFrame()) {
 			FrameObject out = ec.getFrameObject(output);
 			out.setSchema(((FrameObject) in1).getSchema());
 			out.getDataCharacteristics().set(in1.getDataCharacteristics());
-			out.setFedMapping(fedMap.copyWithNewID(fr2[0].getID()));
+			out.setFedMapping(fedMap.copyWithNewID(id));
 		} else {
 			MatrixObject out = ec.getMatrixObject(output);
-			out.getDataCharacteristics().set(in1.getDataCharacteristics());;
-			out.setFedMapping(fedMap.copyWithNewID(fr2[0].getID()));
+			out.getDataCharacteristics().set(in1.getDataCharacteristics());
+			out.setFedMapping(fedMap.copyWithNewID(id));
 		}
 	}
 
@@ -309,6 +353,13 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
 		return String.join(Lop.OPERAND_DELIMITOR, instParts);
 	}
 
+	private String changeScalarLiteralFlag(String inst, int partIx) {
+		// change the literal flag of the broadcast scalar
+		String[] instParts = inst.split(Lop.OPERAND_DELIMITOR);
+		instParts[partIx] = instParts[partIx].replace("true", "false");
+		return String.join(Lop.OPERAND_DELIMITOR, instParts);
+	}
+
 	private String createCopyInstString() {
 		String[] instParts = instString.split(Lop.OPERAND_DELIMITOR);
 		return VariableCPInstruction.prepareCopyInstruction(instParts[2], instParts[8]).toString();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
index 3c337286a3..6686bdd298 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.federated.primitives;
 import java.util.Arrays;
 import java.util.Collection;
 
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.HDFSTool;
@@ -40,6 +39,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {
 
 	private final static String TEST_NAME1 = "FederatedLeftIndexFullTest";
 	private final static String TEST_NAME2 = "FederatedLeftIndexFrameFullTest";
+	private final static String TEST_NAME3 = "FederatedLeftIndexScalarTest";
 
 	private final static String TEST_DIR = "functions/federated/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + FederatedLeftIndexTest.class.getSimpleName() + "/";
@@ -81,7 +81,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {
 	}
 
 	private enum DataType {
-		MATRIX, FRAME
+		MATRIX, FRAME, SCALAR
 	}
 
 	@Override
@@ -89,6 +89,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {
 		TestUtils.clearAssertionInformation();
 		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
 		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+		addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
 	}
 
 	@Test
@@ -102,108 +103,122 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {
 	}
 
 	@Test
-	public void testLeftIndexFullDenseMatrixSP() { runAggregateOperationTest(DataType.MATRIX, ExecMode.SPARK); }
+	public void testLeftIndexFullDenseMatrixSP() {
+		runAggregateOperationTest(DataType.MATRIX, ExecMode.SPARK);
+	}
 
 	@Test
 	public void testLeftIndexFullDenseFrameSP() {
 		runAggregateOperationTest(DataType.FRAME, ExecMode.SPARK);
 	}
 
-	private void runAggregateOperationTest(DataType dataType, ExecMode execMode) {
-		setExecMode(execMode);
-
-		String TEST_NAME = null;
-
-		if(dataType == DataType.MATRIX)
-			TEST_NAME = TEST_NAME1;
-		else
-			TEST_NAME = TEST_NAME2;
-
+	@Test
+	public void testLeftIndexScalarCP() {
+		runAggregateOperationTest(DataType.SCALAR, ExecMode.SINGLE_NODE);
+	}
 
-		getAndLoadTestConfiguration(TEST_NAME);
-		String HOME = SCRIPT_DIR + TEST_DIR;
+	@Test
+	public void testLeftIndexScalarSP() {
+		runAggregateOperationTest(DataType.SCALAR, ExecMode.SPARK);
+	}
 
-		// write input matrices
-		int r1 = rows1;
-		int c1 = cols1 / 4;
-		if(rowPartitioned) {
-			r1 = rows1 / 4;
-			c1 = cols1;
+	private void runAggregateOperationTest(DataType dataType, ExecMode execMode) {
+		ExecMode oldPlatform = setExecMode(execMode);
+		
+		try {
+			String TEST_NAME = null;
+	
+			if(dataType == DataType.MATRIX)
+				TEST_NAME = TEST_NAME1;
+			else if(dataType == DataType.FRAME)
+				TEST_NAME = TEST_NAME2;
+			else
+				TEST_NAME = TEST_NAME3;
+	
+			getAndLoadTestConfiguration(TEST_NAME);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+	
+			// write input matrices
+			int r1 = rows1;
+			int c1 = cols1 / 4;
+			if(rowPartitioned) {
+				r1 = rows1 / 4;
+				c1 = cols1;
+			}
+	
+			double[][] X1 = getRandomMatrix(r1, c1, 1, 5, 1, 3);
+			double[][] X2 = getRandomMatrix(r1, c1, 1, 5, 1, 7);
+			double[][] X3 = getRandomMatrix(r1, c1,  1, 5, 1, 8);
+			double[][] X4 = getRandomMatrix(r1, c1, 1, 5, 1, 9);
+	
+			MatrixCharacteristics mc = new MatrixCharacteristics(r1, c1,  blocksize, r1 * c1);
+			writeInputMatrixWithMTD("X1", X1, false, mc);
+			writeInputMatrixWithMTD("X2", X2, false, mc);
+			writeInputMatrixWithMTD("X3", X3, false, mc);
+			writeInputMatrixWithMTD("X4", X4, false, mc);
+	
+			if(dataType != DataType.SCALAR) {
+				double[][] Y = getRandomMatrix(rows2, cols2, 1, 5, 1, 3);
+	
+				MatrixCharacteristics mc2 = new MatrixCharacteristics(rows2, cols2, blocksize, rows2 * cols2);
+				writeInputMatrixWithMTD("Y", Y, false, mc2);
+			}
+	
+			// empty script name because we don't execute any script, just start the worker
+			fullDMLScriptName = "";
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			int port3 = getRandomAvailablePort();
+			int port4 = getRandomAvailablePort();
+			Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+			Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+			Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+			Thread t4 = startLocalFedWorkerThread(port4);
+	
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+	
+			var lfrom = Math.min(from, to);
+			var lfrom2 = Math.min(from2, to2);
+	
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+			programArgs = new String[] {"-args", input("X1"), input("X2"), input("X3"), input("X4"),
+				input("Y"), String.valueOf(lfrom), String.valueOf(to),
+				String.valueOf(lfrom2), String.valueOf(to2),
+				Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+			runTest(null);
+			// Run actual dml script with federated matrix
+	
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-stats", "100", "-nvargs",
+				"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+				"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+				"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+				"in_Y=" + input("Y"), "rows=" + rows1, "cols=" + cols1,
+				"rows2=" + rows2, "cols2=" + cols2,
+				"from=" + from, "to=" + to,"from2=" + from2, "to2=" + to2,
+				"rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+	
+			runTest(null);
+	
+			// compare via files
+			compareResults(1e-9, "Stat-DML1", "Stat-DML2");
+	
+			Assert.assertTrue(rtplatform ==ExecMode.SPARK ?
+				heavyHittersContainsString("fed_mapLeftIndex") : heavyHittersContainsString("fed_leftIndex"));
+	
+			// check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+	
+			TestUtils.shutdownThreads(t1, t2, t3, t4);
 		}
-
-		double[][] X1 = getRandomMatrix(r1, c1, 1, 5, 1, 3);
-		double[][] X2 = getRandomMatrix(r1, c1, 1, 5, 1, 7);
-		double[][] X3 = getRandomMatrix(r1, c1,  1, 5, 1, 8);
-		double[][] X4 = getRandomMatrix(r1, c1, 1, 5, 1, 9);
-
-		MatrixCharacteristics mc = new MatrixCharacteristics(r1, c1,  blocksize, r1 * c1);
-		writeInputMatrixWithMTD("X1", X1, false, mc);
-		writeInputMatrixWithMTD("X2", X2, false, mc);
-		writeInputMatrixWithMTD("X3", X3, false, mc);
-		writeInputMatrixWithMTD("X4", X4, false, mc);
-
-		double[][] Y = getRandomMatrix(rows2, cols2, 1, 5, 1, 3);
-
-		MatrixCharacteristics mc2 = new MatrixCharacteristics(rows2, cols2, blocksize, rows2 * cols2);
-		writeInputMatrixWithMTD("Y", Y, false, mc2);
-
-		// empty script name because we don't execute any script, just start the worker
-		fullDMLScriptName = "";
-		int port1 = getRandomAvailablePort();
-		int port2 = getRandomAvailablePort();
-		int port3 = getRandomAvailablePort();
-		int port4 = getRandomAvailablePort();
-		Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
-		Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
-		Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
-		Thread t4 = startLocalFedWorkerThread(port4);
-
-		rtplatform = execMode;
-		if(rtplatform == ExecMode.SPARK) {
-			System.out.println(7);
-			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		finally {
+			resetExecMode(oldPlatform);
 		}
-		TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
-		loadTestConfiguration(config);
-
-		if(from > to)
-			from = to;
-		if(from2 > to2)
-			from2 = to2;
-
-		// Run reference dml script with normal matrix
-		fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-		programArgs = new String[] {"-explain", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
-			input("Y"), String.valueOf(from), String.valueOf(to),
-			String.valueOf(from2), String.valueOf(to2),
-			Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
-		runTest(null);
-		// Run actual dml script with federated matrix
-
-		fullDMLScriptName = HOME + TEST_NAME + ".dml";
-		programArgs = new String[] {"-stats", "100", "-nvargs",
-			"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
-			"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
-			"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
-			"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
-			"in_Y=" + input("Y"), "rows=" + rows1, "cols=" + cols1,
-			"rows2=" + rows2, "cols2=" + cols2,
-			"from=" + from, "to=" + to,"from2=" + from2, "to2=" + to2,
-			"rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
-
-		runTest(null);
-
-		// compare via files
-		compareResults(1e-9, "Stat-DML1", "Stat-DML2");
-
-		Assert.assertTrue(rtplatform ==ExecMode.SPARK ? heavyHittersContainsString("fed_mapLeftIndex") : heavyHittersContainsString("fed_leftIndex"));
-
-		// check that federated input files are still existing
-		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
-		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
-		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
-		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
-
-		TestUtils.shutdownThreads(t1, t2, t3, t4);
 	}
 }
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
index ca9fe81f40..a10bb72f77 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
@@ -41,5 +41,3 @@ A = as.frame(A)
 
 A[from:to, from2:to2] = B;
 write(A, $out_S);
-
-print(toString(A))
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
index 4b5a85234c..6589134273 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
@@ -37,5 +37,3 @@ A = as.frame(A)
 
 A[from:to, from2:to2] = B;
 write(A, $11);
-
-print(toString(A))
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml
index a201f7bfe3..c048cb77c2 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml
@@ -38,5 +38,3 @@ B = read($in_Y)
 
 A[from:to, from2:to2] = B;
 write(A, $out_S);
-
-print(toString(A))
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml
index 2cc29f7ca8..ecd123254e 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml
@@ -34,5 +34,3 @@ B = read($5)
 
 A[from:to, from2:to2] = B;
 write(A, $11);
-
-print(toString(A))
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml
similarity index 90%
copy from src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
copy to src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml
index ca9fe81f40..71a9f93490 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml
@@ -19,10 +19,10 @@
 #
 #-------------------------------------------------------------
 
-from = $from;
-to = $to;
-from2 = $from2;
-to2 = $to2;
+row1 = $from;
+row2 = $to;
+col1 = $from2;
+col2 = $to2;
 
 if ($rP) {
   A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
@@ -34,12 +34,11 @@ if ($rP) {
     list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
 }
 
-B = read($in_Y)
+b = 13;
+c = as.scalar(rand(rows=1, cols=1, seed=456));
 
-B = as.frame(B)
-A = as.frame(A)
+A[row1, col1] = b;
+A[row2, col2] = c;
 
-A[from:to, from2:to2] = B;
 write(A, $out_S);
 
-print(toString(A))
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml
similarity index 88%
copy from src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
copy to src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml
index 4b5a85234c..14ea17fbda 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml
@@ -19,10 +19,10 @@
 #
 #-------------------------------------------------------------
 
-from = $6;
-to = $7;
-from2 = $8;
-to2 = $9;
+row1 = $6;
+row2 = $7;
+col1 = $8;
+col2 = $9;
 if($10) {
   A = rbind(read($1), read($2), read($3), read($4));
 }
@@ -30,12 +30,11 @@ else {
   A = cbind(read($1), read($2), read($3), read($4));
 }
 
-B = read($5)
+b = 13;
+c = as.scalar(rand(rows=1, cols=1, seed=456));
 
-B = as.frame(B)
-A = as.frame(A)
+A[row1, col1] = b;
+A[row2, col2] = c;
 
-A[from:to, from2:to2] = B;
 write(A, $11);
 
-print(toString(A))