You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2022/06/04 17:44:46 UTC

[systemds] branch main updated: [MINOR] Fix warnings, data types, formatting of the federated backend

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 85fa35312c [MINOR] Fix warnings, data types, formatting of the federated backend
85fa35312c is described below

commit 85fa35312c3e536024d9a14738d10cb181e343c0
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Jun 4 19:44:13 2022 +0200

    [MINOR] Fix warnings, data types, formatting of the federated backend
---
 src/main/java/org/apache/sysds/common/Types.java   |   4 +-
 src/main/java/org/apache/sysds/hops/Hop.java       |   2 +-
 .../RewriteElementwiseMultChainOptimization.java   |   2 +-
 .../federated/FederatedLocalData.java              |   1 -
 .../paramserv/FederatedPSControlThread.java        |   1 -
 .../controlprogram/paramserv/HEParamServer.java    | 298 ++++++++++-----------
 .../paramserv/NetworkTrafficCounter.java           |  27 +-
 .../runtime/instructions/cp/PlaintextMatrix.java   |  16 +-
 .../sysds/utils/stats/ParamServStatistics.java     |   1 -
 .../paramserv/EncryptedFederatedParamservTest.java |   7 +-
 .../fedplanning/FederatedMultiplyPlanningTest.java |   1 -
 11 files changed, 179 insertions(+), 181 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index 3dfad3413e..a7cfa823aa 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -44,7 +44,9 @@ public class Types
 	 * Data types (tensor, matrix, scalar, frame, object, unknown).
 	 */
 	public enum DataType {
-		TENSOR, MATRIX, SCALAR, FRAME, LIST, ENCRYPTED_CIPHER, ENCRYPTED_PLAIN, UNKNOWN;
+		TENSOR, MATRIX, SCALAR, FRAME, LIST, UNKNOWN,
+		//TODO remove from Data Type -> generic object
+		ENCRYPTED_CIPHER, ENCRYPTED_PLAIN;
 		
 		public boolean isMatrix() {
 			return this == MATRIX;
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 7bdb5a424e..2ee317f35e 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -813,7 +813,7 @@ public abstract class Hop implements ParseInfo {
 				
 				break;
 			}
-			case UNKNOWN: {
+			default: {
 				//memory estimate always unknown
 				_outputMemEstimate = OptimizerUtils.DEFAULT_SIZE;
 				break;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index bbb8f0a161..d2244c8a7d 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -262,8 +262,8 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 					case MATRIX: orderDataType[i] = 1; break;
 					case TENSOR: orderDataType[i] = 2; break;
 					case FRAME:  orderDataType[i] = 3; break;
-					case UNKNOWN:orderDataType[i] = 4; break;
 					case LIST:   orderDataType[i] = 5; break;
+					default:     orderDataType[i] = 4; break;
 				}
 		}
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
index 77ffb7f847..de56a1a52e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
@@ -25,7 +25,6 @@ import java.util.concurrent.Future;
 import org.apache.log4j.Logger;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
 
 public class FederatedLocalData extends FederatedData {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 54d778486a..0c984698c1 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -378,7 +378,6 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	@Override
 	public Void call() throws Exception {
 		try {
-			Timing tTotal = new Timing(true);
 			switch (_freq) {
 				case BATCH:
 					computeWithBatchUpdates();
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
index 577bf6c820..4e873abdb6 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
@@ -29,7 +29,6 @@ import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix;
-import org.apache.sysds.utils.NativeHelper;
 import org.apache.sysds.utils.stats.ParamServStatistics;
 
 import java.util.ArrayList;
@@ -43,152 +42,153 @@ import java.util.stream.IntStream;
  * This class implements Homomorphic Encryption (HE) for LocalParamServer. It only supports modelAvg=true.
  */
 public class HEParamServer extends LocalParamServer {
-    private int _thread_counter = 0;
-    private final List<FederatedPSControlThread> _threads;
-    private final List<Object> _result_buffer; // one per thread
-    private Object _result;
-    private final SEALServer _seal_server;
-
-    public static HEParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
-                                          Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
-                                          MatrixObject valFeatures, MatrixObject valLabels, int nbatches)
-    {
-        NativeHEHelper.initialize();
-        return new HEParamServer(model, aggFunc, updateType, freq, ec,
-                workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches);
-    }
-
-    private HEParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
-                             Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
-                             MatrixObject valFeatures, MatrixObject valLabels, int nbatches)
-    {
-        super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, true);
-
-        _seal_server = new SEALServer();
-
-        _threads = Collections.synchronizedList(new ArrayList<>(workerNum));
-        for (int i = 0; i < getNumWorkers(); i++) {
-            _threads.add(null);
-        }
-
-        _result_buffer = new ArrayList<>(workerNum);
-        resetResultBuffer();
-    }
-
-    public void registerThread(int thread_id, FederatedPSControlThread thread) {
-        _threads.set(thread_id, thread);
-    }
-
-    private synchronized void resetResultBuffer() {
-        _result_buffer.clear();
-        for (int i = 0; i < getNumWorkers(); i++) {
-            _result_buffer.add(null);
-        }
-    }
-
-    public byte[] generateA() {
-        return _seal_server.generateA();
-    }
-
-    public PublicKey aggregatePartialPublicKeys(PublicKey[] partial_public_keys) {
-        return _seal_server.aggregatePartialPublicKeys(partial_public_keys);
-    }
-
-    /**
-     * this method collects all T Objects from each worker into a list and then calls f once on this list to produce
-     * another T, which it returns.
-     */
-    private synchronized <T,U> U collectAndDo(int workerId, T obj, Function<List<T>, U> f) {
-        _result_buffer.set(workerId, obj);
-        _thread_counter++;
-
-        if (_thread_counter == getNumWorkers()) {
-            List<T> buf = _result_buffer.stream().map(x -> (T)x).collect(Collectors.toList());
-            _result = f.apply(buf);
-            resetResultBuffer();
-            _thread_counter = 0;
-            notifyAll();
-        } else {
-            try {
-                wait();
-            } catch (InterruptedException i) {
-                throw new RuntimeException("thread interrupted");
-            }
-        }
-
-        return (U) _result;
-    }
-
-    private CiphertextMatrix[] homomorphicAggregation(List<ListObject> encrypted_models) {
-        Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
-        CiphertextMatrix[] result = new CiphertextMatrix[encrypted_models.get(0).getLength()];
-        IntStream.range(0, encrypted_models.get(0).getLength()).forEach(matrix_idx -> {
-            CiphertextMatrix[] summands = new CiphertextMatrix[encrypted_models.size()];
-            for (int i = 0; i < encrypted_models.size(); i++) {
-                summands[i] = (CiphertextMatrix) encrypted_models.get(i).getData(matrix_idx);
-            }
-            result[matrix_idx] = _seal_server.accumulateCiphertexts(summands);;
-        });
-        if (tAgg != null) {
-            ParamServStatistics.accHEAccumulation((long)tAgg.stop());
-        }
-        return result;
-    }
-
-    private Void homomorphicAverage(CiphertextMatrix[] encrypted_sums, List<PlaintextMatrix[]> partial_decryptions) {
-        Timing tDecrypt = DMLScript.STATISTICS ? new Timing(true) : null;
-
-        MatrixObject[] result = new MatrixObject[partial_decryptions.get(0).length];
-
-        IntStream.range(0, partial_decryptions.get(0).length).forEach(matrix_idx -> {
-            PlaintextMatrix[] partial_plaintexts = new PlaintextMatrix[partial_decryptions.size()];
-            for (int i = 0; i < partial_decryptions.size(); i++) {
-                partial_plaintexts[i] = partial_decryptions.get(i)[matrix_idx];
-            }
-
-            result[matrix_idx] = _seal_server.average(encrypted_sums[matrix_idx], partial_plaintexts);
-        });
-
-        ListObject old_model = getResult();
-        ListObject new_model = new ListObject(old_model);
-        for (int i = 0; i < new_model.getLength(); i++) {
-            new_model.set(i, result[i]);
-        }
-
-        if (tDecrypt != null) {
-            ParamServStatistics.accHEDecryptionTime((long)tDecrypt.stop());
-        }
-
-        updateAndBroadcastModel(new_model, null);
-        return null;
-    }
-
-    // this is only to be used in push()
-    private Timing commTimer;
-    private void startCommTimer() {
-        commTimer = new Timing(true);
-    }
-    private long stopCommTimer() {
-        return (long)commTimer.stop();
-    }
-    // ---------------------------------
-
-    @Override
-    public void push(int workerID, ListObject encrypted_model) {
-        // wait for all updates and sum them homomorphically
-        CiphertextMatrix[] homomorphic_sum = collectAndDo(workerID, encrypted_model, x -> {
-            CiphertextMatrix[] res = this.homomorphicAggregation(x);
-            this.startCommTimer();
-            return res;
-        });
-
-        // get partial decryptions
-        PlaintextMatrix[] partial_decryption = _threads.get(workerID).getPartialDecryption(homomorphic_sum);
-
-        // do average and update global model
-        collectAndDo(workerID, partial_decryption, x -> {
-            ParamServStatistics.accFedNetworkTime(this.stopCommTimer());
-            return this.homomorphicAverage(homomorphic_sum, x);
-        });
-    }
+	private int _thread_counter = 0;
+	private final List<FederatedPSControlThread> _threads;
+	private final List<Object> _result_buffer; // one per thread
+	private Object _result;
+	private final SEALServer _seal_server;
+
+	public static HEParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
+		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
+		MatrixObject valFeatures, MatrixObject valLabels, int nbatches)
+	{
+		NativeHEHelper.initialize();
+		return new HEParamServer(model, aggFunc, updateType, freq, ec,
+				workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches);
+	}
+
+	private HEParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
+		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc,
+		int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches)
+	{
+		super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, true);
+
+		_seal_server = new SEALServer();
+
+		_threads = Collections.synchronizedList(new ArrayList<>(workerNum));
+		for (int i = 0; i < getNumWorkers(); i++) {
+			_threads.add(null);
+		}
+
+		_result_buffer = new ArrayList<>(workerNum);
+		resetResultBuffer();
+	}
+
+	public void registerThread(int thread_id, FederatedPSControlThread thread) {
+		_threads.set(thread_id, thread);
+	}
+
+	private synchronized void resetResultBuffer() {
+		_result_buffer.clear();
+		for (int i = 0; i < getNumWorkers(); i++) {
+			_result_buffer.add(null);
+		}
+	}
+
+	public byte[] generateA() {
+		return _seal_server.generateA();
+	}
+
+	public PublicKey aggregatePartialPublicKeys(PublicKey[] partial_public_keys) {
+		return _seal_server.aggregatePartialPublicKeys(partial_public_keys);
+	}
+
+	/**
+	 * This method collects one object of type T from each worker into a list and then calls f once on this list to
+	 * produce a result of type U, which it returns to every caller.
+	 */
+	@SuppressWarnings("unchecked")
+	private synchronized <T,U> U collectAndDo(int workerId, T obj, Function<List<T>, U> f) {
+		_result_buffer.set(workerId, obj);
+		_thread_counter++;
+
+		if (_thread_counter == getNumWorkers()) {
+			List<T> buf = _result_buffer.stream().map(x -> (T)x).collect(Collectors.toList());
+			_result = f.apply(buf);
+			resetResultBuffer();
+			_thread_counter = 0;
+			notifyAll();
+		} else {
+			try {
+				wait();
+			} catch (InterruptedException i) {
+				throw new RuntimeException("thread interrupted");
+			}
+		}
+
+		return (U) _result;
+	}
+
+	private CiphertextMatrix[] homomorphicAggregation(List<ListObject> encrypted_models) {
+		Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
+		CiphertextMatrix[] result = new CiphertextMatrix[encrypted_models.get(0).getLength()];
+		IntStream.range(0, encrypted_models.get(0).getLength()).forEach(matrix_idx -> {
+			CiphertextMatrix[] summands = new CiphertextMatrix[encrypted_models.size()];
+			for (int i = 0; i < encrypted_models.size(); i++) {
+				summands[i] = (CiphertextMatrix) encrypted_models.get(i).getData(matrix_idx);
+			}
+			result[matrix_idx] = _seal_server.accumulateCiphertexts(summands);;
+		});
+		if (tAgg != null) {
+			ParamServStatistics.accHEAccumulation((long)tAgg.stop());
+		}
+		return result;
+	}
+
+	private Void homomorphicAverage(CiphertextMatrix[] encrypted_sums, List<PlaintextMatrix[]> partial_decryptions) {
+		Timing tDecrypt = DMLScript.STATISTICS ? new Timing(true) : null;
+
+		MatrixObject[] result = new MatrixObject[partial_decryptions.get(0).length];
+
+		IntStream.range(0, partial_decryptions.get(0).length).forEach(matrix_idx -> {
+			PlaintextMatrix[] partial_plaintexts = new PlaintextMatrix[partial_decryptions.size()];
+			for (int i = 0; i < partial_decryptions.size(); i++) {
+				partial_plaintexts[i] = partial_decryptions.get(i)[matrix_idx];
+			}
+
+			result[matrix_idx] = _seal_server.average(encrypted_sums[matrix_idx], partial_plaintexts);
+		});
+
+		ListObject old_model = getResult();
+		ListObject new_model = new ListObject(old_model);
+		for (int i = 0; i < new_model.getLength(); i++) {
+			new_model.set(i, result[i]);
+		}
+
+		if (tDecrypt != null) {
+			ParamServStatistics.accHEDecryptionTime((long)tDecrypt.stop());
+		}
+
+		updateAndBroadcastModel(new_model, null);
+		return null;
+	}
+
+	// this is only to be used in push()
+	private Timing commTimer;
+	private void startCommTimer() {
+		commTimer = new Timing(true);
+	}
+	private long stopCommTimer() {
+		return (long)commTimer.stop();
+	}
+	// ---------------------------------
+
+	@Override
+	public void push(int workerID, ListObject encrypted_model) {
+		// wait for all updates and sum them homomorphically
+		CiphertextMatrix[] homomorphic_sum = collectAndDo(workerID, encrypted_model, x -> {
+			CiphertextMatrix[] res = this.homomorphicAggregation(x);
+			this.startCommTimer();
+			return res;
+		});
+
+		// get partial decryptions
+		PlaintextMatrix[] partial_decryption = _threads.get(workerID).getPartialDecryption(homomorphic_sum);
+
+		// do average and update global model
+		collectAndDo(workerID, partial_decryption, x -> {
+			ParamServStatistics.accFedNetworkTime(this.stopCommTimer());
+			return this.homomorphicAverage(homomorphic_sum, x);
+		});
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java
index f823b9d3be..9c353c3258 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java
@@ -19,24 +19,23 @@
 
 package org.apache.sysds.runtime.controlprogram.paramserv;
 
-import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.handler.traffic.ChannelTrafficShapingHandler;
 import java.util.function.BiConsumer;
 
 public class NetworkTrafficCounter extends ChannelTrafficShapingHandler {
-    private final BiConsumer<Long, Long> _fn; // (read, written) -> Void, logs bytes read and written
-    public NetworkTrafficCounter(BiConsumer<Long, Long> fn) {
-        // checkInterval of zero means that doAccounting will not be called
-        super( 0);
-        _fn = fn;
-    }
+	private final BiConsumer<Long, Long> _fn; // (read, written) -> Void, logs bytes read and written
+	public NetworkTrafficCounter(BiConsumer<Long, Long> fn) {
+		// checkInterval of zero means that doAccounting will not be called
+		super( 0);
+		_fn = fn;
+	}
 
-    // log bytes read/written after channel is closed
-    @Override
-    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
-        _fn.accept(trafficCounter.cumulativeReadBytes(), trafficCounter.cumulativeWrittenBytes());
-        trafficCounter.resetCumulativeTime();
-        super.channelInactive(ctx);
-    }
+	// log bytes read/written after channel is closed
+	@Override
+	public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+		_fn.accept(trafficCounter.cumulativeReadBytes(), trafficCounter.cumulativeWrittenBytes());
+		trafficCounter.resetCumulativeTime();
+		super.channelInactive(ctx);
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java
index 6fe2b3814f..d36d40bc8e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java
@@ -26,14 +26,14 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
  * This class abstracts over an encrypted matrix of ciphertexts. It stores the data as opaque byte array. The layout is unspecified.
  */
 public class PlaintextMatrix extends Encrypted {
-    private static final long serialVersionUID = 5732436872261940616L;
+	private static final long serialVersionUID = 5732436872261940616L;
 
-    public PlaintextMatrix(int[] dims, DataCharacteristics dc, byte[] data) {
-        super(dims, dc, data, Types.DataType.ENCRYPTED_PLAIN);
-    }
+	public PlaintextMatrix(int[] dims, DataCharacteristics dc, byte[] data) {
+		super(dims, dc, data, Types.DataType.ENCRYPTED_PLAIN);
+	}
 
-    @Override
-    public String getDebugName() {
-        return "PlaintextMatrix " + getData().hashCode();
-    }
+	@Override
+	public String getDebugName() {
+		return "PlaintextMatrix " + getData().hashCode();
+	}
 }
diff --git a/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java b/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java
index 8eb26a1963..3edf7bb77e 100644
--- a/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java
@@ -21,7 +21,6 @@ package org.apache.sysds.utils.stats;
 
 import java.util.concurrent.atomic.LongAdder;
 
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 
 public class ParamServStatistics {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
index 250358d408..ca50338e33 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
@@ -25,7 +25,6 @@ import java.util.Collection;
 import java.util.List;
 
 import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.hops.codegen.SpoofCompiler;
 import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -94,8 +93,10 @@ public class EncryptedFederatedParamservTest extends AutomatedTestBase {
 		});
 	}
 
-	public EncryptedFederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size,
-										  int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String weighting, String data_distribution, int seed) {
+	public EncryptedFederatedParamservTest(String networkType, int numFederatedWorkers,
+		int dataSetSize, int batch_size, int epochs, double eta, String utype, String freq,
+		String scheme, String runtime_balancing, String weighting, String data_distribution, int seed)
+	{
 		try {
 			NativeHEHelper.initialize();
 		} catch (Exception e) {
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index b9a3a14fd5..14c093ebe8 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -23,7 +23,6 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;