This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git


The following commit(s) were added to refs/heads/main by this push:
     new 0b473e9  Add support for opening a database (#14)
0b473e9 is described below

commit 0b473e9e6e28dddc343378b686e086e64c159d67
Author: Sutou Kouhei <[email protected]>
AuthorDate: Fri Feb 10 15:26:41 2023 +0900

    Add support for opening a database (#14)
---
 src/afs.cc              | 352 ++++++++++++++++++++++++++++++++++++++++++------
 test/helper/sandbox.rb  |  11 +-
 test/test-flight-sql.rb |   4 +
 3 files changed, 326 insertions(+), 41 deletions(-)

diff --git a/src/afs.cc b/src/afs.cc
index a1e016a..4fd3f76 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -19,93 +19,367 @@ extern "C"
 {
 #include <postgres.h>
 
+#include <access/xact.h>
+#include <executor/spi.h>
 #include <fmgr.h>
 #include <miscadmin.h>
 #include <postmaster/bgworker.h>
 #include <storage/ipc.h>
 #include <storage/latch.h>
+#include <storage/lwlock.h>
+#include <storage/procsignal.h>
+#include <storage/shmem.h>
+#include <utils/backend_status.h>
+#include <utils/dsa.h>
 #include <utils/guc.h>
+#include <utils/snapmgr.h>
 #include <utils/wait_event.h>
 }
 
 #include <arrow/flight/sql/server.h>
 
+#include <condition_variable>
+
 extern "C"
 {
        PG_MODULE_MAGIC;
 
        extern PGDLLEXPORT void _PG_init(void);
-       extern PGDLLEXPORT void afs_listener(Datum datum) pg_attribute_noreturn();
+       extern PGDLLEXPORT void afs_executor(Datum datum) pg_attribute_noreturn();
+       extern PGDLLEXPORT void afs_server(Datum datum) pg_attribute_noreturn();
+       extern PGDLLEXPORT void afs_main(Datum datum) pg_attribute_noreturn();
 }
 
-#define TAG "arrow-flight-sql"
-#define AFSURIDefault "grpc://127.0.0.1:15432"
-
 namespace {
-static char* AFSURI;
-static volatile sig_atomic_t AFSGotSIGTERM = false;
-static const char* AFSLibraryName = "arrow_flight_sql";
+static const char* Tag = "arrow-flight-sql";
+static const char* URIDefault = "grpc://127.0.0.1:15432";
+static char* URI;
+static volatile sig_atomic_t GotSIGTERM = false;
+static volatile sig_atomic_t GotSIGUSR1 = false;
+static const char* LibraryName = "arrow_flight_sql";
+static const char* SharedDataName = "arrow-flight-sql: shared data";
 
 void afs_sigterm(SIGNAL_ARGS)
 {
-       auto save_errno = errno;
-
-       AFSGotSIGTERM = true;
+       auto errnoSaved = errno;
+       GotSIGTERM = true;
        SetLatch(MyLatch);
+       errno = errnoSaved;
+}
 
-       errno = save_errno;
+void afs_sigusr1(SIGNAL_ARGS)
+{
+       auto errnoSaved = errno;  // save before anything below can clobber errno
+       procsignal_sigusr1_handler(postgres_signal_arg);
+       GotSIGUSR1 = true;
+       SetLatch(MyLatch);
+       errno = errnoSaved;
 }
 
-class PostgreSQLFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {
+struct SharedData {
+       dsa_handle handle;
+       pid_t executorPID;
+       pid_t serverPID;
+       pid_t mainPID;
+       Oid databaseOID;
+};
+
+class Processor {
+   public:
+       explicit Processor(const char* tag) : tag_(tag), sharedData_(nullptr), area_(nullptr) {}
+
+       virtual ~Processor() { if (area_) dsa_detach(area_); }
+
+   protected:
+       const char* tag_;
+       SharedData* sharedData_;
+       dsa_area* area_;
+};
+
+class WorkerProcessor : public Processor {
+   public:
+       explicit WorkerProcessor(const char* tag) : Processor(tag)
+       {
+               LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
+               bool found;
+               auto sharedData = static_cast<SharedData*>(
+                       ShmemInitStruct(SharedDataName, sizeof(SharedData), &found));
+               if (!found)
+               {
+                       LWLockRelease(AddinShmemInitLock);
+                       elog(ERROR, "%s: %s: shared data isn't created yet", Tag, tag_);
+               }
+               auto area = dsa_attach(sharedData->handle);
+               LWLockRelease(AddinShmemInitLock);
+               sharedData_ = sharedData;
+               area_ = area;
+       }
+};
+
+class Executor : public WorkerProcessor {
+   public:
+       explicit Executor() : WorkerProcessor("executor") {}
+
+       void open()
+       {
+               pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": opening").c_str());
+               BackgroundWorkerInitializeConnectionByOid(
+                       sharedData_->databaseOID, InvalidOid, 0);
+               StartTransactionCommand();
+               SPI_connect();
+               PushActiveSnapshot(GetTransactionSnapshot());
+               pgstat_report_activity(STATE_IDLE, NULL);
+       }
+
+       void close()
+       {
+               pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": closing").c_str());
+               PopActiveSnapshot();
+               SPI_finish();
+               CommitTransactionCommand();
+               pgstat_report_activity(STATE_IDLE, NULL);
+       }
+
+       void execute() {}
+};
+
+class Proxy : public WorkerProcessor {
+   public:
+       explicit Proxy() : WorkerProcessor("proxy") {}
+
+       void connect()
+       {
+               kill(sharedData_->mainPID, SIGUSR1);
+               std::unique_lock<std::mutex> lock(mutex_);
+               condition_variable_.wait(lock,
+                                        [&] { return sharedData_->executorPID != InvalidPid; });
+       }
+
+       void signaled()
+       {
+               std::lock_guard<std::mutex> lock(mutex_);
+               condition_variable_.notify_all();
+       }
+
+   private:
+       std::mutex mutex_;
+       std::condition_variable condition_variable_;
+};
+
+class MainProcessor : public Processor {
+   public:
+       MainProcessor() : Processor("main")
+       {
+               LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
+               bool found;
+               auto sharedData = static_cast<SharedData*>(
+                       ShmemInitStruct(SharedDataName, sizeof(SharedData), &found));
+               if (found)
+               {
+                       LWLockRelease(AddinShmemInitLock);
+                       elog(ERROR, "%s: %s: shared data is already created", Tag, tag_);
+               }
+               auto area = dsa_create(LWLockNewTrancheId());
+               sharedData->handle = dsa_get_handle(area);
+               sharedData->executorPID = InvalidPid;
+               sharedData->serverPID = InvalidPid;
+               sharedData->mainPID = MyProcPid;
+               sharedData->databaseOID = InvalidOid;
+               LWLockRelease(AddinShmemInitLock);
+               sharedData_ = sharedData;
+               area_ = area;
+       }
+
+       BackgroundWorkerHandle* start_server()
+       {
+               BackgroundWorker worker = {0};
+               snprintf(worker.bgw_name, BGW_MAXLEN, "%s: server", Tag);
+               snprintf(worker.bgw_type, BGW_MAXLEN, "%s", Tag);
+               worker.bgw_flags = BGWORKER_SHMEM_ACCESS;
+               worker.bgw_start_time = BgWorkerStart_ConsistentState;
+               worker.bgw_restart_time = BGW_NEVER_RESTART;
+               snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
+               snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_server");
+               worker.bgw_main_arg = 0;
+               worker.bgw_notify_pid = MyProcPid;
+               BackgroundWorkerHandle* handle;
+               if (!RegisterDynamicBackgroundWorker(&worker, &handle))
+                       elog(ERROR, "%s: %s: failed to start server", Tag, tag_);
+               if (WaitForBackgroundWorkerStartup(handle, &(sharedData_->serverPID)) != BGWH_STARTED)
+                       elog(ERROR, "%s: %s: server didn't start", Tag, tag_);
+               return handle;
+       }
+
+       void process_connect_request()
+       {
+               BackgroundWorker worker = {0};
+               snprintf(worker.bgw_name, BGW_MAXLEN, "%s: executor", Tag);
+               snprintf(worker.bgw_type, BGW_MAXLEN, "%s", Tag);
+               worker.bgw_flags = BGWORKER_SHMEM_ACCESS | BGWORKER_BACKEND_DATABASE_CONNECTION;
+               worker.bgw_start_time = BgWorkerStart_ConsistentState;
+               worker.bgw_restart_time = BGW_NEVER_RESTART;
+               snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
+               snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_executor");
+               worker.bgw_main_arg = 0;
+               worker.bgw_notify_pid = MyProcPid;
+               BackgroundWorkerHandle* handle;
+               if (!RegisterDynamicBackgroundWorker(&worker, &handle))
+                       elog(ERROR, "%s: %s: failed to start executor", Tag, tag_);
+               if (WaitForBackgroundWorkerStartup(handle, &(sharedData_->executorPID)) != BGWH_STARTED)
+                       elog(ERROR, "%s: %s: executor didn't start", Tag, tag_);
+               // Wake the server so that it can notify clients waiting for an executor.
+               kill(sharedData_->serverPID, SIGUSR1);
+       }
+};
+
+class AuthHandler : public arrow::flight::ServerAuthHandler {
+   public:
+       explicit AuthHandler(Proxy* proxy) : arrow::flight::ServerAuthHandler(), proxy_(proxy)
+       {
+       }
+
+       ~AuthHandler() override {}
+
+       arrow::Status Authenticate(arrow::flight::ServerAuthSender* outgoing,
+                                  arrow::flight::ServerAuthReader* incoming) override
+       {
+               proxy_->connect();
+               return arrow::Status::OK();
+       }
+
+       arrow::Status IsValid(const std::string& token, std::string* peer_identity) override
+       {
+               *peer_identity = "postgres";  // NOTE(review): token is not checked yet
+               return arrow::Status::OK();
+       }
+
+   private:
+       Proxy* proxy_;
+};
+
+class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
    public:
