You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2021/08/21 03:31:35 UTC

[spark] branch branch-2.4 updated: Update Spark key negotiation protocol

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 4be5660  Update Spark key negotiation protocol
4be5660 is described below

commit 4be566062defa249435c4d72eb106fe7b933e023
Author: Sean Owen <sr...@gmail.com>
AuthorDate: Wed Aug 11 18:04:55 2021 -0500

    Update Spark key negotiation protocol
---
 common/network-common/pom.xml                      |   4 +
 .../spark/network/crypto/AuthClientBootstrap.java  |   6 +-
 .../apache/spark/network/crypto/AuthEngine.java    | 420 +++++++++------------
 .../{ServerResponse.java => AuthMessage.java}      |  56 ++-
 .../spark/network/crypto/AuthRpcHandler.java       |   6 +-
 .../spark/network/crypto/ClientChallenge.java      | 101 -----
 .../java/org/apache/spark/network/crypto/README.md | 217 ++++-------
 .../spark/network/crypto/AuthEngineSuite.java      | 182 ++++++---
 .../spark/network/crypto/AuthMessagesSuite.java    |  46 +--
 dev/deps/spark-deps-hadoop-2.6                     |   1 +
 dev/deps/spark-deps-hadoop-2.7                     |   1 +
 dev/deps/spark-deps-hadoop-3.1                     |   1 +
 pom.xml                                            |   6 +
 13 files changed, 432 insertions(+), 615 deletions(-)

diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index cd57c43..d585185 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -85,6 +85,10 @@
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-crypto</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.google.crypto.tink</groupId>
+      <artifactId>tink</artifactId>
+    </dependency>
 
     <!-- Test dependencies -->
     <dependency>
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
index 737e187..1586989 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
@@ -98,15 +98,15 @@ public class AuthClientBootstrap implements TransportClientBootstrap {
 
     String secretKey = secretKeyHolder.getSecretKey(appId);
     try (AuthEngine engine = new AuthEngine(appId, secretKey, conf)) {
-      ClientChallenge challenge = engine.challenge();
+      AuthMessage challenge = engine.challenge();
       ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength());
       challenge.encode(challengeData);
 
       ByteBuffer responseData =
           client.sendRpcSync(challengeData.nioBuffer(), conf.authRTTimeoutMs());
-      ServerResponse response = ServerResponse.decodeMessage(responseData);
+      AuthMessage response = AuthMessage.decodeMessage(responseData);
 
-      engine.validate(response);
+      engine.deriveSessionCipher(challenge, response);
       engine.sessionCipher().addToChannel(channel);
     }
   }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
index 64fdb32..078d9ce 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
@@ -17,134 +17,216 @@
 
 package org.apache.spark.network.crypto;
 
+import javax.crypto.spec.SecretKeySpec;
 import java.io.Closeable;
-import java.io.IOException;
-import java.math.BigInteger;
 import java.security.GeneralSecurityException;
 import java.util.Arrays;
 import java.util.Properties;
-import javax.crypto.Cipher;
-import javax.crypto.SecretKey;
-import javax.crypto.SecretKeyFactory;
-import javax.crypto.ShortBufferException;
-import javax.crypto.spec.IvParameterSpec;
-import javax.crypto.spec.PBEKeySpec;
-import javax.crypto.spec.SecretKeySpec;
-import static java.nio.charset.StandardCharsets.UTF_8;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.primitives.Bytes;
-import org.apache.commons.crypto.cipher.CryptoCipher;
-import org.apache.commons.crypto.cipher.CryptoCipherFactory;
-import org.apache.commons.crypto.random.CryptoRandom;
-import org.apache.commons.crypto.random.CryptoRandomFactory;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
+import com.google.crypto.tink.subtle.AesGcmJce;
+import com.google.crypto.tink.subtle.Hkdf;
+import com.google.crypto.tink.subtle.Random;
+import com.google.crypto.tink.subtle.X25519;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import static java.nio.charset.StandardCharsets.UTF_8;
 import org.apache.spark.network.util.TransportConf;
 
 /**
- * A helper class for abstracting authentication and key negotiation details. This is used by
- * both client and server sides, since the operations are basically the same.
+ * A helper class for abstracting authentication and key negotiation details.
+ * This supports a forward-secure authentication protocol based on X25519 Diffie-Hellman Key
+ * Exchange, using a pre-shared key to derive an AES-GCM key-encrypting key.
  */
 class AuthEngine implements Closeable {
-
-  private static final Logger LOG = LoggerFactory.getLogger(AuthEngine.class);
-  private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 });
-
-  private final byte[] appId;
-  private final char[] secret;
+  public static final byte[] INPUT_IV_INFO = "inputIv".getBytes(UTF_8);
+  public static final byte[] OUTPUT_IV_INFO = "outputIv".getBytes(UTF_8);
+  private static final String MAC_ALGORITHM = "HMACSHA256";
+  private static final int AES_GCM_KEY_SIZE_BYTES = 16;
+  private static final byte[] EMPTY_TRANSCRIPT = new byte[0];
+
+  private final String appId;
+  private final byte[] preSharedSecret;
   private final TransportConf conf;
   private final Properties cryptoConf;
-  private final CryptoRandom random;
-
-  private byte[] authNonce;
-
-  @VisibleForTesting
-  byte[] challenge;
 
+  private byte[] clientPrivateKey;
   private TransportCipher sessionCipher;
-  private CryptoCipher encryptor;
-  private CryptoCipher decryptor;
 
-  AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException {
-    this.appId = appId.getBytes(UTF_8);
+  AuthEngine(String appId, String preSharedSecret, TransportConf conf) {
+    Preconditions.checkNotNull(appId);
+    Preconditions.checkNotNull(preSharedSecret);
+    this.appId = appId;
+    this.preSharedSecret = preSharedSecret.getBytes(UTF_8);
     this.conf = conf;
     this.cryptoConf = conf.cryptoConf();
-    this.secret = secret.toCharArray();
-    this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf);
+  }
+
+  @VisibleForTesting
+  void setClientPrivateKey(byte[] privateKey) {
+    this.clientPrivateKey = privateKey;
   }
 
   /**
-   * Create the client challenge.
+   * This method will derive a key from a pre-shared secret, a random salt, and an arbitrary
+   * transcript. It will then use that derived key to AES-GCM encrypt an ephemeral X25519 public
+   * key.
    *
-   * @return A challenge to be sent the remote side.
+   * @param ephemeralX25519PublicKey Ephemeral X25519 Public Key to encrypt under a derived key.
+   * @param transcript               Optional byte array representing a protocol transcript, which
+   *                                 is mixed into the key derivation and included as AES-GCM
+   *                                 associated authenticated data (AAD).
+   * @return An encrypted ephemeral X25519 public key.
+   * @throws GeneralSecurityException If HKDF key derivation or AES-GCM encryption fails.
    */
