This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 491ab8d4 fix(go/adbc/sqldriver): Fix nil pointer panics for query parameters (#1342)
491ab8d4 is described below
commit 491ab8d4638391e5b60c32f1e29c5fe3bacbd0f5
Author: William <[email protected]>
AuthorDate: Tue Dec 5 20:20:34 2023 +0100
fix(go/adbc/sqldriver): Fix nil pointer panics for query parameters (#1342)
Thought I would contribute some fixes I've been using locally for the
issues described in #1341
I have no previous experience with this repository or with the Arrow
memory model so I would say it's likely I've gotten something wrong.
Feel free to ask me to improve on my contributions or merely take them
as inspiration for some other fix.
Resolves #1341
---
go/adbc/sqldriver/driver.go | 11 +++-
go/adbc/sqldriver/driver_internals_test.go | 95 ++++++++++++++++++++++++++++++
go/adbc/sqldriver/driver_test.go | 6 +-
3 files changed, 106 insertions(+), 6 deletions(-)
diff --git a/go/adbc/sqldriver/driver.go b/go/adbc/sqldriver/driver.go
index 4b83495f..775f3f78 100644
--- a/go/adbc/sqldriver/driver.go
+++ b/go/adbc/sqldriver/driver.go
@@ -22,6 +22,7 @@ import (
"database/sql"
"database/sql/driver"
"errors"
+ "fmt"
"io"
"reflect"
"strconv"
@@ -463,16 +464,20 @@ func arrFromVal(val any) arrow.Array {
case []byte:
dt = arrow.BinaryTypes.Binary
buffers[1] = memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
- buffers[2] = memory.NewBufferBytes(v)
+ buffers = append(buffers, memory.NewBufferBytes(v))
case string:
dt = arrow.BinaryTypes.String
buffers[1] = memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
var buf = *(*[]byte)(unsafe.Pointer(&v))
(*reflect.SliceHeader)(unsafe.Pointer(&buf)).Cap = len(v)
- buffers[2] = memory.NewBufferBytes(buf)
+ buffers = append(buffers, memory.NewBufferBytes(buf))
+ default:
+ panic(fmt.Sprintf("unsupported type %T", val))
}
for _, b := range buffers {
- defer b.Release()
+ if b != nil {
+ defer b.Release()
+ }
}
data := array.NewData(dt, 1, buffers, nil, 0, 0)
defer data.Release()
diff --git a/go/adbc/sqldriver/driver_internals_test.go b/go/adbc/sqldriver/driver_internals_test.go
index 8e9ce565..9981a40d 100644
--- a/go/adbc/sqldriver/driver_internals_test.go
+++ b/go/adbc/sqldriver/driver_internals_test.go
@@ -19,6 +19,7 @@ package sqldriver
import (
"database/sql/driver"
+ "encoding/base64"
"fmt"
"strings"
"testing"
@@ -273,3 +274,97 @@ func TestNextRowTypes(t *testing.T) {
})
}
}
+
+func TestArrFromVal(t *testing.T) {
+ tests := []struct {
+ value any
+ expectedDataType arrow.DataType
+ expectedStringValue string
+ }{
+ {
+ value: true,
+ expectedDataType: arrow.FixedWidthTypes.Boolean,
+ expectedStringValue: "true",
+ },
+ {
+ value: int8(1),
+ expectedDataType: arrow.PrimitiveTypes.Int8,
+ expectedStringValue: "1",
+ },
+ {
+ value: uint8(1),
+ expectedDataType: arrow.PrimitiveTypes.Uint8,
+ expectedStringValue: "1",
+ },
+ {
+ value: int16(1),
+ expectedDataType: arrow.PrimitiveTypes.Int16,
+ expectedStringValue: "1",
+ },
+ {
+ value: uint16(1),
+ expectedDataType: arrow.PrimitiveTypes.Uint16,
+ expectedStringValue: "1",
+ },
+ {
+ value: int32(1),
+ expectedDataType: arrow.PrimitiveTypes.Int32,
+ expectedStringValue: "1",
+ },
+ {
+ value: uint32(1),
+ expectedDataType: arrow.PrimitiveTypes.Uint32,
+ expectedStringValue: "1",
+ },
+ {
+ value: int64(1),
+ expectedDataType: arrow.PrimitiveTypes.Int64,
+ expectedStringValue: "1",
+ },
+ {
+ value: uint64(1),
+ expectedDataType: arrow.PrimitiveTypes.Uint64,
+ expectedStringValue: "1",
+ },
+ {
+ value: float32(1),
+ expectedDataType: arrow.PrimitiveTypes.Float32,
+ expectedStringValue: "1",
+ },
+ {
+ value: float64(1),
+ expectedDataType: arrow.PrimitiveTypes.Float64,
+ expectedStringValue: "1",
+ },
+ {
+ value: arrow.Date32FromTime(testTime),
+ expectedDataType: arrow.PrimitiveTypes.Date32,
+ expectedStringValue: testTime.UTC().Truncate(24 * time.Hour).Format("2006-01-02"),
+ },
+ {
+ value: arrow.Date64FromTime(testTime),
+ expectedDataType: arrow.PrimitiveTypes.Date64,
+ expectedStringValue: testTime.UTC().Truncate(24 * time.Hour).Format("2006-01-02"),
+ },
+ {
+ value: []byte("my-string"),
+ expectedDataType: arrow.BinaryTypes.Binary,
+ expectedStringValue: base64.StdEncoding.EncodeToString([]byte("my-string")),
+ },
+ {
+ value: "my-string",
+ expectedDataType: arrow.BinaryTypes.String,
+ expectedStringValue: "my-string",
+ },
+ }
+ for i, test := range tests {
+ t.Run(fmt.Sprintf("%d-%T", i, test.value), func(t *testing.T) {
+ arr := arrFromVal(test.value)
+
+ assert.Equal(t, test.expectedDataType, arr.DataType())
+ require.Equal(t, 1, arr.Len())
+ assert.True(t, arr.IsValid(0))
+ assert.Equal(t, test.expectedStringValue, arr.ValueStr(0))
+ })
+ }
+}
diff --git a/go/adbc/sqldriver/driver_test.go b/go/adbc/sqldriver/driver_test.go
index 121e04f3..511ce9c9 100644
--- a/go/adbc/sqldriver/driver_test.go
+++ b/go/adbc/sqldriver/driver_test.go
@@ -38,7 +38,7 @@ func Example() {
panic(err)
}
- rows, err := db.Query("SELECT 1")
+ rows, err := db.Query("SELECT ?", 1)
if err != nil {
panic(err)
}
@@ -66,8 +66,8 @@ func Example() {
}
// Output:
- // [1]
- // 1
+ // [?]
+ // ?
// true true
// 1
}