This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
     new f68b23a30 feat(csharp/src/Drivers/Apache): Enabled Standard protocol for Spark and used SASL transport with basic auth (#3380)
f68b23a30 is described below

commit f68b23a30d392ea85c6466fb1ff5a6b6b723e3e1
Author: Ryan Syed <sy.r...@icloud.com>
AuthorDate: Wed Sep 3 08:55:49 2025 -0700

    feat(csharp/src/Drivers/Apache): Enabled Standard protocol for Spark and used SASL transport with basic auth (#3380)

    1. Enabled Spark Standard protocol
    2. Used SASL transport in SparkStandardConnection with basic auth
---
 csharp/src/Drivers/Apache/Spark/SparkConnection.cs |  8 ++
 .../Drivers/Apache/Spark/SparkConnectionFactory.cs |  3 +-
 .../Drivers/Apache/Spark/SparkHttpConnection.cs    |  5 --
 .../Apache/Spark/SparkStandardConnection.cs        | 94 +++++++++++++++++++---
 .../test/Drivers/Apache/ApacheTestConfiguration.cs |  9 +++
 .../Drivers/Apache/Spark/SparkTestEnvironment.cs   | 27 +++++++
 6 files changed, 130 insertions(+), 16 deletions(-)

diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index 584380bce..bf6269c3a 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -19,6 +19,7 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics.CodeAnalysis;
 using System.Threading;
+using System.Threading.Tasks;
 using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
 using Apache.Hive.Service.Rpc.Thrift;
@@ -57,6 +58,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             return statement;
         }
 
+        protected override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(IResponse response, CancellationToken cancellationToken = default) =>
+            GetResultSetMetadataAsync(response.OperationHandle!, Client, cancellationToken);
+        protected override Task<TRowSet> GetRowSetAsync(IResponse response, CancellationToken cancellationToken = default) =>
+            FetchResultsAsync(response.OperationHandle!, cancellationToken: cancellationToken);
+
         protected internal override int PositionRequiredOffset => 1;
 
         internal override void SetPrecisionScaleAndTypeName(
@@ -113,6 +119,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
 
         protected override bool IsColumnSizeValidForDecimal => false;
 
+        internal override SchemaParser SchemaParser { get; } = new HiveServer2SchemaParser();
+
         protected internal override bool TrySetGetDirectResults(IRequest request)
         {
             request.GetDirectResults = sparkGetDirectResults;
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs b/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs
index 484e51bc9..deeb82600 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs
@@ -36,8 +36,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             return serverTypeValue switch
             {
                 SparkServerType.Http => new SparkHttpConnection(properties),
-                // TODO: Re-enable when properly supported
-                //SparkServerType.Standard => new SparkStandardConnection(properties),
+                SparkServerType.Standard => new SparkStandardConnection(properties),
                 _ => throw new ArgumentOutOfRangeException(nameof(properties), $"Unsupported or unknown value '{type}' given for property '{SparkParameters.Type}'. Supported types: {ServerTypeParser.SupportedList}"),
             };
         }
diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
index 1e34991e9..b3710314d 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
@@ -241,11 +241,6 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             return req;
         }
 
-        protected override Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(IResponse response, CancellationToken cancellationToken = default) =>
-            GetResultSetMetadataAsync(response.OperationHandle!, Client, cancellationToken);
-        protected override Task<TRowSet> GetRowSetAsync(IResponse response, CancellationToken cancellationToken = default) =>
-            FetchResultsAsync(response.OperationHandle!, cancellationToken: cancellationToken);
-
         internal override SchemaParser SchemaParser => new HiveServer2SchemaParser();
 
         internal override SparkServerType ServerType => SparkServerType.Http;
diff --git a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs
index 25d808574..f7438fe47 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs
@@ -18,8 +18,12 @@
 using System;
 using System.Collections.Generic;
 using System.Net;
+using System.Net.Security;
+using System.Security.Cryptography.X509Certificates;
 using System.Threading;
 using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
+using Apache.Arrow.Ipc;
 using Apache.Hive.Service.Rpc.Thrift;
 using Thrift.Protocol;
 using Thrift.Transport;
@@ -27,7 +31,7 @@ using Thrift.Transport.Client;
 
 namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
 {
-    internal class SparkStandardConnection : SparkHttpConnection
+    internal class SparkStandardConnection : SparkConnection
     {
         public SparkStandardConnection(IReadOnlyDictionary<string, string> properties) : base(properties)
         {
@@ -82,12 +86,19 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
 
             // Validate port range
             Properties.TryGetValue(SparkParameters.Port, out string? port);
+            if (string.IsNullOrWhiteSpace(port))
+            {
+                throw new ArgumentException(
+                    $"Required parameter '{SparkParameters.Port}' is missing. Please provide a port number for the data source.",
+                    nameof(Properties));
+            }
             if (int.TryParse(port, out int portNumber) && (portNumber <= IPEndPoint.MinPort || portNumber > IPEndPoint.MaxPort))
+            {
                 throw new ArgumentOutOfRangeException(
                     nameof(Properties),
                     port,
                     $"Parameter '{SparkParameters.Port}' value is not in the valid range of 1 .. {IPEndPoint.MaxPort}.");
-
+            }
         }
 
         protected override TTransport CreateTransport()
@@ -95,19 +106,68 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             // Assumption: hostName and port have already been validated.
             Properties.TryGetValue(SparkParameters.HostName, out string? hostName);
             Properties.TryGetValue(SparkParameters.Port, out string? port);
+            Properties.TryGetValue(SparkParameters.AuthType, out string? authType);
+
+            if (!SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue))
+            {
+                throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value.");
+            }
 
             // Delay the open connection until later.
             bool connectClient = false;
-            TSocketTransport transport = new(hostName!, int.Parse(port!), connectClient, config: new());
-            return transport;
+            int portValue = int.Parse(port!);
+
+            // TLS setup
+            TTransport baseTransport;
+            if (TlsOptions.IsTlsEnabled)
+            {
+                X509Certificate2? trustedCert = !string.IsNullOrEmpty(TlsOptions.TrustedCertificatePath)
+                    ? new X509Certificate2(TlsOptions.TrustedCertificatePath!)
+                    : null;
+
+                RemoteCertificateValidationCallback certValidator = (sender, cert, chain, errors) => HiveServer2TlsImpl.ValidateCertificate(cert, errors, TlsOptions);
+
+                if (IPAddress.TryParse(hostName!, out var ipAddress))
+                {
+                    baseTransport = new TTlsSocketTransport(ipAddress, portValue, config: new(), 0, trustedCert, certValidator);
+                }
+                else
+                {
+                    baseTransport = new TTlsSocketTransport(hostName!, portValue, config: new(), 0, trustedCert, certValidator);
+                }
+            }
+            else
+            {
+                baseTransport = new TSocketTransport(hostName!, portValue, connectClient, config: new());
+            }
+
+            TBufferedTransport bufferedTransport = new TBufferedTransport(baseTransport);
+            switch (authTypeValue)
+            {
+                case SparkAuthType.None:
+                    return bufferedTransport;
+
+                case SparkAuthType.Basic:
+                    Properties.TryGetValue(AdbcOptions.Username, out string? username);
+                    Properties.TryGetValue(AdbcOptions.Password, out string? password);
+
+                    if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password))
+                    {
+                        throw new InvalidOperationException("Username and password must be provided for this authentication type.");
+                    }
+
+                    PlainSaslMechanism saslMechanism = new(username, password);
+                    TSaslTransport saslTransport = new(bufferedTransport, saslMechanism, config: new());
+                    return new TFramedTransport(saslTransport);
+
+                default:
+                    throw new NotSupportedException($"Authentication type '{authTypeValue}' is not supported.");
+            }
         }
 
         protected override async Task<TProtocol> CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default)
         {
-            return await base.CreateProtocolAsync(transport, cancellationToken);
-
-            //if (!transport.IsOpen) await transport.OpenAsync(CancellationToken.None);
-            //return new TBinaryProtocol(transport);
+            if (!transport.IsOpen) await transport.OpenAsync(cancellationToken);
+            return new TBinaryProtocol(transport, true, true);
         }
 
         protected override TOpenSessionReq CreateSessionRequest()
@@ -120,7 +180,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             {
                 throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value.");
             }
-            TOpenSessionReq request = base.CreateSessionRequest();
+            TOpenSessionReq request = new TOpenSessionReq
+            {
+                Client_protocol = TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V7,
+                CanUseMultipleCatalogs = true,
+            };
             switch (authTypeValue)
             {
                 case SparkAuthType.UsernameOnly:
@@ -139,10 +203,22 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             return request;
         }
 
+        protected override void ValidateOptions()
+        {
+            Properties.TryGetValue(SparkParameters.DataTypeConv, out string? dataTypeConv);
+            DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv);
+            TlsOptions = HiveServer2TlsImpl.GetStandardTlsOptions(Properties);
+        }
+
+        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, IResponse response, TGetResultSetMetadataResp? metadataResp = null) =>
+            new HiveServer2Reader(statement, schema, response, dataTypeConversion: statement.Connection.DataTypeConversion);
+
         internal override SparkServerType ServerType => SparkServerType.Standard;
 
         public override string AssemblyName => s_assemblyName;
 
         public override string AssemblyVersion => s_assemblyVersion;
+
+        protected override int ColumnMapIndexOffset => 0;
     }
 }
diff --git a/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs b/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
index 47571194e..6b1b74f86 100644
--- a/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
+++ b/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
@@ -63,6 +63,9 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache
         [JsonPropertyName("http_options"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
         public HttpTestConfiguration? HttpOptions { get; set; }
 
+        [JsonPropertyName("standard_options"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
+        public StandardTestConfiguration? StandardOptions { get; set; }
+
         [JsonPropertyName("catalog"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
         public string Catalog { get; set; } = string.Empty;
 
@@ -79,6 +82,12 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache
         public ProxyTestConfiguration? Proxy { get; set; }
     }
 
+    public class StandardTestConfiguration
+    {
+        [JsonPropertyName("tls"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
+        public TlsTestConfiguration? Tls { get; set; }
+    }
+
     public class ProxyTestConfiguration
     {
         [JsonPropertyName("use_proxy"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
index d05025469..98b1f8ac4 100644
--- a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
+++ b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
@@ -176,6 +176,33 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
                     }
                 }
             }
+            if (testConfiguration.StandardOptions != null)
+            {
+                if (testConfiguration.StandardOptions.Tls != null)
+                {
+                    TlsTestConfiguration tlsOptions = testConfiguration.StandardOptions.Tls;
+                    if (tlsOptions.Enabled.HasValue)
+                    {
+                        parameters.Add(StandardTlsOptions.IsTlsEnabled, tlsOptions.Enabled.Value.ToString());
+                    }
+                    if (tlsOptions.AllowSelfSigned.HasValue)
+                    {
+                        parameters.Add(StandardTlsOptions.AllowSelfSigned, tlsOptions.AllowSelfSigned.Value.ToString());
+                    }
+                    if (tlsOptions.AllowHostnameMismatch.HasValue)
+                    {
+                        parameters.Add(StandardTlsOptions.AllowHostnameMismatch, tlsOptions.AllowHostnameMismatch.Value.ToString());
+                    }
+                    if (tlsOptions.DisableServerCertificateValidation.HasValue)
+                    {
+                        parameters.Add(StandardTlsOptions.DisableServerCertificateValidation, tlsOptions.DisableServerCertificateValidation.Value.ToString());
+                    }
+                    if (!string.IsNullOrEmpty(tlsOptions.TrustedCertificatePath))
+                    {
+                        parameters.Add(StandardTlsOptions.TrustedCertificatePath, tlsOptions.TrustedCertificatePath!);
+                    }
+                }
+            }
             return parameters;
         }
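
For context, here is a minimal sketch (not part of the commit) of how a client
might exercise this change by opening a Standard-protocol Spark connection with
basic authentication. The SparkParameters.* and AdbcOptions.* constants appear
in the patch; the SparkDriver entry point, the literal option values "standard"
and "basic", and the host/port are illustrative assumptions.

    using System.Collections.Generic;
    using Apache.Arrow.Adbc;
    using Apache.Arrow.Adbc.Drivers.Apache.Spark;

    // Option values below are assumed; consult ServerTypeParser and
    // SparkAuthTypeParser for the exact accepted strings.
    var parameters = new Dictionary<string, string>
    {
        [SparkParameters.Type] = "standard",              // routed to SparkStandardConnection by the factory
        [SparkParameters.HostName] = "spark.example.com", // placeholder host
        [SparkParameters.Port] = "10000",                 // validated to be in 1..65535
        [SparkParameters.AuthType] = "basic",             // builds socket -> buffered -> SASL(PLAIN) -> framed
        [AdbcOptions.Username] = "user",                  // required for basic auth; see CreateTransport
        [AdbcOptions.Password] = "secret",
    };

    using AdbcDatabase database = new SparkDriver().Open(parameters);
    using AdbcConnection connection = database.Connect(new Dictionary<string, string>());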
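
Similarly, the new test-configuration surface might be driven by a JSON fragment
along these lines. Only the "standard_options" and "tls" property names appear
in the patch; the nested TLS key names are assumptions modeled on the
TlsTestConfiguration fields and may differ.

    {
      "standard_options": {
        "tls": {
          "enabled": true,
          "allow_self_signed": true,
          "trusted_certificate_path": "/path/to/ca.pem"
        }
      }
    }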