You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by bk...@apache.org on 2020/08/05 18:15:17 UTC
[arrow] branch master updated: ARROW-8002: [C++][Dataset][R] Support partitioned dataset writing
This is an automated email from the ASF dual-hosted git repository.
bkietz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new ef2ccfc ARROW-8002: [C++][Dataset][R] Support partitioned dataset writing
ef2ccfc is described below
commit ef2ccfc7620add5f365b4391dcf61d5c01da65a9
Author: Benjamin Kietzman <be...@gmail.com>
AuthorDate: Wed Aug 5 14:14:09 2020 -0400
ARROW-8002: [C++][Dataset][R] Support partitioned dataset writing
Closes #7869 from bkietz/8002-partitioned-write
Lead-authored-by: Benjamin Kietzman <be...@gmail.com>
Co-authored-by: Neal Richardson <ne...@gmail.com>
Signed-off-by: Benjamin Kietzman <be...@gmail.com>
---
c_glib/arrow-glib/compute.cpp | 3 +-
cpp/src/arrow/array/array_nested.cc | 17 +-
cpp/src/arrow/array/array_nested.h | 10 +-
cpp/src/arrow/array/array_view_test.cc | 11 +
cpp/src/arrow/compute/api_vector.cc | 9 +-
cpp/src/arrow/compute/api_vector.h | 4 +-
cpp/src/arrow/compute/kernels/vector_hash_test.cc | 15 +-
cpp/src/arrow/dataset/dataset_internal.h | 4 +-
cpp/src/arrow/dataset/dataset_test.cc | 2 +-
cpp/src/arrow/dataset/file_base.cc | 140 +++++----
cpp/src/arrow/dataset/file_base.h | 132 +-------
cpp/src/arrow/dataset/file_ipc.cc | 55 +---
cpp/src/arrow/dataset/file_ipc.h | 6 +-
cpp/src/arrow/dataset/file_ipc_test.cc | 353 +++++++++++++++++++---
cpp/src/arrow/dataset/filter.cc | 188 ++++++++++++
cpp/src/arrow/dataset/filter.h | 17 ++
cpp/src/arrow/dataset/filter_test.cc | 52 ++++
cpp/src/arrow/dataset/partition.cc | 321 ++++----------------
cpp/src/arrow/dataset/partition.h | 27 +-
cpp/src/arrow/dataset/partition_test.cc | 181 +----------
cpp/src/arrow/dataset/scanner_internal.h | 37 +++
cpp/src/arrow/dataset/test_util.h | 8 +-
cpp/src/arrow/dataset/type_fwd.h | 3 -
cpp/src/arrow/record_batch.cc | 2 +-
cpp/src/arrow/record_batch.h | 7 +-
r/DESCRIPTION | 1 +
r/NAMESPACE | 2 +
r/R/arrowExports.R | 4 +
r/R/dataset-write.R | 82 +++++
r/R/dataset.R | 18 +-
r/R/schema.R | 24 ++
r/man/Dataset.Rd | 3 +
r/man/write_dataset.Rd | 42 +++
r/src/arrowExports.cpp | 22 ++
r/src/dataset.cpp | 15 +
r/tests/testthat/test-dataset.R | 101 +++++++
r/tests/testthat/test-schema.R | 18 ++
37 files changed, 1193 insertions(+), 743 deletions(-)
diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp
index 3e31899..20d910e 100644
--- a/c_glib/arrow-glib/compute.cpp
+++ b/c_glib/arrow-glib/compute.cpp
@@ -1246,7 +1246,8 @@ garrow_array_count_values(GArrowArray *array,
auto arrow_array = garrow_array_get_raw(array);
auto arrow_counted_values = arrow::compute::ValueCounts(arrow_array);
if (garrow::check(error, arrow_counted_values, "[array][count-values]")) {
- return GARROW_STRUCT_ARRAY(garrow_array_new_raw(&(*arrow_counted_values)));
+ std::shared_ptr<arrow::Array> arrow_counted_values_array = *arrow_counted_values;
+ return GARROW_STRUCT_ARRAY(garrow_array_new_raw(&arrow_counted_values_array));
} else {
return NULL;
}
diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc
index e8d4ed9..f1e8d32 100644
--- a/cpp/src/arrow/array/array_nested.cc
+++ b/cpp/src/arrow/array/array_nested.cc
@@ -101,9 +101,8 @@ Status CleanListOffsets(const Array& offsets, MemoryPool* pool,
}
template <typename TYPE>
-Result<std::shared_ptr<Array>> ListArrayFromArrays(const Array& offsets,
- const Array& values,
- MemoryPool* pool) {
+Result<std::shared_ptr<typename TypeTraits<TYPE>::ArrayType>> ListArrayFromArrays(
+ const Array& offsets, const Array& values, MemoryPool* pool) {
using offset_type = typename TYPE::offset_type;
using ArrayType = typename TypeTraits<TYPE>::ArrayType;
using OffsetArrowType = typename CTypeTraits<offset_type>::ArrowType;
@@ -238,15 +237,15 @@ void LargeListArray::SetData(const std::shared_ptr<ArrayData>& data) {
values_ = MakeArray(data_->child_data[0]);
}
-Result<std::shared_ptr<Array>> ListArray::FromArrays(const Array& offsets,
- const Array& values,
- MemoryPool* pool) {
+Result<std::shared_ptr<ListArray>> ListArray::FromArrays(const Array& offsets,
+ const Array& values,
+ MemoryPool* pool) {
return ListArrayFromArrays<ListType>(offsets, values, pool);
}
-Result<std::shared_ptr<Array>> LargeListArray::FromArrays(const Array& offsets,
- const Array& values,
- MemoryPool* pool) {
+Result<std::shared_ptr<LargeListArray>> LargeListArray::FromArrays(const Array& offsets,
+ const Array& values,
+ MemoryPool* pool) {
return ListArrayFromArrays<LargeListType>(offsets, values, pool);
}
diff --git a/cpp/src/arrow/array/array_nested.h b/cpp/src/arrow/array/array_nested.h
index e37c34b..e5a219c 100644
--- a/cpp/src/arrow/array/array_nested.h
+++ b/cpp/src/arrow/array/array_nested.h
@@ -103,7 +103,7 @@ class ARROW_EXPORT ListArray : public BaseListArray<ListType> {
/// \param[in] values Array containing list values
/// \param[in] pool MemoryPool in case new offsets array needs to be
/// allocated because of null values
- static Result<std::shared_ptr<Array>> FromArrays(
+ static Result<std::shared_ptr<ListArray>> FromArrays(
const Array& offsets, const Array& values,
MemoryPool* pool = default_memory_pool());
@@ -148,7 +148,7 @@ class ARROW_EXPORT LargeListArray : public BaseListArray<LargeListType> {
/// \param[in] values Array containing list values
/// \param[in] pool MemoryPool in case new offsets array needs to be
/// allocated because of null values
- static Result<std::shared_ptr<Array>> FromArrays(
+ static Result<std::shared_ptr<LargeListArray>> FromArrays(
const Array& offsets, const Array& values,
MemoryPool* pool = default_memory_pool());
@@ -310,8 +310,7 @@ class ARROW_EXPORT StructArray : public Array {
/// The length and data type are automatically inferred from the arguments.
/// There should be at least one child array.
static Result<std::shared_ptr<StructArray>> Make(
- const std::vector<std::shared_ptr<Array>>& children,
- const std::vector<std::string>& field_names,
+ const ArrayVector& children, const std::vector<std::string>& field_names,
std::shared_ptr<Buffer> null_bitmap = NULLPTR,
int64_t null_count = kUnknownNullCount, int64_t offset = 0);
@@ -321,8 +320,7 @@ class ARROW_EXPORT StructArray : public Array {
/// There should be at least one child array. This method does not
/// check that field types and child array types are consistent.
static Result<std::shared_ptr<StructArray>> Make(
- const std::vector<std::shared_ptr<Array>>& children,
- const std::vector<std::shared_ptr<Field>>& fields,
+ const ArrayVector& children, const FieldVector& fields,
std::shared_ptr<Buffer> null_bitmap = NULLPTR,
int64_t null_count = kUnknownNullCount, int64_t offset = 0);
diff --git a/cpp/src/arrow/array/array_view_test.cc b/cpp/src/arrow/array/array_view_test.cc
index 3aac62d..e73bbda 100644
--- a/cpp/src/arrow/array/array_view_test.cc
+++ b/cpp/src/arrow/array/array_view_test.cc
@@ -329,6 +329,17 @@ TEST(TestArrayView, FixedSizeListAsFlat) {
// XXX With nulls (currently fails)
}
+TEST(TestArrayView, FixedSizeListAsFixedSizeBinary) {
+ auto ty1 = fixed_size_list(int32(), 1);
+#if ARROW_LITTLE_ENDIAN
+ auto arr = ArrayFromJSON(ty1, "[[2020568934], [2054316386]]");
+#else
+ auto arr = ArrayFromJSON(ty1, "[[1718579064], [1650553466]]");
+#endif
+ auto expected = ArrayFromJSON(fixed_size_binary(4), R"(["foox", "barz"])");
+ CheckView(arr, expected);
+}
+
TEST(TestArrayView, SparseUnionAsStruct) {
auto child1 = ArrayFromJSON(int16(), "[0, -1, 42]");
auto child2 = ArrayFromJSON(int32(), "[0, 1069547520, -1071644672]");
diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc
index 9a36714..1f69728 100644
--- a/cpp/src/arrow/compute/api_vector.cc
+++ b/cpp/src/arrow/compute/api_vector.cc
@@ -21,13 +21,18 @@
#include <utility>
#include <vector>
+#include "arrow/array/array_nested.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/compute/exec.h"
#include "arrow/datum.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
+#include "arrow/util/checked_cast.h"
namespace arrow {
+
+using internal::checked_pointer_cast;
+
namespace compute {
// ----------------------------------------------------------------------
@@ -60,9 +65,9 @@ const char kCountsFieldName[] = "counts";
const int32_t kValuesFieldIndex = 0;
const int32_t kCountsFieldIndex = 1;
-Result<std::shared_ptr<Array>> ValueCounts(const Datum& value, ExecContext* ctx) {
+Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value, ExecContext* ctx) {
ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("value_counts", {value}, ctx));
- return result.make_array();
+ return checked_pointer_cast<StructArray>(result.make_array());
}
// ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h
index 3aa3434..de36202 100644
--- a/cpp/src/arrow/compute/api_vector.h
+++ b/cpp/src/arrow/compute/api_vector.h
@@ -202,8 +202,8 @@ ARROW_EXPORT extern const int32_t kCountsFieldIndex;
/// \since 1.0.0
/// \note API not yet finalized
ARROW_EXPORT
-Result<std::shared_ptr<Array>> ValueCounts(const Datum& value,
- ExecContext* ctx = NULLPTR);
+Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value,
+ ExecContext* ctx = NULLPTR);
/// \brief Dictionary-encode values in an array-like object
/// \param[in] data array-like input
diff --git a/cpp/src/arrow/compute/kernels/vector_hash_test.cc b/cpp/src/arrow/compute/kernels/vector_hash_test.cc
index 70ed84b..10562e9 100644
--- a/cpp/src/arrow/compute/kernels/vector_hash_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_hash_test.cc
@@ -82,9 +82,8 @@ void CheckValueCountsNull(const std::shared_ptr<DataType>& type) {
std::shared_ptr<Array> ex_values = ArrayFromJSON(type, "[]");
std::shared_ptr<Array> ex_counts = ArrayFromJSON(int64(), "[]");
- ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> result, ValueCounts(input));
- ASSERT_OK(result->ValidateFull());
- auto result_struct = std::dynamic_pointer_cast<StructArray>(result);
+ ASSERT_OK_AND_ASSIGN(auto result_struct, ValueCounts(input));
+ ASSERT_OK(result_struct->ValidateFull());
ASSERT_NE(result_struct->GetFieldByName(kValuesFieldName), nullptr);
// TODO: We probably shouldn't rely on value ordering.
ASSERT_ARRAYS_EQUAL(*ex_values, *result_struct->GetFieldByName(kValuesFieldName));
@@ -615,8 +614,7 @@ TEST_F(TestHashKernel, ChunkedArrayInvoke) {
std::vector<std::string> dict_values = {"foo", "bar", "baz", "quuux"};
auto ex_dict = _MakeArray<StringType, std::string>(type, dict_values, {});
- std::vector<int64_t> counts = {3, 2, 1, 1};
- auto ex_counts = _MakeArray<Int64Type, int64_t>(int64(), counts, {});
+ auto ex_counts = _MakeArray<Int64Type, int64_t>(int64(), {3, 2, 1, 1}, {});
ArrayVector arrays = {a1, a2};
auto carr = std::make_shared<ChunkedArray>(arrays);
@@ -636,10 +634,9 @@ TEST_F(TestHashKernel, ChunkedArrayInvoke) {
auto dict_carr = std::make_shared<ChunkedArray>(dict_arrays);
// Unique counts
- ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> counts_array, ValueCounts(carr));
- auto counts_struct = std::dynamic_pointer_cast<StructArray>(counts_array);
- ASSERT_ARRAYS_EQUAL(*ex_dict, *counts_struct->field(0));
- ASSERT_ARRAYS_EQUAL(*ex_counts, *counts_struct->field(1));
+ ASSERT_OK_AND_ASSIGN(auto counts, ValueCounts(carr));
+ ASSERT_ARRAYS_EQUAL(*ex_dict, *counts->field(0));
+ ASSERT_ARRAYS_EQUAL(*ex_counts, *counts->field(1));
// Dictionary encode
ASSERT_OK_AND_ASSIGN(Datum encoded_out, DictionaryEncode(carr));
diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h
index 40ffab5..489339e 100644
--- a/cpp/src/arrow/dataset/dataset_internal.h
+++ b/cpp/src/arrow/dataset/dataset_internal.h
@@ -23,6 +23,7 @@
#include <vector>
#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/file_base.h"
#include "arrow/dataset/type_fwd.h"
#include "arrow/record_batch.h"
#include "arrow/scalar.h"
@@ -51,7 +52,8 @@ inline FragmentIterator GetFragmentsFromDatasets(const DatasetVector& datasets,
return MakeFlattenIterator(std::move(fragments_it));
}
-inline RecordBatchIterator IteratorFromReader(std::shared_ptr<RecordBatchReader> reader) {
+inline RecordBatchIterator IteratorFromReader(
+ const std::shared_ptr<RecordBatchReader>& reader) {
return MakeFunctionIterator([reader] { return reader->Next(); });
}
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
index e9430aa..7a378cd 100644
--- a/cpp/src/arrow/dataset/dataset_test.cc
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -385,7 +385,7 @@ class TestEndToEnd : public TestUnionDataset {
auto mock_fs = std::make_shared<fs::internal::MockFileSystem>(fs::kNoTime);
for (const auto& f : files) {
- ARROW_EXPECT_OK(mock_fs->CreateFile(f.first, f.second, /* recursive */ true));
+ ARROW_EXPECT_OK(mock_fs->CreateFile(f.first, f.second, /*recursive=*/true));
}
fs_ = mock_fs;
diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index 806c625..caa44c6 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -23,37 +23,31 @@
#include "arrow/dataset/dataset_internal.h"
#include "arrow/dataset/filter.h"
#include "arrow/dataset/scanner.h"
+#include "arrow/dataset/scanner_internal.h"
#include "arrow/filesystem/filesystem.h"
#include "arrow/filesystem/localfs.h"
#include "arrow/filesystem/path_util.h"
#include "arrow/io/interfaces.h"
#include "arrow/io/memory.h"
#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
#include "arrow/util/task_group.h"
namespace arrow {
namespace dataset {
-Result<std::shared_ptr<arrow::io::RandomAccessFile>> FileSource::Open() const {
+Result<std::shared_ptr<io::RandomAccessFile>> FileSource::Open() const {
if (filesystem_) {
return filesystem_->OpenInputFile(file_info_);
}
if (buffer_) {
- return std::make_shared<::arrow::io::BufferReader>(buffer_);
+ return std::make_shared<io::BufferReader>(buffer_);
}
return custom_open_();
}
-Result<std::shared_ptr<arrow::io::OutputStream>> WritableFileSource::Open() const {
- if (filesystem_) {
- return filesystem_->OpenOutputStream(path_);
- }
-
- return std::make_shared<::arrow::io::BufferOutputStream>(buffer_);
-}
-
Result<std::shared_ptr<FileFragment>> FileFormat::MakeFragment(
FileSource source, std::shared_ptr<Schema> physical_schema) {
return MakeFragment(std::move(source), scalar(true), std::move(physical_schema));
@@ -71,11 +65,8 @@ Result<std::shared_ptr<FileFragment>> FileFormat::MakeFragment(
new FileFragment(std::move(source), shared_from_this(),
std::move(partition_expression), std::move(physical_schema)));
}
-
-Result<std::shared_ptr<WriteTask>> FileFormat::WriteFragment(
- WritableFileSource destination, std::shared_ptr<Fragment> fragment,
- std::shared_ptr<ScanOptions> scan_options,
- std::shared_ptr<ScanContext> scan_context) {
+Status FileFormat::WriteFragment(RecordBatchReader* batches,
+ io::OutputStream* destination) {
return Status::NotImplemented("writing fragment of format ", type_name());
}
@@ -154,52 +145,99 @@ FragmentIterator FileSystemDataset::GetFragmentsImpl(
return MakeVectorIterator(std::move(fragments));
}
-Result<std::shared_ptr<FileSystemDataset>> FileSystemDataset::Write(
- const WritePlan& plan, std::shared_ptr<ScanOptions> scan_options,
- std::shared_ptr<ScanContext> scan_context) {
- auto filesystem = plan.filesystem;
- if (filesystem == nullptr) {
- filesystem = std::make_shared<fs::LocalFileSystem>();
- }
+struct WriteTask {
+ Status Execute();
- auto task_group = scan_context->TaskGroup();
- auto partition_base_dir = fs::internal::EnsureTrailingSlash(plan.partition_base_dir);
- auto extension = "." + plan.format->type_name();
-
- std::vector<std::shared_ptr<FileFragment>> fragments;
- for (size_t i = 0; i < plan.paths.size(); ++i) {
- const auto& op = plan.fragment_or_partition_expressions[i];
- if (op.kind() == WritePlan::FragmentOrPartitionExpression::FRAGMENT) {
- auto path = partition_base_dir + plan.paths[i] + extension;
-
- const auto& input_fragment = op.fragment();
- FileSource dest(path, filesystem);
-
- ARROW_ASSIGN_OR_RAISE(auto write_task,
- plan.format->WriteFragment({path, filesystem}, input_fragment,
- scan_options, scan_context));
- task_group->Append([write_task] { return write_task->Execute(); });
-
- ARROW_ASSIGN_OR_RAISE(
- auto fragment, plan.format->MakeFragment(
- {path, filesystem}, input_fragment->partition_expression()));
- fragments.push_back(std::move(fragment));
+ /// The basename of files written by this WriteTask. Extensions
+ /// are derived from format
+ std::string basename;
+
+ /// The partitioning with which paths will be generated
+ std::shared_ptr<Partitioning> partitioning;
+
+ /// The format in which fragments will be written
+ std::shared_ptr<FileFormat> format;
+
+ /// The FileSystem and base directory into which fragments will be written
+ std::shared_ptr<fs::FileSystem> filesystem;
+ std::string base_dir;
+
+ /// Batches to be written
+ std::shared_ptr<RecordBatchReader> batches;
+
+ /// An Expression already satisfied by every batch to be written
+ std::shared_ptr<Expression> partition_expression;
+};
+
+Status WriteTask::Execute() {
+ std::unordered_map<std::string, RecordBatchVector> path_to_batches;
+
+ // TODO(bkietz) these calls to Partition() should be scattered across a TaskGroup
+ for (auto maybe_batch : IteratorFromReader(batches)) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, std::move(maybe_batch));
+ ARROW_ASSIGN_OR_RAISE(auto partitioned_batches, partitioning->Partition(batch));
+ for (auto&& partitioned_batch : partitioned_batches) {
+ AndExpression expr(std::move(partitioned_batch.partition_expression),
+ partition_expression);
+ ARROW_ASSIGN_OR_RAISE(std::string path, partitioning->Format(expr));
+ path = fs::internal::EnsureLeadingSlash(path);
+ path_to_batches[path].push_back(std::move(partitioned_batch.batch));
}
}
- RETURN_NOT_OK(task_group->Finish());
+ for (auto&& path_batches : path_to_batches) {
+ auto dir = base_dir + path_batches.first;
+ RETURN_NOT_OK(filesystem->CreateDir(dir, /*recursive=*/true));
- return Make(plan.schema, scalar(true), plan.format, fragments);
-}
+ auto path = fs::internal::ConcatAbstractPath(dir, basename);
+ ARROW_ASSIGN_OR_RAISE(auto destination, filesystem->OpenOutputStream(path));
-Status WriteTask::CreateDestinationParentDir() const {
- if (auto filesystem = destination_.filesystem()) {
- auto parent = fs::internal::GetAbstractPathParent(destination_.path()).first;
- return filesystem->CreateDir(parent, /* recursive = */ true);
+ DCHECK(!path_batches.second.empty());
+ ARROW_ASSIGN_OR_RAISE(auto reader,
+ RecordBatchReader::Make(std::move(path_batches.second)));
+ RETURN_NOT_OK(format->WriteFragment(reader.get(), destination.get()));
}
return Status::OK();
}
+Status FileSystemDataset::Write(std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileFormat> format,
+ std::shared_ptr<fs::FileSystem> filesystem,
+ std::string base_dir,
+ std::shared_ptr<Partitioning> partitioning,
+ std::shared_ptr<ScanContext> scan_context,
+ FragmentIterator fragment_it) {
+ auto task_group = scan_context->TaskGroup();
+
+ base_dir = fs::internal::RemoveTrailingSlash(base_dir).to_string();
+
+ for (const auto& f : partitioning->schema()->fields()) {
+ if (f->type()->id() == Type::DICTIONARY) {
+ return Status::NotImplemented("writing with dictionary partitions");
+ }
+ }
+
+ int i = 0;
+ for (auto maybe_fragment : fragment_it) {
+ ARROW_ASSIGN_OR_RAISE(auto fragment, std::move(maybe_fragment));
+ auto task = std::make_shared<WriteTask>();
+
+ task->basename = "dat_" + std::to_string(i++) + "." + format->type_name();
+ task->partition_expression = fragment->partition_expression();
+ task->format = format;
+ task->filesystem = filesystem;
+ task->base_dir = base_dir;
+ task->partitioning = partitioning;
+
+ // make a record batch reader which yields from a fragment
+ ARROW_ASSIGN_OR_RAISE(task->batches, FragmentRecordBatchReader::Make(
+ std::move(fragment), schema, scan_context));
+ task_group->Append([task] { return task->Execute(); });
+ }
+
+ return task_group->Finish();
+}
+
} // namespace dataset
} // namespace arrow
diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h
index 246e71d..c64714a 100644
--- a/cpp/src/arrow/dataset/file_base.h
+++ b/cpp/src/arrow/dataset/file_base.h
@@ -117,45 +117,6 @@ class ARROW_DS_EXPORT FileSource {
Compression::type compression_ = Compression::UNCOMPRESSED;
};
-/// \brief The path and filesystem where an actual file is located or a buffer which can
-/// be written to like a file
-class ARROW_DS_EXPORT WritableFileSource {
- public:
- WritableFileSource(std::string path, std::shared_ptr<fs::FileSystem> filesystem,
- Compression::type compression = Compression::UNCOMPRESSED)
- : path_(std::move(path)),
- filesystem_(std::move(filesystem)),
- compression_(compression) {}
-
- explicit WritableFileSource(std::shared_ptr<ResizableBuffer> buffer,
- Compression::type compression = Compression::UNCOMPRESSED)
- : buffer_(std::move(buffer)), compression_(compression) {}
-
- /// \brief Return the type of raw compression on the file, if any
- Compression::type compression() const { return compression_; }
-
- /// \brief Return the file path, if any. Only valid when file source wraps a path.
- const std::string& path() const {
- static std::string buffer_path = "<Buffer>";
- return filesystem_ ? path_ : buffer_path;
- }
-
- /// \brief Return the filesystem, if any. Otherwise returns nullptr
- const std::shared_ptr<fs::FileSystem>& filesystem() const { return filesystem_; }
-
- /// \brief Return the buffer containing the file, if any. Otherwise returns nullptr
- const std::shared_ptr<ResizableBuffer>& buffer() const { return buffer_; }
-
- /// \brief Get an OutputStream which wraps this file source
- Result<std::shared_ptr<arrow::io::OutputStream>> Open() const;
-
- private:
- std::string path_;
- std::shared_ptr<fs::FileSystem> filesystem_;
- std::shared_ptr<ResizableBuffer> buffer_;
- Compression::type compression_ = Compression::UNCOMPRESSED;
-};
-
/// \brief Base class for file format implementation
class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this<FileFormat> {
public:
@@ -190,12 +151,9 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this<FileForma
Result<std::shared_ptr<FileFragment>> MakeFragment(
FileSource source, std::shared_ptr<Schema> physical_schema = NULLPTR);
- /// \brief Write a fragment. If the parent directory of destination does not exist, it
- /// will be created.
- virtual Result<std::shared_ptr<WriteTask>> WriteFragment(
- WritableFileSource destination, std::shared_ptr<Fragment> fragment,
- std::shared_ptr<ScanOptions> options,
- std::shared_ptr<ScanContext> scan_context); // FIXME(bkietz) make this pure virtual
+ /// \brief Write a fragment.
+ /// FIXME(bkietz) make this pure virtual
+ virtual Status WriteFragment(RecordBatchReader* batches, io::OutputStream* destination);
};
/// \brief A Fragment that is stored in a file with a known format
@@ -248,14 +206,20 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset {
std::shared_ptr<FileFormat> format,
std::vector<std::shared_ptr<FileFragment>> fragments);
- /// \brief Write to a new format and filesystem location, preserving partitioning.
+ /// \brief Write a dataset.
///
- /// \param[in] plan the WritePlan to execute.
- /// \param[in] scan_options options in which to scan fragments
- /// \param[in] scan_context context in which to scan fragments before writing.
- static Result<std::shared_ptr<FileSystemDataset>> Write(
- const WritePlan& plan, std::shared_ptr<ScanOptions> scan_options,
- std::shared_ptr<ScanContext> scan_context);
+ /// \param[in] schema Schema of written dataset.
+ /// \param[in] format FileFormat with which fragments will be written.
+ /// \param[in] filesystem FileSystem into which the dataset will be written.
+ /// \param[in] base_dir Root directory into which the dataset will be written.
+ /// \param[in] partitioning Partitioning used to generate fragment paths.
+ /// \param[in] scan_context Resource pool used to scan and write fragments.
+ /// \param[in] fragments Fragments to be written to disk.
+ static Status Write(std::shared_ptr<Schema> schema, std::shared_ptr<FileFormat> format,
+ std::shared_ptr<fs::FileSystem> filesystem, std::string base_dir,
+ std::shared_ptr<Partitioning> partitioning,
+ std::shared_ptr<ScanContext> scan_context,
+ FragmentIterator fragments);
/// \brief Return the type name of the dataset.
std::string type_name() const override { return "filesystem"; }
@@ -284,69 +248,5 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset {
std::vector<std::shared_ptr<FileFragment>> fragments_;
};
-/// \brief Write a fragment to a single OutputStream.
-class ARROW_DS_EXPORT WriteTask {
- public:
- virtual Status Execute() = 0;
-
- virtual ~WriteTask() = default;
-
- const WritableFileSource& destination() const;
- const std::shared_ptr<FileFormat>& format() const { return format_; }
-
- protected:
- WriteTask(WritableFileSource destination, std::shared_ptr<FileFormat> format)
- : destination_(std::move(destination)), format_(std::move(format)) {}
-
- Status CreateDestinationParentDir() const;
-
- WritableFileSource destination_;
- std::shared_ptr<FileFormat> format_;
-};
-
-/// \brief A declarative plan for writing fragments to a partitioned directory structure.
-class ARROW_DS_EXPORT WritePlan {
- public:
- /// The partitioning with which paths were generated
- std::shared_ptr<Partitioning> partitioning;
-
- /// The schema of the Dataset which will be written
- std::shared_ptr<Schema> schema;
-
- /// The format into which fragments will be written
- std::shared_ptr<FileFormat> format;
-
- /// The FileSystem and base directory for partitioned writing
- std::shared_ptr<fs::FileSystem> filesystem;
- std::string partition_base_dir;
-
- class FragmentOrPartitionExpression {
- public:
- enum Kind { EXPRESSION, FRAGMENT };
-
- explicit FragmentOrPartitionExpression(std::shared_ptr<Expression> partition_expr)
- : kind_(EXPRESSION), partition_expr_(std::move(partition_expr)) {}
-
- explicit FragmentOrPartitionExpression(std::shared_ptr<Fragment> fragment)
- : kind_(FRAGMENT), fragment_(std::move(fragment)) {}
-
- Kind kind() const { return kind_; }
-
- const std::shared_ptr<Expression>& partition_expr() const { return partition_expr_; }
- const std::shared_ptr<Fragment>& fragment() const { return fragment_; }
-
- private:
- Kind kind_;
- std::shared_ptr<Expression> partition_expr_;
- std::shared_ptr<Fragment> fragment_;
- };
-
- /// If fragment_or_partition_expressions[i] is a Fragment, that Fragment will be
- /// written to paths[i]. If it is an Expression, a directory representing that partition
- /// expression will be created at paths[i] instead.
- std::vector<FragmentOrPartitionExpression> fragment_or_partition_expressions;
- std::vector<std::string> paths;
-};
-
} // namespace dataset
} // namespace arrow
diff --git a/cpp/src/arrow/dataset/file_ipc.cc b/cpp/src/arrow/dataset/file_ipc.cc
index 63f3020..de25875b5 100644
--- a/cpp/src/arrow/dataset/file_ipc.cc
+++ b/cpp/src/arrow/dataset/file_ipc.cc
@@ -159,54 +159,17 @@ Result<ScanTaskIterator> IpcFileFormat::ScanFile(std::shared_ptr<ScanOptions> op
fragment->source());
}
-class IpcWriteTask : public WriteTask {
- public:
- IpcWriteTask(WritableFileSource destination, std::shared_ptr<FileFormat> format,
- std::shared_ptr<Fragment> fragment,
- std::shared_ptr<ScanOptions> scan_options,
- std::shared_ptr<ScanContext> scan_context)
- : WriteTask(std::move(destination), std::move(format)),
- fragment_(std::move(fragment)),
- scan_options_(std::move(scan_options)),
- scan_context_(std::move(scan_context)) {}
-
- Status Execute() override {
- RETURN_NOT_OK(CreateDestinationParentDir());
-
- auto schema = scan_options_->schema();
-
- ARROW_ASSIGN_OR_RAISE(auto out_stream, destination_.Open());
- ARROW_ASSIGN_OR_RAISE(auto writer, ipc::NewFileWriter(out_stream.get(), schema));
- ARROW_ASSIGN_OR_RAISE(auto scan_task_it,
- fragment_->Scan(scan_options_, scan_context_));
-
- for (auto maybe_scan_task : scan_task_it) {
- ARROW_ASSIGN_OR_RAISE(auto scan_task, maybe_scan_task);
-
- ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute());
-
- for (auto maybe_batch : batch_it) {
- ARROW_ASSIGN_OR_RAISE(auto batch, std::move(maybe_batch));
- RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
- }
- }
-
- return writer->Close();
+Status IpcFileFormat::WriteFragment(RecordBatchReader* batches,
+ io::OutputStream* destination) {
+ ARROW_ASSIGN_OR_RAISE(auto writer, ipc::NewFileWriter(destination, batches->schema()));
+
+ for (;;) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, batches->Next());
+ if (batch == nullptr) break;
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
}
- private:
- std::shared_ptr<Fragment> fragment_;
- std::shared_ptr<ScanOptions> scan_options_;
- std::shared_ptr<ScanContext> scan_context_;
-};
-
-Result<std::shared_ptr<WriteTask>> IpcFileFormat::WriteFragment(
- WritableFileSource destination, std::shared_ptr<Fragment> fragment,
- std::shared_ptr<ScanOptions> scan_options,
- std::shared_ptr<ScanContext> scan_context) {
- return std::make_shared<IpcWriteTask>(std::move(destination), shared_from_this(),
- std::move(fragment), std::move(scan_options),
- std::move(scan_context));
+ return writer->Close();
}
} // namespace dataset
diff --git a/cpp/src/arrow/dataset/file_ipc.h b/cpp/src/arrow/dataset/file_ipc.h
index ad18354..2b8b94d 100644
--- a/cpp/src/arrow/dataset/file_ipc.h
+++ b/cpp/src/arrow/dataset/file_ipc.h
@@ -47,10 +47,8 @@ class ARROW_DS_EXPORT IpcFileFormat : public FileFormat {
std::shared_ptr<ScanContext> context,
FileFragment* fragment) const override;
- Result<std::shared_ptr<WriteTask>> WriteFragment(
- WritableFileSource destination, std::shared_ptr<Fragment> fragment,
- std::shared_ptr<ScanOptions> options,
- std::shared_ptr<ScanContext> context) override;
+ Status WriteFragment(RecordBatchReader* batches,
+ io::OutputStream* destination) override;
};
} // namespace dataset
diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc
index c557621..e1b5740 100644
--- a/cpp/src/arrow/dataset/file_ipc_test.cc
+++ b/cpp/src/arrow/dataset/file_ipc_test.cc
@@ -22,6 +22,7 @@
#include <vector>
#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/discovery.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/filter.h"
#include "arrow/dataset/partition.h"
@@ -87,10 +88,10 @@ class TestIpcFileFormat : public ArrowIpcWriterMixin {
kBatchRepetitions);
}
- Result<WritableFileSource> GetFileSink() {
+ Result<std::shared_ptr<io::BufferOutputStream>> GetFileSink() {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> buffer,
AllocateResizableBuffer(0));
- return WritableFileSource(std::move(buffer));
+ return std::make_shared<io::BufferOutputStream>(buffer);
}
RecordBatchIterator Batches(ScanTaskIterator scan_task_it) {
@@ -149,63 +150,327 @@ TEST_F(TestIpcFileFormat, ScanRecordBatchReaderWithVirtualColumn) {
TEST_F(TestIpcFileFormat, WriteRecordBatchReader) {
std::shared_ptr<RecordBatchReader> reader = GetRecordBatchReader();
auto source = GetFileSource(reader.get());
+ reader = GetRecordBatchReader();
opts_ = ScanOptions::Make(reader->schema());
- ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source));
EXPECT_OK_AND_ASSIGN(auto sink, GetFileSink());
- EXPECT_OK_AND_ASSIGN(auto write_task,
- format_->WriteFragment(sink, fragment, opts_, ctx_));
+ ASSERT_OK(format_->WriteFragment(reader.get(), sink.get()));
- ASSERT_OK(write_task->Execute());
+ EXPECT_OK_AND_ASSIGN(auto written, sink->Finish());
- AssertBufferEqual(*sink.buffer(), *source->buffer());
+ AssertBufferEqual(*written, *source->buffer());
}
class TestIpcFileSystemDataset : public TestIpcFileFormat,
- public MakeFileSystemDatasetMixin {};
-
-TEST_F(TestIpcFileSystemDataset, Write) {
- std::string paths = R"(
- old_root/i32=0/str=aaa/dat
- old_root/i32=0/str=bbb/dat
- old_root/i32=0/str=ccc/dat
- old_root/i32=1/str=aaa/dat
- old_root/i32=1/str=bbb/dat
- old_root/i32=1/str=ccc/dat
- )";
-
- ExpressionVector partitions{
- ("i32"_ == 0 and "str"_ == "aaa").Copy(), ("i32"_ == 0 and "str"_ == "bbb").Copy(),
- ("i32"_ == 0 and "str"_ == "ccc").Copy(), ("i32"_ == 1 and "str"_ == "aaa").Copy(),
- ("i32"_ == 1 and "str"_ == "bbb").Copy(), ("i32"_ == 1 and "str"_ == "ccc").Copy(),
- };
-
- MakeDatasetFromPathlist(paths, scalar(true), partitions);
-
- auto schema = arrow::schema({field("i32", int32()), field("str", utf8())});
- opts_ = ScanOptions::Make(schema);
-
- auto partitioning_factory = DirectoryPartitioning::MakeFactory({"str", "i32"});
- ASSERT_OK_AND_ASSIGN(
- auto plan, partitioning_factory->MakeWritePlan(schema, dataset_->GetFragments()));
+ public MakeFileSystemDatasetMixin {
+ public:
+ using PathAndContent = std::unordered_map<std::string, std::string>;
+
+ /// Build `dataset_`: a hive-partitioned (year=/month=) JSON dataset stored on a
+ /// MockFileSystem under "/dataset". Selector entries starting with "." are
+ /// ignored during discovery, so "/dataset/.pesky" is expected to be skipped.
+ void SetUp() override {
+ PathAndContent source_files;
+
+ source_files["/dataset/year=2018/month=01/dat0.json"] = R"([
+ {"region": "NY", "model": "3", "sales": 742.0, "country": "US"},
+ {"region": "NY", "model": "S", "sales": 304.125, "country": "US"},
+ {"region": "NY", "model": "Y", "sales": 27.5, "country": "US"}
+ ])";
+ source_files["/dataset/year=2018/month=01/dat1.json"] = R"([
+ {"region": "QC", "model": "3", "sales": 512, "country": "CA"},
+ {"region": "QC", "model": "S", "sales": 978, "country": "CA"},
+ {"region": "NY", "model": "X", "sales": 136.25, "country": "US"},
+ {"region": "QC", "model": "X", "sales": 1.0, "country": "CA"},
+ {"region": "QC", "model": "Y", "sales": 69, "country": "CA"}
+ ])";
+ source_files["/dataset/year=2019/month=01/dat0.json"] = R"([
+ {"region": "CA", "model": "3", "sales": 273.5, "country": "US"},
+ {"region": "CA", "model": "S", "sales": 13, "country": "US"},
+ {"region": "CA", "model": "X", "sales": 54, "country": "US"},
+ {"region": "QC", "model": "S", "sales": 10, "country": "CA"},
+ {"region": "CA", "model": "Y", "sales": 21, "country": "US"}
+ ])";
+ source_files["/dataset/year=2019/month=01/dat1.json"] = R"([
+ {"region": "QC", "model": "3", "sales": 152.25, "country": "CA"},
+ {"region": "QC", "model": "X", "sales": 42, "country": "CA"},
+ {"region": "QC", "model": "Y", "sales": 37, "country": "CA"}
+ ])";
+ source_files["/dataset/.pesky"] = "garbage content";
+
+ auto mock_fs = std::make_shared<fs::internal::MockFileSystem>(fs::kNoTime);
+ for (const auto& f : source_files) {
+ ARROW_EXPECT_OK(mock_fs->CreateFile(f.first, f.second, /* recursive */ true));
+ }
+ fs_ = mock_fs;
+
+ /// schema for the whole dataset (both source and destination)
+ schema_ = schema({
+ field("region", utf8()),
+ field("model", utf8()),
+ field("sales", float64()),
+ field("year", int32()),
+ field("month", int32()),
+ field("country", utf8()),
+ });
+
+ /// Dummy file format for source dataset. Note that it isn't partitioned on country
+ auto source_format = std::make_shared<JSONRecordBatchFileFormat>(
+ SchemaFromColumnNames(schema_, {"region", "model", "sales", "country"}));
+
+ fs::FileSelector s;
+ s.base_dir = "/dataset";
+ s.recursive = true;
+
+ FileSystemFactoryOptions options;
+ options.selector_ignore_prefixes = {"."};
+ options.partitioning = HivePartitioning::MakeFactory();
+ ASSERT_OK_AND_ASSIGN(auto factory,
+ FileSystemDatasetFactory::Make(fs_, s, source_format, options));
+ ASSERT_OK_AND_ASSIGN(dataset_, factory->Finish());
+ }
+
+ /// Verify `written_` against the expectations set by the test:
+ /// - its file paths are exactly the keys of `expected_files_` (any order),
+ /// - every file's physical schema equals `expected_physical_schema_`,
+ /// - every file's rows equal the JSON stored for its path in `expected_files_`.
+ void AssertWrittenAsExpected() {
+ std::vector<std::string> files;
+ for (const auto& file_contents : expected_files_) {
+ files.push_back(file_contents.first);
+ }
+ EXPECT_THAT(checked_pointer_cast<FileSystemDataset>(written_)->files(),
+ testing::UnorderedElementsAreArray(files));
+
+ for (auto maybe_fragment : written_->GetFragments()) {
+ ASSERT_OK_AND_ASSIGN(auto fragment, std::move(maybe_fragment));
+
+ ASSERT_OK_AND_ASSIGN(auto actual_physical_schema, fragment->ReadPhysicalSchema());
+ AssertSchemaEqual(*expected_physical_schema_, *actual_physical_schema,
+ /*verbose=*/true);
+
+ const auto& path = checked_pointer_cast<FileFragment>(fragment)->source().path();
- plan.format = format_;
- plan.filesystem = fs_;
- plan.partition_base_dir = "new_root/";
+ auto expected_struct = ArrayFromJSON(struct_(expected_physical_schema_->fields()),
+ {expected_files_[path]});
- ASSERT_OK_AND_ASSIGN(auto written, FileSystemDataset::Write(plan, opts_, ctx_));
+ ASSERT_OK_AND_ASSIGN(auto scanner, ScannerBuilder(actual_physical_schema, fragment,
+ std::make_shared<ScanContext>())
+ .Finish());
+ ASSERT_OK_AND_ASSIGN(auto actual_table, scanner->ToTable());
+ ASSERT_OK_AND_ASSIGN(actual_table, actual_table->CombineChunks());
+ std::shared_ptr<Array> actual_struct;
- auto parent_directories = written->files();
- for (auto& path : parent_directories) {
- EXPECT_EQ(fs::internal::GetAbstractPathExtension(path), "ipc");
- path = fs::internal::GetAbstractPathParent(path).first;
+ // The table was combined into single chunks above, so this loop presumably
+ // yields one batch; actual_struct keeps the last batch seen.
+ for (auto maybe_batch :
+ IteratorFromReader(std::make_shared<TableBatchReader>(*actual_table))) {
+ ASSERT_OK_AND_ASSIGN(auto batch, std::move(maybe_batch));
+ ASSERT_OK_AND_ASSIGN(actual_struct, batch->ToStructArray());
+ }
+
+ AssertArraysEqual(*expected_struct, *actual_struct, /*verbose=*/true);
+ }
}
- EXPECT_THAT(parent_directories,
- testing::ElementsAre("new_root/aaa/0", "new_root/aaa/1", "new_root/bbb/0",
- "new_root/bbb/1", "new_root/ccc/0", "new_root/ccc/1"));
+ // Filled in by each test before calling AssertWrittenAsExpected():
+ // expected written path -> JSON rows, the schema each written file should
+ // have, and the dataset rediscovered from the written files.
+ PathAndContent expected_files_;
+ std::shared_ptr<Schema> expected_physical_schema_;
+ std::shared_ptr<Dataset> written_;
+};
+
+TEST_F(TestIpcFileSystemDataset, WriteWithIdenticalPartitioningSchema) {
+  // Write using the same (year, month) keys the source dataset is partitioned on.
+  auto partitioning = std::make_shared<DirectoryPartitioning>(
+      SchemaFromColumnNames(schema_, {"year", "month"}));
+  auto scan_context = std::make_shared<ScanContext>();
+  ASSERT_OK(FileSystemDataset::Write(schema_, format_, fs_, "new_root/", partitioning,
+                                     scan_context, dataset_->GetFragments()));
+
+  // Rediscover the output, parsing directory names with the same partitioning.
+  fs::FileSelector selector;
+  selector.base_dir = "/new_root";
+  selector.recursive = true;
+  FileSystemFactoryOptions discovery_options;
+  discovery_options.partitioning = partitioning;
+  ASSERT_OK_AND_ASSIGN(auto factory, FileSystemDatasetFactory::Make(
+                                         fs_, selector, format_, discovery_options));
+  ASSERT_OK_AND_ASSIGN(written_, factory->Finish());
+
+  // Every source file maps to one output file; year/month live only in the path.
+  expected_files_["/new_root/2018/1/dat_0.ipc"] = R"([
+ {"region": "NY", "model": "3", "sales": 742.0, "country": "US"},
+ {"region": "NY", "model": "S", "sales": 304.125, "country": "US"},
+ {"region": "NY", "model": "Y", "sales": 27.5, "country": "US"}
+ ])";
+  expected_files_["/new_root/2018/1/dat_1.ipc"] = R"([
+ {"region": "QC", "model": "3", "sales": 512, "country": "CA"},
+ {"region": "QC", "model": "S", "sales": 978, "country": "CA"},
+ {"region": "NY", "model": "X", "sales": 136.25, "country": "US"},
+ {"region": "QC", "model": "X", "sales": 1.0, "country": "CA"},
+ {"region": "QC", "model": "Y", "sales": 69, "country": "CA"}
+ ])";
+  expected_files_["/new_root/2019/1/dat_2.ipc"] = R"([
+ {"region": "CA", "model": "3", "sales": 273.5, "country": "US"},
+ {"region": "CA", "model": "S", "sales": 13, "country": "US"},
+ {"region": "CA", "model": "X", "sales": 54, "country": "US"},
+ {"region": "QC", "model": "S", "sales": 10, "country": "CA"},
+ {"region": "CA", "model": "Y", "sales": 21, "country": "US"}
+ ])";
+  expected_files_["/new_root/2019/1/dat_3.ipc"] = R"([
+ {"region": "QC", "model": "3", "sales": 152.25, "country": "CA"},
+ {"region": "QC", "model": "X", "sales": 42, "country": "CA"},
+ {"region": "QC", "model": "Y", "sales": 37, "country": "CA"}
+ ])";
+  expected_physical_schema_ =
+      SchemaFromColumnNames(schema_, {"region", "model", "sales", "country"});
+
+  AssertWrittenAsExpected();
+}
+
+TEST_F(TestIpcFileSystemDataset, WriteWithUnrelatedPartitioningSchema) {
+  // Repartition on (country, region), neither of which the source is partitioned on.
+  auto partitioning = std::make_shared<DirectoryPartitioning>(
+      SchemaFromColumnNames(schema_, {"country", "region"}));
+  auto scan_context = std::make_shared<ScanContext>();
+  ASSERT_OK(FileSystemDataset::Write(schema_, format_, fs_, "new_root/", partitioning,
+                                     scan_context, dataset_->GetFragments()));
+
+  // Rediscover the output, parsing directory names with the same partitioning.
+  fs::FileSelector selector;
+  selector.base_dir = "/new_root";
+  selector.recursive = true;
+  FileSystemFactoryOptions discovery_options;
+  discovery_options.partitioning = partitioning;
+  ASSERT_OK_AND_ASSIGN(auto factory, FileSystemDatasetFactory::Make(
+                                         fs_, selector, format_, discovery_options));
+  ASSERT_OK_AND_ASSIGN(written_, factory->Finish());
+
+  // XXX first thing a user will be annoyed by: we don't support left
+  // padding the month field with 0.
+  expected_files_["/new_root/US/NY/dat_0.ipc"] = R"([
+ {"year": 2018, "month": 1, "model": "3", "sales": 742.0},
+ {"year": 2018, "month": 1, "model": "S", "sales": 304.125},
+ {"year": 2018, "month": 1, "model": "Y", "sales": 27.5}
+ ])";
+  expected_files_["/new_root/US/NY/dat_1.ipc"] = R"([
+ {"year": 2018, "month": 1, "model": "X", "sales": 136.25}
+ ])";
+  expected_files_["/new_root/CA/QC/dat_1.ipc"] = R"([
+ {"year": 2018, "month": 1, "model": "3", "sales": 512},
+ {"year": 2018, "month": 1, "model": "S", "sales": 978},
+ {"year": 2018, "month": 1, "model": "X", "sales": 1.0},
+ {"year": 2018, "month": 1, "model": "Y", "sales": 69}
+ ])";
+  expected_files_["/new_root/US/CA/dat_2.ipc"] = R"([
+ {"year": 2019, "month": 1, "model": "3", "sales": 273.5},
+ {"year": 2019, "month": 1, "model": "S", "sales": 13},
+ {"year": 2019, "month": 1, "model": "X", "sales": 54},
+ {"year": 2019, "month": 1, "model": "Y", "sales": 21}
+ ])";
+  expected_files_["/new_root/CA/QC/dat_2.ipc"] = R"([
+ {"year": 2019, "month": 1, "model": "S", "sales": 10}
+ ])";
+  expected_files_["/new_root/CA/QC/dat_3.ipc"] = R"([
+ {"year": 2019, "month": 1, "model": "3", "sales": 152.25},
+ {"year": 2019, "month": 1, "model": "X", "sales": 42},
+ {"year": 2019, "month": 1, "model": "Y", "sales": 37}
+ ])";
+  expected_physical_schema_ =
+      SchemaFromColumnNames(schema_, {"model", "sales", "year", "month"});
+
+  AssertWrittenAsExpected();
+}
+
+TEST_F(TestIpcFileSystemDataset, WriteWithSupersetPartitioningSchema) {
+  // Partition on the source keys (year, month) plus two more (country, region).
+  auto partitioning = std::make_shared<DirectoryPartitioning>(
+      SchemaFromColumnNames(schema_, {"year", "month", "country", "region"}));
+  auto scan_context = std::make_shared<ScanContext>();
+  ASSERT_OK(FileSystemDataset::Write(schema_, format_, fs_, "new_root/", partitioning,
+                                     scan_context, dataset_->GetFragments()));
+
+  // Rediscover the output, parsing directory names with the same partitioning.
+  fs::FileSelector selector;
+  selector.base_dir = "/new_root";
+  selector.recursive = true;
+  FileSystemFactoryOptions discovery_options;
+  discovery_options.partitioning = partitioning;
+  ASSERT_OK_AND_ASSIGN(auto factory, FileSystemDatasetFactory::Make(
+                                         fs_, selector, format_, discovery_options));
+  ASSERT_OK_AND_ASSIGN(written_, factory->Finish());
+
+  // XXX first thing a user will be annoyed by: we don't support left
+  // padding the month field with 0.
+  expected_files_["/new_root/2018/1/US/NY/dat_0.ipc"] = R"([
+ {"model": "3", "sales": 742.0},
+ {"model": "S", "sales": 304.125},
+ {"model": "Y", "sales": 27.5}
+ ])";
+  expected_files_["/new_root/2018/1/US/NY/dat_1.ipc"] = R"([
+ {"model": "X", "sales": 136.25}
+ ])";
+  expected_files_["/new_root/2018/1/CA/QC/dat_1.ipc"] = R"([
+ {"model": "3", "sales": 512},
+ {"model": "S", "sales": 978},
+ {"model": "X", "sales": 1.0},
+ {"model": "Y", "sales": 69}
+ ])";
+  expected_files_["/new_root/2019/1/US/CA/dat_2.ipc"] = R"([
+ {"model": "3", "sales": 273.5},
+ {"model": "S", "sales": 13},
+ {"model": "X", "sales": 54},
+ {"model": "Y", "sales": 21}
+ ])";
+  expected_files_["/new_root/2019/1/CA/QC/dat_2.ipc"] = R"([
+ {"model": "S", "sales": 10}
+ ])";
+  expected_files_["/new_root/2019/1/CA/QC/dat_3.ipc"] = R"([
+ {"model": "3", "sales": 152.25},
+ {"model": "X", "sales": 42},
+ {"model": "Y", "sales": 37}
+ ])";
+  expected_physical_schema_ = SchemaFromColumnNames(schema_, {"model", "sales"});
+
+  AssertWrittenAsExpected();
+}
+
+TEST_F(TestIpcFileSystemDataset, WriteWithEmptyPartitioningSchema) {
+  // A partitioning with no fields: everything lands directly under new_root/.
+  auto partitioning =
+      std::make_shared<DirectoryPartitioning>(SchemaFromColumnNames(schema_, {}));
+  auto scan_context = std::make_shared<ScanContext>();
+  ASSERT_OK(FileSystemDataset::Write(schema_, format_, fs_, "new_root/", partitioning,
+                                     scan_context, dataset_->GetFragments()));
+
+  // Rediscover the output, parsing directory names with the same partitioning.
+  fs::FileSelector selector;
+  selector.base_dir = "/new_root";
+  selector.recursive = true;
+  FileSystemFactoryOptions discovery_options;
+  discovery_options.partitioning = partitioning;
+  ASSERT_OK_AND_ASSIGN(auto factory, FileSystemDatasetFactory::Make(
+                                         fs_, selector, format_, discovery_options));
+  ASSERT_OK_AND_ASSIGN(written_, factory->Finish());
+
+  // With no partition keys every column, including year/month, is physically stored.
+  expected_files_["/new_root/dat_0.ipc"] = R"([
+ {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "3", "sales": 742.0},
+ {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "S", "sales": 304.125},
+ {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "Y", "sales": 27.5}
+ ])";
+  expected_files_["/new_root/dat_1.ipc"] = R"([
+ {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "3", "sales": 512},
+ {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "S", "sales": 978},
+ {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "X", "sales": 136.25},
+ {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "X", "sales": 1.0},
+ {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "Y", "sales": 69}
+ ])";
+  expected_files_["/new_root/dat_2.ipc"] = R"([
+ {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "3", "sales": 273.5},
+ {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "S", "sales": 13},
+ {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "X", "sales": 54},
+ {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "S", "sales": 10},
+ {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "Y", "sales": 21}
+ ])";
+  expected_files_["/new_root/dat_3.ipc"] = R"([
+ {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "3", "sales": 152.25},
+ {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "X", "sales": 42},
+ {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "Y", "sales": 37}
+ ])";
+  expected_physical_schema_ = schema_;
+
+  AssertWrittenAsExpected();
}
TEST_F(TestIpcFileFormat, OpenFailureWithRelevantError) {
diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc
index b35b8da..d99d624 100644
--- a/cpp/src/arrow/dataset/filter.cc
+++ b/cpp/src/arrow/dataset/filter.cc
@@ -28,6 +28,7 @@
#include "arrow/buffer.h"
#include "arrow/buffer_builder.h"
+#include "arrow/builder.h"
#include "arrow/compute/api.h"
#include "arrow/dataset/dataset.h"
#include "arrow/io/memory.h"
@@ -38,6 +39,7 @@
#include "arrow/scalar.h"
#include "arrow/type_fwd.h"
#include "arrow/util/checked_cast.h"
+#include "arrow/util/int_util.h"
#include "arrow/util/iterator.h"
#include "arrow/util/logging.h"
#include "arrow/util/string.h"
@@ -1493,5 +1495,191 @@ Result<std::shared_ptr<Expression>> Expression::Deserialize(const Buffer& serial
return DeserializeImpl{}.FromBuffer(serialized);
}
+// Transform an array of counts into the offsets of a ListArray whose i-th list
+// has length counts[i] (offsets[0] == 0, offsets[i + 1] == offsets[i] + counts[i]).
+//
+// ListArray offsets are int32, so this returns CapacityError when the running
+// total of counts no longer fits in int32 instead of silently truncating it.
+inline Result<std::shared_ptr<Array>> CountsToOffsets(
+    std::shared_ptr<Int64Array> counts) {
+  Int32Builder offset_builder;
+  RETURN_NOT_OK(offset_builder.Resize(counts->length() + 1));
+  offset_builder.UnsafeAppend(0);
+
+  // Accumulate in int64 so an oversized total is detected before narrowing.
+  int64_t running_offset = 0;
+  for (int64_t i = 0; i < counts->length(); ++i) {
+    DCHECK_NE(counts->Value(i), 0);
+    running_offset += counts->Value(i);
+    auto next_offset = static_cast<int32_t>(running_offset);
+    if (next_offset != running_offset) {
+      return Status::CapacityError("Cannot represent ", running_offset,
+                                   " grouped rows with int32 list offsets");
+    }
+    offset_builder.UnsafeAppend(next_offset);
+  }
+
+  std::shared_ptr<Array> offsets;
+  RETURN_NOT_OK(offset_builder.Finish(&offsets));
+  return offsets;
+}
+
+// Helper for simultaneous dictionary encoding of multiple arrays.
+//
+// The fused dictionary is the Cartesian product of the individual dictionaries.
+// For example given two arrays A, B where A has unique values ["ex", "why"]
+// and B has unique values [0, 1] the fused dictionary is the set of tuples
+// [["ex", 0], ["why", 0], ["ex", 1], ["why", 1]] (listed in fused-code order:
+// the first encoded column is the least significant component of a fused code).
+//
+// TODO(bkietz) this capability belongs in an Action of the hash kernels, where
+// it can be used to group aggregates without materializing a grouped batch.
+// For the purposes of writing we need the materialized grouped batch anyway
+// since no Writers accept a selection vector.
+class StructDictionary {
+ public:
+  struct Encoded {
+    std::shared_ptr<Int32Array> indices;
+    std::shared_ptr<StructDictionary> dictionary;
+  };
+
+  /// Encode `columns` into one fused int32 code per row.
+  ///
+  /// Returns NotImplemented if any column contains nulls, and CapacityError if
+  /// the product of the per-column dictionary sizes does not fit in int32.
+  static Result<Encoded> Encode(const ArrayVector& columns) {
+    Encoded out{nullptr, std::make_shared<StructDictionary>()};
+
+    for (const auto& column : columns) {
+      if (column->null_count() != 0) {
+        return Status::NotImplemented("Grouping on a field with nulls");
+      }
+
+      RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices));
+    }
+
+    return out;
+  }
+
+  /// Expand fused codes back into one column per component dictionary.
+  ///
+  /// \param[in] fused_indices codes produced by Encode() with this dictionary
+  /// \param[in] fields output fields, in the order columns were passed to Encode()
+  Result<std::shared_ptr<StructArray>> Decode(std::shared_ptr<Int32Array> fused_indices,
+                                              FieldVector fields) {
+    std::vector<Int32Builder> builders(dictionaries_.size());
+    for (Int32Builder& b : builders) {
+      RETURN_NOT_OK(b.Resize(fused_indices->length()));
+    }
+
+    std::vector<int32_t> codes(dictionaries_.size());
+    for (int64_t i = 0; i < fused_indices->length(); ++i) {
+      Expand(fused_indices->Value(i), codes.data());
+
+      auto builder_it = builders.begin();
+      for (int32_t index : codes) {
+        builder_it++->UnsafeAppend(index);
+      }
+    }
+
+    ArrayVector columns(dictionaries_.size());
+    for (size_t i = 0; i < dictionaries_.size(); ++i) {
+      std::shared_ptr<ArrayData> indices;
+      RETURN_NOT_OK(builders[i].FinishInternal(&indices));
+
+      // Materialize column i by looking its codes up in the i-th dictionary.
+      ARROW_ASSIGN_OR_RAISE(Datum column, compute::Take(dictionaries_[i], indices));
+      columns[i] = column.make_array();
+    }
+
+    return StructArray::Make(std::move(columns), std::move(fields));
+  }
+
+ private:
+  // Dictionary-encode `column` and fold its codes into `*fused_indices`.
+  Status AddOne(const std::shared_ptr<Array>& column,
+                std::shared_ptr<Int32Array>* fused_indices) {
+    ARROW_ASSIGN_OR_RAISE(Datum encoded, compute::DictionaryEncode(column));
+    ArrayData* encoded_array = encoded.mutable_array();
+
+    auto indices = std::make_shared<Int32Array>(encoded_array->length,
+                                                std::move(encoded_array->buffers[1]));
+
+    dictionaries_.push_back(MakeArray(std::move(encoded_array->dictionary)));
+    auto dictionary_size = static_cast<int32_t>(dictionaries_.back()->length());
+
+    if (*fused_indices == nullptr) {
+      *fused_indices = std::move(indices);
+      size_ = dictionary_size;
+      return Status::OK();
+    }
+
+    // The largest fused code after this step is size_ * dictionary_size - 1, so
+    // checking the product up front also rules out overflow in Multiply/Add below.
+    // This must be a runtime check: a DCHECK vanishes in release builds, where
+    // signed overflow would silently produce wrong groupings.
+    if (internal::HasPositiveMultiplyOverflow(size_, dictionary_size)) {
+      return Status::CapacityError(
+          "Too many distinct combinations of grouping keys to encode as int32");
+    }
+
+    // It's useful to think about the case where each of dictionaries_ has size 10.
+    // In this case the decimal digit in the ones place is the code in dictionaries_[0],
+    // the tens place corresponds to dictionaries_[1], etc.
+    // The new column's codes are shifted past the incumbent fused codes so as not
+    // to collide.
+    ARROW_ASSIGN_OR_RAISE(Datum new_fused_indices,
+                          compute::Multiply(indices, MakeScalar(size_)));
+
+    ARROW_ASSIGN_OR_RAISE(new_fused_indices,
+                          compute::Add(new_fused_indices, *fused_indices));
+
+    *fused_indices = checked_pointer_cast<Int32Array>(new_fused_indices.make_array());
+
+    size_ *= dictionary_size;
+    return Status::OK();
+  }
+
+  // expand a fused code into component dict codes, order is in order of addition
+  void Expand(int32_t fused_code, int32_t* codes) {
+    for (size_t i = 0; i < dictionaries_.size(); ++i) {
+      auto dictionary_size = static_cast<int32_t>(dictionaries_[i]->length());
+      codes[i] = fused_code % dictionary_size;
+      fused_code /= dictionary_size;
+    }
+  }
+
+  // Number of distinct fused codes representable so far (product of dictionary sizes).
+  int32_t size_ = 0;
+  ArrayVector dictionaries_;
+};
+
+Result<std::shared_ptr<StructArray>> MakeGroupings(const StructArray& by) {
+  if (by.num_fields() == 0) {
+    return Status::NotImplemented("Grouping with no criteria");
+  }
+
+  // One int32 code per row identifies that row's full key tuple.
+  ARROW_ASSIGN_OR_RAISE(auto encoded, StructDictionary::Encode(by.fields()));
+
+  // Order row positions by code so that rows sharing a key become adjacent.
+  ARROW_ASSIGN_OR_RAISE(auto row_order, compute::SortToIndices(*encoded.indices));
+  ARROW_ASSIGN_OR_RAISE(Datum sorted_codes, compute::Take(encoded.indices, *row_order));
+  encoded.indices = checked_pointer_cast<Int32Array>(sorted_codes.make_array());
+
+  ARROW_ASSIGN_OR_RAISE(auto value_counts, compute::ValueCounts(encoded.indices));
+  // The per-row codes are not needed past this point; release them early.
+  encoded.indices.reset();
+
+  // Turn each distinct code back into its key tuple.
+  auto distinct_codes =
+      checked_pointer_cast<Int32Array>(value_counts->GetFieldByName("values"));
+  ARROW_ASSIGN_OR_RAISE(auto distinct_rows,
+                        encoded.dictionary->Decode(std::move(distinct_codes),
+                                                   by.type()->fields()));
+
+  // Group sizes become list offsets over the sorted row positions.
+  auto group_sizes =
+      checked_pointer_cast<Int64Array>(value_counts->GetFieldByName("counts"));
+  ARROW_ASSIGN_OR_RAISE(auto offsets, CountsToOffsets(std::move(group_sizes)));
+  ARROW_ASSIGN_OR_RAISE(auto row_lists, ListArray::FromArrays(*offsets, *row_order));
+
+  ArrayVector children{std::move(distinct_rows), std::move(row_lists)};
+  std::vector<std::string> names{"values", "groupings"};
+  return StructArray::Make(std::move(children), std::move(names));
+}
+
+// Gather `array` into lists: list i holds the elements of `array` at the indices
+// listed in groupings[i].
+Result<std::shared_ptr<ListArray>> ApplyGroupings(const ListArray& groupings,
+                                                  const Array& array) {
+  // Reorder the whole child once; groupings' offsets then delimit each group.
+  ARROW_ASSIGN_OR_RAISE(Datum sorted,
+                        compute::Take(array, groupings.data()->child_data[0]));
+
+  // The offsets buffer is indexed relative to groupings' own slice offset, so
+  // that offset (and the matching validity bitmap) must be carried over;
+  // otherwise a sliced `groupings` would produce misaligned lists.
+  return std::make_shared<ListArray>(list(array.type()), groupings.length(),
+                                     groupings.value_offsets(), sorted.make_array(),
+                                     groupings.null_bitmap(), groupings.null_count(),
+                                     groupings.offset());
+}
+
+// Split `batch` into one RecordBatch per grouping: output i holds the rows of
+// `batch` at the indices listed in groupings[i].
+Result<RecordBatchVector> ApplyGroupings(const ListArray& groupings,
+                                         const std::shared_ptr<RecordBatch>& batch) {
+  // Reorder every column once so each group becomes a contiguous run of rows.
+  ARROW_ASSIGN_OR_RAISE(Datum reordered,
+                        compute::Take(batch, groupings.data()->child_data[0]));
+  std::shared_ptr<RecordBatch> reordered_batch = reordered.record_batch();
+
+  RecordBatchVector slices;
+  slices.reserve(static_cast<size_t>(groupings.length()));
+  for (int64_t group = 0; group < groupings.length(); ++group) {
+    slices.push_back(reordered_batch->Slice(groupings.value_offset(group),
+                                            groupings.value_length(group)));
+  }
+  return slices;
+}
+
} // namespace dataset
} // namespace arrow
diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h
index b7d4655..ebf58cc 100644
--- a/cpp/src/arrow/dataset/filter.h
+++ b/cpp/src/arrow/dataset/filter.h
@@ -641,5 +641,22 @@ class ARROW_DS_EXPORT TreeEvaluator : public ExpressionEvaluator {
struct Impl;
};
+/// \brief Assemble lists of indices of identical rows.
+///
+/// \param[in] by A StructArray whose columns will be used as grouping criteria.
+/// \return A StructArray mapping unique rows (in field "values", represented as a
+/// StructArray with the same fields as `by`) to lists of indices where
+/// that row appears (in field "groupings").
+///
+/// Returns NotImplemented if `by` has no fields or if any of its columns
+/// contains nulls.
+ARROW_DS_EXPORT
+Result<std::shared_ptr<StructArray>> MakeGroupings(const StructArray& by);
+
+/// \brief Produce slices of an Array which correspond to the provided groupings.
+///
+/// \param[in] groupings the "groupings" field of MakeGroupings' result
+/// \param[in] array values to gather, indexed like the rows grouped by MakeGroupings
+/// \return a ListArray whose i-th list holds the elements of `array` at the
+/// indices listed in groupings[i]
+ARROW_DS_EXPORT
+Result<std::shared_ptr<ListArray>> ApplyGroupings(const ListArray& groupings,
+ const Array& array);
+/// \brief Split a RecordBatch into one RecordBatch per grouping.
+///
+/// Output i holds the rows of `batch` at the indices listed in groupings[i].
+ARROW_DS_EXPORT
+Result<RecordBatchVector> ApplyGroupings(const ListArray& groupings,
+ const std::shared_ptr<RecordBatch>& batch);
+
} // namespace dataset
} // namespace arrow
diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc
index 7c3c1a2..8e16208 100644
--- a/cpp/src/arrow/dataset/filter_test.cc
+++ b/cpp/src/arrow/dataset/filter_test.cc
@@ -592,5 +592,57 @@ TEST(ExpressionSerializationTest, RoundTrips) {
}
}
+// Group `batch_json` (the `by_fields` columns plus a trailing int32 "id" column)
+// on `by_fields`, and check that each distinct key maps to the expected list of
+// ids. `expected_json` has the `by_fields` columns plus a list<int32> "ids" column.
+void AssertGrouping(const FieldVector& by_fields, const std::string& batch_json,
+                    const std::string& expected_json) {
+  FieldVector output_fields = by_fields;
+  output_fields.push_back(field("ids", list(int32())));
+  auto expected = ArrayFromJSON(struct_(output_fields), expected_json);
+
+  FieldVector input_fields = by_fields;
+  input_fields.push_back(field("id", int32()));
+  auto batch = RecordBatchFromJSON(schema(input_fields), batch_json);
+
+  // Everything except the trailing "id" column is a grouping criterion.
+  ASSERT_OK_AND_ASSIGN(auto criteria_batch,
+                       batch->RemoveColumn(batch->num_columns() - 1));
+  ASSERT_OK_AND_ASSIGN(auto by, criteria_batch->ToStructArray());
+
+  ASSERT_OK_AND_ASSIGN(auto groupings_and_values, MakeGroupings(*by));
+  auto groupings =
+      checked_pointer_cast<ListArray>(groupings_and_values->GetFieldByName("groupings"));
+
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> grouped_ids,
+                       ApplyGroupings(*groupings, *batch->GetColumnByName("id")));
+
+  auto unique_keys =
+      checked_pointer_cast<StructArray>(groupings_and_values->GetFieldByName("values"));
+  ArrayVector actual_columns = unique_keys->fields();
+  actual_columns.push_back(grouped_ids);
+
+  ASSERT_OK_AND_ASSIGN(auto actual, StructArray::Make(actual_columns, output_fields));
+  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+TEST(GroupTest, Basics) {
+  const FieldVector by_fields = {field("a", utf8()), field("b", int32())};
+
+  const std::string batch_json = R"([
+ {"a": "ex", "b": 0, "id": 0},
+ {"a": "ex", "b": 0, "id": 1},
+ {"a": "why", "b": 0, "id": 2},
+ {"a": "ex", "b": 1, "id": 3},
+ {"a": "why", "b": 0, "id": 4},
+ {"a": "ex", "b": 1, "id": 5},
+ {"a": "ex", "b": 0, "id": 6},
+ {"a": "why", "b": 1, "id": 7}
+ ])";
+
+  // Groups come out in fused-code order, in which column "a" varies fastest.
+  const std::string expected_json = R"([
+ {"a": "ex", "b": 0, "ids": [0, 1, 6]},
+ {"a": "why", "b": 0, "ids": [2, 4]},
+ {"a": "ex", "b": 1, "ids": [3, 5]},
+ {"a": "why", "b": 1, "ids": [7]}
+ ])";
+
+  AssertGrouping(by_fields, batch_json, expected_json);
+}
+
} // namespace dataset
} // namespace arrow
diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc
index 5844f89..a0ea91d 100644
--- a/cpp/src/arrow/dataset/partition.cc
+++ b/cpp/src/arrow/dataset/partition.cc
@@ -26,26 +26,30 @@
#include <vector>
#include "arrow/array/array_base.h"
+#include "arrow/array/array_nested.h"
#include "arrow/array/builder_binary.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/dataset/dataset_internal.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/filter.h"
#include "arrow/dataset/scanner.h"
+#include "arrow/dataset/scanner_internal.h"
#include "arrow/filesystem/filesystem.h"
#include "arrow/filesystem/path_util.h"
#include "arrow/scalar.h"
#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
#include "arrow/util/range.h"
#include "arrow/util/sort.h"
#include "arrow/util/string_view.h"
namespace arrow {
-namespace dataset {
+using internal::checked_cast;
+using internal::checked_pointer_cast;
using util::string_view;
-using arrow::internal::checked_cast;
+namespace dataset {
std::shared_ptr<Partitioning> Partitioning::Default() {
class DefaultPartitioning : public Partitioning {
@@ -62,24 +66,16 @@ std::shared_ptr<Partitioning> Partitioning::Default() {
return Status::NotImplemented("formatting paths from ", type_name(),
" Partitioning");
}
+
+ Result<std::vector<PartitionedBatch>> Partition(
+ const std::shared_ptr<RecordBatch>& batch) const override {
+ return std::vector<PartitionedBatch>{{batch, scalar(true)}};
+ }
};
return std::make_shared<DefaultPartitioning>();
}
-Result<WritePlan> PartitioningFactory::MakeWritePlan(std::shared_ptr<Schema> schema,
- FragmentIterator fragment_it) {
- return Status::NotImplemented("MakeWritePlan from PartitioningFactory of type ",
- type_name());
-}
-
-Result<WritePlan> PartitioningFactory::MakeWritePlan(
- std::shared_ptr<Schema> schema, FragmentIterator fragment_it,
- std::shared_ptr<Schema> partition_schema) {
- return Status::NotImplemented("MakeWritePlan from PartitioningFactory of type ",
- type_name());
-}
-
Status KeyValuePartitioning::VisitKeys(
const Expression& expr,
const std::function<Status(const std::string& name,
@@ -136,6 +132,56 @@ Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression& expr,
});
}
+// Build `f0 == v0 and f1 == v1 and ...` from one row of the "values" StructArray
+// produced by MakeGroupings. The row's child scalars are moved out of `row`.
+inline std::shared_ptr<Expression> ConjunctionFromGroupingRow(Scalar* row) {
+  auto* struct_row = checked_cast<StructScalar*>(row);
+  const int num_keys = static_cast<int>(struct_row->value.size());
+
+  ExpressionVector conjuncts;
+  conjuncts.reserve(struct_row->value.size());
+  for (int key = 0; key < num_keys; ++key) {
+    const std::string& field_name = row->type->field(key)->name();
+    conjuncts.push_back(
+        equal(field_ref(field_name), scalar(std::move(struct_row->value[key]))));
+  }
+  return and_(std::move(conjuncts));
+}
+
+// Split `batch` by the values of whichever partition fields it contains. Each
+// output batch omits those fields and carries an equality conjunction describing
+// its key values. If no partition field is present, `batch` is returned whole.
+Result<std::vector<Partitioning::PartitionedBatch>> KeyValuePartitioning::Partition(
+    const std::shared_ptr<RecordBatch>& batch) const {
+  FieldVector key_fields;
+  ArrayVector key_columns;
+
+  // Peel each partition column that is present off of `payload`.
+  std::shared_ptr<RecordBatch> payload = batch;
+  for (const auto& partition_field : schema_->fields()) {
+    ARROW_ASSIGN_OR_RAISE(
+        auto match, FieldRef(partition_field->name()).FindOneOrNone(*payload->schema()));
+    if (!match) continue;
+
+    key_fields.push_back(partition_field);
+    key_columns.push_back(payload->column(match[0]));
+    ARROW_ASSIGN_OR_RAISE(payload, payload->RemoveColumn(match[0]));
+  }
+
+  if (key_fields.empty()) {
+    // no fields to group by; return the whole batch
+    return std::vector<PartitionedBatch>{{batch, scalar(true)}};
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto keys,
+                        StructArray::Make(std::move(key_columns), std::move(key_fields)));
+  ARROW_ASSIGN_OR_RAISE(auto groupings_and_values, MakeGroupings(*keys));
+  auto groupings =
+      checked_pointer_cast<ListArray>(groupings_and_values->GetFieldByName("groupings"));
+  auto unique_keys = groupings_and_values->GetFieldByName("values");
+
+  ARROW_ASSIGN_OR_RAISE(auto grouped_batches, ApplyGroupings(*groupings, payload));
+
+  std::vector<PartitionedBatch> partitions(grouped_batches.size());
+  for (size_t i = 0; i < partitions.size(); ++i) {
+    ARROW_ASSIGN_OR_RAISE(auto key_row, unique_keys->GetScalar(i));
+    partitions[i].batch = std::move(grouped_batches[i]);
+    partitions[i].partition_expression = ConjunctionFromGroupingRow(key_row.get());
+  }
+  return partitions;
+}
+
Result<std::shared_ptr<Expression>> KeyValuePartitioning::ConvertKey(
const Key& key) const {
ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(key.name).FindOneOrNone(*schema_));
@@ -254,7 +300,7 @@ Result<std::string> DirectoryPartitioning::FormatValues(
if (auto illegal_index = NextValid(values, i)) {
// XXX maybe we should just ignore keys provided after the first absent one?
return Status::Invalid("No partition key for ", schema_->field(i)->name(),
- " but subsequent a key was provided subsequently for ",
+ " but a key was provided subsequently for ",
schema_->field(*illegal_index)->name(), ".");
}
@@ -311,6 +357,8 @@ class KeyValuePartitioningInspectImpl {
void InsertRepr(int index, std::string repr) { values_[index].insert(std::move(repr)); }
Result<std::shared_ptr<Schema>> Finish(ArrayVector* dictionaries) {
+ dictionaries->clear();
+
if (options_.max_partition_dictionary_size != 0) {
dictionaries->resize(name_to_index_.size());
}
@@ -391,253 +439,12 @@ class DirectoryPartitioningFactory : public PartitioningFactory {
return std::make_shared<DirectoryPartitioning>(std::move(out_schema), dictionaries_);
}
- struct MakeWritePlanImpl;
-
- Result<WritePlan> MakeWritePlan(std::shared_ptr<Schema> schema,
- FragmentIterator fragments) override;
-
- Result<WritePlan> MakeWritePlan(std::shared_ptr<Schema> schema,
- FragmentIterator fragments,
- std::shared_ptr<Schema> partition_schema) override;
-
private:
std::vector<std::string> field_names_;
ArrayVector dictionaries_;
PartitioningFactoryOptions options_;
};
-struct DirectoryPartitioningFactory::MakeWritePlanImpl {
- using Indices = std::basic_string<int>;
-
- MakeWritePlanImpl(DirectoryPartitioningFactory* factory, std::shared_ptr<Schema> schema,
- FragmentVector source_fragments)
- : this_(factory),
- schema_(std::move(schema)),
- source_fragments_(std::move(source_fragments)),
- right_hand_sides_(source_fragments_.size(), Indices(num_fields(), -1)) {}
-
- int num_fields() const { return static_cast<int>(this_->field_names_.size()); }
-
- // For a KeyValuePartitioning, every partition expression will be an equality
- // ComparisonExpression where the left operand is a FieldExpression and the right is a
- // ScalarExpression. Comparing Scalars directly is expensive, so first assemble a
- // dictionary containing the scalars from the right operands of every partition
- // expression. This allows later stages of MakeWritePlan to handle a scalar by its
- // dictionary code, which is both more compact to store and cheap to compare.
- //
- // Scalars are stored such that the dictionary code of a fragment's RHS in the
- // partition expression for a given field is given by
- // int code = right_hand_sides_[fragment_index][field_index];
- // and the corresponding scalar can be retrieved with
- // std::shared_ptr<Scalar> scalar = scalar_dict_.code_to_scalar[code];
- Status DictEncodeRightHandSides() {
- if (source_fragments_.empty()) {
- return Status::OK();
- }
-
- for (size_t fragment_i = 0; fragment_i < source_fragments_.size(); ++fragment_i) {
- const auto& fragment = source_fragments_[fragment_i];
-
- auto insert_representable_into_dict = [this, fragment_i](
- const std::string& name,
- const std::shared_ptr<Scalar>& value) {
- auto it = std::find(this_->field_names_.begin(), this_->field_names_.end(), name);
- if (it == this_->field_names_.end()) {
- return Status::OK();
- }
-
- auto field_i = it - this_->field_names_.begin();
-
- int code = scalar_dict_.GetOrInsert(value);
- right_hand_sides_[fragment_i][field_i] = code;
-
- return Status::OK();
- };
-
- RETURN_NOT_OK(KeyValuePartitioning::VisitKeys(*fragment->partition_expression(),
- insert_representable_into_dict));
-
- auto it = std::find(right_hand_sides_[fragment_i].begin(),
- right_hand_sides_[fragment_i].end(), -1);
- if (it != right_hand_sides_[fragment_i].end()) {
- // NB: this is an error when writing DirectoryPartitioning but not
- // HivePartitioning (as it will be valid to simply omit segments)
- return Status::Invalid(
- "fragment ", fragment_i, " had no partition expression for field '",
- this_->field_names_.at(it - right_hand_sides_[fragment_i].begin()), "'");
- }
- }
-
- return Status::OK();
- }
-
- // Infer the Partitioning schema from partition expressions.
- // For example if one partition expression is "omega"_ == 13
- // we can infer that the field "omega" has type int32
- Result<std::shared_ptr<Schema>> InferPartitioningSchema() const {
- if (source_fragments_.empty()) {
- return Status::Invalid(
- "No fragments were provided so the Partitioning schema could not be "
- "inferred.");
- }
-
- // NB: under DirectoryPartitioning every fragment has a partition expression for every
- // field, so we can infer the schema by looking only at the first fragment. This will
- // be more complicated for HivePartitioning.
- int fragment_i = 0;
-
- FieldVector fields(num_fields());
- for (int field_i = 0; field_i < num_fields(); ++field_i) {
- const auto& name = this_->field_names_[field_i];
- const auto& type =
- scalar_dict_.code_to_scalar[right_hand_sides_[fragment_i][field_i]]->type;
- fields[field_i] = field(name, type);
- }
-
- return schema(std::move(fields));
- }
-
- // reconstitute fragment_i's partition expression for field_i by reading the right
- // hand side from the scalar dictionary and constructing an equality
- // ComparisonExpression
- std::shared_ptr<Expression> PartitionExpression(size_t fragment_i, int field_i) {
- auto left_hand_side = field_ref(this_->field_names_[field_i]);
- auto right_hand_side =
- scalar(scalar_dict_.code_to_scalar[right_hand_sides_[fragment_i][field_i]]);
- return equal(std::move(left_hand_side), std::move(right_hand_side));
- }
-
- // create a guid by stringifying the number of milliseconds since the epoch
- std::string Guid() {
- using std::chrono::duration_cast;
- using std::chrono::milliseconds;
- using std::chrono::steady_clock;
- auto milliseconds_since_epoch =
- duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count();
- return std::to_string(milliseconds_since_epoch);
- }
-
- Result<WritePlan> Finish(std::shared_ptr<Schema> partitioning_schema = nullptr) && {
- WritePlan out;
-
- RETURN_NOT_OK(DictEncodeRightHandSides());
-
- if (partitioning_schema == nullptr) {
- ARROW_ASSIGN_OR_RAISE(partitioning_schema, InferPartitioningSchema());
- }
- ARROW_ASSIGN_OR_RAISE(out.partitioning,
- this_->Finish(std::move(partitioning_schema)));
-
- // There's no guarantee that all Fragments have the same schema.
- ARROW_ASSIGN_OR_RAISE(out.schema,
- UnifySchemas({out.partitioning->schema(), schema_}));
-
- // Lexicographic ordering WRT right_hand_sides_ ensures that source_fragments_ are in
- // a depth first visitation order WRT their partition expressions. This makes
- // generation of the full directory tree far simpler since a directory's files are
- // grouped.
- auto permutation = arrow::internal::ArgSort(right_hand_sides_);
- arrow::internal::Permute(permutation, &source_fragments_);
- arrow::internal::Permute(permutation, &right_hand_sides_);
-
- // out.paths[parents[i]] is the parent directory of out.paths[i]
- std::vector<int> parents;
-
- // current_right_hand_sides[field_i] is the RHS dictionary code for the current
- // partition expression corresponding to field_i
- Indices current_right_hand_sides(num_fields(), -1);
-
- // current_partition_expressions[field_i] is the current partition expression
- // corresponding to field_i
- ExpressionVector current_partition_expressions(num_fields());
-
- // out.paths[current_parents[field_i]] is the current ancestor directory corresponding
- // to field_i
- Indices current_parents(num_fields() + 1, -1);
-
- for (size_t fragment_i = 0; fragment_i < source_fragments_.size(); ++fragment_i) {
- int field_i = 0;
- for (; field_i < num_fields(); ++field_i) {
- // these directories have already been created and we're still writing their
- // children
- if (right_hand_sides_[fragment_i][field_i] != current_right_hand_sides[field_i]) {
- break;
- }
- }
-
- for (; field_i < num_fields(); ++field_i) {
- // push a new directory
- current_parents[field_i + 1] = static_cast<int>(parents.size());
- parents.push_back(current_parents[field_i]);
-
- current_partition_expressions.resize(field_i + 1);
- current_partition_expressions[field_i] = PartitionExpression(fragment_i, field_i);
- auto partition_expression = and_(current_partition_expressions);
-
- // format segment for partition_expression
- ARROW_ASSIGN_OR_RAISE(auto path, out.partitioning->Format(*partition_expression));
- out.paths.push_back(std::move(path));
-
- // store partition_expression for use in the written Dataset
- out.fragment_or_partition_expressions.emplace_back(
- current_partition_expressions[field_i]);
-
- current_right_hand_sides[field_i] = right_hand_sides_[fragment_i][field_i];
- }
-
- // push a fragment (not attempting to give files meaningful names)
- std::string basename = Guid() + "_" + std::to_string(fragment_i);
- int parent_i = current_parents[field_i];
- parents.push_back(parent_i);
- out.paths.push_back(fs::internal::JoinAbstractPath(
- std::vector<std::string>{out.paths[parent_i], std::move(basename)}));
-
- // store a fragment for writing to disk
- out.fragment_or_partition_expressions.emplace_back(
- std::move(source_fragments_[fragment_i]));
- }
-
- return out;
- }
-
- DirectoryPartitioningFactory* this_;
- std::shared_ptr<Schema> schema_;
- FragmentVector source_fragments_;
-
- struct {
- std::unordered_map<std::shared_ptr<Scalar>, int, Scalar::Hash, Scalar::PtrsEqual>
- scalar_to_code;
-
- ScalarVector code_to_scalar;
-
- int GetOrInsert(const std::shared_ptr<Scalar>& scalar) {
- int new_code = static_cast<int>(code_to_scalar.size());
-
- auto it_inserted = scalar_to_code.emplace(scalar, new_code);
- if (!it_inserted.second) {
- return it_inserted.first->second;
- }
-
- code_to_scalar.push_back(scalar);
- return new_code;
- }
- } scalar_dict_;
- std::vector<Indices> right_hand_sides_;
-};
-
-Result<WritePlan> DirectoryPartitioningFactory::MakeWritePlan(
- std::shared_ptr<Schema> schema, FragmentIterator fragment_it,
- std::shared_ptr<Schema> partition_schema) {
- ARROW_ASSIGN_OR_RAISE(auto fragments, fragment_it.ToVector());
- return MakeWritePlanImpl(this, schema, std::move(fragments)).Finish(partition_schema);
-}
-
-Result<WritePlan> DirectoryPartitioningFactory::MakeWritePlan(
- std::shared_ptr<Schema> schema, FragmentIterator fragment_it) {
- ARROW_ASSIGN_OR_RAISE(auto fragments, fragment_it.ToVector());
- return MakeWritePlanImpl(this, schema, std::move(fragments)).Finish();
-}
-
std::shared_ptr<PartitioningFactory> DirectoryPartitioning::MakeFactory(
std::vector<std::string> field_names, PartitioningFactoryOptions options) {
return std::shared_ptr<PartitioningFactory>(
diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h
index ea5828ec..021f822 100644
--- a/cpp/src/arrow/dataset/partition.h
+++ b/cpp/src/arrow/dataset/partition.h
@@ -59,6 +59,15 @@ class ARROW_DS_EXPORT Partitioning {
/// \brief The name identifying the kind of partitioning
virtual std::string type_name() const = 0;
+ /// \brief If the input batch shares any fields with this partitioning,
+ /// produce slices of the batch which satisfy mutually exclusive Expressions.
+ struct PartitionedBatch {
+ std::shared_ptr<RecordBatch> batch;  ///< the input rows belonging to this partition (implementations may drop the partition fields)
+ std::shared_ptr<Expression> partition_expression;  ///< an Expression satisfied by every row of `batch`
+ };
+ virtual Result<std::vector<PartitionedBatch>> Partition(
+ const std::shared_ptr<RecordBatch>& batch) const = 0;
+
/// \brief Parse a path into a partition expression
virtual Result<std::shared_ptr<Expression>> Parse(const std::string& path) const = 0;
@@ -104,15 +113,6 @@ class ARROW_DS_EXPORT PartitioningFactory {
/// (fields may be dropped).
virtual Result<std::shared_ptr<Partitioning>> Finish(
const std::shared_ptr<Schema>& schema) const = 0;
-
- // FIXME(bkietz) Make these pure virtual
- /// Construct a WritePlan for the provided fragments
- virtual Result<WritePlan> MakeWritePlan(std::shared_ptr<Schema> schema,
- FragmentIterator fragments,
- std::shared_ptr<Schema> partition_schema);
- /// Construct a WritePlan for the provided fragments, inferring schema
- virtual Result<WritePlan> MakeWritePlan(std::shared_ptr<Schema> schema,
- FragmentIterator fragments);
};
/// \brief Subclass for the common case of a partitioning which yields an equality
@@ -136,6 +136,9 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning {
static Status SetDefaultValuesFromKeys(const Expression& expr,
RecordBatchProjector* projector);
+ Result<std::vector<PartitionedBatch>> Partition(
+ const std::shared_ptr<RecordBatch>& batch) const override;
+
Result<std::shared_ptr<Expression>> Parse(const std::string& path) const override;
Result<std::string> Format(const Expression& expr) const override;
@@ -240,6 +243,12 @@ class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning {
return Status::NotImplemented("formatting paths from ", type_name(), " Partitioning");
}
+ Result<std::vector<PartitionedBatch>> Partition(  // not supported: this partitioning is built from parse/format functions
+ const std::shared_ptr<RecordBatch>& batch) const override {
+ return Status::NotImplemented("partitioning batches from ", type_name(),
+ " Partitioning");
+ }
+
private:
ParseImpl parse_impl_;
FormatImpl format_impl_;
diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc
index 785be6a..27ab00a 100644
--- a/cpp/src/arrow/dataset/partition_test.cc
+++ b/cpp/src/arrow/dataset/partition_test.cc
@@ -28,6 +28,7 @@
#include <vector>
#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/scanner_internal.h"
#include "arrow/dataset/test_util.h"
#include "arrow/filesystem/localfs.h"
#include "arrow/filesystem/path_util.h"
@@ -36,6 +37,8 @@
#include "arrow/util/io_util.h"
namespace arrow {
+using internal::checked_pointer_cast;
+
namespace dataset {
using E = TestExpression;
@@ -436,6 +439,10 @@ class RangePartitioning : public Partitioning {
}
Result<std::string> Format(const Expression&) const override { return ""; }
+ Result<std::vector<PartitionedBatch>> Partition(  // Result<T> must never hold an OK Status, so report an error
+ const std::shared_ptr<RecordBatch>&) const override {
+ return Status::NotImplemented("partitioning batches from RangePartitioning");
+ }
};
TEST_F(TestPartitioning, Range) {
@@ -447,178 +454,6 @@ TEST_F(TestPartitioning, Range) {
("z"_ > 1.5 and "z"_ <= 3.0));
}
-class TestPartitioningWritePlan : public ::testing::Test {
- protected:
- FragmentIterator MakeFragments(const ExpressionVector& partition_expressions) {
- fragments_.clear();
- for (const auto& expr : partition_expressions) {
- fragments_.emplace_back(new InMemoryFragment(RecordBatchVector{}, expr));
- }
- return MakeVectorIterator(fragments_);
- }
-
- std::shared_ptr<Expression> ExpressionPtr(const Expression& e) { return e.Copy(); }
- std::shared_ptr<Expression> ExpressionPtr(std::shared_ptr<Expression> e) { return e; }
-
- template <typename... E>
- FragmentIterator MakeFragments(const E&... partition_expressions) {
- return MakeFragments(ExpressionVector{ExpressionPtr(partition_expressions)...});
- }
-
- template <typename... E>
- void MakeWritePlan(const E&... partition_expressions) {
- auto fragments = MakeFragments(partition_expressions...);
- EXPECT_OK_AND_ASSIGN(plan_,
- factory_->MakeWritePlan(schema({}), std::move(fragments)));
- }
-
- template <typename... E>
- Status MakeWritePlanError(const E&... partition_expressions) {
- auto fragments = MakeFragments(partition_expressions...);
- return factory_->MakeWritePlan(schema({}), std::move(fragments)).status();
- }
-
- template <typename... E>
- void MakeWritePlanWithSchema(const std::shared_ptr<Schema>& partition_schema,
- const E&... partition_expressions) {
- auto fragments = MakeFragments(partition_expressions...);
- EXPECT_OK_AND_ASSIGN(plan_, factory_->MakeWritePlan(schema({}), std::move(fragments),
- partition_schema));
- }
-
- template <typename... E>
- Status MakeWritePlanWithSchemaError(const std::shared_ptr<Schema>& partition_schema,
- const E&... partition_expressions) {
- auto fragments = MakeFragments(partition_expressions...);
- return factory_->MakeWritePlan(schema({}), std::move(fragments), partition_schema)
- .status();
- }
-
- struct ExpectedWritePlan {
- ExpectedWritePlan() = default;
-
- ExpectedWritePlan(const WritePlan& actual_plan, const FragmentVector& fragments) {
- int i = 0;
- for (const auto& op : actual_plan.fragment_or_partition_expressions) {
- if (op.kind() == WritePlan::FragmentOrPartitionExpression::FRAGMENT) {
- auto fragment = op.fragment();
- auto fragment_index =
- static_cast<int>(std::find(fragments.begin(), fragments.end(), fragment) -
- fragments.begin());
- auto path = fs::internal::GetAbstractPathParent(actual_plan.paths[i]).first;
- dirs_[path].fragments.push_back(fragment_index);
- } else {
- auto partition_expression = op.partition_expr();
- dirs_[actual_plan.paths[i]].partition_expression = partition_expression;
- }
- ++i;
- }
- }
-
- ExpectedWritePlan Dir(const std::string& path, const Expression& expr,
- const std::vector<int>& fragments) && {
- dirs_.emplace(path, DirectoryWriteOp{expr.Copy(), fragments});
- return std::move(*this);
- }
-
- struct DirectoryWriteOp {
- std::shared_ptr<Expression> partition_expression;
- std::vector<int> fragments;
-
- bool operator==(const DirectoryWriteOp& other) const {
- return partition_expression->Equals(other.partition_expression) &&
- fragments == other.fragments;
- }
-
- friend void PrintTo(const DirectoryWriteOp& op, std::ostream* os) {
- *os << op.partition_expression->ToString();
-
- *os << " { ";
- for (const auto& fragment : op.fragments) {
- *os << fragment << " ";
- }
- *os << "}\n";
- }
- };
- std::map<std::string, DirectoryWriteOp> dirs_;
- };
-
- struct AssertPlanIs : ExpectedWritePlan {};
-
- void AssertPlanIs(ExpectedWritePlan expected_plan) {
- ExpectedWritePlan actual_plan(plan_, fragments_);
- EXPECT_THAT(actual_plan.dirs_, testing::ContainerEq(expected_plan.dirs_));
- }
-
- FragmentVector fragments_;
- std::shared_ptr<ScanOptions> scan_options_ = ScanOptions::Make(schema({}));
- std::shared_ptr<PartitioningFactory> factory_;
- WritePlan plan_;
-};
-
-TEST_F(TestPartitioningWritePlan, Empty) {
- factory_ = DirectoryPartitioning::MakeFactory({"a", "b"});
-
- // no expressions from which to infer the types of fields a, b
- EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("No fragments"),
- MakeWritePlanError());
-
- MakeWritePlanWithSchema(schema({field("a", int32()), field("b", utf8())}));
- AssertPlanIs({});
-
- factory_ = HivePartitioning::MakeFactory();
- EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, testing::HasSubstr("hive"),
- MakeWritePlanError());
-}
-
-TEST_F(TestPartitioningWritePlan, SingleDirectory) {
- factory_ = DirectoryPartitioning::MakeFactory({"a"});
-
- MakeWritePlan("a"_ == 42, "a"_ == 99, "a"_ == 101);
- AssertPlanIs(ExpectedWritePlan()
- .Dir("42", "a"_ == 42, {0})
- .Dir("99", "a"_ == 99, {1})
- .Dir("101", "a"_ == 101, {2}));
-
- MakeWritePlan("a"_ == 42, "a"_ == 99, "a"_ == 99, "a"_ == 101, "a"_ == 99);
- AssertPlanIs(ExpectedWritePlan()
- .Dir("42", "a"_ == 42, {0})
- .Dir("99", "a"_ == 99, {1, 2, 4})
- .Dir("101", "a"_ == 101, {3}));
-}
-
-TEST_F(TestPartitioningWritePlan, NestedDirectories) {
- factory_ = DirectoryPartitioning::MakeFactory({"a", "b"});
-
- MakeWritePlan("a"_ == 42 and "b"_ == "hello", "a"_ == 42 and "b"_ == "world",
- "a"_ == 99 and "b"_ == "hello", "a"_ == 99 and "b"_ == "world");
-
- AssertPlanIs(ExpectedWritePlan()
- .Dir("42", "a"_ == 42, {})
- .Dir("42/hello", "b"_ == "hello", {0})
- .Dir("42/world", "b"_ == "world", {1})
- .Dir("99", "a"_ == 99, {})
- .Dir("99/hello", "b"_ == "hello", {2})
- .Dir("99/world", "b"_ == "world", {3}));
-}
-
-TEST_F(TestPartitioningWritePlan, Errors) {
- factory_ = DirectoryPartitioning::MakeFactory({"a"});
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, testing::HasSubstr("no partition expression for field 'a'"),
- MakeWritePlanError("a"_ == 42, scalar(true), "a"_ == 101));
-
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- TypeError, testing::HasSubstr("scalar hello (of type string) is invalid"),
- MakeWritePlanError("a"_ == 42, "a"_ == "hello"));
-
- factory_ = DirectoryPartitioning::MakeFactory({"a", "b"});
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, testing::HasSubstr("no partition expression for field 'a'"),
- MakeWritePlanError("a"_ == 42 and "b"_ == "hello", "a"_ == 99 and "b"_ == "world",
- "b"_ == "forever alone"));
-}
-
TEST(TestStripPrefixAndFilename, Basic) {
ASSERT_EQ(StripPrefixAndFilename("", ""), "");
ASSERT_EQ(StripPrefixAndFilename("a.csv", ""), "");
@@ -633,7 +468,7 @@ TEST(TestStripPrefixAndFilename, Basic) {
EXPECT_THAT(StripPrefixAndFilename(input, "/data"),
testing::ElementsAre("year=2019", "year=2019/month=12",
"year=2019/month=12/day=01"));
-} // namespace dataset
+}
} // namespace dataset
} // namespace arrow
diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h
index 5737b1e..94df944 100644
--- a/cpp/src/arrow/dataset/scanner_internal.h
+++ b/cpp/src/arrow/dataset/scanner_internal.h
@@ -112,5 +112,42 @@ inline ScanTaskIterator GetScanTaskIterator(FragmentIterator fragments,
return MakeFlattenIterator(std::move(maybe_scantask_it));
}
+struct FragmentRecordBatchReader : RecordBatchReader {  // exposes a Fragment's scan as a RecordBatchReader projected to a requested schema
+ public:
+ std::shared_ptr<Schema> schema() const override { return options_->schema(); }  // the schema passed to Make()
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+ return iterator_.Next().Value(batch);
+ }
+
+ static Result<std::shared_ptr<FragmentRecordBatchReader>> Make(
+ std::shared_ptr<Fragment> fragment, std::shared_ptr<Schema> schema,
+ std::shared_ptr<ScanContext> context) {
+ // build scan options for the requested schema; seed the projector's defaults from the fragment's partition keys
+ auto options = ScanOptions::Make(std::move(schema));
+ RETURN_NOT_OK(KeyValuePartitioning::SetDefaultValuesFromKeys(
+ *fragment->partition_expression(), &options->projector));
+
+ auto pool = context->pool;  // read before `context` is moved into Scan() below
+ ARROW_ASSIGN_OR_RAISE(auto scan_tasks, fragment->Scan(options, std::move(context)));
+
+ auto reader = std::make_shared<FragmentRecordBatchReader>();
+ reader->options_ = std::move(options);  // owns the projector referenced by iterator_
+ reader->fragment_ = std::move(fragment);  // never read; presumably keeps the fragment alive while iterator_ is consumed
+ reader->iterator_ = ProjectRecordBatch(  // execute each scan task, flatten their batches, project each one
+ MakeFlattenIterator(MakeMaybeMapIterator(
+ [](std::shared_ptr<ScanTask> task) { return task->Execute(); },
+ std::move(scan_tasks))),
+ &reader->options_->projector, pool);
+
+ return reader;
+ }
+
+ private:
+ std::shared_ptr<ScanOptions> options_;
+ std::shared_ptr<Fragment> fragment_;
+ RecordBatchIterator iterator_;
+};
+
} // namespace dataset
} // namespace arrow
diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h
index 8f22b39..0a686a1 100644
--- a/cpp/src/arrow/dataset/test_util.h
+++ b/cpp/src/arrow/dataset/test_util.h
@@ -300,8 +300,10 @@ struct MakeFileSystemDatasetMixin {
partitions.resize(n_fragments, scalar(true));
}
+ auto s = schema({});
+
MakeFileSystem(infos);
- auto format = std::make_shared<DummyFileFormat>();
+ auto format = std::make_shared<DummyFileFormat>(s);
std::vector<std::shared_ptr<FileFragment>> fragments;
for (size_t i = 0; i < n_fragments; i++) {
@@ -315,8 +317,8 @@ struct MakeFileSystemDatasetMixin {
fragments.push_back(std::move(fragment));
}
- ASSERT_OK_AND_ASSIGN(dataset_, FileSystemDataset::Make(schema({}), root_partition,
- format, std::move(fragments)));
+ ASSERT_OK_AND_ASSIGN(dataset_, FileSystemDataset::Make(s, root_partition, format,
+ std::move(fragments)));
}
void MakeDatasetFromPathlist(const std::string& pathlist,
diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h
index 44c4bbc..089e1df 100644
--- a/cpp/src/arrow/dataset/type_fwd.h
+++ b/cpp/src/arrow/dataset/type_fwd.h
@@ -83,8 +83,5 @@ using ScanTaskIterator = Iterator<std::shared_ptr<ScanTask>>;
class RecordBatchProjector;
-class WriteTask;
-class WritePlan;
-
} // namespace dataset
} // namespace arrow
diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc
index de56b68..56b2783 100644
--- a/cpp/src/arrow/record_batch.cc
+++ b/cpp/src/arrow/record_batch.cc
@@ -173,7 +173,7 @@ Result<std::shared_ptr<RecordBatch>> RecordBatch::FromStructArray(
array->data()->child_data);
}
-Result<std::shared_ptr<Array>> RecordBatch::ToStructArray() const {
+Result<std::shared_ptr<StructArray>> RecordBatch::ToStructArray() const {
if (num_columns() != 0) {
return StructArray::Make(columns(), schema()->fields());
}
diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h
index 0d1b1b1..63d0bd8 100644
--- a/cpp/src/arrow/record_batch.h
+++ b/cpp/src/arrow/record_batch.h
@@ -65,7 +65,7 @@ class ARROW_EXPORT RecordBatch {
/// Create a struct array whose child arrays are the record batch's columns.
/// Note that the record batch's top-level field metadata cannot be reflected
/// in the resulting struct array.
- Result<std::shared_ptr<Array>> ToStructArray() const;
+ Result<std::shared_ptr<StructArray>> ToStructArray() const;
/// \brief Construct record batch from struct array
///
@@ -210,7 +210,7 @@ class ARROW_EXPORT RecordBatchReader {
}
/// \brief Consume entire stream as a vector of record batches
- Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches);
+ Status ReadAll(RecordBatchVector* batches);
/// \brief Read all batches and concatenate as arrow::Table
Status ReadAll(std::shared_ptr<Table>* table);
@@ -221,8 +221,7 @@ class ARROW_EXPORT RecordBatchReader {
/// \param[in] schema schema to conform to. Will be inferred from the first
/// element if not provided.
static Result<std::shared_ptr<RecordBatchReader>> Make(
- std::vector<std::shared_ptr<RecordBatch>> batches,
- std::shared_ptr<Schema> schema = NULLPTR);
+ RecordBatchVector batches, std::shared_ptr<Schema> schema = NULLPTR);
};
} // namespace arrow
diff --git a/r/DESCRIPTION b/r/DESCRIPTION
index ff455b0..cd11e46 100644
--- a/r/DESCRIPTION
+++ b/r/DESCRIPTION
@@ -65,6 +65,7 @@ Collate:
'compute.R'
'config.R'
'csv.R'
+ 'dataset-write.R'
'dataset.R'
'deprecated.R'
'dictionary.R'
diff --git a/r/NAMESPACE b/r/NAMESPACE
index 4e5f130..0748591 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -8,6 +8,7 @@ S3method("==",ArrowObject)
S3method("[",Array)
S3method("[",ChunkedArray)
S3method("[",RecordBatch)
+S3method("[",Schema)
S3method("[",Table)
S3method("[[",RecordBatch)
S3method("[[",Schema)
@@ -239,6 +240,7 @@ export(uint8)
export(unify_schemas)
export(utf8)
export(write_arrow)
+export(write_dataset)
export(write_feather)
export(write_ipc_stream)
export(write_parquet)
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index a98a6cb..c013a71 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -452,6 +452,10 @@ dataset___ScanTask__get_batches <- function(scan_task){
.Call(`_arrow_dataset___ScanTask__get_batches` , scan_task)
}
+dataset___Dataset__Write <- function(ds, schema, format, filesystem, path, partitioning){
+ invisible(.Call(`_arrow_dataset___Dataset__Write` , ds, schema, format, filesystem, path, partitioning)) # called for its side effect; result returned invisibly
+}
+
shared_ptr_is_null <- function(xp){
.Call(`_arrow_shared_ptr_is_null` , xp)
}
diff --git a/r/R/dataset-write.R b/r/R/dataset-write.R
new file mode 100644
index 0000000..8baca1e
--- /dev/null
+++ b/r/R/dataset-write.R
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#' Write a dataset
+#'
+#' This function allows you to write a dataset. By writing to more efficient
+#' binary storage formats, and by specifying relevant partitioning, you can
+#' make it much faster to read and query.
+#'
+#' @param dataset [Dataset] or `arrow_dplyr_query`. If an `arrow_dplyr_query`,
+#' note that `select()` or `filter()` queries are not currently supported.
+#' @param path string path to a directory to write to (directory will be
+#' created if it does not exist)
+#' @param format file format to write the dataset to. Currently only "feather"
+#' (aka "ipc") is supported.
+#' @param partitioning `Partitioning` or a character vector of columns to
+#' use as partition keys (to be written as path segments). Default is to
+#' use the current `group_by()` columns.
+#' @param hive_style logical: write partition segments as Hive-style
+#' (`key1=value1/key2=value2/file.ext`) or as just bare values. Default is `TRUE`.
+#' @param ... additional arguments, passed to `FileFormat$create()` (when `format` is not already a `FileFormat`) and to `dataset$write()`
+#' @return The input `dataset` (or, for an `arrow_dplyr_query`, its underlying `Dataset`), invisibly
+#' @export
+write_dataset <- function(dataset,
+ path,
+ format = dataset$format$type,
+ partitioning = dplyr::group_vars(dataset),
+ hive_style = TRUE,
+ ...) {
+ if (inherits(dataset, "arrow_dplyr_query")) {
+ force(partitioning) # get the group_vars before we drop the object
+ # TODO: Write a filtered/projected dataset
+ if (!isTRUE(dataset$filtered_rows)) {
+ stop("Writing a filtered dataset is not yet supported", call. = FALSE)
+ }
+ if (!identical(dataset$selected_columns, set_names(names(dataset$.data)))) {
+ # TODO: support select()/rename() by projecting before writing
+ stop("Writing a projected dataset is not yet supported", call. = FALSE)
+ }
+ dataset <- dataset$.data
+ }
+ if (!inherits(dataset, "Dataset")) {
+ stop("'dataset' must be a Dataset", call. = FALSE)
+ # TODO: This does not exist yet (in the R bindings at least)
+ # dataset <- InMemoryDataset$create(dataset)
+ }
+
+ if (!inherits(format, "FileFormat")) {
+ format <- FileFormat$create(format, ...)
+ }
+ if (!inherits(format, "IpcFileFormat")) {
+ stop(
+ "Unsupported format; datasets currently can only be written to IPC/Feather format",
+ call. = FALSE
+ )
+ }
+
+ if (!inherits(partitioning, "Partitioning")) {
+ # TODO: tidyselect?
+ partition_schema <- dataset$schema[partitioning]
+ if (isTRUE(hive_style)) {
+ partitioning <- HivePartitioning$create(partition_schema)
+ } else {
+ partitioning <- DirectoryPartitioning$create(partition_schema)
+ }
+ }
+ dataset$write(path, format = format, partitioning = partitioning, ...) # returns the Dataset invisibly
+}
diff --git a/r/R/dataset.R b/r/R/dataset.R
index eb229d7..a50e297 100644
--- a/r/R/dataset.R
+++ b/r/R/dataset.R
@@ -133,6 +133,9 @@ open_dataset <- function(sources,
#' may also replace the dataset's schema by using `ds$schema <- new_schema`.
#' This method currently supports only adding, removing, or reordering
#' fields in the schema: you cannot alter or cast the field types.
+#' - `$write(path, filesystem, schema, format, partitioning)`: writes the
+#' dataset to `path` in the `format` file format, partitioned by `partitioning`,
+#' and invisibly returns `self`. See [write_dataset()].
#'
#' `FileSystemDataset` has the following methods:
#' - `$files`: Active binding, returns the files of the `FileSystemDataset`
@@ -159,7 +162,20 @@ Dataset <- R6Class("Dataset", inherit = ArrowObject,
# Start a new scan of the data
# @return A [ScannerBuilder]
NewScan = function() unique_ptr(ScannerBuilder, dataset___Dataset__NewScan(self)),
- ToString = function() self$schema$ToString()
+ ToString = function() self$schema$ToString(),
+ write = function(path, filesystem = NULL, schema = self$schema, format, partitioning) {
+ if (!inherits(filesystem, "FileSystem")) { # no FileSystem supplied: infer one from `path`
+ if (grepl("://", path)) { # `path` is a URI: split it into a filesystem and a path within it
+ fs_from_uri <- FileSystem$from_uri(path)
+ filesystem <- fs_from_uri$fs
+ path <- fs_from_uri$path
+ } else {
+ filesystem <- LocalFileSystem$create() # plain path: write to the local filesystem
+ }
+ }
+ dataset___Dataset__Write(self, schema, format, filesystem, path, partitioning)
+ invisible(self) # allows piping; see write_dataset()
+ }
),
active = list(
schema = function(schema) {
diff --git a/r/R/schema.R b/r/R/schema.R
index 963e5f4..ddbf30f 100644
--- a/r/R/schema.R
+++ b/r/R/schema.R
@@ -152,6 +152,30 @@ length.Schema <- function(x) x$num_fields
}
#' @export
+`[.Schema` <- function(x, i, ...) { # subset a Schema's fields by name, position, or logical mask
+ if (is.logical(i)) {
+ i <- rep_len(i, length(x)) # For R recycling behavior
+ i <- which(i)
+ }
+ if (is.numeric(i)) {
+ if (length(i) > 0 && all(i < 0)) { # all() is TRUE for an empty i; an empty selection must select nothing
+ # in R, negative i means "everything but i"
+ i <- setdiff(seq_len(length(x)), -1 * i)
+ }
+ }
+ fields <- map(i, ~x[[.]])
+ invalid <- map_lgl(fields, is.null)
+ if (any(invalid)) {
+ stop(
+ "Invalid field name", if (sum(invalid) > 1) "s: " else ": ",
+ oxford_paste(i[invalid]),
+ call. = FALSE
+ )
+ }
+ shared_ptr(Schema, schema_(fields))
+}
+
+#' @export
`$.Schema` <- function(x, name, ...) {
assert_that(is.string(name))
if (name %in% ls(x)) {
diff --git a/r/man/Dataset.Rd b/r/man/Dataset.Rd
index 686611c..1f8ce96 100644
--- a/r/man/Dataset.Rd
+++ b/r/man/Dataset.Rd
@@ -60,6 +60,9 @@ A \code{Dataset} has the following methods:
may also replace the dataset's schema by using \code{ds$schema <- new_schema}.
This method currently supports only adding, removing, or reordering
fields in the schema: you cannot alter or cast the field types.
+\item \verb{$write(path, filesystem, schema, format, partitioning)}: writes the
+dataset to \code{path} in the \code{format} file format, partitioned by \code{partitioning},
+and invisibly returns \code{self}. See \code{\link[=write_dataset]{write_dataset()}}.
}
\code{FileSystemDataset} has the following methods:
diff --git a/r/man/write_dataset.Rd b/r/man/write_dataset.Rd
new file mode 100644
index 0000000..54bd353
--- /dev/null
+++ b/r/man/write_dataset.Rd
@@ -0,0 +1,42 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/dataset-write.R
+\name{write_dataset}
+\alias{write_dataset}
+\title{Write a dataset}
+\usage{
+write_dataset(
+ dataset,
+ path,
+ format = dataset$format$type,
+ partitioning = dplyr::group_vars(dataset),
+ hive_style = TRUE,
+ ...
+)
+}
+\arguments{
+\item{dataset}{\link{Dataset} or \code{arrow_dplyr_query}. If an \code{arrow_dplyr_query},
+note that \code{select()} or \code{filter()} queries are not currently supported.}
+
+\item{path}{string path to a directory to write to (directory will be
+created if it does not exist)}
+
+\item{format}{file format to write the dataset to. Currently only "feather"
+(aka "ipc") is supported.}
+
+\item{partitioning}{\code{Partitioning} or a character vector of columns to
+use as partition keys (to be written as path segments). Default is to
+use the current \code{group_by()} columns.}
+
+\item{hive_style}{logical: write partition segments as Hive-style
+(\code{key1=value1/key2=value2/file.ext}) or as just bare values. Default is \code{TRUE}.}
+
+\item{...}{additional arguments, passed to \code{dataset$write()}}
+}
+\value{
+The input \code{dataset}, invisibly
+}
+\description{
+This function allows you to write a dataset. By writing to more efficient
+binary storage formats, and by specifying relevant partitioning, you can
+make it much faster to read and query.
+}
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index 9d0058b..4a80ed0 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -1774,6 +1774,27 @@ RcppExport SEXP _arrow_dataset___ScanTask__get_batches(SEXP scan_task_sexp){
}
#endif
+// dataset.cpp
+// R entry point for dataset___Dataset__Write(), which is defined in dataset.cpp.
+// NOTE(review): this file looks machine-generated from `[[arrow::export]]`
+// tags; hand edits here would likely be lost on regeneration. TODO: confirm
+// against the export codegen script.
+#if defined(ARROW_R_WITH_ARROW)
+void dataset___Dataset__Write(const std::shared_ptr<ds::Dataset>& ds, const std::shared_ptr<arrow::Schema>& schema, const std::shared_ptr<ds::FileFormat>& format, const std::shared_ptr<fs::FileSystem>& filesystem, std::string path, const std::shared_ptr<ds::Partitioning>& partitioning);
+// Converts each SEXP argument with Rcpp::traits::input_parameter and calls the
+// C++ implementation. The implementation returns void, so R receives NULL.
+RcppExport SEXP _arrow_dataset___Dataset__Write(SEXP ds_sexp, SEXP schema_sexp, SEXP format_sexp, SEXP filesystem_sexp, SEXP path_sexp, SEXP partitioning_sexp){
+BEGIN_RCPP
+ Rcpp::traits::input_parameter<const std::shared_ptr<ds::Dataset>&>::type ds(ds_sexp);
+ Rcpp::traits::input_parameter<const std::shared_ptr<arrow::Schema>&>::type schema(schema_sexp);
+ Rcpp::traits::input_parameter<const std::shared_ptr<ds::FileFormat>&>::type format(format_sexp);
+ Rcpp::traits::input_parameter<const std::shared_ptr<fs::FileSystem>&>::type filesystem(filesystem_sexp);
+ Rcpp::traits::input_parameter<std::string>::type path(path_sexp);
+ Rcpp::traits::input_parameter<const std::shared_ptr<ds::Partitioning>&>::type partitioning(partitioning_sexp);
+ dataset___Dataset__Write(ds, schema, format, filesystem, path, partitioning);
+ return R_NilValue;
+END_RCPP
+}
+#else
+// Fallback when built without the Arrow C++ library: the symbol still exists
+// (so registration works) but calling it raises an R error.
+RcppExport SEXP _arrow_dataset___Dataset__Write(SEXP ds_sexp, SEXP schema_sexp, SEXP format_sexp, SEXP filesystem_sexp, SEXP path_sexp, SEXP partitioning_sexp){
+ Rf_error("Cannot call dataset___Dataset__Write(). Please use arrow::install_arrow() to install required runtime libraries. ");
+}
+#endif
+
// datatype.cpp
#if defined(ARROW_R_WITH_ARROW)
bool shared_ptr_is_null(SEXP xp);
@@ -6023,6 +6044,7 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_dataset___Scanner__ToTable", (DL_FUNC) &_arrow_dataset___Scanner__ToTable, 1},
{ "_arrow_dataset___Scanner__Scan", (DL_FUNC) &_arrow_dataset___Scanner__Scan, 1},
{ "_arrow_dataset___ScanTask__get_batches", (DL_FUNC) &_arrow_dataset___ScanTask__get_batches, 1},
+ { "_arrow_dataset___Dataset__Write", (DL_FUNC) &_arrow_dataset___Dataset__Write, 6},
{ "_arrow_shared_ptr_is_null", (DL_FUNC) &_arrow_shared_ptr_is_null, 1},
{ "_arrow_unique_ptr_is_null", (DL_FUNC) &_arrow_unique_ptr_is_null, 1},
{ "_arrow_Int8__initialize", (DL_FUNC) &_arrow_Int8__initialize, 0},
diff --git a/r/src/dataset.cpp b/r/src/dataset.cpp
index ade66ef..aac809a 100644
--- a/r/src/dataset.cpp
+++ b/r/src/dataset.cpp
@@ -289,4 +289,19 @@ std::vector<std::shared_ptr<arrow::RecordBatch>> dataset___ScanTask__get_batches
return out;
}
+// Writes every fragment of `ds` under `path` on `filesystem` in the given
+// `format`, using `schema` as the written schema and partitioned by
+// `partitioning`, via ds::FileSystemDataset::Write. The ScanContext used for
+// reading the fragments has use_threads enabled. A non-OK Status is passed to
+// StopIfNotOk, which surfaces it to the R caller.
+// (Doc comments sit above the export tag so the tag stays directly attached
+// to the signature for the export code generator.)
+// [[arrow::export]]
+void dataset___Dataset__Write(const std::shared_ptr<ds::Dataset>& ds,
+                              const std::shared_ptr<arrow::Schema>& schema,
+                              const std::shared_ptr<ds::FileFormat>& format,
+                              const std::shared_ptr<fs::FileSystem>& filesystem,
+                              std::string path,
+                              const std::shared_ptr<ds::Partitioning>& partitioning) {
+  auto fragments = ds->GetFragments();
+  auto ctx = std::make_shared<ds::ScanContext>();
+  ctx->use_threads = true;
+  // `path` is a by-value sink parameter: move it instead of copying it again.
+  StopIfNotOk(ds::FileSystemDataset::Write(schema, format, filesystem, std::move(path),
+                                           partitioning, ctx, std::move(fragments)));
+}
+
#endif
diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R
index e78dfd3..1e93d41 100644
--- a/r/tests/testthat/test-dataset.R
+++ b/r/tests/testthat/test-dataset.R
@@ -626,3 +626,104 @@ test_that("Assembling multiple DatasetFactories with DatasetFactory", {
expect_scan_result(ds, schm)
})
+
+test_that("Writing a dataset: CSV->IPC", {
+  skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651
+  # Rewrite the CSV fixture as Feather, partitioned on `int`
+  src <- open_dataset(csv_dir, partitioning = "part", format = "csv")
+  dst_dir <- make_temp_dir()
+  write_dataset(src, dst_dir, format = "feather", partitioning = "int")
+
+  # Expect one Hive-style `int=<value>` directory per partition value
+  expected_dirs <- sort(paste("int", c(1:10, 101:110), sep = "="))
+  expect_true(dir.exists(dst_dir))
+  expect_identical(dir(dst_dir), expected_dirs)
+
+  # Query the rewritten copy and compare with the same summary of df1
+  written <- open_dataset(dst_dir, format = "feather")
+  actual <- written %>%
+    select(string = chr, integer = int) %>%
+    filter(integer > 6 & integer < 11) %>%
+    collect() %>%
+    summarize(mean = mean(integer))
+  expected <- df1 %>%
+    select(string = chr, integer = int) %>%
+    filter(integer > 6) %>%
+    summarize(mean = mean(integer))
+  expect_equivalent(actual, expected)
+})
+
+test_that("Writing a dataset: Parquet->IPC", {
+  skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651
+  # Rewrite the Hive-partitioned fixture as Feather, partitioned on `int`
+  src <- open_dataset(hive_dir)
+  dst_dir <- make_temp_dir()
+  write_dataset(src, dst_dir, format = "feather", partitioning = "int")
+
+  expected_dirs <- sort(paste("int", c(1:10, 101:110), sep = "="))
+  expect_true(dir.exists(dst_dir))
+  expect_identical(dir(dst_dir), expected_dirs)
+
+  # The `group` column must survive the rewrite and still be filterable
+  written <- open_dataset(dst_dir, format = "feather")
+  actual <- written %>%
+    select(string = chr, integer = int, group) %>%
+    filter(integer > 6 & group == 1) %>%
+    collect() %>%
+    summarize(mean = mean(integer))
+  expected <- df1 %>%
+    select(string = chr, integer = int) %>%
+    filter(integer > 6) %>%
+    summarize(mean = mean(integer))
+  expect_equivalent(actual, expected)
+})
+
+test_that("Dataset writing: dplyr methods", {
+  skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651
+  ds <- open_dataset(hive_dir)
+  dst_dir <- tempfile()
+  # Partition columns may come from group_by() instead of `partitioning`
+  grouped <- group_by(ds, int)
+  write_dataset(grouped, dst_dir, format = "feather")
+  expect_true(dir.exists(dst_dir))
+  expected_dirs <- sort(paste("int", c(1:10, 101:110), sep = "="))
+  expect_identical(dir(dst_dir), expected_dirs)
+
+  # select to specify schema
+  skip("TODO: select to specify schema")
+  selected <- select(group_by(ds, int), lgl, chr)
+  write_dataset(selected, dst_dir, format = "feather")
+  written <- open_dataset(dst_dir, format = "feather")
+
+  keep <- c("lgl", "chr", "int")
+  expect_equivalent(collect(written), rbind(df1[keep], df2[keep]))
+})
+
+test_that("Dataset writing: non-hive", {
+  skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651
+  src <- open_dataset(hive_dir)
+  dst_dir <- tempfile()
+  # hive_style = FALSE names each directory with the bare partition value
+  write_dataset(src, dst_dir, format = "feather", partitioning = "int", hive_style = FALSE)
+  expect_true(dir.exists(dst_dir))
+  bare_values <- as.character(c(1:10, 101:110))
+  expect_identical(dir(dst_dir), sort(bare_values))
+})
+
+test_that("Dataset writing: no partitioning", {
+  skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651
+  ds <- open_dataset(hive_dir)
+  dst_dir <- tempfile()
+  write_dataset(ds, dst_dir, format = "feather", partitioning = NULL)
+  expect_true(dir.exists(dst_dir))
+  # expect_gt() reports the actual file count on failure, unlike
+  # expect_true(length(...) > 1), which only reports "isn't true".
+  expect_gt(length(dir(dst_dir)), 1)
+})
+
+test_that("Dataset writing: unsupported features/input validation", {
+  # These calls are expected to error during input validation. A throwaway
+  # destination is still passed so that each call is well-formed and the
+  # test does not depend on `path` being lazily left unevaluated.
+  expect_error(write_dataset(4), "'dataset' must be a Dataset")
+
+  ds <- open_dataset(hive_dir)
+  dst_dir <- tempfile()
+
+  expect_error(write_dataset(ds, dst_dir, format = "csv"), "Unsupported format")
+  expect_error(
+    # previously `write_dataset(ds)`, which passed the Dataset itself as `path`
+    filter(ds, int == 4) %>% write_dataset(dst_dir),
+    "Writing a filtered dataset is not yet supported"
+  )
+
+  expect_error(
+    write_dataset(ds, dst_dir, partitioning = c("int", "NOTACOLUMN"), format = "ipc"),
+    'Invalid field name: "NOTACOLUMN"'
+  )
+})
diff --git a/r/tests/testthat/test-schema.R b/r/tests/testthat/test-schema.R
index 6671828..23b08da 100644
--- a/r/tests/testthat/test-schema.R
+++ b/r/tests/testthat/test-schema.R
@@ -45,9 +45,27 @@ test_that("Schema $GetFieldByName", {
expect_null(schm$GetFieldByName("f"))
# TODO: schema(b = double(), b = string())$GetFieldByName("b")
# also returns NULL and probably should error bc duplicated names
+})
+test_that("Schema extract (returns Field)", {
+ schm <- schema(b = double(), c = string())
+ # `$name`, `[[name]]` and `[[position]]` all return the same Field
expect_equal(schm$b, field("b", double()))
expect_equal(schm[["b"]], field("b", double()))
+ expect_equal(schm[[1]], field("b", double()))
+
+ # An unknown name yields NULL, while an out-of-range position errors
+ expect_null(schm[["ZZZ"]])
+ expect_error(schm[[42]]) # Should have better error message
+})
+
+test_that("Schema slicing", {
+  schm <- schema(b = double(), c = string(), d = int8())
+  tail_two <- schema(c = string(), d = int8())
+  # Positive, negative and logical selection of the last two fields
+  expect_equal(schm[2:3], tail_two)
+  expect_equal(schm[-1], tail_two)
+  # Selection by name follows the requested order
+  expect_equal(schm[c("d", "c")], schema(d = int8(), c = string()))
+  expect_equal(schm[c(FALSE, TRUE, TRUE)], tail_two)
+  # Unknown names are reported, pluralized when there are several
+  expect_error(schm[c("c", "ZZZ")], 'Invalid field name: "ZZZ"')
+  expect_error(schm[c("XXX", "c", "ZZZ")], 'Invalid field names: "XXX" and "ZZZ"')
})
test_that("reading schema from Buffer", {