alexandreyc commented on code in PR #1756: URL: https://github.com/apache/arrow-adbc/pull/1756#discussion_r1584769814
########## rust/core/src/driver_exporter.rs: ########## @@ -0,0 +1,1624 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::{HashMap, HashSet}; +use std::ffi::{CStr, CString}; +use std::hash::Hash; +use std::os::raw::{c_char, c_int, c_void}; + +use arrow::array::StructArray; +use arrow::datatypes::DataType; +use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; + +use crate::error::{Error, Result, Status}; +use crate::ffi::constants::ADBC_STATUS_OK; +use crate::ffi::{ + types::ErrorPrivateData, FFI_AdbcConnection, FFI_AdbcDatabase, FFI_AdbcDriver, FFI_AdbcError, + FFI_AdbcErrorDetail, FFI_AdbcPartitions, FFI_AdbcStatement, FFI_AdbcStatusCode, +}; +use crate::options::{InfoCode, ObjectDepth, OptionConnection, OptionDatabase, OptionValue}; +use crate::{Connection, Database, Driver, Optionable, Statement}; + +type DatabaseType<DriverType> = <DriverType as Driver>::DatabaseType; +type ConnectionType<DriverType> = + <<DriverType as Driver>::DatabaseType as Database>::ConnectionType; +type StatementType<DriverType> = + <<<DriverType as Driver>::DatabaseType as Database>::ConnectionType as Connection>::StatementType; + 
+enum ExportedDatabase<DriverType: Driver> { + Options(HashMap<OptionDatabase, OptionValue>), // Pre-init options + Database(DatabaseType<DriverType>), // Initialized database +} + +impl<DriverType: Driver> ExportedDatabase<DriverType> { + fn tuple( + &mut self, + ) -> ( + Option<&mut HashMap<OptionDatabase, OptionValue>>, + Option<&mut DatabaseType<DriverType>>, + ) { + match self { + Self::Options(options) => (Some(options), None), + Self::Database(database) => (None, Some(database)), + } + } +} + +enum ExportedConnection<DriverType: Driver> { + Options(HashMap<OptionConnection, OptionValue>), // Pre-init options + Connection(ConnectionType<DriverType>), // Initialized connection +} + +impl<DriverType: Driver> ExportedConnection<DriverType> { + fn tuple( + &mut self, + ) -> ( + Option<&mut HashMap<OptionConnection, OptionValue>>, + Option<&mut ConnectionType<DriverType>>, + ) { + match self { + Self::Options(options) => (Some(options), None), + Self::Connection(connection) => (None, Some(connection)), + } + } + + fn connection_or_panic(&mut self) -> &mut ConnectionType<DriverType> { + match self { + Self::Connection(connection) => connection, + _ => panic!("Broken invariant"), + } + } +} + +struct ExportedStatement<DriverType: Driver>(StatementType<DriverType>); + +pub trait FFIDriver { + fn ffi_driver() -> FFI_AdbcDriver; +} + +impl<DriverType: Driver + Default + 'static> FFIDriver for DriverType { + fn ffi_driver() -> FFI_AdbcDriver { + FFI_AdbcDriver { + private_data: std::ptr::null_mut(), + private_manager: std::ptr::null(), + release: Some(release_ffi_driver), + DatabaseInit: Some(database_init::<DriverType>), + DatabaseNew: Some(database_new::<DriverType>), + DatabaseSetOption: Some(database_set_option::<DriverType>), + DatabaseRelease: Some(database_release::<DriverType>), + ConnectionCommit: Some(connection_commit::<DriverType>), + ConnectionGetInfo: Some(connection_get_info::<DriverType>), + ConnectionGetObjects: 
Some(connection_get_objects::<DriverType>), + ConnectionGetTableSchema: Some(connection_get_table_schema::<DriverType>), + ConnectionGetTableTypes: Some(connection_get_table_types::<DriverType>), + ConnectionInit: Some(connection_init::<DriverType>), + ConnectionNew: Some(connection_new::<DriverType>), + ConnectionSetOption: Some(connection_set_option::<DriverType>), + ConnectionReadPartition: Some(connection_read_partition::<DriverType>), + ConnectionRelease: Some(connection_release::<DriverType>), + ConnectionRollback: Some(connection_rollback::<DriverType>), + StatementBind: Some(statement_bind::<DriverType>), + StatementBindStream: Some(statement_bind_stream::<DriverType>), + StatementExecuteQuery: Some(statement_execute_query::<DriverType>), + StatementExecutePartitions: Some(statement_execute_partitions::<DriverType>), + StatementGetParameterSchema: Some(statement_get_parameter_schema::<DriverType>), + StatementNew: Some(statement_new::<DriverType>), + StatementPrepare: Some(statement_prepare::<DriverType>), + StatementRelease: Some(statement_release::<DriverType>), + StatementSetOption: Some(statement_set_option::<DriverType>), + StatementSetSqlQuery: Some(statement_set_sql_query::<DriverType>), + StatementSetSubstraitPlan: Some(statement_set_substrait_plan::<DriverType>), + ErrorGetDetailCount: Some(error_get_detail_count), + ErrorGetDetail: Some(error_get_detail), + ErrorFromArrayStream: None, // TODO(alexandreyc): what to do with this? 
+ DatabaseGetOption: Some(database_get_option::<DriverType>), + DatabaseGetOptionBytes: Some(database_get_option_bytes::<DriverType>), + DatabaseGetOptionDouble: Some(database_get_option_double::<DriverType>), + DatabaseGetOptionInt: Some(database_get_option_int::<DriverType>), + DatabaseSetOptionBytes: Some(database_set_option_bytes::<DriverType>), + DatabaseSetOptionDouble: Some(database_set_option_double::<DriverType>), + DatabaseSetOptionInt: Some(database_set_option_int::<DriverType>), + ConnectionCancel: Some(connection_cancel::<DriverType>), + ConnectionGetOption: Some(connection_get_option::<DriverType>), + ConnectionGetOptionBytes: Some(connection_get_option_bytes::<DriverType>), + ConnectionGetOptionDouble: Some(connection_get_option_double::<DriverType>), + ConnectionGetOptionInt: Some(connection_get_option_int::<DriverType>), + ConnectionGetStatistics: Some(connection_get_statistics::<DriverType>), + ConnectionGetStatisticNames: Some(connection_get_statistic_names::<DriverType>), + ConnectionSetOptionBytes: Some(connection_set_option_bytes::<DriverType>), + ConnectionSetOptionDouble: Some(connection_set_option_double::<DriverType>), + ConnectionSetOptionInt: Some(connection_set_option_int::<DriverType>), + StatementCancel: Some(statement_cancel::<DriverType>), + StatementExecuteSchema: Some(statement_execute_schema::<DriverType>), + StatementGetOption: Some(statement_get_option::<DriverType>), + StatementGetOptionBytes: Some(statement_get_option_bytes::<DriverType>), + StatementGetOptionDouble: Some(statement_get_option_double::<DriverType>), + StatementGetOptionInt: Some(statement_get_option_int::<DriverType>), + StatementSetOptionBytes: Some(statement_set_option_bytes::<DriverType>), + StatementSetOptionDouble: Some(statement_set_option_double::<DriverType>), + StatementSetOptionInt: Some(statement_set_option_int::<DriverType>), + } + } +} + +/// Export a Rust driver as a C driver. 
+/// +/// # Parameters +/// +/// - `$func_name` - Driver's initialization function name. The recommended name +/// is `AdbcDriverInit`, or a name derived from the name of the driver's shared +/// library as follows: remove the `lib` prefix (on Unix systems) and all file +/// extensions, then `PascalCase` the driver name, append `Init`, and prepend +/// `Adbc` (if not already there). For example: +/// - `libadbc_driver_sqlite.so.2.0.0` -> `AdbcDriverSqliteInit` +/// - `adbc_driver_sqlite.dll` -> `AdbcDriverSqliteInit` +/// - `proprietary_driver.dll` -> `AdbcProprietaryDriverInit` +/// - `$driver_type` - Driver's type which must implement [Driver] and [Default]. +/// Currently, the Rust driver is exported as an ADBC 1.1.0 C driver. +/// +/// The generated function returns `NotImplemented` for any version other than +/// 1.1.0 and `InvalidArguments` when `driver` is null. +#[macro_export] +macro_rules! export_driver { + ($func_name:ident, $driver_type:ty) => { + #[no_mangle] + pub unsafe extern "C" fn $func_name( + version: std::os::raw::c_int, + driver: *mut std::os::raw::c_void, + error: *mut $crate::ffi::FFI_AdbcError, + ) -> $crate::ffi::FFI_AdbcStatusCode { + if version != $crate::options::AdbcVersion::V110.into() { + let err = $crate::error::Error::with_message_and_status( + format!("Unsupported ADBC version: {version}"), + $crate::error::Status::NotImplemented, + ); + $crate::check_err!(Err(err), error); + } + + if driver.is_null() { + // A null `driver` out-pointer is caller misuse, not a missing feature: + // report it like every other null argument (see `check_not_null!`). 
+ let err = $crate::error::Error::with_message_and_status( + "Passed null pointer to initialization function", + $crate::error::Status::InvalidArguments, + ); + $crate::check_err!(Err(err), error); + } + + let ffi_driver = <$driver_type as $crate::FFIDriver>::ffi_driver(); + unsafe { + std::ptr::write_unaligned(driver as *mut $crate::ffi::FFI_AdbcDriver, ffi_driver); + } + $crate::ffi::constants::ADBC_STATUS_OK + } + }; +} + +/// Given a Result, either unwrap the value or handle the error in ADBC function. +/// +/// This macro is for use when implementing ADBC methods that have an out +/// parameter for [FFI_AdbcError] and return [FFI_AdbcStatusCode].
If the result is +/// `Ok`, the expression resolves to the value. Otherwise, it will return early, +/// setting the error and status code appropriately. In order for this to work, +/// the error must be convertible to [crate::error::Error]. +#[doc(hidden)] +#[macro_export] +macro_rules! check_err { + ($res:expr, $err_out:expr) => { + match $res { + Ok(x) => x, + Err(error) => { + let error = $crate::error::Error::from(error); + let status: $crate::ffi::FFI_AdbcStatusCode = error.status.into(); + if !$err_out.is_null() { + let mut ffi_error = + $crate::ffi::FFI_AdbcError::try_from(error).unwrap_or_else(Into::into); + ffi_error.private_driver = (*$err_out).private_driver; + unsafe { std::ptr::write_unaligned($err_out, ffi_error) }; + } + return status; + } + } + }; +} + +/// Check that the given raw pointer is not null. +/// +/// If null, an error is returned from the enclosing function, otherwise this is +/// a no-op. +macro_rules! check_not_null { + ($ptr:ident, $err_out:expr) => { + let res = if $ptr.is_null() { + Err(Error::with_message_and_status( + format!("Passed null pointer for argument {:?}", stringify!($ptr)), + Status::InvalidArguments, + )) + } else { + Ok(()) + }; + check_err!(res, $err_out); + }; +} + +unsafe extern "C" fn release_ffi_driver( + driver: *mut FFI_AdbcDriver, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + if let Some(driver) = driver.as_mut() { + let release = driver.release.take(); + if release.is_none() { + check_err!( + Err(Error::with_message_and_status( + "Driver already released", + Status::InvalidState + )), + error + ); + } + } + ADBC_STATUS_OK +} + +// Option helpers + +// SAFETY: `dst` and `length` must be not null otherwise the function will panic.
+unsafe fn copy_string(src: &str, dst: *mut c_char, length: *mut usize) -> Result<()> { + assert!(!dst.is_null() && !length.is_null()); + let src = CString::new(src)?; + let n = src.to_bytes_with_nul().len(); + if n <= *length { + std::ptr::copy_nonoverlapping(src.as_ptr(), dst, n); + } + *length = n; + Ok::<(), Error>(()) +} + +// SAFETY: `dst` and `length` must be not null otherwise the function will panic. +unsafe fn copy_bytes(src: &[u8], dst: *mut u8, length: *mut usize) { + assert!(!dst.is_null() && !length.is_null()); + let n = src.len(); + if n <= *length { + std::ptr::copy_nonoverlapping(src.as_ptr(), dst, n); + } + *length = n; +} + +unsafe fn get_option_int<'a, OptionType, Object>( + object: Option<&mut Object>, + options: Option<&mut HashMap<OptionType, OptionValue>>, + key: *const c_char, +) -> Result<i64> +where + OptionType: Hash + Eq + From<&'a str>, + Object: Optionable<Option = OptionType>, +{ + let key = CStr::from_ptr(key).to_str()?; + + if let Some(options) = options { + let optvalue = options + .get(&key.into()) + .ok_or(Error::with_message_and_status( + format!("Option key not found: {key:?}"), + Status::NotFound, + ))?; + if let OptionValue::Int(optvalue) = optvalue { + Ok(*optvalue) + } else { + let err = Error::with_message_and_status( + format!("Option value for key {key:?} has wrong type"), + Status::InvalidState, + ); + Err(err) + } + } else { + let object = object.expect("Broken invariant"); + let optvalue = object.get_option_int(key.into())?; + Ok(optvalue) + } +} + +unsafe fn get_option_double<'a, OptionType, Object>( + object: Option<&mut Object>, + options: Option<&mut HashMap<OptionType, OptionValue>>, + key: *const c_char, +) -> Result<f64> +where + OptionType: Hash + Eq + From<&'a str>, + Object: Optionable<Option = OptionType>, +{ + let key = CStr::from_ptr(key).to_str()?; + + if let Some(options) = options { + let optvalue = options + .get(&key.into()) + .ok_or(Error::with_message_and_status( + format!("Option key not found: 
{key}"), + Status::NotFound, + ))?; + if let OptionValue::Double(optvalue) = optvalue { + Ok(*optvalue) + } else { + let err = Error::with_message_and_status( + format!("Option value for key {:?} has wrong type", key), + Status::InvalidState, + ); + Err(err) + } + } else { + let object = object.expect("Broken invariant"); + let optvalue = object.get_option_double(key.into())?; + Ok(optvalue) + } +} + +unsafe fn get_option<'a, OptionType, Object>( + object: Option<&mut Object>, + options: Option<&mut HashMap<OptionType, OptionValue>>, + key: *const c_char, +) -> Result<String> +where + OptionType: Hash + Eq + From<&'a str>, + Object: Optionable<Option = OptionType>, +{ + let key = CStr::from_ptr(key).to_str()?; + + if let Some(options) = options { + let optvalue = options + .get(&key.into()) + .ok_or(Error::with_message_and_status( + format!("Option key not found: {key:?}"), + Status::NotFound, + ))?; + if let OptionValue::String(optvalue) = optvalue { + Ok(optvalue.clone()) + } else { + let err = Error::with_message_and_status( + format!("Option value for key {key:?} has wrong type"), + Status::InvalidState, + ); + Err(err) + } + } else { + let database = object.expect("Broken invariant"); + let optvalue = database.get_option_string(key.into())?; + Ok(optvalue) + } +} + +unsafe fn get_option_bytes<'a, OptionType, Object>( + object: Option<&mut Object>, + options: Option<&mut HashMap<OptionType, OptionValue>>, + key: *const c_char, +) -> Result<Vec<u8>> +where + OptionType: Hash + Eq + From<&'a str>, + Object: Optionable<Option = OptionType>, +{ + let key = CStr::from_ptr(key).to_str()?; + + if let Some(options) = options { + let optvalue = options + .get(&key.into()) + .ok_or(Error::with_message_and_status( + format!("Option key not found: {key:?}"), + Status::NotFound, + ))?; + if let OptionValue::Bytes(optvalue) = optvalue { + Ok(optvalue.clone()) + } else { + let err = Error::with_message_and_status( + format!("Option value for key {key:?} has wrong type"), + 
Status::InvalidState, + ); + Err(err) + } + } else { + let connection = object.expect("Broken invariant"); + let optvalue = connection.get_option_bytes(key.into())?; + Ok(optvalue) + } +} + +// Database + +unsafe fn database_private_data<'a, DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, +) -> Result<&'a mut ExportedDatabase<DriverType>> { + let database = database.as_mut().ok_or(Error::with_message_and_status( + "Passed null database pointer", + Status::InvalidArguments, + ))?; + let exported = database.private_data as *mut ExportedDatabase<DriverType>; + let exported = exported.as_mut().ok_or(Error::with_message_and_status( + "Uninitialized database", + Status::InvalidState, + )); + exported +} + +unsafe fn database_set_option_impl<DriverType: Driver + Default, Value: Into<OptionValue>>( + database: *mut FFI_AdbcDatabase, + key: *const c_char, + value: Value, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + let exported = check_err!(database_private_data::<DriverType>(database), error); + let key = check_err!(CStr::from_ptr(key).to_str(), error); + + match exported { + ExportedDatabase::Options(options) => { + options.insert(key.into(), value.into()); + } + ExportedDatabase::Database(database) => { + check_err!(database.set_option(key.into(), value.into()), error); + } + } + + ADBC_STATUS_OK +} + +unsafe extern "C" fn database_new<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + + let database = database.as_mut().unwrap(); + let exported = Box::new(ExportedDatabase::<DriverType>::Options(HashMap::new())); + database.private_data = Box::into_raw(exported) as *mut c_void; + + ADBC_STATUS_OK +} + +unsafe extern "C" fn database_init<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + + let exported = 
check_err!(database_private_data::<DriverType>(database), error); + + if let ExportedDatabase::Options(options) = exported { + let mut driver = DriverType::default(); + let database = driver.new_database_with_opts(options.clone()); + let database = check_err!(database, error); + *exported = ExportedDatabase::Database(database); + } else { + check_err!( + Err(Error::with_message_and_status( + "Database already initialized", + Status::InvalidState + )), + error + ); + } + + ADBC_STATUS_OK +} + +unsafe extern "C" fn database_release<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + + let database = database.as_mut().unwrap(); + let exported = Box::from_raw(database.private_data as *mut ExportedDatabase<DriverType>); + drop(exported); + database.private_data = std::ptr::null_mut(); + + ADBC_STATUS_OK +} + +unsafe extern "C" fn database_set_option<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + key: *const c_char, + value: *const c_char, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + check_not_null!(key, error); + check_not_null!(value, error); + + let value = check_err!(CStr::from_ptr(value).to_str(), error); + database_set_option_impl::<DriverType, &str>(database, key, value, error) +} + +unsafe extern "C" fn database_set_option_int<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + key: *const c_char, + value: i64, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + check_not_null!(key, error); + + database_set_option_impl::<DriverType, i64>(database, key, value, error) +} + +unsafe extern "C" fn database_set_option_double<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + key: *const c_char, + value: f64, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + check_not_null!(key, 
error); + + database_set_option_impl::<DriverType, f64>(database, key, value, error) +} + +unsafe extern "C" fn database_set_option_bytes<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + key: *const c_char, + value: *const u8, + length: usize, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + check_not_null!(key, error); + check_not_null!(value, error); + + let value = std::slice::from_raw_parts(value, length); + database_set_option_impl::<DriverType, &[u8]>(database, key, value, error) +} + +unsafe extern "C" fn database_get_option<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + key: *const c_char, + value: *mut c_char, + length: *mut usize, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + check_not_null!(key, error); + check_not_null!(value, error); + check_not_null!(length, error); + + let exported = check_err!(database_private_data::<DriverType>(database), error); + let (options, database) = exported.tuple(); + + let optvalue = get_option(database, options, key); + let optvalue = check_err!(optvalue, error); + check_err!(copy_string(&optvalue, value, length), error); + + ADBC_STATUS_OK +} + +unsafe extern "C" fn database_get_option_int<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + key: *const c_char, + value: *mut i64, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + check_not_null!(key, error); + check_not_null!(value, error); + + let exported = check_err!(database_private_data::<DriverType>(database), error); + let (options, database) = exported.tuple(); + + let optvalue = check_err!(get_option_int(database, options, key), error); + std::ptr::write_unaligned(value, optvalue); + + ADBC_STATUS_OK +} + +unsafe extern "C" fn database_get_option_double<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + key: *const c_char, + value: *mut f64, + error: *mut 
FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + check_not_null!(key, error); + check_not_null!(value, error); + + let exported = check_err!(database_private_data::<DriverType>(database), error); + let (options, database) = exported.tuple(); + + let optvalue = check_err!(get_option_double(database, options, key), error); + std::ptr::write_unaligned(value, optvalue); + + ADBC_STATUS_OK +} + +unsafe extern "C" fn database_get_option_bytes<DriverType: Driver + Default>( + database: *mut FFI_AdbcDatabase, + key: *const c_char, + value: *mut u8, + length: *mut usize, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(database, error); + check_not_null!(key, error); + check_not_null!(value, error); + check_not_null!(length, error); + + let exported = check_err!(database_private_data::<DriverType>(database), error); + let (options, database) = exported.tuple(); + + let optvalue = get_option_bytes(database, options, key); + let optvalue = check_err!(optvalue, error); + copy_bytes(&optvalue, value, length); + + ADBC_STATUS_OK +} + +// Connection + +unsafe fn connection_private_data<'a, DriverType: Driver + Default>( + connection: *mut FFI_AdbcConnection, +) -> Result<&'a mut ExportedConnection<DriverType>> { + let connection = connection.as_mut().ok_or(Error::with_message_and_status( + "Passed null connection pointer", + Status::InvalidArguments, + ))?; + let exported = connection.private_data as *mut ExportedConnection<DriverType>; + let exported = exported.as_mut().ok_or(Error::with_message_and_status( + "Uninitialized connection", + Status::InvalidState, + )); + exported +} + +unsafe fn connection_set_option_impl<DriverType: Driver + Default, Value: Into<OptionValue>>( + connection: *mut FFI_AdbcConnection, + key: *const c_char, + value: Value, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + let exported = check_err!(connection_private_data::<DriverType>(connection), error); + + let key = 
check_err!(CStr::from_ptr(key).to_str(), error); + + match exported { + ExportedConnection::Options(options) => { + options.insert(key.into(), value.into()); + } + ExportedConnection::Connection(connection) => { + check_err!(connection.set_option(key.into(), value.into()), error); + } + } + + ADBC_STATUS_OK +} + +unsafe extern "C" fn connection_new<DriverType: Driver + Default>( + connection: *mut FFI_AdbcConnection, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(connection, error); + + let connection = connection.as_mut().unwrap(); + let exported = Box::new(ExportedConnection::<DriverType>::Options(HashMap::new())); + connection.private_data = Box::into_raw(exported) as *mut c_void; + + ADBC_STATUS_OK +} + +unsafe extern "C" fn connection_init<DriverType: Driver + Default>( + connection: *mut FFI_AdbcConnection, + database: *mut FFI_AdbcDatabase, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(connection, error); + check_not_null!(database, error); + + let exported_connection = check_err!(connection_private_data::<DriverType>(connection), error); + let exported_database = check_err!(database_private_data::<DriverType>(database), error); + + if let ExportedConnection::Options(options) = exported_connection { + let connection = match exported_database { + ExportedDatabase::Database(database) => { + database.new_connection_with_opts(options.clone()) + } + _ => panic!("Broken invariant"), + }; + let connection = check_err!(connection, error); + *exported_connection = ExportedConnection::Connection(connection); + } else { + check_err!( + Err(Error::with_message_and_status( + "Connection already initialized", + Status::InvalidState + )), + error + ); + } + + ADBC_STATUS_OK +} + +unsafe extern "C" fn connection_release<DriverType: Driver + Default>( + connection: *mut FFI_AdbcConnection, + error: *mut FFI_AdbcError, +) -> FFI_AdbcStatusCode { + check_not_null!(connection, error); + + let connection = 
connection.as_mut().unwrap(); + let exported = Box::from_raw(connection.private_data as *mut ExportedConnection<DriverType>); Review Comment: I've added a null check with an error here instead of panicking. Same reasoning as before: it can only be an API misuse (namely calling Release twice or more) -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
