This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 30297b8  fix(go/adbc/driver/driver/flightsql): use libc allocator (#407)
30297b8 is described below

commit 30297b82136a6c041954b97e32562518b6b4ec48
Author: David Li <[email protected]>
AuthorDate: Thu Feb 2 12:27:41 2023 -0500

    fix(go/adbc/driver/driver/flightsql): use libc allocator (#407)
    
    This lets us pass cgocheck=2.
---
 ci/scripts/cpp_test.sh                     |  1 +
 go/adbc/driver/flightsql/flightsql_adbc.go | 16 +++++++++-------
 go/adbc/pkg/_tmpl/driver.go.tmpl           | 26 ++++++++++++++++++++------
 go/adbc/pkg/flightsql/driver.go            | 26 ++++++++++++++++++++------
 4 files changed, 50 insertions(+), 19 deletions(-)

diff --git a/ci/scripts/cpp_test.sh b/ci/scripts/cpp_test.sh
index a62d30e..fdd9a3f 100755
--- a/ci/scripts/cpp_test.sh
+++ b/ci/scripts/cpp_test.sh
@@ -63,6 +63,7 @@ main() {
     fi
 
     if [[ "${BUILD_DRIVER_FLIGHTSQL}" -gt 0 ]]; then
+        export GODEBUG=cgocheck=2
         test_subproject "${build_dir}" driver/flightsql
     fi
 }
diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go
index c0b70fa..81952c7 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -785,7 +785,8 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *
        for rdr.Next() {
                arr := rdr.Record().Column(0).(*array.String)
                for i := 0; i < arr.Len(); i++ {
-                       catalogName := arr.Value(i)
+                       // XXX: force copy since accessor is unsafe
+                       catalogName := string([]byte(arr.Value(i)))
                        g.appendCatalog(catalogName)
                }
        }
@@ -916,6 +917,7 @@ func (g *getObjects) release() {
 
 func (g *getObjects) finish() (array.RecordReader, error) {
        record := g.builder.NewRecord()
+       defer record.Release()
        result, err := array.NewRecordReader(g.builder.Schema(), 
[]arrow.Record{record})
        if err != nil {
                return nil, adbc.Error{
@@ -1080,9 +1082,9 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
                for i := 0; i < catalog.Len(); i++ {
                        catalogName := ""
                        if !catalog.IsNull(i) {
-                               catalogName = catalog.Value(i)
+                               catalogName = string([]byte(catalog.Value(i)))
                        }
-                       result[catalogName] = append(result[catalogName], 
dbSchema.Value(i))
+                       result[catalogName] = append(result[catalogName], 
string([]byte(dbSchema.Value(i))))
                }
        }
 
@@ -1141,10 +1143,10 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
                        catalogName := ""
                        dbSchemaName := ""
                        if !catalog.IsNull(i) {
-                               catalogName = catalog.Value(i)
+                               catalogName = string([]byte(catalog.Value(i)))
                        }
                        if !dbSchema.IsNull(i) {
-                               dbSchemaName = dbSchema.Value(i)
+                               dbSchemaName = string([]byte(dbSchema.Value(i)))
                        }
                        key := catalogAndSchema{
                                catalog: catalogName,
@@ -1165,8 +1167,8 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
                        }
 
                        result[key] = append(result[key], tableInfo{
-                               name:      tableName.Value(i),
-                               tableType: tableType.Value(i),
+                               name:      string([]byte(tableName.Value(i))),
+                               tableType: string([]byte(tableType.Value(i))),
                                schema:    schema,
                        })
                }
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 8920915..d81bb90 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -41,9 +41,10 @@ import (
        "github.com/apache/arrow-adbc/go/adbc"
        "github.com/apache/arrow/go/v12/arrow/array"
        "github.com/apache/arrow/go/v12/arrow/cdata"
+       "github.com/apache/arrow/go/v12/arrow/memory/mallocator"
 )
 
-var drv = {{.Driver}}{}
+var drv = {{.Driver}}{Alloc: mallocator.NewMallocator()}
 
 const errPrefix = "[{{.Prefix}}] "
 
@@ -617,11 +618,16 @@ func {{.Prefix}}StatementSetOption(stmt *C.struct_AdbcStatement, key, value *C.c
 
 //export releasePartitions
 func releasePartitions(partitions *C.struct_AdbcPartitions) {
+       if partitions.private_data == nil {
+               return
+       }
+
        C.free(unsafe.Pointer(partitions.partitions))
        C.free(unsafe.Pointer(partitions.partition_lengths))
-       h := (*(*cgo.Handle)(partitions.private_data))
        C.free(partitions.private_data)
-       h.Delete()
+       partitions.partitions = nil
+       partitions.partition_lengths = nil
+       partitions.private_data = nil
 }
 
 //export {{.Prefix}}StatementExecutePartitions
@@ -653,15 +659,23 @@ func {{.Prefix}}StatementExecutePartitions(stmt *C.struct_AdbcStatement, schema
        partitions.partitions = 
(**C.cuint8_t)(C.malloc(C.size_t(unsafe.Sizeof((*C.uint8_t)(nil)) * 
uintptr(part.NumPartitions))))
        partitions.partition_lengths = 
(*C.size_t)(C.malloc(C.size_t(unsafe.Sizeof(C.size_t(0)) * 
uintptr(part.NumPartitions))))
 
+       // Copy into C-allocated memory to avoid violating CGO rules
+       totalLen := 0
+       for _, p := range part.PartitionIDs {
+               totalLen += len(p)
+       }
+       partitions.private_data = C.malloc(C.size_t(totalLen))
+       dst := fromCArr[byte]((*byte)(partitions.private_data), totalLen)
+
        partIDs := fromCArr[*C.cuint8_t](partitions.partitions, 
int(partitions.num_partitions))
        partLens := fromCArr[C.size_t](partitions.partition_lengths, 
int(partitions.num_partitions))
        for i, p := range part.PartitionIDs {
-               partIDs[i] = (*C.cuint8_t)(unsafe.Pointer(&p[0]))
+               partIDs[i] = (*C.cuint8_t)(&dst[0])
+               copy(dst, p)
+               dst = dst[len(p):]
                partLens[i] = C.size_t(len(p))
        }
 
-       h := cgo.NewHandle(part)
-       partitions.private_data = createHandle(h)
        partitions.release = (*[0]byte)(C.releasePartitions)
        return C.ADBC_STATUS_OK
 }
diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go
index b48a51c..f70d6a5 100644
--- a/go/adbc/pkg/flightsql/driver.go
+++ b/go/adbc/pkg/flightsql/driver.go
@@ -44,9 +44,10 @@ import (
        "github.com/apache/arrow-adbc/go/adbc/driver/flightsql"
        "github.com/apache/arrow/go/v12/arrow/array"
        "github.com/apache/arrow/go/v12/arrow/cdata"
+       "github.com/apache/arrow/go/v12/arrow/memory/mallocator"
 )
 
-var drv = flightsql.Driver{}
+var drv = flightsql.Driver{Alloc: mallocator.NewMallocator()}
 
 const errPrefix = "[FlightSQL] "
 
@@ -620,11 +621,16 @@ func FlightSQLStatementSetOption(stmt *C.struct_AdbcStatement, key, value *C.cch
 
 //export releasePartitions
 func releasePartitions(partitions *C.struct_AdbcPartitions) {
+       if partitions.private_data == nil {
+               return
+       }
+
        C.free(unsafe.Pointer(partitions.partitions))
        C.free(unsafe.Pointer(partitions.partition_lengths))
-       h := (*(*cgo.Handle)(partitions.private_data))
        C.free(partitions.private_data)
-       h.Delete()
+       partitions.partitions = nil
+       partitions.partition_lengths = nil
+       partitions.private_data = nil
 }
 
 //export FlightSQLStatementExecutePartitions
@@ -656,15 +662,23 @@ func FlightSQLStatementExecutePartitions(stmt *C.struct_AdbcStatement, schema *C
        partitions.partitions = 
(**C.cuint8_t)(C.malloc(C.size_t(unsafe.Sizeof((*C.uint8_t)(nil)) * 
uintptr(part.NumPartitions))))
        partitions.partition_lengths = 
(*C.size_t)(C.malloc(C.size_t(unsafe.Sizeof(C.size_t(0)) * 
uintptr(part.NumPartitions))))
 
+       // Copy into C-allocated memory to avoid violating CGO rules
+       totalLen := 0
+       for _, p := range part.PartitionIDs {
+               totalLen += len(p)
+       }
+       partitions.private_data = C.malloc(C.size_t(totalLen))
+       dst := fromCArr[byte]((*byte)(partitions.private_data), totalLen)
+
        partIDs := fromCArr[*C.cuint8_t](partitions.partitions, 
int(partitions.num_partitions))
        partLens := fromCArr[C.size_t](partitions.partition_lengths, 
int(partitions.num_partitions))
        for i, p := range part.PartitionIDs {
-               partIDs[i] = (*C.cuint8_t)(unsafe.Pointer(&p[0]))
+               partIDs[i] = (*C.cuint8_t)(&dst[0])
+               copy(dst, p)
+               dst = dst[len(p):]
                partLens[i] = C.size_t(len(p))
        }
 
-       h := cgo.NewHandle(part)
-       partitions.private_data = createHandle(h)
        partitions.release = (*[0]byte)(C.releasePartitions)
        return C.ADBC_STATUS_OK
 }

Reply via email to