This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 67b39d8 Make AdbcConnectionNew 2-adic for consistency (#27)
67b39d8 is described below
commit 67b39d85bc22a13585a165bb674c39a141c07c6d
Author: David Li <[email protected]>
AuthorDate: Tue Jul 5 12:54:33 2022 -0400
Make AdbcConnectionNew 2-adic for consistency (#27)
* Make AdbcConnectionNew 2-adic for consistency
* Fix Python build
---
adbc.h | 11 +++--
adbc_driver_manager/adbc_driver_manager.cc | 52 +++++++++++++++++-----
adbc_driver_manager/adbc_driver_manager_test.cc | 22 ++++++++-
drivers/flight_sql/flight_sql.cc | 30 +++++++------
drivers/flight_sql/flight_sql_test.cc | 5 ++-
drivers/sqlite/sqlite.cc | 31 +++++++------
drivers/sqlite/sqlite_test.cc | 16 +++----
.../adbc_driver_manager/_lib.pyx | 8 ++--
8 files changed, 115 insertions(+), 60 deletions(-)
diff --git a/adbc.h b/adbc.h
index f21415f..3c2445b 100644
--- a/adbc.h
+++ b/adbc.h
@@ -322,8 +322,7 @@ struct ADBC_EXPORT AdbcConnection {
/// \brief Allocate a new (but uninitialized) connection.
ADBC_EXPORT
-AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
- struct AdbcConnection* connection,
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
struct AdbcError* error);
ADBC_EXPORT
@@ -333,7 +332,7 @@ AdbcStatusCode AdbcConnectionSetOption(struct
AdbcConnection* connection, const
/// \brief Finish setting options and initialize the connection.
ADBC_EXPORT
AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
- struct AdbcError* error);
+ struct AdbcDatabase* database, struct
AdbcError* error);
/// \brief Destroy this connection.
/// \param[in] connection The connection to release.
@@ -813,11 +812,11 @@ struct ADBC_EXPORT AdbcDriver {
AdbcStatusCode (*DatabaseInit)(struct AdbcDatabase*, struct AdbcError*);
AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*);
- AdbcStatusCode (*ConnectionNew)(struct AdbcDatabase*, struct AdbcConnection*,
- struct AdbcError*);
+ AdbcStatusCode (*ConnectionNew)(struct AdbcConnection*, struct AdbcError*);
AdbcStatusCode (*ConnectionSetOption)(struct AdbcConnection*, const char*,
const char*,
struct AdbcError*);
- AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcError*);
+ AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct
AdbcDatabase*,
+ struct AdbcError*);
AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection*, struct
AdbcError*);
AdbcStatusCode (*ConnectionDeserializePartitionDesc)(struct AdbcConnection*,
diff --git a/adbc_driver_manager/adbc_driver_manager.cc
b/adbc_driver_manager/adbc_driver_manager.cc
index 68e7223..96f0bf2 100644
--- a/adbc_driver_manager/adbc_driver_manager.cc
+++ b/adbc_driver_manager/adbc_driver_manager.cc
@@ -21,6 +21,7 @@
#include <cstring>
#include <string>
#include <unordered_map>
+#include <utility>
#if defined(_WIN32)
#include <windows.h> // Must come first
@@ -127,6 +128,11 @@ struct TempDatabase {
std::string entrypoint;
};
+/// Temporary state while the connection is being configured. It is stored in
+/// AdbcConnection::private_data by AdbcConnectionNew and freed by
+/// AdbcConnectionInit (or by AdbcConnectionRelease if Init never ran).
+struct TempConnection {
+ /// Key/value options replayed onto the driver connection during AdbcConnectionInit.
+ std::unordered_map<std::string, std::string> options;
+};
+
#if defined(_WIN32)
/// Append a description of the Windows error to the buffer.
void GetWinError(std::string* buffer) {
@@ -278,7 +284,12 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase*
database, struct AdbcError*
AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,
struct AdbcError* error) {
if (!database->private_driver) {
- return ADBC_STATUS_INVALID_STATE;
+ if (database->private_data) {
+ TempDatabase* args =
reinterpret_cast<TempDatabase*>(database->private_data);
+ delete args;
+ database->private_data = nullptr;
+ }
+ return ADBC_STATUS_OK;
}
auto status = database->private_driver->DatabaseRelease(database, error);
if (database->private_driver->release) {
@@ -297,28 +308,45 @@ AdbcStatusCode AdbcConnectionCommit(struct
AdbcConnection* connection,
}
AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+ struct AdbcDatabase* database,
struct AdbcError* error) {
- if (!connection->private_driver) {
+ if (!connection->private_data) {
+ SetError(error, "Must call AdbcConnectionNew first");
return ADBC_STATUS_INVALID_STATE;
}
- return connection->private_driver->ConnectionInit(connection, error);
+ if (connection->private_driver) {
+ // private_data already belongs to the driver, not a TempConnection.
+ SetError(error, "Connection is already initialized");
+ return ADBC_STATUS_INVALID_STATE;
+ }
+ if (!database || !database->private_driver) {
+ // Checked before the TempConnection is consumed, so a failed Init
+ // leaves the connection releasable via AdbcConnectionRelease.
+ SetError(error, "Database is not initialized");
+ return ADBC_STATUS_INVALID_STATE;
+ }
+
+ // Take the options recorded before Init and free the temporary state.
+ // Null the pointer right away: otherwise a failing ConnectionNew could
+ // leave a dangling TempConnection* that AdbcConnectionRelease would
+ // delete a second time.
+ TempConnection* args = reinterpret_cast<TempConnection*>(connection->private_data);
+ std::unordered_map<std::string, std::string> options = std::move(args->options);
+ delete args;
+ connection->private_data = nullptr;
+
+ // Record the driver before checking the status so that, even if
+ // ConnectionNew fails partway, AdbcConnectionRelease routes cleanup
+ // to the driver rather than treating private_data as a TempConnection.
+ auto status = database->private_driver->ConnectionNew(connection, error);
+ connection->private_driver = database->private_driver;
+ if (status != ADBC_STATUS_OK) return status;
+
+ // Replay the recorded options on the driver's connection.
+ for (const auto& option : options) {
+ status = connection->private_driver->ConnectionSetOption(
+ connection, option.first.c_str(), option.second.c_str(), error);
+ if (status != ADBC_STATUS_OK) return status;
+ }
+ return connection->private_driver->ConnectionInit(connection, database, error);
}
-AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
- struct AdbcConnection* connection,
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
struct AdbcError* error) {
- if (!database->private_driver) {
- return ADBC_STATUS_INVALID_STATE;
- }
- auto status = database->private_driver->ConnectionNew(database, connection,
error);
- connection->private_driver = database->private_driver;
- return status;
+ // Allocate a temporary structure to store options pre-Init
+ connection->private_data = new TempConnection;
+ connection->private_driver = nullptr;
+ return ADBC_STATUS_OK;
}
AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
struct AdbcError* error) {
if (!connection->private_driver) {
- return ADBC_STATUS_INVALID_STATE;
+ if (connection->private_data) {
+ TempConnection* args =
reinterpret_cast<TempConnection*>(connection->private_data);
+ delete args;
+ connection->private_data = nullptr;
+ }
+ return ADBC_STATUS_OK;
}
auto status = connection->private_driver->ConnectionRelease(connection,
error);
connection->private_driver = nullptr;
diff --git a/adbc_driver_manager/adbc_driver_manager_test.cc
b/adbc_driver_manager/adbc_driver_manager_test.cc
index e130621..5858594 100644
--- a/adbc_driver_manager/adbc_driver_manager_test.cc
+++ b/adbc_driver_manager/adbc_driver_manager_test.cc
@@ -58,8 +58,8 @@ class DriverManager : public ::testing::Test {
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
ASSERT_NE(database.private_data, nullptr);
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection,
&error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection,
&database, &error));
ASSERT_NE(connection.private_data, nullptr);
}
@@ -88,6 +88,24 @@ class DriverManager : public ::testing::Test {
AdbcError error = {};
};
+TEST_F(DriverManager, DatabaseInitRelease) {
+ AdbcError error = {};
+ AdbcDatabase database;
+ std::memset(&database, 0, sizeof(database));
+
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
+}
+
+TEST_F(DriverManager, ConnectionInitRelease) {
+ AdbcError error = {};
+ AdbcConnection connection;
+ std::memset(&connection, 0, sizeof(connection));
+
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionRelease(&connection, &error));
+}
+
TEST_F(DriverManager, SqlExecute) {
std::string query = "SELECT 1";
AdbcStatement statement;
diff --git a/drivers/flight_sql/flight_sql.cc b/drivers/flight_sql/flight_sql.cc
index 50b11d3..f33ac8b 100644
--- a/drivers/flight_sql/flight_sql.cc
+++ b/drivers/flight_sql/flight_sql.cc
@@ -143,8 +143,7 @@ class FlightSqlDatabaseImpl {
class FlightSqlConnectionImpl {
public:
- explicit FlightSqlConnectionImpl(std::shared_ptr<FlightSqlDatabaseImpl>
database)
- : database_(std::move(database)), client_(nullptr) {}
+ FlightSqlConnectionImpl() : database_(nullptr), client_(nullptr) {}
//----------------------------------------------------------
// Common Functions
@@ -152,7 +151,14 @@ class FlightSqlConnectionImpl {
flightsql::FlightSqlClient* client() const { return client_; }
- AdbcStatusCode Init(struct AdbcError* error) {
+ AdbcStatusCode Init(struct AdbcDatabase* database, struct AdbcError* error) {
+ if (!database->private_data) {
+ SetError(error, "database is not initialized");
+ return ADBC_STATUS_INVALID_STATE;
+ }
+
+ database_ = *reinterpret_cast<std::shared_ptr<FlightSqlDatabaseImpl>*>(
+ database->private_data);
client_ = database_->Connect();
if (!client_) {
SetError(error, "Database not yet initialized!");
@@ -398,12 +404,9 @@ AdbcStatusCode FlightSqlConnectionGetTableTypes(struct
AdbcConnection* connectio
return (*ptr)->GetTableTypes(error);
}
-AdbcStatusCode FlightSqlConnectionNew(struct AdbcDatabase* database,
- struct AdbcConnection* connection,
+AdbcStatusCode FlightSqlConnectionNew(struct AdbcConnection* connection,
struct AdbcError* error) {
- auto ptr =
-
reinterpret_cast<std::shared_ptr<FlightSqlDatabaseImpl>*>(database->private_data);
- auto impl = std::make_shared<FlightSqlConnectionImpl>(*ptr);
+ auto impl = std::make_shared<FlightSqlConnectionImpl>();
connection->private_data = new
std::shared_ptr<FlightSqlConnectionImpl>(impl);
return ADBC_STATUS_OK;
}
@@ -415,11 +418,12 @@ AdbcStatusCode FlightSqlConnectionSetOption(struct
AdbcConnection* connection,
}
AdbcStatusCode FlightSqlConnectionInit(struct AdbcConnection* connection,
+ struct AdbcDatabase* database,
struct AdbcError* error) {
if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
auto ptr = reinterpret_cast<std::shared_ptr<FlightSqlConnectionImpl>*>(
connection->private_data);
- return (*ptr)->Init(error);
+ return (*ptr)->Init(database, error);
}
AdbcStatusCode FlightSqlConnectionRelease(struct AdbcConnection* connection,
@@ -533,14 +537,14 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct
AdbcConnection* connection,
}
AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+ struct AdbcDatabase* database,
struct AdbcError* error) {
- return FlightSqlConnectionInit(connection, error);
+ return FlightSqlConnectionInit(connection, database, error);
}
-AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
- struct AdbcConnection* connection,
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
struct AdbcError* error) {
- return FlightSqlConnectionNew(database, connection, error);
+ return FlightSqlConnectionNew(connection, error);
}
AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection,
const char* key,
diff --git a/drivers/flight_sql/flight_sql_test.cc
b/drivers/flight_sql/flight_sql_test.cc
index ecb5a30..815078a 100644
--- a/drivers/flight_sql/flight_sql_test.cc
+++ b/drivers/flight_sql/flight_sql_test.cc
@@ -41,8 +41,9 @@ class AdbcFlightSqlTest : public ::testing::Test {
ADBC_ASSERT_OK_WITH_ERROR(
error, AdbcDatabaseSetOption(&database, "location", location,
&error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database,
&connection, &error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection,
&error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error,
+ AdbcConnectionInit(&connection, &database,
&error));
} else {
FAIL() << "Must provide location of Flight SQL server at " <<
kServerEnvVar;
}
diff --git a/drivers/sqlite/sqlite.cc b/drivers/sqlite/sqlite.cc
index 10abb8c..493675a 100644
--- a/drivers/sqlite/sqlite.cc
+++ b/drivers/sqlite/sqlite.cc
@@ -236,8 +236,7 @@ class SqliteDatabaseImpl {
class SqliteConnectionImpl {
public:
- explicit SqliteConnectionImpl(std::shared_ptr<SqliteDatabaseImpl> database)
- : database_(std::move(database)), db_(nullptr), autocommit_(true) {}
+ SqliteConnectionImpl() : database_(nullptr), db_(nullptr), autocommit_(true)
{}
sqlite3* db() const { return db_; }
@@ -274,7 +273,15 @@ class SqliteConnectionImpl {
return FromArrowStatus(arrow::ExportSchema(*arrow_schema, schema), error);
}
- AdbcStatusCode Init(struct AdbcError* error) { return
database_->Connect(&db_, error); }
+ AdbcStatusCode Init(struct AdbcDatabase* database, struct AdbcError* error) {
+ if (!database->private_data) {
+ SetError(error, "database is not initialized");
+ return ADBC_STATUS_INVALID_STATE;
+ }
+ database_ =
+
*reinterpret_cast<std::shared_ptr<SqliteDatabaseImpl>*>(database->private_data);
+ return database_->Connect(&db_, error);
+ }
AdbcStatusCode Release(struct AdbcError* error) {
return database_->Disconnect(db_, error);
@@ -1228,19 +1235,17 @@ AdbcStatusCode SqliteConnectionGetTableTypes(struct
AdbcConnection* connection,
}
AdbcStatusCode SqliteConnectionInit(struct AdbcConnection* connection,
+ struct AdbcDatabase* database,
struct AdbcError* error) {
if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
auto ptr =
reinterpret_cast<std::shared_ptr<SqliteConnectionImpl>*>(connection->private_data);
- return (*ptr)->Init(error);
+ return (*ptr)->Init(database, error);
}
-AdbcStatusCode SqliteConnectionNew(struct AdbcDatabase* database,
- struct AdbcConnection* connection,
+AdbcStatusCode SqliteConnectionNew(struct AdbcConnection* connection,
struct AdbcError* error) {
- auto ptr =
-
reinterpret_cast<std::shared_ptr<SqliteDatabaseImpl>*>(database->private_data);
- auto impl = std::make_shared<SqliteConnectionImpl>(*ptr);
+ auto impl = std::make_shared<SqliteConnectionImpl>();
connection->private_data = new std::shared_ptr<SqliteConnectionImpl>(impl);
return ADBC_STATUS_OK;
}
@@ -1430,14 +1435,14 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct
AdbcConnection* connection,
}
AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+ struct AdbcDatabase* database,
struct AdbcError* error) {
- return SqliteConnectionInit(connection, error);
+ return SqliteConnectionInit(connection, database, error);
}
-AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
- struct AdbcConnection* connection,
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
struct AdbcError* error) {
- return SqliteConnectionNew(database, connection, error);
+ return SqliteConnectionNew(connection, error);
}
AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
diff --git a/drivers/sqlite/sqlite_test.cc b/drivers/sqlite/sqlite_test.cc
index 6ffcdc2..3f40587 100644
--- a/drivers/sqlite/sqlite_test.cc
+++ b/drivers/sqlite/sqlite_test.cc
@@ -56,8 +56,8 @@ class Sqlite : public ::testing::Test {
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
ASSERT_NE(database.private_data, nullptr);
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection,
&error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection,
&database, &error));
ASSERT_NE(connection.private_data, nullptr);
}
@@ -328,8 +328,8 @@ TEST_F(Sqlite, BulkIngestStream) {
TEST_F(Sqlite, MultipleConnections) {
struct AdbcConnection connection2;
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection2,
&error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection2, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection2, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection2, &database,
&error));
ASSERT_NE(connection.private_data, nullptr);
{
@@ -661,12 +661,12 @@ TEST_F(Sqlite, Transactions) {
AdbcDatabaseSetOption(&database, "filename",
"file:Sqlite_Transactions?mode=memory&cache=shared", &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection,
&error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &database,
&error));
struct AdbcConnection connection2;
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection2,
&error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection2, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection2, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection2, &database,
&error));
ASSERT_NE(connection.private_data, nullptr);
AdbcStatement statement;
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 614f701..410e9af 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -79,9 +79,9 @@ cdef extern from "adbc.h":
AdbcStatusCode AdbcDatabaseInit(CAdbcDatabase* database, CAdbcError* error)
AdbcStatusCode AdbcDatabaseRelease(CAdbcDatabase* database, CAdbcError*
error)
- AdbcStatusCode AdbcConnectionNew(CAdbcDatabase* database, CAdbcConnection*
connection, CAdbcError* error)
+ AdbcStatusCode AdbcConnectionNew(CAdbcConnection* connection, CAdbcError*
error)
AdbcStatusCode AdbcConnectionSetOption(CAdbcConnection* connection, const
char* key, const char* value, CAdbcError* error)
- AdbcStatusCode AdbcConnectionInit(CAdbcConnection* connection, CAdbcError*
error)
+ AdbcStatusCode AdbcConnectionInit(CAdbcConnection* connection,
CAdbcDatabase* database, CAdbcError* error)
AdbcStatusCode AdbcConnectionRelease(CAdbcConnection* connection,
CAdbcError* error)
AdbcStatusCode AdbcStatementBind(CAdbcStatement* statement, CArrowArray*,
CArrowSchema*, CAdbcError* error)
@@ -247,7 +247,7 @@ cdef class AdbcConnection(_AdbcHandle):
cdef const char* c_value
memset(&self.connection, 0, cython.sizeof(CAdbcConnection))
- status = AdbcConnectionNew(&database.database, &self.connection,
&c_error)
+ status = AdbcConnectionNew(&self.connection, &c_error)
check_error(status, &c_error)
for key, value in kwargs.items():
@@ -258,7 +258,7 @@ cdef class AdbcConnection(_AdbcHandle):
status = AdbcConnectionSetOption(&self.connection, c_key, c_value,
&c_error)
check_error(status, &c_error)
- status = AdbcConnectionInit(&self.connection, &c_error)
+ status = AdbcConnectionInit(&self.connection, &database.database,
&c_error)
check_error(status, &c_error)
def close(self) -> None: