This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 653683ed fix(go/adbc/sqldriver): fix handling of decimal types (#970)
653683ed is described below

commit 653683ede04b284974070876e0714dc1e07291b1
Author: Matt Topol <[email protected]>
AuthorDate: Tue Aug 15 11:21:26 2023 -0400

    fix(go/adbc/sqldriver): fix handling of decimal types (#970)
    
    Fixes #969.
---
 go/adbc/sqldriver/driver.go                |  8 ++++++++
 go/adbc/sqldriver/driver_internals_test.go | 26 ++++++++++++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/go/adbc/sqldriver/driver.go b/go/adbc/sqldriver/driver.go
index ce008e19..f7f5c08c 100644
--- a/go/adbc/sqldriver/driver.go
+++ b/go/adbc/sqldriver/driver.go
@@ -361,6 +361,10 @@ func isCorrectParamType(typ arrow.Type, val driver.Value) 
bool {
                return checkType[arrow.Time64](val)
        case arrow.TIMESTAMP:
                return checkType[arrow.Timestamp](val)
+       case arrow.DECIMAL128:
+               return checkType[decimal128.Num](val)
+       case arrow.DECIMAL256:
+               return checkType[decimal256.Num](val)
        }
        // TODO: add more types here
        return true
@@ -639,6 +643,10 @@ func (r *rows) Next(dest []driver.Value) error {
                        dest[i] = 
col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.Time64Type).Unit)
                case *array.Timestamp:
                        dest[i] = 
col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.TimestampType).Unit)
+               case *array.Decimal128:
+                       dest[i] = col.Value(int(r.curRow))
+               case *array.Decimal256:
+                       dest[i] = col.Value(int(r.curRow))
                default:
                        return &adbc.Error{
                                Code: adbc.StatusNotImplemented,
diff --git a/go/adbc/sqldriver/driver_internals_test.go 
b/go/adbc/sqldriver/driver_internals_test.go
index b8bb7563..9d6f775e 100644
--- a/go/adbc/sqldriver/driver_internals_test.go
+++ b/go/adbc/sqldriver/driver_internals_test.go
@@ -27,6 +27,8 @@ import (
        "github.com/apache/arrow-adbc/go/adbc"
        "github.com/apache/arrow/go/v13/arrow"
        "github.com/apache/arrow/go/v13/arrow/array"
+       "github.com/apache/arrow/go/v13/arrow/decimal128"
+       "github.com/apache/arrow/go/v13/arrow/decimal256"
        "github.com/apache/arrow/go/v13/arrow/memory"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
@@ -111,6 +113,14 @@ func TestColumnTypeDatabaseTypeName(t *testing.T) {
                        typ:      &arrow.DurationType{Unit: arrow.Nanosecond},
                        typeName: "duration[ns]",
                },
+               {
+                       typ:      &arrow.Decimal128Type{Precision: 9, Scale: 2},
+                       typeName: "decimal(9, 2)",
+               },
+               {
+                       typ:      &arrow.Decimal256Type{Precision: 28, Scale: 
4},
+                       typeName: "decimal256(28, 4)",
+               },
        }
 
        for i, test := range tests {
@@ -227,6 +237,22 @@ func TestNextRowTypes(t *testing.T) {
                        },
                        golangValue: time.Date(1970, time.January, 1, 
testTime.Hour(), testTime.Minute(), testTime.Second(), testTime.Nanosecond(), 
time.UTC),
                },
+               {
+                       arrowType: &arrow.Decimal128Type{Precision: 9, Scale: 
2},
+                       arrowValueFunc: func(t *testing.T, b array.Builder) {
+                               t.Helper()
+                               
b.(*array.Decimal128Builder).Append(decimal128.FromU64(10))
+                       },
+                       golangValue: decimal128.FromU64(10),
+               },
+               {
+                       arrowType: &arrow.Decimal256Type{Precision: 10, Scale: 
5},
+                       arrowValueFunc: func(t *testing.T, b array.Builder) {
+                               t.Helper()
+                               
b.(*array.Decimal256Builder).Append(decimal256.FromU64(10))
+                       },
+                       golangValue: decimal256.FromU64(10),
+               },
        }
 
        for i, test := range tests {

Reply via email to