This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 30e8856dc feat(go/adbc/sqldriver): handle timestamp/time.Time values for input (#3109)
30e8856dc is described below

commit 30e8856dc89a41d2e412a7ac5279053b753c6d24
Author: Matt Topol <[email protected]>
AuthorDate: Wed Jul 9 14:51:11 2025 -0400

    feat(go/adbc/sqldriver): handle timestamp/time.Time values for input (#3109)
    
    fixes #3103
    
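    Add new cases to `arrFromVal` to handle `time.Time`, `arrow.Timestamp`,
    and `arrow.Time32`/`arrow.Time64` values. These conversions only work
    when the FlightSQL server provides the parameter schema, for example
    when using a prepared statement. `string` and `[]byte` values are also
    bound as `large_string`/`large_binary` when the parameter schema asks
    for those types.
    
    Without a parameter schema, these values cannot be converted and an
    error is returned. Unsupported value types now return an error instead
    of panicking.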
---
 go/adbc/sqldriver/driver.go                | 95 +++++++++++++++++++++++++-----
 go/adbc/sqldriver/driver_internals_test.go | 69 +++++++++++++++++++++-
 2 files changed, 147 insertions(+), 17 deletions(-)

diff --git a/go/adbc/sqldriver/driver.go b/go/adbc/sqldriver/driver.go
index 9515fda6e..03b6eecb6 100644
--- a/go/adbc/sqldriver/driver.go
+++ b/go/adbc/sqldriver/driver.go
@@ -413,10 +413,9 @@ func (s *stmt) CheckNamedValue(val *driver.NamedValue) 
error {
        return nil
 }
 
-func arrFromVal(val any) arrow.Array {
+func arrFromVal(val any, dt arrow.DataType) (arrow.Array, error) {
        var (
                buffers = make([]*memory.Buffer, 2)
-               dt      arrow.DataType
        )
        switch v := val.(type) {
        case bool:
@@ -459,17 +458,65 @@ func arrFromVal(val any) arrow.Array {
                dt = arrow.PrimitiveTypes.Date64
                buffers[1] = 
memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&v))[:])
        case []byte:
-               dt = arrow.BinaryTypes.Binary
-               buffers[1] = 
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
+               if dt == nil || dt.ID() == arrow.BINARY {
+                       dt = arrow.BinaryTypes.Binary
+                       buffers[1] = 
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
+               } else if dt.ID() == arrow.LARGE_BINARY {
+                       dt = arrow.BinaryTypes.LargeBinary
+                       buffers[1] = 
memory.NewBufferBytes(arrow.Int64Traits.CastToBytes([]int64{0, int64(len(v))}))
+               }
                buffers = append(buffers, memory.NewBufferBytes(v))
        case string:
-               dt = arrow.BinaryTypes.String
-               buffers[1] = 
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
-
+               if dt == nil || dt.ID() == arrow.STRING {
+                       dt = arrow.BinaryTypes.String
+                       buffers[1] = 
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
+               } else if dt.ID() == arrow.LARGE_STRING {
+                       dt = arrow.BinaryTypes.LargeString
+                       buffers[1] = 
memory.NewBufferBytes(arrow.Int64Traits.CastToBytes([]int64{0, int64(len(v))}))
+               }
                buf := unsafe.Slice(unsafe.StringData(v), len(v))
                buffers = append(buffers, memory.NewBufferBytes(buf))
+       case arrow.Time32:
+               if dt == nil || dt.ID() != arrow.TIME32 {
+                       return nil, errors.New("can only create array from 
arrow.Time32 with a provided parameter schema")
+               }
+
+               buffers[1] = 
memory.NewBufferBytes((*[4]byte)(unsafe.Pointer(&v))[:])
+       case arrow.Time64:
+               if dt == nil || dt.ID() != arrow.TIME64 {
+                       return nil, errors.New("can only create array from 
arrow.Time64 with a provided parameter schema")
+               }
+
+               buffers[1] = 
memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&v))[:])
+       case arrow.Timestamp:
+               if dt == nil || dt.ID() != arrow.TIMESTAMP {
+                       return nil, errors.New("can only create array from 
arrow.Timestamp with a provided parameter schema")
+               }
+
+               buffers[1] = 
memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&v))[:])
+       case time.Time:
+               if dt == nil {
+                       return nil, errors.New("can only create array from 
time.Time with a provided parameter schema")
+               }
+
+               switch dt.ID() {
+               case arrow.DATE32:
+                       val := arrow.Date32FromTime(v)
+                       buffers[1] = 
memory.NewBufferBytes((*[4]byte)(unsafe.Pointer(&val))[:])
+               case arrow.DATE64:
+                       val := arrow.Date64FromTime(v)
+                       buffers[1] = 
memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&val))[:])
+               case arrow.TIMESTAMP:
+                       val, err := arrow.TimestampFromTime(v, 
dt.(*arrow.TimestampType).Unit)
+                       if err != nil {
+                               return nil, fmt.Errorf("could not convert 
time.Time to arrow.Timestamp: %v", err)
+                       }
+                       buffers[1] = 
memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&val))[:])
+               default:
+                       return nil, fmt.Errorf("time.Time with type %s 
unsupported", dt)
+               }
        default:
-               panic(fmt.Sprintf("unsupported type %T", val))
+               return nil, fmt.Errorf("unsupported type %T", val)
        }
        for _, b := range buffers {
                if b != nil {
@@ -478,10 +525,10 @@ func arrFromVal(val any) arrow.Array {
        }
        data := array.NewData(dt, 1, buffers, nil, 0, 0)
        defer data.Release()
-       return array.MakeFromData(data)
+       return array.MakeFromData(data), nil
 }
 
-func createBoundRecord(values []driver.NamedValue, schema *arrow.Schema) 
arrow.Record {
+func createBoundRecord(values []driver.NamedValue, schema *arrow.Schema) 
(arrow.Record, error) {
        fields := make([]arrow.Field, len(values))
        cols := make([]arrow.Array, len(values))
        if schema == nil {
@@ -492,13 +539,16 @@ func createBoundRecord(values []driver.NamedValue, schema 
*arrow.Schema) arrow.R
                        } else {
                                f.Name = v.Name
                        }
-                       arr := arrFromVal(v.Value)
+                       arr, err := arrFromVal(v.Value, nil)
+                       if err != nil {
+                               return nil, err
+                       }
                        defer arr.Release()
                        f.Type = arr.DataType()
                        cols[v.Ordinal-1] = arr
                }
 
-               return array.NewRecord(arrow.NewSchema(fields, nil), cols, 1)
+               return array.NewRecord(arrow.NewSchema(fields, nil), cols, 1), 
nil
        }
 
        for _, v := range values {
@@ -514,17 +564,25 @@ func createBoundRecord(values []driver.NamedValue, schema 
*arrow.Schema) arrow.R
 
                f := &fields[idx]
                f.Name = name
-               arr := arrFromVal(v.Value)
+               arr, err := arrFromVal(v.Value, f.Type)
+               if err != nil {
+                       return nil, err
+               }
                defer arr.Release()
                f.Type = arr.DataType()
                cols[idx] = arr
        }
-       return array.NewRecord(arrow.NewSchema(fields, nil), cols, 1)
+       return array.NewRecord(arrow.NewSchema(fields, nil), cols, 1), nil
 }
 
 func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) 
(driver.Result, error) {
        if len(args) > 0 {
-               if err := s.stmt.Bind(ctx, createBoundRecord(args, 
s.paramSchema)); err != nil {
+               rec, err := createBoundRecord(args, s.paramSchema)
+               if err != nil {
+                       return nil, err
+               }
+
+               if err := s.stmt.Bind(ctx, rec); err != nil {
                        return nil, err
                }
        }
@@ -539,7 +597,12 @@ func (s *stmt) ExecContext(ctx context.Context, args 
[]driver.NamedValue) (drive
 
 func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) 
(driver.Rows, error) {
        if len(args) > 0 {
-               if err := s.stmt.Bind(ctx, createBoundRecord(args, 
s.paramSchema)); err != nil {
+               rec, err := createBoundRecord(args, s.paramSchema)
+               if err != nil {
+                       return nil, err
+               }
+
+               if err := s.stmt.Bind(ctx, rec); err != nil {
                        return nil, err
                }
        }
diff --git a/go/adbc/sqldriver/driver_internals_test.go 
b/go/adbc/sqldriver/driver_internals_test.go
index 539d55735..0c42e5ac2 100644
--- a/go/adbc/sqldriver/driver_internals_test.go
+++ b/go/adbc/sqldriver/driver_internals_test.go
@@ -146,6 +146,11 @@ var (
                Name: "int",
                Type: arrow.PrimitiveTypes.Int32,
        }
+
+       tstampSec, _   = arrow.TimestampFromTime(testTime, arrow.Second)
+       tstampMilli, _ = arrow.TimestampFromTime(testTime, arrow.Millisecond)
+       tstampMicro, _ = arrow.TimestampFromTime(testTime, arrow.Microsecond)
+       tstampNano, _  = arrow.TimestampFromTime(testTime, arrow.Nanosecond)
 )
 
 func TestNextRowTypes(t *testing.T) {
@@ -328,6 +333,7 @@ func TestNextRowTypes(t *testing.T) {
 func TestArrFromVal(t *testing.T) {
        tests := []struct {
                value               any
+               inputDataType       arrow.DataType
                expectedDataType    arrow.DataType
                expectedStringValue string
        }{
@@ -401,15 +407,76 @@ func TestArrFromVal(t *testing.T) {
                        expectedDataType:    arrow.BinaryTypes.Binary,
                        expectedStringValue: 
base64.StdEncoding.EncodeToString([]byte("my-string")),
                },
+               {
+                       value:               []byte("my-string"),
+                       inputDataType:       arrow.BinaryTypes.LargeBinary,
+                       expectedDataType:    arrow.BinaryTypes.LargeBinary,
+                       expectedStringValue: 
base64.StdEncoding.EncodeToString([]byte("my-string")),
+               },
                {
                        value:               "my-string",
                        expectedDataType:    arrow.BinaryTypes.String,
                        expectedStringValue: "my-string",
                },
+               {
+                       value:               "my-string",
+                       inputDataType:       arrow.BinaryTypes.LargeString,
+                       expectedDataType:    arrow.BinaryTypes.LargeString,
+                       expectedStringValue: "my-string",
+               },
+               {
+                       value:               tstampSec,
+                       inputDataType:       &arrow.TimestampType{Unit: 
arrow.Second},
+                       expectedDataType:    &arrow.TimestampType{Unit: 
arrow.Second},
+                       expectedStringValue: 
testTime.UTC().Truncate(time.Second).Format("2006-01-02 15:04:05Z"),
+               },
+               {
+                       value:               tstampMilli,
+                       inputDataType:       &arrow.TimestampType{Unit: 
arrow.Millisecond},
+                       expectedDataType:    &arrow.TimestampType{Unit: 
arrow.Millisecond},
+                       expectedStringValue: 
testTime.UTC().Truncate(time.Millisecond).Format("2006-01-02 15:04:05.000Z"),
+               },
+               {
+                       value:               tstampMicro,
+                       inputDataType:       &arrow.TimestampType{Unit: 
arrow.Microsecond},
+                       expectedDataType:    &arrow.TimestampType{Unit: 
arrow.Microsecond},
+                       expectedStringValue: 
testTime.UTC().Truncate(time.Microsecond).Format("2006-01-02 15:04:05.000000Z"),
+               },
+               {
+                       value:               tstampNano,
+                       inputDataType:       &arrow.TimestampType{Unit: 
arrow.Nanosecond},
+                       expectedDataType:    &arrow.TimestampType{Unit: 
arrow.Nanosecond},
+                       expectedStringValue: 
testTime.UTC().Truncate(time.Nanosecond).Format("2006-01-02 
15:04:05.000000000Z"),
+               },
+               {
+                       value:               testTime,
+                       inputDataType:       &arrow.TimestampType{Unit: 
arrow.Second},
+                       expectedDataType:    &arrow.TimestampType{Unit: 
arrow.Second},
+                       expectedStringValue: 
testTime.UTC().Truncate(time.Second).Format("2006-01-02 15:04:05Z"),
+               },
+               {
+                       value:               testTime,
+                       inputDataType:       &arrow.TimestampType{Unit: 
arrow.Millisecond},
+                       expectedDataType:    &arrow.TimestampType{Unit: 
arrow.Millisecond},
+                       expectedStringValue: 
testTime.UTC().Truncate(time.Millisecond).Format("2006-01-02 15:04:05.000Z"),
+               },
+               {
+                       value:               testTime,
+                       inputDataType:       &arrow.TimestampType{Unit: 
arrow.Microsecond},
+                       expectedDataType:    &arrow.TimestampType{Unit: 
arrow.Microsecond},
+                       expectedStringValue: 
testTime.UTC().Truncate(time.Microsecond).Format("2006-01-02 15:04:05.000000Z"),
+               },
+               {
+                       value:               testTime,
+                       inputDataType:       &arrow.TimestampType{Unit: 
arrow.Nanosecond},
+                       expectedDataType:    &arrow.TimestampType{Unit: 
arrow.Nanosecond},
+                       expectedStringValue: 
testTime.UTC().Truncate(time.Nanosecond).Format("2006-01-02 
15:04:05.000000000Z"),
+               },
        }
        for i, test := range tests {
                t.Run(fmt.Sprintf("%d-%T", i, test.value), func(t *testing.T) {
-                       arr := arrFromVal(test.value)
+                       arr, err := arrFromVal(test.value, test.inputDataType)
+                       require.NoError(t, err)
 
                        assert.Equal(t, test.expectedDataType, arr.DataType())
                        require.Equal(t, 1, arr.Len())

Reply via email to