You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2022/06/03 08:49:46 UTC

[systemds] branch main updated: [SYSTEMDS-3280] Add homomorphic encryption functionality to Parameter Server.

This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new ccbdd3560d [SYSTEMDS-3280] Add homomorphic encryption functionality to Parameter Server.
ccbdd3560d is described below

commit ccbdd3560d6690b9124625acb1886f98d0f8de4c
Author: Florian Lackner <la...@gmail.com>
AuthorDate: Wed Jan 26 14:20:18 2022 +0100

    [SYSTEMDS-3280] Add homomorphic encryption functionality to Parameter Server.
    
    This commit adds homomorphic encryption functionality to the ParameterServer.
    It allows a federated parameter server to work on encrypted client data: each client encrypts its data under a jointly generated public key, and the server performs the accumulation step on these ciphertexts using homomorphic operations.
    The data never leaves the clients in plaintext.
    
    Closes #1525.
---
 src/main/cpp/build.bat                             |   4 +
 src/main/cpp/build.sh                              |   5 +
 src/main/cpp/he/CMakeLists.txt                     |  64 ++++
 src/main/cpp/he/he.cpp                             | 298 ++++++++++++++++++
 src/main/cpp/he/he.h                               | 111 +++++++
 src/main/cpp/he/libhe.cpp                          | 294 ++++++++++++++++++
 src/main/cpp/he/libhe.h                            | 144 +++++++++
 src/main/cpp/lib/libhe-Linux-x86_64.so             | Bin 0 -> 168696 bytes
 src/main/cpp/systemds.cpp                          |   8 +-
 src/main/java/org/apache/sysds/common/Types.java   |   2 +-
 .../ParameterizedBuiltinFunctionExpression.java    |   3 +-
 .../java/org/apache/sysds/parser/Statement.java    |   2 +-
 .../controlprogram/context/ExecutionContext.java   |  17 +-
 .../controlprogram/federated/FederatedData.java    |   2 +
 .../federated/FederatedLocalData.java              |   1 +
 .../federated/FederatedStatistics.java             |  34 +++
 .../controlprogram/federated/FederatedWorker.java  |  16 +-
 .../federated/FederatedWorkerHandler.java          |  40 ++-
 .../paramserv/FederatedPSControlThread.java        | 337 +++++++++++++++++----
 .../controlprogram/paramserv/HEParamServer.java    | 194 ++++++++++++
 .../controlprogram/paramserv/LocalParamServer.java |   2 +-
 .../controlprogram/paramserv/NativeHEHelper.java   | 119 ++++++++
 .../paramserv/NetworkTrafficCounter.java           |  42 +++
 .../controlprogram/paramserv/ParamServer.java      |  61 ++--
 .../paramserv/dp/DataPartitionFederatedScheme.java |   1 +
 .../paramserv/homomorphicEncryption/PublicKey.java |  36 +++
 .../homomorphicEncryption/SEALClient.java          |  88 ++++++
 .../homomorphicEncryption/SEALServer.java          | 112 +++++++
 .../runtime/instructions/cp/CiphertextMatrix.java  |  39 +++
 .../sysds/runtime/instructions/cp/Encrypted.java   |  53 ++++
 .../sysds/runtime/instructions/cp/ListObject.java  |  27 ++
 .../cp/ParamservBuiltinCPInstruction.java          | 103 ++++---
 .../runtime/instructions/cp/PlaintextMatrix.java   |  39 +++
 .../java/org/apache/sysds/utils/NativeHelper.java  |  92 +++---
 .../java/org/apache/sysds/utils/Statistics.java    |   2 +
 .../sysds/utils/stats/ParamServStatistics.java     |  51 ++++
 .../org/apache/sysds/test/AutomatedTestBase.java   |  10 +-
 .../paramserv/EncryptedFederatedParamservTest.java | 256 ++++++++++++++++
 .../functions/homomorphicEncryption/InOutTest.java | 118 ++++++++
 .../paramserv/EncryptedFederatedParamservTest.dml  |  61 ++++
 40 files changed, 2700 insertions(+), 188 deletions(-)

diff --git a/src/main/cpp/build.bat b/src/main/cpp/build.bat
index 93c3819c3b..316c4f393d 100644
--- a/src/main/cpp/build.bat
+++ b/src/main/cpp/build.bat
@@ -34,5 +34,9 @@ cmake . -B OPENBLAS -DUSE_OPEN_BLAS=ON -DCMAKE_BUILD_TYPE=Release
 cmake --build OPENBLAS --target install --config Release
 rmdir /Q /S OPENBLAS
 
+REM compile HE
+cmake -S he -B HE -DCMAKE_BUILD_TYPE=Release
+cmake --build HE --target install --config Release
+rmdir /Q /S HE
+
 echo.
 echo "Make sure to re-run mvn package to make use of the newly compiled libraries"
\ No newline at end of file
diff --git a/src/main/cpp/build.sh b/src/main/cpp/build.sh
index df67aba539..e40ec895a3 100755
--- a/src/main/cpp/build.sh
+++ b/src/main/cpp/build.sh
@@ -66,3 +66,8 @@ ldd lib/libsystemds_mkl-Linux-x86_64.so | grep -v $gcc_toolkit"\|$linux_loader\|
 echo "Non-standard dependencies for libsystemds_openblas-linux-x86_64.so"
 ldd lib/libsystemds_openblas-Linux-x86_64.so | grep -v $gcc_toolkit"\|$linux_loader\|"$openblas
 echo "-----------------------------------------------------------------------"
