This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d021eac6 feat(csharp/src/Drivers/Apache): extend capability of 
GetInfo for Spark driver (#1863)
3d021eac6 is described below

commit 3d021eac66aef5e03306554b0dceab81835dc14a
Author: Bruce Irschick <[email protected]>
AuthorDate: Fri May 17 15:51:09 2024 -0700

    feat(csharp/src/Drivers/Apache): extend capability of GetInfo for Spark 
driver (#1863)
    
    Extends the capability of GetInfo for the Spark driver:
    * Adds dynamic calls to get the following from the DBMS:
      * vendor name
      * vendor version
    * Adds the vendor SQL info code (`true` - hard-coded default)
    * Reports the driver version from the assembly's file version info
      (product version)
    
    Adds tests for both supported and unsupported info codes.
---
 .../Drivers/Apache/Hive2/HiveServer2Connection.cs  | 91 ++++++++--------------
 csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 61 ++++++++++++---
 csharp/test/Drivers/Apache/Spark/DriverTests.cs    | 77 +++++++++++++++++-
 3 files changed, 154 insertions(+), 75 deletions(-)

diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs 
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index 9e11cec10..57bca5d1d 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -17,6 +17,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Threading;
 using System.Threading.Tasks;
 using Apache.Arrow.Ipc;
@@ -35,10 +36,18 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
         internal TTransport? transport;
         internal TCLIService.Client? client;
         internal TSessionHandle? sessionHandle;
+        private readonly Lazy<string> _vendorVersion;
+        private readonly Lazy<string> _vendorName;
 
         internal HiveServer2Connection(IReadOnlyDictionary<string, string> 
properties)
         {
             this.properties = properties;
+            // Note: "LazyThreadSafetyMode.PublicationOnly" is thread-safe 
initialization where
+            // the first successful thread sets the value. If an exception is 
thrown, initialization
+            // will retry until it successfully returns a value without an 
exception.
+            // 
https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects
+            _vendorVersion = new Lazy<string>(() => 
GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER), 
LazyThreadSafetyMode.PublicationOnly);
+            _vendorName = new Lazy<string>(() => 
GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME), 
LazyThreadSafetyMode.PublicationOnly);
         }
 
         internal TCLIService.Client Client
@@ -46,6 +55,10 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             get { return this.client ?? throw new 
InvalidOperationException("connection not open"); }
         }
 
+        protected string VendorVersion => _vendorVersion.Value;
+
+        protected string VendorName => _vendorName.Value;
+
         internal async Task OpenAsync()
         {
             TProtocol protocol = await CreateProtocolAsync();
@@ -81,6 +94,24 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             } while (statusResponse.OperationState == 
TOperationState.PENDING_STATE || statusResponse.OperationState == 
TOperationState.RUNNING_STATE);
         }
 
+        private string GetInfoTypeStringValue(TGetInfoType infoType)
+        {
+            TGetInfoReq req = new()
+            {
+                SessionHandle = this.sessionHandle ?? throw new 
InvalidOperationException("session not created"),
+                InfoType = infoType,
+            };
+
+            TGetInfoResp getInfoResp = Client.GetInfo(req).Result;
+            if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
+            {
+                throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage)
+                    .SetNativeError(getInfoResp.Status.ErrorCode)
+                    .SetSqlState(getInfoResp.Status.SqlState);
+            }
+
+            return getInfoResp.InfoValue.StringValue;
+        }
 
         public override void Dispose()
         {
@@ -102,65 +133,5 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             TGetResultSetMetadataResp response = 
this.Client.GetResultSetMetadata(request).Result;
             return SchemaParser.GetArrowSchema(response.Schema);
         }
