This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository 
https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git


The following commit(s) were added to refs/heads/main by this push:
     new d4c7753  Add missing error check on waiting read/write (#174)
d4c7753 is described below

commit d4c77538c657b9f33afa3c6a660d6cc6d98054be
Author: Sutou Kouhei <[email protected]>
AuthorDate: Mon Nov 20 16:30:19 2023 +0900

    Add missing error check on waiting read/write (#174)
    
    Closes GH-173
---
 src/afs.cc | 481 ++++++++++++++++++++++++++++++++++---------------------------
 1 file changed, 270 insertions(+), 211 deletions(-)

diff --git a/src/afs.cc b/src/afs.cc
index 867596c..d564fbc 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -471,7 +471,8 @@ struct LocalSessionData {
                  valid(false),
                  peerPID(InvalidPid),
                  bufferData(nullptr),
-                 bufferAddress(nullptr)
+                 bufferAddress(nullptr),
+                 errorMessage(std::nullopt)
        {
        }
 
@@ -480,6 +481,7 @@ struct LocalSessionData {
        pid_t peerPID;
        SharedRingBufferData* bufferData;
        void* bufferAddress;
+       std::optional<std::string> errorMessage;
 };
 
 // Session data shared with multiple processes. LWLockAcquire() with
@@ -1471,8 +1473,8 @@ class Executor : public WorkerProcessor {
        void open()
        {
                const char* tag = "open";
-               // pg_usleep(5000000);
-               // pg_usleep(5000000);
+               // Use this when you want to attach to an executor process.
+               // sleep(5);
                pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": opening").c_str());
                auto session = find_session();
                PG_TRY();
@@ -2435,10 +2437,60 @@ class Proxy : public WorkerProcessor {
                process_update_prepared_statement_requests();
                process_read_requests();
                process_write_requests();
-               delete_finished_sessions();
+               sync_sessions();
                conditionVariable_.notify_all();
        }
 
+       // Don't use PostgreSQL API.
+       bool is_valid_session(uint64_t sessionID) { return 
find_session(sessionID).ok(); }
+
+   private:
+       // Don't use PostgreSQL API.
+       arrow::Result<std::shared_ptr<LocalSessionData>> find_session(uint64_t 
id)
+       {
+               std::lock_guard<std::mutex> lock(mutex_);
+               auto it = localSessions_.find(id);
+               if (it == localSessions_.end())
+               {
+                       return arrow::Status::Invalid("Unknown session: ", id);
+               }
+               return it->second;
+       }
+
+       // Don't use PostgreSQL API.
+       arrow::Status check_local_session_error(LocalSessionData* localSession,
+                                               bool locked = false)
+       {
+               if (!localSession->errorMessage.has_value())
+               {
+                       return arrow::Status::OK();
+               }
+               if (locked)
+               {
+                       auto errorMessage = 
std::move(localSession->errorMessage.value());
+                       localSession->errorMessage = std::nullopt;
+                       return arrow::Status::Invalid(errorMessage);
+               }
+               else
+               {
+                       std::lock_guard<std::mutex> lock(mutex_);
+                       if (!localSession->errorMessage.has_value())
+                       {
+                               return arrow::Status::OK();
+                       }
+                       auto errorMessage = 
std::move(localSession->errorMessage.value());
+                       localSession->errorMessage = std::nullopt;
+                       return arrow::Status::Invalid(errorMessage);
+               }
+       }
+
+       // Don't use PostgreSQL API.
+       arrow::Status check_local_session_error(
+               const std::shared_ptr<LocalSessionData>& localSession)
+       {
+               return check_local_session_error(localSession.get());
+       }
+
    private:
        struct ConnectRequest {
                ConnectRequest(uint64_t sessionID,
@@ -2452,8 +2504,7 @@ class Proxy : public WorkerProcessor {
                          password(std::move(password)),
                          clientAddress(std::move(clientAddress)),
                          processing(false),
-                         finished(false),
-                         errorMessage(std::nullopt)
+                         finished(false)
                {
                }
                uint64_t sessionID;
@@ -2463,7 +2514,6 @@ class Proxy : public WorkerProcessor {
                std::string clientAddress;
                bool processing;
                bool finished;
-               std::optional<std::string> errorMessage;
        };
 
        std::list<std::shared_ptr<ConnectRequest>> connectRequests_;
@@ -2487,27 +2537,18 @@ class Proxy : public WorkerProcessor {
                                        continue;
                                }
                                const auto initialized = session->initialized;
-                               if (initialized)
+                               if (initialized && 
!DsaPointerIsValid(session->errorMessage))
                                {
-                                       if 
(DsaPointerIsValid(session->errorMessage))
-                                       {
-                                               request->errorMessage = 
static_cast<const char*>(
-                                                       dsa_get_address(area_, 
session->errorMessage));
-                                       }
-                                       else
-                                       {
-                                               auto& localSession =
-                                                       
localSessions_.find(request->sessionID)->second;
-                                               localSession->valid = true;
-                                               localSession->peerPID = 
session->executorPID;
-                                               localSession->bufferData = 
&(session->bufferData);
-                                               localSession->bufferAddress =
-                                                       dsa_get_address(area_, 
session->bufferData.pointer);
-                                               P("%s: %s: connect: %" PRIu64,
-                                                 Tag,
-                                                 tag_,
-                                                 session->bufferData.pointer);
-                                       }
+                                       auto& localSession = 
localSessions_.find(request->sessionID)->second;
+                                       localSession->valid = true;
+                                       localSession->peerPID = 
session->executorPID;
+                                       localSession->bufferData = 
&(session->bufferData);
+                                       localSession->bufferAddress =
+                                               dsa_get_address(area_, 
session->bufferData.pointer);
+                                       P("%s: %s: connect: %" PRIu64,
+                                         Tag,
+                                         tag_,
+                                         session->bufferData.pointer);
                                }
                                dshash_release_lock(sessions_, session);
                                if (!initialized)
@@ -2526,8 +2567,9 @@ class Proxy : public WorkerProcessor {
                                if (found)
                                {
                                        request->finished = true;
-                                       request->errorMessage = 
std::string("duplicated session ID: ") +
-                                                               
std::to_string(request->sessionID);
+                                       auto& localSession = 
localSessions_.find(request->sessionID)->second;
+                                       localSession->errorMessage = 
std::string("duplicated session ID: ") +
+                                                                    
std::to_string(request->sessionID);
                                        it = connectRequests_.erase(it);
                                }
                                else
@@ -2583,6 +2625,7 @@ class Proxy : public WorkerProcessor {
                                        const std::string& clientAddress)
        {
                auto id = assign_session_id();
+               ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(id));
                auto request = std::make_shared<ConnectRequest>(
                        id, databaseName, userName, password, clientAddress);
                {
@@ -2593,6 +2636,10 @@ class Proxy : public WorkerProcessor {
                {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
+                               if (localSession->errorMessage.has_value())
+                               {
+                                       return true;
+                               }
                                if (INTERRUPTS_PENDING_CONDITION())
                                {
                                        return true;
@@ -2600,10 +2647,7 @@ class Proxy : public WorkerProcessor {
                                return request->finished;
                        });
                }
-               if (request->errorMessage.has_value())
-               {
-                       return 
arrow::Status::Invalid(request->errorMessage.value());
-               }
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::Invalid("interrupted");
@@ -2611,31 +2655,17 @@ class Proxy : public WorkerProcessor {
                return id;
        }
 
-       // Don't use PostgreSQL API.
-       bool is_valid_session(uint64_t sessionID)
-       {
-               std::lock_guard<std::mutex> lock(mutex_);
-               auto it = localSessions_.find(sessionID);
-               if (it == localSessions_.end())
-               {
-                       return false;
-               }
-               return it->second->valid;
-       }
-
    private:
        struct SelectRequest {
-               SelectRequest(uint64_t sessionID, std::string query)
-                       : sessionID(sessionID),
+               SelectRequest(std::shared_ptr<LocalSessionData> localSession, 
std::string query)
+                       : localSession(std::move(localSession)),
                          query(std::move(query)),
-                         finished(false),
-                         errorMessage(std::nullopt)
+                         finished(false)
                {
                }
-               uint64_t sessionID;
+               std::shared_ptr<LocalSessionData> localSession;
                std::string query;
                bool finished;
-               std::optional<std::string> errorMessage;
        };
 
        std::list<std::shared_ptr<SelectRequest>> selectRequests_;
@@ -2651,11 +2681,12 @@ class Proxy : public WorkerProcessor {
                {
                        request->finished = true;
                        auto session = static_cast<SharedSessionData*>(
-                               dshash_find(sessions_, &(request->sessionID), 
false));
+                               dshash_find(sessions_, 
&(request->localSession->id), false));
                        if (!session)
                        {
-                               request->errorMessage =
-                                       std::string("stolen session: ") + 
std::to_string(request->sessionID);
+                               request->localSession->errorMessage =
+                                       std::string("stolen session: ") +
+                                       
std::to_string(request->localSession->id);
                                continue;
                        }
                        auto executorPID = session->executorPID;
@@ -2671,15 +2702,11 @@ class Proxy : public WorkerProcessor {
                selectRequests_.clear();
        }
 
-       arrow::Result<std::shared_ptr<arrow::Schema>> read_schema(uint64_t 
sessionID,
-                                                                 const char* 
tag)
+       arrow::Result<std::shared_ptr<arrow::Schema>> read_schema(
+               std::shared_ptr<LocalSessionData> localSession, const char* tag)
        {
-               auto it = localSessions_.find(sessionID);
-               if (it == localSessions_.end())
-               {
-                       return arrow::Status::Invalid("Unknown session: ", 
sessionID);
-               }
-               auto input = 
std::make_shared<SharedRingBufferInputStream>(this, it->second);
+               auto input =
+                       std::make_shared<SharedRingBufferInputStream>(this, 
std::move(localSession));
 
                // Read schema only stream format data.
                ARROW_ASSIGN_OR_RAISE(auto reader,
@@ -2702,7 +2729,8 @@ class Proxy : public WorkerProcessor {
                                                             const std::string& 
query)
        {
                const char* tag = "select";
-               auto request = std::make_shared<SelectRequest>(sessionID, 
query);
+               ARROW_ASSIGN_OR_RAISE(auto localSession, 
find_session(sessionID));
+               auto request = std::make_shared<SelectRequest>(localSession, 
query);
                {
                        std::lock_guard<std::mutex> lock(mutex_);
                        selectRequests_.push_back(request);
@@ -2711,6 +2739,10 @@ class Proxy : public WorkerProcessor {
                {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
+                               if (localSession->errorMessage.has_value())
+                               {
+                                       return true;
+                               }
                                if (INTERRUPTS_PENDING_CONDITION())
                                {
                                        return true;
@@ -2718,37 +2750,32 @@ class Proxy : public WorkerProcessor {
                                return request->finished;
                        });
                }
-               if (request->errorMessage.has_value())
-               {
-                       return 
arrow::Status::Invalid(request->errorMessage.value());
-               }
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::Invalid("interrupted");
                }
                P("%s: %s: %s: open", Tag, tag_, tag);
-               auto schema = read_schema(sessionID, tag);
+               auto schema = read_schema(localSession, tag);
                P("%s: %s: %s: schema", Tag, tag_, tag);
                return schema;
        }
 
    private:
        struct UpdateRequest {
-               UpdateRequest(uint64_t sessionID, std::string query)
-                       : sessionID(sessionID),
+               UpdateRequest(std::shared_ptr<LocalSessionData> localSession, 
std::string query)
+                       : localSession(std::move(localSession)),
                          query(std::move(query)),
                          processing(false),
                          finished(false),
-                         nUpdatedRecords(0),
-                         errorMessage(std::nullopt)
+                         nUpdatedRecords(0)
                {
                }
-               uint64_t sessionID;
+               std::shared_ptr<LocalSessionData> localSession;
                std::string query;
                bool processing;
                bool finished;
                int64_t nUpdatedRecords;
-               std::optional<std::string> errorMessage;
        };
 
        std::list<std::shared_ptr<UpdateRequest>> updateRequests_;
@@ -2764,12 +2791,13 @@ class Proxy : public WorkerProcessor {
                {
                        auto& request = *it;
                        auto session = static_cast<SharedSessionData*>(
-                               dshash_find(sessions_, &(request->sessionID), 
false));
+                               dshash_find(sessions_, 
&(request->localSession->id), false));
                        if (!session)
                        {
                                request->finished = true;
-                               request->errorMessage =
-                                       std::string("stolen session: ") + 
std::to_string(request->sessionID);
+                               request->localSession->errorMessage =
+                                       std::string("stolen session: ") +
+                                       
std::to_string(request->localSession->id);
                                it = updateRequests_.erase(it);
                                continue;
                        }
@@ -2780,8 +2808,6 @@ class Proxy : public WorkerProcessor {
                                {
                                        request->processing = false;
                                        request->finished = true;
-                                       request->errorMessage = 
static_cast<const char*>(
-                                               dsa_get_address(area_, 
session->errorMessage));
                                        it = updateRequests_.erase(it);
                                }
                                else if (session->nUpdatedRecords >= 0)
@@ -2820,7 +2846,8 @@ class Proxy : public WorkerProcessor {
 #ifdef AFS_VERBOSE
                const char* tag = "update";
 #endif
-               auto request = std::make_shared<UpdateRequest>(sessionID, 
query);
+               ARROW_ASSIGN_OR_RAISE(auto localSession, 
find_session(sessionID));
+               auto request = std::make_shared<UpdateRequest>(localSession, 
query);
                {
                        std::lock_guard<std::mutex> lock(mutex_);
                        updateRequests_.push_back(request);
@@ -2829,6 +2856,10 @@ class Proxy : public WorkerProcessor {
                {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
+                               if (localSession->errorMessage.has_value())
+                               {
+                                       return true;
+                               }
                                if (INTERRUPTS_PENDING_CONDITION())
                                {
                                        return true;
@@ -2836,10 +2867,7 @@ class Proxy : public WorkerProcessor {
                                return request->finished;
                        });
                }
-               if (request->errorMessage.has_value())
-               {
-                       return 
arrow::Status::Invalid(request->errorMessage.value());
-               }
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::Invalid("interrupted");
@@ -2850,21 +2878,19 @@ class Proxy : public WorkerProcessor {
 
    private:
        struct PrepareRequest {
-               PrepareRequest(uint64_t sessionID, std::string query)
-                       : sessionID(sessionID),
+               PrepareRequest(std::shared_ptr<LocalSessionData> localSession, 
std::string query)
+                       : localSession(std::move(localSession)),
                          query(std::move(query)),
                          processing(false),
                          finished(false),
-                         handle(),
-                         errorMessage(std::nullopt)
+                         handle()
                {
                }
-               uint64_t sessionID;
+               std::shared_ptr<LocalSessionData> localSession;
                std::string query;
                bool processing;
                bool finished;
                std::string handle;
-               std::optional<std::string> errorMessage;
        };
 
        std::list<std::shared_ptr<PrepareRequest>> prepareRequests_;
@@ -2880,12 +2906,13 @@ class Proxy : public WorkerProcessor {
                {
                        auto& request = *it;
                        auto session = static_cast<SharedSessionData*>(
-                               dshash_find(sessions_, &(request->sessionID), 
false));
+                               dshash_find(sessions_, 
&(request->localSession->id), false));
                        if (!session)
                        {
                                request->finished = true;
-                               request->errorMessage =
-                                       std::string("stolen session: ") + 
std::to_string(request->sessionID);
+                               request->localSession->errorMessage =
+                                       std::string("stolen session: ") +
+                                       
std::to_string(request->localSession->id);
                                it = prepareRequests_.erase(it);
                                continue;
                        }
@@ -2896,8 +2923,6 @@ class Proxy : public WorkerProcessor {
                                {
                                        request->processing = false;
                                        request->finished = true;
-                                       request->errorMessage = 
static_cast<const char*>(
-                                               dsa_get_address(area_, 
session->errorMessage));
                                        it = prepareRequests_.erase(it);
                                }
                                else if 
(DsaPointerIsValid(session->preparedStatementHandle))
@@ -2939,7 +2964,8 @@ class Proxy : public WorkerProcessor {
 #ifdef AFS_VERBOSE
                const char* tag = "prepare";
 #endif
-               auto request = std::make_shared<PrepareRequest>(sessionID, 
query);
+               ARROW_ASSIGN_OR_RAISE(auto localSession, 
find_session(sessionID));
+               auto request = std::make_shared<PrepareRequest>(localSession, 
query);
                {
                        std::lock_guard<std::mutex> lock(mutex_);
                        prepareRequests_.push_back(request);
@@ -2948,6 +2974,10 @@ class Proxy : public WorkerProcessor {
                {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
+                               if (localSession->errorMessage.has_value())
+                               {
+                                       return true;
+                               }
                                if (INTERRUPTS_PENDING_CONDITION())
                                {
                                        return true;
@@ -2955,10 +2985,7 @@ class Proxy : public WorkerProcessor {
                                return request->finished;
                        });
                }
-               if (request->errorMessage.has_value())
-               {
-                       return 
arrow::Status::Invalid(request->errorMessage.value());
-               }
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::Invalid("interrupted");
@@ -2974,19 +3001,18 @@ class Proxy : public WorkerProcessor {
 
    private:
        struct ClosePreparedStatementRequest {
-               ClosePreparedStatementRequest(uint64_t sessionID, std::string 
handle)
-                       : sessionID(sessionID),
+               ClosePreparedStatementRequest(std::shared_ptr<LocalSessionData> 
localSession,
+                                             std::string handle)
+                       : localSession(std::move(localSession)),
                          handle(std::move(handle)),
                          processing(false),
-                         finished(false),
-                         errorMessage(std::nullopt)
+                         finished(false)
                {
                }
-               uint64_t sessionID;
+               std::shared_ptr<LocalSessionData> localSession;
                std::string handle;
                bool processing;
                bool finished;
-               std::optional<std::string> errorMessage;
        };
 
        std::list<std::shared_ptr<ClosePreparedStatementRequest>>
@@ -3004,12 +3030,13 @@ class Proxy : public WorkerProcessor {
                {
                        auto& request = *it;
                        auto session = static_cast<SharedSessionData*>(
-                               dshash_find(sessions_, &(request->sessionID), 
false));
+                               dshash_find(sessions_, 
&(request->localSession->id), false));
                        if (!session)
                        {
                                request->finished = true;
-                               request->errorMessage =
-                                       std::string("stolen session: ") + 
std::to_string(request->sessionID);
+                               request->localSession->errorMessage =
+                                       std::string("stolen session: ") +
+                                       
std::to_string(request->localSession->id);
                                it = closePreparedStatementRequests_.erase(it);
                                continue;
                        }
@@ -3020,8 +3047,6 @@ class Proxy : public WorkerProcessor {
                                {
                                        request->processing = false;
                                        request->finished = true;
-                                       request->errorMessage = 
static_cast<const char*>(
-                                               dsa_get_address(area_, 
session->errorMessage));
                                        it = 
closePreparedStatementRequests_.erase(it);
                                }
                                else if 
(DsaPointerIsValid(session->preparedStatementHandle))
@@ -3058,7 +3083,9 @@ class Proxy : public WorkerProcessor {
 #ifdef AFS_VERBOSE
                const char* tag = "close prepared statement";
 #endif
-               auto request = 
std::make_shared<ClosePreparedStatementRequest>(sessionID, handle);
+               ARROW_ASSIGN_OR_RAISE(auto localSession, 
find_session(sessionID));
+               auto request =
+                       
std::make_shared<ClosePreparedStatementRequest>(localSession, handle);
                {
                        std::lock_guard<std::mutex> lock(mutex_);
                        closePreparedStatementRequests_.push_back(request);
@@ -3067,6 +3094,10 @@ class Proxy : public WorkerProcessor {
                {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
+                               if (localSession->errorMessage.has_value())
+                               {
+                                       return true;
+                               }
                                if (INTERRUPTS_PENDING_CONDITION())
                                {
                                        return true;
@@ -3074,10 +3105,7 @@ class Proxy : public WorkerProcessor {
                                return request->finished;
                        });
                }
-               if (request->errorMessage.has_value())
-               {
-                       return 
arrow::Status::Invalid(request->errorMessage.value());
-               }
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::Invalid("interrupted");
@@ -3088,19 +3116,18 @@ class Proxy : public WorkerProcessor {
 
    private:
        struct SetParametersRequest {
-               SetParametersRequest(uint64_t sessionID, std::string handle)
-                       : sessionID(sessionID),
+               SetParametersRequest(std::shared_ptr<LocalSessionData> 
localSession,
+                                    std::string handle)
+                       : localSession(std::move(localSession)),
                          handle(std::move(handle)),
                          processing(false),
-                         finished(false),
-                         errorMessage(std::nullopt)
+                         finished(false)
                {
                }
-               uint64_t sessionID;
+               std::shared_ptr<LocalSessionData> localSession;
                std::string handle;
                bool processing;
                bool finished;
-               std::optional<std::string> errorMessage;
        };
 
        std::list<std::shared_ptr<SetParametersRequest>> setParametersRequests_;
@@ -3117,12 +3144,13 @@ class Proxy : public WorkerProcessor {
                {
                        auto& request = *it;
                        auto session = static_cast<SharedSessionData*>(
-                               dshash_find(sessions_, &(request->sessionID), 
false));
+                               dshash_find(sessions_, 
&(request->localSession->id), false));
                        if (!session)
                        {
                                request->finished = true;
-                               request->errorMessage =
-                                       std::string("stolen session: ") + 
std::to_string(request->sessionID);
+                               request->localSession->errorMessage =
+                                       std::string("stolen session: ") +
+                                       
std::to_string(request->localSession->id);
                                it = setParametersRequests_.erase(it);
                                continue;
                        }
@@ -3133,8 +3161,6 @@ class Proxy : public WorkerProcessor {
                                {
                                        request->processing = false;
                                        request->finished = true;
-                                       request->errorMessage = 
static_cast<const char*>(
-                                               dsa_get_address(area_, 
session->errorMessage));
                                        it = setParametersRequests_.erase(it);
                                }
                                else if (session->setParametersFinished)
@@ -3175,17 +3201,17 @@ class Proxy : public WorkerProcessor {
 #ifdef AFS_VERBOSE
                const char* tag = "set parameters";
 #endif
-               auto request = std::make_shared<SetParametersRequest>(sessionID, handle);
+               ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(sessionID));
+               auto request = std::make_shared<SetParametersRequest>(localSession, handle);
                {
                        std::lock_guard<std::mutex> lock(mutex_);
                        setParametersRequests_.push_back(request);
                }
                kill(MyProcPid, SIGUSR1);
-               auto session = localSessions_.find(sessionID)->second;
-               auto executorPID = session->peerPID;
+               auto executorPID = localSession->peerPID;
                {
                        ARROW_ASSIGN_OR_RAISE(const auto& schema, reader->GetSchema());
-                       SharedRingBufferOutputStream output(this, std::move(session));
+                       SharedRingBufferOutputStream output(this, localSession);
                        auto options = arrow::ipc::IpcWriteOptions::Defaults();
                        options.emit_dictionary_deltas = true;
                        ARROW_ASSIGN_OR_RAISE(auto writer,
@@ -3206,6 +3232,10 @@ class Proxy : public WorkerProcessor {
                {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
+                               if (localSession->errorMessage.has_value())
+                               {
+                                       return true;
+                               }
                                if (INTERRUPTS_PENDING_CONDITION())
                                {
                                        return true;
@@ -3213,10 +3243,7 @@ class Proxy : public WorkerProcessor {
                                return request->finished;
                        });
                }
-               if (request->errorMessage.has_value())
-               {
-                       return arrow::Status::Invalid(request->errorMessage.value());
-               }
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::Invalid("interrupted");
@@ -3227,17 +3254,16 @@ class Proxy : public WorkerProcessor {
 
    private:
        struct SelectPreparedStatementRequest {
-               SelectPreparedStatementRequest(uint64_t sessionID, std::string handle)
-                       : sessionID(sessionID),
+               SelectPreparedStatementRequest(std::shared_ptr<LocalSessionData> localSession,
+                                              std::string handle)
+                       : localSession(std::move(localSession)),
                          handle(std::move(handle)),
-                         finished(false),
-                         errorMessage(std::nullopt)
+                         finished(false)
                {
                }
-               uint64_t sessionID;
+               std::shared_ptr<LocalSessionData> localSession;
                std::string handle;
                bool finished;
-               std::optional<std::string> errorMessage;
        };
 
        std::list<std::shared_ptr<SelectPreparedStatementRequest>>
@@ -3254,11 +3280,12 @@ class Proxy : public WorkerProcessor {
                {
                        request->finished = true;
                        auto session = static_cast<SharedSessionData*>(
-                               dshash_find(sessions_, &(request->sessionID), false));
+                               dshash_find(sessions_, &(request->localSession->id), false));
                        if (!session)
                        {
-                               request->errorMessage =
-                                       std::string("stolen session: ") + std::to_string(request->sessionID);
+                               request->localSession->errorMessage =
+                                       std::string("stolen session: ") +
+                                       std::to_string(request->localSession->id);
                                continue;
                        }
                        auto executorPID = session->executorPID;
@@ -3279,8 +3306,9 @@ class Proxy : public WorkerProcessor {
                uint64_t sessionID, const std::string& handle)
        {
                const char* tag = "select prepared statement";
+               ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(sessionID));
                auto request =
-                       std::make_shared<SelectPreparedStatementRequest>(sessionID, handle);
+                       std::make_shared<SelectPreparedStatementRequest>(localSession, handle);
                {
                        std::lock_guard<std::mutex> lock(mutex_);
                        selectPreparedStatementRequests_.push_back(request);
@@ -3289,6 +3317,10 @@ class Proxy : public WorkerProcessor {
                {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
+                               if (localSession->errorMessage.has_value())
+                               {
+                                       return true;
+                               }
                                if (INTERRUPTS_PENDING_CONDITION())
                                {
                                        return true;
@@ -3296,37 +3328,33 @@ class Proxy : public WorkerProcessor {
                                return request->finished;
                        });
                }
-               if (request->errorMessage.has_value())
-               {
-                       return arrow::Status::Invalid(request->errorMessage.value());
-               }
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::Invalid("interrupted");
                }
                P("%s: %s: %s: open", Tag, tag_, tag);
-               auto schema = read_schema(sessionID, tag);
+               auto schema = read_schema(localSession, tag);
                P("%s: %s: %s: schema", Tag, tag_, tag);
                return schema;
        }
 
    private:
        struct UpdatePreparedStatementRequest {
-               UpdatePreparedStatementRequest(uint64_t sessionID, std::string handle)
-                       : sessionID(sessionID),
+               UpdatePreparedStatementRequest(std::shared_ptr<LocalSessionData> localSession,
+                                              std::string handle)
+                       : localSession(std::move(localSession)),
                          handle(std::move(handle)),
                          processing(false),
                          finished(false),
-                         nUpdatedRecords(0),
-                         errorMessage(std::nullopt)
+                         nUpdatedRecords(0)
                {
                }
-               uint64_t sessionID;
+               std::shared_ptr<LocalSessionData> localSession;
                std::string handle;
                bool processing;
                bool finished;
                int64_t nUpdatedRecords;
-               std::optional<std::string> errorMessage;
        };
 
        std::list<std::shared_ptr<UpdatePreparedStatementRequest>>
@@ -3344,12 +3372,13 @@ class Proxy : public WorkerProcessor {
                {
                        auto& request = *it;
                        auto session = static_cast<SharedSessionData*>(
-                               dshash_find(sessions_, &(request->sessionID), false));
+                               dshash_find(sessions_, &(request->localSession->id), false));
                        if (!session)
                        {
                                request->finished = true;
-                               request->errorMessage =
-                                       std::string("stolen session: ") + std::to_string(request->sessionID);
+                               request->localSession->errorMessage =
+                                       std::string("stolen session: ") +
+                                       std::to_string(request->localSession->id);
                                it = updatePreparedStatementRequests_.erase(it);
                                continue;
                        }
@@ -3360,8 +3389,6 @@ class Proxy : public WorkerProcessor {
                                {
                                        request->processing = false;
                                        request->finished = true;
-                                       request->errorMessage = static_cast<const char*>(
-                                               dsa_get_address(area_, session->errorMessage));
                                        it = updatePreparedStatementRequests_.erase(it);
                                }
                                else if (session->nUpdatedRecords >= 0)
@@ -3403,18 +3430,18 @@ class Proxy : public WorkerProcessor {
 #ifdef AFS_VERBOSE
                const char* tag = "update prepared statement";
 #endif
+               ARROW_ASSIGN_OR_RAISE(auto localSession, find_session(sessionID));
                auto request =
-                       std::make_shared<UpdatePreparedStatementRequest>(sessionID, handle);
+                       std::make_shared<UpdatePreparedStatementRequest>(localSession, handle);
                {
                        std::lock_guard<std::mutex> lock(mutex_);
                        updatePreparedStatementRequests_.push_back(request);
                }
                kill(MyProcPid, SIGUSR1);
-               auto session = localSessions_.find(sessionID)->second;
-               auto executorPID = session->peerPID;
+               auto executorPID = localSession->peerPID;
                {
                        ARROW_ASSIGN_OR_RAISE(const auto& schema, reader->GetSchema());
-                       SharedRingBufferOutputStream output(this, std::move(session));
+                       SharedRingBufferOutputStream output(this, localSession);
                        auto options = arrow::ipc::IpcWriteOptions::Defaults();
                        options.emit_dictionary_deltas = true;
                        ARROW_ASSIGN_OR_RAISE(auto writer,
@@ -3435,6 +3462,10 @@ class Proxy : public WorkerProcessor {
                {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
+                               if (localSession->errorMessage.has_value())
+                               {
+                                       return true;
+                               }
                                if (INTERRUPTS_PENDING_CONDITION())
                                {
                                        return true;
@@ -3442,10 +3473,7 @@ class Proxy : public WorkerProcessor {
                                return request->finished;
                        });
                }
-               if (request->errorMessage.has_value())
-               {
-                       return arrow::Status::Invalid(request->errorMessage.value());
-               }
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::Invalid("interrupted");
@@ -3456,11 +3484,11 @@ class Proxy : public WorkerProcessor {
 
    private:
        struct ReadRequest {
-               ReadRequest(LocalSessionData* session,
+               ReadRequest(LocalSessionData* localSession,
                            SharedRingBuffer* buffer,
                            size_t n,
                            void* output)
-                       : session(session),
+                       : localSession(localSession),
                          buffer(buffer),
                          n(n),
                          output(output),
@@ -3468,7 +3496,7 @@ class Proxy : public WorkerProcessor {
                          readBytes(0)
                {
                }
-               LocalSessionData* session;
+               LocalSessionData* localSession;
                SharedRingBuffer* buffer;
                size_t n;
                void* output;
@@ -3492,20 +3520,24 @@ class Proxy : public WorkerProcessor {
                                ProcessorLockGuard lock(this);
                                request->readBytes = request->buffer->read(request->n, request->output);
                        }
-                       P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, request->session->peerPID);
-                       kill(request->session->peerPID, SIGUSR1);
+                       P("%s: %s: %s: kill executor: %d",
+                         Tag,
+                         tag_,
+                         tag,
+                         request->localSession->peerPID);
+                       kill(request->localSession->peerPID, SIGUSR1);
                }
                readRequests_.clear();
        }
 
    public:
        // Don't use PostgreSQL API.
-       arrow::Result<size_t> read(LocalSessionData* session,
+       arrow::Result<size_t> read(LocalSessionData* localSession,
                                   SharedRingBuffer* buffer,
                                   size_t n,
                                   void* output) override
        {
-               auto request = std::make_shared<ReadRequest>(session, buffer, n, output);
+               auto request = std::make_shared<ReadRequest>(localSession, buffer, n, output);
                {
                        std::lock_guard<std::mutex> lock(mutex_);
                        readRequests_.push_back(request);
@@ -3514,6 +3546,10 @@ class Proxy : public WorkerProcessor {
                {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
+                               if (localSession->errorMessage.has_value())
+                               {
+                                       return true;
+                               }
                                if (INTERRUPTS_PENDING_CONDITION())
                                {
                                        return true;
@@ -3521,6 +3557,7 @@ class Proxy : public WorkerProcessor {
                                return request->finished;
                        });
                }
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::IOError("interrupted");
@@ -3530,26 +3567,24 @@ class Proxy : public WorkerProcessor {
 
    private:
        struct WriteRequest {
-               WriteRequest(LocalSessionData* session,
+               WriteRequest(LocalSessionData* localSession,
                             SharedRingBuffer* buffer,
                             const void* data,
                             size_t n)
-                       : session(session),
+                       : localSession(localSession),
                          buffer(buffer),
                          data(data),
                          n(n),
                          finished(false),
-                         writtenBytes(0),
-                         errorMessage(std::nullopt)
+                         writtenBytes(0)
                {
                }
-               LocalSessionData* session;
+               LocalSessionData* localSession;
                SharedRingBuffer* buffer;
                const void* data;
                size_t n;
                bool finished;
                size_t writtenBytes;
-               std::optional<std::string> errorMessage;
        };
 
        std::list<std::shared_ptr<WriteRequest>> writeRequests_;
@@ -3564,22 +3599,18 @@ class Proxy : public WorkerProcessor {
                for (auto& request : writeRequests_)
                {
                        auto session = static_cast<SharedSessionData*>(
-                               dshash_find(sessions_, &(request->session->id), false));
+                               dshash_find(sessions_, &(request->localSession->id), false));
                        if (!session)
                        {
                                request->finished = true;
-                               request->errorMessage = std::string("stolen session: ") +
-                                                       std::to_string(request->session->id);
+                               request->localSession->errorMessage =
+                                       std::string("stolen session: ") +
+                                       std::to_string(request->localSession->id);
                                continue;
                        }
                        {
                                SharedSessionReleaser sessionReleaser(sessions_, session);
-                               if (DsaPointerIsValid(session->errorMessage))
-                               {
-                                       request->errorMessage = static_cast<const char*>(
-                                               dsa_get_address(area_, session->errorMessage));
-                               }
-                               else
+                               if (!DsaPointerIsValid(session->errorMessage))
                                {
                                        ProcessorLockGuard lock(this);
                                        request->writtenBytes =
@@ -3587,20 +3618,24 @@ class Proxy : public WorkerProcessor {
                                }
                                request->finished = true;
                        }
-                       P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, request->session->peerPID);
-                       kill(request->session->peerPID, SIGUSR1);
+                       P("%s: %s: %s: kill executor: %d",
+                         Tag,
+                         tag_,
+                         tag,
+                         request->localSession->peerPID);
+                       kill(request->localSession->peerPID, SIGUSR1);
                }
                writeRequests_.clear();
        }
 
    public:
        // Don't use PostgreSQL API.
-       arrow::Result<size_t> write(LocalSessionData* session,
+       arrow::Result<size_t> write(LocalSessionData* localSession,
                                    SharedRingBuffer* buffer,
                                    const void* data,
                                    size_t n) override
        {
-               auto request = std::make_shared<WriteRequest>(session, buffer, data, n);
+               auto request = std::make_shared<WriteRequest>(localSession, buffer, data, n);
                {
                        std::lock_guard<std::mutex> lock(mutex_);
                        writeRequests_.push_back(request);
@@ -3609,6 +3644,10 @@ class Proxy : public WorkerProcessor {
                {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
+                               if (localSession->errorMessage.has_value())
+                               {
+                                       return true;
+                               }
                                if (INTERRUPTS_PENDING_CONDITION())
                                {
                                        return true;
@@ -3616,10 +3655,7 @@ class Proxy : public WorkerProcessor {
                                return request->finished;
                        });
                }
-               if (request->errorMessage.has_value())
-               {
-                       return arrow::Status::IOError(request->errorMessage.value());
-               }
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::IOError("interrupted");
@@ -3629,24 +3665,39 @@ class Proxy : public WorkerProcessor {
 
    private:
        // Can use PostgreSQL API.
-       void delete_finished_sessions()
+       void sync_sessions()
        {
                dshash_seq_status sessionsStatus;
                dshash_seq_init(&sessionsStatus, sessions_, false);
-               SharedSessionData* session;
-               while (
-                       (session = static_cast<SharedSessionData*>(dshash_seq_next(&sessionsStatus))))
+               while (true)
                {
-                       if (!session->finished)
+                       auto session =
+                               static_cast<SharedSessionData*>(dshash_seq_next(&sessionsStatus));
+                       if (!session)
                        {
-                               continue;
+                               break;
                        }
+                       if (DsaPointerIsValid(session->errorMessage))
                        {
                                std::lock_guard<std::mutex> lock(mutex_);
-                               localSessions_.erase(localSessions_.find(session->id));
+                               auto localSession = localSessions_.find(session->id)->second;
+                               if (localSession)
+                               {
+                                       localSession->errorMessage = static_cast<const char*>(
+                                               dsa_get_address(area_, session->errorMessage));
+                               }
+                               dsa_free(area_, session->errorMessage);
+                               session->errorMessage = InvalidDsaPointer;
+                       }
+                       if (session->finished)
+                       {
+                               {
+                                       std::lock_guard<std::mutex> lock(mutex_);
+                                       localSessions_.erase(localSessions_.find(session->id));
+                               }
+                               shared_session_data_finalize(session, area_);
+                               dshash_delete_current(&sessionsStatus);
                        }
-                       shared_session_data_finalize(session, area_);
-                       dshash_delete_current(&sessionsStatus);
                }
                dshash_seq_term(&sessionsStatus);
        }
@@ -3656,25 +3707,33 @@ class Proxy : public WorkerProcessor {
 
        const char* peer_name() override { return "executor"; }
 
-       arrow::Status wait_internal(LocalSessionData* session,
+       arrow::Status wait_internal(LocalSessionData* localSession,
                                    SharedRingBuffer* buffer,
                                    WaitMode mode,
                                    const char* tag) override
        {
                std::unique_lock<std::mutex> lock(mutex_);
                conditionVariable_.wait(lock, [&] {
-                       P("%s: %s: %s: %s: wait: %" PRIsize,
+                       P("%s: %s: %s: %s: wait: %" PRIsize ": error: %s",
                          Tag,
                          tag_,
                          tag,
                          peer_name(),
-                         get_waiting_buffer_size(buffer, mode));
+                         get_waiting_buffer_size(buffer, mode),
+                         localSession->errorMessage.has_value()
+                             ? localSession->errorMessage.value().c_str()
+                             : "");
+                       if (localSession->errorMessage.has_value())
+                       {
+                               return true;
+                       }
                        if (INTERRUPTS_PENDING_CONDITION())
                        {
                                return true;
                        }
                        return get_waiting_buffer_size(buffer, mode) > 0;
                });
+               ARROW_RETURN_NOT_OK(check_local_session_error(localSession, true));
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::Invalid(tag_, ": ", tag, ": ", "interrupted");


Reply via email to