+
+# compile HE
+cmake he/ -B HE -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=g++
+cmake --build HE --target install --config Release
+rm -R HE
\ No newline at end of file
diff --git a/src/main/cpp/he/CMakeLists.txt b/src/main/cpp/he/CMakeLists.txt
new file mode 100644
index 0000000000..373ba3a5d9
--- /dev/null
+++ b/src/main/cpp/he/CMakeLists.txt
@@ -0,0 +1,64 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+cmake_minimum_required(VERSION 3.8)
+cmake_policy(SET CMP0074 NEW) # make use of <package>_ROOT variable
+project (he LANGUAGES CXX)
+
+# All custom find modules
+set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/")
+
+# Default to a Release build, but respect a build type requested on the command line
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE Release)
+endif()
+
+# unify library filenames to libhe_<...>; set before the target is created so it picks the prefixes up
+if (WIN32)
+    set(CMAKE_IMPORT_LIBRARY_PREFIX lib CACHE INTERNAL "")
+    set(CMAKE_SHARED_LIBRARY_PREFIX lib CACHE INTERNAL "")
+endif()
+
+set(HEADER_FILES libhe.h he.h)
+set(SOURCE_FILES libhe.cpp he.cpp)
+
+# Build a shared library
+add_library(he SHARED ${SOURCE_FILES} ${HEADER_FILES})
+
+# the sources use C++17 features (std::optional, class template argument deduction for gsl::span)
+target_compile_features(he PRIVATE cxx_std_17)
+
+set_target_properties(he PROPERTIES MACOSX_RPATH 1)
+set_target_properties(he PROPERTIES OUTPUT_NAME "he-${CMAKE_SYSTEM_NAME}-${CMAKE_SYSTEM_PROCESSOR}")
+
+# sets the installation path to src/main/cpp/lib
+if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
+  set(CMAKE_INSTALL_PREFIX "${CMAKE_SOURCE_DIR}/.." CACHE PATH "sets the installation path to src/main/cpp/lib" FORCE)
+endif()
+
+# A Windows DLL is a RUNTIME artifact, a Linux/macOS shared object is a LIBRARY artifact. Name both destinations so
+# the library lands in lib/ on every platform (older CMake rejects a shared library without a LIBRARY destination).
+install(TARGETS he RUNTIME DESTINATION lib LIBRARY DESTINATION lib)
+
+find_package(SEAL 3.7 REQUIRED)
+target_link_libraries(he SEAL::seal_shared)
+
+# Include directories. (added for Linux & Darwin, fix later for windows)
+# include paths can be spurious
+include_directories($ENV{JAVA_HOME}/include/)
+include_directories($ENV{JAVA_HOME}/include/darwin)
+include_directories($ENV{JAVA_HOME}/include/linux)
+include_directories($ENV{JAVA_HOME}/include/win32)
diff --git a/src/main/cpp/he/he.cpp b/src/main/cpp/he/he.cpp
new file mode 100644
index 0000000000..71ff571af1
--- /dev/null
+++ b/src/main/cpp/he/he.cpp
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "he.h"
+#include "libhe.h"
+
+#ifdef _WIN32
+#include <winsock.h>
+#else
+#include <arpa/inet.h>
+#endif
+
+// Copies the contents of a Java byte[] into an in-memory input stream.
+// GetByteArrayRegion copies straight into the string buffer, so the Java array is never pinned via
+// Get/ReleaseByteArrayElements and no intermediate buffer is needed.
+unique_ptr<istream> get_stream(JNIEnv* env, jbyteArray ary) {
+    const jsize size = env->GetArrayLength(ary);
+    string data_s(static_cast<size_t>(size), '\0');
+    if (size > 0) {
+        env->GetByteArrayRegion(ary, 0, size, reinterpret_cast<jbyte*>(&data_s[0]));
+    }
+    // FIXME: istringstream still copies the string once; a custom stream wrapping the buffer would avoid that
+    return std::make_unique<istringstream>(std::move(data_s));
+}
+
+// Copies everything written to `stream` into a newly allocated Java byte[].
+// Returns nullptr if the JVM could not allocate the array; an OutOfMemoryError is then pending and must not be
+// followed by further array calls on the null reference.
+jbyteArray allocate_byte_array(JNIEnv* env, ostringstream& stream) {
+    const string data = stream.str(); // FIXME: this copies string content. maybe implement custom ostream
+    const jsize size = static_cast<jsize>(data.size());
+    jbyteArray ret = env->NewByteArray(size);
+    if (ret == nullptr) {
+        return nullptr;
+    }
+    env->SetByteArrayRegion(ret, 0, size, reinterpret_cast<const jbyte*>(data.data()));
+    return ret;
+}
+
+// Throws std::logic_error carrying `message` when `assertion` is false.
+// Unlike assert(), this check is not compiled out in release builds.
+void my_assert(bool assertion, const char* message = "Assertion failed") {
+    if (assertion) {
+        return;
+    }
+    throw logic_error(message);
+}
+
+// Serializes any object exposing save(ostream&) (SEAL keys, ciphertexts, ...) into a new Java byte[].
+template<typename T> jbyteArray serialize(JNIEnv* env, T& object) {
+    ostringstream buffer;
+    object.save(buffer);
+    return allocate_byte_array(env, buffer);
+}
+
+// Writes `n` to `ss` as exactly 4 bytes in network (big-endian) byte order.
+void serialize_uint32_t(ostream& ss, uint32_t n) {
+    const uint32_t big_endian = htonl(n);
+    ss.write(reinterpret_cast<const char*>(&big_endian), sizeof(big_endian));
+}
+
+// Reads a 4-byte big-endian (network order) unsigned integer from `ss`.
+// Throws std::logic_error if the stream ends before 4 bytes were read, instead of returning an uninitialized value.
+uint32_t deserialize_uint32_t(istream& ss) {
+    uint32_t ret = 0;
+    ss.read(reinterpret_cast<char*>(&ret), sizeof(ret));
+    if (ss.gcount() != static_cast<streamsize>(sizeof(ret))) {
+        throw logic_error("unexpected end of stream while reading uint32_t");
+    }
+    return ntohl(ret);
+}
+
+// Loads the next serialized Ciphertext from `ss`, validated against `context`.
+Ciphertext deserialize_ciphertext(istream& ss, const SEALContext& context) {
+    Ciphertext loaded;
+    loaded.load(context, ss);
+    return loaded;
+}
+
+// Writes `plaintext` to `ss` in SEAL's serialization format.
+// Taken by const reference: save() does not modify the plaintext, so the copy made by pass-by-value was wasted.
+void serialize_plaintext(ostream& ss, const Plaintext& plaintext) {
+    plaintext.save(ss);
+}
+
+// Deserializes a SEAL object of type T from a Java byte[] with unsafe_load(), i.e. without validity checks.
+// Needed because partial public keys are not valid public keys on their own.
+template<typename T> T deserialize_unsafe(JNIEnv* env, const SEALContext& context, jbyteArray serialized_object) {
+    T result;
+    const auto input = get_stream(env, serialized_object);
+    result.unsafe_load(context, *input);
+    return result;
+}
+
+// Deserializes a SEAL object of type T from a Java byte[] with load(), which, unlike unsafe_load() used by
+// deserialize_unsafe, validates the loaded object against `context`.
+template<typename T> T deserialize(JNIEnv* env, const SEALContext& context, jbyteArray serialized_object) {
+    auto ss = get_stream(env, serialized_object);
+    T deserialized;
+    deserialized.load(context, *ss);
+    return deserialized;
+}
+
+// Creates a Client whose public polynomial `a` is taken from `a_ary`, which holds raw uint64_t words (the layout
+// written by generateA). Returns the Client as an opaque jlong handle; he.h declares no function that frees it.
+JNIEXPORT jlong JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_initClient
+  (JNIEnv* env, jclass, jbyteArray a_ary) {
+    double scale = pow(2.0, 40);
+    GlobalState gs(scale);
+
+    // copy a to global state. GetByteArrayRegion writes directly into a properly aligned uint64_t buffer, which
+    // avoids treating the jbyte buffer as uint64_t* and avoids the extra span -> vector copy.
+    const jsize byte_size = env->GetArrayLength(a_ary);
+    my_assert(byte_size % sizeof(uint64_t) == 0);
+    vector<uint64_t> new_a_buf(static_cast<size_t>(byte_size) / sizeof(uint64_t));
+    if (byte_size > 0) {
+        env->GetByteArrayRegion(a_ary, 0, byte_size, reinterpret_cast<jbyte*>(new_a_buf.data()));
+    }
+    gs.a.set_data(new_a_buf);
+
+    // move: Client takes GlobalState by value, so this avoids copying the SEAL context and `a`
+    Client* client = new Client(std::move(gs));
+    return reinterpret_cast<jlong>(client);
+}
+
+
+// Serializes this client's partial public key (only the Ciphertext inside the PublicKey) into a byte[].
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_generatePartialPublicKey
+  (JNIEnv* env, jclass, jlong client_ptr) {
+    Client* const client = reinterpret_cast<Client*>(client_ptr);
+    return serialize(env, client->partial_public_key().data());
+}
+
+
+// Loads the serialized public key (validated against the client's context via load()) and hands it to the client.
+JNIEXPORT void JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_setPublicKey
+  (JNIEnv* env, jclass, jlong client_ptr, jbyteArray serialized_public_key) {
+    Client* const client = reinterpret_cast<Client*>(client_ptr);
+    PublicKey joint_key = deserialize<PublicKey>(env, client->context(), serialized_public_key);
+    client->set_public_key(std::move(joint_key));
+}
+
+
+// Encrypts `jdata` in chunks of get_slot_count() doubles.
+// Output format: big-endian uint32 chunk count, followed by that many serialized Ciphertexts.
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_encrypt
+  (JNIEnv* env, jclass, jlong client_ptr, jdoubleArray jdata) {
+    Client* client = reinterpret_cast<Client*>(client_ptr);
+    const size_t slot_count = get_slot_count(client->context());
+    const jsize num_data = env->GetArrayLength(jdata);
+
+    // Copy the input instead of pinning it with GetDoubleArrayElements: if encryption throws, there is no
+    // ReleaseDoubleArrayElements call that could be skipped.
+    vector<double> values(static_cast<size_t>(num_data));
+    if (num_data > 0) {
+        env->GetDoubleArrayRegion(jdata, 0, num_data, values.data());
+    }
+
+    std::ostringstream ss;
+    // write chunk count; ceiling division so empty input yields 0 chunks (num_data - 1 would underflow)
+    const uint32_t num_chunks = static_cast<uint32_t>((values.size() + slot_count - 1) / slot_count);
+    serialize_uint32_t(ss, num_chunks);
+    for (size_t i = 0; i < num_chunks; i++) {
+        const size_t offset = slot_count * i;
+        const size_t length = min(slot_count, values.size() - offset);
+        gsl::span<const double> chunk(values.data() + offset, length);
+        Ciphertext encrypted_chunk = client->encrypted_data(chunk);
+        encrypted_chunk.save(ss);
+    }
+    return allocate_byte_array(env, ss);
+}
+
+
+// Partially decrypts a chunked ciphertext stream (uint32 chunk count + Ciphertexts, as written by encrypt and
+// accumulateCiphertexts) with this client's secret key share.
+// Output: the same big-endian uint32 chunk count followed by one serialized Plaintext per chunk.
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_partiallyDecrypt
+  (JNIEnv* env, jclass, jlong client_ptr, jbyteArray serialized_ciphertexts) {
+    Client* client = reinterpret_cast<Client*>(client_ptr);
+    auto input = get_stream(env, serialized_ciphertexts);
+    std::ostringstream ss;
+
+    // read num of chunks
+    const uint32_t num_chunks = deserialize_uint32_t(*input);
+
+    // write chunk count
+    serialize_uint32_t(ss, num_chunks);
+    // unsigned counter: num_chunks is a uint32_t, so an int counter mixes signedness and could overflow
+    for (uint32_t i = 0; i < num_chunks; i++) {
+        Ciphertext ciphertext = deserialize_ciphertext(*input, client->context());
+        Plaintext plaintext = client->partial_decryption(ciphertext);
+        plaintext.save(ss);
+    }
+
+    return allocate_byte_array(env, ss);
+}
+
+
+// Creates a Server whose global state uses scale 2^40 and returns it as an opaque jlong handle.
+JNIEXPORT jlong JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_initServer
+  (JNIEnv *, jclass) {
+    GlobalState gs(pow(2.0, 40));
+    return reinterpret_cast<jlong>(new Server(gs));
+}
+
+
+// Returns the server's public polynomial `a` as its raw uint64_t words reinterpreted as a byte[] (native byte
+// order). Returns nullptr with a pending OutOfMemoryError if the array cannot be allocated.
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_generateA
+  (JNIEnv* env, jclass, jlong server_ptr) {
+    Server* server = reinterpret_cast<Server*>(server_ptr);
+    // keep the result of a() alive for the whole function, so data() cannot dangle even if a() returns by value
+    auto&& a = server->a();
+    uint64_t* data = a.data();
+    const jsize size = static_cast<jsize>(a.size() * sizeof(data[0]) / sizeof(jbyte));
+    jbyteArray ret = env->NewByteArray(size);
+    if (ret == nullptr) {
+        return nullptr;
+    }
+    env->SetByteArrayRegion(ret, 0, size, reinterpret_cast<jbyte*>(data));
+    return ret;
+}
+
+
+// Loads every client's partial public key (a Ciphertext, loaded without validation because a partial key is not a
+// valid public key by itself), passes them to Server::accumulate_partial_public_keys and returns the server's
+// resulting public key serialized.
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_aggregatePartialPublicKeys
+  (JNIEnv* env, jclass, jlong server_ptr, jobjectArray partial_public_keys_serialized) {
+    Server* const server = reinterpret_cast<Server*>(server_ptr);
+    const jsize key_count = env->GetArrayLength(partial_public_keys_serialized);
+
+    std::vector<Ciphertext> keys;
+    keys.reserve(key_count);
+    for (jsize idx = 0; idx < key_count; ++idx) {
+        auto serialized_key = static_cast<jbyteArray>(env->GetObjectArrayElement(partial_public_keys_serialized, idx));
+        keys.push_back(deserialize_unsafe<Ciphertext>(env, server->context(), serialized_key));
+        // drop the local reference right away so long key lists do not exhaust the local reference table
+        env->DeleteLocalRef(serialized_key);
+    }
+
+    server->accumulate_partial_public_keys(gsl::span(keys));
+    return serialize(env, server->public_key());
+}
+
+
+// Sums the clients' encrypted vectors chunk by chunk via Server::sum_data.
+// Each input byte[] has the format written by encrypt (big-endian uint32 chunk count, then that many Ciphertexts),
+// and all inputs must have the same chunk count. The result is returned in the same format.
+// Throws std::logic_error if no inputs are given or the chunk counts differ.
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_accumulateCiphertexts
+  (JNIEnv* env, jclass, jlong server_ptr, jobjectArray ciphertexts_serialized) {
+    Server* server = reinterpret_cast<Server*>(server_ptr);
+    const size_t num_ciphertext_arys = env->GetArrayLength(ciphertexts_serialized);
+    // buf[0] is read below, so an empty input would be an out-of-bounds access
+    my_assert(num_ciphertext_arys > 0, "no ciphertexts to accumulate");
+
+    // init streams
+    vector<unique_ptr<istream>> buf;
+    buf.reserve(num_ciphertext_arys);
+    for (size_t i = 0; i < num_ciphertext_arys; i++) {
+        jbyteArray j_data = static_cast<jbyteArray>(
+                env->GetObjectArrayElement(ciphertexts_serialized, static_cast<jsize>(i)));
+        buf.emplace_back(get_stream(env, j_data));
+        env->DeleteLocalRef(j_data);
+    }
+
+    // read the chunk counts of the ciphertext arys and check that they are all the same
+    const uint32_t num_chunks = deserialize_uint32_t(*buf[0]);
+    for (size_t i = 1; i < num_ciphertext_arys; i++) {
+        my_assert(deserialize_uint32_t(*buf[i]) == num_chunks);
+    }
+
+    // read ciphertexts in chunks and accumulate them
+    ostringstream result;
+    serialize_uint32_t(result, num_chunks);
+    for (uint32_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) {
+        vector<Ciphertext> ciphertexts;
+        ciphertexts.reserve(num_ciphertext_arys);
+        for (auto& stream : buf) {
+            // the loaded ciphertext is a temporary, so it is moved into the vector rather than copied
+            ciphertexts.emplace_back(deserialize_ciphertext(*stream, server->context()));
+        }
+        Ciphertext sum = server->sum_data(std::move(ciphertexts));
+        sum.save(result);
+    }
+
+    return allocate_byte_array(env, result);
+}
+
+
+// Combines the encrypted sum with every client's partial decryption, chunk by chunk, via Server::average.
+// `ciphertext_sum_serialized` has the accumulateCiphertexts format. Each element of
+// `partial_decryptions_serialized` has the partiallyDecrypt format, and all must agree on the chunk count.
+// Returns chunk_count * slot_count doubles, where chunk i starts at offset i * slot_count.
+// Throws std::logic_error if no partial decryptions are given or the chunk counts differ.
+JNIEXPORT jdoubleArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_average
+  (JNIEnv* env, jclass, jlong server_ptr, jbyteArray ciphertext_sum_serialized, jobjectArray partial_decryptions_serialized) {
+    Server* server = reinterpret_cast<Server*>(server_ptr);
+    const size_t slot_size = get_slot_count(server->context());
+    const size_t num_plaintext_arys = env->GetArrayLength(partial_decryptions_serialized);
+    // buf[0] is read below, so an empty input would be an out-of-bounds access
+    my_assert(num_plaintext_arys > 0, "no partial decryptions given");
+
+    // init streams
+    vector<unique_ptr<istream>> buf;
+    buf.reserve(num_plaintext_arys);
+    for (size_t i = 0; i < num_plaintext_arys; i++) {
+        jbyteArray j_data = static_cast<jbyteArray>(
+                env->GetObjectArrayElement(partial_decryptions_serialized, static_cast<jsize>(i)));
+        buf.emplace_back(get_stream(env, j_data));
+        env->DeleteLocalRef(j_data);
+    }
+
+    // read the chunk counts of the plaintext arys and check that they are all the same
+    const uint32_t num_chunks = deserialize_uint32_t(*buf[0]);
+    for (size_t i = 1; i < num_plaintext_arys; i++) {
+        my_assert(deserialize_uint32_t(*buf[i]) == num_chunks, "number of plaintext slots is different");
+    }
+
+    auto encrypted_sum_stream = get_stream(env, ciphertext_sum_serialized);
+    my_assert(deserialize_uint32_t(*encrypted_sum_stream) == num_chunks, "number of ciphertext slots is different");
+
+    jdoubleArray result = env->NewDoubleArray(static_cast<jsize>(num_chunks * slot_size));
+    if (result == nullptr) {
+        return nullptr; // OutOfMemoryError is pending
+    }
+    for (uint32_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) {
+        Ciphertext encrypted_sum = deserialize_ciphertext(*encrypted_sum_stream, server->context());
+
+        vector<Plaintext> partial_decryptions;
+        partial_decryptions.reserve(num_plaintext_arys);
+        for (auto& stream : buf) {
+            Plaintext deserialized;
+            deserialized.load(server->context(), *stream);
+            partial_decryptions.emplace_back(std::move(deserialized));
+        }
+        vector<double> averages = server->average(encrypted_sum, std::move(partial_decryptions));
+        env->SetDoubleArrayRegion(result, static_cast<jsize>(chunk_idx * slot_size),
+                                  static_cast<jsize>(averages.size()), averages.data());
+    }
+
+    return result;
+}
diff --git a/src/main/cpp/he/he.h b/src/main/cpp/he/he.h
new file mode 100644
index 0000000000..c7b0ad05d5
--- /dev/null
+++ b/src/main/cpp/he/he.h
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/* Header for class org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper */
+
+#ifndef _Included_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+#define _Included_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+
+#include <jni.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class:     org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+ * Method:    initClient
+ * Signature: ([B)J
+ */
+JNIEXPORT jlong JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_initClient
+  (JNIEnv *, jclass, jbyteArray);
+
+/*
+ * Class:     org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+ * Method:    generatePartialPublicKey
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_generatePartialPublicKey
+  (JNIEnv *, jclass, jlong);
+
+/*
+ * Class:     org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+ * Method:    setPublicKey
+ * Signature: (J[B)V
+ */
+JNIEXPORT void JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_setPublicKey
+  (JNIEnv *, jclass, jlong, jbyteArray);
+
+/*
+ * Class:     org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+ * Method:    encrypt
+ * Signature: (J[D)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_encrypt
+  (JNIEnv *, jclass, jlong, jdoubleArray);
+
+/*
+ * Class:     org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+ * Method:    partiallyDecrypt
+ * Signature: (J[B)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_partiallyDecrypt
+  (JNIEnv *, jclass, jlong, jbyteArray);
+
+/*
+ * Class:     org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+ * Method:    initServer
+ * Signature: ()J
+ */
+JNIEXPORT jlong JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_initServer
+  (JNIEnv *, jclass);
+
+/*
+ * Class:     org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+ * Method:    generateA
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_generateA
+  (JNIEnv *, jclass, jlong);
+
+/*
+ * Class:     org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+ * Method:    aggregatePartialPublicKeys
+ * Signature: (J[[B)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_aggregatePartialPublicKeys
+  (JNIEnv *, jclass, jlong, jobjectArray);
+
+/*
+ * Class:     org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+ * Method:    accumulateCiphertexts
+ * Signature: (J[[B)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_accumulateCiphertexts
+  (JNIEnv *, jclass, jlong, jobjectArray);
+
+/*
+ * Class:     org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper
+ * Method:    average
+ * Signature: (J[B[[B)[D
+ */
+JNIEXPORT jdoubleArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_average
+  (JNIEnv *, jclass, jlong, jbyteArray, jobjectArray);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
\ No newline at end of file
diff --git a/src/main/cpp/he/libhe.cpp b/src/main/cpp/he/libhe.cpp
new file mode 100644
index 0000000000..5f8a929972
--- /dev/null
+++ b/src/main/cpp/he/libhe.cpp
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <algorithm>
+#include <cassert>
+#include <optional>
+#include <stdexcept>
+#include <vector>
+#include <gsl/span>
+
+#include "libhe.h"
+
+#include "seal/seal.h"
+#include "seal/util/common.h"
+#include "seal/util/rlwe.h"
+#include "seal/util/polyarithsmallmod.h"
+
+using namespace std;
+using namespace seal;
+
+// Sizes the buffer for one RNS polynomial at the key level of `context`: poly_modulus_degree coefficients for
+// each prime of the coefficient modulus. Only the size is set here; the coefficients are supplied via set_data().
+RawPolynomData::RawPolynomData(const SEALContext& context) {
+    // Extract encryption parameters (by reference: the coefficient modulus vector does not need to be copied)
+    const auto& parms = context.key_context_data()->parms();
+    const size_t coeff_modulus_size = parms.coeff_modulus().size();
+    const size_t coeff_count = parms.poly_modulus_degree();
+    _size = util::mul_safe(coeff_count, coeff_modulus_size);
+}
+
+// Takes over `data` as the polynomial coefficients; the caller's vector is moved from.
+// The size must equal size(). This is checked in release builds too, because generate_partial_public_key copies
+// size() words out of data() with only a debug assert guarding it, and initClient fills `a` from Java input.
+// Throws std::invalid_argument on a size mismatch.
+void RawPolynomData::set_data(vector<uint64_t >& data) {
+    if (data.size() != _size) {
+        throw std::invalid_argument("RawPolynomData::set_data: data has the wrong number of coefficients");
+    }
+    _data = std::move(data);
+}
+
+
+// Returns a mutable view of the n-th polynomial of `c` (poly_modulus_degree * coeff_modulus_size coefficients).
+gsl::span<Ciphertext::ct_coeff_type > data_span(Ciphertext& c, size_t n) {
+    const size_t coeff_total = util::mul_safe(c.poly_modulus_degree(), c.coeff_modulus_size());
+    Ciphertext::ct_coeff_type* first = c.data(n);
+    return { first, coeff_total };
+}
+
+// Samples a fresh uniformly random polynomial at the key level of `context`.
+// Client::generate_partial_public_key copies this polynomial into c1 of the partial public key.
+RawPolynomData generate_a(const SEALContext& context) {
+    auto ciphertext_prng = UniformRandomGeneratorFactory::DefaultFactory()->create();
+
+    auto &context_data = *context.key_context_data();
+    auto &parms = context_data.parms();
+
+    // Size the buffer from `context` itself. The only RawPolynomData constructor defined in this file takes a
+    // SEALContext, so passing `parms` would implicitly construct a brand-new SEALContext just to read two sizes.
+    RawPolynomData rpd(context);
+    vector<uint64_t > a_poly_data(rpd.size());
+    util::sample_poly_uniform(ciphertext_prng, parms, a_poly_data.data());
+    rpd.set_data(a_poly_data);
+    return rpd;
+}
+
+// CKKS encryption parameters used by every GlobalState: polynomial modulus degree 4096 and a coefficient modulus
+// made of two 54-bit primes.
+EncryptionParameters generateParameters() {
+    const size_t degree = 4096;
+    EncryptionParameters encryption_params(scheme_type::ckks);
+    encryption_params.set_poly_modulus_degree(degree);
+    encryption_params.set_coeff_modulus(CoeffModulus::Create(degree, { 54, 54 }));
+    return encryption_params;
+}
+
+// Number of doubles that fit into one ciphertext. CKKS offers poly_modulus_degree / 2 slots, but each slot holds
+// one complex number, i.e. two doubles, so the usable count equals poly_modulus_degree.
+size_t get_slot_count(const SEALContext& ctx) {
+    const auto& first_parms = ctx.first_context_data()->parms();
+    return first_parms.poly_modulus_degree();
+}
+
+// Returns `count` pseudo-random doubles in [0, 1], each the square root of rand() / RAND_MAX (as in SEAL's CKKS
+// example). Draws from the global rand() state, so the values depend on srand().
+vector<double> random_plaintext_data(size_t count) {
+    vector<double> values(count);
+    generate(values.begin(), values.end(), [] {
+        return sqrt(static_cast<double>(rand()) / RAND_MAX);
+    });
+    return values;
+}
+
+// Builds the CKKS context from generateParameters() and samples a fresh public polynomial `a` from it.
+// NOTE: `a` is initialized from `context`, so `context` must be declared before `a` in GlobalState.
+GlobalState::GlobalState(double _scale)
+    : context(generateParameters()),
+      a(generate_a(context)),
+      scale(_scale) {}
+
+
+// Builds this client's partial public key with the layout of a SEAL public key:
+// (c0, c1) = ([-(a*s + e)]_q, a) in NTT form, where s is `secret_key`, a is the shared polynomial passed in and
+// e is freshly sampled noise. `a` must hold exactly poly_modulus_degree * coeff_modulus_size words.
+PublicKey Client::generate_partial_public_key(const SecretKey &secret_key, const SEALContext &context, RawPolynomData& a)
+{
+    PublicKey public_key;
+    Ciphertext& destination = public_key.data();
+
+    // We use a fresh memory pool with `clear_on_destruction' enabled.
+    MemoryPoolHandle pool = MemoryManager::GetPool(mm_prof_opt::mm_force_new, true);
+
+    auto &context_data = *context.key_context_data();
+    auto &parms = context_data.parms();
+    auto &coeff_modulus = parms.coeff_modulus();
+    size_t coeff_modulus_size = coeff_modulus.size();
+    size_t coeff_count = parms.poly_modulus_degree();
+    auto ntt_tables = context_data.small_ntt_tables();
+    // a public key consists of two polynomials (c0, c1)
+    size_t encrypted_size = 2;
+
+    // Number of uint64 words in one RNS polynomial (coeff_count coefficients per coefficient-modulus prime);
+    // `a` is copied into c1 with exactly this many words.
+    size_t poly_uint64_count = util::mul_safe(coeff_count, coeff_modulus_size);
+
+    destination.resize(context, context.key_parms_id(), encrypted_size);
+    destination.is_ntt_form() = true;
+    destination.scale() = 1.0;
+
+    // Create an instance of a random number generator. It is used to sample a public seed and, below, the
+    // noise e.
+    auto bootstrap_prng = parms.random_generator()->create();
+
+    // Sample a public seed for generating uniform randomness
+    prng_seed_type public_prng_seed;
+    bootstrap_prng->generate(prng_seed_byte_count, reinterpret_cast<seal_byte *>(public_prng_seed.data()));
+
+    // NOTE(review): ciphertext_prng is never used below, because `a` is supplied by the caller instead of being
+    // sampled here. Sampling the seed above still advances bootstrap_prng before the noise is drawn from it.
+    auto ciphertext_prng = UniformRandomGeneratorFactory::DefaultFactory()->create(public_prng_seed);
+
+    // Generate ciphertext: (c[0], c[1]) = ([-(as+e)]_q, a)
+    uint64_t *c0 = destination.data();
+    uint64_t *c1 = destination.data(1);
+
+    // copy a into c1 (the size check is an assert, i.e. only active in debug builds)
+    assert(a.size() == poly_uint64_count);
+    copy(a.data(), a.data()+poly_uint64_count, c1);
+
+    // Sample e <-- chi
+    auto noise(util::allocate_poly(coeff_count, coeff_modulus_size, pool));
+    util::SEAL_NOISE_SAMPLER(bootstrap_prng, parms, noise.get());
+
+    // Calculate -(a*s + e) (mod q) and store in c[0], one RNS component (coefficient-modulus prime) at a time
+    for (size_t i = 0; i < coeff_modulus_size; i++)
+    {
+        // c0_i = s_i * a_i (pointwise product in NTT form)
+        util::dyadic_product_coeffmod(
+                secret_key.data().data() + i * coeff_count, c1 + i * coeff_count, coeff_count, coeff_modulus[i],
+                c0 + i * coeff_count);
+
+        // Transform the noise e into NTT representation
+        ntt_negacyclic_harvey(noise.get() + i * coeff_count, ntt_tables[i]);
+
+        // c0_i = -(s_i * a_i + e_i)
+        util::add_poly_coeffmod(
+                noise.get() + i * coeff_count, c0 + i * coeff_count, coeff_count, coeff_modulus[i],
+                c0 + i * coeff_count);
+        util::negate_poly_coeffmod(c0 + i * coeff_count, coeff_count, coeff_modulus[i], c0 + i * coeff_count);
+    }
+
+    public_key.parms_id() = context.key_parms_id();
+    return public_key;
+}
+
+// Creates a client: draws a fresh secret key share and derives the matching partial public key from the shared
+// polynomial `a` of the global state.
+Client::Client(GlobalState global_state)
+    : _gs(move(global_state)),
+      _encoder(_gs.context) {
+    KeyGenerator key_generator(_gs.context);
+    _partial_secret_key = key_generator.secret_key();
+    _partial_public_key = generate_partial_public_key(_partial_secret_key, _gs.context, _gs.a);
+}
+
+// Encodes `plain_data` with CKKS at scale _gs.scale and encrypts it under the public key set via set_public_key().
+// The doubles are packed pairwise as complex<double>, one pair per slot. Odd-length input is padded with a
+// trailing 0.0; previously the last value was silently dropped in release builds, where the assert is disabled.
+// Throws std::logic_error if no public key has been set yet.
+Ciphertext Client::encrypted_data(gsl::span<const double> plain_data) {
+    if (!_public_key) {
+        throw std::logic_error("public key has not been set");
+    }
+    if (!_encryptor) {
+        _encryptor = make_unique<Encryptor>(_gs.context, *_public_key);
+    }
+
+    // reinterpret plain data as complex<double> (two doubles per slot); pad odd-length input with a single zero
+    vector<double> padded;
+    const double* values = plain_data.data();
+    size_t value_count = plain_data.size();
+    if (value_count % 2 != 0) {
+        padded.assign(plain_data.begin(), plain_data.end());
+        padded.push_back(0.0);
+        values = padded.data();
+        value_count = padded.size();
+    }
+    gsl::span<const complex<double>> complex_plain_data(
+            reinterpret_cast<const complex<double>*>(values), value_count / 2);
+
+    Plaintext plaintext;
+    encoder().encode(complex_plain_data, _gs.scale, plaintext);
+    Ciphertext ciphertext;
+    encryptor().encrypt(plaintext, ciphertext);
+    return ciphertext;
+}
+
+// Computes this client's share of the joint decryption of `encrypted`.
+// The returned plaintext holds c1*s_i + e_i, where:
+//   - c1 is the ciphertext's second polynomial,
+//   - s_i is this client's secret-key share,
+//   - e_i is freshly sampled noise.
+// All terms are in NTT representation at the ciphertext's level; parms_id and scale are copied
+// from `encrypted`. The server adds c0 once and sums all clients' shares (see Server::average).
+// NOTE(review): e_i is drawn with SEAL_NOISE_SAMPLER, i.e. the same distribution as fresh
+// encryption noise; whether that suffices to hide s_i (noise smudging) is not established here.
+Plaintext Client::partial_decryption(const Ciphertext& encrypted) {
+    using namespace seal::util;
+
+    // c = (c0, c1)
+    // dec(c) = c0+c1*s
+    // we need: c0 + c1*sum(s[i])
+    // so we return c1*s[i] + e[i] and add c0 at the server. e[i] is a noise term necessary for security
+
+    // adapted from Decryptor::decrypt
+
+    auto &context_data = *_gs.context.get_context_data(encrypted.parms_id());
+    auto &parms = context_data.parms();
+    auto &coeff_modulus = parms.coeff_modulus();
+    size_t coeff_count = parms.poly_modulus_degree();
+    size_t coeff_modulus_size = coeff_modulus.size();
+    size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_modulus_size);
+
+    Plaintext plaintext;
+    // Since we overwrite destination, we zeroize destination parameters
+    // This is necessary, otherwise resize will throw an exception.
+    plaintext.parms_id() = parms_id_zero;
+
+    // Resize destination to appropriate size (one RNS polynomial at the ciphertext's level)
+    plaintext.resize(rns_poly_uint64_count);
+
+    // Multiply c1 with the secret-key share coefficient-wise, per RNS prime; a dyadic product
+    // corresponds to polynomial multiplication because both operands are in NTT form.
+    RNSIter destination(plaintext.data(), coeff_count);
+    ConstRNSIter secret_key_array(_partial_secret_key.data().data(), coeff_count);
+    ConstRNSIter c1(encrypted.data(1), coeff_count);
+
+    SEAL_ITERATE(
+            iter(c1, secret_key_array, coeff_modulus, destination), coeff_modulus_size, [&](auto I) {
+                // put < c_1 * s > mod q in destination
+                dyadic_product_coeffmod(get<0>(I), get<1>(I), coeff_count, get<2>(I), get<3>(I));
+            });
+
+    // for security we need to introduce noise here
+    // this part is based on rlwe.cpp:encrypt_zero_symmetric()
+    auto prng = parms.random_generator()->create();
+    MemoryPoolHandle pool = MemoryManager::GetPool(mm_prof_opt::mm_force_new, true);
+    auto noise(allocate_poly(coeff_count, coeff_modulus_size, pool));
+    SEAL_NOISE_SAMPLER(prng, parms, noise.get());
+    auto ntt_tables = context_data.small_ntt_tables();
+
+    for (size_t i = 0; i < coeff_modulus_size; i++)
+    {
+        // Transform the noise e into NTT representation
+        ntt_negacyclic_harvey(noise.get() + i * coeff_count, ntt_tables[i]);
+
+        // plaintext += e (mod q_i)
+        add_poly_coeffmod(
+                noise.get() + i * coeff_count, plaintext.data() + i * coeff_count, coeff_count, coeff_modulus[i],
+                plaintext.data() + i * coeff_count);
+    }
+
+    // Set destination parameters as in encrypted
+    plaintext.parms_id() = encrypted.parms_id();
+    plaintext.scale() = encrypted.scale();
+    return plaintext;
+}
+
+// Takes ownership of the global state (SEAL context, common polynomial `a`, encoding scale).
+// The joint public key starts default-constructed until accumulate_partial_public_keys() runs.
+Server::Server(GlobalState global_state)
+    : _gs(std::move(global_state))
+{
+}
+
+// Combines the clients' partial public keys into the joint public key.
+// Only the first polynomial of each key differs between clients. The second one is the shared
+// polynomial `a` (see GlobalState.a), so only first polynomials are summed.
+// Throws std::invalid_argument if no partial keys are given (previously undefined behaviour).
+void Server::accumulate_partial_public_keys(gsl::span<const Ciphertext> partial_pub_keys) {
+    if (partial_pub_keys.empty()) {
+        throw std::invalid_argument("Server::accumulate_partial_public_keys: no partial public keys given");
+    }
+    _public_key.data() = sum_first_polys(context(), partial_pub_keys);
+    assert(is_valid_for(_public_key, context()));
+}
+
+// Homomorphically adds every ciphertext in `data` and returns the encrypted total.
+// SEAL's Evaluator::add_many rejects an empty input vector.
+Ciphertext Server::sum_data(vector<Ciphertext>&& data) const {
+    Ciphertext encrypted_total;
+    Evaluator(context()).add_many(data, encrypted_total);
+    return encrypted_total;
+}
+
+// Finishes the joint decryption of `encrypted_sum` and returns the element-wise average.
+// Each partial decryption holds c1*s[i] + e[i]. Adding the ciphertext's c0 once gives
+// c0 + c1*sum(s[i]) (+ noise), which decodes to the plaintext sum. That sum is then divided by
+// the number of partial decryptions.
+// The result has 2 * slot_count doubles: the decoded complex slots as consecutive (real, imag) pairs.
+// Throws std::invalid_argument if `partial_decryptions` is empty (would otherwise divide by zero).
+vector<double> Server::average(const Ciphertext& encrypted_sum, gsl::span<const Plaintext> partial_decryptions) const {
+    if (partial_decryptions.empty()) {
+        throw std::invalid_argument("Server::average: no partial decryptions given");
+    }
+
+    // Take c0 of the ciphertext as a plaintext.
+    // FIXME: this copies encrypted_sum's first polynomial, which is unnecessary
+    size_t num_coeffs = util::mul_safe(encrypted_sum.poly_modulus_degree(), encrypted_sum.coeff_modulus_size());
+    gsl::span<const Plaintext::pt_coeff_type> es_data(encrypted_sum.data(0), num_coeffs);
+    Plaintext c0(es_data);
+    // Tag c0 with the ciphertext's own level: num_coeffs above is sized for that level,
+    // so first_parms_id() would be wrong for a ciphertext that is not at the first level.
+    c0.parms_id() = encrypted_sum.parms_id();
+    c0.scale() = encrypted_sum.scale();
+
+    sum_first_polys_inplace(_gs.context, c0, partial_decryptions); // c0 + sum(c1*s[i] + e[i])
+
+    // decode sum
+    size_t slot_count = context().first_context_data()->parms().poly_modulus_degree() >> 1;
+    CKKSEncoder encoder(context());
+    vector<double> result(slot_count * 2, 0.0);
+    gsl::span<complex<double>> result_destination(reinterpret_cast<complex<double>*>(result.data()), slot_count);
+    encoder.decode(c0, result_destination);
+
+    // divide by N for average
+    const auto num_parties = static_cast<double>(partial_decryptions.size());
+    for (double& x : result) {
+        x /= num_parties;
+    }
+    return result;
+}
+
diff --git a/src/main/cpp/he/libhe.h b/src/main/cpp/he/libhe.h
new file mode 100644
index 0000000000..25774a80a4
--- /dev/null
+++ b/src/main/cpp/he/libhe.h
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef LIBHE_H
+#define LIBHE_H
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <stdexcept>
+#include <vector>
+
+#include <gsl/span>
+
+#include "seal/seal.h"
+#include "seal/util/common.h"
+#include "seal/util/rlwe.h"
+#include "seal/util/polyarithsmallmod.h"
+
+using namespace std;
+using namespace seal;
+
+// Owns the raw coefficients of one polynomial, constructed for a given SEAL context.
+// Used for the common polynomial `a` (see GlobalState and generate_a).
+class RawPolynomData {
+private:
+    // NOTE(review): member order kept as-is; the layout is presumably shared with the prebuilt libhe binary.
+    std::vector<uint64_t> _data;
+    size_t _size;
+
+public:
+    explicit RawPolynomData(const SEALContext& context);
+
+    // Number of coefficients exposed through data()/data_span().
+    SEAL_NODISCARD const size_t& size() const { return _size; }
+
+    // Mutable pointer to the first stored coefficient.
+    SEAL_NODISCARD uint64_t* data() { return _data.data(); }
+
+    // Mutable view over the first size() coefficients.
+    SEAL_NODISCARD gsl::span<uint64_t> data_span() { return gsl::span<uint64_t>(data(), size()); }
+
+    void set_data(std::vector<uint64_t>& new_data);
+};
+
+// View over ciphertext coefficients.
+// NOTE(review): the exact meaning of `n` is defined in libhe.cpp — confirm.
+gsl::span<Ciphertext::ct_coeff_type> data_span(Ciphertext& ciphertext, size_t n);
+
+// Generates the common polynomial `a` stored in GlobalState (implementation in libhe.cpp).
+RawPolynomData generate_a(const SEALContext& context);
+
+// Builds the encryption parameters used by this library (implementation in libhe.cpp).
+EncryptionParameters generateParameters();
+
+// Presumably the number of CKKS slots for `context` — confirm against libhe.cpp.
+size_t get_slot_count(const SEALContext& context);
+
+// returns a vector filled with random double values between 0 and 1
+std::vector<double> random_plaintext_data(size_t count);
+
+// State held by the server and by every client: SEAL context, common polynomial and encoding scale.
+struct GlobalState {
+    SEALContext context;
+    // Common polynomial: the second component of every partial public key, identical for all parties.
+    RawPolynomData a;
+    // CKKS scale passed to CKKSEncoder::encode when clients encrypt data.
+    double scale;
+
+    explicit GlobalState(double scale_value);
+};
+
+// One party of the multi-party CKKS scheme. A client:
+//   - owns a secret-key share and derives a partial public key from it,
+//   - encrypts its data under the joint public key,
+//   - contributes a partial decryption of aggregated ciphertexts.
+class Client {
+    // NOTE(review): member order and types left unchanged; the layout is presumably shared with
+    // the prebuilt libhe binary.
+    GlobalState _gs;
+    CKKSEncoder _encoder;
+    SecretKey _partial_secret_key;
+    PublicKey _partial_public_key;
+    std::optional<PublicKey> _public_key = std::nullopt;
+    // Created lazily from _public_key by encrypted_data(); discarded whenever the public key changes.
+    std::unique_ptr<Encryptor> _encryptor = nullptr;
+
+    SEAL_NODISCARD static PublicKey generate_partial_public_key(const SecretKey &secret_key, const SEALContext &context, RawPolynomData& a);
+
+public:
+    explicit Client(GlobalState global_state);
+
+    SEAL_NODISCARD const SEALContext& context() const { return _gs.context; }
+    SEAL_NODISCARD const PublicKey& partial_public_key() const { return _partial_public_key; }
+    SEAL_NODISCARD const CKKSEncoder& encoder() const { return _encoder; }
+    SEAL_NODISCARD CKKSEncoder& encoder() { return _encoder; }
+
+    // Only valid once an Encryptor exists (created by encrypted_data()).
+    SEAL_NODISCARD const Encryptor& encryptor() const { assert(_encryptor != nullptr); return *_encryptor; }
+
+    // Joint public key installed via set_public_key().
+    // Throws std::bad_optional_access if it was never set, instead of dereferencing an empty optional.
+    SEAL_NODISCARD const PublicKey& public_key() const { return _public_key.value(); }
+
+    // Installs the joint public key. Any cached Encryptor built for a previous key is dropped so
+    // the next encrypted_data() call encrypts under this key rather than a stale one.
+    void set_public_key(const PublicKey& pk) {
+        _public_key = pk;
+        _encryptor.reset();
+    }
+
+    // Encodes `plain_data` (pairs of doubles read as complex values) and encrypts it under the joint public key.
+    Ciphertext encrypted_data(gsl::span<const double> plain_data);
+
+    // Returns this client's decryption share c1*s_i + e_i for `encrypted`.
+    Plaintext partial_decryption(const Ciphertext& encrypted);
+};
+
+// Adds the first polynomial of `b` to the first polynomial of `a`, in place, modulo each RNS prime.
+// T is Ciphertext or Plaintext. Both operands must be at the same level (same parms_id).
+// Throws std::invalid_argument if the parms_ids differ or are not valid for `context`.
+template<typename T> void sum_first_poly_inplace(const SEALContext& context, T& a, const T& b) {
+    // Same precondition as Evaluator::add_inplace: different levels mean different RNS bases and sizes.
+    if (a.parms_id() != b.parms_id()) {
+        throw std::invalid_argument("sum_first_poly_inplace: operands have different parms_id");
+    }
+    auto context_data_ptr = context.get_context_data(a.parms_id());
+    if (!context_data_ptr) {
+        throw std::invalid_argument("sum_first_poly_inplace: parms_id is not valid for the given context");
+    }
+    auto &parms = context_data_ptr->parms();
+    auto &coeff_modulus = parms.coeff_modulus();
+    size_t coeff_count = parms.poly_modulus_degree();
+    size_t coeff_modulus_size = coeff_modulus.size();
+
+    // by dereferencing we get only the first poly
+    auto summand_iter = *util::ConstPolyIter(b.data(), coeff_count, coeff_modulus_size);
+    auto sum_iter = *util::ConstPolyIter(a.data(), coeff_count, coeff_modulus_size);
+    auto result_iter = *util::PolyIter(a.data(), coeff_count, coeff_modulus_size);
+    // see Evaluator::add_inplace
+    util::add_poly_coeffmod(sum_iter, summand_iter, coeff_modulus_size, coeff_modulus, result_iter);
+}
+
+// Adds the first polynomial of every element of `summands` to `sum`, in place.
+// T is Ciphertext or Plaintext. Returns `sum` by reference, so chaining does not copy a whole ciphertext.
+template<typename T> T& sum_first_polys_inplace(const SEALContext& context, T& sum, gsl::span<const T> summands) {
+    for (const T& summand : summands) {
+        sum_first_poly_inplace(context, sum, summand);
+    }
+    return sum;
+}
+
+// Returns a copy of summands[0] whose first polynomial is replaced by the sum of all first polynomials.
+// All other polynomials are taken unchanged from summands[0]. T is Ciphertext or Plaintext.
+// Throws std::invalid_argument if `summands` is empty.
+template<typename T> T sum_first_polys(const SEALContext& context, gsl::span<const T> summands) {
+    if (summands.empty()) {
+        throw std::invalid_argument("sum_first_polys: summands must not be empty");
+    }
+    T sum = summands[0];
+    sum_first_polys_inplace(context, sum, summands.subspan(1));
+    return sum;
+}
+
+// Aggregating party. It:
+//   - builds the joint public key from the clients' partial keys,
+//   - sums encrypted client data,
+//   - finishes decryption using the clients' partial decryptions.
+class Server {
+private:
+    // NOTE(review): member order kept unchanged (object layout).
+    GlobalState _gs;
+    PublicKey _public_key;
+
+public:
+    explicit Server(GlobalState global_state);
+
+    // Mutable access to the common polynomial `a` held in the global state.
+    SEAL_NODISCARD RawPolynomData& a() { return _gs.a; }
+
+    SEAL_NODISCARD const SEALContext& context() const { return _gs.context; }
+
+    // Joint public key; only meaningful after accumulate_partial_public_keys().
+    SEAL_NODISCARD const PublicKey& public_key() const { return _public_key; }
+
+    // Sets the joint public key from the clients' partial public keys.
+    void accumulate_partial_public_keys(gsl::span<const Ciphertext> partial_pub_keys);
+
+    // Homomorphic sum of all ciphertexts in `data`.
+    Ciphertext sum_data(std::vector<Ciphertext>&& data) const;
+
+    // Combines the encrypted sum with the clients' partial decryptions and returns the decoded average.
+    std::vector<double> average(const Ciphertext& encrypted_sum, gsl::span<const Plaintext> partial_decryptions) const;
+};
+
+#endif //LIBHE_H
diff --git a/src/main/cpp/lib/libhe-Linux-x86_64.so b/src/main/cpp/lib/libhe-Linux-x86_64.so
new file mode 100644
index 0000000000..5d55922788
Binary files /dev/null and b/src/main/cpp/lib/libhe-Linux-x86_64.so differ
diff --git a/src/main/cpp/systemds.cpp b/src/main/cpp/systemds.cpp
index bed1d42f57..86ac053ad1 100644
--- a/src/main/cpp/systemds.cpp
+++ b/src/main/cpp/systemds.cpp
@@ -17,6 +17,12 @@
  * under the License.
  */
 
