This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 491ab8d4 fix(go/adbc/sqldriver): Fix nil pointer panics for query parameters (#1342)
491ab8d4 is described below

commit 491ab8d4638391e5b60c32f1e29c5fe3bacbd0f5
Author: William <[email protected]>
AuthorDate: Tue Dec 5 20:20:34 2023 +0100

    fix(go/adbc/sqldriver): Fix nil pointer panics for query parameters (#1342)
    
    I thought I would contribute some fixes I've been using locally for the
    issues described in #1341.
    
    I have no previous experience with this repository or with the Arrow
    memory model, so it's quite possible I've gotten something wrong.
    Feel free to ask me to improve on my contributions, or simply take them
    as inspiration for some other fix.
    
    Resolves #1341
---
 go/adbc/sqldriver/driver.go                | 11 +++-
 go/adbc/sqldriver/driver_internals_test.go | 95 ++++++++++++++++++++++++++++++
 go/adbc/sqldriver/driver_test.go           |  6 +-
 3 files changed, 106 insertions(+), 6 deletions(-)

diff --git a/go/adbc/sqldriver/driver.go b/go/adbc/sqldriver/driver.go
index 4b83495f..775f3f78 100644
--- a/go/adbc/sqldriver/driver.go
+++ b/go/adbc/sqldriver/driver.go
@@ -22,6 +22,7 @@ import (
        "database/sql"
        "database/sql/driver"
        "errors"
+       "fmt"
        "io"
        "reflect"
        "strconv"
@@ -463,16 +464,20 @@ func arrFromVal(val any) arrow.Array {
        case []byte:
                dt = arrow.BinaryTypes.Binary
                buffers[1] = 
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
-               buffers[2] = memory.NewBufferBytes(v)
+               buffers = append(buffers, memory.NewBufferBytes(v))
        case string:
                dt = arrow.BinaryTypes.String
                buffers[1] = 
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
                var buf = *(*[]byte)(unsafe.Pointer(&v))
                (*reflect.SliceHeader)(unsafe.Pointer(&buf)).Cap = len(v)
-               buffers[2] = memory.NewBufferBytes(buf)
+               buffers = append(buffers, memory.NewBufferBytes(buf))
+       default:
+               panic(fmt.Sprintf("unsupported type %T", val))
        }
        for _, b := range buffers {
-               defer b.Release()
+               if b != nil {
+                       defer b.Release()
+               }
        }
        data := array.NewData(dt, 1, buffers, nil, 0, 0)
        defer data.Release()
diff --git a/go/adbc/sqldriver/driver_internals_test.go b/go/adbc/sqldriver/driver_internals_test.go
index 8e9ce565..9981a40d 100644
--- a/go/adbc/sqldriver/driver_internals_test.go
+++ b/go/adbc/sqldriver/driver_internals_test.go
@@ -19,6 +19,7 @@ package sqldriver
 
 import (
        "database/sql/driver"
+       "encoding/base64"
        "fmt"
        "strings"
        "testing"
@@ -273,3 +274,97 @@ func TestNextRowTypes(t *testing.T) {
                })
        }
 }
+
+func TestArrFromVal(t *testing.T) {
+       tests := []struct {
+               value               any
+               expectedDataType    arrow.DataType
+               expectedStringValue string
+       }{
+               {
+                       value:               true,
+                       expectedDataType:    arrow.FixedWidthTypes.Boolean,
+                       expectedStringValue: "true",
+               },
+               {
+                       value:               int8(1),
+                       expectedDataType:    arrow.PrimitiveTypes.Int8,
+                       expectedStringValue: "1",
+               },
+               {
+                       value:               uint8(1),
+                       expectedDataType:    arrow.PrimitiveTypes.Uint8,
+                       expectedStringValue: "1",
+               },
+               {
+                       value:               int16(1),
+                       expectedDataType:    arrow.PrimitiveTypes.Int16,
+                       expectedStringValue: "1",
+               },
+               {
+                       value:               uint16(1),
+                       expectedDataType:    arrow.PrimitiveTypes.Uint16,
+                       expectedStringValue: "1",
+               },
+               {
+                       value:               int32(1),
+                       expectedDataType:    arrow.PrimitiveTypes.Int32,
+                       expectedStringValue: "1",
+               },
+               {
+                       value:               uint32(1),
+                       expectedDataType:    arrow.PrimitiveTypes.Uint32,
+                       expectedStringValue: "1",
+               },
+               {
+                       value:               int64(1),
+                       expectedDataType:    arrow.PrimitiveTypes.Int64,
+                       expectedStringValue: "1",
+               },
+               {
+                       value:               uint64(1),
+                       expectedDataType:    arrow.PrimitiveTypes.Uint64,
+                       expectedStringValue: "1",
+               },
+               {
+                       value:               float32(1),
+                       expectedDataType:    arrow.PrimitiveTypes.Float32,
+                       expectedStringValue: "1",
+               },
+               {
+                       value:               float64(1),
+                       expectedDataType:    arrow.PrimitiveTypes.Float64,
+                       expectedStringValue: "1",
+               },
+               {
+                       value:               arrow.Date32FromTime(testTime),
+                       expectedDataType:    arrow.PrimitiveTypes.Date32,
+                       expectedStringValue: testTime.UTC().Truncate(24 * 
time.Hour).Format("2006-01-02"),
+               },
+               {
+                       value:               arrow.Date64FromTime(testTime),
+                       expectedDataType:    arrow.PrimitiveTypes.Date64,
+                       expectedStringValue: testTime.UTC().Truncate(24 * 
time.Hour).Format("2006-01-02"),
+               },
+               {
+                       value:               []byte("my-string"),
+                       expectedDataType:    arrow.BinaryTypes.Binary,
+                       expectedStringValue: 
base64.StdEncoding.EncodeToString([]byte("my-string")),
+               },
+               {
+                       value:               "my-string",
+                       expectedDataType:    arrow.BinaryTypes.String,
+                       expectedStringValue: "my-string",
+               },
+       }
+       for i, test := range tests {
+               t.Run(fmt.Sprintf("%d-%T", i, test.value), func(t *testing.T) {
+                       arr := arrFromVal(test.value)
+
+                       assert.Equal(t, test.expectedDataType, arr.DataType())
+                       require.Equal(t, 1, arr.Len())
+                       assert.True(t, arr.IsValid(0))
+                       assert.Equal(t, test.expectedStringValue, 
arr.ValueStr(0))
+               })
+       }
+}
diff --git a/go/adbc/sqldriver/driver_test.go b/go/adbc/sqldriver/driver_test.go
index 121e04f3..511ce9c9 100644
--- a/go/adbc/sqldriver/driver_test.go
+++ b/go/adbc/sqldriver/driver_test.go
@@ -38,7 +38,7 @@ func Example() {
                panic(err)
        }
 
-       rows, err := db.Query("SELECT 1")
+       rows, err := db.Query("SELECT ?", 1)
        if err != nil {
                panic(err)
        }
@@ -66,8 +66,8 @@ func Example() {
        }
 
        // Output:
-       // [1]
-       // 1
+       // [?]
+       // ?
        // true true
        // 1
 }

Reply via email to