This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git


The following commit(s) were added to refs/heads/main by this push:
     new fbddcf3  Add support for INSERT/UPDATE/DELETE (#42)
fbddcf3 is described below

commit fbddcf38596e3f73710ce84c26a9e5d606e8c1de
Author: Sutou Kouhei <[email protected]>
AuthorDate: Fri Jun 30 17:13:50 2023 +0900

    Add support for INSERT/UPDATE/DELETE (#42)
    
    Closes GH-19
---
 src/afs.cc              | 149 ++++++++++++++++++++++++++++++++++++++----------
 test/test-flight-sql.rb |  23 ++++++++
 2 files changed, 143 insertions(+), 29 deletions(-)

diff --git a/src/afs.cc b/src/afs.cc
index 3cd0009..8f0e778 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -335,7 +335,9 @@ struct SessionData {
        dsa_pointer userName;
        dsa_pointer password;
        dsa_pointer clientAddress;
-       dsa_pointer query;
+       dsa_pointer selectQuery;
+       dsa_pointer updateQuery;
+       int64_t nUpdatedRecords;
        SharedRingBufferData bufferData;
 };
 
@@ -529,8 +531,10 @@ class WorkerProcessor : public Processor {
                        dsa_free(area_, session->userName);
                if (DsaPointerIsValid(session->password))
                        dsa_free(area_, session->password);
-               if (DsaPointerIsValid(session->query))
-                       dsa_free(area_, session->query);
+               if (DsaPointerIsValid(session->selectQuery))
+                       dsa_free(area_, session->selectQuery);
+               if (DsaPointerIsValid(session->updateQuery))
+                       dsa_free(area_, session->updateQuery);
                SharedRingBuffer::free_data(&(session->bufferData), area_);
                dshash_delete_entry(sessions_, session);
        }
@@ -645,17 +649,20 @@ class Executor : public WorkerProcessor {
 
        void signaled()
        {
-               P("%s: %s: signaled: before: %d", Tag, tag_, session_->query);
-               P("signaled: before: %d", session_->query);
-               if (DsaPointerIsValid(session_->query))
+               P("%s: %s: signaled: before: %d/%d", Tag, tag_, session_->selectQuery, session_->updateQuery);
+               if (DsaPointerIsValid(session_->selectQuery))
                {
-                       execute();
+                       select();
+               }
+               else if (DsaPointerIsValid(session_->updateQuery))
+               {
+                       update();
                }
                else
                {
                        Processor::signaled();
                }
-               P("%s: %s: signaled: after: %d", Tag, tag_, session_->query);
+               P("%s: %s: signaled: after: %d/%d", Tag, tag_, session_->selectQuery, session_->updateQuery);
        }
 
    private:
@@ -844,26 +851,26 @@ class Executor : public WorkerProcessor {
                return true;
        }
 
-       void execute()
+       void select()
        {
-               pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": executing").c_str());
+               pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": selecting").c_str());
 
                PushActiveSnapshot(GetTransactionSnapshot());
 
                LWLockAcquire(lock_, LW_EXCLUSIVE);
                std::string query(
-                       static_cast<const char*>(dsa_get_address(area_, session_->query)));
-               dsa_free(area_, session_->query);
-               session_->query = InvalidDsaPointer;
+                       static_cast<const char*>(dsa_get_address(area_, session_->selectQuery)));
+               dsa_free(area_, session_->selectQuery);
+               session_->selectQuery = InvalidDsaPointer;
                SetCurrentStatementStartTimestamp();