-  ClientChallenge challenge() throws GeneralSecurityException {
-    this.authNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
-    SecretKeySpec authKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
-      authNonce, conf.encryptionKeyLength());
-    initializeForAuth(conf.cipherTransformation(), authNonce, authKey);
-
-    this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
-    return new ClientChallenge(new String(appId, UTF_8),
-      conf.keyFactoryAlgorithm(),
-      conf.keyFactoryIterations(),
-      conf.cipherTransformation(),
-      conf.encryptionKeyLength(),
-      authNonce,
-      challenge(appId, authNonce, challenge));
+  private AuthMessage encryptEphemeralPublicKey(
+      byte[] ephemeralX25519PublicKey,
+      byte[] transcript) throws GeneralSecurityException {
+    // This non-secret salt is used in the HKDF key derivations and will be sent in plaintext as
+    // part of the AES-GCM encrypted X25519 public key. It will be included as additional
+    // associated data (AAD).
+    byte[] nonSecretSalt = Random.randBytes(AES_GCM_KEY_SIZE_BYTES);
+    // Mix in the app ID, salt, and transcript into HKDF and use it as AES-GCM AAD
+    byte[] aadState = Bytes.concat(appId.getBytes(UTF_8), nonSecretSalt, transcript);
+    // Use HKDF to derive an AES_GCM key from the pre-shared key, non-secret salt, and AAD state
+    byte[] derivedKeyEncryptingKey = Hkdf.computeHkdf(
+        MAC_ALGORITHM,
+        preSharedSecret,
+        nonSecretSalt,
+        aadState,
+        AES_GCM_KEY_SIZE_BYTES);
+    // AES-GCM encrypt the X25519 public key and include the app ID, salt, and transcript as AAD
+    byte[] aesGcmCiphertext = new AesGcmJce(derivedKeyEncryptingKey)
+        .encrypt(ephemeralX25519PublicKey, aadState);
+    return new AuthMessage(appId, nonSecretSalt, aesGcmCiphertext);
   }
 
   /**
-   * Validates the client challenge, and create the encryption backend for the channel from the
-   * parameters sent by the client.
+   * This method will derive a key from a pre-shared secret, the received salt, and an
+   * arbitrary transcript. It will then use that derived key to AES-GCM decrypt an ephemeral
+   * X25519 public key.
    *
-   * @param clientChallenge The challenge from the client.
-   * @return A response to be sent to the client.
+   * @param encryptedPublicKey An X25519 public key to decrypt with a derived key
+   * @param transcript         Optional byte array representing a protocol transcript, which is
+   *                           mixed into the key derivation and included as AES-GCM associated
+   *                           authenticated data (AAD).
+   * @return A decrypted ephemeral public key
+   * @throws GeneralSecurityException If decryption fails, notably if authentication checks fail.
    */
-  ServerResponse respond(ClientChallenge clientChallenge)
-    throws GeneralSecurityException {
-
-    SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
-      clientChallenge.nonce, clientChallenge.keyLength);
-    initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey);
-
-    byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge);
-    byte[] response = challenge(appId, clientChallenge.nonce, rawResponse(challenge));
-    byte[] sessionNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
-    byte[] inputIv = randomBytes(conf.ivLength());
-    byte[] outputIv = randomBytes(conf.ivLength());
+  private byte[] decryptEphemeralPublicKey(
+      AuthMessage encryptedPublicKey,
+      byte[] transcript) throws GeneralSecurityException {
+    Preconditions.checkArgument(appId.equals(encryptedPublicKey.appId));
+    // Mix in the app ID, salt, and transcript into HKDF and use it as AES-GCM AAD
+    byte[] aadState = Bytes.concat(appId.getBytes(UTF_8), encryptedPublicKey.salt, transcript);
+    // Use HKDF to derive an AES_GCM key from the pre-shared key, non-secret salt, and AAD state
+    byte[] derivedKeyEncryptingKey = Hkdf.computeHkdf(
+        MAC_ALGORITHM,
+        preSharedSecret,
+        encryptedPublicKey.salt,
+        aadState,
+        AES_GCM_KEY_SIZE_BYTES);
+    // If the AES-GCM payload is modified at all or if the AAD state does not match, decryption
+    // will throw a GeneralSecurityException.
+    return new AesGcmJce(derivedKeyEncryptingKey)
+        .decrypt(encryptedPublicKey.ciphertext, aadState);
+  }
 
-    SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
-      sessionNonce, clientChallenge.keyLength);
-    this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey,
-      inputIv, outputIv);
+  /**
+   * Encrypt an ephemeral X25519 public key to be sent to the server as a challenge.
+   *
+   * @return An encrypted client ephemeral public key to be sent to the server.
+   */
+  AuthMessage challenge() throws GeneralSecurityException {
+    setClientPrivateKey(X25519.generatePrivateKey());
+    return encryptEphemeralPublicKey(
+        X25519.publicFromPrivate(clientPrivateKey),
+        EMPTY_TRANSCRIPT);
+  }
 
-    // Note the IVs are swapped in the response.
-    return new ServerResponse(response, encrypt(sessionNonce), encrypt(outputIv), encrypt(inputIv));
+  /**
+   * Validates the client challenge by decrypting the ephemeral X25519 public key, computing a
+   * shared secret from it, then encrypting a server ephemeral X25519 public key for the client.
+   *
+   * @param encryptedClientPublicKey The encrypted public key from the client to be decrypted.
+   * @return An encrypted server ephemeral public key to be sent to the client.
+   */
+  AuthMessage response(AuthMessage encryptedClientPublicKey) throws GeneralSecurityException {
+    Preconditions.checkArgument(appId.equals(encryptedClientPublicKey.appId));
+    // Decrypt the client's ephemeral public key with a key derived from the pre-shared secret
+    byte[] clientPublicKey =
+        decryptEphemeralPublicKey(encryptedClientPublicKey, EMPTY_TRANSCRIPT);
+    // Generate an ephemeral X25519 private key.
+    byte[] serverEphemeralPrivateKey = X25519.generatePrivateKey();
+    // Encrypt the X25519 public key with a key derived from the preSharedSecret and transcript
+    AuthMessage ephemeralServerPublicKey = encryptEphemeralPublicKey(
+        X25519.publicFromPrivate(serverEphemeralPrivateKey),
+        getTranscript(encryptedClientPublicKey));
+    // Compute a shared secret given the client public key and the server private key
+    byte[] sharedSecret =
+        X25519.computeSharedSecret(serverEphemeralPrivateKey, clientPublicKey);
+    byte[] challengeResponseTranscript =
+        getTranscript(encryptedClientPublicKey, ephemeralServerPublicKey);
+    this.sessionCipher =
+        generateTransportCipher(sharedSecret, false, challengeResponseTranscript);
+    return ephemeralServerPublicKey;
   }
 
   /**
    * Validates the server response and initializes the cipher to use for the session.
    *
-   * @param serverResponse The response from the server.
+   * @param encryptedClientPublicKey The encrypted ephemeral public key from the client.
+   * @param encryptedServerPublicKey The encrypted ephemeral public key from the server.
    */