-       PostgreSQLFlightSqlServer() : arrow::flight::sql::FlightSqlServerBase() {}
-       ~PostgreSQLFlightSqlServer() override {}
+       explicit FlightSQLServer(Proxy* proxy)
+               : arrow::flight::sql::FlightSqlServerBase(), proxy_(proxy)
+       {
+       }
+
+       ~FlightSQLServer() override {}
+
+   private:
+       Proxy* proxy_;
 };
 
 arrow::Status
-afs_listen_internal(void)
+afs_server_internal(Proxy* proxy)
 {
-       ARROW_ASSIGN_OR_RAISE(auto location, arrow::flight::Location::Parse(AFSURI));
+       ARROW_ASSIGN_OR_RAISE(auto location, arrow::flight::Location::Parse(URI));
        arrow::flight::FlightServerOptions options(location);
-       PostgreSQLFlightSqlServer server;
-       ARROW_RETURN_NOT_OK(server.Init(options));
+       options.auth_handler = std::make_shared<AuthHandler>(proxy);
+       FlightSQLServer flightSQLServer(proxy);
+       ARROW_RETURN_NOT_OK(flightSQLServer.Init(options));
 
-       while (!AFSGotSIGTERM)
+       while (!GotSIGTERM)
        {
-               WaitLatch(MyLatch,
-                         WL_LATCH_SET | WL_TIMEOUT | WL_EXIT_ON_PM_DEATH,
-                         0,
-                         PG_WAIT_EXTENSION);
+               WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
                ResetLatch(MyLatch);
 
+               if (GotSIGUSR1)
+               {
+                       GotSIGUSR1 = false;
+                       proxy->signaled();
+               }
+
                CHECK_FOR_INTERRUPTS();
        }
 
-       return server.Shutdown();
+       auto deadline = std::chrono::system_clock::now() + std::chrono::microseconds(10);
+       return flightSQLServer.Shutdown(&deadline);
 }
 
 }  // namespace
 
 extern "C" void
