This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new b6b237717 feat(csharp/src/Client): Additional parameter support for DbCommand (#2195)
b6b237717 is described below

commit b6b237717f51713fb08b2f35c0d2688f1e63c74d
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Fri Sep 27 15:18:48 2024 -0700

    feat(csharp/src/Client): Additional parameter support for DbCommand (#2195)
    
    Implements support for mapping DbType.Time and DbType.Decimal. Uses
    System.Convert to support a larger number of source types.
---
 csharp/src/Client/AdbcCommand.cs                   | 207 +++++++++++++++------
 csharp/src/Client/AdbcParameter.cs                 |   8 +-
 .../test/Drivers/Interop/Snowflake/ClientTests.cs  |  19 +-
 3 files changed, 172 insertions(+), 62 deletions(-)

diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs
index f76c246cc..2dac9a95d 100644
--- a/csharp/src/Client/AdbcCommand.cs
+++ b/csharp/src/Client/AdbcCommand.cs
@@ -20,6 +20,7 @@ using System.Collections;
 using System.Collections.Generic;
 using System.Data;
 using System.Data.Common;
+using System.Data.SqlTypes;
 using System.Linq;
 using System.Threading.Tasks;
 using Apache.Arrow.Types;
@@ -231,111 +232,195 @@ namespace Apache.Arrow.Adbc.Client
                 for (int i = 0; i < fields.Length; i++)
                 {
                     AdbcParameter param = 
(AdbcParameter)_dbParameterCollection[i];
-                    ArrowType type;
                     switch (param.DbType)
                     {
                         case DbType.Binary:
-                            type = BinaryType.Default;
                             var binaryBuilder = new BinaryArray.Builder();
-                            if (param.Value == null)
+                            switch (param.Value)
                             {
-                                binaryBuilder.AppendNull();
-                            }
-                            else
-                            {
-                                
binaryBuilder.Append(((byte[])param.Value).AsSpan());
+                                case null: binaryBuilder.AppendNull(); break;
+                                case byte[] array: 
binaryBuilder.Append(array.AsSpan()); break;
+                                default: throw new 
NotSupportedException($"Values of type {param.Value.GetType().Name} cannot be 
bound as binary");
                             }
                             parameters[i] = binaryBuilder.Build();
                             break;
                         case DbType.Boolean:
-                            type = BooleanType.Default;
                             var boolBuilder = new BooleanArray.Builder();
-                            if (param.Value == null)
+                            switch (param.Value)
                             {
-                                boolBuilder.AppendNull();
-                            }
-                            else
-                            {
-                                boolBuilder.Append((bool)param.Value);
+                                case null: boolBuilder.AppendNull(); break;
+                                case bool boolValue: 
boolBuilder.Append(boolValue); break;
+                                default: 
boolBuilder.Append(ConvertValue(param.Value, Convert.ToBoolean, 
DbType.Boolean)); break;
                             }
                             parameters[i] = boolBuilder.Build();
                             break;
                         case DbType.Byte:
-                            type = UInt8Type.Default;
-                            parameters[i] = new 
UInt8Array.Builder().Append((byte?)param.Value).Build();
+                            var uint8Builder = new UInt8Array.Builder();
+                            switch (param.Value)
+                            {
+                                case null: uint8Builder.AppendNull(); break;
+                                case byte byteValue: 
uint8Builder.Append(byteValue); break;
+                                default: 
uint8Builder.Append(ConvertValue(param.Value, Convert.ToByte, DbType.Byte)); 
break;
+                            }
+                            parameters[i] = uint8Builder.Build();
                             break;
                         case DbType.Date:
-                            type = Date32Type.Default;
                             var dateBuilder = new Date32Array.Builder();
-                            if (param.Value == null)
+                            switch (param.Value)
                             {
-                                dateBuilder.AppendNull();
-                            }
+                                case null: dateBuilder.AppendNull(); break;
+                                case DateTime datetime: 
dateBuilder.Append(datetime); break;
 #if NET5_0_OR_GREATER
-                            else if (param.Value is DateOnly)
-                            {
-                                dateBuilder.Append((DateOnly)param.Value);
-                            }
+                                case DateOnly dateonly: 
dateBuilder.Append(dateonly); break;
 #endif
-                            else
-                            {
-                                dateBuilder.Append((DateTime)param.Value);
+                                default: 
dateBuilder.Append(ConvertValue(param.Value, Convert.ToDateTime, DbType.Date)); 
break;
                             }
                             parameters[i] = dateBuilder.Build();
                             break;
                         case DbType.DateTime:
-                            type = TimestampType.Default;
                             var timestampBuilder = new 
TimestampArray.Builder();
-                            if (param.Value == null)
+                            switch (param.Value)
                             {
-                                timestampBuilder.AppendNull();
+                                case null: timestampBuilder.AppendNull(); 
break;
+                                case DateTime datetime: 
timestampBuilder.Append(datetime); break;
+                                default: 
timestampBuilder.Append(ConvertValue(param.Value, Convert.ToDateTime, 
DbType.DateTime)); break;
+                            }
+                            parameters[i] = timestampBuilder.Build();
+                            break;
+                        case DbType.Decimal:
+                            var value = param.Value switch
+                            {
+                                null => (SqlDecimal?)null,
+                                SqlDecimal sqlDecimal => sqlDecimal,
+                                decimal d => new SqlDecimal(d),
+                                _ => new SqlDecimal(ConvertValue(param.Value, 
Convert.ToDecimal, DbType.Decimal)),
+                            };
+                            var decimalBuilder = new 
Decimal128Array.Builder(new Decimal128Type(value?.Precision ?? 10, value?.Scale 
?? 0));
+                            if (value is null)
+                            {
+                                decimalBuilder.AppendNull();
                             }
                             else
                             {
-                                timestampBuilder.Append((DateTime)param.Value);
+                                decimalBuilder.Append(value.Value);
                             }
+                            parameters[i] = decimalBuilder.Build();
                             break;
-                        // TODO: case DbType.Decimal:
                         case DbType.Double:
-                            type = DoubleType.Default;
-                            parameters[i] = new 
DoubleArray.Builder().Append((double?)param.Value).Build();
+                            var doubleBuilder = new DoubleArray.Builder();
+                            switch (param.Value)
+                            {
+                                case null: doubleBuilder.AppendNull(); break;
+                                case double dbl: doubleBuilder.Append(dbl); 
break;
+                                default: 
doubleBuilder.Append(ConvertValue(param.Value, Convert.ToDouble, 
DbType.Double)); break;
+                            }
+                            parameters[i] = doubleBuilder.Build();
                             break;
                         case DbType.Int16:
-                            type = Int16Type.Default;
-                            parameters[i] = new 
Int16Array.Builder().Append((short?)param.Value).Build();
+                            var int16Builder = new Int16Array.Builder();
+                            switch (param.Value)
+                            {
+                                case null: int16Builder.AppendNull(); break;
+                                case short shortValue: 
int16Builder.Append(shortValue); break;
+                                default: 
int16Builder.Append(ConvertValue(param.Value, Convert.ToInt16, DbType.Int16)); 
break;
+                            }
+                            parameters[i] = int16Builder.Build();
                             break;
                         case DbType.Int32:
-                            type = Int32Type.Default;
-                            parameters[i] = new 
Int32Array.Builder().Append((int?)param.Value).Build();
+                            var int32Builder = new Int32Array.Builder();
+                            switch (param.Value)
+                            {
+                                case null: int32Builder.AppendNull(); break;
+                                case int intValue: 
int32Builder.Append(intValue); break;
+                                default: 
int32Builder.Append(ConvertValue(param.Value, Convert.ToInt32, DbType.Int32)); 
break;
+                            }
+                            parameters[i] = int32Builder.Build();
                             break;
                         case DbType.Int64:
-                            type = Int64Type.Default;
-                            parameters[i] = new 
Int64Array.Builder().Append((long?)param.Value).Build();
+                            var int64Builder = new Int64Array.Builder();
+                            switch (param.Value)
+                            {
+                                case null: int64Builder.AppendNull(); break;
+                                case long longValue: 
int64Builder.Append(longValue); break;
+                                default: 
int64Builder.Append(ConvertValue(param.Value, Convert.ToInt64, DbType.Int64)); 
break;
+                            }
+                            parameters[i] = int64Builder.Build();
                             break;
                         case DbType.SByte:
-                            type = Int8Type.Default;
-                            parameters[i] = new 
Int8Array.Builder().Append((sbyte?)param.Value).Build();
+                            var int8Builder = new Int8Array.Builder();
+                            switch (param.Value)
+                            {
+                                case null: int8Builder.AppendNull(); break;
+                                case sbyte sbyteValue: 
int8Builder.Append(sbyteValue); break;
+                                default: 
int8Builder.Append(ConvertValue(param.Value, Convert.ToSByte, DbType.SByte)); 
break;
+                            }
+                            parameters[i] = int8Builder.Build();
                             break;
                         case DbType.Single:
-                            type = FloatType.Default;
-                            parameters[i] = new 
FloatArray.Builder().Append((float?)param.Value).Build();
+                            var floatBuilder = new FloatArray.Builder();
+                            switch (param.Value)
+                            {
+                                case null: floatBuilder.AppendNull(); break;
+                                case float floatValue: 
floatBuilder.Append(floatValue); break;
+                                default: 
floatBuilder.Append(ConvertValue(param.Value, Convert.ToSingle, 
DbType.Single)); break;
+                            }
+                            parameters[i] = floatBuilder.Build();
                             break;
                         case DbType.String:
-                            type = StringType.Default;
-                            parameters[i] = new 
StringArray.Builder().Append((string)param.Value!).Build();
+                            var stringBuilder = new StringArray.Builder();
+                            switch (param.Value)
+                            {
+                                case null: stringBuilder.AppendNull(); break;
+                                case string stringValue: 
stringBuilder.Append(stringValue); break;
+                                default: 
stringBuilder.Append(ConvertValue(param.Value, Convert.ToString, 
DbType.String)); break;
+                            }
+                            parameters[i] = stringBuilder.Build();
+                            break;
+                        case DbType.Time:
+                            var timeBuilder = new Time32Array.Builder();
+                            switch (param.Value)
+                            {
+                                case null: timeBuilder.AppendNull(); break;
+                                case DateTime datetime: 
timeBuilder.Append((int)(datetime.TimeOfDay.Ticks / 
TimeSpan.TicksPerMillisecond)); break;
+#if NET5_0_OR_GREATER
+                                case TimeOnly timeonly: 
timeBuilder.Append(timeonly); break;
+#endif
+                                default:
+                                    DateTime convertedDateTime = 
ConvertValue(param.Value, Convert.ToDateTime, DbType.Time);
+                                    
timeBuilder.Append((int)(convertedDateTime.TimeOfDay.Ticks / 
TimeSpan.TicksPerMillisecond));
+                                    break;
+                            }
+                            parameters[i] = timeBuilder.Build();
                             break;
-                        // TODO: case DbType.Time:
                         case DbType.UInt16:
-                            type = UInt16Type.Default;
-                            parameters[i] = new 
UInt16Array.Builder().Append((ushort?)param.Value).Build();
+                            var uint16Builder = new UInt16Array.Builder();
+                            switch (param.Value)
+                            {
+                                case null: uint16Builder.AppendNull(); break;
+                                case ushort ushortValue: 
uint16Builder.Append(ushortValue); break;
+                                default: 
uint16Builder.Append(ConvertValue(param.Value, Convert.ToUInt16, 
DbType.UInt16)); break;
+                            }
+                            parameters[i] = uint16Builder.Build();
                             break;
                         case DbType.UInt32:
-                            type = UInt32Type.Default;
-                            parameters[i] = new 
UInt32Array.Builder().Append((uint?)param.Value).Build();
+                            var uint32Builder = new UInt32Array.Builder();
+                            switch (param.Value)
+                            {
+                                case null: uint32Builder.AppendNull(); break;
+                                case uint uintValue: 
uint32Builder.Append(uintValue); break;
+                                default: 
uint32Builder.Append(ConvertValue(param.Value, Convert.ToUInt32, 
DbType.UInt32)); break;
+                            }
+                            parameters[i] = uint32Builder.Build();
                             break;
                         case DbType.UInt64:
-                            type = UInt64Type.Default;
-                            parameters[i] = new 
UInt64Array.Builder().Append((ulong?)param.Value).Build();
+                            var uint64Builder = new UInt64Array.Builder();
+                            switch (param.Value)
+                            {
+                                case null: uint64Builder.AppendNull(); break;
+                                case ulong ulongValue: 
uint64Builder.Append(ulongValue); break;
+                                default: 
uint64Builder.Append(ConvertValue(param.Value, Convert.ToUInt64, 
DbType.UInt64)); break;
+                            }
+                            parameters[i] = uint64Builder.Build();
                             break;
                         default:
                             throw new NotSupportedException($"Parameters of 
type {param.DbType} are not supported");
@@ -343,7 +428,7 @@ namespace Apache.Arrow.Adbc.Client
 
                     fields[i] = new Field(
                         string.IsNullOrWhiteSpace(param.ParameterName) ? 
Guid.NewGuid().ToString() : param.ParameterName,
-                        type,
+                        parameters[i].Data.DataType,
                         param.IsNullable || param.Value == null);
                 }
 
@@ -352,6 +437,18 @@ namespace Apache.Arrow.Adbc.Client
             }
         }
 
+        private static T ConvertValue<T>(object value, Func<object, T> converter, DbType type) // converts value for binding as `type`, mapping any failure to NotSupportedException
+        {
+            try
+            {
+                return converter(value);
+            }
+            catch (Exception ex) // keep the converter's exception (e.g. FormatException, OverflowException) as InnerException for diagnosis
+            {
+                throw new NotSupportedException($"Values of type {value.GetType().Name} cannot be bound as {type}.", ex);
+            }
+        }
+
 #if NET5_0_OR_GREATER
         public override ValueTask DisposeAsync()
         {
diff --git a/csharp/src/Client/AdbcParameter.cs b/csharp/src/Client/AdbcParameter.cs
index 620b921c5..c816b1a0b 100644
--- a/csharp/src/Client/AdbcParameter.cs
+++ b/csharp/src/Client/AdbcParameter.cs
@@ -25,13 +25,17 @@ namespace Apache.Arrow.Adbc.Client
     sealed public class AdbcParameter : DbParameter
     {
         public override DbType DbType { get; set; }
-        public override ParameterDirection Direction { get => 
ParameterDirection.Input; set => throw new NotImplementedException(); }
+        public override ParameterDirection Direction
+        {
+            get => ParameterDirection.Input;
+            set { if (value != ParameterDirection.Input) { throw new 
NotSupportedException(); } }
+        }
         public override bool IsNullable { get; set; } = true;
 #if NET5_0_OR_GREATER
         [AllowNull]
 #endif
         public override string ParameterName { get; set; } = string.Empty;
-        public override int Size { get => throw new NotImplementedException(); 
set => throw new NotImplementedException(); }
+        public override int Size { get; set; }
 #if NET5_0_OR_GREATER
         [AllowNull]
 #endif
diff --git a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
index 942a78fe7..b5f045740 100644
--- a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
@@ -149,17 +149,26 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
         public void CanClientExecuteParameterizedQuery()
         {
             SnowflakeTestConfiguration testConfiguration = 
Utils.LoadTestConfiguration<SnowflakeTestConfiguration>(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE);
-            testConfiguration.Query = "SELECT * FROM (SELECT column1 FROM 
(VALUES (1), (2), (3))) WHERE column1 < ?";
+            testConfiguration.Query = "SELECT ? as A, ? as B, ? as C, * FROM 
(SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?";
             testConfiguration.ExpectedResultsCount = 1;
 
             using (Adbc.Client.AdbcConnection adbcConnection = 
GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration))
             {
                 Tests.ClientTests.CanClientExecuteQuery(adbcConnection, 
testConfiguration, command =>
                 {
-                    DbParameter parameter1 = command.CreateParameter();
-                    parameter1.Value = 2;
-                    parameter1.DbType = DbType.Int32;
-                    command.Parameters.Add(parameter1);
+                    DbParameter CreateParameter(DbType dbType, object value)
+                    {
+                        DbParameter result = command.CreateParameter();
+                        result.DbType = dbType;
+                        result.Value = value;
+                        return result;
+                    }
+
+                    // TODO: Add tests for decimal and time once supported by 
the driver or gosnowflake
+                    command.Parameters.Add(CreateParameter(DbType.Int32, 2));
+                    command.Parameters.Add(CreateParameter(DbType.String, 
"text"));
+                    command.Parameters.Add(CreateParameter(DbType.Double, 
2.5));
+                    command.Parameters.Add(CreateParameter(DbType.Int32, 2));
                 });
             }
         }

Reply via email to