This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new a88751df0 feat(csharp/src/Drivers/Databricks): Implement ClientCredentialsProvider (#2743)
a88751df0 is described below

commit a88751df030664fbd06725fe5f5caba19b4da4d7
Author: Todd Meng <[email protected]>
AuthorDate: Tue Apr 29 11:07:19 2025 -0700

    feat(csharp/src/Drivers/Databricks): Implement ClientCredentialsProvider (#2743)
    
    First PR for a class that obtains an access token from the OAuth service for
    M2M (machine-to-machine) authentication. Includes simple expiration and refresh logic.
    
    For reference, the Databricks Java SDK's token refresh logic is
    
[here](https://github.com/databricks/databricks-sdk-java/blob/main/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java).
    
    A follow-up PR will integrate this provider with the rest of the driver.
    
    
    To test it out:
    1. Create a `databricks_test_config.json` (`host` is the workspace hostname;
       JSON does not allow comments, so keep the file free of them):
    ```
       {
         "oauth_client_id": "...",
         "oauth_client_secret": "...",
         "host": "databricks....com"
       }
    ```
    2. Point `DATABRICKS_TEST_CONFIG_FILE` at that file.
       On macOS/Linux:
       export DATABRICKS_TEST_CONFIG_FILE=/path/to/your/databricks_test_config.json
    
       On Windows PowerShell:
       $env:DATABRICKS_TEST_CONFIG_FILE = "C:\path\to\your\databricks_test_config.json"
    
    3. From the `csharp` directory, run:
       dotnet test --filter "FullyQualifiedName~OAuthClientCredentialsProviderTests"
---
 .../Auth/OAuthClientCredentialsProvider.cs         | 216 +++++++++++++++++++++
 .../Auth/OAuthClientCredentialsProviderTests.cs    |  82 ++++++++
 .../Databricks/DatabricksTestConfiguration.cs      |   5 +
 3 files changed, 303 insertions(+)

diff --git 
a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs 
b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs
new file mode 100644
index 000000000..8fa797586
--- /dev/null
+++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs
@@ -0,0 +1,216 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Net.Http;
+using System.Net.Http.Headers;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
+{
+    /// <summary>
+    /// Obtains OAuth access tokens for a Databricks workspace using the client credentials
+    /// grant type (M2M authentication), caching each token until shortly before it expires.
+    /// </summary>
+    internal class OAuthClientCredentialsProvider : IDisposable
+    {
+        // Tokens are refreshed this long before their reported expiry so that a token handed
+        // to a caller does not lapse while the caller is still using it.
+        private static readonly TimeSpan s_refreshBuffer = TimeSpan.FromMinutes(5);
+
+        private readonly HttpClient _httpClient;
+        private readonly string _clientId;
+        private readonly string _clientSecret;
+        private readonly string _host;
+        private readonly string _tokenEndpoint;
+        private readonly int _timeoutMinutes;
+        // Serializes refreshes so concurrent callers do not issue overlapping token requests.
+        private readonly SemaphoreSlim _tokenLock = new SemaphoreSlim(1, 1);
+        // Written only while holding _tokenLock but read lock-free by GetAccessToken, hence
+        // volatile; TokenInfo is immutable, so any published instance is fully initialized.
+        private volatile TokenInfo? _cachedToken;
+
+        /// <summary>
+        /// An access token together with its absolute (UTC) expiry time.
+        /// </summary>
+        private sealed class TokenInfo
+        {
+            public TokenInfo(string accessToken, DateTime expiresAt)
+            {
+                AccessToken = accessToken;
+                ExpiresAt = expiresAt;
+            }
+
+            public string AccessToken { get; }
+
+            public DateTime ExpiresAt { get; }
+
+            /// <summary>True once the token is within the refresh buffer of its expiry.</summary>
+            public bool NeedsRefresh => DateTime.UtcNow >= ExpiresAt - s_refreshBuffer;
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="OAuthClientCredentialsProvider"/> class.
+        /// </summary>
+        /// <param name="clientId">The OAuth client ID.</param>
+        /// <param name="clientSecret">The OAuth client secret.</param>
+        /// <param name="host">
+        /// The Databricks workspace host name (for example <c>example.cloud.databricks.com</c>).
+        /// A full workspace URL such as <c>https://example.cloud.databricks.com/</c> is also
+        /// accepted; only its host (and port, if any) is used.
+        /// </param>
+        /// <param name="timeoutMinutes">HTTP timeout, in minutes, for each token request. Must be positive.</param>
+        /// <exception cref="ArgumentNullException">A required argument is <c>null</c>.</exception>
+        /// <exception cref="ArgumentOutOfRangeException"><paramref name="timeoutMinutes"/> is not positive.</exception>
+        public OAuthClientCredentialsProvider(
+            string clientId,
+            string clientSecret,
+            string host,
+            int timeoutMinutes = 1)
+        {
+            _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId));
+            _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret));
+            _host = host ?? throw new ArgumentNullException(nameof(host));
+
+            // Validate before creating the HttpClient so an invalid timeout cannot leak it
+            // (HttpClient.Timeout would reject a non-positive value with the same exception type).
+            if (timeoutMinutes <= 0)
+            {
+                throw new ArgumentOutOfRangeException(nameof(timeoutMinutes), timeoutMinutes, "Timeout must be positive.");
+            }
+
+            _timeoutMinutes = timeoutMinutes;
+            _tokenEndpoint = DetermineTokenEndpoint(_host);
+
+            _httpClient = new HttpClient();
+            _httpClient.Timeout = TimeSpan.FromMinutes(_timeoutMinutes);
+        }
+
+        /// <summary>
+        /// Builds the workspace OIDC token endpoint, <c>https://{host}/oidc/v1/token</c>.
+        /// </summary>
+        private static string DetermineTokenEndpoint(string host)
+        {
+            // Accept a full http(s) URL as well as a bare host name; the endpoint is always HTTPS.
+            string authority =
+                Uri.TryCreate(host, UriKind.Absolute, out Uri? uri)
+                && uri != null
+                && (uri.Scheme == Uri.UriSchemeHttps || uri.Scheme == Uri.UriSchemeHttp)
+                    ? uri.Authority
+                    : host.TrimEnd('/');
+
+            // For workspace URLs, the token endpoint is always /oidc/v1/token
+            return $"https://{authority}/oidc/v1/token";
+        }
+
+        /// <summary>
+        /// Returns the cached access token if one exists and is not yet within the refresh
+        /// buffer of its expiry; otherwise <c>null</c>.
+        /// </summary>
+        private string? GetValidCachedToken()
+        {
+            // Read the field once so the checks and the returned value refer to the same token.
+            TokenInfo? token = _cachedToken;
+            return token != null && !token.NeedsRefresh ? token.AccessToken : null;
+        }
+
+        /// <summary>
+        /// Requests a new token from the token endpoint, caches it, and returns it.
+        /// Must be called while holding <c>_tokenLock</c>.
+        /// </summary>
+        /// <exception cref="DatabricksException">The request failed or the response was malformed.</exception>
+        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was canceled.</exception>
+        private async Task<string> RefreshTokenInternalAsync(CancellationToken cancellationToken)
+        {
+            string content;
+            try
+            {
+                using HttpRequestMessage request = CreateTokenRequest();
+                using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);
+                response.EnsureSuccessStatusCode();
+                content = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
+            }
+            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
+            {
+                // Caller-requested cancellation is not an acquisition failure; let it propagate
+                // unchanged. (An HttpClient timeout, with the token not canceled, is still wrapped below.)
+                throw;
+            }
+            catch (Exception ex)
+            {
+                throw new DatabricksException($"Failed to acquire OAuth access token: {ex.Message}", ex);
+            }
+
+            TokenInfo token;
+            try
+            {
+                token = ParseTokenResponse(content);
+            }
+            catch (JsonException ex)
+            {
+                throw new DatabricksException($"Failed to parse OAuth response: {ex.Message}", ex);
+            }
+
+            _cachedToken = token;
+            return token.AccessToken;
+        }
+
+        /// <summary>
+        /// Builds the client-credentials POST request. It authenticates with HTTP Basic auth
+        /// (client ID as user name, client secret as password) and requests the <c>all-apis</c> scope.
+        /// </summary>
+        private HttpRequestMessage CreateTokenRequest()
+        {
+            var requestContent = new FormUrlEncodedContent(new[]
+            {
+                new KeyValuePair<string, string>("grant_type", "client_credentials"),
+                new KeyValuePair<string, string>("scope", "all-apis")
+            });
+
+            var request = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint)
+            {
+                Content = requestContent
+            };
+
+            // UTF-8 rather than ASCII so non-ASCII characters in the credentials are not silently
+            // replaced with '?'; the two encodings produce identical bytes for ASCII input.
+            var authHeader = Convert.ToBase64String(
+                System.Text.Encoding.UTF8.GetBytes($"{_clientId}:{_clientSecret}"));
+            request.Headers.Authorization = new AuthenticationHeaderValue("Basic", authHeader);
+            request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
+
+            return request;
+        }
+
+        /// <summary>
+        /// Parses a token endpoint response containing <c>access_token</c> and <c>expires_in</c> (seconds).
+        /// </summary>
+        /// <exception cref="JsonException"><paramref name="content"/> is not valid JSON.</exception>
+        /// <exception cref="DatabricksException">A required field is missing or has the wrong type or value.</exception>
+        private static TokenInfo ParseTokenResponse(string content)
+        {
+            using var jsonDoc = JsonDocument.Parse(content);
+            JsonElement root = jsonDoc.RootElement;
+
+            // TryGetProperty throws InvalidOperationException on non-object roots.
+            if (root.ValueKind != JsonValueKind.Object)
+            {
+                throw new DatabricksException("OAuth response was not a JSON object");
+            }
+
+            if (!root.TryGetProperty("access_token", out JsonElement accessTokenElement))
+            {
+                throw new DatabricksException("OAuth response did not contain an access_token");
+            }
+
+            // GetString throws InvalidOperationException for non-string, non-null values.
+            string? accessToken = accessTokenElement.ValueKind switch
+            {
+                JsonValueKind.String => accessTokenElement.GetString(),
+                JsonValueKind.Null => null,
+                _ => throw new DatabricksException("OAuth access_token was not a string"),
+            };
+            if (accessToken is null || accessToken.Length == 0)
+            {
+                throw new DatabricksException("OAuth access_token was null or empty");
+            }
+
+            // Get expiration time from response
+            if (!root.TryGetProperty("expires_in", out JsonElement expiresInElement))
+            {
+                throw new DatabricksException("OAuth response did not contain expires_in");
+            }
+
+            // Check the kind first: TryGetInt32 throws on non-numbers, and GetInt32 would throw
+            // FormatException for fractional or out-of-range values.
+            if (expiresInElement.ValueKind != JsonValueKind.Number || !expiresInElement.TryGetInt32(out int expiresIn))
+            {
+                throw new DatabricksException("OAuth expires_in value was not a valid integer");
+            }
+
+            if (expiresIn <= 0)
+            {
+                throw new DatabricksException("OAuth expires_in value must be positive");
+            }
+
+            return new TokenInfo(accessToken, DateTime.UtcNow.AddSeconds(expiresIn));
+        }
+
+        /// <summary>
+        /// Returns a valid token, refreshing it under <c>_tokenLock</c> if necessary.
+        /// </summary>
+        private async Task<string> GetAccessTokenAsync(CancellationToken cancellationToken = default)
+        {
+            // ConfigureAwait(false) throughout: GetAccessToken blocks on this task, so resuming on
+            // a captured synchronization context could deadlock.
+            await _tokenLock.WaitAsync(cancellationToken).ConfigureAwait(false);
+
+            try
+            {
+                // Double-check pattern in case another thread refreshed while we were waiting
+                if (GetValidCachedToken() is string refreshedToken)
+                {
+                    return refreshedToken;
+                }
+
+                return await RefreshTokenInternalAsync(cancellationToken).ConfigureAwait(false);
+            }
+            finally
+            {
+                _tokenLock.Release();
+            }
+        }
+
+        /// <summary>
+        /// Gets an OAuth access token using the client credentials grant type. A cached token is
+        /// returned without locking while it has more than the refresh buffer left; otherwise a
+        /// new token is requested synchronously.
+        /// </summary>
+        /// <param name="cancellationToken">A cancellation token to cancel the operation.</param>
+        /// <returns>The access token.</returns>
+        /// <exception cref="DatabricksException">The token could not be acquired or parsed.</exception>
+        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was canceled.</exception>
+        public string GetAccessToken(CancellationToken cancellationToken = default)
+        {
+            // First try to get cached token without acquiring lock
+            if (GetValidCachedToken() is string cachedToken)
+            {
+                return cachedToken;
+            }
+
+            return GetAccessTokenAsync(cancellationToken).GetAwaiter().GetResult();
+        }
+
+        /// <summary>
+        /// Releases the refresh lock and the HTTP client.
+        /// </summary>
+        public void Dispose()
+        {
+            _tokenLock.Dispose();
+            _httpClient.Dispose();
+        }
+    }
+}
diff --git 
a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs 
b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs
new file mode 100644
index 000000000..de9b01684
--- /dev/null
+++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs
@@ -0,0 +1,82 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Databricks.Auth;
+using Xunit;
+using Xunit.Abstractions;
+using Apache.Arrow.Adbc.Tests.Drivers.Databricks;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth
+{
+    /// <summary>
+    /// Integration tests for <see cref="OAuthClientCredentialsProvider"/> against the configured
+    /// workspace. Each test is skipped unless both an OAuth client ID and secret are configured.
+    /// </summary>
+    public class OAuthClientCredentialsProviderTests : TestBase<DatabricksTestConfiguration, DatabricksTestEnvironment>, IDisposable
+    {
+        public OAuthClientCredentialsProviderTests(ITestOutputHelper? outputHelper)
+            : base(outputHelper, new DatabricksTestEnvironment.Factory())
+        {
+        }
+
+        /// <summary>
+        /// Skips the calling test unless both the OAuth client ID and client secret are configured.
+        /// </summary>
+        private void SkipIfCredentialsMissing()
+        {
+            Skip.If(
+                string.IsNullOrEmpty(TestConfiguration.OAuthClientId) ||
+                string.IsNullOrEmpty(TestConfiguration.OAuthClientSecret),
+                "OAuth credentials not configured");
+        }
+
+        /// <summary>
+        /// Creates a provider for the configured workspace. The caller owns it and must dispose it.
+        /// </summary>
+        private OAuthClientCredentialsProvider CreateService()
+        {
+            return new OAuthClientCredentialsProvider(
+                TestConfiguration.OAuthClientId,
+                TestConfiguration.OAuthClientSecret,
+                TestConfiguration.HostName,
+                timeoutMinutes: 1);
+        }
+
+        [SkippableFact]
+        public void GetAccessToken_WithValidCredentials_ReturnsToken()
+        {
+            SkipIfCredentialsMissing();
+
+            using var service = CreateService();
+            var token = service.GetAccessToken();
+
+            Assert.NotNull(token);
+            Assert.NotEmpty(token);
+        }
+
+        [SkippableFact]
+        public void GetAccessToken_WithCancellation_ThrowsOperationCanceledException()
+        {
+            SkipIfCredentialsMissing();
+
+            using var service = CreateService();
+            using var cts = new CancellationTokenSource();
+            cts.Cancel();
+
+            var ex = Assert.ThrowsAny<OperationCanceledException>(() =>
+                service.GetAccessToken(cts.Token));
+            Assert.IsType<TaskCanceledException>(ex);
+        }
+
+        [SkippableFact]
+        public void GetAccessToken_MultipleCalls_ReusesCachedToken()
+        {
+            SkipIfCredentialsMissing();
+
+            using var service = CreateService();
+            var token1 = service.GetAccessToken();
+            var token2 = service.GetAccessToken();
+
+            Assert.Equal(token1, token2);
+        }
+    }
+}
diff --git a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs 
b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs
index fb221560b..0f0366c20 100644
--- a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs
+++ b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs
@@ -15,12 +15,17 @@
 * limitations under the License.
 */
 
+using System.Text.Json.Serialization;
 using Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark;
 
 namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
 {
     public class DatabricksTestConfiguration : SparkTestConfiguration
     {
+        /// <summary>
+        /// OAuth client ID for the M2M authentication tests, read from <c>oauth_client_id</c>
+        /// in the test configuration file. Empty when not configured.
+        /// </summary>
+        [JsonPropertyName("oauth_client_id")]
+        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
+        public string OAuthClientId { get; set; } = string.Empty;
 
+        /// <summary>
+        /// OAuth client secret paired with <see cref="OAuthClientId"/>, read from
+        /// <c>oauth_client_secret</c> in the test configuration file. Empty when not configured.
+        /// </summary>
+        [JsonPropertyName("oauth_client_secret")]
+        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
+        public string OAuthClientSecret { get; set; } = string.Empty;
     }
 }

Reply via email to