You are viewing a plain-text version of this content. The canonical link for it is "here" (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@arrow.apache.org by li...@apache.org on 2022/08/18 16:04:50 UTC
[arrow] branch master updated: ARROW-17254: [C++][Go][Java][FlightRPC] Implement and test Flight SQL GetSchema (#13898)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new bc52f9f0e5 ARROW-17254: [C++][Go][Java][FlightRPC] Implement and test Flight SQL GetSchema (#13898)
bc52f9f0e5 is described below
commit bc52f9f0e582474501e92e6a281f0110754a8af1
Author: David Li <li...@gmail.com>
AuthorDate: Thu Aug 18 12:04:42 2022 -0400
ARROW-17254: [C++][Go][Java][FlightRPC] Implement and test Flight SQL GetSchema (#13898)
Consistently implements and tests the GetSchema method in Flight SQL.
Builds on #13897.
Authored-by: David Li <li...@gmail.com>
Signed-off-by: David Li <li...@gmail.com>
---
.../arrow/flight/integration_tests/CMakeLists.txt | 11 +
.../integration_tests/flight_integration_test.cc | 60 ++++
.../flight/integration_tests/test_integration.cc | 351 +++++++++++++--------
cpp/src/arrow/flight/sql/client.cc | 208 ++++++++----
cpp/src/arrow/flight/sql/client.h | 83 ++++-
cpp/src/arrow/flight/sql/server.cc | 79 +++++
cpp/src/arrow/flight/sql/server.h | 23 ++
cpp/src/arrow/flight/types.cc | 5 +-
cpp/src/arrow/flight/types.h | 2 +-
cpp/src/arrow/python/flight.cc | 5 +-
go/arrow/flight/flightsql/client.go | 88 ++++++
go/arrow/flight/flightsql/server.go | 60 ++++
go/arrow/internal/flight_integration/scenario.go | 131 +++++++-
.../integration/tests/FlightSqlScenario.java | 36 ++-
.../tests/FlightSqlScenarioProducer.java | 9 +
.../flight/integration/tests/IntegrationTest.java | 65 ++++
.../apache/arrow/flight/sql/FlightSqlClient.java | 135 ++++++++
.../apache/arrow/flight/sql/FlightSqlProducer.java | 45 ++-
.../apache/arrow/flight/sql/FlightSqlUtils.java | 2 +-
19 files changed, 1173 insertions(+), 225 deletions(-)
diff --git a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt
index 66a021b4b5..1bbd923160 100644
--- a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt
+++ b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt
@@ -40,3 +40,14 @@ target_link_libraries(flight-test-integration-client
add_dependencies(arrow-integration flight-test-integration-client
flight-test-integration-server)
+
+if(ARROW_BUILD_TESTS)
+ add_arrow_test(flight_integration_test
+ SOURCES
+ flight_integration_test.cc
+ test_integration.cc
+ STATIC_LINK_LIBS
+ ${ARROW_FLIGHT_INTEGRATION_TEST_LINK_LIBS}
+ LABELS
+ "arrow_flight")
+endif()
diff --git a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
new file mode 100644
index 0000000000..706ac3b7d9
--- /dev/null
+++ b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Run the integration test scenarios in-process.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/flight/integration_tests/test_integration.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace flight {
+namespace integration_tests {
+
+Status RunScenario(const std::string& scenario_name) {
+ std::shared_ptr<Scenario> scenario;
+ ARROW_RETURN_NOT_OK(GetScenario(scenario_name, &scenario));
+
+ std::unique_ptr<FlightServerBase> server;
+ ARROW_ASSIGN_OR_RAISE(Location bind_location,
+ arrow::flight::Location::ForGrpcTcp("0.0.0.0", 0));
+ FlightServerOptions server_options(bind_location);
+ ARROW_RETURN_NOT_OK(scenario->MakeServer(&server, &server_options));
+ ARROW_RETURN_NOT_OK(server->Init(server_options));
+
+ ARROW_ASSIGN_OR_RAISE(Location location,
+ arrow::flight::Location::ForGrpcTcp("0.0.0.0", server->port()));
+ auto client_options = arrow::flight::FlightClientOptions::Defaults();
+ ARROW_RETURN_NOT_OK(scenario->MakeClient(&client_options));
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<FlightClient> client,
+ FlightClient::Connect(location, client_options));
+ ARROW_RETURN_NOT_OK(scenario->RunClient(std::move(client)));
+ return Status::OK();
+}
+
+TEST(FlightIntegration, AuthBasicProto) { ASSERT_OK(RunScenario("auth:basic_proto")); }
+
+TEST(FlightIntegration, Middleware) { ASSERT_OK(RunScenario("middleware")); }
+
+TEST(FlightIntegration, FlightSql) { ASSERT_OK(RunScenario("flight_sql")); }
+
+} // namespace integration_tests
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc
index 7bdd27da79..b228f9cceb 100644
--- a/cpp/src/arrow/flight/integration_tests/test_integration.cc
+++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc
@@ -117,7 +117,7 @@ class AuthBasicProtoScenario : public Scenario {
/// regardless of what gRPC does.
class TestServerMiddleware : public ServerMiddleware {
public:
- explicit TestServerMiddleware(std::string received) : received_(received) {}
+ explicit TestServerMiddleware(std::string received) : received_(std::move(received)) {}
void SendingHeaders(AddCallHeaders* outgoing_headers) override {
outgoing_headers->AddHeader("x-middleware", received_);
@@ -154,11 +154,11 @@ class TestClientMiddleware : public ClientMiddleware {
explicit TestClientMiddleware(std::string* received_header)
: received_header_(received_header) {}
- void SendingHeaders(AddCallHeaders* outgoing_headers) {
+ void SendingHeaders(AddCallHeaders* outgoing_headers) override {
outgoing_headers->AddHeader("x-middleware", "expected value");
}
- void ReceivedHeaders(const CallHeaders& incoming_headers) {
+ void ReceivedHeaders(const CallHeaders& incoming_headers) override {
// We expect the server to always send this header. gRPC/Java may
// send it in trailers instead of headers, so we expect Flight to
// account for this.
@@ -170,7 +170,7 @@ class TestClientMiddleware : public ClientMiddleware {
}
}
- void CallCompleted(const Status& status) {}
+ void CallCompleted(const Status& status) override {}
private:
std::string* received_header_;
@@ -178,7 +178,8 @@ class TestClientMiddleware : public ClientMiddleware {
class TestClientMiddlewareFactory : public ClientMiddlewareFactory {
public:
- void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) {
+ void StartCall(const CallInfo& info,
+ std::unique_ptr<ClientMiddleware>* middleware) override {
*middleware =
std::unique_ptr<ClientMiddleware>(new TestClientMiddleware(&received_header_));
}
@@ -218,8 +219,8 @@ class MiddlewareServer : public FlightServerBase {
class MiddlewareScenario : public Scenario {
Status MakeServer(std::unique_ptr<FlightServerBase>* server,
FlightServerOptions* options) override {
- options->middleware.push_back(
- {"grpc_trailers", std::make_shared<TestServerMiddlewareFactory>()});
+ options->middleware.emplace_back("grpc_trailers",
+ std::make_shared<TestServerMiddlewareFactory>());
server->reset(new MiddlewareServer());
return Status::OK();
}
@@ -284,11 +285,13 @@ std::shared_ptr<Schema> GetQuerySchema() {
constexpr int64_t kUpdateStatementExpectedRows = 10000L;
constexpr int64_t kUpdatePreparedStatementExpectedRows = 20000L;
+constexpr char kSelectStatement[] = "SELECT STATEMENT";
template <typename T>
-arrow::Status AssertEq(const T& expected, const T& actual) {
+arrow::Status AssertEq(const T& expected, const T& actual, const std::string& message) {
if (expected != actual) {
- return Status::Invalid("Expected \"", expected, "\", got \'", actual, "\"");
+ return Status::Invalid(message, ": expected \"", expected, "\", got \"", actual,
+ "\"");
}
return Status::OK();
}
@@ -301,7 +304,9 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoStatement(
const ServerCallContext& context, const sql::StatementQuery& command,
const FlightDescriptor& descriptor) override {
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("SELECT STATEMENT", command.query));
+ ARROW_RETURN_NOT_OK(
+ AssertEq<std::string>(kSelectStatement, command.query,
+ "Unexpected statement in GetFlightInfoStatement"));
ARROW_ASSIGN_OR_RAISE(auto handle,
sql::CreateStatementQueryTicket("SELECT STATEMENT HANDLE"));
@@ -313,6 +318,14 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
return std::unique_ptr<FlightInfo>(new FlightInfo(result));
}
+ arrow::Result<std::unique_ptr<SchemaResult>> GetSchemaStatement(
+ const ServerCallContext& context, const sql::StatementQuery& command,
+ const FlightDescriptor& descriptor) override {
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>(
+ kSelectStatement, command.query, "Unexpected statement in GetSchemaStatement"));
+ return SchemaResult::Make(*GetQuerySchema());
+ }
+
arrow::Result<std::unique_ptr<FlightDataStream>> DoGetStatement(
const ServerCallContext& context,
const sql::StatementQueryTicket& command) override {
@@ -323,11 +336,21 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
const ServerCallContext& context, const sql::PreparedStatementQuery& command,
const FlightDescriptor& descriptor) override {
ARROW_RETURN_NOT_OK(AssertEq<std::string>("SELECT PREPARED STATEMENT HANDLE",
- command.prepared_statement_handle));
+ command.prepared_statement_handle,
+ "Unexpected prepared statement handle"));
return GetFlightInfoForCommand(descriptor, GetQuerySchema());
}
+ arrow::Result<std::unique_ptr<SchemaResult>> GetSchemaPreparedStatement(
+ const ServerCallContext& context, const sql::PreparedStatementQuery& command,
+ const FlightDescriptor& descriptor) override {
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>("SELECT PREPARED STATEMENT HANDLE",
+ command.prepared_statement_handle,
+ "Unexpected prepared statement handle"));
+ return SchemaResult::Make(*GetQuerySchema());
+ }
+
arrow::Result<std::unique_ptr<FlightDataStream>> DoGetPreparedStatement(
const ServerCallContext& context,
const sql::PreparedStatementQuery& command) override {
@@ -358,11 +381,14 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoSqlInfo(
const ServerCallContext& context, const sql::GetSqlInfo& command,
const FlightDescriptor& descriptor) override {
- ARROW_RETURN_NOT_OK(AssertEq<int64_t>(2, command.info.size()));
- ARROW_RETURN_NOT_OK(AssertEq<int32_t>(
- sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, command.info[0]));
- ARROW_RETURN_NOT_OK(AssertEq<int32_t>(
- sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, command.info[1]));
+ ARROW_RETURN_NOT_OK(AssertEq<int64_t>(2, command.info.size(),
+ "Wrong number of SqlInfo values passed"));
+ ARROW_RETURN_NOT_OK(
+ AssertEq<int32_t>(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME,
+ command.info[0], "Unexpected SqlInfo passed"));
+ ARROW_RETURN_NOT_OK(
+ AssertEq<int32_t>(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY,
+ command.info[1], "Unexpected SqlInfo passed"));
return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema());
}
@@ -375,9 +401,11 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoSchemas(
const ServerCallContext& context, const sql::GetDbSchemas& command,
const FlightDescriptor& descriptor) override {
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("catalog", command.catalog.value()));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>("catalog", command.catalog.value(),
+ "Wrong catalog passed"));
ARROW_RETURN_NOT_OK(AssertEq<std::string>("db_schema_filter_pattern",
- command.db_schema_filter_pattern.value()));
+ command.db_schema_filter_pattern.value(),
+ "Wrong db_schema_filter_pattern passed"));
return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetDbSchemasSchema());
}
@@ -390,15 +418,22 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoTables(
const ServerCallContext& context, const sql::GetTables& command,
const FlightDescriptor& descriptor) override {
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("catalog", command.catalog.value()));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>("catalog", command.catalog.value(),
+ "Wrong catalog passed"));
ARROW_RETURN_NOT_OK(AssertEq<std::string>("db_schema_filter_pattern",
- command.db_schema_filter_pattern.value()));
+ command.db_schema_filter_pattern.value(),
+ "Wrong db_schema_filter_pattern passed"));
ARROW_RETURN_NOT_OK(AssertEq<std::string>("table_filter_pattern",
- command.table_name_filter_pattern.value()));
- ARROW_RETURN_NOT_OK(AssertEq<int64_t>(2, command.table_types.size()));
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("table", command.table_types[0]));
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("view", command.table_types[1]));
- ARROW_RETURN_NOT_OK(AssertEq<bool>(true, command.include_schema));
+ command.table_name_filter_pattern.value(),
+ "Wrong table_filter_pattern passed"));
+ ARROW_RETURN_NOT_OK(AssertEq<int64_t>(2, command.table_types.size(),
+ "Wrong number of table types passed"));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>("table", command.table_types[0],
+ "Wrong table type passed"));
+ ARROW_RETURN_NOT_OK(
+ AssertEq<std::string>("view", command.table_types[1], "Wrong table type passed"));
+ ARROW_RETURN_NOT_OK(
+ AssertEq<bool>(true, command.include_schema, "include_schema should be true"));
return GetFlightInfoForCommand(descriptor,
sql::SqlSchema::GetTablesSchemaWithIncludedSchema());
@@ -422,11 +457,12 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoPrimaryKeys(
const ServerCallContext& context, const sql::GetPrimaryKeys& command,
const FlightDescriptor& descriptor) override {
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>(
+ "catalog", command.table_ref.catalog.value(), "Wrong catalog passed"));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>(
+ "db_schema", command.table_ref.db_schema.value(), "Wrong db_schema passed"));
ARROW_RETURN_NOT_OK(
- AssertEq<std::string>("catalog", command.table_ref.catalog.value()));
- ARROW_RETURN_NOT_OK(
- AssertEq<std::string>("db_schema", command.table_ref.db_schema.value()));
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("table", command.table_ref.table));
+ AssertEq<std::string>("table", command.table_ref.table, "Wrong table passed"));
return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetPrimaryKeysSchema());
}
@@ -439,11 +475,12 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoExportedKeys(
const ServerCallContext& context, const sql::GetExportedKeys& command,
const FlightDescriptor& descriptor) override {
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>(
+ "catalog", command.table_ref.catalog.value(), "Wrong catalog passed"));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>(
+ "db_schema", command.table_ref.db_schema.value(), "Wrong db_schema passed"));
ARROW_RETURN_NOT_OK(
- AssertEq<std::string>("catalog", command.table_ref.catalog.value()));
- ARROW_RETURN_NOT_OK(
- AssertEq<std::string>("db_schema", command.table_ref.db_schema.value()));
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("table", command.table_ref.table));
+ AssertEq<std::string>("table", command.table_ref.table, "Wrong table passed"));
return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetExportedKeysSchema());
}
@@ -456,11 +493,12 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoImportedKeys(
const ServerCallContext& context, const sql::GetImportedKeys& command,
const FlightDescriptor& descriptor) override {
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>(
+ "catalog", command.table_ref.catalog.value(), "Wrong catalog passed"));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>(
+ "db_schema", command.table_ref.db_schema.value(), "Wrong db_schema passed"));
ARROW_RETURN_NOT_OK(
- AssertEq<std::string>("catalog", command.table_ref.catalog.value()));
- ARROW_RETURN_NOT_OK(
- AssertEq<std::string>("db_schema", command.table_ref.db_schema.value()));
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("table", command.table_ref.table));
+ AssertEq<std::string>("table", command.table_ref.table, "Wrong table passed"));
return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetImportedKeysSchema());
}
@@ -473,16 +511,20 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoCrossReference(
const ServerCallContext& context, const sql::GetCrossReference& command,
const FlightDescriptor& descriptor) override {
- ARROW_RETURN_NOT_OK(
- AssertEq<std::string>("pk_catalog", command.pk_table_ref.catalog.value()));
- ARROW_RETURN_NOT_OK(
- AssertEq<std::string>("pk_db_schema", command.pk_table_ref.db_schema.value()));
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("pk_table", command.pk_table_ref.table));
- ARROW_RETURN_NOT_OK(
- AssertEq<std::string>("fk_catalog", command.fk_table_ref.catalog.value()));
- ARROW_RETURN_NOT_OK(
- AssertEq<std::string>("fk_db_schema", command.fk_table_ref.db_schema.value()));
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("fk_table", command.fk_table_ref.table));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>(
+ "pk_catalog", command.pk_table_ref.catalog.value(), "Wrong pk catalog passed"));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>("pk_db_schema",
+ command.pk_table_ref.db_schema.value(),
+ "Wrong pk db_schema passed"));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>("pk_table", command.pk_table_ref.table,
+ "Wrong pk table passed"));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>(
+ "fk_catalog", command.fk_table_ref.catalog.value(), "Wrong fk catalog passed"));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>("fk_db_schema",
+ command.fk_table_ref.db_schema.value(),
+ "Wrong fk db_schema passed"));
+ ARROW_RETURN_NOT_OK(AssertEq<std::string>("fk_table", command.fk_table_ref.table,
+ "Wrong fk table passed"));
return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetTableTypesSchema());
}
@@ -494,7 +536,9 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<int64_t> DoPutCommandStatementUpdate(
const ServerCallContext& context, const sql::StatementUpdate& command) override {
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("UPDATE STATEMENT", command.query));
+ ARROW_RETURN_NOT_OK(
+ AssertEq<std::string>("UPDATE STATEMENT", command.query,
+ "Wrong query for DoPutCommandStatementUpdate"));
return kUpdateStatementExpectedRows;
}
@@ -502,9 +546,10 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<sql::ActionCreatePreparedStatementResult> CreatePreparedStatement(
const ServerCallContext& context,
const sql::ActionCreatePreparedStatementRequest& request) override {
- ARROW_RETURN_NOT_OK(
- AssertEq<bool>(true, request.query == "SELECT PREPARED STATEMENT" ||
- request.query == "UPDATE PREPARED STATEMENT"));
+ if (request.query != "SELECT PREPARED STATEMENT" &&
+ request.query != "UPDATE PREPARED STATEMENT") {
+ return Status::Invalid("Unexpected query: ", request.query);
+ }
sql::ActionCreatePreparedStatementResult result;
result.prepared_statement_handle = request.query + " HANDLE";
@@ -515,6 +560,11 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
Status ClosePreparedStatement(
const ServerCallContext& context,
const sql::ActionClosePreparedStatementRequest& request) override {
+ if (request.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE" &&
+ request.prepared_statement_handle != "UPDATE PREPARED STATEMENT HANDLE") {
+ return Status::Invalid("Invalid handle for ClosePreparedStatement: ",
+ request.prepared_statement_handle);
+ }
return Status::OK();
}
@@ -522,11 +572,14 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
const sql::PreparedStatementQuery& command,
FlightMessageReader* reader,
FlightMetadataWriter* writer) override {
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("SELECT PREPARED STATEMENT HANDLE",
- command.prepared_statement_handle));
+ if (command.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE") {
+ return Status::Invalid("Invalid handle for DoPutPreparedStatementQuery: ",
+ command.prepared_statement_handle);
+ }
ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema());
- ARROW_RETURN_NOT_OK(AssertEq<Schema>(*GetQuerySchema(), *actual_schema));
+ ARROW_RETURN_NOT_OK(AssertEq<Schema>(*GetQuerySchema(), *actual_schema,
+ "Wrong schema for DoPutPreparedStatementQuery"));
return Status::OK();
}
@@ -534,10 +587,11 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<int64_t> DoPutPreparedStatementUpdate(
const ServerCallContext& context, const sql::PreparedStatementUpdate& command,
FlightMessageReader* reader) override {
- ARROW_RETURN_NOT_OK(AssertEq<std::string>("UPDATE PREPARED STATEMENT HANDLE",
- command.prepared_statement_handle));
-
- return kUpdatePreparedStatementExpectedRows;
+ if (command.prepared_statement_handle == "UPDATE PREPARED STATEMENT HANDLE") {
+ return kUpdatePreparedStatementExpectedRows;
+ }
+ return Status::Invalid("Invalid handle for DoPutPreparedStatementUpdate: ",
+ command.prepared_statement_handle);
}
private:
@@ -569,19 +623,27 @@ class FlightSqlScenario : public Scenario {
Status MakeClient(FlightClientOptions* options) override { return Status::OK(); }
- Status Validate(std::shared_ptr<Schema> expected_schema,
- arrow::Result<std::unique_ptr<FlightInfo>> flight_info_result,
- sql::FlightSqlClient* sql_client) {
+ Status Validate(const std::shared_ptr<Schema>& expected_schema,
+ const FlightInfo& flight_info, sql::FlightSqlClient* sql_client) {
FlightCallOptions call_options;
-
- ARROW_ASSIGN_OR_RAISE(auto flight_info, flight_info_result);
ARROW_ASSIGN_OR_RAISE(
- auto reader, sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket));
-
+ std::unique_ptr<FlightStreamReader> reader,
+ sql_client->DoGet(call_options, flight_info.endpoints()[0].ticket));
ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema());
+ if (!expected_schema->Equals(*actual_schema, /*check_metadata=*/true)) {
+ return Status::Invalid("Schemas did not match. Expected:\n", *expected_schema,
+ "\nActual:\n", *actual_schema);
+ }
+ ARROW_RETURN_NOT_OK(reader->ToTable());
+ return Status::OK();
+ }
- if (!actual_schema->Equals(*expected_schema, /*check_metadata=*/true)) {
- return Status::Invalid("Schemas do not match. Expected:\n", *expected_schema,
+ Status ValidateSchema(const std::shared_ptr<Schema>& expected_schema,
+ const SchemaResult& result) {
+ ipc::DictionaryMemo memo;
+ ARROW_ASSIGN_OR_RAISE(auto actual_schema, result.GetSchema(&memo));
+ if (!expected_schema->Equals(*actual_schema, /*check_metadata=*/true)) {
+ return Status::Invalid("Schemas did not match. Expected:\n", *expected_schema,
"\nActual:\n", *actual_schema);
}
return Status::OK();
@@ -589,13 +651,9 @@ class FlightSqlScenario : public Scenario {
Status RunClient(std::unique_ptr<FlightClient> client) override {
sql::FlightSqlClient sql_client(std::move(client));
-
ARROW_RETURN_NOT_OK(ValidateMetadataRetrieval(&sql_client));
-
ARROW_RETURN_NOT_OK(ValidateStatementExecution(&sql_client));
-
ARROW_RETURN_NOT_OK(ValidatePreparedStatementExecution(&sql_client));
-
return Status::OK();
}
@@ -613,82 +671,119 @@ class FlightSqlScenario : public Scenario {
sql::TableRef pk_table_ref = {"pk_catalog", "pk_db_schema", "pk_table"};
sql::TableRef fk_table_ref = {"fk_catalog", "fk_db_schema", "fk_table"};
- ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetCatalogsSchema(),
- sql_client->GetCatalogs(options), sql_client));
+ std::unique_ptr<FlightInfo> info;
+ std::unique_ptr<SchemaResult> schema;
+
+ ARROW_ASSIGN_OR_RAISE(info, sql_client->GetCatalogs(options));
+ ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetCatalogsSchema(options));
+ ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetCatalogsSchema(), *info, sql_client));
+ ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetCatalogsSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(
+ info, sql_client->GetDbSchemas(options, &catalog, &db_schema_filter_pattern));
+ ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetDbSchemasSchema(options));
+ ARROW_RETURN_NOT_OK(
+ Validate(sql::SqlSchema::GetDbSchemasSchema(), *info, sql_client));
+ ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetDbSchemasSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(
+ info, sql_client->GetTables(options, &catalog, &db_schema_filter_pattern,
+ &table_filter_pattern, true, &table_types));
+ ARROW_ASSIGN_OR_RAISE(schema,
+ sql_client->GetTablesSchema(options, /*include_schema=*/true));
+ ARROW_RETURN_NOT_OK(
+ Validate(sql::SqlSchema::GetTablesSchemaWithIncludedSchema(), *info, sql_client));
+ ARROW_RETURN_NOT_OK(
+ ValidateSchema(sql::SqlSchema::GetTablesSchemaWithIncludedSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(schema,
+ sql_client->GetTablesSchema(options, /*include_schema=*/false));
+ ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetTablesSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(info, sql_client->GetTableTypes(options));
+ ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetTableTypesSchema(options));
+ ARROW_RETURN_NOT_OK(
+ Validate(sql::SqlSchema::GetTableTypesSchema(), *info, sql_client));
+ ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetTableTypesSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(info, sql_client->GetPrimaryKeys(options, table_ref));
+ ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetPrimaryKeysSchema(options));
+ ARROW_RETURN_NOT_OK(
+ Validate(sql::SqlSchema::GetPrimaryKeysSchema(), *info, sql_client));
+ ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetPrimaryKeysSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(info, sql_client->GetExportedKeys(options, table_ref));
+ ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetExportedKeysSchema(options));
+ ARROW_RETURN_NOT_OK(
+ Validate(sql::SqlSchema::GetExportedKeysSchema(), *info, sql_client));
+ ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetExportedKeysSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(info, sql_client->GetImportedKeys(options, table_ref));
+ ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetImportedKeysSchema(options));
+ ARROW_RETURN_NOT_OK(
+ Validate(sql::SqlSchema::GetImportedKeysSchema(), *info, sql_client));
+ ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetImportedKeysSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(
+ info, sql_client->GetCrossReference(options, pk_table_ref, fk_table_ref));
+ ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetCrossReferenceSchema(options));
ARROW_RETURN_NOT_OK(
- Validate(sql::SqlSchema::GetDbSchemasSchema(),
- sql_client->GetDbSchemas(options, &catalog, &db_schema_filter_pattern),
- sql_client));
+ Validate(sql::SqlSchema::GetCrossReferenceSchema(), *info, sql_client));
ARROW_RETURN_NOT_OK(
- Validate(sql::SqlSchema::GetTablesSchemaWithIncludedSchema(),
- sql_client->GetTables(options, &catalog, &db_schema_filter_pattern,
- &table_filter_pattern, true, &table_types),
- sql_client));
- ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetTableTypesSchema(),
- sql_client->GetTableTypes(options), sql_client));
- ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetPrimaryKeysSchema(),
- sql_client->GetPrimaryKeys(options, table_ref),
- sql_client));
- ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetExportedKeysSchema(),
- sql_client->GetExportedKeys(options, table_ref),
- sql_client));
- ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetImportedKeysSchema(),
- sql_client->GetImportedKeys(options, table_ref),
- sql_client));
- ARROW_RETURN_NOT_OK(Validate(
- sql::SqlSchema::GetCrossReferenceSchema(),
- sql_client->GetCrossReference(options, pk_table_ref, fk_table_ref), sql_client));
- ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetXdbcTypeInfoSchema(),
- sql_client->GetXdbcTypeInfo(options), sql_client));
- ARROW_RETURN_NOT_OK(Validate(
- sql::SqlSchema::GetSqlInfoSchema(),
- sql_client->GetSqlInfo(
- options, {sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME,
- sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY}),
- sql_client));
+ ValidateSchema(sql::SqlSchema::GetCrossReferenceSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(info, sql_client->GetXdbcTypeInfo(options));
+ ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetXdbcTypeInfoSchema(options));
+ ARROW_RETURN_NOT_OK(
+ Validate(sql::SqlSchema::GetXdbcTypeInfoSchema(), *info, sql_client));
+ ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetXdbcTypeInfoSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(
+ info, sql_client->GetSqlInfo(
+ options, {sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME,
+ sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY}));
+ ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetSqlInfoSchema(options));
+ ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetSqlInfoSchema(), *info, sql_client));
+ ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetSqlInfoSchema(), *schema));
return Status::OK();
}
Status ValidateStatementExecution(sql::FlightSqlClient* sql_client) {
- FlightCallOptions options;
+ ARROW_ASSIGN_OR_RAISE(auto info, sql_client->Execute({}, kSelectStatement));
+ ARROW_RETURN_NOT_OK(Validate(GetQuerySchema(), *info, sql_client));
- ARROW_RETURN_NOT_OK(Validate(
- GetQuerySchema(), sql_client->Execute(options, "SELECT STATEMENT"), sql_client));
- ARROW_ASSIGN_OR_RAISE(auto update_statement_result,
- sql_client->ExecuteUpdate(options, "UPDATE STATEMENT"));
- if (update_statement_result != kUpdateStatementExpectedRows) {
- return Status::Invalid("Expected 'UPDATE STATEMENT' return ",
- kUpdateStatementExpectedRows, ", got ",
- update_statement_result);
- }
+ ARROW_ASSIGN_OR_RAISE(auto schema,
+ sql_client->GetExecuteSchema({}, kSelectStatement));
+ ARROW_RETURN_NOT_OK(ValidateSchema(GetQuerySchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(auto updated_rows,
+ sql_client->ExecuteUpdate({}, "UPDATE STATEMENT"));
+ ARROW_RETURN_NOT_OK(AssertEq(kUpdateStatementExpectedRows, updated_rows,
+ "Wrong number of updated rows for ExecuteUpdate"));
return Status::OK();
}
Status ValidatePreparedStatementExecution(sql::FlightSqlClient* sql_client) {
- FlightCallOptions options;
-
- ARROW_ASSIGN_OR_RAISE(auto select_prepared_statement,
- sql_client->Prepare(options, "SELECT PREPARED STATEMENT"));
-
auto parameters =
RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")});
- ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters));
- ARROW_RETURN_NOT_OK(
- Validate(GetQuerySchema(), select_prepared_statement->Execute(), sql_client));
+ ARROW_ASSIGN_OR_RAISE(auto select_prepared_statement,
+ sql_client->Prepare({}, "SELECT PREPARED STATEMENT"));
+ ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters));
+ ARROW_ASSIGN_OR_RAISE(auto info, select_prepared_statement->Execute());
+ ARROW_RETURN_NOT_OK(Validate(GetQuerySchema(), *info, sql_client));
+ ARROW_ASSIGN_OR_RAISE(auto schema, select_prepared_statement->GetSchema({}));
+ ARROW_RETURN_NOT_OK(ValidateSchema(GetQuerySchema(), *schema));
ARROW_RETURN_NOT_OK(select_prepared_statement->Close());
ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement,
- sql_client->Prepare(options, "UPDATE PREPARED STATEMENT"));
- ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement_result,
- update_prepared_statement->ExecuteUpdate());
- if (update_prepared_statement_result != kUpdatePreparedStatementExpectedRows) {
- return Status::Invalid("Expected 'UPDATE STATEMENT' return ",
- kUpdatePreparedStatementExpectedRows, ", got ",
- update_prepared_statement_result);
- }
+ sql_client->Prepare({}, "UPDATE PREPARED STATEMENT"));
+ ARROW_ASSIGN_OR_RAISE(auto updated_rows, update_prepared_statement->ExecuteUpdate());
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kUpdatePreparedStatementExpectedRows, updated_rows,
+ "Wrong number of updated rows for prepared statement ExecuteUpdate"));
ARROW_RETURN_NOT_OK(update_prepared_statement->Close());
return Status::OK();
diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc
index 10ff1eea6f..e299b7ceb1 100644
--- a/cpp/src/arrow/flight/sql/client.cc
+++ b/cpp/src/arrow/flight/sql/client.cc
@@ -36,15 +36,45 @@ namespace arrow {
namespace flight {
namespace sql {
+namespace {
+arrow::Result<FlightDescriptor> GetFlightDescriptorForCommand(
+ const google::protobuf::Message& command) {
+ google::protobuf::Any any;
+ if (!any.PackFrom(command)) {
+ return Status::SerializationError("Failed to pack ", command.GetTypeName());
+ }
+
+ std::string buf;
+ if (!any.SerializeToString(&buf)) {
+ return Status::SerializationError("Failed to serialize ", command.GetTypeName());
+ }
+ return FlightDescriptor::Command(buf);
+}
+
+arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoForCommand(
+ FlightSqlClient* client, const FlightCallOptions& options,
+ const google::protobuf::Message& command) {
+ ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
+ GetFlightDescriptorForCommand(command));
+ return client->GetFlightInfo(options, descriptor);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> GetSchemaForCommand(
+ FlightSqlClient* client, const FlightCallOptions& options,
+ const google::protobuf::Message& command) {
+ ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
+ GetFlightDescriptorForCommand(command));
+ return client->GetSchema(options, descriptor);
+}
+} // namespace
+
FlightSqlClient::FlightSqlClient(std::shared_ptr<FlightClient> client)
: impl_(std::move(client)) {}
PreparedStatement::PreparedStatement(FlightSqlClient* client, std::string handle,
std::shared_ptr<Schema> dataset_schema,
- std::shared_ptr<Schema> parameter_schema,
- FlightCallOptions options)
+ std::shared_ptr<Schema> parameter_schema)
: client_(client),
- options_(std::move(options)),
handle_(std::move(handle)),
dataset_schema_(std::move(dataset_schema)),
parameter_schema_(std::move(parameter_schema)),
@@ -59,30 +89,20 @@ PreparedStatement::~PreparedStatement() {
}
}
-inline FlightDescriptor GetFlightDescriptorForCommand(
- const google::protobuf::Message& command) {
- google::protobuf::Any any;
- any.PackFrom(command);
-
- const std::string& string = any.SerializeAsString();
- return FlightDescriptor::Command(string);
-}
-
-arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoForCommand(
- FlightSqlClient& client, const FlightCallOptions& options,
- const google::protobuf::Message& command) {
- const FlightDescriptor& descriptor = GetFlightDescriptorForCommand(command);
+arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::Execute(
+ const FlightCallOptions& options, const std::string& query) {
+ flight_sql_pb::CommandStatementQuery command;
+ command.set_query(query);
- ARROW_ASSIGN_OR_RAISE(auto flight_info, client.GetFlightInfo(options, descriptor));
- return std::move(flight_info);
+ return GetFlightInfoForCommand(this, options, command);
}
-arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::Execute(
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetExecuteSchema(
const FlightCallOptions& options, const std::string& query) {
flight_sql_pb::CommandStatementQuery command;
command.set_query(query);
- return GetFlightInfoForCommand(*this, options, command);
+ return GetSchemaForCommand(this, options, command);
}
arrow::Result<int64_t> FlightSqlClient::ExecuteUpdate(const FlightCallOptions& options,
@@ -90,7 +110,8 @@ arrow::Result<int64_t> FlightSqlClient::ExecuteUpdate(const FlightCallOptions& o
flight_sql_pb::CommandStatementUpdate command;
command.set_query(query);
- const FlightDescriptor& descriptor = GetFlightDescriptorForCommand(command);
+ ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
+ GetFlightDescriptorForCommand(command));
std::unique_ptr<FlightStreamWriter> writer;
std::unique_ptr<FlightMetadataReader> reader;
@@ -114,8 +135,13 @@ arrow::Result<int64_t> FlightSqlClient::ExecuteUpdate(const FlightCallOptions& o
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetCatalogs(
const FlightCallOptions& options) {
flight_sql_pb::CommandGetCatalogs command;
+ return GetFlightInfoForCommand(this, options, command);
+}
- return GetFlightInfoForCommand(*this, options, command);
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetCatalogsSchema(
+ const FlightCallOptions& options) {
+ flight_sql_pb::CommandGetCatalogs command;
+ return GetSchemaForCommand(this, options, command);
}
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetDbSchemas(
@@ -129,7 +155,13 @@ arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetDbSchemas(
command.set_db_schema_filter_pattern(*db_schema_filter_pattern);
}
- return GetFlightInfoForCommand(*this, options, command);
+ return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetDbSchemasSchema(
+ const FlightCallOptions& options) {
+ flight_sql_pb::CommandGetDbSchemas command;
+ return GetSchemaForCommand(this, options, command);
}
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetTables(
@@ -158,7 +190,14 @@ arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetTables(
}
}
- return GetFlightInfoForCommand(*this, options, command);
+ return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetTablesSchema(
+ const FlightCallOptions& options, bool include_schema) {
+ flight_sql_pb::CommandGetTables command;
+ command.set_include_schema(include_schema);
+ return GetSchemaForCommand(this, options, command);
}
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetPrimaryKeys(
@@ -175,7 +214,13 @@ arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetPrimaryKeys(
command.set_table(table_ref.table);
- return GetFlightInfoForCommand(*this, options, command);
+ return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetPrimaryKeysSchema(
+ const FlightCallOptions& options) {
+ flight_sql_pb::CommandGetPrimaryKeys command;
+ return GetSchemaForCommand(this, options, command);
}
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetExportedKeys(
@@ -192,7 +237,13 @@ arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetExportedKeys(
command.set_table(table_ref.table);
- return GetFlightInfoForCommand(*this, options, command);
+ return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetExportedKeysSchema(
+ const FlightCallOptions& options) {
+ flight_sql_pb::CommandGetExportedKeys command;
+ return GetSchemaForCommand(this, options, command);
}
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetImportedKeys(
@@ -209,7 +260,13 @@ arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetImportedKeys(
command.set_table(table_ref.table);
- return GetFlightInfoForCommand(*this, options, command);
+ return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetImportedKeysSchema(
+ const FlightCallOptions& options) {
+ flight_sql_pb::CommandGetImportedKeys command;
+ return GetSchemaForCommand(this, options, command);
}
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetCrossReference(
@@ -233,21 +290,33 @@ arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetCrossReference(
}
command.set_fk_table(fk_table_ref.table);
- return GetFlightInfoForCommand(*this, options, command);
+ return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetCrossReferenceSchema(
+ const FlightCallOptions& options) {
+ flight_sql_pb::CommandGetCrossReference command;
+ return GetSchemaForCommand(this, options, command);
}
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetTableTypes(
const FlightCallOptions& options) {
flight_sql_pb::CommandGetTableTypes command;
- return GetFlightInfoForCommand(*this, options, command);
+ return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetTableTypesSchema(
+ const FlightCallOptions& options) {
+ flight_sql_pb::CommandGetTableTypes command;
+ return GetSchemaForCommand(this, options, command);
}
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetXdbcTypeInfo(
const FlightCallOptions& options) {
flight_sql_pb::CommandGetXdbcTypeInfo command;
- return GetFlightInfoForCommand(*this, options, command);
+ return GetFlightInfoForCommand(this, options, command);
}
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetXdbcTypeInfo(
@@ -256,7 +325,27 @@ arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetXdbcTypeInfo(
command.set_data_type(data_type);
- return GetFlightInfoForCommand(*this, options, command);
+ return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetXdbcTypeInfoSchema(
+ const FlightCallOptions& options) {
+ flight_sql_pb::CommandGetXdbcTypeInfo command;
+ return GetSchemaForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetSqlInfo(
+ const FlightCallOptions& options, const std::vector<int>& sql_info) {
+ flight_sql_pb::CommandGetSqlInfo command;
+ for (const int& info : sql_info) command.add_info(info);
+
+ return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetSqlInfoSchema(
+ const FlightCallOptions& options) {
+ flight_sql_pb::CommandGetSqlInfo command;
+ return GetSchemaForCommand(this, options, command);
}
arrow::Result<std::unique_ptr<FlightStreamReader>> FlightSqlClient::DoGet(
@@ -319,28 +408,24 @@ arrow::Result<std::shared_ptr<PreparedStatement>> FlightSqlClient::Prepare(
auto handle = prepared_statement_result.prepared_statement_handle();
return std::make_shared<PreparedStatement>(this, handle, dataset_schema,
- parameter_schema, options);
+ parameter_schema);
}
-arrow::Result<std::unique_ptr<FlightInfo>> PreparedStatement::Execute() {
+arrow::Result<std::unique_ptr<FlightInfo>> PreparedStatement::Execute(
+ const FlightCallOptions& options) {
if (is_closed_) {
return Status::Invalid("Statement already closed.");
}
- flight_sql_pb::CommandPreparedStatementQuery execute_query_command;
-
- execute_query_command.set_prepared_statement_handle(handle_);
-
- google::protobuf::Any any;
- any.PackFrom(execute_query_command);
-
- const std::string& string = any.SerializeAsString();
- const FlightDescriptor descriptor = FlightDescriptor::Command(string);
+ flight_sql_pb::CommandPreparedStatementQuery command;
+ command.set_prepared_statement_handle(handle_);
+ ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
+ GetFlightDescriptorForCommand(command));
if (parameter_binding_ && parameter_binding_->num_rows() > 0) {
std::unique_ptr<FlightStreamWriter> writer;
std::unique_ptr<FlightMetadataReader> reader;
- ARROW_RETURN_NOT_OK(client_->DoPut(options_, descriptor, parameter_binding_->schema(),
+ ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(),
&writer, &reader));
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_));
@@ -350,28 +435,30 @@ arrow::Result<std::unique_ptr<FlightInfo>> PreparedStatement::Execute() {
ARROW_RETURN_NOT_OK(reader->ReadMetadata(&buffer));
}
- ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options_, descriptor));
+ ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options, descriptor));
return std::move(flight_info);
}
-arrow::Result<int64_t> PreparedStatement::ExecuteUpdate() {
+arrow::Result<int64_t> PreparedStatement::ExecuteUpdate(
+ const FlightCallOptions& options) {
if (is_closed_) {
return Status::Invalid("Statement already closed.");
}
flight_sql_pb::CommandPreparedStatementUpdate command;
command.set_prepared_statement_handle(handle_);
- const FlightDescriptor& descriptor = GetFlightDescriptorForCommand(command);
+ ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
+ GetFlightDescriptorForCommand(command));
std::unique_ptr<FlightStreamWriter> writer;
std::unique_ptr<FlightMetadataReader> reader;
if (parameter_binding_ && parameter_binding_->num_rows() > 0) {
- ARROW_RETURN_NOT_OK(client_->DoPut(options_, descriptor, parameter_binding_->schema(),
+ ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(),
&writer, &reader));
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_));
} else {
const std::shared_ptr<Schema> schema = arrow::schema({});
- ARROW_RETURN_NOT_OK(client_->DoPut(options_, descriptor, schema, &writer, &reader));
+ ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, schema, &writer, &reader));
const ArrayVector columns;
const auto& record_batch = arrow::RecordBatch::Make(schema, 0, columns);
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*record_batch));
@@ -406,7 +493,20 @@ std::shared_ptr<Schema> PreparedStatement::parameter_schema() const {
return parameter_schema_;
}
-Status PreparedStatement::Close() {
+arrow::Result<std::unique_ptr<SchemaResult>> PreparedStatement::GetSchema(
+ const FlightCallOptions& options) {
+ if (is_closed_) {
+ return Status::Invalid("Statement already closed");
+ }
+
+ flight_sql_pb::CommandPreparedStatementQuery command;
+ command.set_prepared_statement_handle(handle_);
+ ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
+ GetFlightDescriptorForCommand(command));
+ return client_->GetSchema(options, descriptor);
+}
+
+Status PreparedStatement::Close(const FlightCallOptions& options) {
if (is_closed_) {
return Status::Invalid("Statement already closed.");
}
@@ -422,7 +522,7 @@ Status PreparedStatement::Close() {
std::unique_ptr<ResultStream> results;
- ARROW_RETURN_NOT_OK(client_->DoAction(options_, action, &results));
+ ARROW_RETURN_NOT_OK(client_->DoAction(options, action, &results));
is_closed_ = true;
@@ -431,14 +531,6 @@ Status PreparedStatement::Close() {
Status FlightSqlClient::Close() { return impl_->Close(); }
-arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetSqlInfo(
- const FlightCallOptions& options, const std::vector<int>& sql_info) {
- flight_sql_pb::CommandGetSqlInfo command;
- for (const int& info : sql_info) command.add_info(info);
-
- return GetFlightInfoForCommand(*this, options, command);
-}
-
} // namespace sql
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h
index 7c8cb640e8..26315e0d23 100644
--- a/cpp/src/arrow/flight/sql/client.h
+++ b/cpp/src/arrow/flight/sql/client.h
@@ -54,6 +54,10 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
arrow::Result<std::unique_ptr<FlightInfo>> Execute(const FlightCallOptions& options,
const std::string& query);
+ /// \brief Get the result set schema from the server.
+ arrow::Result<std::unique_ptr<SchemaResult>> GetExecuteSchema(
+ const FlightCallOptions& options, const std::string& query);
+
/// \brief Execute an update query on the server.
/// \param[in] options RPC-layer hints for this call.
/// \param[in] query The query to be executed in the UTF-8 format.
@@ -67,6 +71,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
arrow::Result<std::unique_ptr<FlightInfo>> GetCatalogs(
const FlightCallOptions& options);
+ /// \brief Get the catalogs schema from the server (should be
+ /// identical to SqlSchema::GetCatalogsSchema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetCatalogsSchema(
+ const FlightCallOptions& options);
+
/// \brief Request a list of database schemas.
/// \param[in] options RPC-layer hints for this call.
/// \param[in] catalog The catalog.
@@ -76,6 +85,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
const FlightCallOptions& options, const std::string* catalog,
const std::string* db_schema_filter_pattern);
+ /// \brief Get the database schemas schema from the server (should be
+ /// identical to SqlSchema::GetDbSchemasSchema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetDbSchemasSchema(
+ const FlightCallOptions& options);
+
/// \brief Given a flight ticket and schema, request to be sent the
/// stream. Returns record batch stream reader
/// \param[in] options Per-RPC options
@@ -99,6 +113,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
const std::string* table_filter_pattern, bool include_schema,
const std::vector<std::string>* table_types);
+ /// \brief Get the tables schema from the server (should be
+ /// identical to SqlSchema::GetTablesSchema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetTablesSchema(
+ const FlightCallOptions& options, bool include_schema);
+
/// \brief Request the primary keys for a table.
/// \param[in] options RPC-layer hints for this call.
/// \param[in] table_ref The table reference.
@@ -106,6 +125,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
arrow::Result<std::unique_ptr<FlightInfo>> GetPrimaryKeys(
const FlightCallOptions& options, const TableRef& table_ref);
+ /// \brief Get the primary keys schema from the server (should be
+ /// identical to SqlSchema::GetPrimaryKeysSchema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetPrimaryKeysSchema(
+ const FlightCallOptions& options);
+
/// \brief Retrieves a description about the foreign key columns that reference the
/// primary key columns of the given table.
/// \param[in] options RPC-layer hints for this call.
@@ -114,6 +138,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
arrow::Result<std::unique_ptr<FlightInfo>> GetExportedKeys(
const FlightCallOptions& options, const TableRef& table_ref);
+ /// \brief Get the exported keys schema from the server (should be
+ /// identical to SqlSchema::GetExportedKeysSchema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetExportedKeysSchema(
+ const FlightCallOptions& options);
+
/// \brief Retrieves the foreign key columns for the given table.
/// \param[in] options RPC-layer hints for this call.
/// \param[in] table_ref The table reference.
@@ -121,6 +150,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
arrow::Result<std::unique_ptr<FlightInfo>> GetImportedKeys(
const FlightCallOptions& options, const TableRef& table_ref);
+ /// \brief Get the imported keys schema from the server (should be
+ /// identical to SqlSchema::GetImportedKeysSchema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetImportedKeysSchema(
+ const FlightCallOptions& options);
+
/// \brief Retrieves a description of the foreign key columns in the given foreign key
/// table that reference the primary key or the columns representing a unique
/// constraint of the parent table (could be the same or a different table).
@@ -132,12 +166,22 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
const FlightCallOptions& options, const TableRef& pk_table_ref,
const TableRef& fk_table_ref);
+ /// \brief Get the cross reference schema from the server (should be
+ /// identical to SqlSchema::GetCrossReferenceSchema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetCrossReferenceSchema(
+ const FlightCallOptions& options);
+
/// \brief Request a list of table types.
/// \param[in] options RPC-layer hints for this call.
/// \return The FlightInfo describing where to access the dataset.
arrow::Result<std::unique_ptr<FlightInfo>> GetTableTypes(
const FlightCallOptions& options);
+ /// \brief Get the table types schema from the server (should be
+ /// identical to SqlSchema::GetTableTypesSchema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetTableTypesSchema(
+ const FlightCallOptions& options);
+
/// \brief Request the information about all the data types supported.
/// \param[in] options RPC-layer hints for this call.
/// \return The FlightInfo describing where to access the dataset.
@@ -151,6 +195,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
arrow::Result<std::unique_ptr<FlightInfo>> GetXdbcTypeInfo(
const FlightCallOptions& options, int data_type);
+ /// \brief Get the type info schema from the server (should be
+ /// identical to SqlSchema::GetXdbcTypeInfoSchema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetXdbcTypeInfoSchema(
+ const FlightCallOptions& options);
+
/// \brief Request a list of SQL information.
/// \param[in] options RPC-layer hints for this call.
/// \param[in] sql_info the SQL info required.
@@ -158,6 +207,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
arrow::Result<std::unique_ptr<FlightInfo>> GetSqlInfo(const FlightCallOptions& options,
const std::vector<int>& sql_info);
+ /// \brief Get the SQL information schema from the server (should be
+ /// identical to SqlSchema::GetSqlInfoSchema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetSqlInfoSchema(
+ const FlightCallOptions& options);
+
/// \brief Create a prepared statement object.
/// \param[in] options RPC-layer hints for this call.
/// \param[in] query The query that will be executed.
@@ -165,17 +219,18 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
arrow::Result<std::shared_ptr<PreparedStatement>> Prepare(
const FlightCallOptions& options, const std::string& query);
- /// \brief Retrieve the FlightInfo.
- /// \param[in] options RPC-layer hints for this call.
- /// \param[in] descriptor The flight descriptor.
- /// \return The flight info with the metadata.
- // NOTE: This is public because it is been used by the anonymous
- // function GetFlightInfoForCommand.
+ /// \brief Call the underlying Flight client's GetFlightInfo.
virtual arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfo(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
return impl_->GetFlightInfo(options, descriptor);
}
+ /// \brief Call the underlying Flight client's GetSchema.
+ virtual arrow::Result<std::unique_ptr<SchemaResult>> GetSchema(
+ const FlightCallOptions& options, const FlightDescriptor& descriptor) {
+ return impl_->GetSchema(options, descriptor);
+ }
+
/// \brief Explicitly shut down and clean up the client.
Status Close();
@@ -212,10 +267,9 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
/// \param[in] handle Handle for this prepared statement.
/// \param[in] dataset_schema Schema of the resulting dataset.
/// \param[in] parameter_schema Schema of the parameters (if any).
- /// \param[in] options RPC-layer hints for this call.
PreparedStatement(FlightSqlClient* client, std::string handle,
std::shared_ptr<Schema> dataset_schema,
- std::shared_ptr<Schema> parameter_schema, FlightCallOptions options);
+ std::shared_ptr<Schema> parameter_schema);
/// \brief Default destructor for the PreparedStatement class.
/// The destructor will call the Close method from the class in order,
@@ -226,11 +280,12 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
/// \brief Executes the prepared statement query on the server.
/// \return A FlightInfo object representing the stream(s) to fetch.
- arrow::Result<std::unique_ptr<FlightInfo>> Execute();
+ arrow::Result<std::unique_ptr<FlightInfo>> Execute(
+ const FlightCallOptions& options = {});
/// \brief Executes the prepared statement update query on the server.
/// \return The number of rows affected.
- arrow::Result<int64_t> ExecuteUpdate();
+ arrow::Result<int64_t> ExecuteUpdate(const FlightCallOptions& options = {});
/// \brief Retrieve the parameter schema from the query.
/// \return The parameter schema from the query.
@@ -245,10 +300,15 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
/// \return Status.
Status SetParameters(std::shared_ptr<RecordBatch> parameter_binding);
+ /// \brief Re-request the result set schema from the server (should
+ /// be identical to dataset_schema).
+ arrow::Result<std::unique_ptr<SchemaResult>> GetSchema(
+ const FlightCallOptions& options = {});
+
/// \brief Close the prepared statement, so that this PreparedStatement can not used
/// anymore and server can free up any resources.
/// \return Status.
- Status Close();
+ Status Close(const FlightCallOptions& options = {});
/// \brief Check if the prepared statement is closed.
/// \return The state of the prepared statement.
@@ -256,7 +316,6 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
private:
FlightSqlClient* client_;
- FlightCallOptions options_;
std::string handle_;
std::shared_ptr<Schema> dataset_schema_;
std::shared_ptr<Schema> parameter_schema_;
diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc
index 0ebe647ba1..78fbff0c33 100644
--- a/cpp/src/arrow/flight/sql/server.cc
+++ b/cpp/src/arrow/flight/sql/server.cc
@@ -344,6 +344,72 @@ Status FlightSqlServerBase::GetFlightInfo(const ServerCallContext& context,
return Status::Invalid("The defined request is invalid.");
}
+Status FlightSqlServerBase::GetSchema(const ServerCallContext& context,
+ const FlightDescriptor& request,
+ std::unique_ptr<SchemaResult>* schema) {
+ google::protobuf::Any any;
+ if (!any.ParseFromArray(request.cmd.data(), static_cast<int>(request.cmd.size()))) {
+ return Status::Invalid("Unable to parse command");
+ }
+
+ if (any.Is<pb::sql::CommandStatementQuery>()) {
+ ARROW_ASSIGN_OR_RAISE(StatementQuery internal_command,
+ ParseCommandStatementQuery(any));
+ ARROW_ASSIGN_OR_RAISE(*schema,
+ GetSchemaStatement(context, internal_command, request));
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandPreparedStatementQuery>()) {
+ ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command,
+ ParseCommandPreparedStatementQuery(any));
+ ARROW_ASSIGN_OR_RAISE(*schema,
+ GetSchemaPreparedStatement(context, internal_command, request));
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandGetCatalogs>()) {
+ ARROW_ASSIGN_OR_RAISE(*schema, SchemaResult::Make(*SqlSchema::GetCatalogsSchema()));
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandGetCrossReference>()) {
+ ARROW_ASSIGN_OR_RAISE(*schema,
+ SchemaResult::Make(*SqlSchema::GetCrossReferenceSchema()));
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandGetDbSchemas>()) {
+ ARROW_ASSIGN_OR_RAISE(*schema, SchemaResult::Make(*SqlSchema::GetDbSchemasSchema()));
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandGetExportedKeys>()) {
+ ARROW_ASSIGN_OR_RAISE(*schema,
+ SchemaResult::Make(*SqlSchema::GetExportedKeysSchema()));
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandGetImportedKeys>()) {
+ ARROW_ASSIGN_OR_RAISE(*schema,
+ SchemaResult::Make(*SqlSchema::GetImportedKeysSchema()));
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandGetPrimaryKeys>()) {
+ ARROW_ASSIGN_OR_RAISE(*schema,
+ SchemaResult::Make(*SqlSchema::GetPrimaryKeysSchema()));
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandGetSqlInfo>()) {
+ ARROW_ASSIGN_OR_RAISE(*schema, SchemaResult::Make(*SqlSchema::GetSqlInfoSchema()));
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandGetTables>()) {
+ ARROW_ASSIGN_OR_RAISE(GetTables command, ParseCommandGetTables(any));
+ if (command.include_schema) {
+ ARROW_ASSIGN_OR_RAISE(
+ *schema, SchemaResult::Make(*SqlSchema::GetTablesSchemaWithIncludedSchema()));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(*schema, SchemaResult::Make(*SqlSchema::GetTablesSchema()));
+ }
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandGetTableTypes>()) {
+ ARROW_ASSIGN_OR_RAISE(*schema, SchemaResult::Make(*SqlSchema::GetTableTypesSchema()));
+ return Status::OK();
+ } else if (any.Is<pb::sql::CommandGetXdbcTypeInfo>()) {
+ ARROW_ASSIGN_OR_RAISE(*schema,
+ SchemaResult::Make(*SqlSchema::GetXdbcTypeInfoSchema()));
+ return Status::OK();
+ }
+
+ return Status::NotImplemented("Command not recognized: ", any.type_url());
+}
+
Status FlightSqlServerBase::DoGet(const ServerCallContext& context, const Ticket& request,
std::unique_ptr<FlightDataStream>* stream) {
google::protobuf::Any any;
@@ -531,6 +597,12 @@ arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlServerBase::GetFlightInfoSta
return Status::NotImplemented("GetFlightInfoStatement not implemented");
}
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlServerBase::GetSchemaStatement(
+ const ServerCallContext& context, const StatementQuery& command,
+ const FlightDescriptor& descriptor) {
+ return Status::NotImplemented("GetSchemaStatement not implemented");
+}
+
arrow::Result<std::unique_ptr<FlightDataStream>> FlightSqlServerBase::DoGetStatement(
const ServerCallContext& context, const StatementQueryTicket& command) {
return Status::NotImplemented("DoGetStatement not implemented");
@@ -543,6 +615,13 @@ FlightSqlServerBase::GetFlightInfoPreparedStatement(const ServerCallContext& con
return Status::NotImplemented("GetFlightInfoPreparedStatement not implemented");
}
+arrow::Result<std::unique_ptr<SchemaResult>>
+FlightSqlServerBase::GetSchemaPreparedStatement(const ServerCallContext& context,
+ const PreparedStatementQuery& command,
+ const FlightDescriptor& descriptor) {
+ return Status::NotImplemented("GetSchemaPreparedStatement not implemented");
+}
+
arrow::Result<std::unique_ptr<FlightDataStream>>
FlightSqlServerBase::DoGetPreparedStatement(const ServerCallContext& context,
const PreparedStatementQuery& command) {
diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h
index f077c5d5d5..49e239a0cd 100644
--- a/cpp/src/arrow/flight/sql/server.h
+++ b/cpp/src/arrow/flight/sql/server.h
@@ -28,6 +28,7 @@
#include "arrow/flight/sql/server.h"
#include "arrow/flight/sql/types.h"
#include "arrow/flight/sql/visibility.h"
+#include "arrow/flight/types.h"
#include "arrow/util/optional.h"
namespace arrow {
@@ -221,6 +222,25 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase {
virtual arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoCatalogs(
const ServerCallContext& context, const FlightDescriptor& descriptor);
+ /// \brief Get the schema of the result set of a query.
+ /// \param[in] context Per-call context.
+ /// \param[in] command The StatementQuery containing the SQL query.
+ /// \param[in] descriptor The descriptor identifying the data stream.
+ /// \return The schema of the result set.
+ virtual arrow::Result<std::unique_ptr<SchemaResult>> GetSchemaStatement(
+ const ServerCallContext& context, const StatementQuery& command,
+ const FlightDescriptor& descriptor);
+
+ /// \brief Get the schema of the result set of a prepared statement.
+ /// \param[in] context Per-call context.
+ /// \param[in] command The PreparedStatementQuery containing the
+ /// prepared statement handle.
+ /// \param[in] descriptor The descriptor identifying the data stream.
+ /// \return The schema of the result set.
+ virtual arrow::Result<std::unique_ptr<SchemaResult>> GetSchemaPreparedStatement(
+ const ServerCallContext& context, const PreparedStatementQuery& command,
+ const FlightDescriptor& descriptor);
+
/// \brief Get a FlightDataStream containing the list of catalogs.
/// \param[in] context Per-call context.
/// \return An interface for sending data back to the client.
@@ -462,6 +482,9 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase {
Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
std::unique_ptr<FlightInfo>* info) final;
+ Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<SchemaResult>* schema) override;
+
Status DoGet(const ServerCallContext& context, const Ticket& request,
std::unique_ptr<FlightDataStream>* stream) final;
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index ddb8a036fb..6e80f40cfb 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -28,6 +28,7 @@
#include "arrow/ipc/reader.h"
#include "arrow/status.h"
#include "arrow/table.h"
+#include "arrow/util/make_unique.h"
#include "arrow/util/string_view.h"
#include "arrow/util/uri.h"
@@ -150,10 +151,10 @@ arrow::Result<std::shared_ptr<Schema>> SchemaResult::GetSchema(
return ipc::ReadSchema(&schema_reader, dictionary_memo);
}
-arrow::Result<SchemaResult> SchemaResult::Make(const Schema& schema) {
+arrow::Result<std::unique_ptr<SchemaResult>> SchemaResult::Make(const Schema& schema) {
std::string schema_in;
RETURN_NOT_OK(internal::SchemaToString(schema, &schema_in));
- return SchemaResult(std::move(schema_in));
+ return arrow::internal::make_unique<SchemaResult>(std::move(schema_in));
}
Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo,
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index a061f33afe..2ec24ff586 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -397,7 +397,7 @@ struct ARROW_FLIGHT_EXPORT SchemaResult {
explicit SchemaResult(std::string schema) : raw_schema_(std::move(schema)) {}
/// \brief Factory method to construct a SchemaResult.
- static arrow::Result<SchemaResult> Make(const Schema& schema);
+ static arrow::Result<std::unique_ptr<SchemaResult>> Make(const Schema& schema);
/// \brief return schema
/// \param[in,out] dictionary_memo for dictionary bookkeeping, will
diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc
index 9077bbe4ac..bf7af27ac7 100644
--- a/cpp/src/arrow/python/flight.cc
+++ b/cpp/src/arrow/python/flight.cc
@@ -380,10 +380,7 @@ Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
std::unique_ptr<arrow::flight::SchemaResult>* out) {
- ARROW_ASSIGN_OR_RAISE(auto result, arrow::flight::SchemaResult::Make(*schema));
- *out = std::unique_ptr<arrow::flight::SchemaResult>(
- new arrow::flight::SchemaResult(std::move(result)));
- return Status::OK();
+ return arrow::flight::SchemaResult::Make(*schema).Value(out);
}
} // namespace flight
diff --git a/go/arrow/flight/flightsql/client.go b/go/arrow/flight/flightsql/client.go
index 5f7f693d2b..b8ee01cfde 100644
--- a/go/arrow/flight/flightsql/client.go
+++ b/go/arrow/flight/flightsql/client.go
@@ -77,6 +77,14 @@ func flightInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message, op
return cl.getFlightInfo(ctx, desc, opts...)
}
+func schemaForCommand(ctx context.Context, cl *Client, cmd proto.Message, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ desc, err := descForCommand(cmd)
+ if err != nil {
+ return nil, err
+ }
+ return cl.getSchema(ctx, desc, opts...)
+}
+
// Execute executes the desired query on the server and returns a FlightInfo
// object describing where to retrieve the results.
func (c *Client) Execute(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.FlightInfo, error) {
@@ -84,6 +92,13 @@ func (c *Client) Execute(ctx context.Context, query string, opts ...grpc.CallOpt
return flightInfoForCommand(ctx, c, &cmd, opts...)
}
+// GetExecuteSchema gets the schema of the result set of a query without
+// executing the query itself.
+func (c *Client) GetExecuteSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ cmd := pb.CommandStatementQuery{Query: query}
+ return schemaForCommand(ctx, c, &cmd, opts...)
+}
+
// ExecuteUpdate is for executing an update query and only returns the number of affected rows.
func (c *Client) ExecuteUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) {
var (
@@ -128,12 +143,22 @@ func (c *Client) GetCatalogs(ctx context.Context, opts ...grpc.CallOption) (*fli
return flightInfoForCommand(ctx, c, &pb.CommandGetCatalogs{}, opts...)
}
+// GetCatalogsSchema requests the schema of GetCatalogs from the server.
+func (c *Client) GetCatalogsSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return schemaForCommand(ctx, c, &pb.CommandGetCatalogs{}, opts...)
+}
+
// GetDBSchemas requests the list of schemas from the database and
// returns a FlightInfo object where the response can be retrieved
func (c *Client) GetDBSchemas(ctx context.Context, cmdOpts *GetDBSchemasOpts, opts ...grpc.CallOption) (*flight.FlightInfo, error) {
return flightInfoForCommand(ctx, c, (*pb.CommandGetDbSchemas)(cmdOpts), opts...)
}
+// GetDBSchemasSchema requests the schema of GetDBSchemas from the server.
+func (c *Client) GetDBSchemasSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return schemaForCommand(ctx, c, &pb.CommandGetDbSchemas{}, opts...)
+}
+
// DoGet uses the provided flight ticket to request the stream of data.
// It returns a recordbatch reader to stream the results. Release
// should be called on the reader when done.
@@ -154,6 +179,11 @@ func (c *Client) GetTables(ctx context.Context, reqOptions *GetTablesOpts, opts
return flightInfoForCommand(ctx, c, (*pb.CommandGetTables)(reqOptions), opts...)
}
+// GetTablesSchema requests the schema of GetTables from the server.
+func (c *Client) GetTablesSchema(ctx context.Context, reqOptions *GetTablesOpts, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return schemaForCommand(ctx, c, (*pb.CommandGetTables)(reqOptions), opts...)
+}
+
// GetPrimaryKeys requests the primary keys for a specific table from the
// server, specified using a TableRef. Returns a FlightInfo object where
// the response can be retrieved.
@@ -166,6 +196,11 @@ func (c *Client) GetPrimaryKeys(ctx context.Context, ref TableRef, opts ...grpc.
return flightInfoForCommand(ctx, c, &cmd, opts...)
}
+// GetPrimaryKeysSchema requests the schema of GetPrimaryKeys from the server.
+func (c *Client) GetPrimaryKeysSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return schemaForCommand(ctx, c, &pb.CommandGetPrimaryKeys{}, opts...)
+}
+
// GetExportedKeys retrieves a description about the foreign key columns
// that reference the primary key columns of the specified table. Returns
// a FlightInfo object where the response can be retrieved.
@@ -178,6 +213,11 @@ func (c *Client) GetExportedKeys(ctx context.Context, ref TableRef, opts ...grpc
return flightInfoForCommand(ctx, c, &cmd, opts...)
}
+// GetExportedKeysSchema requests the schema of GetExportedKeys from the server.
+func (c *Client) GetExportedKeysSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return schemaForCommand(ctx, c, &pb.CommandGetExportedKeys{}, opts...)
+}
+
// GetImportedKeys returns the foreign key columns for the specified table.
// Returns a FlightInfo object indicating where the response can be retrieved.
func (c *Client) GetImportedKeys(ctx context.Context, ref TableRef, opts ...grpc.CallOption) (*flight.FlightInfo, error) {
@@ -189,6 +229,11 @@ func (c *Client) GetImportedKeys(ctx context.Context, ref TableRef, opts ...grpc
return flightInfoForCommand(ctx, c, &cmd, opts...)
}
+// GetImportedKeysSchema requests the schema of GetImportedKeys from the server.
+func (c *Client) GetImportedKeysSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return schemaForCommand(ctx, c, &pb.CommandGetImportedKeys{}, opts...)
+}
+
// GetCrossReference retrieves a description of the foreign key columns
// in the specified ForeignKey table that reference the primary key or
// columns representing a restraint of the parent table (could be the same
@@ -206,6 +251,11 @@ func (c *Client) GetCrossReference(ctx context.Context, pkTable, fkTable TableRe
return flightInfoForCommand(ctx, c, &cmd, opts...)
}
+// GetCrossReferenceSchema requests the schema of GetCrossReference from the server.
+func (c *Client) GetCrossReferenceSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return schemaForCommand(ctx, c, &pb.CommandGetCrossReference{}, opts...)
+}
+
// GetTableTypes requests a list of the types of tables available on this
// server. Returns a FlightInfo object indicating where the response can
// be retrieved.
@@ -213,6 +263,11 @@ func (c *Client) GetTableTypes(ctx context.Context, opts ...grpc.CallOption) (*f
return flightInfoForCommand(ctx, c, &pb.CommandGetTableTypes{}, opts...)
}
+// GetTableTypesSchema requests the schema of GetTableTypes from the server.
+func (c *Client) GetTableTypesSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return schemaForCommand(ctx, c, &pb.CommandGetTableTypes{}, opts...)
+}
+
// GetXdbcTypeInfo requests the information about all the data types supported
// (dataType == nil) or a specific data type. Returns a FlightInfo object
// indicating where the response can be retrieved.
@@ -220,6 +275,11 @@ func (c *Client) GetXdbcTypeInfo(ctx context.Context, dataType *int32, opts ...g
return flightInfoForCommand(ctx, c, &pb.CommandGetXdbcTypeInfo{DataType: dataType}, opts...)
}
+// GetXdbcTypeInfoSchema requests the schema of GetXdbcTypeInfo from the server.
+func (c *Client) GetXdbcTypeInfoSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return schemaForCommand(ctx, c, &pb.CommandGetXdbcTypeInfo{}, opts...)
+}
+
// GetSqlInfo returns a list of the requested SQL information corresponding
// to the values in the info slice. Returns a FlightInfo object indicating
// where the response can be retrieved.
@@ -232,6 +292,11 @@ func (c *Client) GetSqlInfo(ctx context.Context, info []SqlInfo, opts ...grpc.Ca
return flightInfoForCommand(ctx, c, cmd, opts...)
}
+// GetSqlInfoSchema requests the schema of GetSqlInfo from the server.
+func (c *Client) GetSqlInfoSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return schemaForCommand(ctx, c, &pb.CommandGetSqlInfo{}, opts...)
+}
+
// Prepare creates a PreparedStatement object for the specified query.
// The resulting PreparedStatement object should be Closed when no longer
// needed. It will maintain a reference to this Client for use to execute
@@ -302,6 +367,10 @@ func (c *Client) getFlightInfo(ctx context.Context, desc *flight.FlightDescripto
return c.Client.GetFlightInfo(ctx, desc, opts...)
}
+func (c *Client) getSchema(ctx context.Context, desc *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
+ return c.Client.GetSchema(ctx, desc, opts...)
+}
+
// Close will close the underlying flight Client in use by this flightsql.Client
func (c *Client) Close() error { return c.Client.Close() }
@@ -430,6 +499,25 @@ func (p *PreparedStatement) DatasetSchema() *arrow.Schema { return p.datasetSche
// the prepared statement.
func (p *PreparedStatement) ParameterSchema() *arrow.Schema { return p.paramSchema }
+// GetSchema re-requests the schema of the result set of the prepared
+// statement from the server. It should otherwise be identical to DatasetSchema.
+//
+// Will error if already closed.
+func (p *PreparedStatement) GetSchema(ctx context.Context) (*flight.SchemaResult, error) {
+ if p.closed {
+ return nil, errors.New("arrow/flightsql: prepared statement already closed")
+ }
+
+ cmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: p.handle}
+
+ desc, err := descForCommand(cmd)
+ if err != nil {
+ return nil, err
+ }
+
+ return p.client.getSchema(ctx, desc, p.opts...)
+}
+
// SetParameters takes a record batch to send as the parameter bindings when
// executing. It should match the schema from ParameterSchema.
//
diff --git a/go/arrow/flight/flightsql/server.go b/go/arrow/flight/flightsql/server.go
index 17bc9e188a..8080df9e4b 100644
--- a/go/arrow/flight/flightsql/server.go
+++ b/go/arrow/flight/flightsql/server.go
@@ -181,6 +181,10 @@ func (BaseServer) GetFlightInfoStatement(context.Context, StatementQuery, *fligh
return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoStatement not implemented")
}
+func (BaseServer) GetSchemaStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error) {
+ return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented")
+}
+
func (BaseServer) DoGetStatement(context.Context, StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
return nil, nil, status.Errorf(codes.Unimplemented, "DoGetStatement not implemented")
}
@@ -189,6 +193,10 @@ func (BaseServer) GetFlightInfoPreparedStatement(context.Context, PreparedStatem
return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoPreparedStatement not implemented")
}
+func (BaseServer) GetSchemaPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error) {
+ return nil, status.Errorf(codes.Unimplemented, "GetSchemaPreparedStatement not implemented")
+}
+
func (BaseServer) DoGetPreparedStatement(context.Context, PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) {
return nil, nil, status.Errorf(codes.Unimplemented, "DoGetPreparedStatement not implemented")
}
@@ -367,12 +375,17 @@ func (BaseServer) DoPutPreparedStatementUpdate(context.Context, PreparedStatemen
type Server interface {
// GetFlightInfoStatement returns a FlightInfo for executing the requested sql query
GetFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.FlightInfo, error)
+ // GetSchemaStatement returns the schema of the result set of the requested sql query
+ GetSchemaStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error)
// DoGetStatement returns a stream containing the query results for the
// requested statement handle that was populated by GetFlightInfoStatement
DoGetStatement(context.Context, StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error)
// GetFlightInfoPreparedStatement returns a FlightInfo for executing an already
// prepared statement with the provided statement handle.
GetFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.FlightInfo, error)
+ // GetSchemaPreparedStatement returns the schema of the result set of executing an already
+ // prepared statement with the provided statement handle.
+ GetSchemaPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error)
// DoGetPreparedStatement returns a stream containing the results from executing
// a prepared statement query with the provided statement handle.
DoGetPreparedStatement(context.Context, PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error)
@@ -519,6 +532,53 @@ func (f *flightSqlServer) GetFlightInfo(ctx context.Context, request *flight.Fli
return nil, status.Error(codes.InvalidArgument, "requested command is invalid")
}
+func (f *flightSqlServer) GetSchema(ctx context.Context, request *flight.FlightDescriptor) (*flight.SchemaResult, error) {
+ var (
+ anycmd anypb.Any
+ cmd proto.Message
+ err error
+ )
+ if err = proto.Unmarshal(request.Cmd, &anycmd); err != nil {
+ return nil, status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error())
+ }
+
+ if cmd, err = anycmd.UnmarshalNew(); err != nil {
+ return nil, status.Errorf(codes.InvalidArgument, "could not unmarshal Any to a command type: %s", err.Error())
+ }
+
+ switch cmd := cmd.(type) {
+ case *pb.CommandStatementQuery:
+ return f.srv.GetSchemaStatement(ctx, cmd, request)
+ case *pb.CommandPreparedStatementQuery:
+ return f.srv.GetSchemaPreparedStatement(ctx, cmd, request)
+ case *pb.CommandGetCatalogs:
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.Catalogs, f.mem)}, nil
+ case *pb.CommandGetDbSchemas:
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.DBSchemas, f.mem)}, nil
+ case *pb.CommandGetTables:
+ if cmd.GetIncludeSchema() {
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.TablesWithIncludedSchema, f.mem)}, nil
+ }
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.Tables, f.mem)}, nil
+ case *pb.CommandGetTableTypes:
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.TableTypes, f.mem)}, nil
+ case *pb.CommandGetXdbcTypeInfo:
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.XdbcTypeInfo, f.mem)}, nil
+ case *pb.CommandGetSqlInfo:
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.SqlInfo, f.mem)}, nil
+ case *pb.CommandGetPrimaryKeys:
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.PrimaryKeys, f.mem)}, nil
+ case *pb.CommandGetExportedKeys:
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.ExportedKeys, f.mem)}, nil
+ case *pb.CommandGetImportedKeys:
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.ImportedKeys, f.mem)}, nil
+ case *pb.CommandGetCrossReference:
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.CrossReference, f.mem)}, nil
+ }
+
+ return nil, status.Errorf(codes.InvalidArgument, "requested command is invalid: %s", anycmd.GetTypeUrl())
+}
+
func (f *flightSqlServer) DoGet(request *flight.Ticket, stream flight.FlightService_DoGetServer) (err error) {
var (
anycmd anypb.Any
diff --git a/go/arrow/internal/flight_integration/scenario.go b/go/arrow/internal/flight_integration/scenario.go
index c89334002d..4e96d7100a 100644
--- a/go/arrow/internal/flight_integration/scenario.go
+++ b/go/arrow/internal/flight_integration/scenario.go
@@ -599,6 +599,22 @@ func (m *flightSqlScenarioTester) validate(expected *arrow.Schema, result *fligh
if !expected.Equal(rdr.Schema()) {
return fmt.Errorf("expected: %s, got: %s", expected, rdr.Schema())
}
+ for {
+ _, err := rdr.Read()
+ if err == io.EOF { break }
+ if err != nil { return err }
+ }
+ return nil
+}
+
+func (m *flightSqlScenarioTester) validateSchema(expected *arrow.Schema, result *flight.SchemaResult) error {
+ schema, err := flight.DeserializeSchema(result.GetSchema(), memory.DefaultAllocator)
+ if err != nil {
+ return err
+ }
+ if !expected.Equal(schema) {
+ return fmt.Errorf("expected: %s, got: %s", expected, schema)
+ }
return nil
}
@@ -626,6 +642,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl
return err
}
+ schema, err := client.GetCatalogsSchema(ctx)
+ if err != nil {
+ return err
+ }
+ if err := m.validateSchema(schema_ref.Catalogs, schema); err != nil {
+ return err
+ }
+
info, err = client.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{Catalog: &catalog, DbSchemaFilterPattern: &dbSchemaFilterPattern})
if err != nil {
return err
@@ -634,6 +658,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl
return err
}
+ schema, err = client.GetDBSchemasSchema(ctx)
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(schema_ref.DBSchemas, schema); err != nil {
+ return err
+ }
+
info, err = client.GetTables(ctx, &flightsql.GetTablesOpts{Catalog: &catalog, DbSchemaFilterPattern: &dbSchemaFilterPattern, TableNameFilterPattern: &tableFilterPattern, IncludeSchema: true, TableTypes: tableTypes})
if err != nil {
return err
@@ -642,6 +674,22 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl
return err
}
+ schema, err = client.GetTablesSchema(ctx, &flightsql.GetTablesOpts{IncludeSchema: true})
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(schema_ref.TablesWithIncludedSchema, schema); err != nil {
+ return err
+ }
+
+ schema, err = client.GetTablesSchema(ctx, &flightsql.GetTablesOpts{IncludeSchema: false})
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(schema_ref.Tables, schema); err != nil {
+ return err
+ }
+
info, err = client.GetTableTypes(ctx)
if err != nil {
return err
@@ -650,6 +698,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl
return err
}
+ schema, err = client.GetTableTypesSchema(ctx)
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(schema_ref.TableTypes, schema); err != nil {
+ return err
+ }
+
info, err = client.GetPrimaryKeys(ctx, ref)
if err != nil {
return err
@@ -658,6 +714,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl
return err
}
+ schema, err = client.GetPrimaryKeysSchema(ctx)
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(schema_ref.PrimaryKeys, schema); err != nil {
+ return err
+ }
+
info, err = client.GetExportedKeys(ctx, ref)
if err != nil {
return err
@@ -666,6 +730,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl
return err
}
+ schema, err = client.GetExportedKeysSchema(ctx)
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(schema_ref.ExportedKeys, schema); err != nil {
+ return err
+ }
+
info, err = client.GetImportedKeys(ctx, ref)
if err != nil {
return err
@@ -674,6 +746,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl
return err
}
+ schema, err = client.GetImportedKeysSchema(ctx)
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(schema_ref.ImportedKeys, schema); err != nil {
+ return err
+ }
+
info, err = client.GetCrossReference(ctx, pkRef, fkRef)
if err != nil {
return err
@@ -682,6 +762,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl
return err
}
+ schema, err = client.GetCrossReferenceSchema(ctx)
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(schema_ref.CrossReference, schema); err != nil {
+ return err
+ }
+
info, err = client.GetXdbcTypeInfo(ctx, nil)
if err != nil {
return err
@@ -690,6 +778,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl
return err
}
+ schema, err = client.GetXdbcTypeInfoSchema(ctx)
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(schema_ref.XdbcTypeInfo, schema); err != nil {
+ return err
+ }
+
info, err = client.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerName, flightsql.SqlInfoFlightSqlServerReadOnly})
if err != nil {
return err
@@ -698,6 +794,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl
return err
}
+ schema, err = client.GetSqlInfoSchema(ctx)
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(schema_ref.SqlInfo, schema); err != nil {
+ return err
+ }
+
return nil
}
@@ -711,6 +815,14 @@ func (m *flightSqlScenarioTester) ValidateStatementExecution(client *flightsql.C
return err
}
+ schema, err := client.GetExecuteSchema(ctx, "SELECT STATEMENT")
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(QuerySchema, schema); err != nil {
+ return err
+ }
+
updateResult, err := client.ExecuteUpdate(ctx, "UPDATE STATEMENT")
if err != nil {
return err
@@ -740,6 +852,13 @@ func (m *flightSqlScenarioTester) ValidatePreparedStatementExecution(client *fli
if err = m.validate(QuerySchema, info, client); err != nil {
return err
}
+ schema, err := prepared.GetSchema(ctx)
+ if err != nil {
+ return err
+ }
+ if err = m.validateSchema(QuerySchema, schema); err != nil {
+ return err
+ }
if err = prepared.Close(ctx); err != nil {
return err
@@ -762,9 +881,7 @@ func (m *flightSqlScenarioTester) ValidatePreparedStatementExecution(client *fli
func (m *flightSqlScenarioTester) doGetForTestCase(schema *arrow.Schema) chan flight.StreamChunk {
ch := make(chan flight.StreamChunk)
- go func() {
- ch <- flight.StreamChunk{Data: array.NewRecord(schema, []arrow.Array{}, 0)}
- }()
+ close(ch)
return ch
}
@@ -789,6 +906,10 @@ func (m *flightSqlScenarioTester) GetFlightInfoStatement(ctx context.Context, cm
}, nil
}
+func (m *flightSqlScenarioTester) GetSchemaStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) {
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(QuerySchema, memory.DefaultAllocator)}, nil
+}
+
func (m *flightSqlScenarioTester) DoGetStatement(ctx context.Context, cmd flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
return QuerySchema, m.doGetForTestCase(QuerySchema), nil
}
@@ -801,6 +922,10 @@ func (m *flightSqlScenarioTester) GetFlightInfoPreparedStatement(_ context.Conte
return m.flightInfoForCommand(desc, QuerySchema), nil
}
+func (m *flightSqlScenarioTester) GetSchemaPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) {
+ return &flight.SchemaResult{Schema: flight.SerializeSchema(QuerySchema, memory.DefaultAllocator)}, nil
+}
+
func (m *flightSqlScenarioTester) DoGetPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) {
return QuerySchema, m.doGetForTestCase(QuerySchema), nil
}
diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java
index cf17349064..19c1378cfe 100644
--- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java
+++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java
@@ -26,6 +26,7 @@ import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.SchemaResult;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.FlightSqlProducer;
@@ -72,32 +73,52 @@ public class FlightSqlScenario implements Scenario {
validate(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA, sqlClient.getCatalogs(options),
sqlClient);
+ validateSchema(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA, sqlClient.getCatalogsSchema(options));
+
validate(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA,
sqlClient.getSchemas("catalog", "db_schema_filter_pattern", options),
sqlClient);
+ validateSchema(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA, sqlClient.getSchemasSchema(options));
+
validate(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA,
sqlClient.getTables("catalog", "db_schema_filter_pattern", "table_filter_pattern",
Arrays.asList("table", "view"), true, options), sqlClient);
- validate(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA, sqlClient.getTableTypes(options),
- sqlClient);
+ validateSchema(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA,
+ sqlClient.getTablesSchema(/*includeSchema*/true, options));
+ validateSchema(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA,
+ sqlClient.getTablesSchema(/*includeSchema*/false, options));
+
+ validate(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA, sqlClient.getTableTypes(options), sqlClient);
+ validateSchema(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA, sqlClient.getTableTypesSchema(options));
+
validate(FlightSqlProducer.Schemas.GET_PRIMARY_KEYS_SCHEMA,
sqlClient.getPrimaryKeys(TableRef.of("catalog", "db_schema", "table"), options),
sqlClient);
+ validateSchema(FlightSqlProducer.Schemas.GET_PRIMARY_KEYS_SCHEMA, sqlClient.getPrimaryKeysSchema(options));
+
validate(FlightSqlProducer.Schemas.GET_EXPORTED_KEYS_SCHEMA,
sqlClient.getExportedKeys(TableRef.of("catalog", "db_schema", "table"), options),
sqlClient);
+ validateSchema(FlightSqlProducer.Schemas.GET_EXPORTED_KEYS_SCHEMA, sqlClient.getExportedKeysSchema(options));
+
validate(FlightSqlProducer.Schemas.GET_IMPORTED_KEYS_SCHEMA,
sqlClient.getImportedKeys(TableRef.of("catalog", "db_schema", "table"), options),
sqlClient);
+ validateSchema(FlightSqlProducer.Schemas.GET_IMPORTED_KEYS_SCHEMA, sqlClient.getImportedKeysSchema(options));
+
validate(FlightSqlProducer.Schemas.GET_CROSS_REFERENCE_SCHEMA,
sqlClient.getCrossReference(TableRef.of("pk_catalog", "pk_db_schema", "pk_table"),
TableRef.of("fk_catalog", "fk_db_schema", "fk_table"), options),
sqlClient);
- validate(FlightSqlProducer.Schemas.GET_TYPE_INFO_SCHEMA,
- sqlClient.getXdbcTypeInfo(options), sqlClient);
+ validateSchema(FlightSqlProducer.Schemas.GET_CROSS_REFERENCE_SCHEMA, sqlClient.getCrossReferenceSchema(options));
+
+ validate(FlightSqlProducer.Schemas.GET_TYPE_INFO_SCHEMA, sqlClient.getXdbcTypeInfo(options), sqlClient);
+ validateSchema(FlightSqlProducer.Schemas.GET_TYPE_INFO_SCHEMA, sqlClient.getXdbcTypeInfoSchema(options));
+
validate(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA,
sqlClient.getSqlInfo(new FlightSql.SqlInfo[] {FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME,
FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY}, options), sqlClient);
+ validateSchema(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, sqlClient.getSqlInfoSchema(options));
}
private void validateStatementExecution(FlightSqlClient sqlClient) throws Exception {
@@ -105,6 +126,8 @@ public class FlightSqlScenario implements Scenario {
validate(FlightSqlScenarioProducer.getQuerySchema(),
sqlClient.execute("SELECT STATEMENT", options), sqlClient);
+ validateSchema(FlightSqlScenarioProducer.getQuerySchema(),
+ sqlClient.getExecuteSchema("SELECT STATEMENT", options));
IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", options),
UPDATE_STATEMENT_EXPECTED_ROWS);
@@ -122,6 +145,7 @@ public class FlightSqlScenario implements Scenario {
validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(options),
sqlClient);
+ validateSchema(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.fetchSchema());
}
try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(
@@ -139,4 +163,8 @@ public class FlightSqlScenario implements Scenario {
IntegrationAssertions.assertEquals(expectedSchema, actualSchema);
}
}
+
+ private void validateSchema(Schema expected, SchemaResult actual) {
+ IntegrationAssertions.assertEquals(expected, actual.getSchema());
+ }
}
diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java
index 7db99187c4..33d62b650e 100644
--- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java
+++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java
@@ -125,9 +125,18 @@ public class FlightSqlScenarioProducer implements FlightSqlProducer {
return getFlightInfoForSchema(command, descriptor, getQuerySchema());
}
+ @Override
+ public SchemaResult getSchemaPreparedStatement(FlightSql.CommandPreparedStatementQuery command, CallContext context,
+ FlightDescriptor descriptor) {
+ IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(),
+ "SELECT PREPARED STATEMENT HANDLE");
+ return new SchemaResult(getQuerySchema());
+ }
+
@Override
public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command,
CallContext context, FlightDescriptor descriptor) {
+ IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT");
return new SchemaResult(getQuerySchema());
}
diff --git a/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java
new file mode 100644
index 0000000000..dfb9a81085
--- /dev/null
+++ b/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.integration.tests;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Run the integration test scenarios in-process.
+ */
+class IntegrationTest {
+ @Test
+ void authBasicProto() throws Exception {
+ testScenario("auth:basic_proto");
+ }
+
+ @Test
+ void middleware() throws Exception {
+ testScenario("middleware");
+ }
+
+ @Test
+ void flightSql() throws Exception {
+ testScenario("flight_sql");
+ }
+
+ void testScenario(String scenarioName) throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator()) {
+ final FlightServer.Builder builder = FlightServer.builder()
+ .allocator(allocator)
+ .location(Location.forGrpcInsecure("0.0.0.0", 0));
+ final Scenario scenario = Scenarios.getScenario(scenarioName);
+ scenario.buildServer(builder);
+ builder.producer(scenario.producer(allocator, Location.forGrpcInsecure("0.0.0.0", 0)));
+
+ try (final FlightServer server = builder.build()) {
+ server.start();
+
+ final Location location = Location.forGrpcInsecure("localhost", server.getPort());
+ try (final FlightClient client = FlightClient.builder(allocator, location).build()) {
+ scenario.client(allocator, location, client);
+ }
+ }
+ }
+ }
+}
diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java
index dd9480f400..f1f07a1588 100644
--- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java
+++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java
@@ -97,6 +97,16 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of the result set of a query.
+ */
+ public SchemaResult getExecuteSchema(final String query, final CallOption... options) {
+ final CommandStatementQuery.Builder builder = CommandStatementQuery.newBuilder();
+ builder.setQuery(query);
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Execute an update query on the server.
*
@@ -137,6 +147,17 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of {@link #getCatalogs(CallOption...)} from the server.
+ *
+ * <p>Should be identical to {@link FlightSqlProducer.Schemas#GET_CATALOGS_SCHEMA}.
+ */
+ public SchemaResult getCatalogsSchema(final CallOption... options) {
+ final CommandGetCatalogs command = CommandGetCatalogs.getDefaultInstance();
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Request a list of schemas.
*
@@ -160,6 +181,17 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of {@link #getSchemas(String, String, CallOption...)} from the server.
+ *
+ * <p>Should be identical to {@link FlightSqlProducer.Schemas#GET_SCHEMAS_SCHEMA}.
+ */
+ public SchemaResult getSchemasSchema(final CallOption... options) {
+ final CommandGetDbSchemas command = CommandGetDbSchemas.getDefaultInstance();
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Get schema for a stream.
*
@@ -231,6 +263,17 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of {@link #getSqlInfo(SqlInfo...)} from the server.
+ *
+ * <p>Should be identical to {@link FlightSqlProducer.Schemas#GET_SQL_INFO_SCHEMA}.
+ */
+ public SchemaResult getSqlInfoSchema(final CallOption... options) {
+ final CommandGetSqlInfo command = CommandGetSqlInfo.getDefaultInstance();
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Request the information about the data types supported related to
* a filter data type.
@@ -261,6 +304,17 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of {@link #getXdbcTypeInfo(CallOption...)} from the server.
+ *
+ * <p>Should be identical to {@link FlightSqlProducer.Schemas#GET_TYPE_INFO_SCHEMA}.
+ */
+ public SchemaResult getXdbcTypeInfoSchema(final CallOption... options) {
+ final CommandGetXdbcTypeInfo command = CommandGetXdbcTypeInfo.getDefaultInstance();
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Request a list of tables.
*
@@ -298,6 +352,18 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of {@link #getTables(String, String, String, List, boolean, CallOption...)} from the server.
+ *
+ * <p>Should be identical to {@link FlightSqlProducer.Schemas#GET_TABLES_SCHEMA} or
+ * {@link FlightSqlProducer.Schemas#GET_TABLES_SCHEMA_NO_SCHEMA}.
+ */
+ public SchemaResult getTablesSchema(boolean includeSchema, final CallOption... options) {
+ final CommandGetTables command = CommandGetTables.newBuilder().setIncludeSchema(includeSchema).build();
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Request the primary keys for a table.
*
@@ -323,6 +389,17 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of {@link #getPrimaryKeys(TableRef, CallOption...)} from the server.
+ *
+ * <p>Should be identical to {@link FlightSqlProducer.Schemas#GET_PRIMARY_KEYS_SCHEMA}.
+ */
+ public SchemaResult getPrimaryKeysSchema(final CallOption... options) {
+ final CommandGetPrimaryKeys command = CommandGetPrimaryKeys.getDefaultInstance();
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Retrieves a description about the foreign key columns that reference the primary key columns of the given table.
*
@@ -350,6 +427,17 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of {@link #getExportedKeys(TableRef, CallOption...)} from the server.
+ *
+ * <p>Should be identical to {@link FlightSqlProducer.Schemas#GET_EXPORTED_KEYS_SCHEMA}.
+ */
+ public SchemaResult getExportedKeysSchema(final CallOption... options) {
+ final CommandGetExportedKeys command = CommandGetExportedKeys.getDefaultInstance();
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Retrieves the foreign key columns for the given table.
*
@@ -378,6 +466,17 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of {@link #getImportedKeys(TableRef, CallOption...)} from the server.
+ *
+ * <p>Should be identical to {@link FlightSqlProducer.Schemas#GET_IMPORTED_KEYS_SCHEMA}.
+ */
+ public SchemaResult getImportedKeysSchema(final CallOption... options) {
+ final CommandGetImportedKeys command = CommandGetImportedKeys.getDefaultInstance();
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Retrieves a description of the foreign key columns that reference the given table's
* primary key columns (the foreign keys exported by a table).
@@ -417,6 +516,17 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of {@link #getCrossReference(TableRef, TableRef, CallOption...)} from the server.
+ *
+ * <p>Should be identical to {@link FlightSqlProducer.Schemas#GET_CROSS_REFERENCE_SCHEMA}.
+ */
+ public SchemaResult getCrossReferenceSchema(final CallOption... options) {
+ final CommandGetCrossReference command = CommandGetCrossReference.getDefaultInstance();
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Request a list of table types.
*
@@ -429,6 +539,17 @@ public class FlightSqlClient implements AutoCloseable {
return client.getInfo(descriptor, options);
}
+ /**
+ * Get the schema of {@link #getTableTypes(CallOption...)} from the server.
+ *
+ * <p>Should be identical to {@link FlightSqlProducer.Schemas#GET_TABLE_TYPES_SCHEMA}.
+ */
+ public SchemaResult getTableTypesSchema(final CallOption... options) {
+ final CommandGetTableTypes command = CommandGetTableTypes.getDefaultInstance();
+ final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
/**
* Create a prepared statement on the server.
*
@@ -534,6 +655,20 @@ public class FlightSqlClient implements AutoCloseable {
return parameterSchema;
}
+ /**
+ * Get the schema of the result set (should be identical to {@link #getResultSetSchema()}).
+ */
+ public SchemaResult fetchSchema(CallOption... options) {
+ checkOpen();
+
+ final FlightDescriptor descriptor = FlightDescriptor
+ .command(Any.pack(CommandPreparedStatementQuery.newBuilder()
+ .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle())
+ .build())
+ .toByteArray());
+ return client.getSchema(descriptor, options);
+ }
+
private Schema deserializeSchema(final ByteString bytes) {
try {
return bytes.isEmpty() ?
diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java
index c617c6a03e..4226ec9e22 100644
--- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java
+++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java
@@ -147,26 +147,32 @@ public interface FlightSqlProducer extends FlightProducer, AutoCloseable {
if (command.is(CommandStatementQuery.class)) {
return getSchemaStatement(
FlightSqlUtils.unpackOrThrow(command, CommandStatementQuery.class), context, descriptor);
+ } else if (command.is(CommandPreparedStatementQuery.class)) {
+ return getSchemaPreparedStatement(
+ FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementQuery.class), context, descriptor);
} else if (command.is(CommandGetCatalogs.class)) {
return new SchemaResult(Schemas.GET_CATALOGS_SCHEMA);
+ } else if (command.is(CommandGetCrossReference.class)) {
+ return new SchemaResult(Schemas.GET_CROSS_REFERENCE_SCHEMA);
} else if (command.is(CommandGetDbSchemas.class)) {
return new SchemaResult(Schemas.GET_SCHEMAS_SCHEMA);
+ } else if (command.is(CommandGetExportedKeys.class)) {
+ return new SchemaResult(Schemas.GET_EXPORTED_KEYS_SCHEMA);
+ } else if (command.is(CommandGetImportedKeys.class)) {
+ return new SchemaResult(Schemas.GET_IMPORTED_KEYS_SCHEMA);
+ } else if (command.is(CommandGetPrimaryKeys.class)) {
+ return new SchemaResult(Schemas.GET_PRIMARY_KEYS_SCHEMA);
} else if (command.is(CommandGetTables.class)) {
- return new SchemaResult(Schemas.GET_TABLES_SCHEMA);
+ if (FlightSqlUtils.unpackOrThrow(command, CommandGetTables.class).getIncludeSchema()) {
+ return new SchemaResult(Schemas.GET_TABLES_SCHEMA);
+ }
+ return new SchemaResult(Schemas.GET_TABLES_SCHEMA_NO_SCHEMA);
} else if (command.is(CommandGetTableTypes.class)) {
return new SchemaResult(Schemas.GET_TABLE_TYPES_SCHEMA);
} else if (command.is(CommandGetSqlInfo.class)) {
return new SchemaResult(Schemas.GET_SQL_INFO_SCHEMA);
} else if (command.is(CommandGetXdbcTypeInfo.class)) {
return new SchemaResult(Schemas.GET_TYPE_INFO_SCHEMA);
- } else if (command.is(CommandGetPrimaryKeys.class)) {
- return new SchemaResult(Schemas.GET_PRIMARY_KEYS_SCHEMA);
- } else if (command.is(CommandGetImportedKeys.class)) {
- return new SchemaResult(Schemas.GET_IMPORTED_KEYS_SCHEMA);
- } else if (command.is(CommandGetExportedKeys.class)) {
- return new SchemaResult(Schemas.GET_EXPORTED_KEYS_SCHEMA);
- } else if (command.is(CommandGetCrossReference.class)) {
- return new SchemaResult(Schemas.GET_CROSS_REFERENCE_SCHEMA);
}
throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid command provided.").toRuntimeException();
@@ -336,16 +342,31 @@ public interface FlightSqlProducer extends FlightProducer, AutoCloseable {
CallContext context, FlightDescriptor descriptor);
/**
- * Gets schema about a particular SQL query based data stream.
+ * Get the schema of the result set of a query.
*
- * @param command The sql command to generate the data stream.
+ * @param command The SQL query.
* @param context Per-call context.
* @param descriptor The descriptor identifying the data stream.
- * @return Schema for the stream.
+ * @return the schema of the result set.
*/
SchemaResult getSchemaStatement(CommandStatementQuery command, CallContext context,
FlightDescriptor descriptor);
+ /**
+ * Get the schema of the result set of a prepared statement.
+ *
+ * @param command The prepared statement handle.
+ * @param context Per-call context.
+ * @param descriptor The descriptor identifying the data stream.
+ * @return the schema of the result set.
+ */
+ default SchemaResult getSchemaPreparedStatement(CommandPreparedStatementQuery command, CallContext context,
+ FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED
+ .withDescription("GetSchema with CommandPreparedStatementQuery is not implemented")
+ .toRuntimeException();
+ }
+
/**
* Returns data for a SQL query based data stream.
* @param ticket Ticket message containing the statement handle.
diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java
index 25affa8f08..e461515c40 100644
--- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java
+++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java
@@ -76,7 +76,7 @@ public final class FlightSqlUtils {
return source.unpack(as);
} catch (final InvalidProtocolBufferException e) {
throw CallStatus.INVALID_ARGUMENT
- .withDescription("Provided message cannot be unpacked as desired type.")
+ .withDescription("Provided message cannot be unpacked as " + as.getName() + ": " + e)
.withCause(e)
.toRuntimeException();
}