This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 612bb9d  Bulk data ingestion (#7)
612bb9d is described below

commit 612bb9dcca1561ebfe27453ed9e24f8b524253f5
Author: David Li <[email protected]>
AuthorDate: Thu Jun 9 10:54:44 2022 -0400

    Bulk data ingestion (#7)
    
    * Sketch out bulk data ingestion
    
    * Document parameters and flow better
---
 adbc.h                                          |  98 +++++--
 adbc_driver_manager/adbc_driver_manager.cc      |  59 ++++-
 adbc_driver_manager/adbc_driver_manager.h       |   1 +
 adbc_driver_manager/adbc_driver_manager_test.cc |  48 ++++
 drivers/sqlite/sqlite.cc                        | 335 +++++++++++++++++-------
 drivers/sqlite/sqlite_test.cc                   |  85 ++++++
 6 files changed, 504 insertions(+), 122 deletions(-)

diff --git a/adbc.h b/adbc.h
index a692e91..fb3db59 100644
--- a/adbc.h
+++ b/adbc.h
@@ -389,9 +389,12 @@ AdbcStatusCode AdbcConnectionGetTables(struct 
AdbcConnection* connection,
 /// }@
 
 /// \defgroup adbc-statement Managing statements.
-/// Applications should first initialize and configure a statement with
-/// AdbcStatementInit and the AdbcStatementSetOption functions, then use the
-/// statement with a function like AdbcConnectionSqlExecute.
+/// Applications should first initialize a statement with
+/// AdbcStatementNew. Then, the statement should be configured with
+/// functions like AdbcStatementSetSqlQuery and
+/// AdbcStatementSetOption. Finally, the statement can be executed
+/// with AdbcStatementExecute (or call AdbcStatementPrepare first to
+/// turn it into a prepared statement instead).
 /// @{
 
 /// \brief An instance of a database query, from parameters set before
@@ -416,6 +419,13 @@ struct AdbcStatement {
 AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
                                 struct AdbcStatement* statement, struct 
AdbcError* error);
 
+/// \brief Destroy a statement.
+/// \param[in] statement The statement to release.
+/// \param[out] error An optional location to return an error
+///   message if necessary.
+AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
+                                    struct AdbcError* error);
+
 /// \brief Execute a statement.
 AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
                                     struct AdbcError* error);
@@ -424,13 +434,6 @@ AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* 
statement,
 AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
                                     struct AdbcError* error);
 
-/// \brief Destroy a statement.
-/// \param[in] statement The statement to release.
-/// \param[out] error An optional location to return an error
-///   message if necessary.
-AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
-                                    struct AdbcError* error);
-
 /// \defgroup adbc-statement-sql SQL Semantics
 /// Functions for executing SQL queries, or querying SQL-related
 /// metadata. Drivers are not required to support both SQL and
@@ -438,16 +441,16 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* 
statement,
 /// between representations internally.
 /// @{
 
-/// \brief Execute a one-shot query.
+/// \brief Set the SQL query to execute.
 ///
-/// For queries expected to be executed repeatedly, create a
-/// prepared statement.
+/// The query can then be executed with AdbcStatementExecute.  For
+/// queries expected to be executed repeatedly, AdbcStatementPrepare
+/// the statement first.
 ///
-/// \param[in] connection The database connection.
+/// \param[in] statement The statement.
 /// \param[in] query The query to execute.
-/// \param[in,out] statement The result set. Allocate with AdbcStatementInit.
 /// \param[out] error Error details, if an error occurs.
-AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* connection,
+AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
                                         const char* query, struct AdbcError* 
error);
 
 /// }@
@@ -459,13 +462,26 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct 
AdbcStatement* connection,
 /// converting between representations internally.
 /// @{
 
-// TODO: not yet defined
+/// \brief Set the Substrait plan to execute.
+///
+/// The query can then be executed with AdbcStatementExecute.  For
+/// queries expected to be executed repeatedly, AdbcStatementPrepare
+/// the statement first.
+///
+/// \param[in] statement The statement.
+/// \param[in] plan The serialized substrait.Plan to execute.
+/// \param[in] length The length of the serialized plan.
+/// \param[out] error Error details, if an error occurs.
+AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement,
+                                             const uint8_t* plan, size_t 
length,
+                                             struct AdbcError* error);
 
 /// }@
 
-/// \brief Bind parameter values for parameterized statements.
+/// \brief Bind Arrow data. This can be used for bulk inserts or
+///   prepared statements.
 /// \param[in] statement The statement to bind to.
-/// \param[in] values The values to bind. The driver will not call the
+/// \param[in] values The values to bind. The driver will call the
 ///   release callback itself, although it may not do this until the
 ///   statement is released.
 /// \param[in] schema The schema of the values to bind.
@@ -475,6 +491,18 @@ AdbcStatusCode AdbcStatementBind(struct AdbcStatement* 
statement,
                                  struct ArrowArray* values, struct 
ArrowSchema* schema,
                                  struct AdbcError* error);
 
