This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 3d021eac6 feat(csharp/src/Drivers/Apache): extend capability of
GetInfo for Spark driver (#1863)
3d021eac6 is described below
commit 3d021eac66aef5e03306554b0dceab81835dc14a
Author: Bruce Irschick <[email protected]>
AuthorDate: Fri May 17 15:51:09 2024 -0700
feat(csharp/src/Drivers/Apache): extend capability of GetInfo for Spark
driver (#1863)
Extend capability of GetInfo for Spark driver
* Adds support for returning the following values from GetInfo:
  * vendor name (retrieved dynamically from the DBMS)
  * vendor version (retrieved dynamically from the DBMS)
  * vendor sql (`true` - hard-coded default)
  * driver version (using file info/product version)
* Adds tests for supported and unsupported info codes.
---
.../Drivers/Apache/Hive2/HiveServer2Connection.cs | 91 ++++++++--------------
csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 61 ++++++++++++---
csharp/test/Drivers/Apache/Spark/DriverTests.cs | 77 +++++++++++++++++-
3 files changed, 154 insertions(+), 75 deletions(-)
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index 9e11cec10..57bca5d1d 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
@@ -35,10 +36,18 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
internal TTransport? transport;
internal TCLIService.Client? client;
internal TSessionHandle? sessionHandle;
+ private readonly Lazy<string> _vendorVersion;
+ private readonly Lazy<string> _vendorName;
internal HiveServer2Connection(IReadOnlyDictionary<string, string>
properties)
{
this.properties = properties;
+ // Note: "LazyThreadSafetyMode.PublicationOnly" is thread-safe
initialization where
+ // the first successful thread sets the value. If an exception is
thrown, initialization
+ // will retry until it successfully returns a value without an
exception.
+ //
https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects
+ _vendorVersion = new Lazy<string>(() =>
GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER),
LazyThreadSafetyMode.PublicationOnly);
+ _vendorName = new Lazy<string>(() =>
GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME),
LazyThreadSafetyMode.PublicationOnly);
}
internal TCLIService.Client Client
@@ -46,6 +55,10 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
get { return this.client ?? throw new
InvalidOperationException("connection not open"); }
}
+ protected string VendorVersion => _vendorVersion.Value;
+
+ protected string VendorName => _vendorName.Value;
+
internal async Task OpenAsync()
{
TProtocol protocol = await CreateProtocolAsync();
@@ -81,6 +94,24 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
} while (statusResponse.OperationState ==
TOperationState.PENDING_STATE || statusResponse.OperationState ==
TOperationState.RUNNING_STATE);
}
+ private string GetInfoTypeStringValue(TGetInfoType infoType)
+ {
+ TGetInfoReq req = new()
+ {
+ SessionHandle = this.sessionHandle ?? throw new
InvalidOperationException("session not created"),
+ InfoType = infoType,
+ };
+
+ TGetInfoResp getInfoResp = Client.GetInfo(req).Result;
+ if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
+ {
+ throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage)
+ .SetNativeError(getInfoResp.Status.ErrorCode)
+ .SetSqlState(getInfoResp.Status.SqlState);
+ }
+
+ return getInfoResp.InfoValue.StringValue;
+ }
public override void Dispose()
{
@@ -102,65 +133,5 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
TGetResultSetMetadataResp response =
this.Client.GetResultSetMetadata(request).Result;
return SchemaParser.GetArrowSchema(response.Schema);
}
-
- sealed class GetObjectsReader : IArrowArrayStream
- {
- HiveServer2Connection? connection;
- Schema schema;
- List<TSparkArrowBatch>? batches;
- int index;
- IArrowReader? reader;
-
- public GetObjectsReader(HiveServer2Connection connection, Schema
schema)
- {
- this.connection = connection;
- this.schema = schema;
- }
-
- public Schema Schema { get { return schema; } }
-
- public async ValueTask<RecordBatch?>
ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
- {
- while (true)
- {
- if (this.reader != null)
- {
- RecordBatch? next = await
this.reader.ReadNextRecordBatchAsync(cancellationToken);
- if (next != null)
- {
- return next;
- }
- this.reader = null;
- }
-
- if (this.batches != null && this.index <
this.batches.Count)
- {
- this.reader = new ArrowStreamReader(new
ChunkStream(this.schema, this.batches[this.index++].Batch));
- continue;
- }
-
- this.batches = null;
- this.index = 0;
-
- if (this.connection == null)
- {
- return null;
- }
-
- TFetchResultsReq request = new
TFetchResultsReq(this.connection.operationHandle, TFetchOrientation.FETCH_NEXT,
50000);
- TFetchResultsResp response = await
this.connection.Client.FetchResults(request, cancellationToken);
- this.batches = response.Results.ArrowBatches;
-
- if (!response.HasMoreRows)
- {
- this.connection = null;
- }
- }
- }
-
- public void Dispose()
- {
- }
- }
}
}
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index 36d62f89f..3e3bbcae2 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -20,6 +20,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Net.Http;
using System.Net.Http.Headers;
+using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
@@ -43,16 +44,20 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.DriverArrowVersion,
- AdbcInfoCode.VendorName
+ AdbcInfoCode.VendorName,
+ AdbcInfoCode.VendorSql,
+ AdbcInfoCode.VendorVersion,
};
+ const string ProductVersionDefault = "1.0.0";
const string InfoDriverName = "ADBC Spark Driver";
- const string InfoDriverVersion = "1.0.0";
- const string InfoVendorName = "Spark";
const string InfoDriverArrowVersion = "1.0.0";
+ const bool InfoVendorSql = true;
const int DecimalPrecisionDefault = 10;
const int DecimalScaleDefault = 0;
+ private readonly Lazy<string> _productVersion;
+
internal static TSparkGetDirectResults sparkGetDirectResults = new
TSparkGetDirectResults(1000);
internal static readonly Dictionary<string, string> timestampConfig =
new Dictionary<string, string>
@@ -83,8 +88,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
internal SparkConnection(IReadOnlyDictionary<string, string>
properties)
: base(properties)
{
+ _productVersion = new Lazy<string>(() => GetProductVersion(),
LazyThreadSafetyMode.PublicationOnly);
}
+ protected string ProductVersion => _productVersion.Value;
+
protected override async ValueTask<TProtocol> CreateProtocolAsync()
{
Trace.TraceError($"create protocol with {properties.Count}
properties.");
@@ -137,6 +145,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode>
codes)
{
const int strValTypeID = 0;
+ const int boolValTypeId = 1;
UnionType infoUnionType = new UnionType(
new Field[]
@@ -178,8 +187,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
ArrowBuffer.Builder<byte> typeBuilder = new
ArrowBuffer.Builder<byte>();
ArrowBuffer.Builder<int> offsetBuilder = new
ArrowBuffer.Builder<int>();
StringArray.Builder stringInfoBuilder = new StringArray.Builder();
+ BooleanArray.Builder booleanInfoBuilder = new
BooleanArray.Builder();
+
int nullCount = 0;
int arrayLength = codes.Count;
+ int offset = 0;
foreach (AdbcInfoCode code in codes)
{
@@ -188,32 +200,53 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
case AdbcInfoCode.DriverName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
- offsetBuilder.Append(stringInfoBuilder.Length);
+ offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverName);
+ booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.DriverVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
- offsetBuilder.Append(stringInfoBuilder.Length);
- stringInfoBuilder.Append(InfoDriverVersion);
+ offsetBuilder.Append(offset++);
+ stringInfoBuilder.Append(ProductVersion);
+ booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.DriverArrowVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
- offsetBuilder.Append(stringInfoBuilder.Length);
+ offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverArrowVersion);
+ booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
- offsetBuilder.Append(stringInfoBuilder.Length);
- stringInfoBuilder.Append(InfoVendorName);
+ offsetBuilder.Append(offset++);
+ string vendorName = VendorName;
+ stringInfoBuilder.Append(vendorName);
+ booleanInfoBuilder.AppendNull();
+ break;
+ case AdbcInfoCode.VendorVersion:
+ infoNameBuilder.Append((UInt32)code);
+ typeBuilder.Append(strValTypeID);
+ offsetBuilder.Append(offset++);
+ string? vendorVersion = VendorVersion;
+ stringInfoBuilder.Append(vendorVersion);
+ booleanInfoBuilder.AppendNull();
+ break;
+ case AdbcInfoCode.VendorSql:
+ infoNameBuilder.Append((UInt32)code);
+ typeBuilder.Append(boolValTypeId);
+ offsetBuilder.Append(offset++);
+ stringInfoBuilder.AppendNull();
+ booleanInfoBuilder.Append(InfoVendorSql);
break;
default:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
- offsetBuilder.Append(stringInfoBuilder.Length);
+ offsetBuilder.Append(offset++);
stringInfoBuilder.AppendNull();
+ booleanInfoBuilder.AppendNull();
nullCount++;
break;
}
@@ -231,7 +264,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
IArrowArray[] childrenArrays = new IArrowArray[]
{
stringInfoBuilder.Build(),
- new BooleanArray.Builder().Build(),
+ booleanInfoBuilder.Build(),
new Int64Array.Builder().Build(),
new Int32Array.Builder().Build(),
new ListArray.Builder(StringType.Default).Build(),
@@ -749,6 +782,12 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
return true;
}
}
+
+ private string GetProductVersion()
+ {
+ FileVersionInfo fileVersionInfo =
FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location);
+ return fileVersionInfo.ProductVersion ?? ProductVersionDefault;
+ }
}
internal struct TableInfoPair
diff --git a/csharp/test/Drivers/Apache/Spark/DriverTests.cs
b/csharp/test/Drivers/Apache/Spark/DriverTests.cs
index a4f3a4607..a7507e473 100644
--- a/csharp/test/Drivers/Apache/Spark/DriverTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/DriverTests.cs
@@ -84,12 +84,30 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
{
AdbcConnection adbcConnection = NewConnection();
- using IArrowArrayStream stream = adbcConnection.GetInfo(new
List<AdbcInfoCode>() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion,
AdbcInfoCode.VendorName });
+ // Test the supported info codes
+ List<AdbcInfoCode> handledCodes = new List<AdbcInfoCode>()
+ {
+ AdbcInfoCode.DriverName,
+ AdbcInfoCode.DriverVersion,
+ AdbcInfoCode.VendorName,
+ AdbcInfoCode.DriverArrowVersion,
+ AdbcInfoCode.VendorVersion,
+ AdbcInfoCode.VendorSql
+ };
+ using IArrowArrayStream stream =
adbcConnection.GetInfo(handledCodes);
RecordBatch recordBatch = await stream.ReadNextRecordBatchAsync();
UInt32Array infoNameArray =
(UInt32Array)recordBatch.Column("info_name");
- List<string> expectedValues = new List<string>() { "DriverName",
"DriverVersion", "VendorName" };
+ List<string> expectedValues = new List<string>()
+ {
+ "DriverName",
+ "DriverVersion",
+ "VendorName",
+ "DriverArrowVersion",
+ "VendorVersion",
+ "VendorSql"
+ };
for (int i = 0; i < infoNameArray.Length; i++)
{
@@ -98,8 +116,59 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
Assert.Contains(value.ToString(), expectedValues);
- StringArray stringArray = (StringArray)valueArray.Fields[0];
- Console.WriteLine($"{value}={stringArray.GetString(i)}");
+ switch (value)
+ {
+ case AdbcInfoCode.VendorSql:
+ // TODO: How does external developer know the second
field is the boolean field?
+ BooleanArray booleanArray =
(BooleanArray)valueArray.Fields[1];
+ bool? boolValue = booleanArray.GetValue(i);
+ OutputHelper?.WriteLine($"{value}={boolValue}");
+ Assert.True(boolValue);
+ break;
+ default:
+ StringArray stringArray =
(StringArray)valueArray.Fields[0];
+ string stringValue = stringArray.GetString(i);
+ OutputHelper?.WriteLine($"{value}={stringValue}");
+ Assert.NotNull(stringValue);
+ break;
+ }
+ }
+
+ // Test the unhandled info codes.
+ List<AdbcInfoCode> unhandledCodes = new List<AdbcInfoCode>()
+ {
+ AdbcInfoCode.VendorArrowVersion,
+ AdbcInfoCode.VendorSubstrait,
+ AdbcInfoCode.VendorSubstraitMaxVersion
+ };
+ using IArrowArrayStream stream2 =
adbcConnection.GetInfo(unhandledCodes);
+
+ recordBatch = await stream2.ReadNextRecordBatchAsync();
+ infoNameArray = (UInt32Array)recordBatch.Column("info_name");
+
+ List<string> unexpectedValues = new List<string>()
+ {
+ "VendorArrowVersion",
+ "VendorSubstrait",
+ "VendorSubstraitMaxVersion"
+ };
+ for (int i = 0; i < infoNameArray.Length; i++)
+ {
+ AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i);
+ DenseUnionArray valueArray =
(DenseUnionArray)recordBatch.Column("info_value");
+
+ Assert.Contains(value.ToString(), unexpectedValues);
+ switch (value)
+ {
+ case AdbcInfoCode.VendorSql:
+ BooleanArray booleanArray =
(BooleanArray)valueArray.Fields[1];
+ Assert.Null(booleanArray.GetValue(i));
+ break;
+ default:
+ StringArray stringArray =
(StringArray)valueArray.Fields[0];
+ Assert.Null(stringArray.GetString(i));
+ break;
+ }
}
}