This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 67b39d8  Make AdbcConnectionNew 2-adic for consistency (#27)
67b39d8 is described below

commit 67b39d85bc22a13585a165bb674c39a141c07c6d
Author: David Li <[email protected]>
AuthorDate: Tue Jul 5 12:54:33 2022 -0400

    Make AdbcConnectionNew 2-adic for consistency (#27)
    
    * Make AdbcConnectionNew 2-adic for consistency
    
    * Fix Python build
---
 adbc.h                                             | 11 +++--
 adbc_driver_manager/adbc_driver_manager.cc         | 52 +++++++++++++++++-----
 adbc_driver_manager/adbc_driver_manager_test.cc    | 22 ++++++++-
 drivers/flight_sql/flight_sql.cc                   | 30 +++++++------
 drivers/flight_sql/flight_sql_test.cc              |  5 ++-
 drivers/sqlite/sqlite.cc                           | 31 +++++++------
 drivers/sqlite/sqlite_test.cc                      | 16 +++----
 .../adbc_driver_manager/_lib.pyx                   |  8 ++--
 8 files changed, 115 insertions(+), 60 deletions(-)

diff --git a/adbc.h b/adbc.h
index f21415f..3c2445b 100644
--- a/adbc.h
+++ b/adbc.h
@@ -322,8 +322,7 @@ struct ADBC_EXPORT AdbcConnection {
 
 /// \brief Allocate a new (but uninitialized) connection.
 ADBC_EXPORT
-AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
-                                 struct AdbcConnection* connection,
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
                                  struct AdbcError* error);
 
 ADBC_EXPORT
@@ -333,7 +332,7 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const
 /// \brief Finish setting options and initialize the connection.
 ADBC_EXPORT
 AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
-                                  struct AdbcError* error);
+                                  struct AdbcDatabase* database, struct AdbcError* error);
 
 /// \brief Destroy this connection.
 /// \param[in] connection The connection to release.
@@ -813,11 +812,11 @@ struct ADBC_EXPORT AdbcDriver {
   AdbcStatusCode (*DatabaseInit)(struct AdbcDatabase*, struct AdbcError*);
   AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*);
 
-  AdbcStatusCode (*ConnectionNew)(struct AdbcDatabase*, struct AdbcConnection*,
-                                  struct AdbcError*);
+  AdbcStatusCode (*ConnectionNew)(struct AdbcConnection*, struct AdbcError*);
   AdbcStatusCode (*ConnectionSetOption)(struct AdbcConnection*, const char*, const char*,
                                         struct AdbcError*);
