You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2023/06/14 08:50:34 UTC
[doris-thirdparty] branch orc updated: [Feature] Add `StringDictFilter` callbacks. (#90)
This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch orc
in repository https://gitbox.apache.org/repos/asf/doris-thirdparty.git
The following commit(s) were added to refs/heads/orc by this push:
new a4e67d73 [Feature] Add `StringDictFilter` callbacks. (#90)
a4e67d73 is described below
commit a4e67d732e9acf3acb45e85c4cfe84d630e71ec1
Author: Qi Chen <ka...@gmail.com>
AuthorDate: Wed Jun 14 16:50:28 2023 +0800
[Feature] Add `StringDictFilter` callbacks. (#90)
---
c++/include/orc/Reader.hh | 21 +++++++--
c++/src/ColumnReader.cc | 108 ++++++++++++++++++++++++++++++++++------------
c++/src/ColumnReader.hh | 7 ++-
c++/src/Reader.cc | 62 +++++++++++++++++++-------
c++/src/Reader.hh | 24 +++++++++--
5 files changed, 171 insertions(+), 51 deletions(-)
diff --git a/c++/include/orc/Reader.hh b/c++/include/orc/Reader.hh
index 429c1ed0..81e3bf1d 100644
--- a/c++/include/orc/Reader.hh
+++ b/c++/include/orc/Reader.hh
@@ -32,6 +32,7 @@
#include <memory>
#include <set>
#include <string>
+#include <unordered_map>
#include <vector>
namespace orc {
@@ -370,6 +371,17 @@ namespace orc {
void* arg = nullptr) const = 0;
};
+ /**
+ * Callback interface for filtering whole stripes by inspecting string dictionaries.
+ * An instance may be passed to Reader::createRowReader; when it is non-null the row
+ * reader invokes the two callbacks below, in order, each time it starts a stripe.
+ */
+ class StringDictFilter {
+ public:
+ virtual ~StringDictFilter() = default;
+ /**
+ * Report which columns the filter wants dictionaries for in the stripe being started.
+ * @param current_strip_information description of the current stripe; ownership is
+ * transferred to the callee
+ * @param columnNames output list (passed in empty by the row reader) to which the
+ * implementation appends the names of the columns of interest
+ */
+ virtual void fillDictFilterColumnNames(
+ std::unique_ptr<orc::StripeInformation> current_strip_information,
+ std::list<std::string>& columnNames) const = 0;
+ /**
+ * Inspect the loaded dictionaries and decide whether the stripe can be skipped.
+ * Invoked only when at least one requested column name was resolved.
+ * @param columnNameToDictMap column name -> dictionary; it contains an entry only for
+ * requested columns that are direct children of the root struct and are read
+ * through a dictionary-encoded string reader. The pointers are non-owning: the
+ * dictionaries belong to the column readers of the current stripe.
+ * @param isStripeFiltered output flag; set it to true to make the row reader drop
+ * this stripe and advance to the next one. Implementations should always
+ * assign it, because the row reader reads the value after the call.
+ */
+ virtual void onStringDictsLoaded(
+ std::unordered_map<std::string, StringDictionary*>& columnNameToDictMap,
+ bool* isStripeFiltered) const = 0;
+ };
+
class RowReader;
/**
@@ -555,15 +567,18 @@ namespace orc {
* Create a RowReader based on this reader with the default options.
* @return a RowReader to read the rows
*/
- virtual std::unique_ptr<RowReader> createRowReader(const ORCFilter* filter = nullptr) const = 0;
+ virtual std::unique_ptr<RowReader> createRowReader(
+ const ORCFilter* filter = nullptr,
+ const StringDictFilter* stringDictFilter = nullptr) const = 0;
/**
* Create a RowReader based on this reader.
* @param options RowReader Options
* @return a RowReader to read the rows
*/
- virtual std::unique_ptr<RowReader> createRowReader(const RowReaderOptions& options,
- const ORCFilter* filter = nullptr) const = 0;
+ virtual std::unique_ptr<RowReader> createRowReader(
+ const RowReaderOptions& options, const ORCFilter* filter = nullptr,
+ const StringDictFilter* stringDictFilter = nullptr) const = 0;
/**
* Get the name of the input stream.
diff --git a/c++/src/ColumnReader.cc b/c++/src/ColumnReader.cc
index 8c6e1663..78005abc 100644
--- a/c++/src/ColumnReader.cc
+++ b/c++/src/ColumnReader.cc
@@ -715,7 +715,14 @@ namespace orc {
class StringDictionaryColumnReader : public ColumnReader {
private:
std::shared_ptr<StringDictionary> dictionary;
+ std::unordered_map<std::string, int64_t> dictValueToCode;
std::unique_ptr<RleDecoder> rle;
+ StripeStreams& stripe;
+ // std::string columnName;
+ bool dictionaryLoaded;
+ uint32_t dictSize;
+ std::unique_ptr<RleDecoder> lengthDecoder;
+ std::unique_ptr<SeekableInputStream> blobStream;
void nextInternal(ColumnVectorBatch& rowBatch, uint64_t numValues, char* notNull,
const ReadPhase& readPhase);
@@ -737,13 +744,19 @@ namespace orc {
void seekToRowGroup(std::unordered_map<uint64_t, PositionProvider>& positions,
const ReadPhase& readPhase) override;
+
+ StringDictionary* loadDictionary();
};
StringDictionaryColumnReader::StringDictionaryColumnReader(const Type& type,
- StripeStreams& stripe)
- : ColumnReader(type, stripe), dictionary(new StringDictionary(stripe.getMemoryPool())) {
+ StripeStreams& _stripe)
+ : ColumnReader(type, _stripe),
+ dictionary(new StringDictionary(_stripe.getMemoryPool())),
+ stripe(_stripe),
+ dictionaryLoaded(false),
+ dictSize(0) {
RleVersion rleVersion = convertRleVersion(stripe.getEncoding(columnId).kind());
- uint32_t dictSize = stripe.getEncoding(columnId).dictionarysize();
+ dictSize = stripe.getEncoding(columnId).dictionarysize();
std::unique_ptr<SeekableInputStream> stream =
stripe.getStream(columnId, proto::Stream_Kind_DATA, true);
if (stream == nullptr) {
@@ -754,26 +767,8 @@ namespace orc {
if (dictSize > 0 && stream == nullptr) {
throw ParseError("LENGTH stream not found in StringDictionaryColumn");
}
- std::unique_ptr<RleDecoder> lengthDecoder =
- createRleDecoder(std::move(stream), false, rleVersion, memoryPool, metrics);
- dictionary->dictionaryOffset.resize(dictSize + 1);
- int64_t* lengthArray = dictionary->dictionaryOffset.data();
- lengthDecoder->next(lengthArray + 1, dictSize, nullptr);
- lengthArray[0] = 0;
- for (uint32_t i = 1; i < dictSize + 1; ++i) {
- if (lengthArray[i] < 0) {
- throw ParseError("Negative dictionary entry length");
- }
- lengthArray[i] += lengthArray[i - 1];
- }
- int64_t blobSize = lengthArray[dictSize];
- dictionary->dictionaryBlob.resize(static_cast<uint64_t>(blobSize));
- std::unique_ptr<SeekableInputStream> blobStream =
- stripe.getStream(columnId, proto::Stream_Kind_DICTIONARY_DATA, false);
- if (blobSize > 0 && blobStream == nullptr) {
- throw ParseError("DICTIONARY_DATA stream not found in StringDictionaryColumn");
- }
- readFully(dictionary->dictionaryBlob.data(), blobSize, blobStream.get());
+ lengthDecoder = createRleDecoder(std::move(stream), false, rleVersion, memoryPool, metrics);
+ blobStream = stripe.getStream(columnId, proto::Stream_Kind_DICTIONARY_DATA, false);
}
StringDictionaryColumnReader::~StringDictionaryColumnReader() {
@@ -802,11 +797,12 @@ namespace orc {
// update the notNull from the parent class
notNull = rowBatch.hasNulls ? rowBatch.notNull.data() : nullptr;
StringVectorBatch& byteBatch = dynamic_cast<StringVectorBatch&>(rowBatch);
- char* blob = dictionary->dictionaryBlob.data();
- int64_t* dictionaryOffsets = dictionary->dictionaryOffset.data();
char** outputStarts = byteBatch.data.data();
int64_t* outputLengths = byteBatch.length.data();
rle->next(outputLengths, numValues, notNull);
+ loadDictionary();
+ char* blob = dictionary->dictionaryBlob.data();
+ int64_t* dictionaryOffsets = dictionary->dictionaryOffset.data();
uint64_t dictionaryCount = dictionary->dictionaryOffset.size() - 1;
if (notNull) {
for (uint64_t i = 0; i < numValues; ++i) {
@@ -840,12 +836,14 @@ namespace orc {
// update the notNull from the parent class
notNull = rowBatch.hasNulls ? rowBatch.notNull.data() : nullptr;
StringVectorBatch& byteBatch = dynamic_cast<StringVectorBatch&>(rowBatch);
- char* blob = dictionary->dictionaryBlob.data();
- int64_t* dictionaryOffsets = dictionary->dictionaryOffset.data();
char** outputStarts = byteBatch.data.data();
int64_t* outputLengths = byteBatch.length.data();
std::unique_ptr<int64_t[]> tmpOutputLengths(new int64_t[byteBatch.length.size()]);
rle->next(tmpOutputLengths.get(), numValues, notNull);
+ loadDictionary();
+ char* blob = dictionary->dictionaryBlob.data();
+ int64_t* dictionaryOffsets = dictionary->dictionaryOffset.data();
+
uint64_t dictionaryCount = dictionary->dictionaryOffset.size() - 1;
if (notNull) {
for (size_t i = 0; i < numValues; i++) {
@@ -892,6 +890,8 @@ namespace orc {
// Length buffer is reused to save dictionary entry ids
rle->next(batch.index.data(), numValues, notNull);
+ loadDictionary();
+ batch.dictionary = this->dictionary;
}
void StringDictionaryColumnReader::seekToRowGroup(
@@ -900,6 +900,32 @@ namespace orc {
rle->seek(positions.at(columnId));
}
+ /**
+  * Decode this column's dictionary on first use and cache it.
+  *
+  * The LENGTH decoder and the DICTIONARY_DATA stream are created by the constructor;
+  * this method consumes them once, converting the per-entry lengths into cumulative
+  * offsets and reading the concatenated entry bytes into the dictionary blob.
+  * Later calls return the cached dictionary without touching the streams.
+  *
+  * @return non-owning pointer to the dictionary held by this reader
+  * @throws ParseError if an entry length is negative, or if the blob is non-empty
+  *         but the DICTIONARY_DATA stream is missing
+  */
+ StringDictionary* StringDictionaryColumnReader::loadDictionary() {
+   if (dictionaryLoaded) {
+     return dictionary.get();
+   }
+   // Widen before adding one: dictSize is a uint32_t taken from the column encoding, and
+   // `dictSize + 1` evaluated in 32 bits wraps to 0 for 0xFFFFFFFF. That would leave the
+   // offset buffer empty while the decoder below still writes dictSize entries into it.
+   const uint64_t entryCount = dictSize;
+   dictionary->dictionaryOffset.resize(entryCount + 1);
+   int64_t* lengthArray = dictionary->dictionaryOffset.data();
+   lengthDecoder->next(lengthArray + 1, entryCount, nullptr);
+   lengthArray[0] = 0;
+   // Turn lengths into cumulative offsets: entry i spans [offset[i], offset[i + 1]).
+   for (uint64_t i = 1; i <= entryCount; ++i) {
+     if (lengthArray[i] < 0) {
+       throw ParseError("Negative dictionary entry length");
+     }
+     lengthArray[i] += lengthArray[i - 1];
+   }
+   int64_t blobSize = lengthArray[entryCount];
+   // For insert_many_strings_overflow: the blob is allocated with extra slack bytes past
+   // the last entry. NOTE(review): presumably so a consumer may over-read a fixed amount
+   // beyond an entry's end -- confirm against the caller.
+   static constexpr int MAX_STRINGS_OVERFLOW_SIZE = 128;
+   dictionary->dictionaryBlob.resize(static_cast<uint64_t>(blobSize) + MAX_STRINGS_OVERFLOW_SIZE);
+   if (blobSize > 0 && blobStream == nullptr) {
+     throw ParseError("DICTIONARY_DATA stream not found in StringDictionaryColumn");
+   }
+   readFully(dictionary->dictionaryBlob.data(), blobSize, blobStream.get());
+   // Only mark as loaded once everything above succeeded.
+   dictionaryLoaded = true;
+   return dictionary.get();
+ }
+
class StringDirectColumnReader : public ColumnReader {
private:
std::unique_ptr<RleDecoder> lengthRle;
@@ -1082,6 +1108,10 @@ namespace orc {
void seekToRowGroup(std::unordered_map<uint64_t, PositionProvider>& positions,
const ReadPhase& readPhase) override;
+ void loadStringDicts(const std::unordered_map<uint64_t, std::string>& columnIdToNameMap,
+ std::unordered_map<std::string, StringDictionary*>* columnNameToDictMap,
+ const StringDictFilter* stringDictFilter);
+
private:
template <bool encoded>
void nextInternal(ColumnVectorBatch& rowBatch, uint64_t numValues, char* notNull,
@@ -1163,6 +1193,22 @@ namespace orc {
}
}
+ /**
+  * Load the dictionaries of the requested direct children of this struct.
+  *
+  * For every child whose column id appears in columnIdToNameMap and whose reader is a
+  * StringDictionaryColumnReader, the dictionary is loaded and stored in
+  * *columnNameToDictMap under the mapped column name. Children that are not requested,
+  * or that are not dictionary-encoded string readers, are left untouched.
+  *
+  * @param columnIdToNameMap   column id -> column name for the requested columns
+  * @param columnNameToDictMap output map, column name -> non-owning dictionary pointer
+  * @param stringDictFilter    not used by this method
+  */
+ void StructColumnReader::loadStringDicts(
+     const std::unordered_map<uint64_t, std::string>& columnIdToNameMap,
+     std::unordered_map<std::string, StringDictionary*>* columnNameToDictMap,
+     const StringDictFilter* stringDictFilter) {
+   for (auto& child : children) {
+     const auto nameEntry = columnIdToNameMap.find(child->getType().getColumnId());
+     if (nameEntry != columnIdToNameMap.end()) {
+       auto* dictReader = dynamic_cast<StringDictionaryColumnReader*>(child.get());
+       if (dictReader != nullptr) {
+         (*columnNameToDictMap)[nameEntry->second] = dictReader->loadDictionary();
+       }
+     }
+   }
+ }
+
class ListColumnReader : public ColumnReader {
private:
std::unique_ptr<ColumnReader> child;
@@ -2296,4 +2342,12 @@ namespace orc {
}
}
+ /**
+  * Load the string dictionaries of the requested columns below a root column reader.
+  *
+  * Forwards to StructColumnReader::loadStringDicts. If columnReader is null or is not a
+  * StructColumnReader, nothing is loaded and *columnNameToDictMap is left unchanged.
+  *
+  * @param columnReader        root reader of the current stripe (may be null)
+  * @param columnIdToNameMap   column id -> column name for the requested columns
+  * @param columnNameToDictMap output map, column name -> non-owning dictionary pointer
+  * @param stringDictFilter    forwarded unchanged to the struct reader
+  */
+ void loadStringDicts(ColumnReader* columnReader,
+                      const std::unordered_map<uint64_t, std::string>& columnIdToNameMap,
+                      std::unordered_map<std::string, StringDictionary*>* columnNameToDictMap,
+                      const StringDictFilter* stringDictFilter) {
+   // Checked downcast: an unchecked static_cast would be undefined behaviour when the
+   // reader passed in is not actually a StructColumnReader.
+   auto* structColumnReader = dynamic_cast<StructColumnReader*>(columnReader);
+   if (structColumnReader == nullptr) {
+     return;
+   }
+   structColumnReader->loadStringDicts(columnIdToNameMap, columnNameToDictMap, stringDictFilter);
+ }
+
} // namespace orc
diff --git a/c++/src/ColumnReader.hh b/c++/src/ColumnReader.hh
index fd894891..c437d7cc 100644
--- a/c++/src/ColumnReader.hh
+++ b/c++/src/ColumnReader.hh
@@ -168,7 +168,7 @@ namespace orc {
const ReadPhase& readPhase = ReadPhase::ALL,
uint16_t* sel_rowid_idx = nullptr, size_t sel_size = 0) {
rowBatch.isEncoded = false;
- next(rowBatch, numValues, notNull, readPhase, sel_rowid_idx);
+ next(rowBatch, numValues, notNull, readPhase, sel_rowid_idx, sel_size);
}
/**
@@ -184,6 +184,11 @@ namespace orc {
*/
std::unique_ptr<ColumnReader> buildReader(const Type& type, StripeStreams& stripe,
bool useTightNumericVector = false);
+
+ void loadStringDicts(ColumnReader* columnReader,
+ const std::unordered_map<uint64_t, std::string>& columnIdToNameMap,
+ std::unordered_map<std::string, StringDictionary*>* columnNameToDictMap,
+ const StringDictFilter* stringDictFilter);
} // namespace orc
#endif
diff --git a/c++/src/Reader.cc b/c++/src/Reader.cc
index 485b874a..f69f27db 100644
--- a/c++/src/Reader.cc
+++ b/c++/src/Reader.cc
@@ -247,7 +247,8 @@ namespace orc {
}
RowReaderImpl::RowReaderImpl(std::shared_ptr<FileContents> _contents,
- const RowReaderOptions& opts, const ORCFilter* _filter)
+ const RowReaderOptions& opts, const ORCFilter* _filter,
+ const StringDictFilter* _stringDictFilter)
: localTimezone(getLocalTimezone()),
contents(_contents),
throwOnHive11DecimalOverflow(opts.getThrowOnHive11DecimalOverflow()),
@@ -256,7 +257,8 @@ namespace orc {
firstRowOfStripe(*contents->pool, 0),
enableEncodedBlock(opts.getEnableLazyDecoding()),
readerTimezone(getTimezoneByName(opts.getTimezoneName())),
- filter(_filter) {
+ filter(_filter),
+ stringDictFilter(_stringDictFilter) {
uint64_t numberOfStripes;
numberOfStripes = static_cast<uint64_t>(footer->stripes_size());
currentStripe = numberOfStripes;
@@ -893,18 +895,20 @@ namespace orc {
}
}
- std::unique_ptr<RowReader> ReaderImpl::createRowReader(const ORCFilter* filter) const {
+ std::unique_ptr<RowReader> ReaderImpl::createRowReader(
+ const ORCFilter* filter, const StringDictFilter* stringDictFilter) const {
RowReaderOptions defaultOpts;
- return createRowReader(defaultOpts, filter);
+ return createRowReader(defaultOpts, filter, stringDictFilter);
}
- std::unique_ptr<RowReader> ReaderImpl::createRowReader(const RowReaderOptions& opts,
- const ORCFilter* filter) const {
+ std::unique_ptr<RowReader> ReaderImpl::createRowReader(
+ const RowReaderOptions& opts, const ORCFilter* filter,
+ const StringDictFilter* stringDictFilter) const {
if (opts.getSearchArgument() && !isMetadataLoaded) {
// load stripe statistics for PPD
readMetadata();
}
- return std::make_unique<RowReaderImpl>(contents, opts, filter);
+ return std::make_unique<RowReaderImpl>(contents, opts, filter, stringDictFilter);
}
uint64_t maxStreamsForType(const proto::Type& type) {
@@ -1101,7 +1105,7 @@ namespace orc {
return;
}
- do {
+ while (currentStripe < lastStripe) {
currentStripeInfo = footer->stripes(static_cast<int>(currentStripe));
uint64_t fileLength = contents->stream->getLength();
if (currentStripeInfo.offset() + currentStripeInfo.indexlength() +
@@ -1144,21 +1148,16 @@ namespace orc {
loadStripeIndex();
// select row groups to read in the current stripe
sargsApplier->pickRowGroups(rowsInCurrentStripe, rowIndexes, bloomFilterIndex);
- if (sargsApplier->hasSelectedFrom(currentRowInStripe)) {
- // current stripe has at least one row group matching the predicate
- break;
- }
- isStripeNeeded = false;
+ isStripeNeeded = sargsApplier->hasSelectedFrom(currentRowInStripe);
}
if (!isStripeNeeded) {
// advance to next stripe when current stripe has no matching rows
currentStripe += 1;
currentRowInStripe = 0;
+ continue;
}
}
- } while (sargsApplier && currentStripe < lastStripe);
- if (currentStripe < lastStripe) {
// get writer timezone info from stripe footer to help understand timestamp values.
const Timezone& writerTimezone = currentStripeFooter.has_writertimezone()
? getTimezoneByName(currentStripeFooter.writertimezone())
@@ -1168,6 +1167,35 @@ namespace orc {
readerTimezone);
reader = buildReader(*contents->schema, stripeStreams, useTightNumericVector);
+ if (stringDictFilter != nullptr) {
+   // Step 1: ask the callback which columns it wants dictionaries for in this stripe.
+   std::list<std::string> dictFilterColumnNames;
+   std::unique_ptr<StripeInformation> currentStripeInformation(new StripeInformationImpl(
+       currentStripeInfo.offset(), currentStripeInfo.indexlength(),
+       currentStripeInfo.datalength(), currentStripeInfo.footerlength(),
+       currentStripeInfo.numberofrows(), contents->stream.get(), *contents->pool,
+       contents->compression, contents->blockSize, contents->readerMetrics));
+   stringDictFilter->fillDictFilterColumnNames(std::move(currentStripeInformation),
+                                               dictFilterColumnNames);
+   // Step 2: resolve the names to column ids.
+   std::unordered_map<uint64_t, std::string> columnIdToNameMap;
+   for (const auto& dictFilterColumnName : dictFilterColumnNames) {
+     // Look up with find(): operator[] would insert a null Type* for an unknown name
+     // and the dereference would crash. Unknown names are skipped, so the callback
+     // simply receives no dictionary for them (same as for non-dictionary columns).
+     const auto typeIter = nameTypeMap.find(dictFilterColumnName);
+     if (typeIter == nameTypeMap.end() || typeIter->second == nullptr) {
+       continue;
+     }
+     columnIdToNameMap[typeIter->second->getColumnId()] = dictFilterColumnName;
+   }
+   // Step 3: load the dictionaries and let the callback decide about the stripe.
+   std::unordered_map<std::string, StringDictionary*> columnNameToDictMap;
+   loadStringDicts(reader.get(), columnIdToNameMap, &columnNameToDictMap, stringDictFilter);
+   if (!columnIdToNameMap.empty()) {
+     // Default to keeping the stripe so that a callback which never writes the flag
+     // does not cause a read of an uninitialized bool.
+     bool isStripeFiltered = false;
+     stringDictFilter->onStringDictsLoaded(columnNameToDictMap, &isStripeFiltered);
+     if (isStripeFiltered) {
+       reader.reset();
+       // advance to next stripe when current stripe has no matching rows
+       currentStripe += 1;
+       currentRowInStripe = 0;
+       continue;
+     }
+   }
+ }
+
if (sargsApplier) {
// move to the 1st selected row group when PPD is enabled.
currentRowInStripe =
@@ -1179,7 +1207,9 @@ namespace orc {
readPhase);
}
}
- } else {
+ break;
+ }
+ if (currentStripe >= lastStripe) {
// All remaining stripes are skipped.
markEndOfFile();
}
diff --git a/c++/src/Reader.hh b/c++/src/Reader.hh
index d84532c8..2ff3bbe8 100644
--- a/c++/src/Reader.hh
+++ b/c++/src/Reader.hh
@@ -49,9 +49,19 @@ namespace orc {
return *this;
}
+ /**
+  * @return the dictionary filter stored by setStringDictFilter(); this is the
+  *         stringDictFilter member, not the ORCFilter held in `filter`
+  */
+ const StringDictFilter* getStringDictFilter() const {
+   return stringDictFilter;
+ }
+
+ /**
+  * Store the dictionary filter pointer in this context.
+  * @param _stringDictFilter filter to remember; the pointer is stored as given
+  * @return this context, so calls can be chained
+  */
+ ReaderContext& setStringDictFilter(const StringDictFilter* _stringDictFilter) {
+   stringDictFilter = _stringDictFilter;
+   return *this;
+ }
+
private:
std::unordered_set<int> filterColumnIds;
const ORCFilter* filter;
+ const StringDictFilter* stringDictFilter;
};
/**
@@ -202,6 +212,8 @@ namespace orc {
std::map<std::string, Type*> nameTypeMap;
std::vector<std::string> columns;
+ const StringDictFilter* stringDictFilter;
+
// load stripe index if not done so
void loadStripeIndex();
@@ -257,7 +269,8 @@ namespace orc {
* @param options options for reading
*/
RowReaderImpl(std::shared_ptr<FileContents> contents, const RowReaderOptions& options,
- const ORCFilter* filter = nullptr);
+ const ORCFilter* filter = nullptr,
+ const StringDictFilter* stringDictFilter = nullptr);
// Select the columns from the options object
const std::vector<bool> getSelectedColumns() const override;
@@ -357,10 +370,13 @@ namespace orc {
std::unique_ptr<StripeStatistics> getStripeStatistics(uint64_t stripeIndex) const override;
- std::unique_ptr<RowReader> createRowReader(const ORCFilter* filter = nullptr) const override;
+ std::unique_ptr<RowReader> createRowReader(
+ const ORCFilter* filter = nullptr,
+ const StringDictFilter* stringDictFilter = nullptr) const override;
- std::unique_ptr<RowReader> createRowReader(const RowReaderOptions& options,
- const ORCFilter* filter = nullptr) const override;
+ std::unique_ptr<RowReader> createRowReader(
+ const RowReaderOptions& options, const ORCFilter* filter = nullptr,
+ const StringDictFilter* stringDictFilter = nullptr) const override;
uint64_t getContentLength() const override;
uint64_t getStripeStatisticsLength() const override;
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org