-               P("%s: %s: execute: %s", Tag, tag_, query.c_str());
+               P("%s: %s: select: %s", Tag, tag_, query.c_str());
                auto result = SPI_execute(query.c_str(), true, 0);
                LWLockRelease(lock_);
 
                if (result == SPI_OK_SELECT)
                {
                        pgstat_report_activity(STATE_RUNNING,
-                                              (std::string(Tag) + ": writing").c_str());
+                                              (std::string(Tag) + ": select: writing").c_str());
                        auto status = write();
                        if (!status.ok())
                        {
@@ -873,7 +880,7 @@ class Executor : public WorkerProcessor {
                else
                {
                        set_shared_string(session_->errorMessage,
-                                         std::string(Tag) + ": " + tag_ +
+                                         std::string(Tag) + ": " + tag_ + ": select" +
                                              ": failed to run a query: <" + query +
                                              ">: " + SPI_result_code_string(result));
                }
@@ -882,7 +889,53 @@ class Executor : public WorkerProcessor {
 
                if (sharedData_->serverPID != InvalidPid)
                {
-                       P("%s: %s: kill server: %d", Tag, tag_, sharedData_->serverPID);
+                       P("%s: %s: select: kill server: %d", Tag, tag_, sharedData_->serverPID);
+                       kill(sharedData_->serverPID, SIGUSR1);
+               }
+
+               pgstat_report_activity(STATE_IDLE, NULL);
+       }
+
+       void update()
+       {
+               pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": updating").c_str());
+
+               PushActiveSnapshot(GetTransactionSnapshot());
+
+               LWLockAcquire(lock_, LW_EXCLUSIVE);
+               std::string query(
+                       static_cast<const char*>(dsa_get_address(area_, session_->updateQuery)));
+               dsa_free(area_, session_->updateQuery);
+               session_->updateQuery = InvalidDsaPointer;
+               SetCurrentStatementStartTimestamp();
+               P("%s: %s: update: %s", Tag, tag_, query.c_str());
+               auto result = SPI_execute(query.c_str(), false, 0);
+               LWLockRelease(lock_);
+
+               switch (result)
+               {
+               case SPI_OK_INSERT:
+               case SPI_OK_DELETE:
+               case SPI_OK_UPDATE:
+                       session_->nUpdatedRecords = SPI_processed;
+                       break;
+               default:
+                       set_shared_string(session_->errorMessage,
+                                         std::string(Tag) + ": " + tag_ + ": update" +
+                                             ": failed to run a query: <" + query +
+                                             ">: " + SPI_result_code_string(result));
+                       break;
+               }
+
+               PopActiveSnapshot();
+
+               // TODO: Is this usage correct?
+               CommitTransactionCommand();
+               StartTransactionCommand();
+
+               if (sharedData_->serverPID != InvalidPid)
+               {
+                       P("%s: %s: update: kill server: %d", Tag, tag_, sharedData_->serverPID);
                        kill(sharedData_->serverPID, SIGUSR1);
                }
 
@@ -1128,22 +1181,22 @@ class Proxy : public WorkerProcessor {
                }
        }
 
-       arrow::Result<std::shared_ptr<arrow::Schema>> execute(uint64_t sessionID,
-                                                             const std::string& query)
+       arrow::Result<std::shared_ptr<arrow::Schema>> select(uint64_t sessionID,
+                                                            const std::string& query)
        {
                auto session = find_session(sessionID);
                SessionReleaser sessionReleaser(sessions_, session);
-               set_shared_string(session->query, query);
+               set_shared_string(session->selectQuery, query);
                if (session->executorPID != InvalidPid)
                {
-                       P("%s: %s: execute: kill executor: %d", Tag, tag_, session->executorPID);
+                       P("%s: %s: select: kill executor: %d", Tag, tag_, session->executorPID);
                        kill(session->executorPID, SIGUSR1);
                }
                {
                        auto buffer = std::move(create_shared_ring_buffer(session));
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
-                               P("%s: %s: %s: wait: execute", Tag, tag_, AFS_FUNC);
+                               P("%s: %s: %s: wait: select", Tag, tag_, AFS_FUNC);
                                return DsaPointerIsValid(session->errorMessage) || buffer.size() > 0;
                        });
                }
