You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@celeborn.apache.org by et...@apache.org on 2023/12/19 03:44:30 UTC
(incubator-celeborn) branch branch-0.4 updated: [CELEBORN-1150] support io encryption for spark
This is an automated email from the ASF dual-hosted git repository.
ethanfeng pushed a commit to branch branch-0.4
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/branch-0.4 by this push:
new a53463825 [CELEBORN-1150] support io encryption for spark
a53463825 is described below
commit a53463825cfc3f7a6fb3dfb8ec2581cea174923e
Author: mingji <fe...@alibaba-inc.com>
AuthorDate: Tue Dec 19 11:44:05 2023 +0800
[CELEBORN-1150] support io encryption for spark
### What changes were proposed in this pull request?
1. Support IO encryption for Spark.
### Why are the changes needed?
Ditto.
### Does this PR introduce _any_ user-facing change?
NO.
### How was this patch tested?
GA and manually tested on a cluster.
Closes #2135 from FMX/B1150.
Authored-by: mingji <fe...@alibaba-inc.com>
Signed-off-by: mingji <fe...@alibaba-inc.com>
(cherry picked from commit 4dacf72a6d662dd6b2bf6c60669a8baad0f1566d)
Signed-off-by: mingji <fe...@alibaba-inc.com>
---
client-spark/spark-2-shaded/pom.xml | 1 +
.../shuffle/celeborn/SparkShuffleManager.java | 36 +++++-
.../shuffle/celeborn/CelebornShuffleHandle.scala | 24 +++-
.../shuffle/celeborn/CelebornShuffleReader.scala | 1 +
.../celeborn/CelebornColumnarShuffleReader.scala | 10 +-
.../CelebornColumnarShuffleReaderSuite.scala | 11 +-
client-spark/spark-3-shaded/pom.xml | 1 +
.../shuffle/celeborn/SparkShuffleManager.java | 51 +++++++-
.../apache/spark/shuffle/celeborn/SparkUtils.java | 14 ++-
.../shuffle/celeborn/CelebornShuffleHandle.scala | 24 +++-
.../shuffle/celeborn/CelebornShuffleReader.scala | 31 ++++-
.../org/apache/celeborn/client/ShuffleClient.java | 27 +++++
.../apache/celeborn/client/ShuffleClientImpl.java | 60 ++++++++++
.../celeborn/client/read/CelebornInputStream.java | 66 +++++++++++
.../celeborn/client/security/CryptoUtils.java | 128 +++++++++++++++++++++
.../org/apache/celeborn/common/CelebornConf.scala | 30 +++++
docs/configuration/client.md | 3 +
.../celeborn/tests/spark/SparkTestBase.scala | 2 +
18 files changed, 498 insertions(+), 22 deletions(-)
diff --git a/client-spark/spark-2-shaded/pom.xml b/client-spark/spark-2-shaded/pom.xml
index 655e4b433..c21ac8871 100644
--- a/client-spark/spark-2-shaded/pom.xml
+++ b/client-spark/spark-2-shaded/pom.xml
@@ -73,6 +73,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
+ <include>org.apache.commons:commons-crypto</include>
</includes>
</artifactSet>
<filters>
diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 470d2e989..35939dd5c 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -18,15 +18,20 @@
package org.apache.spark.shuffle.celeborn;
import java.io.IOException;
+import java.util.Optional;
+import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import scala.Int;
+import scala.Option;
import org.apache.spark.*;
+import org.apache.spark.internal.config.package$;
import org.apache.spark.launcher.SparkLauncher;
import org.apache.spark.rdd.DeterministicLevel;
+import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.shuffle.*;
import org.apache.spark.shuffle.sort.SortShuffleManager;
import org.apache.spark.util.Utils;
@@ -35,6 +40,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.LifecycleManager;
import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.client.security.CryptoUtils;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.common.util.ThreadUtils;
@@ -99,7 +105,29 @@ public class SparkShuffleManager implements ShuffleManager {
return _sortShuffleManager;
}
- private void initializeLifecycleManager(String appId) {
+ private Properties getIoCryptoConf() {
+ if (!celebornConf.sparkIoEncryptionEnabled()) return new Properties();
+ Properties cryptoConf = CryptoStreamUtils.toCryptoConf(conf);
+ cryptoConf.put(
+ CryptoUtils.COMMONS_CRYPTO_CONFIG_TRANSFORMATION,
+ conf.get(package$.MODULE$.IO_CRYPTO_CIPHER_TRANSFORMATION()));
+ return cryptoConf;
+ }
+
+ private Optional<byte[]> getIoCryptoKey() {
+ if (!celebornConf.sparkIoEncryptionEnabled()) return Optional.empty();
+ Option<byte[]> key = SparkEnv.get().securityManager().getIOEncryptionKey();
+ return key.isEmpty() ? Optional.empty() : Optional.ofNullable(key.get());
+ }
+
+ private byte[] getIoCryptoInitializationVector() {
+ if (!celebornConf.sparkIoEncryptionEnabled()) return null;
+ return conf.getBoolean(package$.MODULE$.IO_ENCRYPTION_ENABLED().key(), false)
+ ? CryptoUtils.createIoCryptoInitializationVector()
+ : null;
+ }
+
+ private void initializeLifecycleManager(String appId, byte[] ioCryptoInitializationVector) {
// Only create LifecycleManager singleton in Driver. When register shuffle multiple times, we
// need to ensure that LifecycleManager will only be created once. Parallelism needs to be
// considered in this place, because if there is one RDD that depends on multiple RDDs
@@ -126,7 +154,8 @@ public class SparkShuffleManager implements ShuffleManager {
// is the same SparkContext among different shuffleIds.
// This method may be called many times.
appUniqueId = SparkUtils.appUniqueId(dependency.rdd().context());
- initializeLifecycleManager(appUniqueId);
+ byte[] iv = getIoCryptoInitializationVector();
+ initializeLifecycleManager(appUniqueId, iv);
lifecycleManager.registerAppShuffleDeterminate(
shuffleId,
@@ -146,7 +175,8 @@ public class SparkShuffleManager implements ShuffleManager {
shuffleId,
celebornConf.clientFetchThrowsFetchFailure(),
numMaps,
- dependency);
+ dependency,
+ iv);
}
}
diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
index 4f67edaf3..dc9783a7c 100644
--- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
+++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
@@ -30,5 +30,25 @@ class CelebornShuffleHandle[K, V, C](
shuffleId: Int,
val throwsFetchFailure: Boolean,
numMappers: Int,
- dependency: ShuffleDependency[K, V, C])
- extends BaseShuffleHandle(shuffleId, numMappers, dependency)
+ dependency: ShuffleDependency[K, V, C],
+ val ioCryptoInitializationVector: Array[Byte])
+ extends BaseShuffleHandle(shuffleId, numMappers, dependency) {
+ def this(
+ appUniqueId: String,
+ lifecycleManagerHost: String,
+ lifecycleManagerPort: Int,
+ userIdentifier: UserIdentifier,
+ shuffleId: Int,
+ throwsFetchFailure: Boolean,
+ numMappers: Int,
+ dependency: ShuffleDependency[K, V, C]) = this(
+ appUniqueId,
+ lifecycleManagerHost,
+ lifecycleManagerPort,
+ userIdentifier,
+ shuffleId,
+ throwsFetchFailure,
+ numMappers,
+ dependency,
+ null)
+}
diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index dec305225..56a36f0e0 100644
--- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle.celeborn
import java.io.IOException
+import java.util.{Optional, Properties}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
index fd888fb9d..363f5d7ed 100644
--- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
+++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.celeborn
+import java.util.{Optional, Properties}
+
import org.apache.spark.{ShuffleDependency, TaskContext}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
@@ -34,7 +36,9 @@ class CelebornColumnarShuffleReader[K, C](
context: TaskContext,
conf: CelebornConf,
metrics: ShuffleReadMetricsReporter,
- shuffleIdTracker: ExecutorShuffleIdTracker)
+ shuffleIdTracker: ExecutorShuffleIdTracker,
+ ioCryptoKey: Optional[Array[Byte]],
+ ioCryptoConf: Properties)
extends CelebornShuffleReader[K, C](
handle,
startPartition,
@@ -44,7 +48,9 @@ class CelebornColumnarShuffleReader[K, C](
context,
conf,
metrics,
- shuffleIdTracker) {
+ shuffleIdTracker,
+ ioCryptoKey,
+ ioCryptoConf) {
override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
val schema = CustomShuffleDependencyUtils.getSchema(dep)
diff --git a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
index 5df434f54..231fdfbb5 100644
--- a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
+++ b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.celeborn
+import java.util.Optional
+
import org.apache.spark.{ShuffleDependency, SparkConf}
import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
import org.apache.spark.sql.execution.UnsafeRowSerializer
@@ -55,7 +57,9 @@ class CelebornColumnarShuffleReaderSuite {
null,
new CelebornConf(),
null,
- new ExecutorShuffleIdTracker())
+ new ExecutorShuffleIdTracker(),
+ Optional.empty(),
+ null)
assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]])
} finally {
if (shuffleClient != null) {
@@ -78,6 +82,7 @@ class CelebornColumnarShuffleReaderSuite {
0,
false,
10,
+ null,
null),
0,
10,
@@ -86,7 +91,9 @@ class CelebornColumnarShuffleReaderSuite {
null,
new CelebornConf(),
null,
- new ExecutorShuffleIdTracker())
+ new ExecutorShuffleIdTracker(),
+ Optional.empty(),
+ null)
val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]])
Mockito.when(shuffleDependency.shuffleId).thenReturn(0)
Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer(
diff --git a/client-spark/spark-3-shaded/pom.xml b/client-spark/spark-3-shaded/pom.xml
index c8701776d..f4aab4be8 100644
--- a/client-spark/spark-3-shaded/pom.xml
+++ b/client-spark/spark-3-shaded/pom.xml
@@ -73,6 +73,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
+ <include>org.apache.commons:commons-crypto</include>
</includes>
</artifactSet>
<filters>
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index a1cb458cf..03a38b933 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -18,13 +18,17 @@
package org.apache.spark.shuffle.celeborn;
import java.io.IOException;
+import java.util.Optional;
+import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.*;
+import org.apache.spark.internal.config.package$;
import org.apache.spark.launcher.SparkLauncher;
import org.apache.spark.rdd.DeterministicLevel;
+import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.shuffle.*;
import org.apache.spark.shuffle.sort.SortShuffleManager;
import org.apache.spark.sql.internal.SQLConf;
@@ -34,6 +38,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.LifecycleManager;
import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.client.security.CryptoUtils;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.common.util.ThreadUtils;
@@ -130,7 +135,32 @@ public class SparkShuffleManager implements ShuffleManager {
return _sortShuffleManager;
}
- private void initializeLifecycleManager() {
+ private Properties getIoCryptoConf() {
+ if (!celebornConf.sparkIoEncryptionEnabled()) return new Properties();
+ Properties cryptoConf = CryptoStreamUtils.toCryptoConf(conf);
+ cryptoConf.put(
+ CryptoUtils.COMMONS_CRYPTO_CONFIG_TRANSFORMATION,
+ conf.get(package$.MODULE$.IO_CRYPTO_CIPHER_TRANSFORMATION()));
+ return cryptoConf;
+ }
+
+ private Optional<byte[]> getIoCryptoKey() {
+ if (!celebornConf.sparkIoEncryptionEnabled()) return Optional.empty();
+ return SparkEnv.get()
+ .securityManager()
+ .getIOEncryptionKey()
+ .map(key -> Optional.ofNullable(key))
+ .getOrElse(() -> Optional.empty());
+ }
+
+ private byte[] getIoCryptoInitializationVector() {
+ if (!celebornConf.sparkIoEncryptionEnabled()) return null;
+ return conf.getBoolean(package$.MODULE$.IO_ENCRYPTION_ENABLED().key(), false)
+ ? CryptoUtils.createIoCryptoInitializationVector()
+ : null;
+ }
+
+ private void initializeLifecycleManager(byte[] ioCryptoInitializationVector) {
// Only create LifecycleManager singleton in Driver. When register shuffle multiple times, we
// need to ensure that LifecycleManager will only be created once. Parallelism needs to be
// considered in this place, because if there is one RDD that depends on multiple RDDs
@@ -158,7 +188,8 @@ public class SparkShuffleManager implements ShuffleManager {
// is the same SparkContext among different shuffleIds.
// This method may be called many times.
appUniqueId = SparkUtils.appUniqueId(dependency.rdd().context());
- initializeLifecycleManager();
+ byte[] iv = getIoCryptoInitializationVector();
+ initializeLifecycleManager(iv);
lifecycleManager.registerAppShuffleDeterminate(
shuffleId,
@@ -187,7 +218,8 @@ public class SparkShuffleManager implements ShuffleManager {
shuffleId,
celebornConf.clientFetchThrowsFetchFailure(),
dependency.rdd().getNumPartitions(),
- dependency);
+ dependency,
+ iv);
}
}
@@ -242,7 +274,10 @@ public class SparkShuffleManager implements ShuffleManager {
h.lifecycleManagerHost(),
h.lifecycleManagerPort(),
celebornConf,
- h.userIdentifier());
+ h.userIdentifier(),
+ getIoCryptoKey(),
+ getIoCryptoConf(),
+ h.ioCryptoInitializationVector());
int shuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, true);
shuffleIdTracker.track(h.shuffleId(), shuffleId);
@@ -371,7 +406,9 @@ public class SparkShuffleManager implements ShuffleManager {
context,
celebornConf,
metrics,
- shuffleIdTracker);
+ shuffleIdTracker,
+ getIoCryptoKey(),
+ getIoCryptoConf());
} else {
return new CelebornShuffleReader<>(
h,
@@ -382,7 +419,9 @@ public class SparkShuffleManager implements ShuffleManager {
context,
celebornConf,
metrics,
- shuffleIdTracker);
+ shuffleIdTracker,
+ getIoCryptoKey(),
+ getIoCryptoConf());
}
}
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index e7a6a5b8b..b11dcde29 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.celeborn;
+import java.util.Optional;
+import java.util.Properties;
import java.util.concurrent.atomic.LongAdder;
import scala.Tuple2;
@@ -219,7 +221,9 @@ public class SparkUtils {
TaskContext.class,
CelebornConf.class,
ShuffleReadMetricsReporter.class,
- ExecutorShuffleIdTracker.class);
+ ExecutorShuffleIdTracker.class,
+ Optional.class,
+ Properties.class);
public static <K, C> CelebornShuffleReader<K, C> createColumnarShuffleReader(
CelebornShuffleHandle<K, ?, C> handle,
@@ -230,7 +234,9 @@ public class SparkUtils {
TaskContext context,
CelebornConf conf,
ShuffleReadMetricsReporter metrics,
- ExecutorShuffleIdTracker shuffleIdTracker) {
+ ExecutorShuffleIdTracker shuffleIdTracker,
+ Optional<byte[]> ioCryptoKey,
+ Properties ioCryptoConf) {
return COLUMNAR_SHUFFLE_READER_CONSTRUCTOR_BUILDER
.build()
.invoke(
@@ -243,7 +249,9 @@ public class SparkUtils {
context,
conf,
metrics,
- shuffleIdTracker);
+ shuffleIdTracker,
+ ioCryptoKey,
+ ioCryptoConf);
}
// Added in SPARK-32920, for Spark 3.2 and above
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
index 18a3053e0..2c12282e4 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
@@ -30,5 +30,25 @@ class CelebornShuffleHandle[K, V, C](
shuffleId: Int,
val throwsFetchFailure: Boolean,
val numMappers: Int,
- dependency: ShuffleDependency[K, V, C])
- extends BaseShuffleHandle(shuffleId, dependency)
+ dependency: ShuffleDependency[K, V, C],
+ val ioCryptoInitializationVector: Array[Byte])
+ extends BaseShuffleHandle(shuffleId, dependency) {
+ def this(
+ appUniqueId: String,
+ lifecycleManagerHost: String,
+ lifecycleManagerPort: Int,
+ userIdentifier: UserIdentifier,
+ shuffleId: Int,
+ throwsFetchFailure: Boolean,
+ numMappers: Int,
+ dependency: ShuffleDependency[K, V, C]) = this(
+ appUniqueId,
+ lifecycleManagerHost,
+ lifecycleManagerPort,
+ userIdentifier,
+ shuffleId,
+ throwsFetchFailure,
+ numMappers,
+ dependency,
+ null)
+}
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index fe7af8309..ceb3639b4 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle.celeborn
import java.io.IOException
+import java.util.{Optional, Properties}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
@@ -44,16 +45,42 @@ class CelebornShuffleReader[K, C](
context: TaskContext,
conf: CelebornConf,
metrics: ShuffleReadMetricsReporter,
- shuffleIdTracker: ExecutorShuffleIdTracker)
+ shuffleIdTracker: ExecutorShuffleIdTracker,
+ ioCryptoKey: Optional[Array[Byte]],
+ ioCryptoConf: Properties)
extends ShuffleReader[K, C] with Logging {
+ def this(
+ handle: CelebornShuffleHandle[K, _, C],
+ startPartition: Int,
+ endPartition: Int,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ context: TaskContext,
+ conf: CelebornConf,
+ metrics: ShuffleReadMetricsReporter) = this(
+ handle,
+ startPartition,
+ endPartition,
+ startMapIndex,
+ endMapIndex,
+ context,
+ conf,
+ metrics,
+ null,
+ Optional.empty(),
+ null)
+
private val dep = handle.dependency
private val shuffleClient = ShuffleClient.get(
handle.appUniqueId,
handle.lifecycleManagerHost,
handle.lifecycleManagerPort,
conf,
- handle.userIdentifier)
+ handle.userIdentifier,
+ ioCryptoKey,
+ ioCryptoConf,
+ handle.ioCryptoInitializationVector)
private val exceptionRef = new AtomicReference[IOException]
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 72230a536..b5fc2ec6b 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -18,6 +18,8 @@
package org.apache.celeborn.client;
import java.io.IOException;
+import java.util.Optional;
+import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
@@ -61,6 +63,26 @@ public abstract class ShuffleClient {
int port,
CelebornConf conf,
UserIdentifier userIdentifier) {
+ return ShuffleClient.get(
+ appUniqueId,
+ driverHost,
+ port,
+ conf,
+ userIdentifier,
+ Optional.empty(),
+ new Properties(),
+ null);
+ }
+
+ public static ShuffleClient get(
+ String appUniqueId,
+ String driverHost,
+ int port,
+ CelebornConf conf,
+ UserIdentifier userIdentifier,
+ Optional<byte[]> ioCryptoKey,
+ Properties ioCryptoConf,
+ byte[] ioCryptoInitializationVector) {
if (null == _instance || !initialized) {
synchronized (ShuffleClient.class) {
if (null == _instance) {
@@ -72,11 +94,13 @@ public abstract class ShuffleClient {
// when communicating with LifecycleManager, it will cause a NullPointerException.
_instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier);
_instance.setupLifecycleManagerRef(driverHost, port);
+ _instance.setupIoCrypto(ioCryptoKey, ioCryptoConf, ioCryptoInitializationVector);
initialized = true;
} else if (!initialized) {
_instance.shutdown();
_instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier);
_instance.setupLifecycleManagerRef(driverHost, port);
+ _instance.setupIoCrypto(ioCryptoKey, ioCryptoConf, ioCryptoInitializationVector);
initialized = true;
}
}
@@ -118,6 +142,9 @@ public abstract class ShuffleClient {
String.format("%.2f", (localReadCount * 1.0d / totalReadCount) * 100));
}
+ public void setupIoCrypto(
+ Optional<byte[]> ioCryptoKey, Properties ioCryptoConf, byte[] ioCryptoInitializationVector) {}
+
public abstract void setupLifecycleManagerRef(String host, int port);
public abstract void setupLifecycleManagerRef(RpcEndpointRef endpointRef);
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index c5f463b19..d83e4283a 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -30,6 +30,7 @@ import scala.reflect.ClassTag$;
import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
+import org.apache.commons.crypto.cipher.CryptoCipher;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -37,6 +38,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.compress.Compressor;
import org.apache.celeborn.client.read.CelebornInputStream;
import org.apache.celeborn.client.read.MetricsCallback;
+import org.apache.celeborn.client.security.CryptoUtils;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.identity.UserIdentifier;
@@ -157,6 +159,29 @@ public class ShuffleClientImpl extends ShuffleClient {
protected final Map<Integer, ReduceFileGroups> reduceFileGroupsMap =
JavaUtils.newConcurrentHashMap();
+ protected Optional<byte[]> ioCryptoKey = Optional.empty();
+
+ protected Properties ioCryptoConf;
+
+ protected byte[] ioCyrptoInitializationVector;
+
+ private ThreadLocal<CryptoCipher> encipherThreadLocal =
+ new ThreadLocal<CryptoCipher>() {
+ @Override
+ protected CryptoCipher initialValue() {
+ CryptoCipher cryptoCipher = null;
+ if (ioCryptoKey.isPresent()) {
+ try {
+ cryptoCipher =
+ CryptoUtils.getEncipher(ioCryptoKey, ioCryptoConf, ioCyrptoInitializationVector);
+ } catch (IOException e) {
+ logger.error("Failed to init crypto", e);
+ }
+ }
+ return cryptoCipher;
+ }
+ };
+
public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier userIdentifier) {
super();
this.appUniqueId = appUniqueId;
@@ -878,6 +903,23 @@ public class ShuffleClientImpl extends ShuffleClient {
// increment batchId
final int nextBatchId = pushState.nextBatchId();
+ if (ioCryptoKey.isPresent()) {
+ CryptoCipher encipher = encipherThreadLocal.get();
+ byte[] encryptData = new byte[length + encipher.getBlockSize()];
+ int encryptLength = CryptoUtils.encrypt(encipher, data, offset, length, encryptData);
+ logger.debug(
+ "Push data encryption encryptLength/beforeLength {}/{} for shuffle {} map {} attempt {} partition {}.",
+ encryptLength,
+ length,
+ shuffleId,
+ mapId,
+ attemptId,
+ partitionId);
+ length = encryptLength;
+ data = encryptData;
+ offset = 0;
+ }
+
if (shuffleCompressionEnabled) {
// compress data
final Compressor compressor = compressorThreadLocal.get();
@@ -1651,6 +1693,9 @@ public class ShuffleClientImpl extends ShuffleClient {
startMapIndex,
endMapIndex,
fetchExcludedWorkers,
+ ioCryptoKey,
+ ioCryptoConf,
+ ioCyrptoInitializationVector,
metricsCallback);
}
}
@@ -1754,4 +1799,19 @@ public class ShuffleClientImpl extends ShuffleClient {
public TransportClientFactory getDataClientFactory() {
return dataClientFactory;
}
+
+ @Override
+ public void setupIoCrypto(
+ Optional<byte[]> ioCryptoKey, Properties ioCryptoConf, byte[] ioCryptoInitializationVector) {
+ this.ioCryptoKey = ioCryptoKey;
+ this.ioCryptoConf = ioCryptoConf;
+ this.ioCyrptoInitializationVector = ioCryptoInitializationVector;
+ if (this.ioCryptoKey.isPresent()) {
+ try {
+ CryptoUtils.getEncipher(this.ioCryptoKey, this.ioCryptoConf, ioCryptoInitializationVector);
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to init encipher", e);
+ }
+ }
+ }
}
diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index bb1e95ce9..3284c828a 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -27,12 +27,14 @@ import java.util.concurrent.atomic.LongAdder;
import com.google.common.util.concurrent.Uninterruptibles;
import io.netty.buffer.ByteBuf;
+import org.apache.commons.crypto.cipher.CryptoCipher;
import org.roaringbitmap.RoaringBitmap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.compress.Decompressor;
+import org.apache.celeborn.client.security.CryptoUtils;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.network.client.TransportClientFactory;
@@ -59,6 +61,37 @@ public abstract class CelebornInputStream extends InputStream {
ConcurrentHashMap<String, Long> fetchExcludedWorkers,
MetricsCallback metricsCallback)
throws IOException {
+ return create(
+ conf,
+ clientFactory,
+ shuffleKey,
+ locations,
+ attempts,
+ attemptNumber,
+ startMapIndex,
+ endMapIndex,
+ fetchExcludedWorkers,
+ Optional.empty(),
+ null,
+ null,
+ metricsCallback);
+ }
+
+ public static CelebornInputStream create(
+ CelebornConf conf,
+ TransportClientFactory clientFactory,
+ String shuffleKey,
+ PartitionLocation[] locations,
+ int[] attempts,
+ int attemptNumber,
+ int startMapIndex,
+ int endMapIndex,
+ ConcurrentHashMap<String, Long> fetchExcludedWorkers,
+ Optional<byte[]> ioCryptoKey,
+ Properties ioCryptoProp,
+ byte[] ioCryptoInitializationVector,
+ MetricsCallback metricsCallback)
+ throws IOException {
if (locations == null || locations.length == 0) {
return emptyInputStream;
} else {
@@ -72,6 +105,9 @@ public abstract class CelebornInputStream extends InputStream {
startMapIndex,
endMapIndex,
fetchExcludedWorkers,
+ ioCryptoKey,
+ ioCryptoProp,
+ ioCryptoInitializationVector,
metricsCallback);
}
}
@@ -149,6 +185,9 @@ public abstract class CelebornInputStream extends InputStream {
private boolean shuffleCompressionEnabled;
private long fetchExcludedWorkerExpireTimeout;
private final ConcurrentHashMap<String, Long> fetchExcludedWorkers;
+ private Optional<byte[]> encryptKey;
+ private Properties encryptProp;
+ private CryptoCipher decipher;
private boolean containLocalRead = false;
@@ -162,6 +201,9 @@ public abstract class CelebornInputStream extends InputStream {
int startMapIndex,
int endMapIndex,
ConcurrentHashMap<String, Long> fetchExcludedWorkers,
+ Optional<byte[]> ioCryptoKey,
+ Properties ioCryptoProp,
+ byte[] ioCryptoInitializationVector,
MetricsCallback metricsCallback)
throws IOException {
this.conf = conf;
@@ -202,6 +244,12 @@ public abstract class CelebornInputStream extends InputStream {
retryWaitMs = transportConf.ioRetryWaitTimeMs();
this.callback = metricsCallback;
moveToNextReader();
+
+ this.encryptKey = ioCryptoKey;
+ this.encryptProp = ioCryptoProp;
+ if (ioCryptoKey.isPresent()) {
+ decipher = CryptoUtils.getDecipher(ioCryptoKey, ioCryptoProp, ioCryptoInitializationVector);
+ }
}
private boolean skipLocation(int startMapIndex, int endMapIndex, PartitionLocation location) {
@@ -570,6 +618,7 @@ public abstract class CelebornInputStream extends InputStream {
callback.incBytesRead(BATCH_HEADER_SIZE + size);
if (shuffleCompressionEnabled) {
// decompress data
+
int originalLength = decompressor.getOriginalLen(compressedBuf);
if (rawDataBuf.length < originalLength) {
rawDataBuf = new byte[originalLength];
@@ -578,6 +627,23 @@ public abstract class CelebornInputStream extends InputStream {
} else {
limit = size;
}
+
+ if (decipher != null) {
+ byte[] decryptBuf = new byte[limit];
+ int decryptLength = CryptoUtils.decrypt(decipher, rawDataBuf, 0, limit, decryptBuf);
+ logger.debug(
+ "fetch data decryption shuffleKey: {}, mapId: {}, attempId: {}, batchId: {}, decryptLength/originLength: {}/{}",
+ shuffleKey,
+ mapId,
+ attemptId,
+ batchId,
+ decryptLength,
+ limit);
+ limit = decryptLength;
+ System.arraycopy(decryptBuf, 0, rawDataBuf, 0, limit);
+ decryptBuf = null;
+ }
+
position = 0;
hasData = true;
break;
diff --git a/client/src/main/java/org/apache/celeborn/client/security/CryptoUtils.java b/client/src/main/java/org/apache/celeborn/client/security/CryptoUtils.java
new file mode 100644
index 000000000..5ce4672d6
--- /dev/null
+++ b/client/src/main/java/org/apache/celeborn/client/security/CryptoUtils.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client.security;
+
+import java.io.IOException;
+import java.security.GeneralSecurityException;
+import java.util.Optional;
+import java.util.Properties;
+import java.util.concurrent.TimeUnit;
+
+import javax.crypto.BadPaddingException;
+import javax.crypto.Cipher;
+import javax.crypto.IllegalBlockSizeException;
+import javax.crypto.ShortBufferException;
+import javax.crypto.spec.IvParameterSpec;
+import javax.crypto.spec.SecretKeySpec;
+
+import org.apache.commons.crypto.cipher.CryptoCipher;
+import org.apache.commons.crypto.random.CryptoRandomFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Static helpers for encrypting and decrypting byte buffers with Apache Commons Crypto ciphers.
+ *
+ * <p>The cipher transformation is looked up in the supplied {@link Properties} under {@link
+ * #COMMONS_CRYPTO_CONFIG_TRANSFORMATION}; raw key bytes are wrapped as {@link #CRYPTO_ALGORITHM}
+ * keys.
+ */
+public class CryptoUtils {
+  private static final Logger logger = LoggerFactory.getLogger(CryptoUtils.class);
+
+  /** Length, in bytes, of the initialization vectors produced by this class. */
+  public static final int IV_LENGTH_IN_BYTES = 16;
+
+  public static final String COMMONS_CRYPTO_CONFIG_PREFIX = "commons.crypto.";
+
+  /** Property key whose value names the cipher transformation to instantiate. */
+  public static final String COMMONS_CRYPTO_CONFIG_TRANSFORMATION =
+      COMMONS_CRYPTO_CONFIG_PREFIX + "cipher.transformation";
+
+  public static final String CRYPTO_ALGORITHM = "AES";
+
+  /**
+   * Creates a random initialization vector of {@link #IV_LENGTH_IN_BYTES} bytes.
+   *
+   * <p>The Commons Crypto random source is tried first. If it cannot be created, the JDK's {@code
+   * SecureRandom} is used instead, so the returned vector is never a fixed, predictable value.
+   * A warning is logged when generation takes longer than two seconds.
+   *
+   * @return a freshly generated initialization vector
+   */
+  public static byte[] createIoCryptoInitializationVector() {
+    byte[] iv = new byte[IV_LENGTH_IN_BYTES];
+    long initialIVStart = System.nanoTime();
+    try {
+      CryptoRandomFactory.getCryptoRandom(new Properties()).nextBytes(iv);
+    } catch (GeneralSecurityException e) {
+      logger.warn("Failed to create crypto Initialization Vector", e);
+      // Do not fall back to a constant: a predictable IV weakens the encryption. The class name
+      // is fully qualified so that no additional file-level import is required.
+      new java.security.SecureRandom().nextBytes(iv);
+    }
+    long initialIVFinish = System.nanoTime();
+    long initialIVTime = TimeUnit.NANOSECONDS.toMillis(initialIVFinish - initialIVStart);
+    if (initialIVTime > 2000) {
+      logger.warn(
+          "It costs {} milliseconds to create the Initialization Vector used by crypto",
+          initialIVTime);
+    }
+    return iv;
+  }
+
+  /**
+   * Creates a cipher initialized for encryption.
+   *
+   * <p>The returned cipher is open and ready for use; the caller is responsible for closing it.
+   *
+   * @param ioCryptoKey raw key bytes; when empty no cipher is created
+   * @param ioCryptoConf properties passed to Commons Crypto, including the transformation
+   * @param ioCryptoInitializationVector initialization vector used to initialize the cipher
+   * @return the initialized cipher, or {@code null} when {@code ioCryptoKey} is empty
+   * @throws IOException if the cipher cannot be created or initialized
+   */
+  public static CryptoCipher getEncipher(
+      Optional<byte[]> ioCryptoKey, Properties ioCryptoConf, byte[] ioCryptoInitializationVector)
+      throws IOException {
+    return createCipher(
+        Cipher.ENCRYPT_MODE,
+        ioCryptoKey,
+        ioCryptoConf,
+        ioCryptoInitializationVector,
+        "Failed to init encipher");
+  }
+
+  /**
+   * Encrypts {@code length} bytes of {@code input} starting at {@code offset} into {@code output}.
+   *
+   * @param encipher cipher initialized for encryption
+   * @param input buffer holding the plain data
+   * @param offset index of the first byte of {@code input} to encrypt
+   * @param length number of bytes to encrypt
+   * @param output buffer receiving the encrypted data, written from index 0
+   * @return the total number of bytes written to {@code output}
+   * @throws IOException if the cipher rejects the data or {@code output} is too small
+   */
+  public static int encrypt(
+      CryptoCipher encipher, byte[] input, int offset, int length, byte[] output)
+      throws IOException {
+    try {
+      int updateBytes = encipher.update(input, offset, length, output, 0);
+      // Finish with an empty input so the remaining output is appended after the update output.
+      int finalBytes = encipher.doFinal(input, 0, 0, output, updateBytes);
+      return updateBytes + finalBytes;
+    } catch (ShortBufferException | BadPaddingException | IllegalBlockSizeException e) {
+      throw new IOException("Failed to encrypt", e);
+    }
+  }
+
+  /**
+   * Creates a cipher initialized for decryption.
+   *
+   * <p>The returned cipher is open and ready for use; the caller is responsible for closing it.
+   *
+   * @param key raw key bytes; when empty no cipher is created
+   * @param cryptoProp properties passed to Commons Crypto, including the transformation
+   * @param cryptoInitilizationVector initialization vector used to initialize the cipher
+   * @return the initialized cipher, or {@code null} when {@code key} is empty
+   * @throws IOException if the cipher cannot be created or initialized
+   */
+  public static CryptoCipher getDecipher(
+      Optional<byte[]> key, Properties cryptoProp, byte[] cryptoInitilizationVector)
+      throws IOException {
+    return createCipher(
+        Cipher.DECRYPT_MODE, key, cryptoProp, cryptoInitilizationVector, "Failed to init decipher");
+  }
+
+  /**
+   * Decrypts {@code length} bytes of {@code input} starting at {@code offset} into {@code
+   * decoded}.
+   *
+   * @param decipher cipher initialized for decryption
+   * @param input buffer holding the encrypted data
+   * @param offset index of the first byte of {@code input} to decrypt
+   * @param length number of bytes to decrypt
+   * @param decoded buffer receiving the decrypted data, written from index 0
+   * @return the number of bytes written to {@code decoded}
+   * @throws IOException if the cipher rejects the data or {@code decoded} is too small
+   */
+  public static int decrypt(
+      CryptoCipher decipher, byte[] input, int offset, int length, byte[] decoded)
+      throws IOException {
+    try {
+      return decipher.doFinal(input, offset, length, decoded, 0);
+    } catch (ShortBufferException | IllegalBlockSizeException | BadPaddingException e) {
+      throw new IOException("Failed to decrypt", e);
+    }
+  }
+
+  /**
+   * Shared factory behind {@link #getEncipher} and {@link #getDecipher}.
+   *
+   * <p>The cipher is deliberately not created in a try-with-resources statement: it is handed to
+   * the caller, so it must still be open on return. It is closed here only when initialization
+   * fails.
+   */
+  private static CryptoCipher createCipher(
+      int mode, Optional<byte[]> key, Properties conf, byte[] iv, String initFailureMessage)
+      throws IOException {
+    if (!key.isPresent()) {
+      return null;
+    }
+    // Build the specs before acquiring the cipher so nothing can throw between creation and init.
+    SecretKeySpec keySpec = new SecretKeySpec(key.get(), CRYPTO_ALGORITHM);
+    IvParameterSpec ivSpec = new IvParameterSpec(iv);
+    String transformation = (String) conf.get(COMMONS_CRYPTO_CONFIG_TRANSFORMATION);
+    CryptoCipher cipher =
+        org.apache.commons.crypto.utils.Utils.getCipherInstance(transformation, conf);
+    try {
+      cipher.init(mode, keySpec, ivSpec);
+    } catch (GeneralSecurityException e) {
+      try {
+        cipher.close();
+      } catch (IOException closeFailure) {
+        e.addSuppressed(closeFailure);
+      }
+      throw new IOException(initFailureMessage, e);
+    }
+    return cipher;
+  }
+}
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 0c723832c..e89f30078 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -758,6 +758,12 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def clientExcludedWorkerExpireTimeout: Long = get(CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT)
def clientExcludeReplicaOnFailureEnabled: Boolean =
get(CLIENT_EXCLUDE_PEER_WORKER_ON_FAILURE_ENABLED)
+
+ def sparkIoEncryptionEnabled: Boolean = get(SPARK_CLIENT_IO_ENCRYPTION_ENABLED)
+ def sparkIoEncryptionKey: String = get(SPARK_CLIENT_IO_ENCRYPTION_KEY)
+ def sparkIoEncryptionInitializationVector: String =
+ get(SPARK_CLIENT_IO_ENCRYPTION_INITIALIZATION_VECTOR)
+
def clientMrMaxPushData: Long = get(CLIENT_MR_PUSH_DATA_MAX)
// //////////////////////////////////////////////////////
@@ -4239,4 +4245,28 @@ object CelebornConf extends Logging {
.version("0.5.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("30s")
+
+ val SPARK_CLIENT_IO_ENCRYPTION_ENABLED: ConfigEntry[Boolean] =
+ buildConf("celeborn.client.spark.io.encryption.enabled")
+ .categories("client")
+ .version("0.4.0")
+ .doc("whether to enable io encryption")
+ .booleanConf
+ .createWithDefault(true)
+
+ val SPARK_CLIENT_IO_ENCRYPTION_KEY: ConfigEntry[String] =
+ buildConf("celeborn.client.spark.io.encryption.key")
+ .categories("client")
+ .version("0.4.0")
+ .doc("io encryption key")
+ .stringConf
+ .createWithDefault("")
+
+ val SPARK_CLIENT_IO_ENCRYPTION_INITIALIZATION_VECTOR: ConfigEntry[String] =
+ buildConf("celeborn.client.spark.io.encryption.initialization.vector")
+ .categories("client")
+ .version("0.4.0")
+ .doc("io encryption initialization vector")
+ .stringConf
+ .createWithDefault("")
}
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 7c171ed3f..d77cb9521 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -102,6 +102,9 @@ license: |
| celeborn.client.shuffle.register.filterExcludedWorker.enabled | false | Whether to filter excluded worker when register shuffle. | 0.4.0 |
| celeborn.client.slot.assign.maxWorkers | 10000 | Max workers that slots of one shuffle can be allocated on. Will choose the smaller positive one from Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`. | 0.3.1 |
| celeborn.client.spark.fetch.throwsFetchFailure | false | client throws FetchFailedException instead of CelebornIOException | 0.4.0 |
+| celeborn.client.spark.io.encryption.enabled | true | whether to enable io encryption | 0.4.0 |
+| celeborn.client.spark.io.encryption.initialization.vector | | io encryption initialization vector | 0.4.0 |
+| celeborn.client.spark.io.encryption.key | | io encryption key | 0.4.0 |
| celeborn.client.spark.push.sort.memory.threshold | 64m | When SortBasedPusher use memory over the threshold, will trigger push data. If the pipeline push feature is enabled (`celeborn.client.spark.push.sort.pipeline.enabled=true`), the SortBasedPusher will trigger a data push when the memory usage exceeds half of the threshold(by default, 32m). | 0.3.0 |
| celeborn.client.spark.push.sort.pipeline.enabled | false | Whether to enable pipelining for sort based shuffle writer. If true, double buffering will be used to pipeline push | 0.3.0 |
| celeborn.client.spark.push.unsafeRow.fastWrite.enabled | true | This is Celeborn's optimization on UnsafeRow for Spark and it's true by default. If you have changed UnsafeRow's memory layout set this to false. | 0.2.2 |
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
index e2cb3c98a..54de9b34d 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
@@ -59,6 +59,8 @@ trait SparkTestBase extends AnyFunSuite
sparkConf.set("spark.sql.adaptive.localShuffleReader.enabled", "false")
sparkConf.set(s"spark.${MASTER_ENDPOINTS.key}", masterInfo._1.rpcEnv.address.toString)
sparkConf.set(s"spark.${SPARK_SHUFFLE_WRITER_MODE.key}", mode.toString)
+ sparkConf.set("spark.io.encryption.enabled", "true")
+ sparkConf.set("spark.io.crypto.cipher.transformation", "AES/CBC/PKCS5Padding")
sparkConf
}