This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository 
https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git


The following commit(s) were added to refs/heads/main by this push:
     new 5416129  Add support for authentication (#39)
5416129 is described below

commit 5416129f580c6a9ebd2db74266b7de057f3a6f4a
Author: Sutou Kouhei <[email protected]>
AuthorDate: Mon May 8 21:19:26 2023 +0900

    Add support for authentication (#39)
    
    Closes GH-18
    
    This only supports the "password" and "trust" authentication methods for now.
---
 .github/workflows/test.yaml |   1 +
 src/afs.cc                  | 374 ++++++++++++++++++++++++++++++++++++--------
 test/helper/sandbox.rb      |  34 +++-
 test/test-flight-sql.rb     |   2 +-
 4 files changed, 340 insertions(+), 71 deletions(-)

diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 0b8436f..74ab677 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -61,6 +61,7 @@ jobs:
           wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc 
| sudo apt-key add -
           sudo apt update
           sudo apt -y -V -t ${suite} install \
+            libkrb5-dev \
             postgresql-${{ matrix.postgresql-version }} \
             postgresql-server-dev-${{ matrix.postgresql-version }}
       - name: Install Apache Arrow Flight SQL adapter
diff --git a/src/afs.cc b/src/afs.cc
index b33546f..92b7e10 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -23,6 +23,8 @@ extern "C"
 #include <executor/spi.h>
 #include <fmgr.h>
 #include <lib/dshash.h>
+#include <libpq/crypt.h>
+#include <libpq/libpq-be.h>
 #include <miscadmin.h>
 #include <postmaster/bgworker.h>
 #include <storage/ipc.h>
@@ -33,6 +35,7 @@ extern "C"
 #include <utils/backend_status.h>
 #include <utils/dsa.h>
 #include <utils/guc.h>
+#include <utils/memutils.h>
 #include <utils/snapmgr.h>
 #include <utils/wait_event.h>
 }
@@ -54,6 +57,8 @@ extern "C"
 #include <random>
 #include <sstream>
 
+#include <arpa/inet.h>
+
 #ifdef __GNUC__
 #      define AFS_FUNC __PRETTY_FUNCTION__
 #else
@@ -129,6 +134,25 @@ afs_shmem_request_hook(void)
        RequestNamedLWLockTranche(LWLockTrancheName, 1);
 }
 
