You are viewing a plain text version of this content. The canonical link for it is here.
Posted to common-commits@hadoop.apache.org by wa...@apache.org on 2015/02/26 06:18:15 UTC

hadoop git commit: HADOOP-11620. Add support for load balancing across a group of KMS for HA. Contributed by Arun Suresh.

Repository: hadoop
Updated Branches:
  refs/heads/trunk 725cc499f -> 71385f9b7


HADOOP-11620. Add support for load balancing across a group of KMS for HA. Contributed by Arun Suresh.


Project: http://git-wip-us.apache.org/repos/asf/hadoop/repo
Commit: http://git-wip-us.apache.org/repos/asf/hadoop/commit/71385f9b
Tree: http://git-wip-us.apache.org/repos/asf/hadoop/tree/71385f9b
Diff: http://git-wip-us.apache.org/repos/asf/hadoop/diff/71385f9b

Branch: refs/heads/trunk
Commit: 71385f9b70e22618db3f3d2b2c6dca3b1e82c317
Parents: 725cc49
Author: Andrew Wang <wa...@apache.org>
Authored: Wed Feb 25 21:15:44 2015 -0800
Committer: Andrew Wang <wa...@apache.org>
Committed: Wed Feb 25 21:16:37 2015 -0800

----------------------------------------------------------------------
 hadoop-common-project/hadoop-common/CHANGES.txt |   3 +
 .../crypto/key/kms/KMSClientProvider.java       |  84 ++++-
 .../key/kms/LoadBalancingKMSClientProvider.java | 347 +++++++++++++++++++
 .../kms/TestLoadBalancingKMSClientProvider.java | 166 +++++++++
 .../hadoop/crypto/key/kms/server/TestKMS.java   | 114 +++---
 5 files changed, 654 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/hadoop/blob/71385f9b/hadoop-common-project/hadoop-common/CHANGES.txt
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/CHANGES.txt b/hadoop-common-project/hadoop-common/CHANGES.txt
index 0d452f7..39062a8 100644
--- a/hadoop-common-project/hadoop-common/CHANGES.txt
+++ b/hadoop-common-project/hadoop-common/CHANGES.txt
@@ -648,6 +648,9 @@ Release 2.7.0 - UNRELEASED
     HADOOP-11506. Configuration variable expansion regex expensive for long
     values. (Gera Shegalov via gera)
 
+    HADOOP-11620. Add support for load balancing across a group of KMS for HA.
+    (Arun Suresh via wang)
+
   BUG FIXES
 
     HADOOP-11512. Use getTrimmedStrings when reading serialization keys

http://git-wip-us.apache.org/repos/asf/hadoop/blob/71385f9b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java
index 97ab253..223e69a 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java
@@ -52,6 +52,7 @@ import java.io.Writer;
 import java.lang.reflect.UndeclaredThrowableException;
 import java.net.HttpURLConnection;
 import java.net.InetSocketAddress;
+import java.net.MalformedURLException;
 import java.net.SocketTimeoutException;
 import java.net.URI;
 import java.net.URISyntaxException;
