This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 7b1510aa fix(go/adbc/pkg): follow CGO rules properly (#902)
7b1510aa is described below

commit 7b1510aa087bfa31c1b92e0264d8c13c219d6a4a
Author: David Li <[email protected]>
AuthorDate: Thu Jul 13 19:30:41 2023 -0400

    fix(go/adbc/pkg): follow CGO rules properly (#902)
    
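    The C shims now zero-initialize ArrowArrayStream/ArrowSchema/AdbcPartitions
    out-params before calling into Go, the Go drivers release record readers
    after exporting them, and AdbcStatementNew rejects a statement whose
    private_data is already set. Together these ensure Go never overwrites C
    memory that merely looks like a Go pointer, which can confuse the GC's
    write barrier. A C++ regression test feeds such garbage values.
    
    Depends on apache/arrow#36670.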
    
    Fixes #729.
---
 c/driver/flightsql/sqlite_flightsql_test.cc | 106 ++++++++++++++++++++++++++++
 go/adbc/go.mod                              |   2 +-
 go/adbc/go.sum                              |   6 +-
 go/adbc/pkg/_tmpl/driver.go.tmpl            |  10 +++
 go/adbc/pkg/_tmpl/utils.c.tmpl              |  11 +++
 go/adbc/pkg/flightsql/driver.go             |  10 +++
 go/adbc/pkg/flightsql/utils.c               |  11 +++
 go/adbc/pkg/panicdummy/driver.go            |  10 +++
 go/adbc/pkg/panicdummy/utils.c              |  11 +++
 go/adbc/pkg/snowflake/driver.go             |  10 +++
 go/adbc/pkg/snowflake/utils.c               |  11 +++
 11 files changed, 193 insertions(+), 5 deletions(-)

diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc 
b/c/driver/flightsql/sqlite_flightsql_test.cc
index 2bf0441e..96d448a0 100644
--- a/c/driver/flightsql/sqlite_flightsql_test.cc
+++ b/c/driver/flightsql/sqlite_flightsql_test.cc
@@ -15,15 +15,21 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include <chrono>
+#include <random>
+#include <thread>
+
 #include <adbc.h>
 #include <gmock/gmock-matchers.h>
 #include <gtest/gtest-matchers.h>
 #include <gtest/gtest-param-test.h>
 #include <gtest/gtest.h>
 #include <nanoarrow/nanoarrow.h>
+
 #include "validation/adbc_validation.h"
 #include "validation/adbc_validation_util.h"
 
+using adbc_validation::IsOkErrno;
 using adbc_validation::IsOkStatus;
 
 #define CHECK_OK(EXPR)                                              \
@@ -103,6 +109,106 @@ class SqliteFlightSqlTest : public ::testing::Test, 
public adbc_validation::Data
 };
 ADBCV_TEST_DATABASE(SqliteFlightSqlTest)
 
+TEST_F(SqliteFlightSqlTest, TestGarbageInput) {
+  // Regression test for https://github.com/apache/arrow-adbc/issues/729: caller-provided structs pre-filled with values that look like Go heap pointers.
+
+  // 0xc000000000 is the base of the Go heap.  Go's write barriers ask
+  // the GC to mark both the pointer being written, and the pointer
+  // being *overwritten*.  So if Go overwrites a value in a C
+  // structure that looks like a Go pointer, the GC may get confused
+  // and error.
+  void* bad_pointer = reinterpret_cast<void*>(uintptr_t(0xc000000240));  // just above that base address
+
+  // ADBC functions are expected not to blindly overwrite an
+  // already-allocated value (callers are expected to zero-initialize), so each *New call below on a garbage-filled struct must fail.
+  database.private_data = bad_pointer;
+  database.private_driver = reinterpret_cast<struct AdbcDriver*>(bad_pointer);
+  ASSERT_THAT(AdbcDatabaseNew(&database, &error), 
::testing::Not(IsOkStatus(&error)));
+
+  std::memset(&database, 0, sizeof(database));  // NOTE(review): <cstring> is not directly included above; may only be transitive
+  ASSERT_THAT(AdbcDatabaseNew(&database, &error), IsOkStatus(&error));
+  ASSERT_THAT(quirks()->SetupDatabase(&database, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcDatabaseInit(&database, &error), IsOkStatus(&error));
+
+  struct AdbcConnection connection;  // only private_data/private_driver are set before the first New
+  connection.private_data = bad_pointer;
+  connection.private_driver = reinterpret_cast<struct 
AdbcDriver*>(bad_pointer);
+  ASSERT_THAT(AdbcConnectionNew(&connection, &error), 
::testing::Not(IsOkStatus(&error)));
+
+  std::memset(&connection, 0, sizeof(connection));  // once zeroed, the normal path must succeed
+  ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), 
IsOkStatus(&error));
+
+  struct AdbcStatement statement;  // only private_data/private_driver are set
+  statement.private_data = bad_pointer;
+  statement.private_driver = reinterpret_cast<struct AdbcDriver*>(bad_pointer);
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
+              ::testing::Not(IsOkStatus(&error)));  // private_data is non-null, so New must refuse
+
+  // This needs to happen in parallel since we need to trigger the
+  // write barrier buffer, which means we need to trigger a GC.  The
+  // Go FFI bridge deterministically triggers GC on Release calls.
+
+  auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);  // keep stressing for ~5 seconds
+  while (std::chrono::steady_clock::now() < deadline) {  // one batch of worker threads per round
+    std::vector<std::thread> threads;  // NOTE(review): <vector> is not directly included above
+    std::random_device rd;
+    for (int i = 0; i < 23; i++) {
+      auto seed = rd();  // seed each worker independently
+      threads.emplace_back([&, seed]() {  // NOTE(review): workers share `connection` and `error` by reference; a failing call would race on `error` — consider a per-thread AdbcError
+        std::mt19937 gen(seed);  // random addresses drawn from [0xc000000000, 0xc000002000]
+        std::uniform_int_distribution<int64_t> dist(0xc000000000L, 
0xc000002000L);
+        for (int i = 0; i < 23; i++) {
+          void* bad_pointer = reinterpret_cast<void*>(uintptr_t(dist(gen)));  // shadows the outer bad_pointer
+
+          struct AdbcStatement statement;
+          std::memset(&statement, 0, sizeof(statement));  // correctly zero-initialized this time
+          ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
+                      IsOkStatus(&error));
+
+          ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 1", &error),
+                      IsOkStatus(&error));
+          // Out-param: callers need not zero-initialize it, so seed it with garbage that the driver must overwrite
+          struct ArrowArrayStream stream;
+          stream.private_data = bad_pointer;
+          stream.release =
+              reinterpret_cast<void (*)(struct 
ArrowArrayStream*)>(bad_pointer);
+          ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &stream, nullptr, 
&error),
+                      IsOkStatus(&error));  // must succeed despite the garbage in `stream`
+
+          struct ArrowSchema schema;  // zeroed, then seeded with garbage pointers
+          std::memset(&schema, 0, sizeof(schema));
+          schema.name = reinterpret_cast<const char*>(bad_pointer);
+          schema.format = reinterpret_cast<const char*>(bad_pointer);
+          schema.private_data = bad_pointer;
+          ASSERT_THAT(stream.get_schema(&stream, &schema), IsOkErrno());  // expected to succeed despite the garbage in `schema`
+
+          while (true) {  // drain until get_next yields an array with release == nullptr (end of stream)
+            struct ArrowArray array;
+            array.private_data = bad_pointer;  // other fields deliberately left uninitialized
+            ASSERT_THAT(stream.get_next(&stream, &array), IsOkErrno());
+            if (array.release) {
+              array.release(&array);
+            } else {
+              break;
+            }
+          }
+
+          schema.release(&schema);
+          stream.release(&stream);
+          ASSERT_THAT(AdbcStatementRelease(&statement, &error), 
IsOkStatus(&error));
+        }
+      });
+    }
+    for (auto& thread : threads) {
+      thread.join();
+    }
+  }
+
+  ASSERT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
+}
+
 class SqliteFlightSqlConnectionTest : public ::testing::Test,
                                       public adbc_validation::ConnectionTest {
  public:
diff --git a/go/adbc/go.mod b/go/adbc/go.mod
index e4496c24..1a412b22 100644
--- a/go/adbc/go.mod
+++ b/go/adbc/go.mod
@@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc
 go 1.18
 
 require (
-       github.com/apache/arrow/go/v13 v13.0.0-20230710202504-70f447636553
+       github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355
        github.com/bluele/gcache v0.0.2
        github.com/google/uuid v1.3.0
        github.com/snowflakedb/gosnowflake v1.6.21
diff --git a/go/adbc/go.sum b/go/adbc/go.sum
index c8b128ec..70654f96 100644
--- a/go/adbc/go.sum
+++ b/go/adbc/go.sum
@@ -16,10 +16,8 @@ github.com/andybalholm/brotli v1.0.5 
h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/
 github.com/andybalholm/brotli v1.0.5/go.mod 
h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
 github.com/apache/arrow/go/v12 v12.0.0 
h1:xtZE63VWl7qLdB0JObIXvvhGjoVNrQ9ciIHG2OK5cmc=
 github.com/apache/arrow/go/v12 v12.0.0/go.mod 
h1:d+tV/eHZZ7Dz7RPrFKtPK02tpr+c9/PEd/zm8mDS9Vg=
-github.com/apache/arrow/go/v13 v13.0.0-20230620164925-94af6c3c9646 
h1:hLcsUn9hiiD7jDfJDKOe1tBfOL5v0wgrya5S8XXqzLw=
-github.com/apache/arrow/go/v13 v13.0.0-20230620164925-94af6c3c9646/go.mod 
h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc=
-github.com/apache/arrow/go/v13 v13.0.0-20230710202504-70f447636553 
h1:LV3nIWJ2254APRpYAcMxWbxoQwt66gnrkZ5NaDs1IPI=
-github.com/apache/arrow/go/v13 v13.0.0-20230710202504-70f447636553/go.mod 
h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc=
+github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355 
h1:QuXqLb2HzL5EjY99fFp+iG9NagAruvQIbU/2++x+2VY=
+github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355/go.mod 
h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc=
 github.com/apache/thrift v0.16.0 
h1:qEy6UW60iVOlUy+b9ZR0d5WzUWYGOo4HfopoyBaNmoY=
 github.com/apache/thrift v0.16.0/go.mod 
h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU=
 github.com/aws/aws-sdk-go-v2 v1.18.0 
h1:882kkTpSFhdgYRKVZ/VCgf7sd0ru57p2JCxz4/oN5RY=
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 03a94c02..248800f7 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -444,6 +444,7 @@ func {{.Prefix}}ConnectionGetInfo(cnxn 
*C.struct_AdbcConnection, codes *C.uint32
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
 
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -489,6 +490,7 @@ func {{.Prefix}}ConnectionGetObjects(cnxn 
*C.struct_AdbcConnection, depth C.int,
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -529,6 +531,7 @@ func {{.Prefix}}ConnectionGetTableTypes(cnxn 
*C.struct_AdbcConnection, out *C.st
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -549,6 +552,7 @@ func {{.Prefix}}ConnectionReadPartition(cnxn 
*C.struct_AdbcConnection, serialize
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -612,6 +616,11 @@ func {{.Prefix}}StatementNew(cnxn 
*C.struct_AdbcConnection, stmt *C.struct_AdbcS
                setErr(err, "AdbcStatementNew: Go panicked, driver is in 
unknown state")
                return C.ADBC_STATUS_INTERNAL
        }
+       if stmt.private_data != nil {
+               setErr(err, "AdbcStatementNew: statement already allocated")
+               return C.ADBC_STATUS_INVALID_STATE
+       }
+
        conn := checkConnInit(cnxn, err, "AdbcStatementNew")
        if conn == nil {
                return C.ADBC_STATUS_INVALID_STATE
@@ -711,6 +720,7 @@ func {{.Prefix}}StatementExecuteQuery(stmt 
*C.struct_AdbcStatement, out *C.struc
                        *affected = C.int64_t(n)
                }
 
+               defer rdr.Release()
                cdata.ExportRecordReader(rdr, toCdataStream(out))
        }
        return C.ADBC_STATUS_OK
diff --git a/go/adbc/pkg/_tmpl/utils.c.tmpl b/go/adbc/pkg/_tmpl/utils.c.tmpl
index 29d19bc5..38222875 100644
--- a/go/adbc/pkg/_tmpl/utils.c.tmpl
+++ b/go/adbc/pkg/_tmpl/utils.c.tmpl
@@ -21,6 +21,8 @@
 
 #include "utils.h"
 
+#include <string.h>
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -74,6 +76,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* 
connection,
                                      uint32_t* info_codes, size_t 
info_codes_length,
                                      struct ArrowArrayStream* out,
                                      struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return {{.Prefix}}ConnectionGetInfo(connection, info_codes, 
info_codes_length, out, error);
 }
 
@@ -83,6 +86,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct 
AdbcConnection* connection, int d
                                         const char* column_name,
                                         struct ArrowArrayStream* out,
                                         struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return {{.Prefix}}ConnectionGetObjects(connection, depth, catalog, 
db_schema, table_name,
                                     table_type, column_name, out, error);
 }
@@ -92,6 +96,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct 
AdbcConnection* connection,
                                             const char* table_name,
                                             struct ArrowSchema* schema,
                                             struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));
   return {{.Prefix}}ConnectionGetTableSchema(connection, catalog, db_schema, 
table_name,
                                         schema, error);
 }
@@ -99,6 +104,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct 
AdbcConnection* connection,
 AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
                                            struct ArrowArrayStream* out,
                                            struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));  // zero the caller's possibly-uninitialized out-param before the Go side touches it
   return {{.Prefix}}ConnectionGetTableTypes(connection, out, error);
 }
 
@@ -107,6 +113,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct 
AdbcConnection* connection,
                                            size_t serialized_length,
                                            struct ArrowArrayStream* out,
                                            struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return {{.Prefix}}ConnectionReadPartition(connection, serialized_partition,
                                        serialized_length, out, error);
 }
@@ -136,6 +143,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct 
AdbcStatement* statement,
                                          struct ArrowArrayStream* out,
                                          int64_t* rows_affected,
                                          struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return {{.Prefix}}StatementExecuteQuery(statement, out, rows_affected, 
error);
 }
 
@@ -170,6 +178,7 @@ AdbcStatusCode AdbcStatementBindStream(struct 
AdbcStatement* statement,
 AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
                                                struct ArrowSchema* schema,
                                                struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));  // zero the caller's possibly-uninitialized out-param before the Go side touches it
   return {{.Prefix}}StatementGetParameterSchema(statement, schema, error);
 }
 
@@ -183,6 +192,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct 
AdbcStatement* statement,
                                               struct AdbcPartitions* 
partitions,
                                               int64_t* rows_affected,
                                               struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));
+  if (partitions) memset(partitions, 0, sizeof(*partitions));
   return {{.Prefix}}StatementExecutePartitions(statement, schema, partitions, 
rows_affected,
                                           error);
 }
diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go
index 6d5cf75b..2936f90e 100644
--- a/go/adbc/pkg/flightsql/driver.go
+++ b/go/adbc/pkg/flightsql/driver.go
@@ -448,6 +448,7 @@ func FlightSQLConnectionGetInfo(cnxn 
*C.struct_AdbcConnection, codes *C.uint32_t
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
 
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -493,6 +494,7 @@ func FlightSQLConnectionGetObjects(cnxn 
*C.struct_AdbcConnection, depth C.int, c
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -533,6 +535,7 @@ func FlightSQLConnectionGetTableTypes(cnxn 
*C.struct_AdbcConnection, out *C.stru
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -553,6 +556,7 @@ func FlightSQLConnectionReadPartition(cnxn 
*C.struct_AdbcConnection, serialized
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -616,6 +620,11 @@ func FlightSQLStatementNew(cnxn *C.struct_AdbcConnection, 
stmt *C.struct_AdbcSta
                setErr(err, "AdbcStatementNew: Go panicked, driver is in 
unknown state")
                return C.ADBC_STATUS_INTERNAL
        }
+       if stmt.private_data != nil {
+               setErr(err, "AdbcStatementNew: statement already allocated")
+               return C.ADBC_STATUS_INVALID_STATE
+       }
+
        conn := checkConnInit(cnxn, err, "AdbcStatementNew")
        if conn == nil {
                return C.ADBC_STATUS_INVALID_STATE
@@ -715,6 +724,7 @@ func FlightSQLStatementExecuteQuery(stmt 
*C.struct_AdbcStatement, out *C.struct_
                        *affected = C.int64_t(n)
                }
 
+               defer rdr.Release()
                cdata.ExportRecordReader(rdr, toCdataStream(out))
        }
        return C.ADBC_STATUS_OK
diff --git a/go/adbc/pkg/flightsql/utils.c b/go/adbc/pkg/flightsql/utils.c
index 3d3d89c5..41777a98 100644
--- a/go/adbc/pkg/flightsql/utils.c
+++ b/go/adbc/pkg/flightsql/utils.c
@@ -23,6 +23,8 @@
 
 #include "utils.h"
 
+#include <string.h>
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -76,6 +78,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* 
connection,
                                      uint32_t* info_codes, size_t 
info_codes_length,
                                      struct ArrowArrayStream* out,
                                      struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return FlightSQLConnectionGetInfo(connection, info_codes, info_codes_length, 
out,
                                     error);
 }
@@ -86,6 +89,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct 
AdbcConnection* connection, int d
                                         const char* column_name,
                                         struct ArrowArrayStream* out,
                                         struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return FlightSQLConnectionGetObjects(connection, depth, catalog, db_schema, 
table_name,
                                        table_type, column_name, out, error);
 }
@@ -95,6 +99,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct 
AdbcConnection* connection,
                                             const char* table_name,
                                             struct ArrowSchema* schema,
                                             struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));
   return FlightSQLConnectionGetTableSchema(connection, catalog, db_schema, 
table_name,
                                            schema, error);
 }
@@ -102,6 +107,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct 
AdbcConnection* connection,
 AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
                                            struct ArrowArrayStream* out,
                                            struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));  // zero the caller's possibly-uninitialized out-param before the Go side touches it
   return FlightSQLConnectionGetTableTypes(connection, out, error);
 }
 
@@ -110,6 +116,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct 
AdbcConnection* connection,
                                            size_t serialized_length,
                                            struct ArrowArrayStream* out,
                                            struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return FlightSQLConnectionReadPartition(connection, serialized_partition,
                                           serialized_length, out, error);
 }
@@ -139,6 +146,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct 
AdbcStatement* statement,
                                          struct ArrowArrayStream* out,
                                          int64_t* rows_affected,
                                          struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return FlightSQLStatementExecuteQuery(statement, out, rows_affected, error);
 }
 
@@ -173,6 +181,7 @@ AdbcStatusCode AdbcStatementBindStream(struct 
AdbcStatement* statement,
 AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
                                                struct ArrowSchema* schema,
                                                struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));  // zero the caller's possibly-uninitialized out-param before the Go side touches it
   return FlightSQLStatementGetParameterSchema(statement, schema, error);
 }
 
@@ -186,6 +195,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct 
AdbcStatement* statement,
                                               struct AdbcPartitions* 
partitions,
                                               int64_t* rows_affected,
                                               struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));
+  if (partitions) memset(partitions, 0, sizeof(*partitions));
   return FlightSQLStatementExecutePartitions(statement, schema, partitions, 
rows_affected,
                                              error);
 }
diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go
index 374c3cb8..33a4f984 100644
--- a/go/adbc/pkg/panicdummy/driver.go
+++ b/go/adbc/pkg/panicdummy/driver.go
@@ -448,6 +448,7 @@ func PanicDummyConnectionGetInfo(cnxn 
*C.struct_AdbcConnection, codes *C.uint32_
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
 
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -493,6 +494,7 @@ func PanicDummyConnectionGetObjects(cnxn 
*C.struct_AdbcConnection, depth C.int,
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -533,6 +535,7 @@ func PanicDummyConnectionGetTableTypes(cnxn 
*C.struct_AdbcConnection, out *C.str
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -553,6 +556,7 @@ func PanicDummyConnectionReadPartition(cnxn 
*C.struct_AdbcConnection, serialized
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -616,6 +620,11 @@ func PanicDummyStatementNew(cnxn *C.struct_AdbcConnection, 
stmt *C.struct_AdbcSt
                setErr(err, "AdbcStatementNew: Go panicked, driver is in 
unknown state")
                return C.ADBC_STATUS_INTERNAL
        }
+       if stmt.private_data != nil {
+               setErr(err, "AdbcStatementNew: statement already allocated")
+               return C.ADBC_STATUS_INVALID_STATE
+       }
+
        conn := checkConnInit(cnxn, err, "AdbcStatementNew")
        if conn == nil {
                return C.ADBC_STATUS_INVALID_STATE
@@ -715,6 +724,7 @@ func PanicDummyStatementExecuteQuery(stmt 
*C.struct_AdbcStatement, out *C.struct
                        *affected = C.int64_t(n)
                }
 
+               defer rdr.Release()
                cdata.ExportRecordReader(rdr, toCdataStream(out))
        }
        return C.ADBC_STATUS_OK
diff --git a/go/adbc/pkg/panicdummy/utils.c b/go/adbc/pkg/panicdummy/utils.c
index 5978aaa5..d0a29366 100644
--- a/go/adbc/pkg/panicdummy/utils.c
+++ b/go/adbc/pkg/panicdummy/utils.c
@@ -23,6 +23,8 @@
 
 #include "utils.h"
 
+#include <string.h>
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -76,6 +78,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* 
connection,
                                      uint32_t* info_codes, size_t 
info_codes_length,
                                      struct ArrowArrayStream* out,
                                      struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return PanicDummyConnectionGetInfo(connection, info_codes, 
info_codes_length, out,
                                      error);
 }
@@ -86,6 +89,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct 
AdbcConnection* connection, int d
                                         const char* column_name,
                                         struct ArrowArrayStream* out,
                                         struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return PanicDummyConnectionGetObjects(connection, depth, catalog, db_schema, 
table_name,
                                         table_type, column_name, out, error);
 }
@@ -95,6 +99,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct 
AdbcConnection* connection,
                                             const char* table_name,
                                             struct ArrowSchema* schema,
                                             struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));
   return PanicDummyConnectionGetTableSchema(connection, catalog, db_schema, 
table_name,
                                             schema, error);
 }
@@ -102,6 +107,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct 
AdbcConnection* connection,
 AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
                                            struct ArrowArrayStream* out,
                                            struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));  // zero the caller's possibly-uninitialized out-param before the Go side touches it
   return PanicDummyConnectionGetTableTypes(connection, out, error);
 }
 
@@ -110,6 +116,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct 
AdbcConnection* connection,
                                            size_t serialized_length,
                                            struct ArrowArrayStream* out,
                                            struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return PanicDummyConnectionReadPartition(connection, serialized_partition,
                                            serialized_length, out, error);
 }
@@ -139,6 +146,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct 
AdbcStatement* statement,
                                          struct ArrowArrayStream* out,
                                          int64_t* rows_affected,
                                          struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return PanicDummyStatementExecuteQuery(statement, out, rows_affected, error);
 }
 
@@ -173,6 +181,7 @@ AdbcStatusCode AdbcStatementBindStream(struct 
AdbcStatement* statement,
 AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
                                                struct ArrowSchema* schema,
                                                struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));  // zero the caller's possibly-uninitialized out-param before the Go side touches it
   return PanicDummyStatementGetParameterSchema(statement, schema, error);
 }
 
@@ -186,6 +195,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct 
AdbcStatement* statement,
                                               struct AdbcPartitions* 
partitions,
                                               int64_t* rows_affected,
                                               struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));
+  if (partitions) memset(partitions, 0, sizeof(*partitions));
   return PanicDummyStatementExecutePartitions(statement, schema, partitions,
                                               rows_affected, error);
 }
diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go
index 31e2f131..12307b1a 100644
--- a/go/adbc/pkg/snowflake/driver.go
+++ b/go/adbc/pkg/snowflake/driver.go
@@ -448,6 +448,7 @@ func SnowflakeConnectionGetInfo(cnxn 
*C.struct_AdbcConnection, codes *C.uint32_t
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
 
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -493,6 +494,7 @@ func SnowflakeConnectionGetObjects(cnxn 
*C.struct_AdbcConnection, depth C.int, c
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -533,6 +535,7 @@ func SnowflakeConnectionGetTableTypes(cnxn 
*C.struct_AdbcConnection, out *C.stru
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -553,6 +556,7 @@ func SnowflakeConnectionReadPartition(cnxn 
*C.struct_AdbcConnection, serialized
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
+       defer rdr.Release()
        cdata.ExportRecordReader(rdr, toCdataStream(out))
        return C.ADBC_STATUS_OK
 }
@@ -616,6 +620,11 @@ func SnowflakeStatementNew(cnxn *C.struct_AdbcConnection, 
stmt *C.struct_AdbcSta
                setErr(err, "AdbcStatementNew: Go panicked, driver is in 
unknown state")
                return C.ADBC_STATUS_INTERNAL
        }
+       if stmt.private_data != nil {
+               setErr(err, "AdbcStatementNew: statement already allocated")
+               return C.ADBC_STATUS_INVALID_STATE
+       }
+
        conn := checkConnInit(cnxn, err, "AdbcStatementNew")
        if conn == nil {
                return C.ADBC_STATUS_INVALID_STATE
@@ -715,6 +724,7 @@ func SnowflakeStatementExecuteQuery(stmt 
*C.struct_AdbcStatement, out *C.struct_
                        *affected = C.int64_t(n)
                }
 
+               defer rdr.Release()
                cdata.ExportRecordReader(rdr, toCdataStream(out))
        }
        return C.ADBC_STATUS_OK
diff --git a/go/adbc/pkg/snowflake/utils.c b/go/adbc/pkg/snowflake/utils.c
index 8c360b0f..24d3ca3d 100644
--- a/go/adbc/pkg/snowflake/utils.c
+++ b/go/adbc/pkg/snowflake/utils.c
@@ -23,6 +23,8 @@
 
 #include "utils.h"
 
+#include <string.h>
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -76,6 +78,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* 
connection,
                                      uint32_t* info_codes, size_t 
info_codes_length,
                                      struct ArrowArrayStream* out,
                                      struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return SnowflakeConnectionGetInfo(connection, info_codes, info_codes_length, 
out,
                                     error);
 }
@@ -86,6 +89,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct 
AdbcConnection* connection, int d
                                         const char* column_name,
                                         struct ArrowArrayStream* out,
                                         struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return SnowflakeConnectionGetObjects(connection, depth, catalog, db_schema, 
table_name,
                                        table_type, column_name, out, error);
 }
@@ -95,6 +99,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct 
AdbcConnection* connection,
                                             const char* table_name,
                                             struct ArrowSchema* schema,
                                             struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));
   return SnowflakeConnectionGetTableSchema(connection, catalog, db_schema, 
table_name,
                                            schema, error);
 }
@@ -102,6 +107,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct 
AdbcConnection* connection,
 AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
                                            struct ArrowArrayStream* out,
                                            struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));  // zero the caller's possibly-uninitialized out-param before the Go side touches it
   return SnowflakeConnectionGetTableTypes(connection, out, error);
 }
 
@@ -110,6 +116,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct 
AdbcConnection* connection,
                                            size_t serialized_length,
                                            struct ArrowArrayStream* out,
                                            struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return SnowflakeConnectionReadPartition(connection, serialized_partition,
                                           serialized_length, out, error);
 }
@@ -139,6 +146,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct 
AdbcStatement* statement,
                                          struct ArrowArrayStream* out,
                                          int64_t* rows_affected,
                                          struct AdbcError* error) {
+  if (out) memset(out, 0, sizeof(*out));
   return SnowflakeStatementExecuteQuery(statement, out, rows_affected, error);
 }
 
@@ -173,6 +181,7 @@ AdbcStatusCode AdbcStatementBindStream(struct 
AdbcStatement* statement,
 AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
                                                struct ArrowSchema* schema,
                                                struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));  // zero the caller's possibly-uninitialized out-param before the Go side touches it
   return SnowflakeStatementGetParameterSchema(statement, schema, error);
 }
 
@@ -186,6 +195,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct 
AdbcStatement* statement,
                                               struct AdbcPartitions* 
partitions,
                                               int64_t* rows_affected,
                                               struct AdbcError* error) {
+  if (schema) memset(schema, 0, sizeof(*schema));
+  if (partitions) memset(partitions, 0, sizeof(*partitions));
   return SnowflakeStatementExecutePartitions(statement, schema, partitions, 
rows_affected,
                                              error);
 }

Reply via email to