This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new cd24bc096 fix(rust): implement database/connection constructors without options (#2242)
cd24bc096 is described below

commit cd24bc096612feeef31bff2a9271616cc94e1634
Author: Matthijs Brobbel <[email protected]>
AuthorDate: Sat Oct 12 03:39:17 2024 +0200

    fix(rust): implement database/connection constructors without options (#2242)
    
    Fixes #2241.
---
 rust/core/src/driver_manager.rs | 139 ++++++++++++++++++++++++++++++----------
 rust/drivers/dummy/src/lib.rs   |  19 +++---
 2 files changed, 113 insertions(+), 45 deletions(-)

diff --git a/rust/core/src/driver_manager.rs b/rust/core/src/driver_manager.rs
index ca0d53a63..c486aaaab 100644
--- a/rust/core/src/driver_manager.rs
+++ b/rust/core/src/driver_manager.rs
@@ -240,19 +240,11 @@ impl ManagedDriver {
         check_status(status, error)?;
         Ok(driver)
     }
-}
-
-impl Driver for ManagedDriver {
-    type DatabaseType = ManagedDatabase;
 
-    fn new_database(&mut self) -> Result<Self::DatabaseType> {
-        self.new_database_with_opts(None)
-    }
-
-    fn new_database_with_opts(
-        &mut self,
-        opts: impl IntoIterator<Item = (<Self::DatabaseType as 
Optionable>::Option, OptionValue)>,
-    ) -> Result<Self::DatabaseType> {
+    /// Returns a new database using the loaded driver.
+    ///
+    /// This uses `&mut self` to prevent a deadlock.
+    fn database_new(&mut self) -> Result<ffi::FFI_AdbcDatabase> {
         let driver = &self.inner.driver.lock().unwrap();
         let mut database = ffi::FFI_AdbcDatabase::default();
 
@@ -262,10 +254,17 @@ impl Driver for ManagedDriver {
         let status = unsafe { method(&mut database, &mut error) };
         check_status(status, error)?;
 
-        // DatabaseSetOption
-        for (key, value) in opts {
-            set_option_database(driver, &mut database, self.inner.version, 
key, value)?;
-        }
+        Ok(database)
+    }
+
+    /// Initialize the given database using the loaded driver.
+    ///
+    /// This uses `&mut self` to prevent a deadlock.
+    fn database_init(
+        &mut self,
+        mut database: ffi::FFI_AdbcDatabase,
+    ) -> Result<ffi::FFI_AdbcDatabase> {
+        let driver = &self.inner.driver.lock().unwrap();
 
         // DatabaseInit
         let mut error = ffi::FFI_AdbcError::with_driver(driver);
@@ -273,6 +272,40 @@ impl Driver for ManagedDriver {
         let status = unsafe { method(&mut database, &mut error) };
         check_status(status, error)?;
 
+        Ok(database)
+    }
+}
+
+impl Driver for ManagedDriver {
+    type DatabaseType = ManagedDatabase;
+
+    fn new_database(&mut self) -> Result<Self::DatabaseType> {
+        // Construct a new database.
+        let database = self.database_new()?;
+        // Initialize the database.
+        let database = self.database_init(database)?;
+        let inner = Arc::new(ManagedDatabaseInner {
+            database: Mutex::new(database),
+            driver: self.inner.clone(),
+        });
+        Ok(Self::DatabaseType { inner })
+    }
+
+    fn new_database_with_opts(
+        &mut self,
+        opts: impl IntoIterator<Item = (<Self::DatabaseType as 
Optionable>::Option, OptionValue)>,
+    ) -> Result<Self::DatabaseType> {
+        // Construct a new database.
+        let mut database = self.database_new()?;
+        // Set the options.
+        {
+            let driver = &self.inner.driver.lock().unwrap();
+            for (key, value) in opts {
+                set_option_database(driver, &mut database, self.inner.version, 
key, value)?;
+            }
+        }
+        // Initialize the database.
+        let database = self.database_init(database)?;
         let inner = Arc::new(ManagedDatabaseInner {
             database: Mutex::new(database),
             driver: self.inner.clone(),
@@ -425,6 +458,41 @@ impl ManagedDatabase {
     fn driver_version(&self) -> AdbcVersion {
         self.inner.driver.version
     }
+
+    /// Returns a new connection using the loaded driver.
+    ///
+    /// This uses `&mut self` to prevent a deadlock.
+    fn connection_new(&mut self) -> Result<ffi::FFI_AdbcConnection> {
+        let driver = &self.inner.driver.driver.lock().unwrap();
+        let mut connection = ffi::FFI_AdbcConnection::default();
+
+        // ConnectionNew
+        let mut error = ffi::FFI_AdbcError::with_driver(driver);
+        let method = driver_method!(driver, ConnectionNew);
+        let status = unsafe { method(&mut connection, &mut error) };
+        check_status(status, error)?;
+
+        Ok(connection)
+    }
+
+    /// Initialize the given connection using the loaded driver.
+    ///
+    /// This uses `&mut self` to prevent a deadlock.
+    fn connection_init(
+        &mut self,
+        mut connection: ffi::FFI_AdbcConnection,
+    ) -> Result<ffi::FFI_AdbcConnection> {
+        let driver = &self.inner.driver.driver.lock().unwrap();
+        let mut database = self.inner.database.lock().unwrap();
+
+        // ConnectionInit
+        let mut error = ffi::FFI_AdbcError::with_driver(driver);
+        let method = driver_method!(driver, ConnectionInit);
+        let status = unsafe { method(&mut connection, &mut *database, &mut 
error) };
+        check_status(status, error)?;
+
+        Ok(connection)
+    }
 }
 
 impl Optionable for ManagedDatabase {
@@ -497,35 +565,38 @@ impl Database for ManagedDatabase {
     type ConnectionType = ManagedConnection;
 
     fn new_connection(&mut self) -> Result<Self::ConnectionType> {
-        self.new_connection_with_opts(None)
+        // Construct a new connection.
+        let connection = self.connection_new()?;
+        // Initialize the connection.
+        let connection = self.connection_init(connection)?;
+        let inner = ManagedConnectionInner {
+            connection: Mutex::new(connection),
+            database: self.inner.clone(),
+        };
+        Ok(Self::ConnectionType {
+            inner: Arc::new(inner),
+        })
     }
 
     fn new_connection_with_opts(
         &mut self,
         opts: impl IntoIterator<Item = (<Self::ConnectionType as 
Optionable>::Option, OptionValue)>,
     ) -> Result<Self::ConnectionType> {
-        let driver = &self.inner.driver.driver.lock().unwrap();
-        let mut database = self.inner.database.lock().unwrap();
-        let mut connection = ffi::FFI_AdbcConnection::default();
-        let mut error = ffi::FFI_AdbcError::with_driver(driver);
-        let method = driver_method!(driver, ConnectionNew);
-        let status = unsafe { method(&mut connection, &mut error) };
-        check_status(status, error)?;
-
-        for (key, value) in opts {
-            set_option_connection(driver, &mut connection, 
self.driver_version(), key, value)?;
+        // Construct a new connection.
+        let mut connection = self.connection_new()?;
+        // Set the options.
+        {
+            let driver = &self.inner.driver.driver.lock().unwrap();
+            for (key, value) in opts {
+                set_option_connection(driver, &mut connection, 
self.driver_version(), key, value)?;
+            }
         }
-
-        let mut error = ffi::FFI_AdbcError::with_driver(driver);
-        let method = driver_method!(driver, ConnectionInit);
-        let status = unsafe { method(&mut connection, database.deref_mut(), 
&mut error) };
-        check_status(status, error)?;
-
+        // Initialize the connection.
+        let connection = self.connection_init(connection)?;
         let inner = ManagedConnectionInner {
             connection: Mutex::new(connection),
             database: self.inner.clone(),
         };
-
         Ok(Self::ConnectionType {
             inner: Arc::new(inner),
         })
diff --git a/rust/drivers/dummy/src/lib.rs b/rust/drivers/dummy/src/lib.rs
index 4ba348c3a..a841fe0e3 100644
--- a/rust/drivers/dummy/src/lib.rs
+++ b/rust/drivers/dummy/src/lib.rs
@@ -183,16 +183,14 @@ impl Driver for DummyDriver {
     type DatabaseType = DummyDatabase;
 
     fn new_database(&mut self) -> Result<Self::DatabaseType> {
-        self.new_database_with_opts(None)
+        Ok(Self::DatabaseType::default())
     }
 
     fn new_database_with_opts(
         &mut self,
         opts: impl IntoIterator<Item = (<Self::DatabaseType as 
Optionable>::Option, OptionValue)>,
     ) -> Result<Self::DatabaseType> {
-        let mut database = Self::DatabaseType {
-            options: HashMap::new(),
-        };
+        let mut database = Self::DatabaseType::default();
         for (key, value) in opts {
             database.set_option(key, value)?;
         }
@@ -200,6 +198,7 @@ impl Driver for DummyDriver {
     }
 }
 
+#[derive(Default)]
 pub struct DummyDatabase {
     options: HashMap<OptionDatabase, OptionValue>,
 }
@@ -232,16 +231,14 @@ impl Database for DummyDatabase {
     type ConnectionType = DummyConnection;
 
     fn new_connection(&mut self) -> Result<Self::ConnectionType> {
-        self.new_connection_with_opts(None)
+        Ok(Self::ConnectionType::default())
     }
 
     fn new_connection_with_opts(
         &mut self,
         opts: impl IntoIterator<Item = (<Self::ConnectionType as 
Optionable>::Option, OptionValue)>,
     ) -> Result<Self::ConnectionType> {
-        let mut connection = Self::ConnectionType {
-            options: HashMap::new(),
-        };
+        let mut connection = Self::ConnectionType::default();
         for (key, value) in opts {
             connection.set_option(key, value)?;
         }
@@ -249,6 +246,7 @@ impl Database for DummyDatabase {
     }
 }
 
+#[derive(Default)]
 pub struct DummyConnection {
     options: HashMap<OptionConnection, OptionValue>,
 }
@@ -281,9 +279,7 @@ impl Connection for DummyConnection {
     type StatementType = DummyStatement;
 
     fn new_statement(&mut self) -> Result<Self::StatementType> {
-        Ok(Self::StatementType {
-            options: HashMap::new(),
-        })
+        Ok(Self::StatementType::default())
     }
 
     // This method is used to test that errors round-trip correctly.
@@ -798,6 +794,7 @@ impl Connection for DummyConnection {
     }
 }
 
+#[derive(Default)]
 pub struct DummyStatement {
     options: HashMap<OptionStatement, OptionValue>,
 }

Reply via email to