+/// \brief Bind a stream of Arrow data. This can be used for bulk
+///   inserts or prepared statements.
+/// \param[in] statement The statement to bind to.
+/// \param[in] values The stream of values to bind. The driver will
+///   call the release callback itself, although it may not do this
+///   until the statement is released.
+/// \param[out] error An optional location to return an error message
+///   if necessary.
+AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
+                                       struct ArrowArrayStream* values,
+                                       struct AdbcError* error);
+
 /// \brief Read the result of a statement.
 ///
 /// This method can be called only once per execution of the
@@ -487,6 +515,26 @@ AdbcStatusCode AdbcStatementGetStream(struct 
AdbcStatement* statement,
                                       struct ArrowArrayStream* out,
                                       struct AdbcError* error);
 
+/// \brief Set a string option on a statement.
+AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const 
char* key,
+                                      const char* value, struct AdbcError* 
error);
+
+/// \defgroup adbc-statement-ingestion Bulk Data Ingestion
+/// While it is possible to insert data via prepared statements, it
+/// can be more efficient to explicitly perform a bulk insert.  For
+/// compatible drivers, this can be accomplished by setting up and
+/// executing a statement.  Instead of setting a SQL query or
+/// Substrait plan, bind the source data via AdbcStatementBind or
+/// AdbcStatementBindStream, and set the name of the table to be
+/// created via AdbcStatementSetOption and the options below.
+///
+/// @{
+
+/// \brief The name of the target table for a bulk insert.
+#define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table"
+
+/// }@
+
 // TODO: methods to get a particular result set from the statement,
 // etc. especially for prepared statements with parameter batches
 
@@ -572,10 +620,6 @@ struct AdbcDriver {
                                         struct AdbcError*);
   AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcError*);
   AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection*, struct 
AdbcError*);
-  AdbcStatusCode (*ConnectionSqlExecute)(struct AdbcConnection*, const char*,
-                                         struct AdbcStatement*, struct 
AdbcError*);
-  AdbcStatusCode (*ConnectionSqlPrepare)(struct AdbcConnection*, const char*,
-                                         struct AdbcStatement*, struct 
AdbcError*);
   AdbcStatusCode (*ConnectionDeserializePartitionDesc)(struct AdbcConnection*,
                                                        const uint8_t*, size_t,
                                                        struct AdbcStatement*,
@@ -596,6 +640,8 @@ struct AdbcDriver {
   AdbcStatusCode (*StatementRelease)(struct AdbcStatement*, struct AdbcError*);
   AdbcStatusCode (*StatementBind)(struct AdbcStatement*, struct ArrowArray*,
                                   struct ArrowSchema*, struct AdbcError*);
+  AdbcStatusCode (*StatementBindStream)(struct AdbcStatement*, struct 
ArrowArrayStream*,
+                                        struct AdbcError*);
   AdbcStatusCode (*StatementExecute)(struct AdbcStatement*, struct AdbcError*);
   AdbcStatusCode (*StatementPrepare)(struct AdbcStatement*, struct AdbcError*);
   AdbcStatusCode (*StatementGetStream)(struct AdbcStatement*, struct 
ArrowArrayStream*,
@@ -604,8 +650,12 @@ struct AdbcDriver {
                                                   struct AdbcError*);
   AdbcStatusCode (*StatementGetPartitionDesc)(struct AdbcStatement*, uint8_t*,
                                               struct AdbcError*);
+  AdbcStatusCode (*StatementSetOption)(struct AdbcStatement*, const char*, 
const char*,
+                                       struct AdbcError*);
   AdbcStatusCode (*StatementSetSqlQuery)(struct AdbcStatement*, const char*,
                                          struct AdbcError*);
+  AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const 
uint8_t*,
+                                              size_t, struct AdbcError*);
   // Do not edit fields. New fields can only be appended to the end.
 };
 
@@ -628,7 +678,7 @@ typedef AdbcStatusCode (*AdbcDriverInitFunc)(size_t count, 
struct AdbcDriver* dr
 // struct/entrypoint instead?
 
 // For use with count
-#define ADBC_VERSION_0_0_1 21
+#define ADBC_VERSION_0_0_1 25
 
 /// }@
 
diff --git a/adbc_driver_manager/adbc_driver_manager.cc 
b/adbc_driver_manager/adbc_driver_manager.cc
index e20815a..2575e2d 100644
--- a/adbc_driver_manager/adbc_driver_manager.cc
+++ b/adbc_driver_manager/adbc_driver_manager.cc
@@ -39,10 +39,6 @@ void SetError(struct AdbcError* error, const std::string& 
message) {
 }
 
 // Default stubs
-AdbcStatusCode ConnectionSqlPrepare(struct AdbcConnection*, const char*,
-                                    struct AdbcStatement*, struct AdbcError* 
error) {
-  return ADBC_STATUS_NOT_IMPLEMENTED;
-}
 
 AdbcStatusCode StatementBind(struct AdbcStatement*, struct ArrowArray*,
                              struct ArrowSchema*, struct AdbcError* error) {
@@ -53,7 +49,26 @@ AdbcStatusCode StatementExecute(struct AdbcStatement*, 
struct AdbcError* error)
   return ADBC_STATUS_NOT_IMPLEMENTED;
 }
 
-// Temporary
+AdbcStatusCode StatementPrepare(struct AdbcStatement*, struct AdbcError* 
error) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+AdbcStatusCode StatementSetOption(struct AdbcStatement*, const char*, const 
char*,
+                                  struct AdbcError* error) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement*, const char*,
+                                    struct AdbcError* error) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement*, const 
uint8_t*, size_t,
+                                         struct AdbcError* error) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+/// Temporary state while the database is being configured.
 struct TempDatabase {
   std::unordered_map<std::string, std::string> options;
   std::string driver;
@@ -186,6 +201,15 @@ AdbcStatusCode AdbcStatementBind(struct AdbcStatement* 
statement,
   return statement->private_driver->StatementBind(statement, values, schema, 
error);
 }
 
+AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
+                                       struct ArrowArrayStream* stream,
+                                       struct AdbcError* error) {
+  if (!statement->private_driver) {
+    return ADBC_STATUS_UNINITIALIZED;
+  }
+  return statement->private_driver->StatementBindStream(statement, stream, 
error);
+}
+
 AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
                                     struct AdbcError* error) {
   if (!statement->private_driver) {
@@ -232,6 +256,14 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* 
statement,
   return status;
 }
 
+AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const 
char* key,
+                                      const char* value, struct AdbcError* 
error) {
+  if (!statement->private_driver) {
+    return ADBC_STATUS_UNINITIALIZED;
+  }
+  return statement->private_driver->StatementSetOption(statement, key, value, 
error);
+}
+
 AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
                                         const char* query, struct AdbcError* 
