You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@orc.apache.org by do...@apache.org on 2021/03/09 17:54:39 UTC

[orc] branch master updated: ORC-751: [C++] Implement Predicate Pushdown for C++ Reader (#648)

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/orc.git


The following commit(s) were added to refs/heads/master by this push:
     new 40ee321  ORC-751: [C++] Implement Predicate Pushdown for C++ Reader (#648)
40ee321 is described below

commit 40ee321b209cad1de73fc97b653636aa0fff28f6
Author: Gang Wu <ga...@alibaba-inc.com>
AuthorDate: Wed Mar 10 01:54:31 2021 +0800

    ORC-751: [C++] Implement Predicate Pushdown for C++ Reader (#648)
    
    1. Use RowReaderOptions to pass SearchArgument to enable PPD.
    2. Modify RowReaderImpl::startNextStripe to seek to next matched row group based on sarg evaluation.
    3. RowReaderImpl::next seeks to next matched row group based on sarg evaluation.
    4. RowReaderImpl::seekToRow also jumps to the 1st matched row group after specified row number.
---
 c++/include/orc/Reader.hh         |  11 ++
 c++/src/Options.hh                |  10 ++
 c++/src/Reader.cc                 | 232 +++++++++++++++++++++++++++++++-------
 c++/src/Reader.hh                 |  28 ++++-
 c++/src/sargs/SargsApplier.cc     |  10 +-
 c++/src/sargs/SargsApplier.hh     |   6 +
 c++/test/CMakeLists.txt           |   1 +
 c++/test/TestPredicatePushdown.cc | 187 ++++++++++++++++++++++++++++++
 c++/test/TestReader.cc            |  55 +++++++++
 9 files changed, 495 insertions(+), 45 deletions(-)

diff --git a/c++/include/orc/Reader.hh b/c++/include/orc/Reader.hh
index 267ae93..eee9551 100644
--- a/c++/include/orc/Reader.hh
+++ b/c++/include/orc/Reader.hh
@@ -23,6 +23,7 @@
 #include "orc/Common.hh"
 #include "orc/orc-config.hh"
 #include "orc/Statistics.hh"
+#include "orc/sargs/SearchArgument.hh"
 #include "orc/Type.hh"
 #include "orc/Vector.hh"
 
@@ -192,6 +193,11 @@ namespace orc {
     RowReaderOptions& setEnableLazyDecoding(bool enable);
 
     /**
+     * Set search argument for predicate push down
+     */
+    RowReaderOptions& searchArgument(std::unique_ptr<SearchArgument> sargs);
+
+    /**
      * Should enable encoding block mode
      */
     bool getEnableLazyDecoding() const;
@@ -245,6 +251,11 @@ namespace orc {
      * What scale should all Hive 0.11 decimals be normalized to?
      */
     int32_t getForcedScaleOnHive11Decimal() const;
+
+    /**
+     * Get search argument for predicate push down
+     */
+    std::shared_ptr<SearchArgument> getSearchArgument() const;
   };
 
 
diff --git a/c++/src/Options.hh b/c++/src/Options.hh
index 9581331..2808ffa 100644
--- a/c++/src/Options.hh
+++ b/c++/src/Options.hh
@@ -128,6 +128,7 @@ namespace orc {
     bool throwOnHive11DecimalOverflow;
     int32_t forcedScaleOnHive11Decimal;
     bool enableLazyDecoding;
+    std::shared_ptr<SearchArgument> sargs;
 
     RowReaderOptionsPrivate() {
       selection = ColumnSelection_NONE;
@@ -249,6 +250,15 @@ namespace orc {
     privateBits->enableLazyDecoding = enable;
     return *this;
   }
+
+  RowReaderOptions& RowReaderOptions::searchArgument(std::unique_ptr<SearchArgument> sargs) {
+    privateBits->sargs = std::move(sargs);  // ownership moves into the shared_ptr kept by the options
+    return *this;
+  }
+
+  std::shared_ptr<SearchArgument> RowReaderOptions::getSearchArgument() const {
+    return privateBits->sargs;  // shared with the caller; RowReaderImpl tests it before enabling PPD
+  }
 }
 
 #endif
diff --git a/c++/src/Reader.cc b/c++/src/Reader.cc
index 37225b9..42ab41c 100644
--- a/c++/src/Reader.cc
+++ b/c++/src/Reader.cc
@@ -68,6 +68,12 @@ namespace orc {
       return columnPath.substr(0, columnPath.length() - 1);
   }
 
+  WriterVersion getWriterVersionImpl(const FileContents * contents) {  // used by ReaderImpl and RowReaderImpl
+    if (!contents->postscript->has_writerversion()) {
+      return WriterVersion_ORIGINAL;  // postscript has no writer version field
+    }
+    return static_cast<WriterVersion>(contents->postscript->writerversion());
+  }
 
   void ColumnSelector::selectChildren(std::vector<bool>& selectedColumns, const Type& type) {
     size_t id = static_cast<size_t>(type.getColumnId());
@@ -227,6 +233,15 @@ namespace orc {
 
     ColumnSelector column_selector(contents.get());
     column_selector.updateSelected(selectedColumns, opts);
+
+    // prepare SargsApplier if SearchArgument is available
+    if (opts.getSearchArgument() && footer->rowindexstride() > 0) {
+      sargs = opts.getSearchArgument();
+      sargsApplier.reset(new SargsApplier(*contents->schema,
+                                          sargs.get(),
+                                          footer->rowindexstride(),
+                                          getWriterVersionImpl(_contents.get())));
+    }
   }
 
   CompressionKind RowReaderImpl::getCompression() const {
@@ -293,25 +308,34 @@ namespace orc {
     previousRow = rowNumber;
     startNextStripe();
 
-    uint64_t rowsToSkip = currentRowInStripe;
+    // when predicate push down is enabled, above call to startNextStripe()
+    // will move current row to 1st matching row group; here we only need
+    // to deal with the case when PPD is not enabled.
+    if (!sargsApplier) {
+      uint64_t rowsToSkip = currentRowInStripe;
 
-    if (footer->rowindexstride() > 0 &&
-        currentStripeInfo.indexlength() > 0) {
-      uint32_t rowGroupId =
-        static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride());
-      rowsToSkip -= static_cast<uint64_t>(rowGroupId) * footer->rowindexstride();
+      if (footer->rowindexstride() > 0 &&
+          currentStripeInfo.indexlength() > 0) {
+        if (rowIndexes.empty()) {
+          loadStripeIndex();
+        }
+        uint32_t rowGroupId =
+          static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride());
+        rowsToSkip -= static_cast<uint64_t>(rowGroupId) * footer->rowindexstride();
 
-      if (rowGroupId != 0) {
-        seekToRowGroup(rowGroupId);
+        if (rowGroupId != 0) {
+          seekToRowGroup(rowGroupId);
+        }
       }
-    }
 
-    reader->skip(rowsToSkip);
+      reader->skip(rowsToSkip);
+    }
   }
 
-  void RowReaderImpl::seekToRowGroup(uint32_t rowGroupEntryId) {
+  void RowReaderImpl::loadStripeIndex() {
     // reset all previous row indexes
     rowIndexes.clear();
+    bloomFilterIndex.clear();
 
     // obtain row indexes for selected columns
     uint64_t offset = currentStripeInfo.offset();
@@ -319,7 +343,8 @@ namespace orc {
       const proto::Stream& pbStream = currentStripeFooter.streams(i);
       uint64_t colId = pbStream.column();
       if (selectedColumns[colId] && pbStream.has_kind()
-          && pbStream.kind() == proto::Stream_Kind_ROW_INDEX) {
+          && (pbStream.kind() == proto::Stream_Kind_ROW_INDEX ||
+              pbStream.kind() == proto::Stream_Kind_BLOOM_FILTER_UTF8)) {
         std::unique_ptr<SeekableInputStream> inStream =
           createDecompressor(getCompression(),
                              std::unique_ptr<SeekableInputStream>
@@ -331,16 +356,33 @@ namespace orc {
                              getCompressionSize(),
                              *contents->pool);
 
-        proto::RowIndex rowIndex;
-        if (!rowIndex.ParseFromZeroCopyStream(inStream.get())) {
-          throw ParseError("Failed to parse the row index");
+        if (pbStream.kind() == proto::Stream_Kind_ROW_INDEX) {
+          proto::RowIndex rowIndex;
+          if (!rowIndex.ParseFromZeroCopyStream(inStream.get())) {
+            throw ParseError("Failed to parse the row index");
+          }
+          rowIndexes[colId] = rowIndex;
+        } else { // Stream_Kind_BLOOM_FILTER_UTF8
+          proto::BloomFilterIndex pbBFIndex;
+          if (!pbBFIndex.ParseFromZeroCopyStream(inStream.get())) {
+            throw ParseError("Failed to parse bloom filter index");
+          }
+          BloomFilterIndex bfIndex;
+          for (int j = 0; j < pbBFIndex.bloomfilter_size(); j++) {
+            bfIndex.entries.push_back(BloomFilterUTF8Utils::deserialize(
+              pbStream.kind(),
+              currentStripeFooter.columns(static_cast<int>(pbStream.column())),
+              pbBFIndex.bloomfilter(j)));
+          }
+          // add bloom filters to result for one column
+          bloomFilterIndex[pbStream.column()] = bfIndex;
         }
-
-        rowIndexes[colId] = rowIndex;
       }
       offset += pbStream.length();
     }
+  }
 
+  void RowReaderImpl::seekToRowGroup(uint32_t rowGroupEntryId) {
     // store positions for selected columns
     std::vector<std::list<uint64_t>> positions;
     // store position providers for selected colimns
@@ -516,10 +558,7 @@ namespace orc {
   }
 
   WriterVersion ReaderImpl::getWriterVersion() const {
-    if (!contents->postscript->has_writerversion()) {
-      return WriterVersion_ORIGINAL;
-    }
-    return static_cast<WriterVersion>(contents->postscript->writerversion());
+    return getWriterVersionImpl(contents.get());
   }
 
   uint64_t ReaderImpl::getContentLength() const {
@@ -892,29 +931,68 @@ namespace orc {
 
   void RowReaderImpl::startNextStripe() {
     reader.reset(); // ColumnReaders use lots of memory; free old memory first
-    currentStripeInfo = footer->stripes(static_cast<int>(currentStripe));
-    uint64_t fileLength = contents->stream->getLength();
-    if (currentStripeInfo.offset() + currentStripeInfo.indexlength() +
+    rowIndexes.clear();
+    bloomFilterIndex.clear();
+
+    do {
+      currentStripeInfo = footer->stripes(static_cast<int>(currentStripe));
+      uint64_t fileLength = contents->stream->getLength();
+      if (currentStripeInfo.offset() + currentStripeInfo.indexlength() +
         currentStripeInfo.datalength() + currentStripeInfo.footerlength() >= fileLength) {
-      std::stringstream msg;
-      msg << "Malformed StripeInformation at stripe index " << currentStripe << ": fileLength="
-          << fileLength << ", StripeInfo=(offset=" << currentStripeInfo.offset() << ", indexLength="
-          << currentStripeInfo.indexlength() << ", dataLength=" << currentStripeInfo.datalength()
-          << ", footerLength=" << currentStripeInfo.footerlength() << ")";
-      throw ParseError(msg.str());
+        std::stringstream msg;
+        msg << "Malformed StripeInformation at stripe index " << currentStripe << ": fileLength="
+            << fileLength << ", StripeInfo=(offset=" << currentStripeInfo.offset() << ", indexLength="
+            << currentStripeInfo.indexlength() << ", dataLength=" << currentStripeInfo.datalength()
+            << ", footerLength=" << currentStripeInfo.footerlength() << ")";
+        throw ParseError(msg.str());
+      }
+      currentStripeFooter = getStripeFooter(currentStripeInfo, *contents.get());
+      rowsInCurrentStripe = currentStripeInfo.numberofrows();
+
+      if (sargsApplier) {
+        // read row group statistics and bloom filters of current stripe
+        loadStripeIndex();
+
+        // select row groups to read in the current stripe
+        sargsApplier->pickRowGroups(rowsInCurrentStripe,
+                                    rowIndexes,
+                                    bloomFilterIndex);
+        if (sargsApplier->hasSelectedFrom(currentRowInStripe)) {
+          // current stripe has at least one row group matching the predicate
+          break;
+        } else {
+          // advance to next stripe when current stripe has no matching rows
+          currentStripe += 1;
+          currentRowInStripe = 0;
+        }
+      }
+    } while (sargsApplier && currentStripe < lastStripe);
+
+    if (currentStripe < lastStripe) {
+      // get writer timezone info from stripe footer to help understand timestamp values.
+      const Timezone& writerTimezone =
+        currentStripeFooter.has_writertimezone() ?
+          getTimezoneByName(currentStripeFooter.writertimezone()) :
+          localTimezone;
+      StripeStreamsImpl stripeStreams(*this, currentStripe, currentStripeInfo,
+                                      currentStripeFooter,
+                                      currentStripeInfo.offset(),
+                                      *contents->stream,
+                                      writerTimezone);
+      reader = buildReader(*contents->schema, stripeStreams);
+
+      if (sargsApplier) {
+        // move to the 1st selected row group when PPD is enabled.
+        currentRowInStripe = advanceToNextRowGroup(currentRowInStripe,
+                                                   rowsInCurrentStripe,
+                                                   footer->rowindexstride(),
+                                                   sargsApplier->getRowGroups());
+        previousRow = firstRowOfStripe[currentStripe] + currentRowInStripe - 1;
+        if (currentRowInStripe > 0) {
+          seekToRowGroup(static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride()));
+        }
+      }
     }
-    currentStripeFooter = getStripeFooter(currentStripeInfo, *contents.get());
-    rowsInCurrentStripe = currentStripeInfo.numberofrows();
-    const Timezone& writerTimezone =
-      currentStripeFooter.has_writertimezone() ?
-        getTimezoneByName(currentStripeFooter.writertimezone()) :
-        localTimezone;
-    StripeStreamsImpl stripeStreams(*this, currentStripe, currentStripeInfo,
-                                    currentStripeFooter,
-                                    currentStripeInfo.offset(),
-                                    *(contents->stream.get()),
-                                    writerTimezone);
-    reader = buildReader(*contents->schema.get(), stripeStreams);
   }
 
   bool RowReaderImpl::next(ColumnVectorBatch& data) {
@@ -934,7 +1012,20 @@ namespace orc {
     uint64_t rowsToRead =
       std::min(static_cast<uint64_t>(data.capacity),
                rowsInCurrentStripe - currentRowInStripe);
+    if (sargsApplier) {
+      rowsToRead = computeBatchSize(rowsToRead,
+                                    currentRowInStripe,
+                                    rowsInCurrentStripe,
+                                    footer->rowindexstride(),
+                                    sargsApplier->getRowGroups());
+    }
     data.numElements = rowsToRead;
+    if (rowsToRead == 0) {
+      previousRow = lastStripe <= 0 ? footer->numberofrows() :
+                    firstRowOfStripe[lastStripe - 1] +
+                    footer->stripes(static_cast<int>(lastStripe - 1)).numberofrows();
+      return false;
+    }
     if (enableEncodedBlock) {
       reader->nextEncoded(data, rowsToRead, nullptr);
     }
@@ -944,6 +1035,22 @@ namespace orc {
     // update row number
     previousRow = firstRowOfStripe[currentStripe] + currentRowInStripe;
     currentRowInStripe += rowsToRead;
+
+    // check if we need to advance to next selected row group
+    if (sargsApplier) {
+      uint64_t nextRowToRead = advanceToNextRowGroup(currentRowInStripe,
+                                                     rowsInCurrentStripe,
+                                                     footer->rowindexstride(),
+                                                     sargsApplier->getRowGroups());
+      if (currentRowInStripe != nextRowToRead) {
+        // it is guaranteed to be at start of a row group
+        currentRowInStripe = nextRowToRead;
+        if (currentRowInStripe < rowsInCurrentStripe) {
+          seekToRowGroup(static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride()));
+        }
+      }
+    }
+
     if (currentRowInStripe >= rowsInCurrentStripe) {
       currentStripe += 1;
       currentRowInStripe = 0;
@@ -951,6 +1058,47 @@ namespace orc {
     return rowsToRead != 0;
   }
 
+  uint64_t RowReaderImpl::computeBatchSize(uint64_t requestedSize,
+                                           uint64_t currentRowInStripe,
+                                           uint64_t rowsInCurrentStripe,
+                                           uint64_t rowIndexStride,
+                                           const std::vector<bool>& includedRowGroups) {
+    // In case of PPD, a batch must not run past the current range of selected row groups.
+    // Starting at the row group containing currentRowInStripe, the end marker is extended over
+    // consecutive selected row groups (capped at the stripe end) and stops at the first skipped one.
+    uint64_t endRowInStripe = rowsInCurrentStripe;
+    if (!includedRowGroups.empty()) {  // empty vector: no filtering, read up to the stripe end
+      endRowInStripe = currentRowInStripe;
+      uint32_t rg = static_cast<uint32_t>(currentRowInStripe / rowIndexStride);
+      for (; rg < includedRowGroups.size(); ++rg) {
+        if (!includedRowGroups[rg]) {
+          break;
+        } else {
+          endRowInStripe = std::min(rowsInCurrentStripe, (rg + 1) * rowIndexStride);
+        }
+      }
+    }
+    return std::min(requestedSize, endRowInStripe - currentRowInStripe);
+  }
+
+  uint64_t RowReaderImpl::advanceToNextRowGroup(uint64_t currentRowInStripe,
+                                                uint64_t rowsInCurrentStripe,
+                                                uint64_t rowIndexStride,
+                                                const std::vector<bool>& includedRowGroups) {
+    if (!includedRowGroups.empty()) {  // empty vector: nothing to skip
+      uint32_t rg = static_cast<uint32_t>(currentRowInStripe / rowIndexStride);
+      for (; rg < includedRowGroups.size(); ++rg) {
+        if (includedRowGroups[rg]) {
+          return currentRowInStripe;  // unchanged, or the first row of the next selected row group
+        } else {
+          // advance to start of next row group
+          currentRowInStripe = (rg + 1) * rowIndexStride;
+        }
+      }
+    }
+    return std::min(currentRowInStripe, rowsInCurrentStripe);  // nothing selected ahead: clamp to stripe end
+  }
+
   std::unique_ptr<ColumnVectorBatch> RowReaderImpl::createRowBatch
                                               (uint64_t capacity) const {
     return getSelectedType().createRowBatch(capacity, *contents->pool, enableEncodedBlock);
diff --git a/c++/src/Reader.hh b/c++/src/Reader.hh
index a381956..0693c62 100644
--- a/c++/src/Reader.hh
+++ b/c++/src/Reader.hh
@@ -19,13 +19,14 @@
 #ifndef ORC_READER_IMPL_HH
 #define ORC_READER_IMPL_HH
 
+#include "orc/Exceptions.hh"
 #include "orc/Int128.hh"
 #include "orc/OrcFile.hh"
 #include "orc/Reader.hh"
 
 #include "ColumnReader.hh"
-#include "orc/Exceptions.hh"
 #include "RLE.hh"
+#include "sargs/SargsApplier.hh"
 #include "TypeImpl.hh"
 
 namespace orc {
@@ -142,6 +143,30 @@ namespace orc {
 
     // row index of current stripe with column id as the key
     std::unordered_map<uint64_t, proto::RowIndex> rowIndexes;
+    std::map<uint32_t, BloomFilterIndex> bloomFilterIndex;
+    std::shared_ptr<SearchArgument> sargs;
+    std::unique_ptr<SargsApplier> sargsApplier;
+
+    // clear and (re)load row indexes and bloom filter indexes of the current stripe (selected columns)
+    void loadStripeIndex();
+
+    // In case of PPD, batch size should be aware of row group boundaries.
+    // If only a subset of row groups are selected then the next read should
+    // stop at the end of selected range.
+    static uint64_t computeBatchSize(uint64_t requestedSize,
+                                     uint64_t currentRowInStripe,
+                                     uint64_t rowsInCurrentStripe,
+                                     uint64_t rowIndexStride,
+                                     const std::vector<bool>& includedRowGroups);
+
+    // Skip non-selected rows
+    static uint64_t advanceToNextRowGroup(uint64_t currentRowInStripe,
+                                          uint64_t rowsInCurrentStripe,
+                                          uint64_t rowIndexStride,
+                                          const std::vector<bool>& includedRowGroups);
+
+    friend class TestRowReader_advanceToNextRowGroup_Test;
+    friend class TestRowReader_computeBatchSize_Test;
 
     /**
      * Seek to the start of a row group in the current stripe
@@ -159,7 +184,6 @@ namespace orc {
                   const RowReaderOptions& options);
 
     // Select the columns from the options object
-    void updateSelected();
     const std::vector<bool> getSelectedColumns() const override;
 
     const Type& getSelectedType() const override;
diff --git a/c++/src/sargs/SargsApplier.cc b/c++/src/sargs/SargsApplier.cc
index ea6b2c5..af709ab 100644
--- a/c++/src/sargs/SargsApplier.cc
+++ b/c++/src/sargs/SargsApplier.cc
@@ -17,6 +17,7 @@
  */
 
 #include "SargsApplier.hh"
+#include <numeric>
 
 namespace orc {
 
@@ -43,7 +44,8 @@ namespace orc {
                              : mType(type)
                              , mSearchArgument(searchArgument)
                              , mRowIndexStride(rowIndexStride)
-                             , mWriterVersion(writerVersion) {
+                             , mWriterVersion(writerVersion)
+                             , mStats(0, 0) {
     const SearchArgumentImpl * sargs =
       dynamic_cast<const SearchArgumentImpl *>(mSearchArgument);
 
@@ -106,6 +108,12 @@ namespace orc {
       mHasSkipped = mHasSkipped || (!mRowGroups[rowGroup]);
     }
 
+    // update stats: first += selected row groups, second += evaluated row groups
+    mStats.first = std::accumulate(
+      mRowGroups.cbegin(), mRowGroups.cend(), mStats.first,
+      [](uint64_t selected, bool rg) { return rg ? selected + 1 : selected; });
+    mStats.second += groupsInStripe;
+
     return mHasSelected;
   }
 
diff --git a/c++/src/sargs/SargsApplier.hh b/c++/src/sargs/SargsApplier.hh
index a2b5e69..cc6db92 100644
--- a/c++/src/sargs/SargsApplier.hh
+++ b/c++/src/sargs/SargsApplier.hh
@@ -76,6 +76,10 @@ namespace orc {
       return false;
     }
 
+    std::pair<uint64_t, uint64_t> getStats() const {
+      return mStats;  // pair of (selected RGs, evaluated RGs) counters
+    }
+
   private:
     friend class TestSargsApplier_findColumnTest_Test;
     static uint64_t findColumn(const Type& type, const std::string& colName);
@@ -93,6 +97,8 @@ namespace orc {
     uint64_t mTotalRowsInStripe;
     bool mHasSelected;
     bool mHasSkipped;
+    // keep stats of selected RGs and evaluated RGs
+    std::pair<uint64_t, uint64_t> mStats;
   };
 
 }
diff --git a/c++/test/CMakeLists.txt b/c++/test/CMakeLists.txt
index da5ef57..badee89 100644
--- a/c++/test/CMakeLists.txt
+++ b/c++/test/CMakeLists.txt
@@ -35,6 +35,7 @@ add_executable (orc-test
   TestDriver.cc
   TestInt128.cc
   TestPredicateLeaf.cc
+  TestPredicatePushdown.cc
   TestReader.cc
   TestRleDecoder.cc
   TestRleEncoder.cc
diff --git a/c++/test/TestPredicatePushdown.cc b/c++/test/TestPredicatePushdown.cc
new file mode 100644
index 0000000..bae5687
--- /dev/null
+++ b/c++/test/TestPredicatePushdown.cc
@@ -0,0 +1,187 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "orc/OrcFile.hh"
+#include "orc/sargs/SearchArgument.hh"
+#include "MemoryInputStream.hh"
+#include "MemoryOutputStream.hh"
+#include "wrap/gtest-wrapper.h"
+
+namespace orc {
+
+  static const int DEFAULT_MEM_STREAM_SIZE = 10 * 1024 * 1024; // 10M
+
+  TEST(TestPredicatePushdown, testPredicatePushdown) {
+    MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+    MemoryPool * pool = getDefaultPool();
+    auto type = std::unique_ptr<Type>(Type::buildTypeFromString(
+      "struct<int1:bigint,string1:string>"));
+    WriterOptions options;
+    options.setStripeSize(1024 * 1024)
+      .setCompressionBlockSize(1024)
+      .setCompression(CompressionKind_NONE)
+      .setMemoryPool(pool)
+      .setRowIndexStride(1000);
+
+    auto writer = createWriter(*type, &memStream, options);
+    auto batch = writer->createRowBatch(3500);
+    auto& structBatch = dynamic_cast<StructVectorBatch&>(*batch);
+    auto& longBatch = dynamic_cast<LongVectorBatch&>(*structBatch.fields[0]);
+    auto& strBatch = dynamic_cast<StringVectorBatch&>(*structBatch.fields[1]);
+
+    // row group stride is 1000, here 3500 rows of data constitute 4 row groups.
+    // the min/max pair of each row group is as below:
+    // int1: 0/299700, 300000/599700, 600000/899700, 900000/1049700
+    // string1: "0"/"9990", "10000"/"19990", "20000"/"29990", "30000"/"34990"
+    char buffer[3500 * 5];
+    uint64_t offset = 0;
+    for (uint64_t i = 0; i < 3500; ++i) {
+      longBatch.data[i] = static_cast<int64_t>(i * 300);
+
+      std::ostringstream ss;
+      ss << 10 * i;
+      std::string str = ss.str();
+      memcpy(buffer + offset, str.c_str(), str.size());
+      strBatch.data[i] = buffer + offset;
+      strBatch.length[i] = static_cast<int64_t>(str.size());
+      offset += str.size();
+    }
+    structBatch.numElements = 3500;
+    longBatch.numElements = 3500;
+    strBatch.numElements = 3500;
+    writer->add(*batch);
+    writer->close();
+
+    std::unique_ptr<InputStream> inStream(new MemoryInputStream (
+      memStream.getData(), memStream.getLength()));
+    ReaderOptions readerOptions;
+    readerOptions.setMemoryPool(*pool);  // configure the reader, not the already-used writer options
+    auto reader = createReader(std::move(inStream), readerOptions);
+    EXPECT_EQ(3500, reader->getNumberOfRows());
+
+    // build search argument (x >= 300000 AND x < 600000)
+    {
+      std::unique_ptr<SearchArgument> sarg = SearchArgumentFactory::newBuilder()
+        ->startAnd()
+        .startNot()
+        .lessThan("int1", PredicateDataType::LONG,
+                  Literal(static_cast<int64_t>(300000L)))
+        .end()
+        .lessThan("int1", PredicateDataType::LONG,
+                  Literal(static_cast<int64_t>(600000L)))
+        .end()
+        .build();
+
+      RowReaderOptions rowReaderOpts;
+      rowReaderOpts.searchArgument(std::move(sarg));
+      auto rowReader = reader->createRowReader(rowReaderOpts);
+
+      auto readBatch = rowReader->createRowBatch(2000);
+      auto& batch0 = dynamic_cast<StructVectorBatch&>(*readBatch);
+      auto& batch1 = dynamic_cast<LongVectorBatch&>(*batch0.fields[0]);
+      auto& batch2 = dynamic_cast<StringVectorBatch&>(*batch0.fields[1]);
+
+      EXPECT_EQ(true, rowReader->next(*readBatch));
+      EXPECT_EQ(1000, readBatch->numElements);
+      EXPECT_EQ(1000, rowReader->getRowNumber());
+      for (uint64_t i = 1000; i < 2000; ++i) {
+        EXPECT_EQ(300 * i, batch1.data[i - 1000]);
+        EXPECT_EQ(std::to_string(10 * i),
+          std::string(batch2.data[i - 1000], static_cast<size_t>(batch2.length[i - 1000])));
+      }
+      EXPECT_EQ(false, rowReader->next(*readBatch));
+      EXPECT_EQ(3500, rowReader->getRowNumber());
+    }
+
+    // look through the file with no rows selected: x < 0
+    {
+      std::unique_ptr<SearchArgument> sarg = SearchArgumentFactory::newBuilder()
+        ->startAnd()
+        .lessThan("int1", PredicateDataType::LONG,
+          Literal(static_cast<int64_t>(0)))
+        .end()
+        .build();
+
+      RowReaderOptions rowReaderOpts;
+      rowReaderOpts.searchArgument(std::move(sarg));
+      auto rowReader = reader->createRowReader(rowReaderOpts);
+
+      auto readBatch = rowReader->createRowBatch(2000);
+      EXPECT_EQ(false, rowReader->next(*readBatch));
+      EXPECT_EQ(3500, rowReader->getRowNumber());
+    }
+
+    // select first 1000 and last 500 rows: x < 30000 OR x >= 1020000
+    {
+      std::unique_ptr<SearchArgument> sarg = SearchArgumentFactory::newBuilder()
+        ->startOr()
+        .lessThan("int1", PredicateDataType::LONG,
+          Literal(static_cast<int64_t>(300 * 100)))
+        .startNot()
+        .lessThan("int1", PredicateDataType::LONG,
+          Literal(static_cast<int64_t>(300 * 3400)))
+        .end()
+        .end()
+        .build();
+
+      RowReaderOptions rowReaderOpts;
+      rowReaderOpts.searchArgument(std::move(sarg));
+      auto rowReader = reader->createRowReader(rowReaderOpts);
+
+      auto readBatch = rowReader->createRowBatch(2000);
+      auto& batch0 = dynamic_cast<StructVectorBatch&>(*readBatch);
+      auto& batch1 = dynamic_cast<LongVectorBatch&>(*batch0.fields[0]);
+      auto& batch2 = dynamic_cast<StringVectorBatch&>(*batch0.fields[1]);
+
+      EXPECT_EQ(true, rowReader->next(*readBatch));
+      EXPECT_EQ(1000, readBatch->numElements);
+      EXPECT_EQ(0, rowReader->getRowNumber());
+      for (uint64_t i = 0; i < 1000; ++i) {
+        EXPECT_EQ(300 * i, batch1.data[i]);
+        EXPECT_EQ(std::to_string(10 * i),
+                  std::string(batch2.data[i], static_cast<size_t>(batch2.length[i])));
+      }
+
+      EXPECT_EQ(true, rowReader->next(*readBatch));
+      EXPECT_EQ(500, readBatch->numElements);
+      EXPECT_EQ(3000, rowReader->getRowNumber());
+      for (uint64_t i = 3000; i < 3500; ++i) {
+        EXPECT_EQ(300 * i, batch1.data[i - 3000]);
+        EXPECT_EQ(std::to_string(10 * i),
+                  std::string(batch2.data[i - 3000], static_cast<size_t>(batch2.length[i - 3000])));
+      }
+
+      EXPECT_EQ(false, rowReader->next(*readBatch));
+      EXPECT_EQ(3500, rowReader->getRowNumber());
+
+      // test seek to 3rd row group but is adjusted to 4th row group
+      rowReader->seekToRow(2500);
+      EXPECT_EQ(true, rowReader->next(*readBatch));
+      EXPECT_EQ(3000, rowReader->getRowNumber());
+      EXPECT_EQ(500, readBatch->numElements);
+      for (uint64_t i = 3000; i < 3500; ++i) {
+        EXPECT_EQ(300 * i, batch1.data[i - 3000]);
+        EXPECT_EQ(std::to_string(10 * i),
+                  std::string(batch2.data[i - 3000], static_cast<size_t>(batch2.length[i - 3000])));
+      }
+      EXPECT_EQ(false, rowReader->next(*readBatch));
+      EXPECT_EQ(3500, rowReader->getRowNumber());
+    }
+  }
+
+}  // namespace orc
diff --git a/c++/test/TestReader.cc b/c++/test/TestReader.cc
index 68f6de1..e31ed11 100644
--- a/c++/test/TestReader.cc
+++ b/c++/test/TestReader.cc
@@ -17,6 +17,7 @@
  */
 
 #include "orc/Reader.hh"
+#include "Reader.hh"
 
 #include "Adaptor.hh"
 
@@ -46,4 +47,58 @@ namespace orc {
               compressionKindToString(static_cast<CompressionKind>(99)));
   }
 
+  TEST(TestRowReader, computeBatchSize) {
+    uint64_t rowIndexStride = 100;
+    uint64_t rowsInCurrentStripe = 100 * 8 + 50;
+    std::vector<bool> includedRowGroups =
+      { false, false, true, true, false, false, true, true, false };  // 9 groups; the last holds 50 rows
+
+    EXPECT_EQ(0, RowReaderImpl::computeBatchSize(
+      1024, 0, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(0, RowReaderImpl::computeBatchSize(
+      1024, 50, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(200, RowReaderImpl::computeBatchSize(
+      1024, 200, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(150, RowReaderImpl::computeBatchSize(
+      1024, 250, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(0, RowReaderImpl::computeBatchSize(
+      1024, 550, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(100, RowReaderImpl::computeBatchSize(
+      1024, 700, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(50, RowReaderImpl::computeBatchSize(
+      50, 700, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(0, RowReaderImpl::computeBatchSize(
+      50, 810, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(0, RowReaderImpl::computeBatchSize(
+      50, 900, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+  }
+
+  TEST(TestRowReader, advanceToNextRowGroup) {
+    uint64_t rowIndexStride = 100;
+    uint64_t rowsInCurrentStripe = 100 * 8 + 50;
+    std::vector<bool> includedRowGroups =
+      { false, false, true, true, false, false, true, true, false };  // 9 groups; the last holds 50 rows
+
+    EXPECT_EQ(200, RowReaderImpl::advanceToNextRowGroup(
+      0, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(200, RowReaderImpl::advanceToNextRowGroup(
+      150, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(250, RowReaderImpl::advanceToNextRowGroup(
+      250, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(350, RowReaderImpl::advanceToNextRowGroup(
+      350, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(600, RowReaderImpl::advanceToNextRowGroup(
+      400, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(600, RowReaderImpl::advanceToNextRowGroup(
+      500, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(699, RowReaderImpl::advanceToNextRowGroup(
+      699, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(799, RowReaderImpl::advanceToNextRowGroup(
+      799, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(850, RowReaderImpl::advanceToNextRowGroup(
+      800, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+    EXPECT_EQ(850, RowReaderImpl::advanceToNextRowGroup(
+      900, rowsInCurrentStripe, rowIndexStride, includedRowGroups));
+  }
+
 }  // namespace