-  AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcError*);
+  AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcDatabase*,
+                                   struct AdbcError*);
   AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection*, struct AdbcError*);
 
   AdbcStatusCode (*ConnectionDeserializePartitionDesc)(struct AdbcConnection*,
diff --git a/adbc_driver_manager/adbc_driver_manager.cc b/adbc_driver_manager/adbc_driver_manager.cc
index 68e7223..96f0bf2 100644
--- a/adbc_driver_manager/adbc_driver_manager.cc
+++ b/adbc_driver_manager/adbc_driver_manager.cc
@@ -21,6 +21,7 @@
 #include <cstring>
 #include <string>
 #include <unordered_map>
+#include <utility>
 
 #if defined(_WIN32)
 #include <windows.h>  // Must come first
@@ -127,6 +128,11 @@ struct TempDatabase {
   std::string entrypoint;
 };
 
+/// Temporary state while the connection is being configured.
+struct TempConnection {
+  std::unordered_map<std::string, std::string> options;
+};
+
 #if defined(_WIN32)
 /// Append a description of the Windows error to the buffer.
 void GetWinError(std::string* buffer) {
@@ -278,7 +284,12 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
 AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,
                                    struct AdbcError* error) {
   if (!database->private_driver) {
-    return ADBC_STATUS_INVALID_STATE;
+    if (database->private_data) {
+      TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
+      delete args;
+      database->private_data = nullptr;
+    }
+    return ADBC_STATUS_OK;
   }
   auto status = database->private_driver->DatabaseRelease(database, error);
   if (database->private_driver->release) {
@@ -297,28 +308,45 @@ AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection,
 }
 
 AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+                                  struct AdbcDatabase* database,
                                   struct AdbcError* error) {
-  if (!connection->private_driver) {
+  if (!connection->private_data) {
+    SetError(error, "Must call AdbcConnectionNew first");
     return ADBC_STATUS_INVALID_STATE;
   }
-  return connection->private_driver->ConnectionInit(connection, error);
+  TempConnection* args = reinterpret_cast<TempConnection*>(connection->private_data);
+  std::unordered_map<std::string, std::string> options = std::move(args->options);
+  delete args;
+
+  auto status = database->private_driver->ConnectionNew(connection, error);
+  if (status != ADBC_STATUS_OK) return status;
+  connection->private_driver = database->private_driver;
+
+  for (const auto& option : options) {
+    status = database->private_driver->ConnectionSetOption(
+        connection, option.first.c_str(), option.second.c_str(), error);
+    if (status != ADBC_STATUS_OK) return status;
+  }
+  return connection->private_driver->ConnectionInit(connection, database, error);
 }
 
-AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
-                                 struct AdbcConnection* connection,
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
                                  struct AdbcError* error) {
-  if (!database->private_driver) {
-    return ADBC_STATUS_INVALID_STATE;
-  }
-  auto status = database->private_driver->ConnectionNew(database, connection, error);
-  connection->private_driver = database->private_driver;
-  return status;
+  // Allocate a temporary structure to store options pre-Init
+  connection->private_data = new TempConnection;
+  connection->private_driver = nullptr;
+  return ADBC_STATUS_OK;
 }
 
 AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
                                      struct AdbcError* error) {
   if (!connection->private_driver) {
-    return ADBC_STATUS_INVALID_STATE;
+    if (connection->private_data) {
+      TempConnection* args = reinterpret_cast<TempConnection*>(connection->private_data);
+      delete args;
+      connection->private_data = nullptr;
+    }
+    return ADBC_STATUS_OK;
   }
   auto status = connection->private_driver->ConnectionRelease(connection, error);
   connection->private_driver = nullptr;
diff --git a/adbc_driver_manager/adbc_driver_manager_test.cc b/adbc_driver_manager/adbc_driver_manager_test.cc
index e130621..5858594 100644
--- a/adbc_driver_manager/adbc_driver_manager_test.cc
+++ b/adbc_driver_manager/adbc_driver_manager_test.cc
@@ -58,8 +58,8 @@ class DriverManager : public ::testing::Test {
     ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
     ASSERT_NE(database.private_data, nullptr);
 
-    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection, &error));
-    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &database, &error));
     ASSERT_NE(connection.private_data, nullptr);
   }
 
@@ -88,6 +88,24 @@ class DriverManager : public ::testing::Test {
   AdbcError error = {};
 };
 