+#ifdef _WIN32
+#include <winsock.h>
+#else
+#include <arpa/inet.h>
+#endif
+
 #include "common.h"
 #include "libmatrixdnn.h"
 #include "libmatrixmult.h"
@@ -248,4 +254,4 @@ JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_conv2dBackwardF
   RELEASE_INPUT_ARRAY(env, dout, doutPtr, numThreads);
   RELEASE_ARRAY(env, ret, retPtr, numThreads);
   return static_cast<jlong>(nnz);
-}
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index b6ef98abc7..3dfad3413e 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -44,7 +44,7 @@ public class Types
 	 * Data types (tensor, matrix, scalar, frame, object, unknown).
 	 */
 	public enum DataType {
-		TENSOR, MATRIX, SCALAR, FRAME, LIST, UNKNOWN;
+		TENSOR, MATRIX, SCALAR, FRAME, LIST, ENCRYPTED_CIPHER, ENCRYPTED_PLAIN, UNKNOWN;
 		
 		public boolean isMatrix() {
 			return this == MATRIX;
diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index a6bcc2dac4..56275fd3b9 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -320,7 +320,8 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 			Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
 			Statement.PS_VAL_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
 			Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
-			Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED, Statement.PS_NBATCHES, Statement.PS_MODELAVG);
+			Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED, Statement.PS_NBATCHES,
+			Statement.PS_MODELAVG, Statement.PS_HE);
 		checkInvalidParameters(getOpCode(), getVarParams(), valid);
 
 		// check existence and correctness of parameters
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index 995a1e2330..d22a540180 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -76,6 +76,7 @@ public abstract class Statement implements ParseInfo
 	public static final String PS_SEED = "seed";
 	public static final String PS_MODELAVG = "modelAvg";
 	public static final String PS_NBATCHES = "nbatches";
+	public static final String PS_HE = "he";
 	public enum PSModeType {
 		FEDERATED, LOCAL, REMOTE_SPARK
 	}
@@ -124,7 +125,6 @@ public abstract class Statement implements ParseInfo
 	public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname";
 	public static final String PS_FED_MODEL_VARID = "1701-NCC-model_varid";
 
-
 	public abstract boolean controlStatement();
 	
 	public abstract VariableSet variablesRead();
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index e398abd32b..e211bfeb47 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -38,6 +38,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
 import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
