You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2022/06/05 21:16:36 UTC
[systemds] branch main updated: [SYSTEMDS-1622] Fix federated left indexing with scalar inputs
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new fa81f6a40a [SYSTEMDS-1622] Fix federated left indexing with scalar inputs
fa81f6a40a is described below
commit fa81f6a40ae15c9f50e00cf8ec96af2626684e4f
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sun Jun 5 23:16:07 2022 +0200
[SYSTEMDS-1622] Fix federated left indexing with scalar inputs
This patch generalizes the federated left indexing instruction
to scalar inputs, and fixes a more general issue in replacing
instruction operands in edge cases where the scalar matches
federated input or output variable names.
Closes #1622.
Co-authored-by: Matthias Boehm <mb...@gmail.com>
---
.../federated/FederatedLookupTable.java | 4 +
.../federated/FederatedWorkerHandler.java | 10 +-
.../controlprogram/federated/FederationMap.java | 5 +
.../controlprogram/federated/FederationUtils.java | 21 ++-
.../instructions/fed/IndexingFEDInstruction.java | 143 +++++++++-----
.../primitives/FederatedLeftIndexTest.java | 205 +++++++++++----------
.../federated/FederatedLeftIndexFrameFullTest.dml | 2 -
.../FederatedLeftIndexFrameFullTestReference.dml | 2 -
.../federated/FederatedLeftIndexFullTest.dml | 2 -
.../FederatedLeftIndexFullTestReference.dml | 2 -
...llTest.dml => FederatedLeftIndexScalarTest.dml} | 17 +-
...l => FederatedLeftIndexScalarTestReference.dml} | 17 +-
12 files changed, 252 insertions(+), 178 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
index 55ab9715e2..afba8ac42a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
@@ -47,6 +47,10 @@ public class FederatedLookupTable {
_lookup_table = new ConcurrentHashMap<>();
}
+ public void clear() {
+ _lookup_table.clear();
+ }
+
/**
* Get the ExecutionContextMap corresponding to the given host and pid of the
* requesting coordinator from the lookup table. Create a new
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 592f77ccce..47cedd739c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -220,8 +220,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
containsCLEAR = true;
}
- if(containsCLEAR)
+ if(containsCLEAR) {
+ _flt.clear();
printStatistics();
+ }
return response;
}
@@ -398,7 +400,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
checkNumParams(request.getNumParams(), 1, 2);
final String varName = String.valueOf(request.getID());
ExecutionContext ec = ecm.get(request.getTID());
-
+
if(ec.containsVariable(varName)) {
final Data tgtData = ec.removeVariable(varName);
if(tgtData != null)
@@ -450,7 +452,6 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
private FederatedResponse getVariable(FederatedRequest request, ExecutionContextMap ecm) {
try{
-
checkNumParams(request.getNumParams(), 0);
ExecutionContext ec = ecm.get(request.getTID());
if(!ec.containsVariable(String.valueOf(request.getID())))
@@ -494,7 +495,8 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
//handle missing spark execution context
//TODO handling of spark instructions should be under control of federated site not coordinator
if(ins.getType() == IType.SPARK
- && !(ec instanceof SparkExecutionContext) ) {
+ && !(ec instanceof SparkExecutionContext) )
+ {
ecm.convertToSparkCtx();
return ecm.get(id);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 0053a8b2fe..fcef0d7984 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -351,6 +351,11 @@ public class FederationMap {
return ret.toArray(new Future[0]);
}
+ public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRange[] fedRange1,
+ FederatedRequest elseFr, FederatedRequest frSlice1, FederatedRequest frSlice2, FederatedRequest fr) {
+ return execute(tid, wait, fedRange1, elseFr, new FederatedRequest[]{frSlice1}, new FederatedRequest[]{frSlice2}, fr);
+ }
+
@SuppressWarnings("unchecked")
public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRange[] fedRange1, FederatedRequest elseFr, FederatedRequest[] frSlices1, FederatedRequest[] frSlices2, FederatedRequest... fr) {
// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 671cd0b744..82f78fe176 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -126,19 +126,26 @@ public class FederationUtils {
String[] linst = inst;
FederatedRequest[] fr = new FederatedRequest[inst.length];
for(int j=0; j<inst.length; j++) {
+ linst[j] = InstructionUtils.replaceOperand(linst[j], 0, type == null ?
+ InstructionUtils.getExecType(linst[j]).name() : type.name());
+ // replace inputs before outputs in order to prevent conflicts
+ // on outputId matching input literals (due to a mix of input instructions,
+ // have to apply this replacement even for literal inputs)
for(int i = 0; i < varOldIn.length; i++) {
- linst[j] = InstructionUtils.replaceOperand(linst[j], 0, type == null ? InstructionUtils.getExecType(linst[j]).name() : type.name());
- linst[j] = linst[j].replace(
- Lop.OPERAND_DELIMITOR + varOldOut.getName() + Lop.DATATYPE_PREFIX,
- Lop.OPERAND_DELIMITOR + String.valueOf(outputId) + Lop.DATATYPE_PREFIX);
-
- if(varOldIn[i] != null) {
+ if( varOldIn[i] != null ) {
linst[j] = linst[j].replace(
Lop.OPERAND_DELIMITOR + varOldIn[i].getName() + Lop.DATATYPE_PREFIX,
Lop.OPERAND_DELIMITOR + String.valueOf(varNewIn[i]) + Lop.DATATYPE_PREFIX);
- linst[j] = linst[j].replace("=" + varOldIn[i].getName(), "=" + String.valueOf(varNewIn[i])); //parameterized
+ // handle parameterized builtin functions
+ linst[j] = linst[j].replace("=" + varOldIn[i].getName(), "=" + String.valueOf(varNewIn[i]));
}
}
+ for(int i = 0; i < varOldIn.length; i++) {
+ linst[j] = linst[j].replace(
+ Lop.OPERAND_DELIMITOR + varOldOut.getName() + Lop.DATATYPE_PREFIX,
+ Lop.OPERAND_DELIMITOR + String.valueOf(outputId) + Lop.DATATYPE_PREFIX);
+ }
+
fr[j] = new FederatedRequest(RequestType.EXEC_INST, outputId, (Object) linst[j]);
}
return fr;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index bc70b398f9..4e4448ba97 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -25,8 +25,10 @@ import java.util.Collections;
import java.util.List;
import java.util.Objects;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.LeftIndex;
@@ -44,6 +46,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;
@@ -150,7 +153,7 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
List<Types.ValueType> schema = new ArrayList<>();
// replace old reshape values for each worker
int i = 0;
- for(org.apache.commons.lang3.tuple.Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
+ for(Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
FederatedRange range = e.getKey();
long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
@@ -204,7 +207,8 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
{
//get input and requested index range
CacheableData<?> in1 = ec.getCacheableData(input1);
- CacheableData<?> in2 = ec.getCacheableData(input2);
+ CacheableData<?> in2 = null; // either in2 or scalar is set
+ ScalarObject scalar = null;
IndexRange ixrange = getIndexRange(ec);
//check bounds
@@ -213,11 +217,21 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
throw new DMLRuntimeException("Invalid values for matrix indexing: ["+(ixrange.rowStart+1)+":"+(ixrange.rowEnd+1)+","
+ (ixrange.colStart+1)+":"+(ixrange.colEnd+1)+"] " + "must be within matrix dimensions ["+in1.getNumRows()+","+in1.getNumColumns()+"].");
}
- if( (ixrange.rowEnd-ixrange.rowStart+1) != in2.getNumRows() || (ixrange.colEnd-ixrange.colStart+1) != in2.getNumColumns()) {
- throw new DMLRuntimeException("Invalid values for matrix indexing: " +
- "dimensions of the source matrix ["+in2.getNumRows()+"x" + in2.getNumColumns() + "] " +
- "do not match the shape of the matrix specified by indices [" +
- (ixrange.rowStart+1) +":" + (ixrange.rowEnd+1) + ", " + (ixrange.colStart+1) + ":" + (ixrange.colEnd+1) + "].");
+
+ if(input2.getDataType() == DataType.SCALAR) {
+ if(!ixrange.isScalar())
+ throw new DMLRuntimeException("Invalid index range for leftindexing with scalar: " + ixrange.toString() + ".");
+
+ scalar = ec.getScalarInput(input2);
+ }
+ else {
+ in2 = ec.getCacheableData(input2);
+ if( (ixrange.rowEnd-ixrange.rowStart+1) != in2.getNumRows() || (ixrange.colEnd-ixrange.colStart+1) != in2.getNumColumns()) {
+ throw new DMLRuntimeException("Invalid values for matrix indexing: " +
+ "dimensions of the source matrix ["+in2.getNumRows()+"x" + in2.getNumColumns() + "] " +
+ "do not match the shape of the matrix specified by indices [" +
+ (ixrange.rowStart+1) +":" + (ixrange.rowEnd+1) + ", " + (ixrange.colStart+1) + ":" + (ixrange.colEnd+1) + "].");
+ }
}
FederationMap fedMap = in1.getFedMapping();
@@ -226,9 +240,13 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
int[][] sliceIxs = new int[fedMap.getSize()][];
FederatedRange[] ranges = new FederatedRange[fedMap.getSize()];
+ // instruction string for copying a partition at the federated site
+ int cpVarInstIx = fedMap.getSize();
+ String cpVarInstString = createCopyInstString();
+
// replace old reshape values for each worker
int i = 0, prev = 0, from = fedMap.getSize();
- for(org.apache.commons.lang3.tuple.Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
+ for(Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
FederatedRange range = e.getKey();
long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
@@ -239,29 +257,46 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
long[] newIx = new long[]{(int) rsn, (int) ren, (int) csn, (int) cen};
- // find ranges where to apply leftIndex
- long to;
- if(in1.isFederated(FType.ROW) && (to = (prev + ren - rsn)) >= 0 &&
- to < in2.getNumRows() && ixrange.rowStart <= re) {
- sliceIxs[i] = new int[] { prev, (int) to, 0, (int) in2.getNumColumns()-1};
- prev = (int) (to + 1);
-
- instStrings[i] = modifyIndices(newIx, 4, 8);
- ranges[i] = range;
- from = Math.min(i, from);
+ if(in2 != null) { // matrix, frame
+ // find ranges where to apply leftIndex
+ long to;
+ if(in1.isFederated(FType.ROW) && (to = (prev + ren - rsn)) >= 0 &&
+ to < in2.getNumRows() && ixrange.rowStart <= re) {
+ sliceIxs[i] = new int[] { prev, (int) to, 0, (int) in2.getNumColumns()-1};
+ prev = (int) (to + 1);
+
+ instStrings[i] = modifyIndices(newIx, 4, 8);
+ ranges[i] = range;
+ from = Math.min(i, from);
+ }
+ else if(in1.isFederated(FType.COL) && (to = (prev + cen - csn)) >= 0 &&
+ to < in2.getNumColumns() && ixrange.colStart <= ce) {
+ sliceIxs[i] = new int[] {0, (int) in2.getNumRows() - 1, prev, (int) to};
+ prev = (int) (to + 1);
+
+ instStrings[i] = modifyIndices(newIx, 4, 8);
+ ranges[i] = range;
+ from = Math.min(i, from);
+ }
+ else {
+ // TODO shallow copy, add more advanced update in place for federated
+ cpVarInstIx = Math.min(i, cpVarInstIx);
+ instStrings[i] = cpVarInstString;
+ }
}
- else if(in1.isFederated(FType.COL) && (to = (prev + cen - csn)) >= 0 &&
- to < in2.getNumColumns() && ixrange.colStart <= ce) {
- sliceIxs[i] = new int[] {0, (int) in2.getNumRows() - 1, prev, (int) to};
- prev = (int) (to + 1);
-
- instStrings[i] = modifyIndices(newIx, 4, 8);
- ranges[i] = range;
- from = Math.min(i, from);
+ else { // scalar
+ if(ixrange.rowStart >= rs && ixrange.rowEnd < re
+ && ixrange.colStart >= cs && ixrange.colEnd < ce) {
+ instStrings[i] = modifyIndices(newIx, 4, 8);
+ instStrings[i] = changeScalarLiteralFlag(instStrings[i], 3);
+ ranges[i] = range;
+ from = Math.min(i, from);
+ }
+ else {
+ cpVarInstIx = Math.min(i, cpVarInstIx);
+ instStrings[i] = cpVarInstString;
+ }
}
- else
- // TODO shallow copy, add more advanced update in place for federated
- instStrings[i] = createCopyInstString();
i++;
}
@@ -269,35 +304,44 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
sliceIxs = Arrays.stream(sliceIxs).filter(Objects::nonNull).toArray(int[][] :: new);
long id = FederationUtils.getNextFedDataID();
+ //TODO remove explicit put (unnecessary in CP, only spark which is about to be cleaned up)
FederatedRequest tmp = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), in1.getDataType());
fedMap.execute(getTID(), true, tmp);
- FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, DMLScript.LINEAGE ? ec.getLineageItem(input2) : null,
- input2.isFrame(), sliceIxs);
- FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2},
- new long[]{fedMap.getID(), fr1[0].getID()}, null);
- FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1[0].getID());
+ if(in2 != null) { // matrix, frame
+ FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, DMLScript.LINEAGE ? ec.getLineageItem(input2) : null,
+ input2.isFrame(), sliceIxs);
+ FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2},
+ new long[]{fedMap.getID(), fr1[0].getID()}, null);
+ FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1[0].getID());
- //execute federated instruction and cleanup intermediates
- if(sliceIxs.length == fedMap.getSize())
- fedMap.execute(getTID(), true, fr2, fr1, fr3);
- else {
- // get index of cpvar request
- for(i = 0; i < fr2.length; i++)
- if(i < from || i >= from + sliceIxs.length)
- break;
- fedMap.execute(getTID(), true, ranges, (fr2[i]), Arrays.copyOfRange(fr2, from, from + sliceIxs.length), fr1, fr3);
+ //execute federated instruction and cleanup intermediates
+ if(sliceIxs.length == fedMap.getSize())
+ fedMap.execute(getTID(), true, fr2, fr1, fr3);
+ else
+ fedMap.execute(getTID(), true, ranges, fr2[cpVarInstIx], Arrays.copyOfRange(fr2, from, from + sliceIxs.length), fr1, fr3);
+ }
+ else { // scalar
+ FederatedRequest fr1 = fedMap.broadcast(scalar);
+ FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2},
+ new long[]{fedMap.getID(), fr1.getID()}, null);
+ FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1.getID());
+
+ if(fr2.length == 1)
+ fedMap.execute(getTID(), true, fr2, fr1, fr3);
+ else
+ fedMap.execute(getTID(), true, ranges, fr2[cpVarInstIx], fr2[from], fr1, fr3);
}
if(input1.isFrame()) {
FrameObject out = ec.getFrameObject(output);
out.setSchema(((FrameObject) in1).getSchema());
out.getDataCharacteristics().set(in1.getDataCharacteristics());
- out.setFedMapping(fedMap.copyWithNewID(fr2[0].getID()));
+ out.setFedMapping(fedMap.copyWithNewID(id));
} else {
MatrixObject out = ec.getMatrixObject(output);
- out.getDataCharacteristics().set(in1.getDataCharacteristics());;
- out.setFedMapping(fedMap.copyWithNewID(fr2[0].getID()));
+ out.getDataCharacteristics().set(in1.getDataCharacteristics());
+ out.setFedMapping(fedMap.copyWithNewID(id));
}
}
@@ -309,6 +353,13 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
return String.join(Lop.OPERAND_DELIMITOR, instParts);
}
+ private String changeScalarLiteralFlag(String inst, int partIx) {
+ // change the literal flag of the broadcast scalar
+ String[] instParts = inst.split(Lop.OPERAND_DELIMITOR);
+ instParts[partIx] = instParts[partIx].replace("true", "false");
+ return String.join(Lop.OPERAND_DELIMITOR, instParts);
+ }
+
private String createCopyInstString() {
String[] instParts = instString.split(Lop.OPERAND_DELIMITOR);
return VariableCPInstruction.prepareCopyInstruction(instParts[2], instParts[8]).toString();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
index 3c337286a3..6686bdd298 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.federated.primitives;
import java.util.Arrays;
import java.util.Collection;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.HDFSTool;
@@ -40,6 +39,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {
private final static String TEST_NAME1 = "FederatedLeftIndexFullTest";
private final static String TEST_NAME2 = "FederatedLeftIndexFrameFullTest";
+ private final static String TEST_NAME3 = "FederatedLeftIndexScalarTest";
private final static String TEST_DIR = "functions/federated/";
private static final String TEST_CLASS_DIR = TEST_DIR + FederatedLeftIndexTest.class.getSimpleName() + "/";
@@ -81,7 +81,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {
}
private enum DataType {
- MATRIX, FRAME
+ MATRIX, FRAME, SCALAR
}
@Override
@@ -89,6 +89,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
}
@Test
@@ -102,108 +103,122 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {
}
@Test
- public void testLeftIndexFullDenseMatrixSP() { runAggregateOperationTest(DataType.MATRIX, ExecMode.SPARK); }
+ public void testLeftIndexFullDenseMatrixSP() {
+ runAggregateOperationTest(DataType.MATRIX, ExecMode.SPARK);
+ }
@Test
public void testLeftIndexFullDenseFrameSP() {
runAggregateOperationTest(DataType.FRAME, ExecMode.SPARK);
}
- private void runAggregateOperationTest(DataType dataType, ExecMode execMode) {
- setExecMode(execMode);
-
- String TEST_NAME = null;
-
- if(dataType == DataType.MATRIX)
- TEST_NAME = TEST_NAME1;
- else
- TEST_NAME = TEST_NAME2;
-
+ @Test
+ public void testLeftIndexScalarCP() {
+ runAggregateOperationTest(DataType.SCALAR, ExecMode.SINGLE_NODE);
+ }
- getAndLoadTestConfiguration(TEST_NAME);
- String HOME = SCRIPT_DIR + TEST_DIR;
+ @Test
+ public void testLeftIndexScalarSP() {
+ runAggregateOperationTest(DataType.SCALAR, ExecMode.SPARK);
+ }
- // write input matrices
- int r1 = rows1;
- int c1 = cols1 / 4;
- if(rowPartitioned) {
- r1 = rows1 / 4;
- c1 = cols1;
+ private void runAggregateOperationTest(DataType dataType, ExecMode execMode) {
+ ExecMode oldPlatform = setExecMode(execMode);
+
+ try {
+ String TEST_NAME = null;
+
+ if(dataType == DataType.MATRIX)
+ TEST_NAME = TEST_NAME1;
+ else if(dataType == DataType.FRAME)
+ TEST_NAME = TEST_NAME2;
+ else
+ TEST_NAME = TEST_NAME3;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int r1 = rows1;
+ int c1 = cols1 / 4;
+ if(rowPartitioned) {
+ r1 = rows1 / 4;
+ c1 = cols1;
+ }
+
+ double[][] X1 = getRandomMatrix(r1, c1, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r1, c1, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r1, c1, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r1, c1, 1, 5, 1, 9);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r1, c1, blocksize, r1 * c1);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+ writeInputMatrixWithMTD("X2", X2, false, mc);
+ writeInputMatrixWithMTD("X3", X3, false, mc);
+ writeInputMatrixWithMTD("X4", X4, false, mc);
+
+ if(dataType != DataType.SCALAR) {
+ double[][] Y = getRandomMatrix(rows2, cols2, 1, 5, 1, 3);
+
+ MatrixCharacteristics mc2 = new MatrixCharacteristics(rows2, cols2, blocksize, rows2 * cols2);
+ writeInputMatrixWithMTD("Y", Y, false, mc2);
+ }
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ int port3 = getRandomAvailablePort();
+ int port4 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+ Thread t4 = startLocalFedWorkerThread(port4);
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ var lfrom = Math.min(from, to);
+ var lfrom2 = Math.min(from2, to2);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ input("Y"), String.valueOf(lfrom), String.valueOf(to),
+ String.valueOf(lfrom2), String.valueOf(to2),
+ Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+ runTest(null);
+ // Run actual dml script with federated matrix
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+ "in_Y=" + input("Y"), "rows=" + rows1, "cols=" + cols1,
+ "rows2=" + rows2, "cols2=" + cols2,
+ "from=" + from, "to=" + to,"from2=" + from2, "to2=" + to2,
+ "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+
+ runTest(null);
+
+ // compare via files
+ compareResults(1e-9, "Stat-DML1", "Stat-DML2");
+
+ Assert.assertTrue(rtplatform ==ExecMode.SPARK ?
+ heavyHittersContainsString("fed_mapLeftIndex") : heavyHittersContainsString("fed_leftIndex"));
+
+ // check that federated input files still exist
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
}
-
- double[][] X1 = getRandomMatrix(r1, c1, 1, 5, 1, 3);
- double[][] X2 = getRandomMatrix(r1, c1, 1, 5, 1, 7);
- double[][] X3 = getRandomMatrix(r1, c1, 1, 5, 1, 8);
- double[][] X4 = getRandomMatrix(r1, c1, 1, 5, 1, 9);
-
- MatrixCharacteristics mc = new MatrixCharacteristics(r1, c1, blocksize, r1 * c1);
- writeInputMatrixWithMTD("X1", X1, false, mc);
- writeInputMatrixWithMTD("X2", X2, false, mc);
- writeInputMatrixWithMTD("X3", X3, false, mc);
- writeInputMatrixWithMTD("X4", X4, false, mc);
-
- double[][] Y = getRandomMatrix(rows2, cols2, 1, 5, 1, 3);
-
- MatrixCharacteristics mc2 = new MatrixCharacteristics(rows2, cols2, blocksize, rows2 * cols2);
- writeInputMatrixWithMTD("Y", Y, false, mc2);
-
- // empty script name because we don't execute any script, just start the worker
- fullDMLScriptName = "";
- int port1 = getRandomAvailablePort();
- int port2 = getRandomAvailablePort();
- int port3 = getRandomAvailablePort();
- int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
-
- rtplatform = execMode;
- if(rtplatform == ExecMode.SPARK) {
- System.out.println(7);
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ finally {
+ resetExecMode(oldPlatform);
}
- TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
- loadTestConfiguration(config);
-
- if(from > to)
- from = to;
- if(from2 > to2)
- from2 = to2;
-
- // Run reference dml script with normal matrix
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-explain", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
- input("Y"), String.valueOf(from), String.valueOf(to),
- String.valueOf(from2), String.valueOf(to2),
- Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
- runTest(null);
- // Run actual dml script with federated matrix
-
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs",
- "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
- "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
- "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
- "in_Y=" + input("Y"), "rows=" + rows1, "cols=" + cols1,
- "rows2=" + rows2, "cols2=" + cols2,
- "from=" + from, "to=" + to,"from2=" + from2, "to2=" + to2,
- "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
-
- runTest(null);
-
- // compare via files
- compareResults(1e-9, "Stat-DML1", "Stat-DML2");
-
- Assert.assertTrue(rtplatform ==ExecMode.SPARK ? heavyHittersContainsString("fed_mapLeftIndex") : heavyHittersContainsString("fed_leftIndex"));
-
- // check that federated input files are still existing
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
-
- TestUtils.shutdownThreads(t1, t2, t3, t4);
}
}
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
index ca9fe81f40..a10bb72f77 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
@@ -41,5 +41,3 @@ A = as.frame(A)
A[from:to, from2:to2] = B;
write(A, $out_S);
-
-print(toString(A))
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
index 4b5a85234c..6589134273 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
@@ -37,5 +37,3 @@ A = as.frame(A)
A[from:to, from2:to2] = B;
write(A, $11);
-
-print(toString(A))
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml
index a201f7bfe3..c048cb77c2 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml
@@ -38,5 +38,3 @@ B = read($in_Y)
A[from:to, from2:to2] = B;
write(A, $out_S);
-
-print(toString(A))
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml
index 2cc29f7ca8..ecd123254e 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml
@@ -34,5 +34,3 @@ B = read($5)
A[from:to, from2:to2] = B;
write(A, $11);
-
-print(toString(A))
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml
similarity index 90%
copy from src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
copy to src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml
index ca9fe81f40..71a9f93490 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml
@@ -19,10 +19,10 @@
#
#-------------------------------------------------------------
-from = $from;
-to = $to;
-from2 = $from2;
-to2 = $to2;
+row1 = $from;
+row2 = $to;
+col1 = $from2;
+col2 = $to2;
if ($rP) {
A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
@@ -34,12 +34,11 @@ if ($rP) {
list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
}
-B = read($in_Y)
+b = 13;
+c = as.scalar(rand(rows=1, cols=1, seed=456));
-B = as.frame(B)
-A = as.frame(A)
+A[row1, col1] = b;
+A[row2, col2] = c;
-A[from:to, from2:to2] = B;
write(A, $out_S);
-print(toString(A))
diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml
similarity index 88%
copy from src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
copy to src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml
index 4b5a85234c..14ea17fbda 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml
@@ -19,10 +19,10 @@
#
#-------------------------------------------------------------
-from = $6;
-to = $7;
-from2 = $8;
-to2 = $9;
+row1 = $6;
+row2 = $7;
+col1 = $8;
+col2 = $9;
if($10) {
A = rbind(read($1), read($2), read($3), read($4));
}
@@ -30,12 +30,11 @@ else {
A = cbind(read($1), read($2), read($3), read($4));
}
-B = read($5)
+b = 13;
+c = as.scalar(rand(rows=1, cols=1, seed=456));
-B = as.frame(B)
-A = as.frame(A)
+A[row1, col1] = b;
+A[row2, col2] = c;
-A[from:to, from2:to2] = B;
write(A, $11);
-print(toString(A))