This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git
The following commit(s) were added to refs/heads/main by this push:
new 5c3876c Add support for session timeout (#16)
5c3876c is described below
commit 5c3876cb7c80cf3c3128f8bbfae9b55eea526802
Author: Sutou Kouhei <[email protected]>
AuthorDate: Sun Feb 12 16:13:09 2023 +0900
Add support for session timeout (#16)
Close GH-15
---
src/afs.cc | 97 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-------
1 file changed, 87 insertions(+), 10 deletions(-)
diff --git a/src/afs.cc b/src/afs.cc
index 4fd3f76..9e40639 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -51,14 +51,17 @@ extern "C"
}
namespace {
+static const char* LibraryName = "arrow_flight_sql";
+static const char* SharedDataName = "arrow-flight-sql: shared data";
static const char* Tag = "arrow-flight-sql";
+
static const char* URIDefault = "grpc://127.0.0.1:15432";
static char* URI;
-static volatile sig_atomic_t GotSIGTERM = false;
-static volatile sig_atomic_t GotSIGUSR1 = false;
-static const char* LibraryName = "arrow_flight_sql";
-static const char* SharedDataName = "arrow-flight-sql: shared data";
+static const int SessionTimeoutDefault = 300; // seconds; registered with GUC_UNIT_S in _PG_init
+static int SessionTimeout; // backs the arrow_flight_sql.session_timeout GUC; -1 means no timeout
+
+static volatile sig_atomic_t GotSIGTERM = false; // polled by the while-condition of afs_executor's main loop
void afs_sigterm(SIGNAL_ARGS)
{
auto errnoSaved = errno;
@@ -67,6 +70,7 @@ void afs_sigterm(SIGNAL_ARGS)
errno = errnoSaved;
}
+static volatile sig_atomic_t GotSIGUSR1 = false; // checked by afs_executor's loop after each ResetLatch()
void afs_sigusr1(SIGNAL_ARGS)
{
procsignal_sigusr1_handler(postgres_signal_arg);
@@ -76,12 +80,30 @@ void afs_sigusr1(SIGNAL_ARGS)
errno = errnoSaved;
}
+static shmem_request_hook_type PreviousShmemRequestHook = nullptr; // hook saved in _PG_init() before ours is installed, so it can be chained
+static const char* LWLockTrancheName = "arrow-flight-sql: lwlock tranche"; // later looked up with GetNamedLWLockTranche() by the processors
+void
+afs_shmem_request_hook(void) /* shmem_request_hook: chain the previous hook, then reserve our named LWLock tranche */
+{
+ if (PreviousShmemRequestHook) // non-null when a hook was already installed before _PG_init replaced it
+ PreviousShmemRequestHook();
+
+ RequestNamedLWLockTranche(LWLockTrancheName, 1); // a single lock; processors use element [0]
+}
+
+struct ConnectData { // connection parameters handed from the proxy to the executor through the DSA area
+ dsa_pointer databaseName; // NUL-terminated copy written by Proxy::connect; freed and reset to InvalidDsaPointer by Executor::open
+ dsa_pointer userName; // NOTE(review): only initialized to InvalidDsaPointer in this change; not populated yet
+ dsa_pointer password; // NOTE(review): only initialized to InvalidDsaPointer in this change; not populated yet
+};
+
struct SharedData {
dsa_handle handle;
+ LWLock* lock; // NOTE(review): no added line reads or writes this; processors use lock_ from GetNamedLWLockTranche() instead — candidate for removal
pid_t executorPID;
pid_t serverPID;
pid_t mainPID;
- Oid databaseOID;
+ ConnectData connectData; // written/consumed under lock_ by Proxy::connect and Executor::open; NOTE(review): process_connect_request reads it without the lock
};
class Processor {
@@ -94,6 +116,7 @@ class Processor {
const char* tag_;
SharedData* sharedData_;
dsa_area* area_;
+ LWLock* lock_;
};
class WorkerProcessor : public Processor {
@@ -110,6 +133,7 @@ class WorkerProcessor : public Processor {
elog(ERROR, "%s: %s: shared data isn't created yet",
Tag, tag_);
}
auto area = dsa_attach(sharedData->handle);
+ lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);
LWLockRelease(AddinShmemInitLock);
sharedData_ = sharedData;
area_ = area;
@@ -123,8 +147,15 @@ class Executor : public WorkerProcessor {
void open()
{
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
opening").c_str());
- BackgroundWorkerInitializeConnectionByOid(
- sharedData_->databaseOID, InvalidOid, 0);
+ LWLockAcquire(lock_, LW_EXCLUSIVE);
+ BackgroundWorkerInitializeConnection(
+ static_cast<const char*>(
+ dsa_get_address(area_,
sharedData_->connectData.databaseName)),
+ nullptr,
+ 0);
+ dsa_free(area_, sharedData_->connectData.databaseName);
+ sharedData_->connectData.databaseName = InvalidDsaPointer;
+ LWLockRelease(lock_);
StartTransactionCommand();
SPI_connect();
PushActiveSnapshot(GetTransactionSnapshot());
@@ -147,8 +178,15 @@ class Proxy : public WorkerProcessor {
public:
explicit Proxy() : WorkerProcessor("proxy") {}
- void connect()
+ void connect(const std::string& databaseName)
{
+ LWLockAcquire(lock_, LW_EXCLUSIVE);
+ sharedData_->connectData.databaseName =
+ dsa_allocate(area_, databaseName.size() + 1);
+ memcpy(dsa_get_address(area_,
sharedData_->connectData.databaseName),
+ databaseName.c_str(),
+ databaseName.size() + 1);
+ LWLockRelease(lock_);
kill(sharedData_->mainPID, SIGUSR1);
std::unique_lock<std::mutex> lock(mutex_);
condition_variable_.wait(lock,
@@ -184,6 +222,10 @@ class MainProcessor : public Processor {
sharedData->executorPID = InvalidPid;
sharedData->serverPID = InvalidPid;
sharedData->mainPID = MyProcPid;
+ sharedData->connectData.databaseName = InvalidDsaPointer;
+ sharedData->connectData.userName = InvalidDsaPointer;
+ sharedData->connectData.password = InvalidDsaPointer;
+ lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);
LWLockRelease(AddinShmemInitLock);
sharedData_ = sharedData;
area_ = area;
@@ -212,6 +254,11 @@ class MainProcessor : public Processor {
void process_connect_request()
{
+ if (!DsaPointerIsValid(sharedData_->connectData.databaseName))
+ {
+ return;
+ }
+
BackgroundWorker worker = {0};
snprintf(worker.bgw_name, BGW_MAXLEN, "%s: executor", Tag);
snprintf(worker.bgw_type, BGW_MAXLEN, Tag);
@@ -243,7 +290,8 @@ class AuthHandler : public arrow::flight::ServerAuthHandler
{
arrow::Status Authenticate(arrow::flight::ServerAuthSender* outgoing,
arrow::flight::ServerAuthReader* incoming)
override
{
- proxy_->connect();
+ std::string databaseName("postgres");
+ proxy_->connect(databaseName);
return arrow::Status::OK();
}
@@ -311,7 +359,19 @@ afs_executor(Datum arg)
executor.open();
while (!GotSIGTERM)
{
- WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH,
-1, PG_WAIT_EXTENSION);
+ int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH;
+ const long timeout = SessionTimeout * 1000;
+ if (timeout >= 0)
+ {
+ events |= WL_TIMEOUT;
+ }
+ int conditions = WaitLatch(MyLatch, events, timeout,
PG_WAIT_EXTENSION);
+
+ if (conditions & WL_TIMEOUT)
+ {
+ break;
+ }
+
ResetLatch(MyLatch);
if (GotSIGUSR1)
@@ -396,6 +456,23 @@ _PG_init(void)
NULL,
NULL);
+ DefineCustomIntVariable("arrow_flight_sql.session_timeout",
+ "Maximum session duration in seconds.",
+ "The default is 300 seconds. "
+ "-1 means no timeout.",
+ &SessionTimeout,
+ SessionTimeoutDefault,
+ -1,
+ INT_MAX,
+ PGC_SIGHUP,
+ GUC_UNIT_S,
+ NULL,
+ NULL,
+ NULL);
+
+ PreviousShmemRequestHook = shmem_request_hook;
+ shmem_request_hook = afs_shmem_request_hook;
+
BackgroundWorker worker = {0};
snprintf(worker.bgw_name, BGW_MAXLEN, "%s: main", Tag);
snprintf(worker.bgw_type, BGW_MAXLEN, Tag);