-afs_listener(Datum arg)
+afs_executor(Datum arg)
 {
        pqsignal(SIGTERM, afs_sigterm);
+       pqsignal(SIGUSR1, afs_sigusr1);
        BackgroundWorkerUnblockSignals();
 
-       auto status = afs_listen_internal();
-       if (!status.ok())
        {
-               elog(ERROR, "%s: listener: failed: %s", status.ToString().c_str());
+               Executor executor;
+               executor.open();
+               while (!GotSIGTERM)
+               {
+                       WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
+                       ResetLatch(MyLatch);
+
+                       if (GotSIGUSR1)
+                       {
+                               GotSIGUSR1 = false;
+                               executor.execute();
+                       }
+
+                       CHECK_FOR_INTERRUPTS();
+               }
+               executor.close();
        }
 
-       proc_exit(1);
+       proc_exit(0);
 }
 
 extern "C" void
-_PG_init(void)
+afs_server(Datum arg)
 {
-       BackgroundWorker worker = {0};
+       pqsignal(SIGTERM, afs_sigterm);
+       pqsignal(SIGUSR1, afs_sigusr1);
+       BackgroundWorkerUnblockSignals();
+
+       {
+               Proxy proxy;
+               auto status = afs_server_internal(&proxy);
+               if (!status.ok())
+               {
+                       elog(ERROR, "%s: server: failed: %s", Tag, status.ToString().c_str());
+               }
+       }
+
+       proc_exit(0);
+}
+
+extern "C" void
+afs_main(Datum arg)
+{
+       pqsignal(SIGTERM, afs_sigterm);
+       pqsignal(SIGUSR1, afs_sigusr1);
+       BackgroundWorkerUnblockSignals();
+
+       {
+               MainProcessor processor;
+               auto serverHandle = processor.start_server();
+               while (!GotSIGTERM)
+               {
+                       WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, -1, PG_WAIT_EXTENSION);
+                       ResetLatch(MyLatch);
+
+                       if (GotSIGUSR1)
+                       {
+                               GotSIGUSR1 = false;
+                               processor.process_connect_request();
+                       }
+
+                       CHECK_FOR_INTERRUPTS();
+               }
+
+               WaitForBackgroundWorkerShutdown(serverHandle);
+       }
+
+       proc_exit(0);
+}
 
