You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/08/25 15:59:50 UTC

[arrow] branch master updated: ARROW-6238: [C++][Dataset] Implement SimpleDataSource, SimpleDataFragment and SimpleScanTask

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new e0fa3d1  ARROW-6238: [C++][Dataset] Implement SimpleDataSource, SimpleDataFragment and SimpleScanTask
e0fa3d1 is described below

commit e0fa3d19a1bce02bd39ccb88b0721db162f5d9ba
Author: François Saint-Jacques <fs...@gmail.com>
AuthorDate: Sun Aug 25 10:59:35 2019 -0500

    ARROW-6238: [C++][Dataset] Implement SimpleDataSource, SimpleDataFragment and SimpleScanTask
    
    The Simple* family of classes are iterators backed by explicit vectors. This can be useful to represent an in-memory data source that rarely changes, e.g. a constant join table.
    
    - SimpleDataSource is backed by a vector<DataFragment>.
    - SimpleDataFragment is backed by a vector<RecordBatch>.
    - SimpleScanTask is backed by a vector<RecordBatch>.
    
    Closes #5140 from fsaintjacques/ARROW-6238-simple-datasource-datafragment and squashes the following commits:
    
    2d415666c <François Saint-Jacques> Address comments
    3e6c2b735 <François Saint-Jacques> Linter
    262d18eb8 <François Saint-Jacques> Reformat and lint
    964c799f0 <François Saint-Jacques> Improve ArrowBaseFixtureMixin to generate according to schema
    bce55137d <François Saint-Jacques> ARROW-6238:  Implements SimpleDataSource, SimpleDataFragment and SimpleScanTask
    
    Authored-by: François Saint-Jacques <fs...@gmail.com>
    Signed-off-by: Wes McKinney <we...@apache.org>
---
 cpp/src/arrow/dataset/CMakeLists.txt  | 10 +++++-
 cpp/src/arrow/dataset/dataset.cc      | 50 ++++++++++++++++++++++++++
 cpp/src/arrow/dataset/dataset.h       | 23 ++++++++++++
 cpp/src/arrow/dataset/dataset_test.cc | 62 ++++++++++++++++++++++++++++++++
 cpp/src/arrow/dataset/scanner.cc      |  8 ++++-
 cpp/src/arrow/dataset/scanner.h       | 13 +++++++
 cpp/src/arrow/dataset/test_util.h     | 44 +++++++++++++++++++++++
 cpp/src/arrow/testing/gtest_util.h    | 67 +++++++++++++++++++++++++++++++++++
 cpp/src/arrow/testing/util.h          | 22 ++++++++++++
 cpp/src/arrow/util/iterator.h         | 33 +++++++++++++++--
 10 files changed, 328 insertions(+), 4 deletions(-)

diff --git a/cpp/src/arrow/dataset/CMakeLists.txt b/cpp/src/arrow/dataset/CMakeLists.txt
index 923df33..494fa36 100644
--- a/cpp/src/arrow/dataset/CMakeLists.txt
+++ b/cpp/src/arrow/dataset/CMakeLists.txt
@@ -23,7 +23,7 @@ arrow_install_all_headers("arrow/dataset")
 # pkg-config support
 arrow_add_pkg_config("arrow-dataset")
 
-set(ARROW_DATASET_SRCS scanner.cc file_base.cc)
+set(ARROW_DATASET_SRCS dataset.cc file_base.cc scanner.cc)
 set(ARROW_DATASET_LINK_STATIC arrow_static)
 set(ARROW_DATASET_LINK_SHARED arrow_shared)
 
@@ -57,6 +57,14 @@ foreach(LIB_TARGET ${ARROW_DATASET_LIBRARIES})
 endforeach()
 
 if(NOT WIN32)
