Posted to commits@orc.apache.org by om...@apache.org on 2019/07/03 16:33:53 UTC

[orc] branch master updated: ORC-523: Update ReaderImpl to work with column encryption.

This is an automated email from the ASF dual-hosted git repository.

omalley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/orc.git


The following commit(s) were added to refs/heads/master by this push:
     new 33916ab  ORC-523: Update ReaderImpl to work with column encryption.
33916ab is described below

commit 33916ab11504f7d3b4b76acb2f97d91e2c874523
Author: Owen O'Malley <om...@apache.org>
AuthorDate: Mon Jun 24 13:53:19 2019 -0700

    ORC-523: Update ReaderImpl to work with column encryption.
    
    Fixes #408
    
    Signed-off-by: Owen O'Malley <om...@apache.org>
---
 java/core/src/java/org/apache/orc/OrcUtils.java    |  14 +-
 .../src/java/org/apache/orc/StripeInformation.java |  21 +++
 .../core/src/java/org/apache/orc/impl/OrcTail.java |  29 ++--
 .../src/java/org/apache/orc/impl/ReaderImpl.java   | 152 ++++++++++++++---
 .../apache/orc/impl/mask/SHA256MaskFactory.java    |  14 +-
 .../apache/orc/impl/reader/ReaderEncryption.java   | 145 ++++++++++++++++
 .../orc/impl/reader/ReaderEncryptionKey.java       | 132 ++++++++++++++
 .../orc/impl/reader/ReaderEncryptionVariant.java   | 190 +++++++++++++++++++++
 .../org/apache/orc/impl/TestRecordReaderImpl.java  |   4 +-
 9 files changed, 646 insertions(+), 55 deletions(-)

diff --git a/java/core/src/java/org/apache/orc/OrcUtils.java b/java/core/src/java/org/apache/orc/OrcUtils.java
index 220fa13..0ba46dc 100644
--- a/java/core/src/java/org/apache/orc/OrcUtils.java
+++ b/java/core/src/java/org/apache/orc/OrcUtils.java
@@ -616,9 +616,17 @@ public class OrcUtils {
 
   public static List<StripeInformation> convertProtoStripesToStripes(
       List<OrcProto.StripeInformation> stripes) {
-    List<StripeInformation> result = new ArrayList<StripeInformation>(stripes.size());
-    for (OrcProto.StripeInformation info : stripes) {
-      result.add(new ReaderImpl.StripeInformationImpl(info));
+    List<StripeInformation> result = new ArrayList<>(stripes.size());
+    long previousStripeId = -1;
+    byte[][] previousKeys = null;
+    long stripeId = 0;
+    for (OrcProto.StripeInformation stripeProto: stripes) {
+      ReaderImpl.StripeInformationImpl stripe =
+          new ReaderImpl.StripeInformationImpl(stripeProto, stripeId++,
+              previousStripeId, previousKeys);
+      result.add(stripe);
+      previousStripeId = stripe.getEncryptionStripeId();
+      previousKeys = stripe.getEncryptedLocalKeys();
     }
     return result;
   }
diff --git a/java/core/src/java/org/apache/orc/StripeInformation.java b/java/core/src/java/org/apache/orc/StripeInformation.java
index 38f7eba..6490d6b 100644
--- a/java/core/src/java/org/apache/orc/StripeInformation.java
+++ b/java/core/src/java/org/apache/orc/StripeInformation.java
@@ -56,4 +56,25 @@ public interface StripeInformation {
    * @return a count of the number of rows
    */
   long getNumberOfRows();
+
+  /**
+   * Get the index of this stripe in the current file.
+   * @return 0 to number_of_stripes - 1
+   */
+  long getStripeId();
+
+  /**
+   * Get the original stripe id that was used when the stripe was originally
+   * written. This is only different from getStripeId in merged files.
+   * @return the original stripe id
+   */
+  long getEncryptionStripeId();
+
+  /**
+   * Get the encrypted keys starting from this stripe until overridden by
+   * a new set in a following stripe. The top level array is one for each
+   * encryption variant. Each element is an encrypted key.
+   * @return the array of encrypted keys
+   */
+  byte[][] getEncryptedLocalKeys();
 }
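
The new accessors pair with the key-inheritance rule implemented in OrcUtils.convertProtoStripesToStripes above: a stripe that carries no encrypted local keys of its own reuses the keys of the previous stripe. An illustrative sketch of reading this metadata through the public Reader API (the path and conf variables and the printing are assumptions made for the example, not part of the patch):

    Reader reader = OrcFile.createReader(path, OrcFile.readerOptions(conf));
    for (StripeInformation stripe : reader.getStripes()) {
      byte[][] localKeys = stripe.getEncryptedLocalKeys();  // null when the file is unencrypted
      System.out.println("stripe " + stripe.getStripeId()
          + " (encryption stripe id " + stripe.getEncryptionStripeId() + ") carries "
          + (localKeys == null ? 0 : localKeys.length) + " encrypted local keys");
    }
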
diff --git a/java/core/src/java/org/apache/orc/impl/OrcTail.java b/java/core/src/java/org/apache/orc/impl/OrcTail.java
index 9e8a5f2..2765473 100644
--- a/java/core/src/java/org/apache/orc/impl/OrcTail.java
+++ b/java/core/src/java/org/apache/orc/impl/OrcTail.java
@@ -1,4 +1,4 @@
-/**
+/*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
@@ -27,6 +27,7 @@ import org.apache.orc.CompressionCodec;
 import org.apache.orc.CompressionKind;
 import org.apache.orc.OrcFile;
 import org.apache.orc.OrcProto;
+import org.apache.orc.OrcUtils;
 import org.apache.orc.StripeInformation;
 import org.apache.orc.StripeStatistics;
 
@@ -77,11 +78,7 @@ public final class OrcTail {
   }
 
   public List<StripeInformation> getStripes() {
-    List<StripeInformation> result = new ArrayList<>(fileTail.getFooter().getStripesCount());
-    for (OrcProto.StripeInformation stripeProto : fileTail.getFooter().getStripesList()) {
-      result.add(new ReaderImpl.StripeInformationImpl(stripeProto));
-    }
-    return result;
+    return OrcUtils.convertProtoStripesToStripes(getFooter().getStripesList());
   }
 
   public CompressionKind getCompressionKind() {
@@ -92,9 +89,9 @@ public final class OrcTail {
     return (int) fileTail.getPostscript().getCompressionBlockSize();
   }
 
-  public List<StripeStatistics> getStripeStatistics() throws IOException {
+  public List<StripeStatistics> getStripeStatistics(InStream.StreamOptions options) throws IOException {
     List<StripeStatistics> result = new ArrayList<>();
-    List<OrcProto.StripeStatistics> ssProto = getStripeStatisticsProto();
+    List<OrcProto.StripeStatistics> ssProto = getStripeStatisticsProto(options);
     if (ssProto != null) {
       for (OrcProto.StripeStatistics ss : ssProto) {
         result.add(new StripeStatistics(ss.getColStatsList()));
@@ -103,17 +100,12 @@ public final class OrcTail {
     return result;
   }
 
-  public List<OrcProto.StripeStatistics> getStripeStatisticsProto() throws IOException {
+  public List<OrcProto.StripeStatistics> getStripeStatisticsProto(InStream.StreamOptions options) throws IOException {
     if (serializedTail == null) return null;
     if (metadata == null) {
-      CompressionCodec codec = OrcCodecPool.getCodec(getCompressionKind());
-      try {
-        metadata = extractMetadata(serializedTail, 0,
-            (int) fileTail.getPostscript().getMetadataLength(),
-            InStream.options().withCodec(codec).withBufferSize(getCompressionBufferSize()));
-      } finally {
-        OrcCodecPool.returnCodec(getCompressionKind(), codec);
-      }
+      metadata = extractMetadata(serializedTail, 0,
+          (int) fileTail.getPostscript().getMetadataLength(),
+          options);
       // clear does not clear the contents but sets position to 0 and limit = capacity
       serializedTail.clear();
     }
@@ -137,7 +129,6 @@ public final class OrcTail {
     OrcProto.Footer.Builder footerBuilder = OrcProto.Footer.newBuilder(fileTail.getFooter());
     footerBuilder.clearStatistics();
     fileTailBuilder.setFooter(footerBuilder.build());
-    OrcProto.FileTail result = fileTailBuilder.build();
-    return result;
+    return fileTailBuilder.build();
   }
 }
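
Because OrcTail no longer pulls a codec from the pool itself, callers of getStripeStatistics and getStripeStatisticsProto now supply the decompression options explicitly. A minimal sketch, assuming a tail variable of type OrcTail and following the codec-handling pattern used elsewhere in this patch (getStripeStatistics can throw IOException):

    try (CompressionCodec codec = OrcCodecPool.getCodec(tail.getCompressionKind())) {
      InStream.StreamOptions options = InStream.options()
          .withCodec(codec)
          .withBufferSize(tail.getCompressionBufferSize());
      List<StripeStatistics> stripeStats = tail.getStripeStatistics(options);
    }
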
diff --git a/java/core/src/java/org/apache/orc/impl/ReaderImpl.java b/java/core/src/java/org/apache/orc/impl/ReaderImpl.java
index ddd53b7..d1311b9 100644
--- a/java/core/src/java/org/apache/orc/impl/ReaderImpl.java
+++ b/java/core/src/java/org/apache/orc/impl/ReaderImpl.java
@@ -20,13 +20,15 @@ package org.apache.orc.impl;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.security.Key;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
-import org.apache.hadoop.fs.FileStatus;
+import org.apache.orc.EncryptionAlgorithm;
+import org.apache.orc.EncryptionKey;
 import org.apache.orc.CompressionKind;
 import org.apache.orc.DataMaskDescription;
 import org.apache.orc.EncryptionKey;
@@ -43,8 +45,8 @@ import org.apache.orc.FileFormatException;
 import org.apache.orc.StripeInformation;
 import org.apache.orc.StripeStatistics;
 import org.apache.orc.UnknownFormatException;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.apache.orc.impl.reader.ReaderEncryption;
+import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FSDataInputStream;
 import org.apache.hadoop.fs.FileSystem;
@@ -54,6 +56,9 @@ import org.apache.hadoop.io.Text;
 import org.apache.orc.OrcProto;
 
 import com.google.protobuf.CodedInputStream;
+import org.apache.orc.impl.reader.ReaderEncryptionVariant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class ReaderImpl implements Reader {
 
@@ -78,6 +83,7 @@ public class ReaderImpl implements Reader {
   private final List<StripeInformation> stripes;
   protected final int rowIndexStride;
   private final long contentLength, numberOfRows;
+  private final ReaderEncryption encryption;
 
   private long deserializedSize = -1;
   protected final Configuration conf;
@@ -89,10 +95,30 @@ public class ReaderImpl implements Reader {
 
   public static class StripeInformationImpl
       implements StripeInformation {
+    private final long stripeId;
+    private final long originalStripeId;
+    private final byte[][] encryptedKeys;
     private final OrcProto.StripeInformation stripe;
 
-    public StripeInformationImpl(OrcProto.StripeInformation stripe) {
+    public StripeInformationImpl(OrcProto.StripeInformation stripe,
+                                 long stripeId,
+                                 long previousOriginalStripeId,
+                                 byte[][] previousKeys) {
       this.stripe = stripe;
+      this.stripeId = stripeId;
+      if (stripe.hasEncryptStripeId()) {
+        originalStripeId = stripe.getEncryptStripeId();
+      } else {
+        originalStripeId = previousOriginalStripeId + 1;
+      }
+      if (stripe.getEncryptedLocalKeysCount() != 0) {
+        encryptedKeys = new byte[stripe.getEncryptedLocalKeysCount()][];
+        for(int v=0; v < encryptedKeys.length; ++v) {
+          encryptedKeys[v] = stripe.getEncryptedLocalKeys(v).toByteArray();
+        }
+      } else {
+        encryptedKeys = previousKeys;
+      }
     }
 
     @Override
@@ -126,6 +152,21 @@ public class ReaderImpl implements Reader {
     }
 
     @Override
+    public long getStripeId() {
+      return stripeId;
+    }
+
+    @Override
+    public long getEncryptionStripeId() {
+      return originalStripeId;
+    }
+
+    @Override
+    public byte[][] getEncryptedLocalKeys() {
+      return encryptedKeys;
+    }
+
+    @Override
     public String toString() {
       return "offset: " + getOffset() + " data: " + getDataLength() +
         " rows: " + getNumberOfRows() + " tail: " + getFooterLength() +
@@ -221,20 +262,25 @@ public class ReaderImpl implements Reader {
 
   @Override
   public EncryptionKey[] getColumnEncryptionKeys() {
-    // TODO
-    return new EncryptionKey[0];
+    return encryption.getKeys();
   }
 
   @Override
   public DataMaskDescription[] getDataMasks() {
-    // TODO
-    return new DataMaskDescription[0];
+    return encryption.getMasks();
   }
 
   @Override
-  public EncryptionVariant[] getEncryptionVariants() {
-    // TODO
-    return new EncryptionVariant[0];
+  public ReaderEncryptionVariant[] getEncryptionVariants() {
+    return encryption.getVariants();
+  }
+
+  /**
+   * Internal access to our view of the encryption.
+   * @return the encryption information for this reader.
+   */
+  public ReaderEncryption getEncryption() {
+    return encryption;
   }
 
   @Override
@@ -244,7 +290,61 @@ public class ReaderImpl implements Reader {
 
   @Override
   public ColumnStatistics[] getStatistics() {
-    return deserializeStats(schema, fileStats);
+    ColumnStatistics[] result = deserializeStats(schema, fileStats);
+    try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) {
+      InStream.StreamOptions compression = InStream.options();
+      if (codec != null) {
+        compression.withCodec(codec).withBufferSize(bufferSize);
+      }
+      for (ReaderEncryptionVariant variant : encryption.getVariants()) {
+        ColumnStatistics[] overrides;
+        try {
+          overrides = decryptFileStats(variant, compression,
+              tail.getFooter());
+        } catch (IOException e) {
+          throw new RuntimeException("Can't decrypt file stats for " + path +
+                                        " with " + variant.getKeyDescription());
+        }
+        if (overrides != null) {
+          for (int i = 0; i < overrides.length; ++i) {
+            result[variant.getRoot().getId() + i] = overrides[i];
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  public static ColumnStatistics[] decryptFileStats(ReaderEncryptionVariant encryption,
+                                                    InStream.StreamOptions compression,
+                                                    OrcProto.Footer footer
+                                                    ) throws IOException {
+    Key key = encryption.getFileFooterKey();
+    if (key == null) {
+      return null;
+    } else {
+      OrcProto.EncryptionVariant protoVariant =
+          footer.getEncryption().getVariants(encryption.getVariantId());
+      byte[] bytes = protoVariant.getFileStatistics().toByteArray();
+      BufferChunk buffer = new BufferChunk(ByteBuffer.wrap(bytes), 0);
+      EncryptionAlgorithm algorithm = encryption.getKeyDescription().getAlgorithm();
+      byte[] iv = new byte[algorithm.getIvLength()];
+      CryptoUtils.modifyIvForStream(encryption.getRoot().getId(),
+          OrcProto.Stream.Kind.FILE_STATISTICS, footer.getStripesCount())
+          .accept(iv);
+      InStream.StreamOptions options = new InStream.StreamOptions(compression)
+                                           .withEncryption(algorithm, key, iv);
+      InStream in = InStream.create("encrypted file stats", buffer,
+          bytes.length, 0, options);
+      OrcProto.FileStatistics decrypted = OrcProto.FileStatistics.parseFrom(in);
+      ColumnStatistics[] result = new ColumnStatistics[decrypted.getColumnCount()];
+      TypeDescription root = encryption.getRoot();
+      for(int i= 0; i < result.length; ++i){
+        result[i] = ColumnStatisticsImpl.deserialize(root.findSubtype(root.getId() + i),
+            decrypted.getColumn(i));
+      }
+      return result;
+    }
   }
 
   public static ColumnStatistics[] deserializeStats(
@@ -351,12 +451,17 @@ public class ReaderImpl implements Reader {
       this.writerVersion =
           OrcFile.WriterVersion.from(writer, fileMetadata.getWriterVersionNum());
       this.types = fileMetadata.getTypes();
+      OrcUtils.isValidTypeTree(this.types, 0);
+      this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0);
       this.rowIndexStride = fileMetadata.getRowIndexStride();
       this.contentLength = fileMetadata.getContentLength();
       this.numberOfRows = fileMetadata.getNumberOfRows();
       this.fileStats = fileMetadata.getFileStats();
       this.stripes = fileMetadata.getStripes();
       this.userMetadata = null; // not cached and not needed here
+      // FileMetadata is obsolete and doesn't support encryption
+      this.encryption = new ReaderEncryption(null, schema, stripes,
+          options.getKeyProvider(), conf);
     } else {
       OrcTail orcTail = options.getOrcTail();
       if (orcTail == null) {
@@ -371,6 +476,8 @@ public class ReaderImpl implements Reader {
       this.metadataSize = tail.getMetadataSize();
       this.versionList = tail.getPostScript().getVersionList();
       this.types = tail.getFooter().getTypesList();
+      OrcUtils.isValidTypeTree(this.types, 0);
+      this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0);
       this.rowIndexStride = tail.getFooter().getRowIndexStride();
       this.contentLength = tail.getFooter().getContentLength();
       this.numberOfRows = tail.getFooter().getNumberOfRows();
@@ -378,10 +485,14 @@ public class ReaderImpl implements Reader {
       this.fileStats = tail.getFooter().getStatisticsList();
       this.writerVersion = tail.getWriterVersion();
       this.stripes = tail.getStripes();
-      this.stripeStats = tail.getStripeStatisticsProto();
+      try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) {
+        InStream.StreamOptions compress = InStream.options().withCodec(codec)
+                                             .withBufferSize(bufferSize);
+        this.stripeStats = tail.getStripeStatisticsProto(compress);
+      }
+      this.encryption = new ReaderEncryption(tail.getFooter(), schema,
+          stripes, options.getKeyProvider(), conf);
     }
-    OrcUtils.isValidTypeTree(this.types, 0);
-    this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0);
   }
 
   protected FileSystem getFileSystem() throws IOException {
@@ -559,12 +670,9 @@ public class ReaderImpl implements Reader {
       ByteBuffer footerBuffer = buffer.slice();
       buffer.reset();
       OrcProto.Footer footer;
-      CompressionCodec codec = OrcCodecPool.getCodec(compressionKind);
-      try {
+      try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)){
         footer = extractFooter(footerBuffer, 0, footerSize,
             InStream.options().withCodec(codec).withBufferSize(bufferSize));
-      } finally {
-        OrcCodecPool.returnCodec(compressionKind, codec);
       }
       fileTailBuilder.setFooter(footer);
     } catch (Throwable thr) {
@@ -604,7 +712,6 @@ public class ReaderImpl implements Reader {
     return new RecordReaderImpl(this, options);
   }
 
-
   @Override
   public long getRawDataSize() {
     // if the deserializedSize is not computed, then compute it, else
@@ -757,12 +864,9 @@ public class ReaderImpl implements Reader {
   @Override
   public List<StripeStatistics> getStripeStatistics() throws IOException {
     if (metadata == null) {
-      CompressionCodec codec = OrcCodecPool.getCodec(compressionKind);
-      try {
+      try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) {
         metadata = extractMetadata(tail.getSerializedTail(), 0, metadataSize,
             InStream.options().withCodec(codec).withBufferSize(bufferSize));
-      } finally {
-        OrcCodecPool.returnCodec(compressionKind, codec);
       }
     }
     if (stripeStats == null) {
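
With the TODO stubs gone, the top-level Reader now reports the file's encryption metadata directly. An illustrative sketch of the new accessors (the reader construction and println calls are assumptions for the example, not part of the patch):

    Reader reader = OrcFile.createReader(path, OrcFile.readerOptions(conf));
    for (EncryptionKey key : reader.getColumnEncryptionKeys()) {
      System.out.println("key " + key);          // prints name@version and the algorithm
    }
    for (DataMaskDescription mask : reader.getDataMasks()) {
      System.out.println("mask " + mask);
    }
    for (EncryptionVariant variant : reader.getEncryptionVariants()) {
      System.out.println("variant " + variant.getVariantId()
          + " covers the subtree rooted at column " + variant.getRoot().getId());
    }
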
diff --git a/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java b/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java
index c300500..b3445f9 100644
--- a/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java
+++ b/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java
@@ -62,9 +62,9 @@ import java.util.Arrays;
  */
 public class SHA256MaskFactory extends MaskFactory {
 
-  final MessageDigest md;
+  private final MessageDigest md;
 
-  public SHA256MaskFactory(final String... params) {
+  SHA256MaskFactory() {
     super();
     try {
       md = MessageDigest.getInstance("SHA-256");
@@ -138,9 +138,9 @@ public class SHA256MaskFactory extends MaskFactory {
   /**
    * Helper function to mask binary data with its SHA-256 hash.
    *
-   * @param source
-   * @param row
-   * @param target
+   * @param source the source data
+   * @param row the row that we are translating
+   * @param target the output data
    */
   void maskBinary(final BytesColumnVector source, final int row,
       final BytesColumnVector target) {
@@ -207,7 +207,7 @@ public class SHA256MaskFactory extends MaskFactory {
     final TypeDescription schema;
 
     /* create an instance */
-    public StringMask(TypeDescription schema) {
+    StringMask(TypeDescription schema) {
       super();
       this.schema = schema;
     }
@@ -254,7 +254,7 @@ public class SHA256MaskFactory extends MaskFactory {
   class BinaryMask implements DataMask {
 
     /* create an instance */
-    public BinaryMask() {
+    BinaryMask() {
       super();
     }
 
diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java
new file mode 100644
index 0000000..fe54c49
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl.reader;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.orc.OrcProto;
+import org.apache.orc.StripeInformation;
+import org.apache.orc.TypeDescription;
+import org.apache.orc.impl.HadoopShims;
+import org.apache.orc.impl.HadoopShimsFactory;
+import org.apache.orc.impl.MaskDescriptionImpl;
+
+import java.io.IOException;
+import java.security.SecureRandom;
+import java.util.Arrays;
+import java.util.List;
+
+public class ReaderEncryption {
+  private final HadoopShims.KeyProvider keyProvider;
+  private final ReaderEncryptionKey[] keys;
+  private final MaskDescriptionImpl[] masks;
+  private final ReaderEncryptionVariant[] variants;
+  // Mapping from each column to the next variant to try for that column.
+  // A null entry means no encryption for that column.
+  private final ReaderEncryptionVariant[] columnVariants;
+
+  public ReaderEncryption() throws IOException {
+    this(null, null, null, null, null);
+  }
+
+  public ReaderEncryption(OrcProto.Footer footer,
+                          TypeDescription schema,
+                          List<StripeInformation> stripes,
+                          HadoopShims.KeyProvider provider,
+                          Configuration conf) throws IOException {
+    if (footer == null || !footer.hasEncryption()) {
+      keyProvider = null;
+      keys = new ReaderEncryptionKey[0];
+      masks = new MaskDescriptionImpl[0];
+      variants = new ReaderEncryptionVariant[0];
+      columnVariants = null;
+    } else {
+      keyProvider = provider != null ? provider :
+          HadoopShimsFactory.get().getKeyProvider(conf, new SecureRandom());
+      OrcProto.Encryption encrypt = footer.getEncryption();
+      masks = new MaskDescriptionImpl[encrypt.getMaskCount()];
+      for(int m=0; m < masks.length; ++m) {
+        masks[m] = new MaskDescriptionImpl(m, encrypt.getMask(m));
+      }
+      keys = new ReaderEncryptionKey[encrypt.getKeyCount()];
+      for(int k=0; k < keys.length; ++k) {
+        keys[k] = new ReaderEncryptionKey(encrypt.getKey(k));
+      }
+      variants = new ReaderEncryptionVariant[encrypt.getVariantsCount()];
+      for(int v=0; v < variants.length; ++v) {
+        OrcProto.EncryptionVariant variant = encrypt.getVariants(v);
+        variants[v] = new ReaderEncryptionVariant(keys[variant.getKey()], v,
+            variant, schema, stripes, keyProvider);
+      }
+      columnVariants = new ReaderEncryptionVariant[schema.getMaximumId() + 1];
+      for(int v = 0; v < variants.length; ++v) {
+        TypeDescription root = variants[v].getRoot();
+        for(int c = root.getId(); c <= root.getMaximumId(); ++c) {
+          if (columnVariants[c] == null) {
+            columnVariants[c] = variants[v];
+          }
+        }
+      }
+    }
+  }
+
+  public MaskDescriptionImpl[] getMasks() {
+    return masks;
+  }
+
+  public ReaderEncryptionKey[] getKeys() {
+    return keys;
+  }
+
+  public ReaderEncryptionVariant[] getVariants() {
+    return variants;
+  }
+
+  /**
+   * Find the next possible variant in this file for the given column.
+   * @param column the column to find a variant for
+   * @param lastVariant the previous variant that we looked at
+   * @return the next variant or null if there are none
+   */
+  private ReaderEncryptionVariant findNextVariant(int column,
+                                                  int lastVariant) {
+    for(int v = lastVariant + 1; v < variants.length; ++v) {
+      TypeDescription root = variants[v].getRoot();
+      if (root.getId() <= column && column <= root.getMaximumId()) {
+        return variants[v];
+      }
+    }
+    return null;
+  }
+
+  /**
+   * Get the variant for a given column that the user has access to.
+   * If we haven't tried a given key, try to decrypt this variant's footer key
+   * to see if the KeyProvider will give it to us. If not, continue to the
+   * next variant.
+   * @param column the column id
+   * @return null for no encryption or the encryption variant
+   */
+  public ReaderEncryptionVariant getVariant(int column) throws IOException {
+    if (keyProvider != null) {
+      while (columnVariants[column] != null) {
+        ReaderEncryptionVariant result = columnVariants[column];
+        switch (result.getKeyDescription().getState()) {
+        case FAILURE:
+          break;
+        case SUCCESS:
+          return result;
+        case UNTRIED:
+          // try to get the footer key, to see if we have access
+          if (result.getFileFooterKey() != null) {
+            return result;
+          }
+        }
+        columnVariants[column] = findNextVariant(column, result.getVariantId());
+      }
+    }
+    return null;
+  }
+}
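
A sketch of the intended lookup flow when a column is about to be read (the encryption and columnId variables are assumptions; both calls can throw IOException). getVariant only returns a variant whose footer key the KeyProvider was willing to decrypt:

    ReaderEncryptionVariant variant = encryption.getVariant(columnId);
    if (variant == null) {
      // Either the column is unencrypted or the user has no access to any
      // key covering it; the reader falls back to the unencrypted streams.
    } else {
      // Access was granted, so the footer and stripe keys can be decrypted.
      Key footerKey = variant.getFileFooterKey();
    }
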
diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionKey.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionKey.java
new file mode 100644
index 0000000..407bebb
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionKey.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl.reader;
+
+import org.apache.orc.EncryptionKey;
+import org.apache.orc.EncryptionAlgorithm;
+import org.apache.orc.OrcProto;
+import org.apache.orc.impl.HadoopShims;
+import org.jetbrains.annotations.NotNull;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * This tracks the keys for reading encrypted columns.
+ */
+public class ReaderEncryptionKey implements EncryptionKey {
+  private final String name;
+  private final int version;
+  private final EncryptionAlgorithm algorithm;
+  private final List<ReaderEncryptionVariant> roots = new ArrayList<>();
+
+  /**
+   * Store the state of whether we've tried to decrypt a local key using this
+   * key or not. If it fails the first time, we assume the user doesn't have
+   * permission and move on. However, we don't want to retry the same failed
+   * key over and over again.
+   */
+  public enum State {
+    UNTRIED,
+    FAILURE,
+    SUCCESS
+  }
+
+  private State state = State.UNTRIED;
+
+  public ReaderEncryptionKey(OrcProto.EncryptionKey key) {
+    name = key.getKeyName();
+    version = key.getKeyVersion();
+    algorithm =
+        EncryptionAlgorithm.fromSerialization(key.getAlgorithm().getNumber());
+  }
+
+  @Override
+  public String getKeyName() {
+    return name;
+  }
+
+  @Override
+  public int getKeyVersion() {
+    return version;
+  }
+
+  @Override
+  public EncryptionAlgorithm getAlgorithm() {
+    return algorithm;
+  }
+
+  @Override
+  public ReaderEncryptionVariant[] getEncryptionRoots() {
+    return roots.toArray(new ReaderEncryptionVariant[roots.size()]);
+  }
+
+  public HadoopShims.KeyMetadata getMetadata() {
+    return new HadoopShims.KeyMetadata(name, version, algorithm);
+  }
+
+  public State getState() {
+    return state;
+  }
+
+  public void setFailure() {
+    state = State.FAILURE;
+  }
+
+  public void setSuccess() {
+    if (state == State.FAILURE) {
+      throw new IllegalStateException("Key " + name + " had already failed.");
+    }
+    state = State.SUCCESS;
+  }
+
+  void addVariant(ReaderEncryptionVariant newVariant) {
+    roots.add(newVariant);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other == null || getClass() != other.getClass()) {
+      return false;
+    } else if (other == this) {
+      return true;
+    } else {
+      return compareTo((EncryptionKey) other) == 0;
+    }
+  }
+
+  @Override
+  public int hashCode() {
+    return name.hashCode() * 127 + version * 7 + algorithm.hashCode();
+  }
+
+  @Override
+  public int compareTo(@NotNull EncryptionKey other) {
+    int result = name.compareTo(other.getKeyName());
+    if (result == 0) {
+      result = Integer.compare(version, other.getKeyVersion());
+    }
+    return result;
+  }
+
+  @Override
+  public String toString() {
+    return name + "@" + version + " w/ " + algorithm;
+  }
+}
diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java
new file mode 100644
index 0000000..255952d
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl.reader;
+
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.orc.EncryptionAlgorithm;
+import org.apache.orc.EncryptionVariant;
+import org.apache.orc.OrcProto;
+import org.apache.orc.StripeInformation;
+import org.apache.orc.TypeDescription;
+import org.apache.orc.impl.HadoopShims;
+import org.apache.orc.impl.LocalKey;
+import org.jetbrains.annotations.NotNull;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.security.Key;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Information about an encrypted column.
+ */
+public class ReaderEncryptionVariant implements EncryptionVariant {
+  private static final Logger LOG =
+      LoggerFactory.getLogger(ReaderEncryptionVariant.class);
+  private final HadoopShims.KeyProvider provider;
+  private final ReaderEncryptionKey key;
+  private final TypeDescription column;
+  private final int variantId;
+  private final LocalKey[] localKeys;
+  private final LocalKey footerKey;
+
+  /**
+   * Create a reader's view of an encryption variant.
+   * @param key the encryption key description
+   * @param variantId the id of the variant (0..N-1)
+   * @param proto the serialized description of the variant
+   * @param schema the file schema
+   * @param stripes the stripe information
+   * @param provider the key provider
+   */
+  public ReaderEncryptionVariant(ReaderEncryptionKey key,
+                                 int variantId,
+                                 OrcProto.EncryptionVariant proto,
+                                 TypeDescription schema,
+                                 List<StripeInformation> stripes,
+                                 HadoopShims.KeyProvider provider) {
+    this.key = key;
+    this.variantId = variantId;
+    this.provider = provider;
+    this.column = proto.hasRoot() ? schema.findSubtype(proto.getRoot()) : null;
+    this.localKeys = new LocalKey[stripes.size()];
+    HashMap<BytesWritable, LocalKey> cache = new HashMap<>();
+    for(int s=0; s < localKeys.length; ++s) {
+      StripeInformation stripe = stripes.get(s);
+      localKeys[s] = getCachedKey(cache, key.getAlgorithm(),
+          stripe.getEncryptedLocalKeys()[variantId]);
+    }
+    if (proto.hasEncryptedKey()) {
+      footerKey = getCachedKey(cache, key.getAlgorithm(),
+          proto.getEncryptedKey().toByteArray());
+    } else {
+      footerKey = null;
+    }
+    key.addVariant(this);
+  }
+
+  @Override
+  public ReaderEncryptionKey getKeyDescription() {
+    return key;
+  }
+
+  @Override
+  public TypeDescription getRoot() {
+    return column;
+  }
+
+  @Override
+  public int getVariantId() {
+    return variantId;
+  }
+
+  /**
+   * Deduplicate the local keys so that we only decrypt each local key once.
+   * @param cache the cache to use
+   * @param encrypted the encrypted key
+   * @return the local key
+   */
+  private static LocalKey getCachedKey(Map<BytesWritable, LocalKey> cache,
+                                       EncryptionAlgorithm algorithm,
+                                       byte[] encrypted) {
+    // wrap byte array in BytesWritable to get equality and hash
+    BytesWritable wrap = new BytesWritable(encrypted);
+    LocalKey result = cache.get(wrap);
+    if (result == null) {
+      result = new LocalKey(algorithm, null, encrypted);
+      cache.put(wrap, result);
+    }
+    return result;
+  }
+
+  private Key getDecryptedKey(LocalKey localKey) throws IOException {
+    Key result = localKey.getDecryptedKey();
+    if (result == null) {
+      switch (this.key.getState()) {
+      case UNTRIED:
+        try {
+          result = provider.decryptLocalKey(key.getMetadata(),
+              localKey.getEncryptedKey());
+        } catch (IOException ioe) {
+          LOG.info("Can't decrypt using key {}", key);
+        }
+        if (result != null) {
+          localKey.setDecryptedKey(result);
+          key.setSuccess();
+        } else {
+          key.setFailure();
+        }
+        break;
+      case SUCCESS:
+        result = provider.decryptLocalKey(key.getMetadata(),
+            localKey.getEncryptedKey());
+        if (result == null) {
+          throw new IOException("Can't decrypt local key " + key);
+        }
+        localKey.setDecryptedKey(result);
+        break;
+      case FAILURE:
+        return null;
+      }
+    }
+    return result;
+  }
+
+  @Override
+  public Key getFileFooterKey() throws IOException {
+    return getDecryptedKey(footerKey);
+  }
+
+  @Override
+  public Key getStripeKey(long stripe) throws IOException {
+    return getDecryptedKey(localKeys[(int) stripe]);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other == null || other.getClass() != getClass()) {
+      return false;
+    } else {
+      return compareTo((EncryptionVariant) other) == 0;
+    }
+  }
+
+  @Override
+  public int hashCode() {
+    return key.hashCode() * 127 + column.getId();
+  }
+
+  @Override
+  public int compareTo(@NotNull EncryptionVariant other) {
+    if (other == this) {
+      return 0;
+    } else if (key == other.getKeyDescription()) {
+      return Integer.compare(column.getId(), other.getRoot().getId());
+    } else if (key == null) {
+      return -1;
+    } else {
+      return key.compareTo(other.getKeyDescription());
+    }
+  }
+}
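
Per-stripe data keys are resolved lazily through getStripeKey, which returns null once the key has been marked as FAILURE. A minimal usage sketch (the variant and stripes variables are assumptions; getStripeKey can throw IOException):

    for (StripeInformation stripe : stripes) {
      Key stripeKey = variant.getStripeKey(stripe.getStripeId());
      if (stripeKey == null) {
        // the KeyProvider denied this key earlier; skip the encrypted streams
        continue;
      }
      // stripeKey, combined with the per-stream IV, decrypts this stripe's
      // streams for the columns under variant.getRoot()
    }
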
diff --git a/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java b/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java
index 2e78196..9eaaa61 100644
--- a/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java
+++ b/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java
@@ -2104,7 +2104,7 @@ public class TestRecordReaderImpl {
         .setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build());
     encodings.add(OrcProto.ColumnEncoding.newBuilder()
         .setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build());
-    boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe),
+    boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe, 0, 1, null),
         indexes, null, encodings, null, false);
     assertEquals(4, rows.length);
     assertEquals(false, rows[0]);
@@ -2150,7 +2150,7 @@ public class TestRecordReaderImpl {
         .setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build());
     encodings.add(OrcProto.ColumnEncoding.newBuilder()
         .setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build());
-    boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe),
+    boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe, 0, 1, null),
         indexes, null, encodings, null, false);
     assertEquals(3, rows.length);
     assertEquals(false, rows[0]);