+import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient;
 import org.apache.sysds.runtime.data.TensorBlock;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -82,9 +83,11 @@ public class ExecutionContext {
 	//lineage map, cache, prepared dedup blocks
 	protected Lineage _lineage;
 
+	protected SEALClient _seal_client;
+
 	//parfor temporary functions (created by eval)
 	protected Set<String> _fnNames;
-	
+
 	/**
 	 * List of {@link GPUContext}s owned by this {@link ExecutionContext}
 	 */
@@ -152,6 +155,14 @@ public class ExecutionContext {
 		return _tid;
 	}
 
+	/**
+	 * Attaches the SEAL client used for homomorphic encryption in the parameter server.
+	 *
+	 * @param seal_client client to store; no null check is performed
+	 */
+	public void setSealClient(SEALClient seal_client) {
+		this._seal_client = seal_client;
+	}
+
+	/**
+	 * @return the SEAL client stored via {@link #setSealClient(SEALClient)}; may be null
+	 */
+	public SEALClient getSealClient() {
+		return this._seal_client;
+	}
+
 	/**
 	 * Get the i-th GPUContext
 	 * @param index index of the GPUContext
@@ -891,11 +902,11 @@ public class ExecutionContext {
 	private static String getNonExistingVarError(String varname) {
 		return "Variable '" + varname + "' does not exist in the symbol table.";
 	}
-	
+
 	public void addTmpParforFunction(String fname) {
 		_fnNames.add(fname);
 	}
-	
+
 	public Set<String> getTmpParforFunctions() {
 		return _fnNames;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 74e113ba02..1fb1e8b1ec 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -35,6 +35,7 @@ import org.apache.sysds.common.Types;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.paramserv.NetworkTrafficCounter;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.meta.MetaData;
 
@@ -193,6 +194,7 @@ public class FederatedData {
 			@Override
 			protected void initChannel(SocketChannel ch) throws Exception {
 				final ChannelPipeline cp = ch.pipeline();
+				cp.addLast("NetworkTrafficCounter", new NetworkTrafficCounter(FederatedStatistics::logServerTraffic));
 				if(ssl)
 					cp.addLast(createSSLHandler(ch, address));
 				if(timeout > -1)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
index de56a1a52e..77ffb7f847 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
@@ -25,6 +25,7 @@ import java.util.concurrent.Future;
 import org.apache.log4j.Logger;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
 
 public class FederatedLocalData extends FederatedData {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
index d95b02afd2..5907776898 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -73,6 +73,8 @@ public class FederatedStatistics {
 	private static final LongAdder transferredMatrixBytes = new LongAdder();
 	private static final LongAdder transferredFrameBytes = new LongAdder();
 	private static final LongAdder asyncPrefetchCount = new LongAdder();
+	private static final LongAdder bytesSent = new LongAdder();
+	private static final LongAdder bytesReceived = new LongAdder();
 
 	// stats on the federated worker itself
 	private static final LongAdder fedLookupTableGetCount = new LongAdder();
@@ -80,11 +82,25 @@ public class FederatedStatistics {
 	private static final LongAdder fedLookupTableEntryCount = new LongAdder();
 	private static final LongAdder fedReuseReadHitCount = new LongAdder();
 	private static final LongAdder fedReuseReadBytesCount = new LongAdder();
+	private static final LongAdder fedBytesSent = new LongAdder();
+	private static final LongAdder fedBytesReceived = new LongAdder();
+
 	private static final LongAdder fedPutLineageCount = new LongAdder();
 	private static final LongAdder fedPutLineageItems = new LongAdder();
 	private static final LongAdder fedSerializationReuseCount = new LongAdder();
 	private static final LongAdder fedSerializationReuseBytes = new LongAdder();
 
+	/**
+	 * Accumulates network traffic observed on the coordinator (federated client) side.
+	 *
+	 * @param read    number of bytes received
+	 * @param written number of bytes sent
+	 */
+	public static void logServerTraffic(long read, long written) {
+		bytesSent.add(written);
+		bytesReceived.add(read);
+	}
+
+	/**
+	 * Accumulates network traffic observed on a federated worker.
+	 *
+	 * @param read    number of bytes received
+	 * @param written number of bytes sent
+	 */
+	public static void logWorkerTraffic(long read, long written) {
+		fedBytesSent.add(written);
+		fedBytesReceived.add(read);
+	}
+
+
 	public static synchronized void incFederated(RequestType rqt, List<Object> data){
 		switch (rqt) {
 			case READ_VAR:
@@ -164,6 +180,10 @@ public class FederatedStatistics {
 		fedPutLineageItems.reset();
 		fedSerializationReuseCount.reset();
 		fedSerializationReuseBytes.reset();
+		bytesSent.reset();
+		bytesReceived.reset();
+		fedBytesSent.reset();
+		fedBytesReceived.reset();
 	}
 
 	public static String displayFedIOExecStatistics() {
@@ -194,6 +214,19 @@ public class FederatedStatistics {
 		return "";
 	}
 
+	public static String displayNetworkTrafficStatistics() {
+		return "Server I/O bytes (read/written):\t" +
+				bytesReceived.longValue() +
+				"/" +
+				bytesSent.longValue() +
+				"\n" +
+				"Worker I/O bytes (read/written):\t" +
+				fedBytesReceived.longValue() +
+				"/" +
+				fedBytesSent.longValue() +
+				"\n";
+	}
+
 
 	public static void registerFedWorker(String host, int port) {
 		_fedWorkerAddresses.add(new ImmutablePair<>(host, Integer.valueOf(port)));
@@ -232,6 +265,7 @@ public class FederatedStatistics {
 		sb.append(displayLinCacheStats(fedStats.linCacheStats));
 		sb.append(displayMultiTenantStats(fedStats.mtStats));
 		sb.append(displayHeavyHitters(fedStats.heavyHitters, numHeavyHitters));
+		sb.append(displayNetworkTrafficStatistics());
 		return sb.toString();
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index a41f656524..a7c188ef5a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -33,6 +33,7 @@ import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.controlprogram.paramserv.NetworkTrafficCounter;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.lineage.LineageCache;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
@@ -53,6 +54,9 @@ import io.netty.handler.codec.serialization.ObjectEncoder;
 import io.netty.handler.ssl.SslContext;
 import io.netty.handler.ssl.SslContextBuilder;
 import io.netty.handler.ssl.util.SelfSignedCertificate;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import io.netty.handler.codec.serialization.ObjectDecoder;
+import io.netty.handler.codec.serialization.ClassResolvers;
 
 public class FederatedWorker {
 	protected static Logger log = Logger.getLogger(FederatedWorker.class);
@@ -62,6 +66,7 @@ public class FederatedWorker {
 	private final FederatedReadCache _frc;
 	private final FederatedWorkloadAnalyzer _fan;
 	private final boolean _debug;
+	private Timing networkTimer = new Timing();
 
 	public FederatedWorker(int port, boolean debug) {
 		_flt = new FederatedLookupTable();
@@ -183,10 +188,19 @@ public class FederatedWorker {
 				@Override
 				public void initChannel(SocketChannel ch) {
 					final ChannelPipeline cp = ch.pipeline();
-					if(ssl)
+					// Add the TLS handler at most once. `ssl` and the USE_SSL_FEDERATED_COMMUNICATION flag are
+					// OR-ed so that enabling either (or both) never stacks two SslHandlers on one channel.
+					if(ssl || ConfigurationManager.getDMLConfig()
+						.getBooleanValue(DMLConfig.USE_SSL_FEDERATED_COMMUNICATION))
 						cp.addLast(cont2.newHandler(ch.alloc()));
+					// Reports this channel's read/written bytes to FederatedStatistics.logWorkerTraffic.
+					cp.addLast("NetworkTrafficCounter", new NetworkTrafficCounter(FederatedStatistics::logWorkerTraffic));
+					// NOTE(review): these codecs appear to duplicate FederationUtils.decoder() /
+					// FederatedResponseEncoder added below; confirm whether both are needed.
+					cp.addLast("ObjectDecoder",
+						new ObjectDecoder(Integer.MAX_VALUE,
+							ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())));
+					cp.addLast("ObjectEncoder", new ObjectEncoder());
 					cp.addLast(FederationUtils.decoder(), new FederatedResponseEncoder());
-					cp.addLast(new FederatedWorkerHandler(_flt, _frc, _fan));
+					cp.addLast(new FederatedWorkerHandler(_flt, _frc, _fan, networkTimer));
 				}
 			};
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 4c90c74b1b..d0865df120 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -49,6 +49,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.Instruction.IType;
@@ -73,6 +74,7 @@ import org.apache.sysds.runtime.meta.MetaDataFormat;
 import org.apache.sysds.runtime.privacy.DMLPrivacyException;
 import org.apache.sysds.runtime.privacy.PrivacyMonitor;
 import org.apache.sysds.utils.Statistics;
+import org.apache.sysds.utils.stats.ParamServStatistics;
 
 import io.netty.channel.ChannelFuture;
 import io.netty.channel.ChannelFutureListener;
@@ -87,13 +89,15 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 	private static final Log LOG = LogFactory.getLog(FederatedWorkerHandler.class.getName());
 
 	/** The Federated Lookup Table of the current Federated Worker. */
-	private final FederatedLookupTable _flt;
+	private FederatedLookupTable _flt;
 
 	/** Read cache shared by all worker handlers */
-	private final FederatedReadCache _frc;
+	private FederatedReadCache _frc;
+	private Timing _timing = null;
+
 
 	/** Federated workload analyzer */
-	private final FederatedWorkloadAnalyzer _fan;
+	private FederatedWorkloadAnalyzer _fan;
 
 	/**
 	 * Create a Federated Worker Handler.
@@ -111,6 +115,11 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 		_fan = fan;
 	}
 
+	/**
+	 * Create a Federated Worker Handler that additionally tracks time between requests.
+	 *
+	 * @param flt    The Federated Lookup Table of the current Federated Worker.
+	 * @param frc    Read cache shared by all worker handlers.
+	 * @param fan    Federated workload analyzer.
+	 * @param timing Timer that is stopped when a request arrives and restarted after the response
+	 *               is created; stopped intervals go to ParamServStatistics.accFedNetworkTime. May be null.
+	 */
+	public FederatedWorkerHandler(FederatedLookupTable flt, FederatedReadCache frc, FederatedWorkloadAnalyzer fan, Timing timing) {
+		this(flt, frc, fan);
+		this._timing = timing;
+	}
+
 	@Override
 	public void channelRead(ChannelHandlerContext ctx, Object msg) {
 		ctx.writeAndFlush(createResponse(msg, ctx.channel().remoteAddress()))
@@ -122,6 +131,14 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 	}
 
 	private FederatedResponse createResponse(Object msg, SocketAddress remoteAddress) {
+		try {
+			if (_timing != null) {
+				ParamServStatistics.accFedNetworkTime((long) _timing.stop());
+			}
+		} catch (RuntimeException ignored) {
+			// ignore timing if it wasn't started yet
+		}
+
 		String host;
 		if(remoteAddress instanceof InetSocketAddress) {
 			host = ((InetSocketAddress) remoteAddress).getHostString();
@@ -135,7 +152,11 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 			host = FederatedLookupTable.NOHOST;
 		}
 
-		return createResponse(msg, host);
+		FederatedResponse res = createResponse(msg, host);
+		if (_timing != null) {
+			_timing.start();
+		}
+		return res;
 	}
 
 	private FederatedResponse createResponse(Object msg, String remoteHost) {
@@ -162,7 +183,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 
 	private FederatedResponse createResponse(FederatedRequest[] requests, String remoteHost)
 		throws DMLPrivacyException, FederatedWorkerHandlerException, Exception {
-			
+
 		FederatedResponse response = null; // last response
 		boolean containsCLEAR = false;
 		for(int i = 0; i < requests.length; i++) {
@@ -294,7 +315,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 				throw ex;
 			}
 		}
-		
+
 		if(shouldTryAsyncCompress()) // TODO: replace the reused object
 			CompressedMatrixBlockFactory.compressAsync(ec, sId);
 
@@ -405,7 +426,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 			throw new FederatedWorkerHandlerException(
 				"Unsupported object type, has to be of type CacheBlock or ScalarObject");
 
-				
+
 		// set variable and construct empty response
 		ec.setVariable(varName, data);
 
@@ -429,13 +450,12 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 
 	private FederatedResponse getVariable(FederatedRequest request, ExecutionContextMap ecm) {
 		try{
-
 			checkNumParams(request.getNumParams(), 0);
 			ExecutionContext ec = ecm.get(request.getTID());
 			if(!ec.containsVariable(String.valueOf(request.getID())))
 				throw new FederatedWorkerHandlerException(
 					"Variable " + request.getID() + " does not exist at federated worker.");
-	
+
 			// get variable and construct response
 			Data dataObject = ec.getVariable(String.valueOf(request.getID()));
 			dataObject = PrivacyMonitor.handlePrivacy(dataObject);
@@ -467,7 +487,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 		adaptToWorkload(ec, _fan, tid, ins);
 		return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
 	}
-	
+
 	private static ExecutionContext getContextForInstruction(long id, Instruction ins, ExecutionContextMap ecm){
 		final ExecutionContext ec = ecm.get(id);
 		//handle missing spark execution context
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 004e35b571..54d778486a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -34,24 +34,15 @@ import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysds.runtime.controlprogram.ProgramBlock;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.*;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
-import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey;
+import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.functionobjects.Multiply;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.cp.BooleanObject;
-import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.cp.DoubleObject;
-import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.IntObject;
-import org.apache.sysds.runtime.instructions.cp.ListObject;
-import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.instructions.cp.*;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
 import org.apache.sysds.runtime.lineage.LineageItem;
@@ -59,11 +50,13 @@ import org.apache.sysds.runtime.util.ProgramConverter;
 import org.apache.sysds.utils.stats.ParamServStatistics;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.concurrent.Callable;
 import java.util.concurrent.Future;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import static org.apache.sysds.runtime.util.ProgramConverter.*;
 
@@ -83,20 +76,23 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	private final boolean _weighting;
 	private double _weightingFactor = 1;
 	private boolean _cycleStartAt0 = false;
+	private boolean _use_homomorphic_encryption = false;
+	private PublicKey _partial_public_key;
 
 	public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq,
 		PSRuntimeBalancing runtimeBalancing, boolean weighting, int epochs, long batchSize,
-		int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg)
+		int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg, boolean use_homomorphic_encryption)
 	{
 		super(workerID, updFunc, freq, epochs, batchSize, ec, ps, nbatches, modelAvg);
 
+		// assign the HE flag first; the settings below depend on it (read the parameter, not the
+		// field, so the result does not depend on assignment order)
+		_use_homomorphic_encryption = use_homomorphic_encryption;
 		_numBatchesPerEpoch = numBatchesPerGlobalEpoch;
 		_runtimeBalancing = runtimeBalancing;
-		_weighting = weighting;
+		// FIXME: gradient weighting is not supported together with homomorphic encryption,
+		// so weighting is disabled whenever HE is enabled
+		_weighting = weighting && !use_homomorphic_encryption;
 		_numBatchesPerNbatch = nbatches;
 		// generate the ID for the model
 		_modelVarID = FederationUtils.getNextFedDataID();
-		_modelAvg = modelAvg;
+		// HE only supports model averaging, so it is forced on when HE is enabled
+		_modelAvg = use_homomorphic_encryption || modelAvg;
 	}
 
 	/**
@@ -106,6 +102,9 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	 */
 	public void setup(double weightingFactor) {
 		incWorkerNumber();
+		if (_use_homomorphic_encryption) {
+			((HEParamServer)_ps).registerThread(_workerID, this);
+		}
 
 		// prepare features and labels
 		_featuresData = _features.getFedMapping().getFederatedData()[0];
@@ -160,21 +159,43 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			PROG_END);
 
 		// write program and meta data to worker
-		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
-			new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
-				new SetupFederatedWorker(_batchSize, dataSize, _possibleBatchesPerLocalEpoch,
-					programSerialized, _inst.getNamespace(), _inst.getFunctionName(),
-					_ps.getAggInst().getFunctionName(), _ec.getListObject("hyperparams"),
-					_modelVarID, _nbatches, _modelAvg)));
+		Future<FederatedResponse> udfResponse;
+
+		final SetupFederatedWorker udf;
+		if (_use_homomorphic_encryption) {
+			byte[] a = ((HEParamServer)_ps).generateA();
+			// generate pk[i] on each client and return it
+			udf = new SetupHEFederatedWorker(a);
+		} else {
+			udf = new SetupFederatedWorker();
+		}
 
+		udf.setParams(_batchSize, dataSize, _possibleBatchesPerLocalEpoch,
+				programSerialized, _inst.getNamespace(), _inst.getFunctionName(),
+				_ps.getAggInst().getFunctionName(), _ec.getListObject("hyperparams"),
+				_modelVarID, _nbatches, _use_homomorphic_encryption || _modelAvg);
+
+		udfResponse = _featuresData.executeFederatedOperation(
+				new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), udf));
+
+		FederatedResponse response;
 		try {
-			FederatedResponse response = udfResponse.get();
+			response = udfResponse.get();
 			if(!response.isSuccessful())
 				throw new DMLRuntimeException("FederatedLocalPSThread: Setup UDF failed");
+
 		}
 		catch(Exception e) {
 			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Setup UDF" + e.getMessage());
 		}
+		if (_use_homomorphic_encryption) {
+			try {
+				_partial_public_key = (PublicKey) response.getData()[0];
+			}
+			catch (Exception e) {
+				throw new DMLRuntimeException("FederatedLocalPSThread: HE Setup UDF didn't return an object");
+			}
+		}
 	}
 
 	/**
@@ -196,29 +217,33 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Teardown UDF" + e.getMessage());
 		}
 	}
-	
+
 	/**
 	 * Setup UDF executed on the federated worker
 	 */
 	private static class SetupFederatedWorker extends FederatedUDF {
 		private static final long serialVersionUID = -3148991224792675607L;
-		private final long _batchSize;
-		private final long _dataSize;
-		private final int _possibleBatchesPerLocalEpoch;
-		private final String _programString;
-		private final String _namespace;
-		private final String _gradientsFunctionName;
-		private final String _aggregationFunctionName;
-		private final ListObject _hyperParams;
-		private final long _modelVarID;
-		private final boolean _modelAvg;
-		private final int _nbatches;
-
-		protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch,
-			String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName,
-			ListObject hyperParams, long modelVarID, int nbatches, boolean modelAvg)
+		private long _batchSize;
+		private long _dataSize;
+		private int _possibleBatchesPerLocalEpoch;
+		private String _programString;
+		private String _namespace;
+		private String _gradientsFunctionName;
+		private String _aggregationFunctionName;
+		private ListObject _hyperParams;
+		private long _modelVarID;
+		private boolean _modelAvg;
+		private int _nbatches;
+		private boolean _params_set = false;
+
+		protected SetupFederatedWorker()
 		{
 			super(new long[]{});
+		}
+
+		public void setParams(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch,
+						 String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName,
+						 ListObject hyperParams, long modelVarID, int nbatches, boolean modelAvg) {
 			_batchSize = batchSize;
 			_dataSize = dataSize;
 			_possibleBatchesPerLocalEpoch = possibleBatchesPerLocalEpoch;
@@ -230,10 +255,15 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			_modelVarID = modelVarID;
 			_modelAvg = modelAvg;
 			_nbatches = nbatches;
+			_params_set = true;
 		}
 
 		@Override
 		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			if (!_params_set) {
+				return new FederatedResponse(FederatedResponse.ResponseType.ERROR, "params were not set");
+			}
+
 			// parse and set program
 			ec.setProgram(ProgramConverter.parseProgram(_programString, 0));
 
@@ -258,9 +288,59 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 		}
 	}
 
-	/**
+	/**
+	 * Setup UDF used when homomorphic encryption is enabled. Before running the regular
+	 * worker setup (see SetupFederatedWorker), it creates a SEALClient from the constant
+	 * generated by the parameter server and stores it in the worker's execution context.
+	 * On success it returns the client's partial public key instead of an empty response.
+	 */
+	private static class SetupHEFederatedWorker extends SetupFederatedWorker {
+		private static final long serialVersionUID = 9128347291804980123L;
+
+		// constant produced by HEParamServer.generateA(), sent to every worker
+		private final byte[] _partial_pubkey_a;
+
+		protected SetupHEFederatedWorker(byte[] partial_pubkey_a) {
+			// the regular setup parameters are supplied later via setParams
+			super();
+			_partial_pubkey_a = partial_pubkey_a;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			// TODO: set other CKKS parameters
+			NativeHEHelper.initialize();
+
+			SEALClient sc = new SEALClient(_partial_pubkey_a);
+			ec.setSealClient(sc);
+			PublicKey partial_pubkey = sc.generatePartialPublicKey();
+
+			// run the regular setup and propagate its failure unchanged
+			FederatedResponse res = super.execute(ec, data);
+			if (!res.isSuccessful()) {
+				return res;
+			}
+
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, partial_pubkey);
+		}
+	}
+
+	/**
+	 * UDF that installs the given public key in the worker's SEALClient. The client must
+	 * already be present in the execution context (it is created by SetupHEFederatedWorker).
+	 */
+	private static class SetPublicKeyFederatedWorker extends FederatedUDF {
+		private static final long serialVersionUID = -1536502123123318969L;
+		private final PublicKey _public_key;
+
+		protected SetPublicKeyFederatedWorker(PublicKey public_key) {
+			super(new long[]{});
+			_public_key = public_key;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			ec.getSealClient().setPublicKey(_public_key);
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+		}
+
+		@Override
+		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+			return null;
+		}
+	}
+
+	/**
 	 * Teardown UDF executed on the federated worker
 	 */
 	private static class TeardownFederatedWorker extends FederatedUDF {
 		private static final long serialVersionUID = -153650281873318969L;
 
@@ -298,6 +378,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	@Override
 	public Void call() throws Exception {
 		try {
+			Timing tTotal = new Timing(true);
 			switch (_freq) {
 				case BATCH:
 					computeWithBatchUpdates();
@@ -324,6 +405,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	}
 
 	protected void weightAndPushGradients(ListObject gradients) {
+		assert (!(_weighting && _use_homomorphic_encryption)) : "weights and homomorphic encryption are not supported together";
 		// scale gradients - must only include MatrixObjects
 		if(_weighting && _weightingFactor != 1) {
 			Timing tWeighting = DMLScript.STATISTICS ? new Timing(true) : null;
@@ -354,11 +436,17 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 				int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
 				ListObject model = pullModel();
 				ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum);
-				if (_modelAvg)
+
+				Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
+				if (_modelAvg && !_use_homomorphic_encryption)
+					// we can't call the agg fn if we use HE, because it is implemented homomorphically in SEALServer::aggregateCiphertexts
 					model = _ps.updateLocalModel(_ec, gradients, model);
 				else
 					ParamservUtils.cleanupListObject(model);
-				weightAndPushGradients(_modelAvg ? model : gradients);
+				weightAndPushGradients((_modelAvg && !_use_homomorphic_encryption) ? model : gradients);
+				if (tAgg != null) {
+					ParamServStatistics.accFedAggregation((long)tAgg.stop());
+				}
 				ParamservUtils.cleanupListObject(gradients);
 			}
 		}
@@ -377,7 +465,13 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 				currentLocalBatchNumber = currentLocalBatchNumber + _numBatchesPerNbatch;
 				ListObject model = pullModel();
 				ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerNbatch, localStartBatchNum, true);
+
+				Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
 				weightAndPushGradients(gradients);
+				if (tAgg != null) {
+					ParamServStatistics.accFedAggregation((long)tAgg.stop());
+				}
+
 				ParamservUtils.cleanupListObject(model);
 				ParamservUtils.cleanupListObject(gradients);
 			}
