This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository
https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git
The following commit(s) were added to refs/heads/main by this push:
new 0b473e9 Add support for opening a database (#14)
0b473e9 is described below
commit 0b473e9e6e28dddc343378b686e086e64c159d67
Author: Sutou Kouhei <[email protected]>
AuthorDate: Fri Feb 10 15:26:41 2023 +0900
Add support for opening a database (#14)
---
src/afs.cc | 352 ++++++++++++++++++++++++++++++++++++++++++------
test/helper/sandbox.rb | 11 +-
test/test-flight-sql.rb | 4 +
3 files changed, 326 insertions(+), 41 deletions(-)
diff --git a/src/afs.cc b/src/afs.cc
index a1e016a..4fd3f76 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -19,93 +19,367 @@ extern "C"
{
#include <postgres.h>
+#include <access/xact.h>
+#include <executor/spi.h>
#include <fmgr.h>
#include <miscadmin.h>
#include <postmaster/bgworker.h>
#include <storage/ipc.h>
#include <storage/latch.h>
+#include <storage/lwlock.h>
+#include <storage/procsignal.h>
+#include <storage/shmem.h>
+#include <utils/backend_status.h>
+#include <utils/dsa.h>
#include <utils/guc.h>
+#include <utils/snapmgr.h>
#include <utils/wait_event.h>
}
#include <arrow/flight/sql/server.h>

#include <chrono>
#include <condition_variable>
#include <mutex>
#include <string>

// Symbols exported to PostgreSQL with C linkage. The worker entry points
// are looked up by name through BackgroundWorker::bgw_function_name
// ("afs_main", "afs_server", "afs_executor").
extern "C"
{
	PG_MODULE_MAGIC;

	// Module initializer; registers the GUC and the main background worker.
	extern PGDLLEXPORT void _PG_init(void);

	// Background worker entry points. None of them returns.
	extern PGDLLEXPORT void afs_main(Datum datum) pg_attribute_noreturn();
	extern PGDLLEXPORT void afs_server(Datum datum) pg_attribute_noreturn();
	extern PGDLLEXPORT void afs_executor(Datum datum) pg_attribute_noreturn();
}
-#define TAG "arrow-flight-sql"
-#define AFSURIDefault "grpc://127.0.0.1:15432"
-
namespace {
// Prefix used in log messages, worker names and worker types.
// The pointers are const as well as the pointees: none of these constants
// may be re-pointed at runtime.
static const char* const Tag = "arrow-flight-sql";
// Boot value of the arrow_flight_sql.uri GUC.
static const char* const URIDefault = "grpc://127.0.0.1:15432";
// Storage of the arrow_flight_sql.uri GUC (registered in _PG_init).
static char* URI;
// Flags set by the signal handlers and polled by the worker main loops.
static volatile sig_atomic_t GotSIGTERM = false;
static volatile sig_atomic_t GotSIGUSR1 = false;
// Shared library name used when registering background workers.
static const char* const LibraryName = "arrow_flight_sql";
// ShmemInitStruct() key of the SharedData struct used by all workers.
static const char* const SharedDataName = "arrow-flight-sql: shared data";
void afs_sigterm(SIGNAL_ARGS)
{
- auto save_errno = errno;
-
- AFSGotSIGTERM = true;
+ auto errnoSaved = errno;
+ GotSIGTERM = true;
SetLatch(MyLatch);
+ errno = errnoSaved;
+}
- errno = save_errno;
+void afs_sigusr1(SIGNAL_ARGS)
+{
+ procsignal_sigusr1_handler(postgres_signal_arg);
+ GotSIGUSR1 = true;
+ auto errnoSaved = errno;
+ SetLatch(MyLatch);
+ errno = errnoSaved;
}
-class PostgreSQLFlightSqlServer : public
arrow::flight::sql::FlightSqlServerBase {
+struct SharedData {
+ dsa_handle handle;
+ pid_t executorPID;
+ pid_t serverPID;
+ pid_t mainPID;
+ Oid databaseOID;
+};
+
+class Processor {
+ public:
+ Processor(const char* tag) : tag_(tag), sharedData_(nullptr),
area_(nullptr) {}
+
+ virtual ~Processor() { dsa_detach(area_); }
+
+ protected:
+ const char* tag_;
+ SharedData* sharedData_;
+ dsa_area* area_;
+};
+
+class WorkerProcessor : public Processor {
+ public:
+ explicit WorkerProcessor(const char* tag) : Processor(tag)
+ {
+ LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
+ bool found;
+ auto sharedData = static_cast<SharedData*>(
+ ShmemInitStruct(SharedDataName, sizeof(SharedData),
&found));
+ if (!found)
+ {
+ LWLockRelease(AddinShmemInitLock);
+ elog(ERROR, "%s: %s: shared data isn't created yet",
Tag, tag_);
+ }
+ auto area = dsa_attach(sharedData->handle);
+ LWLockRelease(AddinShmemInitLock);
+ sharedData_ = sharedData;
+ area_ = area;
+ }
+};
+
+class Executor : public WorkerProcessor {
+ public:
+ explicit Executor() : WorkerProcessor("executor") {}
+
+ void open()
+ {
+ pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
opening").c_str());
+ BackgroundWorkerInitializeConnectionByOid(
+ sharedData_->databaseOID, InvalidOid, 0);
+ StartTransactionCommand();
+ SPI_connect();
+ PushActiveSnapshot(GetTransactionSnapshot());
+ pgstat_report_activity(STATE_IDLE, NULL);
+ }
+
+ void close()
+ {
+ pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
closing").c_str());
+ PopActiveSnapshot();
+ SPI_finish();
+ CommitTransactionCommand();
+ pgstat_report_activity(STATE_IDLE, NULL);
+ }
+
+ void execute() {}
+};
+
+class Proxy : public WorkerProcessor {
+ public:
+ explicit Proxy() : WorkerProcessor("proxy") {}
+
+ void connect()
+ {
+ kill(sharedData_->mainPID, SIGUSR1);
+ std::unique_lock<std::mutex> lock(mutex_);
+ condition_variable_.wait(lock,
+ [&] { return sharedData_->executorPID
!= InvalidPid; });
+ }
+
+ void signaled()
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ condition_variable_.notify_all();
+ }
+
+ private:
+ std::mutex mutex_;
+ std::condition_variable condition_variable_;
+};
+
+class MainProcessor : public Processor {
+ public:
+ MainProcessor() : Processor("main")
+ {
+ LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
+ bool found;
+ auto sharedData = static_cast<SharedData*>(
+ ShmemInitStruct(SharedDataName, sizeof(SharedData),
&found));
+ if (found)
+ {
+ LWLockRelease(AddinShmemInitLock);
+ elog(ERROR, "%s: %s: shared data is already created",
Tag, tag_);
+ }
+ auto area = dsa_create(LWLockNewTrancheId());
+ sharedData->handle = dsa_get_handle(area);
+ sharedData->executorPID = InvalidPid;
+ sharedData->serverPID = InvalidPid;
+ sharedData->mainPID = MyProcPid;
+ LWLockRelease(AddinShmemInitLock);
+ sharedData_ = sharedData;
+ area_ = area;
+ }
+
+ BackgroundWorkerHandle* start_server()
+ {
+ BackgroundWorker worker = {0};
+ snprintf(worker.bgw_name, BGW_MAXLEN, "%s: server", Tag);
+ snprintf(worker.bgw_type, BGW_MAXLEN, Tag);
+ worker.bgw_flags = BGWORKER_SHMEM_ACCESS;
+ worker.bgw_start_time = BgWorkerStart_ConsistentState;
+ worker.bgw_restart_time = BGW_NEVER_RESTART;
+ snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s",
LibraryName);
+ snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_server");
+ worker.bgw_main_arg = 0;
+ worker.bgw_notify_pid = MyProcPid;
+ BackgroundWorkerHandle* handle;
+ if (!RegisterDynamicBackgroundWorker(&worker, &handle))
+ {
+ elog(ERROR, "%s: %s: failed to start server", Tag,
tag_);
+ }
+ WaitForBackgroundWorkerStartup(handle,
&(sharedData_->serverPID));
+ return handle;
+ }
+
+ void process_connect_request()
+ {
+ BackgroundWorker worker = {0};
+ snprintf(worker.bgw_name, BGW_MAXLEN, "%s: executor", Tag);
+ snprintf(worker.bgw_type, BGW_MAXLEN, Tag);
+ worker.bgw_flags = BGWORKER_SHMEM_ACCESS |
BGWORKER_BACKEND_DATABASE_CONNECTION;
+ worker.bgw_start_time = BgWorkerStart_ConsistentState;
+ worker.bgw_restart_time = BGW_NEVER_RESTART;
+ snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s",
LibraryName);
+ snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_executor");
+ worker.bgw_main_arg = 0;
+ worker.bgw_notify_pid = MyProcPid;
+ BackgroundWorkerHandle* handle;
+ if (!RegisterDynamicBackgroundWorker(&worker, &handle))
+ {
+ elog(ERROR, "%s: %s: failed to start executor", Tag,
tag_);
+ }
+ WaitForBackgroundWorkerStartup(handle,
&(sharedData_->executorPID));
+ kill(sharedData_->serverPID, SIGUSR1);
+ }
+};
+
+class AuthHandler : public arrow::flight::ServerAuthHandler {
+ public:
+ explicit AuthHandler(Proxy* proxy) :
arrow::flight::ServerAuthHandler(), proxy_(proxy)
+ {
+ }
+
+ ~AuthHandler() override {}
+
+ arrow::Status Authenticate(arrow::flight::ServerAuthSender* outgoing,
+ arrow::flight::ServerAuthReader* incoming)
override
+ {
+ proxy_->connect();
+ return arrow::Status::OK();
+ }
+
+ arrow::Status IsValid(const std::string& token, std::string*
peer_identity) override
+ {
+ *peer_identity = "postgres";
+ return arrow::Status::OK();
+ }
+
+ private:
+ Proxy* proxy_;
+};
+
+class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
public:
- PostgreSQLFlightSqlServer() : arrow::flight::sql::FlightSqlServerBase()
{}
- ~PostgreSQLFlightSqlServer() override {}
+ explicit FlightSQLServer(Proxy* proxy)
+ : arrow::flight::sql::FlightSqlServerBase(), proxy_(proxy)
+ {
+ }
+
+ ~FlightSQLServer() override {}
+
+ private:
+ Proxy* proxy_;
};
// Body of the server background worker: initialize the Flight SQL server on
// arrow_flight_sql.uri with AuthHandler, then sleep on the process latch,
// forwarding SIGUSR1 notifications to the proxy, until SIGTERM arrives.
// Finally shut the server down and return its status.
arrow::Status
afs_server_internal(Proxy* proxy)
{
	ARROW_ASSIGN_OR_RAISE(auto endpoint, arrow::flight::Location::Parse(URI));
	arrow::flight::FlightServerOptions options(endpoint);
	options.auth_handler = std::make_shared<AuthHandler>(proxy);

	FlightSQLServer server(proxy);
	ARROW_RETURN_NOT_OK(server.Init(options));

	for (;;)
	{
		if (GotSIGTERM)
		{
			break;
		}

		WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
		ResetLatch(MyLatch);

		// The main worker signals us once an executor is ready.
		if (GotSIGUSR1)
		{
			GotSIGUSR1 = false;
			proxy->signaled();
		}

		CHECK_FOR_INTERRUPTS();
	}

	// NOTE(review): a 10 microsecond grace period effectively cancels
	// in-flight RPCs immediately; confirm seconds were not intended.
	const auto shutdownDeadline =
		std::chrono::system_clock::now() + std::chrono::microseconds(10);
	return server.Shutdown(&shutdownDeadline);
}
} // namespace
extern "C" void
-afs_listener(Datum arg)
+afs_executor(Datum arg)
{
pqsignal(SIGTERM, afs_sigterm);
+ pqsignal(SIGUSR1, afs_sigusr1);
BackgroundWorkerUnblockSignals();
- auto status = afs_listen_internal();
- if (!status.ok())
{
- elog(ERROR, "%s: listener: failed: %s",
status.ToString().c_str());
+ Executor executor;
+ executor.open();
+ while (!GotSIGTERM)
+ {
+ WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH,
-1, PG_WAIT_EXTENSION);
+ ResetLatch(MyLatch);
+
+ if (GotSIGUSR1)
+ {
+ GotSIGUSR1 = false;
+ executor.execute();
+ }
+
+ CHECK_FOR_INTERRUPTS();
+ }
+ executor.close();
}
- proc_exit(1);
+ proc_exit(0);
}
extern "C" void
-_PG_init(void)
+afs_server(Datum arg)
{
- BackgroundWorker worker = {0};
+ pqsignal(SIGTERM, afs_sigterm);
+ pqsignal(SIGUSR1, afs_sigusr1);
+ BackgroundWorkerUnblockSignals();
+
+ {
+ Proxy proxy;
+ auto status = afs_server_internal(&proxy);
+ if (!status.ok())
+ {
+ elog(ERROR, "%s: server: failed: %s", Tag,
status.ToString().c_str());
+ }
+ }
+
+ proc_exit(0);
+}
+
+extern "C" void
+afs_main(Datum arg)
+{
+ pqsignal(SIGTERM, afs_sigterm);
+ pqsignal(SIGUSR1, afs_sigusr1);
+ BackgroundWorkerUnblockSignals();
+
+ {
+ MainProcessor processor;
+ auto serverHandle = processor.start_server();
+ while (!GotSIGTERM)
+ {
+ WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH,
-1, PG_WAIT_EXTENSION);
+ ResetLatch(MyLatch);
+
+ if (GotSIGUSR1)
+ {
+ GotSIGUSR1 = false;
+ processor.process_connect_request();
+ }
+
+ CHECK_FOR_INTERRUPTS();
+ }
+
+ WaitForBackgroundWorkerShutdown(serverHandle);
+ }
+
+ proc_exit(0);
+}
+extern "C" void
+_PG_init(void)
+{
if (!process_shared_preload_libraries_in_progress)
{
return;
@@ -113,24 +387,24 @@ _PG_init(void)
DefineCustomStringVariable("arrow_flight_sql.uri",
"Apache Arrow Flight SQL endpoint URI.",
- "default: " AFSURIDefault,
- &AFSURI,
- AFSURIDefault,
+ (std::string("default: ") +
URIDefault).c_str(),
+ &URI,
+ URIDefault,
PGC_USERSET,
0,
NULL,
NULL,
NULL);
- snprintf(worker.bgw_name, BGW_MAXLEN, TAG ": listener");
- snprintf(worker.bgw_type, BGW_MAXLEN, TAG);
+ BackgroundWorker worker = {0};
+ snprintf(worker.bgw_name, BGW_MAXLEN, "%s: main", Tag);
+ snprintf(worker.bgw_type, BGW_MAXLEN, Tag);
worker.bgw_flags = BGWORKER_SHMEM_ACCESS;
worker.bgw_start_time = BgWorkerStart_ConsistentState;
worker.bgw_restart_time = BGW_NEVER_RESTART;
- snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", AFSLibraryName);
- snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_listener");
+ snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
+ snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_main");
worker.bgw_main_arg = 0;
worker.bgw_notify_pid = 0;
-
RegisterBackgroundWorker(&worker);
}
diff --git a/test/helper/sandbox.rb b/test/helper/sandbox.rb
index 90bcf37..f7b0d4a 100644
--- a/test/helper/sandbox.rb
+++ b/test/helper/sandbox.rb
@@ -186,9 +186,12 @@ module Helper
[output, error]
end
# Apache Arrow Flight client for @flight_sql_uri, created on first use and
# reused afterwards (presumably so that authentication performed through it
# also applies to flight_sql_client, which wraps it).
def flight_client
  return @flight_client if @flight_client
  @flight_client = ArrowFlight::Client.new(@flight_sql_uri)
end

# Apache Arrow Flight SQL client wrapping flight_client; memoized.
def flight_sql_client
  return @flight_sql_client if @flight_sql_client
  @flight_sql_client = ArrowFlightSQL::Client.new(flight_client)
end
def read_log
@@ -231,6 +234,10 @@ module Helper
psql(@test_db_name, sql)
end
# Apache Arrow Flight client of the sandboxed PostgreSQL instance.
def flight_client
  @postgresql.flight_client
end

# Apache Arrow Flight SQL client of the sandboxed PostgreSQL instance.
def flight_sql_client
  @postgresql.flight_sql_client
end
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index e74723b..5407485 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -19,6 +19,10 @@ class FlightSQLTest < Test::Unit::TestCase
include Helper::Sandbox
def test_connect
+ unless flight_client.respond_to?(:authenticate_basic_token)
+ omit("red-flight-sql 12.0.0 or later is required")
+ end
+ flight_client.authenticate_basic_token(@postgresql.user, "password")
exception = assert_raise(Arrow::Error::NotImplemented) do
flight_sql_client.execute("SELECT 1")
end