This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository
https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git
The following commit(s) were added to refs/heads/main by this push:
new 6c7a8aa Add support for prepared INSERT (#63)
6c7a8aa is described below
commit 6c7a8aabe8f11ee714017c569d834ef710ab8145
Author: Sutou Kouhei <[email protected]>
AuthorDate: Tue Aug 22 14:06:09 2023 +0900
Add support for prepared INSERT (#63)
Closes GH-62
---
src/afs.cc | 1098 ++++++++++++++++++++++++++++++++++++-----------
test/helper/sandbox.rb | 26 +-
test/test-flight-sql.rb | 29 +-
3 files changed, 899 insertions(+), 254 deletions(-)
diff --git a/src/afs.cc b/src/afs.cc
index d267a43..c3befe9 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -58,6 +58,7 @@ extern "C"
#include <condition_variable>
#include <fstream>
#include <iterator>
+#include <map>
#include <random>
#include <sstream>
@@ -167,6 +168,7 @@ struct SharedRingBufferData {
size_t tail;
};
+// Naive ring buffer implementation. We can improve this later.
class SharedRingBuffer {
public:
static void initialize_data(SharedRingBufferData* data)
@@ -333,6 +335,55 @@ class SharedRingBuffer {
}
};
+// Operation requested from an executor. The proxy stores it in
+// SessionData::action and the executor dispatches on it in
+// Executor::signaled().
+enum class Action
+{
+ None,
+ Select,
+ Update,
+ Prepare,
+ ClosePreparedStatement,
+ UpdatePreparedStatement,
+};
+
+// Maps an Action to a printable name for log and error messages.
+// Anything outside the enumeration is reported as "Action::Unknown".
+const char*
+action_name(Action action)
+{
+    const char* name = "Action::Unknown";
+    switch (action)
+    {
+        case Action::None: name = "Action::None"; break;
+        case Action::Select: name = "Action::Select"; break;
+        case Action::Update: name = "Action::Update"; break;
+        case Action::Prepare: name = "Action::Prepare"; break;
+        case Action::ClosePreparedStatement:
+            name = "Action::ClosePreparedStatement";
+            break;
+        case Action::UpdatePreparedStatement:
+            name = "Action::UpdatePreparedStatement";
+            break;
+        default:
+            break;
+    }
+    return name;
+}
+
+// Replaces the string stored at `pointer` in DSA `area` with a copy of
+// `input`. A previously stored string is freed first. For an empty
+// `input` the pointer is left as InvalidDsaPointer; otherwise a
+// NUL-terminated copy is allocated from `area`.
+void
+dsa_pointer_set_string(dsa_pointer& pointer, dsa_area* area, const std::string& input)
+{
+    if (DsaPointerIsValid(pointer))
+    {
+        dsa_free(area, pointer);
+        pointer = InvalidDsaPointer;
+    }
+    if (!input.empty())
+    {
+        const auto nBytes = input.size() + 1;
+        pointer = dsa_allocate(area, nBytes);
+        memcpy(dsa_get_address(area, pointer), input.c_str(), nBytes);
+    }
+}
+
+// Put only data (don't add methods) to use with dshash.
struct SessionData {
uint64_t id;
dsa_pointer errorMessage;
@@ -342,12 +393,61 @@ struct SessionData {
dsa_pointer userName;
dsa_pointer password;
dsa_pointer clientAddress;
+ Action action;
dsa_pointer selectQuery;
dsa_pointer updateQuery;
int64_t nUpdatedRecords;
+ dsa_pointer prepareQuery;
+ dsa_pointer preparedStatementHandle;
SharedRingBufferData bufferData;
};
+// Puts `session` into its initial state and copies the connection
+// parameters into DSA memory allocated from `area`.
+//
+// Every DSA pointer is reset to InvalidDsaPointer *before* any string is
+// stored. dsa_pointer_set_string() frees whatever valid-looking pointer it
+// is handed, and the fields of a freshly inserted shared-hash entry are not
+// guaranteed to be initialized.
+// NOTE(review): assumes `session` owns no DSA memory yet (a new entry);
+// calling this on a live session would leak its previous allocations.
+void
+session_data_initialize(SessionData* session,
+                        dsa_area* area,
+                        const std::string& databaseName,
+                        const std::string& userName,
+                        const std::string& password,
+                        const std::string& clientAddress)
+{
+    session->errorMessage = InvalidDsaPointer;
+    session->executorPID = InvalidPid;
+    session->initialized = false;
+    session->databaseName = InvalidDsaPointer;
+    session->userName = InvalidDsaPointer;
+    session->password = InvalidDsaPointer;
+    session->clientAddress = InvalidDsaPointer;
+    dsa_pointer_set_string(session->databaseName, area, databaseName);
+    dsa_pointer_set_string(session->userName, area, userName);
+    dsa_pointer_set_string(session->password, area, password);
+    dsa_pointer_set_string(session->clientAddress, area, clientAddress);
+    session->action = Action::None;
+    session->selectQuery = InvalidDsaPointer;
+    session->updateQuery = InvalidDsaPointer;
+    session->nUpdatedRecords = -1;
+    session->prepareQuery = InvalidDsaPointer;
+    session->preparedStatementHandle = InvalidDsaPointer;
+    SharedRingBuffer::initialize_data(&(session->bufferData));
+}
+
+// Releases the DSA allocations owned by `session`: every string pointer
+// that is valid, plus the ring-buffer storage. The SessionData struct
+// itself is not freed; the caller removes the entry.
+void
+session_data_finalize(SessionData* session, dsa_area* area)
+{
+    if (DsaPointerIsValid(session->errorMessage))
+        dsa_free(area, session->errorMessage);
+    if (DsaPointerIsValid(session->databaseName))
+        dsa_free(area, session->databaseName);
+    if (DsaPointerIsValid(session->userName))
+        dsa_free(area, session->userName);
+    if (DsaPointerIsValid(session->password))
+        dsa_free(area, session->password);
+    // session_data_initialize() allocates clientAddress too; without this
+    // it leaked once per session.
+    if (DsaPointerIsValid(session->clientAddress))
+        dsa_free(area, session->clientAddress);
+    if (DsaPointerIsValid(session->selectQuery))
+        dsa_free(area, session->selectQuery);
+    if (DsaPointerIsValid(session->updateQuery))
+        dsa_free(area, session->updateQuery);
+    if (DsaPointerIsValid(session->prepareQuery))
+        dsa_free(area, session->prepareQuery);
+    if (DsaPointerIsValid(session->preparedStatementHandle))
+        dsa_free(area, session->preparedStatementHandle);
+    SharedRingBuffer::free_data(&(session->bufferData), area);
+}
+
class SessionReleaser {
public:
explicit SessionReleaser(dshash_table* sessions, SessionData* data)
@@ -381,8 +481,15 @@ struct SharedData {
class Processor {
public:
- Processor(const char* tag)
+ enum class WaitMode
+ {
+ Read,
+ Written,
+ };
+
+ Processor(const char* tag, bool runInPGThread)
: tag_(tag),
+ runInPGThread_(runInPGThread),
sharedData_(nullptr),
area_(nullptr),
lock_(),
@@ -394,16 +501,95 @@ class Processor {
+ // Detaches from the DSA area if one was attached.
virtual ~Processor()
{
if (area_)
+ {
dsa_detach(area_);
+ }
}
const char* tag() { return tag_; }
- void lock_acquire(LWLockMode mode) { LWLockAcquire(lock_,
LW_EXCLUSIVE); }
+ // Takes lock_ in exclusive mode; pair with lock_release().
+ void lock_acquire() { LWLockAcquire(lock_, LW_EXCLUSIVE); }
void lock_release() { LWLockRelease(lock_); }
- void signaled()
+ // Returns a SharedRingBuffer wrapping `session`'s ring-buffer state,
+ // backed by this processor's DSA area.
+ SharedRingBuffer create_shared_ring_buffer(SessionData* session)
+ {
+ return SharedRingBuffer(&(session->bufferData), area_);
+ }
+
+    // Wakes the peer process with SIGUSR1, then blocks until the watched
+    // size of `buffer` changes. For WaitMode::Read it watches
+    // buffer->rest_size() (presumably free space, freed when the peer reads).
+    // For WaitMode::Written it watches buffer->size() (presumably filled
+    // when the peer writes).
+    //
+    // Inside a PostgreSQL backend (runInPGThread_) the wait is driven by
+    // MyLatch. Otherwise it waits on conditionVariable_, which signaled()
+    // notifies, and also ends when interrupts are pending.
+    //
+    // Returns IOError if the peer is not alive or if SIGTERM arrives while
+    // waiting on the latch. Reporting SIGTERM as OK made callers that loop
+    // until the buffer changes spin forever, because GotSIGTERM stays set and
+    // the latch is never reset.
+    arrow::Status wait(SessionData* session, SharedRingBuffer* buffer, WaitMode mode)
+    {
+        const bool read = (mode == WaitMode::Read);
+        const char* tag = read ? "wait read" : "wait written";
+        auto peerPID = peer_pid(session);
+        auto peerName = peer_name(session);
+
+        if (ARROW_PREDICT_FALSE(peerPID == InvalidPid))
+        {
+            return arrow::Status::IOError(
+                Tag, ": ", tag_, ": ", tag, ": ", peerName, ": not alive");
+        }
+
+        P("%s: %s: %s: %s: kill: %d", Tag, tag_, tag, peerName, peerPID);
+        kill(peerPID, SIGUSR1);
+        // Unary + converts each capture-less lambda to a function pointer, so
+        // both ?: operands share one type. Two distinct closure types are
+        // rejected by some compilers.
+        auto get_target_size =
+            read ? +[](SharedRingBuffer* buffer) { return buffer->rest_size(); }
+                 : +[](SharedRingBuffer* buffer) { return buffer->size(); };
+        auto targetSize = get_target_size(buffer);
+        if (runInPGThread_)
+        {
+            while (true)
+            {
+                int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH;
+                WaitLatch(MyLatch, events, -1, PG_WAIT_EXTENSION);
+                if (GotSIGTERM)
+                {
+                    return arrow::Status::IOError(
+                        Tag, ": ", tag_, ": ", tag, ": terminated");
+                }
+                ResetLatch(MyLatch);
+
+                if (GotSIGUSR1)
+                {
+                    GotSIGUSR1 = false;
+                    // Sizes are printed as size_t; %d is not a valid format for them.
+                    P("%s: %s: %s: %s: wait: %zu:%zu",
+                      Tag,
+                      tag_,
+                      tag,
+                      peerName,
+                      static_cast<size_t>(get_target_size(buffer)),
+                      static_cast<size_t>(targetSize));
+                    if (get_target_size(buffer) != targetSize)
+                    {
+                        break;
+                    }
+                }
+
+                // TODO: Convert PG error to arrow::Status.
+                CHECK_FOR_INTERRUPTS();
+            }
+        }
+        else
+        {
+            std::unique_lock<std::mutex> lock(mutex_);
+            conditionVariable_.wait(lock, [&] {
+                P("%s: %s: %s: %s: wait: %zu:%zu",
+                  Tag,
+                  tag_,
+                  tag,
+                  peerName,
+                  static_cast<size_t>(get_target_size(buffer)),
+                  static_cast<size_t>(targetSize));
+                if (INTERRUPTS_PENDING_CONDITION())
+                {
+                    return true;
+                }
+                return get_target_size(buffer) != targetSize;
+            });
+        }
+        return arrow::Status::OK();
+    }
+
+ virtual void signaled()
{
P("%s: %s: signaled: before", Tag, tag_);
conditionVariable_.notify_all();
@@ -413,20 +599,15 @@ class Processor {
protected:
+ // Stores a copy of `input` in this processor's DSA area through
+ // dsa_pointer_set_string(); an empty `input` leaves `pointer` invalid.
void set_shared_string(dsa_pointer& pointer, const std::string& input)
{
- if (DsaPointerIsValid(pointer))
- {
- dsa_free(area_, pointer);
- pointer = InvalidDsaPointer;
- }
- if (input.empty())
- {
- return;
- }
- pointer = dsa_allocate(area_, input.size() + 1);
- memcpy(dsa_get_address(area_, pointer), input.c_str(),
input.size() + 1);
+ dsa_pointer_set_string(pointer, area_, input);
}
+ // PID of the process wait() signals and waits for. The base class has no
+ // peer, so wait() reports it as not alive.
+ virtual pid_t peer_pid(SessionData* session) { return InvalidPid; }
+
+ // Peer name used in wait() log and error messages.
+ virtual const char* peer_name(SessionData* session) { return "unknown";
}
+
const char* tag_;
+ bool runInPGThread_;
SharedData* sharedData_;
dsa_area* area_;
LWLock* lock_;
@@ -437,9 +618,9 @@ class Processor {
class Proxy;
class SharedRingBufferInputStream : public arrow::io::InputStream {
public:
- SharedRingBufferInputStream(Proxy* proxy, SessionData* session)
+ SharedRingBufferInputStream(Processor* processor, SessionData* session)
: arrow::io::InputStream(),
- proxy_(proxy),
+ processor_(processor),
session_(session),
position_(0),
is_open_(true)
@@ -468,7 +649,7 @@ class SharedRingBufferInputStream : public
arrow::io::InputStream {
}
private:
- Proxy* proxy_;
+ Processor* processor_;
SessionData* session_;
int64_t position_;
bool is_open_;
@@ -477,8 +658,12 @@ class SharedRingBufferInputStream : public
arrow::io::InputStream {
class Executor;
class SharedRingBufferOutputStream : public arrow::io::OutputStream {
public:
- SharedRingBufferOutputStream(Executor* executor)
- : arrow::io::OutputStream(), executor_(executor), position_(0),
is_open_(true)
+ // Writes go into `session`'s shared ring buffer. `processor` supplies the
+ // lock and the wait-for-reader step used by Write().
+ SharedRingBufferOutputStream(Processor* processor, SessionData* session)
+ : arrow::io::OutputStream(),
+ processor_(processor),
+ session_(session),
+ position_(0),
+ is_open_(true)
{
}
@@ -497,14 +682,16 @@ class SharedRingBufferOutputStream : public
arrow::io::OutputStream {
using arrow::io::OutputStream::Write;
private:
- Executor* executor_;
+ Processor* processor_;
+ SessionData* session_;
int64_t position_;
bool is_open_;
};
class WorkerProcessor : public Processor {
public:
- explicit WorkerProcessor(const char* tag) : Processor(tag),
sessions_(nullptr)
+ explicit WorkerProcessor(const char* tag, bool runInPGThread)
+ : Processor(tag, runInPGThread), sessions_(nullptr)
{
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
bool found;
@@ -530,19 +717,7 @@ class WorkerProcessor : public Processor {
protected:
+ // Releases the session's DSA memory via session_data_finalize() and then
+ // deletes the entry from sessions_.
void delete_session(SessionData* session)
{
- if (DsaPointerIsValid(session->errorMessage))
- dsa_free(area_, session->errorMessage);
- if (DsaPointerIsValid(session->databaseName))
- dsa_free(area_, session->databaseName);
- if (DsaPointerIsValid(session->userName))
- dsa_free(area_, session->userName);
- if (DsaPointerIsValid(session->password))
- dsa_free(area_, session->password);
- if (DsaPointerIsValid(session->selectQuery))
- dsa_free(area_, session->selectQuery);
- if (DsaPointerIsValid(session->updateQuery))
- dsa_free(area_, session->updateQuery);
- SharedRingBuffer::free_data(&(session->bufferData), area_);
+ session_data_finalize(session, area_);
dshash_delete_entry(sessions_, session);
}
@@ -550,14 +725,153 @@ class WorkerProcessor : public Processor {
dshash_table* sessions_;
};
+// arrow::ArrayVisitor that writes the value at row `i_row` of the visited
+// array into the referenced PostgreSQL Datum. Only Int32 arrays are handled
+// here; other array types use arrow::ArrayVisitor's own Visit().
+class ArrowPGValueConverter : public arrow::ArrayVisitor {
+   public:
+    explicit ArrowPGValueConverter(int64_t i_row, Datum& datum)
+        : row_(i_row), output_(datum)
+    {
+    }
+
+    arrow::Status Visit(const arrow::Int32Array& array) override
+    {
+        const int32_t value = array.Value(row_);
+        output_ = Int32GetDatum(value);
+        return arrow::Status::OK();
+    }
+
+   private:
+    int64_t row_;
+    Datum& output_;
+};
+
+// A statement prepared through Flight SQL. Only the query text is kept
+// here; update() plans it with SPI for the parameter types found in the
+// incoming Arrow stream and runs it once per input row.
+//
+// No user-declared destructor (Rule of Zero). The class therefore keeps its
+// implicit move constructor, and moving it into a std::map does not copy
+// the query.
+class PreparedStatement {
+   public:
+    explicit PreparedStatement(std::string query) : query_(std::move(query)) {}
+
+    // Reads an Arrow IPC stream of parameter rows from `input`. Each row is
+    // bound to the query's parameters in schema field order, and the plan is
+    // executed once per row. Returns the sum of SPI_processed over all
+    // executions.
+    //
+    // Needs an SPI connection and an active transaction/snapshot (callers in
+    // this file provide both). Returns NotImplemented for unsupported Arrow
+    // types. Returns Invalid when the query cannot be prepared, or when an
+    // execution does not report INSERT/UPDATE/DELETE.
+    arrow::Result<int64_t> update(std::shared_ptr<SharedRingBufferInputStream>& input)
+    {
+        ARROW_ASSIGN_OR_RAISE(auto reader,
+                              arrow::ipc::RecordBatchStreamReader::Open(input));
+        const auto& schema = reader->schema();
+        SPIExecuteOptions options = {};
+        if (schema->num_fields() > 0)
+        {
+            options.params = makeParamList(schema->num_fields());
+        }
+        options.read_only = false;
+        options.tcount = 0;
+        ARROW_ASSIGN_OR_RAISE(auto pgTypes, create_pg_types(schema));
+        for (size_t i = 0; i < pgTypes.size(); ++i)
+        {
+            options.params->params[i].pflags = PARAM_FLAG_CONST;
+            options.params->params[i].ptype = pgTypes[i];
+        }
+        auto plan = SPI_prepare(query_.c_str(), pgTypes.size(), pgTypes.data());
+        // SPI_prepare() signals failure by returning NULL and setting
+        // SPI_result; a NULL plan must never reach SPI_execute_plan_extended().
+        if (!plan)
+        {
+            return arrow::Status::Invalid("failed to prepare a statement: ",
+                                          SPI_result_code_string(SPI_result));
+        }
+        // Frees the plan on every return path below.
+        struct PlanFinalizer {
+            explicit PlanFinalizer(SPIPlanPtr plan) : plan_(plan) {}
+            ~PlanFinalizer() { SPI_freeplan(plan_); }
+            PlanFinalizer(const PlanFinalizer&) = delete;
+            PlanFinalizer& operator=(const PlanFinalizer&) = delete;
+            SPIPlanPtr plan_;
+        } planFinalizer(plan);
+
+        int64_t nUpdatedRecords = 0;
+        while (true)
+        {
+            std::shared_ptr<arrow::RecordBatch> recordBatch;
+            ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
+            if (!recordBatch)
+            {
+                break;
+            }
+            const auto& columns = recordBatch->columns();
+            for (int64_t i = 0; i < recordBatch->num_rows(); ++i)
+            {
+                ARROW_RETURN_NOT_OK(assign_parameters(recordBatch, i, columns, options));
+                auto result = SPI_execute_plan_extended(plan, &options);
+                switch (result)
+                {
+                    case SPI_OK_INSERT:
+                    case SPI_OK_DELETE:
+                    case SPI_OK_UPDATE:
+                        break;
+                    default:
+                        return arrow::Status::Invalid(
+                            "failed to run a prepared statement: ",
+                            SPI_result_code_string(result));
+                }
+                nUpdatedRecords += SPI_processed;
+            }
+        }
+        return nUpdatedRecords;
+    }
+
+   private:
+    // Maps each Arrow field type to the PostgreSQL parameter type OID.
+    // Only int32 -> INT4OID is supported so far.
+    arrow::Result<std::vector<Oid>> create_pg_types(
+        const std::shared_ptr<arrow::Schema>& schema)
+    {
+        std::vector<Oid> pgTypes;
+        for (const auto& field : schema->fields())
+        {
+            switch (field->type()->id())
+            {
+                case arrow::Type::INT32:
+                    pgTypes.push_back(INT4OID);
+                    break;
+                default:
+                    return arrow::Status::NotImplemented(
+                        "Unsupported Apache Arrow type: ", field->type()->name());
+            }
+        }
+        return pgTypes;
+    }
+
+    // Copies row `i_row` of `columns` into options.params (one parameter per
+    // column, in column order), marking nulls via isnull.
+    arrow::Status assign_parameters(
+        const std::shared_ptr<arrow::RecordBatch>& recordBatch,
+        int64_t i_row,
+        const std::vector<std::shared_ptr<arrow::Array>>& columns,
+        SPIExecuteOptions& options)
+    {
+        int64_t i_column = 0;
+        for (const auto& column : columns)
+        {
+            auto param = &(options.params->params[i_column]);
+            param->isnull = column->IsNull(i_row);
+            if (!param->isnull)
+            {
+                ArrowPGValueConverter converter(i_row, param->value);
+                ARROW_RETURN_NOT_OK(column->Accept(&converter));
+            }
+            ++i_column;
+        }
+        return arrow::Status::OK();
+    }
+
+    std::string query_;
+};
+
+// Scope guard: StartTransactionCommand() on construction,
+// CommitTransactionCommand() on destruction. Copying is deleted because a
+// copy would commit the same transaction twice.
+struct Transaction {
+    Transaction() { StartTransactionCommand(); }
+
+    ~Transaction() { CommitTransactionCommand(); }
+
+    Transaction(const Transaction&) = delete;
+    Transaction& operator=(const Transaction&) = delete;
+};
+
+// Scope guard: pushes the transaction snapshot as the active snapshot and
+// pops it on destruction. Copying is deleted because a copy would pop twice.
+struct Snapshot {
+    Snapshot() { PushActiveSnapshot(GetTransactionSnapshot()); }
+
+    ~Snapshot() { PopActiveSnapshot(); }
+
+    Snapshot(const Snapshot&) = delete;
+    Snapshot& operator=(const Snapshot&) = delete;
+};
+
class Executor : public WorkerProcessor {
public:
+ // Serves the single session `sessionID`. runInPGThread is true, so
+ // Processor::wait() waits on MyLatch. Prepared statement handles are
+ // numbered from 1.
explicit Executor(uint64_t sessionID)
- : WorkerProcessor("executor"),
+ : WorkerProcessor("executor", true),
sessionID_(sessionID),
session_(nullptr),
connected_(false),
- closed_(false)
+ closed_(false),
+ nextPreparedStatementID_(1),
+ preparedStatements_()
{
}
@@ -597,7 +911,7 @@ class Executor : public WorkerProcessor {
// TODO: Customizable.
buffer.allocate(1L * 1024L * 1024L);
}
- StartTransactionCommand();
+ SetCurrentStatementStartTimestamp();
SPI_connect();
pgstat_report_activity(STATE_IDLE, NULL);
session_->initialized = true;
@@ -608,87 +922,67 @@ class Executor : public WorkerProcessor {
void close() { close_internal(true); }
- SharedRingBuffer create_shared_ring_buffer()
+ // Handles SIGUSR1 from the proxy. It takes the pending Action from the
+ // session under lock_, resets it to Action::None, and runs the matching
+ // handler. A PostgreSQL error raised by a handler is copied into
+ // session_->errorMessage (unless one is already set) and then re-thrown.
+ void signaled() override
{
- return SharedRingBuffer(&(session_->bufferData), area_);
- }
-
- void wait_server_read(SharedRingBuffer* buffer)
- {
- if (ARROW_PREDICT_FALSE(sharedData_->serverPID == InvalidPid))
- {
- ereport(ERROR,
- errcode(ERRCODE_INTERNAL_ERROR),
- errmsg("%s: %s: server isn't alive", Tag,
tag_));
- }
-
- P("%s: %s: %s: kill server: %d", Tag, tag_, AFS_FUNC,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
- auto restSize = buffer->rest_size();
- while (true)
+ lock_acquire();
+ auto action = session_->action;
+ session_->action = Action::None;
+ lock_release();
+ P("%s: %s: signaled: before: %s", Tag, tag_, action_name(action));
+ PG_TRY();
{
- int events = WL_LATCH_SET | WL_EXIT_ON_PM_DEATH;
- WaitLatch(MyLatch, events, -1, PG_WAIT_EXTENSION);
- if (GotSIGTERM)
+ switch (action)
{
- break;
- }
- ResetLatch(MyLatch);
-
- if (GotSIGUSR1)
- {
- GotSIGUSR1 = false;
- P("%s: %s: %s: wait: read: %d:%d",
- Tag,
- tag_,
- AFS_FUNC,
- buffer->rest_size(),
- restSize);
- if (buffer->rest_size() != restSize)
- {
+ case Action::Select:
+ select();
+ break;
+ case Action::Update:
+ update();
+ break;
+ case Action::Prepare:
+ prepare();
+ break;
+ case Action::ClosePreparedStatement:
+ close_prepared_statement();
+ break;
+ case Action::UpdatePreparedStatement:
+ update_prepared_statement();
+ break;
+ // Action::None (or any unknown value): fall back to the base
+ // notification. This was misspelled "deafult:", which is a goto
+ // label, so this branch never ran.
+ default:
+ Processor::signaled();
break;
- }
}
-
- CHECK_FOR_INTERRUPTS();
}
- }
-
- void signaled()
- {
- P("%s: %s: signaled: before: %d/%d",
- Tag,
- tag_,
- session_->selectQuery,
- session_->updateQuery);
- if (DsaPointerIsValid(session_->selectQuery))
+ PG_CATCH();
{
- select();
- }
- else if (DsaPointerIsValid(session_->updateQuery))
- {
- update();
- }
- else
- {
- Processor::signaled();
+ if (session_ && !DsaPointerIsValid(session_->errorMessage))
+ {
+ auto error = CopyErrorData();
+ set_shared_string(session_->errorMessage,
+ std::string("failed to run: ") + action_name(action) +
+ ": " + error->message);
+ FreeErrorData(error);
+ }
+ PG_RE_THROW();
}
- P("%s: %s: signaled: after: %d/%d",
- Tag,
- tag_,
- session_->selectQuery,
- session_->updateQuery);
+ PG_END_TRY();
+ P("%s: %s: signaled: after: %s", Tag, tag_, action_name(action));
}
+ protected:
+ // The executor's peer for Processor::wait() is the server process
+ // recorded in sharedData_->serverPID.
+ pid_t peer_pid(SessionData* session) override { return
sharedData_->serverPID; }
+
+ const char* peer_name(SessionData* session) override { return "server";
}
+
private:
void close_internal(bool unlockSession)
{
closed_ = true;
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
closing").c_str());
+ preparedStatements_.clear();
if (connected_)
{
SPI_finish();
- CommitTransactionCommand();
{
SharedRingBuffer
buffer(&(session_->bufferData), area_);
buffer.free();
@@ -868,89 +1162,59 @@ class Executor : public WorkerProcessor {
void select()
{
- pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
selecting").c_str());
+ if (!DsaPointerIsValid(session_->selectQuery))
+ {
+ lock_acquire();
+ set_shared_string(
+ session_->errorMessage,
+ std::string(Tag) + ": " + tag_ + ": select" +
": query is missing");
+ lock_release();
+ return;
+ }
- PushActiveSnapshot(GetTransactionSnapshot());
+ pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
selecting").c_str());
- LWLockAcquire(lock_, LW_EXCLUSIVE);
+ lock_acquire();
std::string query(
static_cast<const char*>(dsa_get_address(area_,
session_->selectQuery)));
dsa_free(area_, session_->selectQuery);
session_->selectQuery = InvalidDsaPointer;
- SetCurrentStatementStartTimestamp();
+ lock_release();
P("%s: %s: select: %s", Tag, tag_, query.c_str());
- auto result = SPI_execute(query.c_str(), true, 0);
- LWLockRelease(lock_);
-
- if (result == SPI_OK_SELECT)
- {
- pgstat_report_activity(STATE_RUNNING,
- (std::string(Tag) + ": select:
writing").c_str());
- auto status = write();
- if (!status.ok())
- {
- set_shared_string(session_->errorMessage,
status.ToString());
- }
- }
- else
- {
- set_shared_string(session_->errorMessage,
- std::string(Tag) + ": " + tag_ + ":
select" +
- ": failed to run a query: <" +
query +
- ">: " +
SPI_result_code_string(result));
- }
- PopActiveSnapshot();
-
- if (sharedData_->serverPID != InvalidPid)
{
- P("%s: %s: select: kill server: %d", Tag, tag_,
sharedData_->serverPID);
- kill(sharedData_->serverPID, SIGUSR1);
- }
-
- pgstat_report_activity(STATE_IDLE, NULL);
- }
-
- void update()
- {
- pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ":
updating").c_str());
+ Transaction transaction;
+ Snapshot snapshot;
- PushActiveSnapshot(GetTransactionSnapshot());
+ SetCurrentStatementStartTimestamp();
+ auto result = SPI_execute(query.c_str(), true, 0);
- LWLockAcquire(lock_, LW_EXCLUSIVE);
- std::string query(
- static_cast<const char*>(dsa_get_address(area_,
session_->updateQuery)));
- dsa_free(area_, session_->updateQuery);
- session_->updateQuery = InvalidDsaPointer;
- SetCurrentStatementStartTimestamp();
- P("%s: %s: update: %s", Tag, tag_, query.c_str());
- auto result = SPI_execute(query.c_str(), false, 0);
- LWLockRelease(lock_);
-
- switch (result)
- {
- case SPI_OK_INSERT:
- case SPI_OK_DELETE:
- case SPI_OK_UPDATE:
- session_->nUpdatedRecords = SPI_processed;
- break;
- default:
+ if (result == SPI_OK_SELECT)
+ {
+ pgstat_report_activity(STATE_RUNNING,
+ (std::string(Tag) + ":
select: writing").c_str());
+ auto status = write();
+ if (!status.ok())
+ {
+ lock_acquire();
+
set_shared_string(session_->errorMessage, status.ToString());
+ lock_release();
+ }
+ }
+ else
+ {
+ lock_acquire();
set_shared_string(session_->errorMessage,
- std::string(Tag) + ": " +
tag_ + ": update" +
+ std::string(Tag) + ": " +
tag_ + ": select" +
": failed to run a query:
<" + query +
">: " +
SPI_result_code_string(result));
- break;
+ lock_release();
+ }
}
- PopActiveSnapshot();
-
- // TODO: Is this usage correct?
- CommitTransactionCommand();
- StartTransactionCommand();
-
if (sharedData_->serverPID != InvalidPid)
{
- P("%s: %s: update: kill server: %d", Tag, tag_,
sharedData_->serverPID);
+ P("%s: %s: select: kill server: %d", Tag, tag_,
sharedData_->serverPID);
kill(sharedData_->serverPID, SIGUSR1);
}
@@ -959,7 +1223,7 @@ class Executor : public WorkerProcessor {
arrow::Status write()
{
- SharedRingBufferOutputStream output(this);
+ SharedRingBufferOutputStream output(this, session_);
std::vector<std::shared_ptr<arrow::Field>> fields;
for (int i = 0; i < SPI_tuptable->tupdesc->natts; ++i)
{
@@ -981,12 +1245,12 @@ class Executor : public WorkerProcessor {
ARROW_ASSIGN_OR_RAISE(
auto builder,
arrow::RecordBatchBuilder::Make(schema,
arrow::default_memory_pool()));
- auto option = arrow::ipc::IpcWriteOptions::Defaults();
- option.emit_dictionary_deltas = true;
+ auto options = arrow::ipc::IpcWriteOptions::Defaults();
+ options.emit_dictionary_deltas = true;
// Write schema only stream format data to return only schema.
ARROW_ASSIGN_OR_RAISE(auto writer,
- arrow::ipc::MakeStreamWriter(&output,
schema, option));
+ arrow::ipc::MakeStreamWriter(&output,
schema, options));
// Build an empty record batch to write schema.
ARROW_ASSIGN_OR_RAISE(auto recordBatch, builder->Flush());
P("%s: %s: write: schema: WriteRecordBatch", Tag, tag_);
@@ -996,7 +1260,7 @@ class Executor : public WorkerProcessor {
// Write another stream format data with record batches.
ARROW_ASSIGN_OR_RAISE(writer,
- arrow::ipc::MakeStreamWriter(&output,
schema, option));
+ arrow::ipc::MakeStreamWriter(&output,
schema, options));
bool needLastFlush = false;
for (uint64_t iTuple = 0; iTuple < SPI_processed; ++iTuple)
{
@@ -1060,10 +1324,208 @@ class Executor : public WorkerProcessor {
return output.Close();
}
+    // Runs the non-SELECT query the proxy stored in session_->updateQuery in
+    // its own transaction. The affected-row count goes to
+    // session_->nUpdatedRecords, or an error to session_->errorMessage.
+    //
+    // The server is woken with SIGUSR1 on every path, including the early
+    // "query is missing" return. The proxy re-checks its wait condition only
+    // when woken, so a silent return left it blocked.
+    void update()
+    {
+        if (!DsaPointerIsValid(session_->updateQuery))
+        {
+            lock_acquire();
+            set_shared_string(
+                session_->errorMessage,
+                std::string(Tag) + ": " + tag_ + ": update" + ": query is missing");
+            lock_release();
+            if (sharedData_->serverPID != InvalidPid)
+            {
+                P("%s: %s: update: kill server: %d", Tag, tag_, sharedData_->serverPID);
+                kill(sharedData_->serverPID, SIGUSR1);
+            }
+            return;
+        }
+
+        pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": updating").c_str());
+
+        lock_acquire();
+        std::string query(
+            static_cast<const char*>(dsa_get_address(area_, session_->updateQuery)));
+        dsa_free(area_, session_->updateQuery);
+        session_->updateQuery = InvalidDsaPointer;
+        lock_release();
+        P("%s: %s: update: %s", Tag, tag_, query.c_str());
+
+        {
+            Transaction transaction;
+            Snapshot snapshot;
+
+            SetCurrentStatementStartTimestamp();
+            auto result = SPI_execute(query.c_str(), false, 0);
+            // Publish the outcome under lock_. The proxy also resets
+            // nUpdatedRecords under lock_, and errorMessage is always
+            // written under it.
+            lock_acquire();
+            switch (result)
+            {
+                case SPI_OK_INSERT:
+                case SPI_OK_DELETE:
+                case SPI_OK_UPDATE:
+                    session_->nUpdatedRecords = SPI_processed;
+                    break;
+                default:
+                    set_shared_string(session_->errorMessage,
+                                      std::string(Tag) + ": " + tag_ + ": update" +
+                                          ": failed to run a query: <" + query +
+                                          ">: " + SPI_result_code_string(result));
+                    break;
+            }
+            lock_release();
+        }
+
+        if (sharedData_->serverPID != InvalidPid)
+        {
+            P("%s: %s: update: kill server: %d", Tag, tag_, sharedData_->serverPID);
+            kill(sharedData_->serverPID, SIGUSR1);
+        }
+
+        pgstat_report_activity(STATE_IDLE, NULL);
+    }
+
+    // Registers the query from session_->prepareQuery in preparedStatements_
+    // under a new numeric handle. The handle is published in
+    // session_->preparedStatementHandle, or an error in
+    // session_->errorMessage.
+    //
+    // The handle is published under lock_, like the other shared-field
+    // writes. The server is woken with SIGUSR1 on every path, including the
+    // early "query is missing" return, because the proxy re-checks its wait
+    // condition only when woken.
+    void prepare()
+    {
+        if (!DsaPointerIsValid(session_->prepareQuery))
+        {
+            lock_acquire();
+            set_shared_string(
+                session_->errorMessage,
+                std::string(Tag) + ": " + tag_ + ": prepare" + ": query is missing");
+            lock_release();
+            if (sharedData_->serverPID != InvalidPid)
+            {
+                P("%s: %s: prepare: kill server: %d", Tag, tag_, sharedData_->serverPID);
+                kill(sharedData_->serverPID, SIGUSR1);
+            }
+            return;
+        }
+
+        pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": preparing").c_str());
+
+        lock_acquire();
+        std::string query(
+            static_cast<const char*>(dsa_get_address(area_, session_->prepareQuery)));
+        dsa_free(area_, session_->prepareQuery);
+        session_->prepareQuery = InvalidDsaPointer;
+        lock_release();
+        P("%s: %s: prepare: %s", Tag, tag_, query.c_str());
+        std::string handle(std::to_string(nextPreparedStatementID_++));
+        // Register locally before telling the proxy about the handle.
+        preparedStatements_.insert(std::make_pair(handle, PreparedStatement(std::move(query))));
+        lock_acquire();
+        set_shared_string(session_->preparedStatementHandle, handle);
+        lock_release();
+
+        if (sharedData_->serverPID != InvalidPid)
+        {
+            P("%s: %s: prepare: kill server: %d", Tag, tag_, sharedData_->serverPID);
+            kill(sharedData_->serverPID, SIGUSR1);
+        }
+
+        pgstat_report_activity(STATE_IDLE, nullptr);
+    }
+
+    // Drops the prepared statement whose handle the proxy stored in
+    // session_->preparedStatementHandle, and clears that shared handle. An
+    // unknown handle is reported in session_->errorMessage.
+    //
+    // The server is woken with SIGUSR1 on every path, including the early
+    // "handle is missing" return, because the proxy re-checks its wait
+    // condition only when woken.
+    void close_prepared_statement()
+    {
+        const char* tag = "close prepared statement";
+
+        if (!DsaPointerIsValid(session_->preparedStatementHandle))
+        {
+            lock_acquire();
+            set_shared_string(
+                session_->errorMessage,
+                std::string(Tag) + ": " + tag_ + ": " + tag + ": handle is missing");
+            lock_release();
+            if (sharedData_->serverPID != InvalidPid)
+            {
+                P("%s: %s: %s: kill server: %d", Tag, tag_, tag, sharedData_->serverPID);
+                kill(sharedData_->serverPID, SIGUSR1);
+            }
+            return;
+        }
+
+        pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": " + tag).c_str());
+
+        lock_acquire();
+        std::string handle(static_cast<const char*>(
+            dsa_get_address(area_, session_->preparedStatementHandle)));
+        dsa_free(area_, session_->preparedStatementHandle);
+        session_->preparedStatementHandle = InvalidDsaPointer;
+        if (preparedStatements_.erase(handle) == 0)
+        {
+            set_shared_string(session_->errorMessage,
+                              std::string(Tag) + ": " + tag_ + ": " + tag +
+                                  ": nonexistent handle: <" + handle + ">");
+        }
+        lock_release();
+        P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
+
+        if (sharedData_->serverPID != InvalidPid)
+        {
+            P("%s: %s: %s: kill server: %d", Tag, tag_, tag, sharedData_->serverPID);
+            kill(sharedData_->serverPID, SIGUSR1);
+        }
+
+        pgstat_report_activity(STATE_IDLE, nullptr);
+    }
+
+    // Executes the prepared statement named by session_->preparedStatementHandle
+    // once per row of the Arrow IPC parameter stream read from the session's
+    // ring buffer (presumably written by the proxy). The affected-row total
+    // goes to session_->nUpdatedRecords, or an error to session_->errorMessage;
+    // both are written under lock_.
+    //
+    // The server is woken with SIGUSR1 on every path, including the early
+    // "handle is missing" return, because the proxy re-checks its wait
+    // condition only when woken.
+    //
+    // NOTE(review): Transaction commits on scope exit even when
+    // PreparedStatement::update() returns an error status part-way through.
+    // Rows executed before the failure therefore stay committed. Confirm this
+    // is intended.
+    void update_prepared_statement()
+    {
+        const char* tag = "update prepared statement";
+
+        if (!DsaPointerIsValid(session_->preparedStatementHandle))
+        {
+            lock_acquire();
+            set_shared_string(
+                session_->errorMessage,
+                std::string(Tag) + ": " + tag_ + ": " + tag + ": handle is missing");
+            lock_release();
+            if (sharedData_->serverPID != InvalidPid)
+            {
+                P("%s: %s: %s: kill server: %d", Tag, tag_, tag, sharedData_->serverPID);
+                kill(sharedData_->serverPID, SIGUSR1);
+            }
+            return;
+        }
+
+        pgstat_report_activity(
+            STATE_RUNNING, (std::string(Tag) + ": updating prepared statement").c_str());
+
+        lock_acquire();
+        std::string handle(static_cast<const char*>(
+            dsa_get_address(area_, session_->preparedStatementHandle)));
+        dsa_free(area_, session_->preparedStatementHandle);
+        session_->preparedStatementHandle = InvalidDsaPointer;
+        PreparedStatement* preparedStatement = nullptr;
+        auto it = preparedStatements_.find(handle);
+        if (it == preparedStatements_.end())
+        {
+            set_shared_string(session_->errorMessage,
+                              std::string(Tag) + ": " + tag_ + ": " + tag +
+                                  ": nonexistent handle: <" + handle + ">");
+        }
+        else
+        {
+            preparedStatement = &(it->second);
+        }
+        lock_release();
+        P("%s: %s: %s: %s", Tag, tag_, tag, handle.c_str());
+
+        if (preparedStatement)
+        {
+            Transaction transaction;
+            Snapshot snapshot;
+
+            auto input = std::make_shared<SharedRingBufferInputStream>(this, session_);
+            auto n_updated_records_result = preparedStatement->update(input);
+            // update() has finished reading the ring buffer, so taking lock_
+            // here cannot block the input stream.
+            lock_acquire();
+            if (n_updated_records_result.ok())
+            {
+                session_->nUpdatedRecords = *n_updated_records_result;
+            }
+            else
+            {
+                set_shared_string(
+                    session_->errorMessage,
+                    std::string(Tag) + ": " + tag_ + ": " + tag +
+                        ": failed to update a prepared statement: <" + handle +
+                        ">: " + n_updated_records_result.status().ToString());
+            }
+            lock_release();
+        }
+
+        if (sharedData_->serverPID != InvalidPid)
+        {
+            P("%s: %s: %s: kill server: %d", Tag, tag_, tag, sharedData_->serverPID);
+            kill(sharedData_->serverPID, SIGUSR1);
+        }
+
+        pgstat_report_activity(STATE_IDLE, nullptr);
+    }
+
uint64_t sessionID_;
SessionData* session_;
bool connected_;
bool closed_;
+ uint64_t nextPreparedStatementID_;
+ std::map<std::string, PreparedStatement> preparedStatements_;
};
arrow::Status
@@ -1071,18 +1533,18 @@ SharedRingBufferOutputStream::Write(const void* data,
int64_t nBytes)
{
if (ARROW_PREDICT_FALSE(!is_open_))
{
- return arrow::Status::IOError(std::string(Tag) + ": " +
executor_->tag() +
+ return arrow::Status::IOError(std::string(Tag) + ": " +
processor_->tag() +
": SharedRingBufferOutputStream
is closed");
}
if (ARROW_PREDICT_TRUE(nBytes > 0))
{
- auto buffer = std::move(executor_->create_shared_ring_buffer());
+ auto buffer =
std::move(processor_->create_shared_ring_buffer(session_));
size_t rest = static_cast<size_t>(nBytes);
while (true)
{
- executor_->lock_acquire(LW_EXCLUSIVE);
+ processor_->lock_acquire();
auto writtenSize = buffer.write(data, rest);
- executor_->lock_release();
+ processor_->lock_release();
position_ += writtenSize;
rest -= writtenSize;
@@ -1093,7 +1555,8 @@ SharedRingBufferOutputStream::Write(const void* data,
int64_t nBytes)
break;
}
- executor_->wait_server_read(&buffer);
+ ARROW_RETURN_NOT_OK(
+ processor_->wait(session_, &buffer,
Processor::WaitMode::Read));
}
}
return arrow::Status::OK();
@@ -1102,41 +1565,8 @@ SharedRingBufferOutputStream::Write(const void* data,
int64_t nBytes)
class Proxy : public WorkerProcessor {
public:
+ // The proxy runs with runInPGThread = false, so Processor::wait() blocks on
+ // conditionVariable_ rather than on MyLatch. randomEngine_ is seeded from
+ // std::random_device-style randomSeed_.
explicit Proxy()
- : WorkerProcessor("proxy"), randomSeed_(),
randomEngine_(randomSeed_())
- {
- }
-
- SharedRingBuffer create_shared_ring_buffer(SessionData* session)
+ : WorkerProcessor("proxy", false), randomSeed_(),
randomEngine_(randomSeed_())
{
- return SharedRingBuffer(&(session->bufferData), area_);
- }
-
- void wait_executor_written(SessionData* session, SharedRingBuffer*
buffer)
- {
- if (ARROW_PREDICT_FALSE(session->executorPID == InvalidPid))
- {
- ereport(ERROR,
- errcode(ERRCODE_INTERNAL_ERROR),
- errmsg("%s: %s: executor isn't alive", Tag,
tag_));
- }
-
- P("%s: %s: %s: kill executor: %d", Tag, tag_, AFS_FUNC,
session->executorPID);
- kill(session->executorPID, SIGUSR1);
- auto size = buffer->size();
- std::unique_lock<std::mutex> lock(mutex_);
- conditionVariable_.wait(lock, [&] {
- P("%s: %s: %s: wait: write: %d:%d",
- Tag,
- tag_,
- AFS_FUNC,
- buffer->size(),
- size);
- if (INTERRUPTS_PENDING_CONDITION())
- {
- return true;
- }
- return buffer->size() != size;
- });
}
arrow::Result<uint64_t> connect(const std::string& databaseName,
@@ -1199,19 +1629,22 @@ class Proxy : public WorkerProcessor {
arrow::Result<std::shared_ptr<arrow::Schema>> select(uint64_t sessionID,
const std::string&
query)
{
+ const char* tag = "select";
auto session = find_session(sessionID);
SessionReleaser sessionReleaser(sessions_, session);
set_shared_string(session->selectQuery, query);
+ session->action = Action::Select;
if (session->executorPID != InvalidPid)
{
- P("%s: %s: select: kill executor: %d", Tag, tag_,
session->executorPID);
+ P("%s: %s: %s: kill executor: %d", Tag, tag_, tag,
session->executorPID);
kill(session->executorPID, SIGUSR1);
}
{
auto buffer =
std::move(create_shared_ring_buffer(session));
std::unique_lock<std::mutex> lock(mutex_);
conditionVariable_.wait(lock, [&] {
- P("%s: %s: %s: wait: select", Tag, tag_,
AFS_FUNC);
+ P("%s: %s: %s: wait", Tag, tag_, tag);
+ pid_t pid = 0;
return DsaPointerIsValid(session->errorMessage)
|| buffer.size() > 0;
});
}
@@ -1219,40 +1652,31 @@ class Proxy : public WorkerProcessor {
{
return report_session_error(session);
}
- P("%s: %s: select: open", Tag, tag_);
- auto input =
std::make_shared<SharedRingBufferInputStream>(this, session);
- // Read schema only stream format data.
- ARROW_ASSIGN_OR_RAISE(auto reader,
-
arrow::ipc::RecordBatchStreamReader::Open(input));
- while (true)
- {
- std::shared_ptr<arrow::RecordBatch> recordBatch;
- P("%s: %s: select: read next", Tag, tag_);
- ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
- if (!recordBatch)
- {
- break;
- }
- }
- P("%s: %s: select: schema", Tag, tag_);
- return reader->schema();
+ P("%s: %s: %s: open", Tag, tag_, tag);
+ auto schema = read_schema(session, tag);
+ P("%s: %s: %s: schema", Tag, tag_, tag);
+ return std::move(schema);
}
arrow::Result<int64_t> update(uint64_t sessionID, const std::string&
query)
{
+ const char* tag = "update";
auto session = find_session(sessionID);
SessionReleaser sessionReleaser(sessions_, session);
+ lock_acquire();
set_shared_string(session->updateQuery, query);
+ session->action = Action::Update;
session->nUpdatedRecords = -1;
+ lock_release();
if (session->executorPID != InvalidPid)
{
- P("%s: %s: update: kill executor: %d", Tag, tag_,
session->executorPID);
+ P("%s: %s: %s: kill executor: %d", Tag, tag_, tag,
session->executorPID);
kill(session->executorPID, SIGUSR1);
}
{
std::unique_lock<std::mutex> lock(mutex_);
conditionVariable_.wait(lock, [&] {
- P("%s: %s: %s: wait: update", Tag, tag_,
AFS_FUNC);
+ P("%s: %s: %s: wait", Tag, tag_, tag);
return DsaPointerIsValid(session->errorMessage)
||
session->nUpdatedRecords >= 0;
});
@@ -1261,7 +1685,7 @@ class Proxy : public WorkerProcessor {
{
return report_session_error(session);
}
- P("%s: %s: update: done: %ld", Tag, tag_,
session->nUpdatedRecords);
+ P("%s: %s: %s: done: %ld", Tag, tag_, tag,
session->nUpdatedRecords);
return session->nUpdatedRecords;
}
@@ -1274,13 +1698,144 @@ class Proxy : public WorkerProcessor {
return arrow::ipc::RecordBatchStreamReader::Open(input);
}
+ arrow::Result<arrow::flight::sql::ActionCreatePreparedStatementResult>
prepare(
+ uint64_t sessionID, const std::string& query)
+ {
+ const char* tag = "prepare";
+ auto session = find_session(sessionID);
+ SessionReleaser sessionReleaser(sessions_, session);
+ lock_acquire();
+ set_shared_string(session->prepareQuery, query);
+ session->action = Action::Prepare;
+ set_shared_string(session->preparedStatementHandle,
std::string(""));
+ lock_release();
+ if (session->executorPID != InvalidPid)
+ {
+ P("%s: %s: %s: kill executor: %d", Tag, tag_, tag,
session->executorPID);
+ kill(session->executorPID, SIGUSR1);
+ }
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ conditionVariable_.wait(lock, [&] {
+ P("%s: %s: %s: wait", Tag, tag_, tag);
+ return DsaPointerIsValid(session->errorMessage)
||
+
DsaPointerIsValid(session->preparedStatementHandle);
+ });
+ }
+ if (DsaPointerIsValid(session->errorMessage))
+ {
+ return report_session_error(session);
+ }
+ std::string handle(static_cast<const char*>(
+ dsa_get_address(area_,
session->preparedStatementHandle)));
+ arrow::flight::sql::ActionCreatePreparedStatementResult result
= {
+ nullptr,
+ nullptr,
+ std::move(handle),
+ };
+ P("%s: %s: %s: done", Tag, tag_, tag);
+ return std::move(result);
+ }
+
+ arrow::Status close_prepared_statement(uint64_t sessionID, const
std::string& handle)
+ {
+ const char* tag = "close prepared statement";
+ auto session = find_session(sessionID);
+ SessionReleaser sessionReleaser(sessions_, session);
+ lock_acquire();
+ set_shared_string(session->preparedStatementHandle, handle);
+ session->action = Action::ClosePreparedStatement;
+ lock_release();
+ if (session->executorPID != InvalidPid)
+ {
+ P("%s: %s: %s: kill executor: %d", Tag, tag_, tag,
session->executorPID);
+ kill(session->executorPID, SIGUSR1);
+ }
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ conditionVariable_.wait(lock, [&] {
+ P("%s: %s: %s: wait", Tag, tag_, tag);
+ return DsaPointerIsValid(session->errorMessage)
||
+
!DsaPointerIsValid(session->preparedStatementHandle);
+ });
+ }
+ if (DsaPointerIsValid(session->errorMessage))
+ {
+ return report_session_error(session);
+ }
+ P("%s: %s: %s: done", Tag, tag_, tag);
+ return arrow::Status::OK();
+ }
+
+ arrow::Result<int64_t> update_prepared_statement(
+ uint64_t sessionID,
+ const std::string& handle,
+ arrow::flight::FlightMessageReader* reader)
+ {
+ const char* tag = "update prepared statement";
+ auto session = find_session(sessionID);
+ SessionReleaser sessionReleaser(sessions_, session);
+ lock_acquire();
+ set_shared_string(session->preparedStatementHandle, handle);
+ session->action = Action::UpdatePreparedStatement;
+ session->nUpdatedRecords = -1;
+ lock_release();
+ if (session->executorPID != InvalidPid)
+ {
+ P("%s: %s: %s: kill executor: %d", Tag, tag_, tag,
session->executorPID);
+ kill(session->executorPID, SIGUSR1);
+ }
+ {
+ ARROW_ASSIGN_OR_RAISE(const auto& schema,
reader->GetSchema());
+ SharedRingBufferOutputStream output(this, session);
+ auto options = arrow::ipc::IpcWriteOptions::Defaults();
+ options.emit_dictionary_deltas = true;
+ ARROW_ASSIGN_OR_RAISE(auto writer,
+
arrow::ipc::MakeStreamWriter(&output, schema, options));
+ while (true)
+ {
+ ARROW_ASSIGN_OR_RAISE(const auto& chunk,
reader->Next());
+ if (!chunk.data)
+ {
+ break;
+ }
+
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*(chunk.data)));
+ }
+ ARROW_RETURN_NOT_OK(writer->Close());
+ }
+ if (session->executorPID != InvalidPid)
+ {
+ P("%s: %s: %s: kill executor: %d", Tag, tag_, tag,
session->executorPID);
+ kill(session->executorPID, SIGUSR1);
+ }
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ conditionVariable_.wait(lock, [&] {
+ P("%s: %s: %s: wait", Tag, tag_, tag);
+ return DsaPointerIsValid(session->errorMessage)
||
+ session->nUpdatedRecords >= 0;
+ });
+ }
+ if (DsaPointerIsValid(session->errorMessage))
+ {
+ return report_session_error(session);
+ }
+ P("%s: %s: %s: done: %ld", Tag, tag_, tag,
session->nUpdatedRecords);
+ return session->nUpdatedRecords;
+ }
+
+ protected:
+ pid_t peer_pid(SessionData* session) override { return
session->executorPID; }
+
+ const char* peer_name(SessionData* session) override { return
"executor"; }
+
private:
SessionData* create_session(const std::string& databaseName,
const std::string& userName,
const std::string& password,
const std::string& clientAddress)
{
- LWLockAcquire(lock_, LW_EXCLUSIVE);
+ lock_acquire();
uint64_t id = 0;
SessionData* session = nullptr;
do
@@ -1298,18 +1853,9 @@ class Proxy : public WorkerProcessor {
break;
}
} while (true);
- session->errorMessage = InvalidDsaPointer;
- session->executorPID = InvalidPid;
- session->initialized = false;
- set_shared_string(session->databaseName, databaseName);
- set_shared_string(session->userName, userName);
- set_shared_string(session->password, password);
- set_shared_string(session->clientAddress, clientAddress);
- session->selectQuery = InvalidDsaPointer;
- session->updateQuery = InvalidDsaPointer;
- session->nUpdatedRecords = -1;
- SharedRingBuffer::initialize_data(&(session->bufferData));
- LWLockRelease(lock_);
+ session_data_initialize(
+ session, area_, databaseName, userName, password,
clientAddress);
+ lock_release();
return session;
}
@@ -1331,6 +1877,26 @@ class Proxy : public WorkerProcessor {
return status;
}
+ arrow::Result<std::shared_ptr<arrow::Schema>> read_schema(SessionData*
session,
+ const char*
tag)
+ {
+ auto input =
std::make_shared<SharedRingBufferInputStream>(this, session);
+ // Read schema only stream format data.
+ ARROW_ASSIGN_OR_RAISE(auto reader,
+
arrow::ipc::RecordBatchStreamReader::Open(input));
+ while (true)
+ {
+ std::shared_ptr<arrow::RecordBatch> recordBatch;
+ P("%s: %s: %s: read next", Tag, tag_, tag);
+ ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
+ if (!recordBatch)
+ {
+ break;
+ }
+ }
+ return std::move(reader->schema());
+ }
+
std::random_device randomSeed_;
std::mt19937_64 randomEngine_;
};
@@ -1340,16 +1906,16 @@ SharedRingBufferInputStream::Read(int64_t nBytes, void*
out)
{
if (ARROW_PREDICT_FALSE(!is_open_))
{
- return arrow::Status::IOError(std::string(Tag) + ": " +
proxy_->tag() +
+ return arrow::Status::IOError(std::string(Tag) + ": " +
processor_->tag() +
": SharedRingBufferInputStream is
closed");
}
- auto buffer = std::move(proxy_->create_shared_ring_buffer(session_));
+ auto buffer =
std::move(processor_->create_shared_ring_buffer(session_));
size_t rest = static_cast<size_t>(nBytes);
while (true)
{
- proxy_->lock_acquire(LW_EXCLUSIVE);
+ processor_->lock_acquire();
auto readBytes = buffer.read(rest, out);
- proxy_->lock_release();
+ processor_->lock_release();
position_ += readBytes;
rest -= readBytes;
@@ -1359,10 +1925,11 @@ SharedRingBufferInputStream::Read(int64_t nBytes, void*
out)
break;
}
- proxy_->wait_executor_written(session_, &buffer);
+ ARROW_RETURN_NOT_OK(
+ processor_->wait(session_, &buffer,
Processor::WaitMode::Written));
if (INTERRUPTS_PENDING_CONDITION())
{
- return arrow::Status::IOError(std::string(Tag) + ": " +
proxy_->tag() +
+ return arrow::Status::IOError(std::string(Tag) + ": " +
processor_->tag() +
": interrupted");
}
}
@@ -1371,7 +1938,7 @@ SharedRingBufferInputStream::Read(int64_t nBytes, void*
out)
class MainProcessor : public Processor {
public:
- MainProcessor() : Processor("main"), sessions_(nullptr)
+ MainProcessor() : Processor("main", true), sessions_(nullptr)
{
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
bool found;
@@ -1637,6 +2204,35 @@ class FlightSQLServer : public
arrow::flight::sql::FlightSqlServerBase {
return proxy_->update(sessionID, query);
}
+ arrow::Result<arrow::flight::sql::ActionCreatePreparedStatementResult>
+ CreatePreparedStatement(
+ const arrow::flight::ServerCallContext& context,
+ const arrow::flight::sql::ActionCreatePreparedStatementRequest&
request)
+ {
+ ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
+ const auto& query = request.query;
+ return proxy_->prepare(sessionID, query);
+ }
+
+ arrow::Status ClosePreparedStatement(
+ const arrow::flight::ServerCallContext& context,
+ const arrow::flight::sql::ActionClosePreparedStatementRequest&
request)
+ {
+ ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
+ const auto& handle = request.prepared_statement_handle;
+ return proxy_->close_prepared_statement(sessionID, handle);
+ }
+
+ arrow::Result<int64_t> DoPutPreparedStatementUpdate(
+ const arrow::flight::ServerCallContext& context,
+ const arrow::flight::sql::PreparedStatementUpdate& command,
+ arrow::flight::FlightMessageReader* reader) override
+ {
+ ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
+ const auto& handle = command.prepared_statement_handle;
+ return proxy_->update_prepared_statement(sessionID, handle,
reader);
+ }
+
private:
arrow::Result<uint64_t> session_id(const
arrow::flight::ServerCallContext& context)
{
diff --git a/test/helper/sandbox.rb b/test/helper/sandbox.rb
index a320683..995a147 100644
--- a/test/helper/sandbox.rb
+++ b/test/helper/sandbox.rb
@@ -113,6 +113,7 @@ module Helper
@flight_sql_uri = nil
@user = "arrow-flight-sql-test"
@password = "Passw0rd!"
+ @pid = nil
@running = false
end
@@ -193,12 +194,33 @@ module Helper
end
end
@running = true
+ pid_path = File.join(@dir, "postmaster.pid")
+ if File.exist?(pid_path)
+ first_line = File.readlines(pid_path, chomp: true)[0]
+ begin
+ @pid = Integer(first_line, 10)
+ rescue ArgumentError
+ end
+ end
end
def stop
return unless running?
- run_command("pg_ctl", "stop",
- "-D", @dir)
+ begin
+ run_command("pg_ctl", "stop",
+ "-D", @dir,
+ "-t", "60")
+ rescue
+ if @pid
+ Process.kill(:KILL, @pid)
+ @pid = nil
+ @running = false
+ end
+ raise
+ else
+ @pid = nil
+ @running = false
+ end
end
def psql(db, sql)
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index 4237dd4..cb995b4 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -52,7 +52,7 @@ class FlightSQLTest < Test::Unit::TestCase
reader.read_all)
end
- def test_isnert_int32
+ def test_insert_direct
unless flight_sql_client.respond_to?(:execute_update)
omit("red-arrow-flight-sql 13.0.0 or later is required")
end
@@ -74,4 +74,31 @@ SELECT * FROM data
RESULT
end
+
+ def test_insert_int32
+ unless flight_sql_client.respond_to?(:prepare)
+ omit("red-arrow-flight-sql 14.0.0 or later is required")
+ end
+
+ run_sql("CREATE TABLE data (value integer)")
+
+ flight_sql_client.prepare("INSERT INTO data VALUES ($1)",
+ @options) do |statement|
+ values = Arrow::Int32Array.new([1, -2, 3])
+ statement.record_batch = Arrow::RecordBatch.new(value: values)
+ n_changed_records = statement.execute_update(@options)
+ assert_equal(3, n_changed_records)
+ end
+
+ assert_equal([<<-RESULT, ""], run_sql("SELECT * FROM data"))
+SELECT * FROM data
+ value
+-------
+ 1
+ -2
+ 3
+(3 rows)
+
+ RESULT
+ end
end