+  add_arrow_test(dataset_test
+                 EXTRA_LINK_LIBS
+                 ${ARROW_DATASET_TEST_LINK_LIBS}
+                 PREFIX
+                 "arrow-dataset"
+                 LABELS
+                 "arrow_dataset")
+
   add_arrow_test(file_test
                  EXTRA_LINK_LIBS
                  ${ARROW_DATASET_TEST_LINK_LIBS}
diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
new file mode 100644
index 0000000..403db33
--- /dev/null
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/dataset.h"
+
+#include <memory>
+#include <utility>
+
+#include "arrow/dataset/scanner.h"
+#include "arrow/util/stl.h"
+
+namespace arrow {
+namespace dataset {
+
+SimpleDataFragment::SimpleDataFragment(
+    std::vector<std::shared_ptr<RecordBatch>> record_batches)
+    : record_batches_(std::move(record_batches)) {}
+
+Status SimpleDataFragment::Scan(std::shared_ptr<ScanContext> scan_context,
+                                std::unique_ptr<ScanTaskIterator>* out) {
+  // Each RecordBatch becomes its own single-batch SimpleScanTask.
+  auto to_scan_task =
+      [](std::shared_ptr<RecordBatch> batch) -> std::unique_ptr<ScanTask> {
+    std::vector<std::shared_ptr<RecordBatch>> single;
+    single.push_back(std::move(batch));
+    return internal::make_unique<SimpleScanTask>(std::move(single));
+  };
+
+  // MakeIterator takes its vector by value, so the iterator owns a copy of
+  // record_batches_ and this fragment can be scanned any number of times.
+  auto batches = MakeIterator(record_batches_);
+  *out = MakeMapIterator(to_scan_task, std::move(batches));
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace arrow
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index 4a780c0..e538522 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -19,6 +19,7 @@
 
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "arrow/dataset/type_fwd.h"
@@ -53,6 +54,23 @@ class ARROW_DS_EXPORT DataFragment {
   virtual ~DataFragment() = default;
 };
 
+/// \brief DataFragment backed by a fixed, in-memory list of RecordBatches.
+///
+/// Scan() yields one ScanTask per stored RecordBatch. The fragment is never
+/// splittable and carries no ScanOptions of its own.
+class ARROW_DS_EXPORT SimpleDataFragment : public DataFragment {
+ public:
+  /// \param[in] record_batches the batches this fragment will yield
+  explicit SimpleDataFragment(std::vector<std::shared_ptr<RecordBatch>> record_batches);
+
+  /// \brief Always false: the stored batches are not subdivided.
+  bool splittable() const override { return false; }
+
+  /// \brief No fragment-specific options; always returns null.
+  std::shared_ptr<ScanOptions> scan_options() const override { return NULLPTR; }
+
+  Status Scan(std::shared_ptr<ScanContext> scan_context,
+              std::unique_ptr<ScanTaskIterator>* out) override;
+
+ protected:
+  std::vector<std::shared_ptr<RecordBatch>> record_batches_;
+};
+
 /// \brief A basic component of a Dataset which yields zero or more
 /// DataFragments. A DataSource acts as a discovery mechanism of DataFragments
 /// and partitions, e.g. files deeply nested in a directory.
@@ -71,11 +89,16 @@ class ARROW_DS_EXPORT DataSource {
 /// \brief A DataSource consisting of a flat sequence of DataFragments
 class ARROW_DS_EXPORT SimpleDataSource : public DataSource {
  public:
+  /// \brief Construct from an explicit list of fragments; the vector is taken
+  /// by value and moved into the source.
+  explicit SimpleDataSource(DataFragmentVector fragments)
+      : fragments_(std::move(fragments)) {}
+
+  /// \brief Yield every stored fragment. `options` is currently not consulted.
   std::unique_ptr<DataFragmentIterator> GetFragments(
       std::shared_ptr<ScanOptions> options) override {
     return MakeIterator(fragments_);
   }
 
+  /// \brief Returns the fixed type name "simple_data_source".
+  std::string type() const override { return "simple_data_source"; }
+
  private:
   DataFragmentVector fragments_;
 };
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
new file mode 100644
index 0000000..e3dfb34
--- /dev/null
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/dataset.h"
+
+#include "arrow/dataset/test_util.h"
+
+namespace arrow {
+namespace dataset {
+
+class TestSimpleDataFragment : public TestDataFragmentMixin {};
+
+TEST_F(TestSimpleDataFragment, Scan) {
+  constexpr int64_t kBatchSize = 1024;
+  constexpr int64_t kNumberBatches = 16;
+
+  auto s = schema({field("i32", int32()), field("f64", float64())});
+  auto batch = GetRecordBatch(kBatchSize, s);
+  auto reader = GetRecordBatchReader(kNumberBatches, batch);
+
+  // Creates a SimpleDataFragment of the same repeated batch. Parentheses (not
+  // braces) make it explicit that this is the (count, value) constructor and
+  // not a two-element initializer list.
+  std::vector<std::shared_ptr<RecordBatch>> batches(kNumberBatches, batch);
+  SimpleDataFragment fragment(std::move(batches));
+
+  AssertFragmentEquals(reader.get(), &fragment);
+
+  // The fragment must yield every expected batch, not merely a prefix.
+  std::shared_ptr<RecordBatch> leftover;
+  ASSERT_OK(reader->ReadNext(&leftover));
+  ASSERT_EQ(leftover, nullptr);
+}
+
+class TestSimpleDataSource : public TestDataSourceMixin {};
+
+TEST_F(TestSimpleDataSource, GetFragments) {
+  constexpr int64_t kNumberFragments = 4;
+  constexpr int64_t kBatchSize = 1024;
+  constexpr int64_t kNumberBatches = 16;
+
+  auto s = schema({field("i32", int32()), field("f64", float64())});
+  auto batch = GetRecordBatch(kBatchSize, s);
+  auto reader = GetRecordBatchReader(kNumberBatches * kNumberFragments, batch);
+
+  // Parentheses make both vectors unambiguously (count, value) constructions
+  // rather than two-element initializer lists.
+  std::vector<std::shared_ptr<RecordBatch>> batches(kNumberBatches, batch);
+  auto fragment = std::make_shared<SimpleDataFragment>(batches);
+  // It is safe to share one fragment several times since Scan() does not
+  // consume its internal vector.
+  DataFragmentVector fragments(kNumberFragments, fragment);
+  SimpleDataSource source(std::move(fragments));
+
+  AssertDataSourceEquals(reader.get(), &source);
+
+  // All fragments together must yield every expected batch, not a prefix.
+  std::shared_ptr<RecordBatch> leftover;
+  ASSERT_OK(reader->ReadNext(&leftover));
+  ASSERT_EQ(leftover, nullptr);
+}
+
+}  // namespace dataset
+}  // namespace arrow
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index ad80264..110d726 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -18,5 +18,11 @@
 #include "arrow/dataset/scanner.h"
 
 namespace arrow {
-namespace dataset {}  // namespace dataset
+namespace dataset {
+
+std::unique_ptr<RecordBatchIterator> SimpleScanTask::Scan() {
+  // MakeIterator takes the vector by value; hand it an explicit snapshot so
+  // the task keeps its batches and can be scanned again.
+  std::vector<std::shared_ptr<RecordBatch>> snapshot = record_batches_;
+  return MakeIterator(std::move(snapshot));
+}
+
+}  // namespace dataset
 }  // namespace arrow
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 7c83fc4..f942da7 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -19,6 +19,7 @@
 
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "arrow/dataset/type_fwd.h"
@@ -63,6 +64,18 @@ class ARROW_DS_EXPORT ScanTask {
   virtual ~ScanTask() = default;
 };
 
+/// \brief ScanTask that replays a fixed, in-memory vector of RecordBatches.
+///
+/// Each call to Scan() iterates over a fresh copy of the stored vector, so
+/// the task may be scanned repeatedly.
+class ARROW_DS_EXPORT SimpleScanTask : public ScanTask {
+ public:
+  /// \param[in] record_batches the batches yielded by Scan()
+  explicit SimpleScanTask(std::vector<std::shared_ptr<RecordBatch>> record_batches)
+      : record_batches_(std::move(record_batches)) {}
+
+  /// \brief Return an iterator over a copy of the stored batches.
+  std::unique_ptr<RecordBatchIterator> Scan() override;
+
+ protected:
+  std::vector<std::shared_ptr<RecordBatch>> record_batches_;
+};
+
 /// \brief Main interface for
 class ARROW_DS_EXPORT Scanner {
  public:
diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h
index 4f4d081..5835bf8 100644
--- a/cpp/src/arrow/dataset/test_util.h
+++ b/cpp/src/arrow/dataset/test_util.h
@@ -133,6 +133,50 @@ class GeneratedRecordBatch : public RecordBatchReader {
   Gen gen_;
 };
 
+/// \brief Base test fixture holding the ScanOptions and ScanContext handed to
+/// the code under test.
+class DatasetFixtureMixin : public ::testing::Test {
+ protected:
+  // Left null; passed as-is to DataSource::GetFragments.
+  std::shared_ptr<ScanOptions> options_;
+  // A default-constructed context, created per fixture instance.
+  std::shared_ptr<ScanContext> ctx_ = std::make_shared<ScanContext>();
+};
+
+class TestDataFragmentMixin : public DatasetFixtureMixin {
+ public:
+  /// \brief Ensure that the record batches yielded by the data fragment are
+  /// equal to the next record batches read from `expected`.
+  ///
+  /// `expected` is only advanced, not required to be exhausted, so the same
+  /// reader can be checked against several fragments in sequence.
+  void AssertFragmentEquals(RecordBatchReader* expected, DataFragment* fragment) {
+    std::unique_ptr<ScanTaskIterator> it;
+    // Fatal assertion: if Scan() fails, `it` may be null and must not be used.
+    ASSERT_OK(fragment->Scan(ctx_, &it));
+
+    ARROW_EXPECT_OK(it->Visit([expected](std::unique_ptr<ScanTask> task) -> Status {
+      auto batch_it = task->Scan();
+      return batch_it->Visit([expected](std::shared_ptr<RecordBatch> rhs) -> Status {
+        std::shared_ptr<RecordBatch> lhs;
+        RETURN_NOT_OK(expected->ReadNext(&lhs));
+        // EXPECT_* would not stop execution; report the mismatch as an error
+        // instead of dereferencing a null batch below.
+        if (lhs == nullptr) {
+          return Status::Invalid("Fragment yielded more batches than expected");
+        }
+        AssertBatchesEqual(*lhs, *rhs);
+        return Status::OK();
+      });
+    }));
+  }
+};
+
+class TestDataSourceMixin : public TestDataFragmentMixin {
+ public:
+  /// \brief Check every fragment of `source`, in the order GetFragments()
+  /// yields them, against consecutive record batches read from `expected`.
+  void AssertDataSourceEquals(RecordBatchReader* expected, DataSource* source) {
+    ARROW_EXPECT_OK(source->GetFragments(options_)->Visit(
+        [this, expected](std::shared_ptr<DataFragment> fragment) -> Status {
+          AssertFragmentEquals(expected, fragment.get());
+          return Status::OK();
+        }));
+  }
+};
+
 template <typename Gen>
 std::unique_ptr<GeneratedRecordBatch<Gen>> MakeGeneratedRecordBatch(
     std::shared_ptr<Schema> schema, Gen&& gen) {
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index 13d1c3d..34fcfb8 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -33,6 +33,7 @@
 #include "arrow/builder.h"
 #include "arrow/result.h"
 #include "arrow/status.h"
+#include "arrow/testing/util.h"
 #include "arrow/type_fwd.h"
 #include "arrow/type_traits.h"
 #include "arrow/util/bit_util.h"
@@ -345,4 +346,70 @@ void AssertSortedEquals(std::vector<T> u, std::vector<T> v) {
   ASSERT_EQ(u, v);
 }
 
+// Helpers used to generate trivial Array/RecordBatch/RecordBatchReader.
+//
+// GetArray builds an array of `size` elements of `type`, every element set to
+// `value`. T must have a CTypeTraits specialization providing BuilderType.
+template <typename T>
+static inline std::shared_ptr<Array> GetArray(int64_t size,
+                                              std::shared_ptr<DataType> type,
+                                              T value = 0) {
+  using BuilderType = typename CTypeTraits<T>::BuilderType;
+  // Capture `value` so the caller-supplied fill value is actually appended.
+  auto builder_fn = [value](BuilderType* builder) { builder->UnsafeAppend(value); };
+  ASSERT_OK_AND_ASSIGN(auto array, ArrayFromBuilderVisitor(type, size, builder_fn));
+  return array;
+}
+
+// Returns an all-zero (all-false for BOOL) array of `size` elements for a
+// primitive numeric or boolean `type`; aborts the process for any other type.
+static inline std::shared_ptr<Array> GetArrayOfZeros(int64_t size,
+                                                     std::shared_ptr<DataType> type) {
+  switch (type->id()) {
+    case Type::BOOL:
+      return GetArray<bool>(size, type, false);
+    case Type::UINT8:
+      return GetArray<uint8_t>(size, type, 0);
+    case Type::UINT16:
+      return GetArray<uint16_t>(size, type, 0);
+    case Type::UINT32:
+      return GetArray<uint32_t>(size, type, 0);
+    case Type::UINT64:
+      return GetArray<uint64_t>(size, type, 0);
+    case Type::INT8:
+      return GetArray<int8_t>(size, type, 0);
+    case Type::INT16:
+      return GetArray<int16_t>(size, type, 0);
+    case Type::INT32:
+      return GetArray<int32_t>(size, type, 0);
+    case Type::INT64:
+      return GetArray<int64_t>(size, type, 0);
+    case Type::FLOAT:
+      return GetArray<float>(size, type, 0.0f);
+    case Type::DOUBLE:
+      return GetArray<double>(size, type, 0.0);
+    default:
+      break;
+  }
+  // Unsupported type id: there is no zero value to build, so fail hard.
+  std::abort();
+}
+
+// Builds a RecordBatch of `size` rows matching `schema`, with every column
+// produced by GetArrayOfZeros.
+static inline std::shared_ptr<RecordBatch> GetRecordBatch(
+    int64_t size, std::shared_ptr<Schema> schema) {
+  std::vector<std::shared_ptr<Array>> columns;
+  columns.reserve(schema->fields().size());
+  for (const auto& f : schema->fields()) {
+    columns.push_back(GetArrayOfZeros(size, f->type()));
+  }
+  return RecordBatch::Make(schema, size, columns);
+}
+
+// Returns a reader yielding `batch` `n_batch` times (none if n_batch <= 0),
+// then a null batch.
+static inline std::shared_ptr<RecordBatchReader> GetRecordBatchReader(
+    int64_t n_batch, std::shared_ptr<RecordBatch> batch) {
+  std::shared_ptr<RecordBatchReader> reader =
+      std::make_shared<RepeatedRecordBatch>(n_batch, std::move(batch));
+  return reader;
+}
+
+// Convenience overload: builds one all-zero batch of `batch_size` rows for
+// `schema` and repeats it `n_batch` times.
+static inline std::shared_ptr<RecordBatchReader> GetRecordBatchReader(
+    int64_t n_batch, int64_t batch_size, std::shared_ptr<Schema> schema) {
+  return GetRecordBatchReader(n_batch, GetRecordBatch(batch_size, std::move(schema)));
+}
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/testing/util.h b/cpp/src/arrow/testing/util.h
index 801f863..d5f4dca 100644
--- a/cpp/src/arrow/testing/util.h
+++ b/cpp/src/arrow/testing/util.h
@@ -177,4 +177,26 @@ Result<std::shared_ptr<Array>> ArrayFromBuilderVisitor(
   return ArrayFromBuilderVisitor(type, length, length, std::forward<Fn>(fn));
 }
 
+/// \brief RecordBatchReader that returns the same RecordBatch a fixed number
+/// of times, then a null batch on every subsequent ReadNext().
+class RepeatedRecordBatch : public RecordBatchReader {
+ public:
+  RepeatedRecordBatch(int64_t repetitions, std::shared_ptr<RecordBatch> batch)
+      : repetitions_(repetitions), batch_(std::move(batch)) {}
+
+  std::shared_ptr<Schema> schema() const override { return batch_->schema(); }
+
+  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+    if (repetitions_ <= 0) {
+      // Exhausted: report end of stream with a null batch.
+      batch->reset();
+      return Status::OK();
+    }
+    --repetitions_;
+    *batch = batch_;
+    return Status::OK();
+  }
+
+ private:
+  int64_t repetitions_;  // number of times batch_ is still to be returned
+  std::shared_ptr<RecordBatch> batch_;
+};
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/iterator.h b/cpp/src/arrow/util/iterator.h
index 0975e9d..78ab18b 100644
--- a/cpp/src/arrow/util/iterator.h
+++ b/cpp/src/arrow/util/iterator.h
@@ -61,7 +61,7 @@ class Iterator {
   }
 };
 
-/// Simple iterator which yields the elements of a std::vector
+/// \brief Simple iterator which yields the elements of a std::vector
 template <typename T>
 class VectorIterator : public Iterator<T> {
  public:
@@ -78,8 +78,37 @@ class VectorIterator : public Iterator<T> {
 };
 
+/// \brief Return an Iterator yielding the elements of `v`. The vector is taken
+/// by value, so the iterator owns its own copy.
+///
+/// The result is typed as the Iterator<T> base (not VectorIterator<T>) so it
+/// can be passed directly to templates that deduce T from
+/// std::unique_ptr<Iterator<T>>, such as MakeMapIterator.
 template <typename T>
-std::unique_ptr<VectorIterator<T>> MakeIterator(std::vector<T> v) {
+std::unique_ptr<Iterator<T>> MakeIterator(std::vector<T> v) {
   return std::unique_ptr<VectorIterator<T>>(new VectorIterator<T>(std::move(v)));
 }
 
+/// \brief Iterator adaptor that owns an upstream Iterator<I> and yields the
+/// result of applying `map` to each upstream element.
+///
+/// A null upstream element is forwarded as a null output without calling
+/// `map`, so consumer loops that stop on null terminate. I must therefore be
+/// comparable with NULLPTR, and O assignable from it (e.g. smart pointers).
+template <typename Fn, typename I, typename O = typename std::result_of<Fn(I)>::type>
+class MapIterator : public Iterator<O> {
+ public:
+  explicit MapIterator(Fn map, std::unique_ptr<Iterator<I>> it)
+      : map_(std::move(map)), it_(std::move(it)) {}
+
+  Status Next(O* out) override {
+    I upstream;
+    ARROW_RETURN_NOT_OK(it_->Next(&upstream));
+
+    if (upstream == NULLPTR) {
+      *out = NULLPTR;
+    } else {
+      *out = map_(std::move(upstream));
+    }
+    return Status::OK();
+  }
+
+ private:
+  Fn map_;
+  std::unique_ptr<Iterator<I>> it_;
+};
+
+/// \brief Build a MapIterator, deducing I from `it` and O from `map`.
+template <typename Fn, typename I, typename O = typename std::result_of<Fn(I)>::type>
+std::unique_ptr<Iterator<O>> MakeMapIterator(Fn map, std::unique_ptr<Iterator<I>> it) {
+  using MapIteratorType = MapIterator<Fn, I, O>;
+  std::unique_ptr<Iterator<O>> mapped(new MapIteratorType(std::move(map), std::move(it)));
+  return mapped;
+}
+
 }  // namespace arrow