@@ -394,7 +488,13 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			// Pull the global parameters from ps
 			ListObject model = pullModel();
 			ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true);
+
+			Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
 			weightAndPushGradients(gradients);
+			if (tAgg != null) {
+				ParamServStatistics.accFedAggregation((long)tAgg.stop());
+			}
+
 			ParamservUtils.cleanupListObject(model);
 			ParamservUtils.cleanupListObject(gradients);
 		}
@@ -431,11 +531,16 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 		}
 
 		// create and execute the udf on the remote worker
+		Object udf;
+		if (_use_homomorphic_encryption) {
+			udf = new HEComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID()},
+					new long[]{_modelVarID}, numBatchesToCompute, localUpdate, localStartBatchNum);
+		} else {
+			udf = new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(),
+					_modelVarID}, numBatchesToCompute, localUpdate, localStartBatchNum);
+		}
 		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
-			new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
-				new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(),
-				_modelVarID}, numBatchesToCompute, localUpdate, localStartBatchNum)
-		));
+				new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), udf));
 
 		try {
 			Object[] responseData = udfResponse.get().getData();
@@ -444,6 +549,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 				long workerComputing = ((DoubleObject) responseData[1]).getLongValue();
 				ParamServStatistics.accFedWorkerComputing(workerComputing);
 				ParamServStatistics.accFedCommunicationTime(total - workerComputing);
+				ParamServStatistics.accFedNetworkTime(total);
 			}
 			return (ListObject) responseData[0];
 		}
@@ -492,12 +598,12 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			ArrayList<DataIdentifier> inputs = func.getInputParams();
 			ArrayList<DataIdentifier> outputs = func.getOutputParams();
 			CPOperand[] boundInputs = inputs.stream()
-				.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
-				.toArray(CPOperand[]::new);
+					.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+					.toArray(CPOperand[]::new);
 			ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
-				.collect(Collectors.toCollection(ArrayList::new));
+					.collect(Collectors.toCollection(ArrayList::new));
 			Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunc,
-				opt, boundInputs, func.getInputParamNames(), outputNames, "gradient function");
+					opt, boundInputs, func.getInputParamNames(), outputNames, "gradient function");
 			DataIdentifier gradientsOutput = outputs.get(0);
 
 			// recreate aggregation instruction and output if needed
@@ -508,12 +614,12 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 				inputs = func.getInputParams();
 				outputs = func.getOutputParams();
 				boundInputs = inputs.stream()
-					.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
-					.toArray(CPOperand[]::new);
+						.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+						.toArray(CPOperand[]::new);
 				outputNames = outputs.stream().map(DataIdentifier::getName)
-					.collect(Collectors.toCollection(ArrayList::new));
+						.collect(Collectors.toCollection(ArrayList::new));
 				aggregationInstruction = new FunctionCallCPInstruction(namespace, aggFunc,
-					opt, boundInputs, func.getInputParamNames(), outputNames, "aggregation function");
+						opt, boundInputs, func.getInputParamNames(), outputNames, "aggregation function");
 				aggregationOutput = outputs.get(0);
 			}
 			ListObject accGradients = null;
@@ -540,7 +646,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 				// accrue the computed gradients - In the single batch case this is just a list copy
 				// is this equivalent for momentum based and AMS prob?
 				accGradients = modelAvg ? null :
-					ParamservUtils.accrueGradients(accGradients, gradients, false);
+						ParamservUtils.accrueGradients(accGradients, gradients, false);
 
 				// update the local model with gradients if needed
 				// FIXME ensure that with modelAvg we always update the model
@@ -564,11 +670,12 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			// model clean up
 			ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
 			// TODO double check cleanup gradients and models
-			
+
 			// stop timing
 			DoubleObject gradientsTime = new DoubleObject(tGradients.stop());
+			ParamServStatistics.accGradientComputeTime(gradientsTime.getLongValue());
 			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
-				new Object[]{modelAvg ? model : accGradients, gradientsTime});
+					new Object[]{modelAvg ? model : accGradients, gradientsTime});
 		}
 
 		@Override
@@ -577,6 +684,102 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 		}
 	}
 
+
+	/**
+	 * Runs federatedComputeGradientsForNBatches and then replaces the returned model with
+	 * its encryption under the worker's SEALClient.
+	 * The variables whose IDs are given in deferredIds are looked up in the execution context
+	 * when the UDF runs, and are passed in front of the regular UDF inputs.
+	 */
+	private static class HEComputeGradientsForNBatches extends federatedComputeGradientsForNBatches {
+		private static final long serialVersionUID = -3535901512348794852L;
+		private final long[] _deferredIds;
+
+		protected HEComputeGradientsForNBatches(long[] deferredIds, long[] inIDs, int numBatchesToCompute, boolean localUpdate, int localStartBatchNum) {
+			super(inIDs, numBatchesToCompute, localUpdate, localStartBatchNum);
+			_deferredIds = deferredIds;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data_without_deferred) {
+			Timing tTotal = new Timing(true);
+
+			// inputs = [deferred variables (looked up by ID)] ++ [regular UDF inputs]
+			final int numDeferred = _deferredIds.length;
+			Data[] inputs = new Data[numDeferred + data_without_deferred.length];
+			for (int pos = 0; pos < numDeferred; pos++) {
+				inputs[pos] = ec.getVariable(String.valueOf(_deferredIds[pos]));
+			}
+			System.arraycopy(data_without_deferred, 0, inputs, numDeferred, data_without_deferred.length);
+
+			FederatedResponse res = super.execute(ec, inputs);
+			if (!res.isSuccessful())
+				return res;
+
+			try {
+				Timing tEncrypt = DMLScript.STATISTICS ? new Timing(true) : null;
+
+				// encrypt every matrix of the plaintext model with SEAL
+				ListObject plainModel = (ListObject) res.getData()[0];
+				ListObject cipherModel = new ListObject(plainModel);
+				final int numMatrices = plainModel.getLength();
+				for (int idx = 0; idx < numMatrices; idx++) {
+					CiphertextMatrix cipher = ec.getSealClient().encrypt((MatrixObject) plainModel.getData(idx));
+					cipherModel.set(idx, cipher);
+				}
+
+				// swap the plaintext model in the response for its encryption
+				res.getData()[0] = cipherModel;
+
+				if (tEncrypt != null)
+					ParamServStatistics.accHEEncryptionTime((long) tEncrypt.stop());
+
+				// report the total time (gradient computation + encryption)
+				DoubleObject elapsed = new DoubleObject(tTotal.stop());
+				res.getData()[1] = elapsed;
+			}
+			catch (Exception e) {
+				return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new Object[] { e });
+			}
+			return res;
+		}
+	}
+
+	/**
+	 * UDF that partially decrypts each of the given aggregated ciphertexts with the
+	 * worker's SEALClient and returns the resulting PlaintextMatrix array (same order).
+	 */
+	private static class HEComputePartialDecryption extends FederatedUDF {
+		private static final long serialVersionUID = -4535098129348794852L;
+		private final CiphertextMatrix[] _encrypted_sum;
+
+		protected HEComputePartialDecryption(CiphertextMatrix[] encrypted_sum) {
+			super(new long[]{});
+			_encrypted_sum = encrypted_sum;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			Timing tPartialDecrypt = DMLScript.STATISTICS ? new Timing(true) : null;
+			final int count = _encrypted_sum.length;
+			PlaintextMatrix[] partials = new PlaintextMatrix[count];
+			for (int k = 0; k < count; k++) {
+				partials[k] = ec.getSealClient().partiallyDecrypt(_encrypted_sum[k]);
+			}
+			if (tPartialDecrypt != null)
+				ParamServStatistics.accHEPartialDecryptionTime((long) tPartialDecrypt.stop());
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, partials);
+		}
+
+		@Override
+		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+			return null;
+		}
+	}
+
+
+	/**
+	 * Asks this thread's federated worker to partially decrypt the aggregated ciphertexts.
+	 *
+	 * @param encrypted_sum homomorphically aggregated ciphertexts, one per model matrix
+	 * @return the worker's partial decryptions, in the same order as encrypted_sum
+	 * @throws DMLRuntimeException if the UDF could not be executed or reported a failure
+	 */
+	public PlaintextMatrix[] getPartialDecryption(CiphertextMatrix[] encrypted_sum) {
+		Object udf = new HEComputePartialDecryption(encrypted_sum);
+		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+				new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), udf));
+
+		FederatedResponse response;
+		try {
+			response = udfResponse.get();
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF: " + e.getMessage());
+		}
+		// check success explicitly; otherwise an error payload would surface as a ClassCastException
+		if(!response.isSuccessful())
+			throw new DMLRuntimeException("FederatedLocalPSThread: partial decryption UDF failed");
+		try {
+			return (PlaintextMatrix[]) response.getData();
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF: " + e.getMessage());
+		}
+	}
+
 	// Statistics methods
 	protected void accFedPSGradientWeightingTime(Timing time) {
 		if (DMLScript.STATISTICS && time != null)
@@ -608,4 +811,24 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	protected void accGradientComputeTime(Timing time) {
 		throw new NotImplementedException();
 	}
+
+	/**
+	 * @return the partial public key returned by this thread's worker during setup;
+	 *         only assigned when homomorphic encryption is enabled, otherwise null
+	 */
+	public PublicKey getPartialPublicKey() {
+		final PublicKey partialKey = _partial_public_key;
+		return partialKey;
+	}
+
+	/**
+	 * Sends the given public key to this thread's federated worker, which installs it in
+	 * its SEALClient (via SetPublicKeyFederatedWorker).
+	 *
+	 * @param public_key the public key to install on the worker
+	 * @throws DMLRuntimeException if the UDF could not be executed or reported a failure
+	 */
+	public void setPublicKey(PublicKey public_key) {
+		Future<FederatedResponse> res = _featuresData.executeFederatedOperation(
+				new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
+						new SetPublicKeyFederatedWorker(public_key)));
+
+		FederatedResponse response;
+		try {
+			response = res.get();
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Public Key Setup UDF: " + e.getMessage());
+		}
+		// checked outside the try so this exception is not caught and re-wrapped above
+		if(!response.isSuccessful())
+			throw new DMLRuntimeException("FederatedLocalPSThread: SetPublicKey UDF failed");
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
new file mode 100644
index 0000000000..577bf6c820
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey;
+import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALServer;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix;
+import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix;
+import org.apache.sysds.utils.NativeHelper;
+import org.apache.sysds.utils.stats.ParamServStatistics;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * This class implements Homomorphic Encryption (HE) for LocalParamServer. It only supports modelAvg=true.
+ * <p>
+ * Each call to {@link #push(int, ListObject)} acts as a barrier over all workers. Once every worker has
+ * pushed its encrypted model, the models are summed homomorphically. Each worker's control thread then
+ * requests a partial decryption of that sum. Once all partial decryptions have arrived, the averaged
+ * plaintext model is computed and broadcast.
+ */
+public class HEParamServer extends LocalParamServer {
+    // number of workers that have arrived in the current collectAndDo round
+    private int _thread_counter = 0;
+    // incremented every time a collectAndDo round completes; waiting workers block until it changes
+    private long _generation = 0;
+    // control thread per worker ID, used to request partial decryptions
+    private final List<FederatedPSControlThread> _threads;
+    private final List<Object> _result_buffer; // one per thread
+    // result of the most recently completed collectAndDo round
+    private Object _result;
+    // failure thrown by the round function of the most recent round (null if it succeeded)
+    private Throwable _failure;
+    private final SEALServer _seal_server;
+
+    /**
+     * Loads the native HE library and creates a parameter server with model averaging enabled.
+     */
+    public static HEParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
+                                          Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
+                                          MatrixObject valFeatures, MatrixObject valLabels, int nbatches)
+    {
+        NativeHEHelper.initialize();
+        return new HEParamServer(model, aggFunc, updateType, freq, ec,
+                workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches);
+    }
+
+    private HEParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
+                             Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
+                             MatrixObject valFeatures, MatrixObject valLabels, int nbatches)
+    {
+        // modelAvg is always true: HE only supports model averaging
+        super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, true);
+
+        _seal_server = new SEALServer();
+
+        _threads = Collections.synchronizedList(new ArrayList<>(workerNum));
+        for (int i = 0; i < getNumWorkers(); i++) {
+            _threads.add(null);
+        }
+
+        _result_buffer = new ArrayList<>(workerNum);
+        resetResultBuffer();
+    }
+
+    /**
+     * Registers the control thread of a worker so that push() can request its partial decryption.
+     */
+    public void registerThread(int thread_id, FederatedPSControlThread thread) {
+        _threads.set(thread_id, thread);
+    }
+
+    private synchronized void resetResultBuffer() {
+        _result_buffer.clear();
+        for (int i = 0; i < getNumWorkers(); i++) {
+            _result_buffer.add(null);
+        }
+    }
+
+    /**
+     * @return the constant generated by the SEAL server that is sent to every worker for
+     *         partial public key generation
+     */
+    public byte[] generateA() {
+        return _seal_server.generateA();
+    }
+
+    public PublicKey aggregatePartialPublicKeys(PublicKey[] partial_public_keys) {
+        return _seal_server.aggregatePartialPublicKeys(partial_public_keys);
+    }
+
+    /**
+     * Barrier over all workers. Stores obj as the contribution of workerId and blocks until every
+     * worker has contributed. The last worker to arrive applies f once to the list of contributions,
+     * ordered by worker ID. It then wakes the others, and all workers return the same result.
+     * If f throws, the last worker rethrows the exception. Every waiting worker throws a
+     * RuntimeException wrapping it, so no worker is left blocked.
+     */
+    @SuppressWarnings("unchecked")
+    private synchronized <T,U> U collectAndDo(int workerId, T obj, Function<List<T>, U> f) {
+        _result_buffer.set(workerId, obj);
+        _thread_counter++;
+
+        if (_thread_counter == getNumWorkers()) {
+            List<T> buf = _result_buffer.stream().map(x -> (T)x).collect(Collectors.toList());
+            _failure = null;
+            _result = null;
+            try {
+                _result = f.apply(buf);
+            }
+            catch (RuntimeException | Error e) {
+                _failure = e;
+                throw e;
+            }
+            finally {
+                // always start a new round and release the waiting workers, even if f failed
+                resetResultBuffer();
+                _thread_counter = 0;
+                _generation++;
+                notifyAll();
+            }
+        } else {
+            final long generation = _generation;
+            // wait in a loop: Object.wait() may wake up spuriously before the round is complete
+            while (generation == _generation) {
+                try {
+                    wait();
+                } catch (InterruptedException i) {
+                    Thread.currentThread().interrupt();
+                    throw new RuntimeException("thread interrupted", i);
+                }
+            }
+            if (_failure != null)
+                throw new RuntimeException("aggregation failed in another worker", _failure);
+        }
+
+        return (U) _result;
+    }
+
+    /**
+     * Sums the encrypted models of all workers matrix by matrix with SEALServer.accumulateCiphertexts.
+     * Entry i of every list is expected to be a CiphertextMatrix.
+     * @return one accumulated ciphertext per model matrix
+     */
+    private CiphertextMatrix[] homomorphicAggregation(List<ListObject> encrypted_models) {
+        Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
+        CiphertextMatrix[] result = new CiphertextMatrix[encrypted_models.get(0).getLength()];
+        IntStream.range(0, encrypted_models.get(0).getLength()).forEach(matrix_idx -> {
+            CiphertextMatrix[] summands = new CiphertextMatrix[encrypted_models.size()];
+            for (int i = 0; i < encrypted_models.size(); i++) {
+                summands[i] = (CiphertextMatrix) encrypted_models.get(i).getData(matrix_idx);
+            }
+            result[matrix_idx] = _seal_server.accumulateCiphertexts(summands);
+        });
+        if (tAgg != null) {
+            ParamServStatistics.accHEAccumulation((long)tAgg.stop());
+        }
+        return result;
+    }
+
+    /**
+     * Combines each encrypted sum with all workers' partial decryptions via SEALServer.average.
+     * The results replace the matrices of a copy of the model returned by getResult(), and that
+     * copy is passed to updateAndBroadcastModel.
+     */
+    private Void homomorphicAverage(CiphertextMatrix[] encrypted_sums, List<PlaintextMatrix[]> partial_decryptions) {
+        Timing tDecrypt = DMLScript.STATISTICS ? new Timing(true) : null;
+
+        MatrixObject[] result = new MatrixObject[partial_decryptions.get(0).length];
+
+        IntStream.range(0, partial_decryptions.get(0).length).forEach(matrix_idx -> {
+            PlaintextMatrix[] partial_plaintexts = new PlaintextMatrix[partial_decryptions.size()];
+            for (int i = 0; i < partial_decryptions.size(); i++) {
+                partial_plaintexts[i] = partial_decryptions.get(i)[matrix_idx];
+            }
+
+            result[matrix_idx] = _seal_server.average(encrypted_sums[matrix_idx], partial_plaintexts);
+        });
+
+        ListObject old_model = getResult();
+        ListObject new_model = new ListObject(old_model);
+        for (int i = 0; i < new_model.getLength(); i++) {
+            new_model.set(i, result[i]);
+        }
+
+        if (tDecrypt != null) {
+            ParamServStatistics.accHEDecryptionTime((long)tDecrypt.stop());
+        }
+
+        updateAndBroadcastModel(new_model, null);
+        return null;
+    }
+
+    // this is only to be used in push(): started when the homomorphic sum is ready and stopped when all
+    // partial decryptions have arrived (both inside synchronized collectAndDo rounds)
+    private Timing commTimer;
+    private void startCommTimer() {
+        commTimer = new Timing(true);
+    }
+    private long stopCommTimer() {
+        return (long)commTimer.stop();
+    }
+    // ---------------------------------
+
+    /**
+     * Collects the encrypted model of every worker and sums them homomorphically. Each worker's control
+     * thread then obtains a partial decryption of the sum. Finally, the averaged model is computed and
+     * broadcast. Blocks until all workers have called push.
+     */
+    @Override
+    public void push(int workerID, ListObject encrypted_model) {
+        // wait for all updates and sum them homomorphically
+        CiphertextMatrix[] homomorphic_sum = collectAndDo(workerID, encrypted_model, x -> {
+            CiphertextMatrix[] res = this.homomorphicAggregation(x);
+            this.startCommTimer();
+            return res;
+        });
+
+        // get partial decryptions
+        PlaintextMatrix[] partial_decryption = _threads.get(workerID).getPartialDecryption(homomorphic_sum);
+
+        // do average and update global model
+        collectAndDo(workerID, partial_decryption, x -> {
+            ParamServStatistics.accFedNetworkTime(this.stopCommTimer());
+            return this.homomorphicAverage(homomorphic_sum, x);
+        });
+    }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
index 50c76a0f42..9fd49ca0d1 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -39,7 +39,7 @@ public class LocalParamServer extends ParamServer {
 			workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
 	}
 