-
-        sealed class GetObjectsReader : IArrowArrayStream
-        {
-            HiveServer2Connection? connection;
-            Schema schema;
-            List<TSparkArrowBatch>? batches;
-            int index;
-            IArrowReader? reader;
-
-            public GetObjectsReader(HiveServer2Connection connection, Schema 
schema)
-            {
-                this.connection = connection;
-                this.schema = schema;
-            }
-
-            public Schema Schema { get { return schema; } }
-
-            public async ValueTask<RecordBatch?> 
ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
-            {
-                while (true)
-                {
-                    if (this.reader != null)
-                    {
-                        RecordBatch? next = await 
this.reader.ReadNextRecordBatchAsync(cancellationToken);
-                        if (next != null)
-                        {
-                            return next;
-                        }
-                        this.reader = null;
-                    }
-
-                    if (this.batches != null && this.index < 
this.batches.Count)
-                    {
-                        this.reader = new ArrowStreamReader(new 
ChunkStream(this.schema, this.batches[this.index++].Batch));
-                        continue;
-                    }
-
-                    this.batches = null;
-                    this.index = 0;
-
-                    if (this.connection == null)
-                    {
-                        return null;
-                    }
-
-                    TFetchResultsReq request = new 
TFetchResultsReq(this.connection.operationHandle, TFetchOrientation.FETCH_NEXT, 
50000);
-                    TFetchResultsResp response = await 
this.connection.Client.FetchResults(request, cancellationToken);
-                    this.batches = response.Results.ArrowBatches;
-
-                    if (!response.HasMoreRows)
-                    {
-                        this.connection = null;
-                    }
-                }
-            }
-
-            public void Dispose()
-            {
-            }
-        }
     }
 }
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs 
b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index 36d62f89f..3e3bbcae2 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -20,6 +20,7 @@ using System.Collections.Generic;
 using System.Diagnostics;
 using System.Net.Http;
 using System.Net.Http.Headers;
+using System.Reflection;
 using System.Text;
 using System.Text.RegularExpressions;
 using System.Threading;
@@ -43,16 +44,20 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             AdbcInfoCode.DriverName,
             AdbcInfoCode.DriverVersion,
             AdbcInfoCode.DriverArrowVersion,
-            AdbcInfoCode.VendorName
+            AdbcInfoCode.VendorName,
+            AdbcInfoCode.VendorSql,
+            AdbcInfoCode.VendorVersion,
         };
 
+        const string ProductVersionDefault = "1.0.0";
         const string InfoDriverName = "ADBC Spark Driver";
-        const string InfoDriverVersion = "1.0.0";
-        const string InfoVendorName = "Spark";
         const string InfoDriverArrowVersion = "1.0.0";
+        const bool InfoVendorSql = true;
         const int DecimalPrecisionDefault = 10;
         const int DecimalScaleDefault = 0;
 
+        private readonly Lazy<string> _productVersion;
+
         internal static TSparkGetDirectResults sparkGetDirectResults = new 
TSparkGetDirectResults(1000);
 
         internal static readonly Dictionary<string, string> timestampConfig = 
new Dictionary<string, string>
@@ -83,8 +88,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
         internal SparkConnection(IReadOnlyDictionary<string, string> 
properties)
             : base(properties)
         {
+            _productVersion = new Lazy<string>(() => GetProductVersion(), 
LazyThreadSafetyMode.PublicationOnly);
         }
 
