You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/06/26 19:19:47 UTC
[arrow] branch master updated: ARROW-4626: [Flight] Add
application-defined metadata to DoGet/DoPut
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 63971ad ARROW-4626: [Flight] Add application-defined metadata to DoGet/DoPut
63971ad is described below
commit 63971ad645919a4b97818b58686f25e04e39699f
Author: David Li <li...@gmail.com>
AuthorDate: Wed Jun 26 14:19:36 2019 -0500
ARROW-4626: [Flight] Add application-defined metadata to DoGet/DoPut
Also covers [ARROW-4627](https://issues.apache.org/jira/browse/ARROW-4627).
This is quite an enormous change, if preferred, I can do my best to try and separate changes.
Author: David Li <li...@gmail.com>
Closes #4282 from lihalite/arrow-4626-application-metadata and squashes the following commits:
6f1cd8db7 <David Li> Rework interface for accessing server-sent metadata during DoPut
8fd99cd9b <David Li> Inline CompletableFuture in Flight acceptPut
4cebc543a <David Li> Mark flaky Flight test
c551d8527 <David Li> Fix new CheckStyle violations
85e21699c <David Li> Fix Flight integration tests using metadata
eff22393c <David Li> Use FlightStreamChunk in Flight/C++
72c2a3fa0 <David Li> Try to always close FlightStream after acceptPut
1718d9b42 <David Li> Make FlightStream cancellable from acceptPut
7ac44df79 <David Li> Make Netty version consistent with gRPC
1225b67af <David Li> Use ArrowBuf instead of byte for Flight metadata
ccfef2d1e <David Li> Disable Flight cancellation tests in CI
0484c333c <David Li> Pass Flight context to ListActions in Python
b0f71d967 <David Li> Replace ARROW_EXPORT with ARROW_FLIGHT_EXPORT
fdaa76e99 <David Li> Add client-side cancelation of DoGet operations
b4dbc445e <David Li> Enable non-nested dictionary batches in Flight integration tests
f7631a2fd <David Li> Add basic Arrow Flight docs
a8ac27fb3 <David Li> Implement application metadata in Flight
86f4789ab <David Li> Add application metadata field to FlightData message
---
cpp/src/arrow/flight/client.cc | 207 +++++++++++--
cpp/src/arrow/flight/client.h | 52 +++-
cpp/src/arrow/flight/flight-benchmark.cc | 18 +-
cpp/src/arrow/flight/flight-test.cc | 169 +++++++++--
cpp/src/arrow/flight/internal.h | 3 +-
cpp/src/arrow/flight/serialization-internal.cc | 28 +-
cpp/src/arrow/flight/serialization-internal.h | 8 +-
cpp/src/arrow/flight/server.cc | 68 ++++-
cpp/src/arrow/flight/server.h | 27 +-
cpp/src/arrow/flight/test-integration-client.cc | 55 +++-
cpp/src/arrow/flight/test-integration-server.cc | 19 +-
cpp/src/arrow/flight/test-util.cc | 18 ++
cpp/src/arrow/flight/test-util.h | 18 ++
cpp/src/arrow/flight/types.cc | 19 ++
cpp/src/arrow/flight/types.h | 27 ++
cpp/src/arrow/ipc/reader.cc | 2 +-
cpp/src/arrow/python/flight.cc | 8 +-
cpp/src/arrow/python/flight.h | 6 +-
docs/source/conf.py | 9 +
docs/source/cpp/api.rst | 1 +
docs/source/cpp/api/flight.rst | 126 ++++++++
docs/source/format/Flight.rst | 106 +++++++
docs/source/index.rst | 1 +
docs/source/python/api.rst | 1 +
docs/source/python/api/flight.rst | 82 ++++++
format/Flight.proto | 14 +-
integration/integration_test.py | 2 +-
.../java/org/apache/arrow/flight/ArrowMessage.java | 83 ++++--
.../org/apache/arrow/flight/AsyncPutListener.java | 63 ++++
.../org/apache/arrow/flight/DictionaryUtils.java | 77 +++++
.../apache/arrow/flight/FlightBindingService.java | 11 +-
.../java/org/apache/arrow/flight/FlightClient.java | 154 +++++++---
.../org/apache/arrow/flight/FlightProducer.java | 104 ++++++-
.../java/org/apache/arrow/flight/FlightServer.java | 25 +-
.../org/apache/arrow/flight/FlightService.java | 44 ++-
.../java/org/apache/arrow/flight/FlightStream.java | 94 +++++-
.../apache/arrow/flight/NoOpFlightProducer.java | 8 +-
...nericOperation.java => NoOpStreamListener.java} | 33 ++-
.../java/org/apache/arrow/flight/PutResult.java | 97 ++++++
.../org/apache/arrow/flight/SyncPutListener.java | 114 ++++++++
.../apache/arrow/flight/example/FlightHolder.java | 14 +-
.../apache/arrow/flight/example/InMemoryStore.java | 23 +-
.../org/apache/arrow/flight/example/Stream.java | 25 +-
.../example/integration/IntegrationTestClient.java | 32 +-
.../arrow/flight/grpc/GetReadableBuffer.java | 23 ++
.../org/apache/arrow/flight/FlightTestUtil.java | 5 +-
.../arrow/flight/TestApplicationMetadata.java | 245 ++++++++++++++++
.../org/apache/arrow/flight/TestBackPressure.java | 80 +++--
.../apache/arrow/flight/TestBasicOperation.java | 14 +-
.../org/apache/arrow/flight/TestCallOptions.java | 7 +-
.../org/apache/arrow/flight/TestLargeMessage.java | 28 +-
.../test/java/org/apache/arrow/flight/TestTls.java | 4 +-
.../org/apache/arrow/flight/auth/TestAuth.java | 8 +-
.../arrow/flight/example/TestExampleServer.java | 18 +-
.../arrow/flight/perf/PerformanceTestServer.java | 34 +--
.../org/apache/arrow/flight/perf/TestPerf.java | 55 ++--
java/pom.xml | 2 +-
python/examples/flight/server.py | 2 +-
python/pyarrow/_flight.pyx | 325 +++++++++++++++++++--
python/pyarrow/includes/libarrow_flight.pxd | 44 ++-
python/pyarrow/tests/test_flight.py | 197 ++++++++++++-
61 files changed, 2713 insertions(+), 473 deletions(-)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 1926928..f81b627 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -79,7 +79,7 @@ struct ClientRpc {
if (auth_handler) {
std::string token;
RETURN_NOT_OK(auth_handler->GetToken(&token));
- context.AddMetadata(internal::AUTH_HEADER, token);
+ context.AddMetadata(internal::kGrpcAuthHeader, token);
}
return Status::OK();
}
@@ -129,15 +129,47 @@ class GrpcClientAuthReader : public ClientAuthReader {
stream_;
};
-class FlightIpcMessageReader : public ipc::MessageReader {
+// The next two classes are intertwined. To get the application
+// metadata while avoiding reimplementing RecordBatchStreamReader, we
+// create an ipc::MessageReader that is tied to the
+// MetadataRecordBatchReader. Every time an IPC message is read, it updates
+// the application metadata field of the MetadataRecordBatchReader. The
+// MetadataRecordBatchReader wraps RecordBatchStreamReader, offering an
+// additional method to get both the record batch and application
+// metadata.
+
+class GrpcIpcMessageReader;
+class GrpcStreamReader : public FlightStreamReader {
public:
- FlightIpcMessageReader(std::unique_ptr<ClientRpc> rpc,
- std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream)
- : rpc_(std::move(rpc)), stream_(std::move(stream)), stream_finished_(false) {}
+ GrpcStreamReader();
+
+ static Status Open(std::unique_ptr<ClientRpc> rpc,
+ std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream,
+ std::unique_ptr<GrpcStreamReader>* out);
+ std::shared_ptr<Schema> schema() const override;
+ Status Next(FlightStreamChunk* out) override;
+ void Cancel() override;
+
+ private:
+ friend class GrpcIpcMessageReader;
+ std::unique_ptr<ipc::RecordBatchReader> batch_reader_;
+ std::shared_ptr<Buffer> last_app_metadata_;
+ std::shared_ptr<ClientRpc> rpc_;
+};
+
+class GrpcIpcMessageReader : public ipc::MessageReader {
+ public:
+ GrpcIpcMessageReader(GrpcStreamReader* reader, std::shared_ptr<ClientRpc> rpc,
+ std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream)
+ : flight_reader_(reader),
+ rpc_(rpc),
+ stream_(std::move(stream)),
+ stream_finished_(false) {}
Status ReadNextMessage(std::unique_ptr<ipc::Message>* out) override {
if (stream_finished_) {
*out = nullptr;
+ flight_reader_->last_app_metadata_ = nullptr;
return Status::OK();
}
internal::FlightData data;
@@ -145,13 +177,16 @@ class FlightIpcMessageReader : public ipc::MessageReader {
// Stream is completed
stream_finished_ = true;
*out = nullptr;
+ flight_reader_->last_app_metadata_ = nullptr;
return OverrideWithServerError(Status::OK());
}
// Validate IPC message
auto st = data.OpenMessage(out);
if (!st.ok()) {
+ flight_reader_->last_app_metadata_ = nullptr;
return OverrideWithServerError(std::move(st));
}
+ flight_reader_->last_app_metadata_ = data.app_metadata;
return Status::OK();
}
@@ -162,23 +197,93 @@ class FlightIpcMessageReader : public ipc::MessageReader {
return std::move(st);
}
+ private:
+ GrpcStreamReader* flight_reader_;
// The RPC context lifetime must be coupled to the ClientReader
- std::unique_ptr<ClientRpc> rpc_;
+ std::shared_ptr<ClientRpc> rpc_;
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream_;
bool stream_finished_;
};
+GrpcStreamReader::GrpcStreamReader() {}
+
+Status GrpcStreamReader::Open(std::unique_ptr<ClientRpc> rpc,
+ std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream,
+ std::unique_ptr<GrpcStreamReader>* out) {
+ *out = std::unique_ptr<GrpcStreamReader>(new GrpcStreamReader);
+ out->get()->rpc_ = std::move(rpc);
+ std::unique_ptr<GrpcIpcMessageReader> message_reader(
+ new GrpcIpcMessageReader(out->get(), out->get()->rpc_, std::move(stream)));
+ return ipc::RecordBatchStreamReader::Open(std::move(message_reader),
+ &(*out)->batch_reader_);
+}
+
+std::shared_ptr<Schema> GrpcStreamReader::schema() const {
+ return batch_reader_->schema();
+}
+
+Status GrpcStreamReader::Next(FlightStreamChunk* out) {
+ out->app_metadata = nullptr;
+ RETURN_NOT_OK(batch_reader_->ReadNext(&out->data));
+ out->app_metadata = std::move(last_app_metadata_);
+ return Status::OK();
+}
+
+void GrpcStreamReader::Cancel() { rpc_->context.TryCancel(); }
+
+// Similarly, the next two classes are intertwined. In order to get
+// application-specific metadata to the IpcPayloadWriter,
+// DoPutPayloadWriter takes a pointer to
+// GrpcStreamWriter. GrpcStreamWriter updates a metadata field on
+// write; DoPutPayloadWriter reads that metadata field to determine
+// what to write.
+
+class DoPutPayloadWriter;
+class GrpcStreamWriter : public FlightStreamWriter {
+ public:
+ ~GrpcStreamWriter() = default;
+
+ GrpcStreamWriter() : app_metadata_(nullptr), batch_writer_(nullptr) {}
+
+ static Status Open(
+ const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<ClientRpc> rpc, std::unique_ptr<pb::PutResult> response,
+ std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> writer,
+ std::unique_ptr<FlightStreamWriter>* out);
+
+ Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override {
+ return WriteWithMetadata(batch, nullptr, allow_64bit);
+ }
+ Status WriteWithMetadata(const RecordBatch& batch, std::shared_ptr<Buffer> app_metadata,
+ bool allow_64bit = false) override {
+ app_metadata_ = app_metadata;
+ return batch_writer_->WriteRecordBatch(batch, allow_64bit);
+ }
+ void set_memory_pool(MemoryPool* pool) override {
+ batch_writer_->set_memory_pool(pool);
+ }
+ Status Close() override { return batch_writer_->Close(); }
+
+ private:
+ friend class DoPutPayloadWriter;
+ std::shared_ptr<Buffer> app_metadata_;
+ std::unique_ptr<ipc::RecordBatchWriter> batch_writer_;
+};
+
/// A IpcPayloadWriter implementation that writes to a DoPut stream
class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
public:
- DoPutPayloadWriter(const FlightDescriptor& descriptor, std::unique_ptr<ClientRpc> rpc,
- std::unique_ptr<protocol::PutResult> response,
- std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer)
+ DoPutPayloadWriter(
+ const FlightDescriptor& descriptor, std::unique_ptr<ClientRpc> rpc,
+ std::unique_ptr<pb::PutResult> response,
+ std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> writer,
+ GrpcStreamWriter* stream_writer)
: descriptor_(descriptor),
rpc_(std::move(rpc)),
response_(std::move(response)),
writer_(std::move(writer)),
- first_payload_(true) {}
+ first_payload_(true),
+ stream_writer_(stream_writer) {}
~DoPutPayloadWriter() override = default;
@@ -201,6 +306,9 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
}
RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor));
first_payload_ = false;
+ } else if (ipc_payload.type == ipc::Message::RECORD_BATCH &&
+ stream_writer_->app_metadata_) {
+ payload.app_metadata = std::move(stream_writer_->app_metadata_);
}
if (!internal::WritePayload(payload, writer_.get())) {
@@ -211,6 +319,10 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
Status Close() override {
bool finished_writes = writer_->WritesDone();
+ // Drain the read side to avoid hanging
+ pb::PutResult message;
+ while (writer_->Read(&message)) {
+ }
RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish()));
if (!finished_writes) {
return Status::UnknownError(
@@ -223,9 +335,47 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
// TODO: there isn't a way to access this as a user.
const FlightDescriptor descriptor_;
std::unique_ptr<ClientRpc> rpc_;
- std::unique_ptr<protocol::PutResult> response_;
- std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer_;
+ std::unique_ptr<pb::PutResult> response_;
+ std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> writer_;
bool first_payload_;
+ GrpcStreamWriter* stream_writer_;
+};
+
+Status GrpcStreamWriter::Open(
+ const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<ClientRpc> rpc, std::unique_ptr<pb::PutResult> response,
+ std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> writer,
+ std::unique_ptr<FlightStreamWriter>* out) {
+ std::unique_ptr<GrpcStreamWriter> result(new GrpcStreamWriter);
+ std::unique_ptr<ipc::internal::IpcPayloadWriter> payload_writer(new DoPutPayloadWriter(
+ descriptor, std::move(rpc), std::move(response), writer, result.get()));
+ RETURN_NOT_OK(ipc::internal::OpenRecordBatchWriter(std::move(payload_writer), schema,
+ &result->batch_writer_));
+ *out = std::move(result);
+ return Status::OK();
+}
+
+FlightMetadataReader::~FlightMetadataReader() = default;
+
+class GrpcMetadataReader : public FlightMetadataReader {
+ public:
+ explicit GrpcMetadataReader(
+ std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> reader)
+ : reader_(reader) {}
+
+ Status ReadMetadata(std::shared_ptr<Buffer>* out) override {
+ pb::PutResult message;
+ if (reader_->Read(&message)) {
+ *out = Buffer::FromString(std::move(*message.release_app_metadata()));
+ } else {
+ // Stream finished
+ *out = nullptr;
+ }
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> reader_;
};
class FlightClient::FlightClientImpl {
@@ -367,7 +517,7 @@ class FlightClient::FlightClientImpl {
}
Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
- std::unique_ptr<RecordBatchReader>* out) {
+ std::unique_ptr<FlightStreamReader>* out) {
pb::Ticket pb_ticket;
internal::ToProto(ticket, &pb_ticket);
@@ -376,25 +526,25 @@ class FlightClient::FlightClientImpl {
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream(
stub_->DoGet(&rpc->context, pb_ticket));
- std::unique_ptr<ipc::MessageReader> message_reader(
- new FlightIpcMessageReader(std::move(rpc), std::move(stream)));
- return ipc::RecordBatchStreamReader::Open(std::move(message_reader), out);
+ std::unique_ptr<GrpcStreamReader> reader;
+ RETURN_NOT_OK(GrpcStreamReader::Open(std::move(rpc), std::move(stream), &reader));
+ *out = std::move(reader);
+ return Status::OK();
}
Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
- std::unique_ptr<ipc::RecordBatchWriter>* out) {
+ std::unique_ptr<FlightStreamWriter>* out,
+ std::unique_ptr<FlightMetadataReader>* reader) {
std::unique_ptr<ClientRpc> rpc(new ClientRpc(options));
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
- std::unique_ptr<protocol::PutResult> response(new protocol::PutResult);
- std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer(
- stub_->DoPut(&rpc->context, response.get()));
-
- std::unique_ptr<ipc::internal::IpcPayloadWriter> payload_writer(
- new DoPutPayloadWriter(descriptor, std::move(rpc), std::move(response),
- std::move(writer)));
+ std::unique_ptr<pb::PutResult> response(new pb::PutResult);
+ std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> writer(
+ stub_->DoPut(&rpc->context));
- return ipc::internal::OpenRecordBatchWriter(std::move(payload_writer), schema, out);
+ *reader = std::unique_ptr<FlightMetadataReader>(new GrpcMetadataReader(writer));
+ return GrpcStreamWriter::Open(descriptor, schema, std::move(rpc), std::move(response),
+ writer, out);
}
private:
@@ -449,15 +599,16 @@ Status FlightClient::ListFlights(const FlightCallOptions& options,
}
Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket,
- std::unique_ptr<RecordBatchReader>* stream) {
+ std::unique_ptr<FlightStreamReader>* stream) {
return impl_->DoGet(options, ticket, stream);
}
Status FlightClient::DoPut(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
- std::unique_ptr<ipc::RecordBatchWriter>* stream) {
- return impl_->DoPut(options, descriptor, schema, stream);
+ std::unique_ptr<FlightStreamWriter>* stream,
+ std::unique_ptr<FlightMetadataReader>* reader) {
+ return impl_->DoPut(options, descriptor, schema, stream, reader);
}
} // namespace flight
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index b8a5d4f..0fa571d 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -25,6 +25,7 @@
#include <string>
#include <vector>
+#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
#include "arrow/status.h"
@@ -35,7 +36,6 @@ namespace arrow {
class MemoryPool;
class RecordBatch;
-class RecordBatchReader;
class Schema;
namespace flight {
@@ -66,6 +66,43 @@ class ARROW_FLIGHT_EXPORT FlightClientOptions {
std::string override_hostname;
};
+/// \brief A RecordBatchReader exposing Flight metadata and cancel
+/// operations.
+class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader {
+ public:
+ /// \brief Try to cancel the call.
+ virtual void Cancel() = 0;
+};
+
+// Silence warning
+// "non dll-interface class RecordBatchReader used as base for dll-interface class"
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4275)
+#endif
+
+/// \brief A RecordBatchWriter that also allows sending
+/// application-defined metadata via the Flight protocol.
+class ARROW_FLIGHT_EXPORT FlightStreamWriter : public ipc::RecordBatchWriter {
+ public:
+ virtual Status WriteWithMetadata(const RecordBatch& batch,
+ std::shared_ptr<Buffer> app_metadata,
+ bool allow_64bit = false) = 0;
+};
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+/// \brief A reader for application-specific metadata sent back to the
+/// client during an upload.
+class ARROW_FLIGHT_EXPORT FlightMetadataReader {
+ public:
+ virtual ~FlightMetadataReader();
+ /// \brief Read a message from the server.
+ virtual Status ReadMetadata(std::shared_ptr<Buffer>* out) = 0;
+};
+
/// \brief Client class for Arrow Flight RPC services (gRPC-based).
/// API experimental for now
class ARROW_FLIGHT_EXPORT FlightClient {
@@ -151,8 +188,8 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \param[out] stream the returned RecordBatchReader
/// \return Status
Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
- std::unique_ptr<RecordBatchReader>* stream);
- Status DoGet(const Ticket& ticket, std::unique_ptr<RecordBatchReader>* stream) {
+ std::unique_ptr<FlightStreamReader>* stream);
+ Status DoGet(const Ticket& ticket, std::unique_ptr<FlightStreamReader>* stream) {
return DoGet({}, ticket, stream);
}
@@ -163,13 +200,16 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \param[in] descriptor the descriptor of the stream
/// \param[in] schema the schema for the data to upload
/// \param[out] stream a writer to write record batches to
+ /// \param[out] reader a reader for application metadata from the server
/// \return Status
Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
- std::unique_ptr<ipc::RecordBatchWriter>* stream);
+ std::unique_ptr<FlightStreamWriter>* stream,
+ std::unique_ptr<FlightMetadataReader>* reader);
Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
- std::unique_ptr<ipc::RecordBatchWriter>* stream) {
- return DoPut({}, descriptor, schema, stream);
+ std::unique_ptr<FlightStreamWriter>* stream,
+ std::unique_ptr<FlightMetadataReader>* reader) {
+ return DoPut({}, descriptor, schema, stream, reader);
}
private:
diff --git a/cpp/src/arrow/flight/flight-benchmark.cc b/cpp/src/arrow/flight/flight-benchmark.cc
index f2bd356..f5dd462 100644
--- a/cpp/src/arrow/flight/flight-benchmark.cc
+++ b/cpp/src/arrow/flight/flight-benchmark.cc
@@ -106,10 +106,10 @@ Status RunPerformanceTest(const std::string& hostname, const int port) {
perf::Token token;
token.ParseFromString(endpoint.ticket.ticket);
- std::unique_ptr<RecordBatchReader> reader;
+ std::unique_ptr<FlightStreamReader> reader;
RETURN_NOT_OK(client->DoGet(endpoint.ticket, &reader));
- std::shared_ptr<RecordBatch> batch;
+ FlightStreamChunk batch;
// This is hard-coded for right now, 4 columns each with int64
const int bytes_per_record = 32;
@@ -120,26 +120,26 @@ Status RunPerformanceTest(const std::string& hostname, const int port) {
int64_t num_bytes = 0;
int64_t num_records = 0;
while (true) {
- RETURN_NOT_OK(reader->ReadNext(&batch));
- if (!batch) {
+ RETURN_NOT_OK(reader->Next(&batch));
+ if (!batch.data) {
break;
}
if (verify) {
- auto values =
- reinterpret_cast<const int64_t*>(batch->column_data(0)->buffers[1]->data());
+ auto values = reinterpret_cast<const int64_t*>(
+ batch.data->column_data(0)->buffers[1]->data());
const int64_t start = token.start() + num_records;
- for (int64_t i = 0; i < batch->num_rows(); ++i) {
+ for (int64_t i = 0; i < batch.data->num_rows(); ++i) {
if (values[i] != start + i) {
return Status::Invalid("verification failure");
}
}
}
- num_records += batch->num_rows();
+ num_records += batch.data->num_rows();
// Hard-coded
- num_bytes += batch->num_rows() * bytes_per_record;
+ num_bytes += batch.data->num_rows() * bytes_per_record;
}
stats.Update(num_records, num_bytes);
return Status::OK();
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index 3c0b67c..c0f0c7f 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -211,19 +211,19 @@ class TestFlightClient : public ::testing::Test {
// By convention, fetch the first endpoint
Ticket ticket = info->endpoints()[0].ticket;
- std::unique_ptr<RecordBatchReader> stream;
+ std::unique_ptr<FlightStreamReader> stream;
ASSERT_OK(client_->DoGet(ticket, &stream));
- std::shared_ptr<RecordBatch> chunk;
+ FlightStreamChunk chunk;
for (int i = 0; i < num_batches; ++i) {
- ASSERT_OK(stream->ReadNext(&chunk));
- ASSERT_NE(nullptr, chunk);
- ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk);
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.data);
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
}
// Stream exhausted
- ASSERT_OK(stream->ReadNext(&chunk));
- ASSERT_EQ(nullptr, chunk);
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
}
protected:
@@ -255,7 +255,8 @@ class TlsTestServer : public FlightServerBase {
class DoPutTestServer : public FlightServerBase {
public:
Status DoPut(const ServerCallContext& context,
- std::unique_ptr<FlightMessageReader> reader) override {
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) override {
descriptor_ = reader->descriptor();
return reader->ReadAll(&batches_);
}
@@ -267,6 +268,70 @@ class DoPutTestServer : public FlightServerBase {
friend class TestDoPut;
};
+class MetadataTestServer : public FlightServerBase {
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) override {
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ std::shared_ptr<RecordBatchReader> batch_reader =
+ std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+
+ *data_stream = std::unique_ptr<FlightDataStream>(new NumberingStream(
+ std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader))));
+ return Status::OK();
+ }
+
+ Status DoPut(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) override {
+ FlightStreamChunk chunk;
+ int counter = 0;
+ while (true) {
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (chunk.data == nullptr) break;
+ if (chunk.app_metadata == nullptr) {
+ return Status::Invalid("Expected application metadata to be provided");
+ }
+ if (std::to_string(counter) != chunk.app_metadata->ToString()) {
+ return Status::Invalid("Expected metadata value: " + std::to_string(counter) +
+ " but got: " + chunk.app_metadata->ToString());
+ }
+ auto metadata = Buffer::FromString(std::to_string(counter));
+ RETURN_NOT_OK(writer->WriteMetadata(*metadata));
+ counter++;
+ }
+ return Status::OK();
+ }
+};
+
+template <typename T>
+class InsecureTestServer : public ::testing::Test {
+ public:
+ void SetUp() {
+ Location location;
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 30000, &location));
+
+ std::unique_ptr<FlightServerBase> server(new T);
+ FlightServerOptions options(location);
+ ASSERT_OK(server->Init(options));
+
+ server_.reset(new InProcessTestServer(std::move(server), location));
+ ASSERT_OK(server_->Start());
+ ASSERT_OK(ConnectClient());
+ }
+
+ void TearDown() { server_->Stop(); }
+
+ Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); }
+
+ protected:
+ int port_;
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<InProcessTestServer> server_;
+};
+
+using TestMetadata = InsecureTestServer<MetadataTestServer>;
+
class TestAuthHandler : public ::testing::Test {
public:
void SetUp() {
@@ -323,8 +388,9 @@ class TestDoPut : public ::testing::Test {
void CheckDoPut(FlightDescriptor descr, const std::shared_ptr<Schema>& schema,
const BatchVector& batches) {
- std::unique_ptr<ipc::RecordBatchWriter> stream;
- ASSERT_OK(client_->DoPut(descr, schema, &stream));
+ std::unique_ptr<FlightStreamWriter> stream;
+ std::unique_ptr<FlightMetadataReader> reader;
+ ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
for (const auto& batch : batches) {
ASSERT_OK(stream->WriteRecordBatch(*batch));
}
@@ -485,7 +551,7 @@ TEST_F(TestFlightClient, Issue5095) {
// Make sure the server-side error message is reflected to the
// client
Ticket ticket1{"ARROW-5095-fail"};
- std::unique_ptr<RecordBatchReader> stream;
+ std::unique_ptr<FlightStreamReader> stream;
Status status = client_->DoGet(ticket1, &stream);
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Server-side error"));
@@ -588,13 +654,14 @@ TEST_F(TestAuthHandler, PassAuthenticatedCalls) {
status = client_->GetFlightInfo(FlightDescriptor{}, &info);
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr<RecordBatchReader> stream;
+ std::unique_ptr<FlightStreamReader> stream;
status = client_->DoGet(Ticket{}, &stream);
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr<ipc::RecordBatchWriter> writer;
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
std::shared_ptr<Schema> schema = arrow::schema({});
- status = client_->DoPut(FlightDescriptor{}, schema, &writer);
+ status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
ASSERT_OK(status);
status = writer->Close();
ASSERT_RAISES(NotImplemented, status);
@@ -625,15 +692,16 @@ TEST_F(TestAuthHandler, FailUnauthenticatedCalls) {
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr<RecordBatchReader> stream;
+ std::unique_ptr<FlightStreamReader> stream;
status = client_->DoGet(Ticket{}, &stream);
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr<ipc::RecordBatchWriter> writer;
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
std::shared_ptr<Schema> schema(
(new arrow::Schema(std::vector<std::shared_ptr<Field>>())));
- status = client_->DoPut(FlightDescriptor{}, schema, &writer);
+ status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
ASSERT_OK(status);
status = writer->Close();
ASSERT_RAISES(IOError, status);
@@ -693,5 +761,72 @@ TEST_F(TestTls, OverrideHostname) {
ASSERT_RAISES(IOError, client->DoAction(options, action, &results));
}
+TEST_F(TestMetadata, DoGet) {
+ Ticket ticket{""};
+ std::unique_ptr<FlightStreamReader> stream;
+ ASSERT_OK(client_->DoGet(ticket, &stream));
+
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ FlightStreamChunk chunk;
+ auto num_batches = static_cast<int>(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.data);
+ ASSERT_NE(nullptr, chunk.app_metadata);
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
+ ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString());
+ }
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
+}
+
+TEST_F(TestMetadata, DoPut) {
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ std::shared_ptr<Schema> schema = ExampleIntSchema();
+ ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
+
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ std::shared_ptr<RecordBatch> chunk;
+ std::shared_ptr<Buffer> metadata;
+ auto num_batches = static_cast<int>(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
+ Buffer::FromString(std::to_string(i))));
+ }
+ // This eventually calls grpc::ClientReaderWriter::Finish which can
+ // hang if there are unread messages. So make sure our wrapper
+ // around this doesn't hang (because it drains any unread messages)
+ ASSERT_OK(writer->Close());
+}
+
+TEST_F(TestMetadata, DoPutReadMetadata) {
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ std::shared_ptr<Schema> schema = ExampleIntSchema();
+ ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
+
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ std::shared_ptr<RecordBatch> chunk;
+ std::shared_ptr<Buffer> metadata;
+ auto num_batches = static_cast<int>(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
+ Buffer::FromString(std::to_string(i))));
+ ASSERT_OK(reader->ReadMetadata(&metadata));
+ ASSERT_NE(nullptr, metadata);
+ ASSERT_EQ(std::to_string(i), metadata->ToString());
+ }
+ // As opposed to DoPutDrainMetadata, now we've read the messages, so
+ // make sure this still closes as expected.
+ ASSERT_OK(writer->Close());
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h
index 784e8eb..5283bed 100644
--- a/cpp/src/arrow/flight/internal.h
+++ b/cpp/src/arrow/flight/internal.h
@@ -63,7 +63,8 @@ namespace flight {
namespace internal {
-static const char* AUTH_HEADER = "auth-token-bin";
+/// The name of the header used to pass authentication tokens.
+static const char* kGrpcAuthHeader = "auth-token-bin";
ARROW_FLIGHT_EXPORT
Status SchemaToString(const Schema& schema, std::string* out);
diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc
index d78bac8..0ff5aba 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/serialization-internal.cc
@@ -163,6 +163,14 @@ grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out,
// 1 byte for metadata tag
header_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size);
+ // App metadata tag if appropriate
+ int32_t app_metadata_size = 0;
+ if (msg.app_metadata && msg.app_metadata->size() > 0) {
+ DCHECK_LT(msg.app_metadata->size(), kInt32Max);
+ app_metadata_size = static_cast<int32_t>(msg.app_metadata->size());
+ header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size);
+ }
+
for (const auto& buffer : ipc_msg.body_buffers) {
// Buffer may be null when the row length is zero, or when all
// entries are invalid.
@@ -214,6 +222,15 @@ grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out,
header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(),
static_cast<int>(ipc_msg.metadata->size()));
+ // Write app metadata
+ if (app_metadata_size > 0) {
+ WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber,
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
+ header_stream.WriteVarint32(app_metadata_size);
+ header_stream.WriteRawMaybeAliased(msg.app_metadata->data(),
+ static_cast<int>(msg.app_metadata->size()));
+ }
+
if (has_body) {
// Write body tag
WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
@@ -292,6 +309,12 @@ grpc::Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) {
"Unable to read FlightData metadata");
}
} break;
+ case pb::FlightData::kAppMetadataFieldNumber: {
+ if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->app_metadata)) {
+ return grpc::Status(grpc::StatusCode::INTERNAL,
+ "Unable to read FlightData application metadata");
+ }
+ } break;
case pb::FlightData::kDataBodyFieldNumber: {
if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) {
return grpc::Status(grpc::StatusCode::INTERNAL,
@@ -330,7 +353,7 @@ Status FlightData::OpenMessage(std::unique_ptr<ipc::Message>* message) {
// (see customize_protobuf.h).
bool WritePayload(const FlightPayload& payload,
- grpc::ClientWriter<pb::FlightData>* writer) {
+ grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>* writer) {
// Pretend to be pb::FlightData and intercept in SerializationTraits
return writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
grpc::WriteOptions());
@@ -348,7 +371,8 @@ bool ReadPayload(grpc::ClientReader<pb::FlightData>* reader, FlightData* data) {
return reader->Read(reinterpret_cast<pb::FlightData*>(data));
}
-bool ReadPayload(grpc::ServerReader<pb::FlightData>* reader, FlightData* data) {
+bool ReadPayload(grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader,
+ FlightData* data) {
// Pretend to be pb::FlightData and intercept in SerializationTraits
return reader->Read(reinterpret_cast<pb::FlightData*>(data));
}
diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h
index aa47af6..cfb8b8a 100644
--- a/cpp/src/arrow/flight/serialization-internal.h
+++ b/cpp/src/arrow/flight/serialization-internal.h
@@ -43,6 +43,9 @@ struct FlightData {
/// Non-length-prefixed Message header as described in format/Message.fbs
std::shared_ptr<Buffer> metadata;
+ /// Application-defined metadata
+ std::shared_ptr<Buffer> app_metadata;
+
/// Message body
std::shared_ptr<Buffer> body;
@@ -53,14 +56,15 @@ struct FlightData {
/// Write Flight message on gRPC stream with zero-copy optimizations.
/// True is returned on success, false if some error occurred (connection closed?).
bool WritePayload(const FlightPayload& payload,
- grpc::ClientWriter<pb::FlightData>* writer);
+ grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>* writer);
bool WritePayload(const FlightPayload& payload,
grpc::ServerWriter<pb::FlightData>* writer);
/// Read Flight message from gRPC stream with zero-copy optimizations.
/// True is returned on success, false if stream ended.
bool ReadPayload(grpc::ClientReader<pb::FlightData>* reader, FlightData* data);
-bool ReadPayload(grpc::ServerReader<pb::FlightData>* reader, FlightData* data);
+bool ReadPayload(grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader,
+ FlightData* data);
} // namespace internal
} // namespace flight
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 6f3c466..d059a8b 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -72,12 +72,15 @@ namespace {
// A MessageReader implementation that reads from a gRPC ServerReader
class FlightIpcMessageReader : public ipc::MessageReader {
public:
- explicit FlightIpcMessageReader(grpc::ServerReader<pb::FlightData>* reader)
- : reader_(reader) {}
+ explicit FlightIpcMessageReader(
+ grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader,
+ std::shared_ptr<Buffer>* last_metadata)
+ : reader_(reader), app_metadata_(last_metadata) {}
Status ReadNextMessage(std::unique_ptr<ipc::Message>* out) override {
if (stream_finished_) {
*out = nullptr;
+ *app_metadata_ = nullptr;
return Status::OK();
}
internal::FlightData data;
@@ -89,6 +92,7 @@ class FlightIpcMessageReader : public ipc::MessageReader {
"Client provided malformed message or did not provide message");
}
*out = nullptr;
+ *app_metadata_ = nullptr;
return Status::OK();
}
@@ -100,25 +104,29 @@ class FlightIpcMessageReader : public ipc::MessageReader {
first_message_ = false;
}
- return data.OpenMessage(out);
+ RETURN_NOT_OK(data.OpenMessage(out));
+ *app_metadata_ = std::move(data.app_metadata);
+ return Status::OK();
}
const FlightDescriptor& descriptor() const { return descriptor_; }
protected:
- grpc::ServerReader<pb::FlightData>* reader_;
+ grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader_;
bool stream_finished_ = false;
bool first_message_ = true;
FlightDescriptor descriptor_;
+ std::shared_ptr<Buffer>* app_metadata_;
};
class FlightMessageReaderImpl : public FlightMessageReader {
public:
- explicit FlightMessageReaderImpl(grpc::ServerReader<pb::FlightData>* reader)
+ explicit FlightMessageReaderImpl(
+ grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader)
: reader_(reader) {}
Status Init() {
- message_reader_ = new FlightIpcMessageReader(reader_);
+ message_reader_ = new FlightIpcMessageReader(reader_, &last_metadata_);
return ipc::RecordBatchStreamReader::Open(
std::unique_ptr<ipc::MessageReader>(message_reader_), &batch_reader_);
}
@@ -129,18 +137,41 @@ class FlightMessageReaderImpl : public FlightMessageReader {
std::shared_ptr<Schema> schema() const override { return batch_reader_->schema(); }
- Status ReadNext(std::shared_ptr<RecordBatch>* out) override {
- return batch_reader_->ReadNext(out);
+ Status Next(FlightStreamChunk* out) override {
+ out->app_metadata = nullptr;
+ RETURN_NOT_OK(batch_reader_->ReadNext(&out->data));
+ out->app_metadata = std::move(last_metadata_);
+ return Status::OK();
}
private:
std::shared_ptr<Schema> schema_;
std::unique_ptr<ipc::DictionaryMemo> dictionary_memo_;
- grpc::ServerReader<pb::FlightData>* reader_;
+ grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader_;
FlightIpcMessageReader* message_reader_;
+ std::shared_ptr<Buffer> last_metadata_;
std::shared_ptr<RecordBatchReader> batch_reader_;
};
+class GrpcMetadataWriter : public FlightMetadataWriter {
+ public:
+ explicit GrpcMetadataWriter(
+ grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* writer)
+ : writer_(writer) {}
+
+ Status WriteMetadata(const Buffer& buffer) override {
+ pb::PutResult message{};
+ message.set_app_metadata(buffer.data(), buffer.size());
+ if (writer_->Write(message)) {
+ return Status::OK();
+ }
+ return Status::IOError("Unknown error writing metadata.");
+ }
+
+ private:
+ grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* writer_;
+};
+
class GrpcServerAuthReader : public ServerAuthReader {
public:
explicit GrpcServerAuthReader(
@@ -153,7 +184,7 @@ class GrpcServerAuthReader : public ServerAuthReader {
*token = std::move(*request.release_payload());
return Status::OK();
}
- return Status::UnknownError("Could not read client handshake request.");
+ return Status::IOError("Stream is closed.");
}
private:
@@ -246,7 +277,7 @@ class FlightServiceImpl : public FlightService::Service {
}
const auto client_metadata = context->client_metadata();
- const auto auth_header = client_metadata.find(internal::AUTH_HEADER);
+ const auto auth_header = client_metadata.find(internal::kGrpcAuthHeader);
std::string token;
if (auth_header == client_metadata.end()) {
token = "";
@@ -349,16 +380,18 @@ class FlightServiceImpl : public FlightService::Service {
return grpc::Status::OK;
}
- grpc::Status DoPut(ServerContext* context, grpc::ServerReader<pb::FlightData>* reader,
- pb::PutResult* response) {
+ grpc::Status DoPut(ServerContext* context,
+ grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader) {
GrpcServerCallContext flight_context;
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(context, flight_context));
auto message_reader =
std::unique_ptr<FlightMessageReaderImpl>(new FlightMessageReaderImpl(reader));
GRPC_RETURN_NOT_OK(message_reader->Init());
- return internal::ToGrpcStatus(
- server_->DoPut(flight_context, std::move(message_reader)));
+ auto metadata_writer =
+ std::unique_ptr<FlightMetadataWriter>(new GrpcMetadataWriter(reader));
+ return internal::ToGrpcStatus(server_->DoPut(
+ flight_context, std::move(message_reader), std::move(metadata_writer)));
}
grpc::Status ListActions(ServerContext* context, const pb::Empty* request,
@@ -410,6 +443,8 @@ class FlightServiceImpl : public FlightService::Service {
} // namespace
+FlightMetadataWriter::~FlightMetadataWriter() = default;
+
//
// gRPC server lifecycle
//
@@ -572,7 +607,8 @@ Status FlightServerBase::DoGet(const ServerCallContext& context, const Ticket& r
}
Status FlightServerBase::DoPut(const ServerCallContext& context,
- std::unique_ptr<FlightMessageReader> reader) {
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) {
return Status::NotImplemented("NYI");
}
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index c1bcb5c..25656e6 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -74,23 +74,22 @@ class ARROW_FLIGHT_EXPORT RecordBatchStream : public FlightDataStream {
std::unique_ptr<RecordBatchStreamImpl> impl_;
};
-// Silence warning
-// "non dll-interface class RecordBatchReader used as base for dll-interface class"
-#ifdef _MSC_VER
-#pragma warning(push)
-#pragma warning(disable : 4275)
-#endif
-
-/// \brief A reader for IPC payloads uploaded by a client
-class ARROW_FLIGHT_EXPORT FlightMessageReader : public RecordBatchReader {
+/// \brief A reader for IPC payloads uploaded by a client. Also allows
+/// reading application-defined metadata via the Flight protocol.
+class ARROW_FLIGHT_EXPORT FlightMessageReader : public MetadataRecordBatchReader {
public:
/// \brief Get the descriptor for this upload.
virtual const FlightDescriptor& descriptor() const = 0;
};
-#ifdef _MSC_VER
-#pragma warning(pop)
-#endif
+/// \brief A writer for application-specific metadata sent back to the
+/// client during an upload.
+class ARROW_FLIGHT_EXPORT FlightMetadataWriter {
+ public:
+ virtual ~FlightMetadataWriter();
+ /// \brief Send a message to the client.
+ virtual Status WriteMetadata(const Buffer& app_metadata) = 0;
+};
/// \brief Call state/contextual data.
class ARROW_FLIGHT_EXPORT ServerCallContext {
@@ -178,9 +177,11 @@ class ARROW_FLIGHT_EXPORT FlightServerBase {
/// \brief Process a stream of IPC payloads sent from a client
/// \param[in] context The call context.
/// \param[in] reader a sequence of uploaded record batches
+ /// \param[in] writer send metadata back to the client
/// \return Status
virtual Status DoPut(const ServerCallContext& context,
- std::unique_ptr<FlightMessageReader> reader);
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer);
/// \brief Execute an action, return stream of zero or more results
/// \param[in] context The call context.
diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc
index abaa3bc..b02595e 100644
--- a/cpp/src/arrow/flight/test-integration-client.cc
+++ b/cpp/src/arrow/flight/test-integration-client.cc
@@ -44,17 +44,26 @@ DEFINE_string(host, "localhost", "Server port to connect to");
DEFINE_int32(port, 31337, "Server port to connect to");
DEFINE_string(path, "", "Resource path to request");
-/// \brief Helper to read a RecordBatchReader into a Table.
-arrow::Status ReadToTable(std::unique_ptr<arrow::RecordBatchReader>& reader,
+/// \brief Helper to read a MetadataRecordBatchReader into a Table.
+arrow::Status ReadToTable(arrow::flight::MetadataRecordBatchReader& reader,
std::shared_ptr<arrow::Table>* retrieved_data) {
+ // For integration testing, we expect the server numbers the
+ // batches, to test the application metadata part of the spec.
std::vector<std::shared_ptr<arrow::RecordBatch>> retrieved_chunks;
- std::shared_ptr<arrow::RecordBatch> chunk;
+ arrow::flight::FlightStreamChunk chunk;
+ int counter = 0;
while (true) {
- RETURN_NOT_OK(reader->ReadNext(&chunk));
- if (chunk == nullptr) break;
- retrieved_chunks.push_back(chunk);
+ RETURN_NOT_OK(reader.Next(&chunk));
+ if (!chunk.data) break;
+ retrieved_chunks.push_back(chunk.data);
+ if (std::to_string(counter) != chunk.app_metadata->ToString()) {
+ return arrow::Status::Invalid(
+ "Expected metadata value: " + std::to_string(counter) +
+ " but got: " + chunk.app_metadata->ToString());
+ }
+ counter++;
}
- return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks,
+ return arrow::Table::FromRecordBatches(reader.schema(), retrieved_chunks,
retrieved_data);
}
@@ -71,14 +80,27 @@ arrow::Status ReadToTable(std::unique_ptr<arrow::ipc::internal::json::JsonReader
retrieved_data);
}
-/// \brief Helper to copy a RecordBatchReader to a RecordBatchWriter.
-arrow::Status CopyReaderToWriter(std::unique_ptr<arrow::RecordBatchReader>& reader,
- arrow::ipc::RecordBatchWriter& writer) {
+/// \brief Upload the contents of a RecordBatchReader to a Flight
+/// server, validating the application metadata on the side.
+arrow::Status UploadReaderToFlight(arrow::RecordBatchReader* reader,
+ arrow::flight::FlightStreamWriter& writer,
+ arrow::flight::FlightMetadataReader& metadata_reader) {
+ int counter = 0;
while (true) {
std::shared_ptr<arrow::RecordBatch> chunk;
RETURN_NOT_OK(reader->ReadNext(&chunk));
if (chunk == nullptr) break;
- RETURN_NOT_OK(writer.WriteRecordBatch(*chunk));
+ std::shared_ptr<arrow::Buffer> metadata =
+ arrow::Buffer::FromString(std::to_string(counter));
+ RETURN_NOT_OK(writer.WriteWithMetadata(*chunk, metadata));
+ // Wait for the server to ack the result
+ std::shared_ptr<arrow::Buffer> ack_metadata;
+ RETURN_NOT_OK(metadata_reader.ReadMetadata(&ack_metadata));
+ if (!ack_metadata->Equals(*metadata)) {
+ return arrow::Status::Invalid("Expected metadata value: " + metadata->ToString() +
+ " but got: " + ack_metadata->ToString());
+ }
+ counter++;
}
return writer.Close();
}
@@ -91,10 +113,10 @@ arrow::Status ConsumeFlightLocation(const arrow::flight::Location& location,
std::unique_ptr<arrow::flight::FlightClient> read_client;
RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location, &read_client));
- std::unique_ptr<arrow::RecordBatchReader> stream;
+ std::unique_ptr<arrow::flight::FlightStreamReader> stream;
RETURN_NOT_OK(read_client->DoGet(ticket, &stream));
- return ReadToTable(stream, retrieved_data);
+ return ReadToTable(*stream, retrieved_data);
}
int main(int argc, char** argv) {
@@ -120,11 +142,12 @@ int main(int argc, char** argv) {
std::shared_ptr<arrow::Table> original_data;
ABORT_NOT_OK(ReadToTable(reader, &original_data));
- std::unique_ptr<arrow::ipc::RecordBatchWriter> write_stream;
- ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream));
+ std::unique_ptr<arrow::flight::FlightStreamWriter> write_stream;
+ std::unique_ptr<arrow::flight::FlightMetadataReader> metadata_reader;
+ ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream, &metadata_reader));
std::unique_ptr<arrow::RecordBatchReader> table_reader(
new arrow::TableBatchReader(*original_data));
- ABORT_NOT_OK(CopyReaderToWriter(table_reader, *write_stream));
+ ABORT_NOT_OK(UploadReaderToFlight(table_reader.get(), *write_stream, *metadata_reader));
// 2. Get the ticket for the data.
std::unique_ptr<arrow::flight::FlightInfo> info;
diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc
index c5bb180..fe6b53d 100644
--- a/cpp/src/arrow/flight/test-integration-server.cc
+++ b/cpp/src/arrow/flight/test-integration-server.cc
@@ -79,14 +79,16 @@ class FlightIntegrationTestServer : public FlightServerBase {
}
auto flight = data->second;
- *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(
- std::shared_ptr<RecordBatchReader>(new TableBatchReader(*flight))));
+ *data_stream = std::unique_ptr<FlightDataStream>(
+ new NumberingStream(std::unique_ptr<FlightDataStream>(new RecordBatchStream(
+ std::shared_ptr<RecordBatchReader>(new TableBatchReader(*flight))))));
return Status::OK();
}
Status DoPut(const ServerCallContext& context,
- std::unique_ptr<FlightMessageReader> reader) override {
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) override {
const FlightDescriptor& descriptor = reader->descriptor();
if (descriptor.type != FlightDescriptor::DescriptorType::PATH) {
@@ -98,11 +100,14 @@ class FlightIntegrationTestServer : public FlightServerBase {
std::string key = descriptor.path[0];
std::vector<std::shared_ptr<arrow::RecordBatch>> retrieved_chunks;
- std::shared_ptr<arrow::RecordBatch> chunk;
+ arrow::flight::FlightStreamChunk chunk;
while (true) {
- RETURN_NOT_OK(reader->ReadNext(&chunk));
- if (chunk == nullptr) break;
- retrieved_chunks.push_back(chunk);
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (chunk.data == nullptr) break;
+ retrieved_chunks.push_back(chunk.data);
+ if (chunk.app_metadata) {
+ RETURN_NOT_OK(writer->WriteMetadata(*chunk.app_metadata));
+ }
}
std::shared_ptr<arrow::Table> retrieved_data;
RETURN_NOT_OK(arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks,
diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc
index 7dd78fd..4408801 100644
--- a/cpp/src/arrow/flight/test-util.cc
+++ b/cpp/src/arrow/flight/test-util.cc
@@ -260,6 +260,24 @@ Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
return internal::SchemaToString(schema, &out->schema);
}
+NumberingStream::NumberingStream(std::unique_ptr<FlightDataStream> stream)
+ : counter_(0), stream_(std::move(stream)) {}
+
+std::shared_ptr<Schema> NumberingStream::schema() { return stream_->schema(); }
+
+Status NumberingStream::GetSchemaPayload(FlightPayload* payload) {
+ return stream_->GetSchemaPayload(payload);
+}
+
+Status NumberingStream::Next(FlightPayload* payload) {
+ RETURN_NOT_OK(stream_->Next(payload));
+ if (payload && payload->ipc_message.type == ipc::Message::RECORD_BATCH) {
+ payload->app_metadata = Buffer::FromString(std::to_string(counter_));
+ counter_++;
+ }
+ return Status::OK();
+}
+
std::shared_ptr<Schema> ExampleIntSchema() {
auto f0 = field("f0", int32());
auto f1 = field("f1", int32());
diff --git a/cpp/src/arrow/flight/test-util.h b/cpp/src/arrow/flight/test-util.h
index 5b02630..7fb0b60 100644
--- a/cpp/src/arrow/flight/test-util.h
+++ b/cpp/src/arrow/flight/test-util.h
@@ -25,6 +25,7 @@
#include "arrow/status.h"
#include "arrow/flight/client_auth.h"
+#include "arrow/flight/server.h"
#include "arrow/flight/server_auth.h"
#include "arrow/flight/types.h"
#include "arrow/flight/visibility.h"
@@ -128,6 +129,23 @@ class ARROW_FLIGHT_EXPORT BatchIterator : public RecordBatchReader {
#endif
// ----------------------------------------------------------------------
+// A FlightDataStream that numbers the record batches
+/// \brief A basic implementation of FlightDataStream that will provide
+/// a sequence of FlightData messages to be written to a gRPC stream
+class ARROW_FLIGHT_EXPORT NumberingStream : public FlightDataStream {
+ public:
+ explicit NumberingStream(std::unique_ptr<FlightDataStream> stream);
+
+ std::shared_ptr<Schema> schema() override;
+ Status GetSchemaPayload(FlightPayload* payload) override;
+ Status Next(FlightPayload* payload) override;
+
+ private:
+ int counter_;
+ std::shared_ptr<FlightDataStream> stream_;
+};
+
+// ----------------------------------------------------------------------
// Example data for test-server and unit tests
using BatchVector = std::vector<std::shared_ptr<RecordBatch>>;
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index d982efc..c82e681 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -25,6 +25,7 @@
#include "arrow/ipc/dictionary.h"
#include "arrow/ipc/reader.h"
#include "arrow/status.h"
+#include "arrow/table.h"
#include "arrow/util/uri.h"
namespace arrow {
@@ -122,6 +123,24 @@ bool Location::Equals(const Location& other) const {
return ToString() == other.ToString();
}
+Status MetadataRecordBatchReader::ReadAll(
+ std::vector<std::shared_ptr<RecordBatch>>* batches) {
+ FlightStreamChunk chunk;
+
+ while (true) {
+ RETURN_NOT_OK(Next(&chunk));
+ if (!chunk.data) break;
+ batches->emplace_back(std::move(chunk.data));
+ }
+ return Status::OK();
+}
+
+Status MetadataRecordBatchReader::ReadAll(std::shared_ptr<Table>* table) {
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ RETURN_NOT_OK(ReadAll(&batches));
+ return Table::FromRecordBatches(schema(), batches, table);
+}
+
SimpleFlightListing::SimpleFlightListing(const std::vector<FlightInfo>& flights)
: position_(0), flights_(flights) {}
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index e5f7bcd..abf894c 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -32,8 +32,10 @@
namespace arrow {
class Buffer;
+class RecordBatch;
class Schema;
class Status;
+class Table;
namespace ipc {
@@ -205,6 +207,7 @@ struct ARROW_FLIGHT_EXPORT FlightEndpoint {
/// This structure corresponds to FlightData in the protocol.
struct ARROW_FLIGHT_EXPORT FlightPayload {
std::shared_ptr<Buffer> descriptor;
+ std::shared_ptr<Buffer> app_metadata;
ipc::internal::IpcPayload ipc_message;
};
@@ -278,6 +281,30 @@ class ARROW_FLIGHT_EXPORT ResultStream {
virtual Status Next(std::unique_ptr<Result>* info) = 0;
};
+/// \brief A holder for a RecordBatch with associated Flight metadata.
+struct ARROW_FLIGHT_EXPORT FlightStreamChunk {
+ public:
+ std::shared_ptr<RecordBatch> data;
+ std::shared_ptr<Buffer> app_metadata;
+};
+
+/// \brief An interface to read Flight data with metadata.
+class ARROW_FLIGHT_EXPORT MetadataRecordBatchReader {
+ public:
+ virtual ~MetadataRecordBatchReader() = default;
+
+ /// \brief Get the schema for this stream.
+ virtual std::shared_ptr<Schema> schema() const = 0;
+ /// \brief Get the next message from Flight. If the stream is
+ /// finished, then the members of \a FlightStreamChunk will be
+ /// nullptr.
+ virtual Status Next(FlightStreamChunk* next) = 0;
+ /// \brief Consume entire stream as a vector of record batches
+ virtual Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches);
+ /// \brief Consume entire stream as a Table
+ virtual Status ReadAll(std::shared_ptr<Table>* table);
+};
+
// \brief Create a FlightListing from a vector of FlightInfo objects. This can
// be iterated once, then it is consumed
class ARROW_FLIGHT_EXPORT SimpleFlightListing : public FlightListing {
diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc
index 717f29a..992c4fc 100644
--- a/cpp/src/arrow/ipc/reader.cc
+++ b/cpp/src/arrow/ipc/reader.cc
@@ -455,7 +455,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
RETURN_NOT_OK(
ReadMessageAndValidate(message_reader_.get(), /*allow_null=*/false, &message));
- CHECK_MESSAGE_TYPE(message->type(), Message::SCHEMA);
+ CHECK_MESSAGE_TYPE(Message::SCHEMA, message->type());
CHECK_HAS_NO_BODY(*message);
if (message->header() == nullptr) {
return Status::IOError("Header-pointer of flatbuffer-encoded Message is null.");
diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc
index ee19fb9..da5026a 100644
--- a/cpp/src/arrow/python/flight.cc
+++ b/cpp/src/arrow/python/flight.cc
@@ -108,10 +108,12 @@ Status PyFlightServer::DoGet(const arrow::flight::ServerCallContext& context,
});
}
-Status PyFlightServer::DoPut(const arrow::flight::ServerCallContext& context,
- std::unique_ptr<arrow::flight::FlightMessageReader> reader) {
+Status PyFlightServer::DoPut(
+ const arrow::flight::ServerCallContext& context,
+ std::unique_ptr<arrow::flight::FlightMessageReader> reader,
+ std::unique_ptr<arrow::flight::FlightMetadataWriter> writer) {
return SafeCallIntoPython([&] {
- vtable_.do_put(server_.obj(), context, std::move(reader));
+ vtable_.do_put(server_.obj(), context, std::move(reader), std::move(writer));
return CheckPyError();
});
}
diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h
index aecb97a..5aea7e8 100644
--- a/cpp/src/arrow/python/flight.h
+++ b/cpp/src/arrow/python/flight.h
@@ -50,7 +50,8 @@ class ARROW_PYTHON_EXPORT PyFlightServerVtable {
std::unique_ptr<arrow::flight::FlightDataStream>*)>
do_get;
std::function<void(PyObject*, const arrow::flight::ServerCallContext&,
- std::unique_ptr<arrow::flight::FlightMessageReader>)>
+ std::unique_ptr<arrow::flight::FlightMessageReader>,
+ std::unique_ptr<arrow::flight::FlightMetadataWriter>)>
do_put;
std::function<void(PyObject*, const arrow::flight::ServerCallContext&,
const arrow::flight::Action&,
@@ -123,7 +124,8 @@ class ARROW_PYTHON_EXPORT PyFlightServer : public arrow::flight::FlightServerBas
const arrow::flight::Ticket& request,
std::unique_ptr<arrow::flight::FlightDataStream>* stream) override;
Status DoPut(const arrow::flight::ServerCallContext& context,
- std::unique_ptr<arrow::flight::FlightMessageReader> reader) override;
+ std::unique_ptr<arrow::flight::FlightMessageReader> reader,
+ std::unique_ptr<arrow::flight::FlightMetadataWriter> writer) override;
Status DoAction(const arrow::flight::ServerCallContext& context,
const arrow::flight::Action& action,
std::unique_ptr<arrow::flight::ResultStream>* result) override;
diff --git a/docs/source/conf.py b/docs/source/conf.py
index d525fa9..e605125 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -416,7 +416,16 @@ except ImportError:
from unittest import mock
pyarrow.cuda = sys.modules['pyarrow.cuda'] = mock.Mock()
+try:
+ import pyarrow.flight
+ flight_enabled = True
+except ImportError:
+ flight_enabled = False
+ pyarrow.flight = sys.modules['pyarrow.flight'] = mock.Mock()
+
+
def setup(app):
# Use a config value to indicate whether CUDA API docs can be generated.
# This will also rebuild appropriately when the value changes.
app.add_config_value('cuda_enabled', cuda_enabled, 'env')
+ app.add_config_value('flight_enabled', flight_enabled, 'env')
diff --git a/docs/source/cpp/api.rst b/docs/source/cpp/api.rst
index 522609e..1c113b7 100644
--- a/docs/source/cpp/api.rst
+++ b/docs/source/cpp/api.rst
@@ -30,3 +30,4 @@ API Reference
api/table
api/utilities
api/cuda
+ api/flight
diff --git a/docs/source/cpp/api/flight.rst b/docs/source/cpp/api/flight.rst
new file mode 100644
index 0000000..4e56a76
--- /dev/null
+++ b/docs/source/cpp/api/flight.rst
@@ -0,0 +1,126 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements. See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership. The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied. See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+
+================
+Arrow Flight RPC
+================
+
+.. warning:: Flight is currently unstable. APIs are subject to change,
+ though we don't expect drastic changes.
+
+.. warning:: Flight is currently only available when built from source
+ appropriately.
+
+Common Types
+============
+
+.. doxygenstruct:: arrow::flight::Action
+ :project: arrow_cpp
+ :members:
+
+.. doxygenstruct:: arrow::flight::ActionType
+ :project: arrow_cpp
+ :members:
+
+.. doxygenstruct:: arrow::flight::Criteria
+ :project: arrow_cpp
+ :members:
+
+.. doxygenstruct:: arrow::flight::FlightDescriptor
+ :project: arrow_cpp
+ :members:
+
+.. doxygenstruct:: arrow::flight::FlightEndpoint
+ :project: arrow_cpp
+ :members:
+
+.. doxygenclass:: arrow::flight::FlightInfo
+ :project: arrow_cpp
+ :members:
+
+.. doxygenstruct:: arrow::flight::FlightPayload
+ :project: arrow_cpp
+ :members:
+
+.. doxygenclass:: arrow::flight::FlightListing
+ :project: arrow_cpp
+ :members:
+
+.. doxygenstruct:: arrow::flight::Location
+ :project: arrow_cpp
+ :members:
+
+.. doxygenstruct:: arrow::flight::PutResult
+ :project: arrow_cpp
+ :members:
+
+.. doxygenstruct:: arrow::flight::Result
+ :project: arrow_cpp
+ :members:
+
+.. doxygenclass:: arrow::flight::ResultStream
+ :project: arrow_cpp
+ :members:
+
+.. doxygenstruct:: arrow::flight::Ticket
+ :project: arrow_cpp
+ :members:
+
+Clients
+=======
+
+.. doxygenclass:: arrow::flight::FlightClient
+ :project: arrow_cpp
+ :members:
+
+.. doxygenclass:: arrow::flight::FlightCallOptions
+ :project: arrow_cpp
+ :members:
+
+.. doxygenclass:: arrow::flight::ClientAuthHandler
+ :project: arrow_cpp
+ :members:
+
+.. doxygentypedef:: arrow::flight::TimeoutDuration
+ :project: arrow_cpp
+
+Servers
+=======
+
+.. doxygenclass:: arrow::flight::FlightServerBase
+ :project: arrow_cpp
+ :members:
+
+.. doxygenclass:: arrow::flight::FlightDataStream
+ :project: arrow_cpp
+ :members:
+
+.. doxygenclass:: arrow::flight::FlightMessageReader
+ :project: arrow_cpp
+ :members:
+
+.. doxygenclass:: arrow::flight::RecordBatchStream
+ :project: arrow_cpp
+ :members:
+
+.. doxygenclass:: arrow::flight::ServerAuthHandler
+ :project: arrow_cpp
+ :members:
+
+.. doxygenclass:: arrow::flight::ServerCallContext
+ :project: arrow_cpp
+ :members:
diff --git a/docs/source/format/Flight.rst b/docs/source/format/Flight.rst
new file mode 100644
index 0000000..b3476ea
--- /dev/null
+++ b/docs/source/format/Flight.rst
@@ -0,0 +1,106 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements. See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership. The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied. See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+
+Arrow Flight RPC
+================
+
+Arrow Flight is a RPC framework for high-performance data services
+based on Arrow data, and is built on top of gRPC_ and the :doc:`IPC
+format <IPC>`.
+
+Flight is organized around streams of Arrow record batches, being
+either downloaded from or uploaded to another service. A set of
+metadata methods offers discovery and introspection of streams, as
+well as the ability to implement application-specific methods.
+
+Methods and message wire formats are defined by Protobuf, enabling
+interoperability with clients that may support gRPC and Arrow
+separately, but not Flight. However, Flight implementations include
+further optimizations to avoid overhead in usage of Protobuf (mostly
+around avoiding excessive memory copies).
+
+.. _gRPC: https://grpc.io/
+
+RPC Methods
+-----------
+
+Flight defines a set of RPC methods for uploading/downloading data,
+retrieving metadata about a data stream, listing available data
+streams, and for implementing application-specific RPC methods. A
+Flight service implements some subset of these methods, while a Flight
+client can call any of these methods. Thus, one Flight client can
+connect to any Flight service and perform basic operations.
+
+Data streams are identified by descriptors, which are either a path or
+an arbitrary binary command. A client that wishes to download the data
+would:
+
+#. Construct or acquire a ``FlightDescriptor`` for the data set they
+ are interested in. A client may know what descriptor they want
+ already, or they may use methods like ``ListFlights`` to discover
+ them.
+#. Call ``GetFlightInfo(FlightDescriptor)`` to get a ``FlightInfo``
+ message containing details on where the data is located (as well as
+ other metadata, like the schema and possibly an estimate of the
+ dataset size).
+
+ Flight does not require that data live on the same server as
+ metadata: this call may list other servers to connect to. The
+ ``FlightInfo`` message includes a ``Ticket``, an opaque binary
+ token that the server uses to identify the exact data set being
+ requested.
+#. Connect to other servers (if needed).
+#. Call ``DoGet(Ticket)`` to get back a stream of Arrow record
+ batches.
+
+To upload data, a client would:
+
+#. Construct or acquire a ``FlightDescriptor``, as before.
+#. Call ``DoPut(FlightData)`` and upload a stream of Arrow record
+ batches. They would also include the ``FlightDescriptor`` with the
+ first message.
+
+See `Protocol Buffer Definitions`_ for full details on the methods and
+messages involved.
+
+Authentication
+~~~~~~~~~~~~~~
+
+Flight supports application-implemented authentication
+methods. Authentication, if enabled, has two phases: at connection
+time, the client and server can exchange any number of messages. Then,
+the client can provide a token alongside each call, and the server can
+validate that token.
+
+Applications may use any part of this; for instance, they may ignore
+the initial handshake and send an externally acquired token on each
+call, or they may establish trust during the handshake and not
+validate a token for each call. (Note that the latter is not secure if
+you choose to deploy a layer 7 load balancer, as is common with gRPC.)
+
+External Resources
+------------------
+
+- https://arrow.apache.org/blog/2018/10/09/0.11.0-release/
+- https://www.slideshare.net/JacquesNadeau5/apache-arrow-flight-overview
+
+Protocol Buffer Definitions
+---------------------------
+
+.. literalinclude:: ../../../format/Flight.proto
+ :language: protobuf
+ :linenos:
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 3b639b4..6fb16f2 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -43,6 +43,7 @@ such topics as:
format/Layout
format/Metadata
format/IPC
+ format/Flight
.. _toc.usage:
diff --git a/docs/source/python/api.rst b/docs/source/python/api.rst
index b06509f..b1dccd4 100644
--- a/docs/source/python/api.rst
+++ b/docs/source/python/api.rst
@@ -30,6 +30,7 @@ API Reference
api/files
api/tables
api/ipc
+ api/flight
api/formats
api/plasma
api/cuda
diff --git a/docs/source/python/api/flight.rst b/docs/source/python/api/flight.rst
new file mode 100644
index 0000000..4fa1374
--- /dev/null
+++ b/docs/source/python/api/flight.rst
@@ -0,0 +1,82 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements. See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership. The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied. See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+
+.. currentmodule:: pyarrow.flight
+
+Arrow Flight
+============
+
+.. ifconfig:: not flight_enabled
+
+ .. error::
+ This documentation was built without Flight enabled. The Flight
+ API docs are not available.
+
+.. NOTE We still generate those API docs (with empty docstrings)
+.. when Flight is disabled and `pyarrow.flight` mocked (see conf.py).
+.. Otherwise we'd get autodoc warnings, see https://github.com/sphinx-doc/sphinx/issues/4770
+
+.. warning:: Flight is currently unstable. APIs are subject to change,
+ though we don't expect drastic changes.
+
+.. warning:: Flight is currently not distributed as part of wheels or
+ in Conda - it is only available when built from source
+ appropriately.
+
+Common Types
+------------
+
+.. autosummary::
+ :toctree: ../generated/
+
+ Action
+ ActionType
+ DescriptorType
+ FlightDescriptor
+ FlightEndpoint
+ FlightInfo
+ Location
+ Ticket
+ Result
+
+Flight Client
+-------------
+
+.. autosummary::
+ :toctree: ../generated/
+
+ FlightCallOptions
+ FlightClient
+
+Flight Server
+-------------
+
+.. autosummary::
+ :toctree: ../generated/
+
+ FlightServerBase
+ GeneratorStream
+ RecordBatchStream
+
+Authentication
+--------------
+
+.. autosummary::
+ :toctree: ../generated/
+
+ ClientAuthHandler
+ ServerAuthHandler
diff --git a/format/Flight.proto b/format/Flight.proto
index 7f0488b..0c8f28e 100644
--- a/format/Flight.proto
+++ b/format/Flight.proto
@@ -77,7 +77,7 @@ service FlightService {
* number. In the latter, the service might implement a 'seal' action that
* can be applied to a descriptor once all streams are uploaded.
*/
- rpc DoPut(stream FlightData) returns (PutResult) {}
+ rpc DoPut(stream FlightData) returns (stream PutResult) {}
/*
* Flight services can support an arbitrary number of simple actions in
@@ -286,6 +286,11 @@ message FlightData {
bytes data_header = 2;
/*
+ * Application-defined metadata.
+ */
+ bytes app_metadata = 3;
+
+ /*
* The actual batch of Arrow data. Preferably handled with minimal-copies
* coming last in the definition to help with sidecar patterns (it is
* expected that some implementations will fetch this field off the wire
@@ -295,7 +300,8 @@ message FlightData {
}
/**
- * The response message (currently empty) associated with the submission of a
- * DoPut.
+ * The response message associated with the submission of a DoPut.
*/
-message PutResult {}
+message PutResult {
+ bytes app_metadata = 1;
+}
diff --git a/integration/integration_test.py b/integration/integration_test.py
index a4763c9..aca0574 100644
--- a/integration/integration_test.py
+++ b/integration/integration_test.py
@@ -1152,7 +1152,7 @@ def get_generated_json_files(tempdir=None, flight=False):
generate_interval_case(),
generate_map_case(),
generate_nested_case(),
- generate_dictionary_case().skip_category(SKIP_FLIGHT),
+ generate_dictionary_case(),
generate_nested_dictionary_case().skip_category(SKIP_ARROW)
.skip_category(SKIP_FLIGHT),
]
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java
index 550f5c1..7879069 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java
@@ -35,6 +35,7 @@ import org.apache.arrow.flight.impl.Flight.FlightData;
import org.apache.arrow.flight.impl.Flight.FlightDescriptor;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
@@ -52,7 +53,6 @@ import com.google.protobuf.WireFormat;
import io.grpc.Drainable;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
-import io.grpc.internal.ReadableBuffer;
import io.grpc.protobuf.ProtoUtils;
import io.netty.buffer.ArrowBuf;
@@ -74,9 +74,16 @@ class ArrowMessage implements AutoCloseable {
(FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
private static final int HEADER_TAG =
(FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
+ private static final int APP_METADATA_TAG =
+ (FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
private static Marshaller<FlightData> NO_BODY_MARSHALLER = ProtoUtils.marshaller(FlightData.getDefaultInstance());
+ /** Get the application-specific metadata in this message. The ArrowMessage retains ownership of the buffer. */
+ public ArrowBuf getApplicationMetadata() {
+ return appMetadata;
+ }
+
/** Types of messages that can be sent. */
public enum HeaderType {
NONE,
@@ -114,6 +121,7 @@ class ArrowMessage implements AutoCloseable {
private final FlightDescriptor descriptor;
private final Message message;
+ private final ArrowBuf appMetadata;
private final List<ArrowBuf> bufs;
public ArrowMessage(FlightDescriptor descriptor, Schema schema) {
@@ -124,9 +132,15 @@ class ArrowMessage implements AutoCloseable {
message = Message.getRootAsMessage(serializedMessage);
bufs = ImmutableList.of();
this.descriptor = descriptor;
+ this.appMetadata = null;
}
- public ArrowMessage(ArrowRecordBatch batch) {
+ /**
+ * Create an ArrowMessage from a record batch and app metadata.
+ * @param batch The record batch.
+ * @param appMetadata The app metadata. May be null. Takes ownership of the buffer otherwise.
+ */
+ public ArrowMessage(ArrowRecordBatch batch, ArrowBuf appMetadata) {
FlatBufferBuilder builder = new FlatBufferBuilder();
int batchOffset = batch.writeTo(builder);
ByteBuffer serializedMessage = MessageSerializer.serializeMessage(builder, MessageHeader.RecordBatch, batchOffset,
@@ -135,11 +149,28 @@ class ArrowMessage implements AutoCloseable {
this.message = Message.getRootAsMessage(serializedMessage);
this.bufs = ImmutableList.copyOf(batch.getBuffers());
this.descriptor = null;
+ this.appMetadata = appMetadata;
}
- private ArrowMessage(FlightDescriptor descriptor, Message message, ArrowBuf buf) {
+ public ArrowMessage(ArrowDictionaryBatch batch) {
+ FlatBufferBuilder builder = new FlatBufferBuilder();
+ int batchOffset = batch.writeTo(builder);
+ ByteBuffer serializedMessage = MessageSerializer
+ .serializeMessage(builder, MessageHeader.DictionaryBatch, batchOffset,
+ batch.computeBodyLength());
+ serializedMessage = serializedMessage.slice();
+ this.message = Message.getRootAsMessage(serializedMessage);
+ // asInputStream will free the buffers implicitly, so increment the reference count
+ batch.getDictionary().getBuffers().forEach(buf -> buf.getReferenceManager().retain());
+ this.bufs = ImmutableList.copyOf(batch.getDictionary().getBuffers());
+ this.descriptor = null;
+ this.appMetadata = null;
+ }
+
+ private ArrowMessage(FlightDescriptor descriptor, Message message, ArrowBuf appMetadata, ArrowBuf buf) {
this.message = message;
this.descriptor = descriptor;
+ this.appMetadata = appMetadata;
this.bufs = buf == null ? ImmutableList.of() : ImmutableList.of(buf);
}
@@ -169,10 +200,18 @@ class ArrowMessage implements AutoCloseable {
RecordBatch recordBatch = new RecordBatch();
message.header(recordBatch);
ArrowBuf underlying = bufs.get(0);
+ underlying.getReferenceManager().retain();
ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(recordBatch, underlying);
return batch;
}
+ public ArrowDictionaryBatch asDictionaryBatch() throws IOException {
+ Preconditions.checkArgument(bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf.");
+ Preconditions.checkArgument(getMessageType() == HeaderType.DICTIONARY_BATCH);
+ ArrowBuf underlying = bufs.get(0);
+ return MessageSerializer.deserializeDictionaryBatch(message, underlying);
+ }
+
public Iterable<ArrowBuf> getBufs() {
return Iterables.unmodifiableIterable(bufs);
}
@@ -183,6 +222,7 @@ class ArrowMessage implements AutoCloseable {
FlightDescriptor descriptor = null;
Message header = null;
ArrowBuf body = null;
+ ArrowBuf appMetadata = null;
while (stream.available() > 0) {
int tag = readRawVarint32(stream);
switch (tag) {
@@ -201,6 +241,12 @@ class ArrowMessage implements AutoCloseable {
header = Message.getRootAsMessage(ByteBuffer.wrap(bytes));
break;
}
+ case APP_METADATA_TAG: {
+ int size = readRawVarint32(stream);
+ appMetadata = allocator.buffer(size);
+ GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, FAST_PATH);
+ break;
+ }
case BODY_TAG:
if (body != null) {
// only read last body.
@@ -209,15 +255,7 @@ class ArrowMessage implements AutoCloseable {
}
int size = readRawVarint32(stream);
body = allocator.buffer(size);
- ReadableBuffer readableBuffer = FAST_PATH ? GetReadableBuffer.getReadableBuffer(stream) : null;
- if (readableBuffer != null) {
- readableBuffer.readBytes(body.nioBuffer(0, size));
- } else {
- byte[] heapBytes = new byte[size];
- ByteStreams.readFully(stream, heapBytes);
- body.writeBytes(heapBytes);
- }
- body.writerIndex(size);
+ GetReadableBuffer.readIntoBuffer(stream, body, size, FAST_PATH);
break;
default:
@@ -225,7 +263,7 @@ class ArrowMessage implements AutoCloseable {
}
}
- return new ArrowMessage(descriptor, header, body);
+ return new ArrowMessage(descriptor, header, appMetadata, body);
} catch (Exception ioe) {
throw new RuntimeException(ioe);
}
@@ -246,7 +284,6 @@ class ArrowMessage implements AutoCloseable {
final ByteString bytes = ByteString.copyFrom(message.getByteBuffer(), message.getByteBuffer().remaining());
-
if (getMessageType() == HeaderType.SCHEMA) {
final FlightData.Builder builder = FlightData.newBuilder()
@@ -260,15 +297,23 @@ class ArrowMessage implements AutoCloseable {
return NO_BODY_MARSHALLER.stream(builder.build());
}
- Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH);
+ Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH ||
+ getMessageType() == HeaderType.DICTIONARY_BATCH);
Preconditions.checkArgument(!bufs.isEmpty());
Preconditions.checkArgument(descriptor == null, "Descriptor should only be included in the schema message.");
ByteArrayOutputStream baos = new ByteArrayOutputStream();
CodedOutputStream cos = CodedOutputStream.newInstance(baos);
cos.writeBytes(FlightData.DATA_HEADER_FIELD_NUMBER, bytes);
- cos.writeTag(FlightData.DATA_BODY_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED);
+ if (appMetadata != null && appMetadata.capacity() > 0) {
+ // Must call slice() as CodedOutputStream#writeByteBuffer writes -capacity- bytes, not -limit- bytes
+ cos.writeByteBuffer(FlightData.APP_METADATA_FIELD_NUMBER, appMetadata.asNettyBuffer().nioBuffer().slice());
+ // This is weird, but implicitly, writing an ArrowMessage frees any references it has
+ appMetadata.getReferenceManager().release();
+ }
+
+ cos.writeTag(FlightData.DATA_BODY_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED);
int size = 0;
List<ByteBuf> allBufs = new ArrayList<>();
for (ArrowBuf b : bufs) {
@@ -290,6 +335,7 @@ class ArrowMessage implements AutoCloseable {
initialBuf.writeBytes(baos.toByteArray());
final CompositeByteBuf bb = new CompositeByteBuf(allocator.getAsByteBufAllocator(), true, bufs.size() + 1,
ImmutableList.<ByteBuf>builder().add(initialBuf.asNettyBuffer()).addAll(allBufs).build());
+ // Implicitly, transfer ownership of our buffers to the input stream (which will decrement the refcount when done)
final ByteBufInputStream is = new DrainableByteBufInputStream(bb);
return is;
} catch (Exception ex) {
@@ -319,7 +365,7 @@ class ArrowMessage implements AutoCloseable {
}
@Override
- public void close() throws IOException {
+ public void close() {
buf.release();
}
@@ -354,5 +400,8 @@ class ArrowMessage implements AutoCloseable {
@Override
public void close() throws Exception {
AutoCloseables.close(bufs);
+ if (appMetadata != null) {
+ appMetadata.close();
+ }
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java b/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java
new file mode 100644
index 0000000..c8214e3
--- /dev/null
+++ b/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+
+/**
+ * A handler for server-sent application metadata messages during a Flight DoPut operation.
+ *
+ * <p>To handle messages, create an instance of this class overriding {@link #onNext(PutResult)}. The other methods
+ * should not be overridden.
+ */
+public class AsyncPutListener implements FlightClient.PutListener {
+
+ private CompletableFuture<Void> completed;
+
+ public AsyncPutListener() {
+ completed = new CompletableFuture<>();
+ }
+
+ /**
+ * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have
+ * happened during the upload.
+ */
+ @Override
+ public final void getResult() {
+ try {
+ completed.get();
+ } catch (InterruptedException | ExecutionException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void onNext(PutResult val) {
+ }
+
+ @Override
+ public final void onError(Throwable t) {
+ completed.completeExceptionally(t);
+ }
+
+ @Override
+ public final void onCompleted() {
+ completed.complete(null);
+ }
+}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java
new file mode 100644
index 0000000..6409b6a
--- /dev/null
+++ b/java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.function.Consumer;
+
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.DictionaryUtility;
+
+/**
+ * Utilities to work with dictionaries in Flight.
+ */
+final class DictionaryUtils {
+
+ private DictionaryUtils() {
+ throw new UnsupportedOperationException("Do not instantiate this class.");
+ }
+
+ /**
+ * Generate all the necessary Flight messages to send a schema and associated dictionaries.
+ */
+ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDescriptor descriptor,
+ final DictionaryProvider provider, final Consumer<ArrowMessage> messageCallback) {
+ final List<Field> fields = new ArrayList<>(originalSchema.getFields().size());
+ final Set<Long> dictionaryIds = new HashSet<>();
+ for (final Field field : originalSchema.getFields()) {
+ fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIds));
+ }
+ final Schema schema = new Schema(fields, originalSchema.getCustomMetadata());
+ // Send the schema message
+ messageCallback.accept(new ArrowMessage(descriptor == null ? null : descriptor.toProtocol(), schema));
+ // Create and write dictionary batches
+ for (Long id : dictionaryIds) {
+ final Dictionary dictionary = provider.lookup(id);
+ final FieldVector vector = dictionary.getVector();
+ final int count = vector.getValueCount();
+ // Do NOT close this root, as it does not actually own the vector.
+ final VectorSchemaRoot dictRoot = new VectorSchemaRoot(
+ Collections.singletonList(vector.getField()),
+ Collections.singletonList(vector),
+ count);
+ final VectorUnloader unloader = new VectorUnloader(dictRoot);
+ try (final ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(
+ id, unloader.getRecordBatch())) {
+ messageCallback.accept(new ArrowMessage(dictionaryBatch));
+ }
+ }
+ return schema;
+ }
+}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java
index d352b2b..13a28f9 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java
@@ -28,12 +28,11 @@ import com.google.common.collect.ImmutableSet;
import io.grpc.BindableService;
import io.grpc.MethodDescriptor;
+import io.grpc.MethodDescriptor.MethodType;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.ServerCalls;
-import io.grpc.stub.ServerCalls.ClientStreamingMethod;
-import io.grpc.stub.ServerCalls.ServerStreamingMethod;
import io.grpc.stub.StreamObserver;
/**
@@ -66,7 +65,7 @@ class FlightBindingService implements BindableService {
public static MethodDescriptor<ArrowMessage, Flight.PutResult> getDoPutDescriptor(BufferAllocator allocator) {
return MethodDescriptor.<ArrowMessage, Flight.PutResult>newBuilder()
- .setType(io.grpc.MethodDescriptor.MethodType.CLIENT_STREAMING)
+ .setType(MethodType.BIDI_STREAMING)
.setFullMethodName(DO_PUT)
.setSampledToLocalTracing(false)
.setRequestMarshaller(ArrowMessage.createMarshaller(allocator))
@@ -84,7 +83,7 @@ class FlightBindingService implements BindableService {
ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition.builder(FlightConstants.SERVICE);
serviceBuilder.addMethod(doGetDescriptor, ServerCalls.asyncServerStreamingCall(new DoGetMethod(delegate)));
- serviceBuilder.addMethod(doPutDescriptor, ServerCalls.asyncClientStreamingCall(new DoPutMethod(delegate)));
+ serviceBuilder.addMethod(doPutDescriptor, ServerCalls.asyncBidiStreamingCall(new DoPutMethod(delegate)));
// copy over not-overridden methods.
for (ServerMethodDefinition<?, ?> definition : baseDefinition.getMethods()) {
@@ -98,7 +97,7 @@ class FlightBindingService implements BindableService {
return serviceBuilder.build();
}
- private class DoGetMethod implements ServerStreamingMethod<Flight.Ticket, ArrowMessage> {
+ private class DoGetMethod implements ServerCalls.ServerStreamingMethod<Flight.Ticket, ArrowMessage> {
private final FlightService delegate;
@@ -112,7 +111,7 @@ class FlightBindingService implements BindableService {
}
}
- private class DoPutMethod implements ClientStreamingMethod<ArrowMessage, Flight.PutResult> {
+ private class DoPutMethod implements ServerCalls.BidiStreamingMethod<ArrowMessage, PutResult> {
private final FlightService delegate;
public DoPutMethod(FlightService delegate) {
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index 37e4514..9ac3686 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -17,9 +17,6 @@
package org.apache.arrow.flight;
-import static io.grpc.stub.ClientCalls.asyncClientStreamingCall;
-import static io.grpc.stub.ClientCalls.asyncServerStreamingCall;
-
import java.io.InputStream;
import java.net.URISyntaxException;
import java.util.Iterator;
@@ -28,27 +25,26 @@ import java.util.stream.Collectors;
import javax.net.ssl.SSLException;
+import org.apache.arrow.flight.FlightProducer.StreamListener;
import org.apache.arrow.flight.auth.BasicClientAuthHandler;
import org.apache.arrow.flight.auth.ClientAuthHandler;
import org.apache.arrow.flight.auth.ClientAuthInterceptor;
import org.apache.arrow.flight.auth.ClientAuthWrapper;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.Flight.Empty;
-import org.apache.arrow.flight.impl.Flight.PutResult;
import org.apache.arrow.flight.impl.FlightServiceGrpc;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceBlockingStub;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import com.google.common.base.Preconditions;
-import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterators;
-import com.google.common.util.concurrent.ListenableFuture;
-import com.google.common.util.concurrent.SettableFuture;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
@@ -56,9 +52,11 @@ import io.grpc.MethodDescriptor;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.ClientCallStreamObserver;
+import io.grpc.stub.ClientCalls;
import io.grpc.stub.ClientResponseObserver;
import io.grpc.stub.StreamObserver;
+import io.netty.buffer.ArrowBuf;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.handler.ssl.SslContextBuilder;
@@ -78,6 +76,9 @@ public class FlightClient implements AutoCloseable {
private final MethodDescriptor<Flight.Ticket, ArrowMessage> doGetDescriptor;
private final MethodDescriptor<ArrowMessage, Flight.PutResult> doPutDescriptor;
+ /**
+ * Create a Flight client from an allocator and a gRPC channel.
+ */
private FlightClient(BufferAllocator incomingAllocator, ManagedChannel channel) {
this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE);
this.channel = channel;
@@ -156,27 +157,42 @@ public class FlightClient implements AutoCloseable {
/**
* Create or append a descriptor with another stream.
- * @param descriptor FlightDescriptor
- * @param root VectorSchemaRoot
+ *
+ * @param descriptor FlightDescriptor the descriptor for the data
+ * @param root VectorSchemaRoot the root containing data
+ * @param metadataListener A handler for metadata messages from the server. This will be passed buffers that will be
+ * freed after {@link StreamListener#onNext(Object)} is called!
* @param options RPC-layer hints for this call.
- * @return ClientStreamListener
+ * @return ClientStreamListener an interface to control uploading data
*/
- public ClientStreamListener startPut(
- FlightDescriptor descriptor, VectorSchemaRoot root, CallOption... options) {
+ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root,
+ PutListener metadataListener, CallOption... options) {
+ return startPut(descriptor, root, new MapDictionaryProvider(), metadataListener, options);
+ }
+
+ /**
+ * Create or append a descriptor with another stream.
+ * @param descriptor FlightDescriptor the descriptor for the data
+ * @param root VectorSchemaRoot the root containing data
+ * @param metadataListener A handler for metadata messages from the server.
+ * @param options RPC-layer hints for this call.
+ * @return ClientStreamListener an interface to control uploading data
+ */
+ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, DictionaryProvider provider,
+ PutListener metadataListener, CallOption... options) {
Preconditions.checkNotNull(descriptor);
Preconditions.checkNotNull(root);
- SetStreamObserver<PutResult> resultObserver = new SetStreamObserver<>();
+ SetStreamObserver resultObserver = new SetStreamObserver(allocator, metadataListener);
final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>)
- asyncClientStreamingCall(
- authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver);
+ ClientCalls.asyncBidiStreamingCall(
+ authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver);
// send the schema to start.
- ArrowMessage message = new ArrowMessage(descriptor.toProtocol(), root.getSchema());
- observer.onNext(message);
+ DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, provider, observer::onNext);
return new PutObserver(new VectorUnloader(
root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */),
- observer, resultObserver.getFuture());
+ observer, metadataListener);
}
/**
@@ -202,7 +218,7 @@ public class FlightClient implements AutoCloseable {
public FlightStream getStream(Ticket ticket, CallOption... options) {
final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
ClientCall<Flight.Ticket, ArrowMessage> call =
- authInterceptor.interceptCall(doGetDescriptor, callOptions, channel);
+ authInterceptor.interceptCall(doGetDescriptor, callOptions, channel);
FlightStream stream = new FlightStream(
allocator,
PENDING_REQUESTS,
@@ -235,54 +251,64 @@ public class FlightClient implements AutoCloseable {
};
- asyncServerStreamingCall(call, ticket.toProtocol(), clientResponseObserver);
+ ClientCalls.asyncServerStreamingCall(call, ticket.toProtocol(), clientResponseObserver);
return stream;
}
- private static class SetStreamObserver<T> implements StreamObserver<T> {
- private final SettableFuture<T> result = SettableFuture.create();
- private volatile T resultLocal;
+ private static class SetStreamObserver implements StreamObserver<Flight.PutResult> {
+ private final BufferAllocator allocator;
+ private final StreamListener<PutResult> listener;
+
+ SetStreamObserver(BufferAllocator allocator, StreamListener<PutResult> listener) {
+ super();
+ this.allocator = allocator;
+ this.listener = listener == null ? NoOpStreamListener.getInstance() : listener;
+ }
@Override
- public void onNext(T value) {
- resultLocal = value;
+ public void onNext(Flight.PutResult value) {
+ try (final PutResult message = PutResult.fromProtocol(allocator, value)) {
+ listener.onNext(message);
+ }
}
@Override
public void onError(Throwable t) {
- result.setException(t);
+ listener.onError(t);
}
@Override
public void onCompleted() {
- result.set(Preconditions.checkNotNull(resultLocal));
- }
-
- public ListenableFuture<T> getFuture() {
- return result;
+ listener.onCompleted();
}
}
private static class PutObserver implements ClientStreamListener {
+
private final ClientCallStreamObserver<ArrowMessage> observer;
private final VectorUnloader unloader;
- private final ListenableFuture<PutResult> futureResult;
+ private final PutListener listener;
public PutObserver(VectorUnloader unloader, ClientCallStreamObserver<ArrowMessage> observer,
- ListenableFuture<PutResult> futureResult) {
+ PutListener listener) {
this.observer = observer;
this.unloader = unloader;
- this.futureResult = futureResult;
+ this.listener = listener;
}
@Override
public void putNext() {
+ putNext(null);
+ }
+
+ @Override
+ public void putNext(ArrowBuf appMetadata) {
ArrowRecordBatch batch = unloader.getRecordBatch();
- // Check the futureResult in case server sent an exception
- while (!observer.isReady() && !futureResult.isDone()) {
+ while (!observer.isReady()) {
/* busy wait */
}
- observer.onNext(new ArrowMessage(batch));
+ // Takes ownership of appMetadata
+ observer.onNext(new ArrowMessage(batch, appMetadata));
}
@Override
@@ -296,12 +322,8 @@ public class FlightClient implements AutoCloseable {
}
@Override
- public PutResult getResult() {
- try {
- return futureResult.get();
- } catch (Exception ex) {
- throw Throwables.propagate(ex);
- }
+ public void getResult() {
+ listener.getResult();
}
}
@@ -310,17 +332,59 @@ public class FlightClient implements AutoCloseable {
*/
public interface ClientStreamListener {
+ /**
+ * Send the current data in the corresponding {@link VectorSchemaRoot} to the server.
+ */
void putNext();
+ /**
+ * Send the current data in the corresponding {@link VectorSchemaRoot} to the server, along with
+ * application-specific metadata. This takes ownership of the buffer.
+ */
+ void putNext(ArrowBuf appMetadata);
+
+ /**
+ * Indicate an error to the server. Terminates the stream; do not call {@link #completed()}.
+ */
void error(Throwable ex);
+ /** Indicate the stream is finished on the client side. */
void completed();
- PutResult getResult();
-
+ /**
+ * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have
+ * happened during the upload.
+ */
+ void getResult();
}
+ /**
+ * A handler for server-sent application metadata messages during a Flight DoPut operation.
+ *
+ * <p>Generally, instead of implementing this yourself, you should use {@link AsyncPutListener} or {@link
+ * SyncPutListener}.
+ */
+ public interface PutListener extends StreamListener<PutResult> {
+ /**
+ * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have
+ * happened during the upload.
+ */
+ void getResult();
+
+ /**
+ * Called when a message from the server is received.
+ *
+ * @param val The application metadata. This buffer will be reclaimed once onNext returns; you must retain a
+ * reference to use it outside this method.
+ */
+ @Override
+ void onNext(PutResult val);
+ }
+
+ /**
+ * Shut down this client.
+ */
public void close() throws InterruptedException {
channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
allocator.close();
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java
index 9c3cc2b..fdb5e9f 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java
@@ -17,49 +17,118 @@
package org.apache.arrow.flight;
-import java.util.concurrent.Callable;
-
-import org.apache.arrow.flight.impl.Flight.PutResult;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+
+import io.netty.buffer.ArrowBuf;
/**
* API to Implement an Arrow Flight producer.
*/
public interface FlightProducer {
- void getStream(CallContext context, Ticket ticket,
- ServerStreamListener listener);
+ /**
+ * Return data for a stream.
+ *
+ * @param context Per-call context.
+ * @param ticket The application-defined ticket identifying this stream.
+ * @param listener An interface for sending data back to the client.
+ */
+ void getStream(CallContext context, Ticket ticket, ServerStreamListener listener);
+ /**
+ * List available data streams on this service.
+ *
+ * @param context Per-call context.
+ * @param criteria Application-defined criteria for filtering streams.
+ * @param listener An interface for sending data back to the client.
+ */
void listFlights(CallContext context, Criteria criteria,
StreamListener<FlightInfo> listener);
- FlightInfo getFlightInfo(CallContext context,
- FlightDescriptor descriptor);
+ /**
+ * Get information about a particular data stream.
+ *
+ * @param context Per-call context.
+ * @param descriptor The descriptor identifying the data stream.
+ * @return Metadata about the stream.
+ */
+ FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor);
- Callable<PutResult> acceptPut(CallContext context,
- FlightStream flightStream);
+ /**
+ * Accept uploaded data for a particular stream.
+ *
+ * @param context Per-call context.
+ * @param flightStream The data stream being uploaded.
+ */
+ Runnable acceptPut(CallContext context,
+ FlightStream flightStream, StreamListener<PutResult> ackStream);
+ /**
+ * Generic handler for application-defined RPCs.
+ *
+ * @param context Per-call context.
+ * @param action Client-supplied parameters.
+ * @param listener A stream of responses.
+ */
void doAction(CallContext context, Action action,
StreamListener<Result> listener);
- void listActions(CallContext context,
- StreamListener<ActionType> listener);
+ /**
+ * List available application-defined RPCs.
+ * @param context Per-call context.
+ * @param listener An interface for sending data back to the client.
+ */
+ void listActions(CallContext context, StreamListener<ActionType> listener);
/**
- * Listener for creating a stream on the server side.
+ * An interface for sending Arrow data back to a client.
*/
interface ServerStreamListener {
+ /**
+ * Check whether the call has been cancelled. If so, stop sending data.
+ */
boolean isCancelled();
+ /**
+ * A hint indicating whether the client is ready to receive data without excessive buffering.
+ */
boolean isReady();
+ /**
+ * Start sending data, using the schema of the given {@link VectorSchemaRoot}.
+ *
+ * <p>This method must be called before all others.
+ */
void start(VectorSchemaRoot root);
+ /**
+ * Start sending data, using the schema of the given {@link VectorSchemaRoot}.
+ *
+ * <p>This method must be called before all others.
+ */
+ void start(VectorSchemaRoot root, DictionaryProvider dictionaries);
+
+ /**
+ * Send the current contents of the associated {@link VectorSchemaRoot}.
+ */
void putNext();
+ /**
+ * Send the current contents of the associated {@link VectorSchemaRoot} alongside application-defined metadata.
+ * @param metadata The metadata to send. Ownership of the buffer is transferred to the Flight implementation.
+ */
+ void putNext(ArrowBuf metadata);
+
+ /**
+ * Indicate an error to the client. Terminates the stream; do not call {@link #completed()} afterwards.
+ */
void error(Throwable ex);
+ /**
+ * Indicate that transmission is finished.
+ */
void completed();
}
@@ -71,10 +140,21 @@ public interface FlightProducer {
*/
interface StreamListener<T> {
+ /**
+ * Send the next value to the client.
+ */
void onNext(T val);
+ /**
+ * Indicate an error to the client.
+ *
+ * <p>Terminates the stream; do not call {@link #onCompleted()}.
+ */
void onError(Throwable t);
+ /**
+ * Indicate that the transmission is finished.
+ */
void onCompleted();
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
index cd59a75..3f02dd5 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
@@ -21,6 +21,8 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
+import java.net.URI;
+import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executor;
@@ -47,13 +49,15 @@ public class FlightServer implements AutoCloseable {
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(FlightServer.class);
+ private final Location location;
private final Server server;
/** The maximum size of an individual gRPC message. This effectively disables the limit. */
static final int MAX_GRPC_MESSAGE_SIZE = Integer.MAX_VALUE;
/** Create a new instance from a gRPC server. For internal use only. */
- private FlightServer(Server server) {
+ private FlightServer(Location location, Server server) {
+ this.location = location;
this.server = server;
}
@@ -63,10 +67,27 @@ public class FlightServer implements AutoCloseable {
return this;
}
+ /** Get the port the server is running on (if applicable). */
public int getPort() {
return server.getPort();
}
+ /** Get the location for this server. */
+ public Location getLocation() {
+ if (location.getUri().getPort() == 0) {
+ // If the server was bound to port 0, replace the port in the location with the real port.
+ final URI uri = location.getUri();
+ try {
+ return new Location(new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), getPort(),
+ uri.getPath(), uri.getQuery(), uri.getFragment()));
+ } catch (URISyntaxException e) {
+ // We don't expect this to happen
+ throw new RuntimeException(e);
+ }
+ }
+ return location;
+ }
+
/** Block until the server shuts down. */
public void awaitTermination() throws InterruptedException {
server.awaitTermination();
@@ -211,7 +232,7 @@ public class FlightServer implements AutoCloseable {
return null;
});
- return new FlightServer(builder.build());
+ return new FlightServer(location, builder.build());
}
/**
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java
index b5c22ef..ee45cef 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java
@@ -30,18 +30,21 @@ import org.apache.arrow.flight.impl.Flight.ActionType;
import org.apache.arrow.flight.impl.Flight.Empty;
import org.apache.arrow.flight.impl.Flight.HandshakeRequest;
import org.apache.arrow.flight.impl.Flight.HandshakeResponse;
-import org.apache.arrow.flight.impl.Flight.PutResult;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceImplBase;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Preconditions;
+import io.grpc.Status;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
+import io.netty.buffer.ArrowBuf;
/**
* GRPC service implementation for a flight server.
@@ -138,15 +141,25 @@ class FlightService extends FlightServiceImplBase {
@Override
public void start(VectorSchemaRoot root) {
- responseObserver.onNext(new ArrowMessage(null, root.getSchema()));
- // [ARROW-4213] We must align buffers to be compatible with other languages.
+ start(root, new MapDictionaryProvider());
+ }
+
+ @Override
+ public void start(VectorSchemaRoot root, DictionaryProvider provider) {
unloader = new VectorUnloader(root, true, true);
+
+ DictionaryUtils.generateSchemaMessages(root.getSchema(), null, provider, responseObserver::onNext);
}
@Override
public void putNext() {
+ putNext(null);
+ }
+
+ @Override
+ public void putNext(ArrowBuf metadata) {
Preconditions.checkNotNull(unloader);
- responseObserver.onNext(new ArrowMessage(unloader.getRecordBatch()));
+ responseObserver.onNext(new ArrowMessage(unloader.getRecordBatch(), metadata));
}
@Override
@@ -161,18 +174,33 @@ class FlightService extends FlightServiceImplBase {
}
- public StreamObserver<ArrowMessage> doPutCustom(final StreamObserver<PutResult> responseObserverSimple) {
- ServerCallStreamObserver<PutResult> responseObserver = (ServerCallStreamObserver<PutResult>) responseObserverSimple;
+ public StreamObserver<ArrowMessage> doPutCustom(final StreamObserver<Flight.PutResult> responseObserverSimple) {
+ ServerCallStreamObserver<Flight.PutResult> responseObserver =
+ (ServerCallStreamObserver<Flight.PutResult>) responseObserverSimple;
responseObserver.disableAutoInboundFlowControl();
responseObserver.request(1);
- FlightStream fs = new FlightStream(allocator, PENDING_REQUESTS, null, (count) -> responseObserver.request(count));
+ // Set a default metadata listener that does nothing. Service implementations should call
+ // FlightStream#setMetadataListener before returning a Runnable if they want to receive metadata.
+ FlightStream fs = new FlightStream(allocator, PENDING_REQUESTS, (String message, Throwable cause) -> {
+ responseObserver.onError(Status.CANCELLED.withCause(cause).withDescription(message).asException());
+ }, responseObserver::request);
executors.submit(() -> {
try {
- responseObserver.onNext(producer.acceptPut(makeContext(responseObserver), fs).call());
+ producer.acceptPut(makeContext(responseObserver), fs,
+ StreamPipe.wrap(responseObserver, PutResult::toProtocol)).run();
responseObserver.onCompleted();
} catch (Exception ex) {
responseObserver.onError(ex);
+ // The client may have terminated, so the exception here is effectively swallowed.
+ // Log the error as well so -something- makes it to the developer.
+ logger.error("Exception handling DoPut", ex);
+ }
+ try {
+ fs.close();
+ } catch (Exception e) {
+ logger.error("Exception closing Flight stream", e);
+ throw new RuntimeException(e);
}
});
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java
index 79685c4..010ff33 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java
@@ -17,17 +17,28 @@
package org.apache.arrow.flight;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.stream.Collectors;
+import org.apache.arrow.flight.ArrowMessage.HeaderType;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.DictionaryUtility;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
@@ -35,11 +46,12 @@ import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.stub.StreamObserver;
+import io.netty.buffer.ArrowBuf;
/**
* An adaptor between protobuf streams and flight data streams.
*/
-public class FlightStream {
+public class FlightStream implements AutoCloseable {
private final Object DONE = new Object();
@@ -56,10 +68,12 @@ public class FlightStream {
private volatile int pending = 1;
private boolean completed = false;
private volatile VectorSchemaRoot fulfilledRoot;
+ private DictionaryProvider.MapDictionaryProvider dictionaries;
private volatile VectorLoader loader;
private volatile Throwable ex;
private volatile FlightDescriptor descriptor;
private volatile Schema schema;
+ private volatile ArrowBuf applicationMetadata = null;
/**
* Constructs a new instance.
@@ -74,12 +88,17 @@ public class FlightStream {
this.pendingTarget = pendingTarget;
this.cancellable = cancellable;
this.requestor = requestor;
+ this.dictionaries = new DictionaryProvider.MapDictionaryProvider();
}
public Schema getSchema() {
return schema;
}
+ public DictionaryProvider getDictionaryProvider() {
+ return dictionaries;
+ }
+
public FlightDescriptor getDescriptor() {
return descriptor;
}
@@ -98,7 +117,10 @@ public class FlightStream {
.map(t -> ((AutoCloseable) t))
.collect(Collectors.toList());
- AutoCloseables.close(Iterables.concat(closeables, ImmutableList.of(root.get())));
+ // Must check for null since ImmutableList doesn't accept nulls
+ AutoCloseables.close(Iterables.concat(closeables,
+ applicationMetadata != null ? ImmutableList.of(root.get(), applicationMetadata)
+ : ImmutableList.of(root.get())));
}
/**
@@ -131,15 +153,41 @@ public class FlightStream {
throw new Exception(ex);
}
} else {
- ArrowMessage msg = ((ArrowMessage) data);
- try (ArrowRecordBatch arb = msg.asRecordBatch()) {
- loader.load(arb);
+ try (ArrowMessage msg = ((ArrowMessage) data)) {
+ if (msg.getMessageType() == HeaderType.RECORD_BATCH) {
+ try (ArrowRecordBatch arb = msg.asRecordBatch()) {
+ loader.load(arb);
+ }
+ if (this.applicationMetadata != null) {
+ this.applicationMetadata.close();
+ }
+ this.applicationMetadata = msg.getApplicationMetadata();
+ if (this.applicationMetadata != null) {
+ this.applicationMetadata.getReferenceManager().retain();
+ }
+ } else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) {
+ try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) {
+ final long id = arb.getDictionaryId();
+ final Dictionary dictionary = dictionaries.lookup(id);
+ if (dictionary == null) {
+ throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id);
+ }
+
+ final FieldVector vector = dictionary.getVector();
+ final VectorSchemaRoot dictionaryRoot = new VectorSchemaRoot(Collections.singletonList(vector.getField()),
+ Collections.singletonList(vector), 0);
+ final VectorLoader dictionaryLoader = new VectorLoader(dictionaryRoot);
+ dictionaryLoader.load(arb.getDictionary());
+ }
+ return next();
+ } else {
+ throw new UnsupportedOperationException("Message type is unsupported: " + msg.getMessageType());
+ }
+ return true;
}
- return true;
}
-
} catch (Exception e) {
- throw Throwables.propagate(e);
+ throw new RuntimeException(e);
}
}
@@ -152,6 +200,17 @@ public class FlightStream {
}
}
+ /**
+ * Get the most recent metadata sent from the server. This may be cleared by calls to {@link #next()} if the server
+ * sends a message without metadata. This does NOT take ownership of the buffer - call retain() to create a reference
+ * if you need the buffer after a call to {@link #next()}.
+ *
+ * @return the application metadata. May be null.
+ */
+ public ArrowBuf getLatestMetadata() {
+ return applicationMetadata;
+ }
+
private synchronized void requestOutstanding() {
if (pending < pendingTarget) {
requestor.request(pendingTarget - pending);
@@ -169,23 +228,36 @@ public class FlightStream {
public void onNext(ArrowMessage msg) {
requestOutstanding();
switch (msg.getMessageType()) {
- case SCHEMA:
+ case SCHEMA: {
schema = msg.asSchema();
+ final List<Field> fields = new ArrayList<>();
+ final Map<Long, Dictionary> dictionaryMap = new HashMap<>();
+ for (final Field originalField : schema.getFields()) {
+ final Field updatedField = DictionaryUtility.toMemoryFormat(originalField, allocator, dictionaryMap);
+ fields.add(updatedField);
+ }
+ for (final Map.Entry<Long, Dictionary> entry : dictionaryMap.entrySet()) {
+ dictionaries.put(entry.getValue());
+ }
+ schema = new Schema(fields, schema.getCustomMetadata());
fulfilledRoot = VectorSchemaRoot.create(schema, allocator);
loader = new VectorLoader(fulfilledRoot);
descriptor = msg.getDescriptor() != null ? new FlightDescriptor(msg.getDescriptor()) : null;
root.set(fulfilledRoot);
break;
+ }
case RECORD_BATCH:
queue.add(msg);
break;
- case NONE:
case DICTIONARY_BATCH:
+ queue.add(msg);
+ break;
+ case NONE:
case TENSOR:
default:
queue.add(DONE_EX);
- ex = new UnsupportedOperationException("Unable to handle message of type: " + msg);
+ ex = new UnsupportedOperationException("Unable to handle message of type: " + msg.getMessageType());
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java b/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
index 0e6e373..eca32e1 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
@@ -17,10 +17,6 @@
package org.apache.arrow.flight;
-import java.util.concurrent.Callable;
-
-import org.apache.arrow.flight.impl.Flight.PutResult;
-
/**
* A {@link FlightProducer} that throws on all operations.
*/
@@ -45,8 +41,8 @@ public class NoOpFlightProducer implements FlightProducer {
}
@Override
- public Callable<PutResult> acceptPut(CallContext context,
- FlightStream flightStream) {
+ public Runnable acceptPut(CallContext context,
+ FlightStream flightStream, StreamListener<PutResult> ackStream) {
throw new UnsupportedOperationException("NYI");
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/GenericOperation.java b/java/flight/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java
similarity index 54%
rename from java/flight/src/main/java/org/apache/arrow/flight/GenericOperation.java
rename to java/flight/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java
index 03a1e92..e06af1a 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/GenericOperation.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java
@@ -17,26 +17,33 @@
package org.apache.arrow.flight;
+import org.apache.arrow.flight.FlightProducer.StreamListener;
+
/**
- * Unused?.
+ * A {@link StreamListener} that does nothing for all callbacks.
+ * @param <T> The type of the callback object.
*/
-class GenericOperation {
-
- private final String type;
- private final byte[] body;
+public class NoOpStreamListener<T> implements StreamListener<T> {
+ private static NoOpStreamListener INSTANCE = new NoOpStreamListener();
- public GenericOperation(String type, byte[] body) {
- super();
- this.type = type;
- this.body = body == null ? new byte[0] : body;
+ /** Ignores the value received. */
+ @Override
+ public void onNext(T val) {
}
- public String getType() {
- return type;
+ /** Ignores the error received. */
+ @Override
+ public void onError(Throwable t) {
}
- public byte[] getBody() {
- return body;
+ /** Ignores the stream completion event. */
+ @Override
+ public void onCompleted() {
}
+ @SuppressWarnings("unchecked")
+ public static <T> StreamListener<T> getInstance() {
+ // Safe because we never use T
+ return (StreamListener<T>) INSTANCE;
+ }
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/PutResult.java b/java/flight/src/main/java/org/apache/arrow/flight/PutResult.java
new file mode 100644
index 0000000..1184869
--- /dev/null
+++ b/java/flight/src/main/java/org/apache/arrow/flight/PutResult.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.ReferenceManager;
+
+import com.google.protobuf.ByteString;
+
+import io.netty.buffer.ArrowBuf;
+
+/**
+ * A message from the server during a DoPut operation.
+ *
+ * <p>This object owns an {@link ArrowBuf} and should be closed when you are done with it.
+ */
+public class PutResult implements AutoCloseable {
+
+ private ArrowBuf applicationMetadata;
+
+ private PutResult(ArrowBuf metadata) {
+ applicationMetadata = metadata;
+ }
+
+ /**
+ * Create a PutResult with application-specific metadata.
+ *
+ * <p>This method assumes ownership of the {@link ArrowBuf}.
+ */
+ public static PutResult metadata(ArrowBuf metadata) {
+ if (metadata == null) {
+ return empty();
+ }
+ return new PutResult(metadata);
+ }
+
+ /** Create an empty PutResult. */
+ public static PutResult empty() {
+ return new PutResult(null);
+ }
+
+ /**
+ * Get the metadata in this message. May be null.
+ *
+ * <p>Ownership of the {@link ArrowBuf} is retained by this object. Call {@link ReferenceManager#retain()} to preserve
+ * a reference.
+ */
+ public ArrowBuf getApplicationMetadata() {
+ return applicationMetadata;
+ }
+
+ Flight.PutResult toProtocol() {
+ if (applicationMetadata == null) {
+ return Flight.PutResult.getDefaultInstance();
+ }
+ return Flight.PutResult.newBuilder().setAppMetadata(ByteString.copyFrom(applicationMetadata.nioBuffer())).build();
+ }
+
+ /**
+ * Construct a PutResult from a Protobuf message.
+ *
+ * @param allocator The allocator to use for allocating application metadata memory. The result object owns the
+ * allocated buffer, if any.
+ * @param message The gRPC/Protobuf message.
+ */
+ static PutResult fromProtocol(BufferAllocator allocator, Flight.PutResult message) {
+ final ArrowBuf buf = allocator.buffer(message.getAppMetadata().size());
+ message.getAppMetadata().asReadOnlyByteBufferList().forEach(bb -> {
+ buf.setBytes(buf.writerIndex(), bb);
+ buf.writerIndex(buf.writerIndex() + bb.limit());
+ });
+ return new PutResult(buf);
+ }
+
+ @Override
+ public void close() {
+ if (applicationMetadata != null) {
+ applicationMetadata.close();
+ }
+ }
+}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java b/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java
new file mode 100644
index 0000000..f1246a1
--- /dev/null
+++ b/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+import io.netty.buffer.ArrowBuf;
+
+/**
+ * A listener for server-sent application metadata messages during a Flight DoPut. This class wraps the messages in a
+ * synchronous interface.
+ */
+public final class SyncPutListener implements FlightClient.PutListener, AutoCloseable {
+
+ private final LinkedBlockingQueue<Object> queue;
+ private final CompletableFuture<Void> completed;
+ private static final Object DONE = new Object();
+ private static final Object DONE_WITH_EXCEPTION = new Object();
+
+ public SyncPutListener() {
+ queue = new LinkedBlockingQueue<>();
+ completed = new CompletableFuture<>();
+ }
+
+ private PutResult unwrap(Object queueItem) throws InterruptedException, ExecutionException {
+ if (queueItem == DONE) {
+ queue.put(queueItem);
+ return null;
+ } else if (queueItem == DONE_WITH_EXCEPTION) {
+ queue.put(queueItem);
+ completed.get();
+ }
+ return (PutResult) queueItem;
+ }
+
+ /**
+ * Get the next message from the server, blocking until it is available.
+ *
+ * @return The next message, or null if the server is done sending messages. The caller assumes ownership of the
+ * metadata and must remember to close it.
+ * @throws InterruptedException if interrupted while waiting.
+ * @throws ExecutionException if the server sent an error, or if there was an internal error.
+ */
+ public PutResult read() throws InterruptedException, ExecutionException {
+ return unwrap(queue.take());
+ }
+
+ /**
+ * Get the next message from the server, blocking for the specified amount of time until it is available.
+ *
+ * @return The next message, or null if the server is done sending messages or no message arrived before the timeout.
+ * The caller assumes ownership of the metadata and must remember to close it.
+ * @throws InterruptedException if interrupted while waiting.
+ * @throws ExecutionException if the server sent an error, or if there was an internal error.
+ */
+ public PutResult poll(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException {
+ return unwrap(queue.poll(timeout, unit));
+ }
+
+ @Override
+ public void getResult() {
+ try {
+ completed.get();
+ } catch (InterruptedException | ExecutionException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void onNext(PutResult val) {
+ final ArrowBuf metadata = val.getApplicationMetadata();
+ metadata.getReferenceManager().retain();
+ queue.add(PutResult.metadata(metadata));
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ completed.completeExceptionally(t);
+ queue.add(DONE_WITH_EXCEPTION);
+ }
+
+ @Override
+ public void onCompleted() {
+ completed.complete(null);
+ queue.add(DONE);
+ }
+
+ @Override
+ public void close() {
+ queue.forEach(o -> {
+ if (o instanceof PutResult) {
+ ((PutResult) o).close();
+ }
+ });
+ }
+}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java b/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java
index 91ed04e..cf3eb15 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java
@@ -28,6 +28,7 @@ import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.pojo.Schema;
import com.google.common.base.Preconditions;
@@ -43,19 +44,22 @@ public class FlightHolder implements AutoCloseable {
private final FlightDescriptor descriptor;
private final Schema schema;
private final List<Stream> streams = new CopyOnWriteArrayList<>();
+ private final DictionaryProvider dictionaryProvider;
/**
* Creates a new instance.
- *
- * @param allocator The allocator to use for allocating buffers to store data.
+ * @param allocator The allocator to use for allocating buffers to store data.
* @param descriptor The descriptor for the streams.
* @param schema The schema for the stream.
+ * @param dictionaryProvider The dictionary provider for the stream.
*/
- public FlightHolder(BufferAllocator allocator, FlightDescriptor descriptor, Schema schema) {
+ public FlightHolder(BufferAllocator allocator, FlightDescriptor descriptor, Schema schema,
+ DictionaryProvider dictionaryProvider) {
Preconditions.checkArgument(!descriptor.isCommand());
this.allocator = allocator.newChildAllocator(descriptor.toString(), 0, Long.MAX_VALUE);
this.descriptor = descriptor;
this.schema = schema;
+ this.dictionaryProvider = dictionaryProvider;
}
/**
@@ -72,8 +76,8 @@ public class FlightHolder implements AutoCloseable {
* Adds a new streams which clients can populate via the returned object.
*/
public Stream.StreamCreator addStream(Schema schema) {
- Preconditions.checkArgument(schema.equals(schema), "Stream schema inconsistent with existing schema.");
- return new Stream.StreamCreator(schema, allocator, t -> {
+ Preconditions.checkArgument(this.schema.equals(schema), "Stream schema inconsistent with existing schema.");
+ return new Stream.StreamCreator(schema, dictionaryProvider, allocator, t -> {
synchronized (streams) {
streams.add(t);
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
index 452faa1..59324b3 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
@@ -17,7 +17,6 @@
package org.apache.arrow.flight.example;
-import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@@ -29,15 +28,14 @@ import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.PutResult;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.example.Stream.StreamCreator;
-import org.apache.arrow.flight.impl.Flight.PutResult;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
-import org.apache.arrow.vector.types.pojo.Schema;
/**
* A FlightProducer that hosts an in memory store of Arrow buffers.
@@ -80,17 +78,6 @@ public class InMemoryStore implements FlightProducer, AutoCloseable {
return h.getStream(example);
}
- /**
- * Create a new {@link Stream} with the given schema and descriptor.
- */
- public StreamCreator putStream(final FlightDescriptor descriptor, final Schema schema) {
- final FlightHolder h = holders.computeIfAbsent(
- descriptor,
- t -> new FlightHolder(allocator, t, schema));
-
- return h.addStream(schema);
- }
-
@Override
public void listFlights(CallContext context, Criteria criteria,
StreamListener<FlightInfo> listener) {
@@ -116,25 +103,25 @@ public class InMemoryStore implements FlightProducer, AutoCloseable {
}
@Override
- public Callable<PutResult> acceptPut(CallContext context,
- final FlightStream flightStream) {
+ public Runnable acceptPut(CallContext context,
+ final FlightStream flightStream, final StreamListener<PutResult> ackStream) {
return () -> {
StreamCreator creator = null;
boolean success = false;
try (VectorSchemaRoot root = flightStream.getRoot()) {
final FlightHolder h = holders.computeIfAbsent(
flightStream.getDescriptor(),
- t -> new FlightHolder(allocator, t, flightStream.getSchema()));
+ t -> new FlightHolder(allocator, t, flightStream.getSchema(), flightStream.getDictionaryProvider()));
creator = h.addStream(flightStream.getSchema());
VectorUnloader unloader = new VectorUnloader(root);
while (flightStream.next()) {
+ ackStream.onNext(PutResult.metadata(flightStream.getLatestMetadata()));
creator.add(unloader.getRecordBatch());
}
creator.complete();
success = true;
- return PutResult.getDefaultInstance();
} finally {
if (!success) {
creator.drop();
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java b/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java
index f36b38c..2d42ed2 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java
@@ -17,6 +17,7 @@
package org.apache.arrow.flight.example;
+import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
@@ -28,18 +29,22 @@ import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Schema;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
+import io.netty.buffer.ArrowBuf;
+
/**
* A collection of Arrow record batches.
*/
public class Stream implements AutoCloseable, Iterable<ArrowRecordBatch> {
private final String uuid = UUID.randomUUID().toString();
+ private final DictionaryProvider dictionaryProvider;
private final List<ArrowRecordBatch> batches;
private final Schema schema;
private final long recordCount;
@@ -53,9 +58,11 @@ public class Stream implements AutoCloseable, Iterable<ArrowRecordBatch> {
*/
public Stream(
final Schema schema,
+ final DictionaryProvider dictionaryProvider,
List<ArrowRecordBatch> batches,
long recordCount) {
this.schema = schema;
+ this.dictionaryProvider = dictionaryProvider;
this.batches = ImmutableList.copyOf(batches);
this.recordCount = recordCount;
}
@@ -82,11 +89,17 @@ public class Stream implements AutoCloseable, Iterable<ArrowRecordBatch> {
*/
public void sendTo(BufferAllocator allocator, ServerStreamListener listener) {
try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
- listener.start(root);
+ listener.start(root, dictionaryProvider);
final VectorLoader loader = new VectorLoader(root);
+ int counter = 0;
for (ArrowRecordBatch batch : batches) {
+ final byte[] rawMetadata = Integer.toString(counter).getBytes(StandardCharsets.UTF_8);
+ final ArrowBuf metadata = allocator.buffer(rawMetadata.length);
+ metadata.writeBytes(rawMetadata);
loader.load(batch);
- listener.putNext();
+ // Transfers ownership of the buffer - do not free buffer ourselves
+ listener.putNext(metadata);
+ counter++;
}
listener.completed();
} catch (Exception ex) {
@@ -118,18 +131,22 @@ public class Stream implements AutoCloseable, Iterable<ArrowRecordBatch> {
private final List<ArrowRecordBatch> batches = new ArrayList<>();
private final Consumer<Stream> committer;
private long recordCount = 0;
+ private DictionaryProvider dictionaryProvider;
/**
* Creates a new instance.
*
* @param schema The schema for batches in the stream.
+ * @param dictionaryProvider The dictionary provider for the stream.
* @param allocator The allocator used to copy data permanently into the stream.
* @param committer A callback for when the the stream is ready to be finalized (no more batches).
*/
- public StreamCreator(Schema schema, BufferAllocator allocator, Consumer<Stream> committer) {
+ public StreamCreator(Schema schema, DictionaryProvider dictionaryProvider,
+ BufferAllocator allocator, Consumer<Stream> committer) {
this.allocator = allocator;
this.committer = committer;
this.schema = schema;
+ this.dictionaryProvider = dictionaryProvider;
}
/**
@@ -152,7 +169,7 @@ public class Stream implements AutoCloseable, Iterable<ArrowRecordBatch> {
* Complete building the stream (no more batches can be added).
*/
public void complete() {
- Stream stream = new Stream(schema, batches, recordCount);
+ Stream stream = new Stream(schema, dictionaryProvider, batches, recordCount);
committer.accept(stream);
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
index ccafde0..477dfdb 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
@@ -19,15 +19,18 @@ package org.apache.arrow.flight.example.integration;
import java.io.File;
import java.io.IOException;
+import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
+import org.apache.arrow.flight.AsyncPutListener;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.PutResult;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorLoader;
@@ -41,6 +44,8 @@ import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
+import io.netty.buffer.ArrowBuf;
+
/**
* An Example Flight Server that provides access to the InMemoryStore.
*/
@@ -89,15 +94,36 @@ class IntegrationTestClient {
FlightDescriptor descriptor = FlightDescriptor.path(inputPath);
VectorSchemaRoot jsonRoot;
try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator);
- VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) {
+ VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) {
jsonRoot = VectorSchemaRoot.create(root.getSchema(), allocator);
VectorUnloader unloader = new VectorUnloader(root);
VectorLoader jsonLoader = new VectorLoader(jsonRoot);
- FlightClient.ClientStreamListener stream = client.startPut(descriptor, root);
+ FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, reader,
+ new AsyncPutListener() {
+ int counter = 0;
+
+ @Override
+ public void onNext(PutResult val) {
+ final byte[] metadataRaw = new byte[val.getApplicationMetadata().readableBytes()];
+ val.getApplicationMetadata().readBytes(metadataRaw);
+ final String metadata = new String(metadataRaw, StandardCharsets.UTF_8);
+ if (!Integer.toString(counter).equals(metadata)) {
+ throw new RuntimeException(
+ String.format("Invalid ACK from server. Expected '%d' but got '%s'.", counter, metadata));
+ }
+ counter++;
+ }
+ });
+ int counter = 0;
while (reader.read(root)) {
- stream.putNext();
+ final byte[] rawMetadata = Integer.toString(counter).getBytes(StandardCharsets.UTF_8);
+ final ArrowBuf metadata = allocator.buffer(rawMetadata.length);
+ metadata.writeBytes(rawMetadata);
+ // Transfers ownership of the buffer, so do not release it ourselves
+ stream.putNext(metadata);
jsonLoader.load(unloader.getRecordBatch());
root.clear();
+ counter++;
}
stream.completed();
// Need to call this, or exceptions from the server get swallowed
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java b/java/flight/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java
index 9591cf5..b584d96 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java
@@ -17,12 +17,15 @@
package org.apache.arrow.flight.grpc;
+import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import com.google.common.base.Throwables;
+import com.google.common.io.ByteStreams;
import io.grpc.internal.ReadableBuffer;
+import io.netty.buffer.ArrowBuf;
/**
* Enable access to ReadableBuffer directly to copy data from an BufferInputStream into a target
@@ -72,4 +75,24 @@ public class GetReadableBuffer {
}
}
+ /**
+ * Helper method to read a gRPC-provided InputStream into an ArrowBuf.
+ * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}.
+ * @param buf The buffer to read into.
+ * @param size The number of bytes to read.
+ * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link #BUFFER_INPUT_STREAM}).
+ * @throws IOException if there is an error reading form the stream
+ */
+ public static void readIntoBuffer(final InputStream stream, final ArrowBuf buf, final int size,
+ final boolean fastPath) throws IOException {
+ ReadableBuffer readableBuffer = fastPath ? getReadableBuffer(stream) : null;
+ if (readableBuffer != null) {
+ readableBuffer.readBytes(buf.nioBuffer(0, size));
+ } else {
+ byte[] heapBytes = new byte[size];
+ ByteStreams.readFully(stream, heapBytes);
+ buf.writeBytes(heapBytes);
+ }
+ buf.writerIndex(size);
+ }
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
index 3cb09ef..a10d490 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
@@ -43,14 +43,15 @@ public class FlightTestUtil {
* Returns a a FlightServer (actually anything that is startable)
* that has been started bound to a random port.
*/
- public static <T> T getStartedServer(Function<Integer, T> newServerFromPort) throws IOException {
+ public static <T> T getStartedServer(Function<Location, T> newServerFromLocation) throws IOException {
IOException lastThrown = null;
T server = null;
for (int x = 0; x < 3; x++) {
final int port = 49152 + RANDOM.nextInt(5000);
+ final Location location = Location.forGrpcInsecure(LOCALHOST, port);
lastThrown = null;
try {
- server = newServerFromPort.apply(port);
+ server = newServerFromLocation.apply(location);
try {
server.getClass().getMethod("start").invoke(server);
} catch (NoSuchMethodException | IllegalAccessException e) {
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java b/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
new file mode 100644
index 0000000..ad2c58f
--- /dev/null
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Collections;
+
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
+
+import org.apache.arrow.flight.FlightClient.PutListener;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import io.grpc.Status;
+import io.netty.buffer.ArrowBuf;
+
+/**
+ * Tests for application-specific metadata support in Flight.
+ */
+public class TestApplicationMetadata {
+
+ /**
+ * Ensure that a client can read the metadata sent from the server.
+ */
+ @Test
+ // This test is consistently flaky on CI, unfortunately.
+ @Ignore
+ public void retrieveMetadata() {
+ test((allocator, client) -> {
+ try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
+ byte i = 0;
+ while (stream.next()) {
+ final IntVector vector = (IntVector) stream.getRoot().getVector("a");
+ Assert.assertEquals(1, vector.getValueCount());
+ Assert.assertEquals(10, vector.get(0));
+ Assert.assertEquals(i, stream.getLatestMetadata().getByte(0));
+ i++;
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /**
+ * Ensure that a client can send metadata to the server.
+ */
+ @Test
+ @Ignore
+ public void uploadMetadataAsync() {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ test((allocator, client) -> {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ final FlightDescriptor descriptor = FlightDescriptor.path("test");
+
+ final PutListener listener = new AsyncPutListener() {
+ int counter = 0;
+
+ @Override
+ public void onNext(PutResult val) {
+ Assert.assertNotNull(val);
+ Assert.assertEquals(counter, val.getApplicationMetadata().getByte(0));
+ counter++;
+ }
+ };
+ final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener);
+
+ root.allocateNew();
+ for (byte i = 0; i < 10; i++) {
+ final IntVector vector = (IntVector) root.getVector("a");
+ final ArrowBuf metadata = allocator.buffer(1);
+ metadata.writeByte(i);
+ vector.set(0, 10);
+ vector.setValueCount(1);
+ root.setRowCount(1);
+ writer.putNext(metadata);
+ }
+ writer.completed();
+ // Must attempt to retrieve the result to get any server-side errors.
+ writer.getResult();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /**
+ * Ensure that a client can send metadata to the server. Uses the synchronous API.
+ */
+ @Test
+ @Ignore
+ public void uploadMetadataSync() {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ test((allocator, client) -> {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ final SyncPutListener listener = new SyncPutListener()) {
+ final FlightDescriptor descriptor = FlightDescriptor.path("test");
+ final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener);
+
+ root.allocateNew();
+ for (byte i = 0; i < 10; i++) {
+ final IntVector vector = (IntVector) root.getVector("a");
+ final ArrowBuf metadata = allocator.buffer(1);
+ metadata.writeByte(i);
+ vector.set(0, 10);
+ vector.setValueCount(1);
+ root.setRowCount(1);
+ writer.putNext(metadata);
+ try (final PutResult message = listener.poll(5000, TimeUnit.SECONDS)) {
+ Assert.assertNotNull(message);
+ Assert.assertEquals(i, message.getApplicationMetadata().getByte(0));
+ } catch (InterruptedException | ExecutionException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ writer.completed();
+ // Must attempt to retrieve the result to get any server-side errors.
+ writer.getResult();
+ }
+ });
+ }
+
+ /**
+ * Make sure that a {@link SyncPutListener} properly reclaims memory if ignored.
+ */
+ @Test
+ @Ignore
+ public void syncMemoryReclaimed() {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ test((allocator, client) -> {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ final SyncPutListener listener = new SyncPutListener()) {
+ final FlightDescriptor descriptor = FlightDescriptor.path("test");
+ final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener);
+
+ root.allocateNew();
+ for (byte i = 0; i < 10; i++) {
+ final IntVector vector = (IntVector) root.getVector("a");
+ final ArrowBuf metadata = allocator.buffer(1);
+ metadata.writeByte(i);
+ vector.set(0, 10);
+ vector.setValueCount(1);
+ root.setRowCount(1);
+ writer.putNext(metadata);
+ }
+ writer.completed();
+ // Must attempt to retrieve the result to get any server-side errors.
+ writer.getResult();
+ }
+ });
+ }
+
+ private void test(BiConsumer<BufferAllocator, FlightClient> fun) {
+ try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ final FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (location) -> FlightServer.builder(allocator, location, new MetadataFlightProducer(allocator)).build());
+ final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
+ fun.accept(allocator, client);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * A FlightProducer that always produces a fixed data stream with metadata on the side.
+ */
+ private static class MetadataFlightProducer extends NoOpFlightProducer {
+
+ private final BufferAllocator allocator;
+
+ public MetadataFlightProducer(BufferAllocator allocator) {
+ this.allocator = allocator;
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ root.allocateNew();
+ listener.start(root);
+ for (byte i = 0; i < 10; i++) {
+ final IntVector vector = (IntVector) root.getVector("a");
+ vector.set(0, 10);
+ vector.setValueCount(1);
+ root.setRowCount(1);
+ final ArrowBuf metadata = allocator.buffer(1);
+ metadata.writeByte(i);
+ listener.putNext(metadata);
+ }
+ listener.completed();
+ }
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context, FlightStream stream, StreamListener<PutResult> ackStream) {
+ return () -> {
+ try {
+ byte current = 0;
+ while (stream.next()) {
+ final ArrowBuf metadata = stream.getLatestMetadata();
+ if (current != metadata.getByte(0)) {
+ ackStream.onError(Status.INVALID_ARGUMENT.withDescription(String
+ .format("Metadata does not match expected value; got %d but expected %d.", metadata.getByte(0),
+ current)).asRuntimeException());
+ return;
+ }
+ ackStream.onNext(PutResult.metadata(metadata));
+ current++;
+ }
+ if (current != 10) {
+ throw new IllegalArgumentException("Wrong number of messages sent.");
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ };
+ }
+ }
+}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java
index 1b40e7e..d0e26e1 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java
@@ -46,29 +46,27 @@ public class TestBackPressure {
try (
final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
final PerformanceTestServer server = FlightTestUtil.getStartedServer(
- (port) -> (new PerformanceTestServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port))));
+ (location) -> (new PerformanceTestServer(a, location)));
final FlightClient client = FlightClient.builder(a, server.getLocation()).build()
) {
- FlightStream fs1 = client.getStream(client.getInfo(
+ try (FlightStream fs1 = client.getStream(client.getInfo(
TestPerf.getPerfFlightDescriptor(110L * BATCH_SIZE, BATCH_SIZE, 1))
- .getEndpoints().get(0).getTicket());
- consume(fs1, 10);
+ .getEndpoints().get(0).getTicket())) {
+ consume(fs1, 10);
- // stop consuming fs1 but make sure we can consume a large amount of fs2.
- FlightStream fs2 = client.getStream(client.getInfo(
- TestPerf.getPerfFlightDescriptor(200L * BATCH_SIZE, BATCH_SIZE, 1))
- .getEndpoints().get(0).getTicket());
- consume(fs2, 100);
+ // stop consuming fs1 but make sure we can consume a large amount of fs2.
+ try (FlightStream fs2 = client.getStream(client.getInfo(
+ TestPerf.getPerfFlightDescriptor(200L * BATCH_SIZE, BATCH_SIZE, 1))
+ .getEndpoints().get(0).getTicket())) {
+ consume(fs2, 100);
- consume(fs1, 100);
- consume(fs2, 100);
-
- consume(fs1);
- consume(fs2);
-
- fs1.close();
- fs2.close();
+ consume(fs1, 100);
+ consume(fs2, 100);
+ consume(fs1);
+ consume(fs2);
+ }
+ }
}
}
@@ -92,27 +90,28 @@ public class TestBackPressure {
ServerStreamListener listener) {
int batches = 0;
final Schema pojoSchema = new Schema(ImmutableList.of(Field.nullable("a", MinorType.BIGINT.getType())));
- VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator);
- listener.start(root);
- while (true) {
- while (!listener.isReady()) {
- try {
- Thread.sleep(1);
- sleepTime.addAndGet(1L);
- } catch (InterruptedException ignore) {
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator)) {
+ listener.start(root);
+ while (true) {
+ while (!listener.isReady()) {
+ try {
+ Thread.sleep(1);
+ sleepTime.addAndGet(1L);
+ } catch (InterruptedException ignore) {
+ }
}
- }
- if (batches > 100) {
- root.clear();
- listener.completed();
- return;
- }
+ if (batches > 100) {
+ root.clear();
+ listener.completed();
+ return;
+ }
- root.allocateNew();
- root.setRowCount(4095);
- listener.putNext();
- batches++;
+ root.allocateNew();
+ root.setRowCount(4095);
+ listener.putNext();
+ batches++;
+ }
}
}
};
@@ -121,16 +120,15 @@ public class TestBackPressure {
try (
BufferAllocator serverAllocator = allocator.newChildAllocator("server", 0, Long.MAX_VALUE);
FlightServer server =
- FlightTestUtil.getStartedServer(
- (port) -> FlightServer.builder(serverAllocator, Location.forGrpcInsecure("localhost", port), producer)
- .build());
+ FlightTestUtil.getStartedServer((location) -> FlightServer.builder(serverAllocator, location, producer)
+ .build());
BufferAllocator clientAllocator = allocator.newChildAllocator("client", 0, Long.MAX_VALUE);
FlightClient client =
FlightClient
- .builder(clientAllocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()))
- .build()
+ .builder(clientAllocator, server.getLocation())
+ .build();
+ FlightStream stream = client.getStream(new Ticket(new byte[1]))
) {
- FlightStream stream = client.getStream(new Ticket(new byte[1]));
VectorSchemaRoot root = stream.getRoot();
root.clear();
Thread.sleep(wait);
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
index f8413b0..abc5a2c 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
@@ -19,14 +19,12 @@ package org.apache.arrow.flight;
import java.net.URISyntaxException;
import java.util.Iterator;
-import java.util.concurrent.Callable;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import org.apache.arrow.flight.FlightClient.ClientStreamListener;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.Flight.FlightDescriptor.DescriptorType;
-import org.apache.arrow.flight.impl.Flight.PutResult;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
@@ -98,7 +96,8 @@ public class TestBasicOperation {
IntVector iv = new IntVector("c1", a);
VectorSchemaRoot root = VectorSchemaRoot.of(iv);
- ClientStreamListener listener = c.startPut(FlightDescriptor.path("hello"), root);
+ ClientStreamListener listener = c
+ .startPut(FlightDescriptor.path("hello"), root, new AsyncPutListener());
//batch 1
root.allocateNew();
@@ -155,12 +154,11 @@ public class TestBasicOperation {
Producer producer = new Producer(a);
FlightServer s =
FlightTestUtil.getStartedServer(
- (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), producer).build()
+ (location) -> FlightServer.builder(a, location, producer).build()
)) {
try (
- FlightClient c = FlightClient.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))
- .build()
+ FlightClient c = FlightClient.builder(a, s.getLocation()).build()
) {
try (BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE)) {
consumer.accept(c, testAllocator);
@@ -199,14 +197,12 @@ public class TestBasicOperation {
}
@Override
- public Callable<PutResult> acceptPut(CallContext context,
- FlightStream flightStream) {
+ public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
return () -> {
try (VectorSchemaRoot root = flightStream.getRoot()) {
while (flightStream.next()) {
}
- return PutResult.getDefaultInstance();
}
};
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
index 71d9986..3acb947 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
@@ -69,11 +69,8 @@ public class TestCallOptions {
BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
Producer producer = new Producer(a);
FlightServer s =
- FlightTestUtil.getStartedServer(
- (port) -> FlightServer.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port), producer)
- .build());
- FlightClient client = FlightClient.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort()))
- .build()) {
+ FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build());
+ FlightClient client = FlightClient.builder(a, s.getLocation()).build()) {
testFn.accept(client);
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
index 9913548..629b6f5 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
@@ -19,10 +19,8 @@ package org.apache.arrow.flight;
import java.util.Arrays;
import java.util.List;
-import java.util.concurrent.Callable;
import java.util.stream.Stream;
-import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
@@ -43,13 +41,11 @@ public class TestLargeMessage {
try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
final Producer producer = new Producer(a);
final FlightServer s =
- FlightTestUtil.getStartedServer(
- (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), producer).build())) {
+ FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build())) {
- try (FlightClient client = FlightClient
- .builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())).build()) {
- FlightStream stream = client.getStream(new Ticket(new byte[]{}));
- try (VectorSchemaRoot root = stream.getRoot()) {
+ try (FlightClient client = FlightClient.builder(a, s.getLocation()).build()) {
+ try (FlightStream stream = client.getStream(new Ticket(new byte[]{}));
+ VectorSchemaRoot root = stream.getRoot()) {
while (stream.next()) {
for (final Field field : root.getSchema().getFields()) {
int value = 0;
@@ -61,7 +57,6 @@ public class TestLargeMessage {
}
}
}
- stream.close();
}
}
}
@@ -74,18 +69,17 @@ public class TestLargeMessage {
try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
final Producer producer = new Producer(a);
final FlightServer s =
- FlightTestUtil.getStartedServer(
- (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), producer).build()
+ FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build()
)) {
- try (FlightClient client = FlightClient
- .builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())).build();
+ try (FlightClient client = FlightClient.builder(a, s.getLocation()).build();
BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE);
VectorSchemaRoot root = generateData(testAllocator)) {
- final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root);
+ final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root,
+ new AsyncPutListener());
listener.putNext();
listener.completed();
- Assert.assertEquals(listener.getResult(), Flight.PutResult.getDefaultInstance());
+ listener.getResult();
}
}
}
@@ -141,14 +135,12 @@ public class TestLargeMessage {
}
@Override
- public Callable<Flight.PutResult> acceptPut(CallContext context,
- FlightStream flightStream) {
+ public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
return () -> {
try (VectorSchemaRoot root = flightStream.getRoot()) {
while (flightStream.next()) {
;
}
- return Flight.PutResult.getDefaultInstance();
}
};
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
index c22304d..b9d4dea 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
@@ -96,9 +96,9 @@ public class TestTls {
Producer producer = new Producer();
FlightServer s =
FlightTestUtil.getStartedServer(
- (port) -> {
+ (location) -> {
try {
- return FlightServer.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, port), producer)
+ return FlightServer.builder(a, location, producer)
.useTls(certKey.cert, certKey.key)
.build();
} catch (IOException e) {
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
index 39b2924..54bbadb 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
@@ -29,7 +29,6 @@ import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.FlightTestUtil;
-import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
@@ -119,9 +118,9 @@ public class TestAuth {
}
};
- server = FlightTestUtil.getStartedServer((port) -> FlightServer.builder(
+ server = FlightTestUtil.getStartedServer((location) -> FlightServer.builder(
allocator,
- Location.forGrpcInsecure("localhost", port),
+ location,
new NoOpFlightProducer() {
@Override
public void listFlights(CallContext context, Criteria criteria,
@@ -150,8 +149,7 @@ public class TestAuth {
listener.completed();
}
}).authHandler(new BasicServerAuthHandler(validator)).build());
- client = FlightClient.builder(allocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()))
- .build();
+ client = FlightClient.builder(allocator, server.getLocation()).build();
}
@After
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
index 097c92c..fb157f4 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
@@ -19,6 +19,7 @@ package org.apache.arrow.flight.example;
import java.io.IOException;
+import org.apache.arrow.flight.AsyncPutListener;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClient.ClientStreamListener;
import org.apache.arrow.flight.FlightDescriptor;
@@ -33,12 +34,12 @@ import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.junit.After;
import org.junit.Before;
+import org.junit.Ignore;
import org.junit.Test;
/**
* Ensure that example server supports get and put.
*/
-@org.junit.Ignore
public class TestExampleServer {
private BufferAllocator allocator;
@@ -68,6 +69,7 @@ public class TestExampleServer {
}
@Test
+ @Ignore
public void putStream() {
BufferAllocator a = caseAllocator;
final int size = 10;
@@ -75,7 +77,8 @@ public class TestExampleServer {
IntVector iv = new IntVector("c1", a);
VectorSchemaRoot root = VectorSchemaRoot.of(iv);
- ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root);
+ ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root,
+ new AsyncPutListener());
//batch 1
root.allocateNew();
@@ -102,10 +105,13 @@ public class TestExampleServer {
listener.getResult();
FlightInfo info = client.getInfo(FlightDescriptor.path("hello"));
- FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket());
- VectorSchemaRoot newRoot = stream.getRoot();
- while (stream.next()) {
- newRoot.clear();
+ try (final FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket())) {
+ VectorSchemaRoot newRoot = stream.getRoot();
+ while (stream.next()) {
+ newRoot.clear();
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
}
}
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
index d8d6e67..72099b9 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
@@ -21,21 +21,14 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
-import java.util.concurrent.Callable;
-import org.apache.arrow.flight.Action;
-import org.apache.arrow.flight.ActionType;
-import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
-import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
-import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
-import org.apache.arrow.flight.Result;
+import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.Ticket;
-import org.apache.arrow.flight.impl.Flight.PutResult;
import org.apache.arrow.flight.perf.impl.PerfOuterClass.Perf;
import org.apache.arrow.flight.perf.impl.PerfOuterClass.Token;
import org.apache.arrow.memory.BufferAllocator;
@@ -79,7 +72,7 @@ public class PerformanceTestServer implements AutoCloseable {
AutoCloseables.close(flightServer, allocator);
}
- private final class PerfProducer implements FlightProducer {
+ private final class PerfProducer extends NoOpFlightProducer {
@Override
public void getStream(CallContext context, Ticket ticket,
@@ -146,11 +139,6 @@ public class PerformanceTestServer implements AutoCloseable {
}
@Override
- public void listFlights(CallContext context, Criteria criteria,
- StreamListener<FlightInfo> listener) {
- }
-
- @Override
public FlightInfo getFlightInfo(CallContext context,
FlightDescriptor descriptor) {
try {
@@ -181,24 +169,6 @@ public class PerformanceTestServer implements AutoCloseable {
throw new RuntimeException(e);
}
}
-
- @Override
- public Callable<PutResult> acceptPut(CallContext context,
- FlightStream flightStream) {
- return null;
- }
-
- @Override
- public void doAction(CallContext context, Action action,
- StreamListener<Result> listener) {
- listener.onCompleted();
- }
-
- @Override
- public void listActions(CallContext context,
- StreamListener<ActionType> listener) {
- }
-
}
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
index a9b9d60..c23c793 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
@@ -28,7 +28,6 @@ import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.FlightTestUtil;
-import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.perf.impl.PerfOuterClass.Perf;
import org.apache.arrow.memory.BufferAllocator;
@@ -81,8 +80,7 @@ public class TestPerf {
try (
final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
final PerformanceTestServer server =
- FlightTestUtil.getStartedServer((port) -> new PerformanceTestServer(a,
- Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port)));
+ FlightTestUtil.getStartedServer((location) -> new PerformanceTestServer(a, location));
final FlightClient client = FlightClient.builder(a, server.getLocation()).build();
) {
final FlightInfo info = client.getInfo(getPerfFlightDescriptor(50_000_000L, 4095, 2));
@@ -93,11 +91,13 @@ public class TestPerf {
.map(t -> pool.submit(t))
.collect(Collectors.toList());
- Futures.whenAllSucceed(results);
- Result r = new Result();
- for (ListenableFuture<Result> f : results) {
- r.add(f.get());
- }
+ final Result r = Futures.whenAllSucceed(results).call(() -> {
+ Result res = new Result();
+ for (ListenableFuture<Result> f : results) {
+ res.add(f.get());
+ }
+ return res;
+ }).get();
double seconds = r.nanos * 1.0d / 1000 / 1000 / 1000;
System.out.println(String.format(
@@ -127,28 +127,29 @@ public class TestPerf {
public Result call() throws Exception {
final Result r = new Result();
Stopwatch watch = Stopwatch.createStarted();
- FlightStream stream = client.getStream(ticket);
- final VectorSchemaRoot root = stream.getRoot();
- try {
- BigIntVector a = (BigIntVector) root.getVector("a");
- while (stream.next()) {
- int rows = root.getRowCount();
- long aSum = r.aSum;
- for (int i = 0; i < rows; i++) {
- if (VALIDATE) {
- aSum += a.get(i);
+ try (final FlightStream stream = client.getStream(ticket)) {
+ final VectorSchemaRoot root = stream.getRoot();
+ try {
+ BigIntVector a = (BigIntVector) root.getVector("a");
+ while (stream.next()) {
+ int rows = root.getRowCount();
+ long aSum = r.aSum;
+ for (int i = 0; i < rows; i++) {
+ if (VALIDATE) {
+ aSum += a.get(i);
+ }
}
+ r.bytes += rows * 32;
+ r.rows += rows;
+ r.aSum = aSum;
+ r.batches++;
}
- r.bytes += rows * 32;
- r.rows += rows;
- r.aSum = aSum;
- r.batches++;
- }
- r.nanos = watch.elapsed(TimeUnit.NANOSECONDS);
- return r;
- } finally {
- root.clear();
+ r.nanos = watch.elapsed(TimeUnit.NANOSECONDS);
+ return r;
+ } finally {
+ root.clear();
+ }
}
}
diff --git a/java/pom.xml b/java/pom.xml
index 540b41b..916b7f1 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -33,7 +33,7 @@
<dep.junit.jupiter.version>5.4.0</dep.junit.jupiter.version>
<dep.slf4j.version>1.7.25</dep.slf4j.version>
<dep.guava.version>20.0</dep.guava.version>
- <dep.netty.version>4.1.22.Final</dep.netty.version>
+ <dep.netty.version>4.1.27.Final</dep.netty.version>
<dep.jackson.version>2.9.8</dep.jackson.version>
<dep.hadoop.version>2.7.1</dep.hadoop.version>
<dep.fbs.version>1.9.0</dep.fbs.version>
diff --git a/python/examples/flight/server.py b/python/examples/flight/server.py
index 72ed590..3b69972 100644
--- a/python/examples/flight/server.py
+++ b/python/examples/flight/server.py
@@ -77,7 +77,7 @@ class FlightServer(pyarrow.flight.FlightServerBase):
return None
return pyarrow.flight.RecordBatchStream(self.flights[key])
- def list_actions(self):
+ def list_actions(self, context):
return [
("clear", "Clear the stored flights."),
("shutdown", "Shut down this server."),
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 7ca83a9..7fc4ed4 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -44,6 +44,15 @@ cdef class FlightCallOptions:
CFlightCallOptions options
def __init__(self, timeout=None):
+ """Create call options.
+
+ Parameters
+ ----------
+ timeout : float or None
+ A timeout for the call, in seconds. None means that the
+ timeout defaults to an implementation-specific value.
+
+ """
if timeout is not None:
self.options.timeout = CTimeoutDuration(timeout)
@@ -70,14 +79,24 @@ cdef class Action:
CAction action
def __init__(self, action_type, buf):
+ """Create an action from a type and a buffer.
+
+ Parameters
+ ----------
+ action_type : bytes or str
+ buf : Buffer or bytes-like object
+ """
self.action.type = tobytes(action_type)
self.action.body = pyarrow_unwrap_buffer(as_buffer(buf))
@property
def type(self):
+ """The action type."""
return frombytes(self.action.type)
+ @property
def body(self):
+ """The action body (arguments for the action)."""
return pyarrow_wrap_buffer(self.action.body)
@staticmethod
@@ -92,10 +111,16 @@ _ActionType = collections.namedtuple('_ActionType', ['type', 'description'])
class ActionType(_ActionType):
- """A type of action executable on a Flight service."""
+ """A type of action that is executable on a Flight service."""
def make_action(self, buf):
- """Create an Action with this type."""
+ """Create an Action with this type.
+
+ Parameters
+ ----------
+ buf : obj
+ An Arrow buffer or Python bytes or bytes-like object.
+ """
return Action(self.type, buf)
@@ -105,6 +130,12 @@ cdef class Result:
unique_ptr[CResult] result
def __init__(self, buf):
+ """Create a new result.
+
+ Parameters
+ ----------
+ buf : Buffer or bytes-like object
+ """
self.result.reset(new CResult())
self.result.get().body = pyarrow_unwrap_buffer(as_buffer(buf))
@@ -115,6 +146,23 @@ cdef class Result:
class DescriptorType(enum.Enum):
+ """
+ The type of a FlightDescriptor.
+
+ Attributes
+ ----------
+
+ UNKNOWN
+ An unknown descriptor type.
+
+ PATH
+ A Flight stream represented by a path.
+
+ CMD
+ A Flight stream represented by an application-defined command.
+
+ """
+
UNKNOWN = 0
PATH = 1
CMD = 2
@@ -151,6 +199,7 @@ cdef class FlightDescriptor:
@property
def descriptor_type(self):
+ """Get the type of this descriptor."""
if self.descriptor.type == CDescriptorTypeUnknown:
return DescriptorType.UNKNOWN
elif self.descriptor.type == CDescriptorTypePath:
@@ -309,6 +358,7 @@ cdef class FlightEndpoint:
@property
def ticket(self):
+ """Get the ticket in this endpoint."""
return Ticket(self.endpoint.ticket.ticket)
@property
@@ -400,12 +450,149 @@ cdef class FlightInfo:
return result
-cdef class FlightRecordBatchReader(_CRecordBatchReader, _ReadPandasOption):
+cdef class FlightStreamChunk:
+ """A RecordBatch with application metadata on the side."""
+ cdef:
+ CFlightStreamChunk chunk
+
+ @property
+ def data(self):
+ if self.chunk.data == NULL:
+ return None
+ return pyarrow_wrap_batch(self.chunk.data)
+
+ @property
+ def app_metadata(self):
+ if self.chunk.app_metadata == NULL:
+ return None
+ return pyarrow_wrap_buffer(self.chunk.app_metadata)
+
+ def __iter__(self):
+ return iter((self.data, self.app_metadata))
+
+
+cdef class _MetadataRecordBatchReader:
+ """A reader for Flight streams."""
+
+ # Needs to be separate class so the "real" class can subclass the
+ # pure-Python mixin class
+
cdef dict __dict__
+ cdef shared_ptr[CMetadataRecordBatchReader] reader
+
+ cdef readonly:
+ Schema schema
+
+
+cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader,
+ _ReadPandasOption):
+ """A reader for Flight streams."""
+
+ def __iter__(self):
+ while True:
+ yield self.read_chunk()
+
+ def read_all(self):
+ """Read the entire contents of the stream as a Table."""
+ cdef:
+ shared_ptr[CTable] c_table
+ with nogil:
+ check_status(self.reader.get().ReadAll(&c_table))
+ return pyarrow_wrap_table(c_table)
+
+ def read_chunk(self):
+ """Read the next RecordBatch along with any metadata.
+
+ Returns
+ -------
+ data : RecordBatch
+ The next RecordBatch in the stream.
+ app_metadata : Buffer or None
+ Application-specific metadata for the batch as defined by
+ Flight.
+
+ Raises
+ ------
+ StopIteration
+ when the stream is finished
+ """
+ cdef:
+ FlightStreamChunk chunk = FlightStreamChunk()
+
+ with nogil:
+ check_status(self.reader.get().Next(&chunk.chunk))
+
+ if chunk.chunk.data == NULL:
+ raise StopIteration
+
+ return chunk
+
+
+cdef class FlightStreamReader(MetadataRecordBatchReader):
+ """A reader that can also be canceled."""
+
+ def cancel(self):
+ """Cancel the read operation."""
+ with nogil:
+ (<CFlightStreamReader*> self.reader.get()).Cancel()
-cdef class FlightRecordBatchWriter(_CRecordBatchWriter):
- pass
+cdef class FlightStreamWriter(_CRecordBatchWriter):
+ """A RecordBatchWriter that also allows writing application metadata."""
+
+ def write_with_metadata(self, RecordBatch batch, buf):
+ """Write a RecordBatch along with Flight metadata.
+
+ Parameters
+ ----------
+ batch : RecordBatch
+ The next RecordBatch in the stream.
+ buf : Buffer
+ Application-specific metadata for the batch as defined by
+ Flight.
+ """
+ cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf))
+ with nogil:
+ check_status(
+ (<CFlightStreamWriter*> self.writer.get())
+ .WriteWithMetadata(deref(batch.batch),
+ c_buf,
+ 1))
+
+
+cdef class FlightMetadataReader:
+ """A reader for Flight metadata messages sent during a DoPut."""
+
+ cdef:
+ unique_ptr[CFlightMetadataReader] reader
+
+ def read(self):
+ """Read the next metadata message."""
+ cdef shared_ptr[CBuffer] buf
+ with nogil:
+ check_status(self.reader.get().ReadMetadata(&buf))
+ if buf == NULL:
+ return None
+ return pyarrow_wrap_buffer(buf)
+
+
+cdef class FlightMetadataWriter:
+ """A sender for Flight metadata messages during a DoPut."""
+
+ cdef:
+ unique_ptr[CFlightMetadataWriter] writer
+
+ def write(self, message):
+ """Write the next metadata message.
+
+ Parameters
+ ----------
+ message : Buffer
+ """
+ cdef shared_ptr[CBuffer] buf = \
+ pyarrow_unwrap_buffer(as_buffer(message))
+ with nogil:
+ check_status(self.writer.get().WriteMetadata(deref(buf)))
cdef class FlightClient:
@@ -451,7 +638,15 @@ cdef class FlightClient:
return result
def authenticate(self, auth_handler, options: FlightCallOptions = None):
- """Authenticate to the server."""
+ """Authenticate to the server.
+
+ Parameters
+ ----------
+ auth_handler : ClientAuthHandler
+ The authentication mechanism to use.
+ options : FlightCallOptions
+ Options for this call.
+ """
cdef:
unique_ptr[CClientAuthHandler] handler
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
@@ -539,34 +734,53 @@ cdef class FlightClient:
return result
def do_get(self, ticket: Ticket, options: FlightCallOptions = None):
- """Request the data for a flight."""
+ """Request the data for a flight.
+
+ Returns
+ -------
+ reader : FlightStreamReader
+ """
cdef:
- unique_ptr[CRecordBatchReader] reader
+ unique_ptr[CFlightStreamReader] reader
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
with nogil:
- check_status(self.client.get().DoGet(
- deref(c_options), ticket.ticket, &reader))
- result = FlightRecordBatchReader()
+ check_status(
+ self.client.get().DoGet(
+ deref(c_options), ticket.ticket, &reader))
+ result = FlightStreamReader()
result.reader.reset(reader.release())
+ result.schema = pyarrow_wrap_schema(result.reader.get().schema())
return result
def do_put(self, descriptor: FlightDescriptor, schema: Schema,
options: FlightCallOptions = None):
- """Upload data to a flight."""
+ """Upload data to a flight.
+
+ Returns
+ -------
+ writer : FlightStreamWriter
+ reader : FlightMetadataReader
+ """
cdef:
shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
- unique_ptr[CRecordBatchWriter] writer
+ unique_ptr[CFlightStreamWriter] writer
+ unique_ptr[CFlightMetadataReader] metadata_reader
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
CFlightDescriptor c_descriptor = \
FlightDescriptor.unwrap(descriptor)
+ FlightMetadataReader reader = FlightMetadataReader()
with nogil:
check_status(self.client.get().DoPut(
- deref(c_options), c_descriptor, c_schema, &writer))
- result = FlightRecordBatchWriter()
+ deref(c_options),
+ c_descriptor,
+ c_schema,
+ &writer,
+ &reader.reader))
+ result = FlightStreamWriter()
result.writer.reset(writer.release())
- return result
+ return result, reader
cdef class FlightDataStream:
@@ -809,11 +1023,22 @@ cdef void _data_stream_next(void* self, CFlightPayload* payload) except *:
payload.ipc_message.metadata.reset(<CBuffer*> nullptr)
return
+ if isinstance(result, (list, tuple)):
+ result, metadata = result
+ else:
+ result, metadata = result, None
+
if isinstance(result, (Table, _CRecordBatchReader)):
+ if metadata:
+ raise ValueError("Can only return metadata alongside a "
+ "RecordBatch.")
result = RecordBatchStream(result)
stream_schema = pyarrow_wrap_schema(stream.schema)
if isinstance(result, FlightDataStream):
+ if metadata:
+ raise ValueError("Can only return metadata alongside a "
+ "RecordBatch.")
data_stream = unique_ptr[CFlightDataStream](
(<FlightDataStream> result).to_stream())
substream_schema = pyarrow_wrap_schema(data_stream.get().schema())
@@ -838,6 +1063,8 @@ cdef void _data_stream_next(void* self, CFlightPayload* payload) except *:
deref(batch.batch),
c_default_memory_pool(),
&payload.ipc_message))
+ if metadata:
+ payload.app_metadata = pyarrow_unwrap_buffer(as_buffer(metadata))
else:
raise TypeError("GeneratorStream must be initialized with "
"an iterator of FlightDataStream, Table, "
@@ -880,17 +1107,22 @@ cdef void _get_flight_info(void* self, const CServerCallContext& context,
cdef void _do_put(void* self, const CServerCallContext& context,
- unique_ptr[CFlightMessageReader] reader) except *:
+ unique_ptr[CFlightMessageReader] reader,
+ unique_ptr[CFlightMetadataWriter] writer) except *:
"""Callback for implementing Flight servers in Python."""
cdef:
- FlightRecordBatchReader py_reader = FlightRecordBatchReader()
+ MetadataRecordBatchReader py_reader = MetadataRecordBatchReader()
+ FlightMetadataWriter py_writer = FlightMetadataWriter()
FlightDescriptor descriptor = \
FlightDescriptor.__new__(FlightDescriptor)
descriptor.descriptor = reader.get().descriptor()
py_reader.reader.reset(reader.release())
+ py_reader.schema = pyarrow_wrap_schema(
+ py_reader.reader.get().schema())
+ py_writer.writer.reset(writer.release())
(<object> self).do_put(ServerCallContext.wrap(context), descriptor,
- py_reader)
+ py_reader, py_writer)
cdef void _do_get(void* self, const CServerCallContext& context,
@@ -943,7 +1175,7 @@ cdef void _list_actions(void* self, const CServerCallContext& context,
cdef:
CActionType action_type
# Method should return a list of ActionTypes or similar tuple
- result = (<object> self).list_actions()
+ result = (<object> self).list_actions(ServerCallContext.wrap(context))
for action in result:
action_type.type = tobytes(action[0])
action_type.description = tobytes(action[1])
@@ -990,10 +1222,25 @@ cdef void _get_token(void* self, c_string* token) except *:
cdef class ServerAuthHandler:
- """Authentication middleware for a server."""
+ """Authentication middleware for a server.
+
+ To implement an authentication mechanism, subclass this class and
+ override its methods.
+
+ """
def authenticate(self, outgoing, incoming):
- """Conduct the handshake with the client."""
+ """Conduct the handshake with the client.
+
+ May raise an error if the client cannot authenticate.
+
+ Parameters
+ ----------
+ outgoing : ServerAuthSender
+ A channel to send messages to the client.
+ incoming : ServerAuthReader
+ A channel to read messages from the client.
+ """
raise NotImplementedError
def is_valid(self, token):
@@ -1003,6 +1250,11 @@ cdef class ServerAuthHandler:
name the peer) or raise an exception (if the token is
invalid).
+ Parameters
+ ----------
+ token : bytes
+ The authentication token from the client.
+
"""
raise NotImplementedError
@@ -1017,7 +1269,15 @@ cdef class ClientAuthHandler:
"""Authentication plugin for a client."""
def authenticate(self, outgoing, incoming):
- """Conduct the handshake with the server."""
+ """Conduct the handshake with the server.
+
+ Parameters
+ ----------
+ outgoing : ClientAuthSender
+ A channel to send messages to the server.
+ incoming : ClientAuthReader
+ A channel to read messages from the server.
+ """
raise NotImplementedError
def get_token(self):
@@ -1032,12 +1292,26 @@ cdef class ClientAuthHandler:
cdef class FlightServerBase:
- """A Flight service definition."""
+ """A Flight service definition.
+
+ Override methods to define your Flight service.
+
+ """
cdef:
unique_ptr[PyFlightServer] server
def run(self, location, auth_handler=None, tls_certificates=None):
+ """Start this server.
+
+ Parameters
+ ----------
+ location : Location
+ auth_handler : ServerAuthHandler
+ An authentication mechanism to use. May be None.
+ tls_certificates : list
+ A list of (certificate, key) pairs.
+ """
cdef:
PyFlightServerVtable vtable = PyFlightServerVtable()
PyFlightServer* c_server
@@ -1078,7 +1352,8 @@ cdef class FlightServerBase:
def get_flight_info(self, context, descriptor):
raise NotImplementedError
- def do_put(self, context, descriptor, reader):
+ def do_put(self, context, descriptor, reader,
+ writer: FlightMetadataWriter):
raise NotImplementedError
def do_get(self, context, ticket):
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 61e9571..49a5153 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -112,22 +112,50 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CSimpleFlightListing" arrow::flight::SimpleFlightListing":
CSimpleFlightListing(vector[CFlightInfo]&& info)
- cdef cppclass CFlightMessageReader \
- " arrow::flight::FlightMessageReader"(CRecordBatchReader):
- CFlightDescriptor& descriptor()
-
cdef cppclass CFlightPayload" arrow::flight::FlightPayload":
shared_ptr[CBuffer] descriptor
+ shared_ptr[CBuffer] app_metadata
CIpcPayload ipc_message
cdef cppclass CFlightDataStream" arrow::flight::FlightDataStream":
shared_ptr[CSchema] schema()
CStatus Next(CFlightPayload*)
+ cdef cppclass CFlightStreamChunk" arrow::flight::FlightStreamChunk":
+ CFlightStreamChunk()
+ shared_ptr[CRecordBatch] data
+ shared_ptr[CBuffer] app_metadata
+
+ cdef cppclass CMetadataRecordBatchReader \
+ " arrow::flight::MetadataRecordBatchReader":
+ shared_ptr[CSchema] schema()
+ CStatus Next(CFlightStreamChunk* out)
+ CStatus ReadAll(shared_ptr[CTable]* table)
+
+ cdef cppclass CFlightStreamReader \
+ " arrow::flight::FlightStreamReader"(CMetadataRecordBatchReader):
+ void Cancel()
+
+ cdef cppclass CFlightMessageReader \
+ " arrow::flight::FlightMessageReader"(CMetadataRecordBatchReader):
+ CFlightDescriptor& descriptor()
+
+ cdef cppclass CFlightStreamWriter \
+ " arrow::flight::FlightStreamWriter"(CRecordBatchWriter):
+ CStatus WriteWithMetadata(const CRecordBatch& batch,
+ shared_ptr[CBuffer] app_metadata,
+ c_bool allow_64bit)
+
cdef cppclass CRecordBatchStream \
" arrow::flight::RecordBatchStream"(CFlightDataStream):
CRecordBatchStream(shared_ptr[CRecordBatchReader]& reader)
+ cdef cppclass CFlightMetadataReader" arrow::flight::FlightMetadataReader":
+ CStatus ReadMetadata(shared_ptr[CBuffer]* out)
+
+ cdef cppclass CFlightMetadataWriter" arrow::flight::FlightMetadataWriter":
+ CStatus WriteMetadata(const CBuffer& message)
+
cdef cppclass CServerAuthReader" arrow::flight::ServerAuthReader":
CStatus Read(c_string* token)
@@ -193,11 +221,12 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
unique_ptr[CFlightInfo]* info)
CStatus DoGet(CFlightCallOptions& options, CTicket& ticket,
- unique_ptr[CRecordBatchReader]* stream)
+ unique_ptr[CFlightStreamReader]* stream)
CStatus DoPut(CFlightCallOptions& options,
CFlightDescriptor& descriptor,
shared_ptr[CSchema]& schema,
- unique_ptr[CRecordBatchWriter]* stream)
+ unique_ptr[CFlightStreamWriter]* stream,
+ unique_ptr[CFlightMetadataReader]* reader)
# Callbacks for implementing Flight servers
@@ -209,7 +238,8 @@ ctypedef void cb_get_flight_info(object, const CServerCallContext&,
const CFlightDescriptor&,
unique_ptr[CFlightInfo]*)
ctypedef void cb_do_put(object, const CServerCallContext&,
- unique_ptr[CFlightMessageReader])
+ unique_ptr[CFlightMessageReader],
+ unique_ptr[CFlightMetadataWriter])
ctypedef void cb_do_get(object, const CServerCallContext&,
const CTicket&,
unique_ptr[CFlightDataStream]*)
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 3088a7a..3f83a1c 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -20,13 +20,13 @@ import base64
import contextlib
import os
import socket
+import struct
import tempfile
import threading
import time
import traceback
import pytest
-
import pyarrow as pa
from pyarrow.compat import tobytes
@@ -114,17 +114,57 @@ class ConstantFlightServer(flight.FlightServerBase):
return flight.RecordBatchStream(table)
+class MetadataFlightServer(flight.FlightServerBase):
+ """A Flight server that numbers incoming/outgoing data."""
+
+ def do_get(self, context, ticket):
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+ return flight.GeneratorStream(
+ table.schema,
+ self.number_batches(table))
+
+ def do_put(self, context, descriptor, reader, writer):
+ counter = 0
+ expected_data = [-10, -5, 0, 5, 10]
+ while True:
+ try:
+ batch, buf = reader.read_chunk()
+ assert batch.equals(pa.RecordBatch.from_arrays(
+ [pa.array([expected_data[counter]])],
+ ['a']
+ ))
+ assert buf is not None
+ client_counter, = struct.unpack('<i', buf.to_pybytes())
+ assert counter == client_counter
+ writer.write(struct.pack('<i', counter))
+ counter += 1
+ except StopIteration:
+ return
+
+ @staticmethod
+ def number_batches(table):
+ for idx, batch in enumerate(table.to_batches()):
+ buf = struct.pack('<i', idx)
+ yield batch, buf
+
+
class EchoFlightServer(flight.FlightServerBase):
"""A Flight server that returns the last data uploaded."""
- def __init__(self):
+ def __init__(self, expected_schema=None):
super(EchoFlightServer, self).__init__()
self.last_message = None
+ self.expected_schema = expected_schema
def do_get(self, context, ticket):
return flight.RecordBatchStream(self.last_message)
- def do_put(self, context, descriptor, reader):
+ def do_put(self, context, descriptor, reader, writer):
+ if self.expected_schema:
+ assert self.expected_schema == reader.schema
self.last_message = reader.read_all()
@@ -200,10 +240,23 @@ class InvalidStreamFlightServer(flight.FlightServerBase):
class SlowFlightServer(flight.FlightServerBase):
"""A Flight server that delays its responses to test timeouts."""
+ def do_get(self, context, ticket):
+ return flight.GeneratorStream(pa.schema([('a', pa.int32())]),
+ self.slow_stream())
+
def do_action(self, context, action):
time.sleep(0.5)
return iter([])
+ @staticmethod
+ def slow_stream():
+ data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
+ yield pa.Table.from_arrays(data1, names=['a'])
+ # The second message should never get sent; the client should
+ # cancel before we send this
+ time.sleep(10)
+ yield pa.Table.from_arrays(data1, names=['a'])
+
class HttpBasicServerAuthHandler(flight.ServerAuthHandler):
"""An example implementation of HTTP basic authentication."""
@@ -338,7 +391,7 @@ def flight_server(server_base, *args, **kwargs):
yield location
finally:
server_instance.shutdown()
- thread.join()
+ thread.join(3.0)
def test_flight_do_get_ints():
@@ -351,6 +404,17 @@ def test_flight_do_get_ints():
assert data.equals(table)
+@pytest.mark.pandas
+def test_do_get_ints_pandas():
+ """Try a simple do_get call."""
+ table = simple_ints_table()
+
+ with flight_server(ConstantFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
+ data = client.do_get(flight.Ticket(b'ints')).read_pandas()
+ assert list(data['some_ints']) == table.column(0).to_pylist()
+
+
def test_flight_do_get_dicts():
table = simple_dicts_table()
@@ -392,15 +456,23 @@ def test_flight_get_info():
reason="Unix sockets can't be tested on Windows")
def test_flight_domain_socket():
"""Try a simple do_get call over a Unix domain socket."""
- table = simple_ints_table()
-
with tempfile.NamedTemporaryFile() as sock:
sock.close()
location = flight.Location.for_grpc_unix(sock.name)
with flight_server(ConstantFlightServer,
location=location) as server_location:
client = flight.FlightClient.connect(server_location)
- data = client.do_get(flight.Ticket(b'ints')).read_all()
+
+ reader = client.do_get(flight.Ticket(b'ints'))
+ table = simple_ints_table()
+ assert reader.schema.equals(table.schema)
+ data = reader.read_all()
+ assert data.equals(table)
+
+ reader = client.do_get(flight.Ticket(b'dicts'))
+ table = simple_dicts_table()
+ assert reader.schema.equals(table.schema)
+ data = reader.read_all()
assert data.equals(table)
@@ -415,10 +487,11 @@ def test_flight_large_message():
pa.array(range(0, 10 * 1024 * 1024))
], names=['a'])
- with flight_server(EchoFlightServer) as server_location:
+ with flight_server(EchoFlightServer,
+ expected_schema=data.schema) as server_location:
client = flight.FlightClient.connect(server_location)
- writer = client.do_put(flight.FlightDescriptor.for_path('test'),
- data.schema)
+ writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+ data.schema)
# Write a single giant chunk
writer.write_table(data, 10 * 1024 * 1024)
writer.close()
@@ -434,8 +507,8 @@ def test_flight_generator_stream():
with flight_server(EchoStreamFlightServer) as server_location:
client = flight.FlightClient.connect(server_location)
- writer = client.do_put(flight.FlightDescriptor.for_path('test'),
- data.schema)
+ writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+ data.schema)
writer.write_table(data)
writer.close()
result = client.do_get(flight.Ticket(b'')).read_all()
@@ -458,7 +531,9 @@ def test_timeout_fires():
client = flight.FlightClient.connect(server_location)
action = flight.Action("", b"")
options = flight.FlightCallOptions(timeout=0.2)
- with pytest.raises(pa.ArrowIOError, match="Deadline Exceeded"):
+ # gRPC error messages change based on version, so don't look
+ # for a particular error
+ with pytest.raises(pa.ArrowIOError):
list(client.do_action(action, options=options))
@@ -479,6 +554,7 @@ token_auth_handler = TokenServerAuthHandler(creds={
})
+@pytest.mark.slow
def test_http_basic_unauth():
"""Test that auth fails when not authenticated."""
with flight_server(EchoStreamFlightServer,
@@ -553,7 +629,9 @@ def test_tls_fails():
# Ensure client doesn't connect when certificate verification
# fails (this is a slow test since gRPC does retry a few times)
client = flight.FlightClient.connect(server_location)
- with pytest.raises(pa.ArrowIOError, match="Connect Failed"):
+ # gRPC error messages change based on version, so don't look
+ # for a particular error
+ with pytest.raises(pa.ArrowIOError):
client.do_get(flight.Ticket(b'ints'))
@@ -585,3 +663,94 @@ def test_tls_override_hostname():
override_hostname="fakehostname")
with pytest.raises(pa.ArrowIOError):
client.do_get(flight.Ticket(b'ints'))
+
+
+def test_flight_do_get_metadata():
+ """Try a simple do_get call with metadata."""
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+
+ batches = []
+ with flight_server(MetadataFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
+ reader = client.do_get(flight.Ticket(b''))
+ idx = 0
+ while True:
+ try:
+ batch, metadata = reader.read_chunk()
+ batches.append(batch)
+ server_idx, = struct.unpack('<i', metadata.to_pybytes())
+ assert idx == server_idx
+ idx += 1
+ except StopIteration:
+ break
+ data = pa.Table.from_batches(batches)
+ assert data.equals(table)
+
+
+def test_flight_do_put_metadata():
+ """Try a simple do_put call with metadata."""
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+
+ with flight_server(MetadataFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
+ writer, metadata_reader = client.do_put(
+ flight.FlightDescriptor.for_path(''),
+ table.schema)
+ with writer:
+ for idx, batch in enumerate(table.to_batches(chunksize=1)):
+ metadata = struct.pack('<i', idx)
+ writer.write_with_metadata(batch, metadata)
+ buf = metadata_reader.read()
+ assert buf is not None
+ server_idx, = struct.unpack('<i', buf.to_pybytes())
+ assert idx == server_idx
+
+
+@pytest.mark.slow
+def test_cancel_do_get():
+ """Test canceling a DoGet operation on the client side."""
+ with flight_server(ConstantFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
+ reader = client.do_get(flight.Ticket(b'ints'))
+ reader.cancel()
+ with pytest.raises(pa.ArrowIOError, match=".*Cancel.*"):
+ reader.read_chunk()
+
+
+@pytest.mark.slow
+def test_cancel_do_get_threaded():
+ """Test canceling a DoGet operation from another thread."""
+ with flight_server(SlowFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
+ reader = client.do_get(flight.Ticket(b'ints'))
+
+ read_first_message = threading.Event()
+ stream_canceled = threading.Event()
+ result_lock = threading.Lock()
+ raised_proper_exception = threading.Event()
+
+ def block_read():
+ reader.read_chunk()
+ read_first_message.set()
+ stream_canceled.wait(timeout=5)
+ try:
+ reader.read_chunk()
+ except pa.ArrowIOError:
+ with result_lock:
+ raised_proper_exception.set()
+
+ thread = threading.Thread(target=block_read, daemon=True)
+ thread.start()
+ read_first_message.wait(timeout=5)
+ reader.cancel()
+ stream_canceled.set()
+ thread.join(timeout=1)
+
+ with result_lock:
+ assert raised_proper_exception.is_set()