You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@orc.apache.org by do...@apache.org on 2021/03/09 17:54:39 UTC
[orc] branch master updated: ORC-751: [C++] Implement Predicate
Pushdown for C++ Reader (#648)
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/orc.git
The following commit(s) were added to refs/heads/master by this push:
new 40ee321 ORC-751: [C++] Implement Predicate Pushdown for C++ Reader (#648)
40ee321 is described below
commit 40ee321b209cad1de73fc97b653636aa0fff28f6
Author: Gang Wu <ga...@alibaba-inc.com>
AuthorDate: Wed Mar 10 01:54:31 2021 +0800
ORC-751: [C++] Implement Predicate Pushdown for C++ Reader (#648)
1. Use RowReaderOptions to pass SearchArgument to enable PPD.
2. Modify RowReaderImpl::startNextStripe to seek to next matched row group based on sarg evaluation.
3. RowReaderImpl::next seeks to next matched row group based on sarg evaluation.
4. RowReaderImpl::seekToRow also jumps to the 1st matched row group after specified row number.
---
c++/include/orc/Reader.hh | 11 ++
c++/src/Options.hh | 10 ++
c++/src/Reader.cc | 232 +++++++++++++++++++++++++++++++-------
c++/src/Reader.hh | 28 ++++-
c++/src/sargs/SargsApplier.cc | 10 +-
c++/src/sargs/SargsApplier.hh | 6 +
c++/test/CMakeLists.txt | 1 +
c++/test/TestPredicatePushdown.cc | 187 ++++++++++++++++++++++++++++++
c++/test/TestReader.cc | 55 +++++++++
9 files changed, 495 insertions(+), 45 deletions(-)
diff --git a/c++/include/orc/Reader.hh b/c++/include/orc/Reader.hh
index 267ae93..eee9551 100644
--- a/c++/include/orc/Reader.hh
+++ b/c++/include/orc/Reader.hh
@@ -23,6 +23,7 @@
#include "orc/Common.hh"
#include "orc/orc-config.hh"
#include "orc/Statistics.hh"
+#include "orc/sargs/SearchArgument.hh"
#include "orc/Type.hh"
#include "orc/Vector.hh"
@@ -192,6 +193,11 @@ namespace orc {
RowReaderOptions& setEnableLazyDecoding(bool enable);
/**
+ * Set search argument for predicate push down
+ */
+ RowReaderOptions& searchArgument(std::unique_ptr<SearchArgument> sargs);
+
+ /**
* Should enable encoding block mode
*/
bool getEnableLazyDecoding() const;
@@ -245,6 +251,11 @@ namespace orc {
* What scale should all Hive 0.11 decimals be normalized to?
*/
int32_t getForcedScaleOnHive11Decimal() const;
+
+ /**
+ * Get search argument for predicate push down
+ */
+ std::shared_ptr<SearchArgument> getSearchArgument() const;
};
diff --git a/c++/src/Options.hh b/c++/src/Options.hh
index 9581331..2808ffa 100644
--- a/c++/src/Options.hh
+++ b/c++/src/Options.hh
@@ -128,6 +128,7 @@ namespace orc {
bool throwOnHive11DecimalOverflow;
int32_t forcedScaleOnHive11Decimal;
bool enableLazyDecoding;
+ std::shared_ptr<SearchArgument> sargs;
RowReaderOptionsPrivate() {
selection = ColumnSelection_NONE;
@@ -249,6 +250,15 @@ namespace orc {
privateBits->enableLazyDecoding = enable;
return *this;
}
+
+ // Takes ownership of the given search argument and keeps it in
+ // privateBits->sargs (a shared_ptr); returns *this so calls can be chained.
+ RowReaderOptions& RowReaderOptions::searchArgument(std::unique_ptr<SearchArgument> sargs) {
+ privateBits->sargs = std::move(sargs);
+ return *this;
+ }
+
+ // Returns a shared_ptr copy of the search argument held in privateBits->sargs.
+ std::shared_ptr<SearchArgument> RowReaderOptions::getSearchArgument() const {
+ return privateBits->sargs;
+ }
}
#endif
diff --git a/c++/src/Reader.cc b/c++/src/Reader.cc
index 37225b9..42ab41c 100644
--- a/c++/src/Reader.cc
+++ b/c++/src/Reader.cc
@@ -68,6 +68,12 @@ namespace orc {
return columnPath.substr(0, columnPath.length() - 1);
}
+  // Reads the writer version from the postscript of the given file contents,
+  // falling back to WriterVersion_ORIGINAL when the postscript has none.
+  WriterVersion getWriterVersionImpl(const FileContents * contents) {
+    const auto& postscript = *contents->postscript;
+    return postscript.has_writerversion()
+      ? static_cast<WriterVersion>(postscript.writerversion())
+      : WriterVersion_ORIGINAL;
+  }
void ColumnSelector::selectChildren(std::vector<bool>& selectedColumns, const Type& type) {
size_t id = static_cast<size_t>(type.getColumnId());
@@ -227,6 +233,15 @@ namespace orc {
ColumnSelector column_selector(contents.get());
column_selector.updateSelected(selectedColumns, opts);
+
+ // prepare SargsApplier if SearchArgument is available
+ if (opts.getSearchArgument() && footer->rowindexstride() > 0) {
+ sargs = opts.getSearchArgument();
+ sargsApplier.reset(new SargsApplier(*contents->schema,
+ sargs.get(),
+ footer->rowindexstride(),
+ getWriterVersionImpl(_contents.get())));
+ }
}
CompressionKind RowReaderImpl::getCompression() const {
@@ -293,25 +308,34 @@ namespace orc {
previousRow = rowNumber;
startNextStripe();
- uint64_t rowsToSkip = currentRowInStripe;
+ // when predicate push down is enabled, above call to startNextStripe()
+ // will move current row to 1st matching row group; here we only need
+ // to deal with the case when PPD is not enabled.
+ if (!sargsApplier) {
+ uint64_t rowsToSkip = currentRowInStripe;
- if (footer->rowindexstride() > 0 &&
- currentStripeInfo.indexlength() > 0) {
- uint32_t rowGroupId =
- static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride());
- rowsToSkip -= static_cast<uint64_t>(rowGroupId) * footer->rowindexstride();
+ if (footer->rowindexstride() > 0 &&
+ currentStripeInfo.indexlength() > 0) {
+ if (rowIndexes.empty()) {
+ loadStripeIndex();
+ }
+ uint32_t rowGroupId =
+ static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride());
+ rowsToSkip -= static_cast<uint64_t>(rowGroupId) * footer->rowindexstride();
- if (rowGroupId != 0) {
- seekToRowGroup(rowGroupId);
+ if (rowGroupId != 0) {
+ seekToRowGroup(rowGroupId);
+ }
}
- }
- reader->skip(rowsToSkip);
+ reader->skip(rowsToSkip);
+ }
}
- void RowReaderImpl::seekToRowGroup(uint32_t rowGroupEntryId) {
+ void RowReaderImpl::loadStripeIndex() {
// reset all previous row indexes
rowIndexes.clear();
+ bloomFilterIndex.clear();
// obtain row indexes for selected columns
uint64_t offset = currentStripeInfo.offset();
@@ -319,7 +343,8 @@ namespace orc {
const proto::Stream& pbStream = currentStripeFooter.streams(i);
uint64_t colId = pbStream.column();
if (selectedColumns[colId] && pbStream.has_kind()
- && pbStream.kind() == proto::Stream_Kind_ROW_INDEX) {
+ && (pbStream.kind() == proto::Stream_Kind_ROW_INDEX ||
+ pbStream.kind() == proto::Stream_Kind_BLOOM_FILTER_UTF8)) {
std::unique_ptr<SeekableInputStream> inStream =
createDecompressor(getCompression(),
std::unique_ptr<SeekableInputStream>
@@ -331,16 +356,33 @@ namespace orc {
getCompressionSize(),
*contents->pool);
- proto::RowIndex rowIndex;
- if (!rowIndex.ParseFromZeroCopyStream(inStream.get())) {
- throw ParseError("Failed to parse the row index");
+ if (pbStream.kind() == proto::Stream_Kind_ROW_INDEX) {
+ proto::RowIndex rowIndex;
+ if (!rowIndex.ParseFromZeroCopyStream(inStream.get())) {
+ throw ParseError("Failed to parse the row index");
+ }
+ rowIndexes[colId] = rowIndex;
+ } else { // Stream_Kind_BLOOM_FILTER_UTF8
+ proto::BloomFilterIndex pbBFIndex;
+ if (!pbBFIndex.ParseFromZeroCopyStream(inStream.get())) {
+ throw ParseError("Failed to parse bloom filter index");
+ }
+ BloomFilterIndex bfIndex;
+ for (int j = 0; j < pbBFIndex.bloomfilter_size(); j++) {
+ bfIndex.entries.push_back(BloomFilterUTF8Utils::deserialize(
+ pbStream.kind(),
+ currentStripeFooter.columns(static_cast<int>(pbStream.column())),
+ pbBFIndex.bloomfilter(j)));
+ }
+ // add bloom filters to result for one column
+ bloomFilterIndex[pbStream.column()] = bfIndex;
}
-
- rowIndexes[colId] = rowIndex;
}
offset += pbStream.length();
}
+ }
+ void RowReaderImpl::seekToRowGroup(uint32_t rowGroupEntryId) {
// store positions for selected columns
std::vector<std::list<uint64_t>> positions;
// store position providers for selected colimns
@@ -516,10 +558,7 @@ namespace orc {
}
WriterVersion ReaderImpl::getWriterVersion() const {
- if (!contents->postscript->has_writerversion()) {
- return WriterVersion_ORIGINAL;
- }
- return static_cast<WriterVersion>(contents->postscript->writerversion());
+ return getWriterVersionImpl(contents.get());
}
uint64_t ReaderImpl::getContentLength() const {
@@ -892,29 +931,68 @@ namespace orc {
void RowReaderImpl::startNextStripe() {
reader.reset(); // ColumnReaders use lots of memory; free old memory first
- currentStripeInfo = footer->stripes(static_cast<int>(currentStripe));
- uint64_t fileLength = contents->stream->getLength();
- if (currentStripeInfo.offset() + currentStripeInfo.indexlength() +
+ rowIndexes.clear();
+ bloomFilterIndex.clear();
+
+ do {
+ currentStripeInfo = footer->stripes(static_cast<int>(currentStripe));
+ uint64_t fileLength = contents->stream->getLength();
+ if (currentStripeInfo.offset() + currentStripeInfo.indexlength() +
currentStripeInfo.datalength() + currentStripeInfo.footerlength() >= fileLength) {
- std::stringstream msg;
- msg << "Malformed StripeInformation at stripe index " << currentStripe << ": fileLength="
- << fileLength << ", StripeInfo=(offset=" << currentStripeInfo.offset() << ", indexLength="
- << currentStripeInfo.indexlength() << ", dataLength=" << currentStripeInfo.datalength()
- << ", footerLength=" << currentStripeInfo.footerlength() << ")";
- throw ParseError(msg.str());
+ std::stringstream msg;
+ msg << "Malformed StripeInformation at stripe index " << currentStripe << ": fileLength="
+ << fileLength << ", StripeInfo=(offset=" << currentStripeInfo.offset() << ", indexLength="
+ << currentStripeInfo.indexlength() << ", dataLength=" << currentStripeInfo.datalength()
+ << ", footerLength=" << currentStripeInfo.footerlength() << ")";
+ throw ParseError(msg.str());
+ }
+ currentStripeFooter = getStripeFooter(currentStripeInfo, *contents.get());
+ rowsInCurrentStripe = currentStripeInfo.numberofrows();
+
+ if (sargsApplier) {
+ // read row group statistics and bloom filters of current stripe
+ loadStripeIndex();
+
+ // select row groups to read in the current stripe
+ sargsApplier->pickRowGroups(rowsInCurrentStripe,
+ rowIndexes,
+ bloomFilterIndex);
+ if (sargsApplier->hasSelectedFrom(currentRowInStripe)) {
+ // current stripe has at least one row group matching the predicate
+ break;
+ } else {
+ // advance to next stripe when current stripe has no matching rows
+ currentStripe += 1;
+ currentRowInStripe = 0;
+ }
+ }
+ } while (sargsApplier && currentStripe < lastStripe);
+
+ if (currentStripe < lastStripe) {
+ // get writer timezone info from stripe footer to help understand timestamp values.
+ const Timezone& writerTimezone =
+ currentStripeFooter.has_writertimezone() ?
+ getTimezoneByName(currentStripeFooter.writertimezone()) :
+ localTimezone;
+ StripeStreamsImpl stripeStreams(*this, currentStripe, currentStripeInfo,
+ currentStripeFooter,
+ currentStripeInfo.offset(),
+ *contents->stream,
+ writerTimezone);
+ reader = buildReader(*contents->schema, stripeStreams);
+
+ if (sargsApplier) {
+ // move to the 1st selected row group when PPD is enabled.
+ currentRowInStripe = advanceToNextRowGroup(currentRowInStripe,
+ rowsInCurrentStripe,
+ footer->rowindexstride(),
+ sargsApplier->getRowGroups());
+ previousRow = firstRowOfStripe[currentStripe] + currentRowInStripe - 1;
+ if (currentRowInStripe > 0) {
+ seekToRowGroup(static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride()));
+ }
+ }
}
- currentStripeFooter = getStripeFooter(currentStripeInfo, *contents.get());
- rowsInCurrentStripe = currentStripeInfo.numberofrows();
- const Timezone& writerTimezone =
- currentStripeFooter.has_writertimezone() ?
- getTimezoneByName(currentStripeFooter.writertimezone()) :
- localTimezone;
- StripeStreamsImpl stripeStreams(*this, currentStripe, currentStripeInfo,
- currentStripeFooter,
- currentStripeInfo.offset(),
- *(contents->stream.get()),
- writerTimezone);
- reader = buildReader(*contents->schema.get(), stripeStreams);
}
bool RowReaderImpl::next(ColumnVectorBatch& data) {
@@ -934,7 +1012,20 @@ namespace orc {
uint64_t rowsToRead =
std::min(static_cast<uint64_t>(data.capacity),
rowsInCurrentStripe - currentRowInStripe);
+ if (sargsApplier) {
+ rowsToRead = computeBatchSize(rowsToRead,
+ currentRowInStripe,
+ rowsInCurrentStripe,
+ footer->rowindexstride(),
+ sargsApplier->getRowGroups());
+ }
data.numElements = rowsToRead;
+ if (rowsToRead == 0) {
+ previousRow = lastStripe <= 0 ? footer->numberofrows() :
+ firstRowOfStripe[lastStripe - 1] +
+ footer->stripes(static_cast<int>(lastStripe - 1)).numberofrows();
+ return false;
+ }
if (enableEncodedBlock) {
reader->nextEncoded(data, rowsToRead, nullptr);
}
@@ -944,6 +1035,22 @@ namespace orc {
// update row number
previousRow = firstRowOfStripe[currentStripe] + currentRowInStripe;
currentRowInStripe += rowsToRead;
+
+ // check if we need to advance to next selected row group
+ if (sargsApplier) {
+ uint64_t nextRowToRead = advanceToNextRowGroup(currentRowInStripe,
+ rowsInCurrentStripe,
+ footer->rowindexstride(),
+ sargsApplier->getRowGroups());
+ if (currentRowInStripe != nextRowToRead) {
+ // it is guaranteed to be at start of a row group
+ currentRowInStripe = nextRowToRead;
+ if (currentRowInStripe < rowsInCurrentStripe) {
+ seekToRowGroup(static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride()));
+ }
+ }
+ }
+
if (currentRowInStripe >= rowsInCurrentStripe) {
currentStripe += 1;
currentRowInStripe = 0;
@@ -951,6 +1058,47 @@ namespace orc {
return rowsToRead != 0;
}
+  // Limits the size of the next batch so that it never crosses into a row
+  // group that was not selected.
+  //   requestedSize       upper bound for the returned size
+  //   currentRowInStripe  row (relative to the stripe) the batch starts at
+  //   rowsInCurrentStripe total number of rows in the stripe
+  //   rowIndexStride      number of rows per row group
+  //   includedRowGroups   one flag per row group; empty means no selection
+  // Returns 0 when the row group containing currentRowInStripe is not
+  // selected or lies past the last row group.
+  uint64_t RowReaderImpl::computeBatchSize(uint64_t requestedSize,
+                                           uint64_t currentRowInStripe,
+                                           uint64_t rowsInCurrentStripe,
+                                           uint64_t rowIndexStride,
+                                           const std::vector<bool>& includedRowGroups) {
+    // no row group selection: everything up to the end of the stripe is readable
+    if (includedRowGroups.empty()) {
+      return std::min(requestedSize, rowsInCurrentStripe - currentRowInStripe);
+    }
+
+    // extend the readable range over the run of consecutive selected row
+    // groups that starts at the current row group
+    uint64_t rangeEnd = currentRowInStripe;
+    uint32_t group = static_cast<uint32_t>(currentRowInStripe / rowIndexStride);
+    while (group < includedRowGroups.size() && includedRowGroups[group]) {
+      rangeEnd = std::min(rowsInCurrentStripe, (group + 1) * rowIndexStride);
+      ++group;
+    }
+    return std::min(requestedSize, rangeEnd - currentRowInStripe);
+  }
+
+  // Finds the next row to read at or after currentRowInStripe.
+  // If the row group holding currentRowInStripe is selected the row is
+  // returned unchanged; otherwise the start of the next selected row group is
+  // returned. When no selected row group is left (or includedRowGroups is
+  // empty) the result is capped at rowsInCurrentStripe.
+  uint64_t RowReaderImpl::advanceToNextRowGroup(uint64_t currentRowInStripe,
+                                                uint64_t rowsInCurrentStripe,
+                                                uint64_t rowIndexStride,
+                                                const std::vector<bool>& includedRowGroups) {
+    uint64_t nextRow = currentRowInStripe;
+    if (!includedRowGroups.empty()) {
+      uint32_t group = static_cast<uint32_t>(nextRow / rowIndexStride);
+      while (group < includedRowGroups.size()) {
+        if (includedRowGroups[group]) {
+          // already inside (or at the start of) a selected row group
+          return nextRow;
+        }
+        // skip the whole unselected row group
+        ++group;
+        nextRow = group * rowIndexStride;
+      }
+    }
+    return std::min(nextRow, rowsInCurrentStripe);
+  }
+
std::unique_ptr<ColumnVectorBatch> RowReaderImpl::createRowBatch
(uint64_t capacity) const {
return getSelectedType().createRowBatch(capacity, *contents->pool, enableEncodedBlock);
diff --git a/c++/src/Reader.hh b/c++/src/Reader.hh
index a381956..0693c62 100644
--- a/c++/src/Reader.hh
+++ b/c++/src/Reader.hh
@@ -19,13 +19,14 @@
#ifndef ORC_READER_IMPL_HH
#define ORC_READER_IMPL_HH
+#include "orc/Exceptions.hh"
#include "orc/Int128.hh"
#include "orc/OrcFile.hh"
#include "orc/Reader.hh"
#include "ColumnReader.hh"
-#include "orc/Exceptions.hh"
#include "RLE.hh"
+#include "sargs/SargsApplier.hh"
#include "TypeImpl.hh"
namespace orc {
@@ -142,6 +143,30 @@ namespace orc {
// row index of current stripe with column id as the key
std::unordered_map<uint64_t, proto::RowIndex> rowIndexes;
+ std::map<uint32_t, BloomFilterIndex> bloomFilterIndex;
+ std::shared_ptr<SearchArgument> sargs;
+ std::unique_ptr<SargsApplier> sargsApplier;
+
+ // load row indexes and bloom filters of the current stripe, discarding any previously loaded ones
+ void loadStripeIndex();
+
+ // In case of PPD, batch size should be aware of row group boundaries.
+ // If only a subset of row groups are selected then the next read should
+ // stop at the end of selected range.
+ static uint64_t computeBatchSize(uint64_t requestedSize,
+ uint64_t currentRowInStripe,
+ uint64_t rowsInCurrentStripe,
+ uint64_t rowIndexStride,
+ const std::vector<bool>& includedRowGroups);
+
+ // Skip non-selected rows
+ static uint64_t advanceToNextRowGroup(uint64_t currentRowInStripe,
+ uint64_t rowsInCurrentStripe,
+ uint64_t rowIndexStride,
+ const std::vector<bool>& includedRowGroups);
+
+ friend class TestRowReader_advanceToNextRowGroup_Test;
+ friend class TestRowReader_computeBatchSize_Test;
/**
* Seek to the start of a row group in the current stripe
@@ -159,7 +184,6 @@ namespace orc {
const RowReaderOptions& options);
// Select the columns from the options object
- void updateSelected();
const std::vector<bool> getSelectedColumns() const override;
const Type& getSelectedType() const override;
diff --git a/c++/src/sargs/SargsApplier.cc b/c++/src/sargs/SargsApplier.cc
index ea6b2c5..af709ab 100644
--- a/c++/src/sargs/SargsApplier.cc
+++ b/c++/src/sargs/SargsApplier.cc
@@ -17,6 +17,7 @@
*/
#include "SargsApplier.hh"
+#include <numeric>
namespace orc {
@@ -43,7 +44,8 @@ namespace orc {
: mType(type)
, mSearchArgument(searchArgument)
, mRowIndexStride(rowIndexStride)
- , mWriterVersion(writerVersion) {
+ , mWriterVersion(writerVersion)
+ , mStats(0, 0) {
const SearchArgumentImpl * sargs =
dynamic_cast<const SearchArgumentImpl *>(mSearchArgument);
@@ -106,6 +108,12 @@ namespace orc {
mHasSkipped = mHasSkipped || (!mRowGroups[rowGroup]);
}
+    // update stats: first counts the selected row groups, second the evaluated
+    // ones; both keep growing across calls because the previous values are
+    // used as the starting point.
+    // std::accumulate calls the functor as op(accumulator, element), so the
+    // running count must be the first parameter and the row group flag the
+    // second.
+    mStats.first = std::accumulate(
+      mRowGroups.cbegin(), mRowGroups.cend(), mStats.first,
+      [](uint64_t s, bool rg) { return rg ? s + 1 : s; });
+    mStats.second += groupsInStripe;
+
return mHasSelected;
}
diff --git a/c++/src/sargs/SargsApplier.hh b/c++/src/sargs/SargsApplier.hh
index a2b5e69..cc6db92 100644
--- a/c++/src/sargs/SargsApplier.hh
+++ b/c++/src/sargs/SargsApplier.hh
@@ -76,6 +76,10 @@ namespace orc {
return false;
}
+ // Returns a copy of mStats: first is the number of selected row groups,
+ // second is the number of evaluated row groups.
+ std::pair<uint64_t, uint64_t> getStats() const {
+ return mStats;
+ }
+
private:
friend class TestSargsApplier_findColumnTest_Test;
static uint64_t findColumn(const Type& type, const std::string& colName);
@@ -93,6 +97,8 @@ namespace orc {
uint64_t mTotalRowsInStripe;
bool mHasSelected;
bool mHasSkipped;
+ // keep stats of selected RGs and evaluated RGs
+ std::pair<uint64_t, uint64_t> mStats;
};
}
diff --git a/c++/test/CMakeLists.txt b/c++/test/CMakeLists.txt
index da5ef57..badee89 100644
--- a/c++/test/CMakeLists.txt
+++ b/c++/test/CMakeLists.txt
@@ -35,6 +35,7 @@ add_executable (orc-test
TestDriver.cc
TestInt128.cc
TestPredicateLeaf.cc
+ TestPredicatePushdown.cc
TestReader.cc
TestRleDecoder.cc
TestRleEncoder.cc
diff --git a/c++/test/TestPredicatePushdown.cc b/c++/test/TestPredicatePushdown.cc
new file mode 100644
index 0000000..bae5687
--- /dev/null
+++ b/c++/test/TestPredicatePushdown.cc
@@ -0,0 +1,187 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "orc/OrcFile.hh"
+#include "orc/sargs/SearchArgument.hh"
+#include "MemoryInputStream.hh"
+#include "MemoryOutputStream.hh"
+#include "wrap/gtest-wrapper.h"
+
+namespace orc {
+
+ static const int DEFAULT_MEM_STREAM_SIZE = 10 * 1024 * 1024; // 10M
+
+  // Writes 3500 rows (row index stride 1000, i.e. 4 row groups) into an
+  // in-memory ORC file and reads it back with different search arguments,
+  // checking which row groups are returned and what row numbers are reported.
+  TEST(TestPredicatePushdown, testPredicatePushdown) {
+    MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+    MemoryPool * pool = getDefaultPool();
+    auto type = std::unique_ptr<Type>(Type::buildTypeFromString(
+      "struct<int1:bigint,string1:string>"));
+    WriterOptions options;
+    options.setStripeSize(1024 * 1024)
+      .setCompressionBlockSize(1024)
+      .setCompression(CompressionKind_NONE)
+      .setMemoryPool(pool)
+      .setRowIndexStride(1000);
+
+    auto writer = createWriter(*type, &memStream, options);
+    auto batch = writer->createRowBatch(3500);
+    auto& structBatch = dynamic_cast<StructVectorBatch&>(*batch);
+    auto& longBatch = dynamic_cast<LongVectorBatch&>(*structBatch.fields[0]);
+    auto& strBatch = dynamic_cast<StringVectorBatch&>(*structBatch.fields[1]);
+
+    // row group stride is 1000, here 3500 rows of data constitute 4 row groups.
+    // the min/max pair of each row group is as below:
+    // int1: 0/299700, 300000/599700, 600000/899700, 900000/1049700
+    // string1: "0"/"9990", "10000"/"19990", "20000"/"29990", "30000"/"34990"
+    char buffer[3500 * 5];
+    uint64_t offset = 0;
+    for (uint64_t i = 0; i < 3500; ++i) {
+      longBatch.data[i] = static_cast<int64_t>(i * 300);
+
+      // string values are stored back to back in buffer; each is at most
+      // 5 characters long ("34990")
+      std::ostringstream ss;
+      ss << 10 * i;
+      std::string str = ss.str();
+      memcpy(buffer + offset, str.c_str(), str.size());
+      strBatch.data[i] = buffer + offset;
+      strBatch.length[i] = static_cast<int64_t>(str.size());
+      offset += str.size();
+    }
+    structBatch.numElements = 3500;
+    longBatch.numElements = 3500;
+    strBatch.numElements = 3500;
+    writer->add(*batch);
+    writer->close();
+
+    std::unique_ptr<InputStream> inStream(new MemoryInputStream (
+      memStream.getData(), memStream.getLength()));
+    ReaderOptions readerOptions;
+    // configure the reader (not the already used writer options) with the pool
+    readerOptions.setMemoryPool(*pool);
+    auto reader = createReader(std::move(inStream), readerOptions);
+    EXPECT_EQ(3500, reader->getNumberOfRows());
+
+    // build search argument (x >= 300000 AND x < 600000)
+    {
+      std::unique_ptr<SearchArgument> sarg = SearchArgumentFactory::newBuilder()
+        ->startAnd()
+        .startNot()
+        .lessThan("int1", PredicateDataType::LONG,
+                  Literal(static_cast<int64_t>(300000L)))
+        .end()
+        .lessThan("int1", PredicateDataType::LONG,
+                  Literal(static_cast<int64_t>(600000L)))
+        .end()
+        .build();
+
+      RowReaderOptions rowReaderOpts;
+      rowReaderOpts.searchArgument(std::move(sarg));
+      auto rowReader = reader->createRowReader(rowReaderOpts);
+
+      auto readBatch = rowReader->createRowBatch(2000);
+      auto& batch0 = dynamic_cast<StructVectorBatch&>(*readBatch);
+      auto& batch1 = dynamic_cast<LongVectorBatch&>(*batch0.fields[0]);
+      auto& batch2 = dynamic_cast<StringVectorBatch&>(*batch0.fields[1]);
+
+      // only the 2nd row group (rows 1000..1999) is expected
+      EXPECT_EQ(true, rowReader->next(*readBatch));
+      EXPECT_EQ(1000, readBatch->numElements);
+      EXPECT_EQ(1000, rowReader->getRowNumber());
+      for (uint64_t i = 1000; i < 2000; ++i) {
+        EXPECT_EQ(300 * i, batch1.data[i - 1000]);
+        EXPECT_EQ(std::to_string(10 * i),
+          std::string(batch2.data[i - 1000], static_cast<size_t>(batch2.length[i - 1000])));
+      }
+      EXPECT_EQ(false, rowReader->next(*readBatch));
+      EXPECT_EQ(3500, rowReader->getRowNumber());
+    }
+
+    // look through the file with no rows selected: x < 0
+    {
+      std::unique_ptr<SearchArgument> sarg = SearchArgumentFactory::newBuilder()
+        ->startAnd()
+        .lessThan("int1", PredicateDataType::LONG,
+                  Literal(static_cast<int64_t>(0)))
+        .end()
+        .build();
+
+      RowReaderOptions rowReaderOpts;
+      rowReaderOpts.searchArgument(std::move(sarg));
+      auto rowReader = reader->createRowReader(rowReaderOpts);
+
+      auto readBatch = rowReader->createRowBatch(2000);
+      EXPECT_EQ(false, rowReader->next(*readBatch));
+      EXPECT_EQ(3500, rowReader->getRowNumber());
+    }
+
+    // select first 1000 and last 500 rows: x < 30000 OR x >= 1020000
+    {
+      std::unique_ptr<SearchArgument> sarg = SearchArgumentFactory::newBuilder()
+        ->startOr()
+        .lessThan("int1", PredicateDataType::LONG,
+                  Literal(static_cast<int64_t>(300 * 100)))
+        .startNot()
+        .lessThan("int1", PredicateDataType::LONG,
+                  Literal(static_cast<int64_t>(300 * 3400)))
+        .end()
+        .end()
+        .build();
+
+      RowReaderOptions rowReaderOpts;
+      rowReaderOpts.searchArgument(std::move(sarg));
+      auto rowReader = reader->createRowReader(rowReaderOpts);
+
+      auto readBatch = rowReader->createRowBatch(2000);
+      auto& batch0 = dynamic_cast<StructVectorBatch&>(*readBatch);
+      auto& batch1 = dynamic_cast<LongVectorBatch&>(*batch0.fields[0]);
+      auto& batch2 = dynamic_cast<StringVectorBatch&>(*batch0.fields[1]);
+
+      // 1st row group (rows 0..999)
+      EXPECT_EQ(true, rowReader->next(*readBatch));
+      EXPECT_EQ(1000, readBatch->numElements);
+      EXPECT_EQ(0, rowReader->getRowNumber());
+      for (uint64_t i = 0; i < 1000; ++i) {
+        EXPECT_EQ(300 * i, batch1.data[i]);
+        EXPECT_EQ(std::to_string(10 * i),
+          std::string(batch2.data[i], static_cast<size_t>(batch2.length[i])));
+      }
+
+      // 4th row group (rows 3000..3499)
+      EXPECT_EQ(true, rowReader->next(*readBatch));
+      EXPECT_EQ(500, readBatch->numElements);
+      EXPECT_EQ(3000, rowReader->getRowNumber());
+      for (uint64_t i = 3000; i < 3500; ++i) {
+        EXPECT_EQ(300 * i, batch1.data[i - 3000]);
+        EXPECT_EQ(std::to_string(10 * i),
+          std::string(batch2.data[i - 3000], static_cast<size_t>(batch2.length[i - 3000])));
+      }
+
+      EXPECT_EQ(false, rowReader->next(*readBatch));
+      EXPECT_EQ(3500, rowReader->getRowNumber());
+
+      // test seek to 3rd row group but is adjusted to 4th row group
+      rowReader->seekToRow(2500);
+      EXPECT_EQ(true, rowReader->next(*readBatch));
+      EXPECT_EQ(3000, rowReader->getRowNumber());
+      EXPECT_EQ(500, readBatch->numElements);
+      for (uint64_t i = 3000; i < 3500; ++i) {
+        EXPECT_EQ(300 * i, batch1.data[i - 3000]);
+        EXPECT_EQ(std::to_string(10 * i),
+          std::string(batch2.data[i - 3000], static_cast<size_t>(batch2.length[i - 3000])));
+      }
+      EXPECT_EQ(false, rowReader->next(*readBatch));
+      EXPECT_EQ(3500, rowReader->getRowNumber());
+    }
+  }
+
+} // namespace orc
diff --git a/c++/test/TestReader.cc b/c++/test/TestReader.cc
index 68f6de1..e31ed11 100644
--- a/c++/test/TestReader.cc
+++ b/c++/test/TestReader.cc
@@ -17,6 +17,7 @@
*/
#include "orc/Reader.hh"
+#include "Reader.hh"
#include "Adaptor.hh"
@@ -46,4 +47,58 @@ namespace orc {
compressionKindToString(static_cast<CompressionKind>(99)));
}
+  // Checks RowReaderImpl::computeBatchSize against a stripe of 850 rows
+  // (8 full row groups of 100 rows plus one of 50) where only row groups
+  // 2, 3, 6 and 7 are selected.
+  TEST(TestRowReader, computeBatchSize) {
+    const uint64_t stride = 100;
+    const uint64_t stripeRows = 100 * 8 + 50;
+    const std::vector<bool> selected =
+      { false, false, true, true, false, false, true, true, false };
+
+    // each entry: requested size, current row in stripe, expected batch size
+    const uint64_t cases[][3] = {
+      { 1024,   0,   0 },
+      { 1024,  50,   0 },
+      { 1024, 200, 200 },
+      { 1024, 250, 150 },
+      { 1024, 550,   0 },
+      { 1024, 700, 100 },
+      {   50, 700,  50 },
+      {   50, 810,   0 },
+      {   50, 900,   0 }
+    };
+
+    for (const auto& c : cases) {
+      EXPECT_EQ(c[2], RowReaderImpl::computeBatchSize(
+        c[0], c[1], stripeRows, stride, selected));
+    }
+  }
+
+  // Checks RowReaderImpl::advanceToNextRowGroup against a stripe of 850 rows
+  // (8 full row groups of 100 rows plus one of 50) where only row groups
+  // 2, 3, 6 and 7 are selected.
+  TEST(TestRowReader, advanceToNextRowGroup) {
+    const uint64_t stride = 100;
+    const uint64_t stripeRows = 100 * 8 + 50;
+    const std::vector<bool> selected =
+      { false, false, true, true, false, false, true, true, false };
+
+    // each entry: current row in stripe, expected next row to read
+    const uint64_t cases[][2] = {
+      {   0, 200 },
+      { 150, 200 },
+      { 250, 250 },
+      { 350, 350 },
+      { 350, 350 },
+      { 500, 600 },
+      { 699, 699 },
+      { 799, 799 },
+      { 800, 850 },
+      { 900, 850 }
+    };
+
+    for (const auto& c : cases) {
+      EXPECT_EQ(c[1], RowReaderImpl::advanceToNextRowGroup(
+        c[0], stripeRows, stride, selected));
+    }
+  }
+
} // namespace