@@ -74,6 +75,7 @@ import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
+import com.google.common.base.Strings;
 
 /**
  * KMS client <code>KeyProvider</code> implementation.
@@ -221,14 +223,71 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension,
    */
   public static class Factory extends KeyProviderFactory {
 
+    /**
+     * This provider expects URIs in the following form :
+     * kms://<PROTO>@<AUTHORITY>/<PATH>
+     *
+     * where :
+     * - PROTO = http or https
+     * - AUTHORITY = <HOSTS>[:<PORT>]
+     * - HOSTS = <HOSTNAME>[;<HOSTS>]
+     * - HOSTNAME = string
+     * - PORT = integer
+     *
+     * If multiple hosts are provided, the Factory will create a
+     * {@link LoadBalancingKMSClientProvider} that round-robins requests
+     * across the provided list of hosts.
+     */
     @Override
-    public KeyProvider createProvider(URI providerName, Configuration conf)
+    public KeyProvider createProvider(URI providerUri, Configuration conf)
         throws IOException {
-      if (SCHEME_NAME.equals(providerName.getScheme())) {
-        return new KMSClientProvider(providerName, conf);
+      if (SCHEME_NAME.equals(providerUri.getScheme())) {
+        URL origUrl = new URL(extractKMSPath(providerUri).toString());
+        String authority = origUrl.getAuthority();
+        // fail fast when the URI carries no authority (no KMS host given)
+        if (Strings.isNullOrEmpty(authority)) {
+          throw new IOException(
+              "No valid authority in kms uri [" + origUrl + "]");
+        }
+        // Check if port is present in authority
+        // In the current scheme, all hosts have to run on the same port
+        int port = -1;
+        String hostsPart = authority;
+        if (authority.contains(":")) {
+          String[] t = authority.split(":");
+          try {
+            port = Integer.parseInt(t[1]);
+          } catch (Exception e) {
+            throw new IOException(
+                "Could not parse port in kms uri [" + origUrl + "]");
+          }
+          hostsPart = t[0];
+        }
+        return createProvider(providerUri, conf, origUrl, port, hostsPart);
       }
       return null;
     }
+
+    private KeyProvider createProvider(URI providerUri, Configuration conf,
+        URL origUrl, int port, String hostsPart) throws IOException {
+      String[] hosts = hostsPart.split(";");
+      if (hosts.length == 1) {
+        return new KMSClientProvider(providerUri, conf);
+      } else {
+        KMSClientProvider[] providers = new KMSClientProvider[hosts.length];
+        for (int i = 0; i < hosts.length; i++) {
+          try {
+            providers[i] =
+                new KMSClientProvider(
+                    new URI("kms", origUrl.getProtocol(), hosts[i], port,
+                        origUrl.getPath(), null, null), conf);
+          } catch (URISyntaxException e) {
+            throw new IOException("Could not instantiate KMSProvider..", e);
+          }
+        }
+        return new LoadBalancingKMSClientProvider(providers, conf);
+      }
+    }
   }
 
   public static <T> T checkNotNull(T o, String name)
@@ -302,10 +361,8 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension,
 
   public KMSClientProvider(URI uri, Configuration conf) throws IOException {
     super(conf);
-    Path path = ProviderUtils.unnestUri(uri);
-    URL url = path.toUri().toURL();
-    kmsUrl = createServiceURL(url);
-    if ("https".equalsIgnoreCase(url.getProtocol())) {
+    kmsUrl = createServiceURL(extractKMSPath(uri));
+    if ("https".equalsIgnoreCase(new URL(kmsUrl).getProtocol())) {
       sslFactory = new SSLFactory(SSLFactory.Mode.CLIENT, conf);
       try {
         sslFactory.init();
@@ -346,8 +403,12 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension,
             .getCurrentUser();
   }
 
-  private String createServiceURL(URL url) throws IOException {
-    String str = url.toExternalForm();
+  private static Path extractKMSPath(URI uri) throws MalformedURLException, IOException {
+    return ProviderUtils.unnestUri(uri);
+  }
+
+  private static String createServiceURL(Path path) throws IOException {
+    String str = new URL(path.toString()).toExternalForm();
     if (str.endsWith("/")) {
       str = str.substring(0, str.length() - 1);
     }
@@ -853,4 +914,9 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension,
       }
     }
   }
+
+  @VisibleForTesting
+  String getKMSUrl() {
+    return kmsUrl;
+  }
 }

http://git-wip-us.apache.org/repos/asf/hadoop/blob/71385f9b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java
new file mode 100644
index 0000000..c1579e7
--- /dev/null
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java
@@ -0,0 +1,347 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.crypto.key.kms;
+
+import java.io.IOException;
+import java.security.GeneralSecurityException;
+import java.security.NoSuchAlgorithmException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
+import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.util.Time;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.annotations.VisibleForTesting;
+
+/**
+ * A simple LoadBalancing KMSClientProvider that round-robins requests
+ * across a provided array of KMSClientProviders. It also retries failed
+ * requests on the next available provider in the load balancer group. It
+ * only retries failed requests that result in an IOException, sending back
+ * all other Exceptions to the caller without retry.
+ */
+public class LoadBalancingKMSClientProvider extends KeyProvider implements
+    CryptoExtension,
+    KeyProviderDelegationTokenExtension.DelegationTokenExtension {
+
+  public static Logger LOG =
+      LoggerFactory.getLogger(LoadBalancingKMSClientProvider.class);
+
+  static interface ProviderCallable<T> {
+    public T call(KMSClientProvider provider) throws IOException, Exception;
+  }
+
+  @SuppressWarnings("serial")
+  static class WrapperException extends RuntimeException {
+    public WrapperException(Throwable cause) {
+      super(cause);
+    }
+  }
+
+  private final KMSClientProvider[] providers;
+  private final AtomicInteger currentIdx;
+
+  public LoadBalancingKMSClientProvider(KMSClientProvider[] providers,
+      Configuration conf) {
+    this(shuffle(providers), Time.monotonicNow(), conf);
+  }
+
+  @VisibleForTesting
+  LoadBalancingKMSClientProvider(KMSClientProvider[] providers, long seed,
+      Configuration conf) {
+    super(conf);
+    this.providers = providers;
+    this.currentIdx = new AtomicInteger((int)(seed % providers.length));
+  }
+
+  @VisibleForTesting
+  KMSClientProvider[] getProviders() {
+    return providers;
+  }
+
+  private <T> T doOp(ProviderCallable<T> op, int currPos)
+      throws IOException {
+    IOException ex = null;
+    for (int i = 0; i < providers.length; i++) {
+      KMSClientProvider provider = providers[(currPos + i) % providers.length];
+      try {
+        return op.call(provider);
+      } catch (IOException ioe) {
+        LOG.warn("KMS provider at [{}] threw an IOException [{}]!!",
+            provider.getKMSUrl(), ioe.getMessage());
+        ex = ioe;
+      } catch (Exception e) {
+        if (e instanceof RuntimeException) {
+          throw (RuntimeException)e;
+        } else {
+          throw new WrapperException(e);
+        }
+      }
+    }
+    if (ex != null) {
+      LOG.warn("Aborting since the Request has failed with all KMS"
+          + " providers in the group. !!");
+      throw ex;
+    }
+    throw new IOException("No providers configured !!");
+  }
+
+  private int nextIdx() {
+    while (true) {
+      int current = currentIdx.get();
+      int next = (current + 1) % providers.length;
+      if (currentIdx.compareAndSet(current, next)) {
+        return current;
+      }
+    }
+  }
+
+  @Override
+  public Token<?>[]
+      addDelegationTokens(final String renewer, final Credentials credentials)
+          throws IOException {
+    return doOp(new ProviderCallable<Token<?>[]>() {
+      @Override
+      public Token<?>[] call(KMSClientProvider provider) throws IOException {
+        return provider.addDelegationTokens(renewer, credentials);
+      }
+    }, nextIdx());
+  }
+
+  // This request is sent to all providers in the load-balancing group
+  @Override
+  public void warmUpEncryptedKeys(String... keyNames) throws IOException {
+    for (KMSClientProvider provider : providers) {
+      try {
+        provider.warmUpEncryptedKeys(keyNames);
+      } catch (IOException ioe) {
+        LOG.error(
+            "Error warming up keys for provider with url"
+            + "[" + provider.getKMSUrl() + "]");
+      }
+    }
+  }
+
+  // This request is sent to all providers in the load-balancing group
+  @Override
+  public void drain(String keyName) {
+    for (KMSClientProvider provider : providers) {
+      provider.drain(keyName);
+    }
+  }
+
+  @Override
+  public EncryptedKeyVersion
+      generateEncryptedKey(final String encryptionKeyName)
+          throws IOException, GeneralSecurityException {
+    try {
+      return doOp(new ProviderCallable<EncryptedKeyVersion>() {
+        @Override
+        public EncryptedKeyVersion call(KMSClientProvider provider)
+            throws IOException, GeneralSecurityException {
+          return provider.generateEncryptedKey(encryptionKeyName);
+        }
+      }, nextIdx());
+    } catch (WrapperException we) {
+      throw (GeneralSecurityException) we.getCause();
+    }
+  }
+
+  @Override
+  public KeyVersion
+      decryptEncryptedKey(final EncryptedKeyVersion encryptedKeyVersion)
+          throws IOException, GeneralSecurityException {
+    try {
+      return doOp(new ProviderCallable<KeyVersion>() {
+        @Override
+        public KeyVersion call(KMSClientProvider provider)
+            throws IOException, GeneralSecurityException {
+          return provider.decryptEncryptedKey(encryptedKeyVersion);
+        }
+      }, nextIdx());
+    } catch (WrapperException we) {
+      throw (GeneralSecurityException)we.getCause();
+    }
+  }
+
+  @Override
+  public KeyVersion getKeyVersion(final String versionName) throws IOException {
+    return doOp(new ProviderCallable<KeyVersion>() {
+      @Override
+      public KeyVersion call(KMSClientProvider provider) throws IOException {
+        return provider.getKeyVersion(versionName);
+      }
+    }, nextIdx());
+  }
+
+  @Override
+  public List<String> getKeys() throws IOException {
+    return doOp(new ProviderCallable<List<String>>() {
+      @Override
+      public List<String> call(KMSClientProvider provider) throws IOException {
+        return provider.getKeys();
+      }
+    }, nextIdx());
+  }
+
+  @Override
+  public Metadata[] getKeysMetadata(final String... names) throws IOException {
+    return doOp(new ProviderCallable<Metadata[]>() {
+      @Override
+      public Metadata[] call(KMSClientProvider provider) throws IOException {
+        return provider.getKeysMetadata(names);
+      }
+    }, nextIdx());
+  }
+
+  @Override
+  public List<KeyVersion> getKeyVersions(final String name) throws IOException {
+    return doOp(new ProviderCallable<List<KeyVersion>>() {
+      @Override
+      public List<KeyVersion> call(KMSClientProvider provider)
+          throws IOException {
+        return provider.getKeyVersions(name);
+      }
+    }, nextIdx());
+  }
+
+  @Override
+  public KeyVersion getCurrentKey(final String name) throws IOException {
+    return doOp(new ProviderCallable<KeyVersion>() {
+      @Override
+      public KeyVersion call(KMSClientProvider provider) throws IOException {
+        return provider.getCurrentKey(name);
+      }
+    }, nextIdx());
+  }
+  @Override
+  public Metadata getMetadata(final String name) throws IOException {
+    return doOp(new ProviderCallable<Metadata>() {
+      @Override
+      public Metadata call(KMSClientProvider provider) throws IOException {
+        return provider.getMetadata(name);
+      }
+    }, nextIdx());
+  }
+
+  @Override
+  public KeyVersion createKey(final String name, final byte[] material,
+      final Options options) throws IOException {
+    return doOp(new ProviderCallable<KeyVersion>() {
+      @Override
+      public KeyVersion call(KMSClientProvider provider) throws IOException {
+        return provider.createKey(name, material, options);
+      }
+    }, nextIdx());
+  }
+
+  @Override
+  public KeyVersion createKey(final String name, final Options options)
+      throws NoSuchAlgorithmException, IOException {
+    try {
+      return doOp(new ProviderCallable<KeyVersion>() {
+        @Override
+        public KeyVersion call(KMSClientProvider provider) throws IOException,
+            NoSuchAlgorithmException {
+          return provider.createKey(name, options);
+        }
+      }, nextIdx());
+    } catch (WrapperException e) {
+      throw (NoSuchAlgorithmException)e.getCause();
+    }
+  }
+  @Override
+  public void deleteKey(final String name) throws IOException {
+    doOp(new ProviderCallable<Void>() {
+      @Override
+      public Void call(KMSClientProvider provider) throws IOException {
+        provider.deleteKey(name);
+        return null;
+      }
+    }, nextIdx());
+  }
+  @Override
+  public KeyVersion rollNewVersion(final String name, final byte[] material)
+      throws IOException {
+    return doOp(new ProviderCallable<KeyVersion>() {
+      @Override
+      public KeyVersion call(KMSClientProvider provider) throws IOException {
+        return provider.rollNewVersion(name, material);
+      }
+    }, nextIdx());
+  }
+
+  @Override
+  public KeyVersion rollNewVersion(final String name)
+      throws NoSuchAlgorithmException, IOException {
+    try {
+      return doOp(new ProviderCallable<KeyVersion>() {
+        @Override
+        public KeyVersion call(KMSClientProvider provider) throws IOException,
+        NoSuchAlgorithmException {
+          return provider.rollNewVersion(name);
+        }
+      }, nextIdx());
+    } catch (WrapperException e) {
+      throw (NoSuchAlgorithmException)e.getCause();
+    }
+  }
+
+  // Close all providers in the LB group
+  @Override
+  public void close() throws IOException {
+    for (KMSClientProvider provider : providers) {
+      try {
+        provider.close();
+      } catch (IOException ioe) {
+        LOG.error("Error closing provider with url"
+            + "[" + provider.getKMSUrl() + "]");
+      }
+    }
+  }
+
+
+  @Override
+  public void flush() throws IOException {
+    for (KMSClientProvider provider : providers) {
+      try {
+        provider.flush();
+      } catch (IOException ioe) {
+        LOG.error("Error flushing provider with url"
+            + "[" + provider.getKMSUrl() + "]");
+      }
+    }
+  }
+
+  private static KMSClientProvider[] shuffle(KMSClientProvider[] providers) {
+    List<KMSClientProvider> list = Arrays.asList(providers);
+    Collections.shuffle(list);
+    return list.toArray(providers);
+  }
+}

http://git-wip-us.apache.org/repos/asf/hadoop/blob/71385f9b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java
new file mode 100644
index 0000000..08a3d93
--- /dev/null
+++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java
@@ -0,0 +1,166 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.crypto.key.kms;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.net.URI;
+import java.security.NoSuchAlgorithmException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProvider.Options;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import com.google.common.collect.Sets;
+
+public class TestLoadBalancingKMSClientProvider {
+
+  @Test
+  public void testCreation() throws Exception {
+    Configuration conf = new Configuration();
+    KeyProvider kp = new KMSClientProvider.Factory().createProvider(new URI(
+        "kms://http@host1/kms/foo"), conf);
+    assertTrue(kp instanceof KMSClientProvider);
+    assertEquals("http://host1/kms/foo/v1/",
+        ((KMSClientProvider) kp).getKMSUrl());
+
+    kp = new KMSClientProvider.Factory().createProvider(new URI(
+        "kms://http@host1;host2;host3/kms/foo"), conf);
+    assertTrue(kp instanceof LoadBalancingKMSClientProvider);
+    KMSClientProvider[] providers =
+        ((LoadBalancingKMSClientProvider) kp).getProviders();
+    assertEquals(3, providers.length);
+    assertEquals(Sets.newHashSet("http://host1/kms/foo/v1/",
+        "http://host2/kms/foo/v1/",
+        "http://host3/kms/foo/v1/"),
+        Sets.newHashSet(providers[0].getKMSUrl(),
+            providers[1].getKMSUrl(),
+            providers[2].getKMSUrl()));
+
+    kp = new KMSClientProvider.Factory().createProvider(new URI(
+        "kms://http@host1;host2;host3:16000/kms/foo"), conf);
+    assertTrue(kp instanceof LoadBalancingKMSClientProvider);
+    providers =
+        ((LoadBalancingKMSClientProvider) kp).getProviders();
+    assertEquals(3, providers.length);
+    assertEquals(Sets.newHashSet("http://host1:16000/kms/foo/v1/",
+        "http://host2:16000/kms/foo/v1/",
+        "http://host3:16000/kms/foo/v1/"),
+        Sets.newHashSet(providers[0].getKMSUrl(),
+            providers[1].getKMSUrl(),
+            providers[2].getKMSUrl()));
+  }
+
+  @Test
+  public void testLoadBalancing() throws Exception {
+    Configuration conf = new Configuration();
+    KMSClientProvider p1 = mock(KMSClientProvider.class);
+    when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenReturn(
+            new KMSClientProvider.KMSKeyVersion("p1", "v1", new byte[0]));
+    KMSClientProvider p2 = mock(KMSClientProvider.class);
+    when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenReturn(
+            new KMSClientProvider.KMSKeyVersion("p2", "v2", new byte[0]));
+    KMSClientProvider p3 = mock(KMSClientProvider.class);
+    when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenReturn(
+            new KMSClientProvider.KMSKeyVersion("p3", "v3", new byte[0]));
+    KeyProvider kp = new LoadBalancingKMSClientProvider(
+        new KMSClientProvider[] { p1, p2, p3 }, 0, conf);
+    assertEquals("p1", kp.createKey("test1", new Options(conf)).getName());
+    assertEquals("p2", kp.createKey("test2", new Options(conf)).getName());
+    assertEquals("p3", kp.createKey("test3", new Options(conf)).getName());
+    assertEquals("p1", kp.createKey("test4", new Options(conf)).getName());
+  }
+
+  @Test
+  public void testLoadBalancingWithFailure() throws Exception {
+    Configuration conf = new Configuration();
+    KMSClientProvider p1 = mock(KMSClientProvider.class);
+    when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenReturn(
+            new KMSClientProvider.KMSKeyVersion("p1", "v1", new byte[0]));
+    when(p1.getKMSUrl()).thenReturn("p1");
+    // This should not be retried
+    KMSClientProvider p2 = mock(KMSClientProvider.class);
+    when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenThrow(new NoSuchAlgorithmException("p2"));
+    when(p2.getKMSUrl()).thenReturn("p2");
+    KMSClientProvider p3 = mock(KMSClientProvider.class);
+    when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenReturn(
+            new KMSClientProvider.KMSKeyVersion("p3", "v3", new byte[0]));
+    when(p3.getKMSUrl()).thenReturn("p3");
+    // This should be retried
+    KMSClientProvider p4 = mock(KMSClientProvider.class);
+    when(p4.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenThrow(new IOException("p4"));
+    when(p4.getKMSUrl()).thenReturn("p4");
+    KeyProvider kp = new LoadBalancingKMSClientProvider(
+        new KMSClientProvider[] { p1, p2, p3, p4 }, 0, conf);
+
+    assertEquals("p1", kp.createKey("test4", new Options(conf)).getName());
+    // Exceptions other than IOExceptions will not be retried
+    try {
+      kp.createKey("test1", new Options(conf)).getName();
+      fail("Should fail since its not an IOException");
+    } catch (Exception e) {
+      assertTrue(e instanceof NoSuchAlgorithmException);
+    }
+    assertEquals("p3", kp.createKey("test2", new Options(conf)).getName());
+    // IOException will trigger retry in next provider
+    assertEquals("p1", kp.createKey("test3", new Options(conf)).getName());
+  }
+
+  @Test
+  public void testLoadBalancingWithAllBadNodes() throws Exception {
+    Configuration conf = new Configuration();
+    KMSClientProvider p1 = mock(KMSClientProvider.class);
+    when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenThrow(new IOException("p1"));
+    KMSClientProvider p2 = mock(KMSClientProvider.class);
+    when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenThrow(new IOException("p2"));
+    KMSClientProvider p3 = mock(KMSClientProvider.class);
+    when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenThrow(new IOException("p3"));
+    KMSClientProvider p4 = mock(KMSClientProvider.class);
+    when(p4.createKey(Mockito.anyString(), Mockito.any(Options.class)))
+        .thenThrow(new IOException("p4"));
+    when(p1.getKMSUrl()).thenReturn("p1");
+    when(p2.getKMSUrl()).thenReturn("p2");
+    when(p3.getKMSUrl()).thenReturn("p3");
+    when(p4.getKMSUrl()).thenReturn("p4");
+    KeyProvider kp = new LoadBalancingKMSClientProvider(
+        new KMSClientProvider[] { p1, p2, p3, p4 }, 0, conf);
+    try {
+      kp.createKey("test3", new Options(conf)).getName();
+      fail("Should fail since all providers threw an IOException");
+    } catch (Exception e) {
+      assertTrue(e instanceof IOException);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/hadoop/blob/71385f9b/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java b/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java
index 70ba95f..c5a990b 100644
--- a/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java
+++ b/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java
@@ -24,9 +24,11 @@ import org.apache.hadoop.crypto.key.KeyProvider;
 import org.apache.hadoop.crypto.key.KeyProvider.KeyVersion;
 import org.apache.hadoop.crypto.key.KeyProvider.Options;
 import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension;
 import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
 import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension;
 import org.apache.hadoop.crypto.key.kms.KMSClientProvider;
+import org.apache.hadoop.crypto.key.kms.LoadBalancingKMSClientProvider;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.minikdc.MiniKdc;
@@ -99,6 +101,12 @@ public class TestKMS {
     }
   }
 
+  protected KeyProvider createProvider(URI uri, Configuration conf)
+      throws IOException {
+    return new LoadBalancingKMSClientProvider(
+        new KMSClientProvider[] { new KMSClientProvider(uri, conf) }, conf);
+  }
+
   protected <T> T runServer(String keystore, String password, File confDir,
       KMSCallable<T> callable) throws Exception {
     return runServer(-1, keystore, password, confDir, callable);
@@ -305,7 +313,7 @@ public class TestKMS {
         final URI uri = createKMSUri(getKMSUrl());
 
         if (ssl) {
-          KeyProvider testKp = new KMSClientProvider(uri, conf);
+          KeyProvider testKp = createProvider(uri, conf);
           ThreadGroup threadGroup = Thread.currentThread().getThreadGroup();
           while (threadGroup.getParent() != null) {
             threadGroup = threadGroup.getParent();
@@ -335,12 +343,14 @@ public class TestKMS {
             doAs(user, new PrivilegedExceptionAction<Void>() {
               @Override
               public Void run() throws Exception {
-                final KeyProvider kp = new KMSClientProvider(uri, conf);
+                final KeyProvider kp = createProvider(uri, conf);
                 // getKeys() empty
                 Assert.assertTrue(kp.getKeys().isEmpty());
 
                 Thread.sleep(4000);
-                Token<?>[] tokens = ((KMSClientProvider)kp).addDelegationTokens("myuser", new Credentials());
+                Token<?>[] tokens =
+                    ((KeyProviderDelegationTokenExtension.DelegationTokenExtension)kp)
+                    .addDelegationTokens("myuser", new Credentials());
                 Assert.assertEquals(1, tokens.length);
                 Assert.assertEquals("kms-dt", tokens[0].getKind().toString());
                 return null;
@@ -348,12 +358,14 @@ public class TestKMS {
             });
           }
         } else {
-          KeyProvider kp = new KMSClientProvider(uri, conf);
+          KeyProvider kp = createProvider(uri, conf);
           // getKeys() empty
           Assert.assertTrue(kp.getKeys().isEmpty());
 
           Thread.sleep(4000);
-          Token<?>[] tokens = ((KMSClientProvider)kp).addDelegationTokens("myuser", new Credentials());
+          Token<?>[] tokens =
+              ((KeyProviderDelegationTokenExtension.DelegationTokenExtension)kp)
+              .addDelegationTokens("myuser", new Credentials());
           Assert.assertEquals(1, tokens.length);
           Assert.assertEquals("kms-dt", tokens[0].getKind().toString());
         }
@@ -404,7 +416,7 @@ public class TestKMS {
         Date started = new Date();
         Configuration conf = new Configuration();
         URI uri = createKMSUri(getKMSUrl());
-        KeyProvider kp = new KMSClientProvider(uri, conf);
+        KeyProvider kp = createProvider(uri, conf);
 
         // getKeys() empty
         Assert.assertTrue(kp.getKeys().isEmpty());
@@ -687,7 +699,7 @@ public class TestKMS {
         doAs("CREATE", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               Options options = new KeyProvider.Options(conf);
               Map<String, String> attributes = options.getAttributes();
@@ -727,7 +739,7 @@ public class TestKMS {
         doAs("DECRYPT_EEK", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               Options options = new KeyProvider.Options(conf);
               Map<String, String> attributes = options.getAttributes();
@@ -760,7 +772,7 @@ public class TestKMS {
         doAs("ROLLOVER", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               Options options = new KeyProvider.Options(conf);
               Map<String, String> attributes = options.getAttributes();
@@ -804,7 +816,7 @@ public class TestKMS {
         doAs("GET", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               Options options = new KeyProvider.Options(conf);
               Map<String, String> attributes = options.getAttributes();
@@ -836,7 +848,7 @@ public class TestKMS {
         final EncryptedKeyVersion ekv = doAs("GENERATE_EEK", new PrivilegedExceptionAction<EncryptedKeyVersion>() {
           @Override
           public EncryptedKeyVersion run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               Options options = new KeyProvider.Options(conf);
               Map<String, String> attributes = options.getAttributes();
@@ -861,7 +873,7 @@ public class TestKMS {
         doAs("ROLLOVER", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               KeyProviderCryptoExtension kpce =
                   KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp);
@@ -891,7 +903,7 @@ public class TestKMS {
         doAs("GENERATE_EEK", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               KeyProviderCryptoExtension kpce =
                   KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp);
@@ -964,7 +976,7 @@ public class TestKMS {
                 new PrivilegedExceptionAction<KeyProvider>() {
                   @Override
                   public KeyProvider run() throws Exception {
-                    KMSClientProvider kp = new KMSClientProvider(uri, conf);
+                    KeyProvider kp = createProvider(uri, conf);
                         kp.createKey("k1", new byte[16],
                             new KeyProvider.Options(conf));
                     return kp;
@@ -1041,7 +1053,7 @@ public class TestKMS {
                 new PrivilegedExceptionAction<Void>() {
                   @Override
                   public Void run() throws Exception {
-                    KMSClientProvider kp = new KMSClientProvider(uri, conf);
+                    KeyProvider kp = createProvider(uri, conf);
 
                     kp.createKey("k0", new byte[16],
                         new KeyProvider.Options(conf));
@@ -1072,7 +1084,7 @@ public class TestKMS {
                 new PrivilegedExceptionAction<Void>() {
                   @Override
                   public Void run() throws Exception {
-                    KMSClientProvider kp = new KMSClientProvider(uri, conf);
+                    KeyProvider kp = createProvider(uri, conf);
                     kp.createKey("k3", new byte[16],
                         new KeyProvider.Options(conf));
                     // Atleast 2 rollovers.. so should induce signer Exception
@@ -1132,7 +1144,7 @@ public class TestKMS {
         doAs("client", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               kp.createKey("k", new KeyProvider.Options(conf));
               Assert.fail();
@@ -1223,7 +1235,7 @@ public class TestKMS {
         doAs("CREATE", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               KeyProvider.KeyVersion kv = kp.createKey("k0",
                   new KeyProvider.Options(conf));
@@ -1238,7 +1250,7 @@ public class TestKMS {
         doAs("DELETE", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               kp.deleteKey("k0");
             } catch (Exception ex) {
@@ -1251,7 +1263,7 @@ public class TestKMS {
         doAs("SET_KEY_MATERIAL", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               KeyProvider.KeyVersion kv = kp.createKey("k1", new byte[16],
                   new KeyProvider.Options(conf));
@@ -1266,7 +1278,7 @@ public class TestKMS {
         doAs("ROLLOVER", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               KeyProvider.KeyVersion kv = kp.rollNewVersion("k1");
               Assert.assertNull(kv.getMaterial());
@@ -1280,7 +1292,7 @@ public class TestKMS {
         doAs("SET_KEY_MATERIAL", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               KeyProvider.KeyVersion kv =
                   kp.rollNewVersion("k1", new byte[16]);
@@ -1296,7 +1308,7 @@ public class TestKMS {
             doAs("GET", new PrivilegedExceptionAction<KeyVersion>() {
           @Override
           public KeyVersion run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               kp.getKeyVersion("k1@0");
               KeyVersion kv = kp.getCurrentKey("k1");
@@ -1313,7 +1325,7 @@ public class TestKMS {
                 new PrivilegedExceptionAction<EncryptedKeyVersion>() {
           @Override
           public EncryptedKeyVersion run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               KeyProviderCryptoExtension kpCE = KeyProviderCryptoExtension.
                       createKeyProviderCryptoExtension(kp);
@@ -1330,7 +1342,7 @@ public class TestKMS {
         doAs("DECRYPT_EEK", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               KeyProviderCryptoExtension kpCE = KeyProviderCryptoExtension.
                       createKeyProviderCryptoExtension(kp);
@@ -1345,7 +1357,7 @@ public class TestKMS {
         doAs("GET_KEYS", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               kp.getKeys();
             } catch (Exception ex) {
@@ -1358,7 +1370,7 @@ public class TestKMS {
         doAs("GET_METADATA", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             try {
               kp.getMetadata("k1");
               kp.getKeysMetadata("k1");
@@ -1385,7 +1397,7 @@ public class TestKMS {
           @Override
           public Void run() throws Exception {
             try {
-              KeyProvider kp = new KMSClientProvider(uri, conf);
+              KeyProvider kp = createProvider(uri, conf);
               KeyProvider.KeyVersion kv = kp.createKey("k2",
                   new KeyProvider.Options(conf));
               Assert.fail();
@@ -1440,12 +1452,12 @@ public class TestKMS {
           @Override
           public Void run() throws Exception {
             try {
-              KMSClientProvider kp = new KMSClientProvider(uri, conf);
+              KeyProvider kp = createProvider(uri, conf);
               KeyProvider.KeyVersion kv = kp.createKey("ck0",
                   new KeyProvider.Options(conf));
               EncryptedKeyVersion eek =
-                  kp.generateEncryptedKey("ck0");
-              kp.decryptEncryptedKey(eek);
+                  ((CryptoExtension)kp).generateEncryptedKey("ck0");
+              ((CryptoExtension)kp).decryptEncryptedKey(eek);
               Assert.assertNull(kv.getMaterial());
             } catch (Exception ex) {
               Assert.fail(ex.getMessage());
@@ -1458,12 +1470,12 @@ public class TestKMS {
           @Override
           public Void run() throws Exception {
             try {
-              KMSClientProvider kp = new KMSClientProvider(uri, conf);
+              KeyProvider kp = createProvider(uri, conf);
               KeyProvider.KeyVersion kv = kp.createKey("ck1",
                   new KeyProvider.Options(conf));
               EncryptedKeyVersion eek =
-                  kp.generateEncryptedKey("ck1");
-              kp.decryptEncryptedKey(eek);
+                  ((CryptoExtension)kp).generateEncryptedKey("ck1");
+              ((CryptoExtension)kp).decryptEncryptedKey(eek);
               Assert.fail("admin user must not be allowed to decrypt !!");
             } catch (Exception ex) {
             }
@@ -1475,12 +1487,12 @@ public class TestKMS {
           @Override
           public Void run() throws Exception {
             try {
-              KMSClientProvider kp = new KMSClientProvider(uri, conf);
+              KeyProvider kp = createProvider(uri, conf);
               KeyProvider.KeyVersion kv = kp.createKey("ck2",
                   new KeyProvider.Options(conf));
               EncryptedKeyVersion eek =
-                  kp.generateEncryptedKey("ck2");
-              kp.decryptEncryptedKey(eek);
+                  ((CryptoExtension)kp).generateEncryptedKey("ck2");
+              ((CryptoExtension)kp).decryptEncryptedKey(eek);
               Assert.fail("admin user must not be allowed to decrypt !!");
             } catch (Exception ex) {
             }
@@ -1525,7 +1537,7 @@ public class TestKMS {
           @Override
           public Void run() throws Exception {
             try {
-              KeyProvider kp = new KMSClientProvider(uri, conf);
+              KeyProvider kp = createProvider(uri, conf);
               KeyProvider.KeyVersion kv = kp.createKey("ck0",
                   new KeyProvider.Options(conf));
               Assert.assertNull(kv.getMaterial());
@@ -1540,7 +1552,7 @@ public class TestKMS {
           @Override
           public Void run() throws Exception {
             try {
-              KeyProvider kp = new KMSClientProvider(uri, conf);
+              KeyProvider kp = createProvider(uri, conf);
               KeyProvider.KeyVersion kv = kp.createKey("ck1",
                   new KeyProvider.Options(conf));
               Assert.assertNull(kv.getMaterial());
@@ -1583,7 +1595,7 @@ public class TestKMS {
 
     boolean caughtTimeout = false;
     try {
-      KeyProvider kp = new KMSClientProvider(uri, conf);
+      KeyProvider kp = createProvider(uri, conf);
       kp.getKeys();
     } catch (SocketTimeoutException e) {
       caughtTimeout = true;
@@ -1593,7 +1605,7 @@ public class TestKMS {
 
     caughtTimeout = false;
     try {
-      KeyProvider kp = new KMSClientProvider(uri, conf);
+      KeyProvider kp = createProvider(uri, conf);
       KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp)
           .generateEncryptedKey("a");
     } catch (SocketTimeoutException e) {
@@ -1604,7 +1616,7 @@ public class TestKMS {
 
     caughtTimeout = false;
     try {
-      KeyProvider kp = new KMSClientProvider(uri, conf);
+      KeyProvider kp = createProvider(uri, conf);
       KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp)
           .decryptEncryptedKey(
               new KMSClientProvider.KMSEncryptedKeyVersion("a",
@@ -1651,7 +1663,7 @@ public class TestKMS {
             UserGroupInformation.getCurrentUser();
 
         try {
-          KeyProvider kp = new KMSClientProvider(uri, conf);
+          KeyProvider kp = createProvider(uri, conf);
           kp.createKey(keyA, new KeyProvider.Options(conf));
         } catch (IOException ex) {
           System.out.println(ex.getMessage());
@@ -1660,7 +1672,7 @@ public class TestKMS {
         doAs("client", new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             KeyProviderDelegationTokenExtension kpdte =
                 KeyProviderDelegationTokenExtension.
                     createKeyProviderDelegationTokenExtension(kp);
@@ -1672,7 +1684,7 @@ public class TestKMS {
         nonKerberosUgi.addCredentials(credentials);
 
         try {
-          KeyProvider kp = new KMSClientProvider(uri, conf);
+          KeyProvider kp = createProvider(uri, conf);
           kp.createKey(keyA, new KeyProvider.Options(conf));
         } catch (IOException ex) {
           System.out.println(ex.getMessage());
@@ -1681,7 +1693,7 @@ public class TestKMS {
         nonKerberosUgi.doAs(new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            KeyProvider kp = new KMSClientProvider(uri, conf);
+            KeyProvider kp = createProvider(uri, conf);
             kp.createKey(keyD, new KeyProvider.Options(conf));
             return null;
           }
@@ -1767,7 +1779,7 @@ public class TestKMS {
                   new PrivilegedExceptionAction<KeyProvider>() {
                     @Override
                     public KeyProvider run() throws Exception {
-                      KMSClientProvider kp = new KMSClientProvider(uri, conf);
+                      KeyProvider kp = createProvider(uri, conf);
                           kp.createKey("k1", new byte[16],
                               new KeyProvider.Options(conf));
                           kp.createKey("k2", new byte[16],
@@ -1844,7 +1856,7 @@ public class TestKMS {
         clientUgi.doAs(new PrivilegedExceptionAction<Void>() {
           @Override
           public Void run() throws Exception {
-            final KeyProvider kp = new KMSClientProvider(uri, conf);
+            final KeyProvider kp = createProvider(uri, conf);
             kp.createKey("kaa", new KeyProvider.Options(conf));
 
             // authorized proxyuser
@@ -1956,7 +1968,7 @@ public class TestKMS {
             fooUgi.doAs(new PrivilegedExceptionAction<Void>() {
               @Override
               public Void run() throws Exception {
-                KeyProvider kp = new KMSClientProvider(uri, conf);
+                KeyProvider kp = createProvider(uri, conf);
                 Assert.assertNotNull(kp.createKey("kaa",
                     new KeyProvider.Options(conf)));
                 return null;
@@ -1970,7 +1982,7 @@ public class TestKMS {
               @Override
               public Void run() throws Exception {
                 try {
-                  KeyProvider kp = new KMSClientProvider(uri, conf);
+                  KeyProvider kp = createProvider(uri, conf);
                   kp.createKey("kbb", new KeyProvider.Options(conf));
                   Assert.fail();
                 } catch (Exception ex) {
@@ -1986,7 +1998,7 @@ public class TestKMS {
             barUgi.doAs(new PrivilegedExceptionAction<Void>() {
               @Override
               public Void run() throws Exception {
-                KeyProvider kp = new KMSClientProvider(uri, conf);
+                KeyProvider kp = createProvider(uri, conf);
                 Assert.assertNotNull(kp.createKey("kcc",
                     new KeyProvider.Options(conf)));
                 return null;