This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 3aa0d121 feat(go/adbc)!: close database explicitly  (#1460)
3aa0d121 is described below

commit 3aa0d12169764e2b0afabaf9b1f1f68c2d63aea8
Author: Anton Levakin <[email protected]>
AuthorDate: Fri Jan 19 17:09:54 2024 +0100

    feat(go/adbc)!: close database explicitly  (#1460)
    
    Implicit (finalizer-based) database release behaves inconsistently across
    operating systems, which leads to bugs.
    
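    BREAKING CHANGE: adds Close to the Database interface. Every
    implementation must provide it, and callers should call it once they
    are done with the database to release its resources.
    Closes #1306.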
    
    ---------
    
    Co-authored-by: Matt Topol <[email protected]>
---
 docs/source/driver/duckdb.rst                      |  1 +
 docs/source/driver/flight_sql.rst                  |  1 +
 docs/source/driver/postgresql.rst                  |  1 +
 docs/source/driver/snowflake.rst                   |  2 +
 docs/source/driver/sqlite.rst                      |  1 +
 go/adbc/adbc.go                                    |  3 ++
 go/adbc/driver/driverbase/database.go              |  5 ++
 go/adbc/driver/driverbase/driver.go                |  4 +-
 .../driver/flightsql/flightsql_adbc_server_test.go |  1 +
 go/adbc/driver/flightsql/flightsql_adbc_test.go    |  9 ++++
 go/adbc/driver/flightsql/flightsql_database.go     | 20 ++++---
 go/adbc/driver/flightsql/flightsql_driver.go       |  1 +
 go/adbc/driver/panicdummy/panicdummy_adbc.go       |  5 ++
 go/adbc/driver/snowflake/connection.go             |  6 +--
 go/adbc/driver/snowflake/driver.go                 |  1 +
 go/adbc/driver/snowflake/driver_test.go            | 62 ++++++++++++----------
 go/adbc/driver/snowflake/snowflake_database.go     |  4 ++
 go/adbc/drivermgr/wrapper.go                       | 46 ++++++++++------
 go/adbc/drivermgr/wrapper_sqlite_test.go           |  5 ++
 go/adbc/pkg/_tmpl/driver.go.tmpl                   | 11 ++--
 go/adbc/pkg/flightsql/driver.go                    | 11 ++--
 go/adbc/pkg/panicdummy/driver.go                   |  1 +
 go/adbc/pkg/snowflake/driver.go                    | 11 ++--
 go/adbc/validation/validation.go                   |  3 ++
 24 files changed, 148 insertions(+), 67 deletions(-)

diff --git a/docs/source/driver/duckdb.rst b/docs/source/driver/duckdb.rst
index 410331c3..94460eb5 100644
--- a/docs/source/driver/duckdb.rst
+++ b/docs/source/driver/duckdb.rst
@@ -72,6 +72,7 @@ ADBC support in DuckDB requires the driver manager.
             if err != nil {
                // handle error
             }
+            defer db.Close()
 
             cnxn, err := db.Open(context.Background())
             if err != nil {
diff --git a/docs/source/driver/flight_sql.rst b/docs/source/driver/flight_sql.rst
index aca95d86..7473a7cb 100644
--- a/docs/source/driver/flight_sql.rst
+++ b/docs/source/driver/flight_sql.rst
@@ -152,6 +152,7 @@ the :cpp:class:`AdbcDatabase`.
             if err != nil {
                 // do something with the error
             }
+            defer db.Close()
 
             cnxn, err := db.Open(context.Background())
             if err != nil {
diff --git a/docs/source/driver/postgresql.rst b/docs/source/driver/postgresql.rst
index ddf9115d..c724a2c1 100644
--- a/docs/source/driver/postgresql.rst
+++ b/docs/source/driver/postgresql.rst
@@ -124,6 +124,7 @@ the :cpp:class:`AdbcDatabase`.  This should be a `connection URI
             if err != nil {
                // handle error
             }
+            defer db.Close()
 
             cnxn, err := db.Open(context.Background())
             if err != nil {
diff --git a/docs/source/driver/snowflake.rst b/docs/source/driver/snowflake.rst
index 04023a62..bf445349 100644
--- a/docs/source/driver/snowflake.rst
+++ b/docs/source/driver/snowflake.rst
@@ -127,6 +127,7 @@ constructing the :cpp::class:`AdbcDatabase`.
             if err != nil {
                 // handle error
             }
+            defer db.Close()
 
             cnxn, err := db.Open(context.Background())
             if err != nil {
@@ -241,6 +242,7 @@ a listing).
             if err != nil {
                 // handle error
             }
+            defer db.Close()
 
             cnxn, err := db.Open(context.Background())
             if err != nil {
diff --git a/docs/source/driver/sqlite.rst b/docs/source/driver/sqlite.rst
index 30e7d32b..96bd7bbd 100644
--- a/docs/source/driver/sqlite.rst
+++ b/docs/source/driver/sqlite.rst
@@ -140,6 +140,7 @@ shared across all connections.
             if err != nil {
                // handle error
             }
+            defer db.Close()
 
             cnxn, err := db.Open(context.Background())
             if err != nil {
diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go
index 3fb61d69..71a75daf 100644
--- a/go/adbc/adbc.go
+++ b/go/adbc/adbc.go
@@ -329,6 +329,9 @@ type Driver interface {
 type Database interface {
        SetOptions(map[string]string) error
        Open(ctx context.Context) (Connection, error)
+
+       // Close closes this database and releases any associated resources.
+       Close() error
 }
 
 type InfoCode uint32
diff --git a/go/adbc/driver/driverbase/database.go b/go/adbc/driver/driverbase/database.go
index e3a96ff1..7f32510c 100644
--- a/go/adbc/driver/driverbase/database.go
+++ b/go/adbc/driver/driverbase/database.go
@@ -31,6 +31,7 @@ type DatabaseImpl interface {
        adbc.GetSetOptions
        Base() *DatabaseImplBase
        Open(context.Context) (adbc.Connection, error)
+       Close() error
        SetOptions(map[string]string) error
 }
 
@@ -134,6 +135,10 @@ func (db *database) Open(ctx context.Context) (adbc.Connection, error) {
        return db.impl.Open(ctx)
 }
 
+func (db *database) Close() error {
+       return db.impl.Close()
+}
+
 func (db *database) SetLogger(logger *slog.Logger) {
        if logger != nil {
                db.impl.Base().Logger = logger
diff --git a/go/adbc/driver/driverbase/driver.go b/go/adbc/driver/driverbase/driver.go
index c4767794..acd182f8 100644
--- a/go/adbc/driver/driverbase/driver.go
+++ b/go/adbc/driver/driverbase/driver.go
@@ -32,7 +32,7 @@ type DriverImpl interface {
        NewDatabase(opts map[string]string) (adbc.Database, error)
 }
 
-// DatabaseImplBase is a struct that provides default implementations of the
+// DriverImplBase is a struct that provides default implementations of the
 // DriverImpl interface. It is meant to be used as a composite struct for a
 // driver's DriverImpl implementation.
 type DriverImplBase struct {
@@ -56,7 +56,7 @@ type driver struct {
        impl DriverImpl
 }
 
-// NewDatabase wraps a DriverImpl to create an adbc.Driver.
+// NewDriver wraps a DriverImpl to create an adbc.Driver.
 func NewDriver(impl DriverImpl) adbc.Driver {
        return &driver{impl}
 }
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index a591f1ca..dfd1f6cf 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -90,6 +90,7 @@ func (suite *ServerBasedTests) TearDownTest() {
 }
 
 func (suite *ServerBasedTests) TearDownSuite() {
+       suite.NoError(suite.db.Close())
        suite.db = nil
        suite.s.Shutdown()
 }
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 1619f8fa..dc7d207d 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -352,6 +352,7 @@ func (suite *DefaultDialOptionsTests) SetupSuite() {
 
 func (suite *DefaultDialOptionsTests) TearDownSuite() {
        suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
+       suite.NoError(suite.DB.Close())
        suite.DB = nil
        suite.Driver = nil
 }
@@ -361,6 +362,7 @@ func (suite *DefaultDialOptionsTests) TestMaxIncomingMessageSizeDefault() {
        opts["adbc.flight.sql.client_option.with_max_msg_size"] = "1000000"
        db, err := suite.Driver.NewDatabase(opts)
        suite.NoError(err)
+       defer suite.NoError(db.Close())
 
        cnxn, err := db.Open(suite.ctx)
        suite.NoError(err)
@@ -505,6 +507,7 @@ func (suite *PartitionTests) TearDownTest() {
        suite.Require().NoError(suite.Cnxn.Close())
        suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
        suite.Cnxn = nil
+       suite.NoError(suite.DB.Close())
        suite.DB = nil
        suite.Driver = nil
 }
@@ -558,6 +561,7 @@ func (suite *StatementTests) TearDownTest() {
        suite.Require().NoError(suite.Cnxn.Close())
        suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
        suite.Cnxn = nil
+       suite.NoError(suite.DB.Close())
        suite.DB = nil
        suite.Driver = nil
 }
@@ -639,6 +643,7 @@ func (suite *HeaderTests) TearDownTest() {
        suite.Require().NoError(suite.Cnxn.Close())
        suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
        suite.Cnxn = nil
+       suite.NoError(suite.DB.Close())
        suite.DB = nil
        suite.Driver = nil
 }
@@ -842,6 +847,7 @@ func (suite *TLSTests) TearDownTest() {
        suite.Require().NoError(suite.Cnxn.Close())
        suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
        suite.Cnxn = nil
+       suite.NoError(suite.DB.Close())
        suite.DB = nil
        suite.Driver = nil
 }
@@ -863,6 +869,7 @@ func (suite *TLSTests) TestInvalidOptions() {
                "adbc.flight.sql.client_option.tls_skip_verify": "false",
        })
        suite.Require().NoError(err)
+       defer suite.NoError(db.Close())
 
        cnxn, err := db.Open(suite.ctx)
        suite.Require().NoError(err)
@@ -912,6 +919,7 @@ func (suite *ConnectionTests) SetupSuite() {
 }
 
 func (suite *ConnectionTests) TearDownSuite() {
+       suite.NoError(suite.DB.Close())
        suite.server.Shutdown()
        suite.alloc.AssertSize(suite.T(), 0)
 }
@@ -1009,6 +1017,7 @@ func (suite *DomainSocketTests) SetupSuite() {
 func (suite *DomainSocketTests) TearDownSuite() {
        suite.Require().NoError(suite.Stmt.Close())
        suite.Require().NoError(suite.Cnxn.Close())
+       suite.NoError(suite.DB.Close())
        suite.server.Shutdown()
        suite.alloc.AssertSize(suite.T(), 0)
 }
diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go
index 8b6ab2cc..f9537f50 100644
--- a/go/adbc/driver/flightsql/flightsql_database.go
+++ b/go/adbc/driver/flightsql/flightsql_database.go
@@ -332,6 +332,10 @@ func (d *databaseImpl) SetOptionDouble(key string, value float64) error {
        return d.DatabaseImplBase.SetOptionDouble(key, value)
 }
 
+func (d *databaseImpl) Close() error {
+       return nil
+}
+
 func getFlightClient(ctx context.Context, loc string, d *databaseImpl) 
(*flightsql.Client, error) {
        authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()}
        middleware := []flight.ClientMiddleware{
@@ -396,8 +400,8 @@ type support struct {
        transactions bool
 }
 
-func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
-       cl, err := getFlightClient(ctx, impl.uri.String(), impl)
+func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
+       cl, err := getFlightClient(ctx, d.uri.String(), d)
        if err != nil {
                return nil, err
        }
@@ -410,12 +414,12 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
                                return nil, adbc.Error{Msg: 
fmt.Sprintf("Location must be a string, got %#v", uri), Code: 
adbc.StatusInternal}
                        }
 
-                       cl, err := getFlightClient(context.Background(), uri, 
impl)
+                       cl, err := getFlightClient(context.Background(), uri, d)
                        if err != nil {
                                return nil, err
                        }
 
-                       cl.Alloc = impl.Alloc
+                       cl.Alloc = d.Alloc
                        return cl, nil
                }).
                EvictedFunc(func(_, client interface{}) {
@@ -425,13 +429,13 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
 
        var cnxnSupport support
 
-       info, err := cl.GetSqlInfo(ctx, 
[]flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, impl.timeout)
+       info, err := cl.GetSqlInfo(ctx, 
[]flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, d.timeout)
        // ignore this if it fails
        if err == nil {
                const int32code = 3
 
                for _, endpoint := range info.Endpoint {
-                       rdr, err := doGet(ctx, cl, endpoint, cache, 
impl.timeout)
+                       rdr, err := doGet(ctx, cl, endpoint, cache, d.timeout)
                        if err != nil {
                                continue
                        }
@@ -465,8 +469,8 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
                }
        }
 
-       return &cnxn{cl: cl, db: impl, clientCache: cache,
-               hdrs: make(metadata.MD), timeouts: impl.timeout,
+       return &cnxn{cl: cl, db: d, clientCache: cache,
+               hdrs: make(metadata.MD), timeouts: d.timeout,
                supportInfo: cnxnSupport}, nil
 }
 
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go
index 0060c040..cc58a9e1 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -145,5 +145,6 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
        if err := db.SetOptions(opts); err != nil {
                return nil, err
        }
+
        return driverbase.NewDatabase(db), nil
 }
diff --git a/go/adbc/driver/panicdummy/panicdummy_adbc.go b/go/adbc/driver/panicdummy/panicdummy_adbc.go
index f0513cd6..95171591 100644
--- a/go/adbc/driver/panicdummy/panicdummy_adbc.go
+++ b/go/adbc/driver/panicdummy/panicdummy_adbc.go
@@ -66,6 +66,11 @@ func (d *database) Open(ctx context.Context) (adbc.Connection, error) {
        return &cnxn{}, nil
 }
 
+func (d *database) Close() error {
+       maybePanic("DatabaseClose")
+       return nil
+}
+
 type cnxn struct{}
 
 func (c *cnxn) SetOption(key, value string) error {
diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go
index 73c31604..e2f98487 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -714,16 +714,16 @@ func prepareTablesSQL(matchingCatalogNames []string, catalog *string, dbSchema *
 
 func prepareColumnsSQL(matchingCatalogNames []string, catalog *string, 
dbSchema *string, tableName *string, columnName *string, tableType []string) 
(string, []interface{}) {
        prefixQuery := ""
-       for _, catalog_name := range matchingCatalogNames {
+       for _, catalogName := range matchingCatalogNames {
                if prefixQuery != "" {
                        prefixQuery += " UNION ALL "
                }
                prefixQuery += `SELECT T.table_type,
                                        C.*
                                FROM
-                               "` + strings.ReplaceAll(catalog_name, "\"", 
"\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
+                               "` + strings.ReplaceAll(catalogName, "\"", 
"\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
                        JOIN
-                               "` + strings.ReplaceAll(catalog_name, "\"", 
"\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C
+                               "` + strings.ReplaceAll(catalogName, "\"", 
"\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C
                        ON
                                T.table_catalog = C.table_catalog
                                AND T.table_schema = C.table_schema
diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go
index db5efed4..3b9d72cc 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -202,5 +202,6 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
        if err := db.SetOptions(opts); err != nil {
                return nil, err
        }
+
        return driverbase.NewDatabase(db), nil
 }
diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index 61e94486..a69a0b04 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -336,6 +336,7 @@ func (suite *SnowflakeTests) TearDownTest() {
 }
 
 func (suite *SnowflakeTests) TearDownSuite() {
+       suite.NoError(suite.db.Close())
        suite.db = nil
 }
 
@@ -464,21 +465,21 @@ func (suite *SnowflakeTests) TestMetadataGetObjectsColumnsXdbc() {
                xdbcDateTimeSub  []string
        }{
                {
-                       "BASIC",                       //name
-                       []string{"int64s", "strings"}, //colNames
-                       []string{"1", "2"},            //positions
-                       []string{"NUMBER", "TEXT"},    //dataTypes
-                       []string{"", ""},              //comments
-                       []string{"9", "13"},           //xdbcDataType
-                       []string{"NUMBER", "TEXT"},    //xdbcTypeName
-                       []string{"-5", "12"},          //xdbcSqlDataType
-                       []string{"1", "1"},            //xdbcNullable
-                       []string{"YES", "YES"},        //xdbcIsNullable
-                       []string{"0", "0"},            //xdbcScale
-                       []string{"10", "0"},           //xdbcNumPrecRadix
-                       []string{"38", "16777216"},    //xdbcCharMaxLen 
(xdbcPrecision)
-                       []string{"0", "16777216"},     //xdbcCharOctetLen
-                       []string{"-5", "12", "0"},     //xdbcDateTimeSub
+                       "BASIC",                       // name
+                       []string{"int64s", "strings"}, // colNames
+                       []string{"1", "2"},            // positions
+                       []string{"NUMBER", "TEXT"},    // dataTypes
+                       []string{"", ""},              // comments
+                       []string{"9", "13"},           // xdbcDataType
+                       []string{"NUMBER", "TEXT"},    // xdbcTypeName
+                       []string{"-5", "12"},          // xdbcSqlDataType
+                       []string{"1", "1"},            // xdbcNullable
+                       []string{"YES", "YES"},        // xdbcIsNullable
+                       []string{"0", "0"},            // xdbcScale
+                       []string{"10", "0"},           // xdbcNumPrecRadix
+                       []string{"38", "16777216"},    // xdbcCharMaxLen 
(xdbcPrecision)
+                       []string{"0", "16777216"},     // xdbcCharOctetLen
+                       []string{"-5", "12", "0"},     // xdbcDateTimeSub
                },
        }
 
@@ -576,20 +577,20 @@ func (suite *SnowflakeTests) TestMetadataGetObjectsColumnsXdbc() {
 
                        suite.False(rdr.Next())
                        suite.True(foundExpected)
-                       suite.Equal(tt.colnames, colnames)                  
//colNames
-                       suite.Equal(tt.positions, positions)                
//positions
-                       suite.Equal(tt.comments, comments)                  
//comments
-                       suite.Equal(tt.xdbcDataType, xdbcDataTypes)         
//xdbcDataType
-                       suite.Equal(tt.dataTypes, dataTypes)                
//dataTypes
-                       suite.Equal(tt.xdbcTypeName, xdbcTypeNames)         
//xdbcTypeName
-                       suite.Equal(tt.xdbcCharMaxLen, xdbcCharMaxLens)     
//xdbcCharMaxLen
-                       suite.Equal(tt.xdbcScale, xdbcScales)               
//xdbcScale
-                       suite.Equal(tt.xdbcNumPrecRadix, xdbcNumPrecRadixs) 
//xdbcNumPrecRadix
-                       suite.Equal(tt.xdbcNullable, xdbcNullables)         
//xdbcNullable
-                       suite.Equal(tt.xdbcSqlDataType, xdbcSqlDataTypes)   
//xdbcSqlDataType
-                       suite.Equal(tt.xdbcDateTimeSub, xdbcDateTimeSub)    
//xdbcDateTimeSub
-                       suite.Equal(tt.xdbcCharOctetLen, xdbcCharOctetLen)  
//xdbcCharOctetLen
-                       suite.Equal(tt.xdbcIsNullable, xdbcIsNullables)     
//xdbcIsNullable
+                       suite.Equal(tt.colnames, colnames)                  // 
colNames
+                       suite.Equal(tt.positions, positions)                // 
positions
+                       suite.Equal(tt.comments, comments)                  // 
comments
+                       suite.Equal(tt.xdbcDataType, xdbcDataTypes)         // 
xdbcDataType
+                       suite.Equal(tt.dataTypes, dataTypes)                // 
dataTypes
+                       suite.Equal(tt.xdbcTypeName, xdbcTypeNames)         // 
xdbcTypeName
+                       suite.Equal(tt.xdbcCharMaxLen, xdbcCharMaxLens)     // 
xdbcCharMaxLen
+                       suite.Equal(tt.xdbcScale, xdbcScales)               // 
xdbcScale
+                       suite.Equal(tt.xdbcNumPrecRadix, xdbcNumPrecRadixs) // 
xdbcNumPrecRadix
+                       suite.Equal(tt.xdbcNullable, xdbcNullables)         // 
xdbcNullable
+                       suite.Equal(tt.xdbcSqlDataType, xdbcSqlDataTypes)   // 
xdbcSqlDataType
+                       suite.Equal(tt.xdbcDateTimeSub, xdbcDateTimeSub)    // 
xdbcDateTimeSub
+                       suite.Equal(tt.xdbcCharOctetLen, xdbcCharOctetLen)  // 
xdbcCharOctetLen
+                       suite.Equal(tt.xdbcIsNullable, xdbcIsNullables)     // 
xdbcIsNullable
 
                })
        }
@@ -605,6 +606,7 @@ func (suite *SnowflakeTests) TestNewDatabaseGetSetOptions() {
        })
        suite.NoError(err)
        suite.NotNil(db)
+       defer suite.NoError(db.Close())
 
        getSetDB, ok := db.(adbc.GetSetOptions)
        suite.True(ok)
@@ -862,6 +864,7 @@ func ConnectWithJwt(uri, keyValue, passcode string) {
        if err != nil {
                panic(err)
        }
+       defer db.Close()
 
        cnxn, err := db.Open(context.Background())
        if err != nil {
@@ -912,6 +915,7 @@ func (suite *SnowflakeTests) TestJwtPrivateKey() {
                opts[driver.OptionJwtPrivateKey] = keyFile
                db, err := suite.driver.NewDatabase(opts)
                suite.NoError(err)
+               defer db.Close()
                cnxn, err := db.Open(suite.ctx)
                suite.NoError(err)
                defer cnxn.Close()
diff --git a/go/adbc/driver/snowflake/snowflake_database.go b/go/adbc/driver/snowflake/snowflake_database.go
index 45e3aab4..7b76fa5a 100644
--- a/go/adbc/driver/snowflake/snowflake_database.go
+++ b/go/adbc/driver/snowflake/snowflake_database.go
@@ -466,3 +466,7 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
                useHighPrecision: d.useHighPrecision,
        }, nil
 }
+
+func (d *databaseImpl) Close() error {
+       return nil
+}
diff --git a/go/adbc/drivermgr/wrapper.go b/go/adbc/drivermgr/wrapper.go
index 07bb94b8..63fb9ee9 100644
--- a/go/adbc/drivermgr/wrapper.go
+++ b/go/adbc/drivermgr/wrapper.go
@@ -39,7 +39,7 @@ package drivermgr
 import "C"
 import (
        "context"
-       "runtime"
+       "sync"
        "unsafe"
 
        "github.com/apache/arrow-adbc/go/adbc"
@@ -100,27 +100,15 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
                return nil, errOut
        }
 
-       runtime.SetFinalizer(db, func(db *Database) {
-               if db.db != nil {
-                       var err C.struct_AdbcError
-                       code := adbc.Status(C.AdbcDatabaseRelease(db.db, &err))
-                       if code != adbc.StatusOK {
-                               panic(toAdbcError(code, &err))
-                       }
-               }
-
-               for _, o := range db.options {
-                       C.free(unsafe.Pointer(o.key))
-                       C.free(unsafe.Pointer(o.val))
-               }
-       })
-
        return db, nil
 }
 
 type Database struct {
        options map[string]option
        db      *C.struct_AdbcDatabase
+
+       mu     sync.Mutex // protects following fields
+       closed bool
 }
 
 func toAdbcError(code adbc.Status, e *C.struct_AdbcError) error {
@@ -182,6 +170,32 @@ func (d *Database) Open(context.Context) (adbc.Connection, error) {
        return &cnxn{conn: &c}, nil
 }
 
+func (d *Database) Close() error {
+       d.mu.Lock()
+       defer d.mu.Unlock()
+
+       if d.closed {
+               return nil
+       }
+
+       d.closed = true
+
+       for _, o := range d.options {
+               C.free(unsafe.Pointer(o.key))
+               C.free(unsafe.Pointer(o.val))
+       }
+
+       if d.db != nil {
+               var err C.struct_AdbcError
+               code := adbc.Status(C.AdbcDatabaseRelease(d.db, &err))
+               if code != adbc.StatusOK {
+                       return toAdbcError(code, &err)
+               }
+       }
+
+       return nil
+}
+
 func getRdr(out *C.struct_ArrowArrayStream) (array.RecordReader, error) {
        rdr, err := 
cdata.ImportCRecordReader((*cdata.CArrowArrayStream)(unsafe.Pointer(out)), nil)
        if err != nil {
diff --git a/go/adbc/drivermgr/wrapper_sqlite_test.go b/go/adbc/drivermgr/wrapper_sqlite_test.go
index c33adf27..af307a08 100644
--- a/go/adbc/drivermgr/wrapper_sqlite_test.go
+++ b/go/adbc/drivermgr/wrapper_sqlite_test.go
@@ -74,6 +74,10 @@ func (dm *DriverMgrSuite) SetupSuite() {
        dm.Equal(int64(1), nrows)
 }
 
+func (dm *DriverMgrSuite) TearDownSuite() {
+       dm.NoError(dm.db.Close())
+}
+
 func (dm *DriverMgrSuite) SetupTest() {
        cnxn, err := dm.db.Open(dm.ctx)
        dm.Require().NoError(err)
@@ -597,6 +601,7 @@ func TestDriverMgrCustomInitFunc(t *testing.T) {
        cnxn, err := db.Open(context.Background())
        assert.NoError(t, err)
        require.NoError(t, cnxn.Close())
+       require.NoError(t, db.Close())
 
        // set invalid entrypoint
        drv = drivermgr.Driver{}
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 4b7008ea..901d164e 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -591,10 +591,15 @@ func {{.Prefix}}DatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErr
        h := (*(*cgo.Handle)(db.private_data))
 
        cdb := h.Value().(*cDatabase)
-       cdb.db = nil
+       if cdb.db != nil {
+               cdb.db.Close()
+               cdb.db = nil
+       }
        cdb.opts = nil
-       C.free(unsafe.Pointer(db.private_data))
-       db.private_data = nil
+       if db.private_data != nil {
+               C.free(unsafe.Pointer(db.private_data))
+               db.private_data = nil
+       }
        h.Delete()
        // manually trigger GC for two reasons:
        //  1. ASAN expects the release callback to be called before
diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go
index 2847274c..d57a91b7 100644
--- a/go/adbc/pkg/flightsql/driver.go
+++ b/go/adbc/pkg/flightsql/driver.go
@@ -594,10 +594,15 @@ func FlightSQLDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError
        h := (*(*cgo.Handle)(db.private_data))
 
        cdb := h.Value().(*cDatabase)
-       cdb.db = nil
+       if cdb.db != nil {
+               cdb.db.Close()
+               cdb.db = nil
+       }
        cdb.opts = nil
-       C.free(unsafe.Pointer(db.private_data))
-       db.private_data = nil
+       if db.private_data != nil {
+               C.free(unsafe.Pointer(db.private_data))
+               db.private_data = nil
+       }
        h.Delete()
        // manually trigger GC for two reasons:
        //  1. ASAN expects the release callback to be called before
diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go
index 399d0edc..fbaa5204 100644
--- a/go/adbc/pkg/panicdummy/driver.go
+++ b/go/adbc/pkg/panicdummy/driver.go
@@ -594,6 +594,7 @@ func PanicDummyDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErro
        h := (*(*cgo.Handle)(db.private_data))
 
        cdb := h.Value().(*cDatabase)
+       cdb.db.Close()
        cdb.db = nil
        cdb.opts = nil
        C.free(unsafe.Pointer(db.private_data))
diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go
index b5910181..6e2d3bac 100644
--- a/go/adbc/pkg/snowflake/driver.go
+++ b/go/adbc/pkg/snowflake/driver.go
@@ -594,10 +594,15 @@ func SnowflakeDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError
        h := (*(*cgo.Handle)(db.private_data))
 
        cdb := h.Value().(*cDatabase)
-       cdb.db = nil
+       if cdb.db != nil {
+               cdb.db.Close()
+               cdb.db = nil
+       }
        cdb.opts = nil
-       C.free(unsafe.Pointer(db.private_data))
-       db.private_data = nil
+       if db.private_data != nil {
+               C.free(unsafe.Pointer(db.private_data))
+               db.private_data = nil
+       }
        h.Delete()
        // manually trigger GC for two reasons:
        //  1. ASAN expects the release callback to be called before
diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go
index 8925f725..19222807 100644
--- a/go/adbc/validation/validation.go
+++ b/go/adbc/validation/validation.go
@@ -100,6 +100,7 @@ func (d *DatabaseTests) TestNewDatabase() {
        d.NoError(err)
        d.NotNil(db)
        d.Implements((*adbc.Database)(nil), db)
+       d.NoError(db.Close())
 }
 
 type ConnectionTests struct {
@@ -121,6 +122,7 @@ func (c *ConnectionTests) SetupTest() {
 func (c *ConnectionTests) TearDownTest() {
        c.Quirks.TearDownDriver(c.T(), c.Driver)
        c.Driver = nil
+       c.NoError(c.DB.Close())
        c.DB = nil
 }
 
@@ -514,6 +516,7 @@ func (s *StatementTests) TearDownTest() {
        s.Require().NoError(s.Cnxn.Close())
        s.Quirks.TearDownDriver(s.T(), s.Driver)
        s.Cnxn = nil
+       s.NoError(s.DB.Close())
        s.DB = nil
        s.Driver = nil
 }

Reply via email to