You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2020/11/12 19:57:52 UTC
[systemds] 01/05: [SYSTEMDS-2726] Federated right indexing
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit e2bd5bf4fe90a66f316898208c851f796117a90c
Author: Olga <ov...@gmail.com>
AuthorDate: Tue Nov 10 16:05:30 2020 +0100
[SYSTEMDS-2726] Federated right indexing
---
.../controlprogram/federated/FederatedRange.java | 15 ++
.../controlprogram/federated/FederationMap.java | 6 +-
.../runtime/instructions/fed/FEDInstruction.java | 1 +
.../instructions/fed/FEDInstructionUtils.java | 10 ++
.../instructions/fed/IndexingFEDInstruction.java | 113 ++++++++++++
.../fed/MatrixIndexingFEDInstruction.java | 144 ++++++++++++++++
.../primitives/FederatedRightIndexTest.java | 191 +++++++++++++++++++++
.../federated/FederatedRightIndexFullTest.dml | 36 ++++
.../FederatedRightIndexFullTestReference.dml | 29 ++++
.../federated/FederatedRightIndexLeftTest.dml | 36 ++++
.../FederatedRightIndexLeftTestReference.dml | 29 ++++
.../federated/FederatedRightIndexRightTest.dml | 36 ++++
.../FederatedRightIndexRightTestReference.dml | 29 ++++
13 files changed, 673 insertions(+), 2 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 4289cfe..3bd5734 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -102,6 +102,21 @@ public class FederatedRange implements Comparable<FederatedRange> {
return Arrays.toString(_beginDims) + " - " + Arrays.toString(_endDims);
}
+ @Override public boolean equals(Object o) {
+ if(this == o)
+ return true;
+ if(o == null || getClass() != o.getClass())
+ return false;
+ FederatedRange range = (FederatedRange) o;
+ return Arrays.equals(_beginDims, range._beginDims) && Arrays.equals(_endDims, range._endDims);
+ }
+
+	/** Hash consistent with {@link #equals(Object)}: combines the begin and end coordinates. */
+	@Override
+	public int hashCode() {
+		return 31 * Arrays.hashCode(_beginDims) + Arrays.hashCode(_endDims);
+	}
+
public FederatedRange shift(long rshift, long cshift) {
//row shift
_beginDims[0] += rshift;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 04251fc..b647476 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -224,8 +224,10 @@ public class FederationMap
+	/**
+	 * Returns a copy of this federation map bound to the given variable ID. Each range is
+	 * copied via the FederatedRange copy constructor, each federated data object via
+	 * {@code copyWithNewID(id)}, and the federation type is kept. Ranges of size zero are
+	 * not copied.
+	 */
public FederationMap copyWithNewID(long id) {
Map<FederatedRange, FederatedData> map = new TreeMap<>();
//TODO handling of file path, but no danger as never written
- for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
- map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
+ for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() ) {
+		// drop empty ranges, e.g. partitions that right indexing collapsed to size zero
+		// NOTE(review): this applies to every caller of copyWithNewID - confirm none relies on empty ranges being kept
+ if(e.getKey().getSize() != 0)
+ map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
+ }
return new FederationMap(id, map, _type);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 9301765..8094c96 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -37,6 +37,7 @@ public abstract class FEDInstruction extends Instruction {
Tsmm,
MMChain,
Reorg,
+ MatrixIndexing
}
protected final FEDType _fedType;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 795db11..2edc5f2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -32,6 +32,7 @@ import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
@@ -127,6 +128,15 @@ public class FEDInstructionUtils {
if( mo.isFederated() )
fedinst = ReorgFEDInstruction.parseInstruction(rinst.getInstructionString());
}
+ else if(inst instanceof MatrixIndexingCPInstruction && inst.getOpcode().equalsIgnoreCase("rightIndex")) {
+ // matrix indexing
+ MatrixIndexingCPInstruction minst = (MatrixIndexingCPInstruction) inst;
+ if(minst.input1.isMatrix()) {
+ CacheableData<?> fo = ec.getCacheableData(minst.input1);
+ if(fo.isFederated())
+ fedinst = MatrixIndexingFEDInstruction.parseInstruction(minst.getInstructionString());
+ }
+ }
else if(inst instanceof VariableCPInstruction ){
VariableCPInstruction ins = (VariableCPInstruction) inst;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
new file mode 100644
index 0000000..15fe1ab
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.LeftIndex;
+import org.apache.sysds.lops.RightIndex;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.util.IndexRange;
+
+public abstract class IndexingFEDInstruction extends UnaryFEDInstruction {
+	/** Operands holding the 1-based, inclusive row and column bounds of the indexing expression. */
+	protected final CPOperand rowLower, rowUpper, colLower, colUpper;
+
+	/** Creates a right-indexing instruction {@code out = in[rl:ru, cl:cu]}. */
+	protected IndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
+		CPOperand out, String opcode, String istr) {
+		super(FEDInstruction.FEDType.MatrixIndexing, null, in, out, opcode, istr);
+		rowLower = rl;
+		rowUpper = ru;
+		colLower = cl;
+		colUpper = cu;
+	}
+
+	/**
+	 * Creates an indexing instruction with a left-hand and a right-hand input
+	 * (presumably for left indexing, which {@link #parseInstruction(String)} does not create yet).
+	 */
+	protected IndexingFEDInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, CPOperand cl,
+		CPOperand cu, CPOperand out, String opcode, String istr) {
+		super(FEDInstruction.FEDType.MatrixIndexing, null, lhsInput, rhsInput, out, opcode, istr);
+		rowLower = rl;
+		rowUpper = ru;
+		colLower = cl;
+		colUpper = cu;
+	}
+
+	/**
+	 * Evaluates the bound operands and converts them into a 0-based, inclusive index range.
+	 * The bounds are kept as longs: federated matrices may exceed the int index range.
+	 */
+	protected IndexRange getIndexRange(ExecutionContext ec) {
+		return new IndexRange( //rl, ru, cl, cu
+			ec.getScalarInput(rowLower).getLongValue() - 1,
+			ec.getScalarInput(rowUpper).getLongValue() - 1,
+			ec.getScalarInput(colLower).getLongValue() - 1,
+			ec.getScalarInput(colUpper).getLongValue() - 1);
+	}
+
+	/**
+	 * Parses a federated indexing instruction. Only right indexing on matrices is supported.
+	 *
+	 * @param str instruction string
+	 * @return the parsed instruction
+	 * @throws DMLRuntimeException on an unsupported opcode, operand count, or input data type
+	 */
+	public static IndexingFEDInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode = parts[0];
+
+		if(!opcode.equalsIgnoreCase(RightIndex.OPCODE))
+			throw new DMLRuntimeException("Unknown opcode while parsing a MatrixIndexingFEDInstruction: " + str);
+		if(parts.length != 7)
+			throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
+
+		CPOperand in = new CPOperand(parts[1]);
+		CPOperand rl = new CPOperand(parts[2]);
+		CPOperand ru = new CPOperand(parts[3]);
+		CPOperand cl = new CPOperand(parts[4]);
+		CPOperand cu = new CPOperand(parts[5]);
+		CPOperand out = new CPOperand(parts[6]);
+
+		// frames and lists have no federated indexing implementation yet
+		if(in.getDataType() != Types.DataType.MATRIX)
+			throw new DMLRuntimeException("Federated indexing is only supported on matrices: " + str);
+		return new MatrixIndexingFEDInstruction(in, rl, ru, cl, cu, out, opcode, str);
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
new file mode 100644
index 0000000..ea2e905
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.IndexRange;
+
+public final class MatrixIndexingFEDInstruction extends IndexingFEDInstruction {
+	private static final Log LOG = LogFactory.getLog(MatrixIndexingFEDInstruction.class.getName());
+
+	public MatrixIndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
+		CPOperand out, String opcode, String istr) {
+		super(in, rl, ru, cl, cu, out, opcode, istr);
+	}
+
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		rightIndexing(ec);
+	}
+
+	/**
+	 * Right indexing (slicing) of a federated matrix. Every worker slices its local
+	 * partition via {@link SliceMatrix}, and the output federation map holds the re-based
+	 * ranges of the partitions that overlap the requested index range.
+	 *
+	 * The ranges of the input federation map are rewritten in place to output coordinates
+	 * while the output map is derived from them. They are restored in a finally block, so
+	 * the input matrix keeps its original federation map and can be indexed again.
+	 */
+	private void rightIndexing(ExecutionContext ec) {
+		MatrixObject in = ec.getMatrixObject(input1);
+		FederationMap fedMapping = in.getFedMapping();
+		IndexRange ixrange = getIndexRange(ec);
+		FederatedRange[] ranges = fedMapping.getFederatedRanges();
+
+		// snapshot of the input ranges, required to undo the in-place rewrite below
+		FederatedRange[] originals = new FederatedRange[ranges.length];
+		for(int i = 0; i < ranges.length; i++)
+			originals[i] = new FederatedRange(ranges[i]);
+
+		try {
+			Map<FederatedRange, IndexRange> ixs = rewriteRanges(ranges, ixrange);
+
+			// NOTE(review): relies on mapParallel copying the rewritten ranges into the new map
+			// (the zero-size filter in FederationMap.copyWithNewID suggests it does) - confirm.
+			long varID = FederationUtils.getNextFedDataID();
+			FederationMap slicedMapping = fedMapping.mapParallel(varID, (range, data) -> {
+				try {
+					IndexRange local = ixs.getOrDefault(range, new IndexRange(-1, -1, -1, -1));
+					FederatedResponse response = data.executeFederatedOperation(
+						new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+							new SliceMatrix(data.getVarID(), varID, local))).get();
+					if(!response.isSuccessful())
+						response.throwExceptionFromResponse();
+				}
+				catch(Exception e) {
+					throw new DMLRuntimeException(e);
+				}
+				return null;
+			});
+
+			MatrixObject sliced = ec.getMatrixObject(output);
+			sliced.getDataCharacteristics().set(fedMapping.getMaxIndexInRange(0),
+				fedMapping.getMaxIndexInRange(1), (int) in.getBlocksize());
+			sliced.setFedMapping(slicedMapping);
+		}
+		finally {
+			// restore the input federation map
+			for(int i = 0; i < ranges.length; i++) {
+				for(int dim = 0; dim < 2; dim++) {
+					ranges[i].setBeginDim(dim, originals[i].getBeginDims()[dim]);
+					ranges[i].setEndDim(dim, originals[i].getEndDims()[dim]);
+				}
+			}
+		}
+	}
+
+	/**
+	 * Rewrites every range in place to its position in the sliced output. It returns the
+	 * 0-based, inclusive index range each worker has to slice from its partition, keyed by
+	 * a copy of the corresponding rewritten non-empty range. Ranges that do not overlap
+	 * {@code ixrange} collapse to an empty range at the current output offset.
+	 */
+	private static Map<FederatedRange, IndexRange> rewriteRanges(FederatedRange[] ranges, IndexRange ixrange) {
+		Map<FederatedRange, IndexRange> ixs = new HashMap<>();
+		long nextRow = 0, nextCol = 0; // output offset of the next range
+		for(int i = 0; i < ranges.length; i++) {
+			FederatedRange range = ranges[i];
+			long rs = range.getBeginDims()[0], re = range.getEndDims()[0];
+			long cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
+
+			// heuristic: if the next range starts where this one ends row-wise, the partitions
+			// are stacked by rows and the output offset advances in rows, otherwise in columns
+			boolean rowAdjacent = i + 1 < ranges.length && re == ranges[i + 1].getBeginDims()[0];
+
+			// partition-local slice bounds, clipped to the partition
+			long rsn = (ixrange.rowStart >= rs && ixrange.rowStart < re) ? ixrange.rowStart - rs : 0;
+			long ren = (ixrange.rowEnd >= rs && ixrange.rowEnd < re) ? ixrange.rowEnd - rs : re - rs - 1;
+			long csn = (ixrange.colStart >= cs && ixrange.colStart < ce) ? ixrange.colStart - cs : 0;
+			long cen = (ixrange.colEnd >= cs && ixrange.colEnd < ce) ? ixrange.colEnd - cs : ce - cs - 1;
+			boolean overlaps = ixrange.rowStart < re && ixrange.rowEnd >= rs
+				&& ixrange.colStart < ce && ixrange.colEnd >= cs;
+
+			range.setBeginDim(0, nextRow);
+			range.setBeginDim(1, nextCol);
+			range.setEndDim(0, overlaps ? nextRow + ren - rsn + 1 : nextRow);
+			range.setEndDim(1, overlaps ? nextCol + cen - csn + 1 : nextCol);
+			if(overlaps)
+				ixs.put(new FederatedRange(range), new IndexRange(rsn, ren, csn, cen));
+
+			if(rowAdjacent)
+				nextRow = range.getEndDims()[0];
+			else
+				nextCol = range.getEndDims()[1];
+		}
+		return ixs;
+	}
+
+	/**
+	 * Worker-side UDF: slices the local matrix with the given partition-local index range
+	 * and stores the result under the output variable ID. A row start of -1 marks a
+	 * partition without overlap, for which an empty matrix block is stored.
+	 */
+	private static class SliceMatrix extends FederatedUDF {
+
+		private static final long serialVersionUID = 5956832933333848772L;
+		private final long _outputID;
+		private final IndexRange _ixrange;
+
+		private SliceMatrix(long input, long outputID, IndexRange ixrange) {
+			super(new long[] {input});
+			_outputID = outputID;
+			_ixrange = ixrange;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+			MatrixBlock res = (_ixrange.rowStart != -1) ? mb.slice(_ixrange, new MatrixBlock()) : new MatrixBlock();
+			MatrixObject mout = ExecutionContext.createMatrixObject(res);
+			ec.setVariable(String.valueOf(_outputID), mout);
+
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
new file mode 100644
index 0000000..a16e4ed
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedRightIndexTest extends AutomatedTestBase {
+ private final static String TEST_NAME1 = "FederatedRightIndexRightTest";
+ private final static String TEST_NAME2 = "FederatedRightIndexLeftTest";
+ private final static String TEST_NAME3 = "FederatedRightIndexFullTest";
+
+ private final static String TEST_DIR = "functions/federated/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRightIndexTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+
+ @Parameterized.Parameter(2)
+ public int from;
+
+ @Parameterized.Parameter(3)
+ public int to;
+
+ @Parameterized.Parameter(4)
+ public boolean rowPartitioned;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {20, 10, 6, 8, true}, {20, 10, 2, 10, true},
+ {20, 12, 2, 10, false}, {20, 12, 1, 4, false}
+ });
+ }
+
+ private enum IndexType {
+ RIGHT, LEFT, FULL
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
+ }
+
+ @Test
+ public void testRightIndexRightDenseMatrixCP() {
+ runAggregateOperationTest(IndexType.RIGHT, ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ public void testRightIndexLeftDenseMatrixCP() {
+ runAggregateOperationTest(IndexType.LEFT, ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ public void testRightIndexFullDenseMatrixCP() {
+ runAggregateOperationTest(IndexType.FULL, ExecMode.SINGLE_NODE);
+ }
+
+ private void runAggregateOperationTest(IndexType type, ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ ExecMode platformOld = rtplatform;
+
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ String TEST_NAME = null;
+ switch(type) {
+ case RIGHT:
+ TEST_NAME = TEST_NAME1; break;
+ case LEFT:
+ TEST_NAME = TEST_NAME2; break;
+ case FULL:
+ TEST_NAME = TEST_NAME3; break;
+ }
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int r = rows;
+ int c = cols / 4;
+ if(rowPartitioned) {
+ r = rows / 4;
+ c = cols;
+ }
+
+ double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+ writeInputMatrixWithMTD("X2", X2, false, mc);
+ writeInputMatrixWithMTD("X3", X3, false, mc);
+ writeInputMatrixWithMTD("X4", X4, false, mc);
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ int port3 = getRandomAvailablePort();
+ int port4 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
+ Thread t3 = startLocalFedWorkerThread(port3);
+ Thread t4 = startLocalFedWorkerThread(port4);
+
+ rtplatform = execMode;
+ if(rtplatform == ExecMode.SPARK) {
+ System.out.println(7);
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] { "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ String.valueOf(from), String.valueOf(to), Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+ "from=" + from, "to=" + to, "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+ "out_S=" + output("S")};
+
+ runTest(true, false, null, -1);
+
+ // compare via files
+ compareResults(1e-9);
+
+ Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+
+ }
+}
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexFullTest.dml b/src/test/scripts/functions/federated/FederatedRightIndexFullTest.dml
new file mode 100644
index 0000000..46bc064
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexFullTest.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Slices the window A[from:to, from:to] out of a federated $rows x $cols matrix A and
+# writes it to $out_S. A is assembled from four workers ($in_X1..$in_X4), either as four
+# row blocks ($rP = TRUE) or as four column blocks ($rP = FALSE).
+# NOTE(review): the range bounds use '/', presumably floating-point division in DML; they
+# are only integral when $rows (row blocks) or $cols (column blocks) is divisible by 4.
+from = $from;
+to = $to;
+
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+s = A[from:to, from:to];
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexFullTestReference.dml b/src/test/scripts/functions/federated/FederatedRightIndexFullTestReference.dml
new file mode 100644
index 0000000..8261f5e
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexFullTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference for the federated full-slice test: assembles the four parts locally and
+# takes the window A[$5:$6, $5:$6]. Arguments: $1-$4 input matrices, $5 from, $6 to,
+# $7 TRUE = stack parts by rows (rbind) / FALSE = side by side (cbind), $8 output path.
+lo = $5;
+hi = $6;
+X1 = read($1);
+X2 = read($2);
+X3 = read($3);
+X4 = read($4);
+if ($7) {
+  A = rbind(X1, X2, X3, X4);
+} else {
+  A = cbind(X1, X2, X3, X4);
+}
+s = A[lo:hi, lo:hi];
+write(s, $8);
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexLeftTest.dml b/src/test/scripts/functions/federated/FederatedRightIndexLeftTest.dml
new file mode 100644
index 0000000..3f690b1
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexLeftTest.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Selects rows from..to (all columns) of a federated $rows x $cols matrix A and writes
+# them to $out_S. A is assembled from four workers ($in_X1..$in_X4), either as four
+# row blocks ($rP = TRUE) or as four column blocks ($rP = FALSE).
+# NOTE(review): the range bounds use '/', presumably floating-point division in DML; they
+# are only integral when $rows (row blocks) or $cols (column blocks) is divisible by 4.
+from = $from;
+to = $to;
+
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+s = A[from:to,];
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexLeftTestReference.dml b/src/test/scripts/functions/federated/FederatedRightIndexLeftTestReference.dml
new file mode 100644
index 0000000..ef095f3
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexLeftTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference for the federated row-slice test: assembles the four parts locally and
+# selects rows $5..$6 (all columns). Arguments: $1-$4 input matrices, $5 from, $6 to,
+# $7 TRUE = stack parts by rows (rbind) / FALSE = side by side (cbind), $8 output path.
+lo = $5;
+hi = $6;
+X1 = read($1);
+X2 = read($2);
+X3 = read($3);
+X4 = read($4);
+if ($7) {
+  A = rbind(X1, X2, X3, X4);
+} else {
+  A = cbind(X1, X2, X3, X4);
+}
+s = A[lo:hi, ];
+write(s, $8);
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexRightTest.dml b/src/test/scripts/functions/federated/FederatedRightIndexRightTest.dml
new file mode 100644
index 0000000..ee80b46
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexRightTest.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Selects columns from..to (all rows) of a federated $rows x $cols matrix A and writes
+# them to $out_S. A is assembled from four workers ($in_X1..$in_X4), either as four
+# row blocks ($rP = TRUE) or as four column blocks ($rP = FALSE).
+# NOTE(review): the range bounds use '/', presumably floating-point division in DML; they
+# are only integral when $rows (row blocks) or $cols (column blocks) is divisible by 4.
+from = $from;
+to = $to;
+
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+s = A[, from:to];
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexRightTestReference.dml b/src/test/scripts/functions/federated/FederatedRightIndexRightTestReference.dml
new file mode 100644
index 0000000..af83ca0
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexRightTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference for the federated column-slice test: assembles the four parts locally and
+# selects columns $5..$6 (all rows). Arguments: $1-$4 input matrices, $5 from, $6 to,
+# $7 TRUE = stack parts by rows (rbind) / FALSE = side by side (cbind), $8 output path.
+lo = $5;
+hi = $6;
+X1 = read($1);
+X2 = read($2);
+X3 = read($3);
+X4 = read($4);
+if ($7) {
+  A = rbind(X1, X2, X3, X4);
+} else {
+  A = cbind(X1, X2, X3, X4);
+}
+s = A[, lo:hi];
+write(s, $8);