+TEST_F(DriverManager, DatabaseInitRelease) {
+  AdbcError error = {};
+  AdbcDatabase database;
+  std::memset(&database, 0, sizeof(database));
+
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
+}
+
+TEST_F(DriverManager, ConnectionInitRelease) {
+  AdbcError error = {};
+  AdbcConnection connection;
+  std::memset(&connection, 0, sizeof(connection));
+
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionRelease(&connection, &error));
+}
+
 TEST_F(DriverManager, SqlExecute) {
   std::string query = "SELECT 1";
   AdbcStatement statement;
diff --git a/drivers/flight_sql/flight_sql.cc b/drivers/flight_sql/flight_sql.cc
index 50b11d3..f33ac8b 100644
--- a/drivers/flight_sql/flight_sql.cc
+++ b/drivers/flight_sql/flight_sql.cc
@@ -143,8 +143,7 @@ class FlightSqlDatabaseImpl {
 
 class FlightSqlConnectionImpl {
  public:
-  explicit FlightSqlConnectionImpl(std::shared_ptr<FlightSqlDatabaseImpl> database)
-      : database_(std::move(database)), client_(nullptr) {}
+  FlightSqlConnectionImpl() : database_(nullptr), client_(nullptr) {}
 
   //----------------------------------------------------------
   // Common Functions
@@ -152,7 +151,14 @@ class FlightSqlConnectionImpl {
 
   flightsql::FlightSqlClient* client() const { return client_; }
 
-  AdbcStatusCode Init(struct AdbcError* error) {
+  AdbcStatusCode Init(struct AdbcDatabase* database, struct AdbcError* error) {
+    if (!database->private_data) {
+      SetError(error, "database is not initialized");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+
+    database_ = *reinterpret_cast<std::shared_ptr<FlightSqlDatabaseImpl>*>(
+        database->private_data);
     client_ = database_->Connect();
     if (!client_) {
       SetError(error, "Database not yet initialized!");
@@ -398,12 +404,9 @@ AdbcStatusCode FlightSqlConnectionGetTableTypes(struct AdbcConnection* connectio
   return (*ptr)->GetTableTypes(error);
 }
 
-AdbcStatusCode FlightSqlConnectionNew(struct AdbcDatabase* database,
-                                      struct AdbcConnection* connection,
+AdbcStatusCode FlightSqlConnectionNew(struct AdbcConnection* connection,
                                       struct AdbcError* error) {
-  auto ptr =
-      reinterpret_cast<std::shared_ptr<FlightSqlDatabaseImpl>*>(database->private_data);
-  auto impl = std::make_shared<FlightSqlConnectionImpl>(*ptr);
+  auto impl = std::make_shared<FlightSqlConnectionImpl>();
   connection->private_data = new std::shared_ptr<FlightSqlConnectionImpl>(impl);
   return ADBC_STATUS_OK;
 }
@@ -415,11 +418,12 @@ AdbcStatusCode FlightSqlConnectionSetOption(struct AdbcConnection* connection,
 }
 
 AdbcStatusCode FlightSqlConnectionInit(struct AdbcConnection* connection,
+                                       struct AdbcDatabase* database,
                                        struct AdbcError* error) {
   if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
   auto ptr = reinterpret_cast<std::shared_ptr<FlightSqlConnectionImpl>*>(
       connection->private_data);
-  return (*ptr)->Init(error);
+  return (*ptr)->Init(database, error);
 }
 
 AdbcStatusCode FlightSqlConnectionRelease(struct AdbcConnection* connection,
@@ -533,14 +537,14 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
 }
 
 AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+                                  struct AdbcDatabase* database,
                                   struct AdbcError* error) {
-  return FlightSqlConnectionInit(connection, error);
+  return FlightSqlConnectionInit(connection, database, error);
 }
 
-AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
-                                 struct AdbcConnection* connection,
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
                                  struct AdbcError* error) {
-  return FlightSqlConnectionNew(database, connection, error);
+  return FlightSqlConnectionNew(connection, error);
 }
 
 AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key,
diff --git a/drivers/flight_sql/flight_sql_test.cc b/drivers/flight_sql/flight_sql_test.cc
index ecb5a30..815078a 100644
--- a/drivers/flight_sql/flight_sql_test.cc
+++ b/drivers/flight_sql/flight_sql_test.cc
@@ -41,8 +41,9 @@ class AdbcFlightSqlTest : public ::testing::Test {
       ADBC_ASSERT_OK_WITH_ERROR(
           error, AdbcDatabaseSetOption(&database, "location", location, &error));
       ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
-      ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection, &error));
-      ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &error));
+      ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection, &error));
+      ADBC_ASSERT_OK_WITH_ERROR(error,
+                                AdbcConnectionInit(&connection, &database, &error));
     } else {
       FAIL() << "Must provide location of Flight SQL server at " << kServerEnvVar;
     }
diff --git a/drivers/sqlite/sqlite.cc b/drivers/sqlite/sqlite.cc
index 10abb8c..493675a 100644
--- a/drivers/sqlite/sqlite.cc
+++ b/drivers/sqlite/sqlite.cc
@@ -236,8 +236,7 @@ class SqliteDatabaseImpl {
 
 class SqliteConnectionImpl {
  public:
-  explicit SqliteConnectionImpl(std::shared_ptr<SqliteDatabaseImpl> database)
-      : database_(std::move(database)), db_(nullptr), autocommit_(true) {}
+  SqliteConnectionImpl() : database_(nullptr), db_(nullptr), autocommit_(true) {}
 
   sqlite3* db() const { return db_; }
 
@@ -274,7 +273,15 @@ class SqliteConnectionImpl {
     return FromArrowStatus(arrow::ExportSchema(*arrow_schema, schema), error);
   }
 
-  AdbcStatusCode Init(struct AdbcError* error) { return database_->Connect(&db_, error); }
+  AdbcStatusCode Init(struct AdbcDatabase* database, struct AdbcError* error) {
+    if (!database->private_data) {
+      SetError(error, "database is not initialized");
+      return ADBC_STATUS_INVALID_STATE;
+    }
+    database_ =
+        *reinterpret_cast<std::shared_ptr<SqliteDatabaseImpl>*>(database->private_data);
+    return database_->Connect(&db_, error);
+  }
 
   AdbcStatusCode Release(struct AdbcError* error) {
     return database_->Disconnect(db_, error);
@@ -1228,19 +1235,17 @@ AdbcStatusCode SqliteConnectionGetTableTypes(struct AdbcConnection* connection,
 }
 
 AdbcStatusCode SqliteConnectionInit(struct AdbcConnection* connection,
+                                    struct AdbcDatabase* database,
                                     struct AdbcError* error) {
   if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
   auto ptr =
       reinterpret_cast<std::shared_ptr<SqliteConnectionImpl>*>(connection->private_data);
-  return (*ptr)->Init(error);
+  return (*ptr)->Init(database, error);
 }
 
-AdbcStatusCode SqliteConnectionNew(struct AdbcDatabase* database,
-                                   struct AdbcConnection* connection,
+AdbcStatusCode SqliteConnectionNew(struct AdbcConnection* connection,
                                    struct AdbcError* error) {
-  auto ptr =
-      reinterpret_cast<std::shared_ptr<SqliteDatabaseImpl>*>(database->private_data);
-  auto impl = std::make_shared<SqliteConnectionImpl>(*ptr);
+  auto impl = std::make_shared<SqliteConnectionImpl>();
   connection->private_data = new std::shared_ptr<SqliteConnectionImpl>(impl);
   return ADBC_STATUS_OK;
 }
@@ -1430,14 +1435,14 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
 }
 
 AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+                                  struct AdbcDatabase* database,
                                   struct AdbcError* error) {
-  return SqliteConnectionInit(connection, error);
+  return SqliteConnectionInit(connection, database, error);
 }
 
-AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
-                                 struct AdbcConnection* connection,
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
                                  struct AdbcError* error) {
-  return SqliteConnectionNew(database, connection, error);
+  return SqliteConnectionNew(connection, error);
 }
 
 AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
diff --git a/drivers/sqlite/sqlite_test.cc b/drivers/sqlite/sqlite_test.cc
index 6ffcdc2..3f40587 100644
--- a/drivers/sqlite/sqlite_test.cc
+++ b/drivers/sqlite/sqlite_test.cc
@@ -56,8 +56,8 @@ class Sqlite : public ::testing::Test {
     ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
     ASSERT_NE(database.private_data, nullptr);
 
-    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection, &error));
-    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection, &error));
+    ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &database, &error));
     ASSERT_NE(connection.private_data, nullptr);
   }
 