+class ScopedMemoryContext {
+   public:
+       explicit ScopedMemoryContext(MemoryContext memoryContext)
+               : memoryContext_(memoryContext), oldMemoryContext_(nullptr)
+       {
+               oldMemoryContext_ = MemoryContextSwitchTo(memoryContext_);
+       }
+
+       ~ScopedMemoryContext()
+       {
+               MemoryContextSwitchTo(oldMemoryContext_);
+               MemoryContextDelete(memoryContext_);
+       }
+
+   private:
+       MemoryContext memoryContext_;
+       MemoryContext oldMemoryContext_;
+};
+
 struct SharedRingBufferData {
        dsa_pointer pointer;
        size_t total;
@@ -304,12 +328,13 @@ class SharedRingBuffer {
 
 struct SessionData {
        uint64_t id;
-       arrow::Status status;
+       dsa_pointer errorMessage;
        pid_t executorPID;
        bool initialized;
        dsa_pointer databaseName;
        dsa_pointer userName;
        dsa_pointer password;
+       dsa_pointer clientAddress;
        dsa_pointer query;
        SharedRingBufferData bufferData;
 };
@@ -377,6 +402,21 @@ class Processor {
        }
 
    protected:
+       void set_shared_string(dsa_pointer& pointer, const std::string& input)
+       {
+               if (DsaPointerIsValid(pointer))
+               {
+                       dsa_free(area_, pointer);
+                       pointer = InvalidDsaPointer;
+               }
+               if (input.empty())
+               {
+                       return;
+               }
+               pointer = dsa_allocate(area_, input.size() + 1);
+               memcpy(dsa_get_address(area_, pointer), input.c_str(), 
input.size() + 1);
+       }
+
        const char* tag_;
        SharedData* sharedData_;
        dsa_area* area_;
@@ -481,7 +521,8 @@ class WorkerProcessor : public Processor {
    protected:
        void delete_session(SessionData* session)
        {
-               session->status.~Status();
+               if (DsaPointerIsValid(session->errorMessage))
+                       dsa_free(area_, session->errorMessage);
                if (DsaPointerIsValid(session->databaseName))
                        dsa_free(area_, session->databaseName);
                if (DsaPointerIsValid(session->userName))
@@ -504,20 +545,42 @@ class Executor : public WorkerProcessor {
                : WorkerProcessor("executor"),
                  sessionID_(sessionID),
                  session_(nullptr),
-                 connected_(false)
+                 connected_(false),
+                 closed_(false)
        {
        }
 
+       ~Executor()
+       {
+               if (!closed_)
+               {
+                       close_internal(false);
+               }
+       }
+
        void open()
        {
+               // pg_usleep(5000000);
+               // pg_usleep(5000000);
                pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": 
opening").c_str());
                session_ = static_cast<SessionData*>(dshash_find(sessions_, 
&sessionID_, false));
                auto databaseName =
                        static_cast<const char*>(dsa_get_address(area_, 
session_->databaseName));
                auto userName =
                        static_cast<const char*>(dsa_get_address(area_, 
session_->userName));
-               // TODO: Check password. See src/backend/libpq/auth.c
+               auto password =
+                       static_cast<const char*>(dsa_get_address(area_, 
session_->password));
+               auto clientAddress =
+                       static_cast<const char*>(dsa_get_address(area_, 
session_->clientAddress));
                BackgroundWorkerInitializeConnection(databaseName, userName, 0);
+               CurrentResourceOwner = ResourceOwnerCreate(nullptr, 
"arrow-flight-sql: Executor");
+               if (!check_password(databaseName, userName, password, 
clientAddress))
+               {
+                       session_->initialized = true;
+                       P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC, 
sharedData_->serverPID);
+                       kill(sharedData_->serverPID, SIGUSR1);
+                       return;
+               }
                {
                        SharedRingBuffer buffer(&(session_->bufferData), area_);
                        // TODO: Customizable.
@@ -532,29 +595,7 @@ class Executor : public WorkerProcessor {
                kill(sharedData_->serverPID, SIGUSR1);
        }
 
-       void close()
-       {
-               pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": 
closing").c_str());
-               if (connected_)
-               {
-                       SPI_finish();
-                       CommitTransactionCommand();
-                       {
-                               SharedRingBuffer 
buffer(&(session_->bufferData), area_);
-                               buffer.free();
-                       }
-                       delete_session(session_);
-               }
-               else
-               {
-                       // TODO: Improve failed to connect case.
-                       session_->status = arrow::Status::Invalid("failed to 
connect");
-                       session_->initialized = true;
-                       P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC, 
sharedData_->serverPID);
-                       kill(sharedData_->serverPID, SIGUSR1);
-               }
-               pgstat_report_activity(STATE_IDLE, NULL);
-       }
+       void close() { close_internal(true); }
 
        SharedRingBuffer create_shared_ring_buffer()
        {
@@ -618,6 +659,191 @@ class Executor : public WorkerProcessor {
        }
 
    private:
+       void close_internal(bool unlockSession)
+       {
+               closed_ = true;
+               pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": 
closing").c_str());
+               if (connected_)
+               {
+                       SPI_finish();
+                       CommitTransactionCommand();
+                       {
+                               SharedRingBuffer 
buffer(&(session_->bufferData), area_);
+                               buffer.free();
+                       }
+                       delete_session(session_);
+               }
+               else
+               {
+                       if (!DsaPointerIsValid(session_->errorMessage))
+                       {
+                               set_shared_string(session_->errorMessage, 
"failed to connect");
+                       }
+                       session_->initialized = true;
+                       if (unlockSession)
+                       {
+                               dshash_release_lock(sessions_, session_);
+                       }
+                       P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC, 
sharedData_->serverPID);
+                       kill(sharedData_->serverPID, SIGUSR1);
+               }
+               if (CurrentResourceOwner)
+               {
+                       auto resourceOwner = CurrentResourceOwner;
+                       CurrentResourceOwner = nullptr;
+                       ResourceOwnerRelease(
+                               resourceOwner, RESOURCE_RELEASE_BEFORE_LOCKS, 
false, true);
+                       ResourceOwnerRelease(resourceOwner, 
RESOURCE_RELEASE_LOCKS, false, true);
+                       ResourceOwnerRelease(
+                               resourceOwner, RESOURCE_RELEASE_AFTER_LOCKS, 
false, true);
+                       ResourceOwnerDelete(resourceOwner);
+               }
+               pgstat_report_activity(STATE_IDLE, NULL);
+       }
+
+       bool check_password(const char* databaseName,
+                           const char* userName,
+                           const char* password,
+                           const char* clientAddress)
+       {
+               MemoryContext memoryContext =
+                       AllocSetContextCreate(CurrentMemoryContext,
+                                         "arrow-flight-sql: 
Executor::check_password()",
+                                         ALLOCSET_DEFAULT_SIZES);
+               ScopedMemoryContext scopedMemoryContext(memoryContext);
+               Port port = {};
+               port.database_name = pstrdup(databaseName);
+               port.user_name = pstrdup(userName);
+               if (!fill_client_address(&port, clientAddress))
+               {
+                       return false;
+               }
+               load_hba();
+               hba_getauthmethod(&port);
+               if (!port.hba)
+               {
+                       set_shared_string(session_->errorMessage, "failed to 
get auth method");
+                       return false;
+               }
+               switch (port.hba->auth_method)
+               {
+                       case uaMD5:
+                               // TODO
+                               set_shared_string(session_->errorMessage,
+                                                 "MD5 auth method isn't 
supported yet");
+                               return false;
+                       case uaSCRAM:
+                               // TODO
+                               set_shared_string(session_->errorMessage,
+                                                 "SCRAM auth method isn't 
supported yet");
+                               return false;
+                       case uaPassword:
+                       {
+                               const char* logDetail = nullptr;
+                               auto shadowPassword = 
get_role_password(port.user_name, &logDetail);
+                               if (!shadowPassword)
+                               {
+                                       set_shared_string(
+                                               session_->errorMessage,
+                                               std::string("failed to get 
password: ") + logDetail);
+                                       return false;
+                               }
+                               auto result = plain_crypt_verify(
+                                       port.user_name, shadowPassword, 
password, &logDetail);
+                               if (result != STATUS_OK)
+                               {
+                                       set_shared_string(
+                                               session_->errorMessage,
+                                               std::string("failed to verify 
password: ") + logDetail);
+                                       return false;
+                               }
+                               return true;
+                       }
+                       case uaTrust:
+                               return true;
+                       default:
+                               set_shared_string(session_->errorMessage,
+                                                 std::string("unsupported auth 
method: ") +
+                                                     
hba_authname(port.hba->auth_method));
+                               return false;
+               }
+       }
+
+       bool fill_client_address(Port* port, const char* clientAddress)
+       {
+               // clientAddress: "ipv4:127.0.0.1:40468"
+               // family: "ipv4"
+               // host: "127.0.0.1"
+               // port: "40468"
+               std::stringstream 
clientAddressStream{std::string(clientAddress)};
+               std::string clientFamily("");
+               std::string clientHost("");
+               std::string clientPort("");
+               std::getline(clientAddressStream, clientFamily, ':');
+               std::getline(clientAddressStream, clientHost, ':');
+               std::getline(clientAddressStream, clientPort);
+               if (!(clientFamily == "ipv4" || clientFamily == "ipv6"))
+               {
+                       set_shared_string(
+                               session_->errorMessage,
+                               std::string("client family must be ipv4 or 
ipv6: ") + clientFamily);
+                       return false;
+               }
+               auto clientPortStart = clientPort.c_str();
+               char* clientPortEnd = nullptr;
+               auto clientPortNumber = std::strtoul(clientPortStart, 
&clientPortEnd, 10);
+               if (clientPortEnd[0] != '\0')
+               {
+                       set_shared_string(session_->errorMessage,
+                                         std::string("client port is invalid: 
") + clientPort);
+                       return false;
+               }
+               if (clientPortNumber == 0)
+               {
+                       set_shared_string(session_->errorMessage,
+                                         std::string("client port must not 
0"));
+                       return false;
+               }
+               if (clientPortNumber > 65535)
+               {
+                       set_shared_string(session_->errorMessage,
+                                         std::string("client port is too 
large: ") +
+                                             std::to_string(clientPortNumber));
+                       return false;
+               }
+               if (clientFamily == "ipv4")
+               {
+                       auto raddr = 
reinterpret_cast<sockaddr_in*>(&(port->raddr.addr));
+                       port->raddr.salen = sizeof(sockaddr_in);
+                       raddr->sin_family = AF_INET;
+                       raddr->sin_port = htons(clientPortNumber);
+                       if (inet_pton(AF_INET, clientHost.c_str(), 
&(raddr->sin_addr)) == 0)
+                       {
+                               set_shared_string(
+                                       session_->errorMessage,
+                                       std::string("client IPv4 address is 
invalid: ") + clientHost);
+                               return false;
+                       }
+               }
+               else if (clientFamily == "ipv6")
+               {
+                       auto raddr = 
reinterpret_cast<sockaddr_in6*>(&(port->raddr.addr));
+                       port->raddr.salen = sizeof(sockaddr_in6);
+                       raddr->sin6_family = AF_INET6;
+                       raddr->sin6_port = htons(clientPortNumber);
+                       raddr->sin6_flowinfo = 0;
+                       if (inet_pton(AF_INET6, clientHost.c_str(), 
&(raddr->sin6_addr)) == 0)
+                       {
+                               set_shared_string(
+                                       session_->errorMessage,
+                                       std::string("client IPv6 address is 
invalid: ") + clientHost);
+                               return false;
+                       }
+                       raddr->sin6_scope_id = 0;
+               }
+               return true;
+       }
+
        void execute()
        {
                pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": 
executing").c_str());
@@ -641,18 +867,15 @@ class Executor : public WorkerProcessor {
                        auto status = write();
                        if (!status.ok())
                        {
-                               session_->status = status;
+                               set_shared_string(session_->errorMessage, 
status.ToString());
                        }
                }
                else
                {
-                       session_->status = arrow::Status::Invalid(Tag,
-                                                                 ": ",
-                                                                 tag_,
-                                                                 ": failed to 
run a query: <",
-                                                                 query,
-                                                                 ">: ",
-                                                                 
SPI_result_code_string(result));
+                       set_shared_string(session_->errorMessage,
+                                         std::string(Tag) + ": " + tag_ +
+                                             ": failed to run a query: <" + 
query +
+                                             ">: " + 
SPI_result_code_string(result));
                }
 
                PopActiveSnapshot();
@@ -772,6 +995,7 @@ class Executor : public WorkerProcessor {
        uint64_t sessionID_;
        SessionData* session_;
        bool connected_;
+       bool closed_;
 };
 
 arrow::Status
@@ -849,9 +1073,10 @@ class Proxy : public WorkerProcessor {
 
        arrow::Result<uint64_t> connect(const std::string& databaseName,
                                        const std::string& userName,
-                                       const std::string& password)
+                                       const std::string& password,
+                                       const std::string& clientAddress)
        {
-               auto session = create_session(databaseName, userName, password);
+               auto session = create_session(databaseName, userName, password, 
clientAddress);
                auto id = session->id;
                dshash_release_lock(sessions_, session);
                kill(sharedData_->mainPID, SIGUSR1);
@@ -877,13 +1102,11 @@ class Proxy : public WorkerProcessor {
                {
                        return arrow::Status::Invalid("session is stale: ", id);
                }
-               if (!session->status.ok())
+               SessionReleaser sessionReleaser(sessions_, session);
+               if (DsaPointerIsValid(session->errorMessage))
                {
-                       auto status = session->status;
-                       delete_session(session);
-                       return status;
+                       return report_session_error(session);
                }
-               dshash_release_lock(sessions_, session);
                if (INTERRUPTS_PENDING_CONDITION())
                {
                        return arrow::Status::Invalid("interrupted");
@@ -921,9 +1144,13 @@ class Proxy : public WorkerProcessor {
                        std::unique_lock<std::mutex> lock(mutex_);
                        conditionVariable_.wait(lock, [&] {
                                P("%s: %s: %s: wait: execute", Tag, tag_, 
AFS_FUNC);
-                               return !session->status.ok() || buffer.size() > 
0;
+                               return DsaPointerIsValid(session->errorMessage) 
|| buffer.size() > 0;
                        });
                }
+               if (DsaPointerIsValid(session->errorMessage))
+               {
+                       return report_session_error(session);
+               }
                P("%s: %s: execute: open", Tag, tag_);
                auto input = 
std::make_shared<SharedRingBufferInputStream>(this, session);
                // Read schema only stream format data.
@@ -955,7 +1182,8 @@ class Proxy : public WorkerProcessor {
    private:
        SessionData* create_session(const std::string& databaseName,
                                    const std::string& userName,
-                                   const std::string& password)
+                                   const std::string& password,
+                                   const std::string& clientAddress)
        {
                LWLockAcquire(lock_, LW_EXCLUSIVE);
                uint64_t id = 0;
@@ -975,12 +1203,13 @@ class Proxy : public WorkerProcessor {
                                break;
                        }
                } while (true);
-               new (&(session->status)) arrow::Status;
+               session->errorMessage = InvalidDsaPointer;
                session->executorPID = InvalidPid;
                session->initialized = false;
                set_shared_string(session->databaseName, databaseName);
                set_shared_string(session->userName, userName);
                set_shared_string(session->password, password);
+               set_shared_string(session->clientAddress, clientAddress);
                session->query = InvalidDsaPointer;
                SharedRingBuffer::initialize_data(&(session->bufferData));
                LWLockRelease(lock_);
@@ -992,19 +1221,17 @@ class Proxy : public WorkerProcessor {
                return static_cast<SessionData*>(dshash_find(sessions_, 
&sessionID, false));
        }
 
-       void set_shared_string(dsa_pointer& pointer, const std::string& input)
+       arrow::Status report_session_error(SessionData* session)
        {
-               if (DsaPointerIsValid(pointer))
-               {
-                       dsa_free(area_, pointer);
-                       pointer = InvalidDsaPointer;
-               }
-               if (input.empty())
-               {
-                       return;
-               }
-               pointer = dsa_allocate(area_, input.size() + 1);
-               memcpy(dsa_get_address(area_, pointer), input.c_str(), 
input.size() + 1);
+               auto status = arrow::Status::Invalid(
+                       static_cast<const char*>(dsa_get_address(area_, 
session->errorMessage)));
+               P("%s: %s: %s: kill SIGTERM executor: %d",
+                 Tag,
+                 tag_,
+                 AFS_FUNC,
+                 session->executorPID);
+               kill(session->executorPID, SIGTERM);
+               return status;
        }
 
        std::random_device randomSeed_;
@@ -1130,8 +1357,10 @@ class MainProcessor : public Processor {
                        }
                        else
                        {
-                               session->status = arrow::Status::UnknownError(
-                                       Tag, ": ", tag_, ": failed to start 
executor: ", session->id);
+                               set_shared_string(
+                                       session->errorMessage,
+                                       std::string(Tag) + ": " + tag_ +
+                                               ": failed to start executor: " 
+ std::to_string(session->id));
                        }
                }
                dshash_seq_term(&sessionsStatus);
@@ -1171,17 +1400,26 @@ class HeaderAuthServerMiddlewareFactory : public 
arrow::flight::ServerMiddleware
 
        arrow::Status StartCall(
                const arrow::flight::CallInfo& info,
+#if ARROW_VERSION_MAJOR >= 13
+               const arrow::flight::ServerCallContext& context,
+#else
                const arrow::flight::CallHeaders& incoming_headers,
+#endif
                std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) 
override
        {
                std::string databaseName("postgres");
-               auto databaseHeader = 
incoming_headers.find("x-flight-sql-database");
-               if (databaseHeader != incoming_headers.end())
+#if ARROW_VERSION_MAJOR >= 13
+               const auto& incomingHeaders = context.incoming_headers();
+#else
+               const auto& incomingHeaders = incoming_headers;
+#endif
+               auto databaseHeader = 
incomingHeaders.find("x-flight-sql-database");
+               if (databaseHeader != incomingHeaders.end())
                {
                        databaseName = databaseHeader->second;
                }
-               auto authorizationHeader = 
incoming_headers.find("authorization");
-               if (authorizationHeader == incoming_headers.end())
+               auto authorizationHeader = 
incomingHeaders.find("authorization");
+               if (authorizationHeader == incomingHeaders.end())
                {
                        return arrow::flight::MakeFlightError(
                                
arrow::flight::FlightStatusCode::Unauthenticated,
@@ -1199,7 +1437,14 @@ class HeaderAuthServerMiddlewareFactory : public 
arrow::flight::ServerMiddleware
                        std::string password("");
                        std::getline(decodedStream, userName, ':');
                        std::getline(decodedStream, password);
-                       auto sessionIDResult = proxy_->connect(databaseName, 
userName, password);
+#if ARROW_VERSION_MAJOR >= 13
+                       const auto& clientAddress = context.peer();
+#else
+                       // 192.0.2.1 is one of the reserved IPv4 addresses for documentation (TEST-NET-1, RFC 5737).
+                       std::string clientAddress("ipv4:192.0.2.1:2929");
+#endif
+                       auto sessionIDResult =
+                               proxy_->connect(databaseName, userName, 
password, clientAddress);
                        if (!sessionIDResult.status().ok())
                        {
                                return arrow::flight::MakeFlightError(
@@ -1393,6 +1638,7 @@ afs_executor(Datum arg)
 
                CHECK_FOR_INTERRUPTS();
        }
+       executor->close();
 
        proc_exit(0);
 }
diff --git a/test/helper/sandbox.rb b/test/helper/sandbox.rb
index f7b0d4a..6aec966 100644
--- a/test/helper/sandbox.rb
+++ b/test/helper/sandbox.rb
@@ -17,6 +17,7 @@
 
 require "fileutils"
 require "socket"
+require "tempfile"
 
 require "arrow-flight-sql"
 
@@ -27,6 +28,9 @@ module Helper
         "LC_ALL" => "C",
         "PGCLIENTENCODING" => "UTF-8",
       }
+      if args.first.is_a?(Hash)
+        env.merge!(args.shift)
+      end
       output_read, output_write = IO.pipe
       error_read, error_write = IO.pipe
       options = {
@@ -97,6 +101,7 @@ module Helper
     attr_reader :flight_sql_port
     attr_reader :flight_sql_uri
     attr_reader :user
+    attr_reader :password
     def initialize(base_dir)
       @base_dir = base_dir
       @dir = nil
@@ -107,6 +112,7 @@ module Helper
       @flight_sql_port = nil
       @flight_sql_uri = nil
       @user = "arrow-flight-sql-test"
+      @password = "Passw0rd!"
       @running = false
     end
 
@@ -122,13 +128,21 @@ module Helper
       @log_path = File.join(@dir, "log", @log_base_name)
       socket_dir = File.join(@dir, "socket")
       @port = port
+      @pgpass = Tempfile.new("arrow-flight-sql-test-pgpass")
+      @pgpass.puts("#{@host}:#{@port}:*:#{@user}:#{@password}")
+      @pgpass.close
       @flight_sql_port = flight_sql_port
       @flight_sql_uri = "grpc://#{@host}:#{@flight_sql_port}"
-      run_command("initdb",
-                  "--locale", "C",
-                  "--encoding", "UTF-8",
-                  "--username", @user,
-                  "-D", @dir)
+      Tempfile.create("arrow-flight-sql-test-password") do |password|
+        password.print(@password)
+        password.close
+        run_command("initdb",
+                    "--locale", "C",
+                    "--encoding", "UTF-8",
+                    "--username", @user,
+                    "--pwfile", password.path,
+                    "-D", @dir)
+      end
       FileUtils.mkdir_p(socket_dir)
       postgresql_conf = File.join(@dir, "postgresql.conf")
       File.open(postgresql_conf, "a") do |conf|
@@ -144,6 +158,10 @@ module Helper
         conf.puts("arrow_flight_sql.uri = #{@flight_sql_uri}")
         yield(conf) if block_given?
       end
+      pg_hba_conf = File.join(@dir, "pg_hba.conf")
+      pg_hba = File.read(pg_hba_conf)
+      pg_hba.gsub!(/^(host.+)trust$/, "\\1password")
+      File.write(pg_hba_conf, pg_hba)
     end
 
     def start
@@ -175,12 +193,16 @@ module Helper
     end
 
     def psql(db, sql)
-      output, error = run_command("psql",
+      output, error = run_command({
+                                    "PGPASSFILE" => @pgpass.path,
+                                  },
+                                  "psql",
                                   "--host", @host,
                                   "--port", @port.to_s,
                                   "--username", @user,
                                   "--dbname", db,
                                   "--echo-all",
+                                  "--no-password",
                                   "--no-psqlrc",
                                   "--command", sql)
       [output, error]
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index 71a2174..3e0992d 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -25,7 +25,7 @@ class FlightSQLTest < Test::Unit::TestCase
     @options = ArrowFlight::CallOptions.new
     @options.add_header("x-flight-sql-database", @test_db_name)
     user = @postgresql.user
-    password = ""
+    password = @postgresql.password
     flight_client.authenticate_basic(user, password, @options)
   end
 

Reply via email to