You are viewing a plain-text version of this content. Its canonical link was a hyperlink ("here") that is not preserved in this text.
Posted to commits@arrow.apache.org by ko...@apache.org on 2023/06/30 08:13:56 UTC

[arrow-flight-sql-postgresql] branch main updated: Add support for INSERT/UPDATE/DELETE (#42)

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git


The following commit(s) were added to refs/heads/main by this push:
     new fbddcf3  Add support for INSERT/UPDATE/DELETE (#42)
fbddcf3 is described below

commit fbddcf38596e3f73710ce84c26a9e5d606e8c1de
Author: Sutou Kouhei <ko...@clear-code.com>
AuthorDate: Fri Jun 30 17:13:50 2023 +0900

    Add support for INSERT/UPDATE/DELETE (#42)
    
    Closes GH-19
---
 src/afs.cc              | 149 ++++++++++++++++++++++++++++++++++++++----------
 test/test-flight-sql.rb |  23 ++++++++
 2 files changed, 143 insertions(+), 29 deletions(-)

diff --git a/src/afs.cc b/src/afs.cc
index 3cd0009..8f0e778 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -335,7 +335,9 @@ struct SessionData {
 	dsa_pointer userName;
 	dsa_pointer password;
 	dsa_pointer clientAddress;
-	dsa_pointer query;
+	dsa_pointer selectQuery;
+	dsa_pointer updateQuery;
+	int64_t nUpdatedRecords;
 	SharedRingBufferData bufferData;
 };
 
@@ -529,8 +531,10 @@ class WorkerProcessor : public Processor {
 			dsa_free(area_, session->userName);
 		if (DsaPointerIsValid(session->password))
 			dsa_free(area_, session->password);
-		if (DsaPointerIsValid(session->query))
-			dsa_free(area_, session->query);
+		if (DsaPointerIsValid(session->selectQuery))
+			dsa_free(area_, session->selectQuery);
+		if (DsaPointerIsValid(session->updateQuery))
+			dsa_free(area_, session->updateQuery);
 		SharedRingBuffer::free_data(&(session->bufferData), area_);
 		dshash_delete_entry(sessions_, session);
 	}
@@ -645,17 +649,20 @@ class Executor : public WorkerProcessor {
 
 	void signaled()
 	{
-		P("%s: %s: signaled: before: %d", Tag, tag_, session_->query);
-		P("signaled: before: %d", session_->query);
-		if (DsaPointerIsValid(session_->query))
+		P("%s: %s: signaled: before: %llu/%llu", Tag, tag_, static_cast<unsigned long long>(session_->selectQuery), static_cast<unsigned long long>(session_->updateQuery));
+		if (DsaPointerIsValid(session_->selectQuery))  // a pending SELECT is served before a pending update
 		{
-			execute();
+			select();
+		}
+		else if (DsaPointerIsValid(session_->updateQuery))
+		{
+			update();
 		}
 		else
 		{
 			Processor::signaled();
 		}
-		P("%s: %s: signaled: after: %d", Tag, tag_, session_->query);
+		P("%s: %s: signaled: after: %llu/%llu", Tag, tag_, static_cast<unsigned long long>(session_->selectQuery), static_cast<unsigned long long>(session_->updateQuery));
 	}
 
    private:
@@ -844,26 +851,26 @@ class Executor : public WorkerProcessor {
 		return true;
 	}
 
-	void execute()
+	void select()
 	{
-		pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": executing").c_str());
+		pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": selecting").c_str());
 
 		PushActiveSnapshot(GetTransactionSnapshot());
 
 		LWLockAcquire(lock_, LW_EXCLUSIVE);
 		std::string query(
-			static_cast<const char*>(dsa_get_address(area_, session_->query)));
-		dsa_free(area_, session_->query);
-		session_->query = InvalidDsaPointer;
+			static_cast<const char*>(dsa_get_address(area_, session_->selectQuery)));
+		dsa_free(area_, session_->selectQuery);
+		session_->selectQuery = InvalidDsaPointer;
 		SetCurrentStatementStartTimestamp();
