This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git


The following commit(s) were added to refs/heads/main by this push:
     new 8e92eee  Add support for prepared SELECT (#86)
8e92eee is described below

commit 8e92eee4768e831c4ef16cb61c60ebaaa92e19c6
Author: Sutou Kouhei <[email protected]>
AuthorDate: Wed Aug 30 17:15:43 2023 +0900

    Add support for prepared SELECT (#86)
    
    Closes GH-81
---
 src/afs.cc              | 790 ++++++++++++++++++++++++++++++++----------------
 test/test-flight-sql.rb |  33 +-
 2 files changed, 557 insertions(+), 266 deletions(-)

diff --git a/src/afs.cc b/src/afs.cc
index bfb6776..bf78472 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -63,6 +63,7 @@ extern "C"
 #include <map>
 #include <random>
 #include <sstream>
+#include <type_traits>
 
 #include <arpa/inet.h>
 
@@ -163,6 +164,22 @@ class ScopedMemoryContext {
        MemoryContext oldMemoryContext_;
 };
 
+struct ScopedTransaction {
+       ScopedTransaction() { StartTransactionCommand(); }
+       ~ScopedTransaction() { CommitTransactionCommand(); }
+};
+
+struct ScopedSnapshot {
+       ScopedSnapshot() { PushActiveSnapshot(GetTransactionSnapshot()); }
+       ~ScopedSnapshot() { PopActiveSnapshot(); }
+};
+
+struct ScopedPlan {
+       ScopedPlan(SPIPlanPtr plan) : plan_(plan) {}
+       ~ScopedPlan() { SPI_freeplan(plan_); }
+       SPIPlanPtr plan_;
+};
+
 struct SharedRingBufferData {
        dsa_pointer pointer;
        size_t total;
@@ -344,6 +361,8 @@ enum class Action
        Update,
        Prepare,
        ClosePreparedStatement,
+       SetParameters,
+       SelectPreparedStatement,
        UpdatePreparedStatement,
 };
 
@@ -362,6 +381,10 @@ action_name(Action action)
                        return "Action::Prepare";
                case Action::ClosePreparedStatement:
                        return "Action::ClosePreparedStatement";
+               case Action::SetParameters:
+                       return "Action::SetParameters";
+               case Action::SelectPreparedStatement:
+                       return "Action::SelectPreparedStatement";
                case Action::UpdatePreparedStatement:
                        return "Action::UpdatePreparedStatement";
                default:
@@ -617,6 +640,17 @@ class Processor {
        std::condition_variable conditionVariable_;
 };
 
+struct ProcessorLockGuard {
+       ProcessorLockGuard(Processor* processor) : processor_(processor)
+       {
+               processor_->lock_acquire();
+       }
+       ~ProcessorLockGuard() { processor_->lock_release(); }
+
+   private:
+       Processor* processor_;
+};
+
 class Proxy;
 class SharedRingBufferInputStream : public arrow::io::InputStream {
    public:
@@ -1004,34 +1038,56 @@ class PGArrowValueConverter : public 
arrow::ArrayVisitor {
 
 class PreparedStatement {
    public:
-       explicit PreparedStatement(std::string query) : 
query_(std::move(query)) {}
+       explicit PreparedStatement(std::string query)
+               : query_(std::move(query)), parameters_()
+       {
+       }
 
        ~PreparedStatement() {}
 
-       arrow::Result<int64_t> 
update(std::shared_ptr<SharedRingBufferInputStream>& input)
+       using WriteFunc = std::add_pointer<arrow::Status(void*)>::type;
+       arrow::Status select(WriteFunc write, void* writeData)
        {
-               ARROW_ASSIGN_OR_RAISE(auto reader,
-                                     
arrow::ipc::RecordBatchStreamReader::Open(input));
-               const auto& schema = reader->schema();
-               SPIExecuteOptions options = {};
-               if (schema->num_fields() > 0)
+               for (const auto& recordBatch : parameters_)
                {
-                       options.params = makeParamList(schema->num_fields());
+                       SPIExecuteOptions options = {};
+                       std::vector<Oid> pgTypes;
+                       ARROW_RETURN_NOT_OK(prepare(options, pgTypes, 
recordBatch->schema()));
+                       auto plan = SPI_prepare(query_.c_str(), pgTypes.size(), 
pgTypes.data());
+                       ScopedPlan scopedPlan(plan);
+                       ARROW_RETURN_NOT_OK(
+                               execute(plan, recordBatch, options, [&]() { 
return write(writeData); }));
                }
-               options.read_only = false;
-               options.tcount = 0;
-               ARROW_ASSIGN_OR_RAISE(auto pgTypes, create_pg_types(schema));
-               for (size_t i = 0; i < pgTypes.size(); ++i)
+               return arrow::Status::OK();
+       }
+
+       arrow::Status 
set_parameters(std::shared_ptr<SharedRingBufferInputStream>& input)
+       {
+               parameters_.clear();
+               ARROW_ASSIGN_OR_RAISE(auto reader,
+                                     
arrow::ipc::RecordBatchStreamReader::Open(input));
+               while (true)
                {
-                       options.params->params[i].pflags = PARAM_FLAG_CONST;
-                       options.params->params[i].ptype = pgTypes[i];
+                       std::shared_ptr<arrow::RecordBatch> recordBatch;
+                       ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
+                       if (!recordBatch)
+                       {
+                               break;
+                       }
+                       parameters_.push_back(std::move(recordBatch));
                }
+               return arrow::Status::OK();
+       }
+
+       arrow::Result<int64_t> 
update(std::shared_ptr<SharedRingBufferInputStream>& input)
+       {
+               ARROW_ASSIGN_OR_RAISE(auto reader,
+                                     
arrow::ipc::RecordBatchStreamReader::Open(input));
+               SPIExecuteOptions options = {};
+               std::vector<Oid> pgTypes;
+               ARROW_RETURN_NOT_OK(prepare(options, pgTypes, 
reader->schema()));
                auto plan = SPI_prepare(query_.c_str(), pgTypes.size(), 
pgTypes.data());
-               struct PlanFinalizer {
-                       PlanFinalizer(SPIPlanPtr plan) : plan_(plan) {}
-                       ~PlanFinalizer() { SPI_freeplan(plan_); }
-                       SPIPlanPtr plan_;
-               } planFinalizer(plan);
+               ScopedPlan scopedPlan(plan);
 
                int64_t nUpdatedRecords = 0;
                while (true)
@@ -1042,41 +1098,67 @@ class PreparedStatement {
                        {
                                break;
                        }
-                       const auto& columns = recordBatch->columns();
-                       for (int64_t i = 0; i < recordBatch->num_rows(); ++i)
-                       {
-                               
ARROW_RETURN_NOT_OK(assign_parameters(recordBatch, i, columns, options));
-                               auto result = SPI_execute_plan_extended(plan, 
&options);
-                               switch (result)
-                               {
-                                       case SPI_OK_INSERT:
-                                       case SPI_OK_DELETE:
-                                       case SPI_OK_UPDATE:
-                                               break;
-                                       default:
-                                               return arrow::Status::Invalid(
-                                                       "failed to run a 
prepared statement: ",
-                                                       
SPI_result_code_string(result));
-                                               break;
-                               }
+                       ARROW_RETURN_NOT_OK(execute(plan, recordBatch, options, 
[&nUpdatedRecords]() {
                                nUpdatedRecords += SPI_processed;
-                       }
+                               return arrow::Status::OK();
+                       }));
                }
                return nUpdatedRecords;
        }
 
    private:
-       arrow::Result<std::vector<Oid>> create_pg_types(
-               const std::shared_ptr<arrow::Schema>& schema)
+       arrow::Status prepare_pg_types(std::vector<Oid>& pgTypes,
+                                      const std::shared_ptr<arrow::Schema>& 
schema)
        {
-               std::vector<Oid> pgTypes;
                ArrowPGTypeConverter converter;
                for (const auto& field : schema->fields())
                {
                        ARROW_RETURN_NOT_OK(field->type()->Accept(&converter));
                        pgTypes.push_back(converter.oid());
                }
-               return std::move(pgTypes);
+               return arrow::Status::OK();
+       }
+
+       arrow::Status prepare(SPIExecuteOptions& options,
+                             std::vector<Oid>& pgTypes,
+                             const std::shared_ptr<arrow::Schema>& schema)
+       {
+               if (schema->num_fields() > 0)
+               {
+                       options.params = makeParamList(schema->num_fields());
+               }
+               options.read_only = false;
+               options.tcount = 0;
+               ARROW_RETURN_NOT_OK(prepare_pg_types(pgTypes, schema));
+               for (size_t i = 0; i < pgTypes.size(); ++i)
+               {
+                       options.params->params[i].pflags = PARAM_FLAG_CONST;
+                       options.params->params[i].ptype = pgTypes[i];
+               }
+               return arrow::Status::OK();
+       }
+
+       template <typename OnSuccessFunc>
+       arrow::Status execute(SPIPlanPtr plan,
+                             const std::shared_ptr<arrow::RecordBatch>& 
recordBatch,
+                             SPIExecuteOptions& options,
+                             OnSuccessFunc onSuccess)
+       {
+               const auto& columns = recordBatch->columns();
+               for (int64_t i = 0; i < recordBatch->num_rows(); ++i)
+               {
+                       ARROW_RETURN_NOT_OK(assign_parameters(recordBatch, i, 
columns, options));
+                       auto result = SPI_execute_plan_extended(plan, &options);
+                       if (result <= 0)
+                       {
+                               return arrow::Status::Invalid("failed to run a 
prepared statement: ",
+                                                             
SPI_result_code_string(result),
+                                                             ": ",
+                                                             query_);
+                       }
+                       ARROW_RETURN_NOT_OK(onSuccess());
+               }
+               return arrow::Status::OK();
        }
 
        arrow::Status assign_parameters(
@@ -1101,18 +1183,7 @@ class PreparedStatement {
        }
 
        std::string query_;
-};
-
-struct Transaction {
-       Transaction() { StartTransactionCommand(); }
-
-       ~Transaction() { CommitTransactionCommand(); }
-};
-
-struct Snapshot {
-       Snapshot() { PushActiveSnapshot(GetTransactionSnapshot()); }
-
-       ~Snapshot() { PopActiveSnapshot(); }
+       std::vector<std::shared_ptr<arrow::RecordBatch>> parameters_;
 };
 
 class Executor : public WorkerProcessor {
@@ -1138,6 +1209,7 @@ class Executor : public WorkerProcessor {
 
        void open()
        {
+               const char* tag = "open";
                // pg_usleep(5000000);
                // pg_usleep(5000000);
                pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": 
opening").c_str());
@@ -1155,8 +1227,7 @@ class Executor : public WorkerProcessor {
                if (!check_password(databaseName, userName, password, 
clientAddress))
                {
                        session_->initialized = true;
-                       P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC, 
sharedData_->serverPID);
-                       kill(sharedData_->serverPID, SIGUSR1);
+                       signal_server(tag);
                        return;
                }
                {
@@ -1169,18 +1240,19 @@ class Executor : public WorkerProcessor {
                pgstat_report_activity(STATE_IDLE, NULL);
                session_->initialized = true;
                connected_ = true;
-               P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC, 
sharedData_->serverPID);
-               kill(sharedData_->serverPID, SIGUSR1);
+               signal_server(tag);
        }
 
        void close() { close_internal(true); }
 
        void signaled() override
        {
-               lock_acquire();
-               auto action = session_->action;
-               session_->action = Action::None;
-               lock_release();
+               Action action;
+               {
+                       ProcessorLockGuard lock(this);
+                       action = session_->action;
+                       session_->action = Action::None;
+               }
                P("%s: %s: signaled: before: %s", Tag, tag_, 
action_name(action));
                PG_TRY();
                {
@@ -1198,6 +1270,12 @@ class Executor : public WorkerProcessor {
                                case Action::ClosePreparedStatement:
                                        close_prepared_statement();
                                        break;
+                               case Action::SetParameters:
+                                       set_parameters();
+                                       break;
+                               case Action::SelectPreparedStatement:
+                                       select_prepared_statement();
+                                       break;
                                case Action::UpdatePreparedStatement:
                                        update_prepared_statement();
                                        break;
@@ -1211,14 +1289,16 @@ class Executor : public WorkerProcessor {
                        if (session_ && 
!DsaPointerIsValid(session_->errorMessage))
                        {
                                auto error = CopyErrorData();
-                               set_shared_string(session_->errorMessage,
-                                                 std::string("failed to run: 
") + action_name(action) +
-                                                     ": " + error->message);
+                               set_error_message(std::string("failed to run: 
") + action_name(action) +
+                                                     ": " + error->message,
+                                                 "unexpected error");
                                FreeErrorData(error);
                        }
+                       pgstat_report_activity(STATE_IDLE, NULL);
                        PG_RE_THROW();
                }
                PG_END_TRY();
+               pgstat_report_activity(STATE_IDLE, NULL);
                P("%s: %s: signaled: after: %s", Tag, tag_, 
action_name(action));
        }
 
@@ -1228,8 +1308,32 @@ class Executor : public WorkerProcessor {
        const char* peer_name(SessionData* session) override { return "server"; 
}
 
    private:
+       void signal_server(const char* tag)
+       {
+               if (sharedData_->serverPID == InvalidPid)
+               {
+                       return;
+               }
+               P("%s: %s: %s: kill server: %d", Tag, tag_, tag, 
sharedData_->serverPID);
+               kill(sharedData_->serverPID, SIGUSR1);
+       }
+
+       void set_error_message(const std::string& message, const char* tag)
+       {
+               if (DsaPointerIsValid(session_->errorMessage))
+               {
+                       return;
+               }
+               {
+                       ProcessorLockGuard lock(this);
+                       set_shared_string(session_->errorMessage, message);
+               }
+               signal_server(tag);
+       }
+
        void close_internal(bool unlockSession)
        {
+               const char* tag = "close";
                closed_ = true;
                pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": 
closing").c_str());
                preparedStatements_.clear();
@@ -1244,17 +1348,13 @@ class Executor : public WorkerProcessor {
                }
                else
                {
-                       if (!DsaPointerIsValid(session_->errorMessage))
-                       {
-                               set_shared_string(session_->errorMessage, 
"failed to connect");
-                       }
+                       set_error_message("failed to connect", tag);
                        session_->initialized = true;
                        if (unlockSession)
                        {
                                dshash_release_lock(sessions_, session_);
                        }
-                       P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC, 
sharedData_->serverPID);
-                       kill(sharedData_->serverPID, SIGUSR1);
+                       signal_server(tag);
                }
                if (CurrentResourceOwner)
                {
@@ -1267,7 +1367,6 @@ class Executor : public WorkerProcessor {
                                resourceOwner, RESOURCE_RELEASE_AFTER_LOCKS, 
false, true);
                        ResourceOwnerDelete(resourceOwner);
                }
-               pgstat_report_activity(STATE_IDLE, NULL);
        }
 
        bool check_password(const char* databaseName,
@@ -1275,6 +1374,7 @@ class Executor : public WorkerProcessor {
                            const char* password,
                            const char* clientAddress)
        {
+               const char* tag = "check password";
                MemoryContext memoryContext =
                        AllocSetContextCreate(CurrentMemoryContext,
                                          "arrow-flight-sql: 
Executor::check_password()",
@@ -1291,20 +1391,18 @@ class Executor : public WorkerProcessor {
                hba_getauthmethod(&port);
                if (!port.hba)
                {
-                       set_shared_string(session_->errorMessage, "failed to 
get auth method");
+                       set_error_message("failed to get auth method", tag);
                        return false;
                }
                switch (port.hba->auth_method)
                {
                        case uaMD5:
                                // TODO
-                               set_shared_string(session_->errorMessage,
-                                                 "MD5 auth method isn't 
supported yet");
+                               set_error_message("MD5 auth method isn't 
supported yet", tag);
                                return false;
                        case uaSCRAM:
                                // TODO
-                               set_shared_string(session_->errorMessage,
-                                                 "SCRAM auth method isn't 
supported yet");
+                               set_error_message("SCRAM auth method isn't 
supported yet", tag);
                                return false;
                        case uaPassword:
                        {
@@ -1312,18 +1410,16 @@ class Executor : public WorkerProcessor {
                                auto shadowPassword = 
get_role_password(port.user_name, &logDetail);
                                if (!shadowPassword)
                                {
-                                       set_shared_string(
-                                               session_->errorMessage,
-                                               std::string("failed to get 
password: ") + logDetail);
+                                       set_error_message(std::string("failed 
to get password: ") + logDetail,
+                                                         tag);
                                        return false;
                                }
                                auto result = plain_crypt_verify(
                                        port.user_name, shadowPassword, 
password, &logDetail);
                                if (result != STATUS_OK)
                                {
-                                       set_shared_string(
-                                               session_->errorMessage,
-                                               std::string("failed to verify 
password: ") + logDetail);
+                                       set_error_message(
+                                               std::string("failed to verify 
password: ") + logDetail, tag);
                                        return false;
                                }
                                return true;
@@ -1331,15 +1427,16 @@ class Executor : public WorkerProcessor {
                        case uaTrust:
                                return true;
                        default:
-                               set_shared_string(session_->errorMessage,
-                                                 std::string("unsupported auth 
method: ") +
-                                                     
hba_authname(port.hba->auth_method));
+                               set_error_message(std::string("unsupported auth 
method: ") +
+                                                     
hba_authname(port.hba->auth_method),
+                                                 tag);
                                return false;
                }
        }
 
        bool fill_client_address(Port* port, const char* clientAddress)
        {
+               const char* tag = "fill client address";
                // clientAddress: "ipv4:127.0.0.1:40468"
                // family: "ipv4"
                // host: "127.0.0.1"
@@ -1353,9 +1450,8 @@ class Executor : public WorkerProcessor {
                std::getline(clientAddressStream, clientPort);
                if (!(clientFamily == "ipv4" || clientFamily == "ipv6"))
                {
-                       set_shared_string(
-                               session_->errorMessage,
-                               std::string("client family must be ipv4 or 
ipv6: ") + clientFamily);
+                       set_error_message(
+                               std::string("client family must be ipv4 or 
ipv6: ") + clientFamily, tag);
                        return false;
                }
                auto clientPortStart = clientPort.c_str();
@@ -1363,21 +1459,19 @@ class Executor : public WorkerProcessor {
                auto clientPortNumber = std::strtoul(clientPortStart, 
&clientPortEnd, 10);
                if (clientPortEnd[0] != '\0')
                {
-                       set_shared_string(session_->errorMessage,
-                                         std::string("client port is invalid: 
") + clientPort);
+                       set_error_message(std::string("client port is invalid: 
") + clientPort, tag);
                        return false;
                }
                if (clientPortNumber == 0)
                {
-                       set_shared_string(session_->errorMessage,
-                                         std::string("client port must not 
0"));
+                       set_error_message(std::string("client port must not 
0"), tag);
                        return false;
                }
                if (clientPortNumber > 65535)
                {
-                       set_shared_string(session_->errorMessage,
-                                         std::string("client port is too 
large: ") +
-                                             std::to_string(clientPortNumber));
+                       set_error_message(std::string("client port is too 
large: ") +
+                                             std::to_string(clientPortNumber),
+                                         tag);
                        return false;
                }
                if (clientFamily == "ipv4")
@@ -1388,9 +1482,8 @@ class Executor : public WorkerProcessor {
                        raddr->sin_port = htons(clientPortNumber);
                        if (inet_pton(AF_INET, clientHost.c_str(), 
&(raddr->sin_addr)) == 0)
                        {
-                               set_shared_string(
-                                       session_->errorMessage,
-                                       std::string("client IPv4 address is 
invalid: ") + clientHost);
+                               set_error_message(
+                                       std::string("client IPv4 address is 
invalid: ") + clientHost, tag);
                                return false;
                        }
                }
@@ -1403,9 +1496,8 @@ class Executor : public WorkerProcessor {
                        raddr->sin6_flowinfo = 0;
                        if (inet_pton(AF_INET6, clientHost.c_str(), 
&(raddr->sin6_addr)) == 0)
                        {
-                               set_shared_string(
-                                       session_->errorMessage,
-                                       std::string("client IPv6 address is 
invalid: ") + clientHost);
+                               set_error_message(
+                                       std::string("client IPv6 address is 
invalid: ") + clientHost, tag);
                                return false;
                        }
                        raddr->sin6_scope_id = 0;
@@ -1415,66 +1507,60 @@ class Executor : public WorkerProcessor {
 
        void select()
        {
+               const char* tag = "select";
                if (!DsaPointerIsValid(session_->selectQuery))
                {
-                       lock_acquire();
-                       set_shared_string(
-                               session_->errorMessage,
-                               std::string(Tag) + ": " + tag_ + ": select" + 
": query is missing");
-                       lock_release();
+                       set_error_message(
+                               std::string(Tag) + ": " + tag_ + ": " + tag + 
": query is missing", tag);
                        return;
                }
 
                pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": 
selecting").c_str());
 
-               lock_acquire();
-               std::string query(
-                       static_cast<const char*>(dsa_get_address(area_, 
session_->selectQuery)));
-               dsa_free(area_, session_->selectQuery);
-               session_->selectQuery = InvalidDsaPointer;
-               lock_release();
-               P("%s: %s: select: %s", Tag, tag_, query.c_str());
+               std::string query;
+               {
+                       ProcessorLockGuard lock(this);
+                       query =
+                               static_cast<const char*>(dsa_get_address(area_, 
session_->selectQuery));
+                       dsa_free(area_, session_->selectQuery);
+                       session_->selectQuery = InvalidDsaPointer;
+               }
+               P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
 
                {
-                       Transaction transaction;
-                       Snapshot snapshot;
+                       ScopedTransaction scopedTransaction;
+                       ScopedSnapshot scopedSnapshot;
 
                        SetCurrentStatementStartTimestamp();
                        auto result = SPI_execute(query.c_str(), true, 0);
 
-                       if (result == SPI_OK_SELECT)
+                       if (result > 0)
                        {
-                               pgstat_report_activity(STATE_RUNNING,
-                                                      (std::string(Tag) + ": 
select: writing").c_str());
-                               auto status = write();
-                               if (!status.ok())
+                               pgstat_report_activity(
+                                       STATE_RUNNING, (std::string(Tag) + ": " 
+ tag + ": writing").c_str());
+                               auto status = write(tag);
+                               if (status.ok())
                                {
-                                       lock_acquire();
-                                       
set_shared_string(session_->errorMessage, status.ToString());
-                                       lock_release();
+                                       signal_server(tag);
+                               }
+                               else
+                               {
+                                       set_error_message(std::string(Tag) + ": 
" + tag_ + ": " + tag +
+                                                             ": failed to 
write: " + status.ToString(),
+                                                         tag);
                                }
                        }
                        else
                        {
-                               lock_acquire();
-                               set_shared_string(session_->errorMessage,
-                                                 std::string(Tag) + ": " + 
tag_ + ": select" +
+                               set_error_message(std::string(Tag) + ": " + 
tag_ + ": " + tag +
                                                      ": failed to run a query: 
<" + query +
-                                                     ">: " + 
SPI_result_code_string(result));
-                               lock_release();
+                                                     ">: " + 
SPI_result_code_string(result),
+                                                 tag);
                        }
                }
-
-               if (sharedData_->serverPID != InvalidPid)
-               {
-                       P("%s: %s: select: kill server: %d", Tag, tag_, 
sharedData_->serverPID);
-                       kill(sharedData_->serverPID, SIGUSR1);
-               }
-
-               pgstat_report_activity(STATE_IDLE, NULL);
        }
 
-       arrow::Status write()
+       arrow::Status write(const char* tag)
        {
                SharedRingBufferOutputStream output(this, session_);
                std::vector<PGArrowValueConverter> converters;
@@ -1500,9 +1586,9 @@ class Executor : public WorkerProcessor {
                                      arrow::ipc::MakeStreamWriter(&output, 
schema, options));
                // Build an empty record batch to write schema.
                ARROW_ASSIGN_OR_RAISE(auto recordBatch, builder->Flush());
-               P("%s: %s: write: schema: WriteRecordBatch", Tag, tag_);
+               P("%s: %s: %s: write: schema: WriteRecordBatch", Tag, tag_, 
tag);
                ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
-               P("%s: %s: write: schema: Close", Tag, tag_);
+               P("%s: %s: %s: write: schema: Close", Tag, tag_, tag);
                ARROW_RETURN_NOT_OK(writer->Close());
 
                // Write another stream format data with record batches.
@@ -1511,17 +1597,19 @@ class Executor : public WorkerProcessor {
                bool needLastFlush = false;
                for (uint64_t iTuple = 0; iTuple < SPI_processed; ++iTuple)
                {
-                       P("%s: %s: write: data: record batch: %d/%d",
+                       P("%s: %s: %s: write: data: record batch: %d/%d",
                          Tag,
                          tag_,
+                         tag,
                          iTuple,
                          SPI_processed);
                        for (int iAttribute = 0; iAttribute < 
SPI_tuptable->tupdesc->natts;
                             ++iAttribute)
                        {
-                               P("%s: %s: write: data: record batch: %d/%d: 
%d/%d",
+                               P("%s: %s: %s: write: data: record batch: 
%d/%d: %d/%d",
                                  Tag,
                                  tag_,
+                                 tag,
                                  iTuple,
                                  SPI_processed,
                                  iAttribute,
@@ -1546,9 +1634,10 @@ class Executor : public WorkerProcessor {
                        if (((iTuple + 1) % MaxNRowsPerRecordBatch) == 0)
                        {
                                ARROW_ASSIGN_OR_RAISE(recordBatch, 
builder->Flush());
-                               P("%s: %s: write: data: WriteRecordBatch: 
%d/%d",
+                               P("%s: %s: %s: write: data: WriteRecordBatch: 
%d/%d",
                                  Tag,
                                  tag_,
+                                 tag,
                                  iTuple,
                                  SPI_processed);
                                
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
@@ -1562,208 +1651,259 @@ class Executor : public WorkerProcessor {
                if (needLastFlush)
                {
                        ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush());
-                       P("%s: %s: write: data: WriteRecordBatch", Tag, tag_);
+                       P("%s: %s: %s: write: data: WriteRecordBatch", Tag, 
tag_, tag);
                        
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
                }
-               P("%s: %s: write: data: Close", Tag, tag_);
+               P("%s: %s: %s, write: data: Close", Tag, tag_, tag);
                ARROW_RETURN_NOT_OK(writer->Close());
                return output.Close();
        }
 
        void update()
        {
+               const char* tag = "update";
                if (!DsaPointerIsValid(session_->updateQuery))
                {
-                       lock_acquire();
-                       set_shared_string(
-                               session_->errorMessage,
-                               std::string(Tag) + ": " + tag_ + ": update" + 
": query is missing");
-                       lock_release();
+                       set_error_message(
+                               std::string(Tag) + ": " + tag_ + ": " + tag + 
": query is missing", tag);
                        return;
                }
 
                pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": 
updating").c_str());
 
-               lock_acquire();
-               std::string query(
-                       static_cast<const char*>(dsa_get_address(area_, 
session_->updateQuery)));
-               dsa_free(area_, session_->updateQuery);
-               session_->updateQuery = InvalidDsaPointer;
-               lock_release();
-               P("%s: %s: update: %s", Tag, tag_, query.c_str());
+               std::string query;
+               {
+                       ProcessorLockGuard lock(this);
+                       query =
+                               static_cast<const char*>(dsa_get_address(area_, 
session_->updateQuery));
+                       dsa_free(area_, session_->updateQuery);
+                       session_->updateQuery = InvalidDsaPointer;
+               }
+               P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
 
                {
-                       Transaction transaction;
-                       Snapshot snapshot;
+                       ScopedTransaction scopedTransaction;
+                       ScopedSnapshot scopedSnapshot;
 
                        SetCurrentStatementStartTimestamp();
                        auto result = SPI_execute(query.c_str(), false, 0);
-                       switch (result)
+                       if (result > 0)
                        {
-                               case SPI_OK_INSERT:
-                               case SPI_OK_DELETE:
-                               case SPI_OK_UPDATE:
-                                       session_->nUpdatedRecords = 
SPI_processed;
-                                       break;
-                               default:
-                                       lock_acquire();
-                                       
set_shared_string(session_->errorMessage,
-                                                         std::string(Tag) + ": 
" + tag_ + ": update" +
-                                                             ": failed to run 
a query: <" + query +
-                                                             ">: " + 
SPI_result_code_string(result));
-                                       lock_release();
-                                       break;
+                               session_->nUpdatedRecords = SPI_processed;
+                               signal_server(tag);
+                       }
+                       else
+                       {
+                               set_error_message(std::string(Tag) + ": " + 
tag_ + ": " + tag +
+                                                     ": failed to run a query: 
<" + query +
+                                                     ">: " + 
SPI_result_code_string(result),
+                                                 tag);
                        }
                }
-
-               if (sharedData_->serverPID != InvalidPid)
-               {
-                       P("%s: %s: update: kill server: %d", Tag, tag_, 
sharedData_->serverPID);
-                       kill(sharedData_->serverPID, SIGUSR1);
-               }
-
-               pgstat_report_activity(STATE_IDLE, NULL);
        }
 
        void prepare()
        {
+               const char* tag = "prepare";
                if (!DsaPointerIsValid(session_->prepareQuery))
                {
-                       lock_acquire();
-                       set_shared_string(
-                               session_->errorMessage,
-                               std::string(Tag) + ": " + tag_ + ": prepare" + 
": query is missing");
-                       lock_release();
+                       set_error_message(
+                               std::string(Tag) + ": " + tag_ + ": " + tag + 
": query is missing", tag);
                        return;
                }
 
                pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": 
preparing").c_str());
 
-               lock_acquire();
-               std::string query(
-                       static_cast<const char*>(dsa_get_address(area_, 
session_->prepareQuery)));
-               dsa_free(area_, session_->prepareQuery);
-               session_->prepareQuery = InvalidDsaPointer;
-               lock_release();
-               P("%s: %s: prepare: %s", Tag, tag_, query.c_str());
+               std::string query;
+               {
+                       ProcessorLockGuard lock(this);
+                       query =
+                               static_cast<const char*>(dsa_get_address(area_, 
session_->prepareQuery));
+                       dsa_free(area_, session_->prepareQuery);
+                       session_->prepareQuery = InvalidDsaPointer;
+               }
+               P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
                std::string handle(std::to_string(nextPreparedStatementID_++));
-               set_shared_string(session_->preparedStatementHandle, handle);
                preparedStatements_.insert(
                        std::make_pair(handle, 
PreparedStatement(std::move(query))));
+               {
+                       ProcessorLockGuard lock(this);
+                       set_shared_string(session_->preparedStatementHandle, 
handle);
+               }
+               signal_server(tag);
+       }
+
+       bool extract_handle(std::string& handle, const char* tag)
+       {
+               if (!DsaPointerIsValid(session_->preparedStatementHandle))
+               {
+                       set_error_message(
+                               std::string(Tag) + ": " + tag_ + ": " + tag + 
": handle is missing", tag);
+                       return false;
+               }
 
-               if (sharedData_->serverPID != InvalidPid)
                {
-                       P("%s: %s: prepare: kill server: %d", Tag, tag_, 
sharedData_->serverPID);
-                       kill(sharedData_->serverPID, SIGUSR1);
+                       ProcessorLockGuard lock(this);
+                       handle = static_cast<const char*>(
+                               dsa_get_address(area_, 
session_->preparedStatementHandle));
+                       dsa_free(area_, session_->preparedStatementHandle);
+                       session_->preparedStatementHandle = InvalidDsaPointer;
                }
 
-               pgstat_report_activity(STATE_IDLE, nullptr);
+               return true;
        }
 
        void close_prepared_statement()
        {
                const char* tag = "close prepared statement";
 
-               if (!DsaPointerIsValid(session_->preparedStatementHandle))
+               pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": " 
+ tag).c_str());
+
+               std::string handle;
+               if (!extract_handle(handle, tag))
                {
-                       lock_acquire();
-                       set_shared_string(
-                               session_->errorMessage,
-                               std::string(Tag) + ": " + tag_ + ": " + tag + 
": handle is missing");
-                       lock_release();
                        return;
                }
+               P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
+               if (preparedStatements_.erase(handle) > 0)
+               {
+                       signal_server(tag);
+               }
+               else
+               {
+                       set_error_message(std::string(Tag) + ": " + tag_ + ": " 
+ tag +
+                                             ": nonexistent handle: <" + 
handle + ">",
+                                         tag);
+               }
+       }
 
-               pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": " 
+ tag).c_str());
+       PreparedStatement* find_prepared_statement(std::string& handle, const 
char* tag)
+       {
+               if (!extract_handle(handle, tag))
+               {
+                       return nullptr;
+               }
 
-               lock_acquire();
-               std::string handle(static_cast<const char*>(
-                       dsa_get_address(area_, 
session_->preparedStatementHandle)));
-               dsa_free(area_, session_->preparedStatementHandle);
-               session_->preparedStatementHandle = InvalidDsaPointer;
-               if (preparedStatements_.erase(handle) == 0)
+               ProcessorLockGuard lock(this);
+               auto it = preparedStatements_.find(handle);
+               if (it == preparedStatements_.end())
                {
-                       set_shared_string(session_->errorMessage,
-                                         std::string(Tag) + ": " + tag_ + ": " 
+ tag +
-                                             ": nonexistent handle: <" + 
handle + ">");
+                       set_error_message(std::string(Tag) + ": " + tag_ + ": " 
+ tag +
+                                             ": nonexistent handle: <" + 
handle + ">",
+                                         tag);
+                       return nullptr;
                }
-               lock_release();
+               else
+               {
+                       return &(it->second);
+               }
+       }
+
+       void set_parameters()
+       {
+               const char* tag = "set parameters";
+
+               pgstat_report_activity(STATE_RUNNING,
+                                      (std::string(Tag) + ": setting 
parameters").c_str());
+               std::string handle;
+               auto preparedStatement = find_prepared_statement(handle, tag);
                P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
 
-               if (sharedData_->serverPID != InvalidPid)
+               if (!preparedStatement)
                {
-                       P("%s: %s: %s: kill server: %d", Tag, tag_, tag, 
sharedData_->serverPID);
-                       kill(sharedData_->serverPID, SIGUSR1);
+                       return;
                }
 
-               pgstat_report_activity(STATE_IDLE, nullptr);
+               auto input = 
std::make_shared<SharedRingBufferInputStream>(this, session_);
+               auto status = preparedStatement->set_parameters(input);
+               if (status.ok())
+               {
+                       signal_server(tag);
+               }
+               else
+               {
+                       set_error_message(std::string(Tag) + ": " + tag_ + ": " 
+ tag +
+                                             ": failed to set parameters: <" + 
handle +
+                                             ">: " + status.ToString(),
+                                         tag);
+               }
        }
 
-       void update_prepared_statement()
+       void select_prepared_statement()
        {
-               const char* tag = "update prepared statement";
+               const char* tag = "select prepared statement";
 
-               if (!DsaPointerIsValid(session_->preparedStatementHandle))
+               pgstat_report_activity(
+                       STATE_RUNNING, (std::string(Tag) + ": selecting 
prepared statement").c_str());
+
+               std::string handle;
+               auto preparedStatement = find_prepared_statement(handle, tag);
+               P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
+
+               if (!preparedStatement)
                {
-                       lock_acquire();
-                       set_shared_string(
-                               session_->errorMessage,
-                               std::string(Tag) + ": " + tag_ + ": " + tag + " 
: handle is missing");
-                       lock_release();
                        return;
                }
 
-               pgstat_report_activity(
-                       STATE_RUNNING, (std::string(Tag) + ": updating prepared 
statement").c_str());
+               ScopedTransaction scopedTransaction;
+               ScopedSnapshot scopedSnapshot;
 
-               lock_acquire();
-               std::string handle(static_cast<const char*>(
-                       dsa_get_address(area_, 
session_->preparedStatementHandle)));
-               dsa_free(area_, session_->preparedStatementHandle);
-               session_->preparedStatementHandle = InvalidDsaPointer;
-               PreparedStatement* preparedStatement = nullptr;
-               auto it = preparedStatements_.find(handle);
-               if (it == preparedStatements_.end())
+               struct Data {
+                       Executor* executor;
+                       const char* tag;
+               } data = {this, tag};
+               auto status = preparedStatement->select(
+                       [](void* data) {
+                               auto d = static_cast<Data*>(data);
+                               return d->executor->write(d->tag);
+                       },
+                       &data);
+               if (status.ok())
                {
-                       set_shared_string(session_->errorMessage,
-                                         std::string(Tag) + ": " + tag_ + ": " 
+ tag +
-                                             ": nonexistent handle: <" + 
handle + ">");
+                       signal_server(tag);
                }
                else
                {
-                       preparedStatement = &(it->second);
+                       set_error_message(std::string(Tag) + ": " + tag_ + ": " 
+ tag +
+                                             ": failed to select a prepared 
statement: <" + handle +
+                                             ">: " + status.ToString(),
+                                         tag);
                }
-               lock_release();
+       }
+
+       void update_prepared_statement()
+       {
+               const char* tag = "update prepared statement";
+
+               pgstat_report_activity(
+                       STATE_RUNNING, (std::string(Tag) + ": updating prepared 
statement").c_str());
+
+               std::string handle;
+               auto preparedStatement = find_prepared_statement(handle, tag);
                P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
 
-               if (preparedStatement)
+               if (!preparedStatement)
                {
-                       Transaction transaction;
-                       Snapshot snapshot;
-
-                       auto input = 
std::make_shared<SharedRingBufferInputStream>(this, session_);
-                       auto n_updated_records_result = 
preparedStatement->update(input);
-                       if (n_updated_records_result.ok())
-                       {
-                               session_->nUpdatedRecords = 
*n_updated_records_result;
-                       }
-                       else
-                       {
-                               set_shared_string(
-                                       session_->errorMessage,
-                                       std::string(Tag) + ": " + tag_ + ": " + 
tag +
-                                               ": failed to update a prepared 
statement: <" + handle +
-                                               ">: " + 
n_updated_records_result.status().ToString());
-                       }
+                       return;
                }
 
-               if (sharedData_->serverPID != InvalidPid)
+               ScopedTransaction scopedTransaction;
+               ScopedSnapshot scopedSnapshot;
+
+               auto input = 
std::make_shared<SharedRingBufferInputStream>(this, session_);
+               auto n_updated_records_result = 
preparedStatement->update(input);
+               if (n_updated_records_result.ok())
                {
-                       P("%s: %s: %s: kill server: %d", Tag, tag_, tag, 
sharedData_->serverPID);
-                       kill(sharedData_->serverPID, SIGUSR1);
+                       session_->nUpdatedRecords = *n_updated_records_result;
+                       signal_server(tag);
+               }
+               else
+               {
+                       set_error_message(std::string(Tag) + ": " + tag_ + ": " 
+ tag +
+                                             ": failed to update a prepared 
statement: <" + handle +
+                                             ">: " + 
n_updated_records_result.status().ToString(),
+                                         tag);
                }
-
-               pgstat_report_activity(STATE_IDLE, nullptr);
        }
 
        uint64_t sessionID_;
@@ -2018,6 +2158,97 @@ class Proxy : public WorkerProcessor {
                return arrow::Status::OK();
        }
 
+       arrow::Status set_parameters(uint64_t sessionID,
+                                    const std::string& handle,
+                                    arrow::flight::FlightMessageReader* reader,
+                                    arrow::flight::FlightMetadataWriter* 
writer)
+       {
+#ifdef AFS_DEBUG
+               const char* tag = "set parameters";
+#endif
+               auto session = find_session(sessionID);
+               SessionReleaser sessionReleaser(sessions_, session);
+               lock_acquire();
+               set_shared_string(session->preparedStatementHandle, handle);
+               session->action = Action::SetParameters;
+               lock_release();
+               if (session->executorPID != InvalidPid)
+               {
+                       P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, 
session->executorPID);
+                       kill(session->executorPID, SIGUSR1);
+               }
+               {
+                       ARROW_ASSIGN_OR_RAISE(const auto& schema, 
reader->GetSchema());
+                       SharedRingBufferOutputStream output(this, session);
+                       auto options = arrow::ipc::IpcWriteOptions::Defaults();
+                       options.emit_dictionary_deltas = true;
+                       ARROW_ASSIGN_OR_RAISE(auto writer,
+                                             
arrow::ipc::MakeStreamWriter(&output, schema, options));
+                       while (true)
+                       {
+                               ARROW_ASSIGN_OR_RAISE(const auto& chunk, 
reader->Next());
+                               if (!chunk.data)
+                               {
+                                       break;
+                               }
+                               
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*(chunk.data)));
+                       }
+                       ARROW_RETURN_NOT_OK(writer->Close());
+               }
+               if (session->executorPID != InvalidPid)
+               {
+                       P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, 
session->executorPID);
+                       kill(session->executorPID, SIGUSR1);
+               }
+               {
+                       auto buffer = create_shared_ring_buffer(session);
+                       std::unique_lock<std::mutex> lock(mutex_);
+                       conditionVariable_.wait(lock, [&] {
+                               P("%s: %s: %s: wait", Tag, tag_, tag);
+                               return DsaPointerIsValid(session->errorMessage) 
|| buffer.size() == 0;
+                       });
+               }
+               if (DsaPointerIsValid(session->errorMessage))
+               {
+                       return report_session_error(session);
+               }
+               P("%s: %s: %s: done", Tag, tag_, tag);
+               return arrow::Status::OK();
+       }
+
+       arrow::Result<std::shared_ptr<arrow::Schema>> select_prepared_statement(
+               uint64_t sessionID, const std::string& handle)
+       {
+               const char* tag = "select prepared statement";
+               auto session = find_session(sessionID);
+               SessionReleaser sessionReleaser(sessions_, session);
+               lock_acquire();
+               set_shared_string(session->preparedStatementHandle, handle);
+               session->action = Action::SelectPreparedStatement;
+               lock_release();
+               if (session->executorPID != InvalidPid)
+               {
+                       P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, 
session->executorPID);
+                       kill(session->executorPID, SIGUSR1);
+               }
+               {
+                       auto buffer = create_shared_ring_buffer(session);
+                       std::unique_lock<std::mutex> lock(mutex_);
+                       conditionVariable_.wait(lock, [&] {
+                               P("%s: %s: %s: wait", Tag, tag_, tag);
+                               return DsaPointerIsValid(session->errorMessage) 
|| buffer.size() > 0;
+                       });
+               }
+               if (DsaPointerIsValid(session->errorMessage))
+               {
+                       return report_session_error(session);
+               }
+               P("%s: %s: %s: open", Tag, tag_, tag);
+               auto schema = read_schema(session, tag);
+               P("%s: %s: %s: schema", Tag, tag_, tag);
+               return schema;
+       }
+
        arrow::Result<int64_t> update_prepared_statement(
                uint64_t sessionID,
                const std::string& handle,
@@ -2476,6 +2707,37 @@ class FlightSQLServer : public 
arrow::flight::sql::FlightSqlServerBase {
                return proxy_->close_prepared_statement(sessionID, handle);
        }
 
+       arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>>
+       GetFlightInfoPreparedStatement(
+               const arrow::flight::ServerCallContext& context,
+               const arrow::flight::sql::PreparedStatementQuery& command,
+               const arrow::flight::FlightDescriptor& descriptor) override
+       {
+               ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
+               const auto& handle = command.prepared_statement_handle;
+               ARROW_ASSIGN_OR_RAISE(auto schema,
+                                     
proxy_->select_prepared_statement(sessionID, handle));
+               ARROW_ASSIGN_OR_RAISE(auto ticket,
+                                     
arrow::flight::sql::CreateStatementQueryTicket(handle));
+               std::vector<arrow::flight::FlightEndpoint> endpoints{
+                       arrow::flight::FlightEndpoint{std::move(ticket), {}}};
+               ARROW_ASSIGN_OR_RAISE(
+                       auto result,
+                       arrow::flight::FlightInfo::Make(*schema, descriptor, 
endpoints, -1, -1));
+               return std::make_unique<arrow::flight::FlightInfo>(result);
+       }
+
+       arrow::Status DoPutPreparedStatementQuery(
+               const arrow::flight::ServerCallContext& context,
+               const arrow::flight::sql::PreparedStatementQuery& command,
+               arrow::flight::FlightMessageReader* reader,
+               arrow::flight::FlightMetadataWriter* writer) override
+       {
+               ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
+               const auto& handle = command.prepared_statement_handle;
+               return proxy_->set_parameters(sessionID, handle, reader, 
writer);
+       }
+
        arrow::Result<int64_t> DoPutPreparedStatementUpdate(
                const arrow::flight::ServerCallContext& context,
                const arrow::flight::sql::PreparedStatementUpdate& command,
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index 4465d56..429fa95 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -67,7 +67,7 @@ class FlightSQLTest < Test::Unit::TestCase
   data("string - varchar", ["varchar(10)", :string, "b"])
   data("binary", ["bytea", :binary, "\x0".b])
   data("timestamp", ["timestamp", [:timestamp, :micro], timestamp_value])
-  def test_select_type
+  def test_select_direct
     pg_type, data_type, value = data
     data_type = Arrow::DataType.resolve(data_type)
     values = data_type.build_array([value])
@@ -81,6 +81,35 @@ class FlightSQLTest < Test::Unit::TestCase
                  reader.read_all)
   end
 
+  data("int16",  [:int16, -2])
+  data("int32",  [:int32, -2])
+  data("int64",  [:int64, -2])
+  data("float",  [:float, -2.2])
+  data("double", [:double, -2.2])
+  data("string", [:string, "b"])
+  data("binary", [:binary, "\x0".b])
+  data("timestamp", [[:timestamp, :micro], timestamp_value])
+  def test_select_prepare
+    unless flight_sql_client.respond_to?(:prepare)
+      omit("red-arrow-flight-sql 14.0.0 or later is required")
+    end
+
+    data_type, value = data
+    data_type = Arrow::DataType.resolve(data_type)
+    values = data_type.build_array([value])
+    flight_sql_client.prepare("SELECT $1 AS value",
+                              @options) do |statement|
+      statement.record_batch = Arrow::RecordBatch.new(value: values)
+      info = statement.execute(@options)
+      assert_equal(Arrow::Schema.new(value: values.value_data_type),
+                   info.get_schema)
+      endpoint = info.endpoints.first
+      reader = flight_sql_client.do_get(endpoint.ticket, @options)
+      assert_equal(Arrow::Table.new(value: values),
+                   reader.read_all)
+    end
+  end
+
   def test_select_from
     run_sql("CREATE TABLE data (value integer)")
     run_sql("INSERT INTO data VALUES (1), (-2), (3)")
@@ -144,7 +173,7 @@ SELECT * FROM data
        ["timestamp", [:timestamp, :micro], timestamp_values])
   data("timestamp(nano)",
        ["timestamp", [:timestamp, :nano], timestamp_values])
-  def test_insert_type
+  def test_insert_prepare
     unless flight_sql_client.respond_to?(:prepare)
       omit("red-arrow-flight-sql 14.0.0 or later is required")
     end


Reply via email to