@@ -1151,7 +1204,7 @@ class Proxy : public WorkerProcessor {
                {
                        return report_session_error(session);
                }
-               P("%s: %s: execute: open", Tag, tag_);
+               P("%s: %s: select: open", Tag, tag_);
                auto input = std::make_shared<SharedRingBufferInputStream>(this, session);
                // Read schema only stream format data.
                ARROW_ASSIGN_OR_RAISE(auto reader,
@@ -1159,17 +1212,44 @@ class Proxy : public WorkerProcessor {
                while (true)
                {
                        std::shared_ptr<arrow::RecordBatch> recordBatch;
-                       P("%s: %s: execute: read next", Tag, tag_);
+                       P("%s: %s: select: read next", Tag, tag_);
                        ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
                        if (!recordBatch)
                        {
                                break;
                        }
                }
-               P("%s: %s: execute: schema", Tag, tag_);
+               P("%s: %s: select: schema", Tag, tag_);
                return reader->schema();
        }
 
+       arrow::Result<int64_t> update(uint64_t sessionID, const std::string& query)
+       {
+               auto session = find_session(sessionID);
+               SessionReleaser sessionReleaser(sessions_, session);
+               set_shared_string(session->updateQuery, query);
+               session->nUpdatedRecords = -1;
+               if (session->executorPID != InvalidPid)
+               {
+                       P("%s: %s: update: kill executor: %d",
+                         Tag, tag_, session->executorPID);
+                       kill(session->executorPID, SIGUSR1);
+               }
+               {
+                       std::unique_lock<std::mutex> lock(mutex_);
+                       conditionVariable_.wait(lock, [&] {
+                               P("%s: %s: %s: wait: update", Tag, tag_, AFS_FUNC);
+                               return DsaPointerIsValid(session->errorMessage) || session->nUpdatedRecords >= 0;
+                       });
+               }
+               if (DsaPointerIsValid(session->errorMessage))
+               {
+                       return report_session_error(session);
+               }
+               P("%s: %s: update: done: %ld", Tag, tag_, session->nUpdatedRecords);
+               return session->nUpdatedRecords;
+       }
+
        arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> read(uint64_t sessionID)
        {
                auto session = find_session(sessionID);
@@ -1210,7 +1290,9 @@ class Proxy : public WorkerProcessor {
                set_shared_string(session->userName, userName);
                set_shared_string(session->password, password);
                set_shared_string(session->clientAddress, clientAddress);
-               session->query = InvalidDsaPointer;
+               session->selectQuery = InvalidDsaPointer;
+               session->updateQuery = InvalidDsaPointer;
+               session->nUpdatedRecords = -1;
                SharedRingBuffer::initialize_data(&(session->bufferData));
                LWLockRelease(lock_);
                return session;
@@ -1507,11 +1589,11 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
        arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoStatement(
                const arrow::flight::ServerCallContext& context,
                const arrow::flight::sql::StatementQuery& command,
-               const arrow::flight::FlightDescriptor& descriptor)
+               const arrow::flight::FlightDescriptor& descriptor) override
        {
                ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
                const auto& query = command.query;
-               ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->execute(sessionID, query));
+               ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->select(sessionID, query));
                ARROW_ASSIGN_OR_RAISE(auto ticket,
                                      arrow::flight::sql::CreateStatementQueryTicket(query));
                std::vector<arrow::flight::FlightEndpoint> endpoints{
@@ -1524,13 +1606,22 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
 
        arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> DoGetStatement(
                const arrow::flight::ServerCallContext& context,
-               const arrow::flight::sql::StatementQueryTicket& command)
+               const arrow::flight::sql::StatementQueryTicket& command) override
        {
                ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
                ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->read(sessionID));
                return std::make_unique<arrow::flight::RecordBatchStream>(reader);
        }
 
+       arrow::Result<int64_t> DoPutCommandStatementUpdate(
+               const arrow::flight::ServerCallContext& context,
+               const arrow::flight::sql::StatementUpdate& command) override
+       {
+               ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
+               const auto& query = command.query;
+               return proxy_->update(sessionID, query);
+       }
+
    private:
        arrow::Result<uint64_t> session_id(const arrow::flight::ServerCallContext& context)
        {
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index 3e0992d..6997460 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -51,4 +51,27 @@ class FlightSQLTest < Test::Unit::TestCase
     assert_equal(Arrow::Table.new(value: Arrow::Int32Array.new([1, -2, 3])),
                  reader.read_all)
   end
+
+  def test_isnert_int32
+    unless filght_sql_client.respond_to?(:execute_update)
+      omit("red-arrow-flight-sql 13.0.0 or later is required")
+    end
+
+    run_sql("CREATE TABLE data (value integer)")
+
+    n_changed_records = flight_sql_client.execute_update(
+      "INSERT INTO data VALUES (1), (-2), (3)",
+      @options)
+    assert_equal(3, n_changed_records)
+    assert_equal([<<-RESULT, ""], run_sql("SELECT * FROM data"))
+SELECT * FROM data
+ value 
+-------
+     1
+    -2
+     3
+(3 rows)
+
+    RESULT
+  end
 end

Reply via email to