You are viewing a plain-text version of this content. The canonical link to it (originally the hyperlink "here") was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by ko...@apache.org on 2023/06/30 08:13:56 UTC
[arrow-flight-sql-postgresql] branch main updated: Add support for INSERT/UPDATE/DELETE (#42)
This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git
The following commit(s) were added to refs/heads/main by this push:
new fbddcf3 Add support for INSERT/UPDATE/DELETE (#42)
fbddcf3 is described below
commit fbddcf38596e3f73710ce84c26a9e5d606e8c1de
Author: Sutou Kouhei <ko...@clear-code.com>
AuthorDate: Fri Jun 30 17:13:50 2023 +0900
Add support for INSERT/UPDATE/DELETE (#42)
Closes GH-19
---
src/afs.cc | 149 ++++++++++++++++++++++++++++++++++++++----------
test/test-flight-sql.rb | 23 ++++++++
2 files changed, 143 insertions(+), 29 deletions(-)
diff --git a/src/afs.cc b/src/afs.cc
index 3cd0009..8f0e778 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -335,7 +335,9 @@ struct SessionData {
dsa_pointer userName;
dsa_pointer password;
dsa_pointer clientAddress;
- dsa_pointer query;
+ dsa_pointer selectQuery;
+ dsa_pointer updateQuery;
+ int64_t nUpdatedRecords;
SharedRingBufferData bufferData;
};
@@ -529,8 +531,10 @@ class WorkerProcessor : public Processor {
dsa_free(area_, session->userName);
if (DsaPointerIsValid(session->password))
dsa_free(area_, session->password);
- if (DsaPointerIsValid(session->query))
- dsa_free(area_, session->query);
+ if (DsaPointerIsValid(session->selectQuery))
+ dsa_free(area_, session->selectQuery);
+ if (DsaPointerIsValid(session->updateQuery))
+ dsa_free(area_, session->updateQuery);
SharedRingBuffer::free_data(&(session->bufferData), area_);
dshash_delete_entry(sessions_, session);
}
@@ -645,17 +649,20 @@ class Executor : public WorkerProcessor {
void signaled()
{
- P("%s: %s: signaled: before: %d", Tag, tag_, session_->query);
- P("signaled: before: %d", session_->query);
- if (DsaPointerIsValid(session_->query))
+ P("%s: %s: signaled: before: %d/%d", Tag, tag_, session_->selectQuery, session_->updateQuery);
+ if (DsaPointerIsValid(session_->selectQuery))
{
- execute();
+ select();
+ }
+ else if (DsaPointerIsValid(session_->updateQuery))
+ {
+ update();
}
else
{
Processor::signaled();
}
- P("%s: %s: signaled: after: %d", Tag, tag_, session_->query);
+ P("%s: %s: signaled: after: %d/%d", Tag, tag_, session_->selectQuery, session_->updateQuery);
}
private:
@@ -844,26 +851,26 @@ class Executor : public WorkerProcessor {
return true;
}
- void execute()
+ void select()
{
- pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": executing").c_str());
+ pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": selecting").c_str());
PushActiveSnapshot(GetTransactionSnapshot());
LWLockAcquire(lock_, LW_EXCLUSIVE);
std::string query(
- static_cast<const char*>(dsa_get_address(area_, session_->query)));
- dsa_free(area_, session_->query);
- session_->query = InvalidDsaPointer;
+ static_cast<const char*>(dsa_get_address(area_, session_->selectQuery)));
+ dsa_free(area_, session_->selectQuery);
+ session_->selectQuery = InvalidDsaPointer;
SetCurrentStatementStartTimestamp();
- P("%s: %s: execute: %s", Tag, tag_, query.c_str());
+ P("%s: %s: select: %s", Tag, tag_, query.c_str());
auto result = SPI_execute(query.c_str(), true, 0);
LWLockRelease(lock_);
if (result == SPI_OK_SELECT)
{
pgstat_report_activity(STATE_RUNNING,
- (std::string(Tag) + ": writing").c_str());
+ (std::string(Tag) + ": select: writing").c_str());
auto status = write();
if (!status.ok())
{
@@ -873,7 +880,7 @@ class Executor : public WorkerProcessor {
else
{
set_shared_string(session_->errorMessage,
- std::string(Tag) + ": " + tag_ +
+ std::string(Tag) + ": " + tag_ + ": select" +
": failed to run a query: <" + query +
">: " + SPI_result_code_string(result));
}
@@ -882,7 +889,53 @@ class Executor : public WorkerProcessor {
if (sharedData_->serverPID != InvalidPid)
{
- P("%s: %s: kill server: %d", Tag, tag_, sharedData_->serverPID);
+ P("%s: %s: select: kill server: %d", Tag, tag_, sharedData_->serverPID);
+ kill(sharedData_->serverPID, SIGUSR1);
+ }
+
+ pgstat_report_activity(STATE_IDLE, NULL);
+ }
+
+ void update()
+ {
+ pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": updating").c_str());
+
+ PushActiveSnapshot(GetTransactionSnapshot());
+
+ LWLockAcquire(lock_, LW_EXCLUSIVE);
+ std::string query(
+ static_cast<const char*>(dsa_get_address(area_, session_->updateQuery)));
+ dsa_free(area_, session_->updateQuery);
+ session_->updateQuery = InvalidDsaPointer;
+ SetCurrentStatementStartTimestamp();
+ P("%s: %s: update: %s", Tag, tag_, query.c_str());
+ auto result = SPI_execute(query.c_str(), false, 0);
+ LWLockRelease(lock_);
+
+ switch (result)
+ {
+ case SPI_OK_INSERT:
+ case SPI_OK_DELETE:
+ case SPI_OK_UPDATE:
+ session_->nUpdatedRecords = SPI_processed;
+ break;
+ default:
+ set_shared_string(session_->errorMessage,
+ std::string(Tag) + ": " + tag_ + ": update" +
+ ": failed to run a query: <" + query +
+ ">: " + SPI_result_code_string(result));
+ break;
+ }
+
+ PopActiveSnapshot();
+
+ // TODO: Is this usage correct?
+ CommitTransactionCommand();
+ StartTransactionCommand();
+
+ if (sharedData_->serverPID != InvalidPid)
+ {
+ P("%s: %s: update: kill server: %d", Tag, tag_, sharedData_->serverPID);
kill(sharedData_->serverPID, SIGUSR1);
}
@@ -1128,22 +1181,22 @@ class Proxy : public WorkerProcessor {
}
}
- arrow::Result<std::shared_ptr<arrow::Schema>> execute(uint64_t sessionID,
- const std::string& query)
+ arrow::Result<std::shared_ptr<arrow::Schema>> select(uint64_t sessionID,
+ const std::string& query)
{
auto session = find_session(sessionID);
SessionReleaser sessionReleaser(sessions_, session);
- set_shared_string(session->query, query);
+ set_shared_string(session->selectQuery, query);
if (session->executorPID != InvalidPid)
{
- P("%s: %s: execute: kill executor: %d", Tag, tag_, session->executorPID);
+ P("%s: %s: select: kill executor: %d", Tag, tag_, session->executorPID);
kill(session->executorPID, SIGUSR1);
}
{
auto buffer = std::move(create_shared_ring_buffer(session));
std::unique_lock<std::mutex> lock(mutex_);
conditionVariable_.wait(lock, [&] {
- P("%s: %s: %s: wait: execute", Tag, tag_, AFS_FUNC);
+ P("%s: %s: %s: wait: select", Tag, tag_, AFS_FUNC);
return DsaPointerIsValid(session->errorMessage) || buffer.size() > 0;
});
}
@@ -1151,7 +1204,7 @@ class Proxy : public WorkerProcessor {
{
return report_session_error(session);
}
- P("%s: %s: execute: open", Tag, tag_);
+ P("%s: %s: select: open", Tag, tag_);
auto input = std::make_shared<SharedRingBufferInputStream>(this, session);
// Read schema only stream format data.
ARROW_ASSIGN_OR_RAISE(auto reader,
@@ -1159,17 +1212,44 @@ class Proxy : public WorkerProcessor {
while (true)
{
std::shared_ptr<arrow::RecordBatch> recordBatch;
- P("%s: %s: execute: read next", Tag, tag_);
+ P("%s: %s: select: read next", Tag, tag_);
ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
if (!recordBatch)
{
break;
}
}
- P("%s: %s: execute: schema", Tag, tag_);
+ P("%s: %s: select: schema", Tag, tag_);
return reader->schema();
}
+ arrow::Result<int64_t> update(uint64_t sessionID, const std::string& query)
+ {
+ auto session = find_session(sessionID);
+ SessionReleaser sessionReleaser(sessions_, session);
+ set_shared_string(session->updateQuery, query);
+ session->nUpdatedRecords = -1;
+ if (session->executorPID != InvalidPid)
+ {
+ P("%s: %s: update: kill executor: %d",
+ Tag, tag_, session->executorPID);
+ kill(session->executorPID, SIGUSR1);
+ }
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ conditionVariable_.wait(lock, [&] {
+ P("%s: %s: %s: wait: update", Tag, tag_, AFS_FUNC);
+ return DsaPointerIsValid(session->errorMessage) || session->nUpdatedRecords >= 0;
+ });
+ }
+ if (DsaPointerIsValid(session->errorMessage))
+ {
+ return report_session_error(session);
+ }
+ P("%s: %s: update: done: %ld", Tag, tag_, session->nUpdatedRecords);
+ return session->nUpdatedRecords;
+ }
+
arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> read(uint64_t sessionID)
{
auto session = find_session(sessionID);
@@ -1210,7 +1290,9 @@ class Proxy : public WorkerProcessor {
set_shared_string(session->userName, userName);
set_shared_string(session->password, password);
set_shared_string(session->clientAddress, clientAddress);
- session->query = InvalidDsaPointer;
+ session->selectQuery = InvalidDsaPointer;
+ session->updateQuery = InvalidDsaPointer;
+ session->nUpdatedRecords = -1;
SharedRingBuffer::initialize_data(&(session->bufferData));
LWLockRelease(lock_);
return session;
@@ -1507,11 +1589,11 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQuery& command,
- const arrow::flight::FlightDescriptor& descriptor)
+ const arrow::flight::FlightDescriptor& descriptor) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& query = command.query;
- ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->execute(sessionID, query));
+ ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->select(sessionID, query));
ARROW_ASSIGN_OR_RAISE(auto ticket,
arrow::flight::sql::CreateStatementQueryTicket(query));
std::vector<arrow::flight::FlightEndpoint> endpoints{
@@ -1524,13 +1606,22 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> DoGetStatement(
const arrow::flight::ServerCallContext& context,
- const arrow::flight::sql::StatementQueryTicket& command)
+ const arrow::flight::sql::StatementQueryTicket& command) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->read(sessionID));
return std::make_unique<arrow::flight::RecordBatchStream>(reader);
}
+ arrow::Result<int64_t> DoPutCommandStatementUpdate(
+ const arrow::flight::ServerCallContext& context,
+ const arrow::flight::sql::StatementUpdate& command) override
+ {
+ // Resolve the caller's session, then run the statement via the proxy; the result is the affected-row count.
+ ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
+ return proxy_->update(sessionID, command.query);
+ }
+
private:
arrow::Result<uint64_t> session_id(const arrow::flight::ServerCallContext& context)
{
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index 3e0992d..6997460 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -51,4 +51,27 @@ class FlightSQLTest < Test::Unit::TestCase
assert_equal(Arrow::Table.new(value: Arrow::Int32Array.new([1, -2, 3])),
reader.read_all)
end
+
+ def test_isnert_int32
+ unless filght_sql_client.respond_to?(:execute_update)
+ omit("red-arrow-flight-sql 13.0.0 or later is required")
+ end
+
+ run_sql("CREATE TABLE data (value integer)")
+
+ n_changed_records = flight_sql_client.execute_update(
+ "INSERT INTO data VALUES (1), (-2), (3)",
+ @options)
+ assert_equal(3, n_changed_records)
+ assert_equal([<<-RESULT, ""], run_sql("SELECT * FROM data"))
+SELECT * FROM data
+ value
+-------
+ 1
+ -2
+ 3
+(3 rows)
+
+ RESULT
+ end
end