This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new cac2d836 fix(go/adbc/driver/snowflake): fix XDBC support when using high precision (#1311)
cac2d836 is described below

commit cac2d836c0371f86add8de0fe6f3d90293c808b0
Author: davidhcoe <[email protected]>
AuthorDate: Tue Nov 21 10:58:16 2023 -0500

    fix(go/adbc/driver/snowflake): fix XDBC support when using high precision (#1311)
    
    - Fixes https://github.com/apache/arrow-adbc/issues/1309
    - Fixes the namespace of ValueTests.cs (now Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake)
    - Adds a C# test verifying that, when UseHighPrecision=true, NUMBER columns report XdbcSqlDataType XDBC_DECIMAL
    - Fixes a leftover (hanging) merge issue introduced by a rebase
    
    ---------
    
    Co-authored-by: David Coe <[email protected]>
---
 csharp/test/Drivers/Snowflake/DriverTests.cs | 15 ++++++++++++++-
 csharp/test/Drivers/Snowflake/ValueTests.cs  | 10 +---------
 go/adbc/driver/snowflake/connection.go       | 17 ++++++++++++-----
 3 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/csharp/test/Drivers/Snowflake/DriverTests.cs 
b/csharp/test/Drivers/Snowflake/DriverTests.cs
index a6c8818c..1d80ce2b 100644
--- a/csharp/test/Drivers/Snowflake/DriverTests.cs
+++ b/csharp/test/Drivers/Snowflake/DriverTests.cs
@@ -330,8 +330,21 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
                 .FirstOrDefault();
 
             Assert.True(columns != null, "Columns cannot be null");
+            Assert.Equal(testConfiguration.Metadata.ExpectedColumnCount, 
columns.Count);
 
-            Assert.Equal(_testConfiguration.Metadata.ExpectedColumnCount, 
columns.Count);
+            if (testConfiguration.UseHighPrecision)
+            {
+                IEnumerable<AdbcColumn> highPrecisionColumns = columns.Where(c 
=> c.XdbcTypeName == "NUMBER");
+
+                if(highPrecisionColumns.Count() > 0)
+                {
+                    // ensure they all are coming back as 
XdbcDataType_XDBC_DECIMAL because they are Decimal128
+                    short XdbcDataType_XDBC_DECIMAL = 3;
+                    IEnumerable<AdbcColumn> invalidHighPrecisionColumns  = 
highPrecisionColumns.Where(c => c.XdbcSqlDataType != XdbcDataType_XDBC_DECIMAL);
+                    int count = invalidHighPrecisionColumns.Count();
+                    Assert.True(count == 0, $"There are {count} columns that 
do not map to the correct XdbcSqlDataType when UseHighPrecision=true");
+                }
+            }
         }
 
         /// <summary>
diff --git a/csharp/test/Drivers/Snowflake/ValueTests.cs 
b/csharp/test/Drivers/Snowflake/ValueTests.cs
index f6fa5937..1294dfde 100644
--- a/csharp/test/Drivers/Snowflake/ValueTests.cs
+++ b/csharp/test/Drivers/Snowflake/ValueTests.cs
@@ -18,12 +18,11 @@
 using System;
 using System.Collections.Generic;
 using System.Data.SqlTypes;
-using Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake;
 using Apache.Arrow.Ipc;
 using Apache.Arrow.Types;
 using Xunit;
 
-namespace Apache.Arrow.Adbc.Tests
+namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
 {
     // TODO: When supported, use prepared statements instead of SQL string 
literals
     //      Which will better test how the driver handles values sent/received
@@ -239,13 +238,6 @@ namespace Apache.Arrow.Adbc.Tests
                     return "'-inf'";
                 case double.NaN:
                     return "'NaN'";
-#if NET472
-                // Standard Double.ToString() calls round up the max value, 
resulting in Snowflake storing infinity
-                case double.MaxValue:
-                    return "1.7976931348623157E+308";
-                case double.MinValue:
-                    return "-1.7976931348623157E+308";
-#endif
                 default:
                     return value.ToString();
             }
diff --git a/go/adbc/driver/snowflake/connection.go 
b/go/adbc/driver/snowflake/connection.go
index 69160203..1b2d72ef 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -323,15 +323,22 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, 
depth adbc.ObjectDepth,
 
 var loc = time.Now().Location()
 
-func toField(name string, isnullable bool, dataType string, numPrec, 
numPrecRadix, numScale sql.NullInt16, isIdent bool, identGen, identInc 
sql.NullString, charMaxLength, charOctetLength sql.NullInt32, datetimePrec 
sql.NullInt16, comment sql.NullString, ordinalPos int) (ret arrow.Field) {
+func toField(name string, isnullable bool, dataType string, numPrec, 
numPrecRadix, numScale sql.NullInt16, isIdent, useHighPrecision bool, identGen, 
identInc sql.NullString, charMaxLength, charOctetLength sql.NullInt32, 
datetimePrec sql.NullInt16, comment sql.NullString, ordinalPos int) (ret 
arrow.Field) {
        ret.Name, ret.Nullable = name, isnullable
 
        switch dataType {
        case "NUMBER":
-               if !numScale.Valid || numScale.Int16 == 0 {
-                       ret.Type = arrow.PrimitiveTypes.Int64
+               if useHighPrecision {
+                       ret.Type = &arrow.Decimal128Type{
+                               Precision: int32(numPrec.Int16),
+                               Scale:     int32(numScale.Int16),
+                       }
                } else {
-                       ret.Type = arrow.PrimitiveTypes.Float64
+                       if !numScale.Valid || numScale.Int16 == 0 {
+                               ret.Type = arrow.PrimitiveTypes.Int64
+                       } else {
+                               ret.Type = arrow.PrimitiveTypes.Float64
+                       }
                }
        case "FLOAT":
                fallthrough
@@ -639,7 +646,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth 
adbc.ObjectDepth, cat
                        }
 
                        prevKey = key
-                       fieldList = append(fieldList, toField(colName, 
isNullable, dataType, numericPrec, numericPrecRadix, numericScale, isIdent, 
identGen, identIncrement, charMaxLength, charOctetLength, datetimePrec, 
comment, ordinalPos))
+                       fieldList = append(fieldList, toField(colName, 
isNullable, dataType, numericPrec, numericPrecRadix, numericScale, isIdent, 
c.useHighPrecision, identGen, identIncrement, charMaxLength, charOctetLength, 
datetimePrec, comment, ordinalPos))
                }
 
                if len(fieldList) > 0 && curTableInfo != nil {

Reply via email to