-	private LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
+	protected LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
 		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
 		MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)
 	{
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NativeHEHelper.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NativeHEHelper.java
new file mode 100644
index 0000000000..38e4dec553
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NativeHEHelper.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv;
+
+import org.apache.commons.lang.SystemUtils;
+import org.apache.sysds.utils.NativeHelper;
+
+public class NativeHEHelper { // JNI entry points into the native libhe library (loaded by initialize())
+    public static boolean initialize() { // returns NativeHelper's load result for libhe (presumably true on success)
+        String platform_suffix = (SystemUtils.IS_OS_WINDOWS ? "-Windows-AMD64.dll" : "-Linux-x86_64.so"); // any non-Windows OS (incl. macOS) gets the Linux x86_64 build
+        String library_name = "libhe" + platform_suffix;
+        return NativeHelper.loadLibraryHelperFromResource(library_name);
+    }
+
+    // ----------------------------------------------------------------------------------------------------------------
+    // SEAL integration (JNI): every method below is implemented natively in libhe
+    // ----------------------------------------------------------------------------------------------------------------
+
+    // these are called by SEALClient
+
+    /**
+     * generates a native Client object
+     * @param a a constant generated by generateA
+     * @return a pointer to the native client object
+     */
+    public static native long initClient(byte[] a);
+
+    /**
+     * generates a partial public key
+     * stores a partial private key corresponding to the partial public key in client
+     * @param client A pointer to a Client, obtained from initClient
+     * @return a serialized partial public key
+     */
+    public static native byte[] generatePartialPublicKey(long client);
+
+    /**
+     * sets the public key and stores it in client
+     * @param client A pointer to a Client, obtained from initClient
+     * @param public_key serialized public key; presumably the aggregate returned by aggregatePartialPublicKeys (TODO confirm)
+     */
+    public static native void setPublicKey(long client, byte[] public_key);
+
+    /**
+     * encrypts data with public key stored in client
+     * setPublicKey() must have been called before calling this
+     * @param client A pointer to a Client, obtained from initClient
+     * @param plaintexts array of double values to be encrypted
+     * @return serialized ciphertext
+     */
+    public static native byte[] encrypt(long client, double[] plaintexts);
+
+    /**
+     * partially decrypts ciphertexts with the partial private key. generatePartialPublicKey() must
+     * have been called before calling this function
+     * @param client A pointer to a Client, obtained from initClient
+     * @param ciphertext serialized ciphertext
+     * @return serialized partial decryption
+     */
+    public static native byte[] partiallyDecrypt(long client, byte[] ciphertext);
+
+    // ----------------------------------------------------------------------------------------------------------------
+
+    // these are called by SEALServer
+
+    /**
+     * generates the Server object and returns a pointer to it
+     * @return pointer to a native Server object
+     */
+    public static native long initServer();
+
+    /**
+     * generates the "a" constant. In a future version we want to generate this together with the clients to prevent misuse
+     * @param server A pointer to a Server, obtained from initServer
+     * @return serialized "a" constant
+     */
+    public static native byte[] generateA(long server);
+
+    /**
+     * accumulates the given partial public keys into a public key, stores it in server and returns it
+     * @param server A pointer to a Server, obtained from initServer
+     * @param partial_public_keys array of serialized partial public keys
+     * @return serialized aggregated public key
+     */
+    public static native byte[] aggregatePartialPublicKeys(long server, byte[][] partial_public_keys);
+
+    /**
+     * accumulates the given ciphertexts into a sum ciphertext and returns it
+     * @param server A pointer to a Server, obtained from initServer
+     * @param ciphertexts array of serialized ciphertexts
+     * @return serialized accumulated ciphertext
+     */
+    public static native byte[] accumulateCiphertexts(long server, byte[][] ciphertexts);
+
+    /**
+     * averages the partial decryptions and returns the result
+     * @param server A pointer to a Server, obtained from initServer
+     * @param encrypted_sum the result of accumulateCiphertexts()
+     * @param partial_plaintexts the partial decryptions of encrypted_sum, one per client (from partiallyDecrypt)
+     * @return element-wise average of the original (unencrypted) data
+     */
+    public static native double[] average(long server, byte[] encrypted_sum, byte[][] partial_plaintexts);
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java
new file mode 100644
index 0000000000..f823b9d3be
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv;
+
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.traffic.ChannelTrafficShapingHandler;
+import java.util.function.BiConsumer;
+
+public class NetworkTrafficCounter extends ChannelTrafficShapingHandler {
+    private final BiConsumer<Long, Long> _fn; // receives (bytesRead, bytesWritten) when the channel becomes inactive
+    public NetworkTrafficCounter(BiConsumer<Long, Long> fn) {
+        super(0); // a checkInterval of 0 disables periodic accounting: doAccounting is never invoked
+        _fn = fn;
+    }
+    /** reports the channel's cumulative traffic once it is closed, resets the counters, then delegates */
+    @Override
+    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+        final long bytesRead = trafficCounter.cumulativeReadBytes();
+        final long bytesWritten = trafficCounter.cumulativeWrittenBytes();
+        _fn.accept(bytesRead, bytesWritten);
+        trafficCounter.resetCumulativeTime();
+        super.channelInactive(ctx);
+    }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 009dc20a33..0e09fabf30 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -19,10 +19,7 @@
 
 package org.apache.sysds.runtime.controlprogram.paramserv;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
+import java.util.*;
 import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.BlockingQueue;
 import java.util.stream.Collectors;
@@ -326,30 +323,8 @@ public abstract class ParamServer
 					_accModels = ParamservUtils.accrueGradients(_accModels, weightParams, true);
 
 					if(allFinished()) {
-						_model = setParams(_ec, _accModels, _model);
-						if (DMLScript.STATISTICS && tAgg != null)
-							ParamServStatistics.accAggregationTime((long) tAgg.stop());
-						_accModels = null; //reset for next accumulation
-
-						// This if has grown to be quite complex its function is rather simple. Validate at the end of each epoch
-						// In the BSP batch case that occurs after the sync counter reaches the number of batches and in the
-						// BSP epoch case every time
-						if(_numBatchesPerEpoch != -1 && (_freq == Statement.PSFrequency.EPOCH || (_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
-
-							if(LOG.isInfoEnabled())
-								LOG.info("[+] PARAMSERV: completed EPOCH " + _epochCounter);
-							time_epoch();
-							if(_validationPossible) {
-								validate();
-							}
-							_epochCounter++;
-							_syncCounter = 0;
-						}
-						// Broadcast the updated model
+						updateAndBroadcastModel(_accModels, tAgg);
 						resetFinishedStates();
-						broadcastModel(true);
-						if(LOG.isDebugEnabled())
-							LOG.debug("Global parameter is broadcasted successfully ");
 					}
 					break;
 				}
@@ -365,7 +340,33 @@ public abstract class ParamServer
 		}
 	}
 
-	protected  ListObject weightModels(ListObject params, int numWorkers) {
+	protected void updateAndBroadcastModel(ListObject new_model, Timing tAgg) { // installs new_model via setParams, records aggregation time, validates at epoch end (if possible), broadcasts
+		_model = setParams(_ec, new_model, _model);
+		if (DMLScript.STATISTICS && tAgg != null)
+			ParamServStatistics.accAggregationTime((long) tAgg.stop());
+		_accModels = null; // reset the accumulator for the next round
+
+		// Validate at the end of each epoch: in the BSP epoch case that is every call, in the BSP batch case only
+		// when the sync counter reaches a multiple of the number of batches per epoch. Note that ++_syncCounter is
+		// evaluated (short-circuit) only when _numBatchesPerEpoch != -1 and _freq is not EPOCH but BATCH.
+		if(_numBatchesPerEpoch != -1 && (_freq == Statement.PSFrequency.EPOCH || (_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+
+			if(LOG.isInfoEnabled())
+				LOG.info("[+] PARAMSERV: completed EPOCH " + _epochCounter);
+			time_epoch();
+			if(_validationPossible) {
+				validate();
+			}
+			_epochCounter++;
+			_syncCounter = 0;
+		}
+		// Broadcast the updated model
+		broadcastModel(true);
+		if(LOG.isDebugEnabled())
+			LOG.debug("Global parameter is broadcasted successfully ");
+	}
+
+	protected ListObject weightModels(ListObject params, int numWorkers) {
 		double _averagingFactor = 1d / numWorkers;
 
 		if( _averagingFactor != 1) {
@@ -472,6 +473,10 @@ public abstract class ParamServer
 			ParamServStatistics.accValidationTime((long) tValidate.stop());
 	}
 
+	public int getNumWorkers() { // read-only accessor for the worker count held in _numWorkers
+		return this._numWorkers;
+	}
+
 	public FunctionCallCPInstruction getAggInst() {
 		return _inst;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index 5bb3e12dca..96979e3a5d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -87,6 +87,7 @@ public abstract class DataPartitionFederatedScheme {
 						new MatrixCharacteristics(range.getSize(0), range.getSize(1)),
 						Types.FileFormat.BINARY)
 				);
+				slice.setPrivacyConstraints(fedMatrix.getPrivacyConstraint());
 
 				// Create new federation map
 				List<Pair<FederatedRange, FederatedData>> newFedHashMap = new ArrayList<>();
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/PublicKey.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/PublicKey.java
new file mode 100644
index 0000000000..96fd415308
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/PublicKey.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption;
+
+import java.io.Serializable;
+
+public class PublicKey implements Serializable {
+    private static final long serialVersionUID = 91289081237980123L;
+    /** raw key bytes as produced by the native library; stored by reference, not copied */
+    private final byte[] _data;
+
+    /** @param data serialized (partial or aggregated) public key bytes, kept by reference */
+    public PublicKey(byte[] data) { this._data = data; }
+
+    /** @return the serialized key bytes (the internal array, not a copy) */
+    public byte[] getData() {
+        return this._data;
+    }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java
new file mode 100644
index 0000000000..935f2808af
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix;
+import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.stream.IntStream;
+
+public class SEALClient { // client side of the HE protocol: key share generation, encryption, partial decryption
+    public SEALClient(byte[] a) {
+        // TODO take params here, like slot_count etc.
+        ctx = NativeHEHelper.initClient(a);
+    }
+
+    // opaque pointer to the native Client object (from NativeHEHelper.initClient), passed to every native call
+    private final long ctx;
+
+    /**
+     * generates a partial public key
+     * stores a partial private key corresponding to the partial public key in ctx
+     *
+     * @return the partial public key
+     */
+    public PublicKey generatePartialPublicKey() {
+        return new PublicKey(NativeHEHelper.generatePartialPublicKey(ctx));
+    }
+
+    /**
+     * sets the public key and stores it in ctx
+     *
+     * @param public_key the public key to set
+     */
+    public void setPublicKey(PublicKey public_key) {
+        NativeHEHelper.setPublicKey(ctx, public_key.getData());
+    }
+
+    /**
+     * encrypts the dense values of a matrix with the public key stored in ctx and returns the ciphertext.
+     * setPublicKey() must have been called before calling this. The acquired input block is never modified.
+     * @param plaintext the MatrixObject to encrypt
+     * @return the encrypted matrix
+     */
+    public CiphertextMatrix encrypt(MatrixObject plaintext) {
+        MatrixBlock mb = plaintext.acquireReadAndRelease();
+        if (mb.isInSparseFormat() || mb.getDenseBlock() == null) {
+            mb = new MatrixBlock(mb); // densify a private copy: the acquired block may be cached and shared
+            mb.allocateSparseRowsBlock(); // guarantees sparseToDense() allocates a dense block, even for all-zero input
+            mb.sparseToDense();
+        }
+        DenseBlock db = mb.getDenseBlock();
+        int[] dims = IntStream.range(0, db.numDims()).map(db::getDim).toArray();
+        double[] raw_data = mb.getDenseBlockValues();
+        return new CiphertextMatrix(dims, plaintext.getDataCharacteristics(), NativeHEHelper.encrypt(ctx, raw_data));
+    }
+
+    /**
+     * partially decrypts ciphertext with the partial private key. generatePartialPublicKey() must
+     * have been called before calling this function
+     *
+     * @param ciphertext the ciphertext to partially decrypt
+     * @return the partial decryption of ciphertext
+     */
+    public PlaintextMatrix partiallyDecrypt(CiphertextMatrix ciphertext) {
+        return new PlaintextMatrix(ciphertext.getDims(), ciphertext.getDataCharacteristics(), NativeHEHelper.partiallyDecrypt(ctx, ciphertext.getData()));
+    }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALServer.java
new file mode 100644
index 0000000000..d6265c7f6d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALServer.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.DenseBlockFactory;
+import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix;
+import org.apache.sysds.runtime.instructions.cp.Encrypted;
+import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+
+import java.util.Arrays;
+
+public class SEALServer {
+    public SEALServer() {
+        // TODO take params here, like slot_count etc.
+        ctx = NativeHEHelper.initServer();
+    }
+
+    // opaque pointer to the native Server object (from NativeHEHelper.initServer), passed to every native call
+    private final long ctx;
+    // serialized "a" constant; generated lazily by generateA() and cached so every caller gets the same value
+    private byte[] _a;
+
+    /**
+     * generates the "a" constant on the first call and returns the cached value afterwards.
+     * In a future version we want to generate this together with the clients to prevent misuse.
+     * @return serialized a constant
+     */
+    public synchronized byte[] generateA() {
+        if (_a == null) {
+            _a = NativeHEHelper.generateA(ctx);
+        }
+        return _a;
+    }
+
+    /**
+     * accumulates the given partial public keys into a public key, stores it in ctx and returns it
+     * @param partial_public_keys an array of partial public keys generated with SEALClient::generatePartialPublicKey
+     * @return the aggregated public key
+     */
+    public PublicKey aggregatePartialPublicKeys(PublicKey[] partial_public_keys) {
+        return new PublicKey(NativeHEHelper.aggregatePartialPublicKeys(ctx, extractRawData(partial_public_keys)));
+    }
+
+    /**
+     * accumulates the given ciphertext blocks into a sum ciphertext and returns it
+     * @param ciphertexts ciphertexts encrypted with the public key set on the clients; must not be empty
+     * @return the accumulated ciphertext (which is the homomorphic sum of ciphertexts)
+     * @throws IllegalArgumentException if ciphertexts is null or empty
+     */
+    public CiphertextMatrix accumulateCiphertexts(CiphertextMatrix[] ciphertexts) {
+        if (ciphertexts == null || ciphertexts.length == 0)
+            throw new IllegalArgumentException("accumulateCiphertexts requires at least one ciphertext");
+        // the sum reuses the shape metadata of the first input; all inputs are assumed to share it
+        CiphertextMatrix first = ciphertexts[0];
+        return new CiphertextMatrix(first.getDims(), first.getDataCharacteristics(), NativeHEHelper.accumulateCiphertexts(ctx, extractRawData(ciphertexts)));
+    }
+
+    /**
+     * averages the partial decryptions
+     * @param encrypted_sum is the result of accumulateCiphertexts()
+     * @param partial_plaintexts the partial decryptions of encrypted_sum, one per client (SEALClient::partiallyDecrypt)
+     * @return the unencrypted, element-wise average of the original matrices
+     */
+    public MatrixObject average(CiphertextMatrix encrypted_sum, PlaintextMatrix[] partial_plaintexts) {
+        double[] raw_result = NativeHEHelper.average(ctx, encrypted_sum.getData(), extractRawData(partial_plaintexts));
+        int[] dims = encrypted_sum.getDims();
+        int result_len = Arrays.stream(dims).reduce(1, (x,y) -> x*y);
+        DataCharacteristics dc = encrypted_sum.getDataCharacteristics();
+
+        // keep only the first result_len values; the native result may be longer (presumably slot padding)
+        DenseBlock new_dense_block = DenseBlockFactory.createDenseBlock(Arrays.copyOf(raw_result, result_len), dims);
+        MatrixBlock new_matrix_block = new MatrixBlock((int)dc.getRows(), (int)dc.getCols(), new_dense_block);
+        MatrixObject new_mo = new MatrixObject(Types.ValueType.FP64, OptimizerUtils.getUniqueTempFileName(), new MetaDataFormat(dc, Types.FileFormat.BINARY));
+        new_mo.acquireModify(new_matrix_block);
+        new_mo.release();
+        return new_mo;
+    }
+
+    // TODO: extract a common interface for Encrypted and PublicKey and merge these two helpers
+    private static byte[][] extractRawData(Encrypted[] data) {
+        return Arrays.stream(data).map(Encrypted::getData).toArray(byte[][]::new);
+    }
+
+    private static byte[][] extractRawData(PublicKey[] data) {
+        return Arrays.stream(data).map(PublicKey::getData).toArray(byte[][]::new);
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CiphertextMatrix.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CiphertextMatrix.java
new file mode 100644
index 0000000000..1cbef9d488
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CiphertextMatrix.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+/**
+ * This class abstracts over an encrypted matrix of ciphertexts. It stores the data as opaque byte array. The layout is unspecified.
+ */
+public class CiphertextMatrix extends Encrypted {
+    private static final long serialVersionUID = 1762936872261940616L;
+    /** wraps serialized ciphertext bytes together with the dims/characteristics of the encrypted matrix */
+    public CiphertextMatrix(int[] dims, DataCharacteristics dc, byte[] data) {
+        super(dims, dc, data, Types.DataType.ENCRYPTED_CIPHER);
+    }
+    /** @return a label built from the (identity) hash code of the ciphertext byte array */
+    @Override
+    public String getDebugName() {
+        return "CiphertextMatrix " + Integer.toString(this.getData().hashCode());
+    }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/Encrypted.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/Encrypted.java
new file mode 100644
index 0000000000..eb7d1ea44a
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/Encrypted.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+/**
+ * This class abstracts over an encrypted data. It stores the data as opaque byte array. The layout is unspecified.
+ */
+public abstract class Encrypted extends Data {
+    private static final long serialVersionUID = 1762936872268046168L;
+
+    private final int[] _dims;             // dimensions of the underlying matrix
+    private final DataCharacteristics _dc; // characteristics (rows, cols, ...) of that matrix
+    private final byte[] _data;            // opaque serialized payload; layout is unspecified
+
+    /** stores all arguments by reference (no defensive copies); the value type is always UNKNOWN */
+    public Encrypted(int[] dims, DataCharacteristics dc, byte[] data, Types.DataType dt) {
+        super(dt, Types.ValueType.UNKNOWN);
+        this._data = data;
+        this._dc = dc;
+        this._dims = dims;
+    }
+
+    /** @return the matrix dimensions (internal array, not a copy) */
+    public int[] getDims() { return this._dims; }
+
+    /** @return the data characteristics of the underlying matrix */
+    public DataCharacteristics getDataCharacteristics() { return this._dc; }
+
+    /** @return the opaque serialized bytes (internal array, not a copy) */
+    public byte[] getData() {
+        return this._data;
+    }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index 38288178e4..5c302fe80a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -397,6 +397,18 @@ public class ListObject extends Data implements Externalizable {
 					ScalarObject so = (ScalarObject) d;
 					out.writeObject(so.getStringValue());
 					break;
+				case ENCRYPTED_CIPHER:
+				case ENCRYPTED_PLAIN:
+					Encrypted e = (Encrypted) d;
+					int[] dims = e.getDims();
+					dc = e.getDataCharacteristics();
+					out.writeObject(dims);
+					out.writeObject(dc.getRows());
+					out.writeObject(dc.getCols());
+					out.writeObject(dc.getBlocksize());
+					out.writeObject(dc.getNonZeros());
+					out.writeObject(e.getData());
+					break;
 				default:
 					throw new DMLRuntimeException("Unable to serialize datatype " + dataType);
 			}
@@ -463,6 +475,21 @@ public class ListObject extends Data implements Externalizable {
 					}
 					d = so;
 					break;
+				case ENCRYPTED_CIPHER:
+				case ENCRYPTED_PLAIN:
+					int[] dims = (int[]) in.readObject();
+					rows = (long) in.readObject();
+					cols = (long) in.readObject();
+					blockSize = (int) in.readObject();
+					nonZeros = (long) in.readObject();
+					byte[] data = (byte[])in.readObject();
+					DataCharacteristics dc = new MatrixCharacteristics(rows, cols, blockSize, nonZeros);
+					if (dataType == DataType.ENCRYPTED_CIPHER) {
+						d = new CiphertextMatrix(dims, dc, data);
+					} else {
+						d = new PlaintextMatrix(dims, dc, data);
+					}
+					break;
 				default:
 					throw new DMLRuntimeException("Unable to deserialize datatype " + dataType);
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 25353f6003..d16aa9ec4e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -19,28 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN;
-import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE;
-import static org.apache.sysds.parser.Statement.PS_EPOCHS;
-import static org.apache.sysds.parser.Statement.PS_FEATURES;
-import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
-import static org.apache.sysds.parser.Statement.PS_FED_WEIGHTING;
-import static org.apache.sysds.parser.Statement.PS_FREQUENCY;
-import static org.apache.sysds.parser.Statement.PS_HYPER_PARAMS;
-import static org.apache.sysds.parser.Statement.PS_LABELS;
-import static org.apache.sysds.parser.Statement.PS_MODE;
-import static org.apache.sysds.parser.Statement.PS_MODEL;
-import static org.apache.sysds.parser.Statement.PS_MODELAVG;
-import static org.apache.sysds.parser.Statement.PS_NBATCHES;
-import static org.apache.sysds.parser.Statement.PS_PARALLELISM;
-import static org.apache.sysds.parser.Statement.PS_SCHEME;
-import static org.apache.sysds.parser.Statement.PS_SEED;
-import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
-import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
-import static org.apache.sysds.parser.Statement.PS_VAL_FEATURES;
-import static org.apache.sysds.parser.Statement.PS_VAL_FUN;
-import static org.apache.sysds.parser.Statement.PS_VAL_LABELS;
-
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
@@ -71,25 +49,22 @@ import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysds.runtime.controlprogram.paramserv.FederatedPSControlThread;
-import org.apache.sysds.runtime.controlprogram.paramserv.LocalPSWorker;
-import org.apache.sysds.runtime.controlprogram.paramserv.LocalParamServer;
-import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer;
-import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody;
-import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSWorker;
-import org.apache.sysds.runtime.controlprogram.paramserv.SparkParamservUtils;
+import org.apache.sysds.runtime.controlprogram.paramserv.*;
 import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
 import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
 import org.apache.sysds.runtime.controlprogram.paramserv.dp.FederatedDataPartitioner;
 import org.apache.sysds.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
+import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey;
 import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.runtime.util.ProgramConverter;
 import org.apache.sysds.utils.stats.ParamServStatistics;
 
+import static org.apache.sysds.parser.Statement.*;
+
 public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction {
 	private static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName());
 
@@ -102,6 +77,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 	private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP;
 	public static final int DEFAULT_NBATCHES = 1;
 	private static final Boolean DEFAULT_MODELAVG = false;
+	private static final Boolean DEFAULT_HE = false;
 
 	public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) {
 		super(op, paramsMap, out, opcode, istr);
@@ -188,23 +164,56 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null) ? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
 		MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
 		boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG));
-		ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum, model, aggServiceEC, getValFunction(),
-			getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics), val_features, val_labels, nbatches, modelAvg);
+
+		// check if we need homomorphic encryption
+		boolean use_homomorphic_encryption_ = getHe();
+		for (int i = 0; i < workerNum; i++) {
+			use_homomorphic_encryption_ = use_homomorphic_encryption_ || checkIsPrivate(result._pFeatures.get(i));
+			use_homomorphic_encryption_ = use_homomorphic_encryption_ || checkIsPrivate(result._pLabels.get(i));
+		}
+		final boolean use_homomorphic_encryption = use_homomorphic_encryption_;
+		if (use_homomorphic_encryption && !modelAvg) {
+			throw new DMLRuntimeException("can't use homomorphic encryption without modelAvg");
+		}
+
+		if (use_homomorphic_encryption && weighting) {
+			throw new DMLRuntimeException("can't use homomorphic encryption with weighting");
+		}
+
+		LocalParamServer ps = (LocalParamServer)createPS(PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum, model, aggServiceEC, getValFunction(),
+			getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics), val_features, val_labels, nbatches, modelAvg, use_homomorphic_encryption);
 		// Create the local workers
 		int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
 		List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
 			.mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighting,
-				getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps, nbatches, modelAvg))
+				getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps, nbatches, modelAvg, use_homomorphic_encryption))
 			.collect(Collectors.toList());
 		if(workerNum != threads.size()) {
 			throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!");
 		}
+
 		// Set features and lables for the control threads and write the program and instructions and hyperparams to the federated workers
 		for (int i = 0; i < threads.size(); i++) {
 			threads.get(i).setFeatures(result._pFeatures.get(i));
 			threads.get(i).setLabels(result._pLabels.get(i));
 			threads.get(i).setup(result._weightingFactors.get(i));
 		}
+
+		if (use_homomorphic_encryption) {
+			// generate public key from partial public keys
+			PublicKey[] partial_public_keys = new PublicKey[threads.size()];
+			for (int i = 0; i < threads.size(); i++) {
+				partial_public_keys[i] = threads.get(i).getPartialPublicKey();
+			}
+
+			// TODO: accumulate public keys with SEAL
+			PublicKey public_key = ((HEParamServer)ps).aggregatePartialPublicKeys(partial_public_keys);
+
+			for (FederatedPSControlThread thread : threads) {
+				thread.setPublicKey(public_key);
+			}
+		}
+
 		if (DMLScript.STATISTICS)
 			ParamServStatistics.accSetupTime((long) tSetup.stop());
 
@@ -479,21 +488,33 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 	 * @return parameter server
 	 */
 	private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType,
 		PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, int nbatches, boolean modelAvg)
 	{
 		return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, null, -1, null, null, nbatches, modelAvg);
 	}
 
