You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this text.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/15 17:09:29 UTC

[systemds] branch master updated: [SYSTEMDS-2543, 2549, 2623] Additional federated instructions (for pca)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new a083501  [SYSTEMDS-2543,2549,2623] Additional federated instructions (for pca)
a083501 is described below

commit a0835010840346a260029eaebe3813d8f7a05a0f
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Aug 15 19:09:06 2020 +0200

    [SYSTEMDS-2543,2549,2623] Additional federated instructions (for pca)
    
    This patch adds all remaining federated instructions needed to run PCA with
    and without the scale and shift options. In detail, this includes:
    
    * Federated binary matrix-matrix operations (matrix-matrix and matrix-vector,
    with one federated input on the left-hand side)
    * Federated column aggregates for uacmean, including local compensation
    (partial means weighted by each partition's row count)
    * Federated replace (parameterized built-in function)
    * Clean up federated statistics output (alignment, spacing)
---
 .../compress/AbstractCompressedMatrixBlock.java    |   3 +-
 .../controlprogram/federated/FederatedRange.java   |   9 +-
 .../controlprogram/federated/FederationMap.java    |   4 +
 .../controlprogram/federated/FederationUtils.java  |  53 +++++++--
 .../cp/ParameterizedBuiltinCPInstruction.java      |   4 +
 .../fed/AggregateUnaryFEDInstruction.java          |   8 +-
 .../instructions/fed/BinaryFEDInstruction.java     |   2 +-
 .../fed/BinaryMatrixMatrixFEDInstruction.java      |  61 +++++++++++
 .../runtime/instructions/fed/FEDInstruction.java   |   1 +
 .../instructions/fed/FEDInstructionUtils.java      |  27 +++--
 .../fed/ParameterizedBuiltinFEDInstruction.java    | 121 +++++++++++++++++++++
 .../sysds/runtime/matrix/data/CM_N_COVCell.java    |   2 +-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |   3 +-
 .../sysds/runtime/matrix/data/MatrixCell.java      |   3 +-
 .../sysds/runtime/matrix/data/MatrixValue.java     |   2 +-
 .../java/org/apache/sysds/utils/Statistics.java    |  14 +--
 .../test/functions/federated/FederatedPCATest.java |   9 +-
 17 files changed, 280 insertions(+), 46 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
index bf86ede..913563c 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
@@ -148,11 +148,12 @@ public abstract class AbstractCompressedMatrixBlock extends MatrixBlock {
 	}
 
 	@Override
-	public void binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
+	public MatrixBlock binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
 		printDecompressWarning("binaryOperationsInPlace", (MatrixBlock) thatValue);
 		MatrixBlock left = isCompressed() ? decompress() : this;
 		MatrixBlock right = getUncompressed(thatValue);
 		left.binaryOperationsInPlace(op, right);
+		return this;
 	}
 
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index b4f69ad..6571666 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -71,12 +71,15 @@ public class FederatedRange implements Comparable<FederatedRange> {
 	
 	public long getSize() {
 		long size = 1;
-		for (int i = 0; i < _beginDims.length; i++) {
-			size *= _endDims[i] - _beginDims[i];
-		}
+		for (int i = 0; i < _beginDims.length; i++)
+			size *= getSize(i);
 		return size;
 	}
 	
+	public long getSize(int dim) {
+		return _endDims[dim] - _beginDims[dim];
+	}
+	
 	@Override
 	public int compareTo(FederatedRange o) {
 		for (int i = 0; i < _beginDims.length; i++) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index f224da2..04532fd 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -62,6 +62,10 @@ public class FederationMap
 		return _ID >= 0;
 	}
 	
+	public FederatedRange[] getFederatedRanges() {
+		return _fedMap.keySet().toArray(new FederatedRange[0]);
+	}
+	
 	public FederatedRequest broadcast(CacheableData<?> data) {
 		//prepare single request for all federated data
 		long id = FederationUtils.getNextFedDataID();
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index f2c8227..c34fa62 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -29,12 +29,16 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.functionobjects.KahanFunction;
+import org.apache.sysds.runtime.functionobjects.Mean;
 import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
 import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
 
 public class FederationUtils {
@@ -50,9 +54,10 @@ public class FederationUtils {
 		String linst = inst.replace(ExecType.SPARK.name(), ExecType.CP.name());
 		linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldOut.getName(), Lop.OPERAND_DELIMITOR+String.valueOf(id));
 		for(int i=0; i<varOldIn.length; i++)
-			if( varOldIn[i] != null )
-				linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(),
-					Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+			if( varOldIn[i] != null ) {
+				linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(), Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+				linst = linst.replace("="+varOldIn[i].getName(), "="+String.valueOf(varNewIn[i])); //parameterized
+			}
 		return new FederatedRequest(RequestType.EXEC_INST, id, linst);
 	}
 
@@ -69,6 +74,29 @@ public class FederationUtils {
 		}
 	}
 	
+	public static MatrixBlock aggMean(Future<FederatedResponse>[] ffr, FederationMap map) {
+		try {
+			FederatedRange[] ranges = map.getFederatedRanges();
+			BinaryOperator bop = InstructionUtils.parseBinaryOperator("+");
+			ScalarOperator sop1 = InstructionUtils.parseScalarBinaryOperator("*", false);
+			MatrixBlock ret = null;
+			long size = 0;
+			for(int i=0; i<ffr.length; i++) {
+				MatrixBlock tmp = (MatrixBlock)ffr[i].get().getData()[0];
+				size += ranges[i].getSize(0);
+				sop1 = sop1.setConstant(ranges[i].getSize(0));
+				tmp = tmp.scalarOperations(sop1, new MatrixBlock());
+				ret = (ret==null) ? tmp : ret.binaryOperationsInPlace(bop, tmp);
+			}
+			ScalarOperator sop2 = InstructionUtils.parseScalarBinaryOperator("/", false);
+			sop2 = sop2.setConstant(size);
+			return ret.scalarOperations(sop2, new MatrixBlock());
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+	}
+	
 	public static MatrixBlock[] getResults(Future<FederatedResponse>[] ffr) {
 		try {
 			MatrixBlock[] ret = new MatrixBlock[ffr.length];
@@ -111,13 +139,20 @@ public class FederationUtils {
 		}
 	}
 
-	public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr) {
-		if( !(aop.aggOp.increOp.fn instanceof KahanFunction) ) {
-			throw new DMLRuntimeException("Unsupported aggregation operator: "
-				+ aop.aggOp.increOp.getClass().getSimpleName());
+	public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) {
+		// handle row aggregate
+		if( aop.isRowAggregate() ) {
+			//independent of aggregation function for row-partitioned federated matrices
+			return rbind(ffr);
 		}
 		
-		//assumes full row partitions for row and col aggregates
-		return aop.isRowAggregate() ?  rbind(ffr) : aggAdd(ffr);
+		// handle col aggregate
+		if( aop.aggOp.increOp.fn instanceof KahanFunction )
+			return aggAdd(ffr);
+		else if( aop.aggOp.increOp.fn instanceof Mean )
+			return aggMean(ffr, map);
+		else
+			throw new DMLRuntimeException("Unsupported aggregation operator: "
+				+ aop.aggOp.increOp.fn.getClass().getSimpleName());
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index f330793..4ced46f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -450,6 +450,10 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
 		}
 	}
 	
+	public MatrixObject getTarget(ExecutionContext ec) {
+		return ec.getMatrixObject(params.get("target"));
+	}
+	
 	private CPOperand getTargetOperand() {
 		return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX);
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index e5dd81e..a9b655b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -26,6 +26,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -61,11 +62,12 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
 		FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
 		
 		//execute federated commands and cleanups
-		Future<FederatedResponse>[] tmp = in.getFedMapping().execute(fr1, fr2);
-		in.getFedMapping().cleanup(fr1.getID());
+		FederationMap map = in.getFedMapping();
+		Future<FederatedResponse>[] tmp = map.execute(fr1, fr2);
+		map.cleanup(fr1.getID());
 		if( output.isScalar() )
 			ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp));
 		else