-  void validate(ServerResponse serverResponse) throws GeneralSecurityException {
-    byte[] response = validateChallenge(authNonce, serverResponse.response);
-
-    byte[] expected = rawResponse(challenge);
-    Preconditions.checkArgument(Arrays.equals(expected, response));
-
-    byte[] nonce = decrypt(serverResponse.nonce);
-    byte[] inputIv = decrypt(serverResponse.inputIv);
-    byte[] outputIv = decrypt(serverResponse.outputIv);
-
-    SecretKeySpec sessionKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
-      nonce, conf.encryptionKeyLength());
-    this.sessionCipher = new TransportCipher(cryptoConf, conf.cipherTransformation(), sessionKey,
-      inputIv, outputIv);
+  void deriveSessionCipher(AuthMessage encryptedClientPublicKey,
+                           AuthMessage encryptedServerPublicKey) throws GeneralSecurityException {
+    Preconditions.checkArgument(appId.equals(encryptedClientPublicKey.appId));
+    Preconditions.checkArgument(appId.equals(encryptedServerPublicKey.appId));
+    // Decrypt the server's ephemeral public key, mixing the client's challenge into the key
+    // derivation as the protocol transcript.
+    byte[] serverPublicKey = decryptEphemeralPublicKey(
+        encryptedServerPublicKey,
+        getTranscript(encryptedClientPublicKey));
+    // Compute a shared secret given the server public key and the client private key
+    byte[] sharedSecret = X25519.computeSharedSecret(clientPrivateKey, serverPublicKey);
+    byte[] challengeResponseTranscript =
+        getTranscript(encryptedClientPublicKey, encryptedServerPublicKey);
+    this.sessionCipher =
+        generateTransportCipher(sharedSecret, true, challengeResponseTranscript);
+  }
+
+  private TransportCipher generateTransportCipher(
+      byte[] sharedSecret,
+      boolean isClient,
+      byte[] transcript) throws GeneralSecurityException {
+    byte[] clientIv = Hkdf.computeHkdf(
+        MAC_ALGORITHM,
+        sharedSecret,
+        transcript,  // Passing this as the HKDF salt
+        INPUT_IV_INFO,  // This is the HKDF info field used to differentiate IV values
+        AES_GCM_KEY_SIZE_BYTES);
+    byte[] serverIv = Hkdf.computeHkdf(
+        MAC_ALGORITHM,
+        sharedSecret,
+        transcript,  // Passing this as the HKDF salt
+        OUTPUT_IV_INFO,  // This is the HKDF info field used to differentiate IV values
+        AES_GCM_KEY_SIZE_BYTES);
+    SecretKeySpec sessionKey = new SecretKeySpec(sharedSecret, "AES");
+    return new TransportCipher(
+        cryptoConf,
+        conf.cipherTransformation(),
+        sessionKey,
+        isClient ? clientIv : serverIv,  // If it's the client, use the client IV first
+        isClient ? serverIv : clientIv);
+  }
+
+  private byte[] getTranscript(AuthMessage... encryptedPublicKeys) {
+    ByteBuf transcript = Unpooled.buffer(
+        Arrays.stream(encryptedPublicKeys).mapToInt(k -> k.encodedLength()).sum());
+    Arrays.stream(encryptedPublicKeys).forEachOrdered(k -> k.encode(transcript));
+    return transcript.array();
   }
 
   TransportCipher sessionCipher() {
@@ -153,163 +235,7 @@ class AuthEngine implements Closeable {
   }
 
   @Override
-  public void close() throws IOException {
-    // Close ciphers (by calling "doFinal()" with dummy data) and the random instance so that
-    // internal state is cleaned up. Error handling here is just for paranoia, and not meant to
-    // accurately report the errors when they happen.
-    RuntimeException error = null;
-    byte[] dummy = new byte[8];
-    if (encryptor != null) {
-      try {
-        doCipherOp(Cipher.ENCRYPT_MODE, dummy, true);
-      } catch (Exception e) {
-        error = new RuntimeException(e);
-      }
-      encryptor = null;
-    }
-    if (decryptor != null) {
-      try {
-        doCipherOp(Cipher.DECRYPT_MODE, dummy, true);
-      } catch (Exception e) {
-        error = new RuntimeException(e);
-      }
-      decryptor = null;
-    }
-    random.close();
-
-    if (error != null) {
-      throw error;
-    }
-  }
+  public void close() {
 
-  @VisibleForTesting
-  byte[] challenge(byte[] appId, byte[] nonce, byte[] challenge) throws GeneralSecurityException {
-    return encrypt(Bytes.concat(appId, nonce, challenge));
   }
-
-  @VisibleForTesting
-  byte[] rawResponse(byte[] challenge) {
-    BigInteger orig = new BigInteger(challenge);
-    BigInteger response = orig.add(ONE);
-    return response.toByteArray();
-  }
-
-  private byte[] decrypt(byte[] in) throws GeneralSecurityException {
-    return doCipherOp(Cipher.DECRYPT_MODE, in, false);
-  }
-
-  private byte[] encrypt(byte[] in) throws GeneralSecurityException {
-    return doCipherOp(Cipher.ENCRYPT_MODE, in, false);
-  }
-
-  private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key)
-    throws GeneralSecurityException {
-
-    // commons-crypto currently only supports ciphers that require an initial vector; so
-    // create a dummy vector so that we can initialize the ciphers. In the future, if
-    // different ciphers are supported, this will have to be configurable somehow.
-    byte[] iv = new byte[conf.ivLength()];
-    System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length));
-
-    CryptoCipher _encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
-    _encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
-    this.encryptor = _encryptor;
-
-    CryptoCipher _decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
-    _decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
-    this.decryptor = _decryptor;
-  }
-
-  /**
-   * Validates an encrypted challenge as defined in the protocol, and returns the byte array
-   * that corresponds to the actual challenge data.
-   */
-  private byte[] validateChallenge(byte[] nonce, byte[] encryptedChallenge)
-    throws GeneralSecurityException {
-
-    byte[] challenge = decrypt(encryptedChallenge);
-    checkSubArray(appId, challenge, 0);
-    checkSubArray(nonce, challenge, appId.length);
-    return Arrays.copyOfRange(challenge, appId.length + nonce.length, challenge.length);
-  }
-
-  private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int keyLength)
-    throws GeneralSecurityException {
-
-    SecretKeyFactory factory = SecretKeyFactory.getInstance(kdf);
-    PBEKeySpec spec = new PBEKeySpec(secret, salt, iterations, keyLength);
-
-    long start = System.nanoTime();
-    SecretKey key = factory.generateSecret(spec);
-    long end = System.nanoTime();
-
-    LOG.debug("Generated key with {} iterations in {} us.", conf.keyFactoryIterations(),
-      (end - start) / 1000);
-
-    return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm());
-  }
-
-  private byte[] doCipherOp(int mode, byte[] in, boolean isFinal)
-    throws GeneralSecurityException {
-
-    CryptoCipher cipher;
-    switch (mode) {
-      case Cipher.ENCRYPT_MODE:
-        cipher = encryptor;
-        break;
-      case Cipher.DECRYPT_MODE:
-        cipher = decryptor;
-        break;
-      default:
-        throw new IllegalArgumentException(String.valueOf(mode));
-    }
-
-    Preconditions.checkState(cipher != null, "Cipher is invalid because of previous error.");
-
-    try {
-      int scale = 1;
-      while (true) {
-        int size = in.length * scale;
-        byte[] buffer = new byte[size];
-        try {
-          int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0)
-            : cipher.update(in, 0, in.length, buffer, 0);
-          if (outSize != buffer.length) {
-            byte[] output = new byte[outSize];
-            System.arraycopy(buffer, 0, output, 0, output.length);
-            return output;
-          } else {
-            return buffer;
-          }
-        } catch (ShortBufferException e) {
-          // Try again with a bigger buffer.
-          scale *= 2;
-        }
-      }
-    } catch (InternalError ie) {
-      // SPARK-25535. The commons-cryto library will throw InternalError if something goes wrong,
-      // and leave bad state behind in the Java wrappers, so it's not safe to use them afterwards.
-      if (mode == Cipher.ENCRYPT_MODE) {
-        this.encryptor = null;
-      } else {
-        this.decryptor = null;
-      }
-      throw ie;
-    }
-  }
-
-  private byte[] randomBytes(int count) {
-    byte[] bytes = new byte[count];
-    random.nextBytes(bytes);
-    return bytes;
-  }
-
-  /** Checks that the "test" array is in the data array starting at the given offset. */
-  private void checkSubArray(byte[] test, byte[] data, int offset) {
-    Preconditions.checkArgument(data.length >= test.length + offset);
-    for (int i = 0; i < test.length; i++) {
-      Preconditions.checkArgument(test[i] == data[i + offset]);
-    }
-  }
-
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthMessage.java
similarity index 53%
rename from common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java
rename to common/network-common/src/main/java/org/apache/spark/network/crypto/AuthMessage.java
index caf3a0f..76690cb 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthMessage.java
@@ -21,65 +21,55 @@ import java.nio.ByteBuffer;
 
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
-
 import org.apache.spark.network.protocol.Encodable;
 import org.apache.spark.network.protocol.Encoders;
 
 /**
- * Server's response to client's challenge.
+ * A message sent in the forward secure authentication protocol, containing an app ID, a salt for
+ * key derivation, and an encrypted payload.
  *
- * Please see crypto/README.md for more details.
+ * Please see crypto/README.md for more details of implementation.
  */
