This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository 
https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c3876c  Add support for session timeout (#16)
5c3876c is described below

commit 5c3876cb7c80cf3c3128f8bbfae9b55eea526802
Author: Sutou Kouhei <[email protected]>
AuthorDate: Sun Feb 12 16:13:09 2023 +0900

    Add support for session timeout (#16)
    
    Close GH-15
---
 src/afs.cc | 97 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 87 insertions(+), 10 deletions(-)

diff --git a/src/afs.cc b/src/afs.cc
index 4fd3f76..9e40639 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -51,14 +51,17 @@ extern "C"
 }
 
 namespace {
+static const char* LibraryName = "arrow_flight_sql";
+static const char* SharedDataName = "arrow-flight-sql: shared data";
 static const char* Tag = "arrow-flight-sql";
+
 static const char* URIDefault = "grpc://127.0.0.1:15432";
 static char* URI;
-static volatile sig_atomic_t GotSIGTERM = false;
-static volatile sig_atomic_t GotSIGUSR1 = false;
-static const char* LibraryName = "arrow_flight_sql";
-static const char* SharedDataName = "arrow-flight-sql: shared data";
 
+static const int SessionTimeoutDefault = 300;
+static int SessionTimeout;
+
+static volatile sig_atomic_t GotSIGTERM = false;
 void afs_sigterm(SIGNAL_ARGS)
 {
        auto errnoSaved = errno;
@@ -67,6 +70,7 @@ void afs_sigterm(SIGNAL_ARGS)
        errno = errnoSaved;
 }
 
+static volatile sig_atomic_t GotSIGUSR1 = false;
 void afs_sigusr1(SIGNAL_ARGS)
 {
        procsignal_sigusr1_handler(postgres_signal_arg);
@@ -76,12 +80,30 @@ void afs_sigusr1(SIGNAL_ARGS)
        errno = errnoSaved;
 }
 
+static shmem_request_hook_type PreviousShmemRequestHook = nullptr;
+static const char* LWLockTrancheName = "arrow-flight-sql: lwlock tranche";
+void
+afs_shmem_request_hook(void)
+{
+       if (PreviousShmemRequestHook)
+               PreviousShmemRequestHook();
+
+       RequestNamedLWLockTranche(LWLockTrancheName, 1);
+}
+
+struct ConnectData {
+       dsa_pointer databaseName;
+       dsa_pointer userName;
+       dsa_pointer password;
+};
+
 struct SharedData {
        dsa_handle handle;
+       LWLock* lock;
        pid_t executorPID;
        pid_t serverPID;
        pid_t mainPID;
-       Oid databaseOID;
+       ConnectData connectData;
 };
 
 class Processor {
@@ -94,6 +116,7 @@ class Processor {
        const char* tag_;
        SharedData* sharedData_;
        dsa_area* area_;
+       LWLock* lock_;
 };
 
 class WorkerProcessor : public Processor {
@@ -110,6 +133,7 @@ class WorkerProcessor : public Processor {
                        elog(ERROR, "%s: %s: shared data isn't created yet", Tag, tag_);
                }
                auto area = dsa_attach(sharedData->handle);
+               lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);
                LWLockRelease(AddinShmemInitLock);
                sharedData_ = sharedData;
                area_ = area;
@@ -123,8 +147,15 @@ class Executor : public WorkerProcessor {
        void open()
        {
                pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": opening").c_str());
-               BackgroundWorkerInitializeConnectionByOid(
-                       sharedData_->databaseOID, InvalidOid, 0);
+               LWLockAcquire(lock_, LW_EXCLUSIVE);
+               BackgroundWorkerInitializeConnection(
+                       static_cast<const char*>(
+                               dsa_get_address(area_, sharedData_->connectData.databaseName)),
+                       nullptr,
+                       0);
+               dsa_free(area_, sharedData_->connectData.databaseName);
+               sharedData_->connectData.databaseName = InvalidDsaPointer;
+               LWLockRelease(lock_);
                StartTransactionCommand();
                SPI_connect();
                PushActiveSnapshot(GetTransactionSnapshot());
@@ -147,8 +178,15 @@ class Proxy : public WorkerProcessor {
    public:
        explicit Proxy() : WorkerProcessor("proxy") {}
 
-       void connect()
+       void connect(const std::string& databaseName)
        {
+               LWLockAcquire(lock_, LW_EXCLUSIVE);
+               sharedData_->connectData.databaseName =
+                       dsa_allocate(area_, databaseName.size() + 1);
+               memcpy(dsa_get_address(area_, sharedData_->connectData.databaseName),
+                      databaseName.c_str(),
+                      databaseName.size() + 1);
+               LWLockRelease(lock_);
                kill(sharedData_->mainPID, SIGUSR1);
                std::unique_lock<std::mutex> lock(mutex_);
                condition_variable_.wait(lock,
@@ -184,6 +222,10 @@ class MainProcessor : public Processor {
                sharedData->executorPID = InvalidPid;
                sharedData->serverPID = InvalidPid;
                sharedData->mainPID = MyProcPid;
+               sharedData->connectData.databaseName = InvalidDsaPointer;
+               sharedData->connectData.userName = InvalidDsaPointer;
+               sharedData->connectData.password = InvalidDsaPointer;
+               lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);
                LWLockRelease(AddinShmemInitLock);
                sharedData_ = sharedData;
                area_ = area;
@@ -212,6 +254,11 @@ class MainProcessor : public Processor {
 
        void process_connect_request()
        {
+               if (!DsaPointerIsValid(sharedData_->connectData.databaseName))
+               {
+                       return;
+               }
+
                BackgroundWorker worker = {0};
                snprintf(worker.bgw_name, BGW_MAXLEN, "%s: executor", Tag);
                snprintf(worker.bgw_type, BGW_MAXLEN, Tag);
@@ -243,7 +290,8 @@ class AuthHandler : public arrow::flight::ServerAuthHandler {
        arrow::Status Authenticate(arrow::flight::ServerAuthSender* outgoing,
                                   arrow::flight::ServerAuthReader* incoming) override
        {
-               proxy_->connect();
+               std::string databaseName("postgres");
+               proxy_->connect(databaseName);
                return arrow::Status::OK();
        }
 
@@ -311,7 +359,19 @@ afs_executor(Datum arg)
                executor.open();
                while (!GotSIGTERM)
                {
-                       WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
+                       int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH;
+                       const long timeout = SessionTimeout * 1000;
+                       if (timeout >= 0)
+                       {
+                               events |= WL_TIMEOUT;
+                       }
+                       int conditions = WaitLatch(MyLatch, events, timeout, PG_WAIT_EXTENSION);
+
+                       if (conditions & WL_TIMEOUT)
+                       {
+                               break;
+                       }
+
                        ResetLatch(MyLatch);
 
                        if (GotSIGUSR1)
@@ -396,6 +456,23 @@ _PG_init(void)
                                   NULL,
                                   NULL);
 
+       DefineCustomIntVariable("arrow_flight_sql.session_timeout",
+                               "Maximum session duration in seconds.",
+                               "The default is 300 seconds. "
+                               "-1 means no timeout.",
+                               &SessionTimeout,
+                               SessionTimeoutDefault,
+                               -1,
+                               INT_MAX,
+                               PGC_SIGHUP,
+                               GUC_UNIT_S,
+                               NULL,
+                               NULL,
+                               NULL);
+
+       PreviousShmemRequestHook = shmem_request_hook;
+       shmem_request_hook = afs_shmem_request_hook;
+
        BackgroundWorker worker = {0};
        snprintf(worker.bgw_name, BGW_MAXLEN, "%s: main", Tag);
        snprintf(worker.bgw_type, BGW_MAXLEN, Tag);

Reply via email to