-			ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp));
+			ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, map));
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 9782558..f1f8f38 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -48,7 +48,7 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
 		if( in1.getDataType() == DataType.SCALAR && in2.getDataType() == DataType.SCALAR )
 			throw new DMLRuntimeException("Federated binary scalar scalar operations not yet supported");
 		else if( in1.getDataType() == DataType.MATRIX && in2.getDataType() == DataType.MATRIX )
-			throw new DMLRuntimeException("Federated binary matrix matrix operations not yet supported");
+			return new BinaryMatrixMatrixFEDInstruction(operator, in1, in2, out, opcode, str);
 		else if( in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.TENSOR )
 			throw new DMLRuntimeException("Federated binary tensor tensor operations not yet supported");
 		else if( in1.isMatrix() && in2.isScalar() )
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
new file mode 100644
index 0000000..d124c76
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
+{
+	protected BinaryMatrixMatrixFEDInstruction(Operator op,
+		CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
+		super(FEDType.Binary, op, in1, in2, out, opcode, istr);
+	}
+
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		MatrixObject mo1 = ec.getMatrixObject(input1);
+		MatrixObject mo2 = ec.getMatrixObject(input2);
+		
+		if( mo2.isFederated() ) {
+			throw new DMLRuntimeException("Matrix-matrix binary operations "
+				+ " with a federated right input are not supported yet.");
+		}
+		
+		//matrix-matrix binary operations -> lhs fed input -> fed output
+		FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
+		
+		//execute federated instruction and cleanup intermediates
+		mo1.getFedMapping().execute(fr1, fr2);
+		mo1.getFedMapping().cleanup(fr1.getID());
+		
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index d6bd388..9e58e52 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -33,6 +33,7 @@ public abstract class FEDInstruction extends Instruction {
 		Binary,
 		Init,
 		MultiReturnParameterizedBuiltin,
+		ParameterizedBuiltin,
 		Tsmm,
 	}
 	
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 5f97350..00f3b04 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -54,25 +54,24 @@ public class FEDInstructionUtils {
 		}
 		else if (inst instanceof BinaryCPInstruction) {
 			BinaryCPInstruction instruction = (BinaryCPInstruction) inst;
-			if( instruction.input1.isMatrix() && instruction.input2.isScalar() ){
-				MatrixObject mo = ec.getMatrixObject(instruction.input1);
-				if(mo.isFederated())
-					return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+			if( instruction.input1.isMatrix() && ec.getMatrixObject(instruction.input1).isFederated()
+				|| instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederated() ) {
+				return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
 			}
-			if( instruction.input2.isMatrix() && instruction.input1.isScalar() ){
-				MatrixObject mo = ec.getMatrixObject(instruction.input2);
-				if(mo.isFederated())
-					return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+		}
+		else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
+			ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction)inst;
+			if(pinst.getOpcode().equals("replace") && pinst.getTarget(ec).isFederated()) {
+				return ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
 			}
 		}
 		else if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction) {
-			MultiReturnParameterizedBuiltinCPInstruction instruction = (MultiReturnParameterizedBuiltinCPInstruction) inst;
-			String opcode = instruction.getOpcode();
-			if(opcode.equals("transformencode") && instruction.input1.isFrame()) {
-				CacheableData<?> fo = ec.getCacheableData(instruction.input1);
+			MultiReturnParameterizedBuiltinCPInstruction minst = (MultiReturnParameterizedBuiltinCPInstruction) inst;
+			if(minst.getOpcode().equals("transformencode") && minst.input1.isFrame()) {
+				CacheableData<?> fo = ec.getCacheableData(minst.input1);
 				if(fo.isFederated()) {
 					return MultiReturnParameterizedBuiltinFEDInstruction
-						.parseInstruction(instruction.getInstructionString());
+						.parseInstruction(minst.getInstructionString());
 				}
 			}
 		}
@@ -80,7 +79,7 @@ public class FEDInstructionUtils {
 			MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
 			MatrixObject mo = ec.getMatrixObject(linst.input1);
 			if( mo.isFederated() )
-				return TsmmFEDInstruction.parseInstruction(linst.toString());
+				return TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
 		}
 		return inst;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
new file mode 100644
index 0000000..3a5ff8a
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
+import org.apache.sysds.runtime.functionobjects.ValueFunction;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+
+public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstruction {
+
+	protected final LinkedHashMap<String, String> params;
+	
+	protected ParameterizedBuiltinFEDInstruction(Operator op,
+		LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr)
+	{
+		super(FEDType.ParameterizedBuiltin, op, null, null, out, opcode, istr);
+		params = paramsMap;
+	}
+	
+	public HashMap<String,String> getParameterMap() { 
+		return params; 
+	}
+	
+	public String getParam(String key) {
+		return getParameterMap().get(key);
+	}
+	
+	public static LinkedHashMap<String, String> constructParameterMap(String[] params) {
+		// process all elements in "params" except first(opcode) and last(output)
+		LinkedHashMap<String,String> paramMap = new LinkedHashMap<>();
+		
+		// all parameters are of form <name=value>
+		String[] parts;
+		for ( int i=1; i <= params.length-2; i++ ) {
+			parts = params[i].split(Lop.NAME_VALUE_SEPARATOR);
+			paramMap.put(parts[0], parts[1]);
+		}
+		
+		return paramMap;
+	}
+	
+	public static ParameterizedBuiltinFEDInstruction parseInstruction ( String str ) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		// first part is always the opcode
+		String opcode = parts[0];
+		// last part is always the output
+		CPOperand out = new CPOperand( parts[parts.length-1] ); 
+	
+		// process remaining parts and build a hash map
+		LinkedHashMap<String,String> paramsMap = constructParameterMap(parts);
+	
+		// determine the appropriate value function
+		ValueFunction func = null;
+		if( opcode.equalsIgnoreCase("replace") ) {
+			func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+			return new ParameterizedBuiltinFEDInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
+		}
+		else {
+			throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for ParameterizedBuiltinFEDInstruction.");
+		}
+	}
+	
+	@Override 
+	public void processInstruction(ExecutionContext ec) {
+		String opcode = getOpcode();
+		if ( opcode.equalsIgnoreCase("replace") ) {
+			//similar to unary federated instructions, get federated input
+			//execute instruction, and derive federated output matrix
+			MatrixObject mo = getTarget(ec);
+			FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+				new CPOperand[]{getTargetOperand()}, new long[]{mo.getFedMapping().getID()});
+			mo.getFedMapping().execute(fr1);
+			
+			//derive new fed mapping for output
+			MatrixObject out = ec.getMatrixObject(output);
+			out.getDataCharacteristics().set(mo.getDataCharacteristics());
+			out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
+		}
+		else {
+			throw new DMLRuntimeException("Unknown opcode : " + opcode);
+		}
+	}
+	
+	public MatrixObject getTarget(ExecutionContext ec) {
+		return ec.getMatrixObject(params.get("target"));
+	}
+	
+	private CPOperand getTargetOperand() {
+		return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX);
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
index a79677b..063ff77 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
@@ -61,7 +61,7 @@ public class CM_N_COVCell extends MatrixValue implements WritableComparable
 	}
 
 	@Override
