This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 3aa0d121 feat(go/adbc)!: close database explicitly (#1460)
3aa0d121 is described below
commit 3aa0d12169764e2b0afabaf9b1f1f68c2d63aea8
Author: Anton Levakin <[email protected]>
AuthorDate: Fri Jan 19 17:09:54 2024 +0100
feat(go/adbc)!: close database explicitly (#1460)
Implicit database release (via a runtime finalizer) behaves inconsistently
across operating systems, which leads to bugs.
BREAKING CHANGE: adds Close to the Database interface.
Closes #1306.
---------
Co-authored-by: Matt Topol <[email protected]>
---
docs/source/driver/duckdb.rst | 1 +
docs/source/driver/flight_sql.rst | 1 +
docs/source/driver/postgresql.rst | 1 +
docs/source/driver/snowflake.rst | 2 +
docs/source/driver/sqlite.rst | 1 +
go/adbc/adbc.go | 3 ++
go/adbc/driver/driverbase/database.go | 5 ++
go/adbc/driver/driverbase/driver.go | 4 +-
.../driver/flightsql/flightsql_adbc_server_test.go | 1 +
go/adbc/driver/flightsql/flightsql_adbc_test.go | 9 ++++
go/adbc/driver/flightsql/flightsql_database.go | 20 ++++---
go/adbc/driver/flightsql/flightsql_driver.go | 1 +
go/adbc/driver/panicdummy/panicdummy_adbc.go | 5 ++
go/adbc/driver/snowflake/connection.go | 6 +--
go/adbc/driver/snowflake/driver.go | 1 +
go/adbc/driver/snowflake/driver_test.go | 62 ++++++++++++----------
go/adbc/driver/snowflake/snowflake_database.go | 4 ++
go/adbc/drivermgr/wrapper.go | 46 ++++++++++------
go/adbc/drivermgr/wrapper_sqlite_test.go | 5 ++
go/adbc/pkg/_tmpl/driver.go.tmpl | 11 ++--
go/adbc/pkg/flightsql/driver.go | 11 ++--
go/adbc/pkg/panicdummy/driver.go | 1 +
go/adbc/pkg/snowflake/driver.go | 11 ++--
go/adbc/validation/validation.go | 3 ++
24 files changed, 148 insertions(+), 67 deletions(-)
diff --git a/docs/source/driver/duckdb.rst b/docs/source/driver/duckdb.rst
index 410331c3..94460eb5 100644
--- a/docs/source/driver/duckdb.rst
+++ b/docs/source/driver/duckdb.rst
@@ -72,6 +72,7 @@ ADBC support in DuckDB requires the driver manager.
if err != nil {
// handle error
}
+ defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
diff --git a/docs/source/driver/flight_sql.rst
b/docs/source/driver/flight_sql.rst
index aca95d86..7473a7cb 100644
--- a/docs/source/driver/flight_sql.rst
+++ b/docs/source/driver/flight_sql.rst
@@ -152,6 +152,7 @@ the :cpp:class:`AdbcDatabase`.
if err != nil {
// do something with the error
}
+ defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
diff --git a/docs/source/driver/postgresql.rst
b/docs/source/driver/postgresql.rst
index ddf9115d..c724a2c1 100644
--- a/docs/source/driver/postgresql.rst
+++ b/docs/source/driver/postgresql.rst
@@ -124,6 +124,7 @@ the :cpp:class:`AdbcDatabase`. This should be a
`connection URI
if err != nil {
// handle error
}
+ defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
diff --git a/docs/source/driver/snowflake.rst b/docs/source/driver/snowflake.rst
index 04023a62..bf445349 100644
--- a/docs/source/driver/snowflake.rst
+++ b/docs/source/driver/snowflake.rst
@@ -127,6 +127,7 @@ constructing the :cpp::class:`AdbcDatabase`.
if err != nil {
// handle error
}
+ defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
@@ -241,6 +242,7 @@ a listing).
if err != nil {
// handle error
}
+ defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
diff --git a/docs/source/driver/sqlite.rst b/docs/source/driver/sqlite.rst
index 30e7d32b..96bd7bbd 100644
--- a/docs/source/driver/sqlite.rst
+++ b/docs/source/driver/sqlite.rst
@@ -140,6 +140,7 @@ shared across all connections.
if err != nil {
// handle error
}
+ defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go
index 3fb61d69..71a75daf 100644
--- a/go/adbc/adbc.go
+++ b/go/adbc/adbc.go
@@ -329,6 +329,9 @@ type Driver interface {
type Database interface {
SetOptions(map[string]string) error
Open(ctx context.Context) (Connection, error)
+
+ // Close closes this database and releases any associated resources.
+ Close() error
}
type InfoCode uint32
diff --git a/go/adbc/driver/driverbase/database.go
b/go/adbc/driver/driverbase/database.go
index e3a96ff1..7f32510c 100644
--- a/go/adbc/driver/driverbase/database.go
+++ b/go/adbc/driver/driverbase/database.go
@@ -31,6 +31,7 @@ type DatabaseImpl interface {
adbc.GetSetOptions
Base() *DatabaseImplBase
Open(context.Context) (adbc.Connection, error)
+ Close() error
SetOptions(map[string]string) error
}
@@ -134,6 +135,10 @@ func (db *database) Open(ctx context.Context)
(adbc.Connection, error) {
return db.impl.Open(ctx)
}
+// Close closes the database by delegating to the wrapped DatabaseImpl and
+// returns the error it reports, if any.
+func (db *database) Close() error {
+ return db.impl.Close()
+}
+
func (db *database) SetLogger(logger *slog.Logger) {
if logger != nil {
db.impl.Base().Logger = logger
diff --git a/go/adbc/driver/driverbase/driver.go
b/go/adbc/driver/driverbase/driver.go
index c4767794..acd182f8 100644
--- a/go/adbc/driver/driverbase/driver.go
+++ b/go/adbc/driver/driverbase/driver.go
@@ -32,7 +32,7 @@ type DriverImpl interface {
NewDatabase(opts map[string]string) (adbc.Database, error)
}
-// DatabaseImplBase is a struct that provides default implementations of the
+// DriverImplBase is a struct that provides default implementations of the
// DriverImpl interface. It is meant to be used as a composite struct for a
// driver's DriverImpl implementation.
type DriverImplBase struct {
@@ -56,7 +56,7 @@ type driver struct {
impl DriverImpl
}
-// NewDatabase wraps a DriverImpl to create an adbc.Driver.
+// NewDriver wraps a DriverImpl to create an adbc.Driver.
func NewDriver(impl DriverImpl) adbc.Driver {
return &driver{impl}
}
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index a591f1ca..dfd1f6cf 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -90,6 +90,7 @@ func (suite *ServerBasedTests) TearDownTest() {
}
func (suite *ServerBasedTests) TearDownSuite() {
+ suite.NoError(suite.db.Close())
suite.db = nil
suite.s.Shutdown()
}
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go
b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 1619f8fa..dc7d207d 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -352,6 +352,7 @@ func (suite *DefaultDialOptionsTests) SetupSuite() {
func (suite *DefaultDialOptionsTests) TearDownSuite() {
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
+ suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
@@ -361,6 +362,7 @@ func (suite *DefaultDialOptionsTests)
TestMaxIncomingMessageSizeDefault() {
opts["adbc.flight.sql.client_option.with_max_msg_size"] = "1000000"
db, err := suite.Driver.NewDatabase(opts)
suite.NoError(err)
+ defer func() { suite.NoError(db.Close()) }() // closure: a bare defer would run db.Close() right here
cnxn, err := db.Open(suite.ctx)
suite.NoError(err)
@@ -505,6 +507,7 @@ func (suite *PartitionTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
+ suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
@@ -558,6 +561,7 @@ func (suite *StatementTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
+ suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
@@ -639,6 +643,7 @@ func (suite *HeaderTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
+ suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
@@ -842,6 +847,7 @@ func (suite *TLSTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
+ suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
@@ -863,6 +869,7 @@ func (suite *TLSTests) TestInvalidOptions() {
"adbc.flight.sql.client_option.tls_skip_verify": "false",
})
suite.Require().NoError(err)
+ defer func() { suite.NoError(db.Close()) }() // closure: a bare defer would run db.Close() right here
cnxn, err := db.Open(suite.ctx)
suite.Require().NoError(err)
@@ -912,6 +919,7 @@ func (suite *ConnectionTests) SetupSuite() {
}
func (suite *ConnectionTests) TearDownSuite() {
+ suite.NoError(suite.DB.Close())
suite.server.Shutdown()
suite.alloc.AssertSize(suite.T(), 0)
}
@@ -1009,6 +1017,7 @@ func (suite *DomainSocketTests) SetupSuite() {
func (suite *DomainSocketTests) TearDownSuite() {
suite.Require().NoError(suite.Stmt.Close())
suite.Require().NoError(suite.Cnxn.Close())
+ suite.NoError(suite.DB.Close())
suite.server.Shutdown()
suite.alloc.AssertSize(suite.T(), 0)
}
diff --git a/go/adbc/driver/flightsql/flightsql_database.go
b/go/adbc/driver/flightsql/flightsql_database.go
index 8b6ab2cc..f9537f50 100644
--- a/go/adbc/driver/flightsql/flightsql_database.go
+++ b/go/adbc/driver/flightsql/flightsql_database.go
@@ -332,6 +332,10 @@ func (d *databaseImpl) SetOptionDouble(key string, value
float64) error {
return d.DatabaseImplBase.SetOptionDouble(key, value)
}
+// Close implements driverbase.DatabaseImpl. It currently performs no cleanup
+// and always returns nil.
+// NOTE(review): confirm databaseImpl owns nothing that must be released here.
+func (d *databaseImpl) Close() error {
+ return nil
+}
+
func getFlightClient(ctx context.Context, loc string, d *databaseImpl)
(*flightsql.Client, error) {
authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()}
middleware := []flight.ClientMiddleware{
@@ -396,8 +400,8 @@ type support struct {
transactions bool
}
-func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
- cl, err := getFlightClient(ctx, impl.uri.String(), impl)
+func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
+ cl, err := getFlightClient(ctx, d.uri.String(), d)
if err != nil {
return nil, err
}
@@ -410,12 +414,12 @@ func (impl *databaseImpl) Open(ctx context.Context)
(adbc.Connection, error) {
return nil, adbc.Error{Msg:
fmt.Sprintf("Location must be a string, got %#v", uri), Code:
adbc.StatusInternal}
}
- cl, err := getFlightClient(context.Background(), uri,
impl)
+ cl, err := getFlightClient(context.Background(), uri, d)
if err != nil {
return nil, err
}
- cl.Alloc = impl.Alloc
+ cl.Alloc = d.Alloc
return cl, nil
}).
EvictedFunc(func(_, client interface{}) {
@@ -425,13 +429,13 @@ func (impl *databaseImpl) Open(ctx context.Context)
(adbc.Connection, error) {
var cnxnSupport support
- info, err := cl.GetSqlInfo(ctx,
[]flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, impl.timeout)
+ info, err := cl.GetSqlInfo(ctx,
[]flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, d.timeout)
// ignore this if it fails
if err == nil {
const int32code = 3
for _, endpoint := range info.Endpoint {
- rdr, err := doGet(ctx, cl, endpoint, cache,
impl.timeout)
+ rdr, err := doGet(ctx, cl, endpoint, cache, d.timeout)
if err != nil {
continue
}
@@ -465,8 +469,8 @@ func (impl *databaseImpl) Open(ctx context.Context)
(adbc.Connection, error) {
}
}
- return &cnxn{cl: cl, db: impl, clientCache: cache,
- hdrs: make(metadata.MD), timeouts: impl.timeout,
+ return &cnxn{cl: cl, db: d, clientCache: cache,
+ hdrs: make(metadata.MD), timeouts: d.timeout,
supportInfo: cnxnSupport}, nil
}
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go
b/go/adbc/driver/flightsql/flightsql_driver.go
index 0060c040..cc58a9e1 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -145,5 +145,6 @@ func (d *driverImpl) NewDatabase(opts map[string]string)
(adbc.Database, error)
if err := db.SetOptions(opts); err != nil {
return nil, err
}
+
return driverbase.NewDatabase(db), nil
}
diff --git a/go/adbc/driver/panicdummy/panicdummy_adbc.go
b/go/adbc/driver/panicdummy/panicdummy_adbc.go
index f0513cd6..95171591 100644
--- a/go/adbc/driver/panicdummy/panicdummy_adbc.go
+++ b/go/adbc/driver/panicdummy/panicdummy_adbc.go
@@ -66,6 +66,11 @@ func (d *database) Open(ctx context.Context)
(adbc.Connection, error) {
return &cnxn{}, nil
}
+// Close calls maybePanic("DatabaseClose"), so this dummy driver can be made to
+// panic on database close (presumably controlled by test configuration), and
+// otherwise returns nil.
+func (d *database) Close() error {
+ maybePanic("DatabaseClose")
+ return nil
+}
+
type cnxn struct{}
func (c *cnxn) SetOption(key, value string) error {
diff --git a/go/adbc/driver/snowflake/connection.go
b/go/adbc/driver/snowflake/connection.go
index 73c31604..e2f98487 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -714,16 +714,16 @@ func prepareTablesSQL(matchingCatalogNames []string,
catalog *string, dbSchema *
func prepareColumnsSQL(matchingCatalogNames []string, catalog *string,
dbSchema *string, tableName *string, columnName *string, tableType []string)
(string, []interface{}) {
prefixQuery := ""
- for _, catalog_name := range matchingCatalogNames {
+ for _, catalogName := range matchingCatalogNames {
if prefixQuery != "" {
prefixQuery += " UNION ALL "
}
prefixQuery += `SELECT T.table_type,
C.*
FROM
- "` + strings.ReplaceAll(catalog_name, "\"",
"\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
+ "` + strings.ReplaceAll(catalogName, "\"",
"\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
JOIN
- "` + strings.ReplaceAll(catalog_name, "\"",
"\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C
+ "` + strings.ReplaceAll(catalogName, "\"",
"\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C
ON
T.table_catalog = C.table_catalog
AND T.table_schema = C.table_schema
diff --git a/go/adbc/driver/snowflake/driver.go
b/go/adbc/driver/snowflake/driver.go
index db5efed4..3b9d72cc 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -202,5 +202,6 @@ func (d *driverImpl) NewDatabase(opts map[string]string)
(adbc.Database, error)
if err := db.SetOptions(opts); err != nil {
return nil, err
}
+
return driverbase.NewDatabase(db), nil
}
diff --git a/go/adbc/driver/snowflake/driver_test.go
b/go/adbc/driver/snowflake/driver_test.go
index 61e94486..a69a0b04 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -336,6 +336,7 @@ func (suite *SnowflakeTests) TearDownTest() {
}
func (suite *SnowflakeTests) TearDownSuite() {
+ suite.NoError(suite.db.Close())
suite.db = nil
}
@@ -464,21 +465,21 @@ func (suite *SnowflakeTests)
TestMetadataGetObjectsColumnsXdbc() {
xdbcDateTimeSub []string
}{
{
- "BASIC", //name
- []string{"int64s", "strings"}, //colNames
- []string{"1", "2"}, //positions
- []string{"NUMBER", "TEXT"}, //dataTypes
- []string{"", ""}, //comments
- []string{"9", "13"}, //xdbcDataType
- []string{"NUMBER", "TEXT"}, //xdbcTypeName
- []string{"-5", "12"}, //xdbcSqlDataType
- []string{"1", "1"}, //xdbcNullable
- []string{"YES", "YES"}, //xdbcIsNullable
- []string{"0", "0"}, //xdbcScale
- []string{"10", "0"}, //xdbcNumPrecRadix
- []string{"38", "16777216"}, //xdbcCharMaxLen
(xdbcPrecision)
- []string{"0", "16777216"}, //xdbcCharOctetLen
- []string{"-5", "12", "0"}, //xdbcDateTimeSub
+ "BASIC", // name
+ []string{"int64s", "strings"}, // colNames
+ []string{"1", "2"}, // positions
+ []string{"NUMBER", "TEXT"}, // dataTypes
+ []string{"", ""}, // comments
+ []string{"9", "13"}, // xdbcDataType
+ []string{"NUMBER", "TEXT"}, // xdbcTypeName
+ []string{"-5", "12"}, // xdbcSqlDataType
+ []string{"1", "1"}, // xdbcNullable
+ []string{"YES", "YES"}, // xdbcIsNullable
+ []string{"0", "0"}, // xdbcScale
+ []string{"10", "0"}, // xdbcNumPrecRadix
+ []string{"38", "16777216"}, // xdbcCharMaxLen
(xdbcPrecision)
+ []string{"0", "16777216"}, // xdbcCharOctetLen
+ []string{"-5", "12", "0"}, // xdbcDateTimeSub
},
}
@@ -576,20 +577,20 @@ func (suite *SnowflakeTests)
TestMetadataGetObjectsColumnsXdbc() {
suite.False(rdr.Next())
suite.True(foundExpected)
- suite.Equal(tt.colnames, colnames)
//colNames
- suite.Equal(tt.positions, positions)
//positions
- suite.Equal(tt.comments, comments)
//comments
- suite.Equal(tt.xdbcDataType, xdbcDataTypes)
//xdbcDataType
- suite.Equal(tt.dataTypes, dataTypes)
//dataTypes
- suite.Equal(tt.xdbcTypeName, xdbcTypeNames)
//xdbcTypeName
- suite.Equal(tt.xdbcCharMaxLen, xdbcCharMaxLens)
//xdbcCharMaxLen
- suite.Equal(tt.xdbcScale, xdbcScales)
//xdbcScale
- suite.Equal(tt.xdbcNumPrecRadix, xdbcNumPrecRadixs)
//xdbcNumPrecRadix
- suite.Equal(tt.xdbcNullable, xdbcNullables)
//xdbcNullable
- suite.Equal(tt.xdbcSqlDataType, xdbcSqlDataTypes)
//xdbcSqlDataType
- suite.Equal(tt.xdbcDateTimeSub, xdbcDateTimeSub)
//xdbcDateTimeSub
- suite.Equal(tt.xdbcCharOctetLen, xdbcCharOctetLen)
//xdbcCharOctetLen
- suite.Equal(tt.xdbcIsNullable, xdbcIsNullables)
//xdbcIsNullable
+ suite.Equal(tt.colnames, colnames) //
colNames
+ suite.Equal(tt.positions, positions) //
positions
+ suite.Equal(tt.comments, comments) //
comments
+ suite.Equal(tt.xdbcDataType, xdbcDataTypes) //
xdbcDataType
+ suite.Equal(tt.dataTypes, dataTypes) //
dataTypes
+ suite.Equal(tt.xdbcTypeName, xdbcTypeNames) //
xdbcTypeName
+ suite.Equal(tt.xdbcCharMaxLen, xdbcCharMaxLens) //
xdbcCharMaxLen
+ suite.Equal(tt.xdbcScale, xdbcScales) //
xdbcScale
+ suite.Equal(tt.xdbcNumPrecRadix, xdbcNumPrecRadixs) //
xdbcNumPrecRadix
+ suite.Equal(tt.xdbcNullable, xdbcNullables) //
xdbcNullable
+ suite.Equal(tt.xdbcSqlDataType, xdbcSqlDataTypes) //
xdbcSqlDataType
+ suite.Equal(tt.xdbcDateTimeSub, xdbcDateTimeSub) //
xdbcDateTimeSub
+ suite.Equal(tt.xdbcCharOctetLen, xdbcCharOctetLen) //
xdbcCharOctetLen
+ suite.Equal(tt.xdbcIsNullable, xdbcIsNullables) //
xdbcIsNullable
})
}
@@ -605,6 +606,7 @@ func (suite *SnowflakeTests) TestNewDatabaseGetSetOptions()
{
})
suite.NoError(err)
suite.NotNil(db)
+ defer func() { suite.NoError(db.Close()) }() // closure: a bare defer would run db.Close() right here
getSetDB, ok := db.(adbc.GetSetOptions)
suite.True(ok)
@@ -862,6 +864,7 @@ func ConnectWithJwt(uri, keyValue, passcode string) {
if err != nil {
panic(err)
}
+ defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
@@ -912,6 +915,7 @@ func (suite *SnowflakeTests) TestJwtPrivateKey() {
opts[driver.OptionJwtPrivateKey] = keyFile
db, err := suite.driver.NewDatabase(opts)
suite.NoError(err)
+ defer db.Close()
cnxn, err := db.Open(suite.ctx)
suite.NoError(err)
defer cnxn.Close()
diff --git a/go/adbc/driver/snowflake/snowflake_database.go
b/go/adbc/driver/snowflake/snowflake_database.go
index 45e3aab4..7b76fa5a 100644
--- a/go/adbc/driver/snowflake/snowflake_database.go
+++ b/go/adbc/driver/snowflake/snowflake_database.go
@@ -466,3 +466,7 @@ func (d *databaseImpl) Open(ctx context.Context)
(adbc.Connection, error) {
useHighPrecision: d.useHighPrecision,
}, nil
}
+
+// Close implements driverbase.DatabaseImpl. It currently performs no cleanup
+// and always returns nil.
+func (d *databaseImpl) Close() error {
+ return nil
+}
diff --git a/go/adbc/drivermgr/wrapper.go b/go/adbc/drivermgr/wrapper.go
index 07bb94b8..63fb9ee9 100644
--- a/go/adbc/drivermgr/wrapper.go
+++ b/go/adbc/drivermgr/wrapper.go
@@ -39,7 +39,7 @@ package drivermgr
import "C"
import (
"context"
- "runtime"
+ "sync"
"unsafe"
"github.com/apache/arrow-adbc/go/adbc"
@@ -100,27 +100,15 @@ func (d Driver) NewDatabase(opts map[string]string)
(adbc.Database, error) {
return nil, errOut
}
- runtime.SetFinalizer(db, func(db *Database) {
- if db.db != nil {
- var err C.struct_AdbcError
- code := adbc.Status(C.AdbcDatabaseRelease(db.db, &err))
- if code != adbc.StatusOK {
- panic(toAdbcError(code, &err))
- }
- }
-
- for _, o := range db.options {
- C.free(unsafe.Pointer(o.key))
- C.free(unsafe.Pointer(o.val))
- }
- })
-
return db, nil
}
type Database struct {
options map[string]option
db *C.struct_AdbcDatabase
+
+ mu sync.Mutex // protects following fields
+ closed bool // set by the first Close call; later Close calls return nil
}
func toAdbcError(code adbc.Status, e *C.struct_AdbcError) error {
@@ -182,6 +170,41 @@ func (d *Database) Open(context.Context) (adbc.Connection,
error) {
return &cnxn{conn: &c}, nil
}
+// Close releases the underlying AdbcDatabase and frees the C strings that
+// back the stored options. Only the first call does any work; later calls
+// return nil.
+//
+// The database is released before the option strings are freed (the same
+// order the previous finalizer used), and the strings are freed even if the
+// release fails, so they are never leaked.
+func (d *Database) Close() error {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+
+	if d.closed {
+		return nil
+	}
+	d.closed = true
+
+	defer func() {
+		// Drop each entry as it is freed so no dangling C pointers remain.
+		for k, o := range d.options {
+			C.free(unsafe.Pointer(o.key))
+			C.free(unsafe.Pointer(o.val))
+			delete(d.options, k)
+		}
+	}()
+
+	if d.db != nil {
+		var err C.struct_AdbcError
+		code := adbc.Status(C.AdbcDatabaseRelease(d.db, &err))
+		if code != adbc.StatusOK {
+			return toAdbcError(code, &err)
+		}
+	}
+
+	return nil
+}
+
func getRdr(out *C.struct_ArrowArrayStream) (array.RecordReader, error) {
rdr, err :=
cdata.ImportCRecordReader((*cdata.CArrowArrayStream)(unsafe.Pointer(out)), nil)
if err != nil {
diff --git a/go/adbc/drivermgr/wrapper_sqlite_test.go
b/go/adbc/drivermgr/wrapper_sqlite_test.go
index c33adf27..af307a08 100644
--- a/go/adbc/drivermgr/wrapper_sqlite_test.go
+++ b/go/adbc/drivermgr/wrapper_sqlite_test.go
@@ -74,6 +74,10 @@ func (dm *DriverMgrSuite) SetupSuite() {
dm.Equal(int64(1), nrows)
}
+// TearDownSuite closes the suite-wide database dm.db and asserts that Close
+// succeeds.
+func (dm *DriverMgrSuite) TearDownSuite() {
+ dm.NoError(dm.db.Close())
+}
+
func (dm *DriverMgrSuite) SetupTest() {
cnxn, err := dm.db.Open(dm.ctx)
dm.Require().NoError(err)
@@ -597,6 +601,7 @@ func TestDriverMgrCustomInitFunc(t *testing.T) {
cnxn, err := db.Open(context.Background())
assert.NoError(t, err)
require.NoError(t, cnxn.Close())
+ require.NoError(t, db.Close())
// set invalid entrypoint
drv = drivermgr.Driver{}
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 4b7008ea..901d164e 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -591,10 +591,15 @@ func {{.Prefix}}DatabaseRelease(db
*C.struct_AdbcDatabase, err *C.struct_AdbcErr
h := (*(*cgo.Handle)(db.private_data))
cdb := h.Value().(*cDatabase)
- cdb.db = nil
+ if cdb.db != nil {
+ cdb.db.Close()
+ cdb.db = nil
+ }
cdb.opts = nil
- C.free(unsafe.Pointer(db.private_data))
- db.private_data = nil
+ if db.private_data != nil {
+ C.free(unsafe.Pointer(db.private_data))
+ db.private_data = nil
+ }
h.Delete()
// manually trigger GC for two reasons:
// 1. ASAN expects the release callback to be called before
diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go
index 2847274c..d57a91b7 100644
--- a/go/adbc/pkg/flightsql/driver.go
+++ b/go/adbc/pkg/flightsql/driver.go
@@ -594,10 +594,15 @@ func FlightSQLDatabaseRelease(db *C.struct_AdbcDatabase,
err *C.struct_AdbcError
h := (*(*cgo.Handle)(db.private_data))
cdb := h.Value().(*cDatabase)
- cdb.db = nil
+ if cdb.db != nil {
+ cdb.db.Close()
+ cdb.db = nil
+ }
cdb.opts = nil
- C.free(unsafe.Pointer(db.private_data))
- db.private_data = nil
+ if db.private_data != nil {
+ C.free(unsafe.Pointer(db.private_data))
+ db.private_data = nil
+ }
h.Delete()
// manually trigger GC for two reasons:
// 1. ASAN expects the release callback to be called before
diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go
index 399d0edc..fbaa5204 100644
--- a/go/adbc/pkg/panicdummy/driver.go
+++ b/go/adbc/pkg/panicdummy/driver.go
@@ -594,6 +594,10 @@ func PanicDummyDatabaseRelease(db *C.struct_AdbcDatabase,
err *C.struct_AdbcErro
h := (*(*cgo.Handle)(db.private_data))
cdb := h.Value().(*cDatabase)
- cdb.db = nil
+ // cdb.db may be nil here; guard it the same way _tmpl/driver.go.tmpl does.
+ if cdb.db != nil {
+ cdb.db.Close()
+ cdb.db = nil
+ }
cdb.opts = nil
C.free(unsafe.Pointer(db.private_data))
diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go
index b5910181..6e2d3bac 100644
--- a/go/adbc/pkg/snowflake/driver.go
+++ b/go/adbc/pkg/snowflake/driver.go
@@ -594,10 +594,15 @@ func SnowflakeDatabaseRelease(db *C.struct_AdbcDatabase,
err *C.struct_AdbcError
h := (*(*cgo.Handle)(db.private_data))
cdb := h.Value().(*cDatabase)
- cdb.db = nil
+ if cdb.db != nil {
+ cdb.db.Close()
+ cdb.db = nil
+ }
cdb.opts = nil
- C.free(unsafe.Pointer(db.private_data))
- db.private_data = nil
+ if db.private_data != nil {
+ C.free(unsafe.Pointer(db.private_data))
+ db.private_data = nil
+ }
h.Delete()
// manually trigger GC for two reasons:
// 1. ASAN expects the release callback to be called before
diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go
index 8925f725..19222807 100644
--- a/go/adbc/validation/validation.go
+++ b/go/adbc/validation/validation.go
@@ -100,6 +100,7 @@ func (d *DatabaseTests) TestNewDatabase() {
d.NoError(err)
d.NotNil(db)
d.Implements((*adbc.Database)(nil), db)
+ d.NoError(db.Close())
}
type ConnectionTests struct {
@@ -121,6 +122,7 @@ func (c *ConnectionTests) SetupTest() {
func (c *ConnectionTests) TearDownTest() {
c.Quirks.TearDownDriver(c.T(), c.Driver)
c.Driver = nil
+ c.NoError(c.DB.Close())
c.DB = nil
}
@@ -514,6 +516,7 @@ func (s *StatementTests) TearDownTest() {
s.Require().NoError(s.Cnxn.Close())
s.Quirks.TearDownDriver(s.T(), s.Driver)
s.Cnxn = nil
+ s.NoError(s.DB.Close())
s.DB = nil
s.Driver = nil
}