You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/15 17:09:29 UTC
[systemds] branch master updated: [SYSTEMDS-2543, 2549,
2623] Additional federated instructions (for pca)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new a083501 [SYSTEMDS-2543,2549,2623] Additional federated instructions (for pca)
a083501 is described below
commit a0835010840346a260029eaebe3813d8f7a05a0f
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Aug 15 19:09:06 2020 +0200
[SYSTEMDS-2543,2549,2623] Additional federated instructions (for pca)
This patch adds all remaining federated instructions to run pca with and
without scale and shift. In detail, this includes:
* Federated matrix-matrix operations (matrix-matrix and matrix-vector w/
one federated input)
* Federated column aggregates for uacmean, incl local compensation
* Federated replace parameterized builtin function
* Cleanup federated statistics (alignment, spaces)
---
.../compress/AbstractCompressedMatrixBlock.java | 3 +-
.../controlprogram/federated/FederatedRange.java | 9 +-
.../controlprogram/federated/FederationMap.java | 4 +
.../controlprogram/federated/FederationUtils.java | 53 +++++++--
.../cp/ParameterizedBuiltinCPInstruction.java | 4 +
.../fed/AggregateUnaryFEDInstruction.java | 8 +-
.../instructions/fed/BinaryFEDInstruction.java | 2 +-
.../fed/BinaryMatrixMatrixFEDInstruction.java | 61 +++++++++++
.../runtime/instructions/fed/FEDInstruction.java | 1 +
.../instructions/fed/FEDInstructionUtils.java | 27 +++--
.../fed/ParameterizedBuiltinFEDInstruction.java | 121 +++++++++++++++++++++
.../sysds/runtime/matrix/data/CM_N_COVCell.java | 2 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 3 +-
.../sysds/runtime/matrix/data/MatrixCell.java | 3 +-
.../sysds/runtime/matrix/data/MatrixValue.java | 2 +-
.../java/org/apache/sysds/utils/Statistics.java | 14 +--
.../test/functions/federated/FederatedPCATest.java | 9 +-
17 files changed, 280 insertions(+), 46 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
index bf86ede..913563c 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
@@ -148,11 +148,12 @@ public abstract class AbstractCompressedMatrixBlock extends MatrixBlock {
}
@Override
- public void binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
+ public MatrixBlock binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
printDecompressWarning("binaryOperationsInPlace", (MatrixBlock) thatValue);
MatrixBlock left = isCompressed() ? decompress() : this;
MatrixBlock right = getUncompressed(thatValue);
left.binaryOperationsInPlace(op, right);
+ return this;
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index b4f69ad..6571666 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -71,12 +71,15 @@ public class FederatedRange implements Comparable<FederatedRange> {
public long getSize() {
long size = 1;
- for (int i = 0; i < _beginDims.length; i++) {
- size *= _endDims[i] - _beginDims[i];
- }
+ for (int i = 0; i < _beginDims.length; i++)
+ size *= getSize(i);
return size;
}
+ public long getSize(int dim) {
+ return _endDims[dim] - _beginDims[dim];
+ }
+
@Override
public int compareTo(FederatedRange o) {
for (int i = 0; i < _beginDims.length; i++) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index f224da2..04532fd 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -62,6 +62,10 @@ public class FederationMap
return _ID >= 0;
}
+ public FederatedRange[] getFederatedRanges() {
+ return _fedMap.keySet().toArray(new FederatedRange[0]);
+ }
+
public FederatedRequest broadcast(CacheableData<?> data) {
//prepare single request for all federated data
long id = FederationUtils.getNextFedDataID();
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index f2c8227..c34fa62 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -29,12 +29,16 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
+import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
public class FederationUtils {
@@ -50,9 +54,10 @@ public class FederationUtils {
String linst = inst.replace(ExecType.SPARK.name(), ExecType.CP.name());
linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldOut.getName(), Lop.OPERAND_DELIMITOR+String.valueOf(id));
for(int i=0; i<varOldIn.length; i++)
- if( varOldIn[i] != null )
- linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(),
- Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+ if( varOldIn[i] != null ) {
+ linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(), Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+ linst = linst.replace("="+varOldIn[i].getName(), "="+String.valueOf(varNewIn[i])); //parameterized
+ }
return new FederatedRequest(RequestType.EXEC_INST, id, linst);
}
@@ -69,6 +74,29 @@ public class FederationUtils {
}
}
+ public static MatrixBlock aggMean(Future<FederatedResponse>[] ffr, FederationMap map) {
+ try {
+ FederatedRange[] ranges = map.getFederatedRanges();
+ BinaryOperator bop = InstructionUtils.parseBinaryOperator("+");
+ ScalarOperator sop1 = InstructionUtils.parseScalarBinaryOperator("*", false);
+ MatrixBlock ret = null;
+ long size = 0;
+ for(int i=0; i<ffr.length; i++) {
+ MatrixBlock tmp = (MatrixBlock)ffr[i].get().getData()[0];
+ size += ranges[i].getSize(0);
+ sop1 = sop1.setConstant(ranges[i].getSize(0));
+ tmp = tmp.scalarOperations(sop1, new MatrixBlock());
+ ret = (ret==null) ? tmp : ret.binaryOperationsInPlace(bop, tmp);
+ }
+ ScalarOperator sop2 = InstructionUtils.parseScalarBinaryOperator("/", false);
+ sop2 = sop2.setConstant(size);
+ return ret.scalarOperations(sop2, new MatrixBlock());
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
public static MatrixBlock[] getResults(Future<FederatedResponse>[] ffr) {
try {
MatrixBlock[] ret = new MatrixBlock[ffr.length];
@@ -111,13 +139,20 @@ public class FederationUtils {
}
}
- public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr) {
- if( !(aop.aggOp.increOp.fn instanceof KahanFunction) ) {
- throw new DMLRuntimeException("Unsupported aggregation operator: "
- + aop.aggOp.increOp.getClass().getSimpleName());
+ public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) {
+ // handle row aggregate
+ if( aop.isRowAggregate() ) {
+ //independent of aggregation function for row-partitioned federated matrices
+ return rbind(ffr);
}
- //assumes full row partitions for row and col aggregates
- return aop.isRowAggregate() ? rbind(ffr) : aggAdd(ffr);
+ // handle col aggregate
+ if( aop.aggOp.increOp.fn instanceof KahanFunction )
+ return aggAdd(ffr);
+ else if( aop.aggOp.increOp.fn instanceof Mean )
+ return aggMean(ffr, map);
+ else
+ throw new DMLRuntimeException("Unsupported aggregation operator: "
+ + aop.aggOp.increOp.fn.getClass().getSimpleName());
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index f330793..4ced46f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -450,6 +450,10 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
}
}
+ public MatrixObject getTarget(ExecutionContext ec) {
+ return ec.getMatrixObject(params.get("target"));
+ }
+
private CPOperand getTargetOperand() {
return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index e5dd81e..a9b655b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -26,6 +26,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -61,11 +62,12 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
//execute federated commands and cleanups
- Future<FederatedResponse>[] tmp = in.getFedMapping().execute(fr1, fr2);
- in.getFedMapping().cleanup(fr1.getID());
+ FederationMap map = in.getFedMapping();
+ Future<FederatedResponse>[] tmp = map.execute(fr1, fr2);
+ map.cleanup(fr1.getID());
if( output.isScalar() )
ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp));
else
- ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp));
+ ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, map));
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 9782558..f1f8f38 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -48,7 +48,7 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
if( in1.getDataType() == DataType.SCALAR && in2.getDataType() == DataType.SCALAR )
throw new DMLRuntimeException("Federated binary scalar scalar operations not yet supported");
else if( in1.getDataType() == DataType.MATRIX && in2.getDataType() == DataType.MATRIX )
- throw new DMLRuntimeException("Federated binary matrix matrix operations not yet supported");
+ return new BinaryMatrixMatrixFEDInstruction(operator, in1, in2, out, opcode, str);
else if( in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.TENSOR )
throw new DMLRuntimeException("Federated binary tensor tensor operations not yet supported");
else if( in1.isMatrix() && in2.isScalar() )
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
new file mode 100644
index 0000000..d124c76
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
+{
+ protected BinaryMatrixMatrixFEDInstruction(Operator op,
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
+ super(FEDType.Binary, op, in1, in2, out, opcode, istr);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ MatrixObject mo1 = ec.getMatrixObject(input1);
+ MatrixObject mo2 = ec.getMatrixObject(input2);
+
+ if( mo2.isFederated() ) {
+ throw new DMLRuntimeException("Matrix-matrix binary operations "
+ + " with a federated right input are not supported yet.");
+ }
+
+ //matrix-matrix binary operations -> lhs fed input -> fed output
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
+
+ //execute federated instruction and cleanup intermediates
+ mo1.getFedMapping().execute(fr1, fr2);
+ mo1.getFedMapping().cleanup(fr1.getID());
+
+ //derive new fed mapping for output
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+ out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index d6bd388..9e58e52 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -33,6 +33,7 @@ public abstract class FEDInstruction extends Instruction {
Binary,
Init,
MultiReturnParameterizedBuiltin,
+ ParameterizedBuiltin,
Tsmm,
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 5f97350..00f3b04 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -54,25 +54,24 @@ public class FEDInstructionUtils {
}
else if (inst instanceof BinaryCPInstruction) {
BinaryCPInstruction instruction = (BinaryCPInstruction) inst;
- if( instruction.input1.isMatrix() && instruction.input2.isScalar() ){
- MatrixObject mo = ec.getMatrixObject(instruction.input1);
- if(mo.isFederated())
- return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ if( instruction.input1.isMatrix() && ec.getMatrixObject(instruction.input1).isFederated()
+ || instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederated() ) {
+ return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
}
- if( instruction.input2.isMatrix() && instruction.input1.isScalar() ){
- MatrixObject mo = ec.getMatrixObject(instruction.input2);
- if(mo.isFederated())
- return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
+ ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction)inst;
+ if(pinst.getOpcode().equals("replace") && pinst.getTarget(ec).isFederated()) {
+ return ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
}
}
else if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction) {
- MultiReturnParameterizedBuiltinCPInstruction instruction = (MultiReturnParameterizedBuiltinCPInstruction) inst;
- String opcode = instruction.getOpcode();
- if(opcode.equals("transformencode") && instruction.input1.isFrame()) {
- CacheableData<?> fo = ec.getCacheableData(instruction.input1);
+ MultiReturnParameterizedBuiltinCPInstruction minst = (MultiReturnParameterizedBuiltinCPInstruction) inst;
+ if(minst.getOpcode().equals("transformencode") && minst.input1.isFrame()) {
+ CacheableData<?> fo = ec.getCacheableData(minst.input1);
if(fo.isFederated()) {
return MultiReturnParameterizedBuiltinFEDInstruction
- .parseInstruction(instruction.getInstructionString());
+ .parseInstruction(minst.getInstructionString());
}
}
}
@@ -80,7 +79,7 @@ public class FEDInstructionUtils {
MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
MatrixObject mo = ec.getMatrixObject(linst.input1);
if( mo.isFederated() )
- return TsmmFEDInstruction.parseInstruction(linst.toString());
+ return TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
}
return inst;
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
new file mode 100644
index 0000000..3a5ff8a
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
+import org.apache.sysds.runtime.functionobjects.ValueFunction;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+
+public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstruction {
+
+ protected final LinkedHashMap<String, String> params;
+
+ protected ParameterizedBuiltinFEDInstruction(Operator op,
+ LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr)
+ {
+ super(FEDType.ParameterizedBuiltin, op, null, null, out, opcode, istr);
+ params = paramsMap;
+ }
+
+ public HashMap<String,String> getParameterMap() {
+ return params;
+ }
+
+ public String getParam(String key) {
+ return getParameterMap().get(key);
+ }
+
+ public static LinkedHashMap<String, String> constructParameterMap(String[] params) {
+ // process all elements in "params" except first(opcode) and last(output)
+ LinkedHashMap<String,String> paramMap = new LinkedHashMap<>();
+
+ // all parameters are of form <name=value>
+ String[] parts;
+ for ( int i=1; i <= params.length-2; i++ ) {
+ parts = params[i].split(Lop.NAME_VALUE_SEPARATOR);
+ paramMap.put(parts[0], parts[1]);
+ }
+
+ return paramMap;
+ }
+
+ public static ParameterizedBuiltinFEDInstruction parseInstruction ( String str ) {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ // first part is always the opcode
+ String opcode = parts[0];
+ // last part is always the output
+ CPOperand out = new CPOperand( parts[parts.length-1] );
+
+ // process remaining parts and build a hash map
+ LinkedHashMap<String,String> paramsMap = constructParameterMap(parts);
+
+ // determine the appropriate value function
+ ValueFunction func = null;
+ if( opcode.equalsIgnoreCase("replace") ) {
+ func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+ return new ParameterizedBuiltinFEDInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
+ }
+ else {
+ throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for ParameterizedBuiltinFEDInstruction.");
+ }
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ String opcode = getOpcode();
+ if ( opcode.equalsIgnoreCase("replace") ) {
+ //similar to unary federated instructions, get federated input
+ //execute instruction, and derive federated output matrix
+ MatrixObject mo = getTarget(ec);
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{getTargetOperand()}, new long[]{mo.getFedMapping().getID()});
+ mo.getFedMapping().execute(fr1);
+
+ //derive new fed mapping for output
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(mo.getDataCharacteristics());
+ out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
+ }
+ else {
+ throw new DMLRuntimeException("Unknown opcode : " + opcode);
+ }
+ }
+
+ public MatrixObject getTarget(ExecutionContext ec) {
+ return ec.getMatrixObject(params.get("target"));
+ }
+
+ private CPOperand getTargetOperand() {
+ return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX);
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
index a79677b..063ff77 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
@@ -61,7 +61,7 @@ public class CM_N_COVCell extends MatrixValue implements WritableComparable
}
@Override
- public void binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
+ public MatrixValue binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
throw new DMLRuntimeException("operation not supported for CM_N_COVCell");
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 45a0965..ed52481 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2837,7 +2837,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
}
@Override
- public void binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
+ public MatrixBlock binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
MatrixBlock that=checkType(thatValue);
if( !LibMatrixBincell.isValidDimensionsBinary(this, that) ) {
throw new RuntimeException("block sizes are not matched for binary " +
@@ -2853,6 +2853,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
//core binary cell operation
LibMatrixBincell.bincellOpInPlace(this, that, op);
+ return this;
}
public MatrixBlock ternaryOperations(TernaryOperator op, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
index 42338d9..10d7e61 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
@@ -198,10 +198,11 @@ public class MatrixCell extends MatrixValue implements WritableComparable, Seria
}
@Override
- public void binaryOperationsInPlace(BinaryOperator op,
+ public MatrixValue binaryOperationsInPlace(BinaryOperator op,
MatrixValue thatValue) {
MatrixCell c2=checkType(thatValue);
setValue(op.fn.execute(this.getValue(), c2.getValue()));
+ return this;
}
public void denseScalarOperationsInPlace(ScalarOperator op) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
index 102e433..9b213ec 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
@@ -103,7 +103,7 @@ public abstract class MatrixValue implements WritableComparable
public abstract MatrixValue binaryOperations(BinaryOperator op, MatrixValue thatValue, MatrixValue result);
- public abstract void binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue);
+ public abstract MatrixValue binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue);
public abstract MatrixValue reorgOperations(ReorgOperator op, MatrixValue result,
int startRow, int startColumn, int length);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index 0a6ec38..b498b0e 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -1020,13 +1020,13 @@ public class Statistics
sb.append("ParFor total update in-place:\t" + lTotalUIPVar + "/" + lTotalLixUIP + "/" + lTotalLix + "\n");
}
if( federatedReadCount.longValue() > 0){
- sb.append("Federated (Reads,Puts,Gets) :\t(" +
- federatedReadCount.longValue() + "," +
- federatedPutCount.longValue() + "," +
- federatedGetCount.longValue() + ")\n");
- sb.append("Federated Execute (In,UDF) :\t(" +
- federatedExecuteInstructionCount.longValue() + "," +
- federatedExecuteUDFCount.longValue() + ")\n");
+ sb.append("Federated I/O (Read, Put, Get):\t" +
+ federatedReadCount.longValue() + "/" +
+ federatedPutCount.longValue() + "/" +
+ federatedGetCount.longValue() + ".\n");
+ sb.append("Federated Execute (Inst, UDF):\t" +
+ federatedExecuteInstructionCount.longValue() + "/" +
+ federatedExecuteUDFCount.longValue() + ".\n");
}
sb.append("Total JIT compile time:\t\t" + ((double)getJITCompileTime())/1000 + " sec.\n");
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
index 29826f8..bf674a8 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -61,8 +61,7 @@ public class FederatedPCATest extends AutomatedTestBase {
// rows have to be even and > 1
return Arrays.asList(new Object[][] {
{10000, 10, false}, {2000, 50, false}, {1000, 100, false},
- //TODO support for federated uacmean, uacvar
- //{10000, 10, true}, {2000, 50, true}, {1000, 100, true}
+ {10000, 10, true}, {2000, 50, true}, {1000, 100, true}
});
}
@@ -99,7 +98,6 @@ public class FederatedPCATest extends AutomatedTestBase {
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
- setOutputBuffering(false);
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
@@ -124,8 +122,11 @@ public class FederatedPCATest extends AutomatedTestBase {
Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
if( scaleAndShift ) {
+ Assert.assertTrue(heavyHittersContainsString("fed_uacsqk+"));
Assert.assertTrue(heavyHittersContainsString("fed_uacmean"));
- Assert.assertTrue(heavyHittersContainsString("fed_uacvar"));
+ Assert.assertTrue(heavyHittersContainsString("fed_-"));
+ Assert.assertTrue(heavyHittersContainsString("fed_/"));
+ Assert.assertTrue(heavyHittersContainsString("fed_replace"));
}
resetExecMode(platformOld);