This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new a88751df0 feat(csharp/src/Drivers/Databricks): Implement
ClientCredentialsProvider (#2743)
a88751df0 is described below
commit a88751df030664fbd06725fe5f5caba19b4da4d7
Author: Todd Meng <[email protected]>
AuthorDate: Tue Apr 29 11:07:19 2025 -0700
feat(csharp/src/Drivers/Databricks): Implement ClientCredentialsProvider
(#2743)
First PR for a class that obtains an access token from the OAuth service for M2M
(machine-to-machine) authentication. Includes simple expiration and refresh logic.
For reference, see the Databricks Java SDK refresh logic
[here](https://github.com/databricks/databricks-sdk-java/blob/main/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java).
A follow-up PR will integrate this with the rest of the driver.
To test:
1. Create a `databricks_test_config.json` file, where `host` is the workspace hostname:
```
{
"oauth_client_id": "...",
"oauth_client_secret": "...",
"host": "databricks....com"
}
```
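2. Point the tests at that file with the `DATABRICKS_TEST_CONFIG_FILE` environment variable.
On macOS/Linux:
```
export DATABRICKS_TEST_CONFIG_FILE=/path/to/your/databricks_test_config.json
```
On Windows PowerShell:
```
$env:DATABRICKS_TEST_CONFIG_FILE = "C:\path\to\your\databricks_test_config.json"
```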
3. From the `csharp` directory, run:
```
dotnet test --filter "FullyQualifiedName~OAuthClientCredentialsProviderTests"
```
---
.../Auth/OAuthClientCredentialsProvider.cs | 216 +++++++++++++++++++++
.../Auth/OAuthClientCredentialsProviderTests.cs | 82 ++++++++
.../Databricks/DatabricksTestConfiguration.cs | 5 +
3 files changed, 303 insertions(+)
diff --git
a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs
b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs
new file mode 100644
index 000000000..8fa797586
--- /dev/null
+++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs
@@ -0,0 +1,216 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Net.Http;
+using System.Net.Http.Headers;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+
namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
{
    /// <summary>
    /// Obtains OAuth access tokens from a Databricks workspace using the client credentials
    /// grant type, caching each token until shortly before it expires.
    /// </summary>
    internal class OAuthClientCredentialsProvider : IDisposable
    {
        private readonly HttpClient _httpClient;
        private readonly string _clientId;
        private readonly string _clientSecret;
        private readonly string _host;
        private readonly string _tokenEndpoint;
        private readonly int _timeoutMinutes;
        private readonly SemaphoreSlim _tokenLock = new SemaphoreSlim(1, 1);

        // Read without holding _tokenLock on the fast path of GetAccessToken, so it is
        // volatile to ensure readers observe a fully published TokenInfo.
        private volatile TokenInfo? _cachedToken;

        /// <summary>
        /// An immutable access token together with the UTC time at which it expires.
        /// </summary>
        private sealed class TokenInfo
        {
            // Minutes before ExpiresAt at which the token is already treated as stale.
            private const int RefreshBufferMinutes = 5;

            public TokenInfo(string accessToken, DateTime expiresAt)
            {
                AccessToken = accessToken;
                ExpiresAt = expiresAt;
            }

            public string AccessToken { get; }

            public DateTime ExpiresAt { get; }

            // Refresh ahead of actual expiration. A token whose lifetime is shorter than
            // the buffer is therefore never served from the cache.
            public bool NeedsRefresh => DateTime.UtcNow >= ExpiresAt.AddMinutes(-RefreshBufferMinutes);
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="OAuthClientCredentialsProvider"/> class.
        /// </summary>
        /// <param name="clientId">The OAuth client ID.</param>
        /// <param name="clientSecret">The OAuth client secret.</param>
        /// <param name="host">
        /// The Databricks workspace host name without a scheme; tokens are requested from
        /// <c>https://{host}/oidc/v1/token</c>.
        /// </param>
        /// <param name="timeoutMinutes">The HTTP timeout for token requests, in minutes. Must be positive.</param>
        /// <exception cref="ArgumentNullException">A string argument is null.</exception>
        /// <exception cref="ArgumentException">A string argument is empty.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="timeoutMinutes"/> is not positive.</exception>
        public OAuthClientCredentialsProvider(
            string clientId,
            string clientSecret,
            string host,
            int timeoutMinutes = 1)
        {
            _clientId = RequireNonEmpty(clientId, nameof(clientId));
            _clientSecret = RequireNonEmpty(clientSecret, nameof(clientSecret));
            _host = RequireNonEmpty(host, nameof(host));

            // HttpClient.Timeout rejects non-positive values anyway; fail early with a clearer message.
            if (timeoutMinutes <= 0)
            {
                throw new ArgumentOutOfRangeException(
                    nameof(timeoutMinutes), timeoutMinutes, "Timeout must be a positive number of minutes.");
            }

            _timeoutMinutes = timeoutMinutes;
            _tokenEndpoint = DetermineTokenEndpoint();

            _httpClient = new HttpClient();
            _httpClient.Timeout = TimeSpan.FromMinutes(_timeoutMinutes);
        }

        // Returns the argument unchanged, throwing if it is null or empty.
        private static string RequireNonEmpty(string? value, string paramName)
        {
            if (value == null)
            {
                throw new ArgumentNullException(paramName);
            }

            if (value.Length == 0)
            {
                throw new ArgumentException("Value cannot be empty.", paramName);
            }

            return value;
        }

        private string DetermineTokenEndpoint()
        {
            // For workspace URLs, the token endpoint is always /oidc/v1/token
            return $"https://{_host}/oidc/v1/token";
        }

        // Returns the cached access token if one exists and is not yet due for refresh; otherwise null.
        private string? GetValidCachedToken()
        {
            // Read the volatile field once so the staleness check and the returned value
            // come from the same TokenInfo instance.
            TokenInfo? token = _cachedToken;
            return token != null && !token.NeedsRefresh ? token.AccessToken : null;
        }

        // Fetches a new token from the endpoint, caches it, and returns the access token string.
        private async Task<string> RefreshTokenInternalAsync(CancellationToken cancellationToken)
        {
            string content = await SendTokenRequestAsync(cancellationToken).ConfigureAwait(false);

            TokenInfo token;
            try
            {
                token = ParseTokenResponse(content);
            }
            catch (JsonException ex)
            {
                throw new DatabricksException($"Failed to parse OAuth response: {ex.Message}", ex);
            }

            _cachedToken = token;
            return token.AccessToken;
        }

        // Sends the token request and returns the raw response body. Caller cancellation is
        // propagated as-is; every other failure (including an HttpClient timeout) is wrapped
        // in a DatabricksException.
        private async Task<string> SendTokenRequestAsync(CancellationToken cancellationToken)
        {
            using HttpRequestMessage request = CreateTokenRequest();
            try
            {
                using HttpResponseMessage response =
                    await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);
                response.EnsureSuccessStatusCode();
                return await response.Content.ReadAsStringAsync().ConfigureAwait(false);
            }
            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
            {
                throw;
            }
            catch (Exception ex)
            {
                throw new DatabricksException($"Failed to acquire OAuth access token: {ex.Message}", ex);
            }
        }

        // Builds the form-encoded client_credentials POST request with Basic authentication.
        private HttpRequestMessage CreateTokenRequest()
        {
            var requestContent = new FormUrlEncodedContent(new[]
            {
                new KeyValuePair<string, string>("grant_type", "client_credentials"),
                new KeyValuePair<string, string>("scope", "all-apis")
            });

            var request = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint)
            {
                Content = requestContent
            };

            // Use Basic Auth with client ID and secret. UTF-8 rather than ASCII so that
            // non-ASCII characters are not silently replaced with '?'.
            string authHeader = Convert.ToBase64String(
                System.Text.Encoding.UTF8.GetBytes($"{_clientId}:{_clientSecret}"));
            request.Headers.Authorization = new AuthenticationHeaderValue("Basic", authHeader);
            request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));

            return request;
        }

        // Parses the JSON token response. Throws JsonException for malformed JSON and
        // DatabricksException for well-formed JSON lacking a usable token or lifetime.
        private static TokenInfo ParseTokenResponse(string content)
        {
            using JsonDocument jsonDoc = JsonDocument.Parse(content);
            JsonElement root = jsonDoc.RootElement;

            if (root.ValueKind != JsonValueKind.Object)
            {
                throw new DatabricksException("OAuth response was not a JSON object");
            }

            if (!root.TryGetProperty("access_token", out JsonElement accessTokenElement))
            {
                throw new DatabricksException("OAuth response did not contain an access_token");
            }

            if (accessTokenElement.ValueKind != JsonValueKind.String &&
                accessTokenElement.ValueKind != JsonValueKind.Null)
            {
                throw new DatabricksException("OAuth access_token was not a string");
            }

            string? accessToken = accessTokenElement.GetString();
            if (accessToken == null || accessToken.Length == 0)
            {
                throw new DatabricksException("OAuth access_token was null or empty");
            }

            // Get expiration time (in seconds) from response
            if (!root.TryGetProperty("expires_in", out JsonElement expiresInElement))
            {
                throw new DatabricksException("OAuth response did not contain expires_in");
            }

            if (expiresInElement.ValueKind != JsonValueKind.Number ||
                !expiresInElement.TryGetInt32(out int expiresIn))
            {
                throw new DatabricksException("OAuth expires_in was not a valid integer");
            }

            if (expiresIn <= 0)
            {
                throw new DatabricksException("OAuth expires_in value must be positive");
            }

            return new TokenInfo(accessToken, DateTime.UtcNow.AddSeconds(expiresIn));
        }

        // Serialized slow path: re-checks the cache under the lock, then refreshes if needed.
        private async Task<string> GetAccessTokenAsync(CancellationToken cancellationToken = default)
        {
            await _tokenLock.WaitAsync(cancellationToken).ConfigureAwait(false);

            try
            {
                // Another caller may have refreshed the token while this one waited for the lock.
                if (GetValidCachedToken() is string refreshedToken)
                {
                    return refreshedToken;
                }

                return await RefreshTokenInternalAsync(cancellationToken).ConfigureAwait(false);
            }
            finally
            {
                _tokenLock.Release();
            }
        }

        /// <summary>
        /// Gets an OAuth access token using the client credentials grant type, reusing the
        /// cached token while it is not yet due for refresh.
        /// </summary>
        /// <param name="cancellationToken">A cancellation token to cancel the operation.</param>
        /// <returns>The access token.</returns>
        /// <exception cref="DatabricksException">The token request failed or its response could not be used.</exception>
        /// <exception cref="OperationCanceledException">
        /// <paramref name="cancellationToken"/> was canceled while a new token was being acquired.
        /// </exception>
        public string GetAccessToken(CancellationToken cancellationToken = default)
        {
            // First try to get cached token without acquiring lock
            if (GetValidCachedToken() is string cachedToken)
            {
                return cachedToken;
            }

            // Blocking wait. The awaits on this path use ConfigureAwait(false), so their
            // continuations do not need the caller's synchronization context.
            return GetAccessTokenAsync(cancellationToken).GetAwaiter().GetResult();
        }

        /// <summary>
        /// Releases the token lock and the underlying <see cref="HttpClient"/>.
        /// </summary>
        public void Dispose()
        {
            _tokenLock.Dispose();
            _httpClient.Dispose();
        }
    }
}
diff --git
a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs
b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs
new file mode 100644
index 000000000..de9b01684
--- /dev/null
+++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs
@@ -0,0 +1,82 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Databricks.Auth;
+using Xunit;
+using Xunit.Abstractions;
+using Apache.Arrow.Adbc.Tests.Drivers.Databricks;
+
namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth
{
    /// <summary>
    /// Integration tests for <see cref="OAuthClientCredentialsProvider"/> that request tokens
    /// from the configured Databricks workspace. Each test is skipped unless both the OAuth
    /// client ID and client secret are configured.
    /// </summary>
    public class OAuthClientCredentialsProviderTests : TestBase<DatabricksTestConfiguration, DatabricksTestEnvironment>, IDisposable
    {
        public OAuthClientCredentialsProviderTests(ITestOutputHelper? outputHelper)
            : base(outputHelper, new DatabricksTestEnvironment.Factory())
        {
        }

        // Skips the current test unless both OAuth credentials are present in the test configuration.
        private void SkipIfCredentialsMissing()
        {
            Skip.If(
                string.IsNullOrEmpty(TestConfiguration.OAuthClientId) ||
                string.IsNullOrEmpty(TestConfiguration.OAuthClientSecret),
                "OAuth credentials not configured");
        }

        // Creates a provider from the test configuration; callers own (and must dispose) it.
        private OAuthClientCredentialsProvider CreateService()
        {
            return new OAuthClientCredentialsProvider(
                TestConfiguration.OAuthClientId,
                TestConfiguration.OAuthClientSecret,
                TestConfiguration.HostName,
                timeoutMinutes: 1);
        }

        [SkippableFact]
        public void GetAccessToken_WithValidCredentials_ReturnsToken()
        {
            SkipIfCredentialsMissing();

            using var service = CreateService();
            var token = service.GetAccessToken();

            Assert.NotNull(token);
            Assert.NotEmpty(token);
        }

        [SkippableFact]
        public void GetAccessToken_WithCancellation_ThrowsOperationCanceledException()
        {
            SkipIfCredentialsMissing();

            using var service = CreateService();
            using var cts = new CancellationTokenSource();
            cts.Cancel();

            var ex = Assert.ThrowsAny<OperationCanceledException>(() =>
                service.GetAccessToken(cts.Token));
            Assert.IsType<TaskCanceledException>(ex);
        }

        [SkippableFact]
        public void GetAccessToken_MultipleCalls_ReusesCachedToken()
        {
            SkipIfCredentialsMissing();

            using var service = CreateService();
            var token1 = service.GetAccessToken();
            var token2 = service.GetAccessToken();

            Assert.Equal(token1, token2);
        }
    }
}
diff --git a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs
b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs
index fb221560b..0f0366c20 100644
--- a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs
+++ b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs
@@ -15,12 +15,17 @@
* limitations under the License.
*/
+using System.Text.Json.Serialization;
using Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark;
namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
{
    /// <summary>
    /// Databricks test configuration: the Spark test settings plus the OAuth client
    /// credentials consumed by the OAuth client-credentials provider tests.
    /// </summary>
    public class DatabricksTestConfiguration : SparkTestConfiguration
    {
        /// <summary>
        /// OAuth client ID, bound to the <c>oauth_client_id</c> JSON property.
        /// Empty when not configured.
        /// </summary>
        [JsonPropertyName("oauth_client_id")]
        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
        public string OAuthClientId { get; set; } = "";

        /// <summary>
        /// OAuth client secret, bound to the <c>oauth_client_secret</c> JSON property.
        /// Empty when not configured.
        /// </summary>
        [JsonPropertyName("oauth_client_secret")]
        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
        public string OAuthClientSecret { get; set; } = "";
    }
}