+	// Delegates to the full overload with homomorphic encryption disabled.
+	private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType,
+		PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc,
+		int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)
+	{
+		return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg, false);
+	}
+
 	// When this creation is used the parameter server is able to validate after each epoch
 	private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType,
 		PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc,
-		int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)
+		int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg, boolean use_homomorphic_encryption)
 	{
 		switch (mode) {
 			case FEDERATED:
 			case LOCAL:
 			case REMOTE_SPARK:
-				return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
+				if (use_homomorphic_encryption) {
+					return HEParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches);
+				} else {
+					return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
+				}
 			default:
 				throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
 		}
@@ -614,4 +634,26 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		}
 		return Integer.parseInt(getParam(PS_NBATCHES));
 	}
+
+	/**
+	 * Checks whether the given matrix carries a privacy constraint with private elements.
+	 *
+	 * @param obj matrix object to inspect
+	 * @return true if a privacy constraint is attached and it reports private elements
+	 */
+	private boolean checkIsPrivate(MatrixObject obj) {
+		PrivacyConstraint constraint = obj.getPrivacyConstraint();
+		return constraint != null && constraint.hasPrivateElements();
+	}
+
+	/**
+	 * Reads the optional homomorphic-encryption flag of the paramserv call.
+	 *
+	 * @return the PS_HE parameter parsed as boolean, or DEFAULT_HE if the parameter is absent
+	 */
+	private boolean getHe() {
+		return getParameterMap().containsKey(PS_HE)
+			? Boolean.parseBoolean(getParam(PS_HE))
+			: DEFAULT_HE;
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java
new file mode 100644
index 0000000000..6fe2b3814f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+/**
+ * This class abstracts over a matrix of plaintexts (data type ENCRYPTED_PLAIN), i.e. values that are not encrypted. It stores the data as opaque byte array. The layout is unspecified.
+ */
+public class PlaintextMatrix extends Encrypted {
+	private static final long serialVersionUID = 5732436872261940616L;
+
+	public PlaintextMatrix(int[] dims, DataCharacteristics dc, byte[] data) {
+		super(dims, dc, data, Types.DataType.ENCRYPTED_PLAIN);
+	}
+
+	@Override
+	public String getDebugName() {
+		return "PlaintextMatrix " + getData().hashCode();
+	}
+}
diff --git a/src/main/java/org/apache/sysds/utils/NativeHelper.java b/src/main/java/org/apache/sysds/utils/NativeHelper.java
index 83869c23d2..e86bd56b19 100644
--- a/src/main/java/org/apache/sysds/utils/NativeHelper.java
+++ b/src/main/java/org/apache/sysds/utils/NativeHelper.java
@@ -44,14 +44,14 @@ import org.apache.commons.lang.SystemUtils;
  * By default, it first tries to load Intel MKL, else tries to load OpenBLAS.
  */
 public class NativeHelper {
-	
+
 	public enum NativeBlasState {
 		NOT_ATTEMPTED_LOADING_NATIVE_BLAS,
 		SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE,
 		SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE,
 		ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY
 	}
-	
+
 	public static NativeBlasState CURRENT_NATIVE_BLAS_STATE = NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS;
 	private static String blasType;
 
@@ -63,16 +63,16 @@ public class NativeHelper {
 
 	/**
 	 * Called by Statistics to print the loaded BLAS.
-	 * 
+	 *
 	 * @return empty string or the BLAS that is loaded
 	 */
 	public static String getCurrentBLAS() {
 		return blasType != null ? blasType : "";
 	}
-	
+
 	/**
 	 * Called by runtime to check if the BLAS is available for exploitation
-	 * 
+	 *
 	 * @return true if CURRENT_NATIVE_BLAS_STATE is SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE else false
 	 */
 	public static boolean isNativeLibraryLoaded() {
@@ -99,10 +99,10 @@ public class NativeHelper {
 		}
 		return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
 	}
-	
+
 	/**
-	 * Initialize the native library before executing the DML program 
-	 * 
+	 * Initialize the native library before executing the DML program
+	 *
 	 * @param customLibPath specified by sysds.native.blas.directory
 	 * @param userSpecifiedBLAS specified by sysds.native.blas
 	 */
@@ -121,22 +121,22 @@ public class NativeHelper {
 			performLoading(customLibPath, userSpecifiedBLAS);
 		}
 	}
-	
+
 	/**
 	 * Return true if the given BLAS type is supported.
-	 * 
+	 *
 	 * @param userSpecifiedBLAS BLAS type specified via sysds.native.blas property
 	 * @return true if the userSpecifiedBLAS is auto | mkl | openblas, else false
 	 */
 	private static boolean isSupportedBLAS(String userSpecifiedBLAS) {
-		return userSpecifiedBLAS.equalsIgnoreCase("auto") || 
-				userSpecifiedBLAS.equalsIgnoreCase("mkl") || 
+		return userSpecifiedBLAS.equalsIgnoreCase("auto") ||
+				userSpecifiedBLAS.equalsIgnoreCase("mkl") ||
 				userSpecifiedBLAS.equalsIgnoreCase("openblas");
 	}
-	
+
 	/**
 	 * Note: we only support 64 bit Java on x86 and AMD machine
-	 * 
+	 *
 	 * @return true if the hardware architecture is supported
 	 */
 	private static boolean isSupportedArchitecture() {
@@ -166,21 +166,21 @@ public class NativeHelper {
 	 * 		   SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE
 	 */
 	private static boolean isBLASLoaded() {
-		return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE || 
+		return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE ||
 				CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE;
 	}
-	
+
 	/**
 	 * Check if we should attempt to perform loading.
 	 * If custom library path is provided, we should attempt to load again if not already loaded.
-	 * 
-	 * @param customLibPath custom library path 
+	 *
+	 * @param customLibPath custom library path
 	 * @return true if we should attempt to load blas again
 	 */
 	private static boolean shouldReload(String customLibPath) {
 		boolean isValidBLASDirectory = customLibPath != null && !customLibPath.equalsIgnoreCase("none");
 		return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS ||
-			   (isValidBLASDirectory && !isBLASLoaded());
+				(isValidBLASDirectory && !isBLASLoaded());
 	}
 
 	// Performing loading in a method instead of a static block will throw a detailed stack trace in case of fatal errors
@@ -191,13 +191,13 @@ public class NativeHelper {
 		// attemptedLoading variable ensures that we don't try to load SystemDS and other dependencies
 		// again and again especially in the parfor (hence the double-checking with synchronized).
 		if(shouldReload(customLibPath) && isSupportedBLAS(userSpecifiedBLAS) && isSupportedArchitecture()
-			&& isSupportedOS()) {
+				&& isSupportedOS()) {
 			long start = System.nanoTime();
 			synchronized(NativeHelper.class) {
 				if(shouldReload(customLibPath)) {
 					// Set attempted loading unsuccessful in case of exception
 					CURRENT_NATIVE_BLAS_STATE = NativeBlasState.ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY;
-					String [] blas = new String[] { userSpecifiedBLAS }; 
+					String [] blas = new String[] { userSpecifiedBLAS };
 					if(userSpecifiedBLAS.equalsIgnoreCase("auto")) {
 						blas = new String[] { "mkl", "openblas" };
 					}
@@ -206,7 +206,7 @@ public class NativeHelper {
 						String platform_suffix = (SystemUtils.IS_OS_WINDOWS ? "-Windows-AMD64.dll" : "-Linux-x86_64.so");
 						String library_name = "libsystemds_" + blasType + platform_suffix;
 						if(loadLibraryHelperFromResource(library_name) ||
-						   loadBLAS(customLibPath, library_name,"Loading native helper with customLibPath."))
+								loadBLAS(customLibPath, library_name,"Loading native helper with customLibPath."))
 						{
 							LOG.info("Using native blas: " + blasType + getNativeBLASPath());
 							CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
@@ -215,15 +215,15 @@ public class NativeHelper {
 				}
 			}
 			double timeToLoadInMilliseconds = (System.nanoTime()-start)*1e-6;
-			if(timeToLoadInMilliseconds > 1000) 
+			if(timeToLoadInMilliseconds > 1000)
 				LOG.warn("Time to load native blas: " + timeToLoadInMilliseconds + " milliseconds.");
 		}
 		else if(LOG.isDebugEnabled() && !isSupportedBLAS(userSpecifiedBLAS)) {
 			LOG.debug("Using internal Java BLAS as native BLAS support instead of the configuration " +
-				"'sysds.native.blas'=" + userSpecifiedBLAS + ".");
+					"'sysds.native.blas'=" + userSpecifiedBLAS + ".");
 		}
 	}
-	
+
 	private static boolean checkAndLoadBLAS(String customLibPath, String [] listBLAS) {
 		if(customLibPath != null && customLibPath.equalsIgnoreCase("none"))
 			customLibPath = null;
@@ -250,10 +250,10 @@ public class NativeHelper {
 		}
 		return isLoaded;
 	}
-	
+
 	/**
 	 * Useful method for debugging.
-	 * 
+	 *
 	 * @return empty string (if !LOG.isDebugEnabled()) or the path from where openblas or mkl is loaded.
 	 */
 	private static String getNativeBLASPath() {
@@ -287,8 +287,8 @@ public class NativeHelper {
 
 	/**
 	 * Attempts to load native BLAS
-	 * 
-	 * @param customLibPath can be null (if we want to only want to use LD_LIBRARY_PATH), else the 
+	 *
+	 * @param customLibPath can be null (if we want to only want to use LD_LIBRARY_PATH), else the
 	 * @param blas can be gomp, openblas or mkl_rt
 	 * @param optionalMsg message for debugging
 	 * @return true if successfully loaded BLAS
@@ -300,8 +300,8 @@ public class NativeHelper {
 			try {
 				// This fixes libPath if it already contained a prefix/suffix and mapLibraryName added another one.
 				libPath = libPath.replace("liblibsystemds", "libsystemds")
-								 .replace(".dll.dll", ".dll")
-								 .replace(".so.so", ".so");
+						.replace(".dll.dll", ".dll")
+						.replace(".so.so", ".so");
 				System.load(libPath);
 				LOG.info("Loaded the library:" + libPath);
 				return true;
@@ -321,7 +321,7 @@ public class NativeHelper {
 		catch (UnsatisfiedLinkError e) {
 			LOG.debug("java.library.path: " + System.getProperty("java.library.path"));
 			LOG.debug("Unable to load " + blas + (optionalMsg == null ? "" : (" (" + optionalMsg + ")")) +
-				" \n Message from exception was: " + e.getMessage());
+					" \n Message from exception was: " + e.getMessage());
 			return false;
 		}
 	}
@@ -355,13 +355,13 @@ public class NativeHelper {
 	}
 
 	// TODO: Add pmm, wsloss, mmchain, etc.
-	
+
 	//double-precision matrix multiply dense-dense
 	public static native long dmmdd(double [] m1, double [] m2, double [] ret, int m1rlen, int m1clen, int m2clen,
-									   int numThreads);
+									int numThreads);
 	//single-precision matrix multiply dense-dense
 	public static native long smmdd(FloatBuffer m1, FloatBuffer m2, FloatBuffer ret, int m1rlen, int m1clen, int m2clen,
-									   int numThreads);
+									int numThreads);
 	//transpose-self matrix multiply
 	public static native long tsmm(double[] m1, double[] ret, int m1rlen, int m1clen, boolean leftTrans, int numThreads);
 
@@ -374,27 +374,27 @@ public class NativeHelper {
 
 	// Returns -1 if failures or returns number of nonzeros
 	// Called by DnnCPInstruction if both input and filter are dense
-	public static native long conv2dDense(double [] input, double [] filter, double [] ret, int N, int C, int H, int W, 
-			int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);
+	public static native long conv2dDense(double [] input, double [] filter, double [] ret, int N, int C, int H, int W,
+										  int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);
 
 	public static native long dconv2dBiasAddDense(double [] input, double [] bias, double [] filter, double [] ret, int N,
-		int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q,
-												 int numThreads);
+												  int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q,
+												  int numThreads);
 
 	public static native long sconv2dBiasAddDense(FloatBuffer input, FloatBuffer bias, FloatBuffer filter, FloatBuffer ret,
-		int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q,
-												 int numThreads);
+												  int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q,
+												  int numThreads);
 
 	// Called by DnnCPInstruction if both input and filter are dense
 	public static native long conv2dBackwardFilterDense(double [] input, double [] dout, double [] ret, int N, int C,
-													   int H, int W, int K, int R, int S, int stride_h, int stride_w,
-													   int pad_h, int pad_w, int P, int Q, int numThreads);
+														int H, int W, int K, int R, int S, int stride_h, int stride_w,
+														int pad_h, int pad_w, int P, int Q, int numThreads);
 
 	// If both filter and dout are dense, then called by DnnCPInstruction
 	// Else, called by LibMatrixDNN's thread if filter is dense. dout[n] is converted to dense if sparse.
 	public static native long conv2dBackwardDataDense(double [] filter, double [] dout, double [] ret, int N, int C,
-													 int H, int W, int K, int R, int S, int stride_h, int stride_w,
-													 int pad_h, int pad_w, int P, int Q, int numThreads);
+													  int H, int W, int K, int R, int S, int stride_h, int stride_w,
+													  int pad_h, int pad_w, int P, int Q, int numThreads);
 
 	// Currently only supported with numThreads = 1 and sparse input
 	// Called by LibMatrixDNN's thread if input is sparse. dout[n] is converted to dense if sparse.
@@ -415,4 +415,4 @@ public class NativeHelper {
 	// different tradeoffs. In current implementation, we always use GetPrimitiveArrayCritical as it has proven to be
 	// fastest. We can revisit this decision later and hence I would not recommend removing this method.
 	private static native void setMaxNumThreads(int numThreads);
-}
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index 77b63a9921..aece9b655a 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -674,6 +674,8 @@ public class Statistics
 		if(DMLScript.FED_STATISTICS) {
 			sb.append("\n");
 			sb.append(FederatedStatistics.displayStatistics(DMLScript.FED_STATISTICS_COUNT));
+			sb.append("\n");
+			sb.append(ParamServStatistics.displayFloStatistics());
 		}
 
 		return sb.toString();
diff --git a/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java b/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java
index 0d97bfd0c6..8eb26a1963 100644
--- a/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java
@@ -21,6 +21,7 @@ package org.apache.sysds.utils.stats;
 
 import java.util.concurrent.atomic.LongAdder;
 
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 
 public class ParamServStatistics {
@@ -41,6 +42,14 @@ public class ParamServStatistics {
 	private static final LongAdder fedWorkerComputingTime = new LongAdder();
 	private static final LongAdder fedGradientWeightingTime = new LongAdder();
 	private static final LongAdder fedCommunicationTime = new LongAdder();
+	private static final LongAdder fedNetworkTime = new LongAdder(); // measures exactly how long it takes netty to send & receive data
+	private static final LongAdder fedAggregation = new LongAdder(); // aggregation time (ms) of the federated parameter server, reported as "PS fed agg time"
+
+	// Homomorphic encryption specifics (time is in milli sec)
+	private static final LongAdder heEncryption = new LongAdder(); // SEALClient::encrypt
+	private static final LongAdder heAccumulation = new LongAdder(); // SEALServer::accumulateCiphertexts
+	private static final LongAdder hePartialDecryption = new LongAdder(); // SEALClient::partiallyDecrypt
+	private static final LongAdder heDecryption = new LongAdder(); // SEALServer::average
 
 	public static void incWorkerNumber() {
 		numWorkers.increment();
@@ -110,6 +119,14 @@ public class ParamServStatistics {
 		fedWorkerComputingTime.add(t);
 	}
 
+	public static void accFedNetworkTime(long t) {
+		fedNetworkTime.add(t);
+	}
+
+	public static void accFedAggregation(long t) {
+		fedAggregation.add(t);
+	}
+
 	public static void accFedGradientWeightingTime(long t) {
 		fedGradientWeightingTime.add(t);
 	}
@@ -118,6 +135,22 @@ public class ParamServStatistics {
 		fedCommunicationTime.add(t);
 	}
 
+	public static void accHEEncryptionTime(long t) {
+		heEncryption.add(t);
+	}
+
+	public static void accHEAccumulation(long t) {
+		heAccumulation.add(t);
+	}
+
+	public static void accHEPartialDecryptionTime(long t) {
+		hePartialDecryption.add(t);
+	}
+
+	public static void accHEDecryptionTime(long t) {
+		heDecryption.add(t);
+	}
+
 	public static void reset() {
 		executionTime.reset();
 		numWorkers.reset();
@@ -133,6 +166,12 @@ public class ParamServStatistics {
 		fedWorkerComputingTime.reset();
 		fedGradientWeightingTime.reset();
 		fedCommunicationTime.reset();
+		fedNetworkTime.reset();
+		heEncryption.reset();
+		heAccumulation.reset();
+		hePartialDecryption.reset();
+		heDecryption.reset();
+		fedAggregation.reset();
 	}
 
 	public static String displayStatistics() {
@@ -168,4 +207,22 @@ public class ParamServStatistics {
 		sb.append(String.format("PS fed grad. weigh. time (cum):\t%.3f secs.\n", fedGradientWeightingTime.doubleValue() / 1000));
 		return sb.toString();
 	}
+
+	/**
+	 * Formats the federated network and aggregation timings, the gradient compute time,
+	 * and the homomorphic encryption timings. Each counter is divided by 1000 and printed as seconds.
+	 *
+	 * @return multi-line statistics string, one entry per line
+	 */
+	public static String displayFloStatistics() {
+		return new StringBuilder()
+			.append(String.format("PS fed network time (cum):\t%.3f secs.\n", fedNetworkTime.doubleValue() / 1000))
+			.append(String.format("PS fed agg time:\t%.3f secs.\n", fedAggregation.doubleValue() / 1000))
+			.append(String.format("Paramserv grad compute time:\t%.3f secs.\n", gradientComputeTime.doubleValue() / 1000))
+			.append(String.format("HE PS encryption time:\t%.3f secs.\n", heEncryption.doubleValue() / 1000))
+			.append(String.format("HE PS accumulation time:\t%.3f secs.\n", heAccumulation.doubleValue() / 1000))
+			.append(String.format("HE PS partial decryption time:\t%.3f secs.\n", hePartialDecryption.doubleValue() / 1000))
+			.append(String.format("HE PS decryption time:\t%.3f secs.\n", heDecryption.doubleValue() / 1000))
+			.toString();
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 6ebff8eacd..c5f7d1a54b 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -652,6 +652,12 @@ public abstract class AutomatedTestBase {
 	 */
 	protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name,
 		double[][] matrix, int numFederatedWorkers, List<Integer> ports, double[][] ranges)
+	{
+		rowFederateLocallyAndWriteInputMatrixWithMTD(name, matrix, numFederatedWorkers, ports, ranges, null);
+	}
+
+	protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name,
+		double[][] matrix, int numFederatedWorkers, List<Integer> ports, double[][] ranges, PrivacyConstraint privacyConstraint)
 	{
 		// check matrix non empty
 		if(matrix.length == 0 || matrix[0].length == 0)
@@ -677,7 +683,7 @@ public abstract class AutomatedTestBase {
 			// write slice
 			writeInputMatrixWithMTD(path, Arrays.copyOfRange(matrix, (int)lowerBound, (int)upperBound),
 				false, new MatrixCharacteristics((long) examplesForWorkerI, ncol,
-				OptimizerUtils.DEFAULT_BLOCKSIZE, (long) examplesForWorkerI * ncol));
+				OptimizerUtils.DEFAULT_BLOCKSIZE, (long) examplesForWorkerI * ncol), privacyConstraint);
 
 			// generate fedmap entry
 			FederatedRange range = new FederatedRange(new long[]{(long) lowerBound, 0}, new long[]{(long) upperBound, ncol});
@@ -688,7 +694,7 @@ public abstract class AutomatedTestBase {
 		federatedMatrixObject.setFedMapping(new FederationMap(FederationUtils.getNextFedDataID(), fedHashMap));
 		federatedMatrixObject.getFedMapping().setType(FType.ROW);
 
-		writeInputFederatedWithMTD(name, federatedMatrixObject, null);
+		writeInputFederatedWithMTD(name, federatedMatrixObject, privacyConstraint);
 	}
 
 	protected double[][] generateBalancedFederatedRowRanges(int numFederatedWorkers, int dataSetSize) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
new file mode 100644
index 0000000000..250358d408
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.paramserv;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.codegen.SpoofCompiler;
+import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class EncryptedFederatedParamservTest extends AutomatedTestBase {
+	// End-to-end test of federated paramserv on privacy-constrained inputs; the native HE library is initialized via NativeHEHelper.
+	private final static String TEST_DIR = "functions/federated/paramserv/";
+	private final static String TEST_NAME = "EncryptedFederatedParamservTest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + EncryptedFederatedParamservTest.class.getSimpleName() + "/";
+
+	private final String _networkType;
+	private final int _numFederatedWorkers;
+	private final int _dataSetSize;
+	private final int _epochs;
+	private final int _batch_size;
+	private final double _eta;
+	private final String _utype;
+	private final String _freq;
+	private final String _scheme;
+	private final String _runtime_balancing;
+	private final String _weighting;
+	private final String _data_distribution;
+	private final int _seed;
+
+	// parameters
+	@Parameterized.Parameters
+	public static Collection<Object[]> parameters() {
+		return Arrays.asList(new Object[][] {
+				// Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency, scheme, runtime balancing, weighting, data distribution, seed
+				// basic functionality
+				//{"TwoNN",	4, 60000, 32, 4, 0.01, 	"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"NONE" ,		"false","BALANCED",		200},
+
+				// One important point is that we do the model averaging in the case of BSP
+				{"TwoNN",	2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"BASELINE",		"false",	"IMBALANCED",	200},
+				{"CNN", 	2, 4, 1, 1, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "BASELINE",		"false",	"IMBALANCED", 	200},
+				//{"TwoNN", 	5, 1000, 100, 1, 0.01, 	"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"NONE",			"true",	"BALANCED",		200},
+
+				/*
+                    // runtime balancing
+                    {"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_MIN", 	"true",	"IMBALANCED",	200},
+                    {"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_MIN", 	"true",	"IMBALANCED",	200},
+                    {"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_AVG", 	"true",	"IMBALANCED",	200},
+                    {"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_AVG", 	"true",	"IMBALANCED",	200},
+                    {"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_MAX",	"true", "IMBALANCED",	200},
+                    {"TwoNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"CYCLE_MAX",	"true", "IMBALANCED",	200},
+    
+                    // data partitioning
+                    {"TwoNN", 	2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "SHUFFLE", 				"CYCLE_AVG", 	"true",	"IMBALANCED",	200},
+                    {"TwoNN", 	2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "REPLICATE_TO_MAX",	 	"NONE", 		"true",	"IMBALANCED",	200},
+                    {"TwoNN", 	2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "SUBSAMPLE_TO_MIN",		"NONE", 		"true",	"IMBALANCED",	200},
+                    {"TwoNN", 	2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "BALANCE_TO_AVG",		"NONE", 		"true",	"IMBALANCED",	200},
+    
+                    // balanced tests
+                    {"CNN", 	5, 1000, 100, 2, 0.01, 	"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 	"NONE", 		"true",	"BALANCED",		200}
+                */
+		});
+	}
+
+	/**
+	 * Initializes the native HE helper (any failure propagates and aborts the test) and stores one parameter row;
+	 * the arguments follow the column order of {@link #parameters()}.
+	 */
+	public EncryptedFederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size,
+										  int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String weighting, String data_distribution, int seed) {
+		NativeHEHelper.initialize();
+		_networkType = networkType;
+		_numFederatedWorkers = numFederatedWorkers;
+		_dataSetSize = dataSetSize;
+		_batch_size = batch_size;
+		_epochs = epochs;
+		_eta = eta;
+		_utype = utype;
+		_freq = freq;
+		_scheme = scheme;
+		_runtime_balancing = runtime_balancing;
+		_weighting = weighting;
+		_data_distribution = data_distribution;
+		_seed = seed;
+	}
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+	}
+
+	@Test
+	@Ignore
+	public void EncryptedfederatedParamservSingleNode() {
+		EncryptedfederatedParamserv(ExecMode.SINGLE_NODE, true);
+	}
+
+	@Test
+	@Ignore
+	public void EncryptedfederatedParamservHybrid() {
+		EncryptedfederatedParamserv(ExecMode.HYBRID, true);
+	}
+
+	private void EncryptedfederatedParamserv(ExecMode mode, boolean modelAvg) {
+		// Warning Statistics accumulate in unit test
+		// config
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+		setOutputBuffering(true);
+
+		int C = 1, Hin = 28, Win = 28;
+		int numLabels = 10;
+
+		ExecMode platformOld = setExecMode(mode);
+		List<Thread> threads = new ArrayList<>();
+
+		try {
+			// start threads
+			List<Integer> ports = new ArrayList<>();
+			for(int i = 0; i < _numFederatedWorkers; i++) {
+				ports.add(getRandomAvailablePort());
+				threads.add(startLocalFedWorkerThread(ports.get(i),
+						i==(_numFederatedWorkers-1) ? FED_WORKER_WAIT : FED_WORKER_WAIT_S));
+			}
+
+			// generate test data
+			double[][] features = generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win);
+			double[][] labels = generateDummyMNISTLabels(_dataSetSize, numLabels);
+			String featuresName = "";
+			String labelsName = "";
+
+			PrivacyConstraint privacyConstraint = new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.Private);
+
+			// federate test data balanced or imbalanced
+			if(_data_distribution.equals("IMBALANCED")) {
+				featuresName = "X_IMBALANCED_" + _numFederatedWorkers;
+				labelsName = "y_IMBALANCED_" + _numFederatedWorkers;
+				// hard-coded split: rows [0,1) on worker 0, rows [1,4) on worker 1 (assumes 2 workers and 4 rows)
+				double[][] ranges = {{0,1}, {1,4}};
+				rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges, privacyConstraint);
+				rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges, privacyConstraint);
+			}
+			else {
+				featuresName = "X_BALANCED_" + _numFederatedWorkers;
+				labelsName = "y_BALANCED_" + _numFederatedWorkers;
+				double[][] ranges = generateBalancedFederatedRowRanges(_numFederatedWorkers, features.length);
+				rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges, privacyConstraint);
+				rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges, privacyConstraint);
+			}
+
+			try {
+				//wait for all workers to be setup
+				Thread.sleep(FED_WORKER_WAIT);
+			}
+			catch(InterruptedException e) {
+				e.printStackTrace();
+			}
+
+			// dml name
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			// generate program args
+			List<String> programArgsList = new ArrayList<>(Arrays.asList("-stats",
+					"-nvargs",
+					"features=" + input(featuresName),
+					"labels=" + input(labelsName),
+					"epochs=" + _epochs,
+					"batch_size=" + _batch_size,
+					"eta=" + _eta,
+					"utype=" + _utype,
+					"freq=" + _freq,
+					"scheme=" + _scheme,
+					"runtime_balancing=" + _runtime_balancing,
+					"weighting=" + _weighting,
+					"network_type=" + _networkType,
+					"channels=" + C,
+					"hin=" + Hin,
+					"win=" + Win,
+					"seed=" + _seed,
+					"modelAvg=" +  Boolean.toString(modelAvg).toUpperCase()));
+
+			programArgs = programArgsList.toArray(new String[0]);
+			String log = runTest(null).toString();
+			Assert.assertEquals("Test Failed \n" + log, 0, Statistics.getNoOfExecutedSPInst());
+		}
+		finally {
+			// shut down the workers even if the test failed, so no threads leak into later tests
+			for(Thread t : threads) {
+				TestUtils.shutdownThreads(t);
+			}
+			resetExecMode(platformOld);
+		}
+	}
+
+	/**
+	 * Generates a feature matrix with the same layout as the MNIST dataset (numExamples x C*Hin*Win),
+	 * filled with random values in [0, 1] instead of real images.
+	 *
+	 *  @param numExamples Number of examples to generate
+	 *  @param C Channels in the input data
+	 *  @param Hin Height in Pixels of the input data
+	 *  @param Win Width in Pixels of the input data
+	 *  @return a dummy MNIST feature matrix
+	 */
+	private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) {
+		// Seed -1 takes the time in milliseconds as a seed
+		// Sparsity 1 means no sparsity
+		return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
+	}
+
+	/**
+	 * Generates a label matrix with the same shape as the MNIST labels (numExamples x numLabels).
+	 * NOTE: the rows are random values in [0, 1], not one-hot encoded vectors.
+	 *
+	 *  @param numExamples Number of examples to generate
+	 *  @param numLabels Number of labels to generate
+	 *  @return a dummy MNIST label matrix
+	 */
+	private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) {
+		// Seed -1 takes the time in milliseconds as a seed
+		// Sparsity 1 means no sparsity
+		return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/homomorphicEncryption/InOutTest.java b/src/test/java/org/apache/sysds/test/functions/homomorphicEncryption/InOutTest.java
new file mode 100644
index 0000000000..5bb952d9ed
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/homomorphicEncryption/InOutTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.homomorphicEncryption;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper;
+import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey;
+import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient;
+import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALServer;
+import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix;
+import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+
+/**
+ * End-to-end check that SEAL-based encrypt / accumulate / partially decrypt / average reproduces the plaintext mean.
+ */
+public class InOutTest extends AutomatedTestBase {
+    private final static String TEST_NAME = "InOutTest";
+    private final static String TEST_DIR = "functions/data/";
+    private final static String TEST_CLASS_DIR = TEST_DIR + InOutTest.class.getSimpleName() + "/";
+
+    private final int num_clients = 3;
+
+    private final int rows = 100;
+    private final int cols = 200;
+    private final long seed = 42;
+
+    @Override
+    public void setUp() {
+        NativeHEHelper.initialize();
+
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "C" }) );
+    }
+
+    @Test
+    @Ignore
+    public void endToEndTest() {
+        SEALServer server = new SEALServer();
+
+        SEALClient[] clients = new SEALClient[num_clients];
+        PublicKey[] partial_pub_keys = new PublicKey[num_clients];
+        for (int i = 0; i < num_clients; i++) {
+            clients[i] = new SEALClient(server.generateA());
+            partial_pub_keys[i] = clients[i].generatePartialPublicKey();
+        }
+
+        PublicKey public_key = server.aggregatePartialPublicKeys(partial_pub_keys);
+
+        MatrixObject[] plaintexts = new MatrixObject[num_clients];
+        CiphertextMatrix[] ciphertexts = new CiphertextMatrix[num_clients];
+        for (int i = 0; i < num_clients; i++) {
+            MatrixBlock mb = TestUtils.generateTestMatrixBlock(rows, cols, -100, 100, 1.0, seed+i);
+            MatrixObject mo = new MatrixObject(Types.ValueType.FP64, null);
+            mo.setMetaData(new MetaDataFormat(new MatrixCharacteristics(rows, cols), Types.FileFormat.BINARY));
+            mo.acquireModify(mb);
+            mo.release();
+            plaintexts[i] = mo;
+
+            clients[i].setPublicKey(public_key);
+            ciphertexts[i] = clients[i].encrypt(plaintexts[i]);
+        }
+
+        CiphertextMatrix encrypted_sum = server.accumulateCiphertexts(ciphertexts);
+
+        PlaintextMatrix[] partial_decryptions = new PlaintextMatrix[num_clients];
+        for (int i = 0; i < num_clients; i++) {
+            partial_decryptions[i] = clients[i].partiallyDecrypt(encrypted_sum);
+        }
+
+        MatrixObject result = server.average(encrypted_sum, partial_decryptions);
+
+        double[] expected_raw_result = new double[rows*cols];
+        double[][] plaintexts_raw = new double[num_clients][];
+        for (int i = 0; i < num_clients; i++) {
+            plaintexts_raw[i] = plaintexts[i].acquireReadAndRelease().getDenseBlockValues();
+        }
+        for (int x = 0; x < rows * cols; x++) {
+            double sum = 0.0;
+            for (int i = 0; i < num_clients; i++) {
+                sum += plaintexts_raw[i][x];
+            }
+            expected_raw_result[x] = sum / num_clients;
+        }
+
+        double[] raw_result = result.acquireReadAndRelease().getDenseBlockValues();
+        // JUnit assertions: plain Java 'assert' statements are skipped unless the JVM runs with -ea
+        org.junit.Assert.assertEquals(rows, result.getNumRows());
+        org.junit.Assert.assertEquals(cols, result.getNumColumns());
+        org.junit.Assert.assertEquals(rows * cols, raw_result.length);
+        TestUtils.compareMatrices(expected_raw_result, raw_result, 5e-8);
+    }
+}
diff --git a/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml
new file mode 100644
index 0000000000..b8021867dc
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml
@@ -0,0 +1,61 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN
+source("src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml") as TwoNNModelAvg
+source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
+source("src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml") as CNNModelAvg
+
+# read the (federated) training inputs; X_dummy/y_dummy are all-zero 100-row stand-ins passed to train_paramserv and validate
+features = read($features)
+labels = read($labels)
+X_dummy = matrix(0, rows=100, cols=$channels*$hin*$win)
+y_dummy = matrix(0, rows=100, cols=10)
+if($network_type == "TwoNN") {
+  if(!as.logical($modelAvg)) {
+    model = TwoNN::train_paramserv(features, labels, X_dummy, y_dummy, 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed)
+    print("Test results:")
+    [loss_test, accuracy_test] = TwoNN::validate(X_dummy, y_dummy, model, list())
+    print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
+  }
+  else {
+    model = TwoNNModelAvg::train_paramserv(features, labels, X_dummy, y_dummy, 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed, $modelAvg)
+    print("Test results:")
+    [loss_test, accuracy_test] = TwoNNModelAvg::validate(X_dummy, y_dummy, model, list())
+    print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
+  }
+}
+else if($network_type == "CNN") {
+  if(!as.logical($modelAvg)) {
+    model = CNN::train_paramserv(features, labels, X_dummy, y_dummy, 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, $win, $seed)
+    print("Test results:")
+    hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
+    [loss_test, accuracy_test] = CNN::validate(X_dummy, y_dummy, model, hyperparams)
+    print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
+  }
+  else {
+    model = CNNModelAvg::train_paramserv(features, labels, X_dummy, y_dummy, 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, $win, $seed, $modelAvg)
+    print("Test results:")
+    hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
+    [loss_test, accuracy_test] = CNNModelAvg::validate(X_dummy, y_dummy, model, hyperparams)
+    print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
+  }
+}