You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2019/04/09 09:52:59 UTC
[arrow] branch master updated: ARROW-3200: [C++] Support dictionaries in Flight streams
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new d8e4763 ARROW-3200: [C++] Support dictionaries in Flight streams
d8e4763 is described below
commit d8e4763c4fafce9bb8fa9621d69a2ac4200186ea
Author: Antoine Pitrou <an...@python.org>
AuthorDate: Tue Apr 9 11:52:50 2019 +0200
ARROW-3200: [C++] Support dictionaries in Flight streams
Author: Antoine Pitrou <an...@python.org>
Closes #4113 from pitrou/ARROW-3200-flight-dicts and squashes the following commits:
efe0f5adf <Antoine Pitrou> Address review comments
f9f458df5 <Antoine Pitrou> ARROW-3200: Support dictionaries in Flight streams
---
ci/travis_script_python.sh | 4 +
cpp/src/arrow/CMakeLists.txt | 1 +
cpp/src/arrow/array-test.cc | 40 --
cpp/src/arrow/array-union-test.cc | 74 +++
cpp/src/arrow/flight/client.cc | 229 +++----
cpp/src/arrow/flight/client.h | 17 -
cpp/src/arrow/flight/customize_protobuf.h | 2 +
cpp/src/arrow/flight/flight-test.cc | 246 ++++---
cpp/src/arrow/flight/protocol-internal.h | 2 +
cpp/src/arrow/flight/serialization-internal.cc | 48 ++
cpp/src/arrow/flight/serialization-internal.h | 22 +-
cpp/src/arrow/flight/server.cc | 63 +-
cpp/src/arrow/flight/test-server.cc | 27 +-
cpp/src/arrow/flight/test-util.cc | 56 +-
cpp/src/arrow/flight/test-util.h | 29 +-
cpp/src/arrow/flight/types.cc | 42 ++
cpp/src/arrow/flight/types.h | 16 +
cpp/src/arrow/gpu/cuda-test.cc | 3 +-
cpp/src/arrow/ipc/dictionary.h | 2 +
cpp/src/arrow/ipc/feather-test.cc | 22 +-
cpp/src/arrow/ipc/json-internal.cc | 12 +-
cpp/src/arrow/ipc/json-test.cc | 3 +
cpp/src/arrow/ipc/metadata-internal.cc | 5 +-
cpp/src/arrow/ipc/metadata-internal.h | 2 +-
cpp/src/arrow/ipc/read-write-test.cc | 53 +-
cpp/src/arrow/ipc/reader.cc | 67 +-
cpp/src/arrow/ipc/reader.h | 5 +-
.../arrow/ipc/{test-common.h => test-common.cc} | 170 ++---
cpp/src/arrow/ipc/test-common.h | 713 ++-------------------
cpp/src/arrow/ipc/writer.cc | 394 +++++++-----
cpp/src/arrow/ipc/writer.h | 40 +-
cpp/src/arrow/python/arrow_to_pandas.cc | 7 +-
cpp/src/arrow/python/numpy_to_arrow.cc | 9 +-
cpp/src/gandiva/expression_registry.cc | 11 +-
cpp/src/gandiva/jni/expression_registry_helper.cc | 9 +-
35 files changed, 1055 insertions(+), 1390 deletions(-)
diff --git a/ci/travis_script_python.sh b/ci/travis_script_python.sh
index 710ebb9..5c95eb5 100755
--- a/ci/travis_script_python.sh
+++ b/ci/travis_script_python.sh
@@ -90,6 +90,10 @@ CMAKE_COMMON_FLAGS="-DARROW_EXTRA_ERROR_CONTEXT=ON"
PYTHON_CPP_BUILD_TARGETS="arrow_python-all plasma parquet"
+if [ "$ARROW_TRAVIS_FLIGHT" == "1" ]; then
+ CMAKE_COMMON_FLAGS="$CMAKE_COMMON_FLAGS -DARROW_FLIGHT=ON"
+fi
+
if [ "$ARROW_TRAVIS_COVERAGE" == "1" ]; then
CMAKE_COMMON_FLAGS="$CMAKE_COMMON_FLAGS -DARROW_GENERATE_COVERAGE=ON"
fi
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 6854f14..d1f4852 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -274,6 +274,7 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS)
# that depend on gtest
add_arrow_lib(arrow_testing
SOURCES
+ ipc/test-common.cc
testing/gtest_util.cc
testing/random.cc
OUTPUTS
diff --git a/cpp/src/arrow/array-test.cc b/cpp/src/arrow/array-test.cc
index ad90e46..6826c72 100644
--- a/cpp/src/arrow/array-test.cc
+++ b/cpp/src/arrow/array-test.cc
@@ -34,7 +34,6 @@
#include "arrow/buffer-builder.h"
#include "arrow/buffer.h"
#include "arrow/builder.h"
-#include "arrow/ipc/test-common.h"
#include "arrow/memory_pool.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
@@ -1690,45 +1689,6 @@ TEST_F(TestAdaptiveUIntBuilder, TestAppendNulls) {
}
}
-// ----------------------------------------------------------------------
-// Union tests
-
-TEST(TestUnionArrayAdHoc, TestSliceEquals) {
- std::shared_ptr<RecordBatch> batch;
- ASSERT_OK(ipc::MakeUnion(&batch));
-
- const int64_t size = batch->num_rows();
-
- auto CheckUnion = [&size](std::shared_ptr<Array> array) {
- std::shared_ptr<Array> slice, slice2;
- slice = array->Slice(2);
- ASSERT_EQ(size - 2, slice->length());
-
- slice2 = array->Slice(2);
- ASSERT_EQ(size - 2, slice->length());
-
- ASSERT_TRUE(slice->Equals(slice2));
- ASSERT_TRUE(array->RangeEquals(2, array->length(), 0, slice));
-
- // Chained slices
- slice2 = array->Slice(1)->Slice(1);
- ASSERT_TRUE(slice->Equals(slice2));
-
- slice = array->Slice(1, 5);
- slice2 = array->Slice(1, 5);
- ASSERT_EQ(5, slice->length());
-
- ASSERT_TRUE(slice->Equals(slice2));
- ASSERT_TRUE(array->RangeEquals(1, 6, 0, slice));
-
- AssertZeroPadded(*array);
- TestInitialized(*array);
- };
-
- CheckUnion(batch->column(1));
- CheckUnion(batch->column(2));
-}
-
using DecimalVector = std::vector<Decimal128>;
class DecimalTest : public ::testing::TestWithParam<int> {
diff --git a/cpp/src/arrow/array-union-test.cc b/cpp/src/arrow/array-union-test.cc
new file mode 100644
index 0000000..067d195
--- /dev/null
+++ b/cpp/src/arrow/array-union-test.cc
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/builder.h"
+#include "arrow/status.h"
+// TODO ipc shouldn't be included here
+#include "arrow/ipc/test-common.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+TEST(TestUnionArrayAdHoc, TestSliceEquals) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeUnion(&batch));
+
+ auto CheckUnion = [](std::shared_ptr<Array> array) {
+ const int64_t size = array->length();
+ std::shared_ptr<Array> slice, slice2;
+ slice = array->Slice(2);
+ ASSERT_EQ(size - 2, slice->length());
+
+ slice2 = array->Slice(2);
+ ASSERT_EQ(size - 2, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(2, array->length(), 0, slice));
+
+ // Chained slices
+ slice2 = array->Slice(1)->Slice(1);
+ ASSERT_TRUE(slice->Equals(slice2));
+
+ slice = array->Slice(1, 5);
+ slice2 = array->Slice(1, 5);
+ ASSERT_EQ(5, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(1, 6, 0, slice));
+
+ AssertZeroPadded(*array);
+ TestInitialized(*array);
+ };
+
+ CheckUnion(batch->column(1));
+ CheckUnion(batch->column(2));
+}
+
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 154d34f..28b237d 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -16,7 +16,6 @@
// under the License.
#include "arrow/flight/client.h"
-#include "arrow/flight/protocol-internal.h" // IWYU pragma: keep
#include <memory>
#include <sstream>
@@ -30,10 +29,10 @@
#include <grpc++/grpc++.h>
#endif
-#include "arrow/ipc/dictionary.h"
-#include "arrow/ipc/metadata-internal.h"
+#include "arrow/buffer.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
+#include "arrow/memory_pool.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
#include "arrow/type.h"
@@ -58,86 +57,90 @@ struct ClientRpc {
/// XXX workaround until we have a handshake in Connect
context.set_wait_for_ready(true);
}
+
+ Status IOError(const std::string& error_message) {
+ std::stringstream ss;
+ ss << error_message << context.debug_error_string();
+ return Status::IOError(ss.str());
+ }
};
-class FlightStreamReader : public RecordBatchReader {
+class FlightIpcMessageReader : public ipc::MessageReader {
public:
- FlightStreamReader(std::unique_ptr<ClientRpc> rpc,
- const std::shared_ptr<Schema>& schema,
- std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream)
- : rpc_(std::move(rpc)),
- stream_finished_(false),
- schema_(schema),
- stream_(std::move(stream)) {}
-
- std::shared_ptr<Schema> schema() const override { return schema_; }
-
- Status ReadNext(std::shared_ptr<RecordBatch>* out) override {
- internal::FlightData data;
+ FlightIpcMessageReader(std::unique_ptr<ClientRpc> rpc,
+ std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream)
+ : rpc_(std::move(rpc)), stream_(std::move(stream)), stream_finished_(false) {}
+ Status ReadNextMessage(std::unique_ptr<ipc::Message>* out) override {
if (stream_finished_) {
*out = nullptr;
return Status::OK();
}
-
- // Pretend to be pb::FlightData and intercept in SerializationTraits
- if (stream_->Read(reinterpret_cast<pb::FlightData*>(&data))) {
- std::unique_ptr<ipc::Message> message;
-
- // Validate IPC message
- RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message));
- if (message->type() == ipc::Message::Type::RECORD_BATCH) {
- return ipc::ReadRecordBatch(*message, schema_, out);
- } else if (message->type() == ipc::Message::Type::SCHEMA) {
- return Status(StatusCode::Invalid, "Flight stream changed schema midway");
- } else {
- return Status(StatusCode::Invalid, "Unrecognized message in Flight stream");
- }
- } else {
+ internal::FlightData data;
+ if (!internal::ReadPayload(stream_.get(), &data)) {
// Stream is completed
stream_finished_ = true;
*out = nullptr;
- return internal::FromGrpcStatus(stream_->Finish());
+ return OverrideWithServerError(Status::OK());
}
+ // Validate IPC message
+ auto st = data.OpenMessage(out);
+ if (!st.ok()) {
+ return OverrideWithServerError(std::move(st));
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status OverrideWithServerError(Status&& st) {
+ // Get the gRPC status if not OK, to propagate any server error message
+ RETURN_NOT_OK(internal::FromGrpcStatus(stream_->Finish()));
+ return st;
}
- private:
// The RPC context lifetime must be coupled to the ClientReader
std::unique_ptr<ClientRpc> rpc_;
-
- bool stream_finished_;
- std::shared_ptr<Schema> schema_;
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream_;
+ bool stream_finished_;
};
-/// \brief A RecordBatchWriter implementation that writes to a Flight
-/// DoPut stream.
-class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter {
+/// An IpcPayloadWriter implementation that writes to a DoPut stream
+class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
public:
- explicit FlightPutWriterImpl(std::unique_ptr<ClientRpc> rpc,
- const FlightDescriptor& descriptor,
- const std::shared_ptr<Schema>& schema,
- MemoryPool* pool = default_memory_pool())
- : rpc_(std::move(rpc)), descriptor_(descriptor), schema_(schema), pool_(pool) {}
+ DoPutPayloadWriter(const FlightDescriptor& descriptor, std::unique_ptr<ClientRpc> rpc,
+ std::unique_ptr<protocol::PutResult> response,
+ std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer)
+ : descriptor_(descriptor),
+ rpc_(std::move(rpc)),
+ response_(std::move(response)),
+ writer_(std::move(writer)),
+ first_payload_(true) {}
+
+ ~DoPutPayloadWriter() override = default;
- Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override {
+ Status Start() override { return Status::OK(); }
+
+ Status WritePayload(const ipc::internal::IpcPayload& ipc_payload) override {
FlightPayload payload;
- RETURN_NOT_OK(
- ipc::internal::GetRecordBatchPayload(batch, pool_, &payload.ipc_message));
+ payload.ipc_message = ipc_payload;
+
+ if (first_payload_) {
+ // First Flight message needs to encode the Flight descriptor
+ DCHECK_EQ(ipc_payload.type, ipc::Message::SCHEMA);
+ std::string str_descr;
+ {
+ pb::FlightDescriptor pb_descr;
+ RETURN_NOT_OK(internal::ToProto(descriptor_, &pb_descr));
+ if (!pb_descr.SerializeToString(&str_descr)) {
+ return Status::UnknownError("Failed to serialized Flight descriptor");
+ }
+ }
+ RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor));
+ first_payload_ = false;
+ }
-#ifndef _WIN32
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif
- if (!writer_->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
- grpc::WriteOptions())) {
-#ifndef _WIN32
-#pragma GCC diagnostic pop
-#endif
- std::stringstream ss;
- ss << "Could not write record batch to stream: "
- << rpc_->context.debug_error_string();
- return Status::IOError(ss.str());
+ if (!internal::WritePayload(payload, writer_.get())) {
+ return rpc_->IOError("Could not write record batch to stream: ");
}
return Status::OK();
}
@@ -152,43 +155,15 @@ class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter {
return Status::OK();
}
- void set_memory_pool(MemoryPool* pool) override { pool_ = pool; }
-
- private:
- /// \brief Set the gRPC writer backing this Flight stream.
- /// \param [in] writer the gRPC writer
- void set_stream(std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer) {
- writer_ = std::move(writer);
- }
-
+ protected:
// TODO: there isn't a way to access this as a user.
- protocol::PutResult response;
+ const FlightDescriptor descriptor_;
std::unique_ptr<ClientRpc> rpc_;
- FlightDescriptor descriptor_;
- std::shared_ptr<Schema> schema_;
+ std::unique_ptr<protocol::PutResult> response_;
std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer_;
- MemoryPool* pool_;
-
- // We need to reference some fields
- friend class FlightClient;
+ bool first_payload_;
};
-FlightPutWriter::~FlightPutWriter() {}
-
-FlightPutWriter::FlightPutWriter(std::unique_ptr<FlightPutWriterImpl> impl) {
- impl_ = std::move(impl);
-}
-
-Status FlightPutWriter::WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) {
- return impl_->WriteRecordBatch(batch, allow_64bit);
-}
-
-Status FlightPutWriter::Close() { return impl_->Close(); }
-
-void FlightPutWriter::set_memory_pool(MemoryPool* pool) {
- return impl_->set_memory_pool(pool);
-}
-
class FlightClient::FlightClientImpl {
public:
Status Connect(const std::string& host, int port) {
@@ -218,13 +193,13 @@ class FlightClient::FlightClientImpl {
std::vector<FlightInfo> flights;
pb::FlightGetInfo pb_info;
- FlightInfo::Data info_data;
while (stream->Read(&pb_info)) {
+ FlightInfo::Data info_data;
RETURN_NOT_OK(internal::FromProto(pb_info, &info_data));
- flights.emplace_back(FlightInfo(std::move(info_data)));
+ flights.emplace_back(std::move(info_data));
}
- listing->reset(new SimpleFlightListing(flights));
+ listing->reset(new SimpleFlightListing(std::move(flights)));
return internal::FromGrpcStatus(stream->Finish());
}
@@ -292,65 +267,23 @@ class FlightClient::FlightClientImpl {
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream(
stub_->DoGet(&rpc->context, pb_ticket));
- // First message must be the schema
- std::shared_ptr<Schema> schema;
- internal::FlightData data;
- if (!stream->Read(reinterpret_cast<pb::FlightData*>(&data))) {
- // Get the gRPC status if not OK, to get any server error
- // messages
- RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish()));
- return Status(StatusCode::Invalid, "No data in Flight stream");
- }
- std::unique_ptr<ipc::Message> message;
- RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message));
- if (message->type() != ipc::Message::Type::SCHEMA) {
- return Status(StatusCode::Invalid, "Flight stream did not start with schema");
- }
- RETURN_NOT_OK(ipc::ReadSchema(*message, &schema));
-
- *out = std::unique_ptr<RecordBatchReader>(
- new FlightStreamReader(std::move(rpc), schema, std::move(stream)));
- return Status::OK();
+ std::unique_ptr<ipc::MessageReader> message_reader(
+ new FlightIpcMessageReader(std::move(rpc), std::move(stream)));
+ return ipc::RecordBatchStreamReader::Open(std::move(message_reader), out);
}
Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
- std::unique_ptr<ipc::RecordBatchWriter>* stream) {
+ std::unique_ptr<ipc::RecordBatchWriter>* out) {
std::unique_ptr<ClientRpc> rpc(new ClientRpc);
- std::unique_ptr<FlightPutWriter::FlightPutWriterImpl> out(
- new FlightPutWriter::FlightPutWriterImpl(std::move(rpc), descriptor, schema));
- std::unique_ptr<grpc::ClientWriter<pb::FlightData>> write_stream(
- stub_->DoPut(&out->rpc_->context, &out->response));
+ std::unique_ptr<protocol::PutResult> response(new protocol::PutResult);
+ std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer(
+ stub_->DoPut(&rpc->context, response.get()));
- // First write the descriptor and schema to the stream.
- FlightPayload payload;
- ipc::DictionaryMemo dictionary_memo;
- RETURN_NOT_OK(ipc::internal::GetSchemaPayload(*schema, out->pool_, &dictionary_memo,
- &payload.ipc_message));
- pb::FlightDescriptor pb_descr;
- RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descr));
- std::string str_descr;
- pb_descr.SerializeToString(&str_descr);
- RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor));
-
-#ifndef _WIN32
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif
- if (!write_stream->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
- grpc::WriteOptions())) {
-#ifndef _WIN32
-#pragma GCC diagnostic pop
-#endif
- std::stringstream ss;
- ss << "Could not write descriptor and schema to stream: "
- << rpc->context.debug_error_string();
- return Status::IOError(ss.str());
- }
+ std::unique_ptr<ipc::internal::IpcPayloadWriter> payload_writer(
+ new DoPutPayloadWriter(descriptor, std::move(rpc), std::move(response),
+ std::move(writer)));
- out->set_stream(std::move(write_stream));
- *stream =
- std::unique_ptr<ipc::RecordBatchWriter>(new FlightPutWriter(std::move(out)));
- return Status::OK();
+ return ipc::internal::OpenRecordBatchWriter(std::move(payload_writer), schema, out);
}
private:
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 6277c15..3603908 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -109,22 +109,5 @@ class ARROW_EXPORT FlightClient {
std::unique_ptr<FlightClientImpl> impl_;
};
-/// \brief An interface to upload record batches to a Flight server
-class ARROW_EXPORT FlightPutWriter : public ipc::RecordBatchWriter {
- public:
- ~FlightPutWriter() override;
-
- Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override;
- Status Close() override;
- void set_memory_pool(MemoryPool* pool) override;
-
- private:
- class FlightPutWriterImpl;
- explicit FlightPutWriter(std::unique_ptr<FlightPutWriterImpl> impl);
- std::unique_ptr<FlightPutWriterImpl> impl_;
-
- friend class FlightClient;
-};
-
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/customize_protobuf.h b/cpp/src/arrow/flight/customize_protobuf.h
index 58c168c..1d67480 100644
--- a/cpp/src/arrow/flight/customize_protobuf.h
+++ b/cpp/src/arrow/flight/customize_protobuf.h
@@ -100,6 +100,8 @@ template <class T>
class SerializationTraits<T, typename std::enable_if<std::is_same<
arrow::flight::protocol::FlightData, T>::value>::type> {
public:
+ // In the functions below, we cast back the Message argument to its real
+ // type (see ReadPayload() and WritePayload() for the initial cast).
static Status Serialize(const grpc::protobuf::Message& msg, ByteBuffer* bb,
bool* own_buffer) {
return arrow::flight::internal::FlightDataSerialize(
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index b099146..cf67e29 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -47,6 +47,90 @@ namespace pb = arrow::flight::protocol;
namespace arrow {
namespace flight {
+void AssertEqual(const ActionType& expected, const ActionType& actual) {
+ ASSERT_EQ(expected.type, actual.type);
+ ASSERT_EQ(expected.description, actual.description);
+}
+
+void AssertEqual(const FlightDescriptor& expected, const FlightDescriptor& actual) {
+ ASSERT_TRUE(expected.Equals(actual));
+}
+
+void AssertEqual(const Ticket& expected, const Ticket& actual) {
+ ASSERT_EQ(expected.ticket, actual.ticket);
+}
+
+void AssertEqual(const Location& expected, const Location& actual) {
+ ASSERT_EQ(expected.host, actual.host);
+ ASSERT_EQ(expected.port, actual.port);
+}
+
+void AssertEqual(const std::vector<FlightEndpoint>& expected,
+ const std::vector<FlightEndpoint>& actual) {
+ ASSERT_EQ(expected.size(), actual.size());
+ for (size_t i = 0; i < expected.size(); ++i) {
+ AssertEqual(expected[i].ticket, actual[i].ticket);
+
+ ASSERT_EQ(expected[i].locations.size(), actual[i].locations.size());
+ for (size_t j = 0; j < expected[i].locations.size(); ++j) {
+ AssertEqual(expected[i].locations[j], actual[i].locations[j]);
+ }
+ }
+}
+
+template <typename T>
+void AssertEqual(const std::vector<T>& expected, const std::vector<T>& actual) {
+ ASSERT_EQ(expected.size(), actual.size());
+ for (size_t i = 0; i < expected.size(); ++i) {
+ AssertEqual(expected[i], actual[i]);
+ }
+}
+
+void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) {
+ std::shared_ptr<Schema> ex_schema, actual_schema;
+ ASSERT_OK(expected.GetSchema(&ex_schema));
+ ASSERT_OK(actual.GetSchema(&actual_schema));
+
+ AssertSchemaEqual(*ex_schema, *actual_schema);
+ ASSERT_EQ(expected.total_records(), actual.total_records());
+ ASSERT_EQ(expected.total_bytes(), actual.total_bytes());
+
+ AssertEqual(expected.descriptor(), actual.descriptor());
+ AssertEqual(expected.endpoints(), actual.endpoints());
+}
+
+TEST(TestFlightDescriptor, Basics) {
+ auto a = FlightDescriptor::Command("select * from table");
+ auto b = FlightDescriptor::Command("select * from table");
+ auto c = FlightDescriptor::Command("select foo from table");
+ auto d = FlightDescriptor::Path({"foo", "bar"});
+ auto e = FlightDescriptor::Path({"foo", "baz"});
+ auto f = FlightDescriptor::Path({"foo", "baz"});
+
+ ASSERT_EQ(a.ToString(), "FlightDescriptor<cmd = 'select * from table'>");
+ ASSERT_EQ(d.ToString(), "FlightDescriptor<path = 'foo/bar'>");
+ ASSERT_TRUE(a.Equals(b));
+ ASSERT_FALSE(a.Equals(c));
+ ASSERT_FALSE(a.Equals(d));
+ ASSERT_FALSE(d.Equals(e));
+ ASSERT_TRUE(e.Equals(f));
+}
+
+TEST(TestFlightDescriptor, ToFromProto) {
+ FlightDescriptor descr_test;
+ pb::FlightDescriptor pb_descr;
+
+ FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}};
+ ASSERT_OK(internal::ToProto(descr1, &pb_descr));
+ ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
+ AssertEqual(descr1, descr_test);
+
+ FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}};
+ ASSERT_OK(internal::ToProto(descr2, &pb_descr));
+ ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
+ AssertEqual(descr2, descr_test);
+}
+
TEST(TestFlight, StartStopTestServer) {
TestServer server("flight-test-server", 30000);
server.Start();
@@ -85,72 +169,52 @@ class TestFlightClient : public ::testing::Test {
Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); }
+ template <typename EndpointCheckFunc>
+ void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches,
+ EndpointCheckFunc&& check_endpoints) {
+ auto num_batches = static_cast<int>(expected_batches.size());
+ DCHECK_GE(num_batches, 2);
+ auto expected_schema = expected_batches[0]->schema();
+
+ std::unique_ptr<FlightInfo> info;
+ ASSERT_OK(client_->GetFlightInfo(descr, &info));
+ check_endpoints(info->endpoints());
+
+ std::shared_ptr<Schema> schema;
+ ASSERT_OK(info->GetSchema(&schema));
+ AssertSchemaEqual(*expected_schema, *schema);
+
+ // By convention, fetch the first endpoint
+ Ticket ticket = info->endpoints()[0].ticket;
+ std::unique_ptr<RecordBatchReader> stream;
+ ASSERT_OK(client_->DoGet(ticket, &stream));
+
+ std::shared_ptr<RecordBatch> chunk;
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(stream->ReadNext(&chunk));
+ ASSERT_NE(nullptr, chunk);
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk);
+ }
+
+ // Stream exhausted
+ ASSERT_OK(stream->ReadNext(&chunk));
+ ASSERT_EQ(nullptr, chunk);
+ }
+
protected:
int port_;
std::unique_ptr<FlightClient> client_;
std::unique_ptr<TestServer> server_;
};
-// The server implementation is in test-server.cc; to make changes to the
-// expected results, make edits there
-void AssertEqual(const FlightDescriptor& expected, const FlightDescriptor& actual) {}
-
-void AssertEqual(const Ticket& expected, const Ticket& actual) {
- ASSERT_EQ(expected.ticket, actual.ticket);
-}
-
-void AssertEqual(const Location& expected, const Location& actual) {
- ASSERT_EQ(expected.host, actual.host);
- ASSERT_EQ(expected.port, actual.port);
-}
-
-void AssertEqual(const std::vector<FlightEndpoint>& expected,
- const std::vector<FlightEndpoint>& actual) {
- ASSERT_EQ(expected.size(), actual.size());
- for (size_t i = 0; i < expected.size(); ++i) {
- AssertEqual(expected[i].ticket, actual[i].ticket);
-
- ASSERT_EQ(expected[i].locations.size(), actual[i].locations.size());
- for (size_t j = 0; j < expected[i].locations.size(); ++j) {
- AssertEqual(expected[i].locations[j], actual[i].locations[j]);
- }
- }
-}
-
-void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) {
- std::shared_ptr<Schema> ex_schema, actual_schema;
- ASSERT_OK(expected.GetSchema(&ex_schema));
- ASSERT_OK(actual.GetSchema(&actual_schema));
-
- AssertSchemaEqual(*ex_schema, *actual_schema);
- ASSERT_EQ(expected.total_records(), actual.total_records());
- ASSERT_EQ(expected.total_bytes(), actual.total_bytes());
-
- AssertEqual(expected.descriptor(), actual.descriptor());
- AssertEqual(expected.endpoints(), actual.endpoints());
-}
-
-void AssertEqual(const ActionType& expected, const ActionType& actual) {
- ASSERT_EQ(expected.type, actual.type);
- ASSERT_EQ(expected.description, actual.description);
-}
-
-template <typename T>
-void AssertEqual(const std::vector<T>& expected, const std::vector<T>& actual) {
- ASSERT_EQ(expected.size(), actual.size());
- for (size_t i = 0; i < expected.size(); ++i) {
- AssertEqual(expected[i], actual[i]);
- }
-}
-
TEST_F(TestFlightClient, ListFlights) {
std::unique_ptr<FlightListing> listing;
ASSERT_OK(client_->ListFlights(&listing));
ASSERT_TRUE(listing != nullptr);
std::vector<FlightInfo> flights = ExampleFlightInfo();
- std::unique_ptr<FlightInfo> info;
+ std::unique_ptr<FlightInfo> info;
for (const FlightInfo& flight : flights) {
ASSERT_OK(listing->Next(&info));
AssertEqual(flight, *info);
@@ -159,66 +223,56 @@ TEST_F(TestFlightClient, ListFlights) {
ASSERT_TRUE(info == nullptr);
ASSERT_OK(listing->Next(&info));
+ ASSERT_TRUE(info == nullptr);
}
TEST_F(TestFlightClient, GetFlightInfo) {
- FlightDescriptor descr{FlightDescriptor::PATH, "", {"foo", "bar"}};
+ auto descr = FlightDescriptor::Path({"examples", "ints"});
std::unique_ptr<FlightInfo> info;
- ASSERT_OK(client_->GetFlightInfo(descr, &info));
- ASSERT_TRUE(info != nullptr);
+ ASSERT_OK(client_->GetFlightInfo(descr, &info));
+ ASSERT_NE(info, nullptr);
std::vector<FlightInfo> flights = ExampleFlightInfo();
AssertEqual(flights[0], *info);
}
-TEST(TestFlightProtocol, FlightDescriptor) {
- FlightDescriptor descr_test;
- pb::FlightDescriptor pb_descr;
-
- FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}};
- ASSERT_OK(internal::ToProto(descr1, &pb_descr));
- ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
- AssertEqual(descr1, descr_test);
-
- FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}};
- ASSERT_OK(internal::ToProto(descr2, &pb_descr));
- ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
- AssertEqual(descr2, descr_test);
-}
-
-TEST_F(TestFlightClient, DoGet) {
- FlightDescriptor descr{FlightDescriptor::PATH, "", {"foo", "bar"}};
+TEST_F(TestFlightClient, GetFlightInfoNotFound) {
+ auto descr = FlightDescriptor::Path({"examples", "things"});
std::unique_ptr<FlightInfo> info;
- ASSERT_OK(client_->GetFlightInfo(descr, &info));
-
- // Two endpoints in the example FlightInfo
- ASSERT_EQ(2, info->endpoints().size());
-
- Ticket ticket = info->endpoints()[0].ticket;
- AssertEqual(Ticket{"ticket-id-1"}, ticket);
+ // XXX Ideally should be Invalid (or KeyError), but gRPC doesn't support
+ // multiple error codes.
+ auto st = client_->GetFlightInfo(descr, &info);
+ ASSERT_RAISES(IOError, st);
+ ASSERT_NE(st.message().find("Flight not found"), std::string::npos);
+}
- std::shared_ptr<Schema> schema;
- ASSERT_OK(info->GetSchema(&schema));
+TEST_F(TestFlightClient, DoGetInts) {
+ auto descr = FlightDescriptor::Path({"examples", "ints"});
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
- auto expected_schema = ExampleSchema1();
- AssertSchemaEqual(*expected_schema, *schema);
+ auto check_endpoints = [](const std::vector<FlightEndpoint>& endpoints) {
+ // Two endpoints in the example FlightInfo
+ ASSERT_EQ(2, endpoints.size());
+ AssertEqual(Ticket{"ticket-ints-1"}, endpoints[0].ticket);
+ };
- std::unique_ptr<RecordBatchReader> stream;
- ASSERT_OK(client_->DoGet(ticket, &stream));
+ CheckDoGet(descr, expected_batches, check_endpoints);
+}
+TEST_F(TestFlightClient, DoGetDicts) {
+ auto descr = FlightDescriptor::Path({"examples", "dicts"});
BatchVector expected_batches;
- const int num_batches = 5;
- ASSERT_OK(SimpleIntegerBatches(num_batches, &expected_batches));
- std::shared_ptr<RecordBatch> chunk;
- for (int i = 0; i < num_batches; ++i) {
- ASSERT_OK(stream->ReadNext(&chunk));
- ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk);
- }
+ ASSERT_OK(ExampleDictBatches(&expected_batches));
+
+ auto check_endpoints = [](const std::vector<FlightEndpoint>& endpoints) {
+ // One endpoint in the example FlightInfo
+ ASSERT_EQ(1, endpoints.size());
+ AssertEqual(Ticket{"ticket-dicts-1"}, endpoints[0].ticket);
+ };
- // Stream exhausted
- ASSERT_OK(stream->ReadNext(&chunk));
- ASSERT_EQ(nullptr, chunk);
+ CheckDoGet(descr, expected_batches, check_endpoints);
}
TEST_F(TestFlightClient, ListActions) {
diff --git a/cpp/src/arrow/flight/protocol-internal.h b/cpp/src/arrow/flight/protocol-internal.h
index 2e8dd32..848c1a8 100644
--- a/cpp/src/arrow/flight/protocol-internal.h
+++ b/cpp/src/arrow/flight/protocol-internal.h
@@ -16,6 +16,8 @@
#pragma once
+// This header holds the Flight protobuf definitions.
+
// Need to include this first to get our gRPC customizations
#include "arrow/flight/customize_protobuf.h" // IWYU pragma: export
diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc
index b7e566c..b7a227a 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/serialization-internal.cc
@@ -24,6 +24,7 @@
#include "arrow/util/config.h"
+#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/wire_format_lite.h>
#include <grpc/byte_buffer_reader.h>
@@ -307,6 +308,53 @@ grpc::Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) {
return grpc::Status::OK;
}
+Status FlightData::OpenMessage(std::unique_ptr<ipc::Message>* message) {
+ return ipc::Message::Open(metadata, body, message);
+}
+
+// The pointer bitcast hack below causes legitimate warnings, silence them.
+#ifndef _WIN32
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#endif
+
+// Pointer bitcast explanation: grpc::*Writer<T>::Write() and grpc::*Reader<T>::Read()
+// both take a T* argument (here pb::FlightData*). But they don't do anything
+// with that argument except pass it to SerializationTraits<T>::Serialize() and
+// SerializationTraits<T>::Deserialize().
+//
+// Since we control SerializationTraits<pb::FlightData>, we can interpret the
+// pointer argument whichever way we want, including casting it back to the original type
+// (see customize_protobuf.h).
+
+bool WritePayload(const FlightPayload& payload,
+ grpc::ClientWriter<pb::FlightData>* writer) {
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ return writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
+ grpc::WriteOptions());
+}
+
+bool WritePayload(const FlightPayload& payload,
+ grpc::ServerWriter<pb::FlightData>* writer) {
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ return writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
+ grpc::WriteOptions());
+}
+
+bool ReadPayload(grpc::ClientReader<pb::FlightData>* reader, FlightData* data) {
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ return reader->Read(reinterpret_cast<pb::FlightData*>(data));
+}
+
+bool ReadPayload(grpc::ServerReader<pb::FlightData>* reader, FlightData* data) {
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ return reader->Read(reinterpret_cast<pb::FlightData*>(data));
+}
+
+#ifndef _WIN32
+#pragma GCC diagnostic pop
+#endif
+
} // namespace internal
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h
index 4576290..aa47af6 100644
--- a/cpp/src/arrow/flight/serialization-internal.h
+++ b/cpp/src/arrow/flight/serialization-internal.h
@@ -20,15 +20,12 @@
#pragma once
-// Enable gRPC customizations
-#include "arrow/flight/protocol-internal.h" // IWYU pragma: keep
-
#include <memory>
-#include <google/protobuf/io/coded_stream.h>
-
#include "arrow/flight/internal.h"
#include "arrow/flight/types.h"
+#include "arrow/ipc/message.h"
+#include "arrow/status.h"
namespace arrow {
@@ -48,8 +45,23 @@ struct FlightData {
/// Message body
std::shared_ptr<Buffer> body;
+
+ /// Open IPC message from the metadata and body
+ Status OpenMessage(std::unique_ptr<ipc::Message>* message);
};
+/// Write Flight message on gRPC stream with zero-copy optimizations.
+/// True is returned on success, false if some error occurred (connection closed?).
+bool WritePayload(const FlightPayload& payload,
+ grpc::ClientWriter<pb::FlightData>* writer);
+bool WritePayload(const FlightPayload& payload,
+ grpc::ServerWriter<pb::FlightData>* writer);
+
+/// Read Flight message from gRPC stream with zero-copy optimizations.
+/// True is returned on success, false if stream ended.
+bool ReadPayload(grpc::ClientReader<pb::FlightData>* reader, FlightData* data);
+bool ReadPayload(grpc::ServerReader<pb::FlightData>* reader, FlightData* data);
+
} // namespace internal
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 5a8dc7e..29de44a 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -16,7 +16,6 @@
// under the License.
#include "arrow/flight/server.h"
-#include "arrow/flight/protocol-internal.h"
#include <signal.h>
#include <atomic>
@@ -32,7 +31,7 @@
#include <grpc++/grpc++.h>
#endif
-#include "arrow/ipc/dictionary.h"
+#include "arrow/buffer.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
#include "arrow/memory_pool.h"
@@ -81,19 +80,11 @@ class FlightMessageReaderImpl : public FlightMessageReader {
}
internal::FlightData data;
- // Pretend to be pb::FlightData and intercept in SerializationTraits
-#ifndef _WIN32
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif
- if (reader_->Read(reinterpret_cast<pb::FlightData*>(&data))) {
-#ifndef _WIN32
-#pragma GCC diagnostic pop
-#endif
+ if (internal::ReadPayload(reader_, &data)) {
std::unique_ptr<ipc::Message> message;
// Validate IPC message
- RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message));
+ RETURN_NOT_OK(data.OpenMessage(&message));
if (message->type() == ipc::Message::Type::RECORD_BATCH) {
return ipc::ReadRecordBatch(*message, schema_, out);
} else {
@@ -126,9 +117,9 @@ class FlightServiceImpl : public FlightService::Service {
return grpc::Status(grpc::StatusCode::INTERNAL, "No items to iterate");
}
// Write flight info to stream until listing is exhausted
- ProtoType pb_value;
- std::unique_ptr<UserType> value;
while (true) {
+ ProtoType pb_value;
+ std::unique_ptr<UserType> value;
GRPC_RETURN_NOT_OK(iterator->Next(&value));
if (!value) {
break;
@@ -148,8 +139,8 @@ class FlightServiceImpl : public FlightService::Service {
grpc::Status WriteStream(const std::vector<UserType>& values,
ServerWriter<ProtoType>* writer) {
// Write flight info to stream until listing is exhausted
- ProtoType pb_value;
for (const UserType& value : values) {
+ ProtoType pb_value;
GRPC_RETURN_NOT_OK(internal::ToProto(value, &pb_value));
// Blocking write
if (!writer->Write(pb_value)) {
@@ -210,36 +201,34 @@ class FlightServiceImpl : public FlightService::Service {
return grpc::Status(grpc::StatusCode::NOT_FOUND, "No data in this flight");
}
- // Write the schema as the first message in the stream
- FlightPayload schema_payload;
+ // Write the schema as the first message(s) in the stream
+ // (several messages may be required if there are dictionaries)
MemoryPool* pool = default_memory_pool();
- ipc::DictionaryMemo dictionary_memo;
- GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload(
- *data_stream->schema(), pool, &dictionary_memo, &schema_payload.ipc_message));
-
- // Pretend to be pb::FlightData, we cast back to FlightPayload in
- // SerializationTraits
-#ifndef _WIN32
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif
- writer->Write(*reinterpret_cast<const pb::FlightData*>(&schema_payload),
- grpc::WriteOptions());
+ std::vector<ipc::internal::IpcPayload> ipc_payloads;
+ GRPC_RETURN_NOT_OK(
+ ipc::internal::GetSchemaPayloads(*data_stream->schema(), pool, &ipc_payloads));
+
+ for (auto& ipc_payload : ipc_payloads) {
+ // For DoGet, descriptor doesn't need to be written out
+ FlightPayload schema_payload;
+ schema_payload.ipc_message = std::move(ipc_payload);
+
+ if (!internal::WritePayload(schema_payload, writer)) {
+ // Connection terminated? XXX return error code?
+ return grpc::Status::OK;
+ }
+ }
+ // Write incoming data as individual messages
while (true) {
FlightPayload payload;
GRPC_RETURN_NOT_OK(data_stream->Next(&payload));
if (payload.ipc_message.metadata == nullptr ||
- !writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
- grpc::WriteOptions())) {
+ !internal::WritePayload(payload, writer))
// No more messages to write, or connection terminated for some other
// reason
break;
- }
}
-#ifndef _WIN32
-#pragma GCC diagnostic pop
-#endif
return grpc::Status::OK;
}
@@ -247,10 +236,10 @@ class FlightServiceImpl : public FlightService::Service {
pb::PutResult* response) {
// Get metadata
internal::FlightData data;
- if (reader->Read(reinterpret_cast<pb::FlightData*>(&data))) {
+ if (internal::ReadPayload(reader, &data)) {
// Message only lives as long as data
std::unique_ptr<ipc::Message> message;
- GRPC_RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message));
+ GRPC_RETURN_NOT_OK(data.OpenMessage(&message));
if (!message || message->type() != ipc::Message::Type::SCHEMA) {
return internal::ToGrpcStatus(
diff --git a/cpp/src/arrow/flight/test-server.cc b/cpp/src/arrow/flight/test-server.cc
index 316d89f..a7049db 100644
--- a/cpp/src/arrow/flight/test-server.cc
+++ b/cpp/src/arrow/flight/test-server.cc
@@ -38,9 +38,14 @@ namespace arrow {
namespace flight {
Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader>* out) {
- if (ticket.ticket == "ticket-id-1") {
+ if (ticket.ticket == "ticket-ints-1") {
BatchVector batches;
- RETURN_NOT_OK(SimpleIntegerBatches(5, &batches));
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+ return Status::OK();
+ } else if (ticket.ticket == "ticket-dicts-1") {
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleDictBatches(&batches));
*out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
return Status::OK();
} else {
@@ -57,20 +62,16 @@ class FlightTestServer : public FlightServerBase {
}
Status GetFlightInfo(const FlightDescriptor& request,
- std::unique_ptr<FlightInfo>* info) override {
+ std::unique_ptr<FlightInfo>* out) override {
std::vector<FlightInfo> flights = ExampleFlightInfo();
- const FlightInfo* value;
-
- // We only have one kind of flight for each descriptor type
- if (request.type == FlightDescriptor::PATH) {
- value = &flights[0];
- } else {
- value = &flights[1];
+ for (const auto& info : flights) {
+ if (info.descriptor().Equals(request)) {
+ *out = std::unique_ptr<FlightInfo>(new FlightInfo(info));
+ return Status::OK();
+ }
}
-
- *info = std::unique_ptr<FlightInfo>(new FlightInfo(*value));
- return Status::OK();
+ return Status::Invalid("Flight not found: ", request.ToString());
}
Status DoGet(const Ticket& request,
diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc
index 7ce8ef5..71ab30c 100644
--- a/cpp/src/arrow/flight/test-util.cc
+++ b/cpp/src/arrow/flight/test-util.cc
@@ -128,28 +128,62 @@ Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
return internal::SchemaToString(schema, &out->schema);
}
+std::shared_ptr<Schema> ExampleIntSchema() {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", int32());
+ return ::arrow::schema({f0, f1});
+}
+
+std::shared_ptr<Schema> ExampleStringSchema() {
+ auto f0 = field("f0", utf8());
+ auto f1 = field("f1", binary());
+ return ::arrow::schema({f0, f1});
+}
+
+std::shared_ptr<Schema> ExampleDictSchema() {
+ std::shared_ptr<RecordBatch> batch;
+ ABORT_NOT_OK(ipc::test::MakeDictionary(&batch));
+ return batch->schema();
+}
+
std::vector<FlightInfo> ExampleFlightInfo() {
- FlightEndpoint endpoint1({{"ticket-id-1"}, {{"foo1.bar.com", 92385}}});
- FlightEndpoint endpoint2({{"ticket-id-2"}, {{"foo2.bar.com", 92385}}});
- FlightEndpoint endpoint3({{"ticket-id-3"}, {{"foo3.bar.com", 92385}}});
- FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}};
+ FlightInfo::Data flight1, flight2, flight3;
+
+ FlightEndpoint endpoint1({{"ticket-ints-1"}, {{"foo1.bar.com", 92385}}});
+ FlightEndpoint endpoint2({{"ticket-ints-2"}, {{"foo2.bar.com", 92385}}});
+ FlightEndpoint endpoint3({{"ticket-cmd"}, {{"foo3.bar.com", 92385}}});
+ FlightEndpoint endpoint4({{"ticket-dicts-1"}, {{"foo4.bar.com", 92385}}});
+
+ FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}};
FlightDescriptor descr2{FlightDescriptor::CMD, "my_command", {}};
+ FlightDescriptor descr3{FlightDescriptor::PATH, "", {"examples", "dicts"}};
- auto schema1 = ExampleSchema1();
- auto schema2 = ExampleSchema2();
+ auto schema1 = ExampleIntSchema();
+ auto schema2 = ExampleStringSchema();
+ auto schema3 = ExampleDictSchema();
- FlightInfo::Data flight1, flight2;
ARROW_EXPECT_OK(
MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, 100000, &flight1));
ARROW_EXPECT_OK(MakeFlightInfo(*schema2, descr2, {endpoint3}, 1000, 100000, &flight2));
- return {FlightInfo(flight1), FlightInfo(flight2)};
+ ARROW_EXPECT_OK(MakeFlightInfo(*schema3, descr3, {endpoint4}, -1, -1, &flight3));
+ return {FlightInfo(flight1), FlightInfo(flight2), FlightInfo(flight3)};
}
-Status SimpleIntegerBatches(const int num_batches, BatchVector* out) {
+Status ExampleIntBatches(BatchVector* out) {
std::shared_ptr<RecordBatch> batch;
- for (int i = 0; i < num_batches; ++i) {
+ for (int i = 0; i < 5; ++i) {
// Make all different sizes, use different random seed
- RETURN_NOT_OK(ipc::MakeIntBatchSized(10 + i, &batch, i));
+ RETURN_NOT_OK(ipc::test::MakeIntBatchSized(10 + i, &batch, i));
+ out->push_back(batch);
+ }
+ return Status::OK();
+}
+
+Status ExampleDictBatches(BatchVector* out) {
+ // Just the same batch, repeated a few times
+ std::shared_ptr<RecordBatch> batch;
+ for (int i = 0; i < 3; ++i) {
+ RETURN_NOT_OK(ipc::test::MakeDictionary(&batch));
out->push_back(batch);
}
return Status::OK();
diff --git a/cpp/src/arrow/flight/test-util.h b/cpp/src/arrow/flight/test-util.h
index 006c966..0c41ec1 100644
--- a/cpp/src/arrow/flight/test-util.h
+++ b/cpp/src/arrow/flight/test-util.h
@@ -88,31 +88,28 @@ class BatchIterator : public RecordBatchReader {
using BatchVector = std::vector<std::shared_ptr<RecordBatch>>;
-inline std::shared_ptr<Schema> ExampleSchema1() {
- auto f0 = field("f0", int32());
- auto f1 = field("f1", int32());
- return ::arrow::schema({f0, f1});
-}
-
-inline std::shared_ptr<Schema> ExampleSchema2() {
- auto f0 = field("f0", utf8());
- auto f1 = field("f1", binary());
- return ::arrow::schema({f0, f1});
-}
+ARROW_EXPORT std::shared_ptr<Schema> ExampleIntSchema();
+
+ARROW_EXPORT std::shared_ptr<Schema> ExampleStringSchema();
+
+ARROW_EXPORT std::shared_ptr<Schema> ExampleDictSchema();
ARROW_EXPORT
-Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
- const std::vector<FlightEndpoint>& endpoints, int64_t total_records,
- int64_t total_bytes, FlightInfo::Data* out);
+Status ExampleIntBatches(BatchVector* out);
ARROW_EXPORT
-std::vector<FlightInfo> ExampleFlightInfo();
+Status ExampleDictBatches(BatchVector* out);
ARROW_EXPORT
-Status SimpleIntegerBatches(const int num_batches, BatchVector* out);
+std::vector<FlightInfo> ExampleFlightInfo();
ARROW_EXPORT
std::vector<ActionType> ExampleActionTypes();
+ARROW_EXPORT
+Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
+ const std::vector<FlightEndpoint>& endpoints, int64_t total_records,
+ int64_t total_bytes, FlightInfo::Data* out);
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index fb8f8c6..3625bc5 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -18,6 +18,7 @@
#include "arrow/flight/types.h"
#include <memory>
+#include <sstream>
#include <utility>
#include "arrow/io/memory.h"
@@ -27,6 +28,47 @@
namespace arrow {
namespace flight {
+bool FlightDescriptor::Equals(const FlightDescriptor& other) const {
+ if (type != other.type) {
+ return false;
+ }
+ switch (type) {
+ case PATH:
+ return path == other.path;
+ case CMD:
+ return cmd == other.cmd;
+ default:
+ return false;
+ }
+}
+
+std::string FlightDescriptor::ToString() const {
+ std::stringstream ss;
+ ss << "FlightDescriptor<";
+ switch (type) {
+ case PATH: {
+ bool first = true;
+ ss << "path = '";
+ for (const auto& p : path) {
+ if (!first) {
+ ss << "/";
+ }
+ first = false;
+ ss << p;
+ }
+ ss << "'";
+ break;
+ }
+ case CMD:
+ ss << "cmd = '" << cmd << "'";
+ break;
+ default:
+ break;
+ }
+ ss << ">";
+ return ss.str();
+}
+
Status FlightInfo::GetSchema(std::shared_ptr<Schema>* out) const {
if (reconstructed_schema_) {
*out = schema_;
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index ba0ab85..0c09766 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -87,6 +87,20 @@ struct FlightDescriptor {
/// List of strings identifying a particular dataset. Should only be defined
/// when type is PATH
std::vector<std::string> path;
+
+ bool Equals(const FlightDescriptor& other) const;
+
+ std::string ToString() const;
+
+ // Convenience factory functions
+
+ static FlightDescriptor Command(const std::string& c) {
+ return FlightDescriptor{CMD, c, {}};
+ }
+
+ static FlightDescriptor Path(const std::vector<std::string>& p) {
+ return FlightDescriptor{PATH, "", p};
+ }
};
/// \brief Data structure providing an opaque identifier or credential to use
@@ -114,6 +128,8 @@ struct FlightEndpoint {
};
/// \brief Staging data structure for messages about to be put on the wire
+///
+/// This structure corresponds to FlightData in the protocol.
struct FlightPayload {
std::shared_ptr<Buffer> descriptor;
ipc::internal::IpcPayload ipc_message;
diff --git a/cpp/src/arrow/gpu/cuda-test.cc b/cpp/src/arrow/gpu/cuda-test.cc
index 51366e1..9a10b27 100644
--- a/cpp/src/arrow/gpu/cuda-test.cc
+++ b/cpp/src/arrow/gpu/cuda-test.cc
@@ -25,6 +25,7 @@
#include "arrow/ipc/test-common.h"
#include "arrow/status.h"
#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
#include "arrow/gpu/cuda_api.h"
@@ -320,7 +321,7 @@ class TestCudaArrowIpc : public TestCudaBufferBase {
TEST_F(TestCudaArrowIpc, BasicWriteRead) {
std::shared_ptr<RecordBatch> batch;
- ASSERT_OK(ipc::MakeIntRecordBatch(&batch));
+ ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch));
std::shared_ptr<CudaBuffer> device_serialized;
ASSERT_OK(SerializeRecordBatch(*batch, context_.get(), &device_serialized));
diff --git a/cpp/src/arrow/ipc/dictionary.h b/cpp/src/arrow/ipc/dictionary.h
index 4494b13..69ea485 100644
--- a/cpp/src/arrow/ipc/dictionary.h
+++ b/cpp/src/arrow/ipc/dictionary.h
@@ -42,6 +42,8 @@ using DictionaryTypeMap = std::unordered_map<int64_t, std::shared_ptr<Field>>;
class ARROW_EXPORT DictionaryMemo {
public:
DictionaryMemo();
+ DictionaryMemo(DictionaryMemo&&) = default;
+ DictionaryMemo& operator=(DictionaryMemo&&) = default;
/// \brief Returns KeyError if dictionary not found
Status GetDictionary(int64_t id, std::shared_ptr<Array>* dictionary) const;
diff --git a/cpp/src/arrow/ipc/feather-test.cc b/cpp/src/arrow/ipc/feather-test.cc
index e7b699d..001e36a 100644
--- a/cpp/src/arrow/ipc/feather-test.cc
+++ b/cpp/src/arrow/ipc/feather-test.cc
@@ -304,9 +304,9 @@ class TestTableReader : public ::testing::Test {
TEST_F(TestTableReader, ReadIndices) {
std::shared_ptr<RecordBatch> batch1;
- ASSERT_OK(MakeIntRecordBatch(&batch1));
+ ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch1));
std::shared_ptr<RecordBatch> batch2;
- ASSERT_OK(MakeIntRecordBatch(&batch2));
+ ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch2));
ASSERT_OK(writer_->Append("f0", *batch1->column(0)));
ASSERT_OK(writer_->Append("f1", *batch1->column(1)));
@@ -329,9 +329,9 @@ TEST_F(TestTableReader, ReadIndices) {
TEST_F(TestTableReader, ReadNames) {
std::shared_ptr<RecordBatch> batch1;
- ASSERT_OK(MakeIntRecordBatch(&batch1));
+ ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch1));
std::shared_ptr<RecordBatch> batch2;
- ASSERT_OK(MakeIntRecordBatch(&batch2));
+ ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch2));
ASSERT_OK(writer_->Append("f0", *batch1->column(0)));
ASSERT_OK(writer_->Append("f1", *batch1->column(1)));
@@ -419,7 +419,7 @@ TEST_F(TestTableWriter, SetDescription) {
TEST_F(TestTableWriter, PrimitiveRoundTrip) {
std::shared_ptr<RecordBatch> batch;
- ASSERT_OK(MakeIntRecordBatch(&batch));
+ ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch));
ASSERT_OK(writer_->Append("f0", *batch->column(0)));
ASSERT_OK(writer_->Append("f1", *batch->column(1)));
@@ -437,7 +437,7 @@ TEST_F(TestTableWriter, PrimitiveRoundTrip) {
TEST_F(TestTableWriter, CategoryRoundtrip) {
std::shared_ptr<RecordBatch> batch;
- ASSERT_OK(MakeDictionaryFlat(&batch));
+ ASSERT_OK(ipc::test::MakeDictionaryFlat(&batch));
CheckBatch(batch);
}
@@ -489,13 +489,13 @@ TEST_F(TestTableWriter, TimeTypes) {
TEST_F(TestTableWriter, VLenPrimitiveRoundTrip) {
std::shared_ptr<RecordBatch> batch;
- ASSERT_OK(MakeStringTypesRecordBatch(&batch));
+ ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&batch));
CheckBatch(batch);
}
TEST_F(TestTableWriter, PrimitiveNullRoundTrip) {
std::shared_ptr<RecordBatch> batch;
- ASSERT_OK(MakeNullRecordBatch(&batch));
+ ASSERT_OK(ipc::test::MakeNullRecordBatch(&batch));
for (int i = 0; i < batch->num_columns(); ++i) {
ASSERT_OK(writer_->Append(batch->column_name(i), *batch->column(i)));
@@ -540,7 +540,7 @@ class TestTableWriterSlice : public TestTableWriter,
TEST_P(TestTableWriterSlice, SliceRoundTrip) {
std::shared_ptr<RecordBatch> batch;
- ASSERT_OK(MakeIntBatchSized(600, &batch));
+ ASSERT_OK(ipc::test::MakeIntBatchSized(600, &batch));
CheckSlice(batch);
}
@@ -549,13 +549,13 @@ TEST_P(TestTableWriterSlice, SliceStringsRoundTrip) {
auto start = std::get<0>(p);
auto with_nulls = start % 2 == 0;
std::shared_ptr<RecordBatch> batch;
- ASSERT_OK(MakeStringTypesRecordBatch(&batch, with_nulls));
+ ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&batch, with_nulls));
CheckSlice(batch);
}
TEST_P(TestTableWriterSlice, SliceBooleanRoundTrip) {
std::shared_ptr<RecordBatch> batch;
- ASSERT_OK(MakeBooleanBatchSized(600, &batch));
+ ASSERT_OK(ipc::test::MakeBooleanBatchSized(600, &batch));
CheckSlice(batch);
}
diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc
index 0420c13..0bc0f20 100644
--- a/cpp/src/arrow/ipc/json-internal.cc
+++ b/cpp/src/arrow/ipc/json-internal.cc
@@ -343,9 +343,8 @@ class SchemaWriter {
return VisitType(*type.dictionary()->type());
}
- Status Visit(const ExtensionType& type) {
- return Status::NotImplemented("extension type");
- }
+ // Default case
+ Status Visit(const DataType& type) { return Status::NotImplemented(type.name()); }
private:
DictionaryMemo dictionary_memo_;
@@ -1210,10 +1209,6 @@ class ArrayReader {
return Status::OK();
}
- Status Visit(const ExtensionType& type) {
- return Status::NotImplemented("extension type");
- }
-
Status Visit(const DictionaryType& type) {
// This stores the indices in result_
//
@@ -1226,6 +1221,9 @@ class ArrayReader {
return Status::OK();
}
+ // Default case
+ Status Visit(const DataType& type) { return Status::NotImplemented(type.name()); }
+
Status GetChildren(const RjObject& obj, const DataType& type,
std::vector<std::shared_ptr<Array>>* array) {
const auto& json_children = obj.FindMember("children");
diff --git a/cpp/src/arrow/ipc/json-test.cc b/cpp/src/arrow/ipc/json-test.cc
index 72504d4..f6198e3 100644
--- a/cpp/src/arrow/ipc/json-test.cc
+++ b/cpp/src/arrow/ipc/json-test.cc
@@ -34,6 +34,7 @@
#include "arrow/record_batch.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
@@ -42,6 +43,8 @@ namespace ipc {
namespace internal {
namespace json {
+using namespace ::arrow::ipc::test; // NOLINT
+
void TestSchemaRoundTrip(const Schema& schema) {
rj::StringBuffer sb;
rj::Writer<rj::StringBuffer> writer(sb);
diff --git a/cpp/src/arrow/ipc/metadata-internal.cc b/cpp/src/arrow/ipc/metadata-internal.cc
index 2589a10..dedeee3 100644
--- a/cpp/src/arrow/ipc/metadata-internal.cc
+++ b/cpp/src/arrow/ipc/metadata-internal.cc
@@ -974,11 +974,12 @@ FileBlocksToFlatbuffer(FBB& fbb, const std::vector<FileBlock>& blocks) {
Status WriteFileFooter(const Schema& schema, const std::vector<FileBlock>& dictionaries,
const std::vector<FileBlock>& record_batches,
- DictionaryMemo* dictionary_memo, io::OutputStream* out) {
+ io::OutputStream* out) {
FBB fbb;
flatbuffers::Offset<flatbuf::Schema> fb_schema;
- RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, dictionary_memo, &fb_schema));
+ DictionaryMemo dictionary_memo; // unused
+ RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, &dictionary_memo, &fb_schema));
#ifndef NDEBUG
for (size_t i = 0; i < dictionaries.size(); ++i) {
diff --git a/cpp/src/arrow/ipc/metadata-internal.h b/cpp/src/arrow/ipc/metadata-internal.h
index 6562382..c91983d 100644
--- a/cpp/src/arrow/ipc/metadata-internal.h
+++ b/cpp/src/arrow/ipc/metadata-internal.h
@@ -151,7 +151,7 @@ Status WriteSparseTensorMessage(const SparseTensor& sparse_tensor, int64_t body_
Status WriteFileFooter(const Schema& schema, const std::vector<FileBlock>& dictionaries,
const std::vector<FileBlock>& record_batches,
- DictionaryMemo* dictionary_memo, io::OutputStream* out);
+ io::OutputStream* out);
Status WriteDictionaryMessage(const int64_t id, const int64_t length,
const int64_t body_length,
diff --git a/cpp/src/arrow/ipc/read-write-test.cc b/cpp/src/arrow/ipc/read-write-test.cc
index 6f4da28..0408a17 100644
--- a/cpp/src/arrow/ipc/read-write-test.cc
+++ b/cpp/src/arrow/ipc/read-write-test.cc
@@ -51,24 +51,10 @@ namespace arrow {
using internal::checked_cast;
namespace ipc {
+namespace test {
using BatchVector = std::vector<std::shared_ptr<RecordBatch>>;
-class TestSchemaMetadata : public ::testing::Test {
- public:
- void SetUp() {}
-
- void CheckRoundtrip(const Schema& schema) {
- std::shared_ptr<Buffer> buffer;
- ASSERT_OK(SerializeSchema(schema, default_memory_pool(), &buffer));
-
- std::shared_ptr<Schema> result;
- io::BufferReader reader(buffer);
- ASSERT_OK(ReadSchema(&reader, &result));
- AssertSchemaEqual(schema, *result);
- }
-};
-
TEST(TestMessage, Equals) {
std::string metadata = "foo";
std::string body = "bar";
@@ -147,6 +133,21 @@ TEST(TestMessage, Verify) {
ASSERT_FALSE(message.Verify());
}
+class TestSchemaMetadata : public ::testing::Test {
+ public:
+ void SetUp() {}
+
+ void CheckRoundtrip(const Schema& schema) {
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_OK(SerializeSchema(schema, default_memory_pool(), &buffer));
+
+ std::shared_ptr<Schema> result;
+ io::BufferReader reader(buffer);
+ ASSERT_OK(ReadSchema(&reader, &result));
+ AssertSchemaEqual(schema, *result);
+ }
+};
+
const std::shared_ptr<DataType> INT32 = std::make_shared<Int32Type>();
TEST_F(TestSchemaMetadata, PrimitiveFields) {
@@ -178,6 +179,25 @@ TEST_F(TestSchemaMetadata, NestedFields) {
CheckRoundtrip(schema);
}
+TEST_F(TestSchemaMetadata, DictionaryFields) {
+ {
+ auto dict_type =
+ dictionary(int8(), ArrayFromJSON(int32(), "[6, 5, 4]"), true /* ordered */);
+ auto f0 = field("f0", dict_type);
+ auto f1 = field("f1", list(dict_type));
+
+ Schema schema({f0, f1});
+ CheckRoundtrip(schema);
+ }
+ {
+ auto dict_type = dictionary(int8(), ArrayFromJSON(list(int32()), "[[4, 5], [6]]"));
+ auto f0 = field("f0", dict_type);
+
+ Schema schema({f0});
+ CheckRoundtrip(schema);
+ }
+}
+
TEST_F(TestSchemaMetadata, KeyValueMetadata) {
auto field_metadata = key_value_metadata({{"key", "value"}});
auto schema_metadata = key_value_metadata({{"foo", "bar"}, {"bizz", "buzz"}});
@@ -388,7 +408,7 @@ TEST_F(TestWriteRecordBatch, SliceTruncatesBuffers) {
// String / Binary
{
- auto s = MakeRandomBinaryArray<StringBuilder, char>(500, false, pool, &a0);
+ auto s = MakeRandomStringArray(500, false, pool, &a0);
ASSERT_TRUE(s.ok());
}
CheckArray(a0);
@@ -993,5 +1013,6 @@ TEST(TestRecordBatchStreamReader, MalformedInput) {
ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader, &batch_reader));
}
+} // namespace test
} // namespace ipc
} // namespace arrow
diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc
index a33f07c..85c6400 100644
--- a/cpp/src/arrow/ipc/reader.cc
+++ b/cpp/src/arrow/ipc/reader.cc
@@ -56,6 +56,38 @@ namespace ipc {
using internal::FileBlock;
using internal::kArrowMagicBytes;
+namespace {
+
+Status InvalidMessageType(Message::Type expected, Message::Type actual) {
+ return Status::IOError("Expected IPC message of type ", FormatMessageType(expected),
+ " got ", FormatMessageType(actual));
+}
+
+#define CHECK_MESSAGE_TYPE(expected, actual) \
+ do { \
+ if ((actual) != (expected)) { \
+ return InvalidMessageType((expected), (actual)); \
+ } \
+ } while (0)
+
+#define CHECK_HAS_BODY(message) \
+ do { \
+ if ((message).body() == nullptr) { \
+ return Status::IOError("Expected body in IPC message of type ", \
+ FormatMessageType((message).type())); \
+ } \
+ } while (0)
+
+#define CHECK_HAS_NO_BODY(message) \
+ do { \
+ if ((message).body_length() != 0) { \
+ return Status::IOError("Unexpected body in IPC message of type ", \
+ FormatMessageType((message).type())); \
+ } \
+ } while (0)
+
+} // namespace
+
// ----------------------------------------------------------------------
// Record batch read path
@@ -287,8 +319,9 @@ Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr<Schema>& sc
Status ReadRecordBatch(const Message& message, const std::shared_ptr<Schema>& schema,
std::shared_ptr<RecordBatch>* out) {
+ CHECK_MESSAGE_TYPE(message.type(), Message::RECORD_BATCH);
+ CHECK_HAS_BODY(message);
io::BufferReader reader(message.body());
- DCHECK_EQ(message.type(), Message::RECORD_BATCH);
return ReadRecordBatch(*message.metadata(), schema, kMaxNestingDepth, &reader, out);
}
@@ -382,14 +415,11 @@ static Status ReadMessageAndValidate(MessageReader* reader, Message::Type expect
}
if ((*message) == nullptr) {
+ // End of stream?
return Status::OK();
}
- if ((*message)->type() != expected_type) {
- return Status::IOError(
- "Message not expected type: ", FormatMessageType(expected_type),
- ", was: ", (*message)->type());
- }
+ CHECK_MESSAGE_TYPE((*message)->type(), expected_type);
return Status::OK();
}
@@ -414,7 +444,13 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
std::unique_ptr<Message> message;
RETURN_NOT_OK(ReadMessageAndValidate(message_reader_.get(), Message::DICTIONARY_BATCH,
false, &message));
+ if (message == nullptr) {
+ // End of stream
+ return Status::IOError(
+ "End of IPC stream when attempting to read dictionary batch");
+ }
+ CHECK_HAS_BODY(*message);
io::BufferReader reader(message->body());
std::shared_ptr<Array> dictionary;
@@ -428,7 +464,12 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
std::unique_ptr<Message> message;
RETURN_NOT_OK(
ReadMessageAndValidate(message_reader_.get(), Message::SCHEMA, false, &message));
+ if (message == nullptr) {
+ // End of stream
+ return Status::IOError("End of IPC stream when attempting to read schema");
+ }
+ CHECK_HAS_NO_BODY(*message);
if (message->header() == nullptr) {
return Status::IOError("Header-pointer of flatbuffer-encoded Message is null.");
}
@@ -448,13 +489,13 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
std::unique_ptr<Message> message;
RETURN_NOT_OK(ReadMessageAndValidate(message_reader_.get(), Message::RECORD_BATCH,
true, &message));
-
if (message == nullptr) {
// End of stream
*batch = nullptr;
return Status::OK();
}
+ CHECK_HAS_BODY(*message);
io::BufferReader reader(message->body());
return ReadRecordBatch(*message->metadata(), schema_, &reader, batch);
}
@@ -485,6 +526,15 @@ Status RecordBatchStreamReader::Open(std::unique_ptr<MessageReader> message_read
return Status::OK();
}
+Status RecordBatchStreamReader::Open(std::unique_ptr<MessageReader> message_reader,
+ std::unique_ptr<RecordBatchReader>* reader) {
+ // Private ctor
+ auto result = std::unique_ptr<RecordBatchStreamReader>(new RecordBatchStreamReader());
+ RETURN_NOT_OK(result->impl_->Open(std::move(message_reader)));
+ *reader = std::move(result);
+ return Status::OK();
+}
+
Status RecordBatchStreamReader::Open(io::InputStream* stream,
std::shared_ptr<RecordBatchReader>* out) {
return Open(MessageReader::Open(stream), out);
@@ -854,7 +904,8 @@ Status ReadSparseTensor(const Message& message, std::shared_ptr<SparseTensor>* o
Status ReadSparseTensor(io::InputStream* file, std::shared_ptr<SparseTensor>* out) {
std::unique_ptr<Message> message;
RETURN_NOT_OK(ReadContiguousPayload(file, &message));
- DCHECK_EQ(message->type(), Message::SPARSE_TENSOR);
+ CHECK_MESSAGE_TYPE(message->type(), Message::SPARSE_TENSOR);
+ CHECK_HAS_BODY(*message);
io::BufferReader buffer_reader(message->body());
return ReadSparseTensor(*message->metadata(), &buffer_reader, out);
}
diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h
index 641de3e..8fe310f 100644
--- a/cpp/src/arrow/ipc/reader.h
+++ b/cpp/src/arrow/ipc/reader.h
@@ -56,13 +56,16 @@ class ARROW_EXPORT RecordBatchStreamReader : public RecordBatchReader {
public:
~RecordBatchStreamReader() override;
- /// Create batch reader from generic MessageReader
+ /// Create batch reader from generic MessageReader.
+ /// This will take ownership of the given MessageReader.
///
/// \param[in] message_reader a MessageReader implementation
/// \param[out] out the created RecordBatchReader object
/// \return Status
static Status Open(std::unique_ptr<MessageReader> message_reader,
std::shared_ptr<RecordBatchReader>* out);
+ static Status Open(std::unique_ptr<MessageReader> message_reader,
+ std::unique_ptr<RecordBatchReader>* out);
/// \brief Record batch stream reader from InputStream
///
diff --git a/cpp/src/arrow/ipc/test-common.h b/cpp/src/arrow/ipc/test-common.cc
similarity index 78%
copy from cpp/src/arrow/ipc/test-common.h
copy to cpp/src/arrow/ipc/test-common.cc
index 8593fbc..44b608d 100644
--- a/cpp/src/arrow/ipc/test-common.h
+++ b/cpp/src/arrow/ipc/test-common.cc
@@ -15,9 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-#ifndef ARROW_IPC_TEST_COMMON_H
-#define ARROW_IPC_TEST_COMMON_H
-
#include <algorithm>
#include <cstdint>
#include <memory>
@@ -28,6 +25,7 @@
#include "arrow/array.h"
#include "arrow/buffer.h"
#include "arrow/builder.h"
+#include "arrow/ipc/test-common.h"
#include "arrow/memory_pool.h"
#include "arrow/pretty_print.h"
#include "arrow/record_batch.h"
@@ -40,9 +38,9 @@
namespace arrow {
namespace ipc {
+namespace test {
-static inline void CompareArraysDetailed(int index, const Array& result,
- const Array& expected) {
+void CompareArraysDetailed(int index, const Array& result, const Array& expected) {
if (!expected.Equals(result)) {
std::stringstream pp_result;
std::stringstream pp_expected;
@@ -55,8 +53,7 @@ static inline void CompareArraysDetailed(int index, const Array& result,
}
}
-static inline void CompareBatchColumnsDetailed(const RecordBatch& result,
- const RecordBatch& expected) {
+void CompareBatchColumnsDetailed(const RecordBatch& result, const RecordBatch& expected) {
for (int i = 0; i < expected.num_columns(); ++i) {
auto left = result.column(i);
auto right = expected.column(i);
@@ -64,12 +61,8 @@ static inline void CompareBatchColumnsDetailed(const RecordBatch& result,
}
}
-const auto kListInt32 = list(int32());
-const auto kListListInt32 = list(kListInt32);
-
-static inline Status MakeRandomInt32Array(int64_t length, bool include_nulls,
- MemoryPool* pool, std::shared_ptr<Array>* out,
- uint32_t seed = 0) {
+Status MakeRandomInt32Array(int64_t length, bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out, uint32_t seed) {
random::RandomArrayGenerator rand(seed);
const double null_probability = include_nulls ? 0.5 : 0.0;
@@ -78,9 +71,9 @@ static inline Status MakeRandomInt32Array(int64_t length, bool include_nulls,
return Status::OK();
}
-static inline Status MakeRandomListArray(const std::shared_ptr<Array>& child_array,
- int num_lists, bool include_nulls,
- MemoryPool* pool, std::shared_ptr<Array>* out) {
+Status MakeRandomListArray(const std::shared_ptr<Array>& child_array, int num_lists,
+ bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out) {
// Create the null list values
std::vector<uint8_t> valid_lists(num_lists);
const double null_percent = include_nulls ? 0.1 : 0;
@@ -122,10 +115,8 @@ static inline Status MakeRandomListArray(const std::shared_ptr<Array>& child_arr
return ValidateArray(**out);
}
-typedef Status MakeRecordBatch(std::shared_ptr<RecordBatch>* out);
-
-static inline Status MakeRandomBooleanArray(const int length, bool include_nulls,
- std::shared_ptr<Array>* out) {
+Status MakeRandomBooleanArray(const int length, bool include_nulls,
+ std::shared_ptr<Array>* out) {
std::vector<uint8_t> values(length);
random_null_bytes(length, 0.5, values.data());
std::shared_ptr<Buffer> data;
@@ -143,8 +134,7 @@ static inline Status MakeRandomBooleanArray(const int length, bool include_nulls
return Status::OK();
}
-static inline Status MakeBooleanBatchSized(const int length,
- std::shared_ptr<RecordBatch>* out) {
+Status MakeBooleanBatchSized(const int length, std::shared_ptr<RecordBatch>* out) {
// Make the schema
auto f0 = field("f0", boolean());
auto f1 = field("f1", boolean());
@@ -157,12 +147,11 @@ static inline Status MakeBooleanBatchSized(const int length,
return Status::OK();
}
-static inline Status MakeBooleanBatch(std::shared_ptr<RecordBatch>* out) {
+Status MakeBooleanBatch(std::shared_ptr<RecordBatch>* out) {
return MakeBooleanBatchSized(1000, out);
}
-static inline Status MakeIntBatchSized(int length, std::shared_ptr<RecordBatch>* out,
- uint32_t seed = 0) {
+Status MakeIntBatchSized(int length, std::shared_ptr<RecordBatch>* out, uint32_t seed) {
// Make the schema
auto f0 = field("f0", int32());
auto f1 = field("f1", int32());
@@ -177,33 +166,32 @@ static inline Status MakeIntBatchSized(int length, std::shared_ptr<RecordBatch>*
return Status::OK();
}
-static inline Status MakeIntRecordBatch(std::shared_ptr<RecordBatch>* out) {
+Status MakeIntRecordBatch(std::shared_ptr<RecordBatch>* out) {
return MakeIntBatchSized(10, out);
}
-template <class Builder, class RawType>
-Status MakeRandomBinaryArray(int64_t length, bool include_nulls, MemoryPool* pool,
+Status MakeRandomStringArray(int64_t length, bool include_nulls, MemoryPool* pool,
std::shared_ptr<Array>* out) {
const std::vector<std::string> values = {"", "", "abc", "123",
"efg", "456!@#!@#", "12312"};
- Builder builder(pool);
+ StringBuilder builder(pool);
const size_t values_len = values.size();
for (int64_t i = 0; i < length; ++i) {
int64_t values_index = i % values_len;
if (include_nulls && values_index == 0) {
RETURN_NOT_OK(builder.AppendNull());
} else {
- const std::string& value = values[values_index];
- RETURN_NOT_OK(builder.Append(reinterpret_cast<const RawType*>(value.data()),
- static_cast<int32_t>(value.size())));
+ const auto& value = values[values_index];
+ RETURN_NOT_OK(builder.Append(value));
}
}
return builder.Finish(out);
}
template <class Builder, class RawType>
-Status MakeBinaryArrayWithUniqueValues(int64_t length, bool include_nulls,
- MemoryPool* pool, std::shared_ptr<Array>* out) {
+static Status MakeBinaryArrayWithUniqueValues(int64_t length, bool include_nulls,
+ MemoryPool* pool,
+ std::shared_ptr<Array>* out) {
Builder builder(pool);
for (int64_t i = 0; i < length; ++i) {
if (include_nulls && (i % 7 == 0)) {
@@ -217,8 +205,7 @@ Status MakeBinaryArrayWithUniqueValues(int64_t length, bool include_nulls,
return builder.Finish(out);
}
-static inline Status MakeStringTypesRecordBatch(std::shared_ptr<RecordBatch>* out,
- bool with_nulls = true) {
+Status MakeStringTypesRecordBatch(std::shared_ptr<RecordBatch>* out, bool with_nulls) {
const int64_t length = 500;
auto string_type = utf8();
auto binary_type = binary();
@@ -245,12 +232,11 @@ static inline Status MakeStringTypesRecordBatch(std::shared_ptr<RecordBatch>* ou
return Status::OK();
}
-static inline Status MakeStringTypesRecordBatchWithNulls(
- std::shared_ptr<RecordBatch>* out) {
+Status MakeStringTypesRecordBatchWithNulls(std::shared_ptr<RecordBatch>* out) {
return MakeStringTypesRecordBatch(out, true);
}
-static inline Status MakeNullRecordBatch(std::shared_ptr<RecordBatch>* out) {
+Status MakeNullRecordBatch(std::shared_ptr<RecordBatch>* out) {
const int64_t length = 500;
auto f0 = field("f0", null());
auto schema = ::arrow::schema({f0});
@@ -259,10 +245,10 @@ static inline Status MakeNullRecordBatch(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeListRecordBatch(std::shared_ptr<RecordBatch>* out) {
+Status MakeListRecordBatch(std::shared_ptr<RecordBatch>* out) {
// Make the schema
- auto f0 = field("f0", kListInt32);
- auto f1 = field("f1", kListListInt32);
+ auto f0 = field("f0", list(int32()));
+ auto f1 = field("f1", list(list(int32())));
auto f2 = field("f2", int32());
auto schema = ::arrow::schema({f0, f1, f2});
@@ -282,10 +268,10 @@ static inline Status MakeListRecordBatch(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeZeroLengthRecordBatch(std::shared_ptr<RecordBatch>* out) {
+Status MakeZeroLengthRecordBatch(std::shared_ptr<RecordBatch>* out) {
// Make the schema
- auto f0 = field("f0", kListInt32);
- auto f1 = field("f1", kListListInt32);
+ auto f0 = field("f0", list(int32()));
+ auto f1 = field("f1", list(list(int32())));
auto f2 = field("f2", int32());
auto schema = ::arrow::schema({f0, f1, f2});
@@ -302,10 +288,10 @@ static inline Status MakeZeroLengthRecordBatch(std::shared_ptr<RecordBatch>* out
return Status::OK();
}
-static inline Status MakeNonNullRecordBatch(std::shared_ptr<RecordBatch>* out) {
+Status MakeNonNullRecordBatch(std::shared_ptr<RecordBatch>* out) {
// Make the schema
- auto f0 = field("f0", kListInt32);
- auto f1 = field("f1", kListListInt32);
+ auto f0 = field("f0", list(int32()));
+ auto f1 = field("f1", list(list(int32())));
auto f2 = field("f2", int32());
auto schema = ::arrow::schema({f0, f1, f2});
@@ -325,7 +311,7 @@ static inline Status MakeNonNullRecordBatch(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeDeeplyNestedList(std::shared_ptr<RecordBatch>* out) {
+Status MakeDeeplyNestedList(std::shared_ptr<RecordBatch>* out) {
const int batch_length = 5;
auto type = int32();
@@ -345,7 +331,7 @@ static inline Status MakeDeeplyNestedList(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeStruct(std::shared_ptr<RecordBatch>* out) {
+Status MakeStruct(std::shared_ptr<RecordBatch>* out) {
// reuse constructed list columns
std::shared_ptr<RecordBatch> list_batch;
RETURN_NOT_OK(MakeListRecordBatch(&list_batch));
@@ -375,7 +361,7 @@ static inline Status MakeStruct(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeUnion(std::shared_ptr<RecordBatch>* out) {
+Status MakeUnion(std::shared_ptr<RecordBatch>* out) {
// Define schema
std::vector<std::shared_ptr<Field>> union_types(
{field("u0", int32()), field("u1", uint8())});
@@ -441,17 +427,13 @@ static inline Status MakeUnion(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeDictionary(std::shared_ptr<RecordBatch>* out) {
+Status MakeDictionary(std::shared_ptr<RecordBatch>* out) {
const int64_t length = 6;
std::vector<bool> is_valid = {true, true, false, true, true, true};
- std::shared_ptr<Array> dict1, dict2;
-
- std::vector<std::string> dict1_values = {"foo", "bar", "baz"};
- std::vector<std::string> dict2_values = {"foo", "bar", "baz", "qux"};
- ArrayFromVector<StringType, std::string>(dict1_values, &dict1);
- ArrayFromVector<StringType, std::string>(dict2_values, &dict2);
+ auto dict1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]");
+ auto dict2 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\", \"qux\"]");
auto f0_type = arrow::dictionary(arrow::int32(), dict1);
auto f1_type = arrow::dictionary(arrow::int8(), dict1, true);
@@ -470,51 +452,30 @@ static inline Status MakeDictionary(std::shared_ptr<RecordBatch>* out) {
auto a1 = std::make_shared<DictionaryArray>(f1_type, indices1);
auto a2 = std::make_shared<DictionaryArray>(f2_type, indices2);
- // List of dictionary-encoded string
+ // Lists of dictionary-encoded strings
auto f3_type = list(f1_type);
- std::vector<int32_t> list_offsets = {0, 0, 2, 2, 5, 6, 9};
- std::shared_ptr<Array> offsets, indices3;
- ArrayFromVector<Int32Type, int32_t>(std::vector<bool>(list_offsets.size(), true),
- list_offsets, &offsets);
-
- std::vector<int8_t> indices3_values = {0, 1, 2, 0, 1, 2, 0, 1, 2};
- std::vector<bool> is_valid3(9, true);
- ArrayFromVector<Int8Type, int8_t>(is_valid3, indices3_values, &indices3);
+ auto indices3 = ArrayFromJSON(int8(), "[0, 1, 2, 0, 1, 1, 2, 1, 0]");
+ auto offsets3 = ArrayFromJSON(int32(), "[0, 0, 2, 2, 5, 6, 9]");
std::shared_ptr<Buffer> null_bitmap;
RETURN_NOT_OK(GetBitmapFromVector(is_valid, &null_bitmap));
std::shared_ptr<Array> a3 = std::make_shared<ListArray>(
- f3_type, length, std::static_pointer_cast<PrimitiveArray>(offsets)->values(),
+ f3_type, length, std::static_pointer_cast<PrimitiveArray>(offsets3)->values(),
std::make_shared<DictionaryArray>(f1_type, indices3), null_bitmap, 1);
- // Dictionary-encoded list of integer
- auto f4_value_type = list(int8());
-
- std::shared_ptr<Array> offsets4, values4, indices4;
-
- std::vector<int32_t> list_offsets4 = {0, 2, 2, 3};
- ArrayFromVector<Int32Type, int32_t>(std::vector<bool>(4, true), list_offsets4,
- &offsets4);
-
- std::vector<int8_t> list_values4 = {0, 1, 2};
- ArrayFromVector<Int8Type, int8_t>(std::vector<bool>(3, true), list_values4, &values4);
+ // Dictionary-encoded lists of integers
+ auto dict4 = ArrayFromJSON(list(int8()), "[[44, 55], [], [66]]");
+ auto f4_type = dictionary(int8(), dict4);
- auto dict3 = std::make_shared<ListArray>(
- f4_value_type, 3, std::static_pointer_cast<PrimitiveArray>(offsets4)->values(),
- values4);
-
- std::vector<int8_t> indices4_values = {0, 1, 2, 0, 1, 2};
- ArrayFromVector<Int8Type, int8_t>(is_valid, indices4_values, &indices4);
-
- auto f4_type = dictionary(int8(), dict3);
+ auto indices4 = ArrayFromJSON(int8(), "[0, 1, 2, 0, 2, 2]");
auto a4 = std::make_shared<DictionaryArray>(f4_type, indices4);
// construct batch
auto schema = ::arrow::schema(
- {field("dict1", f0_type), field("sparse", f1_type), field("dense", f2_type),
- field("list of encoded string", f3_type), field("encoded list<int8>", f4_type)});
+ {field("dict1", f0_type), field("dict2", f1_type), field("dict3", f2_type),
+ field("list<encoded utf8>", f3_type), field("encoded list<int8>", f4_type)});
std::vector<std::shared_ptr<Array>> arrays = {a0, a1, a2, a3, a4};
@@ -522,17 +483,13 @@ static inline Status MakeDictionary(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeDictionaryFlat(std::shared_ptr<RecordBatch>* out) {
+Status MakeDictionaryFlat(std::shared_ptr<RecordBatch>* out) {
const int64_t length = 6;
std::vector<bool> is_valid = {true, true, false, true, true, true};
- std::shared_ptr<Array> dict1, dict2;
- std::vector<std::string> dict1_values = {"foo", "bar", "baz"};
- std::vector<std::string> dict2_values = {"foo", "bar", "baz", "qux"};
-
- ArrayFromVector<StringType, std::string>(dict1_values, &dict1);
- ArrayFromVector<StringType, std::string>(dict2_values, &dict2);
+ auto dict1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]");
+ auto dict2 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\", \"qux\"]");
auto f0_type = arrow::dictionary(arrow::int32(), dict1);
auto f1_type = arrow::dictionary(arrow::int8(), dict1);
@@ -553,14 +510,14 @@ static inline Status MakeDictionaryFlat(std::shared_ptr<RecordBatch>* out) {
// construct batch
auto schema = ::arrow::schema(
- {field("dict1", f0_type), field("sparse", f1_type), field("dense", f2_type)});
+ {field("dict1", f0_type), field("dict2", f1_type), field("dict3", f2_type)});
std::vector<std::shared_ptr<Array>> arrays = {a0, a1, a2};
*out = RecordBatch::Make(schema, length, arrays);
return Status::OK();
}
-static inline Status MakeDates(std::shared_ptr<RecordBatch>* out) {
+Status MakeDates(std::shared_ptr<RecordBatch>* out) {
std::vector<bool> is_valid = {true, true, true, false, true, true, true};
auto f0 = field("f0", date32());
auto f1 = field("f1", date64());
@@ -580,7 +537,7 @@ static inline Status MakeDates(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeTimestamps(std::shared_ptr<RecordBatch>* out) {
+Status MakeTimestamps(std::shared_ptr<RecordBatch>* out) {
std::vector<bool> is_valid = {true, true, true, false, true, true, true};
auto f0 = field("f0", timestamp(TimeUnit::MILLI));
auto f1 = field("f1", timestamp(TimeUnit::NANO, "America/New_York"));
@@ -599,7 +556,7 @@ static inline Status MakeTimestamps(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeTimes(std::shared_ptr<RecordBatch>* out) {
+Status MakeTimes(std::shared_ptr<RecordBatch>* out) {
std::vector<bool> is_valid = {true, true, true, false, true, true, true};
auto f0 = field("f0", time32(TimeUnit::MILLI));
auto f1 = field("f1", time64(TimeUnit::NANO));
@@ -623,8 +580,8 @@ static inline Status MakeTimes(std::shared_ptr<RecordBatch>* out) {
}
template <typename BuilderType, typename T>
-void AppendValues(const std::vector<bool>& is_valid, const std::vector<T>& values,
- BuilderType* builder) {
+static void AppendValues(const std::vector<bool>& is_valid, const std::vector<T>& values,
+ BuilderType* builder) {
for (size_t i = 0; i < values.size(); ++i) {
if (is_valid[i]) {
ASSERT_OK(builder->Append(values[i]));
@@ -634,7 +591,7 @@ void AppendValues(const std::vector<bool>& is_valid, const std::vector<T>& value
}
}
-static inline Status MakeFWBinary(std::shared_ptr<RecordBatch>* out) {
+Status MakeFWBinary(std::shared_ptr<RecordBatch>* out) {
std::vector<bool> is_valid = {true, true, true, false};
auto f0 = field("f0", fixed_size_binary(4));
auto f1 = field("f1", fixed_size_binary(0));
@@ -658,7 +615,7 @@ static inline Status MakeFWBinary(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeDecimal(std::shared_ptr<RecordBatch>* out) {
+Status MakeDecimal(std::shared_ptr<RecordBatch>* out) {
constexpr int kDecimalPrecision = 38;
auto type = decimal(kDecimalPrecision, 4);
auto f0 = field("f0", type);
@@ -687,7 +644,7 @@ static inline Status MakeDecimal(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
-static inline Status MakeNull(std::shared_ptr<RecordBatch>* out) {
+Status MakeNull(std::shared_ptr<RecordBatch>* out) {
auto f0 = field("f0", null());
// Also put a non-null field to make sure we handle the null array buffers properly
@@ -707,7 +664,6 @@ static inline Status MakeNull(std::shared_ptr<RecordBatch>* out) {
return Status::OK();
}
+} // namespace test
} // namespace ipc
} // namespace arrow
-
-#endif // ARROW_IPC_TEST_COMMON_H
diff --git a/cpp/src/arrow/ipc/test-common.h b/cpp/src/arrow/ipc/test-common.h
index 8593fbc..735991b 100644
--- a/cpp/src/arrow/ipc/test-common.h
+++ b/cpp/src/arrow/ipc/test-common.h
@@ -18,695 +18,110 @@
#ifndef ARROW_IPC_TEST_COMMON_H
#define ARROW_IPC_TEST_COMMON_H
-#include <algorithm>
#include <cstdint>
#include <memory>
-#include <numeric>
-#include <string>
-#include <vector>
#include "arrow/array.h"
-#include "arrow/buffer.h"
-#include "arrow/builder.h"
-#include "arrow/memory_pool.h"
-#include "arrow/pretty_print.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
-#include "arrow/testing/gtest_util.h"
-#include "arrow/testing/random.h"
-#include "arrow/testing/util.h"
#include "arrow/type.h"
-#include "arrow/util/bit-util.h"
namespace arrow {
namespace ipc {
+namespace test {
-static inline void CompareArraysDetailed(int index, const Array& result,
- const Array& expected) {
- if (!expected.Equals(result)) {
- std::stringstream pp_result;
- std::stringstream pp_expected;
-
- ASSERT_OK(PrettyPrint(expected, 0, &pp_expected));
- ASSERT_OK(PrettyPrint(result, 0, &pp_result));
-
- FAIL() << "Index: " << index << " Expected: " << pp_expected.str()
- << "\nGot: " << pp_result.str();
- }
-}
-
-static inline void CompareBatchColumnsDetailed(const RecordBatch& result,
- const RecordBatch& expected) {
- for (int i = 0; i < expected.num_columns(); ++i) {
- auto left = result.column(i);
- auto right = expected.column(i);
- CompareArraysDetailed(i, *left, *right);
- }
-}
-
-const auto kListInt32 = list(int32());
-const auto kListListInt32 = list(kListInt32);
-
-static inline Status MakeRandomInt32Array(int64_t length, bool include_nulls,
- MemoryPool* pool, std::shared_ptr<Array>* out,
- uint32_t seed = 0) {
- random::RandomArrayGenerator rand(seed);
- const double null_probability = include_nulls ? 0.5 : 0.0;
-
- *out = rand.Int32(length, 0, 1000, null_probability);
-
- return Status::OK();
-}
-
-static inline Status MakeRandomListArray(const std::shared_ptr<Array>& child_array,
- int num_lists, bool include_nulls,
- MemoryPool* pool, std::shared_ptr<Array>* out) {
- // Create the null list values
- std::vector<uint8_t> valid_lists(num_lists);
- const double null_percent = include_nulls ? 0.1 : 0;
- random_null_bytes(num_lists, null_percent, valid_lists.data());
-
- // Create list offsets
- const int max_list_size = 10;
-
- std::vector<int32_t> list_sizes(num_lists, 0);
- std::vector<int32_t> offsets(
- num_lists + 1, 0); // +1 so we can shift for nulls. See partial sum below.
- const uint32_t seed = static_cast<uint32_t>(child_array->length());
-
- if (num_lists > 0) {
- rand_uniform_int(num_lists, seed, 0, max_list_size, list_sizes.data());
- // make sure sizes are consistent with null
- std::transform(list_sizes.begin(), list_sizes.end(), valid_lists.begin(),
- list_sizes.begin(),
- [](int32_t size, int32_t valid) { return valid == 0 ? 0 : size; });
- std::partial_sum(list_sizes.begin(), list_sizes.end(), ++offsets.begin());
-
- // Force invariants
- const int32_t child_length = static_cast<int32_t>(child_array->length());
- offsets[0] = 0;
- std::replace_if(offsets.begin(), offsets.end(),
- [child_length](int32_t offset) { return offset > child_length; },
- child_length);
- }
-
- offsets[num_lists] = static_cast<int32_t>(child_array->length());
-
- /// TODO(wesm): Implement support for nulls in ListArray::FromArrays
- std::shared_ptr<Buffer> null_bitmap, offsets_buffer;
- RETURN_NOT_OK(GetBitmapFromVector(valid_lists, &null_bitmap));
- RETURN_NOT_OK(CopyBufferFromVector(offsets, pool, &offsets_buffer));
-
- *out = std::make_shared<ListArray>(list(child_array->type()), num_lists, offsets_buffer,
- child_array, null_bitmap, kUnknownNullCount);
- return ValidateArray(**out);
-}
-
+// A typedef used for test parameterization
typedef Status MakeRecordBatch(std::shared_ptr<RecordBatch>* out);
-static inline Status MakeRandomBooleanArray(const int length, bool include_nulls,
- std::shared_ptr<Array>* out) {
- std::vector<uint8_t> values(length);
- random_null_bytes(length, 0.5, values.data());
- std::shared_ptr<Buffer> data;
- RETURN_NOT_OK(BitUtil::BytesToBits(values, default_memory_pool(), &data));
-
- if (include_nulls) {
- std::vector<uint8_t> valid_bytes(length);
- std::shared_ptr<Buffer> null_bitmap;
- RETURN_NOT_OK(BitUtil::BytesToBits(valid_bytes, default_memory_pool(), &null_bitmap));
- random_null_bytes(length, 0.1, valid_bytes.data());
- *out = std::make_shared<BooleanArray>(length, data, null_bitmap, -1);
- } else {
- *out = std::make_shared<BooleanArray>(length, data, NULLPTR, 0);
- }
- return Status::OK();
-}
-
-static inline Status MakeBooleanBatchSized(const int length,
- std::shared_ptr<RecordBatch>* out) {
- // Make the schema
- auto f0 = field("f0", boolean());
- auto f1 = field("f1", boolean());
- auto schema = ::arrow::schema({f0, f1});
-
- std::shared_ptr<Array> a0, a1;
- RETURN_NOT_OK(MakeRandomBooleanArray(length, true, &a0));
- RETURN_NOT_OK(MakeRandomBooleanArray(length, false, &a1));
- *out = RecordBatch::Make(schema, length, {a0, a1});
- return Status::OK();
-}
-
-static inline Status MakeBooleanBatch(std::shared_ptr<RecordBatch>* out) {
- return MakeBooleanBatchSized(1000, out);
-}
-
-static inline Status MakeIntBatchSized(int length, std::shared_ptr<RecordBatch>* out,
- uint32_t seed = 0) {
- // Make the schema
- auto f0 = field("f0", int32());
- auto f1 = field("f1", int32());
- auto schema = ::arrow::schema({f0, f1});
-
- // Example data
- std::shared_ptr<Array> a0, a1;
- MemoryPool* pool = default_memory_pool();
- RETURN_NOT_OK(MakeRandomInt32Array(length, false, pool, &a0, seed));
- RETURN_NOT_OK(MakeRandomInt32Array(length, true, pool, &a1, seed + 1));
- *out = RecordBatch::Make(schema, length, {a0, a1});
- return Status::OK();
-}
-
-static inline Status MakeIntRecordBatch(std::shared_ptr<RecordBatch>* out) {
- return MakeIntBatchSized(10, out);
-}
-
-template <class Builder, class RawType>
-Status MakeRandomBinaryArray(int64_t length, bool include_nulls, MemoryPool* pool,
- std::shared_ptr<Array>* out) {
- const std::vector<std::string> values = {"", "", "abc", "123",
- "efg", "456!@#!@#", "12312"};
- Builder builder(pool);
- const size_t values_len = values.size();
- for (int64_t i = 0; i < length; ++i) {
- int64_t values_index = i % values_len;
- if (include_nulls && values_index == 0) {
- RETURN_NOT_OK(builder.AppendNull());
- } else {
- const std::string& value = values[values_index];
- RETURN_NOT_OK(builder.Append(reinterpret_cast<const RawType*>(value.data()),
- static_cast<int32_t>(value.size())));
- }
- }
- return builder.Finish(out);
-}
-
-template <class Builder, class RawType>
-Status MakeBinaryArrayWithUniqueValues(int64_t length, bool include_nulls,
- MemoryPool* pool, std::shared_ptr<Array>* out) {
- Builder builder(pool);
- for (int64_t i = 0; i < length; ++i) {
- if (include_nulls && (i % 7 == 0)) {
- RETURN_NOT_OK(builder.AppendNull());
- } else {
- const std::string value = std::to_string(i);
- RETURN_NOT_OK(builder.Append(reinterpret_cast<const RawType*>(value.data()),
- static_cast<int32_t>(value.size())));
- }
- }
- return builder.Finish(out);
-}
-
-static inline Status MakeStringTypesRecordBatch(std::shared_ptr<RecordBatch>* out,
- bool with_nulls = true) {
- const int64_t length = 500;
- auto string_type = utf8();
- auto binary_type = binary();
- auto f0 = field("f0", string_type);
- auto f1 = field("f1", binary_type);
- auto schema = ::arrow::schema({f0, f1});
-
- std::shared_ptr<Array> a0, a1;
- MemoryPool* pool = default_memory_pool();
-
- // Quirk with RETURN_NOT_OK macro and templated functions
- {
- auto s = MakeBinaryArrayWithUniqueValues<StringBuilder, char>(length, with_nulls,
- pool, &a0);
- RETURN_NOT_OK(s);
- }
-
- {
- auto s = MakeBinaryArrayWithUniqueValues<BinaryBuilder, uint8_t>(length, with_nulls,
- pool, &a1);
- RETURN_NOT_OK(s);
- }
- *out = RecordBatch::Make(schema, length, {a0, a1});
- return Status::OK();
-}
-
-static inline Status MakeStringTypesRecordBatchWithNulls(
- std::shared_ptr<RecordBatch>* out) {
- return MakeStringTypesRecordBatch(out, true);
-}
-
-static inline Status MakeNullRecordBatch(std::shared_ptr<RecordBatch>* out) {
- const int64_t length = 500;
- auto f0 = field("f0", null());
- auto schema = ::arrow::schema({f0});
- std::shared_ptr<Array> a0 = std::make_shared<NullArray>(length);
- *out = RecordBatch::Make(schema, length, {a0});
- return Status::OK();
-}
-
-static inline Status MakeListRecordBatch(std::shared_ptr<RecordBatch>* out) {
- // Make the schema
- auto f0 = field("f0", kListInt32);
- auto f1 = field("f1", kListListInt32);
- auto f2 = field("f2", int32());
- auto schema = ::arrow::schema({f0, f1, f2});
-
- // Example data
-
- MemoryPool* pool = default_memory_pool();
- const int length = 200;
- std::shared_ptr<Array> leaf_values, list_array, list_list_array, flat_array;
- const bool include_nulls = true;
- RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool, &leaf_values));
- RETURN_NOT_OK(
- MakeRandomListArray(leaf_values, length, include_nulls, pool, &list_array));
- RETURN_NOT_OK(
- MakeRandomListArray(list_array, length, include_nulls, pool, &list_list_array));
- RETURN_NOT_OK(MakeRandomInt32Array(length, include_nulls, pool, &flat_array));
- *out = RecordBatch::Make(schema, length, {list_array, list_list_array, flat_array});
- return Status::OK();
-}
-
-static inline Status MakeZeroLengthRecordBatch(std::shared_ptr<RecordBatch>* out) {
- // Make the schema
- auto f0 = field("f0", kListInt32);
- auto f1 = field("f1", kListListInt32);
- auto f2 = field("f2", int32());
- auto schema = ::arrow::schema({f0, f1, f2});
-
- // Example data
- MemoryPool* pool = default_memory_pool();
- const bool include_nulls = true;
- std::shared_ptr<Array> leaf_values, list_array, list_list_array, flat_array;
- RETURN_NOT_OK(MakeRandomInt32Array(0, include_nulls, pool, &leaf_values));
- RETURN_NOT_OK(MakeRandomListArray(leaf_values, 0, include_nulls, pool, &list_array));
- RETURN_NOT_OK(
- MakeRandomListArray(list_array, 0, include_nulls, pool, &list_list_array));
- RETURN_NOT_OK(MakeRandomInt32Array(0, include_nulls, pool, &flat_array));
- *out = RecordBatch::Make(schema, 0, {list_array, list_list_array, flat_array});
- return Status::OK();
-}
-
-static inline Status MakeNonNullRecordBatch(std::shared_ptr<RecordBatch>* out) {
- // Make the schema
- auto f0 = field("f0", kListInt32);
- auto f1 = field("f1", kListListInt32);
- auto f2 = field("f2", int32());
- auto schema = ::arrow::schema({f0, f1, f2});
-
- // Example data
- MemoryPool* pool = default_memory_pool();
- const int length = 50;
- std::shared_ptr<Array> leaf_values, list_array, list_list_array, flat_array;
-
- RETURN_NOT_OK(MakeRandomInt32Array(1000, true, pool, &leaf_values));
- bool include_nulls = false;
- RETURN_NOT_OK(
- MakeRandomListArray(leaf_values, length, include_nulls, pool, &list_array));
- RETURN_NOT_OK(
- MakeRandomListArray(list_array, length, include_nulls, pool, &list_list_array));
- RETURN_NOT_OK(MakeRandomInt32Array(length, include_nulls, pool, &flat_array));
- *out = RecordBatch::Make(schema, length, {list_array, list_list_array, flat_array});
- return Status::OK();
-}
-
-static inline Status MakeDeeplyNestedList(std::shared_ptr<RecordBatch>* out) {
- const int batch_length = 5;
- auto type = int32();
-
- MemoryPool* pool = default_memory_pool();
- std::shared_ptr<Array> array;
- const bool include_nulls = true;
- RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool, &array));
- for (int i = 0; i < 63; ++i) {
- type = std::static_pointer_cast<DataType>(list(type));
- RETURN_NOT_OK(MakeRandomListArray(array, batch_length, include_nulls, pool, &array));
- }
-
- auto f0 = field("f0", type);
- auto schema = ::arrow::schema({f0});
- std::vector<std::shared_ptr<Array>> arrays = {array};
- *out = RecordBatch::Make(schema, batch_length, arrays);
- return Status::OK();
-}
-
-static inline Status MakeStruct(std::shared_ptr<RecordBatch>* out) {
- // reuse constructed list columns
- std::shared_ptr<RecordBatch> list_batch;
- RETURN_NOT_OK(MakeListRecordBatch(&list_batch));
- std::vector<std::shared_ptr<Array>> columns = {
- list_batch->column(0), list_batch->column(1), list_batch->column(2)};
- auto list_schema = list_batch->schema();
-
- // Define schema
- std::shared_ptr<DataType> type(new StructType(
- {list_schema->field(0), list_schema->field(1), list_schema->field(2)}));
- auto f0 = field("non_null_struct", type);
- auto f1 = field("null_struct", type);
- auto schema = ::arrow::schema({f0, f1});
-
- // construct individual nullable/non-nullable struct arrays
- std::shared_ptr<Array> no_nulls(new StructArray(type, list_batch->num_rows(), columns));
- std::vector<uint8_t> null_bytes(list_batch->num_rows(), 1);
- null_bytes[0] = 0;
- std::shared_ptr<Buffer> null_bitmask;
- RETURN_NOT_OK(BitUtil::BytesToBits(null_bytes, default_memory_pool(), &null_bitmask));
- std::shared_ptr<Array> with_nulls(
- new StructArray(type, list_batch->num_rows(), columns, null_bitmask, 1));
-
- // construct batch
- std::vector<std::shared_ptr<Array>> arrays = {no_nulls, with_nulls};
- *out = RecordBatch::Make(schema, list_batch->num_rows(), arrays);
- return Status::OK();
-}
-
-static inline Status MakeUnion(std::shared_ptr<RecordBatch>* out) {
- // Define schema
- std::vector<std::shared_ptr<Field>> union_types(
- {field("u0", int32()), field("u1", uint8())});
-
- std::vector<uint8_t> type_codes = {5, 10};
- auto sparse_type =
- std::make_shared<UnionType>(union_types, type_codes, UnionMode::SPARSE);
-
- auto dense_type =
- std::make_shared<UnionType>(union_types, type_codes, UnionMode::DENSE);
-
- auto f0 = field("sparse_nonnull", sparse_type, false);
- auto f1 = field("sparse", sparse_type);
- auto f2 = field("dense", dense_type);
-
- auto schema = ::arrow::schema({f0, f1, f2});
-
- // Create data
- std::vector<std::shared_ptr<Array>> sparse_children(2);
- std::vector<std::shared_ptr<Array>> dense_children(2);
-
- const int64_t length = 7;
-
- std::shared_ptr<Buffer> type_ids_buffer;
- std::vector<uint8_t> type_ids = {5, 10, 5, 5, 10, 10, 5};
- RETURN_NOT_OK(CopyBufferFromVector(type_ids, default_memory_pool(), &type_ids_buffer));
-
- std::vector<int32_t> u0_values = {0, 1, 2, 3, 4, 5, 6};
- ArrayFromVector<Int32Type, int32_t>(u0_values, &sparse_children[0]);
-
- std::vector<uint8_t> u1_values = {10, 11, 12, 13, 14, 15, 16};
- ArrayFromVector<UInt8Type, uint8_t>(u1_values, &sparse_children[1]);
-
- // dense children
- u0_values = {0, 2, 3, 7};
- ArrayFromVector<Int32Type, int32_t>(u0_values, &dense_children[0]);
-
- u1_values = {11, 14, 15};
- ArrayFromVector<UInt8Type, uint8_t>(u1_values, &dense_children[1]);
-
- std::shared_ptr<Buffer> offsets_buffer;
- std::vector<int32_t> offsets = {0, 0, 1, 2, 1, 2, 3};
- RETURN_NOT_OK(CopyBufferFromVector(offsets, default_memory_pool(), &offsets_buffer));
-
- std::vector<uint8_t> null_bytes(length, 1);
- null_bytes[2] = 0;
- std::shared_ptr<Buffer> null_bitmask;
- RETURN_NOT_OK(BitUtil::BytesToBits(null_bytes, default_memory_pool(), &null_bitmask));
-
- // construct individual nullable/non-nullable struct arrays
- auto sparse_no_nulls =
- std::make_shared<UnionArray>(sparse_type, length, sparse_children, type_ids_buffer);
- auto sparse = std::make_shared<UnionArray>(sparse_type, length, sparse_children,
- type_ids_buffer, NULLPTR, null_bitmask, 1);
-
- auto dense =
- std::make_shared<UnionArray>(dense_type, length, dense_children, type_ids_buffer,
- offsets_buffer, null_bitmask, 1);
-
- // construct batch
- std::vector<std::shared_ptr<Array>> arrays = {sparse_no_nulls, sparse, dense};
- *out = RecordBatch::Make(schema, length, arrays);
- return Status::OK();
-}
-
-static inline Status MakeDictionary(std::shared_ptr<RecordBatch>* out) {
- const int64_t length = 6;
-
- std::vector<bool> is_valid = {true, true, false, true, true, true};
- std::shared_ptr<Array> dict1, dict2;
-
- std::vector<std::string> dict1_values = {"foo", "bar", "baz"};
- std::vector<std::string> dict2_values = {"foo", "bar", "baz", "qux"};
-
- ArrayFromVector<StringType, std::string>(dict1_values, &dict1);
- ArrayFromVector<StringType, std::string>(dict2_values, &dict2);
-
- auto f0_type = arrow::dictionary(arrow::int32(), dict1);
- auto f1_type = arrow::dictionary(arrow::int8(), dict1, true);
- auto f2_type = arrow::dictionary(arrow::int32(), dict2);
-
- std::shared_ptr<Array> indices0, indices1, indices2;
- std::vector<int32_t> indices0_values = {1, 2, -1, 0, 2, 0};
- std::vector<int8_t> indices1_values = {0, 0, 2, 2, 1, 1};
- std::vector<int32_t> indices2_values = {3, 0, 2, 1, 0, 2};
-
- ArrayFromVector<Int32Type, int32_t>(is_valid, indices0_values, &indices0);
- ArrayFromVector<Int8Type, int8_t>(is_valid, indices1_values, &indices1);
- ArrayFromVector<Int32Type, int32_t>(is_valid, indices2_values, &indices2);
-
- auto a0 = std::make_shared<DictionaryArray>(f0_type, indices0);
- auto a1 = std::make_shared<DictionaryArray>(f1_type, indices1);
- auto a2 = std::make_shared<DictionaryArray>(f2_type, indices2);
-
- // List of dictionary-encoded string
- auto f3_type = list(f1_type);
-
- std::vector<int32_t> list_offsets = {0, 0, 2, 2, 5, 6, 9};
- std::shared_ptr<Array> offsets, indices3;
- ArrayFromVector<Int32Type, int32_t>(std::vector<bool>(list_offsets.size(), true),
- list_offsets, &offsets);
-
- std::vector<int8_t> indices3_values = {0, 1, 2, 0, 1, 2, 0, 1, 2};
- std::vector<bool> is_valid3(9, true);
- ArrayFromVector<Int8Type, int8_t>(is_valid3, indices3_values, &indices3);
-
- std::shared_ptr<Buffer> null_bitmap;
- RETURN_NOT_OK(GetBitmapFromVector(is_valid, &null_bitmap));
-
- std::shared_ptr<Array> a3 = std::make_shared<ListArray>(
- f3_type, length, std::static_pointer_cast<PrimitiveArray>(offsets)->values(),
- std::make_shared<DictionaryArray>(f1_type, indices3), null_bitmap, 1);
+ARROW_EXPORT
+void CompareArraysDetailed(int index, const Array& result, const Array& expected);
- // Dictionary-encoded list of integer
- auto f4_value_type = list(int8());
+ARROW_EXPORT
+void CompareBatchColumnsDetailed(const RecordBatch& result, const RecordBatch& expected);
- std::shared_ptr<Array> offsets4, values4, indices4;
+ARROW_EXPORT
+Status MakeRandomInt32Array(int64_t length, bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out, uint32_t seed = 0);
- std::vector<int32_t> list_offsets4 = {0, 2, 2, 3};
- ArrayFromVector<Int32Type, int32_t>(std::vector<bool>(4, true), list_offsets4,
- &offsets4);
+ARROW_EXPORT
+Status MakeRandomListArray(const std::shared_ptr<Array>& child_array, int num_lists,
+ bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out);
- std::vector<int8_t> list_values4 = {0, 1, 2};
- ArrayFromVector<Int8Type, int8_t>(std::vector<bool>(3, true), list_values4, &values4);
+ARROW_EXPORT
+Status MakeRandomBooleanArray(const int length, bool include_nulls,
+ std::shared_ptr<Array>* out);
- auto dict3 = std::make_shared<ListArray>(
- f4_value_type, 3, std::static_pointer_cast<PrimitiveArray>(offsets4)->values(),
- values4);
+ARROW_EXPORT
+Status MakeBooleanBatchSized(const int length, std::shared_ptr<RecordBatch>* out);
- std::vector<int8_t> indices4_values = {0, 1, 2, 0, 1, 2};
- ArrayFromVector<Int8Type, int8_t>(is_valid, indices4_values, &indices4);
+ARROW_EXPORT
+Status MakeBooleanBatch(std::shared_ptr<RecordBatch>* out);
- auto f4_type = dictionary(int8(), dict3);
- auto a4 = std::make_shared<DictionaryArray>(f4_type, indices4);
-
- // construct batch
- auto schema = ::arrow::schema(
- {field("dict1", f0_type), field("sparse", f1_type), field("dense", f2_type),
- field("list of encoded string", f3_type), field("encoded list<int8>", f4_type)});
-
- std::vector<std::shared_ptr<Array>> arrays = {a0, a1, a2, a3, a4};
-
- *out = RecordBatch::Make(schema, length, arrays);
- return Status::OK();
-}
-
-static inline Status MakeDictionaryFlat(std::shared_ptr<RecordBatch>* out) {
- const int64_t length = 6;
-
- std::vector<bool> is_valid = {true, true, false, true, true, true};
- std::shared_ptr<Array> dict1, dict2;
-
- std::vector<std::string> dict1_values = {"foo", "bar", "baz"};
- std::vector<std::string> dict2_values = {"foo", "bar", "baz", "qux"};
-
- ArrayFromVector<StringType, std::string>(dict1_values, &dict1);
- ArrayFromVector<StringType, std::string>(dict2_values, &dict2);
-
- auto f0_type = arrow::dictionary(arrow::int32(), dict1);
- auto f1_type = arrow::dictionary(arrow::int8(), dict1);
- auto f2_type = arrow::dictionary(arrow::int32(), dict2);
-
- std::shared_ptr<Array> indices0, indices1, indices2;
- std::vector<int32_t> indices0_values = {1, 2, -1, 0, 2, 0};
- std::vector<int8_t> indices1_values = {0, 0, 2, 2, 1, 1};
- std::vector<int32_t> indices2_values = {3, 0, 2, 1, 0, 2};
-
- ArrayFromVector<Int32Type, int32_t>(is_valid, indices0_values, &indices0);
- ArrayFromVector<Int8Type, int8_t>(is_valid, indices1_values, &indices1);
- ArrayFromVector<Int32Type, int32_t>(is_valid, indices2_values, &indices2);
-
- auto a0 = std::make_shared<DictionaryArray>(f0_type, indices0);
- auto a1 = std::make_shared<DictionaryArray>(f1_type, indices1);
- auto a2 = std::make_shared<DictionaryArray>(f2_type, indices2);
-
- // construct batch
- auto schema = ::arrow::schema(
- {field("dict1", f0_type), field("sparse", f1_type), field("dense", f2_type)});
-
- std::vector<std::shared_ptr<Array>> arrays = {a0, a1, a2};
- *out = RecordBatch::Make(schema, length, arrays);
- return Status::OK();
-}
-
-static inline Status MakeDates(std::shared_ptr<RecordBatch>* out) {
- std::vector<bool> is_valid = {true, true, true, false, true, true, true};
- auto f0 = field("f0", date32());
- auto f1 = field("f1", date64());
- auto schema = ::arrow::schema({f0, f1});
-
- std::vector<int32_t> date32_values = {0, 1, 2, 3, 4, 5, 6};
- std::shared_ptr<Array> date32_array;
- ArrayFromVector<Date32Type, int32_t>(is_valid, date32_values, &date32_array);
-
- std::vector<int64_t> date64_values = {1489269000000, 1489270000000, 1489271000000,
- 1489272000000, 1489272000000, 1489273000000,
- 1489274000000};
- std::shared_ptr<Array> date64_array;
- ArrayFromVector<Date64Type, int64_t>(is_valid, date64_values, &date64_array);
-
- *out = RecordBatch::Make(schema, date32_array->length(), {date32_array, date64_array});
- return Status::OK();
-}
-
-static inline Status MakeTimestamps(std::shared_ptr<RecordBatch>* out) {
- std::vector<bool> is_valid = {true, true, true, false, true, true, true};
- auto f0 = field("f0", timestamp(TimeUnit::MILLI));
- auto f1 = field("f1", timestamp(TimeUnit::NANO, "America/New_York"));
- auto f2 = field("f2", timestamp(TimeUnit::SECOND));
- auto schema = ::arrow::schema({f0, f1, f2});
+ARROW_EXPORT
+Status MakeIntBatchSized(int length, std::shared_ptr<RecordBatch>* out,
+ uint32_t seed = 0);
- std::vector<int64_t> ts_values = {1489269000000, 1489270000000, 1489271000000,
- 1489272000000, 1489272000000, 1489273000000};
+ARROW_EXPORT
+Status MakeIntRecordBatch(std::shared_ptr<RecordBatch>* out);
- std::shared_ptr<Array> a0, a1, a2;
- ArrayFromVector<TimestampType, int64_t>(f0->type(), is_valid, ts_values, &a0);
- ArrayFromVector<TimestampType, int64_t>(f1->type(), is_valid, ts_values, &a1);
- ArrayFromVector<TimestampType, int64_t>(f2->type(), is_valid, ts_values, &a2);
+ARROW_EXPORT
+Status MakeRandomStringArray(int64_t length, bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out);
- *out = RecordBatch::Make(schema, a0->length(), {a0, a1, a2});
- return Status::OK();
-}
+ARROW_EXPORT
+Status MakeStringTypesRecordBatch(std::shared_ptr<RecordBatch>* out,
+ bool with_nulls = true);
-static inline Status MakeTimes(std::shared_ptr<RecordBatch>* out) {
- std::vector<bool> is_valid = {true, true, true, false, true, true, true};
- auto f0 = field("f0", time32(TimeUnit::MILLI));
- auto f1 = field("f1", time64(TimeUnit::NANO));
- auto f2 = field("f2", time32(TimeUnit::SECOND));
- auto f3 = field("f3", time64(TimeUnit::NANO));
- auto schema = ::arrow::schema({f0, f1, f2, f3});
+ARROW_EXPORT
+Status MakeStringTypesRecordBatchWithNulls(std::shared_ptr<RecordBatch>* out);
- std::vector<int32_t> t32_values = {1489269000, 1489270000, 1489271000,
- 1489272000, 1489272000, 1489273000};
- std::vector<int64_t> t64_values = {1489269000000, 1489270000000, 1489271000000,
- 1489272000000, 1489272000000, 1489273000000};
-
- std::shared_ptr<Array> a0, a1, a2, a3;
- ArrayFromVector<Time32Type, int32_t>(f0->type(), is_valid, t32_values, &a0);
- ArrayFromVector<Time64Type, int64_t>(f1->type(), is_valid, t64_values, &a1);
- ArrayFromVector<Time32Type, int32_t>(f2->type(), is_valid, t32_values, &a2);
- ArrayFromVector<Time64Type, int64_t>(f3->type(), is_valid, t64_values, &a3);
+ARROW_EXPORT
+Status MakeNullRecordBatch(std::shared_ptr<RecordBatch>* out);
- *out = RecordBatch::Make(schema, a0->length(), {a0, a1, a2, a3});
- return Status::OK();
-}
+ARROW_EXPORT
+Status MakeListRecordBatch(std::shared_ptr<RecordBatch>* out);
-template <typename BuilderType, typename T>
-void AppendValues(const std::vector<bool>& is_valid, const std::vector<T>& values,
- BuilderType* builder) {
- for (size_t i = 0; i < values.size(); ++i) {
- if (is_valid[i]) {
- ASSERT_OK(builder->Append(values[i]));
- } else {
- ASSERT_OK(builder->AppendNull());
- }
- }
-}
-
-static inline Status MakeFWBinary(std::shared_ptr<RecordBatch>* out) {
- std::vector<bool> is_valid = {true, true, true, false};
- auto f0 = field("f0", fixed_size_binary(4));
- auto f1 = field("f1", fixed_size_binary(0));
- auto schema = ::arrow::schema({f0, f1});
-
- std::shared_ptr<Array> a1, a2;
-
- FixedSizeBinaryBuilder b1(f0->type());
- FixedSizeBinaryBuilder b2(f1->type());
-
- std::vector<std::string> values1 = {"foo1", "foo2", "foo3", "foo4"};
- AppendValues(is_valid, values1, &b1);
-
- std::vector<std::string> values2 = {"", "", "", ""};
- AppendValues(is_valid, values2, &b2);
-
- RETURN_NOT_OK(b1.Finish(&a1));
- RETURN_NOT_OK(b2.Finish(&a2));
-
- *out = RecordBatch::Make(schema, a1->length(), {a1, a2});
- return Status::OK();
-}
-
-static inline Status MakeDecimal(std::shared_ptr<RecordBatch>* out) {
- constexpr int kDecimalPrecision = 38;
- auto type = decimal(kDecimalPrecision, 4);
- auto f0 = field("f0", type);
- auto f1 = field("f1", type);
- auto schema = ::arrow::schema({f0, f1});
+ARROW_EXPORT
+Status MakeZeroLengthRecordBatch(std::shared_ptr<RecordBatch>* out);
- constexpr int kDecimalSize = 16;
- constexpr int length = 10;
-
- std::shared_ptr<Buffer> data, is_valid;
- std::vector<uint8_t> is_valid_bytes(length);
-
- RETURN_NOT_OK(AllocateBuffer(kDecimalSize * length, &data));
+ARROW_EXPORT
+Status MakeNonNullRecordBatch(std::shared_ptr<RecordBatch>* out);
- random_decimals(length, 1, kDecimalPrecision, data->mutable_data());
- random_null_bytes(length, 0.1, is_valid_bytes.data());
+ARROW_EXPORT
+Status MakeDeeplyNestedList(std::shared_ptr<RecordBatch>* out);
- RETURN_NOT_OK(BitUtil::BytesToBits(is_valid_bytes, default_memory_pool(), &is_valid));
+ARROW_EXPORT
+Status MakeStruct(std::shared_ptr<RecordBatch>* out);
- auto a1 = std::make_shared<Decimal128Array>(f0->type(), length, data, is_valid,
- kUnknownNullCount);
+ARROW_EXPORT
+Status MakeUnion(std::shared_ptr<RecordBatch>* out);
- auto a2 = std::make_shared<Decimal128Array>(f1->type(), length, data);
+ARROW_EXPORT
+Status MakeDictionary(std::shared_ptr<RecordBatch>* out);
- *out = RecordBatch::Make(schema, length, {a1, a2});
- return Status::OK();
-}
+ARROW_EXPORT
+Status MakeDictionaryFlat(std::shared_ptr<RecordBatch>* out);
-static inline Status MakeNull(std::shared_ptr<RecordBatch>* out) {
- auto f0 = field("f0", null());
+ARROW_EXPORT
+Status MakeDates(std::shared_ptr<RecordBatch>* out);
- // Also put a non-null field to make sure we handle the null array buffers properly
- auto f1 = field("f1", int64());
+ARROW_EXPORT
+Status MakeTimestamps(std::shared_ptr<RecordBatch>* out);
- auto schema = ::arrow::schema({f0, f1});
+ARROW_EXPORT
+Status MakeTimes(std::shared_ptr<RecordBatch>* out);
- auto a1 = std::make_shared<NullArray>(10);
+ARROW_EXPORT
+Status MakeFWBinary(std::shared_ptr<RecordBatch>* out);
- std::vector<int64_t> int_values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- std::vector<bool> is_valid = {true, true, true, false, false,
- true, true, true, true, true};
- std::shared_ptr<Array> a2;
- ArrayFromVector<Int64Type, int64_t>(f1->type(), is_valid, int_values, &a2);
+ARROW_EXPORT
+Status MakeDecimal(std::shared_ptr<RecordBatch>* out);
- *out = RecordBatch::Make(schema, a1->length(), {a1, a2});
- return Status::OK();
-}
+ARROW_EXPORT
+Status MakeNull(std::shared_ptr<RecordBatch>* out);
+} // namespace test
} // namespace ipc
} // namespace arrow
diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc
index ba99390..bc89dc4 100644
--- a/cpp/src/arrow/ipc/writer.cc
+++ b/cpp/src/arrow/ipc/writer.cc
@@ -22,6 +22,7 @@
#include <cstring>
#include <limits>
#include <sstream>
+#include <utility>
#include <vector>
#include "arrow/array.h"
@@ -43,12 +44,14 @@
#include "arrow/util/bit-util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
+#include "arrow/util/stl.h"
#include "arrow/visitor.h"
namespace arrow {
using internal::checked_cast;
using internal::CopyBitmap;
+using internal::make_unique;
namespace ipc {
@@ -529,17 +532,46 @@ Status WriteIpcPayload(const IpcPayload& payload, io::OutputStream* dst,
return Status::OK();
}
-Status GetSchemaPayload(const Schema& schema, MemoryPool* pool,
- DictionaryMemo* dictionary_memo, IpcPayload* out) {
- out->type = Message::Type::SCHEMA;
- out->body_buffers.clear();
- out->body_length = 0;
- RETURN_NOT_OK(SerializeSchema(schema, pool, &out->metadata));
- return WriteSchemaMessage(schema, dictionary_memo, &out->metadata);
+/// \brief Compute the IPC payloads describing a schema: one SCHEMA payload
+/// followed by one DICTIONARY_BATCH payload per dictionary found in the schema
+///
+/// \param[in] schema the Schema to serialize
+/// \param[in] pool handed to the dictionary writer
+/// \param[out] out_memo if non-null, receives the DictionaryMemo that was
+/// filled in while writing the schema message
+/// \param[out] out_payloads cleared first, then filled with the schema payload
+/// followed by the dictionary payloads
+/// \return Status
+Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool, DictionaryMemo* out_memo,
+                         std::vector<IpcPayload>* out_payloads) {
+  DictionaryMemo dictionary_memo;
+
+  out_payloads->clear();
+
+  IpcPayload schema_payload;
+  schema_payload.type = Message::SCHEMA;
+  RETURN_NOT_OK(WriteSchemaMessage(schema, &dictionary_memo, &schema_payload.metadata));
+  // The number of dictionaries is only known once the schema message has been
+  // written, so reserve here, before the first push_back.
+  out_payloads->reserve(dictionary_memo.size() + 1);
+  out_payloads->push_back(std::move(schema_payload));
+
+  // Append dictionaries
+  for (const auto& pair : dictionary_memo.id_to_dictionary()) {
+    const int64_t dictionary_id = pair.first;
+    const auto& dictionary = pair.second;
+
+    // Frame of reference is 0, see ARROW-384
+    const int64_t buffer_start_offset = 0;
+    // Use a fresh payload for every dictionary rather than reusing one that
+    // has been moved from: a moved-from IpcPayload keeps its previous
+    // body_length and its other members are in a moved-from state.
+    IpcPayload dictionary_payload;
+    dictionary_payload.type = Message::DICTIONARY_BATCH;
+    DictionaryWriter writer(dictionary_id, pool, buffer_start_offset, kMaxNestingDepth,
+                            true /* allow_64bit */, &dictionary_payload);
+    RETURN_NOT_OK(writer.Assemble(dictionary));
+    out_payloads->push_back(std::move(dictionary_payload));
+  }
+
+  if (out_memo != nullptr) {
+    *out_memo = std::move(dictionary_memo);
+  }
+
+  return Status::OK();
+}
+
+// Convenience overload for callers that do not need the DictionaryMemo:
+// forwards to the four-argument version with a null out_memo.
+Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool,
+ std::vector<IpcPayload>* out_payloads) {
+ return GetSchemaPayloads(schema, pool, nullptr, out_payloads);
}
Status GetRecordBatchPayload(const RecordBatch& batch, MemoryPool* pool,
IpcPayload* out) {
+ // Tag the payload with its message type: PayloadFileWriter::WritePayload
+ // switches on payload.type to pick the footer list a block is recorded in.
+ out->type = Message::RECORD_BATCH;
RecordBatchSerializer writer(pool, 0, kMaxNestingDepth, true, out);
return writer.Assemble(batch);
}
@@ -846,11 +878,93 @@ Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize)
Status RecordBatchWriter::WriteTable(const Table& table) { return WriteTable(table, -1); }
// ----------------------------------------------------------------------
-// Stream writer implementation
+// Payload writer implementation
+
+namespace internal {
+
+// Out-of-line definition of the virtual destructor declared in writer.h.
+IpcPayloadWriter::~IpcPayloadWriter() {}
+
+// Default Start(): does nothing and reports success. Implementations that
+// need a preamble override it (PayloadFileWriter writes the magic bytes).
+Status IpcPayloadWriter::Start() { return Status::OK(); }
+
+} // namespace internal
+
+namespace {
+
+/// A RecordBatchWriter implementation that writes to an IpcPayloadWriter.
+///
+/// The schema payloads are emitted lazily: the first WriteRecordBatch() or
+/// Close() call goes through CheckStarted(), which runs Start() once.
+class RecordBatchPayloadWriter : public RecordBatchWriter {
+ public:
+ ~RecordBatchPayloadWriter() override = default;
+
+ // Non-owning constructor variant: only a reference to `schema` is kept
+ // (schema_ is a `const Schema&`), so the caller must keep the Schema alive
+ // for as long as this writer is used.
+ RecordBatchPayloadWriter(std::unique_ptr<internal::IpcPayloadWriter> payload_writer,
+ const Schema& schema)
+ : payload_writer_(std::move(payload_writer)),
+ schema_(schema),
+ pool_(default_memory_pool()),
+ started_(false) {}
+
+ // A Schema-owning constructor variant
+ // (shared_schema_ holds the shared_ptr; schema_ refers to the object it owns)
+ RecordBatchPayloadWriter(std::unique_ptr<internal::IpcPayloadWriter> payload_writer,
+ const std::shared_ptr<Schema>& schema)
+ : payload_writer_(std::move(payload_writer)),
+ shared_schema_(schema),
+ schema_(*schema),
+ pool_(default_memory_pool()),
+ started_(false) {}
+
+ // Serialize `batch` into an IpcPayload and hand it to the payload writer.
+ // Returns Status::Invalid if the batch schema differs from the writer's
+ // schema (metadata is not compared).
+ // NOTE(review): `allow_64bit` is not used anywhere in this body; it is not
+ // forwarded to GetRecordBatchPayload() -- confirm this is intended.
+ Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override {
+ if (!batch.schema()->Equals(schema_, false /* check_metadata */)) {
+ return Status::Invalid("Tried to write record batch with different schema");
+ }
+
+ RETURN_NOT_OK(CheckStarted());
+ internal::IpcPayload payload;
+ RETURN_NOT_OK(GetRecordBatchPayload(batch, pool_, &payload));
+ return payload_writer_->WritePayload(payload);
+ }
+
+ // Runs Start() first if it has not run yet, so the schema payloads are
+ // written even when no record batch was written, then closes the
+ // underlying payload writer.
+ Status Close() override {
+ RETURN_NOT_OK(CheckStarted());
+ return payload_writer_->Close();
+ }
+
+ void set_memory_pool(MemoryPool* pool) override { pool_ = pool; }
+
+ // Start the payload writer, then write the schema payload and the
+ // dictionary payloads produced by GetSchemaPayloads().
+ // NOTE(review): started_ is set before any of the calls below can fail, so
+ // CheckStarted() will not call Start() again after an error here.
+ Status Start() {
+ started_ = true;
+ RETURN_NOT_OK(payload_writer_->Start());
+
+ // Write out schema payloads
+ std::vector<internal::IpcPayload> payloads;
+ // XXX should we have a GetSchemaPayloads() variant that generates them
+ // one by one, to minimize memory usage?
+ RETURN_NOT_OK(GetSchemaPayloads(schema_, pool_, &payloads));
+ for (const auto& payload : payloads) {
+ RETURN_NOT_OK(payload_writer_->WritePayload(payload));
+ }
+ return Status::OK();
+ }
+
+ protected:
+ // Calls Start() unless started_ is already set.
+ Status CheckStarted() {
+ if (!started_) {
+ return Start();
+ }
+ return Status::OK();
+ }
+
+ protected:
+ std::unique_ptr<internal::IpcPayloadWriter> payload_writer_;
+ // Only set by the Schema-owning constructor; declared before schema_ so it
+ // is initialized before schema_ is bound to *schema.
+ std::shared_ptr<Schema> shared_schema_;
+ const Schema& schema_;
+ MemoryPool* pool_;
+ bool started_;
+};
+
+// ----------------------------------------------------------------------
+// Stream and file writer implementation
class StreamBookKeeper {
public:
- StreamBookKeeper() : sink_(nullptr), position_(-1) {}
explicit StreamBookKeeper(io::OutputStream* sink) : sink_(sink), position_(-1) {}
Status UpdatePosition() { return sink_->Tell(&position_); }
@@ -883,142 +997,131 @@ class StreamBookKeeper {
int64_t position_;
};
-class SchemaWriter : public StreamBookKeeper {
+/// An IpcPayloadWriter implementation that writes to an IPC stream
+/// (with an end-of-stream marker)
+class PayloadStreamWriter : public internal::IpcPayloadWriter,
+ protected StreamBookKeeper {
public:
- SchemaWriter(const Schema& schema, DictionaryMemo* dictionary_memo, MemoryPool* pool,
- io::OutputStream* sink)
- : StreamBookKeeper(sink),
- pool_(pool),
- schema_(schema),
- dictionary_memo_(dictionary_memo) {}
+ explicit PayloadStreamWriter(io::OutputStream* sink) : StreamBookKeeper(sink) {}
+
+ ~PayloadStreamWriter() override = default;
- Status WriteSchema() {
+ Status WritePayload(const internal::IpcPayload& payload) override {
#ifndef NDEBUG
// Catch bug fixed in ARROW-3236
RETURN_NOT_OK(UpdatePositionCheckAligned());
#endif
- std::shared_ptr<Buffer> schema_fb;
- RETURN_NOT_OK(internal::WriteSchemaMessage(schema_, dictionary_memo_, &schema_fb));
-
- int32_t metadata_length = 0;
- RETURN_NOT_OK(internal::WriteMessage(*schema_fb, 8, sink_, &metadata_length));
+ int32_t metadata_length = 0; // unused
+ RETURN_NOT_OK(WriteIpcPayload(payload, sink_, &metadata_length));
RETURN_NOT_OK(UpdatePositionCheckAligned());
return Status::OK();
}
- Status WriteDictionaries(std::vector<FileBlock>* dictionaries) {
- const DictionaryMap& id_to_dictionary = dictionary_memo_->id_to_dictionary();
-
- dictionaries->resize(id_to_dictionary.size());
-
- // TODO(wesm): does sorting by id yield any benefit?
- int dict_index = 0;
- for (const auto& entry : id_to_dictionary) {
- FileBlock* block = &(*dictionaries)[dict_index++];
-
- block->offset = position_;
-
- // Frame of reference in file format is 0, see ARROW-384
- const int64_t buffer_start_offset = 0;
- RETURN_NOT_OK(WriteDictionary(entry.first, entry.second, buffer_start_offset, sink_,
- &block->metadata_length, &block->body_length, pool_));
- RETURN_NOT_OK(UpdatePositionCheckAligned());
- }
-
- return Status::OK();
- }
-
- Status Write(std::vector<FileBlock>* dictionaries) {
- RETURN_NOT_OK(WriteSchema());
-
- // If there are any dictionaries, write them as the next messages
- return WriteDictionaries(dictionaries);
+ Status Close() override {
+ // Write 0 EOS message
+ const int32_t kEos = 0;
+ return Write(&kEos, sizeof(int32_t));
}
-
- private:
- MemoryPool* pool_;
- const Schema& schema_;
- DictionaryMemo* dictionary_memo_;
};
-class RecordBatchStreamWriter::RecordBatchStreamWriterImpl : public StreamBookKeeper {
+/// An IpcPayloadWriter implementation that writes to an IPC file
+/// (with a footer as defined in File.fbs)
+class PayloadFileWriter : public internal::IpcPayloadWriter, protected StreamBookKeeper {
public:
- RecordBatchStreamWriterImpl(io::OutputStream* sink,
- const std::shared_ptr<Schema>& schema)
- : StreamBookKeeper(sink),
- schema_(schema),
- pool_(default_memory_pool()),
- started_(false) {}
+ PayloadFileWriter(io::OutputStream* sink, const std::shared_ptr<Schema>& schema)
+ : StreamBookKeeper(sink), schema_(schema) {}
- virtual ~RecordBatchStreamWriterImpl() = default;
+ ~PayloadFileWriter() override = default;
- virtual Status Start() {
- SchemaWriter schema_writer(*schema_, &dictionary_memo_, pool_, sink_);
- RETURN_NOT_OK(schema_writer.Write(&dictionaries_));
- started_ = true;
- return Status::OK();
- }
-
- virtual Status Close() {
- // Write the schema if not already written
- // User is responsible for closing the OutputStream
- RETURN_NOT_OK(CheckStarted());
+ Status WritePayload(const internal::IpcPayload& payload) override {
+#ifndef NDEBUG
+ // Catch bug fixed in ARROW-3236
+ RETURN_NOT_OK(UpdatePositionCheckAligned());
+#endif
- // Write 0 EOS message
- const int32_t kEos = 0;
- return Write(&kEos, sizeof(int32_t));
- }
+ // Metadata length must include padding, it's computed by WriteIpcPayload()
+ FileBlock block = {position_, 0, payload.body_length};
+ RETURN_NOT_OK(WriteIpcPayload(payload, sink_, &block.metadata_length));
+ RETURN_NOT_OK(UpdatePositionCheckAligned());
- Status CheckStarted() {
- if (!started_) {
- return Start();
+ // Record position and size of some message types, to list them in the footer
+ switch (payload.type) {
+ case Message::DICTIONARY_BATCH:
+ dictionaries_.push_back(block);
+ break;
+ case Message::RECORD_BATCH:
+ record_batches_.push_back(block);
+ break;
+ default:
+ break;
}
+
return Status::OK();
}
- Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit, FileBlock* block) {
- RETURN_NOT_OK(CheckStarted());
+ Status Start() override {
+ // ARROW-3236: The initial position -1 needs to be updated to the stream's
+ // current position otherwise an incorrect amount of padding will be
+ // written to new files.
RETURN_NOT_OK(UpdatePosition());
- block->offset = position_;
-
- // Frame of reference in file format is 0, see ARROW-384
- const int64_t buffer_start_offset = 0;
- RETURN_NOT_OK(arrow::ipc::WriteRecordBatch(
- batch, buffer_start_offset, sink_, &block->metadata_length, &block->body_length,
- pool_, kMaxNestingDepth, allow_64bit));
- RETURN_NOT_OK(UpdatePositionCheckAligned());
+ // It is only necessary to align to 8-byte boundary at the start of the file
+ RETURN_NOT_OK(Write(kArrowMagicBytes, strlen(kArrowMagicBytes)));
+ RETURN_NOT_OK(Align());
return Status::OK();
}
- Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) {
- // Push an empty FileBlock. Can be written in the footer later
- if (!batch.schema()->Equals(*schema_, false /* check_metadata */)) {
- return Status::Invalid("Tried to write record batch with different schema");
+ Status Close() override {
+ // Write file footer
+ RETURN_NOT_OK(UpdatePosition());
+ int64_t initial_position = position_;
+ RETURN_NOT_OK(WriteFileFooter(*schema_, dictionaries_, record_batches_, sink_));
+
+ // Write footer length
+ RETURN_NOT_OK(UpdatePosition());
+ int32_t footer_length = static_cast<int32_t>(position_ - initial_position);
+ if (footer_length <= 0) {
+ return Status::Invalid("Invalid file footer");
}
- record_batches_.push_back({0, 0, 0});
- return WriteRecordBatch(batch, allow_64bit,
- &record_batches_[record_batches_.size() - 1]);
- }
+ RETURN_NOT_OK(Write(&footer_length, sizeof(int32_t)));
- void set_memory_pool(MemoryPool* pool) { pool_ = pool; }
+ // Write magic bytes to end file
+ return Write(kArrowMagicBytes, strlen(kArrowMagicBytes));
+ }
protected:
std::shared_ptr<Schema> schema_;
- MemoryPool* pool_;
- bool started_;
-
- // When writing out the schema, we keep track of all the dictionaries we
- // encounter, as they must be written out first in the stream
- DictionaryMemo dictionary_memo_;
-
std::vector<FileBlock> dictionaries_;
std::vector<FileBlock> record_batches_;
};
+} // namespace
+
+// Implementation class behind RecordBatchStreamWriter: a
+// RecordBatchPayloadWriter whose payloads go to a PayloadStreamWriter
+// wrapping `sink`.
+class RecordBatchStreamWriter::RecordBatchStreamWriterImpl
+ : public RecordBatchPayloadWriter {
+ public:
+ RecordBatchStreamWriterImpl(io::OutputStream* sink,
+ const std::shared_ptr<Schema>& schema)
+ : RecordBatchPayloadWriter(
+ std::unique_ptr<internal::IpcPayloadWriter>(new PayloadStreamWriter(sink)),
+ schema) {}
+
+ ~RecordBatchStreamWriterImpl() = default;
+};
+
+// Implementation class behind RecordBatchFileWriter: a
+// RecordBatchPayloadWriter whose payloads go to a PayloadFileWriter.
+// The schema is passed twice: PayloadFileWriter uses it when writing the
+// file footer in Close(), the base class uses it for the schema payloads.
+class RecordBatchFileWriter::RecordBatchFileWriterImpl : public RecordBatchPayloadWriter {
+ public:
+ RecordBatchFileWriterImpl(io::OutputStream* sink, const std::shared_ptr<Schema>& schema)
+ : RecordBatchPayloadWriter(std::unique_ptr<internal::IpcPayloadWriter>(
+ new PayloadFileWriter(sink, schema)),
+ schema) {}
+
+ ~RecordBatchFileWriterImpl() = default;
+};
+
RecordBatchStreamWriter::RecordBatchStreamWriter() {}
RecordBatchStreamWriter::~RecordBatchStreamWriter() {}
@@ -1044,59 +1147,6 @@ Status RecordBatchStreamWriter::Open(io::OutputStream* sink,
Status RecordBatchStreamWriter::Close() { return impl_->Close(); }
-// ----------------------------------------------------------------------
-// File writer implementation
-
-class RecordBatchFileWriter::RecordBatchFileWriterImpl
- : public RecordBatchStreamWriter::RecordBatchStreamWriterImpl {
- public:
- using BASE = RecordBatchStreamWriter::RecordBatchStreamWriterImpl;
-
- RecordBatchFileWriterImpl(io::OutputStream* sink, const std::shared_ptr<Schema>& schema)
- : BASE(sink, schema) {}
-
- Status Start() override {
- // ARROW-3236: The initial position -1 needs to be updated to the stream's
- // current position otherwise an incorrect amount of padding will be
- // written to new files.
- RETURN_NOT_OK(UpdatePosition());
-
- // It is only necessary to align to 8-byte boundary at the start of the file
- RETURN_NOT_OK(Write(kArrowMagicBytes, strlen(kArrowMagicBytes)));
- RETURN_NOT_OK(Align());
-
- // We write the schema at the start of the file (and the end). This also
- // writes all the dictionaries at the beginning of the file
- return BASE::Start();
- }
-
- Status Close() override {
- // Write the schema if not already written
- // User is responsible for closing the OutputStream
- RETURN_NOT_OK(CheckStarted());
-
- // Write metadata
- RETURN_NOT_OK(UpdatePosition());
-
- int64_t initial_position = position_;
- RETURN_NOT_OK(WriteFileFooter(*schema_, dictionaries_, record_batches_,
- &dictionary_memo_, sink_));
- RETURN_NOT_OK(UpdatePosition());
-
- // Write footer length
- int32_t footer_length = static_cast<int32_t>(position_ - initial_position);
-
- if (footer_length <= 0) {
- return Status::Invalid("Invalid file footer");
- }
-
- RETURN_NOT_OK(Write(&footer_length, sizeof(int32_t)));
-
- // Write magic bytes to end file
- return Write(kArrowMagicBytes, strlen(kArrowMagicBytes));
- }
-};
-
RecordBatchFileWriter::RecordBatchFileWriter() {}
RecordBatchFileWriter::~RecordBatchFileWriter() {}
@@ -1118,6 +1168,18 @@ Status RecordBatchFileWriter::WriteRecordBatch(const RecordBatch& batch,
Status RecordBatchFileWriter::Close() { return file_impl_->Close(); }
+namespace internal {
+
+// Wrap `sink` in a RecordBatchPayloadWriter (Schema-owning constructor) and
+// store it in `*out`. Start() is not called here: the writer starts itself
+// on the first WriteRecordBatch() or Close() call. Always returns OK.
+Status OpenRecordBatchWriter(std::unique_ptr<IpcPayloadWriter> sink,
+ const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<RecordBatchWriter>* out) {
+ out->reset(new RecordBatchPayloadWriter(std::move(sink), schema));
+ // XXX should we call Start()?
+ return Status::OK();
+}
+
+} // namespace internal
+
// ----------------------------------------------------------------------
// Serialization public APIs
@@ -1142,18 +1204,20 @@ Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool,
kMaxNestingDepth, true);
}
+// TODO: this function also serializes dictionaries. This is suboptimal for
+// the purpose of transmitting working set metadata without actually sending
+// the data (e.g. ListFlights() in Flight RPC).
+
Status SerializeSchema(const Schema& schema, MemoryPool* pool,
std::shared_ptr<Buffer>* out) {
std::shared_ptr<io::BufferOutputStream> stream;
RETURN_NOT_OK(io::BufferOutputStream::Create(1024, pool, &stream));
- DictionaryMemo memo;
- SchemaWriter schema_writer(schema, &memo, pool, stream.get());
-
- // Unused
- std::vector<FileBlock> dictionary_blocks;
+ auto payload_writer = make_unique<PayloadStreamWriter>(stream.get());
+ RecordBatchPayloadWriter writer(std::move(payload_writer), schema);
+ // Write out schema and dictionaries
+ RETURN_NOT_OK(writer.Start());
- RETURN_NOT_OK(schema_writer.Write(&dictionary_blocks));
return stream->Finish(out);
}
diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h
index 50872e9..75034ea 100644
--- a/cpp/src/arrow/ipc/writer.h
+++ b/cpp/src/arrow/ipc/writer.h
@@ -302,28 +302,48 @@ namespace internal {
// Intermediate data structure with metadata header, and zero or more buffers
// for the message body.
struct IpcPayload {
- Message::Type type;
+ Message::Type type = Message::NONE;
std::shared_ptr<Buffer> metadata;
std::vector<std::shared_ptr<Buffer>> body_buffers;
- int64_t body_length;
+ int64_t body_length = 0;
};
-/// \brief Extract IPC payloads from given schema for purposes of wire
-/// transport, separate from using the *StreamWriter classes
+class ARROW_EXPORT IpcPayloadWriter {
+ public:
+ virtual ~IpcPayloadWriter();
+
+ // Default implementation is a no-op
+ virtual Status Start();
+
+ virtual Status WritePayload(const IpcPayload& payload) = 0;
+
+ virtual Status Close() = 0;
+};
+
+/// Create a new RecordBatchWriter from IpcPayloadWriter and schema.
+///
+/// \param[in] sink the IpcPayloadWriter to write to
+/// \param[in] schema the schema of the record batches to be written
+/// \param[out] out the created RecordBatchWriter
+/// \return Status
ARROW_EXPORT
-Status GetDictionaryPayloads(const Schema& schema,
- std::vector<std::unique_ptr<IpcPayload>>* out);
+Status OpenRecordBatchWriter(std::unique_ptr<IpcPayloadWriter> sink,
+ const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<RecordBatchWriter>* out);
-/// \brief Compute IpcPayload for the given schema
+/// \brief Compute IpcPayloads for the given schema
/// \param[in] schema the Schema that is being serialized
/// \param[in,out] pool for any required temporary memory allocations
/// \param[in,out] dictionary_memo class for tracking dictionaries and assigning
/// dictionary ids
-/// \param[out] out the returned IpcPayload
+/// \param[out] out the returned vector of IpcPayloads
/// \return Status
ARROW_EXPORT
-Status GetSchemaPayload(const Schema& schema, MemoryPool* pool,
- DictionaryMemo* dictionary_memo, IpcPayload* out);
+Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool,
+ DictionaryMemo* dictionary_memo, std::vector<IpcPayload>* out);
+ARROW_EXPORT
+Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool,
+ std::vector<IpcPayload>* out);
/// \brief Compute IpcPayload for the given record batch
/// \param[in] batch the RecordBatch that is being serialized
diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc
index f5810a7..01fb29d 100644
--- a/cpp/src/arrow/python/arrow_to_pandas.cc
+++ b/cpp/src/arrow/python/arrow_to_pandas.cc
@@ -1889,11 +1889,8 @@ class ArrowDeserializer {
return Status::OK();
}
- Status Visit(const UnionType& type) { return Status::NotImplemented("union type"); }
-
- Status Visit(const ExtensionType& type) {
- return Status::NotImplemented("extension type");
- }
+ // Default case
+ Status Visit(const DataType& type) { return Status::NotImplemented(type.name()); }
Status Convert(PyObject** out) {
RETURN_NOT_OK(VisitTypeInline(*col_->type(), this));
diff --git a/cpp/src/arrow/python/numpy_to_arrow.cc b/cpp/src/arrow/python/numpy_to_arrow.cc
index 36f3ccb..ca3f596 100644
--- a/cpp/src/arrow/python/numpy_to_arrow.cc
+++ b/cpp/src/arrow/python/numpy_to_arrow.cc
@@ -234,13 +234,8 @@ class NumPyConverter {
Status Visit(const FixedSizeBinaryType& type);
- Status Visit(const Decimal128Type& type) { return TypeNotImplemented(type.ToString()); }
-
- Status Visit(const DictionaryType& type) { return TypeNotImplemented(type.ToString()); }
-
- Status Visit(const NestedType& type) { return TypeNotImplemented(type.ToString()); }
-
- Status Visit(const ExtensionType& type) { return TypeNotImplemented(type.ToString()); }
+ // Default case
+ Status Visit(const DataType& type) { return TypeNotImplemented(type.ToString()); }
protected:
Status InitNullBitmap() {
diff --git a/cpp/src/gandiva/expression_registry.cc b/cpp/src/gandiva/expression_registry.cc
index 8e667f8..d062963 100644
--- a/cpp/src/gandiva/expression_registry.cc
+++ b/cpp/src/gandiva/expression_registry.cc
@@ -139,15 +139,8 @@ void ExpressionRegistry::AddArrowTypesToVector(arrow::Type::type& type,
case arrow::Type::type::DECIMAL:
vector.push_back(arrow::decimal(38, 0));
break;
- case arrow::Type::type::FIXED_SIZE_BINARY:
- case arrow::Type::type::MAP:
- case arrow::Type::type::INTERVAL:
- case arrow::Type::type::LIST:
- case arrow::Type::type::STRUCT:
- case arrow::Type::type::UNION:
- case arrow::Type::type::DICTIONARY:
- case arrow::Type::type::EXTENSION:
- // un-supported types. test ensures that
+ default:
+ // Unsupported types. test ensures that
// when one of these are added build breaks.
DCHECK(false);
}
diff --git a/cpp/src/gandiva/jni/expression_registry_helper.cc b/cpp/src/gandiva/jni/expression_registry_helper.cc
index 2275641..7b7834d 100644
--- a/cpp/src/gandiva/jni/expression_registry_helper.cc
+++ b/cpp/src/gandiva/jni/expression_registry_helper.cc
@@ -127,14 +127,7 @@ void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType* gandiva_data_type)
gandiva_data_type->set_scale(0);
break;
}
- case arrow::Type::type::FIXED_SIZE_BINARY:
- case arrow::Type::type::MAP:
- case arrow::Type::type::INTERVAL:
- case arrow::Type::type::LIST:
- case arrow::Type::type::STRUCT:
- case arrow::Type::type::UNION:
- case arrow::Type::type::DICTIONARY:
- case arrow::Type::type::EXTENSION:
+ default:
// un-supported types. test ensures that
// when one of these are added build breaks.
DCHECK(false);