This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 612bb9d Bulk data ingestion (#7)
612bb9d is described below
commit 612bb9dcca1561ebfe27453ed9e24f8b524253f5
Author: David Li <[email protected]>
AuthorDate: Thu Jun 9 10:54:44 2022 -0400
Bulk data ingestion (#7)
* Sketch out bulk data ingestion
* Document parameters and flow better
---
adbc.h | 98 +++++--
adbc_driver_manager/adbc_driver_manager.cc | 59 ++++-
adbc_driver_manager/adbc_driver_manager.h | 1 +
adbc_driver_manager/adbc_driver_manager_test.cc | 48 ++++
drivers/sqlite/sqlite.cc | 335 +++++++++++++++++-------
drivers/sqlite/sqlite_test.cc | 85 ++++++
6 files changed, 504 insertions(+), 122 deletions(-)
diff --git a/adbc.h b/adbc.h
index a692e91..fb3db59 100644
--- a/adbc.h
+++ b/adbc.h
@@ -389,9 +389,12 @@ AdbcStatusCode AdbcConnectionGetTables(struct
AdbcConnection* connection,
/// }@
/// \defgroup adbc-statement Managing statements.
-/// Applications should first initialize and configure a statement with
-/// AdbcStatementInit and the AdbcStatementSetOption functions, then use the
-/// statement with a function like AdbcConnectionSqlExecute.
+/// Applications should first initialize a statement with
+/// AdbcStatementNew. Then, the statement should be configured with
+/// functions like AdbcStatementSetSqlQuery and
+/// AdbcStatementSetOption. Finally, the statement can be executed
+/// with AdbcStatementExecute (or call AdbcStatementPrepare first to
+/// turn it into a prepared statement instead).
/// @{
/// \brief An instance of a database query, from parameters set before
@@ -416,6 +419,13 @@ struct AdbcStatement {
AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
struct AdbcStatement* statement, struct
AdbcError* error);
+/// \brief Destroy a statement, releasing it.
+///
+/// \param[in] statement The statement to release.
+/// \param[out] error An optional location to return an error
+///   message if necessary.
+AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
+                                    struct AdbcError* error);
+
/// \brief Execute a statement.
AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
struct AdbcError* error);
@@ -424,13 +434,6 @@ AdbcStatusCode AdbcStatementExecute(struct AdbcStatement*
statement,
AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
struct AdbcError* error);
-/// \brief Destroy a statement.
-/// \param[in] statement The statement to release.
-/// \param[out] error An optional location to return an error
-/// message if necessary.
-AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
- struct AdbcError* error);
-
/// \defgroup adbc-statement-sql SQL Semantics
/// Functions for executing SQL queries, or querying SQL-related
/// metadata. Drivers are not required to support both SQL and
@@ -438,16 +441,16 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement*
statement,
/// between representations internally.
/// @{
-/// \brief Execute a one-shot query.
+/// \brief Set the SQL query to execute.
///
-/// For queries expected to be executed repeatedly, create a
-/// prepared statement.
+/// The query can then be executed with AdbcStatementExecute. For
+/// queries expected to be executed repeatedly, AdbcStatementPrepare
+/// the statement first.
///
-/// \param[in] connection The database connection.
+/// \param[in] statement The statement.
/// \param[in] query The query to execute.
-/// \param[in,out] statement The result set. Allocate with AdbcStatementInit.
/// \param[out] error Error details, if an error occurs.
-AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* connection,
+AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
const char* query, struct AdbcError*
error);
/// }@
@@ -459,13 +462,26 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct
AdbcStatement* connection,
/// converting between representations internally.
/// @{
-// TODO: not yet defined
+/// \brief Set the Substrait plan to execute.
+///
+/// Run the plan with AdbcStatementExecute. If the plan will be
+/// executed repeatedly, call AdbcStatementPrepare on the statement
+/// first.
+///
+/// \param[in] statement The statement.
+/// \param[in] plan The serialized substrait.Plan to execute.
+/// \param[in] length The length of the serialized plan, in bytes.
+/// \param[out] error Error details, if an error occurs.
+AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement,
+                                             const uint8_t* plan, size_t length,
+                                             struct AdbcError* error);
/// }@
-/// \brief Bind parameter values for parameterized statements.
+/// \brief Bind Arrow data. This can be used for bulk inserts or
+/// prepared statements.
/// \param[in] statement The statement to bind to.
-/// \param[in] values The values to bind. The driver will not call the
+/// \param[in] values The values to bind. The driver will call the
/// release callback itself, although it may not do this until the
/// statement is released.
/// \param[in] schema The schema of the values to bind.
@@ -475,6 +491,18 @@ AdbcStatusCode AdbcStatementBind(struct AdbcStatement*
statement,
struct ArrowArray* values, struct
ArrowSchema* schema,
struct AdbcError* error);
+/// \brief Bind a stream of Arrow data. This can be used for bulk
+/// inserts or prepared statements.
+/// \param[in] statement The statement to bind to.
+/// \param[in] values The stream of values to bind. The driver will
+///   call the release callback itself, although it may not do this
+///   until the statement is released.
+/// \param[out] error An optional location to return an error message
+///   if necessary.
+AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
+                                       struct ArrowArrayStream* values,
+                                       struct AdbcError* error);
+
/// \brief Read the result of a statement.
///
/// This method can be called only once per execution of the
@@ -487,6 +515,26 @@ AdbcStatusCode AdbcStatementGetStream(struct
AdbcStatement* statement,
struct ArrowArrayStream* out,
struct AdbcError* error);
+/// \brief Set a string option on a statement.
+///
+/// \param[in] statement The statement to configure.
+/// \param[in] key The option to set, e.g.
+///   ADBC_INGEST_OPTION_TARGET_TABLE.
+/// \param[in] value The value of the option.
+/// \param[out] error An optional location to return an error message
+///   if necessary.
+/// \return ADBC_STATUS_OK on success. Drivers may return
+///   ADBC_STATUS_NOT_IMPLEMENTED for options they do not recognize.
+AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key,
+                                      const char* value, struct AdbcError* error);
+
+/// \defgroup adbc-statement-ingestion Bulk Data Ingestion
+/// While it is possible to insert data via prepared statements, it
+/// can be more efficient to explicitly perform a bulk insert. For
+/// compatible drivers, this can be accomplished by setting up and
+/// executing a statement. Instead of setting a SQL query or
+/// Substrait plan, bind the source data via AdbcStatementBind or
+/// AdbcStatementBindStream, set the name of the table to be created
+/// via AdbcStatementSetOption and the options below, and then call
+/// AdbcStatementExecute.
+///
+/// @{
+
+/// \brief The name of the target table for a bulk insert.
+#define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table"
+
+/// @}
+
// TODO: methods to get a particular result set from the statement,
// etc. especially for prepared statements with parameter batches
@@ -572,10 +620,6 @@ struct AdbcDriver {
struct AdbcError*);
AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcError*);
AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection*, struct
AdbcError*);
- AdbcStatusCode (*ConnectionSqlExecute)(struct AdbcConnection*, const char*,
- struct AdbcStatement*, struct
AdbcError*);
- AdbcStatusCode (*ConnectionSqlPrepare)(struct AdbcConnection*, const char*,
- struct AdbcStatement*, struct
AdbcError*);
AdbcStatusCode (*ConnectionDeserializePartitionDesc)(struct AdbcConnection*,
const uint8_t*, size_t,
struct AdbcStatement*,
@@ -596,6 +640,8 @@ struct AdbcDriver {
AdbcStatusCode (*StatementRelease)(struct AdbcStatement*, struct AdbcError*);
AdbcStatusCode (*StatementBind)(struct AdbcStatement*, struct ArrowArray*,
struct ArrowSchema*, struct AdbcError*);
+ AdbcStatusCode (*StatementBindStream)(struct AdbcStatement*, struct
ArrowArrayStream*,
+ struct AdbcError*);
AdbcStatusCode (*StatementExecute)(struct AdbcStatement*, struct AdbcError*);
AdbcStatusCode (*StatementPrepare)(struct AdbcStatement*, struct AdbcError*);
AdbcStatusCode (*StatementGetStream)(struct AdbcStatement*, struct
ArrowArrayStream*,
@@ -604,8 +650,12 @@ struct AdbcDriver {
struct AdbcError*);
AdbcStatusCode (*StatementGetPartitionDesc)(struct AdbcStatement*, uint8_t*,
struct AdbcError*);
+ AdbcStatusCode (*StatementSetOption)(struct AdbcStatement*, const char*,
const char*,
+ struct AdbcError*);
AdbcStatusCode (*StatementSetSqlQuery)(struct AdbcStatement*, const char*,
struct AdbcError*);
+ AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const
uint8_t*,
+ size_t, struct AdbcError*);
// Do not edit fields. New fields can only be appended to the end.
};
@@ -628,7 +678,7 @@ typedef AdbcStatusCode (*AdbcDriverInitFunc)(size_t count,
struct AdbcDriver* dr
// struct/entrypoint instead?
// For use with count
-#define ADBC_VERSION_0_0_1 21
+#define ADBC_VERSION_0_0_1 25
/// }@
diff --git a/adbc_driver_manager/adbc_driver_manager.cc
b/adbc_driver_manager/adbc_driver_manager.cc
index e20815a..2575e2d 100644
--- a/adbc_driver_manager/adbc_driver_manager.cc
+++ b/adbc_driver_manager/adbc_driver_manager.cc
@@ -39,10 +39,6 @@ void SetError(struct AdbcError* error, const std::string&
message) {
}
// Default stubs
-AdbcStatusCode ConnectionSqlPrepare(struct AdbcConnection*, const char*,
- struct AdbcStatement*, struct AdbcError*
error) {
- return ADBC_STATUS_NOT_IMPLEMENTED;
-}
AdbcStatusCode StatementBind(struct AdbcStatement*, struct ArrowArray*,
struct ArrowSchema*, struct AdbcError* error) {
@@ -53,7 +49,26 @@ AdbcStatusCode StatementExecute(struct AdbcStatement*,
struct AdbcError* error)
return ADBC_STATUS_NOT_IMPLEMENTED;
}
-// Temporary
+// Stub entry points: each one unconditionally reports
+// ADBC_STATUS_NOT_IMPLEMENTED. Prepare, SetSqlQuery and SetSubstraitPlan
+// are referenced by AdbcLoadDriver's FILL_DEFAULT list.
+// NOTE(review): StatementSetOption is not in that list in this patch.
+AdbcStatusCode StatementPrepare(struct AdbcStatement*, struct AdbcError*) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+AdbcStatusCode StatementSetOption(struct AdbcStatement*, const char*, const char*,
+                                  struct AdbcError*) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement*, const char*,
+                                    struct AdbcError*) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement*, const uint8_t*, size_t,
+                                         struct AdbcError*) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+/// Temporary state while the database is being configured.
struct TempDatabase {
std::unordered_map<std::string, std::string> options;
std::string driver;
@@ -186,6 +201,15 @@ AdbcStatusCode AdbcStatementBind(struct AdbcStatement*
statement,
return statement->private_driver->StatementBind(statement, values, schema,
error);
}
+/// Forward AdbcStatementBindStream to the statement's driver.
+///
+/// Returns ADBC_STATUS_UNINITIALIZED if the statement has no driver, and
+/// ADBC_STATUS_NOT_IMPLEMENTED if the driver left this entry point unset.
+AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
+                                       struct ArrowArrayStream* stream,
+                                       struct AdbcError* error) {
+  if (!statement->private_driver) {
+    return ADBC_STATUS_UNINITIALIZED;
+  }
+  // AdbcLoadDriver does not FILL_DEFAULT this entry, so a driver that does
+  // not set it leaves a null pointer; don't call through it.
+  if (!statement->private_driver->StatementBindStream) {
+    SetError(error, "AdbcStatementBindStream is not implemented by this driver");
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+  return statement->private_driver->StatementBindStream(statement, stream, error);
+}
+
AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
struct AdbcError* error) {
if (!statement->private_driver) {
@@ -232,6 +256,14 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement*
statement,
return status;
}
+/// Forward AdbcStatementSetOption to the statement's driver.
+///
+/// Returns ADBC_STATUS_UNINITIALIZED if the statement has no driver, and
+/// ADBC_STATUS_NOT_IMPLEMENTED if the driver left this entry point unset.
+AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key,
+                                      const char* value, struct AdbcError* error) {
+  if (!statement->private_driver) {
+    return ADBC_STATUS_UNINITIALIZED;
+  }
+  // The StatementSetOption stub is not installed by AdbcLoadDriver, so a
+  // driver that does not set this entry leaves a null pointer.
+  if (!statement->private_driver->StatementSetOption) {
+    SetError(error, "AdbcStatementSetOption is not implemented by this driver");
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+  return statement->private_driver->StatementSetOption(statement, key, value, error);
+}
+
AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
const char* query, struct AdbcError*
error) {
if (!statement->private_driver) {
@@ -240,6 +272,16 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct
AdbcStatement* statement,
return statement->private_driver->StatementSetSqlQuery(statement, query,
error);
}
+/// Forward AdbcStatementSetSubstraitPlan to the statement's driver, or
+/// report ADBC_STATUS_UNINITIALIZED when the statement has no driver.
+AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement,
+                                             const uint8_t* plan, size_t length,
+                                             struct AdbcError* error) {
+  auto* driver = statement->private_driver;
+  if (driver == nullptr) return ADBC_STATUS_UNINITIALIZED;
+  return driver->StatementSetSubstraitPlan(statement, plan, length, error);
+}
+
const char* AdbcStatusCodeMessage(AdbcStatusCode code) {
#define STRINGIFY(s) #s
#define STRINGIFY_VALUE(s) STRINGIFY(s)
@@ -300,9 +342,14 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name,
const char* entrypoint,
CHECK_REQUIRED(driver, DatabaseNew);
CHECK_REQUIRED(driver, DatabaseInit);
- FILL_DEFAULT(driver, ConnectionSqlPrepare);
+ CHECK_REQUIRED(driver, DatabaseRelease);
+
FILL_DEFAULT(driver, StatementBind);
FILL_DEFAULT(driver, StatementExecute);
+ FILL_DEFAULT(driver, StatementPrepare);
+ FILL_DEFAULT(driver, StatementSetSqlQuery);
+ FILL_DEFAULT(driver, StatementSetSubstraitPlan);
+
return ADBC_STATUS_OK;
#undef FILL_DEFAULT
diff --git a/adbc_driver_manager/adbc_driver_manager.h
b/adbc_driver_manager/adbc_driver_manager.h
index 3a67706..d20c4b1 100644
--- a/adbc_driver_manager/adbc_driver_manager.h
+++ b/adbc_driver_manager/adbc_driver_manager.h
@@ -48,6 +48,7 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const
char* entrypoint,
size_t count, struct AdbcDriver* driver,
size_t* initialized, struct AdbcError* error);
+/// \brief Get a human-friendly description of a status code.
const char* AdbcStatusCodeMessage(AdbcStatusCode code);
#endif // ADBC_DRIVER_MANAGER_H
diff --git a/adbc_driver_manager/adbc_driver_manager_test.cc
b/adbc_driver_manager/adbc_driver_manager_test.cc
index ae53e84..e2ef8a7 100644
--- a/adbc_driver_manager/adbc_driver_manager_test.cc
+++ b/adbc_driver_manager/adbc_driver_manager_test.cc
@@ -20,6 +20,7 @@
#include <arrow/c/bridge.h>
#include <arrow/record_batch.h>
+#include <arrow/table.h>
#include <arrow/testing/matchers.h>
#include "adbc.h"
@@ -162,4 +163,51 @@ TEST_F(DriverManager, SqlPrepareMultipleParams) {
}));
}
+// Bulk-ingest a two-batch stream through the driver manager, then read the
+// table back and check all four rows arrive.
+TEST_F(DriverManager, BulkIngestStream) {
+  ArrowArrayStream export_stream;
+  auto bulk_schema = arrow::schema(
+      {arrow::field("ints", arrow::int64()), arrow::field("strs", arrow::utf8())});
+  std::vector<std::shared_ptr<arrow::RecordBatch>> bulk_batches{
+      adbc::RecordBatchFromJSON(bulk_schema, R"([[1, "foo"], [2, "bar"]])"),
+      adbc::RecordBatchFromJSON(bulk_schema, R"([[3, ""], [4, "baz"]])"),
+  };
+  auto bulk_table = *arrow::Table::FromRecordBatches(bulk_batches);
+  // Let the exported reader share ownership of the table instead of
+  // borrowing a reference to it.
+  auto reader = std::make_shared<arrow::TableBatchReader>(bulk_table);
+  ASSERT_OK(arrow::ExportRecordBatchReader(reader, &export_stream));
+
+  {
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+                                      "bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementBindStream(&statement, &export_stream, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+    // The ingest statement is never read, so release it explicitly.
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
+  }
+
+  {
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+
+    std::shared_ptr<arrow::Schema> schema;
+    arrow::RecordBatchVector batches;
+    ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
+    ASSERT_SCHEMA_EQ(*schema, *bulk_schema);
+    EXPECT_THAT(batches,
+                ::testing::UnorderedPointwise(
+                    PointeesEqual(),
+                    {
+                        adbc::RecordBatchFromJSON(
+                            bulk_schema, R"([[1, "foo"], [2, "bar"], [3, ""], [4, "baz"]])"),
+                    }));
+  }
+}
+
} // namespace adbc
diff --git a/drivers/sqlite/sqlite.cc b/drivers/sqlite/sqlite.cc
index 8d8956a..0f5a99e 100644
--- a/drivers/sqlite/sqlite.cc
+++ b/drivers/sqlite/sqlite.cc
@@ -27,6 +27,7 @@
#include <arrow/c/bridge.h>
#include <arrow/record_batch.h>
#include <arrow/status.h>
+#include <arrow/table.h>
#include <arrow/util/logging.h>
#include <arrow/util/string_builder.h>
@@ -75,6 +76,19 @@ void SetError(struct AdbcError* error, Args&&... args) {
error->release = ReleaseError;
}
+/// Convert a SQLite result code into an ADBC status code.
+///
+/// When rc is not SQLITE_OK:
+/// - records the database error message, tagged with `context`, into `error`;
+/// - finalizes `stmt` (a null `stmt` is a no-op for sqlite3_finalize);
+/// - returns ADBC_STATUS_IO.
+///
+/// Callers must not use or finalize `stmt` again after a failure.
+AdbcStatusCode CheckRc(sqlite3* db, sqlite3_stmt* stmt, int rc, const char* context,
+                       struct AdbcError* error) {
+  if (rc == SQLITE_OK) return ADBC_STATUS_OK;
+
+  SetError(db, context, error);
+  const int finalize_rc = sqlite3_finalize(stmt);
+  if (finalize_rc != SQLITE_OK) {
+    SetError(db, "sqlite3_finalize", error);
+  }
+  return ADBC_STATUS_IO;
+}
+
std::shared_ptr<arrow::Schema> StatementToSchema(sqlite3_stmt* stmt) {
const int num_columns = sqlite3_column_count(stmt);
arrow::FieldVector fields(num_columns);
@@ -124,14 +138,9 @@ class SqliteDatabaseImpl {
auto it = options_.find("filename");
if (it != options_.end()) filename = it->second.c_str();
- auto status = sqlite3_open_v2(
- filename, &db_, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
/*zVfs=*/nullptr);
- if (status != SQLITE_OK) {
- if (db_) {
- SetError(db_, "sqlite3_open_v2", error);
- }
- return ADBC_STATUS_IO;
- }
+ int rc = sqlite3_open_v2(filename, &db_, SQLITE_OPEN_READWRITE |
SQLITE_OPEN_CREATE,
+ /*zVfs=*/nullptr);
+ ADBC_RETURN_NOT_OK(CheckRc(db_, nullptr, rc, "sqlite3_open_v2", error));
options_.clear();
return ADBC_STATUS_OK;
}
@@ -165,7 +174,7 @@ class SqliteDatabaseImpl {
auto status = sqlite3_close(db_);
if (status != SQLITE_OK) {
- // TODO:
+ if (db_) SetError(db_, "sqlite3_close", error);
return ADBC_STATUS_UNKNOWN;
}
return ADBC_STATUS_OK;
@@ -275,7 +284,13 @@ class SqliteStatementImpl : public
arrow::RecordBatchReader {
if (status == SQLITE_ROW) {
continue;
} else if (status == SQLITE_DONE) {
- if (bind_parameters_ && bind_index_ < bind_parameters_->num_rows()) {
+ if (bind_parameters_ &&
+ (!next_parameters_ || bind_index_ >=
next_parameters_->num_rows())) {
+ ARROW_RETURN_NOT_OK(bind_parameters_->ReadNext(&next_parameters_));
+ bind_index_ = 0;
+ }
+
+ if (next_parameters_ && bind_index_ < next_parameters_->num_rows()) {
status = sqlite3_reset(stmt_);
if (status != SQLITE_OK) {
return Status::IOError("[SQLite3] sqlite3_reset: ",
sqlite3_errmsg(db));
@@ -285,7 +300,7 @@ class SqliteStatementImpl : public arrow::RecordBatchReader
{
if (status == SQLITE_ROW) continue;
} else {
done_ = true;
- bind_parameters_.reset();
+ next_parameters_.reset();
}
break;
}
@@ -306,13 +321,12 @@ class SqliteStatementImpl : public
arrow::RecordBatchReader {
AdbcStatusCode Close(struct AdbcError* error) {
if (stmt_) {
- auto status = sqlite3_finalize(stmt_);
- if (status != SQLITE_OK) {
- SetError(connection_->db(), "sqlite3_finalize", error);
- return ADBC_STATUS_UNKNOWN;
- }
+ const int rc = sqlite3_finalize(stmt_);
stmt_ = nullptr;
+ next_parameters_.reset();
bind_parameters_.reset();
+ ADBC_RETURN_NOT_OK(
+ CheckRc(connection_->db(), nullptr, rc, "sqlite3_finalize", error));
connection_.reset();
}
return ADBC_STATUS_OK;
@@ -325,71 +339,42 @@ class SqliteStatementImpl : public
arrow::RecordBatchReader {
AdbcStatusCode Bind(const std::shared_ptr<SqliteStatementImpl>& self,
struct ArrowArray* values, struct ArrowSchema* schema,
struct AdbcError* error) {
- auto status = arrow::ImportRecordBatch(values,
schema).Value(&bind_parameters_);
+ std::shared_ptr<arrow::RecordBatch> batch;
+ auto status = arrow::ImportRecordBatch(values, schema).Value(&batch);
+ if (!status.ok()) {
+ SetError(error, status);
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+
+ std::shared_ptr<arrow::Table> table;
+ status = arrow::Table::FromRecordBatches({std::move(batch)}).Value(&table);
if (!status.ok()) {
SetError(error, status);
return ADBC_STATUS_INVALID_ARGUMENT;
}
+
+ bind_parameters_.reset(new arrow::TableBatchReader(std::move(table)));
return ADBC_STATUS_OK;
}
- AdbcStatusCode Execute(const std::shared_ptr<SqliteStatementImpl>& self,
- struct AdbcError* error) {
- sqlite3* db = connection_->db();
- int rc = 0;
- if (schema_) {
- rc = sqlite3_clear_bindings(stmt_);
- if (rc != SQLITE_OK) {
- SetError(db, "sqlite3_reset_bindings", error);
- rc = sqlite3_finalize(stmt_);
- if (rc != SQLITE_OK) {
- SetError(db, "sqlite3_finalize", error);
- }
- return ADBC_STATUS_IO;
- }
-
- rc = sqlite3_reset(stmt_);
- if (rc != SQLITE_OK) {
- SetError(db, "sqlite3_reset", error);
- rc = sqlite3_finalize(stmt_);
- if (rc != SQLITE_OK) {
- SetError(db, "sqlite3_finalize", error);
- }
- return ADBC_STATUS_IO;
- }
- }
- // Step the statement and get the schema (SQLite doesn't
- // necessarily know the schema until it begins to execute it)
- auto status = BindNext().Value(&rc);
- // XXX: with parameters, inferring the schema from the first
- // argument is inaccurate (what if one is null?). Is there a way
- // to hint to SQLite the real type?
+ AdbcStatusCode Bind(const std::shared_ptr<SqliteStatementImpl>& self,
+ struct ArrowArrayStream* stream, struct AdbcError*
error) {
+ auto status =
arrow::ImportRecordBatchReader(stream).Value(&bind_parameters_);
if (!status.ok()) {
- // TODO: map Arrow codes to ADBC codes
SetError(error, status);
- return ADBC_STATUS_IO;
- } else if (rc != SQLITE_OK) {
- SetError(db, "sqlite3_bind", error);
- rc = sqlite3_finalize(stmt_);
- if (rc != SQLITE_OK) {
- SetError(db, "sqlite3_finalize", error);
- }
- return ADBC_STATUS_IO;
- }
- rc = sqlite3_step(stmt_);
- if (rc == SQLITE_ERROR) {
- SetError(db, "sqlite3_step", error);
- rc = sqlite3_finalize(stmt_);
- if (rc != SQLITE_OK) {
- SetError(db, "sqlite3_finalize", error);
- }
- return ADBC_STATUS_IO;
+ return ADBC_STATUS_INVALID_ARGUMENT;
}
- schema_ = StatementToSchema(stmt_);
- done_ = rc != SQLITE_ROW;
return ADBC_STATUS_OK;
}
+  /// Execute the statement. If a bulk-ingest target table was set via
+  /// ADBC_INGEST_OPTION_TARGET_TABLE, perform a bulk insert; otherwise run
+  /// the SQL query set with SetSqlQuery.
+  AdbcStatusCode Execute(const std::shared_ptr<SqliteStatementImpl>& self,
+                         struct AdbcError* error) {
+    return bulk_table_.empty() ? ExecutePrepared(error) : ExecuteBulk(error);
+  }
+
AdbcStatusCode GetStream(const std::shared_ptr<SqliteStatementImpl>& self,
struct ArrowArrayStream* out, struct AdbcError*
error) {
if (!stmt_ || !schema_) {
@@ -404,51 +389,62 @@ class SqliteStatementImpl : public
arrow::RecordBatchReader {
return ADBC_STATUS_OK;
}
+  /// Set a statement option.
+  ///
+  /// Only ADBC_INGEST_OPTION_TARGET_TABLE is recognized. It switches the
+  /// statement to bulk ingestion into the named table and discards any
+  /// previously prepared SQL statement.
+  AdbcStatusCode SetOption(const std::shared_ptr<SqliteStatementImpl>& self,
+                           const char* key, const char* value, struct AdbcError* error) {
+    if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) != 0) {
+      SetError(error, "Unknown option: ", key);
+      return ADBC_STATUS_NOT_IMPLEMENTED;
+    }
+    // Bulk ingest
+    if (value == nullptr || value[0] == '\0') {
+      SetError(error, "Option ", key, " requires a non-empty table name");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    bulk_table_ = value;
+    if (stmt_) {
+      // Clear the member handle before finalizing. Otherwise Close() (or a
+      // later ExecutePrepared) would use or finalize a freed statement.
+      sqlite3_stmt* stale = stmt_;
+      stmt_ = nullptr;
+      schema_.reset();
+      const int rc = sqlite3_finalize(stale);
+      ADBC_RETURN_NOT_OK(
+          CheckRc(connection_->db(), nullptr, rc, "sqlite3_finalize", error));
+    }
+    return ADBC_STATUS_OK;
+  }
+
AdbcStatusCode SetSqlQuery(const std::shared_ptr<SqliteStatementImpl>& self,
const char* query, struct AdbcError* error) {
+ bulk_table_.clear();
sqlite3* db = connection_->db();
int rc = sqlite3_prepare_v2(db, query,
static_cast<int>(std::strlen(query)), &stmt_,
/*pzTail=*/nullptr);
- if (rc != SQLITE_OK) {
- if (stmt_) {
- rc = sqlite3_finalize(stmt_);
- if (rc != SQLITE_OK) {
- SetError(db, "sqlite3_finalize", error);
- }
- }
- SetError(db, "sqlite3_prepare_v2", error);
- return ADBC_STATUS_UNKNOWN;
- }
- return ADBC_STATUS_OK;
+ return CheckRc(connection_->db(), stmt_, rc, "sqlite3_prepare_v2", error);
}
private:
arrow::Result<int> BindNext() {
- if (!bind_parameters_ || bind_index_ >= bind_parameters_->num_rows()) {
+ if (!next_parameters_ || bind_index_ >= next_parameters_->num_rows()) {
return SQLITE_OK;
}
+ return BindImpl(stmt_, *next_parameters_, bind_index_++);
+ }
+
+ arrow::Result<int> BindImpl(sqlite3_stmt* stmt, const arrow::RecordBatch&
data,
+ int64_t row) {
int col_index = 1;
- // TODO: multiple output rows
- const int bind_index = bind_index_++;
- for (const auto& column : bind_parameters_->columns()) {
- if (column->IsNull(bind_index)) {
- const int rc = sqlite3_bind_null(stmt_, col_index);
+ for (const auto& column : data.columns()) {
+ if (column->IsNull(row)) {
+ const int rc = sqlite3_bind_null(stmt, col_index);
if (rc != SQLITE_OK) return rc;
} else {
switch (column->type()->id()) {
case arrow::Type::INT64: {
const int rc = sqlite3_bind_int64(
- stmt_, col_index,
- static_cast<const
arrow::Int64Array&>(*column).Value(bind_index));
+ stmt, col_index,
+ static_cast<const arrow::Int64Array&>(*column).Value(row));
if (rc != SQLITE_OK) return rc;
break;
}
case arrow::Type::STRING: {
const auto& strings = static_cast<const
arrow::StringArray&>(*column);
- const int rc = sqlite3_bind_text64(
- stmt_, col_index, strings.Value(bind_index).data(),
- strings.value_length(bind_index), SQLITE_STATIC, SQLITE_UTF8);
+ const int rc = sqlite3_bind_text64(stmt, col_index,
strings.Value(row).data(),
+ strings.value_length(row),
SQLITE_STATIC,
+ SQLITE_UTF8);
if (rc != SQLITE_OK) return rc;
break;
}
@@ -459,14 +455,148 @@ class SqliteStatementImpl : public
arrow::RecordBatchReader {
}
col_index++;
}
-
return SQLITE_OK;
}
+  /// Bulk-insert the bound data into bulk_table_.
+  ///
+  /// Creates the table with one (untyped) column per field of the bound
+  /// schema, then inserts every row of every batch from the bound reader
+  /// using a single prepared INSERT statement.
+  /// NOTE(review): no transaction wraps the inserts, so a failure part-way
+  /// leaves the created table partially populated.
+  AdbcStatusCode ExecuteBulk(struct AdbcError* error) {
+    if (!bind_parameters_) {
+      SetError(error, "Must AdbcStatementBind for bulk insertion");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+
+    sqlite3* db = connection_->db();
+    sqlite3_stmt* stmt = nullptr;
+    int rc = SQLITE_OK;
+
+    // Report a failed Arrow status and finalize whichever statement is open.
+    auto check_status = [&](const arrow::Status& st) -> AdbcStatusCode {
+      if (st.ok()) return ADBC_STATUS_OK;
+      SetError(error, st);
+      if (stmt) {
+        rc = sqlite3_finalize(stmt);
+        stmt = nullptr;
+        if (rc != SQLITE_OK) {
+          SetError(db, "sqlite3_finalize", error);
+        }
+      }
+      return ADBC_STATUS_IO;
+    };
+
+    // Quote an identifier for SQLite, doubling embedded double quotes, so a
+    // table or column name cannot break out of the generated SQL.
+    auto quote_identifier = [](const std::string& name) {
+      std::string quoted;
+      quoted.reserve(name.size() + 2);
+      quoted += '"';
+      for (const char c : name) {
+        if (c == '"') quoted += '"';
+        quoted += c;
+      }
+      quoted += '"';
+      return quoted;
+    };
+
+    // Hold the schema so the field vector we reference stays alive.
+    const std::shared_ptr<arrow::Schema> bind_schema = bind_parameters_->schema();
+    const auto& fields = bind_schema->fields();
+
+    // Create the table
+    // TODO: parameter to choose append/overwrite/error
+    {
+      std::string query = "CREATE TABLE ";
+      query += quote_identifier(bulk_table_);
+      query += " (";
+      for (size_t i = 0; i < fields.size(); i++) {
+        if (i > 0) query += ',';
+        query += quote_identifier(fields[i]->name());
+      }
+      query += ')';
+
+      rc = sqlite3_prepare_v2(db, query.c_str(), static_cast<int>(query.size()), &stmt,
+                              /*pzTail=*/nullptr);
+      ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_prepare_v2", error));
+
+      rc = sqlite3_step(stmt);
+      if (rc != SQLITE_DONE) return CheckRc(db, stmt, rc, "sqlite3_step", error);
+
+      rc = sqlite3_finalize(stmt);
+      stmt = nullptr;
+      // Already finalized: pass nullptr so CheckRc does not finalize it again.
+      ADBC_RETURN_NOT_OK(CheckRc(db, nullptr, rc, "sqlite3_finalize", error));
+    }
+
+    // Insert the rows
+    {
+      std::string query = "INSERT INTO ";
+      query += quote_identifier(bulk_table_);
+      query += " VALUES (";
+      for (size_t i = 0; i < fields.size(); i++) {
+        if (i > 0) query += ',';
+        query += '?';
+      }
+      query += ')';
+      rc = sqlite3_prepare_v2(db, query.c_str(), static_cast<int>(query.size()), &stmt,
+                              /*pzTail=*/nullptr);
+      ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, query.c_str(), error));
+    }
+
+    while (true) {
+      std::shared_ptr<arrow::RecordBatch> batch;
+      auto status = bind_parameters_->Next().Value(&batch);
+      ADBC_RETURN_NOT_OK(check_status(status));
+      if (!batch) break;
+
+      for (int64_t row = 0; row < batch->num_rows(); row++) {
+        status = BindImpl(stmt, *batch, row).Value(&rc);
+        ADBC_RETURN_NOT_OK(check_status(status));
+        ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_bind", error));
+
+        rc = sqlite3_step(stmt);
+        if (rc != SQLITE_DONE) {
+          return CheckRc(db, stmt, rc, "sqlite3_step", error);
+        }
+
+        rc = sqlite3_reset(stmt);
+        ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_reset", error));
+
+        rc = sqlite3_clear_bindings(stmt);
+        ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_clear_bindings", error));
+      }
+    }
+
+    rc = sqlite3_finalize(stmt);
+    stmt = nullptr;
+    return CheckRc(db, nullptr, rc, "sqlite3_finalize", error);
+  }
+
+  /// Execute the SQL statement prepared by SetSqlQuery.
+  ///
+  /// Rewinds the statement if it was executed before, binds the first row
+  /// of the next bound parameter batch (if parameters were bound), and steps
+  /// once so that the result schema is known.
+  AdbcStatusCode ExecutePrepared(struct AdbcError* error) {
+    if (!stmt_) {
+      SetError(error, "Must AdbcStatementSetSqlQuery before AdbcStatementExecute");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    sqlite3* db = connection_->db();
+    int rc = SQLITE_OK;
+    if (schema_) {
+      // Executed before: rewind the statement for re-execution.
+      rc = sqlite3_clear_bindings(stmt_);
+      ADBC_RETURN_NOT_OK(CheckStmtRc(db, rc, "sqlite3_clear_bindings", error));
+
+      rc = sqlite3_reset(stmt_);
+      ADBC_RETURN_NOT_OK(CheckStmtRc(db, rc, "sqlite3_reset", error));
+    }
+    // Step the statement and get the schema (SQLite doesn't
+    // necessarily know the schema until it begins to execute it)
+
+    Status status;
+    if (bind_parameters_) {
+      status = bind_parameters_->ReadNext(&next_parameters_);
+      // A freshly read batch is always consumed starting from its first row.
+      bind_index_ = 0;
+      if (status.ok()) status = BindNext().Value(&rc);
+    }
+    // XXX: with parameters, inferring the schema from the first
+    // argument is inaccurate (what if one is null?). Is there a way
+    // to hint to SQLite the real type?
+
+    if (!status.ok()) {
+      // TODO: map Arrow codes to ADBC codes
+      SetError(error, status);
+      return ADBC_STATUS_IO;
+    }
+    ADBC_RETURN_NOT_OK(CheckStmtRc(db, rc, "sqlite3_bind", error));
+
+    rc = sqlite3_step(stmt_);
+    // Anything other than ROW/DONE is a failure, not only SQLITE_ERROR
+    // (e.g. SQLITE_CONSTRAINT, SQLITE_BUSY, SQLITE_MISUSE).
+    if (rc != SQLITE_ROW && rc != SQLITE_DONE) {
+      return CheckStmtRc(db, rc, "sqlite3_step", error);
+    }
+    schema_ = StatementToSchema(stmt_);
+    done_ = rc != SQLITE_ROW;
+    return ADBC_STATUS_OK;
+  }
+
+  /// CheckRc applied to the member statement.
+  ///
+  /// CheckRc finalizes the statement it is given on failure. This helper
+  /// therefore also drops stmt_ and the schema derived from it, so that
+  /// Close() does not finalize the same statement a second time.
+  AdbcStatusCode CheckStmtRc(sqlite3* db, int rc, const char* context,
+                             struct AdbcError* error) {
+    const AdbcStatusCode code = CheckRc(db, stmt_, rc, context, error);
+    if (code != ADBC_STATUS_OK) {
+      stmt_ = nullptr;
+      schema_.reset();
+    }
+    return code;
+  }
+
std::shared_ptr<SqliteConnectionImpl> connection_;
+ // Target of bulk ingestion (rather janky to store state like this, though…)
+ std::string bulk_table_;
sqlite3_stmt* stmt_;
std::shared_ptr<arrow::Schema> schema_;
- std::shared_ptr<arrow::RecordBatch> bind_parameters_;
+ std::shared_ptr<arrow::RecordBatchReader> bind_parameters_;
+ std::shared_ptr<arrow::RecordBatch> next_parameters_;
int64_t bind_index_;
bool done_;
};
@@ -557,6 +687,16 @@ AdbcStatusCode AdbcStatementBind(struct AdbcStatement*
statement,
return (*ptr)->Bind(*ptr, values, schema, error);
}
+// Driver entry point: bind an ArrowArrayStream to a SQLite statement.
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
+                                       struct ArrowArrayStream* stream,
+                                       struct AdbcError* error) {
+  auto* impl =
+      reinterpret_cast<std::shared_ptr<SqliteStatementImpl>*>(statement->private_data);
+  if (impl == nullptr) return ADBC_STATUS_UNINITIALIZED;
+  return (*impl)->Bind(*impl, stream, error);
+}
+
ADBC_DRIVER_EXPORT
AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
struct AdbcError* error) {
@@ -621,6 +761,15 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement*
statement,
return status;
}
+// Driver entry point: set a string option on a SQLite statement.
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key,
+                                      const char* value, struct AdbcError* error) {
+  auto* impl =
+      reinterpret_cast<std::shared_ptr<SqliteStatementImpl>*>(statement->private_data);
+  if (impl == nullptr) return ADBC_STATUS_UNINITIALIZED;
+  return (*impl)->SetOption(*impl, key, value, error);
+}
+
ADBC_DRIVER_EXPORT
AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
const char* query, struct AdbcError*
error) {
@@ -648,6 +797,7 @@ AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct
AdbcDriver* driver,
driver->ConnectionSetOption = AdbcConnectionSetOption;
driver->StatementBind = AdbcStatementBind;
+ driver->StatementBindStream = AdbcStatementBindStream;
driver->StatementExecute = AdbcStatementExecute;
driver->StatementGetPartitionDesc = AdbcStatementGetPartitionDesc;
driver->StatementGetPartitionDescSize = AdbcStatementGetPartitionDescSize;
@@ -655,6 +805,7 @@ AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct
AdbcDriver* driver,
driver->StatementNew = AdbcStatementNew;
driver->StatementPrepare = AdbcStatementPrepare;
driver->StatementRelease = AdbcStatementRelease;
+ driver->StatementSetOption = AdbcStatementSetOption;
driver->StatementSetSqlQuery = AdbcStatementSetSqlQuery;
*initialized = ADBC_VERSION_0_0_1;
return ADBC_STATUS_OK;
diff --git a/drivers/sqlite/sqlite_test.cc b/drivers/sqlite/sqlite_test.cc
index e3b4f6e..b562f6a 100644
--- a/drivers/sqlite/sqlite_test.cc
+++ b/drivers/sqlite/sqlite_test.cc
@@ -20,6 +20,7 @@
#include <arrow/c/bridge.h>
#include <arrow/record_batch.h>
+#include <arrow/table.h>
#include <arrow/testing/matchers.h>
#include "adbc.h"
@@ -149,6 +150,90 @@ TEST_F(Sqlite, SqlPrepareMultipleParams) {
}));
}
+// Bulk-ingest a single record batch via AdbcStatementBind, then read the
+// table back and compare.
+TEST_F(Sqlite, BulkIngestTable) {
+  ArrowArray export_table;
+  ArrowSchema export_schema;
+  auto bulk_schema = arrow::schema(
+      {arrow::field("ints", arrow::int64()), arrow::field("strs", arrow::utf8())});
+  auto bulk_table = adbc::RecordBatchFromJSON(bulk_schema, R"([[1, "foo"], [2, "bar"]])");
+  ASSERT_OK(ExportRecordBatch(*bulk_table, &export_table));
+  ASSERT_OK(ExportSchema(*bulk_schema, &export_schema));
+
+  {
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+                                      "bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementBind(&statement, &export_table, &export_schema, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+    // The ingest statement is never read, so release it explicitly.
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
+  }
+
+  {
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+
+    std::shared_ptr<arrow::Schema> schema;
+    arrow::RecordBatchVector batches;
+    ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
+    ASSERT_SCHEMA_EQ(*schema, *bulk_schema);
+    EXPECT_THAT(batches, ::testing::UnorderedPointwise(PointeesEqual(), {bulk_table}));
+  }
+}
+
+// Bulk-ingest a two-batch stream via AdbcStatementBindStream, then read the
+// table back and check all four rows arrive.
+TEST_F(Sqlite, BulkIngestStream) {
+  ArrowArrayStream export_stream;
+  auto bulk_schema = arrow::schema(
+      {arrow::field("ints", arrow::int64()), arrow::field("strs", arrow::utf8())});
+  std::vector<std::shared_ptr<arrow::RecordBatch>> bulk_batches{
+      adbc::RecordBatchFromJSON(bulk_schema, R"([[1, "foo"], [2, "bar"]])"),
+      adbc::RecordBatchFromJSON(bulk_schema, R"([[3, ""], [4, "baz"]])"),
+  };
+  auto bulk_table = *arrow::Table::FromRecordBatches(bulk_batches);
+  // Let the exported reader share ownership of the table instead of
+  // borrowing a reference to it.
+  auto reader = std::make_shared<arrow::TableBatchReader>(bulk_table);
+  ASSERT_OK(arrow::ExportRecordBatchReader(reader, &export_stream));
+
+  {
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+                                      "bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementBindStream(&statement, &export_stream, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+    // The ingest statement is never read, so release it explicitly.
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
+  }
+
+  {
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+
+    std::shared_ptr<arrow::Schema> schema;
+    arrow::RecordBatchVector batches;
+    ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
+    ASSERT_SCHEMA_EQ(*schema, *bulk_schema);
+    EXPECT_THAT(batches,
+                ::testing::UnorderedPointwise(
+                    PointeesEqual(),
+                    {
+                        adbc::RecordBatchFromJSON(
+                            bulk_schema, R"([[1, "foo"], [2, "bar"], [3, ""], [4, "baz"]])"),
+                    }));
+  }
+}
+
TEST_F(Sqlite, MultipleConnections) {
struct AdbcConnection connection2;