-		P("%s: %s: execute: %s", Tag, tag_, query.c_str());
+		P("%s: %s: select: %s", Tag, tag_, query.c_str());
 		auto result = SPI_execute(query.c_str(), true, 0);
 		LWLockRelease(lock_);
 
 		if (result == SPI_OK_SELECT)
 		{
 			pgstat_report_activity(STATE_RUNNING,
-			                       (std::string(Tag) + ": writing").c_str());
+			                       (std::string(Tag) + ": select: writing").c_str());
 			auto status = write();
 			if (!status.ok())
 			{
@@ -873,7 +880,7 @@ class Executor : public WorkerProcessor {
 		else
 		{
 			set_shared_string(session_->errorMessage,
-			                  std::string(Tag) + ": " + tag_ +
+			                  std::string(Tag) + ": " + tag_ + ": select" +
 			                      ": failed to run a query: <" + query +
 			                      ">: " + SPI_result_code_string(result));
 		}
@@ -882,7 +889,53 @@ class Executor : public WorkerProcessor {
 
 		if (sharedData_->serverPID != InvalidPid)
 		{
-			P("%s: %s: kill server: %d", Tag, tag_, sharedData_->serverPID);
+			P("%s: %s: select: kill server: %d", Tag, tag_, sharedData_->serverPID);
+			kill(sharedData_->serverPID, SIGUSR1);
+		}
+
+		pgstat_report_activity(STATE_IDLE, NULL);
+	}
+
+	void update()
+	{
+		pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": updating").c_str());
+
+		PushActiveSnapshot(GetTransactionSnapshot());
+
+		LWLockAcquire(lock_, LW_EXCLUSIVE);
+		std::string query(
+			static_cast<const char*>(dsa_get_address(area_, session_->updateQuery)));
+		dsa_free(area_, session_->updateQuery);
+		session_->updateQuery = InvalidDsaPointer;
+		SetCurrentStatementStartTimestamp();
+		P("%s: %s: update: %s", Tag, tag_, query.c_str());
+		auto result = SPI_execute(query.c_str(), false, 0);
+		LWLockRelease(lock_);
+
+		switch (result)
+		{
+		case SPI_OK_INSERT:
+		case SPI_OK_DELETE:
+		case SPI_OK_UPDATE:
+			session_->nUpdatedRecords = SPI_processed;
+			break;
+		default:
+			set_shared_string(session_->errorMessage,
+			                  std::string(Tag) + ": " + tag_ + ": update" +
+			                      ": failed to run a query: <" + query +
+			                      ">: " + SPI_result_code_string(result));
+			break;
+		}
+
+		PopActiveSnapshot();
+
+		// TODO: Is this usage correct?
+		CommitTransactionCommand();
+		StartTransactionCommand();
+
+		if (sharedData_->serverPID != InvalidPid)
+		{
+			P("%s: %s: update: kill server: %d", Tag, tag_, sharedData_->serverPID);
 			kill(sharedData_->serverPID, SIGUSR1);
 		}
 
@@ -1128,22 +1181,22 @@ class Proxy : public WorkerProcessor {
 		}
 	}
 