error) {
   if (!statement->private_driver) {
@@ -240,6 +272,16 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct 
AdbcStatement* statement,
   return statement->private_driver->StatementSetSqlQuery(statement, query, 
error);
 }
 
+AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement,
+                                             const uint8_t* plan, size_t 
length,
+                                             struct AdbcError* error) {
+  if (!statement->private_driver) {
+    return ADBC_STATUS_UNINITIALIZED;
+  }
+  return statement->private_driver->StatementSetSubstraitPlan(statement, plan, 
length,
+                                                              error);
+}
+
 const char* AdbcStatusCodeMessage(AdbcStatusCode code) {
 #define STRINGIFY(s) #s
 #define STRINGIFY_VALUE(s) STRINGIFY(s)
@@ -300,9 +342,14 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, 
const char* entrypoint,
 
   CHECK_REQUIRED(driver, DatabaseNew);
   CHECK_REQUIRED(driver, DatabaseInit);
-  FILL_DEFAULT(driver, ConnectionSqlPrepare);
+  CHECK_REQUIRED(driver, DatabaseRelease);
+
   FILL_DEFAULT(driver, StatementBind);
   FILL_DEFAULT(driver, StatementExecute);
+  FILL_DEFAULT(driver, StatementPrepare);
+  FILL_DEFAULT(driver, StatementSetSqlQuery);
+  FILL_DEFAULT(driver, StatementSetSubstraitPlan);
+
   return ADBC_STATUS_OK;
 
 #undef FILL_DEFAULT
diff --git a/adbc_driver_manager/adbc_driver_manager.h 
b/adbc_driver_manager/adbc_driver_manager.h
index 3a67706..d20c4b1 100644
--- a/adbc_driver_manager/adbc_driver_manager.h
+++ b/adbc_driver_manager/adbc_driver_manager.h
@@ -48,6 +48,7 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const 
char* entrypoint,
                               size_t count, struct AdbcDriver* driver,
                               size_t* initialized, struct AdbcError* error);
 
+/// \brief Get a human-friendly description of a status code.
 const char* AdbcStatusCodeMessage(AdbcStatusCode code);
 
 #endif  // ADBC_DRIVER_MANAGER_H
diff --git a/adbc_driver_manager/adbc_driver_manager_test.cc 
b/adbc_driver_manager/adbc_driver_manager_test.cc
index ae53e84..e2ef8a7 100644
--- a/adbc_driver_manager/adbc_driver_manager_test.cc
+++ b/adbc_driver_manager/adbc_driver_manager_test.cc
@@ -20,6 +20,7 @@
 
 #include <arrow/c/bridge.h>
 #include <arrow/record_batch.h>
+#include <arrow/table.h>
 #include <arrow/testing/matchers.h>
 
 #include "adbc.h"
@@ -162,4 +163,51 @@ TEST_F(DriverManager, SqlPrepareMultipleParams) {
                   }));
 }
 
+TEST_F(DriverManager, BulkIngestStream) {
+  ArrowArrayStream export_stream;
+  auto bulk_schema = arrow::schema(
+      {arrow::field("ints", arrow::int64()), arrow::field("strs", 
arrow::utf8())});
+  std::vector<std::shared_ptr<arrow::RecordBatch>> bulk_batches{
+      adbc::RecordBatchFromJSON(bulk_schema, R"([[1, "foo"], [2, "bar"]])"),
+      adbc::RecordBatchFromJSON(bulk_schema, R"([[3, ""], [4, "baz"]])"),
+  };
+  auto bulk_table = *arrow::Table::FromRecordBatches(bulk_batches);
+  auto reader = std::make_shared<arrow::TableBatchReader>(*bulk_table);
+  ASSERT_OK(arrow::ExportRecordBatchReader(reader, &export_stream));
+
+  {
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, 
&error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetOption(&statement, 
ADBC_INGEST_OPTION_TARGET_TABLE,
+                                      "bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementBindStream(&statement, &export_stream, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+  }
+
+  {
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, 
&error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetSqlQuery(&statement, "SELECT * FROM 
bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+
+    std::shared_ptr<arrow::Schema> schema;
+    arrow::RecordBatchVector batches;
+    ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
+    ASSERT_SCHEMA_EQ(*schema, *bulk_schema);
+    EXPECT_THAT(
+        batches,
+        ::testing::UnorderedPointwise(
+            PointeesEqual(),
+            {
+                adbc::RecordBatchFromJSON(
+                    bulk_schema, R"([[1, "foo"], [2, "bar"], [3, ""], [4, 
"baz"]])"),
+            }));
+  }
+}
+
 }  // namespace adbc
