You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@systemds.apache.org by ba...@apache.org on 2020/11/14 23:53:28 UTC
[systemds] branch master updated: [SYSTEMDS-2732] Federated remove empty
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 4d8ec5d [SYSTEMDS-2732] Federated remove empty
4d8ec5d is described below
commit 4d8ec5dd64922441f0acc452eee3e49bee0653cd
Author: Olga <ov...@gmail.com>
AuthorDate: Sat Nov 14 17:42:33 2020 +0100
[SYSTEMDS-2732] Federated remove empty
closes #1104
---
.../instructions/fed/FEDInstructionUtils.java | 2 +-
.../instructions/fed/InitFEDInstruction.java | 2 +-
.../fed/ParameterizedBuiltinFEDInstruction.java | 262 +++++++++++++++++++--
.../primitives/FederatedRemoveEmptyTest.java | 161 +++++++++++++
.../federated/FederatedRemoveEmptyTest.dml | 33 +++
.../FederatedRemoveEmptyTestReference.dml | 26 ++
6 files changed, 468 insertions(+), 18 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index f4b19bf..d8af245 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -129,7 +129,7 @@ public class FEDInstructionUtils {
}
else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst;
- if(pinst.getOpcode().equals("replace") && pinst.getTarget(ec).isFederated()) {
+ if((pinst.getOpcode().equals("replace") || pinst.getOpcode().equals("rmempty")) && pinst.getTarget(ec).isFederated()) {
fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
}
else if((pinst.getOpcode().equals("transformdecode") || pinst.getOpcode().equals("transformapply")) &&
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 8821a71..ce7f3b4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -234,7 +234,7 @@ public class InitFEDInstruction extends FEDInstruction {
}
try {
int timeout = ConfigurationManager.getDMLConfig().getIntValue(DMLConfig.DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT);
- LOG.error("Federated Initialization with timeout: " + timeout);
+ LOG.debug("Federated Initialization with timeout: " + timeout);
for (Pair<FederatedData, Future<FederatedResponse>> idResponse : idResponses)
idResponse.getRight().get(timeout,TimeUnit.SECONDS); //wait for initialization
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index f549dca..c50671e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -19,11 +19,14 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.util.AbstractMap;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
-
import java.util.List;
+import java.util.Map;
+
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
@@ -34,6 +37,7 @@ import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType;
@@ -47,6 +51,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -100,7 +105,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
LinkedHashMap<String, String> paramsMap = constructParameterMap(parts);
// determine the appropriate value function
- if( opcode.equalsIgnoreCase("replace") ) {
+ if(opcode.equalsIgnoreCase("replace") || opcode.equalsIgnoreCase("rmempty")) {
ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
return new ParameterizedBuiltinFEDInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
}
@@ -120,8 +125,10 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
// similar to unary federated instructions, get federated input
// execute instruction, and derive federated output matrix
MatrixObject mo = (MatrixObject) getTarget(ec);
- FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
- new CPOperand[] {getTargetOperand()}, new long[] {mo.getFedMapping().getID()});
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString,
+ output,
+ new CPOperand[] {getTargetOperand()},
+ new long[] {mo.getFedMapping().getID()});
mo.getFedMapping().execute(getTID(), true, fr1);
// derive new fed mapping for output
@@ -129,6 +136,9 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
out.getDataCharacteristics().set(mo.getDataCharacteristics());
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
}
+ else if(opcode.equals("rmempty")) {
+ rmempty(ec);
+ }
else if(opcode.equalsIgnoreCase("transformdecode"))
transformDecode(ec);
else if(opcode.equalsIgnoreCase("transformapply"))
@@ -138,6 +148,136 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
}
}
+	/**
+	 * Executes a federated removeEmpty on the federated target matrix.
+	 * <p>
+	 * If the margin is aligned with the federation type (rows of a row-partitioned or columns of a
+	 * column-partitioned matrix), every worker executes the instruction locally. Otherwise, the
+	 * globally empty rows/columns are determined first (see {@link #rmemptyC}). Afterwards, the
+	 * federated ranges and dimensions of the output are recomputed from the per-worker results.
+	 *
+	 * @param ec execution context
+	 */
+	private void rmempty(ExecutionContext ec) {
+		MatrixObject mo = (MatrixObject) getTarget(ec);
+		MatrixObject out = ec.getMatrixObject(output);
+		String margin = params.get("margin");
+		boolean aligned = ("rows".equals(margin) && mo.isFederated(FederationMap.FType.ROW)) ||
+			("cols".equals(margin) && mo.isFederated(FederationMap.FType.COL));
+
+		Map<FederatedRange, int[]> dcs;
+		if(aligned) {
+			// the instruction is shipped as-is; a select vector would reference a coordinator-side
+			// variable that does not exist on the workers
+			if(params.containsKey("select"))
+				throw new DMLRuntimeException("Federated removeEmpty with select is not supported "
+					+ "if the margin matches the federation type.");
+			// NOTE(review): with empty.return=TRUE a worker whose partition is entirely empty presumably
+			// still returns one zero row/column, which a non-federated removeEmpty would not produce.
+			FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+				new CPOperand[] {getTargetOperand()}, new long[] {mo.getFedMapping().getID()});
+			mo.getFedMapping().execute(getTID(), true, fr1);
+			out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
+			dcs = getRmemptyResultDims(out.getFedMapping());
+		}
+		else {
+			Map.Entry<FederationMap, Map<FederatedRange, int[]>> entry = rmemptyC(ec, mo);
+			out.setFedMapping(entry.getKey());
+			dcs = entry.getValue();
+		}
+
+		// derive contiguous output ranges and the overall output dimensions
+		updateRmemptyRanges(out.getFedMapping(), dcs);
+		out.getDataCharacteristics().set(mo.getDataCharacteristics());
+		out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0),
+			out.getFedMapping().getMaxIndexInRange(1),
+			(int) mo.getBlocksize());
+	}
+
+	/** Fetches the dimensions {rows, cols} of every worker-side partition of the given mapping. */
+	private static Map<FederatedRange, int[]> getRmemptyResultDims(FederationMap fedMap) {
+		Map<FederatedRange, int[]> dcs = new HashMap<>();
+		fedMap.forEachParallel((range, data) -> {
+			try {
+				FederatedResponse response = data
+					.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+						new GetDataCharacteristics(data.getVarID())))
+					.get();
+				if(!response.isSuccessful())
+					response.throwExceptionFromResponse();
+				int[] subRangeCharacteristics = (int[]) response.getData()[0];
+				synchronized(dcs) {
+					dcs.put(range, subRangeCharacteristics);
+				}
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException(e);
+			}
+			return null;
+		});
+		return dcs;
+	}
+
+	/**
+	 * Rewrites the (ordered) federated ranges of the output mapping in place, so that they are
+	 * contiguous and sized according to the per-worker result dimensions in {@code dcs}.
+	 */
+	private static void updateRmemptyRanges(FederationMap fedMap, Map<FederatedRange, int[]> dcs) {
+		FederatedRange[] ranges = fedMap.getFederatedRanges();
+
+		// snapshot all sizes before mutating any range, because the ranges are the keys of dcs
+		int[][] sizes = new int[ranges.length][];
+		for(int i = 0; i < ranges.length; i++) {
+			sizes[i] = dcs.get(ranges[i]);
+			if(sizes[i] == null)
+				throw new DMLRuntimeException("Missing result dimensions for federated range " + ranges[i]);
+		}
+
+		// a range starting at 0 keeps starting at 0, otherwise it starts where its predecessor ends
+		for(int i = 0; i < ranges.length; i++) {
+			for(int dim = 0; dim < 2; dim++) {
+				long begin = (ranges[i].getBeginDims()[dim] == 0 || i == 0) ? 0 : ranges[i - 1].getEndDims()[dim];
+				ranges[i].setBeginDim(dim, begin);
+				ranges[i].setEndDim(dim, begin + sizes[i][dim]);
+			}
+		}
+	}
+
+	/**
+	 * Federated removeEmpty for a margin that is orthogonal to the federation type (rows of a
+	 * column-partitioned or columns of a row-partitioned matrix). A row/column may be empty in one
+	 * partition but not in another, so the globally non-empty rows/columns are determined first and
+	 * then removed consistently on all workers.
+	 *
+	 * @param ec execution context (used to acquire the optional select vector)
+	 * @param mo federated target matrix
+	 * @return the federation mapping of the output and the dimensions {rows, cols} of each output partition
+	 */
+	private Map.Entry<FederationMap, Map<FederatedRange, int[]>> rmemptyC(ExecutionContext ec, MatrixObject mo) {
+		boolean marginRow = "rows".equals(params.get("margin"));
+
+		// collect the local 0/1 indicators of non-empty rows/columns from all workers
+		List<MatrixBlock> localIndicators = new ArrayList<>();
+		mo.getFedMapping().forEachParallel((range, data) -> {
+			try {
+				FederatedResponse response = data
+					.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+						new GetVector(data.getVarID(), marginRow)))
+					.get();
+				if(!response.isSuccessful())
+					response.throwExceptionFromResponse();
+				MatrixBlock vector = (MatrixBlock) response.getData()[0];
+				synchronized(localIndicators) {
+					localIndicators.add(vector);
+				}
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException(e);
+			}
+			return null;
+		});
+
+		// global indicator: a row/column is non-empty if it is non-empty on at least one worker
+		BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
+		BinaryOperator greater = InstructionUtils.parseBinaryOperator(">");
+		MatrixBlock tmp1 = localIndicators.get(0);
+		for(int i = 1; i < localIndicators.size(); i++)
+			tmp1 = tmp1.binaryOperationsInPlace(plus, localIndicators.get(i));
+		tmp1 = tmp1.binaryOperationsInPlace(greater, new MatrixBlock(tmp1.getNumRows(), tmp1.getNumColumns(), 0.0));
+		MatrixBlock globalIndicator = new MatrixBlock(tmp1);
+
+		// a missing empty.return is treated as TRUE (assumed DML default - confirm);
+		// Boolean.parseBoolean is already case-insensitive
+		String emptyReturnParam = params.get("empty.return");
+		boolean emptyReturn = emptyReturnParam == null || Boolean.parseBoolean(emptyReturnParam);
+
+		// acquire the optional select vector once and release it after all workers are done
+		String selectName = params.get("select");
+		MatrixBlock select = (selectName != null) ? ec.getMatrixInput(selectName) : null;
+
+		// always materialize new worker-side variables (even if nothing is empty), so that the
+		// output never shares the federated data and ID of the input
+		Map<FederatedRange, int[]> dcs = new HashMap<>();
+		long varID = FederationUtils.getNextFedDataID();
+		FederationMap resMapping;
+		try {
+			resMapping = mo.getFedMapping().mapParallel(varID, (range, data) -> {
+				try {
+					FederatedResponse response = data
+						.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+							new ParameterizedBuiltinFEDInstruction.RemoveEmpty(data.getVarID(), varID,
+								globalIndicator, select, emptyReturn, marginRow)))
+						.get();
+					if(!response.isSuccessful())
+						response.throwExceptionFromResponse();
+					int[] subRangeCharacteristics = (int[]) response.getData()[0];
+					synchronized(dcs) {
+						dcs.put(range, subRangeCharacteristics);
+					}
+				}
+				catch(Exception e) {
+					throw new DMLRuntimeException(e);
+				}
+				return null;
+			});
+		}
+		finally {
+			if(select != null)
+				ec.releaseMatrixInput(selectName);
+		}
+		return new AbstractMap.SimpleEntry<>(resMapping, dcs);
+	}
+
private void transformDecode(ExecutionContext ec) {
// acquire locks
MatrixObject mo = ec.getMatrixObject(params.get("target"));
@@ -155,14 +295,14 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
long[] beginDims = range.getBeginDims();
long[] endDims = range.getEndDims();
int colStartBefore = (int) beginDims[1];
-
+
// update begin end dims (column part) considering columns added by dummycoding
globalDecoder.updateIndexRanges(beginDims, endDims);
-
+
// get the decoder segment that is relevant for this federated worker
Decoder decoder = globalDecoder
- .subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore);
-
+ .subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore);
+
FrameBlock metaSlice = new FrameBlock();
synchronized(meta) {
meta.slice(0, meta.getNumRows() - 1, (int) beginDims[1], (int) endDims[1] - 1, metaSlice);
@@ -170,9 +310,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
FederatedResponse response;
try {
- response = data.executeFederatedOperation(
- new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new DecodeMatrix(data.getVarID(), varID, metaSlice, decoder))).get();
+ response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+ -1, new DecodeMatrix(data.getVarID(), varID, metaSlice, decoder))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
@@ -217,7 +356,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
try {
FederatedResponse response = data
.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new GetColumnNames(data.getVarID()))).get();
+ new GetColumnNames(data.getVarID())))
+ .get();
// no synchronization necessary since names should anyway match
String[] subRangeColNames = (String[]) response.getData()[0];
@@ -261,7 +401,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
EncoderOmit subRangeEncoder = (EncoderOmit) omitEncoder.subRangeEncoder(range.asIndexRange().add(1));
FederatedResponse response = data
.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new InitRowsToRemoveOmit(data.getVarID(), subRangeEncoder))).get();
+ new InitRowsToRemoveOmit(data.getVarID(), subRangeEncoder)))
+ .get();
// no synchronization necessary since names should anyway match
Encoder builtEncoder = (Encoder) response.getData()[0];
@@ -283,7 +424,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
private CPOperand getTargetOperand() {
return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX);
}
-
+
public static class DecodeMatrix extends FederatedUDF {
private static final long serialVersionUID = 2376756757742169692L;
private final long _outputID;
@@ -330,7 +471,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
@Override
public FederatedResponse execute(ExecutionContext ec, Data... data) {
- FrameBlock fb = ((FrameObject)data[0]).acquireReadAndRelease();
+ FrameBlock fb = ((FrameObject) data[0]).acquireReadAndRelease();
// return column names
return new FederatedResponse(ResponseType.SUCCESS, new Object[] {fb.getColumnNames()});
}
@@ -348,9 +489,98 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
@Override
public FederatedResponse execute(ExecutionContext ec, Data... data) {
- FrameBlock fb = ((FrameObject)data[0]).acquireReadAndRelease();
+ FrameBlock fb = ((FrameObject) data[0]).acquireReadAndRelease();
_encoder.build(fb);
return new FederatedResponse(ResponseType.SUCCESS, new Object[] {_encoder});
}
}
+
+ private static class GetDataCharacteristics extends FederatedUDF {
+
+ private static final long serialVersionUID = 578461386177730925L;
+
+ public GetDataCharacteristics(long varID) {
+ super(new long[] {varID});
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+ return new FederatedResponse(ResponseType.SUCCESS, new int[] {mb.getNumRows(), mb.getNumColumns()});
+ }
+ }
+
+	/**
+	 * Federated UDF that removes rows or columns of a worker-side matrix according to a global
+	 * indicator of non-empty rows/columns (or a user-provided select vector). It stores the result
+	 * under a new worker-side variable and responds with the result dimensions {@code int[] {rows, cols}}.
+	 */
+	private static class RemoveEmpty extends FederatedUDF {
+
+		private static final long serialVersionUID = 12341521331L;
+		private final MatrixBlock _vector;
+		private final long _outputID;
+		private final MatrixBlock _select;
+		private final boolean _emptyReturn;
+		private final boolean _marginRow;
+
+		/**
+		 * @param varID       ID of the worker-side input matrix
+		 * @param outputID    ID under which the result is stored on the worker
+		 * @param vector      global indicator (non-zero = keep) of the rows (margin=rows) or columns to keep
+		 * @param select      optional user-provided select vector, or null
+		 * @param emptyReturn the empty.return flag of removeEmpty
+		 * @param marginRow   true to remove rows, false to remove columns
+		 */
+		public RemoveEmpty(long varID, long outputID, MatrixBlock vector, MatrixBlock select, boolean emptyReturn,
+			boolean marginRow) {
+			super(new long[] {varID});
+			_vector = vector;
+			_outputID = outputID;
+			_select = select;
+			_emptyReturn = emptyReturn;
+			_marginRow = marginRow;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			// read-only: the acquired block still backs the federated input variable on this worker
+			MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+
+			// a user-provided select vector determines the kept rows/columns; otherwise keep exactly the
+			// globally non-empty ones, even if they are empty in this worker's partition
+			MatrixBlock selection = (_select != null) ? _select : _vector;
+			// keep the non-zero count of the selection exact (removeEmpty may rely on it - confirm)
+			selection.recomputeNonZeros();
+			MatrixBlock res = mb.removeEmptyOperations(new MatrixBlock(), _marginRow, _emptyReturn, selection);
+
+			MatrixObject mout = ExecutionContext.createMatrixObject(res);
+			ec.setVariable(String.valueOf(_outputID), mout);
+
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
+				new int[] {res.getNumRows(), res.getNumColumns()});
+		}
+	}
+
+	/**
+	 * Federated UDF that computes a 0/1 indicator of the non-empty rows (margin=rows, a rows x 1
+	 * vector) or non-empty columns (margin=cols, a 1 x cols vector) of a worker-side matrix.
+	 */
+	private static class GetVector extends FederatedUDF {
+
+		private static final long serialVersionUID = -1003061862215703768L;
+		private final boolean _marginRow;
+
+		public GetVector(long varID, boolean marginRow) {
+			super(new long[] {varID});
+			_marginRow = marginRow;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+			int rows = mb.getNumRows();
+			int cols = mb.getNumColumns();
+			MatrixBlock indicator = _marginRow ? new MatrixBlock(rows, 1, false) : new MatrixBlock(1, cols, false);
+			if(mb.isEmptyBlock(false))
+				return new FederatedResponse(ResponseType.SUCCESS, indicator);
+
+			// test every cell for != 0 instead of checking sum > 0: rows/columns with negative or
+			// cancelling values (e.g. 1 and -1) are not empty although their sum is <= 0
+			for(int i = 0; i < rows; i++) {
+				for(int j = 0; j < cols; j++) {
+					if(mb.getValue(i, j) != 0) {
+						if(_marginRow) {
+							indicator.setValue(i, 0, 1);
+							break; // row i is known to be non-empty
+						}
+						indicator.setValue(0, j, 1);
+					}
+				}
+			}
+			return new FederatedResponse(ResponseType.SUCCESS, indicator);
+		}
+	}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
new file mode 100644
index 0000000..de1e6d5
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedRemoveEmptyTest extends AutomatedTestBase {
+ // private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName());
+
+ private final static String TEST_NAME = "FederatedRemoveEmptyTest";
+
+ private final static String TEST_DIR = "functions/federated/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRemoveEmptyTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+
+ @Parameterized.Parameter(2)
+ public boolean rowPartitioned;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {{20, 10, true}, {20, 12, false}});
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+ }
+
+ @Test
+ public void testRemoveEmptyCP() {
+ runAggregateOperationTest(ExecMode.SINGLE_NODE);
+ }
+
+ private void runAggregateOperationTest(ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ ExecMode platformOld = rtplatform;
+
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int r = rows;
+ int c = cols / 4;
+ if(rowPartitioned) {
+ r = rows / 4;
+ c = cols;
+ }
+
+ double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+ for(int k : new int[] {1, 2, 3}) {
+ Arrays.fill(X3[k], 0);
+ if(!rowPartitioned) {
+ Arrays.fill(X1[k], 0);
+ Arrays.fill(X2[k], 0);
+ Arrays.fill(X4[k], 0);
+ }
+ }
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+ writeInputMatrixWithMTD("X2", X2, false, mc);
+ writeInputMatrixWithMTD("X3", X3, false, mc);
+ writeInputMatrixWithMTD("X4", X4, false, mc);
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ int port3 = getRandomAvailablePort();
+ int port4 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t2 = startLocalFedWorkerThread(port2, 10);
+ Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t4 = startLocalFedWorkerThread(port4);
+
+ rtplatform = execMode;
+ if(rtplatform == ExecMode.SPARK) {
+ System.out.println(7);
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+
+ runTest(null);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+ "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+
+ runTest(null);
+
+ // compare via files
+ compareResults(1e-9);
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+
+ }
+}
diff --git a/src/test/scripts/functions/federated/FederatedRemoveEmptyTest.dml b/src/test/scripts/functions/federated/FederatedRemoveEmptyTest.dml
new file mode 100644
index 0000000..0c6b77b
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRemoveEmptyTest.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# A is a federated matrix of four equally sized partitions, split by rows ($rP) or by columns
+if (!$rP) {
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols/4),
+      list(0, $cols/4), list($rows, $cols/2),
+      list(0, $cols/2), list($rows, 3*($cols/4)),
+      list(0, 3*($cols/4)), list($rows, $cols)));
+}
+else {
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols),
+      list($rows/4, 0), list(2*$rows/4, $cols),
+      list(2*$rows/4, 0), list(3*$rows/4, $cols),
+      list(3*$rows/4, 0), list($rows, $cols)));
+}
+
+# remove all-zero columns and write the result
+s = removeEmpty(target=A, margin="cols");
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedRemoveEmptyTestReference.dml b/src/test/scripts/functions/federated/FederatedRemoveEmptyTestReference.dml
new file mode 100644
index 0000000..c4b2dc9
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRemoveEmptyTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# reference: combine the four inputs locally, row-wise ($5 = TRUE) or column-wise
+if (!$5) {
+  A = cbind(read($1), read($2), read($3), read($4));
+}
+else {
+  A = rbind(read($1), read($2), read($3), read($4));
+}
+
+s = removeEmpty(target=A, margin="cols");
+write(s, $6);