-	arrow::Result<std::shared_ptr<arrow::Schema>> execute(uint64_t sessionID,
-	                                                      const std::string& query)
+	arrow::Result<std::shared_ptr<arrow::Schema>> select(uint64_t sessionID,
+														 const std::string& query)
 	{
 		auto session = find_session(sessionID);
 		SessionReleaser sessionReleaser(sessions_, session);
-		set_shared_string(session->query, query);
+		set_shared_string(session->selectQuery, query);
 		if (session->executorPID != InvalidPid)
 		{
-			P("%s: %s: execute: kill executor: %d", Tag, tag_, session->executorPID);
+			P("%s: %s: select: kill executor: %d", Tag, tag_, session->executorPID);
 			kill(session->executorPID, SIGUSR1);
 		}
 		{
 			auto buffer = std::move(create_shared_ring_buffer(session));
 			std::unique_lock<std::mutex> lock(mutex_);
 			conditionVariable_.wait(lock, [&] {
-				P("%s: %s: %s: wait: execute", Tag, tag_, AFS_FUNC);
+				P("%s: %s: %s: wait: select", Tag, tag_, AFS_FUNC);
 				return DsaPointerIsValid(session->errorMessage) || buffer.size() > 0;
 			});
 		}
@@ -1151,7 +1204,7 @@ class Proxy : public WorkerProcessor {
 		{
 			return report_session_error(session);
 		}
-		P("%s: %s: execute: open", Tag, tag_);
+		P("%s: %s: select: open", Tag, tag_);
 		auto input = std::make_shared<SharedRingBufferInputStream>(this, session);
 		// Read schema only stream format data.
 		ARROW_ASSIGN_OR_RAISE(auto reader,
@@ -1159,17 +1212,44 @@ class Proxy : public WorkerProcessor {
 		while (true)
 		{
 			std::shared_ptr<arrow::RecordBatch> recordBatch;
-			P("%s: %s: execute: read next", Tag, tag_);
+			P("%s: %s: select: read next", Tag, tag_);
 			ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
 			if (!recordBatch)
 			{
 				break;
 			}
 		}
-		P("%s: %s: execute: schema", Tag, tag_);
+		P("%s: %s: select: schema", Tag, tag_);
 		return reader->schema();
 	}
 
+	arrow::Result<int64_t> update(uint64_t sessionID, const std::string& query)  // runs a DML query via the executor; returns the affected row count
+	{
+		auto session = find_session(sessionID);
+		SessionReleaser sessionReleaser(sessions_, session);
+		session->nUpdatedRecords = -1;  // reset BEFORE publishing the query so an early executor result is never overwritten
+		set_shared_string(session->updateQuery, query);
+		if (session->executorPID != InvalidPid)
+		{
+			P("%s: %s: update: kill executor: %d",
+			  Tag, tag_, session->executorPID);
+			kill(session->executorPID, SIGUSR1);
+		}
+		{
+			std::unique_lock<std::mutex> lock(mutex_);
+			conditionVariable_.wait(lock, [&] {
+				P("%s: %s: %s: wait: update", Tag, tag_, AFS_FUNC);
+				return DsaPointerIsValid(session->errorMessage) || session->nUpdatedRecords >= 0;
+			});
+		}
+		if (DsaPointerIsValid(session->errorMessage))
+		{
+			return report_session_error(session);
+		}
+		P("%s: %s: update: done: %lld", Tag, tag_, static_cast<long long>(session->nUpdatedRecords));
+		return session->nUpdatedRecords;
+	}
+
 	arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> read(uint64_t sessionID)
 	{
 		auto session = find_session(sessionID);
@@ -1210,7 +1290,9 @@ class Proxy : public WorkerProcessor {
 		set_shared_string(session->userName, userName);
 		set_shared_string(session->password, password);
 		set_shared_string(session->clientAddress, clientAddress);
-		session->query = InvalidDsaPointer;
+		session->selectQuery = InvalidDsaPointer;
+		session->updateQuery = InvalidDsaPointer;
+		session->nUpdatedRecords = -1;
 		SharedRingBuffer::initialize_data(&(session->bufferData));
 		LWLockRelease(lock_);
 		return session;
@@ -1507,11 +1589,11 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
 	arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoStatement(
 		const arrow::flight::ServerCallContext& context,
 		const arrow::flight::sql::StatementQuery& command,
-		const arrow::flight::FlightDescriptor& descriptor)
+		const arrow::flight::FlightDescriptor& descriptor) override
 	{
 		ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
 		const auto& query = command.query;
-		ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->execute(sessionID, query));
+		ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->select(sessionID, query));
 		ARROW_ASSIGN_OR_RAISE(auto ticket,
 		                      arrow::flight::sql::CreateStatementQueryTicket(query));
 		std::vector<arrow::flight::FlightEndpoint> endpoints{
@@ -1524,13 +1606,22 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
 
 	arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> DoGetStatement(
 		const arrow::flight::ServerCallContext& context,
-		const arrow::flight::sql::StatementQueryTicket& command)
+		const arrow::flight::sql::StatementQueryTicket& command) override  // ticket query unused; rows presumably come from GetFlightInfoStatement's earlier select()
 	{
 		ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
 		ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->read(sessionID));
 		return std::make_unique<arrow::flight::RecordBatchStream>(reader);
 	}
 
+	arrow::Result<int64_t> DoPutCommandStatementUpdate(
+		const arrow::flight::ServerCallContext& context,
+		const arrow::flight::sql::StatementUpdate& command) override
+	{
+		ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
+		// Hand the INSERT/UPDATE/DELETE text to this session's executor; the result is the affected row count.
+		return proxy_->update(sessionID, command.query);
+	}
+
    private:
 	arrow::Result<uint64_t> session_id(const arrow::flight::ServerCallContext& context)
 	{
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index 3e0992d..6997460 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -51,4 +51,27 @@ class FlightSQLTest < Test::Unit::TestCase
     assert_equal(Arrow::Table.new(value: Arrow::Int32Array.new([1, -2, 3])),
                  reader.read_all)
   end
+
+  def test_insert_int32
+    unless flight_sql_client.respond_to?(:execute_update)
+      omit("red-arrow-flight-sql 13.0.0 or later is required")
+    end
+
+    run_sql("CREATE TABLE data (value integer)")
+
+    n_changed_records = flight_sql_client.execute_update(
+      "INSERT INTO data VALUES (1), (-2), (3)",
+      @options)
+    assert_equal(3, n_changed_records)  # one per inserted row
+    assert_equal([<<-RESULT, ""], run_sql("SELECT * FROM data"))
+SELECT * FROM data
+ value 
+-------
+     1
+    -2
+     3
+(3 rows)
+
+    RESULT
+  end
 end