This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main in repository
https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git
The following commit(s) were added to refs/heads/main by this push:
new 8e92eee Add support for prepared SELECT (#86)
8e92eee is described below
commit 8e92eee4768e831c4ef16cb61c60ebaaa92e19c6
Author: Sutou Kouhei <[email protected]>
AuthorDate: Wed Aug 30 17:15:43 2023 +0900
Add support for prepared SELECT (#86)
Closes GH-81
---
src/afs.cc | 790 ++++++++++++++++++++++++++++++++----------------
test/test-flight-sql.rb | 33 +-
2 files changed, 557 insertions(+), 266 deletions(-)
diff --git a/src/afs.cc b/src/afs.cc
index bfb6776..bf78472 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -63,6 +63,7 @@ extern "C"
#include <map>
#include <random>
#include <sstream>
+#include <type_traits>
#include <arpa/inet.h>
@@ -163,6 +164,22 @@ class ScopedMemoryContext {
MemoryContext oldMemoryContext_;
};
+struct ScopedTransaction {
+ ScopedTransaction() { StartTransactionCommand(); }
+ ~ScopedTransaction() { CommitTransactionCommand(); }
+};
+
+struct ScopedSnapshot {
+ ScopedSnapshot() { PushActiveSnapshot(GetTransactionSnapshot()); }
+ ~ScopedSnapshot() { PopActiveSnapshot(); }
+};
+
+struct ScopedPlan {
+ ScopedPlan(SPIPlanPtr plan) : plan_(plan) {}
+ ~ScopedPlan() { SPI_freeplan(plan_); }
+ SPIPlanPtr plan_;
+};
+
struct SharedRingBufferData {
dsa_pointer pointer;
size_t total;
@@ -344,6 +361,8 @@ enum class Action
Update,
Prepare,
ClosePreparedStatement,
+ SetParameters,
+ SelectPreparedStatement,
UpdatePreparedStatement,
};
@@ -362,6 +381,10 @@ action_name(Action action)
return "Action::Prepare";
case Action::ClosePreparedStatement:
return "Action::ClosePreparedStatement";
+ case Action::SetParameters:
+ return "Action::SetParameters";
+ case Action::SelectPreparedStatement:
+ return "Action::SelectPreparedStatement";
case Action::UpdatePreparedStatement:
return "Action::UpdatePreparedStatement";
default:
@@ -617,6 +640,17 @@ class Processor {
std::condition_variable conditionVariable_;
};
+struct ProcessorLockGuard {
+ ProcessorLockGuard(Processor* processor) : processor_(processor)
+ {
+ processor_->lock_acquire();
+ }
+ ~ProcessorLockGuard() { processor_->lock_release(); }
+
+ private:
+ Processor* processor_;
+};
+
class Proxy;
class SharedRingBufferInputStream : public arrow::io::InputStream {
public:
@@ -1004,34 +1038,56 @@ class PGArrowValueConverter : public
arrow::ArrayVisitor {
class PreparedStatement {
public:
- explicit PreparedStatement(std::string query) :
query_(std::move(query)) {}
+ explicit PreparedStatement(std::string query)
+ : query_(std::move(query)), parameters_()
+ {
+ }
~PreparedStatement() {}
- arrow::Result<int64_t>
update(std::shared_ptr<SharedRingBufferInputStream>& input)
+ using WriteFunc = std::add_pointer<arrow::Status(void*)>::type;
+ arrow::Status select(WriteFunc write, void* writeData)
{
- ARROW_ASSIGN_OR_RAISE(auto reader,
-
arrow::ipc::RecordBatchStreamReader::Open(input));
- const auto& schema = reader->schema();
- SPIExecuteOptions options = {};
- if (schema->num_fields() > 0)
+ for (const auto& recordBatch : parameters_)
{
- options.params = makeParamList(schema->num_fields());
+ SPIExecuteOptions options = {};
+ std::vector<Oid> pgTypes;
+ ARROW_RETURN_NOT_OK(prepare(options, pgTypes,
recordBatch->schema()));
+ auto plan = SPI_prepare(query_.c_str(), pgTypes.size(),
pgTypes.data());
+ ScopedPlan scopedPlan(plan);
+ ARROW_RETURN_NOT_OK(
+ execute(plan, recordBatch, options, [&]() {
return write(writeData); }));
}
- options.read_only = false;
- options.tcount = 0;
- ARROW_ASSIGN_OR_RAISE(auto pgTypes, create_pg_types(schema));
- for (size_t i = 0; i < pgTypes.size(); ++i)
+ return arrow::Status::OK();
+ }
+
+ arrow::Status
set_parameters(std::shared_ptr<SharedRingBufferInputStream>& input)
+ {
+ parameters_.clear();
+ ARROW_ASSIGN_OR_RAISE(auto reader,
+
arrow::ipc::RecordBatchStreamReader::Open(input));
+ while (true)
{
- options.params->params[i].pflags = PARAM_FLAG_CONST;
- options.params->params[i].ptype = pgTypes[i];
+ std::shared_ptr<arrow::RecordBatch> recordBatch;
+ ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
+ if (!recordBatch)
+ {
+ break;
+ }
+ parameters_.push_back(std::move(recordBatch));
}
+ return arrow::Status::OK();
+ }
+
+ arrow::Result<int64_t>
update(std::shared_ptr<SharedRingBufferInputStream>& input)
+ {
+ ARROW_ASSIGN_OR_RAISE(auto reader,
+
arrow::ipc::RecordBatchStreamReader::Open(input));
+ SPIExecuteOptions options = {};
+ std::vector<Oid> pgTypes;
+ ARROW_RETURN_NOT_OK(prepare(options, pgTypes,
reader->schema()));
auto plan = SPI_prepare(query_.c_str(), pgTypes.size(),
pgTypes.data());
- struct PlanFinalizer {
- PlanFinalizer(SPIPlanPtr plan) : plan_(plan) {}
- ~PlanFinalizer() { SPI_freeplan(plan_); }
- SPIPlanPtr plan_;
- } planFinalizer(plan);
+ ScopedPlan scopedPlan(plan);
int64_t nUpdatedRecords = 0;
while (true)
@@ -1042,41 +1098,67 @@ class PreparedStatement {
{
break;
}
- const auto& columns = recordBatch->columns();
- for (int64_t i = 0; i < recordBatch->num_rows(); ++i)
- {
-
ARROW_RETURN_NOT_OK(assign_parameters(recordBatch, i, columns, options));
- auto result = SPI_execute_plan_extended(plan,
&options);
- switch (result)
- {
- case SPI_OK_INSERT:
- case SPI_OK_DELETE:
- case SPI_OK_UPDATE:
- break;
- default:
- return arrow::Status::Invalid(
- "failed to run a
prepared statement: ",
-
SPI_result_code_string(result));
- break;
- }
+ ARROW_RETURN_NOT_OK(execute(plan, recordBatch, options,
[&nUpdatedRecords]() {
nUpdatedRecords += SPI_processed;
- }
+ return arrow::Status::OK();
+ }));
}
return nUpdatedRecords;
}
private:
- arrow::Result<std::vector<Oid>> create_pg_types(
- const std::shared_ptr<arrow::Schema>& schema)
+ arrow::Status prepare_pg_types(std::vector<Oid>& pgTypes,
+ const std::shared_ptr<arrow::Schema>&
schema)
{
- std::vector<Oid> pgTypes;
ArrowPGTypeConverter converter;
for (const auto& field : schema->fields())
{
ARROW_RETURN_NOT_OK(field->type()->Accept(&converter));
pgTypes.push_back(converter.oid());
}
- return std::move(pgTypes);
+ return arrow::Status::OK();
+ }
+
+ arrow::Status prepare(SPIExecuteOptions& options,
+ std::vector<Oid>& pgTypes,
+ const std::shared_ptr<arrow::Schema>& schema)
+ {
+ if (schema->num_fields() > 0)
+ {
+ options.params = makeParamList(schema->num_fields());
+ }
+ options.read_only = false;
+ options.tcount = 0;
+ ARROW_RETURN_NOT_OK(prepare_pg_types(pgTypes, schema));
+ for (size_t i = 0; i < pgTypes.size(); ++i)
+ {
+ options.params->params[i].pflags = PARAM_FLAG_CONST;
+ options.params->params[i].ptype = pgTypes[i];
+ }
+ return arrow::Status::OK();
+ }
+
+ template <typename OnSuccessFunc>
+ arrow::Status execute(SPIPlanPtr plan,
+ const std::shared_ptr<arrow::RecordBatch>&
recordBatch,
+ SPIExecuteOptions& options,
+ OnSuccessFunc onSuccess)
+ {
+ const auto& columns = recordBatch->columns();
+ for (int64_t i = 0; i < recordBatch->num_rows(); ++i)
+ {
+ ARROW_RETURN_NOT_OK(assign_parameters(recordBatch, i,
columns, options));
+ auto result = SPI_execute_plan_extended(plan, &options);
+ if (result <= 0)
+ {
+ return arrow::Status::Invalid("failed to run a
prepared statement: ",
+
SPI_result_code_string(result),
+ ": ",
+ query_);
+ }
+ ARROW_RETURN_NOT_OK(onSuccess());
+ }
+ return arrow::Status::OK();
}
arrow::Status assign_parameters(
@@ -1101,18 +1183,7 @@ class PreparedStatement {
}
std::string query_;
-};
-
-struct Transaction {
- Transaction() { StartTransactionCommand(); }
-
- ~Transaction() { CommitTransactionCommand(); }
-};
-
-struct Snapshot {
- Snapshot() { PushActiveSnapshot(GetTransactionSnapshot()); }
-
- ~Snapshot() { PopActiveSnapshot(); }
+ std::vector<std::shared_ptr<arrow::RecordBatch>> parameters_;
};
class Executor : public WorkerProcessor {
@@ -1138,6 +1209,7 @@ class Executor : public WorkerProcessor {
void open()
{
+ const char* tag = "open";
// pg_usleep(5000000);
// pg_usleep(5000000);
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
opening").c_str());
@@ -1155,8 +1227,7 @@ class Executor : public WorkerProcessor {
if (!check_password(databaseName, userName, password,
clientAddress))
{
session_->initialized = true;
- P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
+ signal_server(tag);
return;
}
{
@@ -1169,18 +1240,19 @@ class Executor : public WorkerProcessor {
pgstat_report_activity(STATE_IDLE, NULL);
session_->initialized = true;
connected_ = true;
- P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
+ signal_server(tag);
}
void close() { close_internal(true); }
void signaled() override
{
- lock_acquire();
- auto action = session_->action;
- session_->action = Action::None;
- lock_release();
+ Action action;
+ {
+ ProcessorLockGuard lock(this);
+ action = session_->action;
+ session_->action = Action::None;
+ }
P("%s: %s: signaled: before: %s", Tag, tag_,
action_name(action));
PG_TRY();
{
@@ -1198,6 +1270,12 @@ class Executor : public WorkerProcessor {
case Action::ClosePreparedStatement:
close_prepared_statement();
break;
+ case Action::SetParameters:
+ set_parameters();
+ break;
+ case Action::SelectPreparedStatement:
+ select_prepared_statement();
+ break;
case Action::UpdatePreparedStatement:
update_prepared_statement();
break;
@@ -1211,14 +1289,16 @@ class Executor : public WorkerProcessor {
if (session_ &&
!DsaPointerIsValid(session_->errorMessage))
{
auto error = CopyErrorData();
- set_shared_string(session_->errorMessage,
- std::string("failed to run:
") + action_name(action) +
- ": " + error->message);
+ set_error_message(std::string("failed to run:
") + action_name(action) +
+ ": " + error->message,
+ "unexpected error");
FreeErrorData(error);
}
+ pgstat_report_activity(STATE_IDLE, NULL);
PG_RE_THROW();
}
PG_END_TRY();
+ pgstat_report_activity(STATE_IDLE, NULL);
P("%s: %s: signaled: after: %s", Tag, tag_,
action_name(action));
}
@@ -1228,8 +1308,32 @@ class Executor : public WorkerProcessor {
const char* peer_name(SessionData* session) override { return "server";
}
private:
+ void signal_server(const char* tag)
+ {
+ if (sharedData_->serverPID == InvalidPid)
+ {
+ return;
+ }
+ P("%s: %s: %s: kill server: %d", Tag, tag_, tag,
sharedData_->serverPID);
+ kill(sharedData_->serverPID, SIGUSR1);
+ }
+
+ void set_error_message(const std::string& message, const char* tag)
+ {
+ if (DsaPointerIsValid(session_->errorMessage))
+ {
+ return;
+ }
+ {
+ ProcessorLockGuard lock(this);
+ set_shared_string(session_->errorMessage, message);
+ }
+ signal_server(tag);
+ }
+
void close_internal(bool unlockSession)
{
+ const char* tag = "close";
closed_ = true;
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
closing").c_str());
preparedStatements_.clear();
@@ -1244,17 +1348,13 @@ class Executor : public WorkerProcessor {
}
else
{
- if (!DsaPointerIsValid(session_->errorMessage))
- {
- set_shared_string(session_->errorMessage,
"failed to connect");
- }
+ set_error_message("failed to connect", tag);
session_->initialized = true;
if (unlockSession)
{
dshash_release_lock(sessions_, session_);
}
- P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
+ signal_server(tag);
}
if (CurrentResourceOwner)
{
@@ -1267,7 +1367,6 @@ class Executor : public WorkerProcessor {
resourceOwner, RESOURCE_RELEASE_AFTER_LOCKS,
false, true);
ResourceOwnerDelete(resourceOwner);
}
- pgstat_report_activity(STATE_IDLE, NULL);
}
bool check_password(const char* databaseName,
@@ -1275,6 +1374,7 @@ class Executor : public WorkerProcessor {
const char* password,
const char* clientAddress)
{
+ const char* tag = "check password";
MemoryContext memoryContext =
AllocSetContextCreate(CurrentMemoryContext,
"arrow-flight-sql:
Executor::check_password()",
@@ -1291,20 +1391,18 @@ class Executor : public WorkerProcessor {
hba_getauthmethod(&port);
if (!port.hba)
{
- set_shared_string(session_->errorMessage, "failed to
get auth method");
+ set_error_message("failed to get auth method", tag);
return false;
}
switch (port.hba->auth_method)
{
case uaMD5:
// TODO
- set_shared_string(session_->errorMessage,
- "MD5 auth method isn't
supported yet");
+ set_error_message("MD5 auth method isn't
supported yet", tag);
return false;
case uaSCRAM:
// TODO
- set_shared_string(session_->errorMessage,
- "SCRAM auth method isn't
supported yet");
+ set_error_message("SCRAM auth method isn't
supported yet", tag);
return false;
case uaPassword:
{
@@ -1312,18 +1410,16 @@ class Executor : public WorkerProcessor {
auto shadowPassword =
get_role_password(port.user_name, &logDetail);
if (!shadowPassword)
{
- set_shared_string(
- session_->errorMessage,
- std::string("failed to get
password: ") + logDetail);
+ set_error_message(std::string("failed
to get password: ") + logDetail,
+ tag);
return false;
}
auto result = plain_crypt_verify(
port.user_name, shadowPassword,
password, &logDetail);
if (result != STATUS_OK)
{
- set_shared_string(
- session_->errorMessage,
- std::string("failed to verify
password: ") + logDetail);
+ set_error_message(
+ std::string("failed to verify
password: ") + logDetail, tag);
return false;
}
return true;
@@ -1331,15 +1427,16 @@ class Executor : public WorkerProcessor {
case uaTrust:
return true;
default:
- set_shared_string(session_->errorMessage,
- std::string("unsupported auth
method: ") +
-
hba_authname(port.hba->auth_method));
+ set_error_message(std::string("unsupported auth
method: ") +
+
hba_authname(port.hba->auth_method),
+ tag);
return false;
}
}
bool fill_client_address(Port* port, const char* clientAddress)
{
+ const char* tag = "fill client address";
// clientAddress: "ipv4:127.0.0.1:40468"
// family: "ipv4"
// host: "127.0.0.1"
@@ -1353,9 +1450,8 @@ class Executor : public WorkerProcessor {
std::getline(clientAddressStream, clientPort);
if (!(clientFamily == "ipv4" || clientFamily == "ipv6"))
{
- set_shared_string(
- session_->errorMessage,
- std::string("client family must be ipv4 or
ipv6: ") + clientFamily);
+ set_error_message(
+ std::string("client family must be ipv4 or
ipv6: ") + clientFamily, tag);
return false;
}
auto clientPortStart = clientPort.c_str();
@@ -1363,21 +1459,19 @@ class Executor : public WorkerProcessor {
auto clientPortNumber = std::strtoul(clientPortStart,
&clientPortEnd, 10);
if (clientPortEnd[0] != '\0')
{
- set_shared_string(session_->errorMessage,
- std::string("client port is invalid:
") + clientPort);
+ set_error_message(std::string("client port is invalid:
") + clientPort, tag);
return false;
}
if (clientPortNumber == 0)
{
- set_shared_string(session_->errorMessage,
- std::string("client port must not
0"));
+ set_error_message(std::string("client port must not
0"), tag);
return false;
}
if (clientPortNumber > 65535)
{
- set_shared_string(session_->errorMessage,
- std::string("client port is too
large: ") +
- std::to_string(clientPortNumber));
+ set_error_message(std::string("client port is too
large: ") +
+ std::to_string(clientPortNumber),
+ tag);
return false;
}
if (clientFamily == "ipv4")
@@ -1388,9 +1482,8 @@ class Executor : public WorkerProcessor {
raddr->sin_port = htons(clientPortNumber);
if (inet_pton(AF_INET, clientHost.c_str(),
&(raddr->sin_addr)) == 0)
{
- set_shared_string(
- session_->errorMessage,
- std::string("client IPv4 address is
invalid: ") + clientHost);
+ set_error_message(
+ std::string("client IPv4 address is
invalid: ") + clientHost, tag);
return false;
}
}
@@ -1403,9 +1496,8 @@ class Executor : public WorkerProcessor {
raddr->sin6_flowinfo = 0;
if (inet_pton(AF_INET6, clientHost.c_str(),
&(raddr->sin6_addr)) == 0)
{
- set_shared_string(
- session_->errorMessage,
- std::string("client IPv6 address is
invalid: ") + clientHost);
+ set_error_message(
+ std::string("client IPv6 address is
invalid: ") + clientHost, tag);
return false;
}
raddr->sin6_scope_id = 0;
@@ -1415,66 +1507,60 @@ class Executor : public WorkerProcessor {
void select()
{
+ const char* tag = "select";
if (!DsaPointerIsValid(session_->selectQuery))
{
- lock_acquire();
- set_shared_string(
- session_->errorMessage,
- std::string(Tag) + ": " + tag_ + ": select" +
": query is missing");
- lock_release();
+ set_error_message(
+ std::string(Tag) + ": " + tag_ + ": " + tag +
": query is missing", tag);
return;
}
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
selecting").c_str());
- lock_acquire();
- std::string query(
- static_cast<const char*>(dsa_get_address(area_,
session_->selectQuery)));
- dsa_free(area_, session_->selectQuery);
- session_->selectQuery = InvalidDsaPointer;
- lock_release();
- P("%s: %s: select: %s", Tag, tag_, query.c_str());
+ std::string query;
+ {
+ ProcessorLockGuard lock(this);
+ query =
+ static_cast<const char*>(dsa_get_address(area_,
session_->selectQuery));
+ dsa_free(area_, session_->selectQuery);
+ session_->selectQuery = InvalidDsaPointer;
+ }
+ P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
{
- Transaction transaction;
- Snapshot snapshot;
+ ScopedTransaction scopedTransaction;
+ ScopedSnapshot scopedSnapshot;
SetCurrentStatementStartTimestamp();
auto result = SPI_execute(query.c_str(), true, 0);
- if (result == SPI_OK_SELECT)
+ if (result > 0)
{
- pgstat_report_activity(STATE_RUNNING,
- (std::string(Tag) + ":
select: writing").c_str());
- auto status = write();
- if (!status.ok())
+ pgstat_report_activity(
+ STATE_RUNNING, (std::string(Tag) + ": "
+ tag + ": writing").c_str());
+ auto status = write(tag);
+ if (status.ok())
{
- lock_acquire();
-
set_shared_string(session_->errorMessage, status.ToString());
- lock_release();
+ signal_server(tag);
+ }
+ else
+ {
+ set_error_message(std::string(Tag) + ":
" + tag_ + ": " + tag +
+ ": failed to
write: " + status.ToString(),
+ tag);
}
}
else
{
- lock_acquire();
- set_shared_string(session_->errorMessage,
- std::string(Tag) + ": " +
tag_ + ": select" +
+ set_error_message(std::string(Tag) + ": " +
tag_ + ": " + tag +
": failed to run a query:
<" + query +
- ">: " +
SPI_result_code_string(result));
- lock_release();
+ ">: " +
SPI_result_code_string(result),
+ tag);
}
}
-
- if (sharedData_->serverPID != InvalidPid)
- {
- P("%s: %s: select: kill server: %d", Tag, tag_,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
- }
-
- pgstat_report_activity(STATE_IDLE, NULL);
}
- arrow::Status write()
+ arrow::Status write(const char* tag)
{
SharedRingBufferOutputStream output(this, session_);
std::vector<PGArrowValueConverter> converters;
@@ -1500,9 +1586,9 @@ class Executor : public WorkerProcessor {
arrow::ipc::MakeStreamWriter(&output,
schema, options));
// Build an empty record batch to write schema.
ARROW_ASSIGN_OR_RAISE(auto recordBatch, builder->Flush());
- P("%s: %s: write: schema: WriteRecordBatch", Tag, tag_);
+ P("%s: %s: %s: write: schema: WriteRecordBatch", Tag, tag_,
tag);
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
- P("%s: %s: write: schema: Close", Tag, tag_);
+ P("%s: %s: %s: write: schema: Close", Tag, tag_, tag);
ARROW_RETURN_NOT_OK(writer->Close());
// Write another stream format data with record batches.
@@ -1511,17 +1597,19 @@ class Executor : public WorkerProcessor {
bool needLastFlush = false;
for (uint64_t iTuple = 0; iTuple < SPI_processed; ++iTuple)
{
- P("%s: %s: write: data: record batch: %d/%d",
+ P("%s: %s: %s: write: data: record batch: %d/%d",
Tag,
tag_,
+ tag,
iTuple,
SPI_processed);
for (int iAttribute = 0; iAttribute <
SPI_tuptable->tupdesc->natts;
++iAttribute)
{
- P("%s: %s: write: data: record batch: %d/%d:
%d/%d",
+ P("%s: %s: %s: write: data: record batch:
%d/%d: %d/%d",
Tag,
tag_,
+ tag,
iTuple,
SPI_processed,
iAttribute,
@@ -1546,9 +1634,10 @@ class Executor : public WorkerProcessor {
if (((iTuple + 1) % MaxNRowsPerRecordBatch) == 0)
{
ARROW_ASSIGN_OR_RAISE(recordBatch,
builder->Flush());
- P("%s: %s: write: data: WriteRecordBatch:
%d/%d",
+ P("%s: %s: %s: write: data: WriteRecordBatch:
%d/%d",
Tag,
tag_,
+ tag,
iTuple,
SPI_processed);
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
@@ -1562,208 +1651,259 @@ class Executor : public WorkerProcessor {
if (needLastFlush)
{
ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush());
- P("%s: %s: write: data: WriteRecordBatch", Tag, tag_);
+ P("%s: %s: %s: write: data: WriteRecordBatch", Tag,
tag_, tag);
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
}
- P("%s: %s: write: data: Close", Tag, tag_);
+ P("%s: %s: %s, write: data: Close", Tag, tag_, tag);
ARROW_RETURN_NOT_OK(writer->Close());
return output.Close();
}
void update()
{
+ const char* tag = "update";
if (!DsaPointerIsValid(session_->updateQuery))
{
- lock_acquire();
- set_shared_string(
- session_->errorMessage,
- std::string(Tag) + ": " + tag_ + ": update" +
": query is missing");
- lock_release();
+ set_error_message(
+ std::string(Tag) + ": " + tag_ + ": " + tag +
": query is missing", tag);
return;
}
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
updating").c_str());
- lock_acquire();
- std::string query(
- static_cast<const char*>(dsa_get_address(area_,
session_->updateQuery)));
- dsa_free(area_, session_->updateQuery);
- session_->updateQuery = InvalidDsaPointer;
- lock_release();
- P("%s: %s: update: %s", Tag, tag_, query.c_str());
+ std::string query;
+ {
+ ProcessorLockGuard lock(this);
+ query =
+ static_cast<const char*>(dsa_get_address(area_,
session_->updateQuery));
+ dsa_free(area_, session_->updateQuery);
+ session_->updateQuery = InvalidDsaPointer;
+ }
+ P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
{
- Transaction transaction;
- Snapshot snapshot;
+ ScopedTransaction scopedTransaction;
+ ScopedSnapshot scopedSnapshot;
SetCurrentStatementStartTimestamp();
auto result = SPI_execute(query.c_str(), false, 0);
- switch (result)
+ if (result > 0)
{
- case SPI_OK_INSERT:
- case SPI_OK_DELETE:
- case SPI_OK_UPDATE:
- session_->nUpdatedRecords =
SPI_processed;
- break;
- default:
- lock_acquire();
-
set_shared_string(session_->errorMessage,
- std::string(Tag) + ":
" + tag_ + ": update" +
- ": failed to run
a query: <" + query +
- ">: " +
SPI_result_code_string(result));
- lock_release();
- break;
+ session_->nUpdatedRecords = SPI_processed;
+ signal_server(tag);
+ }
+ else
+ {
+ set_error_message(std::string(Tag) + ": " +
tag_ + ": " + tag +
+ ": failed to run a query:
<" + query +
+ ">: " +
SPI_result_code_string(result),
+ tag);
}
}
-
- if (sharedData_->serverPID != InvalidPid)
- {
- P("%s: %s: update: kill server: %d", Tag, tag_,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
- }
-
- pgstat_report_activity(STATE_IDLE, NULL);
}
void prepare()
{
+ const char* tag = "prepare";
if (!DsaPointerIsValid(session_->prepareQuery))
{
- lock_acquire();
- set_shared_string(
- session_->errorMessage,
- std::string(Tag) + ": " + tag_ + ": prepare" +
": query is missing");
- lock_release();
+ set_error_message(
+ std::string(Tag) + ": " + tag_ + ": " + tag +
": query is missing", tag);
return;
}
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
preparing").c_str());
- lock_acquire();
- std::string query(
- static_cast<const char*>(dsa_get_address(area_,
session_->prepareQuery)));
- dsa_free(area_, session_->prepareQuery);
- session_->prepareQuery = InvalidDsaPointer;
- lock_release();
- P("%s: %s: prepare: %s", Tag, tag_, query.c_str());
+ std::string query;
+ {
+ ProcessorLockGuard lock(this);
+ query =
+ static_cast<const char*>(dsa_get_address(area_,
session_->prepareQuery));
+ dsa_free(area_, session_->prepareQuery);
+ session_->prepareQuery = InvalidDsaPointer;
+ }
+ P("%s: %s: %s: %s", Tag, tag_, tag, query.c_str());
std::string handle(std::to_string(nextPreparedStatementID_++));
- set_shared_string(session_->preparedStatementHandle, handle);
preparedStatements_.insert(
std::make_pair(handle,
PreparedStatement(std::move(query))));
+ {
+ ProcessorLockGuard lock(this);
+ set_shared_string(session_->preparedStatementHandle,
handle);
+ }
+ signal_server(tag);
+ }
+
+ bool extract_handle(std::string& handle, const char* tag)
+ {
+ if (!DsaPointerIsValid(session_->preparedStatementHandle))
+ {
+ set_error_message(
+ std::string(Tag) + ": " + tag_ + ": " + tag +
": handle is missing", tag);
+ return false;
+ }
- if (sharedData_->serverPID != InvalidPid)
{
- P("%s: %s: prepare: kill server: %d", Tag, tag_,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
+ ProcessorLockGuard lock(this);
+ handle = static_cast<const char*>(
+ dsa_get_address(area_,
session_->preparedStatementHandle));
+ dsa_free(area_, session_->preparedStatementHandle);
+ session_->preparedStatementHandle = InvalidDsaPointer;
}
- pgstat_report_activity(STATE_IDLE, nullptr);
+ return true;
}
void close_prepared_statement()
{
const char* tag = "close prepared statement";
- if (!DsaPointerIsValid(session_->preparedStatementHandle))
+ pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": "
+ tag).c_str());
+
+ std::string handle;
+ if (!extract_handle(handle, tag))
{
- lock_acquire();
- set_shared_string(
- session_->errorMessage,
- std::string(Tag) + ": " + tag_ + ": " + tag +
": handle is missing");
- lock_release();
return;
}
+ P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
+ if (preparedStatements_.erase(handle) > 0)
+ {
+ signal_server(tag);
+ }
+ else
+ {
+ set_error_message(std::string(Tag) + ": " + tag_ + ": "
+ tag +
+ ": nonexistent handle: <" +
handle + ">",
+ tag);
+ }
+ }
- pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": "
+ tag).c_str());
+ PreparedStatement* find_prepared_statement(std::string& handle, const
char* tag)
+ {
+ if (!extract_handle(handle, tag))
+ {
+ return nullptr;
+ }
- lock_acquire();
- std::string handle(static_cast<const char*>(
- dsa_get_address(area_,
session_->preparedStatementHandle)));
- dsa_free(area_, session_->preparedStatementHandle);
- session_->preparedStatementHandle = InvalidDsaPointer;
- if (preparedStatements_.erase(handle) == 0)
+ ProcessorLockGuard lock(this);
+ auto it = preparedStatements_.find(handle);
+ if (it == preparedStatements_.end())
{
- set_shared_string(session_->errorMessage,
- std::string(Tag) + ": " + tag_ + ": "
+ tag +
- ": nonexistent handle: <" +
handle + ">");
+ set_error_message(std::string(Tag) + ": " + tag_ + ": "
+ tag +
+ ": nonexistent handle: <" +
handle + ">",
+ tag);
+ return nullptr;
}
- lock_release();
+ else
+ {
+ return &(it->second);
+ }
+ }
+
+ void set_parameters()
+ {
+ const char* tag = "set parameters";
+
+ pgstat_report_activity(STATE_RUNNING,
+ (std::string(Tag) + ": setting
parameters").c_str());
+ std::string handle;
+ auto preparedStatement = find_prepared_statement(handle, tag);
P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
- if (sharedData_->serverPID != InvalidPid)
+ if (!preparedStatement)
{
- P("%s: %s: %s: kill server: %d", Tag, tag_, tag,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
+ return;
}
- pgstat_report_activity(STATE_IDLE, nullptr);
+ auto input =
std::make_shared<SharedRingBufferInputStream>(this, session_);
+ auto status = preparedStatement->set_parameters(input);
+ if (status.ok())
+ {
+ signal_server(tag);
+ }
+ else
+ {
+ set_error_message(std::string(Tag) + ": " + tag_ + ": "
+ tag +
+ ": failed to set parameters: <" +
handle +
+ ">: " + status.ToString(),
+ tag);
+ }
}
- void update_prepared_statement()
+ void select_prepared_statement()
{
- const char* tag = "update prepared statement";
+ const char* tag = "select prepared statement";
- if (!DsaPointerIsValid(session_->preparedStatementHandle))
+ pgstat_report_activity(
+ STATE_RUNNING, (std::string(Tag) + ": selecting
prepared statement").c_str());
+
+ std::string handle;
+ auto preparedStatement = find_prepared_statement(handle, tag);
+ P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
+
+ if (!preparedStatement)
{
- lock_acquire();
- set_shared_string(
- session_->errorMessage,
- std::string(Tag) + ": " + tag_ + ": " + tag + "
: handle is missing");
- lock_release();
return;
}
- pgstat_report_activity(
- STATE_RUNNING, (std::string(Tag) + ": updating prepared
statement").c_str());
+ ScopedTransaction scopedTransaction;
+ ScopedSnapshot scopedSnapshot;
- lock_acquire();
- std::string handle(static_cast<const char*>(
- dsa_get_address(area_,
session_->preparedStatementHandle)));
- dsa_free(area_, session_->preparedStatementHandle);
- session_->preparedStatementHandle = InvalidDsaPointer;
- PreparedStatement* preparedStatement = nullptr;
- auto it = preparedStatements_.find(handle);
- if (it == preparedStatements_.end())
+ struct Data {
+ Executor* executor;
+ const char* tag;
+ } data = {this, tag};
+ auto status = preparedStatement->select(
+ [](void* data) {
+ auto d = static_cast<Data*>(data);
+ return d->executor->write(d->tag);
+ },
+ &data);
+ if (status.ok())
{
- set_shared_string(session_->errorMessage,
- std::string(Tag) + ": " + tag_ + ": "
+ tag +
- ": nonexistent handle: <" +
handle + ">");
+ signal_server(tag);
}
else
{
- preparedStatement = &(it->second);
+ set_error_message(std::string(Tag) + ": " + tag_ + ": "
+ tag +
+ ": failed to select a prepared
statement: <" + handle +
+ ">: " + status.ToString(),
+ tag);
}
- lock_release();
+ }
+
+ void update_prepared_statement()
+ {
+ const char* tag = "update prepared statement";
+
+ pgstat_report_activity(
+ STATE_RUNNING, (std::string(Tag) + ": updating prepared
statement").c_str());
+
+ std::string handle;
+ auto preparedStatement = find_prepared_statement(handle, tag);
P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
- if (preparedStatement)
+ if (!preparedStatement)
{
- Transaction transaction;
- Snapshot snapshot;
-
- auto input =
std::make_shared<SharedRingBufferInputStream>(this, session_);
- auto n_updated_records_result =
preparedStatement->update(input);
- if (n_updated_records_result.ok())
- {
- session_->nUpdatedRecords =
*n_updated_records_result;
- }
- else
- {
- set_shared_string(
- session_->errorMessage,
- std::string(Tag) + ": " + tag_ + ": " +
tag +
- ": failed to update a prepared
statement: <" + handle +
- ">: " +
n_updated_records_result.status().ToString());
- }
+ return;
}
- if (sharedData_->serverPID != InvalidPid)
+ ScopedTransaction scopedTransaction;
+ ScopedSnapshot scopedSnapshot;
+
+ auto input =
std::make_shared<SharedRingBufferInputStream>(this, session_);
+ auto n_updated_records_result =
preparedStatement->update(input);
+ if (n_updated_records_result.ok())
{
- P("%s: %s: %s: kill server: %d", Tag, tag_, tag,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
+ session_->nUpdatedRecords = *n_updated_records_result;
+ signal_server(tag);
+ }
+ else
+ {
+ set_error_message(std::string(Tag) + ": " + tag_ + ": "
+ tag +
+ ": failed to update a prepared
statement: <" + handle +
+ ">: " +
n_updated_records_result.status().ToString(),
+ tag);
}
-
- pgstat_report_activity(STATE_IDLE, nullptr);
}
uint64_t sessionID_;
@@ -2018,6 +2158,97 @@ class Proxy : public WorkerProcessor {
return arrow::Status::OK();
}
+ arrow::Status set_parameters(uint64_t sessionID,
+ const std::string& handle,
+ arrow::flight::FlightMessageReader* reader,
+ arrow::flight::FlightMetadataWriter* writer)
+ {
+#ifdef AFS_DEBUG
+ const char* tag = "set parameters";
+#endif
+ auto session = find_session(sessionID);
+ SessionReleaser sessionReleaser(sessions_, session);
+ lock_acquire();
+ set_shared_string(session->preparedStatementHandle, handle);
+ session->action = Action::SetParameters;
+ lock_release();
+ if (session->executorPID != InvalidPid)
+ {
+ P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
+ kill(session->executorPID, SIGUSR1);
+ }
+ {
+ ARROW_ASSIGN_OR_RAISE(const auto& schema, reader->GetSchema());
+ SharedRingBufferOutputStream output(this, session);
+ auto options = arrow::ipc::IpcWriteOptions::Defaults();
+ options.emit_dictionary_deltas = true;
+ ARROW_ASSIGN_OR_RAISE(auto writer,
+ arrow::ipc::MakeStreamWriter(&output, schema, options));
+ while (true)
+ {
+ ARROW_ASSIGN_OR_RAISE(const auto& chunk, reader->Next());
+ if (!chunk.data)
+ {
+ break;
+ }
+ ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*(chunk.data)));
+ }
+ ARROW_RETURN_NOT_OK(writer->Close());
+ }
+ if (session->executorPID != InvalidPid)
+ {
+ P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
+ kill(session->executorPID, SIGUSR1);
+ }
+ {
+ auto buffer = create_shared_ring_buffer(session);
+ std::unique_lock<std::mutex> lock(mutex_);
+ conditionVariable_.wait(lock, [&] {
+ P("%s: %s: %s: wait", Tag, tag_, tag);
+ return DsaPointerIsValid(session->errorMessage) || buffer.size() == 0;
+ });
+ }
+ if (DsaPointerIsValid(session->errorMessage))
+ {
+ return report_session_error(session);
+ }
+ P("%s: %s: %s: done", Tag, tag_, tag);
+ return arrow::Status::OK();
+ }
+
+ arrow::Result<std::shared_ptr<arrow::Schema>> select_prepared_statement(
+ uint64_t sessionID, const std::string& handle)
+ {
+ const char* tag = "select prepared statement";
+ auto session = find_session(sessionID);
+ SessionReleaser sessionReleaser(sessions_, session);
+ lock_acquire();
+ set_shared_string(session->preparedStatementHandle, handle);
+ session->action = Action::SelectPreparedStatement;
+ lock_release();
+ if (session->executorPID != InvalidPid)
+ {
+ P("%s: %s: %s: kill executor: %d", Tag, tag_, tag, session->executorPID);
+ kill(session->executorPID, SIGUSR1);
+ }
+ {
+ auto buffer = create_shared_ring_buffer(session);
+ std::unique_lock<std::mutex> lock(mutex_);
+ conditionVariable_.wait(lock, [&] {
+ P("%s: %s: %s: wait", Tag, tag_, tag);
+ return DsaPointerIsValid(session->errorMessage) || buffer.size() > 0;
+ });
+ }
+ if (DsaPointerIsValid(session->errorMessage))
+ {
+ return report_session_error(session);
+ }
+ P("%s: %s: %s: open", Tag, tag_, tag);
+ auto schema = read_schema(session, tag);
+ P("%s: %s: %s: schema", Tag, tag_, tag);
+ return schema;
+ }
+
arrow::Result<int64_t> update_prepared_statement(
uint64_t sessionID,
const std::string& handle,
@@ -2476,6 +2707,37 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
return proxy_->close_prepared_statement(sessionID, handle);
}
+ arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>>
+ GetFlightInfoPreparedStatement(
+ const arrow::flight::ServerCallContext& context,
+ const arrow::flight::sql::PreparedStatementQuery& command,
+ const arrow::flight::FlightDescriptor& descriptor) override
+ {
+ ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
+ const auto& handle = command.prepared_statement_handle;
+ ARROW_ASSIGN_OR_RAISE(auto schema,
+ proxy_->select_prepared_statement(sessionID, handle));
+ ARROW_ASSIGN_OR_RAISE(auto ticket,
+ arrow::flight::sql::CreateStatementQueryTicket(handle));
+ std::vector<arrow::flight::FlightEndpoint> endpoints{
+ arrow::flight::FlightEndpoint{std::move(ticket), {}}};
+ ARROW_ASSIGN_OR_RAISE(
+ auto result,
+ arrow::flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1));
+ return std::make_unique<arrow::flight::FlightInfo>(result);
+ }
+
+ arrow::Status DoPutPreparedStatementQuery(
+ const arrow::flight::ServerCallContext& context,
+ const arrow::flight::sql::PreparedStatementQuery& command,
+ arrow::flight::FlightMessageReader* reader,
+ arrow::flight::FlightMetadataWriter* writer) override
+ {
+ ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
+ const auto& handle = command.prepared_statement_handle;
+ return proxy_->set_parameters(sessionID, handle, reader, writer);
+ }
+
arrow::Result<int64_t> DoPutPreparedStatementUpdate(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementUpdate& command,
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index 4465d56..429fa95 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -67,7 +67,7 @@ class FlightSQLTest < Test::Unit::TestCase
data("string - varchar", ["varchar(10)", :string, "b"])
data("binary", ["bytea", :binary, "\x0".b])
data("timestamp", ["timestamp", [:timestamp, :micro], timestamp_value])
- def test_select_type
+ def test_select_direct
pg_type, data_type, value = data
data_type = Arrow::DataType.resolve(data_type)
values = data_type.build_array([value])
@@ -81,6 +81,35 @@ class FlightSQLTest < Test::Unit::TestCase
reader.read_all)
end
+ data("int16", [:int16, -2])
+ data("int32", [:int32, -2])
+ data("int64", [:int64, -2])
+ data("float", [:float, -2.2])
+ data("double", [:double, -2.2])
+ data("string", [:string, "b"])
+ data("binary", [:binary, "\x0".b])
+ data("timestamp", [[:timestamp, :micro], timestamp_value])
+ def test_select_prepare
+ unless flight_sql_client.respond_to?(:prepare)
+ omit("red-arrow-flight-sql 14.0.0 or later is required")
+ end
+
+ data_type, value = data
+ data_type = Arrow::DataType.resolve(data_type)
+ values = data_type.build_array([value])
+ flight_sql_client.prepare("SELECT $1 AS value",
+ @options) do |statement|
+ statement.record_batch = Arrow::RecordBatch.new(value: values)
+ info = statement.execute(@options)
+ assert_equal(Arrow::Schema.new(value: values.value_data_type),
+ info.get_schema)
+ endpoint = info.endpoints.first
+ reader = flight_sql_client.do_get(endpoint.ticket, @options)
+ assert_equal(Arrow::Table.new(value: values),
+ reader.read_all)
+ end
+ end
+
def test_select_from
run_sql("CREATE TABLE data (value integer)")
run_sql("INSERT INTO data VALUES (1), (-2), (3)")
@@ -144,7 +173,7 @@ SELECT * FROM data
["timestamp", [:timestamp, :micro], timestamp_values])
data("timestamp(nano)",
["timestamp", [:timestamp, :nano], timestamp_values])
- def test_insert_type
+ def test_insert_prepare
unless flight_sql_client.respond_to?(:prepare)
omit("red-arrow-flight-sql 14.0.0 or later is required")
end