@@ -328,8 +328,8 @@ TEST_F(Sqlite, BulkIngestStream) {
 
 TEST_F(Sqlite, MultipleConnections) {
   struct AdbcConnection connection2;
-  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection2, &error));
-  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection2, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection2, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection2, &database, &error));
   ASSERT_NE(connection.private_data, nullptr);
 
   {
@@ -661,12 +661,12 @@ TEST_F(Sqlite, Transactions) {
       AdbcDatabaseSetOption(&database, "filename",
                             "file:Sqlite_Transactions?mode=memory&cache=shared", &error));
   ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
-  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection, &error));
-  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &database, &error));
 
   struct AdbcConnection connection2;
-  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection2, &error));
-  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection2, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&connection2, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection2, &database, &error));
   ASSERT_NE(connection.private_data, nullptr);
 
   AdbcStatement statement;
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 614f701..410e9af 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -79,9 +79,9 @@ cdef extern from "adbc.h":
     AdbcStatusCode AdbcDatabaseInit(CAdbcDatabase* database, CAdbcError* error)
     AdbcStatusCode AdbcDatabaseRelease(CAdbcDatabase* database, CAdbcError* error)
 
-    AdbcStatusCode AdbcConnectionNew(CAdbcDatabase* database, CAdbcConnection* connection, CAdbcError* error)
+    AdbcStatusCode AdbcConnectionNew(CAdbcConnection* connection, CAdbcError* error)
     AdbcStatusCode AdbcConnectionSetOption(CAdbcConnection* connection, const char* key, const char* value, CAdbcError* error)
-    AdbcStatusCode AdbcConnectionInit(CAdbcConnection* connection, CAdbcError* error)
+    AdbcStatusCode AdbcConnectionInit(CAdbcConnection* connection, CAdbcDatabase* database, CAdbcError* error)
     AdbcStatusCode AdbcConnectionRelease(CAdbcConnection* connection, CAdbcError* error)
 
     AdbcStatusCode AdbcStatementBind(CAdbcStatement* statement, CArrowArray*, CArrowSchema*, CAdbcError* error)
@@ -247,7 +247,7 @@ cdef class AdbcConnection(_AdbcHandle):
         cdef const char* c_value
         memset(&self.connection, 0, cython.sizeof(CAdbcConnection))
 
-        status = AdbcConnectionNew(&database.database, &self.connection, &c_error)
+        status = AdbcConnectionNew(&self.connection, &c_error)
         check_error(status, &c_error)
 
         for key, value in kwargs.items():
@@ -258,7 +258,7 @@ cdef class AdbcConnection(_AdbcHandle):
             status = AdbcConnectionSetOption(&self.connection, c_key, c_value, &c_error)
             check_error(status, &c_error)
 
-        status = AdbcConnectionInit(&self.connection, &c_error)
+        status = AdbcConnectionInit(&self.connection, &database.database, &c_error)
         check_error(status, &c_error)
 
     def close(self) -> None:

Reply via email to