+extern "C" void
+_PG_init(void)
+{
        if (!process_shared_preload_libraries_in_progress)
        {
                return;
@@ -113,24 +387,24 @@ _PG_init(void)
 
        DefineCustomStringVariable("arrow_flight_sql.uri",
                                   "Apache Arrow Flight SQL endpoint URI.",
-                                  "default: " AFSURIDefault,
-                                  &AFSURI,
-                                  AFSURIDefault,
+                                  "default: grpc://127.0.0.1:15432", /* must stay valid after _PG_init returns */
+                                  &URI,
+                                  URIDefault,
                                   PGC_USERSET,
                                   0,
                                   NULL,
                                   NULL,
                                   NULL);
 
-       snprintf(worker.bgw_name, BGW_MAXLEN, TAG ": listener");
-       snprintf(worker.bgw_type, BGW_MAXLEN, TAG);
+       BackgroundWorker worker = {0};
+       snprintf(worker.bgw_name, BGW_MAXLEN, "%s: main", Tag);
+       snprintf(worker.bgw_type, BGW_MAXLEN, "%s", Tag);
        worker.bgw_flags = BGWORKER_SHMEM_ACCESS;
        worker.bgw_start_time = BgWorkerStart_ConsistentState;
        worker.bgw_restart_time = BGW_NEVER_RESTART;
-       snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", AFSLibraryName);
-       snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_listener");
+       snprintf(worker.bgw_library_name, BGW_MAXLEN, "%s", LibraryName);
+       snprintf(worker.bgw_function_name, BGW_MAXLEN, "afs_main");
        worker.bgw_main_arg = 0;
        worker.bgw_notify_pid = 0;
-
        RegisterBackgroundWorker(&worker);
 }
diff --git a/test/helper/sandbox.rb b/test/helper/sandbox.rb
index 90bcf37..f7b0d4a 100644
--- a/test/helper/sandbox.rb
+++ b/test/helper/sandbox.rb
@@ -186,9 +186,12 @@ module Helper
       [output, error]
     end
 
+    def flight_client
+      @flight_client ||= ArrowFlight::Client.new(@flight_sql_uri) # memoized: one client per instance
+    end
+
     def flight_sql_client
-      client = ArrowFlight::Client.new(@flight_sql_uri)
-      ArrowFlightSQL::Client.new(client)
+      @flight_sql_client ||= ArrowFlightSQL::Client.new(flight_client) # built on the shared, memoized flight_client
     end
 
     def read_log
@@ -231,6 +234,10 @@ module Helper
       psql(@test_db_name, sql)
     end
 
+    def flight_client
+      @postgresql.flight_client # delegate to @postgresql
+    end
+
     def flight_sql_client
       @postgresql.flight_sql_client
     end
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index e74723b..5407485 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -19,6 +19,10 @@ class FlightSQLTest < Test::Unit::TestCase
   include Helper::Sandbox
 
   def test_connect
+    unless flight_client.respond_to?(:authenticate_basic_token)
+      omit("red-flight-sql 12.0.0 or later is required")
+    end
+    flight_client.authenticate_basic_token(@postgresql.user, "password")
     exception = assert_raise(Arrow::Error::NotImplemented) do
       flight_sql_client.execute("SELECT 1")
     end

Reply via email to