-public class ServerResponse implements Encodable {
+class AuthMessage implements Encodable {
   /** Serialization tag used to catch incorrect payloads. */
   private static final byte TAG_BYTE = (byte) 0xFB;
 
-  public final byte[] response;
-  public final byte[] nonce;
-  public final byte[] inputIv;
-  public final byte[] outputIv;
+  public final String appId;
+  public final byte[] salt;
+  public final byte[] ciphertext;
 
-  public ServerResponse(
-      byte[] response,
-      byte[] nonce,
-      byte[] inputIv,
-      byte[] outputIv) {
-    this.response = response;
-    this.nonce = nonce;
-    this.inputIv = inputIv;
-    this.outputIv = outputIv;
+  AuthMessage(String appId, byte[] salt, byte[] ciphertext) {
+    this.appId = appId;
+    this.salt = salt;
+    this.ciphertext = ciphertext;
   }
 
   @Override
   public int encodedLength() {
     return 1 +
-      Encoders.ByteArrays.encodedLength(response) +
-      Encoders.ByteArrays.encodedLength(nonce) +
-      Encoders.ByteArrays.encodedLength(inputIv) +
-      Encoders.ByteArrays.encodedLength(outputIv);
+      Encoders.Strings.encodedLength(appId) +
+      Encoders.ByteArrays.encodedLength(salt) +
+      Encoders.ByteArrays.encodedLength(ciphertext);
   }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeByte(TAG_BYTE);
-    Encoders.ByteArrays.encode(buf, response);
-    Encoders.ByteArrays.encode(buf, nonce);
-    Encoders.ByteArrays.encode(buf, inputIv);
-    Encoders.ByteArrays.encode(buf, outputIv);
+    Encoders.Strings.encode(buf, appId);
+    Encoders.ByteArrays.encode(buf, salt);
+    Encoders.ByteArrays.encode(buf, ciphertext);
   }
 
-  public static ServerResponse decodeMessage(ByteBuffer buffer) {
+  public static AuthMessage decodeMessage(ByteBuffer buffer) {
     ByteBuf buf = Unpooled.wrappedBuffer(buffer);
 
     if (buf.readByte() != TAG_BYTE) {
-      throw new IllegalArgumentException("Expected ServerResponse, received something else.");
+      throw new IllegalArgumentException("Expected ClientChallenge, received something else.");
     }
 
-    return new ServerResponse(
-      Encoders.ByteArrays.decode(buf),
-      Encoders.ByteArrays.decode(buf),
-      Encoders.ByteArrays.decode(buf),
-      Encoders.ByteArrays.decode(buf));
+    return new AuthMessage(
+      Encoders.Strings.decode(buf),  // AppID
+      Encoders.ByteArrays.decode(buf),  // Salt
+      Encoders.ByteArrays.decode(buf));  // Ciphertext
   }
-
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
index dd31c95..549ee4d 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
@@ -84,9 +84,9 @@ class AuthRpcHandler extends AbstractAuthRpcHandler {
     int position = message.position();
     int limit = message.limit();
 
-    ClientChallenge challenge;
+    AuthMessage challenge;
     try {
-      challenge = ClientChallenge.decodeMessage(message);
+      challenge = AuthMessage.decodeMessage(message);
       LOG.debug("Received new auth challenge for client {}.", channel.remoteAddress());
     } catch (RuntimeException e) {
       if (conf.saslFallback()) {
@@ -113,7 +113,7 @@ class AuthRpcHandler extends AbstractAuthRpcHandler {
         "Trying to authenticate non-registered app %s.", challenge.appId);
       LOG.debug("Authenticating challenge for app {}.", challenge.appId);
       engine = new AuthEngine(challenge.appId, secret, conf);
-      ServerResponse response = engine.respond(challenge);
+      AuthMessage response = engine.response(challenge);
       ByteBuf responseData = Unpooled.buffer(response.encodedLength());
       response.encode(responseData);
       callback.onSuccess(responseData.nioBuffer());
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java
deleted file mode 100644
index 819b8a7..0000000
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.crypto;
-
-import java.nio.ByteBuffer;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-
-import org.apache.spark.network.protocol.Encodable;
-import org.apache.spark.network.protocol.Encoders;
-
-/**
- * The client challenge message, used to initiate authentication.
- *
- * Please see crypto/README.md for more details of implementation.
- */
-public class ClientChallenge implements Encodable {
-  /** Serialization tag used to catch incorrect payloads. */
-  private static final byte TAG_BYTE = (byte) 0xFA;
-
-  public final String appId;
-  public final String kdf;
-  public final int iterations;
-  public final String cipher;
-  public final int keyLength;
-  public final byte[] nonce;
-  public final byte[] challenge;
-
-  public ClientChallenge(
-      String appId,
-      String kdf,
-      int iterations,
-      String cipher,
-      int keyLength,
-      byte[] nonce,
-      byte[] challenge) {
-    this.appId = appId;
-    this.kdf = kdf;
-    this.iterations = iterations;
-    this.cipher = cipher;
-    this.keyLength = keyLength;
-    this.nonce = nonce;
-    this.challenge = challenge;
-  }
-
-  @Override
-  public int encodedLength() {
-    return 1 + 4 + 4 +
-      Encoders.Strings.encodedLength(appId) +
-      Encoders.Strings.encodedLength(kdf) +
-      Encoders.Strings.encodedLength(cipher) +
-      Encoders.ByteArrays.encodedLength(nonce) +
-      Encoders.ByteArrays.encodedLength(challenge);
-  }
-
-  @Override
-  public void encode(ByteBuf buf) {
-    buf.writeByte(TAG_BYTE);
-    Encoders.Strings.encode(buf, appId);
-    Encoders.Strings.encode(buf, kdf);
-    buf.writeInt(iterations);
-    Encoders.Strings.encode(buf, cipher);
-    buf.writeInt(keyLength);
-    Encoders.ByteArrays.encode(buf, nonce);
-    Encoders.ByteArrays.encode(buf, challenge);
-  }
-
-  public static ClientChallenge decodeMessage(ByteBuffer buffer) {
-    ByteBuf buf = Unpooled.wrappedBuffer(buffer);
-
-    if (buf.readByte() != TAG_BYTE) {
-      throw new IllegalArgumentException("Expected ClientChallenge, received something else.");
-    }
-
-    return new ClientChallenge(
-      Encoders.Strings.decode(buf),
-      Encoders.Strings.decode(buf),
-      buf.readInt(),
-      Encoders.Strings.decode(buf),
-      buf.readInt(),
-      Encoders.ByteArrays.decode(buf),
-      Encoders.ByteArrays.decode(buf));
-  }
-
-}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md
index 14df703..78e7459 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md
@@ -1,158 +1,101 @@
-Spark Auth Protocol and AES Encryption Support
+Forward Secure Auth Protocol
 ==============================================
 
-This file describes an auth protocol used by Spark as a more secure alternative to DIGEST-MD5. This
-protocol is built on symmetric key encryption, based on the assumption that the two endpoints being
-authenticated share a common secret, which is how Spark authentication currently works. The protocol
-provides mutual authentication, meaning that after the negotiation both parties know that the remote
-side knows the shared secret. The protocol is influenced by the ISO/IEC 9798 protocol, although it's
-not an implementation of it.
+This file describes a forward secure authentication protocol which may be used by Spark. This
+protocol is essentially ephemeral Diffie-Hellman key exchange using Curve25519, referred to as
+X25519.
 
-This protocol could be replaced with TLS PSK, except no PSK ciphers are available in the currently
-released JREs.
+Both client and server share a (possibly low-entropy) pre-shared secret that is used to derive a
+key-encrypting key using HKDF. The derivation also mixes in any preceding protocol transcript.
 
-The protocol aims at solving the following shortcomings in Spark's current usage of DIGEST-MD5:
-
-- MD5 is an aging hash algorithm with known weaknesses, and a more secure alternative is desired.
-- DIGEST-MD5 has a pre-defined set of ciphers for which it can generate keys. The only
-  viable, supported cipher these days is 3DES, and a more modern alternative is desired.
-- Encrypting AES session keys with 3DES doesn't solve the issue, since the weakest link
-  in the negotiation would still be MD5 and 3DES.
-
-The protocol assumes that the shared secret is generated and distributed in a secure manner.
-
-The protocol always negotiates encryption keys. If encryption is not desired, the existing
-SASL-based authentication, or no authentication at all, can be chosen instead.
-
-When messages are described below, it's expected that the implementation should support
-arbitrary sizes for fields that don't have a fixed size.
+The key-encrypting key is used to encrypt an X25519 public key with AES-GCM. This is intended to
+authenticate the message exchange between the parties and there is no expectation of secrecy for
+the public key. This protocol utilizes GCM's associated authenticated data (AAD) field to include
+metadata and the prior protocol transcript, to bind each round with all preceding rounds.
 
 Client Challenge
 ----------------
 
-The auth negotiation is started by the client. The client starts by generating an encryption
-key based on the application's shared secret, and a nonce.
-
-    KEY = KDF(SECRET, SALT, KEY_LENGTH)
-
-Where:
-- KDF(): a key derivation function that takes a secret, a salt, a configurable number of
-  iterations, and a configurable key length.
-- SALT: a byte sequence used to salt the key derivation function.
-- KEY_LENGTH: length of the encryption key to generate.
-
+The auth negotiation is started by the client. Given an application ID, the client starts by
+generating a random 16-byte salt value and deriving a key-encrypting key:
 
-The client generates a message with the following content:
+    preSharedKey = lookupKey(appId)
+    nonSecretSalt = Random(16 bytes)
+    aadState = Concat(appId, nonSecretSalt)
+    keyEncryptingKey = HKDF(preSharedKey, nonSecretSalt, aadState)
 
-    CLIENT_CHALLENGE = (
-        APP_ID,
-        KDF,
-        ITERATIONS,
-        CIPHER,
-        KEY_LENGTH,
-        ANONCE,
-        ENC(APP_ID || ANONCE || CHALLENGE))
+This key-encrypting key is then used to encrypt an ephemeral X25519 public key.
 
-Where:
+    clientKeyPair = X25519.generate()
+    randomIV = Random(16 bytes)
+    ciphertext = AES-GCM-Encrypt(
+      key = keyEncryptingKey,
+      iv = randomIV,
+      plaintext = clientKeyPair.publicKey(),
+      aad = aadState)
+    clientChallenge = (appId, nonSecretSalt, randomIV, ciphertext)
 
-- APP_ID: the application ID which the server uses to identify the shared secret.
-- KDF: the key derivation function described above.
-- ITERATIONS: number of iterations to run the KDF when generating keys.
-- CIPHER: the cipher used to encrypt data.
-- KEY_LENGTH: length of the encryption keys to generate, in bits.
-- ANONCE: the nonce used as the salt when generating the auth key.
-- ENC(): an encryption function that uses the cipher and the generated key. This function
-  will also be used in the definition of other messages below.
-- CHALLENGE: a byte sequence used as a challenge to the server.
-- ||: concatenation operator.
-
-When strings are used where byte arrays are expected, the UTF-8 representation of the string
-is assumed.
-
-To respond to the challenge, the server should consider the byte array as representing an
-arbitrary-length integer, and respond with the value of the integer plus one.
+Note that the App ID and non-secret salt are bound to the ciphertext both through HKDF key
+derivation and AES-GCM AAD. We are not relying on keeping the client public key secret and could
+alternatively compute a MAC rather than encrypting with AES-GCM.
 
+The client sends this challenge to the server.
 
 Server Response And Challenge
 -----------------------------
 
-Once the client challenge is received, the server will generate the same auth key by
-using the same algorithm the client has used. It will then verify the client challenge:
-if the APP_ID and ANONCE fields match, the server knows that the client has the shared
-secret. The server then creates a response to the client challenge, to prove that it also
-has the secret key, and provides parameters to be used when creating the session key.
-
-The following describes the response from the server:
-
-    SERVER_CHALLENGE = (
-        ENC(APP_ID || ANONCE || RESPONSE),
-        ENC(SNONCE),
-        ENC(INIV),
-        ENC(OUTIV))
-
-Where:
-
-- RESPONSE: the server's response to the client challenge.
-- SNONCE: a nonce to be used as salt when generating the session key.
-- INIV: initialization vector used to initialize the input channel of the client.
-- OUTIV: initialization vector used to initialize the output channel of the client.
-
-At this point the server considers the client to be authenticated, and will try to
-decrypt any data further sent by the client using the session key.
-
-
-Default Algorithms
-------------------
-
-Configuration options are available for the KDF and cipher algorithms to use.
-
-The default KDF is "PBKDF2WithHmacSHA1". Users should be able to select any algorithm
-from those supported by the `javax.crypto.SecretKeyFactory` class, as long as they support
-PBEKeySpec when generating keys. The default number of iterations was chosen to take a
-reasonable amount of time on modern CPUs. See the documentation in TransportConf for more
-details.
-
-The default cipher algorithm is "AES/CTR/NoPadding". Users should be able to select any
-algorithm supported by the commons-crypto library. It should allow the cipher to operate
-in stream mode.
-
-The default key length is 128 (bits).
-
-
-Implementation Details
-----------------------
-
-The commons-crypto library currently only supports AES ciphers, and requires an initialization
-vector (IV). This first version of the protocol does not explicitly include the IV in the client
-challenge message. Instead, the IV should be derived from the nonce, including the needed bytes, and
-padding the IV with zeroes in case the nonce is not long enough.
-
-Future versions of the protocol might add support for new ciphers and explicitly include needed
-configuration parameters in the messages.
-
-
-Threat Assessment
+Once the client challenge is received, the server will derive the same key-encrypting key and
+recover the client's public key:
+
+    assert(appId == clientChallenge.appId)
+    preSharedKey = lookupKey(appId)
+    aadState = Concat(appId, clientChallenge.nonSecretSalt)
+    keyEncryptingKey = HKDF(preSharedKey, clientChallenge.nonSecretSalt, aadState)
+    clientPublicKey = AES-GCM-Decrypt(
+      key = keyEncryptingKey,
+      iv = clientChallenge.randomIV,
+      ciphertext = clientChallenge.ciphertext,
+      aad = aadState)
+
+The server can then send its own ephemeral public key to the client, encrypted under a key derived
+from the pre-shared key and the protocol transcript so far:
+
+    preSharedKey = lookupKey(appId)
+    nonSecretSalt = Random(16 bytes)
+    aadState = Concat(appId, nonSecretSalt, clientChallenge)
+    keyEncryptingKey = HKDF(preSharedKey, nonSecretSalt, aadState)
+    randomIV = Random(16 bytes)
+    serverKeyPair = X25519.generate()
+    ciphertext = AES-GCM-Encrypt(
+      key = keyEncryptingKey,
+      iv = randomIV,
+      plaintext = serverKeyPair.publicKey(),
+      aad = aadState)
+    serverResponse = (appId, nonSecretSalt, randomIV, ciphertext)
+
+Now that the server has the client's ephemeral public key, it can combine it with the private key
+of its own ephemeral keypair (generated above) to compute a shared secret:
+
+    sharedSecret = X25519.computeSharedSecret(clientPublicKey, serverKeyPair.privateKey())
+
+With the shared secret, the server will also derive two initialization vectors to be used for the
+inbound and outbound streams. These IVs are not secret; they are bound to the preceding protocol
+transcript so that both parties can derive them deterministically.
+
+    clientIv = HKDF(sharedSecret, salt=transcript, info="clientIv")
+    serverIv = HKDF(sharedSecret, salt=transcript, info="serverIv")
+
+The server can then send its response to the client, who can decrypt the server's ephemeral public
+key, and reconstruct the same shared secret and IVs.
+
+Security Comments
 -----------------
 
-The protocol is secure against different forms of attack:
-
-* Eavesdropping: the protocol is built on the assumption that it's computationally infeasible
-  to calculate the original secret from the encrypted messages. Neither the secret nor any
-  encryption keys are transmitted on the wire, encrypted or not.
-
-* Man-in-the-middle: because the protocol performs mutual authentication, both ends need to
-  know the shared secret to be able to decrypt session data. Even if an attacker is able to insert a
-  malicious "proxy" between endpoints, the attacker won't be able to read any of the data exchanged
-  between client and server, nor insert arbitrary commands for the server to execute.
+This protocol is essentially an [NNpsk0](http://www.noiseprotocol.org/noise.html#pattern-modifiers)
+pattern in the [Noise framework](http://www.noiseprotocol.org/) built around ECDHE using X25519 as
+the underlying curve. If the pre-shared key is compromised, it does not allow for recovery of past
+sessions. It would, however, allow impersonation of future sessions.
 
-* Replay attacks: the use of nonces when generating keys prevents an attacker from being able to
-  just replay messages sniffed from the communication channel.
+In the event of a pre-shared key compromise, messages would still be confidential from a passive
+observer. Only active adversaries spoofing a session would be able to recover plaintext.
 
-An attacker may replay the client challenge and successfully "prove" to a server that it "knows" the
-shared secret. But the attacker won't be able to decrypt the server's response, and thus won't be
-able to generate a session key, which will make it hard to craft a valid, encrypted message that the
-server will be able to understand. This will cause the server to close the connection as soon as the
-attacker tries to send any command to the server. The attacker can just hold the channel open for
-some time, which will be closed when the server times out the channel. These issues could be
-separately mitigated by adding a shorter timeout for the first message after authentication, and
-potentially by adding host blacklists if a possible attack is detected from a particular host.
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
index 382b733..fbd8a55 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
@@ -19,30 +19,42 @@ package org.apache.spark.network.crypto;
 
 import java.nio.ByteBuffer;
 import java.nio.channels.WritableByteChannel;
+import java.security.GeneralSecurityException;
 import java.util.Arrays;
 import java.util.Map;
-import java.security.InvalidKeyException;
 import java.util.Random;
 
-import static java.nio.charset.StandardCharsets.UTF_8;
-
 import com.google.common.collect.ImmutableMap;
+import com.google.crypto.tink.subtle.Hex;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 import io.netty.channel.FileRegion;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import static org.junit.Assert.*;
 import org.junit.BeforeClass;
 import org.junit.Test;
+import static org.mockito.Mockito.*;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
-
-import org.apache.spark.network.util.ByteArrayWritableChannel;
-import org.apache.spark.network.util.MapConfigProvider;
-import org.apache.spark.network.util.TransportConf;
 
 public class AuthEngineSuite {
 
+  private static final String clientPrivate =
+      "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186";
+  private static final String clientChallengeHex =
+      "fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9" +
+      "3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d" +
+      "65f8c426e18ff380f6";
+  private static final String serverResponseHex =
+      "fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4" +
+      "e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738" +
+      "08ecad08b46b5ee3ff";
+  private static final String sharedKey =
+      "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31";
+  private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2";
+  private static final String outputIv = "a72709baf00785cad6329ce09f631f71";
   private static TransportConf conf;
 
   @BeforeClass
@@ -56,9 +68,9 @@ public class AuthEngineSuite {
     AuthEngine server = new AuthEngine("appId", "secret", conf);
 
     try {
-      ClientChallenge clientChallenge = client.challenge();
-      ServerResponse serverResponse = server.respond(clientChallenge);
-      client.validate(serverResponse);
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      client.deriveSessionCipher(clientChallenge, serverResponse);
 
       TransportCipher serverCipher = server.sessionCipher();
       TransportCipher clientCipher = client.sessionCipher();
@@ -72,50 +84,113 @@ public class AuthEngineSuite {
     }
   }
 
-  @Test
-  public void testMismatchedSecret() throws Exception {
-    AuthEngine client = new AuthEngine("appId", "secret", conf);
-    AuthEngine server = new AuthEngine("appId", "different_secret", conf);
+  @Test(expected = IllegalArgumentException.class)
+  public void testCorruptChallengeAppId() throws Exception {
+
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
+         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage corruptChallenge =
+              new AuthMessage("junk", clientChallenge.salt, clientChallenge.ciphertext);
+      AuthMessage serverResponse = server.response(corruptChallenge);
+    }
+  }
 
-    ClientChallenge clientChallenge = client.challenge();
-    try {
-      server.respond(clientChallenge);
-      fail("Should have failed to validate response.");
-    } catch (IllegalArgumentException e) {
-      // Expected.
+  @Test(expected = GeneralSecurityException.class)
+  public void testCorruptChallengeSalt() throws Exception {
+
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
+         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
+      AuthMessage clientChallenge = client.challenge();
+      clientChallenge.salt[0] ^= 1;
+      AuthMessage serverResponse = server.response(clientChallenge);
     }
   }
 
-  @Test(expected = IllegalArgumentException.class)
-  public void testWrongAppId() throws Exception {
-    AuthEngine engine = new AuthEngine("appId", "secret", conf);
-    ClientChallenge challenge = engine.challenge();
-
-    byte[] badChallenge = engine.challenge(new byte[] { 0x00 }, challenge.nonce,
-      engine.rawResponse(engine.challenge));
-    engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
-      challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
+  @Test(expected = GeneralSecurityException.class)
+  public void testCorruptChallengeCiphertext() throws Exception {
+
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
+         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
+      AuthMessage clientChallenge = client.challenge();
+      clientChallenge.ciphertext[0] ^= 1;
+      AuthMessage serverResponse = server.response(clientChallenge);
+    }
   }
 
   @Test(expected = IllegalArgumentException.class)
-  public void testWrongNonce() throws Exception {
-    AuthEngine engine = new AuthEngine("appId", "secret", conf);
-    ClientChallenge challenge = engine.challenge();
-
-    byte[] badChallenge = engine.challenge(challenge.appId.getBytes(UTF_8), new byte[] { 0x00 },
-      engine.rawResponse(engine.challenge));
-    engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
-      challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
+  public void testCorruptResponseAppId() throws Exception {
+
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
+         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      AuthMessage corruptResponse =
+              new AuthMessage("junk", serverResponse.salt, serverResponse.ciphertext);
+      client.deriveSessionCipher(clientChallenge, corruptResponse);
+    }
   }
 
-  @Test(expected = IllegalArgumentException.class)
-  public void testBadChallenge() throws Exception {
-    AuthEngine engine = new AuthEngine("appId", "secret", conf);
-    ClientChallenge challenge = engine.challenge();
+  @Test(expected = GeneralSecurityException.class)
+  public void testCorruptResponseSalt() throws Exception {
+
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
+         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      serverResponse.salt[0] ^= 1;
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+    }
+  }
+
+  @Test(expected = GeneralSecurityException.class)
+  public void testCorruptServerCiphertext() throws Exception {
+
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
+         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      serverResponse.ciphertext[0] ^= 1;
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+    }
+  }
+
+  @Test
+  public void testFixedChallenge() throws Exception {
+    try (AuthEngine server = new AuthEngine("appId", "secret", conf)) {
+      AuthMessage clientChallenge =
+              AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex)));
+      // This tests that the server will accept an old challenge as expected. However,
+      // it will generate a fresh ephemeral keypair, so we can't replay an old session.
+      AuthMessage freshServerResponse = server.response(clientChallenge);
+    }
+  }
 
-    byte[] badChallenge = new byte[challenge.challenge.length];
-    engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
-      challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
+  @Test
+  public void testFixedChallengeResponse() throws Exception {
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf)) {
+      byte[] clientPrivateKey = Hex.decode(clientPrivate);
+      client.setClientPrivateKey(clientPrivateKey);
+      AuthMessage clientChallenge =
+              AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex)));
+      AuthMessage serverResponse =
+              AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex)));
+      // Verify that the client will accept an old transcript.
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+      TransportCipher clientCipher = client.sessionCipher();
+      assertEquals(Hex.encode(clientCipher.getKey().getEncoded()), sharedKey);
+      assertEquals(Hex.encode(clientCipher.getInputIv()), inputIv);
+      assertEquals(Hex.encode(clientCipher.getOutputIv()), outputIv);
+    }
+  }
+
+  @Test(expected = GeneralSecurityException.class)
+  public void testMismatchedSecret() throws Exception {
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
+         AuthEngine server = new AuthEngine("appId", "different_secret", conf)) {
+      AuthMessage clientChallenge = client.challenge();
+      server.response(clientChallenge);
+    }
   }
 
   @Test
