You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2023/06/14 08:50:34 UTC

[doris-thirdparty] branch orc updated: [Feature] Add `StringDictFilter` callbacks. (#90)

This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch orc
in repository https://gitbox.apache.org/repos/asf/doris-thirdparty.git


The following commit(s) were added to refs/heads/orc by this push:
     new a4e67d73 [Feature] Add `StringDictFilter` callbacks. (#90)
a4e67d73 is described below

commit a4e67d732e9acf3acb45e85c4cfe84d630e71ec1
Author: Qi Chen <ka...@gmail.com>
AuthorDate: Wed Jun 14 16:50:28 2023 +0800

    [Feature] Add `StringDictFilter` callbacks. (#90)
---
 c++/include/orc/Reader.hh |  21 +++++++--
 c++/src/ColumnReader.cc   | 108 ++++++++++++++++++++++++++++++++++------------
 c++/src/ColumnReader.hh   |   7 ++-
 c++/src/Reader.cc         |  62 +++++++++++++++++++-------
 c++/src/Reader.hh         |  24 +++++++++--
 5 files changed, 171 insertions(+), 51 deletions(-)

diff --git a/c++/include/orc/Reader.hh b/c++/include/orc/Reader.hh
index 429c1ed0..81e3bf1d 100644
--- a/c++/include/orc/Reader.hh
+++ b/c++/include/orc/Reader.hh
@@ -32,6 +32,7 @@
 #include <memory>
 #include <set>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 namespace orc {
@@ -370,6 +371,17 @@ namespace orc {
                         void* arg = nullptr) const = 0;
   };
 
+  class StringDictFilter {
+   public:
+    virtual ~StringDictFilter() = default;
+    virtual void fillDictFilterColumnNames(
+        std::unique_ptr<orc::StripeInformation> current_strip_information,
+        std::list<std::string>& columnNames) const = 0;
+    virtual void onStringDictsLoaded(
+        std::unordered_map<std::string, StringDictionary*>& columnNameToDictMap,
+        bool* isStripeFiltered) const = 0;
+  };
+
   class RowReader;
 
   /**
@@ -555,15 +567,18 @@ namespace orc {
      * Create a RowReader based on this reader with the default options.
      * @return a RowReader to read the rows
      */
-    virtual std::unique_ptr<RowReader> createRowReader(const ORCFilter* filter = nullptr) const = 0;
+    virtual std::unique_ptr<RowReader> createRowReader(
+        const ORCFilter* filter = nullptr,
+        const StringDictFilter* stringDictFilter = nullptr) const = 0;
 
     /**
      * Create a RowReader based on this reader.
      * @param options RowReader Options
      * @return a RowReader to read the rows
      */
-    virtual std::unique_ptr<RowReader> createRowReader(const RowReaderOptions& options,
-                                                       const ORCFilter* filter = nullptr) const = 0;
+    virtual std::unique_ptr<RowReader> createRowReader(
+        const RowReaderOptions& options, const ORCFilter* filter = nullptr,
+        const StringDictFilter* stringDictFilter = nullptr) const = 0;
 
     /**
      * Get the name of the input stream.
diff --git a/c++/src/ColumnReader.cc b/c++/src/ColumnReader.cc
index 8c6e1663..78005abc 100644
--- a/c++/src/ColumnReader.cc
+++ b/c++/src/ColumnReader.cc
@@ -715,7 +715,14 @@ namespace orc {
   class StringDictionaryColumnReader : public ColumnReader {
    private:
     std::shared_ptr<StringDictionary> dictionary;
+    std::unordered_map<std::string, int64_t> dictValueToCode;
     std::unique_ptr<RleDecoder> rle;
+    StripeStreams& stripe;
+    //    std::string columnName;
+    bool dictionaryLoaded;
+    uint32_t dictSize;
+    std::unique_ptr<RleDecoder> lengthDecoder;
+    std::unique_ptr<SeekableInputStream> blobStream;
 
     void nextInternal(ColumnVectorBatch& rowBatch, uint64_t numValues, char* notNull,
                       const ReadPhase& readPhase);
@@ -737,13 +744,19 @@ namespace orc {
 
     void seekToRowGroup(std::unordered_map<uint64_t, PositionProvider>& positions,
                         const ReadPhase& readPhase) override;
+
+    StringDictionary* loadDictionary();
   };
 
   StringDictionaryColumnReader::StringDictionaryColumnReader(const Type& type,
-                                                             StripeStreams& stripe)
-      : ColumnReader(type, stripe), dictionary(new StringDictionary(stripe.getMemoryPool())) {
+                                                             StripeStreams& _stripe)
+      : ColumnReader(type, _stripe),
+        dictionary(new StringDictionary(_stripe.getMemoryPool())),
+        stripe(_stripe),
+        dictionaryLoaded(false),
+        dictSize(0) {
     RleVersion rleVersion = convertRleVersion(stripe.getEncoding(columnId).kind());
-    uint32_t dictSize = stripe.getEncoding(columnId).dictionarysize();
+    dictSize = stripe.getEncoding(columnId).dictionarysize();
     std::unique_ptr<SeekableInputStream> stream =
         stripe.getStream(columnId, proto::Stream_Kind_DATA, true);
     if (stream == nullptr) {
@@ -754,26 +767,8 @@ namespace orc {
     if (dictSize > 0 && stream == nullptr) {
       throw ParseError("LENGTH stream not found in StringDictionaryColumn");
     }
-    std::unique_ptr<RleDecoder> lengthDecoder =
-        createRleDecoder(std::move(stream), false, rleVersion, memoryPool, metrics);
-    dictionary->dictionaryOffset.resize(dictSize + 1);
-    int64_t* lengthArray = dictionary->dictionaryOffset.data();
-    lengthDecoder->next(lengthArray + 1, dictSize, nullptr);
-    lengthArray[0] = 0;
-    for (uint32_t i = 1; i < dictSize + 1; ++i) {
-      if (lengthArray[i] < 0) {
-        throw ParseError("Negative dictionary entry length");
-      }
-      lengthArray[i] += lengthArray[i - 1];
-    }
-    int64_t blobSize = lengthArray[dictSize];
-    dictionary->dictionaryBlob.resize(static_cast<uint64_t>(blobSize));
-    std::unique_ptr<SeekableInputStream> blobStream =
-        stripe.getStream(columnId, proto::Stream_Kind_DICTIONARY_DATA, false);
-    if (blobSize > 0 && blobStream == nullptr) {
-      throw ParseError("DICTIONARY_DATA stream not found in StringDictionaryColumn");
-    }
-    readFully(dictionary->dictionaryBlob.data(), blobSize, blobStream.get());
+    lengthDecoder = createRleDecoder(std::move(stream), false, rleVersion, memoryPool, metrics);
+    blobStream = stripe.getStream(columnId, proto::Stream_Kind_DICTIONARY_DATA, false);
   }
 
   StringDictionaryColumnReader::~StringDictionaryColumnReader() {
@@ -802,11 +797,12 @@ namespace orc {
     // update the notNull from the parent class
     notNull = rowBatch.hasNulls ? rowBatch.notNull.data() : nullptr;
     StringVectorBatch& byteBatch = dynamic_cast<StringVectorBatch&>(rowBatch);
-    char* blob = dictionary->dictionaryBlob.data();
-    int64_t* dictionaryOffsets = dictionary->dictionaryOffset.data();
     char** outputStarts = byteBatch.data.data();
     int64_t* outputLengths = byteBatch.length.data();
     rle->next(outputLengths, numValues, notNull);
+    loadDictionary();
+    char* blob = dictionary->dictionaryBlob.data();
+    int64_t* dictionaryOffsets = dictionary->dictionaryOffset.data();
     uint64_t dictionaryCount = dictionary->dictionaryOffset.size() - 1;
     if (notNull) {
       for (uint64_t i = 0; i < numValues; ++i) {
@@ -840,12 +836,14 @@ namespace orc {
     // update the notNull from the parent class
     notNull = rowBatch.hasNulls ? rowBatch.notNull.data() : nullptr;
     StringVectorBatch& byteBatch = dynamic_cast<StringVectorBatch&>(rowBatch);
-    char* blob = dictionary->dictionaryBlob.data();
-    int64_t* dictionaryOffsets = dictionary->dictionaryOffset.data();
     char** outputStarts = byteBatch.data.data();
     int64_t* outputLengths = byteBatch.length.data();
     std::unique_ptr<int64_t[]> tmpOutputLengths(new int64_t[byteBatch.length.size()]);
     rle->next(tmpOutputLengths.get(), numValues, notNull);
+    loadDictionary();
+    char* blob = dictionary->dictionaryBlob.data();
+    int64_t* dictionaryOffsets = dictionary->dictionaryOffset.data();
+
     uint64_t dictionaryCount = dictionary->dictionaryOffset.size() - 1;
     if (notNull) {
       for (size_t i = 0; i < numValues; i++) {
@@ -892,6 +890,8 @@ namespace orc {
 
     // Length buffer is reused to save dictionary entry ids
     rle->next(batch.index.data(), numValues, notNull);
+    loadDictionary();
+    batch.dictionary = this->dictionary;
   }
 
   void StringDictionaryColumnReader::seekToRowGroup(
@@ -900,6 +900,32 @@ namespace orc {
     rle->seek(positions.at(columnId));
   }
 
+  StringDictionary* StringDictionaryColumnReader::loadDictionary() {
+    if (dictionaryLoaded) {
+      return dictionary.get();
+    }
+    dictionary->dictionaryOffset.resize(dictSize + 1);
+    int64_t* lengthArray = dictionary->dictionaryOffset.data();
+    lengthDecoder->next(lengthArray + 1, dictSize, nullptr);
+    lengthArray[0] = 0;
+    for (uint32_t i = 1; i < dictSize + 1; ++i) {
+      if (lengthArray[i] < 0) {
+        throw ParseError("Negative dictionary entry length");
+      }
+      lengthArray[i] += lengthArray[i - 1];
+    }
+    int64_t blobSize = lengthArray[dictSize];
+    // For insert_many_strings_overflow
+    static constexpr int MAX_STRINGS_OVERFLOW_SIZE = 128;
+    dictionary->dictionaryBlob.resize(static_cast<uint64_t>(blobSize) + MAX_STRINGS_OVERFLOW_SIZE);
+    if (blobSize > 0 && blobStream == nullptr) {
+      throw ParseError("DICTIONARY_DATA stream not found in StringDictionaryColumn");
+    }
+    readFully(dictionary->dictionaryBlob.data(), blobSize, blobStream.get());
+    dictionaryLoaded = true;
+    return dictionary.get();
+  }
+
   class StringDirectColumnReader : public ColumnReader {
    private:
     std::unique_ptr<RleDecoder> lengthRle;
@@ -1082,6 +1108,10 @@ namespace orc {
     void seekToRowGroup(std::unordered_map<uint64_t, PositionProvider>& positions,
                         const ReadPhase& readPhase) override;
 
+    void loadStringDicts(const std::unordered_map<uint64_t, std::string>& columnIdToNameMap,
+                         std::unordered_map<std::string, StringDictionary*>* columnNameToDictMap,
+                         const StringDictFilter* stringDictFilter);
+
    private:
     template <bool encoded>
     void nextInternal(ColumnVectorBatch& rowBatch, uint64_t numValues, char* notNull,
@@ -1163,6 +1193,22 @@ namespace orc {
     }
   }
 
+  void StructColumnReader::loadStringDicts(
+      const std::unordered_map<uint64_t, std::string>& columnIdToNameMap,
+      std::unordered_map<std::string, StringDictionary*>* columnNameToDictMap,
+      const StringDictFilter* stringDictFilter) {
+    for (auto& ptr : children) {
+      auto iter = columnIdToNameMap.find(ptr->getType().getColumnId());
+      if (iter == columnIdToNameMap.end()) {
+        continue;
+      }
+      auto* stringDictionaryColumnReader = dynamic_cast<StringDictionaryColumnReader*>(ptr.get());
+      if (stringDictionaryColumnReader != nullptr) {
+        (*columnNameToDictMap)[iter->second] = stringDictionaryColumnReader->loadDictionary();
+      }
+    }
+  }
+
   class ListColumnReader : public ColumnReader {
    private:
     std::unique_ptr<ColumnReader> child;
@@ -2296,4 +2342,12 @@ namespace orc {
     }
   }
 
+  void loadStringDicts(ColumnReader* columnReader,
+                       const std::unordered_map<uint64_t, std::string>& columnIdToNameMap,
+                       std::unordered_map<std::string, StringDictionary*>* columnNameToDictMap,
+                       const StringDictFilter* stringDictFilter) {
+    auto* structColumnReader = static_cast<StructColumnReader*>(columnReader);
+    structColumnReader->loadStringDicts(columnIdToNameMap, columnNameToDictMap, stringDictFilter);
+  }
+
 }  // namespace orc
diff --git a/c++/src/ColumnReader.hh b/c++/src/ColumnReader.hh
index fd894891..c437d7cc 100644
--- a/c++/src/ColumnReader.hh
+++ b/c++/src/ColumnReader.hh
@@ -168,7 +168,7 @@ namespace orc {
                              const ReadPhase& readPhase = ReadPhase::ALL,
                              uint16_t* sel_rowid_idx = nullptr, size_t sel_size = 0) {
       rowBatch.isEncoded = false;
-      next(rowBatch, numValues, notNull, readPhase, sel_rowid_idx);
+      next(rowBatch, numValues, notNull, readPhase, sel_rowid_idx, sel_size);
     }
 
     /**
@@ -184,6 +184,11 @@ namespace orc {
    */
   std::unique_ptr<ColumnReader> buildReader(const Type& type, StripeStreams& stripe,
                                             bool useTightNumericVector = false);
+
+  void loadStringDicts(ColumnReader* columnReader,
+                       const std::unordered_map<uint64_t, std::string>& columnIdToNameMap,
+                       std::unordered_map<std::string, StringDictionary*>* columnNameToDictMap,
+                       const StringDictFilter* stringDictFilter);
 }  // namespace orc
 
 #endif
diff --git a/c++/src/Reader.cc b/c++/src/Reader.cc
index 485b874a..f69f27db 100644
--- a/c++/src/Reader.cc
+++ b/c++/src/Reader.cc
@@ -247,7 +247,8 @@ namespace orc {
   }
 
   RowReaderImpl::RowReaderImpl(std::shared_ptr<FileContents> _contents,
-                               const RowReaderOptions& opts, const ORCFilter* _filter)
+                               const RowReaderOptions& opts, const ORCFilter* _filter,
+                               const StringDictFilter* _stringDictFilter)
       : localTimezone(getLocalTimezone()),
         contents(_contents),
         throwOnHive11DecimalOverflow(opts.getThrowOnHive11DecimalOverflow()),
