Posted to commits@orc.apache.org by om...@apache.org on 2019/08/02 19:24:55 UTC

[orc] branch master updated: ORC-529: Allow configuration and table properties to control encryption.

This is an automated email from the ASF dual-hosted git repository.

omalley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/orc.git


The following commit(s) were added to refs/heads/master by this push:
     new 79e77f4  ORC-529: Allow configuration and table properties to control encryption.
79e77f4 is described below

commit 79e77f4f75cb1c6e13d36da38be6a502cc405856
Author: Owen O'Malley <om...@apache.org>
AuthorDate: Tue Jul 23 04:40:47 2019 -0700

    ORC-529: Allow configuration and table properties to control encryption.
    
    Fixes #415
    
    Signed-off-by: Owen O'Malley <om...@apache.org>
---
 .../src/java/org/apache/orc/InMemoryKeystore.java  |   5 +-
 java/core/src/java/org/apache/orc/OrcConf.java     |   5 +-
 java/core/src/java/org/apache/orc/OrcFile.java     |  98 ++---
 .../src/java/org/apache/orc/TypeDescription.java   | 385 +++----------------
 .../src/java/org/apache/orc/impl/CryptoUtils.java  |  52 +++
 .../src/java/org/apache/orc/impl/ParserUtils.java  | 425 +++++++++++++++++++++
 .../src/java/org/apache/orc/impl/WriterImpl.java   |  92 +++--
 .../apache/orc/impl/reader/ReaderEncryption.java   |  11 +-
 .../orc/impl/reader/ReaderEncryptionVariant.java   |   6 +-
 .../org.apache.orc.impl.KeyProvider$Factory}       |   3 +-
 .../test/org/apache/orc/TestTypeDescription.java   |  81 ++++
 .../src/test/org/apache/orc/TestVectorOrcFile.java |  11 +-
 .../test/org/apache/orc/impl/TestCryptoUtils.java  |  45 ++-
 .../org/apache/orc/impl/TestPhysicalFsWriter.java  |   2 +-
 .../org/apache/orc/mapred/OrcOutputFormat.java     |   4 +-
 java/pom.xml                                       |   1 +
 .../src/java/org/apache/orc/impl/HadoopShims.java  |  65 +---
 .../org/apache/orc/impl/HadoopShimsCurrent.java    |   4 +-
 .../org/apache/orc/impl/HadoopShimsPre2_3.java     |   2 +-
 .../org/apache/orc/impl/HadoopShimsPre2_6.java     |   2 +-
 .../org/apache/orc/impl/HadoopShimsPre2_7.java     |   4 +-
 .../src/java/org/apache/orc/impl/KeyProvider.java  |  84 ++++
 .../org/apache/orc/impl/TestHadoopShimsPre2_7.java | 140 -------
 .../src/java/org/apache/orc/tools/KeyTool.java     |   9 +-
 .../test/org/apache/orc/impl/FakeKeyProvider.java  | 142 +++++++
 .../org/apache/orc/impl/TestHadoopKeyProvider.java |  62 +++
 ...org.apache.hadoop.crypto.key.KeyProviderFactory |   2 +-
 27 files changed, 1066 insertions(+), 676 deletions(-)
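
The net effect of the patch: the writer-side EncryptionOption builder API is
replaced by two compact strings that can come from code, from configuration
("orc.encrypt"/"orc.mask"), or from table properties. A minimal writer-side
sketch of the new API (the path, schema, and the "pii" key name are
illustrative assumptions; the key must exist in whatever KeyProvider is
configured, or createWriter will fail):

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.orc.OrcFile;
    import org.apache.orc.TypeDescription;
    import org.apache.orc.Writer;

    public class EncryptedWriteSketch {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        TypeDescription schema =
            TypeDescription.fromString("struct<name:string,ssn:string>");
        Writer writer = OrcFile.createWriter(new Path("/tmp/example.orc"),
            OrcFile.writerOptions(conf)
                .setSchema(schema)
                .encrypt("pii:ssn")     // encrypt column ssn with key pii
                .masks("redact:ssn"));  // readers without the key see redacted data
        writer.close();
      }
    }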

diff --git a/java/core/src/java/org/apache/orc/InMemoryKeystore.java b/java/core/src/java/org/apache/orc/InMemoryKeystore.java
index 735962f..78e6a85 100644
--- a/java/core/src/java/org/apache/orc/InMemoryKeystore.java
+++ b/java/core/src/java/org/apache/orc/InMemoryKeystore.java
@@ -18,6 +18,7 @@
 package org.apache.orc;
 
 import org.apache.orc.impl.HadoopShims;
+import org.apache.orc.impl.KeyProvider;
 import org.apache.orc.impl.LocalKey;
 
 import javax.crypto.BadPaddingException;
@@ -40,7 +41,7 @@ import java.util.Random;
 import java.util.TreeMap;
 
 /**
- * This is an in-memory implementation of {@link HadoopShims.KeyProvider}.
+ * This is an in-memory implementation of {@link KeyProvider}.
  *
  * The primary use of this class is for when the user doesn't have a
  * Hadoop KMS running and wishes to use encryption. It is also useful for
@@ -52,7 +53,7 @@ import java.util.TreeMap;
  *
  * This class is not thread safe.
  */
-public class InMemoryKeystore implements HadoopShims.KeyProvider {
+public class InMemoryKeystore implements KeyProvider {
   /**
    * Support AES 256 ?
    */
diff --git a/java/core/src/java/org/apache/orc/OrcConf.java b/java/core/src/java/org/apache/orc/OrcConf.java
index a6fbad1..7cca1db 100644
--- a/java/core/src/java/org/apache/orc/OrcConf.java
+++ b/java/core/src/java/org/apache/orc/OrcConf.java
@@ -162,7 +162,10 @@ public enum OrcConf {
       "Comma-separated list of columns for which dictionary encoding is to be skipped."),
   // some JVM doesn't allow array creation of size Integer.MAX_VALUE, so chunk size is slightly less than max int
   ORC_MAX_DISK_RANGE_CHUNK_LIMIT("orc.max.disk.range.chunk.limit", "hive.exec.orc.max.disk.range.chunk.limit",
-    Integer.MAX_VALUE - 1024, "When reading stripes >2GB, specify max limit for the chunk size.")
+    Integer.MAX_VALUE - 1024, "When reading stripes >2GB, specify max limit for the chunk size."),
+  ENCRYPTION("orc.encrypt", "orc.encrypt", null, "The list of keys and columns to encrypt with"),
+  DATA_MASK("orc.mask", "orc.mask", null, "The masks to apply to the encrypted columns"),
+  KEY_PROVIDER("orc.key.provider", "orc.key.provider", "hadoop", "The kind of KeyProvider to use for encryption.")
   ;
 
   private final String attribute;
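
Because each new option also has a Hive-style name, the same control is
available purely through configuration or table properties, with no writer
code changes. A sketch using the property names defined above (the key and
column names are illustrative):

    import org.apache.hadoop.conf.Configuration;

    public class EncryptionConfSketch {
      static Configuration encryptionConf() {
        Configuration conf = new Configuration();
        // keys and the column lists they protect, separated by ';'
        conf.set("orc.encrypt", "pii:name,address.street;credit:credit_cards");
        // masks for the unencrypted copies of those columns
        conf.set("orc.mask", "nullify:name;redact,Yy:credit_cards");
        // which KeyProvider kind to load; "hadoop" is the default
        conf.set("orc.key.provider", "memory");
        return conf;
      }
    }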
diff --git a/java/core/src/java/org/apache/orc/OrcFile.java b/java/core/src/java/org/apache/orc/OrcFile.java
index 4eb2c83..23c4d0f 100644
--- a/java/core/src/java/org/apache/orc/OrcFile.java
+++ b/java/core/src/java/org/apache/orc/OrcFile.java
@@ -32,6 +32,7 @@ import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.orc.impl.HadoopShims;
 import org.apache.orc.impl.HadoopShimsFactory;
+import org.apache.orc.impl.KeyProvider;
 import org.apache.orc.impl.MemoryManagerImpl;
 import org.apache.orc.impl.OrcTail;
 import org.apache.orc.impl.ReaderImpl;
@@ -275,7 +276,7 @@ public class OrcFile {
     private FileSystem filesystem;
     private long maxLength = Long.MAX_VALUE;
     private OrcTail orcTail;
-    private HadoopShims.KeyProvider keyProvider;
+    private KeyProvider keyProvider;
     // TODO: We can generalize the FileMetadata interface. Make OrcTail implement FileMetadata interface
     // and remove this class altogether. Both footer caching and llap caching just needs OrcTail.
     // For now keeping this around to avoid complex surgery
@@ -306,7 +307,7 @@ public class OrcFile {
      * @param provider
      * @return
      */
-    public ReaderOptions setKeyProvider(HadoopShims.KeyProvider provider) {
+    public ReaderOptions setKeyProvider(KeyProvider provider) {
       this.keyProvider = provider;
       return this;
     }
@@ -327,7 +328,7 @@ public class OrcFile {
       return orcTail;
     }
 
-    public HadoopShims.KeyProvider getKeyProvider() {
+    public KeyProvider getKeyProvider() {
       return keyProvider;
     }
 
@@ -397,40 +398,6 @@ public class OrcFile {
   }
 
   /**
-   * An internal class that describes how to encrypt a column.
-   */
-  public static class EncryptionOption {
-    private final String columnNames;
-    private final String keyName;
-    private final String mask;
-    private final String[] maskParameters;
-
-    EncryptionOption(String columnNames, String keyName, String mask,
-                     String... maskParams) {
-      this.columnNames = columnNames;
-      this.keyName = keyName;
-      this.mask = mask;
-      this.maskParameters = maskParams;
-    }
-
-    public String getColumnNames() {
-      return columnNames;
-    }
-
-    public String getKeyName() {
-      return keyName;
-    }
-
-    public String getMask() {
-      return mask;
-    }
-
-    public String[] getMaskParameters() {
-      return maskParameters;
-    }
-  }
-
-  /**
    * Options for creating ORC file writers.
    */
   public static class WriterOptions implements Cloneable {
@@ -460,8 +427,9 @@ public class OrcFile {
     private boolean writeVariableLengthBlocks;
     private HadoopShims shims;
     private String directEncodingColumns;
-    private List<EncryptionOption> encryption = new ArrayList<>();
-    private HadoopShims.KeyProvider provider;
+    private String encryption;
+    private String masks;
+    private KeyProvider provider;
 
     protected WriterOptions(Properties tableProperties, Configuration conf) {
       configuration = conf;
@@ -757,50 +725,24 @@ public class OrcFile {
       return this;
     }
 
-    /*
-     * Encrypt a set of columns with a key.
-     * For readers without access to the key, they will read nulls.
-     * @param columnNames the columns to encrypt
-     * @param keyName the key name to encrypt the data with
-     * @return this
-     */
-    public WriterOptions encryptColumn(String columnNames,
-                                       String keyName) {
-      return encryptColumn(columnNames, keyName,
-          DataMask.Standard.NULLIFY.getName());
-    }
-
     /**
      * Encrypt a set of columns with a key.
-     * The data is also masked and stored unencrypted in the file. Readers
-     * without access to the key will instead get the masked data.
-     * @param columnNames the column names to encrypt
-     * @param keyName the key name to encrypt the data with
-     * @param mask the kind of masking
-     * @param maskParameters the parameters to the mask
+     * Readers without access to the key will read nulls instead.
+     * @param value a key-list mapping each key name to the columns it encrypts
      * @return this
      */
-    public WriterOptions encryptColumn(String columnNames,
-                                       String keyName,
-                                       String mask,
-                                       String... maskParameters) {
-      encryption.add(new EncryptionOption(columnNames, keyName, mask,
-          maskParameters));
+    public WriterOptions encrypt(String value) {
+      encryption = value;
       return this;
     }
 
     /**
-     * Set a different mask on a subtree that is already being encrypted.
-     * @param columnNames the column names to change the mask on
-     * @param mask the name of the mask
-     * @param maskParameters the parameters for the mask
+     * Set the masks for the unencrypted copies of the encrypted columns.
+     * @param value a mask-list mapping each mask to the columns it applies to
      * @return this
      */
-    public WriterOptions maskColumn(String columnNames,
-                                    String mask,
-                                    String... maskParameters) {
-      encryption.add(new EncryptionOption(columnNames, null,
-          mask, maskParameters));
+    public WriterOptions masks(String value) {
+      masks = value;
       return this;
     }
 
@@ -809,12 +751,12 @@ public class OrcFile {
      * @param provider
      * @return
      */
-    public WriterOptions setKeyProvider(HadoopShims.KeyProvider provider) {
+    public WriterOptions setKeyProvider(KeyProvider provider) {
       this.provider = provider;
       return this;
     }
 
-    public HadoopShims.KeyProvider getKeyProvider() {
+    public KeyProvider getKeyProvider() {
       return provider;
     }
 
@@ -922,9 +864,13 @@ public class OrcFile {
       return directEncodingColumns;
     }
 
-    public List<EncryptionOption> getEncryption() {
+    public String getEncryption() {
       return encryption;
     }
+
+    public String getMasks() {
+      return masks;
+    }
   }
 
   /**
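
On the read side the only visible change is the type of the pluggable
provider. A sketch of supplying one explicitly, using the in-memory keystore
as a stand-in for whatever provider holds the file's keys:

    import java.io.IOException;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.orc.InMemoryKeystore;
    import org.apache.orc.OrcFile;
    import org.apache.orc.Reader;

    public class EncryptedReadSketch {
      static Reader open(Path path, Configuration conf,
                         InMemoryKeystore keystore) throws IOException {
        return OrcFile.createReader(path,
            OrcFile.readerOptions(conf).setKeyProvider(keystore));
      }
    }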
diff --git a/java/core/src/java/org/apache/orc/TypeDescription.java b/java/core/src/java/org/apache/orc/TypeDescription.java
index 4f32cc5..f36b134 100644
--- a/java/core/src/java/org/apache/orc/TypeDescription.java
+++ b/java/core/src/java/org/apache/orc/TypeDescription.java
@@ -30,6 +30,7 @@ import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.TimestampColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.UnionColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.orc.impl.ParserUtils;
 import org.apache.orc.impl.SchemaEvolution;
 import org.jetbrains.annotations.NotNull;
 
@@ -56,6 +57,10 @@ public class TypeDescription
   private static final int DEFAULT_LENGTH = 256;
   static final Pattern UNQUOTED_NAMES = Pattern.compile("^[a-zA-Z0-9_]+$");
 
+  // type attributes
+  public static final String ENCRYPT_ATTRIBUTE = "encrypt";
+  public static final String MASK_ATTRIBUTE = "mask";
+
   @Override
   public int compareTo(TypeDescription other) {
     if (this == other) {
@@ -193,231 +198,6 @@ public class TypeDescription
     return new TypeDescription(Category.DECIMAL);
   }
 
-  static class StringPosition {
-    final String value;
-    int position;
-    final int length;
-
-    StringPosition(String value) {
-      this.value = value;
-      position = 0;
-      length = value.length();
-    }
-
-    @Override
-    public String toString() {
-      StringBuilder buffer = new StringBuilder();
-      buffer.append('\'');
-      buffer.append(value.substring(0, position));
-      buffer.append('^');
-      buffer.append(value.substring(position));
-      buffer.append('\'');
-      return buffer.toString();
-    }
-  }
-
-  static Category parseCategory(StringPosition source) {
-    StringBuilder word = new StringBuilder();
-    boolean hadSpace = true;
-    while (source.position < source.length) {
-      char ch = source.value.charAt(source.position);
-      if (Character.isLetter(ch)) {
-        word.append(Character.toLowerCase(ch));
-        hadSpace = false;
-      } else if (ch == ' ') {
-        if (!hadSpace) {
-          hadSpace = true;
-          word.append(ch);
-        }
-      } else {
-        break;
-      }
-      source.position += 1;
-    }
-    String catString = word.toString();
-    // if there were trailing spaces, remove them.
-    if (hadSpace) {
-      catString = catString.trim();
-    }
-    if (!catString.isEmpty()) {
-      for (Category cat : Category.values()) {
-        if (cat.getName().equals(catString)) {
-          return cat;
-        }
-      }
-    }
-    throw new IllegalArgumentException("Can't parse category at " + source);
-  }
-
-  static int parseInt(StringPosition source) {
-    int start = source.position;
-    int result = 0;
-    while (source.position < source.length) {
-      char ch = source.value.charAt(source.position);
-      if (!Character.isDigit(ch)) {
-        break;
-      }
-      result = result * 10 + (ch - '0');
-      source.position += 1;
-    }
-    if (source.position == start) {
-      throw new IllegalArgumentException("Missing integer at " + source);
-    }
-    return result;
-  }
-
-  static String parseName(StringPosition source) {
-    if (source.position == source.length) {
-      throw new IllegalArgumentException("Missing name at " + source);
-    }
-    final int start = source.position;
-    if (source.value.charAt(source.position) == '`') {
-      source.position += 1;
-      StringBuilder buffer = new StringBuilder();
-      boolean closed = false;
-      while (source.position < source.length) {
-        char ch = source.value.charAt(source.position);
-        source.position += 1;
-        if (ch == '`') {
-          if (source.position < source.length &&
-              source.value.charAt(source.position) == '`') {
-            source.position += 1;
-            buffer.append('`');
-          } else {
-            closed = true;
-            break;
-          }
-        } else {
-          buffer.append(ch);
-        }
-      }
-      if (!closed) {
-        source.position = start;
-        throw new IllegalArgumentException("Unmatched quote at " + source);
-      } else if (buffer.length() == 0) {
-        throw new IllegalArgumentException("Empty quoted field name at " + source);
-      }
-      return buffer.toString();
-    } else {
-      while (source.position < source.length) {
-        char ch = source.value.charAt(source.position);
-        if (!Character.isLetterOrDigit(ch) && ch != '_') {
-          break;
-        }
-        source.position += 1;
-      }
-      if (source.position == start) {
-        throw new IllegalArgumentException("Missing name at " + source);
-      }
-      return source.value.substring(start, source.position);
-    }
-  }
-
-  static void requireChar(StringPosition source, char required) {
-    if (source.position >= source.length ||
-        source.value.charAt(source.position) != required) {
-      throw new IllegalArgumentException("Missing required char '" +
-          required + "' at " + source);
-    }
-    source.position += 1;
-  }
-
-  static boolean consumeChar(StringPosition source, char ch) {
-    boolean result = source.position < source.length &&
-        source.value.charAt(source.position) == ch;
-    if (result) {
-      source.position += 1;
-    }
-    return result;
-  }
-
-  static void parseUnion(TypeDescription type, StringPosition source) {
-    requireChar(source, '<');
-    do {
-      type.addUnionChild(parseType(source));
-    } while (consumeChar(source, ','));
-    requireChar(source, '>');
-  }
-
-  static void parseStruct(TypeDescription type, StringPosition source) {
-    requireChar(source, '<');
-    boolean needComma = false;
-    while (!consumeChar(source, '>')) {
-      if (needComma) {
-        requireChar(source, ',');
-      } else {
-        needComma = true;
-      }
-      String fieldName = parseName(source);
-      requireChar(source, ':');
-      type.addField(fieldName, parseType(source));
-    }
-  }
-
-  static TypeDescription parseType(StringPosition source) {
-    TypeDescription result = new TypeDescription(parseCategory(source));
-    switch (result.getCategory()) {
-      case BINARY:
-      case BOOLEAN:
-      case BYTE:
-      case DATE:
-      case DOUBLE:
-      case FLOAT:
-      case INT:
-      case LONG:
-      case SHORT:
-      case STRING:
-      case TIMESTAMP:
-      case TIMESTAMP_INSTANT:
-        break;
-      case CHAR:
-      case VARCHAR:
-        requireChar(source, '(');
-        result.withMaxLength(parseInt(source));
-        requireChar(source, ')');
-        break;
-      case DECIMAL: {
-        requireChar(source, '(');
-        int precision = parseInt(source);
-        requireChar(source, ',');
-        result.withScale(parseInt(source));
-        result.withPrecision(precision);
-        requireChar(source, ')');
-        break;
-      }
-      case LIST: {
-        requireChar(source, '<');
-        TypeDescription child = parseType(source);
-        result.children.add(child);
-        child.parent = result;
-        requireChar(source, '>');
-        break;
-      }
-      case MAP: {
-        requireChar(source, '<');
-        TypeDescription keyType = parseType(source);
-        result.children.add(keyType);
-        keyType.parent = result;
-        requireChar(source, ',');
-        TypeDescription valueType = parseType(source);
-        result.children.add(valueType);
-        valueType.parent = result;
-        requireChar(source, '>');
-        break;
-      }
-      case UNION:
-        parseUnion(result, source);
-        break;
-      case STRUCT:
-        parseStruct(result, source);
-        break;
-      default:
-        throw new IllegalArgumentException("Unknown type " +
-            result.getCategory() + " at " + source);
-    }
-    return result;
-  }
-
   /**
    * Parse TypeDescription from the Hive type names. This is the inverse
    * of TypeDescription.toString()
@@ -429,9 +209,9 @@ public class TypeDescription
     if (typeName == null) {
       return null;
     }
-    StringPosition source = new StringPosition(typeName);
-    TypeDescription result = parseType(source);
-    if (source.position != source.length) {
+    ParserUtils.StringPosition source = new ParserUtils.StringPosition(typeName);
+    TypeDescription result = ParserUtils.parseType(source);
+    if (source.hasCharactersLeft()) {
       throw new IllegalArgumentException("Extra characters at " + source);
     }
     return result;
@@ -473,7 +253,7 @@ public class TypeDescription
   /**
    * Set an attribute on this type.
    * @param key the attribute name
-   * @param value the attribute value
+   * @param value the attribute value or null to clear the value
    * @return this for method chaining
    */
   public TypeDescription setAttribute(@NotNull String key,
@@ -549,8 +329,7 @@ public class TypeDescription
       throw new IllegalArgumentException("Can only add types to union type" +
           " and not " + category);
     }
-    children.add(child);
-    child.parent = this;
+    addChild(child);
     return this;
   }
 
@@ -566,8 +345,7 @@ public class TypeDescription
           " and not " + category);
     }
     fieldNames.add(field);
-    children.add(fieldType);
-    fieldType.parent = this;
+    addChild(fieldType);
     return this;
   }
 
@@ -876,6 +654,30 @@ public class TypeDescription
     return startId;
   }
 
+  /**
+   * Add a child to a type.
+   * @param child the child to add
+   */
+  public void addChild(TypeDescription child) {
+    switch (category) {
+      case LIST:
+        if (children.size() >= 1) {
+          throw new IllegalArgumentException("Can't add more children to list");
+        }
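+        // intentional fall through: the LIST check above is stricter than MAP's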
+      case MAP:
+        if (children.size() >= 2) {
+          throw new IllegalArgumentException("Can't add more children to map");
+        }
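+        // intentional fall through: add the child below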
+      case UNION:
+      case STRUCT:
+        children.add(child);
+        child.parent = this;
+        break;
+      default:
+        throw new IllegalArgumentException("Can't add children to " + category);
+    }
+  }
+
   public TypeDescription(Category category) {
     this.category = category;
     if (category.isPrimitive) {
@@ -1050,79 +852,6 @@ public class TypeDescription
   }
 
   /**
-   * Split a compound name into parts separated by '.'.
-   * @param source the string to parse into simple names
-   * @return a list of simple names from the source
-   */
-  private static List<String> splitName(StringPosition source) {
-    List<String> result = new ArrayList<>();
-    do {
-      result.add(parseName(source));
-    } while (consumeChar(source, '.'));
-    return result;
-  }
-
-  private static final Pattern INTEGER_PATTERN = Pattern.compile("^[0-9]+$");
-
-  private TypeDescription findSubtype(StringPosition source) {
-    List<String> names = splitName(source);
-    if (names.size() == 1 && INTEGER_PATTERN.matcher(names.get(0)).matches()) {
-      return findSubtype(Integer.parseInt(names.get(0)));
-    }
-    TypeDescription current = SchemaEvolution.checkAcidSchema(this)
-        ? SchemaEvolution.getBaseRow(this) : this;
-    while (names.size() > 0) {
-      String first = names.remove(0);
-      switch (current.category) {
-        case STRUCT: {
-          int posn = current.fieldNames.indexOf(first);
-          if (posn == -1) {
-            throw new IllegalArgumentException("Field " + first +
-                " not found in " + current.toString());
-          }
-          current = current.children.get(posn);
-          break;
-        }
-        case LIST:
-          if (first.equals("_elem")) {
-            current = current.getChildren().get(0);
-          } else {
-            throw new IllegalArgumentException("Field " + first +
-                "not found in " + current.toString());
-          }
-          break;
-        case MAP:
-          if (first.equals("_key")) {
-            current = current.getChildren().get(0);
-          } else if (first.equals("_value")) {
-            current = current.getChildren().get(1);
-          } else {
-            throw new IllegalArgumentException("Field " + first +
-                "not found in " + current.toString());
-          }
-          break;
-        case UNION: {
-          try {
-            int posn = Integer.parseInt(first);
-            if (posn < 0 || posn >= current.getChildren().size()) {
-              throw new NumberFormatException("off end of union");
-            }
-            current = current.getChildren().get(posn);
-          } catch (NumberFormatException e) {
-            throw new IllegalArgumentException("Field " + first +
-                "not found in " + current.toString(), e);
-          }
-          break;
-        }
-        default:
-          throw new IllegalArgumentException("Field " + first +
-              "not found in " + current.toString());
-      }
-    }
-    return current;
-  }
-
-  /**
    * Find a subtype of this schema by name.
    * If the name is a simple integer, it will be used as a column number.
    * Otherwise, this routine will recursively search for the name.
@@ -1137,9 +866,9 @@ public class TypeDescription
    * @return the subtype
    */
   public TypeDescription findSubtype(String columnName) {
-    StringPosition source = new StringPosition(columnName);
-    TypeDescription result = findSubtype(source);
-    if (source.position != source.length) {
+    ParserUtils.StringPosition source = new ParserUtils.StringPosition(columnName);
+    TypeDescription result = ParserUtils.findSubtype(this, source);
+    if (source.hasCharactersLeft()) {
       throw new IllegalArgumentException("Remaining text in parsing field name "
           + source);
     }
@@ -1154,20 +883,32 @@ public class TypeDescription
    * @return the list of subtypes that correspond to the column names
    */
   public List<TypeDescription> findSubtypes(String columnNameList) {
-    StringPosition source = new StringPosition(columnNameList);
-    List<TypeDescription> result = new ArrayList<>();
-    boolean needComma = false;
-    while (source.position != source.length) {
-      if (needComma) {
-        if (!consumeChar(source, ',')) {
-          throw new IllegalArgumentException("Comma expected in list of column"
-              + " names at " + source);
-        }
-      } else {
-        needComma = true;
-      }
-      result.add(findSubtype(source));
+    ParserUtils.StringPosition source = new ParserUtils.StringPosition(columnNameList);
+    List<TypeDescription> result = ParserUtils.findSubtypeList(this, source);
+    if (source.hasCharactersLeft()) {
+      throw new IllegalArgumentException("Remaining text in parsing field name "
+          + source);
     }
     return result;
   }
+
+  /**
+   * Annotate a schema with the encryption keys and masks.
+   * @param encryption the encryption keys and the fields
+   * @param masks the encryption masks and the fields
+   */
+  public void annotateEncryption(String encryption, String masks) {
+    ParserUtils.StringPosition source = new ParserUtils.StringPosition(encryption);
+    ParserUtils.parseKeys(source, this);
+    if (source.hasCharactersLeft()) {
+      throw new IllegalArgumentException("Remaining text in parsing encryption keys "
+          + source);
+    }
+    source = new ParserUtils.StringPosition(masks);
+    ParserUtils.parseMasks(source, this);
+    if (source.hasCharactersLeft()) {
+      throw new IllegalArgumentException("Remaining text in parsing encryption masks "
+          + source);
+    }
+  }
 }
diff --git a/java/core/src/java/org/apache/orc/impl/CryptoUtils.java b/java/core/src/java/org/apache/orc/impl/CryptoUtils.java
index 7d88796..eb4eace 100644
--- a/java/core/src/java/org/apache/orc/impl/CryptoUtils.java
+++ b/java/core/src/java/org/apache/orc/impl/CryptoUtils.java
@@ -18,8 +18,16 @@
 
 package org.apache.orc.impl;
 
+import org.apache.hadoop.conf.Configuration;
+import org.apache.orc.InMemoryKeystore;
+import org.apache.orc.OrcConf;
 import org.apache.orc.OrcProto;
 
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+import java.util.ServiceLoader;
 import java.util.function.Consumer;
 
 /**
@@ -110,4 +118,48 @@ public class CryptoUtils {
       iv[i] = 0;
     }
   }
+
+  /** A cache for the key providers */
+  private static final Map<String, KeyProvider> keyProviderCache = new HashMap<>();
+
+  /**
+   * Create a KeyProvider.
+   * It will cache the result, so that only one provider of each kind will be
+   * created.
+   *
+   * @param random the random generator to use
+   * @return the new KeyProvider
+   */
+  public static KeyProvider getKeyProvider(Configuration conf,
+                                           Random random) throws IOException {
+    String kind = OrcConf.KEY_PROVIDER.getString(conf);
+    String cacheKey = kind + "." + random.getClass().getName();
+    KeyProvider result = keyProviderCache.get(cacheKey);
+    if (result == null) {
+      ServiceLoader<KeyProvider.Factory> loader = ServiceLoader.load(KeyProvider.Factory.class);
+      for (KeyProvider.Factory factory : loader) {
+        result = factory.create(kind, conf, random);
+        if (result != null) {
+          keyProviderCache.put(cacheKey, result);
+          break;
+        }
+      }
+    }
+    return result;
+  }
+
+  public static class HadoopKeyProviderFactory implements KeyProvider.Factory {
+
+    @Override
+    public KeyProvider create(String kind,
+                              Configuration conf,
+                              Random random) throws IOException {
+      if ("hadoop".equals(kind)) {
+        return HadoopShimsFactory.get().getHadoopKeyProvider(conf, random);
+      } else if ("memory".equals(kind)) {
+        return new InMemoryKeystore(random);
+      }
+      return null;
+    }
+  }
 }
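
getKeyProvider resolves the configured kind through ServiceLoader, so other
projects can contribute factories without modifying ORC. A sketch of
obtaining the built-in in-memory provider, which is handy in tests where no
Hadoop KMS is running:

    import java.io.IOException;
    import java.security.SecureRandom;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.orc.impl.CryptoUtils;
    import org.apache.orc.impl.KeyProvider;

    public class ProviderSketch {
      static KeyProvider memoryProvider() throws IOException {
        Configuration conf = new Configuration();
        // handled by CryptoUtils.HadoopKeyProviderFactory above
        conf.set("orc.key.provider", "memory");
        return CryptoUtils.getKeyProvider(conf, new SecureRandom());
      }
    }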
diff --git a/java/core/src/java/org/apache/orc/impl/ParserUtils.java b/java/core/src/java/org/apache/orc/impl/ParserUtils.java
new file mode 100644
index 0000000..a6d227b
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/ParserUtils.java
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl;
+
+import org.apache.orc.TypeDescription;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.regex.Pattern;
+
+public class ParserUtils {
+
+  static TypeDescription.Category parseCategory(ParserUtils.StringPosition source) {
+    StringBuilder word = new StringBuilder();
+    boolean hadSpace = true;
+    while (source.position < source.length) {
+      char ch = source.value.charAt(source.position);
+      if (Character.isLetter(ch)) {
+        word.append(Character.toLowerCase(ch));
+        hadSpace = false;
+      } else if (ch == ' ') {
+        if (!hadSpace) {
+          hadSpace = true;
+          word.append(ch);
+        }
+      } else {
+        break;
+      }
+      source.position += 1;
+    }
+    String catString = word.toString();
+    // if there were trailing spaces, remove them.
+    if (hadSpace) {
+      catString = catString.trim();
+    }
+    if (!catString.isEmpty()) {
+      for (TypeDescription.Category cat : TypeDescription.Category.values()) {
+        if (cat.getName().equals(catString)) {
+          return cat;
+        }
+      }
+    }
+    throw new IllegalArgumentException("Can't parse category at " + source);
+  }
+
+  static int parseInt(ParserUtils.StringPosition source) {
+    int start = source.position;
+    int result = 0;
+    while (source.position < source.length) {
+      char ch = source.value.charAt(source.position);
+      if (!Character.isDigit(ch)) {
+        break;
+      }
+      result = result * 10 + (ch - '0');
+      source.position += 1;
+    }
+    if (source.position == start) {
+      throw new IllegalArgumentException("Missing integer at " + source);
+    }
+    return result;
+  }
+
+  static String parseName(ParserUtils.StringPosition source) {
+    if (source.position == source.length) {
+      throw new IllegalArgumentException("Missing name at " + source);
+    }
+    final int start = source.position;
+    if (source.value.charAt(source.position) == '`') {
+      source.position += 1;
+      StringBuilder buffer = new StringBuilder();
+      boolean closed = false;
+      while (source.position < source.length) {
+        char ch = source.value.charAt(source.position);
+        source.position += 1;
+        if (ch == '`') {
+          if (source.position < source.length &&
+                  source.value.charAt(source.position) == '`') {
+            source.position += 1;
+            buffer.append('`');
+          } else {
+            closed = true;
+            break;
+          }
+        } else {
+          buffer.append(ch);
+        }
+      }
+      if (!closed) {
+        source.position = start;
+        throw new IllegalArgumentException("Unmatched quote at " + source);
+      } else if (buffer.length() == 0) {
+        throw new IllegalArgumentException("Empty quoted field name at " + source);
+      }
+      return buffer.toString();
+    } else {
+      while (source.position < source.length) {
+        char ch = source.value.charAt(source.position);
+        if (!Character.isLetterOrDigit(ch) && ch != '_') {
+          break;
+        }
+        source.position += 1;
+      }
+      if (source.position == start) {
+        throw new IllegalArgumentException("Missing name at " + source);
+      }
+      return source.value.substring(start, source.position);
+    }
+  }
+
+  static void requireChar(ParserUtils.StringPosition source, char required) {
+    if (source.position >= source.length ||
+            source.value.charAt(source.position) != required) {
+      throw new IllegalArgumentException("Missing required char '" +
+              required + "' at " + source);
+    }
+    source.position += 1;
+  }
+
+  private static boolean consumeChar(ParserUtils.StringPosition source,
+                                     char ch) {
+    boolean result = source.position < source.length &&
+            source.value.charAt(source.position) == ch;
+    if (result) {
+      source.position += 1;
+    }
+    return result;
+  }
+
+  private static void parseUnion(TypeDescription type,
+                                 ParserUtils.StringPosition source) {
+    requireChar(source, '<');
+    do {
+      type.addUnionChild(parseType(source));
+    } while (consumeChar(source, ','));
+    requireChar(source, '>');
+  }
+
+  private static void parseStruct(TypeDescription type,
+                                  ParserUtils.StringPosition source) {
+    requireChar(source, '<');
+    boolean needComma = false;
+    while (!consumeChar(source, '>')) {
+      if (needComma) {
+        requireChar(source, ',');
+      } else {
+        needComma = true;
+      }
+      String fieldName = parseName(source);
+      requireChar(source, ':');
+      type.addField(fieldName, parseType(source));
+    }
+  }
+
+  public static TypeDescription parseType(ParserUtils.StringPosition source) {
+    TypeDescription result = new TypeDescription(parseCategory(source));
+    switch (result.getCategory()) {
+      case BINARY:
+      case BOOLEAN:
+      case BYTE:
+      case DATE:
+      case DOUBLE:
+      case FLOAT:
+      case INT:
+      case LONG:
+      case SHORT:
+      case STRING:
+      case TIMESTAMP:
+      case TIMESTAMP_INSTANT:
+        break;
+      case CHAR:
+      case VARCHAR:
+        requireChar(source, '(');
+        result.withMaxLength(parseInt(source));
+        requireChar(source, ')');
+        break;
+      case DECIMAL: {
+        requireChar(source, '(');
+        int precision = parseInt(source);
+        requireChar(source, ',');
+        result.withScale(parseInt(source));
+        result.withPrecision(precision);
+        requireChar(source, ')');
+        break;
+      }
+      case LIST: {
+        requireChar(source, '<');
+        TypeDescription child = parseType(source);
+        result.addChild(child);
+        requireChar(source, '>');
+        break;
+      }
+      case MAP: {
+        requireChar(source, '<');
+        TypeDescription keyType = parseType(source);
+        result.addChild(keyType);
+        requireChar(source, ',');
+        TypeDescription valueType = parseType(source);
+        result.addChild(valueType);
+        requireChar(source, '>');
+        break;
+      }
+      case UNION:
+        parseUnion(result, source);
+        break;
+      case STRUCT:
+        parseStruct(result, source);
+        break;
+      default:
+        throw new IllegalArgumentException("Unknown type " +
+            result.getCategory() + " at " + source);
+    }
+    return result;
+  }
+
+  /**
+   * Split a compound name into parts separated by '.'.
+   * @param source the string to parse into simple names
+   * @return a list of simple names from the source
+   */
+  private static List<String> splitName(ParserUtils.StringPosition source) {
+    List<String> result = new ArrayList<>();
+    do {
+      result.add(parseName(source));
+    } while (consumeChar(source, '.'));
+    return result;
+  }
+
+
+  private static final Pattern INTEGER_PATTERN = Pattern.compile("^[0-9]+$");
+
+  public static TypeDescription findSubtype(TypeDescription schema,
+                                            ParserUtils.StringPosition source) {
+    List<String> names = ParserUtils.splitName(source);
+    if (names.size() == 1 && INTEGER_PATTERN.matcher(names.get(0)).matches()) {
+      return schema.findSubtype(Integer.parseInt(names.get(0)));
+    }
+    TypeDescription current = SchemaEvolution.checkAcidSchema(schema)
+        ? SchemaEvolution.getBaseRow(schema) : schema;
+    while (names.size() > 0) {
+      String first = names.remove(0);
+      switch (current.getCategory()) {
+        case STRUCT: {
+          int posn = current.getFieldNames().indexOf(first);
+          if (posn == -1) {
+            throw new IllegalArgumentException("Field " + first +
+                " not found in " + current.toString());
+          }
+          current = current.getChildren().get(posn);
+          break;
+        }
+        case LIST:
+          if (first.equals("_elem")) {
+            current = current.getChildren().get(0);
+          } else {
+            throw new IllegalArgumentException("Field " + first +
+                " not found in " + current.toString());
+          }
+          break;
+        case MAP:
+          if (first.equals("_key")) {
+            current = current.getChildren().get(0);
+          } else if (first.equals("_value")) {
+            current = current.getChildren().get(1);
+          } else {
+            throw new IllegalArgumentException("Field " + first +
+                " not found in " + current.toString());
+          }
+          break;
+        case UNION: {
+          try {
+            int posn = Integer.parseInt(first);
+            if (posn < 0 || posn >= current.getChildren().size()) {
+              throw new NumberFormatException("off end of union");
+            }
+            current = current.getChildren().get(posn);
+          } catch (NumberFormatException e) {
+            throw new IllegalArgumentException("Field " + first +
+                " not found in " + current.toString(), e);
+          }
+          break;
+        }
+        default:
+          throw new IllegalArgumentException("Field " + first +
+              " not found in " + current.toString());
+      }
+    }
+    return current;
+  }
+
+  public static List<TypeDescription> findSubtypeList(TypeDescription schema,
+                                                      StringPosition source) {
+    List<TypeDescription> result = new ArrayList<>();
+    if (source.hasCharactersLeft()) {
+      do {
+        result.add(findSubtype(schema, source));
+      } while (consumeChar(source, ','));
+    }
+    return result;
+  }
+
+  public static class StringPosition {
+    final String value;
+    int position;
+    final int length;
+
+    public StringPosition(String value) {
+      this.value = value == null ? "" : value;
+      position = 0;
+      length = this.value.length();
+    }
+
+    @Override
+    public String toString() {
+      return '\'' + value.substring(0, position) + '^' +
+          value.substring(position) + '\'';
+    }
+
+    public String fromPosition(int start) {
+      return value.substring(start, this.position);
+    }
+
+    public boolean hasCharactersLeft() {
+      return position != length;
+    }
+  }
+
+  /**
+   * Annotate the given schema with the encryption information.
+   *
+   * Format of the string is a key-list.
+   * <ul>
+   *   <li>key-list = key (';' key-list)?</li>
+   *   <li>key = key-name ':' field-list</li>
+   *   <li>field-list = field-name ( ',' field-list )?</li>
+   *   <li>field-name = number | field-part ('.' field-name)?</li>
+   *   <li>field-part = quoted string | simple name</li>
+   * </ul>
+   *
+   * @param source the string to parse
+   * @param schema the top level schema
+   * @throws IllegalArgumentException if there are conflicting keys for a field
+   */
+  public static void parseKeys(StringPosition source, TypeDescription schema) {
+    if (source.hasCharactersLeft()) {
+      do {
+        String keyName = parseName(source);
+        requireChar(source, ':');
+        for (TypeDescription field : findSubtypeList(schema, source)) {
+          String prev = field.getAttributeValue(TypeDescription.ENCRYPT_ATTRIBUTE);
+          if (prev != null && !prev.equals(keyName)) {
+            throw new IllegalArgumentException("Conflicting encryption keys " +
+                keyName + " and " + prev);
+          }
+          field.setAttribute(TypeDescription.ENCRYPT_ATTRIBUTE, keyName);
+        }
+      } while (consumeChar(source, ';'));
+    }
+  }
+
+  /**
+   * Annotate the given schema with the masking information.
+   *
+   * Format of the string is a mask-list.
+   * <ul>
+   *   <li>mask-list = mask (';' mask-list)?</li>
+   *   <li>mask = mask-name (',' parameter)* ':' field-list</li>
+   *   <li>field-list = field-name ( ',' field-list )?</li>
+   *   <li>field-name = number | field-part ('.' field-name)?</li>
+   *   <li>field-part = quoted string | simple name</li>
+   * </ul>
+   *
+   * @param source the string to parse
+   * @param schema the top level schema
+   * @throws IllegalArgumentException if there are conflicting masks for a field
+   */
+  public static void parseMasks(StringPosition source, TypeDescription schema) {
+    if (source.hasCharactersLeft()) {
+      do {
+        // parse the mask and parameters, but only get the underlying string
+        int start = source.position;
+        parseName(source);
+        while (consumeChar(source, ',')) {
+          parseName(source);
+        }
+        String maskString = source.fromPosition(start);
+        requireChar(source, ':');
+        for (TypeDescription field : findSubtypeList(schema, source)) {
+          String prev = field.getAttributeValue(TypeDescription.MASK_ATTRIBUTE);
+          if (prev != null && !prev.equals(maskString)) {
+            throw new IllegalArgumentException("Conflicting encryption masks " +
+                maskString + " and " + prev);
+          }
+          field.setAttribute(TypeDescription.MASK_ATTRIBUTE, maskString);
+        }
+      } while (consumeChar(source, ';'));
+    }
+  }
+
+  public static MaskDescriptionImpl buildMaskDescription(String value) {
+    StringPosition source = new StringPosition(value);
+    String maskName = parseName(source);
+    List<String> params = new ArrayList<>();
+    while (consumeChar(source, ',')) {
+      params.add(parseName(source));
+    }
+    return new MaskDescriptionImpl(maskName,
+        params.toArray(new String[params.size()]));
+  }
+}
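
Concretely, the key grammar above accepts strings like
"pii:name.first,name.last;credit:card". A sketch of driving the parser
directly (normally TypeDescription.annotateEncryption does this for you):

    import org.apache.orc.TypeDescription;
    import org.apache.orc.impl.ParserUtils;

    public class ParserSketch {
      public static void main(String[] args) {
        TypeDescription schema = TypeDescription.fromString(
            "struct<name:struct<first:string,last:string>>");
        ParserUtils.StringPosition keys =
            new ParserUtils.StringPosition("pii:name.first,name.last");
        ParserUtils.parseKeys(keys, schema); // tags both fields with key "pii"
      }
    }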
diff --git a/java/core/src/java/org/apache/orc/impl/WriterImpl.java b/java/core/src/java/org/apache/orc/impl/WriterImpl.java
index b914263..e556470 100644
--- a/java/core/src/java/org/apache/orc/impl/WriterImpl.java
+++ b/java/core/src/java/org/apache/orc/impl/WriterImpl.java
@@ -25,10 +25,8 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.SortedMap;
-import java.util.SortedSet;
 import java.util.TimeZone;
 import java.util.TreeMap;
-import java.util.TreeSet;
 
 import io.airlift.compress.lz4.Lz4Compressor;
 import io.airlift.compress.lz4.Lz4Decompressor;
@@ -124,14 +122,14 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback {
       new ArrayList<>();
 
   // the list of maskDescriptions, keys, and variants
-  private SortedSet<MaskDescriptionImpl> maskDescriptions = new TreeSet<>();
+  private SortedMap<String, MaskDescriptionImpl> maskDescriptions = new TreeMap<>();
   private SortedMap<String, WriterEncryptionKey> keys = new TreeMap<>();
   private final WriterEncryptionVariant[] encryption;
   // the mapping of columns to maskDescriptions
   private final MaskDescriptionImpl[] columnMaskDescriptions;
   // the mapping of columns to EncryptionVariants
   private final WriterEncryptionVariant[] columnEncryption;
-  private HadoopShims.KeyProvider keyProvider;
+  private KeyProvider keyProvider;
   // do we need to include the current encryption keys in the next stripe
   // information
   private boolean needKeyFlush;
@@ -139,23 +137,24 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback {
   public WriterImpl(FileSystem fs,
                     Path path,
                     OrcFile.WriterOptions opts) throws IOException {
-    this.schema = opts.getSchema();
+    // clone it so that we can annotate it with encryption
+    this.schema = opts.getSchema().clone();
     int numColumns = schema.getMaximumId() + 1;
     if (!opts.isEnforceBufferSize()) {
       opts.bufferSize(getEstimatedBufferSize(opts.getStripeSize(), numColumns,
           opts.getBufferSize()));
     }
 
-    // Do we have column encryption?
-    List<OrcFile.EncryptionOption> encryptionOptions = opts.getEncryption();
+    // Annotate the schema with the column encryption
+    schema.annotateEncryption(opts.getEncryption(), opts.getMasks());
     columnEncryption = new WriterEncryptionVariant[numColumns];
-    if (encryptionOptions.isEmpty()) {
+    if (opts.getEncryption() == null) {
       columnMaskDescriptions = null;
       encryption = new WriterEncryptionVariant[0];
       needKeyFlush = false;
     } else {
       columnMaskDescriptions = new MaskDescriptionImpl[numColumns];
-      encryption = setupEncryption(opts.getKeyProvider(), encryptionOptions);
+      encryption = setupEncryption(opts.getKeyProvider(), schema);
       needKeyFlush = true;
     }
 
@@ -586,7 +585,7 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback {
 
   private OrcProto.Encryption.Builder writeEncryptionFooter() {
     OrcProto.Encryption.Builder encrypt = OrcProto.Encryption.newBuilder();
-    for(MaskDescriptionImpl mask: maskDescriptions) {
+    for(MaskDescriptionImpl mask: maskDescriptions.values()) {
       OrcProto.DataMask.Builder maskBuilder = OrcProto.DataMask.newBuilder();
       maskBuilder.setName(mask.getName());
       for(String param: mask.getParameters()) {
@@ -803,7 +802,7 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback {
   }
 
   WriterEncryptionKey getKey(String keyName,
-                             HadoopShims.KeyProvider provider) throws IOException {
+                             KeyProvider provider) throws IOException {
     WriterEncryptionKey result = keys.get(keyName);
     if (result == null) {
       result = new WriterEncryptionKey(provider.getCurrentKeyVersion(keyName));
@@ -812,12 +811,43 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback {
     return result;
   }
 
-  MaskDescriptionImpl getMask(OrcFile.EncryptionOption opt) {
-    MaskDescriptionImpl result = new MaskDescriptionImpl(opt.getMask(),
-        opt.getMaskParameters());
+  MaskDescriptionImpl getMask(String maskString) {
+    MaskDescriptionImpl result = maskDescriptions.get(maskString);
     // if it is already there, get the earlier object
-    if (!maskDescriptions.add(result)) {
-      result = maskDescriptions.tailSet(result).first();
+    if (result == null) {
+      result = ParserUtils.buildMaskDescription(maskString);
+      maskDescriptions.put(maskString, result);
+    }
+    return result;
+  }
+
+  int visitTypeTree(TypeDescription schema,
+                    boolean encrypted,
+                    KeyProvider provider) throws IOException {
+    int result = 0;
+    String keyName = schema.getAttributeValue(TypeDescription.ENCRYPT_ATTRIBUTE);
+    String maskName = schema.getAttributeValue(TypeDescription.MASK_ATTRIBUTE);
+    if (keyName != null) {
+      if (encrypted) {
+        throw new IllegalArgumentException("Nested encryption type: " + schema);
+      }
+      encrypted = true;
+      result += 1;
+      WriterEncryptionKey key = getKey(keyName, provider);
+      HadoopShims.KeyMetadata metadata = key.getMetadata();
+      WriterEncryptionVariant variant = new WriterEncryptionVariant(key,
+          schema, keyProvider.createLocalKey(metadata));
+      key.addRoot(variant);
+    }
+    if (encrypted && (keyName != null || maskName != null)) {
+      MaskDescriptionImpl mask = getMask(maskName == null ? "nullify" : maskName);
+      mask.addColumn(schema);
+    }
+    List<TypeDescription> children = schema.getChildren();
+    if (children != null) {
+      for(TypeDescription child: children) {
+        result += visitTypeTree(child, encrypted, provider);
+      }
     }
     return result;
   }
@@ -826,37 +856,21 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback {
    * Iterate through the encryption options given by the user and set up
    * our data structures.
    * @param provider the KeyProvider to use to generate keys
-   * @param options the options from the user
+   * @param schema the schema, already annotated with the encryption attributes
    */
-  WriterEncryptionVariant[] setupEncryption(HadoopShims.KeyProvider provider,
-                                            List<OrcFile.EncryptionOption> options
+  WriterEncryptionVariant[] setupEncryption(KeyProvider provider,
+                                            TypeDescription schema
                                             ) throws IOException {
     keyProvider = provider != null ? provider :
-                      SHIMS.getKeyProvider(conf, new SecureRandom());
+                      CryptoUtils.getKeyProvider(conf, new SecureRandom());
     if (keyProvider == null) {
       throw new IllegalArgumentException("Encryption requires a KeyProvider.");
     }
-    // fill out the primary encryption keys
-    int variantCount = 0;
-    for(OrcFile.EncryptionOption option: options) {
-      MaskDescriptionImpl mask = getMask(option);
-      for(TypeDescription col: schema.findSubtypes(option.getColumnNames())) {
-        mask.addColumn(col);
-      }
-      if (option.getKeyName() != null) {
-        WriterEncryptionKey key = getKey(option.getKeyName(), keyProvider);
-        HadoopShims.KeyMetadata metadata = key.getMetadata();
-        for(TypeDescription rootType: schema.findSubtypes(option.getColumnNames())) {
-          WriterEncryptionVariant variant = new WriterEncryptionVariant(key,
-              rootType, keyProvider.createLocalKey(metadata));
-          key.addRoot(variant);
-          variantCount += 1;
-        }
-      }
-    }
+    int variantCount = visitTypeTree(schema, false, provider);
+
     // Now that we have de-duped the keys and maskDescriptions, make the arrays
     int nextId = 0;
-    for (MaskDescriptionImpl mask: maskDescriptions) {
+    for (MaskDescriptionImpl mask: maskDescriptions.values()) {
       mask.setId(nextId++);
       for(TypeDescription column: mask.getColumns()) {
         this.columnMaskDescriptions[column.getId()] = mask;
diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java
index c647c10..fdf7e1c 100644
--- a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java
+++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java
@@ -22,17 +22,16 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.orc.OrcProto;
 import org.apache.orc.StripeInformation;
 import org.apache.orc.TypeDescription;
-import org.apache.orc.impl.HadoopShims;
-import org.apache.orc.impl.HadoopShimsFactory;
+import org.apache.orc.impl.CryptoUtils;
+import org.apache.orc.impl.KeyProvider;
 import org.apache.orc.impl.MaskDescriptionImpl;
 
 import java.io.IOException;
 import java.security.SecureRandom;
-import java.util.Arrays;
 import java.util.List;
 
 public class ReaderEncryption {
-  private final HadoopShims.KeyProvider keyProvider;
+  private final KeyProvider keyProvider;
   private final ReaderEncryptionKey[] keys;
   private final MaskDescriptionImpl[] masks;
   private final ReaderEncryptionVariant[] variants;
@@ -51,7 +50,7 @@ public class ReaderEncryption {
   public ReaderEncryption(OrcProto.Footer footer,
                           TypeDescription schema,
                           List<StripeInformation> stripes,
-                          HadoopShims.KeyProvider provider,
+                          KeyProvider provider,
                           Configuration conf) throws IOException {
     if (footer == null || !footer.hasEncryption()) {
       keyProvider = null;
@@ -61,7 +60,7 @@ public class ReaderEncryption {
       columnVariants = null;
     } else {
       keyProvider = provider != null ? provider :
-          HadoopShimsFactory.get().getKeyProvider(conf, new SecureRandom());
+          CryptoUtils.getKeyProvider(conf, new SecureRandom());
       OrcProto.Encryption encrypt = footer.getEncryption();
       masks = new MaskDescriptionImpl[encrypt.getMaskCount()];
       for(int m=0; m < masks.length; ++m) {
diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java
index 255952d..c20bd8e 100644
--- a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java
+++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java
@@ -24,7 +24,7 @@ import org.apache.orc.EncryptionVariant;
 import org.apache.orc.OrcProto;
 import org.apache.orc.StripeInformation;
 import org.apache.orc.TypeDescription;
-import org.apache.orc.impl.HadoopShims;
+import org.apache.orc.impl.KeyProvider;
 import org.apache.orc.impl.LocalKey;
 import org.jetbrains.annotations.NotNull;
 import org.slf4j.Logger;
@@ -42,7 +42,7 @@ import java.util.Map;
 public class ReaderEncryptionVariant implements EncryptionVariant {
   private static final Logger LOG =
       LoggerFactory.getLogger(ReaderEncryptionVariant.class);
-  private final HadoopShims.KeyProvider provider;
+  private final KeyProvider provider;
   private final ReaderEncryptionKey key;
   private final TypeDescription column;
   private final int variantId;
@@ -63,7 +63,7 @@ public class ReaderEncryptionVariant implements EncryptionVariant {
                                  OrcProto.EncryptionVariant proto,
                                  TypeDescription schema,
                                  List<StripeInformation> stripes,
-                                 HadoopShims.KeyProvider provider) {
+                                 KeyProvider provider) {
     this.key = key;
     this.variantId = variantId;
     this.provider = provider;
diff --git a/java/shims/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory b/java/core/src/resources/META-INF/services/org.apache.orc.impl.KeyProvider$Factory
similarity index 92%
copy from java/shims/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory
copy to java/core/src/resources/META-INF/services/org.apache.orc.impl.KeyProvider$Factory
index 14ee9a5..da1659f 100644
--- a/java/shims/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory
+++ b/java/core/src/resources/META-INF/services/org.apache.orc.impl.KeyProvider$Factory
@@ -12,5 +12,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-org.apache.orc.impl.TestHadoopShimsPre2_7$TestKeyProviderFactory
+org.apache.orc.impl.CryptoUtils$HadoopKeyProviderFactory
\ No newline at end of file
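
The service file's name is the binary name of the nested interface (hence
the $), and each non-comment line names an implementation for ServiceLoader
to instantiate. A hypothetical third-party factory would plug in the same
way ("my-kind" is an invented kind here, and the class simply delegates to
the in-memory keystore for illustration):

    import java.io.IOException;
    import java.util.Random;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.orc.InMemoryKeystore;
    import org.apache.orc.impl.KeyProvider;

    // Registered by listing this class's fully-qualified name in
    // META-INF/services/org.apache.orc.impl.KeyProvider$Factory
    public class CustomFactorySketch implements KeyProvider.Factory {
      @Override
      public KeyProvider create(String kind, Configuration conf,
                                Random random) throws IOException {
        // claim only the kind we understand; return null otherwise
        return "my-kind".equals(kind) ? new InMemoryKeystore(random) : null;
      }
    }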
diff --git a/java/core/src/test/org/apache/orc/TestTypeDescription.java b/java/core/src/test/org/apache/orc/TestTypeDescription.java
index 4a6a199..7519ff2 100644
--- a/java/core/src/test/org/apache/orc/TestTypeDescription.java
+++ b/java/core/src/test/org/apache/orc/TestTypeDescription.java
@@ -19,6 +19,7 @@ package org.apache.orc;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertTrue;
 
 import org.apache.hadoop.conf.Configuration;
@@ -375,4 +376,84 @@ public class TestTypeDescription {
     assertEquals("nullify", street.getAttributeValue("mask"));
     assertEquals(null, street.getAttributeValue("foobar"));
   }
+
+  static int clearAttributes(TypeDescription schema) {
+    int result = 0;
+    for(String attribute: schema.getAttributeNames()) {
+      schema.removeAttribute(attribute);
+      result += 1;
+    }
+    List<TypeDescription> children = schema.getChildren();
+    if (children != null) {
+      for (TypeDescription child : children) {
+        result += clearAttributes(child);
+      }
+    }
+    return result;
+  }
+
+  @Test
+  public void testEncryption() {
+    String schemaString =  "struct<" +
+        "name:struct<first:string,last:string>," +
+        "address:struct<street:string,city:string,country:string,post_code:string>," +
+        "credit_cards:array<struct<card_number:string,expire:date,ccv:string>>>";
+    TypeDescription schema = TypeDescription.fromString(schemaString);
+    TypeDescription copy = TypeDescription.fromString(schemaString);
+    assertEquals(copy, schema);
+
+    // set some encryption
+    schema.annotateEncryption("pii:name,address.street;credit:credit_cards", null);
+    assertEquals("pii",
+        schema.findSubtype("name").getAttributeValue(TypeDescription.ENCRYPT_ATTRIBUTE));
+    assertEquals("pii",
+        schema.findSubtype("address.street").getAttributeValue(TypeDescription.ENCRYPT_ATTRIBUTE));
+    assertEquals("credit",
+        schema.findSubtype("credit_cards").getAttributeValue(TypeDescription.ENCRYPT_ATTRIBUTE));
+    assertNotEquals(copy, schema);
+    assertEquals(3, clearAttributes(schema));
+    assertEquals(copy, schema);
+
+    schema.annotateEncryption("pii:name.first", "redact,Yy:name.first");
+    // check that we ignore if already set
+    schema.annotateEncryption("pii:name.first", "redact,Yy:name.first,credit_cards");
+    assertEquals("pii",
+        schema.findSubtype("name.first").getAttributeValue(TypeDescription.ENCRYPT_ATTRIBUTE));
+    assertEquals("redact,Yy",
+        schema.findSubtype("name.first").getAttributeValue(TypeDescription.MASK_ATTRIBUTE));
+    assertEquals("redact,Yy",
+        schema.findSubtype("credit_cards").getAttributeValue(TypeDescription.MASK_ATTRIBUTE));
+    assertEquals(3, clearAttributes(schema));
+
+    schema.annotateEncryption("pii:name", "redact:name.first;nullify:name.last");
+    assertEquals("pii",
+        schema.findSubtype("name").getAttributeValue(TypeDescription.ENCRYPT_ATTRIBUTE));
+    assertEquals("redact",
+        schema.findSubtype("name.first").getAttributeValue(TypeDescription.MASK_ATTRIBUTE));
+    assertEquals("nullify",
+        schema.findSubtype("name.last").getAttributeValue(TypeDescription.MASK_ATTRIBUTE));
+    assertEquals(3, clearAttributes(schema));
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testEncryptionConflict() {
+    TypeDescription schema = TypeDescription.fromString(
+        "struct<" +
+            "name:struct<first:string,last:string>," +
+            "address:struct<street:string,city:string,country:string,post_code:string>," +
+            "credit_cards:array<struct<card_number:string,expire:date,ccv:string>>>");
+    // set some encryption
+    schema.annotateEncryption("pii:address,personal:address",null);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testMaskConflict() {
+    TypeDescription schema = TypeDescription.fromString(
+        "struct<" +
+            "name:struct<first:string,last:string>," +
+            "address:struct<street:string,city:string,country:string,post_code:string>," +
+            "credit_cards:array<struct<card_number:string,expire:date,ccv:string>>>");
+    // set some encryption
+    schema.annotateEncryption(null,"nullify:name;sha256:name");
+  }
 }
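
For readers following along, the spec strings exercised above use a small
grammar: the encryption spec is a semicolon-separated list of
"keyName:field,field" entries, and the mask spec is a semicolon-separated list
of "maskName[,parameter]:field,field" entries. A minimal sketch with
hypothetical field names:

    TypeDescription schema = TypeDescription.fromString(
        "struct<name:struct<first:string,last:string>,ssn:string>");
    // encrypt the name struct and ssn with the "pii" key, and redact
    // name.first (replacing letters with "Yy") for unauthorized readers
    schema.annotateEncryption("pii:name,ssn", "redact,Yy:name.first");
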
diff --git a/java/core/src/test/org/apache/orc/TestVectorOrcFile.java b/java/core/src/test/org/apache/orc/TestVectorOrcFile.java
index 69e1a40..93b0aa7 100644
--- a/java/core/src/test/org/apache/orc/TestVectorOrcFile.java
+++ b/java/core/src/test/org/apache/orc/TestVectorOrcFile.java
@@ -3745,14 +3745,12 @@ public class TestVectorOrcFile {
     InMemoryKeystore keys = new InMemoryKeystore()
         .addKey("pii", EncryptionAlgorithm.AES_CTR_128, piiKey)
         .addKey("credit", EncryptionAlgorithm.AES_CTR_256, creditKey);
-
     Writer writer = OrcFile.createWriter(testFilePath,
         OrcFile.writerOptions(conf)
             .setSchema(schema)
             .version(fileFormat)
             .setKeyProvider(keys)
-            .encryptColumn("i", "pii")
-            .encryptColumn("x", "credit"));
+            .encrypt("pii:i;credit:x"));
     VectorizedRowBatch batch = schema.createRowBatch();
     batch.size = ROWS;
     LongColumnVector i = (LongColumnVector) batch.cols[0];
@@ -3889,12 +3887,7 @@ public class TestVectorOrcFile {
             .version(fileFormat)
             .stripeSize(10000)
             .setKeyProvider(allKeys)
-            .encryptColumn("dec", "key_0")
-            .encryptColumn("dt", "key_1")
-            .encryptColumn("time", "key_2")
-            .encryptColumn("dbl", "key_3")
-            .encryptColumn("bool", "key_4")
-            .encryptColumn("bin", "key_5"));
+            .encrypt("key_0:dec;key_1:dt;key_2:time;key_3:dbl;key_4:bool;key_5:bin"));
     // Set size to 1000 precisely so that stripes are exactly 5000 rows long.
     VectorizedRowBatch batch = schema.createRowBatch(1000);
     DecimalColumnVector dec = (DecimalColumnVector) batch.cols[0];
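
As the two hunks above show, the per-column encryptColumn calls collapse into
a single spec string on the writer options. A sketch of the new writer-side
call, with illustrative key names and column paths:

    Writer writer = OrcFile.createWriter(path,
        OrcFile.writerOptions(conf)
            .setSchema(schema)
            .setKeyProvider(keys)   // e.g. an InMemoryKeystore in tests
            .encrypt("pii:name,ssn;credit:card_number"));
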
diff --git a/java/core/src/test/org/apache/orc/impl/TestCryptoUtils.java b/java/core/src/test/org/apache/orc/impl/TestCryptoUtils.java
index a2caa2e..2811bbf 100644
--- a/java/core/src/test/org/apache/orc/impl/TestCryptoUtils.java
+++ b/java/core/src/test/org/apache/orc/impl/TestCryptoUtils.java
@@ -18,14 +18,20 @@
 
 package org.apache.orc.impl;
 
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.BytesWritable;
 import org.apache.orc.EncryptionAlgorithm;
+import org.apache.orc.InMemoryKeystore;
+import org.apache.orc.OrcConf;
 import org.apache.orc.OrcProto;
 import org.junit.Test;
 
-import java.util.Arrays;
+import java.io.IOException;
+import java.security.Key;
+import java.util.List;
+import java.util.Random;
 
 import static junit.framework.Assert.assertEquals;
-import static org.junit.Assert.assertNotEquals;
 
 public class TestCryptoUtils {
 
@@ -45,4 +51,39 @@ public class TestCryptoUtils {
     assertEquals(0x34, iv[6]);
     assertEquals(0x56, iv[7]);
   }
+
+  @Test
+  public void testMemoryKeyProvider() throws IOException {
+    Configuration conf = new Configuration();
+    OrcConf.KEY_PROVIDER.setString(conf, "memory");
+    // Hard-code the random seed so that we know the bytes that will come out.
+    InMemoryKeystore provider =
+        (InMemoryKeystore) CryptoUtils.getKeyProvider(conf, new Random(24));
+    byte[] piiKey = new byte[]{0,1,2,3,4,5,6,7,8,9,0xa,0xb,0xc,0xd,0xe,0xf};
+    provider.addKey("pii", EncryptionAlgorithm.AES_CTR_128, piiKey);
+    byte[] piiKey2 = new byte[]{0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,
+        0x18,0x19,0x1a,0x1b,0x1c,0x1d,0x1e,0x1f};
+    provider.addKey("pii", 1, EncryptionAlgorithm.AES_CTR_128, piiKey2);
+    byte[] secretKey = new byte[]{0x20,0x21,0x22,0x23,0x24,0x25,0x26,0x27,
+        0x28,0x29,0x2a,0x2b,0x2c,0x2d,0x2e,0x2f};
+    provider.addKey("secret", EncryptionAlgorithm.AES_CTR_128, secretKey);
+
+    List<String> keyNames = provider.getKeyNames();
+    assertEquals(2, keyNames.size());
+    assertEquals(true, keyNames.contains("pii"));
+    assertEquals(true, keyNames.contains("secret"));
+    HadoopShims.KeyMetadata meta = provider.getCurrentKeyVersion("pii");
+    assertEquals(1, meta.getVersion());
+    LocalKey localKey = provider.createLocalKey(meta);
+    byte[] encrypted = localKey.getEncryptedKey();
+    // make sure the encrypted local key is exactly the bytes we expect
+    assertEquals("c7 ab 4f bb 38 f4 de ad d0 b3 59 e2 21 2a 95 32",
+        new BytesWritable(encrypted).toString());
+    // now check to make sure that we get the expected bytes back
+    assertEquals("c7 a1 d0 41 7b 24 72 44 1a 58 c7 72 4a d4 be b3",
+        new BytesWritable(localKey.getDecryptedKey().getEncoded()).toString());
+    Key key = provider.decryptLocalKey(meta, encrypted);
+    assertEquals(new BytesWritable(localKey.getDecryptedKey().getEncoded()).toString(),
+        new BytesWritable(key.getEncoded()).toString());
+  }
 }
diff --git a/java/core/src/test/org/apache/orc/impl/TestPhysicalFsWriter.java b/java/core/src/test/org/apache/orc/impl/TestPhysicalFsWriter.java
index 333bc98..6028307 100644
--- a/java/core/src/test/org/apache/orc/impl/TestPhysicalFsWriter.java
+++ b/java/core/src/test/org/apache/orc/impl/TestPhysicalFsWriter.java
@@ -277,7 +277,7 @@ public class TestPhysicalFsWriter {
     }
 
     @Override
-    public KeyProvider getKeyProvider(Configuration conf, Random random) {
+    public KeyProvider getHadoopKeyProvider(Configuration conf, Random random) {
       return null;
     }
   }
diff --git a/java/mapreduce/src/java/org/apache/orc/mapred/OrcOutputFormat.java b/java/mapreduce/src/java/org/apache/orc/mapred/OrcOutputFormat.java
index 341fbcd..2322d1b 100644
--- a/java/mapreduce/src/java/org/apache/orc/mapred/OrcOutputFormat.java
+++ b/java/mapreduce/src/java/org/apache/orc/mapred/OrcOutputFormat.java
@@ -61,7 +61,9 @@ public class OrcOutputFormat<V extends Writable>
         .stripeSize(OrcConf.STRIPE_SIZE.getLong(conf))
         .rowIndexStride((int) OrcConf.ROW_INDEX_STRIDE.getLong(conf))
         .bufferSize((int) OrcConf.BUFFER_SIZE.getLong(conf))
-        .paddingTolerance(OrcConf.BLOCK_PADDING_TOLERANCE.getDouble(conf));
+        .paddingTolerance(OrcConf.BLOCK_PADDING_TOLERANCE.getDouble(conf))
+        .encrypt(OrcConf.ENCRYPTION.getString(conf))
+        .masks(OrcConf.DATA_MASK.getString(conf));
   }
 
   @Override
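
With this hook, a MapReduce job can request column encryption and masking
purely through configuration, without touching writer code. A sketch using the
OrcConf setters; the spec values are illustrative:

    Configuration conf = new Configuration();
    // picked up by OrcOutputFormat.buildOptions via encrypt()/masks() above
    OrcConf.ENCRYPTION.setString(conf, "pii:name;credit:card_number");
    OrcConf.DATA_MASK.setString(conf, "nullify:name");
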
diff --git a/java/pom.xml b/java/pom.xml
index 4cf999b..f7efa92 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -218,6 +218,7 @@
               <exclude>**/*.md</exclude>
               <exclude>**/target/**</exclude>
               <exclude>.idea/**</exclude>
+              <exclude>**/*.iml</exclude>
             </excludes>
           </configuration>
         </plugin>
diff --git a/java/shims/src/java/org/apache/orc/impl/HadoopShims.java b/java/shims/src/java/org/apache/orc/impl/HadoopShims.java
index 69ab8f1..ac58b63 100644
--- a/java/shims/src/java/org/apache/orc/impl/HadoopShims.java
+++ b/java/shims/src/java/org/apache/orc/impl/HadoopShims.java
@@ -26,8 +26,6 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.nio.ByteBuffer;
-import java.security.Key;
-import java.util.List;
 import java.util.Random;
 
 public interface HadoopShims {
@@ -145,55 +143,6 @@ public interface HadoopShims {
   }
 
   /**
-   * A source of crypto keys. This is usually backed by a Ranger KMS.
-   */
-  interface KeyProvider {
-
-    /**
-     * Get the list of key names from the key provider.
-     * @return a list of key names
-     */
-    List<String> getKeyNames() throws IOException;
-
-    /**
-     * Get the current metadata for a given key. This is used when encrypting
-     * new data.
-     *
-     * @param keyName the name of a key
-     * @return metadata for the current version of the key
-     * @throws IllegalArgumentException if the key is unknown
-     */
-    KeyMetadata getCurrentKeyVersion(String keyName) throws IOException;
-
-    /**
-     * Create a local key for the given key version. This local key will be
-     * randomly generated and encrypted with the given version of the master
-     * key. The encryption and decryption is done with the local key and the
-     * user process never has access to the master key, because it stays on the
-     * Ranger KMS.
-     *
-     * @param key the master key version
-     * @return the local key's material both encrypted and unencrypted
-     */
-    LocalKey createLocalKey(KeyMetadata key) throws IOException;
-
-    /**
-     * Decrypt a local key for reading a file.
-     *
-     * @param key the master key version
-     * @param encryptedKey the encrypted key
-     * @return the decrypted local key's material or null if the key is not
-     * available
-     */
-    Key decryptLocalKey(KeyMetadata key, byte[] encryptedKey) throws IOException;
-
-    /**
-     * Get the kind of this provider.
-     */
-    KeyProviderKind getKind();
-  }
-
-  /**
    * Information about a crypto key including the key name, version, and the
    * algorithm.
    */
@@ -233,23 +182,17 @@ public interface HadoopShims {
 
     @Override
     public String toString() {
-      StringBuilder buffer = new StringBuilder();
-      buffer.append(keyName);
-      buffer.append('@');
-      buffer.append(version);
-      buffer.append(' ');
-      buffer.append(algorithm);
-      return buffer.toString();
+      return keyName + '@' + version + ' ' + algorithm;
     }
   }
 
   /**
-   * Create a KeyProvider to get encryption keys.
+   * Create a Hadoop KeyProvider to get encryption keys.
    * @param conf the configuration
    * @param random a secure random number generator
    * @return a key provider or null if none was provided
    */
-  KeyProvider getKeyProvider(Configuration conf,
-                             Random random) throws IOException;
+  KeyProvider getHadoopKeyProvider(Configuration conf,
+                                   Random random) throws IOException;
 
 }
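
After this rename, the shims expose only the Hadoop-backed provider; choosing
among provider kinds now happens in CryptoUtils, keyed off
OrcConf.KEY_PROVIDER. A sketch of the new entry point, mirroring
TestCryptoUtils above:

    Configuration conf = new Configuration();
    // "memory" selects the InMemoryKeystore, per TestCryptoUtils above
    OrcConf.KEY_PROVIDER.setString(conf, "memory");
    KeyProvider provider = CryptoUtils.getKeyProvider(conf, new SecureRandom());
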
diff --git a/java/shims/src/java/org/apache/orc/impl/HadoopShimsCurrent.java b/java/shims/src/java/org/apache/orc/impl/HadoopShimsCurrent.java
index ff11159..c32ed2e 100644
--- a/java/shims/src/java/org/apache/orc/impl/HadoopShimsCurrent.java
+++ b/java/shims/src/java/org/apache/orc/impl/HadoopShimsCurrent.java
@@ -59,8 +59,8 @@ public class HadoopShimsCurrent implements HadoopShims {
   }
 
   @Override
-  public KeyProvider getKeyProvider(Configuration conf,
-                                    Random random) throws IOException {
+  public KeyProvider getHadoopKeyProvider(Configuration conf,
+                                          Random random) throws IOException {
     return HadoopShimsPre2_7.createKeyProvider(conf, random);
   }
 }
diff --git a/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_3.java b/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_3.java
index c186e99..07e45b2 100644
--- a/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_3.java
+++ b/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_3.java
@@ -55,7 +55,7 @@ public class HadoopShimsPre2_3 implements HadoopShims {
   }
 
   @Override
-  public KeyProvider getKeyProvider(Configuration conf, Random random) {
+  public KeyProvider getHadoopKeyProvider(Configuration conf, Random random) {
     return new NullKeyProvider();
   }
 
diff --git a/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_6.java b/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_6.java
index 618e4c8..75a1145 100644
--- a/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_6.java
+++ b/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_6.java
@@ -129,7 +129,7 @@ public class HadoopShimsPre2_6 implements HadoopShims {
   }
 
   @Override
-  public KeyProvider getKeyProvider(Configuration conf, Random random) {
+  public KeyProvider getHadoopKeyProvider(Configuration conf, Random random) {
     return new HadoopShimsPre2_3.NullKeyProvider();
   }
 }
diff --git a/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_7.java b/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_7.java
index b92be65..552c6c5 100644
--- a/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_7.java
+++ b/java/shims/src/java/org/apache/orc/impl/HadoopShimsPre2_7.java
@@ -238,8 +238,8 @@ public class HadoopShimsPre2_7 implements HadoopShims {
   }
 
   @Override
-  public KeyProvider getKeyProvider(Configuration conf,
-                                    Random random) throws IOException {
+  public KeyProvider getHadoopKeyProvider(Configuration conf,
+                                          Random random) throws IOException {
     return createKeyProvider(conf, random);
   }
 }
diff --git a/java/shims/src/java/org/apache/orc/impl/KeyProvider.java b/java/shims/src/java/org/apache/orc/impl/KeyProvider.java
new file mode 100644
index 0000000..2384020
--- /dev/null
+++ b/java/shims/src/java/org/apache/orc/impl/KeyProvider.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl;
+
+import org.apache.hadoop.conf.Configuration;
+
+import java.io.IOException;
+import java.security.Key;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * A source of crypto keys. This is usually backed by a Ranger KMS.
+ */
+public interface KeyProvider {
+
+  /**
+   * Get the list of key names from the key provider.
+   * @return a list of key names
+   */
+  List<String> getKeyNames() throws IOException;
+
+  /**
+   * Get the current metadata for a given key. This is used when encrypting
+   * new data.
+   *
+   * @param keyName the name of a key
+   * @return metadata for the current version of the key
+   * @throws IllegalArgumentException if the key is unknown
+   */
+  HadoopShims.KeyMetadata getCurrentKeyVersion(String keyName) throws IOException;
+
+  /**
+   * Create a local key for the given key version. This local key will be
+   * randomly generated and encrypted with the given version of the master
+   * key. The encryption and decryption is done with the local key and the
+   * user process never has access to the master key, because it stays on the
+   * Ranger KMS.
+   *
+   * @param key the master key version
+   * @return the local key's material, both encrypted and unencrypted
+   */
+  LocalKey createLocalKey(HadoopShims.KeyMetadata key) throws IOException;
+
+  /**
+   * Decrypt a local key for reading a file.
+   *
+   * @param key the master key version
+   * @param encryptedKey the encrypted key
+   * @return the decrypted local key's material or null if the key is not
+   * available
+   */
+  Key decryptLocalKey(HadoopShims.KeyMetadata key, byte[] encryptedKey) throws IOException;
+
+  /**
+   * Get the kind of this provider.
+   */
+  HadoopShims.KeyProviderKind getKind();
+
+  /**
+   * A service loader factory interface.
+   */
+  interface Factory {
+    KeyProvider create(String kind,
+                       Configuration conf,
+                       Random random) throws IOException;
+  }
+}
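
Putting the new interface together: the write path asks for the current
master-key version and a freshly generated local key, stores only the
encrypted form, and the read path later recovers the local key from it. A
self-contained sketch against the InMemoryKeystore implementation; the key
name and material are illustrative:

    import org.apache.orc.EncryptionAlgorithm;
    import org.apache.orc.InMemoryKeystore;
    import org.apache.orc.impl.HadoopShims;
    import org.apache.orc.impl.KeyProvider;
    import org.apache.orc.impl.LocalKey;

    import java.io.IOException;
    import java.security.Key;

    public class LocalKeyRoundTrip {
      public static void main(String[] args) throws IOException {
        // populate a keystore with one 128-bit master key (all zeros here)
        KeyProvider provider = new InMemoryKeystore()
            .addKey("pii", EncryptionAlgorithm.AES_CTR_128, new byte[16]);
        // write path: current key version plus a fresh local key
        HadoopShims.KeyMetadata meta = provider.getCurrentKeyVersion("pii");
        LocalKey local = provider.createLocalKey(meta);
        byte[] stored = local.getEncryptedKey();   // persisted in the ORC file
        // read path: the master key decrypts the stored local key
        Key decrypted = provider.decryptLocalKey(meta, stored);
      }
    }
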
diff --git a/java/shims/src/test/org/apache/orc/impl/TestHadoopShimsPre2_7.java b/java/shims/src/test/org/apache/orc/impl/TestHadoopShimsPre2_7.java
index 2db90a5..a07fdb1 100644
--- a/java/shims/src/test/org/apache/orc/impl/TestHadoopShimsPre2_7.java
+++ b/java/shims/src/test/org/apache/orc/impl/TestHadoopShimsPre2_7.java
@@ -67,144 +67,4 @@ public class TestHadoopShimsPre2_7 {
     assertEquals(EncryptionAlgorithm.AES_CTR_256,
         HadoopShimsPre2_7.findAlgorithm(meta));
   }
-
-  @Test
-  public void testHadoopKeyProvider() throws IOException {
-    HadoopShims shims = new HadoopShimsPre2_7();
-    Configuration conf = new Configuration();
-    conf.set("hadoop.security.key.provider.path", "test:///");
-    // Hard code the random so that we know the bytes that will come out.
-    HadoopShims.KeyProvider provider = shims.getKeyProvider(conf, new Random(24));
-    List<String> keyNames = provider.getKeyNames();
-    assertEquals(2, keyNames.size());
-    assertEquals(true, keyNames.contains("pii"));
-    assertEquals(true, keyNames.contains("secret"));
-    HadoopShims.KeyMetadata piiKey = provider.getCurrentKeyVersion("pii");
-    assertEquals(1, piiKey.getVersion());
-    LocalKey localKey = provider.createLocalKey(piiKey);
-    byte[] encrypted = localKey.getEncryptedKey();
-    // make sure that we get exactly what we expect to test the encryption
-    assertEquals("c7 ab 4f bb 38 f4 de ad d0 b3 59 e2 21 2a 95 32",
-        new BytesWritable(encrypted).toString());
-    // now check to make sure that we get the expected bytes back
-    assertEquals("c7 a1 d0 41 7b 24 72 44 1a 58 c7 72 4a d4 be b3",
-        new BytesWritable(localKey.getDecryptedKey().getEncoded()).toString());
-    Key key = provider.decryptLocalKey(piiKey, encrypted);
-    assertEquals(new BytesWritable(localKey.getDecryptedKey().getEncoded()).toString(),
-        new BytesWritable(key.getEncoded()).toString());
-  }
-
-  /**
-   * Create a Hadoop KeyProvider that lets us test the interaction
-   * with the Hadoop code.
-   * Must only be used in unit tests!
-   */
-  public static class TestKeyProviderFactory extends KeyProviderFactory {
-
-    @Override
-    public KeyProvider createProvider(URI uri,
-                                      Configuration conf) throws IOException {
-      if ("test".equals(uri.getScheme())) {
-        KeyProvider provider = new TestKeyProvider(conf);
-        // populate a couple keys into the provider
-        byte[] piiKey = new byte[]{0,1,2,3,4,5,6,7,8,9,0xa,0xb,0xc,0xd,0xe,0xf};
-        KeyProvider.Options aes128 = new KeyProvider.Options(conf);
-        provider.createKey("pii", piiKey, aes128);
-        byte[] piiKey2 = new byte[]{0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,
-            0x18,0x19,0x1a,0x1b,0x1c,0x1d,0x1e,0x1f};
-        provider.rollNewVersion("pii", piiKey2);
-        byte[] secretKey = new byte[]{0x20,0x21,0x22,0x23,0x24,0x25,0x26,0x27,
-            0x28,0x29,0x2a,0x2b,0x2c,0x2d,0x2e,0x2f};
-        provider.createKey("secret", secretKey, aes128);
-        return KeyProviderCryptoExtension.createKeyProviderCryptoExtension(provider);
-      }
-      return null;
-    }
-  }
-
-  /**
-   * A Hadoop KeyProvider that lets us test the interaction
-   * with the Hadoop code.
-   * Must only be used in unit tests!
-   */
-  static class TestKeyProvider extends KeyProvider {
-    // map from key name to metadata
-    private final Map<String, TestMetadata> keyMetdata = new HashMap<>();
-    // map from key version name to material
-    private final Map<String, KeyVersion> keyVersions = new HashMap<>();
-
-    public TestKeyProvider(Configuration conf) {
-      super(conf);
-    }
-
-    @Override
-    public KeyVersion getKeyVersion(String name) {
-      return keyVersions.get(name);
-    }
-
-    @Override
-    public List<String> getKeys() {
-      return new ArrayList<>(keyMetdata.keySet());
-    }
-
-    @Override
-    public List<KeyVersion> getKeyVersions(String name) {
-      List<KeyVersion> result = new ArrayList<>();
-      Metadata meta = getMetadata(name);
-      for(int v=0; v < meta.getVersions(); ++v) {
-        String versionName = buildVersionName(name, v);
-        KeyVersion material = keyVersions.get(versionName);
-        if (material != null) {
-          result.add(material);
-        }
-      }
-      return result;
-    }
-
-    @Override
-    public Metadata getMetadata(String name)  {
-      return keyMetdata.get(name);
-    }
-
-    @Override
-    public KeyVersion createKey(String name, byte[] bytes, Options options) {
-      String versionName = buildVersionName(name, 0);
-      keyMetdata.put(name, new TestMetadata(options.getCipher(),
-          options.getBitLength(), 1));
-      KeyVersion result = new KMSClientProvider.KMSKeyVersion(name, versionName, bytes);
-      keyVersions.put(versionName, result);
-      return result;
-    }
-
-    @Override
-    public void deleteKey(String name) {
-      throw new UnsupportedOperationException("Can't delete keys");
-    }
-
-    @Override
-    public KeyVersion rollNewVersion(String name, byte[] bytes) {
-      TestMetadata key = keyMetdata.get(name);
-      String versionName = buildVersionName(name, key.addVersion());
-      KeyVersion result = new KMSClientProvider.KMSKeyVersion(name, versionName,
-          bytes);
-      keyVersions.put(versionName, result);
-      return result;
-    }
-
-    @Override
-    public void flush() {
-      // Nothing
-    }
-
-    static class TestMetadata extends KeyProvider.Metadata {
-
-      protected TestMetadata(String cipher, int bitLength, int versions) {
-        super(cipher, bitLength, null, null, null, versions);
-      }
-
-      public int addVersion() {
-        return super.addVersion();
-      }
-    }
-  }
 }
diff --git a/java/tools/src/java/org/apache/orc/tools/KeyTool.java b/java/tools/src/java/org/apache/orc/tools/KeyTool.java
index 5339334..ea592bc 100644
--- a/java/tools/src/java/org/apache/orc/tools/KeyTool.java
+++ b/java/tools/src/java/org/apache/orc/tools/KeyTool.java
@@ -26,8 +26,9 @@ import org.apache.commons.cli.ParseException;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.orc.EncryptionAlgorithm;
+import org.apache.orc.impl.CryptoUtils;
 import org.apache.orc.impl.HadoopShims;
-import org.apache.orc.impl.HadoopShimsFactory;
+import org.apache.orc.impl.KeyProvider;
 import org.codehaus.jettison.json.JSONException;
 import org.codehaus.jettison.json.JSONWriter;
 
@@ -42,7 +43,7 @@ import java.security.SecureRandom;
 public class KeyTool {
 
   static void printKey(JSONWriter writer,
-                       HadoopShims.KeyProvider provider,
+                       KeyProvider provider,
                        String keyName) throws JSONException, IOException {
     HadoopShims.KeyMetadata meta = provider.getCurrentKeyVersion(keyName);
     writer.object();
@@ -79,8 +80,8 @@ public class KeyTool {
   }
 
   void run() throws IOException, JSONException {
-    HadoopShims.KeyProvider provider =
-        HadoopShimsFactory.get().getKeyProvider(conf, new SecureRandom());
+    KeyProvider provider =
+        CryptoUtils.getKeyProvider(conf, new SecureRandom());
     if (provider == null) {
       System.err.println("No key provider available.");
       System.exit(1);
diff --git a/java/tools/src/test/org/apache/orc/impl/FakeKeyProvider.java b/java/tools/src/test/org/apache/orc/impl/FakeKeyProvider.java
new file mode 100644
index 0000000..1c3f6c5
--- /dev/null
+++ b/java/tools/src/test/org/apache/orc/impl/FakeKeyProvider.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl;
+
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.crypto.key.KeyProvider;
+import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
+import org.apache.hadoop.crypto.key.KeyProviderFactory;
+import org.apache.hadoop.crypto.key.kms.KMSClientProvider;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A Hadoop KeyProvider that lets us test the interaction
+ * with the Hadoop code.
+ * Must only be used in unit tests!
+ */
+public class FakeKeyProvider extends KeyProvider {
+  // map from key name to metadata
+  private final Map<String, TestMetadata> keyMetadata = new HashMap<>();
+  // map from key version name to material
+  private final Map<String, KeyVersion> keyVersions = new HashMap<>();
+
+  public FakeKeyProvider(Configuration conf) {
+    super(conf);
+  }
+
+  @Override
+  public KeyVersion getKeyVersion(String name) {
+    return keyVersions.get(name);
+  }
+
+  @Override
+  public List<String> getKeys() {
+    return new ArrayList<>(keyMetadata.keySet());
+  }
+
+  @Override
+  public List<KeyVersion> getKeyVersions(String name) {
+    List<KeyVersion> result = new ArrayList<>();
+    Metadata meta = getMetadata(name);
+    for(int v=0; v < meta.getVersions(); ++v) {
+      String versionName = buildVersionName(name, v);
+      KeyVersion material = keyVersions.get(versionName);
+      if (material != null) {
+        result.add(material);
+      }
+    }
+    return result;
+  }
+
+  @Override
+  public Metadata getMetadata(String name)  {
+    return keyMetadata.get(name);
+  }
+
+  @Override
+  public KeyVersion createKey(String name, byte[] bytes, Options options) {
+    String versionName = buildVersionName(name, 0);
+    keyMetadata.put(name, new TestMetadata(options.getCipher(),
+        options.getBitLength(), 1));
+    KeyVersion result = new KMSClientProvider.KMSKeyVersion(name, versionName, bytes);
+    keyVersions.put(versionName, result);
+    return result;
+  }
+
+  @Override
+  public void deleteKey(String name) {
+    throw new UnsupportedOperationException("Can't delete keys");
+  }
+
+  @Override
+  public KeyVersion rollNewVersion(String name, byte[] bytes) {
+    TestMetadata key = keyMetadata.get(name);
+    String versionName = buildVersionName(name, key.addVersion());
+    KeyVersion result = new KMSClientProvider.KMSKeyVersion(name, versionName,
+        bytes);
+    keyVersions.put(versionName, result);
+    return result;
+  }
+
+  @Override
+  public void flush() {
+    // Nothing
+  }
+
+  static class TestMetadata extends KeyProvider.Metadata {
+
+    TestMetadata(String cipher, int bitLength, int versions) {
+      super(cipher, bitLength, null, null, null, versions);
+    }
+
+    public int addVersion() {
+      return super.addVersion();
+    }
+  }
+
+  public static class Factory extends KeyProviderFactory {
+
+    @Override
+    public KeyProvider createProvider(URI uri,
+                                      Configuration conf) throws IOException {
+      if ("test".equals(uri.getScheme())) {
+        KeyProvider provider = new FakeKeyProvider(conf);
+        // populate a couple keys into the provider
+        byte[] piiKey = new byte[]{0,1,2,3,4,5,6,7,8,9,0xa,0xb,0xc,0xd,0xe,0xf};
+        KeyProvider.Options aes128 = new KeyProvider.Options(conf);
+        provider.createKey("pii", piiKey, aes128);
+        byte[] piiKey2 = new byte[]{0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,
+            0x18,0x19,0x1a,0x1b,0x1c,0x1d,0x1e,0x1f};
+        provider.rollNewVersion("pii", piiKey2);
+        byte[] secretKey = new byte[]{0x20,0x21,0x22,0x23,0x24,0x25,0x26,0x27,
+            0x28,0x29,0x2a,0x2b,0x2c,0x2d,0x2e,0x2f};
+        provider.createKey("secret", secretKey, aes128);
+        return KeyProviderCryptoExtension.createKeyProviderCryptoExtension(provider);
+      }
+      return null;
+    }
+  }
+}
diff --git a/java/tools/src/test/org/apache/orc/impl/TestHadoopKeyProvider.java b/java/tools/src/test/org/apache/orc/impl/TestHadoopKeyProvider.java
new file mode 100644
index 0000000..d3cbf60
--- /dev/null
+++ b/java/tools/src/test/org/apache/orc/impl/TestHadoopKeyProvider.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.orc.impl;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.BytesWritable;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.security.Key;
+import java.util.List;
+import java.util.Random;
+
+import static org.junit.Assert.assertEquals;
+
+public class TestHadoopKeyProvider {
+
+  /**
+   * Tests the path through the Hadoop key provider code base.
+   * This should be consistent with TestCryptoUtils.testMemoryKeyProvider.
+   * @throws IOException if the key provider cannot be created
+   */
+  @Test
+  public void testHadoopKeyProvider() throws IOException {
+    Configuration conf = new Configuration();
+    conf.set("hadoop.security.key.provider.path", "test:///");
+    // Hard-code the random seed so that we know the bytes that will come out.
+    KeyProvider provider = CryptoUtils.getKeyProvider(conf, new Random(24));
+    List<String> keyNames = provider.getKeyNames();
+    assertEquals(2, keyNames.size());
+    assertEquals(true, keyNames.contains("pii"));
+    assertEquals(true, keyNames.contains("secret"));
+    HadoopShims.KeyMetadata piiKey = provider.getCurrentKeyVersion("pii");
+    assertEquals(1, piiKey.getVersion());
+    LocalKey localKey = provider.createLocalKey(piiKey);
+    byte[] encrypted = localKey.getEncryptedKey();
+    // make sure the encrypted local key is exactly the bytes we expect
+    assertEquals("c7 ab 4f bb 38 f4 de ad d0 b3 59 e2 21 2a 95 32",
+        new BytesWritable(encrypted).toString());
+    // now check to make sure that we get the expected bytes back
+    assertEquals("c7 a1 d0 41 7b 24 72 44 1a 58 c7 72 4a d4 be b3",
+        new BytesWritable(localKey.getDecryptedKey().getEncoded()).toString());
+    Key key = provider.decryptLocalKey(piiKey, encrypted);
+    assertEquals(new BytesWritable(localKey.getDecryptedKey().getEncoded()).toString(),
+        new BytesWritable(key.getEncoded()).toString());
+  }
+}
diff --git a/java/shims/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory b/java/tools/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory
similarity index 92%
rename from java/shims/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory
rename to java/tools/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory
index 14ee9a5..9648d70 100644
--- a/java/shims/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory
+++ b/java/tools/src/test/resources/META-INF/services/org.apache.hadoop.crypto.key.KeyProviderFactory
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-org.apache.orc.impl.TestHadoopShimsPre2_7$TestKeyProviderFactory
+org.apache.orc.impl.FakeKeyProvider$Factory