@@ -123,9 +198,9 @@ public class AuthEngineSuite {
     AuthEngine client = new AuthEngine("appId", "secret", conf);
     AuthEngine server = new AuthEngine("appId", "secret", conf);
     try {
-      ClientChallenge clientChallenge = client.challenge();
-      ServerResponse serverResponse = server.respond(clientChallenge);
-      client.validate(serverResponse);
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      client.deriveSessionCipher(clientChallenge, serverResponse);
 
       TransportCipher cipher = server.sessionCipher();
       TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher);
@@ -151,9 +226,9 @@ public class AuthEngineSuite {
     AuthEngine client = new AuthEngine("appId", "secret", conf);
     AuthEngine server = new AuthEngine("appId", "secret", conf);
     try {
-      ClientChallenge clientChallenge = client.challenge();
-      ServerResponse serverResponse = server.respond(clientChallenge);
-      client.validate(serverResponse);
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      client.deriveSessionCipher(clientChallenge, serverResponse);
 
       TransportCipher cipher = server.sessionCipher();
       TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher);
@@ -193,7 +268,7 @@ public class AuthEngineSuite {
     }
   }
 
-  @Test(expected = InvalidKeyException.class)
+  @Test(expected = AssertionError.class)
   public void testBadKeySize() throws Exception {
     Map<String, String> mconf = ImmutableMap.of("spark.network.crypto.keyLength", "42");
     TransportConf conf = new TransportConf("rpc", new MapConfigProvider(mconf));
@@ -201,7 +276,6 @@ public class AuthEngineSuite {
     try (AuthEngine engine = new AuthEngine("appId", "secret", conf)) {
       engine.challenge();
       fail("Should have failed to create challenge message.");
-
       // Call close explicitly to make sure it's idempotent.
       engine.close();
     }
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java
index a90ff24..baed940 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java
@@ -17,15 +17,11 @@
 
 package org.apache.spark.network.crypto;
 
-import java.nio.ByteBuffer;
-import java.util.Arrays;
-
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
 import org.junit.Test;
-import static org.junit.Assert.*;
-
-import org.apache.spark.network.protocol.Encodable;
 
 public class AuthMessagesSuite {
 
@@ -42,39 +38,15 @@ public class AuthMessagesSuite {
     } return bytes;
   }
 
-  private static int integer() {
-    return COUNTER++;
-  }
-
-  @Test
-  public void testClientChallenge() {
-    ClientChallenge msg = new ClientChallenge(string(), string(), integer(), string(), integer(),
-      byteArray(), byteArray());
-    ClientChallenge decoded = ClientChallenge.decodeMessage(encode(msg));
-
-    assertEquals(msg.appId, decoded.appId);
-    assertEquals(msg.kdf, decoded.kdf);
-    assertEquals(msg.iterations, decoded.iterations);
-    assertEquals(msg.cipher, decoded.cipher);
-    assertEquals(msg.keyLength, decoded.keyLength);
-    assertTrue(Arrays.equals(msg.nonce, decoded.nonce));
-    assertTrue(Arrays.equals(msg.challenge, decoded.challenge));
-  }
-
   @Test
-  public void testServerResponse() {
-    ServerResponse msg = new ServerResponse(byteArray(), byteArray(), byteArray(), byteArray());
-    ServerResponse decoded = ServerResponse.decodeMessage(encode(msg));
-    assertTrue(Arrays.equals(msg.response, decoded.response));
-    assertTrue(Arrays.equals(msg.nonce, decoded.nonce));
-    assertTrue(Arrays.equals(msg.inputIv, decoded.inputIv));
-    assertTrue(Arrays.equals(msg.outputIv, decoded.outputIv));
-  }
-
-  private ByteBuffer encode(Encodable msg) {
+  public void testPublicKeyEncodeDecode() {
+    AuthMessage msg = new AuthMessage(string(), byteArray(), byteArray());
     ByteBuf buf = Unpooled.buffer();
     msg.encode(buf);
-    return buf.nioBuffer();
-  }
+    AuthMessage decoded = AuthMessage.decodeMessage(buf.nioBuffer());
 
+    assertEquals(msg.appId, decoded.appId);
+    assertArrayEquals(msg.salt, decoded.salt);
+    assertArrayEquals(msg.ciphertext, decoded.ciphertext);
+  }
 }
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index a83dc31..57fb136 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -188,6 +188,7 @@ stax-api/1.0.1//stax-api-1.0.1.jar
 stream/2.7.0//stream-2.7.0.jar
 stringtemplate/3.2.1//stringtemplate-3.2.1.jar
 super-csv/2.2.0//super-csv-2.2.0.jar
+tink/1.6.0//tink-1.6.0.jar
 univocity-parsers/2.7.3//univocity-parsers-2.7.3.jar
 validation-api/1.1.0.Final//validation-api-1.1.0.Final.jar
 xbean-asm6-shaded/4.8//xbean-asm6-shaded-4.8.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index eb6305d..29077c7 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -189,6 +189,7 @@ stax-api/1.0.1//stax-api-1.0.1.jar
 stream/2.7.0//stream-2.7.0.jar
 stringtemplate/3.2.1//stringtemplate-3.2.1.jar
 super-csv/2.2.0//super-csv-2.2.0.jar
+tink/1.6.0//tink-1.6.0.jar
 univocity-parsers/2.7.3//univocity-parsers-2.7.3.jar
 validation-api/1.1.0.Final//validation-api-1.1.0.Final.jar
 xbean-asm6-shaded/4.8//xbean-asm6-shaded-4.8.jar
diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1
index b9db185..0552b6c 100644
--- a/dev/deps/spark-deps-hadoop-3.1
+++ b/dev/deps/spark-deps-hadoop-3.1
@@ -209,6 +209,7 @@ stax2-api/3.1.4//stax2-api-3.1.4.jar
 stream/2.7.0//stream-2.7.0.jar
 stringtemplate/3.2.1//stringtemplate-3.2.1.jar
 super-csv/2.2.0//super-csv-2.2.0.jar
+tink/1.6.0//tink-1.6.0.jar
 token-provider/1.0.1//token-provider-1.0.1.jar
 univocity-parsers/2.7.3//univocity-parsers-2.7.3.jar
 validation-api/1.1.0.Final//validation-api-1.1.0.Final.jar
diff --git a/pom.xml b/pom.xml
index 6fe0a16..889776f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -187,6 +187,7 @@
     <paranamer.version>2.8</paranamer.version>
     <maven-antrun.version>1.8</maven-antrun.version>
     <commons-crypto.version>1.0.0</commons-crypto.version>
+    <tink.version>1.6.0</tink.version>
     <!--
     If you are changing Arrow version specification, please check ./python/pyspark/sql/utils.py,
     ./python/run-tests.py and ./python/setup.py too.
@@ -2028,6 +2029,11 @@
         </exclusions>
       </dependency>
       <dependency>
+        <groupId>com.google.crypto.tink</groupId>
+        <artifactId>tink</artifactId>
+        <version>${tink.version}</version>
+      </dependency>
+      <dependency>
         <groupId>com.thoughtworks.paranamer</groupId>
         <artifactId>paranamer</artifactId>
         <version>${paranamer.version}</version>

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org