This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository
https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git
The following commit(s) were added to refs/heads/main by this push:
new 5416129 Add support for authentication (#39)
5416129 is described below
commit 5416129f580c6a9ebd2db74266b7de057f3a6f4a
Author: Sutou Kouhei <[email protected]>
AuthorDate: Mon May 8 21:19:26 2023 +0900
Add support for authentication (#39)
Closes GH-18
This only supports "password" and "trust" for now.
---
.github/workflows/test.yaml | 1 +
src/afs.cc | 374 ++++++++++++++++++++++++++++++++++++--------
test/helper/sandbox.rb | 34 +++-
test/test-flight-sql.rb | 2 +-
4 files changed, 340 insertions(+), 71 deletions(-)
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 0b8436f..74ab677 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -61,6 +61,7 @@ jobs:
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc
| sudo apt-key add -
sudo apt update
sudo apt -y -V -t ${suite} install \
+ libkrb5-dev \
postgresql-${{ matrix.postgresql-version }} \
postgresql-server-dev-${{ matrix.postgresql-version }}
- name: Install Apache Arrow Flight SQL adapter
diff --git a/src/afs.cc b/src/afs.cc
index b33546f..92b7e10 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -23,6 +23,8 @@ extern "C"
#include <executor/spi.h>
#include <fmgr.h>
#include <lib/dshash.h>
+#include <libpq/crypt.h>
+#include <libpq/libpq-be.h>
#include <miscadmin.h>
#include <postmaster/bgworker.h>
#include <storage/ipc.h>
@@ -33,6 +35,7 @@ extern "C"
#include <utils/backend_status.h>
#include <utils/dsa.h>
#include <utils/guc.h>
+#include <utils/memutils.h>
#include <utils/snapmgr.h>
#include <utils/wait_event.h>
}
@@ -54,6 +57,8 @@ extern "C"
#include <random>
#include <sstream>
+#include <arpa/inet.h>
+
#ifdef __GNUC__
# define AFS_FUNC __PRETTY_FUNCTION__
#else
@@ -129,6 +134,25 @@ afs_shmem_request_hook(void)
RequestNamedLWLockTranche(LWLockTrancheName, 1);
}
+// Scope guard for a PostgreSQL memory context.
+//
+// The constructor makes the given context current and remembers the context
+// that was current before. The destructor switches back to the remembered
+// context and then deletes the given context, so the guard takes ownership
+// of the context passed to it.
+//
+// The guard is non-copyable: a copy would switch back and delete the same
+// context a second time when it goes out of scope.
+class ScopedMemoryContext {
+ public:
+  explicit ScopedMemoryContext(MemoryContext memoryContext)
+    : memoryContext_(memoryContext), oldMemoryContext_(nullptr)
+  {
+    oldMemoryContext_ = MemoryContextSwitchTo(memoryContext_);
+  }
+
+  ScopedMemoryContext(const ScopedMemoryContext&) = delete;
+  ScopedMemoryContext& operator=(const ScopedMemoryContext&) = delete;
+
+  ~ScopedMemoryContext()
+  {
+    // Restore the previous context first so that the context being deleted
+    // is no longer the current one.
+    MemoryContextSwitchTo(oldMemoryContext_);
+    MemoryContextDelete(memoryContext_);
+  }
+
+ private:
+  // Context owned by this guard; deleted in the destructor.
+  MemoryContext memoryContext_;
+  // Context that was current when the guard was constructed.
+  MemoryContext oldMemoryContext_;
+};
+
struct SharedRingBufferData {
dsa_pointer pointer;
size_t total;
@@ -304,12 +328,13 @@ class SharedRingBuffer {
struct SessionData {
uint64_t id;
- arrow::Status status;
+ dsa_pointer errorMessage;
pid_t executorPID;
bool initialized;
dsa_pointer databaseName;
dsa_pointer userName;
dsa_pointer password;
+ dsa_pointer clientAddress;
dsa_pointer query;
SharedRingBufferData bufferData;
};
@@ -377,6 +402,21 @@ class Processor {
}
protected:
+  // Stores a NUL-terminated copy of `input` in the DSA area and records its
+  // location in `pointer`.
+  //
+  // Any string previously referenced by `pointer` is freed first. An empty
+  // `input` allocates nothing and leaves `pointer` as InvalidDsaPointer.
+  void set_shared_string(dsa_pointer& pointer, const std::string& input)
+  {
+    if (DsaPointerIsValid(pointer))
+    {
+      dsa_free(area_, pointer);
+      pointer = InvalidDsaPointer;
+    }
+    if (!input.empty())
+    {
+      // +1 so that the terminating NUL of c_str() is copied too.
+      const auto sizeWithTerminator = input.size() + 1;
+      pointer = dsa_allocate(area_, sizeWithTerminator);
+      memcpy(dsa_get_address(area_, pointer), input.c_str(), sizeWithTerminator);
+    }
+  }
+
const char* tag_;
SharedData* sharedData_;
dsa_area* area_;
@@ -481,7 +521,8 @@ class WorkerProcessor : public Processor {
protected:
void delete_session(SessionData* session)
{
- session->status.~Status();
+ if (DsaPointerIsValid(session->errorMessage))
+ dsa_free(area_, session->errorMessage);
if (DsaPointerIsValid(session->databaseName))
dsa_free(area_, session->databaseName);
if (DsaPointerIsValid(session->userName))
@@ -504,20 +545,42 @@ class Executor : public WorkerProcessor {
: WorkerProcessor("executor"),
sessionID_(sessionID),
session_(nullptr),
- connected_(false)
+ connected_(false),
+ closed_(false)
{
}
+  // Closes the executor on destruction unless it has already been closed.
+  //
+  // close_internal() is called with `false`, so this path does not release
+  // the session's dshash lock.
+  ~Executor()
+  {
+    if (closed_)
+    {
+      return;
+    }
+    close_internal(false);
+  }
+
void open()
{
+ // pg_usleep(5000000);
+ // pg_usleep(5000000);
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
opening").c_str());
session_ = static_cast<SessionData*>(dshash_find(sessions_,
&sessionID_, false));
auto databaseName =
static_cast<const char*>(dsa_get_address(area_,
session_->databaseName));
auto userName =
static_cast<const char*>(dsa_get_address(area_,
session_->userName));
- // TODO: Check password. See src/backend/libpq/auth.c
+ auto password =
+ static_cast<const char*>(dsa_get_address(area_,
session_->password));
+ auto clientAddress =
+ static_cast<const char*>(dsa_get_address(area_,
session_->clientAddress));
BackgroundWorkerInitializeConnection(databaseName, userName, 0);
+ CurrentResourceOwner = ResourceOwnerCreate(nullptr,
"arrow-flight-sql: Executor");
+ if (!check_password(databaseName, userName, password,
clientAddress))
+ {
+ session_->initialized = true;
+ P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC,
sharedData_->serverPID);
+ kill(sharedData_->serverPID, SIGUSR1);
+ return;
+ }
{
SharedRingBuffer buffer(&(session_->bufferData), area_);
// TODO: Customizable.
@@ -532,29 +595,7 @@ class Executor : public WorkerProcessor {
kill(sharedData_->serverPID, SIGUSR1);
}
- void close()
- {
- pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
closing").c_str());
- if (connected_)
- {
- SPI_finish();
- CommitTransactionCommand();
- {
- SharedRingBuffer
buffer(&(session_->bufferData), area_);
- buffer.free();
- }
- delete_session(session_);
- }
- else
- {
- // TODO: Improve failed to connect case.
- session_->status = arrow::Status::Invalid("failed to
connect");
- session_->initialized = true;
- P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
- }
- pgstat_report_activity(STATE_IDLE, NULL);
- }
+  // Explicit close: delegates to close_internal() and asks it to release the
+  // session's dshash lock when the executor never got connected.
+  void close()
+  {
+    close_internal(true);
+  }
SharedRingBuffer create_shared_ring_buffer()
{
@@ -618,6 +659,191 @@ class Executor : public WorkerProcessor {
}
private:
+ // Shared teardown used by close() and by the destructor.
+ //
+ // unlockSession: when true and the executor never got connected, the
+ // dshash lock on the session entry is released before the server is
+ // signalled. The destructor passes false; close() passes true.
+ void close_internal(bool unlockSession)
+ {
+ // Remember that teardown ran so the destructor does not run it again.
+ closed_ = true;
+ pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
closing").c_str());
+ if (connected_)
+ {
+ // Connected: finish SPI, commit, free the ring buffer and delete the
+ // session.
+ SPI_finish();
+ CommitTransactionCommand();
+ {
+ SharedRingBuffer
buffer(&(session_->bufferData), area_);
+ buffer.free();
+ }
+ delete_session(session_);
+ }
+ else
+ {
+ // Not connected: keep an error message that was already recorded
+ // (e.g. by check_password()), otherwise store a generic one.
+ if (!DsaPointerIsValid(session_->errorMessage))
+ {
+ set_shared_string(session_->errorMessage,
"failed to connect");
+ }
+ session_->initialized = true;
+ if (unlockSession)
+ {
+ dshash_release_lock(sessions_, session_);
+ }
+ // Notify the server process with SIGUSR1.
+ P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC,
sharedData_->serverPID);
+ kill(sharedData_->serverPID, SIGUSR1);
+ }
+ if (CurrentResourceOwner)
+ {
+ // Detach the resource owner from the global before releasing it in
+ // the three phases (before locks, locks, after locks) and deleting it.
+ auto resourceOwner = CurrentResourceOwner;
+ CurrentResourceOwner = nullptr;
+ ResourceOwnerRelease(
+ resourceOwner, RESOURCE_RELEASE_BEFORE_LOCKS,
false, true);
+ ResourceOwnerRelease(resourceOwner,
RESOURCE_RELEASE_LOCKS, false, true);
+ ResourceOwnerRelease(
+ resourceOwner, RESOURCE_RELEASE_AFTER_LOCKS,
false, true);
+ ResourceOwnerDelete(resourceOwner);
+ }
+ pgstat_report_activity(STATE_IDLE, NULL);
+ }
+
+  // Authenticates the session's user against pg_hba.conf.
+  //
+  // databaseName, userName: connection target. Either may be null because
+  //   set_shared_string() stores empty strings as InvalidDsaPointer; a null
+  //   value is rejected with an error message.
+  // password: password supplied by the client. May be null for the same
+  //   reason; null is treated as an empty password.
+  // clientAddress: peer address, see fill_client_address().
+  //
+  // Returns true when the matching HBA entry uses "trust", or uses
+  // "password" and the password verifies. On every failure
+  // session_->errorMessage is set and false is returned. MD5 and SCRAM are
+  // not supported yet.
+  bool check_password(const char* databaseName,
+                      const char* userName,
+                      const char* password,
+                      const char* clientAddress)
+  {
+    if (!databaseName || !userName)
+    {
+      set_shared_string(session_->errorMessage,
+                        "database name and user name must not be empty");
+      return false;
+    }
+    // Everything allocated below lives in a temporary context that is
+    // deleted when this function returns.
+    MemoryContext memoryContext =
+      AllocSetContextCreate(CurrentMemoryContext,
+                            "arrow-flight-sql: Executor::check_password()",
+                            ALLOCSET_DEFAULT_SIZES);
+    ScopedMemoryContext scopedMemoryContext(memoryContext);
+    // get_role_password()/plain_crypt_verify() may leave the detail unset.
+    auto detail = [](const char* logDetail) {
+      return std::string(logDetail ? logDetail : "(no detail)");
+    };
+    Port port = {};
+    port.database_name = pstrdup(databaseName);
+    port.user_name = pstrdup(userName);
+    if (!fill_client_address(&port, clientAddress))
+    {
+      return false;
+    }
+    if (!load_hba())
+    {
+      set_shared_string(session_->errorMessage, "failed to load pg_hba.conf");
+      return false;
+    }
+    hba_getauthmethod(&port);
+    if (!port.hba)
+    {
+      set_shared_string(session_->errorMessage, "failed to get auth method");
+      return false;
+    }
+    switch (port.hba->auth_method)
+    {
+      case uaMD5:
+        // TODO
+        set_shared_string(session_->errorMessage,
+                          "MD5 auth method isn't supported yet");
+        return false;
+      case uaSCRAM:
+        // TODO
+        set_shared_string(session_->errorMessage,
+                          "SCRAM auth method isn't supported yet");
+        return false;
+      case uaPassword:
+      {
+        const char* logDetail = nullptr;
+        auto shadowPassword = get_role_password(port.user_name, &logDetail);
+        if (!shadowPassword)
+        {
+          set_shared_string(
+            session_->errorMessage,
+            std::string("failed to get password: ") + detail(logDetail));
+          return false;
+        }
+        // An empty client password arrives as null; verify it as "".
+        auto result = plain_crypt_verify(port.user_name,
+                                         shadowPassword,
+                                         password ? password : "",
+                                         &logDetail);
+        if (result != STATUS_OK)
+        {
+          set_shared_string(
+            session_->errorMessage,
+            std::string("failed to verify password: ") + detail(logDetail));
+          return false;
+        }
+        return true;
+      }
+      case uaTrust:
+        return true;
+      default:
+        set_shared_string(session_->errorMessage,
+                          std::string("unsupported auth method: ") +
+                            hba_authname(port.hba->auth_method));
+        return false;
+    }
+  }
+
+  // Parses a peer address string into port->raddr.
+  //
+  // clientAddress has the form "<family>:<host>:<port>", for example
+  //   "ipv4:127.0.0.1:40468"
+  //   "ipv6:[::1]:40468"
+  // The family ends at the first ':' and the port starts after the last
+  // ':', so colons inside an IPv6 host are kept. Brackets around an IPv6
+  // host are removed, also in their percent-encoded form "%5B...%5D"
+  // (NOTE(review): presumably what newer gRPC versions report; confirm).
+  // A null clientAddress is handled like an empty string.
+  //
+  // Returns true on success. On failure session_->errorMessage is set and
+  // false is returned.
+  bool fill_client_address(Port* port, const char* clientAddress)
+  {
+    const std::string address(clientAddress ? clientAddress : "");
+    std::string clientFamily("");
+    std::string clientHost("");
+    std::string clientPort("");
+    const auto familyEnd = address.find(':');
+    if (familyEnd == std::string::npos)
+    {
+      clientFamily = address;
+    }
+    else
+    {
+      clientFamily = address.substr(0, familyEnd);
+      const auto hostAndPort = address.substr(familyEnd + 1);
+      const auto portSeparator = hostAndPort.rfind(':');
+      if (portSeparator == std::string::npos)
+      {
+        clientHost = hostAndPort;
+      }
+      else
+      {
+        clientHost = hostAndPort.substr(0, portSeparator);
+        clientPort = hostAndPort.substr(portSeparator + 1);
+      }
+    }
+    if (!(clientFamily == "ipv4" || clientFamily == "ipv6"))
+    {
+      set_shared_string(
+        session_->errorMessage,
+        std::string("client family must be ipv4 or ipv6: ") + clientFamily);
+      return false;
+    }
+    auto clientPortStart = clientPort.c_str();
+    char* clientPortEnd = nullptr;
+    auto clientPortNumber = std::strtoul(clientPortStart, &clientPortEnd, 10);
+    if (clientPortEnd[0] != '\0')
+    {
+      set_shared_string(session_->errorMessage,
+                        std::string("client port is invalid: ") + clientPort);
+      return false;
+    }
+    if (clientPortNumber == 0)
+    {
+      set_shared_string(session_->errorMessage,
+                        std::string("client port must not 0"));
+      return false;
+    }
+    if (clientPortNumber > 65535)
+    {
+      set_shared_string(session_->errorMessage,
+                        std::string("client port is too large: ") +
+                          std::to_string(clientPortNumber));
+      return false;
+    }
+    if (clientFamily == "ipv4")
+    {
+      auto raddr = reinterpret_cast<sockaddr_in*>(&(port->raddr.addr));
+      port->raddr.salen = sizeof(sockaddr_in);
+      raddr->sin_family = AF_INET;
+      raddr->sin_port = htons(static_cast<uint16_t>(clientPortNumber));
+      // inet_pton() returns 1 only on success; 0 and -1 are both failures.
+      if (inet_pton(AF_INET, clientHost.c_str(), &(raddr->sin_addr)) != 1)
+      {
+        set_shared_string(
+          session_->errorMessage,
+          std::string("client IPv4 address is invalid: ") + clientHost);
+        return false;
+      }
+    }
+    else if (clientFamily == "ipv6")
+    {
+      // Removes `prefix` and `suffix` from `text` when both are present.
+      auto stripAffixes = [](std::string& text,
+                             const std::string& prefix,
+                             const std::string& suffix) {
+        if (text.size() < prefix.size() + suffix.size() ||
+            text.compare(0, prefix.size(), prefix) != 0 ||
+            text.compare(text.size() - suffix.size(), suffix.size(), suffix) != 0)
+        {
+          return false;
+        }
+        text = text.substr(prefix.size(),
+                           text.size() - prefix.size() - suffix.size());
+        return true;
+      };
+      if (!stripAffixes(clientHost, "[", "]"))
+      {
+        stripAffixes(clientHost, "%5B", "%5D");
+      }
+      auto raddr = reinterpret_cast<sockaddr_in6*>(&(port->raddr.addr));
+      port->raddr.salen = sizeof(sockaddr_in6);
+      raddr->sin6_family = AF_INET6;
+      raddr->sin6_port = htons(static_cast<uint16_t>(clientPortNumber));
+      raddr->sin6_flowinfo = 0;
+      if (inet_pton(AF_INET6, clientHost.c_str(), &(raddr->sin6_addr)) != 1)
+      {
+        set_shared_string(
+          session_->errorMessage,
+          std::string("client IPv6 address is invalid: ") + clientHost);
+        return false;
+      }
+      raddr->sin6_scope_id = 0;
+    }
+    return true;
+  }
+
void execute()
{
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
executing").c_str());
@@ -641,18 +867,15 @@ class Executor : public WorkerProcessor {
auto status = write();
if (!status.ok())
{
- session_->status = status;
+ set_shared_string(session_->errorMessage,
status.ToString());
}
}
else
{
- session_->status = arrow::Status::Invalid(Tag,
- ": ",
- tag_,
- ": failed to
run a query: <",
- query,
- ">: ",
-
SPI_result_code_string(result));
+ set_shared_string(session_->errorMessage,
+ std::string(Tag) + ": " + tag_ +
+ ": failed to run a query: <" +
query +
+ ">: " +
SPI_result_code_string(result));
}
PopActiveSnapshot();
@@ -772,6 +995,7 @@ class Executor : public WorkerProcessor {
uint64_t sessionID_;
SessionData* session_;
bool connected_;
+ bool closed_;
};
arrow::Status
@@ -849,9 +1073,10 @@ class Proxy : public WorkerProcessor {
arrow::Result<uint64_t> connect(const std::string& databaseName,
const std::string& userName,
- const std::string& password)
+ const std::string& password,
+ const std::string& clientAddress)
{
- auto session = create_session(databaseName, userName, password);
+ auto session = create_session(databaseName, userName, password,
clientAddress);
auto id = session->id;
dshash_release_lock(sessions_, session);
kill(sharedData_->mainPID, SIGUSR1);
@@ -877,13 +1102,11 @@ class Proxy : public WorkerProcessor {
{
return arrow::Status::Invalid("session is stale: ", id);
}
- if (!session->status.ok())
+ SessionReleaser sessionReleaser(sessions_, session);
+ if (DsaPointerIsValid(session->errorMessage))
{
- auto status = session->status;
- delete_session(session);
- return status;
+ return report_session_error(session);
}
- dshash_release_lock(sessions_, session);
if (INTERRUPTS_PENDING_CONDITION())
{
return arrow::Status::Invalid("interrupted");
@@ -921,9 +1144,13 @@ class Proxy : public WorkerProcessor {
std::unique_lock<std::mutex> lock(mutex_);
conditionVariable_.wait(lock, [&] {
P("%s: %s: %s: wait: execute", Tag, tag_,
AFS_FUNC);
- return !session->status.ok() || buffer.size() >
0;
+ return DsaPointerIsValid(session->errorMessage)
|| buffer.size() > 0;
});
}
+ if (DsaPointerIsValid(session->errorMessage))
+ {
+ return report_session_error(session);
+ }
P("%s: %s: execute: open", Tag, tag_);
auto input =
std::make_shared<SharedRingBufferInputStream>(this, session);
// Read schema only stream format data.
@@ -955,7 +1182,8 @@ class Proxy : public WorkerProcessor {
private:
SessionData* create_session(const std::string& databaseName,
const std::string& userName,
- const std::string& password)
+ const std::string& password,
+ const std::string& clientAddress)
{
LWLockAcquire(lock_, LW_EXCLUSIVE);
uint64_t id = 0;
@@ -975,12 +1203,13 @@ class Proxy : public WorkerProcessor {
break;
}
} while (true);
- new (&(session->status)) arrow::Status;
+ session->errorMessage = InvalidDsaPointer;
session->executorPID = InvalidPid;
session->initialized = false;
set_shared_string(session->databaseName, databaseName);
set_shared_string(session->userName, userName);
set_shared_string(session->password, password);
+ set_shared_string(session->clientAddress, clientAddress);
session->query = InvalidDsaPointer;
SharedRingBuffer::initialize_data(&(session->bufferData));
LWLockRelease(lock_);
@@ -992,19 +1221,17 @@ class Proxy : public WorkerProcessor {
return static_cast<SessionData*>(dshash_find(sessions_,
&sessionID, false));
}
- void set_shared_string(dsa_pointer& pointer, const std::string& input)
+  // Converts session->errorMessage into an arrow::Status::Invalid and asks
+  // the session's executor to terminate with SIGTERM.
+  //
+  // The caller must have checked that session->errorMessage is valid.
+  //
+  // The executor is signalled only when a PID was recorded for it. Sessions
+  // start with executorPID = InvalidPid, and the "failed to start executor"
+  // path sets only the error message. InvalidPid is -1 in PostgreSQL, and
+  // kill(-1, SIGTERM) would signal every process this user may signal.
+  arrow::Status report_session_error(SessionData* session)
{
- if (DsaPointerIsValid(pointer))
- {
- dsa_free(area_, pointer);
- pointer = InvalidDsaPointer;
- }
- if (input.empty())
- {
- return;
- }
- pointer = dsa_allocate(area_, input.size() + 1);
- memcpy(dsa_get_address(area_, pointer), input.c_str(),
input.size() + 1);
+    auto status = arrow::Status::Invalid(
+      static_cast<const char*>(dsa_get_address(area_, session->errorMessage)));
+    if (session->executorPID != InvalidPid)
+    {
+      P("%s: %s: %s: kill SIGTERM executor: %d",
+        Tag,
+        tag_,
+        AFS_FUNC,
+        session->executorPID);
+      kill(session->executorPID, SIGTERM);
+    }
+    return status;
}
std::random_device randomSeed_;
@@ -1130,8 +1357,10 @@ class MainProcessor : public Processor {
}
else
{
- session->status = arrow::Status::UnknownError(
- Tag, ": ", tag_, ": failed to start
executor: ", session->id);
+ set_shared_string(
+ session->errorMessage,
+ std::string(Tag) + ": " + tag_ +
+ ": failed to start executor: "
+ std::to_string(session->id));
}
}
dshash_seq_term(&sessionsStatus);
@@ -1171,17 +1400,26 @@ class HeaderAuthServerMiddlewareFactory : public
arrow::flight::ServerMiddleware
arrow::Status StartCall(
const arrow::flight::CallInfo& info,
+#if ARROW_VERSION_MAJOR >= 13
+ const arrow::flight::ServerCallContext& context,
+#else
const arrow::flight::CallHeaders& incoming_headers,
+#endif
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware)
override
{
std::string databaseName("postgres");
- auto databaseHeader =
incoming_headers.find("x-flight-sql-database");
- if (databaseHeader != incoming_headers.end())
+#if ARROW_VERSION_MAJOR >= 13
+ const auto& incomingHeaders = context.incoming_headers();
+#else
+ const auto& incomingHeaders = incoming_headers;
+#endif
+ auto databaseHeader =
incomingHeaders.find("x-flight-sql-database");
+ if (databaseHeader != incomingHeaders.end())
{
databaseName = databaseHeader->second;
}
- auto authorizationHeader =
incoming_headers.find("authorization");
- if (authorizationHeader == incoming_headers.end())
+ auto authorizationHeader =
incomingHeaders.find("authorization");
+ if (authorizationHeader == incomingHeaders.end())
{
return arrow::flight::MakeFlightError(
arrow::flight::FlightStatusCode::Unauthenticated,
@@ -1199,7 +1437,14 @@ class HeaderAuthServerMiddlewareFactory : public
arrow::flight::ServerMiddleware
std::string password("");
std::getline(decodedStream, userName, ':');
std::getline(decodedStream, password);
- auto sessionIDResult = proxy_->connect(databaseName,
userName, password);
+#if ARROW_VERSION_MAJOR >= 13
+ const auto& clientAddress = context.peer();
+#else
+ // 192.0.2.1 is one of the IPv4 addresses reserved for documentation.
+ std::string clientAddress("ipv4:192.0.2.1:2929");
+#endif
+ auto sessionIDResult =
+ proxy_->connect(databaseName, userName,
password, clientAddress);
if (!sessionIDResult.status().ok())
{
return arrow::flight::MakeFlightError(
@@ -1393,6 +1638,7 @@ afs_executor(Datum arg)
CHECK_FOR_INTERRUPTS();
}
+ executor->close();
proc_exit(0);
}
diff --git a/test/helper/sandbox.rb b/test/helper/sandbox.rb
index f7b0d4a..6aec966 100644
--- a/test/helper/sandbox.rb
+++ b/test/helper/sandbox.rb
@@ -17,6 +17,7 @@
require "fileutils"
require "socket"
+require "tempfile"
require "arrow-flight-sql"
@@ -27,6 +28,9 @@ module Helper
"LC_ALL" => "C",
"PGCLIENTENCODING" => "UTF-8",
}
+ if args.first.is_a?(Hash)
+ env.merge!(args.shift)
+ end
output_read, output_write = IO.pipe
error_read, error_write = IO.pipe
options = {
@@ -97,6 +101,7 @@ module Helper
attr_reader :flight_sql_port
attr_reader :flight_sql_uri
attr_reader :user
+ attr_reader :password
def initialize(base_dir)
@base_dir = base_dir
@dir = nil
@@ -107,6 +112,7 @@ module Helper
@flight_sql_port = nil
@flight_sql_uri = nil
@user = "arrow-flight-sql-test"
+ @password = "Passw0rd!"
@running = false
end
@@ -122,13 +128,21 @@ module Helper
@log_path = File.join(@dir, "log", @log_base_name)
socket_dir = File.join(@dir, "socket")
@port = port
+ @pgpass = Tempfile.new("arrow-flight-sql-test-pgpass")
+ @pgpass.puts("#{@host}:#{@port}:*:#{@user}:#{@password}")
+ @pgpass.close
@flight_sql_port = flight_sql_port
@flight_sql_uri = "grpc://#{@host}:#{@flight_sql_port}"
- run_command("initdb",
- "--locale", "C",
- "--encoding", "UTF-8",
- "--username", @user,
- "-D", @dir)
+ Tempfile.create("arrow-flight-sql-test-password") do |password|
+ password.print(@password)
+ password.close
+ run_command("initdb",
+ "--locale", "C",
+ "--encoding", "UTF-8",
+ "--username", @user,
+ "--pwfile", password.path,
+ "-D", @dir)
+ end
FileUtils.mkdir_p(socket_dir)
postgresql_conf = File.join(@dir, "postgresql.conf")
File.open(postgresql_conf, "a") do |conf|
@@ -144,6 +158,10 @@ module Helper
conf.puts("arrow_flight_sql.uri = #{@flight_sql_uri}")
yield(conf) if block_given?
end
+ pg_hba_conf = File.join(@dir, "pg_hba.conf")
+ pg_hba = File.read(pg_hba_conf)
+ pg_hba.gsub!(/^(host.+)trust$/, "\\1password")
+ File.write(pg_hba_conf, pg_hba)
end
def start
@@ -175,12 +193,16 @@ module Helper
end
def psql(db, sql)
- output, error = run_command("psql",
+ output, error = run_command({
+ "PGPASSFILE" => @pgpass.path,
+ },
+ "psql",
"--host", @host,
"--port", @port.to_s,
"--username", @user,
"--dbname", db,
"--echo-all",
+ "--no-password",
"--no-psqlrc",
"--command", sql)
[output, error]
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index 71a2174..3e0992d 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -25,7 +25,7 @@ class FlightSQLTest < Test::Unit::TestCase
@options = ArrowFlight::CallOptions.new
@options.add_header("x-flight-sql-database", @test_db_name)
user = @postgresql.user
- password = ""
+ password = @postgresql.password
flight_client.authenticate_basic(user, password, @options)
end