+        protected string ProductVersion => _productVersion.Value;
+
         protected override async ValueTask<TProtocol> CreateProtocolAsync()
         {
             Trace.TraceError($"create protocol with {properties.Count} 
properties.");
@@ -137,6 +145,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
         public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> 
codes)
         {
             const int strValTypeID = 0;
+            const int boolValTypeId = 1;
 
             UnionType infoUnionType = new UnionType(
                 new Field[]
@@ -178,8 +187,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             ArrowBuffer.Builder<byte> typeBuilder = new 
ArrowBuffer.Builder<byte>();
             ArrowBuffer.Builder<int> offsetBuilder = new 
ArrowBuffer.Builder<int>();
             StringArray.Builder stringInfoBuilder = new StringArray.Builder();
+            BooleanArray.Builder booleanInfoBuilder = new 
BooleanArray.Builder();
+
             int nullCount = 0;
             int arrayLength = codes.Count;
+            int offset = 0;
 
             foreach (AdbcInfoCode code in codes)
             {
@@ -188,32 +200,53 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
                     case AdbcInfoCode.DriverName:
                         infoNameBuilder.Append((UInt32)code);
                         typeBuilder.Append(strValTypeID);
-                        offsetBuilder.Append(stringInfoBuilder.Length);
+                        offsetBuilder.Append(offset++);
                         stringInfoBuilder.Append(InfoDriverName);
+                        booleanInfoBuilder.AppendNull();
                         break;
                     case AdbcInfoCode.DriverVersion:
                         infoNameBuilder.Append((UInt32)code);
                         typeBuilder.Append(strValTypeID);
-                        offsetBuilder.Append(stringInfoBuilder.Length);
-                        stringInfoBuilder.Append(InfoDriverVersion);
+                        offsetBuilder.Append(offset++);
+                        stringInfoBuilder.Append(ProductVersion);
+                        booleanInfoBuilder.AppendNull();
                         break;
                     case AdbcInfoCode.DriverArrowVersion:
                         infoNameBuilder.Append((UInt32)code);
                         typeBuilder.Append(strValTypeID);
-                        offsetBuilder.Append(stringInfoBuilder.Length);
+                        offsetBuilder.Append(offset++);
                         stringInfoBuilder.Append(InfoDriverArrowVersion);
+                        booleanInfoBuilder.AppendNull();
                         break;
                     case AdbcInfoCode.VendorName:
                         infoNameBuilder.Append((UInt32)code);
                         typeBuilder.Append(strValTypeID);
-                        offsetBuilder.Append(stringInfoBuilder.Length);
-                        stringInfoBuilder.Append(InfoVendorName);
+                        offsetBuilder.Append(offset++);
+                        string vendorName = VendorName;
+                        stringInfoBuilder.Append(vendorName);
+                        booleanInfoBuilder.AppendNull();
+                        break;
+                    case AdbcInfoCode.VendorVersion:
+                        infoNameBuilder.Append((UInt32)code);
+                        typeBuilder.Append(strValTypeID);
+                        offsetBuilder.Append(offset++);
+                        string? vendorVersion = VendorVersion;
+                        stringInfoBuilder.Append(vendorVersion);
+                        booleanInfoBuilder.AppendNull();
+                        break;
+                    case AdbcInfoCode.VendorSql:
+                        infoNameBuilder.Append((UInt32)code);
+                        typeBuilder.Append(boolValTypeId);
+                        offsetBuilder.Append(offset++);
+                        stringInfoBuilder.AppendNull();
+                        booleanInfoBuilder.Append(InfoVendorSql);
                         break;
                     default:
                         infoNameBuilder.Append((UInt32)code);
                         typeBuilder.Append(strValTypeID);
-                        offsetBuilder.Append(stringInfoBuilder.Length);
+                        offsetBuilder.Append(offset++);
                         stringInfoBuilder.AppendNull();
+                        booleanInfoBuilder.AppendNull();
                         nullCount++;
                         break;
                 }
@@ -231,7 +264,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             IArrowArray[] childrenArrays = new IArrowArray[]
             {
                 stringInfoBuilder.Build(),
-                new BooleanArray.Builder().Build(),
+                booleanInfoBuilder.Build(),
                 new Int64Array.Builder().Build(),
                 new Int32Array.Builder().Build(),
                 new ListArray.Builder(StringType.Default).Build(),
@@ -749,6 +782,12 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
                 return true;
             }
         }
+
+        private string GetProductVersion()
+        {
+            FileVersionInfo fileVersionInfo = 
FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location);
+            return fileVersionInfo.ProductVersion ?? ProductVersionDefault;
+        }
     }
 
     internal struct TableInfoPair
