You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@orc.apache.org by om...@apache.org on 2019/07/03 16:33:53 UTC
[orc] branch master updated: ORC-523: Update ReaderImpl to work
with column encryption.
This is an automated email from the ASF dual-hosted git repository.
omalley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/orc.git
The following commit(s) were added to refs/heads/master by this push:
new 33916ab ORC-523: Update ReaderImpl to work with column encryption.
33916ab is described below
commit 33916ab11504f7d3b4b76acb2f97d91e2c874523
Author: Owen O'Malley <om...@apache.org>
AuthorDate: Mon Jun 24 13:53:19 2019 -0700
ORC-523: Update ReaderImpl to work with column encryption.
Fixes #408
Signed-off-by: Owen O'Malley <om...@apache.org>
---
java/core/src/java/org/apache/orc/OrcUtils.java | 14 +-
.../src/java/org/apache/orc/StripeInformation.java | 21 +++
.../core/src/java/org/apache/orc/impl/OrcTail.java | 29 ++--
.../src/java/org/apache/orc/impl/ReaderImpl.java | 152 ++++++++++++++---
.../apache/orc/impl/mask/SHA256MaskFactory.java | 14 +-
.../apache/orc/impl/reader/ReaderEncryption.java | 145 ++++++++++++++++
.../orc/impl/reader/ReaderEncryptionKey.java | 132 ++++++++++++++
.../orc/impl/reader/ReaderEncryptionVariant.java | 190 +++++++++++++++++++++
.../org/apache/orc/impl/TestRecordReaderImpl.java | 4 +-
9 files changed, 646 insertions(+), 55 deletions(-)
diff --git a/java/core/src/java/org/apache/orc/OrcUtils.java b/java/core/src/java/org/apache/orc/OrcUtils.java
index 220fa13..0ba46dc 100644
--- a/java/core/src/java/org/apache/orc/OrcUtils.java
+++ b/java/core/src/java/org/apache/orc/OrcUtils.java
@@ -616,9 +616,17 @@ public class OrcUtils {
public static List<StripeInformation> convertProtoStripesToStripes(
List<OrcProto.StripeInformation> stripes) {
- List<StripeInformation> result = new ArrayList<StripeInformation>(stripes.size());
- for (OrcProto.StripeInformation info : stripes) {
- result.add(new ReaderImpl.StripeInformationImpl(info));
+ List<StripeInformation> result = new ArrayList<>(stripes.size());
+ long previousStripeId = -1;
+ byte[][] previousKeys = null;
+ long stripeId = 0;
+ for (OrcProto.StripeInformation stripeProto: stripes) {
+ ReaderImpl.StripeInformationImpl stripe =
+ new ReaderImpl.StripeInformationImpl(stripeProto, stripeId++,
+ previousStripeId, previousKeys);
+ result.add(stripe);
+ previousStripeId = stripe.getEncryptionStripeId();
+ previousKeys = stripe.getEncryptedLocalKeys();
}
return result;
}
diff --git a/java/core/src/java/org/apache/orc/StripeInformation.java b/java/core/src/java/org/apache/orc/StripeInformation.java
index 38f7eba..6490d6b 100644
--- a/java/core/src/java/org/apache/orc/StripeInformation.java
+++ b/java/core/src/java/org/apache/orc/StripeInformation.java
@@ -56,4 +56,25 @@ public interface StripeInformation {
* @return a count of the number of rows
*/
long getNumberOfRows();
+
+ /**
+ * Get the index of this stripe in the current file.
+ * @return 0 to number_of_stripes - 1
+ */
+ long getStripeId();
+
+ /**
+ * Get the original stripe id that was used when the stripe was originally
+ * written. This is only different than getStripeId in merged files.
+ * @return the original stripe id
+ */
+ long getEncryptionStripeId();
+
+ /**
+ * Get the encrypted keys starting from this stripe until overridden by
+ * a new set in a following stripe. The top level array is one for each
+ * encryption variant. Each element is an encrypted key.
+ * @return the array of encrypted keys
+ */
+ byte[][] getEncryptedLocalKeys();
}
diff --git a/java/core/src/java/org/apache/orc/impl/OrcTail.java b/java/core/src/java/org/apache/orc/impl/OrcTail.java
index 9e8a5f2..2765473 100644
--- a/java/core/src/java/org/apache/orc/impl/OrcTail.java
+++ b/java/core/src/java/org/apache/orc/impl/OrcTail.java
@@ -1,4 +1,4 @@
-/**
+/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
@@ -27,6 +27,7 @@ import org.apache.orc.CompressionCodec;
import org.apache.orc.CompressionKind;
import org.apache.orc.OrcFile;
import org.apache.orc.OrcProto;
+import org.apache.orc.OrcUtils;
import org.apache.orc.StripeInformation;
import org.apache.orc.StripeStatistics;
@@ -77,11 +78,7 @@ public final class OrcTail {
}
public List<StripeInformation> getStripes() {
- List<StripeInformation> result = new ArrayList<>(fileTail.getFooter().getStripesCount());
- for (OrcProto.StripeInformation stripeProto : fileTail.getFooter().getStripesList()) {
- result.add(new ReaderImpl.StripeInformationImpl(stripeProto));
- }
- return result;
+ return OrcUtils.convertProtoStripesToStripes(getFooter().getStripesList());
}
public CompressionKind getCompressionKind() {
@@ -92,9 +89,9 @@ public final class OrcTail {
return (int) fileTail.getPostscript().getCompressionBlockSize();
}
- public List<StripeStatistics> getStripeStatistics() throws IOException {
+ public List<StripeStatistics> getStripeStatistics(InStream.StreamOptions options) throws IOException {
List<StripeStatistics> result = new ArrayList<>();
- List<OrcProto.StripeStatistics> ssProto = getStripeStatisticsProto();
+ List<OrcProto.StripeStatistics> ssProto = getStripeStatisticsProto(options);
if (ssProto != null) {
for (OrcProto.StripeStatistics ss : ssProto) {
result.add(new StripeStatistics(ss.getColStatsList()));
@@ -103,17 +100,12 @@ public final class OrcTail {
return result;
}
- public List<OrcProto.StripeStatistics> getStripeStatisticsProto() throws IOException {
+ public List<OrcProto.StripeStatistics> getStripeStatisticsProto(InStream.StreamOptions options) throws IOException {
if (serializedTail == null) return null;
if (metadata == null) {
- CompressionCodec codec = OrcCodecPool.getCodec(getCompressionKind());
- try {
- metadata = extractMetadata(serializedTail, 0,
- (int) fileTail.getPostscript().getMetadataLength(),
- InStream.options().withCodec(codec).withBufferSize(getCompressionBufferSize()));
- } finally {
- OrcCodecPool.returnCodec(getCompressionKind(), codec);
- }
+ metadata = extractMetadata(serializedTail, 0,
+ (int) fileTail.getPostscript().getMetadataLength(),
+ options);
// clear does not clear the contents but sets position to 0 and limit = capacity
serializedTail.clear();
}
@@ -137,7 +129,6 @@ public final class OrcTail {
OrcProto.Footer.Builder footerBuilder = OrcProto.Footer.newBuilder(fileTail.getFooter());
footerBuilder.clearStatistics();
fileTailBuilder.setFooter(footerBuilder.build());
- OrcProto.FileTail result = fileTailBuilder.build();
- return result;
+ return fileTailBuilder.build();
}
}
diff --git a/java/core/src/java/org/apache/orc/impl/ReaderImpl.java b/java/core/src/java/org/apache/orc/impl/ReaderImpl.java
index ddd53b7..d1311b9 100644
--- a/java/core/src/java/org/apache/orc/impl/ReaderImpl.java
+++ b/java/core/src/java/org/apache/orc/impl/ReaderImpl.java
@@ -20,13 +20,15 @@ package org.apache.orc.impl;
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.security.Key;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
-import org.apache.hadoop.fs.FileStatus;
+import org.apache.orc.EncryptionAlgorithm;
+import org.apache.orc.EncryptionKey;
import org.apache.orc.CompressionKind;
import org.apache.orc.DataMaskDescription;
import org.apache.orc.EncryptionKey;
@@ -43,8 +45,8 @@ import org.apache.orc.FileFormatException;
import org.apache.orc.StripeInformation;
import org.apache.orc.StripeStatistics;
import org.apache.orc.UnknownFormatException;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.apache.orc.impl.reader.ReaderEncryption;
+import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
@@ -54,6 +56,9 @@ import org.apache.hadoop.io.Text;
import org.apache.orc.OrcProto;
import com.google.protobuf.CodedInputStream;
+import org.apache.orc.impl.reader.ReaderEncryptionVariant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
public class ReaderImpl implements Reader {
@@ -78,6 +83,7 @@ public class ReaderImpl implements Reader {
private final List<StripeInformation> stripes;
protected final int rowIndexStride;
private final long contentLength, numberOfRows;
+ private final ReaderEncryption encryption;
private long deserializedSize = -1;
protected final Configuration conf;
@@ -89,10 +95,30 @@ public class ReaderImpl implements Reader {
public static class StripeInformationImpl
implements StripeInformation {
+ private final long stripeId;
+ private final long originalStripeId;
+ private final byte[][] encryptedKeys;
private final OrcProto.StripeInformation stripe;
- public StripeInformationImpl(OrcProto.StripeInformation stripe) {
+ public StripeInformationImpl(OrcProto.StripeInformation stripe,
+ long stripeId,
+ long previousOriginalStripeId,
+ byte[][] previousKeys) {
this.stripe = stripe;
+ this.stripeId = stripeId;
+ if (stripe.hasEncryptStripeId()) {
+ originalStripeId = stripe.getEncryptStripeId();
+ } else {
+ originalStripeId = previousOriginalStripeId + 1;
+ }
+ if (stripe.getEncryptedLocalKeysCount() != 0) {
+ encryptedKeys = new byte[stripe.getEncryptedLocalKeysCount()][];
+ for(int v=0; v < encryptedKeys.length; ++v) {
+ encryptedKeys[v] = stripe.getEncryptedLocalKeys(v).toByteArray();
+ }
+ } else {
+ encryptedKeys = previousKeys;
+ }
}
@Override
@@ -126,6 +152,21 @@ public class ReaderImpl implements Reader {
}
@Override
+ public long getStripeId() {
+ return stripeId;
+ }
+
+ @Override
+ public long getEncryptionStripeId() {
+ return originalStripeId;
+ }
+
+ @Override
+ public byte[][] getEncryptedLocalKeys() {
+ return encryptedKeys;
+ }
+
+ @Override
public String toString() {
return "offset: " + getOffset() + " data: " + getDataLength() +
" rows: " + getNumberOfRows() + " tail: " + getFooterLength() +
@@ -221,20 +262,25 @@ public class ReaderImpl implements Reader {
@Override
public EncryptionKey[] getColumnEncryptionKeys() {
- // TODO
- return new EncryptionKey[0];
+ return encryption.getKeys();
}
@Override
public DataMaskDescription[] getDataMasks() {
- // TODO
- return new DataMaskDescription[0];
+ return encryption.getMasks();
}
@Override
- public EncryptionVariant[] getEncryptionVariants() {
- // TODO
- return new EncryptionVariant[0];
+ public ReaderEncryptionVariant[] getEncryptionVariants() {
+ return encryption.getVariants();
+ }
+
+ /**
+ * Internal access to our view of the encryption.
+ * @return the encryption information for this reader.
+ */
+ public ReaderEncryption getEncryption() {
+ return encryption;
}
@Override
@@ -244,7 +290,61 @@ public class ReaderImpl implements Reader {
@Override
public ColumnStatistics[] getStatistics() {
- return deserializeStats(schema, fileStats);
+ ColumnStatistics[] result = deserializeStats(schema, fileStats); // stats from the plain footer
+ try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) {
+ InStream.StreamOptions compression = InStream.options();
+ if (codec != null) {
+ compression.withCodec(codec).withBufferSize(bufferSize);
+ }
+ for (ReaderEncryptionVariant variant : encryption.getVariants()) {
+ ColumnStatistics[] overrides;
+ try {
+ overrides = decryptFileStats(variant, compression,
+ tail.getFooter());
+ } catch (IOException e) {
+ throw new RuntimeException("Can't decrypt file stats for " + path +
+ " with " + variant.getKeyDescription(), e); // chain the cause so the failure is diagnosable
+ }
+ if (overrides != null) { // decryptFileStats returns null when the footer key is unavailable
+ for (int i = 0; i < overrides.length; ++i) {
+ result[variant.getRoot().getId() + i] = overrides[i]; // overlay onto the variant's column range
+ }
+ }
+ }
+ }
+ return result;
+ }
+
+ public static ColumnStatistics[] decryptFileStats(ReaderEncryptionVariant encryption,
+ InStream.StreamOptions compression,
+ OrcProto.Footer footer
+ ) throws IOException {
+ Key key = encryption.getFileFooterKey();
+ if (key == null) {
+ return null;
+ } else {
+ OrcProto.EncryptionVariant protoVariant =
+ footer.getEncryption().getVariants(encryption.getVariantId());
+ byte[] bytes = protoVariant.getFileStatistics().toByteArray();
+ BufferChunk buffer = new BufferChunk(ByteBuffer.wrap(bytes), 0);
+ EncryptionAlgorithm algorithm = encryption.getKeyDescription().getAlgorithm();
+ byte[] iv = new byte[algorithm.getIvLength()];
+ CryptoUtils.modifyIvForStream(encryption.getRoot().getId(),
+ OrcProto.Stream.Kind.FILE_STATISTICS, footer.getStripesCount())
+ .accept(iv);
+ InStream.StreamOptions options = new InStream.StreamOptions(compression)
+ .withEncryption(algorithm, key, iv);
+ InStream in = InStream.create("encrypted file stats", buffer,
+ bytes.length, 0, options);
+ OrcProto.FileStatistics decrypted = OrcProto.FileStatistics.parseFrom(in);
+ ColumnStatistics[] result = new ColumnStatistics[decrypted.getColumnCount()];
+ TypeDescription root = encryption.getRoot();
+ for(int i= 0; i < result.length; ++i){
+ result[i] = ColumnStatisticsImpl.deserialize(root.findSubtype(root.getId() + i),
+ decrypted.getColumn(i));
+ }
+ return result;
+ }
}
public static ColumnStatistics[] deserializeStats(
@@ -351,12 +451,17 @@ public class ReaderImpl implements Reader {
this.writerVersion =
OrcFile.WriterVersion.from(writer, fileMetadata.getWriterVersionNum());
this.types = fileMetadata.getTypes();
+ OrcUtils.isValidTypeTree(this.types, 0);
+ this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0);
this.rowIndexStride = fileMetadata.getRowIndexStride();
this.contentLength = fileMetadata.getContentLength();
this.numberOfRows = fileMetadata.getNumberOfRows();
this.fileStats = fileMetadata.getFileStats();
this.stripes = fileMetadata.getStripes();
this.userMetadata = null; // not cached and not needed here
+ // FileMetadata is obsolete and doesn't support encryption
+ this.encryption = new ReaderEncryption(null, schema, stripes,
+ options.getKeyProvider(), conf);
} else {
OrcTail orcTail = options.getOrcTail();
if (orcTail == null) {
@@ -371,6 +476,8 @@ public class ReaderImpl implements Reader {
this.metadataSize = tail.getMetadataSize();
this.versionList = tail.getPostScript().getVersionList();
this.types = tail.getFooter().getTypesList();
+ OrcUtils.isValidTypeTree(this.types, 0);
+ this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0);
this.rowIndexStride = tail.getFooter().getRowIndexStride();
this.contentLength = tail.getFooter().getContentLength();
this.numberOfRows = tail.getFooter().getNumberOfRows();
@@ -378,10 +485,14 @@ public class ReaderImpl implements Reader {
this.fileStats = tail.getFooter().getStatisticsList();
this.writerVersion = tail.getWriterVersion();
this.stripes = tail.getStripes();
- this.stripeStats = tail.getStripeStatisticsProto();
+ try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) {
+ InStream.StreamOptions compress = InStream.options().withCodec(codec)
+ .withBufferSize(bufferSize);
+ this.stripeStats = tail.getStripeStatisticsProto(compress);
+ }
+ this.encryption = new ReaderEncryption(tail.getFooter(), schema,
+ stripes, options.getKeyProvider(), conf);
}
- OrcUtils.isValidTypeTree(this.types, 0);
- this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0);
}
protected FileSystem getFileSystem() throws IOException {
@@ -559,12 +670,9 @@ public class ReaderImpl implements Reader {
ByteBuffer footerBuffer = buffer.slice();
buffer.reset();
OrcProto.Footer footer;
- CompressionCodec codec = OrcCodecPool.getCodec(compressionKind);
- try {
+ try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)){
footer = extractFooter(footerBuffer, 0, footerSize,
InStream.options().withCodec(codec).withBufferSize(bufferSize));
- } finally {
- OrcCodecPool.returnCodec(compressionKind, codec);
}
fileTailBuilder.setFooter(footer);
} catch (Throwable thr) {
@@ -604,7 +712,6 @@ public class ReaderImpl implements Reader {
return new RecordReaderImpl(this, options);
}
-
@Override
public long getRawDataSize() {
// if the deserializedSize is not computed, then compute it, else
@@ -757,12 +864,9 @@ public class ReaderImpl implements Reader {
@Override
public List<StripeStatistics> getStripeStatistics() throws IOException {
if (metadata == null) {
- CompressionCodec codec = OrcCodecPool.getCodec(compressionKind);
- try {
+ try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) {
metadata = extractMetadata(tail.getSerializedTail(), 0, metadataSize,
InStream.options().withCodec(codec).withBufferSize(bufferSize));
- } finally {
- OrcCodecPool.returnCodec(compressionKind, codec);
}
}
if (stripeStats == null) {
diff --git a/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java b/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java
index c300500..b3445f9 100644
--- a/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java
+++ b/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java
@@ -62,9 +62,9 @@ import java.util.Arrays;
*/
public class SHA256MaskFactory extends MaskFactory {
- final MessageDigest md;
+ private final MessageDigest md;
- public SHA256MaskFactory(final String... params) {
+ SHA256MaskFactory() {
super();
try {
md = MessageDigest.getInstance("SHA-256");
@@ -138,9 +138,9 @@ public class SHA256MaskFactory extends MaskFactory {
/**
* Helper function to mask binary data with it's SHA-256 hash.
*
- * @param source
- * @param row
- * @param target
+ * @param source the source data
+ * @param row the row that we are translating
+ * @param target the output data
*/
void maskBinary(final BytesColumnVector source, final int row,
final BytesColumnVector target) {
@@ -207,7 +207,7 @@ public class SHA256MaskFactory extends MaskFactory {
final TypeDescription schema;
/* create an instance */
- public StringMask(TypeDescription schema) {
+ StringMask(TypeDescription schema) {
super();
this.schema = schema;
}
@@ -254,7 +254,7 @@ public class SHA256MaskFactory extends MaskFactory {
class BinaryMask implements DataMask {
/* create an instance */
- public BinaryMask() {
+ BinaryMask() {
super();
}
diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java
new file mode 100644
index 0000000..fe54c49
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl.reader;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.orc.OrcProto;
+import org.apache.orc.StripeInformation;
+import org.apache.orc.TypeDescription;
+import org.apache.orc.impl.HadoopShims;
+import org.apache.orc.impl.HadoopShimsFactory;
+import org.apache.orc.impl.MaskDescriptionImpl;
+
+import java.io.IOException;
+import java.security.SecureRandom;
+import java.util.Arrays;
+import java.util.List;
+
+public class ReaderEncryption {
+ private final HadoopShims.KeyProvider keyProvider;
+ private final ReaderEncryptionKey[] keys;
+ private final MaskDescriptionImpl[] masks;
+ private final ReaderEncryptionVariant[] variants;
+ // Mapping from each column to the next variant to try for that column.
+ // A null entry means no encryption (or no variants left to try).
+ private final ReaderEncryptionVariant[] columnVariants;
+
+ public ReaderEncryption() throws IOException {
+ this(null, null, null, null, null);
+ }
+
+ public ReaderEncryption(OrcProto.Footer footer,
+ TypeDescription schema,
+ List<StripeInformation> stripes,
+ HadoopShims.KeyProvider provider,
+ Configuration conf) throws IOException {
+ if (footer == null || !footer.hasEncryption()) {
+ keyProvider = null;
+ keys = new ReaderEncryptionKey[0];
+ masks = new MaskDescriptionImpl[0];
+ variants = new ReaderEncryptionVariant[0];
+ columnVariants = null;
+ } else {
+ keyProvider = provider != null ? provider :
+ HadoopShimsFactory.get().getKeyProvider(conf, new SecureRandom());
+ OrcProto.Encryption encrypt = footer.getEncryption();
+ masks = new MaskDescriptionImpl[encrypt.getMaskCount()];
+ for(int m=0; m < masks.length; ++m) {
+ masks[m] = new MaskDescriptionImpl(m, encrypt.getMask(m));
+ }
+ keys = new ReaderEncryptionKey[encrypt.getKeyCount()];
+ for(int k=0; k < keys.length; ++k) {
+ keys[k] = new ReaderEncryptionKey(encrypt.getKey(k));
+ }
+ variants = new ReaderEncryptionVariant[encrypt.getVariantsCount()];
+ for(int v=0; v < variants.length; ++v) {
+ OrcProto.EncryptionVariant variant = encrypt.getVariants(v);
+ variants[v] = new ReaderEncryptionVariant(keys[variant.getKey()], v,
+ variant, schema, stripes, keyProvider);
+ }
+ columnVariants = new ReaderEncryptionVariant[schema.getMaximumId() + 1];
+ for(int v = 0; v < variants.length; ++v) {
+ TypeDescription root = variants[v].getRoot();
+ for(int c = root.getId(); c <= root.getMaximumId(); ++c) {
+ if (columnVariants[c] == null) {
+ columnVariants[c] = variants[v];
+ }
+ }
+ }
+ }
+ }
+
+ public MaskDescriptionImpl[] getMasks() {
+ return masks;
+ }
+
+ public ReaderEncryptionKey[] getKeys() {
+ return keys;
+ }
+
+ public ReaderEncryptionVariant[] getVariants() {
+ return variants;
+ }
+
+ /**
+ * Find the next possible variant in this file for the given column.
+ * @param column the column to find a variant for
+ * @param lastVariant the previous variant that we looked at
+ * @return the next variant or null if there are none
+ */
+ private ReaderEncryptionVariant findNextVariant(int column,
+ int lastVariant) {
+ for(int v = lastVariant + 1; v < variants.length; ++v) {
+ TypeDescription root = variants[v].getRoot();
+ if (root.getId() <= column && column <= root.getMaximumId()) {
+ return variants[v];
+ }
+ }
+ return null;
+ }
+
+ /**
+ * Get the variant for a given column that the user has access to.
+ * If we haven't tried a given key, try to decrypt this variant's footer key
+ * to see if the KeyProvider will give it to us. If not, continue to the
+ * next variant.
+ * @param column the column id
+ * @return null for no encryption or the encryption variant
+ */
+ public ReaderEncryptionVariant getVariant(int column) throws IOException {
+ if (keyProvider != null) {
+ while (columnVariants[column] != null) {
+ ReaderEncryptionVariant result = columnVariants[column];
+ switch (result.getKeyDescription().getState()) {
+ case FAILURE:
+ break;
+ case SUCCESS:
+ return result;
+ case UNTRIED:
+ // try to get the footer key, to see if we have access
+ if (result.getFileFooterKey() != null) {
+ return result;
+ }
+ }
+ columnVariants[column] = findNextVariant(column, result.getVariantId());
+ }
+ }
+ return null;
+ }
+}
diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionKey.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionKey.java
new file mode 100644
index 0000000..407bebb
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionKey.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl.reader;
+
+import org.apache.orc.EncryptionKey;
+import org.apache.orc.EncryptionAlgorithm;
+import org.apache.orc.OrcProto;
+import org.apache.orc.impl.HadoopShims;
+import org.jetbrains.annotations.NotNull;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * This tracks the keys for reading encrypted columns.
+ */
+public class ReaderEncryptionKey implements EncryptionKey {
+ private final String name;
+ private final int version;
+ private final EncryptionAlgorithm algorithm;
+ private final List<ReaderEncryptionVariant> roots = new ArrayList<>();
+
+ /**
+ * Store the state of whether we've tried to decrypt a local key using this
+ * key or not. If it fails the first time, we assume the user doesn't have
+ * permission and move on. However, we don't want to retry the same failed
+ * key over and over again.
+ */
+ public enum State {
+ UNTRIED,
+ FAILURE,
+ SUCCESS
+ }
+
+ private State state = State.UNTRIED;
+
+ public ReaderEncryptionKey(OrcProto.EncryptionKey key) {
+ name = key.getKeyName();
+ version = key.getKeyVersion();
+ algorithm =
+ EncryptionAlgorithm.fromSerialization(key.getAlgorithm().getNumber());
+ }
+
+ @Override
+ public String getKeyName() {
+ return name;
+ }
+
+ @Override
+ public int getKeyVersion() {
+ return version;
+ }
+
+ @Override
+ public EncryptionAlgorithm getAlgorithm() {
+ return algorithm;
+ }
+
+ @Override
+ public ReaderEncryptionVariant[] getEncryptionRoots() {
+ return roots.toArray(new ReaderEncryptionVariant[roots.size()]);
+ }
+
+ public HadoopShims.KeyMetadata getMetadata() {
+ return new HadoopShims.KeyMetadata(name, version, algorithm);
+ }
+
+ public State getState() {
+ return state;
+ }
+
+ public void setFailure() {
+ state = State.FAILURE;
+ }
+
+ public void setSucess() {
+ if (state == State.FAILURE) {
+ throw new IllegalStateException("Key " + name + " had already failed.");
+ }
+ state = State.SUCCESS;
+ }
+
+ void addVariant(ReaderEncryptionVariant newVariant) {
+ roots.add(newVariant);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == null || getClass() != other.getClass()) {
+ return false;
+ } else if (other == this) {
+ return true;
+ } else {
+ return compareTo((EncryptionKey) other) == 0;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return name.hashCode() * 127 + version * 7; // only the fields equals()/compareTo() use
+ }
+
+ @Override
+ public int compareTo(@NotNull EncryptionKey other) {
+ int result = name.compareTo(other.getKeyName());
+ if (result == 0) {
+ result = Integer.compare(version, other.getKeyVersion());
+ }
+ return result;
+ }
+
+ @Override
+ public String toString() {
+ return name + "@" + version + " w/ " + algorithm;
+ }
+}
diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java
new file mode 100644
index 0000000..255952d
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl.reader;
+
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.orc.EncryptionAlgorithm;
+import org.apache.orc.EncryptionVariant;
+import org.apache.orc.OrcProto;
+import org.apache.orc.StripeInformation;
+import org.apache.orc.TypeDescription;
+import org.apache.orc.impl.HadoopShims;
+import org.apache.orc.impl.LocalKey;
+import org.jetbrains.annotations.NotNull;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.security.Key;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Information about an encrypted column.
+ */
+public class ReaderEncryptionVariant implements EncryptionVariant {
+ private static final Logger LOG =
+ LoggerFactory.getLogger(ReaderEncryptionVariant.class);
+ private final HadoopShims.KeyProvider provider;
+ private final ReaderEncryptionKey key;
+ private final TypeDescription column;
+ private final int variantId;
+ private final LocalKey[] localKeys;
+ private final LocalKey footerKey;
+
+ /**
+ * Create a reader's view of an encryption variant.
+ * @param key the encryption key description
+ * @param variantId the id of the variant (0..N-1)
+ * @param proto the serialized description of the variant
+ * @param schema the file schema
+ * @param stripes the stripe information
+ * @param provider the key provider
+ */
+ public ReaderEncryptionVariant(ReaderEncryptionKey key,
+ int variantId,
+ OrcProto.EncryptionVariant proto,
+ TypeDescription schema,
+ List<StripeInformation> stripes,
+ HadoopShims.KeyProvider provider) {
+ this.key = key;
+ this.variantId = variantId;
+ this.provider = provider;
+ this.column = proto.hasRoot() ? schema.findSubtype(proto.getRoot()) : null;
+ this.localKeys = new LocalKey[stripes.size()];
+ HashMap<BytesWritable, LocalKey> cache = new HashMap<>();
+ for(int s=0; s < localKeys.length; ++s) {
+ StripeInformation stripe = stripes.get(s);
+ localKeys[s] = getCachedKey(cache, key.getAlgorithm(),
+ stripe.getEncryptedLocalKeys()[variantId]);
+ }
+ if (proto.hasEncryptedKey()) {
+ footerKey = getCachedKey(cache, key.getAlgorithm(),
+ proto.getEncryptedKey().toByteArray());
+ } else {
+ footerKey = null;
+ }
+ key.addVariant(this);
+ }
+
+ @Override
+ public ReaderEncryptionKey getKeyDescription() {
+ return key;
+ }
+
+ @Override
+ public TypeDescription getRoot() {
+ return column;
+ }
+
+ @Override
+ public int getVariantId() {
+ return variantId;
+ }
+
+ /**
+ * Deduplicate the local keys so that we only decrypt each local key once.
+ * @param cache the cache to use
+ * @param encrypted the encrypted key
+ * @return the local key
+ */
+ private static LocalKey getCachedKey(Map<BytesWritable, LocalKey> cache,
+ EncryptionAlgorithm algorithm,
+ byte[] encrypted) {
+ // wrap byte array in BytesWritable to get equality and hash
+ BytesWritable wrap = new BytesWritable(encrypted);
+ LocalKey result = cache.get(wrap);
+ if (result == null) {
+ result = new LocalKey(algorithm, null, encrypted);
+ cache.put(wrap, result);
+ }
+ return result;
+ }
+
+ private Key getDecryptedKey(LocalKey localKey) throws IOException {
+ Key result = localKey.getDecryptedKey();
+ if (result == null) {
+ switch (this.key.getState()) {
+ case UNTRIED:
+ try {
+ result = provider.decryptLocalKey(key.getMetadata(),
+ localKey.getEncryptedKey());
+ } catch (IOException ioe) {
+ LOG.info("Can't decrypt using key {}", key);
+ }
+ if (result != null) {
+ localKey.setDecryptedKey(result);
+ key.setSucess();
+ } else {
+ key.setFailure();
+ }
+ break;
+ case SUCCESS:
+ result = provider.decryptLocalKey(key.getMetadata(),
+ localKey.getEncryptedKey());
+ if (result == null) {
+ throw new IOException("Can't decrypt local key " + key);
+ }
+ localKey.setDecryptedKey(result);
+ break;
+ case FAILURE:
+ return null;
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public Key getFileFooterKey() throws IOException {
+ return footerKey == null ? null : getDecryptedKey(footerKey); // no footer key stored => null, not NPE
+ }
+
+ @Override
+ public Key getStripeKey(long stripe) throws IOException {
+ return getDecryptedKey(localKeys[(int) stripe]);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == null || other.getClass() != getClass()) {
+ return false;
+ } else {
+ return compareTo((EncryptionVariant) other) == 0;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return key.hashCode() * 127 + column.getId();
+ }
+
+ @Override
+ public int compareTo(@NotNull EncryptionVariant other) {
+ if (other == this) {
+ return 0;
+ } else if (key == other.getKeyDescription()) {
+ return Integer.compare(column.getId(), other.getRoot().getId());
+ } else if (key == null) {
+ return -1;
+ } else {
+ return key.compareTo(other.getKeyDescription());
+ }
+ }
+}
diff --git a/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java b/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java
index 2e78196..9eaaa61 100644
--- a/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java
+++ b/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java
@@ -2104,7 +2104,7 @@ public class TestRecordReaderImpl {
.setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build());
encodings.add(OrcProto.ColumnEncoding.newBuilder()
.setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build());
- boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe),
+ boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe, 0, 1, null),
indexes, null, encodings, null, false);
assertEquals(4, rows.length);
assertEquals(false, rows[0]);
@@ -2150,7 +2150,7 @@ public class TestRecordReaderImpl {
.setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build());
encodings.add(OrcProto.ColumnEncoding.newBuilder()
.setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build());
- boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe),
+ boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe, 0, 1, null),
indexes, null, encodings, null, false);
assertEquals(3, rows.length);
assertEquals(false, rows[0]);