@@ -256,7 +257,8 @@ namespace orc {
         firstRowOfStripe(*contents->pool, 0),
         enableEncodedBlock(opts.getEnableLazyDecoding()),
         readerTimezone(getTimezoneByName(opts.getTimezoneName())),
-        filter(_filter) {
+        filter(_filter),
+        stringDictFilter(_stringDictFilter) {
     uint64_t numberOfStripes;
     numberOfStripes = static_cast<uint64_t>(footer->stripes_size());
     currentStripe = numberOfStripes;
@@ -893,18 +895,20 @@ namespace orc {
     }
   }
 
-  std::unique_ptr<RowReader> ReaderImpl::createRowReader(const ORCFilter* filter) const {
+  std::unique_ptr<RowReader> ReaderImpl::createRowReader(
+      const ORCFilter* filter, const StringDictFilter* stringDictFilter) const {
     RowReaderOptions defaultOpts;
-    return createRowReader(defaultOpts, filter);
+    return createRowReader(defaultOpts, filter, stringDictFilter);
   }
 
-  std::unique_ptr<RowReader> ReaderImpl::createRowReader(const RowReaderOptions& opts,
-                                                         const ORCFilter* filter) const {
+  std::unique_ptr<RowReader> ReaderImpl::createRowReader(
+      const RowReaderOptions& opts, const ORCFilter* filter,
+      const StringDictFilter* stringDictFilter) const {
     if (opts.getSearchArgument() && !isMetadataLoaded) {
       // load stripe statistics for PPD
       readMetadata();
     }
-    return std::make_unique<RowReaderImpl>(contents, opts, filter);
+    return std::make_unique<RowReaderImpl>(contents, opts, filter, stringDictFilter);
   }
 
   uint64_t maxStreamsForType(const proto::Type& type) {
@@ -1101,7 +1105,7 @@ namespace orc {
       return;
     }
 
-    do {
+    while (currentStripe < lastStripe) {
       currentStripeInfo = footer->stripes(static_cast<int>(currentStripe));
       uint64_t fileLength = contents->stream->getLength();
       if (currentStripeInfo.offset() + currentStripeInfo.indexlength() +
@@ -1144,21 +1148,16 @@ namespace orc {
           loadStripeIndex();
           // select row groups to read in the current stripe
           sargsApplier->pickRowGroups(rowsInCurrentStripe, rowIndexes, bloomFilterIndex);
-          if (sargsApplier->hasSelectedFrom(currentRowInStripe)) {
-            // current stripe has at least one row group matching the predicate
-            break;
-          }
-          isStripeNeeded = false;
+          isStripeNeeded = sargsApplier->hasSelectedFrom(currentRowInStripe);
         }
         if (!isStripeNeeded) {
           // advance to next stripe when current stripe has no matching rows
           currentStripe += 1;
           currentRowInStripe = 0;
+          continue;
         }
       }
-    } while (sargsApplier && currentStripe < lastStripe);
 
-    if (currentStripe < lastStripe) {
       // get writer timezone info from stripe footer to help understand timestamp values.
       const Timezone& writerTimezone = currentStripeFooter.has_writertimezone()
                                            ? getTimezoneByName(currentStripeFooter.writertimezone())
@@ -1168,6 +1167,35 @@ namespace orc {
                                       readerTimezone);
       reader = buildReader(*contents->schema, stripeStreams, useTightNumericVector);
 
+      if (stringDictFilter != nullptr) {
+        std::list<std::string> dictFilterColumnNames;
+        std::unique_ptr<StripeInformation> currentStripeInformation(new StripeInformationImpl(
+            currentStripeInfo.offset(), currentStripeInfo.indexlength(),
+            currentStripeInfo.datalength(), currentStripeInfo.footerlength(),
+            currentStripeInfo.numberofrows(), contents->stream.get(), *contents->pool,
+            contents->compression, contents->blockSize, contents->readerMetrics));
+        stringDictFilter->fillDictFilterColumnNames(std::move(currentStripeInformation),
+                                                    dictFilterColumnNames);
+        std::unordered_map<uint64_t, std::string> columnIdToNameMap;
+        for (auto& dictFilterColumnName : dictFilterColumnNames) {
+          columnIdToNameMap[nameTypeMap[dictFilterColumnName]->getColumnId()] =
+              dictFilterColumnName;
+        }
+        std::unordered_map<std::string, StringDictionary*> columnIdToDictMap;
+        loadStringDicts(reader.get(), columnIdToNameMap, &columnIdToDictMap, stringDictFilter);
+        if (!columnIdToNameMap.empty()) {
+          bool isStripeFiltered;
+          stringDictFilter->onStringDictsLoaded(columnIdToDictMap, &isStripeFiltered);
+          if (isStripeFiltered) {
+            reader.reset();
+            // advance to next stripe when current stripe has no matching rows
+            currentStripe += 1;
+            currentRowInStripe = 0;
+            continue;
+          }
+        }
+      }
+
       if (sargsApplier) {
         // move to the 1st selected row group when PPD is enabled.
         currentRowInStripe =
@@ -1179,7 +1207,9 @@ namespace orc {
                          readPhase);
         }
       }
-    } else {
+      break;
+    }
+    if (currentStripe >= lastStripe) {
       // All remaining stripes are skipped.
       markEndOfFile();
     }
diff --git a/c++/src/Reader.hh b/c++/src/Reader.hh
index d84532c8..2ff3bbe8 100644
--- a/c++/src/Reader.hh
+++ b/c++/src/Reader.hh
@@ -49,9 +49,19 @@ namespace orc {
       return *this;
     }
 
+    const StringDictFilter* getStringDictFilter() const {
+      return stringDictFilter;  // value stored by setStringDictFilter(), not `filter`
+    }
+
+    ReaderContext& setStringDictFilter(const StringDictFilter* _stringDictFilter) {
+      this->stringDictFilter = _stringDictFilter;
+      return *this;
+    }
+
    private:
     std::unordered_set<int> filterColumnIds;
     const ORCFilter* filter;
+    const StringDictFilter* stringDictFilter;
   };
 
   /**
@@ -202,6 +212,8 @@ namespace orc {
     std::map<std::string, Type*> nameTypeMap;
     std::vector<std::string> columns;
 
+    const StringDictFilter* stringDictFilter;
+
     // load stripe index if not done so
     void loadStripeIndex();
 
@@ -257,7 +269,8 @@ namespace orc {
      * @param options options for reading
      */
     RowReaderImpl(std::shared_ptr<FileContents> contents, const RowReaderOptions& options,
-                  const ORCFilter* filter = nullptr);
+                  const ORCFilter* filter = nullptr,
+                  const StringDictFilter* stringDictFilter = nullptr);
 
     // Select the columns from the options object
     const std::vector<bool> getSelectedColumns() const override;
@@ -357,10 +370,13 @@ namespace orc {
 
     std::unique_ptr<StripeStatistics> getStripeStatistics(uint64_t stripeIndex) const override;
 
-    std::unique_ptr<RowReader> createRowReader(const ORCFilter* filter = nullptr) const override;
+    std::unique_ptr<RowReader> createRowReader(
+        const ORCFilter* filter = nullptr,
+        const StringDictFilter* stringDictFilter = nullptr) const override;
 
-    std::unique_ptr<RowReader> createRowReader(const RowReaderOptions& options,
-                                               const ORCFilter* filter = nullptr) const override;
+    std::unique_ptr<RowReader> createRowReader(
+        const RowReaderOptions& options, const ORCFilter* filter = nullptr,
+        const StringDictFilter* stringDictFilter = nullptr) const override;
 
     uint64_t getContentLength() const override;
     uint64_t getStripeStatisticsLength() const override;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org