diff --git a/csharp/test/Drivers/Apache/Spark/DriverTests.cs 
b/csharp/test/Drivers/Apache/Spark/DriverTests.cs
index a4f3a4607..a7507e473 100644
--- a/csharp/test/Drivers/Apache/Spark/DriverTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/DriverTests.cs
@@ -84,12 +84,30 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
         {
             AdbcConnection adbcConnection = NewConnection();
 
-            using IArrowArrayStream stream = adbcConnection.GetInfo(new 
List<AdbcInfoCode>() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, 
AdbcInfoCode.VendorName });
+            // Test the supported info codes
+            List<AdbcInfoCode> handledCodes = new List<AdbcInfoCode>()
+            {
+                AdbcInfoCode.DriverName,
+                AdbcInfoCode.DriverVersion,
+                AdbcInfoCode.VendorName,
+                AdbcInfoCode.DriverArrowVersion,
+                AdbcInfoCode.VendorVersion,
+                AdbcInfoCode.VendorSql
+            };
+            using IArrowArrayStream stream = 
adbcConnection.GetInfo(handledCodes);
 
             RecordBatch recordBatch = await stream.ReadNextRecordBatchAsync();
             UInt32Array infoNameArray = 
(UInt32Array)recordBatch.Column("info_name");
 
-            List<string> expectedValues = new List<string>() { "DriverName", 
"DriverVersion", "VendorName" };
+            List<string> expectedValues = new List<string>()
+            {
+                "DriverName",
+                "DriverVersion",
+                "VendorName",
+                "DriverArrowVersion",
+                "VendorVersion",
+                "VendorSql"
+            };
 
             for (int i = 0; i < infoNameArray.Length; i++)
             {
@@ -98,8 +116,59 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
 
                 Assert.Contains(value.ToString(), expectedValues);
 
-                StringArray stringArray = (StringArray)valueArray.Fields[0];
-                Console.WriteLine($"{value}={stringArray.GetString(i)}");
+                switch (value)
+                {
+                    case AdbcInfoCode.VendorSql:
+                        // TODO: How does external developer know the second 
field is the boolean field?
+                        BooleanArray booleanArray = 
(BooleanArray)valueArray.Fields[1];
+                        bool? boolValue = booleanArray.GetValue(i);
+                        OutputHelper?.WriteLine($"{value}={boolValue}");
+                        Assert.True(boolValue);
+                        break;
+                    default:
+                        StringArray stringArray = 
(StringArray)valueArray.Fields[0];
+                        string stringValue = stringArray.GetString(i);
+                        OutputHelper?.WriteLine($"{value}={stringValue}");
+                        Assert.NotNull(stringValue);
+                        break;
+                }
+            }
+
+            // Test the unhandled info codes.
+            List<AdbcInfoCode> unhandledCodes = new List<AdbcInfoCode>()
+            {
+                AdbcInfoCode.VendorArrowVersion,
+                AdbcInfoCode.VendorSubstrait,
+                AdbcInfoCode.VendorSubstraitMaxVersion
+            };
+            using IArrowArrayStream stream2 = 
adbcConnection.GetInfo(unhandledCodes);
+
+            recordBatch = await stream2.ReadNextRecordBatchAsync();
+            infoNameArray = (UInt32Array)recordBatch.Column("info_name");
+
+            List<string> unexpectedValues = new List<string>()
+            {
+                "VendorArrowVersion",
+                "VendorSubstrait",
+                "VendorSubstraitMaxVersion"
+            };
+            for (int i = 0; i < infoNameArray.Length; i++)
+            {
+                AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i);
+                DenseUnionArray valueArray = 
(DenseUnionArray)recordBatch.Column("info_value");
+
+                Assert.Contains(value.ToString(), unexpectedValues);
+                switch (value)
+                {
+                    case AdbcInfoCode.VendorSql:
+                        BooleanArray booleanArray = 
(BooleanArray)valueArray.Fields[1];
+                        Assert.Null(booleanArray.GetValue(i));
+                        break;
+                    default:
+                        StringArray stringArray = 
(StringArray)valueArray.Fields[0];
+                        Assert.Null(stringArray.GetString(i));
+                        break;
+                }
             }
         }
 

Reply via email to