diff --git a/drivers/sqlite/sqlite.cc b/drivers/sqlite/sqlite.cc
index 8d8956a..0f5a99e 100644
--- a/drivers/sqlite/sqlite.cc
+++ b/drivers/sqlite/sqlite.cc
@@ -27,6 +27,7 @@
 #include <arrow/c/bridge.h>
 #include <arrow/record_batch.h>
 #include <arrow/status.h>
+#include <arrow/table.h>
 #include <arrow/util/logging.h>
 #include <arrow/util/string_builder.h>
 
@@ -75,6 +76,19 @@ void SetError(struct AdbcError* error, Args&&... args) {
   error->release = ReleaseError;
 }
 
+AdbcStatusCode CheckRc(sqlite3* db, sqlite3_stmt* stmt, int rc, const char* 
context,
+                       struct AdbcError* error) {
+  if (rc != SQLITE_OK) {
+    SetError(db, context, error);
+    rc = sqlite3_finalize(stmt);
+    if (rc != SQLITE_OK) {
+      SetError(db, "sqlite3_finalize", error);
+    }
+    return ADBC_STATUS_IO;
+  }
+  return ADBC_STATUS_OK;
+}
+
 std::shared_ptr<arrow::Schema> StatementToSchema(sqlite3_stmt* stmt) {
   const int num_columns = sqlite3_column_count(stmt);
   arrow::FieldVector fields(num_columns);
@@ -124,14 +138,9 @@ class SqliteDatabaseImpl {
     auto it = options_.find("filename");
     if (it != options_.end()) filename = it->second.c_str();
 
-    auto status = sqlite3_open_v2(
-        filename, &db_, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, 
/*zVfs=*/nullptr);
-    if (status != SQLITE_OK) {
-      if (db_) {
-        SetError(db_, "sqlite3_open_v2", error);
-      }
-      return ADBC_STATUS_IO;
-    }
+    int rc = sqlite3_open_v2(filename, &db_, SQLITE_OPEN_READWRITE | 
SQLITE_OPEN_CREATE,
+                             /*zVfs=*/nullptr);
+    ADBC_RETURN_NOT_OK(CheckRc(db_, nullptr, rc, "sqlite3_open_v2", error));
     options_.clear();
     return ADBC_STATUS_OK;
   }
@@ -165,7 +174,7 @@ class SqliteDatabaseImpl {
 
     auto status = sqlite3_close(db_);
     if (status != SQLITE_OK) {
-      // TODO:
+      if (db_) SetError(db_, "sqlite3_close", error);
       return ADBC_STATUS_UNKNOWN;
     }
     return ADBC_STATUS_OK;
@@ -275,7 +284,13 @@ class SqliteStatementImpl : public 
arrow::RecordBatchReader {
       if (status == SQLITE_ROW) {
         continue;
       } else if (status == SQLITE_DONE) {
-        if (bind_parameters_ && bind_index_ < bind_parameters_->num_rows()) {
+        if (bind_parameters_ &&
+            (!next_parameters_ || bind_index_ >= 
next_parameters_->num_rows())) {
+          ARROW_RETURN_NOT_OK(bind_parameters_->ReadNext(&next_parameters_));
+          bind_index_ = 0;
+        }
+
+        if (next_parameters_ && bind_index_ < next_parameters_->num_rows()) {
           status = sqlite3_reset(stmt_);
           if (status != SQLITE_OK) {
             return Status::IOError("[SQLite3] sqlite3_reset: ", 
sqlite3_errmsg(db));
@@ -285,7 +300,7 @@ class SqliteStatementImpl : public arrow::RecordBatchReader 
{
           if (status == SQLITE_ROW) continue;
         } else {
           done_ = true;
-          bind_parameters_.reset();
+          next_parameters_.reset();
         }
         break;
       }
@@ -306,13 +321,12 @@ class SqliteStatementImpl : public 
arrow::RecordBatchReader {
 
   AdbcStatusCode Close(struct AdbcError* error) {
     if (stmt_) {
-      auto status = sqlite3_finalize(stmt_);
-      if (status != SQLITE_OK) {
-        SetError(connection_->db(), "sqlite3_finalize", error);
-        return ADBC_STATUS_UNKNOWN;
-      }
+      const int rc = sqlite3_finalize(stmt_);
       stmt_ = nullptr;
+      next_parameters_.reset();
       bind_parameters_.reset();
+      ADBC_RETURN_NOT_OK(
+          CheckRc(connection_->db(), nullptr, rc, "sqlite3_finalize", error));
       connection_.reset();
     }
     return ADBC_STATUS_OK;
@@ -325,71 +339,42 @@ class SqliteStatementImpl : public 
arrow::RecordBatchReader {
   AdbcStatusCode Bind(const std::shared_ptr<SqliteStatementImpl>& self,
                       struct ArrowArray* values, struct ArrowSchema* schema,
                       struct AdbcError* error) {
-    auto status = arrow::ImportRecordBatch(values, 
schema).Value(&bind_parameters_);
+    std::shared_ptr<arrow::RecordBatch> batch;
+    auto status = arrow::ImportRecordBatch(values, schema).Value(&batch);
+    if (!status.ok()) {
+      SetError(error, status);
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+
+    std::shared_ptr<arrow::Table> table;
+    status = arrow::Table::FromRecordBatches({std::move(batch)}).Value(&table);
     if (!status.ok()) {
       SetError(error, status);
       return ADBC_STATUS_INVALID_ARGUMENT;
     }
+
+    bind_parameters_.reset(new arrow::TableBatchReader(std::move(table)));
     return ADBC_STATUS_OK;
   }
 
-  AdbcStatusCode Execute(const std::shared_ptr<SqliteStatementImpl>& self,
-                         struct AdbcError* error) {
-    sqlite3* db = connection_->db();
-    int rc = 0;
-    if (schema_) {
-      rc = sqlite3_clear_bindings(stmt_);
-      if (rc != SQLITE_OK) {
-        SetError(db, "sqlite3_reset_bindings", error);
-        rc = sqlite3_finalize(stmt_);
-        if (rc != SQLITE_OK) {
-          SetError(db, "sqlite3_finalize", error);
-        }
-        return ADBC_STATUS_IO;
-      }
-
-      rc = sqlite3_reset(stmt_);
-      if (rc != SQLITE_OK) {
-        SetError(db, "sqlite3_reset", error);
-        rc = sqlite3_finalize(stmt_);
-        if (rc != SQLITE_OK) {
-          SetError(db, "sqlite3_finalize", error);
-        }
-        return ADBC_STATUS_IO;
-      }
-    }
-    // Step the statement and get the schema (SQLite doesn't
-    // necessarily know the schema until it begins to execute it)
-    auto status = BindNext().Value(&rc);
-    // XXX: with parameters, inferring the schema from the first
-    // argument is inaccurate (what if one is null?). Is there a way
-    // to hint to SQLite the real type?
+  AdbcStatusCode Bind(const std::shared_ptr<SqliteStatementImpl>& self,
+                      struct ArrowArrayStream* stream, struct AdbcError* 
error) {
+    auto status = 
arrow::ImportRecordBatchReader(stream).Value(&bind_parameters_);
     if (!status.ok()) {
-      // TODO: map Arrow codes to ADBC codes
       SetError(error, status);
-      return ADBC_STATUS_IO;
-    } else if (rc != SQLITE_OK) {
-      SetError(db, "sqlite3_bind", error);
-      rc = sqlite3_finalize(stmt_);
-      if (rc != SQLITE_OK) {
-        SetError(db, "sqlite3_finalize", error);
-      }
-      return ADBC_STATUS_IO;
-    }
-    rc = sqlite3_step(stmt_);
-    if (rc == SQLITE_ERROR) {
-      SetError(db, "sqlite3_step", error);
-      rc = sqlite3_finalize(stmt_);
-      if (rc != SQLITE_OK) {
-        SetError(db, "sqlite3_finalize", error);
-      }
-      return ADBC_STATUS_IO;
+      return ADBC_STATUS_INVALID_ARGUMENT;
     }
-    schema_ = StatementToSchema(stmt_);
-    done_ = rc != SQLITE_ROW;
     return ADBC_STATUS_OK;
   }
 
+  AdbcStatusCode Execute(const std::shared_ptr<SqliteStatementImpl>& self,
+                         struct AdbcError* error) {
+    if (bulk_table_.empty()) {
+      return ExecutePrepared(error);
+    }
+    return ExecuteBulk(error);
+  }
+
   AdbcStatusCode GetStream(const std::shared_ptr<SqliteStatementImpl>& self,
                            struct ArrowArrayStream* out, struct AdbcError* 
error) {
     if (!stmt_ || !schema_) {
@@ -404,51 +389,62 @@ class SqliteStatementImpl : public 
arrow::RecordBatchReader {
     return ADBC_STATUS_OK;
   }
 
+  AdbcStatusCode SetOption(const std::shared_ptr<SqliteStatementImpl>& self,
+                           const char* key, const char* value, struct 
AdbcError* error) {
+    if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) {
+      // Bulk ingest
+      if (std::strlen(value) == 0) return ADBC_STATUS_INVALID_ARGUMENT;
+      bulk_table_ = value;
+      if (stmt_) {
+        int rc = sqlite3_finalize(stmt_);
+        ADBC_RETURN_NOT_OK(
+            CheckRc(connection_->db(), nullptr, rc, "sqlite3_finalize", 
error));
+      }
+      return ADBC_STATUS_OK;
+    }
+    SetError(error, "Unknown option: ", key);
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+
   AdbcStatusCode SetSqlQuery(const std::shared_ptr<SqliteStatementImpl>& self,
                              const char* query, struct AdbcError* error) {
+    bulk_table_.clear();
     sqlite3* db = connection_->db();
     int rc = sqlite3_prepare_v2(db, query, 
static_cast<int>(std::strlen(query)), &stmt_,
                                 /*pzTail=*/nullptr);
-    if (rc != SQLITE_OK) {
-      if (stmt_) {
-        rc = sqlite3_finalize(stmt_);
-        if (rc != SQLITE_OK) {
-          SetError(db, "sqlite3_finalize", error);
-        }
-      }
-      SetError(db, "sqlite3_prepare_v2", error);
-      return ADBC_STATUS_UNKNOWN;
-    }
-    return ADBC_STATUS_OK;
+    return CheckRc(connection_->db(), stmt_, rc, "sqlite3_prepare_v2", error);
   }
 
  private:
   arrow::Result<int> BindNext() {
-    if (!bind_parameters_ || bind_index_ >= bind_parameters_->num_rows()) {
+    if (!next_parameters_ || bind_index_ >= next_parameters_->num_rows()) {
       return SQLITE_OK;
     }
 
+    return BindImpl(stmt_, *next_parameters_, bind_index_++);
+  }
+
+  arrow::Result<int> BindImpl(sqlite3_stmt* stmt, const arrow::RecordBatch& 
data,
+                              int64_t row) {
     int col_index = 1;
-    // TODO: multiple output rows
-    const int bind_index = bind_index_++;
-    for (const auto& column : bind_parameters_->columns()) {
-      if (column->IsNull(bind_index)) {
-        const int rc = sqlite3_bind_null(stmt_, col_index);
+    for (const auto& column : data.columns()) {
+      if (column->IsNull(row)) {
+        const int rc = sqlite3_bind_null(stmt, col_index);
         if (rc != SQLITE_OK) return rc;
       } else {
         switch (column->type()->id()) {
           case arrow::Type::INT64: {
             const int rc = sqlite3_bind_int64(
-                stmt_, col_index,
-                static_cast<const 
arrow::Int64Array&>(*column).Value(bind_index));
+                stmt, col_index,
+                static_cast<const arrow::Int64Array&>(*column).Value(row));
             if (rc != SQLITE_OK) return rc;
             break;
           }
           case arrow::Type::STRING: {
             const auto& strings = static_cast<const 
arrow::StringArray&>(*column);
-            const int rc = sqlite3_bind_text64(
-                stmt_, col_index, strings.Value(bind_index).data(),
-                strings.value_length(bind_index), SQLITE_STATIC, SQLITE_UTF8);
+            const int rc = sqlite3_bind_text64(stmt, col_index, 
strings.Value(row).data(),
+                                               strings.value_length(row), 
SQLITE_STATIC,
+                                               SQLITE_UTF8);
             if (rc != SQLITE_OK) return rc;
             break;
           }
@@ -459,14 +455,148 @@ class SqliteStatementImpl : public 
arrow::RecordBatchReader {
       }
       col_index++;
     }
-
     return SQLITE_OK;
   }
 
+  AdbcStatusCode ExecuteBulk(struct AdbcError* error) {
+    if (!bind_parameters_) {
+      SetError(error, "Must AdbcStatementBind for bulk insertion");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+
+    sqlite3* db = connection_->db();
+    sqlite3_stmt* stmt = nullptr;
+    int rc = SQLITE_OK;
+
+    auto check_status = [&](const arrow::Status& st) mutable {
+      if (!st.ok()) {
+        SetError(error, st);
+        if (stmt) {
+          rc = sqlite3_finalize(stmt);
+          if (rc != SQLITE_OK) {
+            SetError(db, "sqlite3_finalize", error);
+          }
+        }
+        return ADBC_STATUS_IO;
+      }
+      return ADBC_STATUS_OK;
+    };
+
+    // Create the table
+    // TODO: parameter to choose append/overwrite/error
+    {
+      // XXX: not injection-safe
+      std::string query = "CREATE TABLE ";
+      query += bulk_table_;
+      query += " (";
+      const auto& fields = bind_parameters_->schema()->fields();
+      for (int i = 0; i < fields.size(); i++) {
+        if (i > 0) query += ',';
+        query += fields[i]->name();
+      }
+      query += ')';
+
+      rc = sqlite3_prepare_v2(db, query.c_str(), 
static_cast<int>(query.size()), &stmt,
+                              /*pzTail=*/nullptr);
+      ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_prepare_v2", error));
+
+      rc = sqlite3_step(stmt);
+      if (rc != SQLITE_DONE) return CheckRc(db, stmt, rc, "sqlite3_step", 
error);
+
+      rc = sqlite3_finalize(stmt);
+      ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_finalize", error));
+    }
+
+    // Insert the rows
+
+    {
+      std::string query = "INSERT INTO ";
+      query += bulk_table_;
+      query += " VALUES (";
+      const auto& fields = bind_parameters_->schema()->fields();
+      for (int i = 0; i < fields.size(); i++) {
+        if (i > 0) query += ',';
+        query += '?';
+      }
+      query += ')';
+      rc = sqlite3_prepare_v2(db, query.c_str(), 
static_cast<int>(query.size()), &stmt,
+                              /*pzTail=*/nullptr);
+      ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, query.c_str(), error));
+    }
+
+    while (true) {
+      std::shared_ptr<arrow::RecordBatch> batch;
+      auto status = bind_parameters_->Next().Value(&batch);
+      ADBC_RETURN_NOT_OK(check_status(status));
+      if (!batch) break;
+
+      for (int64_t row = 0; row < batch->num_rows(); row++) {
+        status = BindImpl(stmt, *batch, row).Value(&rc);
+        ADBC_RETURN_NOT_OK(check_status(status));
+        ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_bind", error));
+
+        rc = sqlite3_step(stmt);
+        if (rc != SQLITE_DONE) {
+          return CheckRc(db, stmt, rc, "sqlite3_step", error);
+        }
+
+        rc = sqlite3_reset(stmt);
+        ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_reset", error));
+
+        rc = sqlite3_clear_bindings(stmt);
+        ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_clear_bindings", 
error));
+      }
+    }
+
+    rc = sqlite3_finalize(stmt);
+    return CheckRc(db, nullptr, rc, "sqlite3_finalize", error);
+  }
+
+  AdbcStatusCode ExecutePrepared(struct AdbcError* error) {
+    sqlite3* db = connection_->db();
+    int rc = SQLITE_OK;
+    if (schema_) {
+      rc = sqlite3_clear_bindings(stmt_);
+      ADBC_RETURN_NOT_OK(CheckRc(db, stmt_, rc, "sqlite3_clear_bindings", 
error));
+
+      rc = sqlite3_reset(stmt_);
+      ADBC_RETURN_NOT_OK(CheckRc(db, stmt_, rc, "sqlite3_reset", error));
+    }
+    // Step the statement and get the schema (SQLite doesn't
+    // necessarily know the schema until it begins to execute it)
+
+    Status status;
+    if (bind_parameters_) {
+      status = bind_parameters_->ReadNext(&next_parameters_);
+      if (status.ok()) status = BindNext().Value(&rc);
+    }
+    // XXX: with parameters, inferring the schema from the first
+    // argument is inaccurate (what if one is null?). Is there a way
+    // to hint to SQLite the real type?
+
+    if (!status.ok()) {
+      // TODO: map Arrow codes to ADBC codes
+      SetError(error, status);
+      return ADBC_STATUS_IO;
+    }
+    ADBC_RETURN_NOT_OK(CheckRc(db, stmt_, rc, "sqlite3_bind", error));
+
+    rc = sqlite3_step(stmt_);
+    if (rc == SQLITE_ERROR) {
+      return CheckRc(db, stmt_, rc, "sqlite3_error", error);
+    }
+    schema_ = StatementToSchema(stmt_);
+    done_ = rc != SQLITE_ROW;
+    return ADBC_STATUS_OK;
+  }
+
   std::shared_ptr<SqliteConnectionImpl> connection_;
+  // Target of bulk ingestion (rather janky to store state like this, though…)
+  std::string bulk_table_;
   sqlite3_stmt* stmt_;
   std::shared_ptr<arrow::Schema> schema_;
-  std::shared_ptr<arrow::RecordBatch> bind_parameters_;
+  std::shared_ptr<arrow::RecordBatchReader> bind_parameters_;
+  std::shared_ptr<arrow::RecordBatch> next_parameters_;
   int64_t bind_index_;
   bool done_;
 };
@@ -557,6 +687,16 @@ AdbcStatusCode AdbcStatementBind(struct AdbcStatement* 
statement,
   return (*ptr)->Bind(*ptr, values, schema, error);
 }
 
+// Driver entry point: binds an Arrow C stream of parameters to the
+// statement by forwarding to SqliteStatementImpl::Bind. Returns
+// ADBC_STATUS_UNINITIALIZED when the statement has no private_data.
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
+                                       struct ArrowArrayStream* stream,
+                                       struct AdbcError* error) {
+  using ImplPtr = std::shared_ptr<SqliteStatementImpl>;
+  auto* handle = statement->private_data;
+  if (!handle) {
+    return ADBC_STATUS_UNINITIALIZED;
+  }
+  // private_data stores a pointer to the shared_ptr that owns the impl.
+  ImplPtr& impl = *reinterpret_cast<ImplPtr*>(handle);
+  return impl->Bind(impl, stream, error);
+}
+
 ADBC_DRIVER_EXPORT
 AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
                                     struct AdbcError* error) {
@@ -621,6 +761,15 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* 
statement,
   return status;
 }
 
+// Driver entry point: sets a string option on the statement by
+// forwarding to SqliteStatementImpl::SetOption. Returns
+// ADBC_STATUS_UNINITIALIZED when the statement has no private_data.
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key,
+                                      const char* value, struct AdbcError* error) {
+  using ImplPtr = std::shared_ptr<SqliteStatementImpl>;
+  auto* handle = statement->private_data;
+  if (!handle) {
+    return ADBC_STATUS_UNINITIALIZED;
+  }
+  // private_data stores a pointer to the shared_ptr that owns the impl.
+  ImplPtr& impl = *reinterpret_cast<ImplPtr*>(handle);
+  return impl->SetOption(impl, key, value, error);
+}
+
 ADBC_DRIVER_EXPORT
 AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
                                         const char* query, struct AdbcError* 
error) {
@@ -648,6 +797,7 @@ AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct 
AdbcDriver* driver,
   driver->ConnectionSetOption = AdbcConnectionSetOption;
 
   driver->StatementBind = AdbcStatementBind;
+  driver->StatementBindStream = AdbcStatementBindStream;
   driver->StatementExecute = AdbcStatementExecute;
   driver->StatementGetPartitionDesc = AdbcStatementGetPartitionDesc;
   driver->StatementGetPartitionDescSize = AdbcStatementGetPartitionDescSize;
@@ -655,6 +805,7 @@ AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct 
AdbcDriver* driver,
   driver->StatementNew = AdbcStatementNew;
   driver->StatementPrepare = AdbcStatementPrepare;
   driver->StatementRelease = AdbcStatementRelease;
+  driver->StatementSetOption = AdbcStatementSetOption;
   driver->StatementSetSqlQuery = AdbcStatementSetSqlQuery;
   *initialized = ADBC_VERSION_0_0_1;
   return ADBC_STATUS_OK;
diff --git a/drivers/sqlite/sqlite_test.cc b/drivers/sqlite/sqlite_test.cc
index e3b4f6e..b562f6a 100644
--- a/drivers/sqlite/sqlite_test.cc
+++ b/drivers/sqlite/sqlite_test.cc
@@ -20,6 +20,7 @@
 
 #include <arrow/c/bridge.h>
 #include <arrow/record_batch.h>
+#include <arrow/table.h>
 #include <arrow/testing/matchers.h>
 
 #include "adbc.h"
@@ -149,6 +150,90 @@ TEST_F(Sqlite, SqlPrepareMultipleParams) {
                   }));
 }
 
+// Ingests a single record batch into the table "bulk_insert" via
+// AdbcStatementBind, then reads the table back with a SELECT and checks
+// that the schema and data match what was ingested.
+TEST_F(Sqlite, BulkIngestTable) {
+  ArrowArray export_table;
+  ArrowSchema export_schema;
+  auto bulk_schema = arrow::schema(
+      {arrow::field("ints", arrow::int64()), arrow::field("strs", arrow::utf8())});
+  auto bulk_table = adbc::RecordBatchFromJSON(bulk_schema, R"([[1, "foo"], [2, "bar"]])");
+  ASSERT_OK(ExportRecordBatch(*bulk_table, &export_table));
+  ASSERT_OK(ExportSchema(*bulk_schema, &export_schema));
+
+  {
+    // Ingest: setting the target-table option (instead of a SQL query)
+    // selects bulk ingestion of the bound data.
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+                                      "bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementBind(&statement, &export_table, &export_schema, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+    // Nothing else in this scope releases the statement; release it here
+    // so its driver-side state is not leaked when it goes out of scope.
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
+  }
+
+  {
+    // Verify: read the ingested rows back.
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+
+    std::shared_ptr<arrow::Schema> schema;
+    arrow::RecordBatchVector batches;
+    ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
+    ASSERT_SCHEMA_EQ(*schema, *bulk_schema);
+    EXPECT_THAT(batches, ::testing::UnorderedPointwise(PointeesEqual(), {bulk_table}));
+  }
+}
+
+// Ingests two record batches into the table "bulk_insert" through an
+// Arrow C stream via AdbcStatementBindStream, then reads the table back
+// with a SELECT and checks that all four rows come back.
+TEST_F(Sqlite, BulkIngestStream) {
+  ArrowArrayStream export_stream;
+  auto bulk_schema = arrow::schema(
+      {arrow::field("ints", arrow::int64()), arrow::field("strs", arrow::utf8())});
+  std::vector<std::shared_ptr<arrow::RecordBatch>> bulk_batches{
+      adbc::RecordBatchFromJSON(bulk_schema, R"([[1, "foo"], [2, "bar"]])"),
+      adbc::RecordBatchFromJSON(bulk_schema, R"([[3, ""], [4, "baz"]])"),
+  };
+  auto bulk_table = *arrow::Table::FromRecordBatches(bulk_batches);
+  auto reader = std::make_shared<arrow::TableBatchReader>(*bulk_table);
+  ASSERT_OK(arrow::ExportRecordBatchReader(reader, &export_stream));
+
+  {
+    // Ingest: setting the target-table option (instead of a SQL query)
+    // selects bulk ingestion of the bound stream.
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+                                      "bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementBindStream(&statement, &export_stream, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+    // Nothing else in this scope releases the statement; release it here
+    // so its driver-side state is not leaked when it goes out of scope.
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
+  }
+
+  {
+    // Verify: read the ingested rows back.
+    AdbcStatement statement;
+    std::memset(&statement, 0, sizeof(statement));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(
+        error, AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_insert", &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+
+    std::shared_ptr<arrow::Schema> schema;
+    arrow::RecordBatchVector batches;
+    ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
+    ASSERT_SCHEMA_EQ(*schema, *bulk_schema);
+    // The expectation is one batch holding the rows of both input batches.
+    EXPECT_THAT(
+        batches,
+        ::testing::UnorderedPointwise(
+            PointeesEqual(),
+            {
+                adbc::RecordBatchFromJSON(
+                    bulk_schema, R"([[1, "foo"], [2, "bar"], [3, ""], [4, "baz"]])"),
+            }));
+  }
+}
+
 TEST_F(Sqlite, MultipleConnections) {
   struct AdbcConnection connection2;
 

Reply via email to