-	public void binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
+	public MatrixValue binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
 		throw new DMLRuntimeException("operation not supported for CM_N_COVCell");
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 45a0965..ed52481 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2837,7 +2837,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 	}
 
 	@Override
-	public void binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
+	public MatrixBlock binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) {
 		MatrixBlock that=checkType(thatValue);
 		if( !LibMatrixBincell.isValidDimensionsBinary(this, that) ) {
 			throw new RuntimeException("block sizes are not matched for binary " +
@@ -2853,6 +2853,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		
 		//core binary cell operation
 		LibMatrixBincell.bincellOpInPlace(this, that, op);
+		return this;
 	}
 	
 	public MatrixBlock ternaryOperations(TernaryOperator op, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
index 42338d9..10d7e61 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
@@ -198,10 +198,11 @@ public class MatrixCell extends MatrixValue implements WritableComparable, Seria
 	}
 
 	@Override
-	public void binaryOperationsInPlace(BinaryOperator op,
+	public MatrixValue binaryOperationsInPlace(BinaryOperator op,
 			MatrixValue thatValue) {
 		MatrixCell c2=checkType(thatValue);
 		setValue(op.fn.execute(this.getValue(), c2.getValue()));
+		return this;
 	}
 
 	public void denseScalarOperationsInPlace(ScalarOperator op) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
index 102e433..9b213ec 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
@@ -103,7 +103,7 @@ public abstract class MatrixValue implements WritableComparable
 	
 	public abstract MatrixValue binaryOperations(BinaryOperator op, MatrixValue thatValue, MatrixValue result);
 	
-	public abstract void binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue);
+	public abstract MatrixValue binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue);
 	
 	public abstract MatrixValue reorgOperations(ReorgOperator op, MatrixValue result,
 			int startRow, int startColumn, int length);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index 0a6ec38..b498b0e 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -1020,13 +1020,13 @@ public class Statistics
 				sb.append("ParFor total update in-place:\t" + lTotalUIPVar + "/" + lTotalLixUIP + "/" + lTotalLix + "\n");
 			}
 			if( federatedReadCount.longValue() > 0){
-				sb.append("Federated (Reads,Puts,Gets) :\t(" + 
-					federatedReadCount.longValue() + "," +
-					federatedPutCount.longValue() + "," +
-					federatedGetCount.longValue() + ")\n");
-				sb.append("Federated Execute (In,UDF)  :\t(" +
-					federatedExecuteInstructionCount.longValue() + "," +
-					federatedExecuteUDFCount.longValue() + ")\n");
+				sb.append("Federated I/O (Read, Put, Get):\t" + 
+					federatedReadCount.longValue() + "/" +
+					federatedPutCount.longValue() + "/" +
+					federatedGetCount.longValue() + ".\n");
+				sb.append("Federated Execute (Inst, UDF):\t" +
+					federatedExecuteInstructionCount.longValue() + "/" +
+					federatedExecuteUDFCount.longValue() + ".\n");
 			}
 
 			sb.append("Total JIT compile time:\t\t" + ((double)getJITCompileTime())/1000 + " sec.\n");
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
index 29826f8..bf674a8 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -61,8 +61,7 @@ public class FederatedPCATest extends AutomatedTestBase {
 		// rows have to be even and > 1
 		return Arrays.asList(new Object[][] {
 			{10000, 10, false}, {2000, 50, false}, {1000, 100, false},
-			//TODO support for federated uacmean, uacvar
-			//{10000, 10, true}, {2000, 50, true}, {1000, 100, true}
+			{10000, 10, true}, {2000, 50, true}, {1000, 100, true}
 		});
 	}
 
@@ -99,7 +98,6 @@ public class FederatedPCATest extends AutomatedTestBase {
 
 		TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
 		loadTestConfiguration(config);
-		setOutputBuffering(false);
 		
 		// Run reference dml script with normal matrix
 		fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
@@ -124,8 +122,11 @@ public class FederatedPCATest extends AutomatedTestBase {
 		Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
 		Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
 		if( scaleAndShift ) {
+			Assert.assertTrue(heavyHittersContainsString("fed_uacsqk+"));
 			Assert.assertTrue(heavyHittersContainsString("fed_uacmean"));
-			Assert.assertTrue(heavyHittersContainsString("fed_uacvar"));
+			Assert.assertTrue(heavyHittersContainsString("fed_-"));
+			Assert.assertTrue(heavyHittersContainsString("fed_/"));
+			Assert.assertTrue(heavyHittersContainsString("fed_replace"));
 		}
 		
 		resetExecMode(platformOld);