This is an automated email from the ASF dual-hosted git repository.

exceptionfactory pushed a commit to branch support/nifi-1.x
in repository https://gitbox.apache.org/repos/asf/nifi.git


The following commit(s) were added to refs/heads/support/nifi-1.x by this push:
     new 39c7a356d2 NIFI-11386 Added Resource and Audience support to StandardOauth2AccessTokenProvider
39c7a356d2 is described below

commit 39c7a356d2ef3614db0a123d3c34ed7e5fe984d7
Author: Tamas Palfy <[email protected]>
AuthorDate: Tue Apr 11 17:10:07 2023 +0200

    NIFI-11386 Added Resource and Audience support to StandardOauth2AccessTokenProvider
    
    - Also keeping the previous Refresh Token if a new one is not provided in the refresh response
    
    This closes #7164
    
    Signed-off-by: David Handermann <[email protected]>
    (cherry picked from commit 88587f5c0208d72db95e61112126de64f12c1b87)
---
 .../oauth2/StandardOauth2AccessTokenProvider.java  | 105 +++++++++++++--------
 .../StandardOauth2AccessTokenProviderTest.java     |  65 ++++++++++++-
 2 files changed, 127 insertions(+), 43 deletions(-)

diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
index 6a3cb723aa..3062e887ad 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
@@ -163,6 +163,22 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
         .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
         .build();
 
+    public static final PropertyDescriptor RESOURCE = new 
PropertyDescriptor.Builder()
+        .name("resource")
+        .displayName("Resource")
+        .description("Resource URI for the access token request defined in RFC 
8707 Section 2")
+        .required(false)
+        .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
+        .build();
+
+    public static final PropertyDescriptor AUDIENCE = new 
PropertyDescriptor.Builder()
+        .name("audience")
+        .displayName("Audience")
+        .description("Audience for the access token request defined in RFC 
8693 Section 2.1")
+        .required(false)
+        .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
+        .build();
+
     public static final PropertyDescriptor REFRESH_WINDOW = new 
PropertyDescriptor.Builder()
         .name("refresh-window")
         .displayName("Refresh Window")
@@ -199,6 +215,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
         CLIENT_ID,
         CLIENT_SECRET,
         SCOPE,
+        RESOURCE,
+        AUDIENCE,
         REFRESH_WINDOW,
         SSL_CONTEXT,
         HTTP_PROTOCOL_STRATEGY
@@ -220,6 +238,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
     private volatile String clientId;
     private volatile String clientSecret;
     private volatile String scope;
+    private volatile String resource;
+    private volatile String audience;
     private volatile long refreshWindowSeconds;
 
     private volatile AccessToken accessDetails;
@@ -242,6 +262,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
         clientId = 
context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
         clientSecret = context.getProperty(CLIENT_SECRET).getValue();
         scope = context.getProperty(SCOPE).getValue();
+        resource = context.getProperty(RESOURCE).getValue();
+        audience = context.getProperty(AUDIENCE).getValue();
 
         if (context.getProperty(REFRESH_TOKEN).isSet()) {
             String refreshToken = 
context.getProperty(REFRESH_TOKEN).evaluateAttributeExpressions().getValue();
@@ -319,6 +341,14 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
         return accessDetails;
     }
 
+    private boolean isRefreshRequired() {
+        final Instant expirationRefreshTime = accessDetails.getFetchTime()
+                .plusSeconds(accessDetails.getExpiresIn())
+                .minusSeconds(refreshWindowSeconds);
+
+        return Instant.now().isAfter(expirationRefreshTime);
+    }
+
     private void acquireAccessDetails() {
         getLogger().debug("New Access Token request started [{}]", 
authorizationServerUrl);
 
@@ -332,58 +362,59 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
             acquireTokenBuilder.add("grant_type", "client_credentials");
         }
 
-        if (ClientAuthenticationStrategy.REQUEST_BODY == 
clientAuthenticationStrategy && clientId != null) {
-            acquireTokenBuilder.add("client_id", clientId);
-            acquireTokenBuilder.add("client_secret", clientSecret);
-        }
-
-        if (scope != null) {
-            acquireTokenBuilder.add("scope", scope);
-        }
-
-        RequestBody acquireTokenRequestBody = acquireTokenBuilder.build();
-        Request.Builder acquireTokenRequestBuilder = new Request.Builder()
-            .url(authorizationServerUrl)
-            .post(acquireTokenRequestBody);
+        addFormData(acquireTokenBuilder);
 
-        if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == 
clientAuthenticationStrategy && clientId != null) {
-            acquireTokenRequestBuilder.addHeader(AUTHORIZATION_HEADER, 
Credentials.basic(clientId, clientSecret));
-        }
-
-        Request acquireTokenRequest = acquireTokenRequestBuilder.build();
-
-        this.accessDetails = getAccessDetails(acquireTokenRequest);
+        this.accessDetails = requestToken(acquireTokenBuilder);
     }
 
     private void refreshAccessDetails() {
         getLogger().debug("Refresh Access Token request started [{}]", 
authorizationServerUrl);
 
         FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
-            .add("grant_type", "refresh_token")
-            .add("refresh_token", this.accessDetails.getRefreshToken());
+                .add("grant_type", "refresh_token")
+                .add("refresh_token", this.accessDetails.getRefreshToken());
+
+        addFormData(refreshTokenBuilder);
 
-        if (ClientAuthenticationStrategy.REQUEST_BODY == 
clientAuthenticationStrategy && clientId != null) {
-            refreshTokenBuilder.add("client_id", clientId);
-            refreshTokenBuilder.add("client_secret", clientSecret);
+        AccessToken newAccessDetails = requestToken(refreshTokenBuilder);
+
+        if (newAccessDetails.getRefreshToken() == null) {
+            
newAccessDetails.setRefreshToken(this.accessDetails.getRefreshToken());
         }
 
+        this.accessDetails = newAccessDetails;
+    }
+
+    private void addFormData(FormBody.Builder formBuilder) {
+        if (clientAuthenticationStrategy == 
ClientAuthenticationStrategy.REQUEST_BODY && clientId != null) {
+            formBuilder.add("client_id", clientId);
+            formBuilder.add("client_secret", clientSecret);
+        }
         if (scope != null) {
-            refreshTokenBuilder.add("scope", scope);
+            formBuilder.add("scope", scope);
+        }
+        if (resource != null) {
+            formBuilder.add("resource", resource);
         }
+        if (audience != null) {
+            formBuilder.add("audience", audience);
+        }
+    }
 
-        RequestBody refreshTokenRequestBody = refreshTokenBuilder.build();
+    private AccessToken requestToken(FormBody.Builder formBuilder) {
+        RequestBody requestBody = formBuilder.build();
 
-        Request.Builder refreshRequestBuilder = new Request.Builder()
-            .url(authorizationServerUrl)
-            .post(refreshTokenRequestBody);
+        Request.Builder requestBuilder = new Request.Builder()
+                .url(authorizationServerUrl)
+                .post(requestBody);
 
         if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == 
clientAuthenticationStrategy && clientId != null) {
-            refreshRequestBuilder.addHeader(AUTHORIZATION_HEADER, 
Credentials.basic(clientId, clientSecret));
+            requestBuilder.addHeader(AUTHORIZATION_HEADER, 
Credentials.basic(clientId, clientSecret));
         }
 
-        Request refreshRequest = refreshRequestBuilder.build();
+        Request request = requestBuilder.build();
 
-        this.accessDetails = getAccessDetails(refreshRequest);
+        return getAccessDetails(request);
     }
 
     private AccessToken getAccessDetails(final Request newRequest) {
@@ -402,14 +433,6 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
         }
     }
 
-    private boolean isRefreshRequired() {
-        final Instant expirationRefreshTime = accessDetails.getFetchTime()
-                .plusSeconds(accessDetails.getExpiresIn())
-                .minusSeconds(refreshWindowSeconds);
-
-        return Instant.now().isAfter(expirationRefreshTime);
-    }
-
     @Override
     public List<ConfigVerificationResult> verify(ConfigurationContext context, 
ComponentLog verificationLogger, Map<String, String> variables) {
         ConfigVerificationResult.Builder builder = new 
ConfigVerificationResult.Builder()
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
index 3b53b04d6c..cda25f5b8e 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
@@ -78,6 +78,9 @@ public class StandardOauth2AccessTokenProviderTest {
     private static final String PASSWORD = "password";
     private static final String CLIENT_ID = "clientId";
     private static final String CLIENT_SECRET = "clientSecret";
+    private static final String SCOPE = "scope";
+    private static final String RESOURCE = "resource";
+    private static final String AUDIENCE = "audience";
     private static final long FIVE_MINUTES = 300;
 
     private static final int HTTP_OK = 200;
@@ -120,6 +123,9 @@ public class StandardOauth2AccessTokenProviderTest {
         
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
         
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
         
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
+        
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.SCOPE).getValue()).thenReturn(SCOPE);
+        
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.RESOURCE).getValue()).thenReturn(RESOURCE);
+        
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUDIENCE).getValue()).thenReturn(AUDIENCE);
         
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES);
         
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.BASIC_AUTHENTICATION.getValue());
 
@@ -361,6 +367,57 @@ public class StandardOauth2AccessTokenProviderTest {
         assertEquals(expectedToken, actualToken);
     }
 
+    @Test
+    public void testKeepPreviousRefreshTokenWhenNewOneIsNotProvided() throws 
Exception {
+        // GIVEN
+        String refreshTokenBeforeRefresh = "refresh_token";
+
+        Response response1 = buildResponse(
+                HTTP_OK,
+                "{ \"access_token\":\"not_checking_in_this_test\", 
\"expires_in\":\"0\", \"refresh_token\":\"" + refreshTokenBeforeRefresh + "\" }"
+        );
+
+        Response response2 = buildResponse(
+                HTTP_OK,
+            "{ \"access_token\":\"not_checking_in_this_test_either\" }"
+        );
+
+        
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1,
 response2);
+
+        // WHEN
+        testSubject.getAccessDetails();
+        String refreshTokenAfterRefresh = 
testSubject.getAccessDetails().getRefreshToken();
+
+        // THEN
+        assertEquals(refreshTokenBeforeRefresh, refreshTokenAfterRefresh);
+    }
+
+    @Test
+    public void testOverwritePreviousRefreshTokenWhenNewOneIsProvided() throws 
Exception {
+        // GIVEN
+        String refreshTokenBeforeRefresh = "refresh_token_before_refresh";
+        String expectedRefreshTokenAfterRefresh = 
"refresh_token_after_refresh";
+
+        Response response1 = buildResponse(
+                HTTP_OK,
+                "{ \"access_token\":\"not_checking_in_this_test\", 
\"expires_in\":\"0\", \"refresh_token\":\"" + refreshTokenBeforeRefresh + "\" }"
+        );
+
+        Response response2 = buildResponse(
+                HTTP_OK,
+            "{ \"access_token\":\"not_checking_in_this_test_either\", 
\"refresh_token\":\"" + expectedRefreshTokenAfterRefresh + "\" }"
+        );
+
+        
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1,
 response2);
+
+        // WHEN
+        testSubject.getAccessDetails();
+        String actualRefreshTokenAfterRefresh = 
testSubject.getAccessDetails().getRefreshToken();
+
+        // THEN
+        assertEquals(expectedRefreshTokenAfterRefresh, 
actualRefreshTokenAfterRefresh);
+    }
+
     @Test
     public void testBasicAuthentication() throws Exception {
         // GIVEN
@@ -377,7 +434,7 @@ public class StandardOauth2AccessTokenProviderTest {
     }
 
     @Test
-    public void testRequestBodyAuthentication() throws Exception {
+    public void testRequestBodyFormData() throws Exception {
         
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE.getValue());
         
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.REQUEST_BODY.getValue());
         testSubject.onEnabled(mockContext);
@@ -385,7 +442,11 @@ public class StandardOauth2AccessTokenProviderTest {
         // GIVEN
         Response response = buildResponse(HTTP_OK, 
"{\"access_token\":\"foobar\"}");
         
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
-        String expected = "grant_type=client_credentials&client_id=" + 
CLIENT_ID + "&client_secret=" + CLIENT_SECRET;
+        String expected = "grant_type=client_credentials&client_id=" + 
CLIENT_ID
+                + "&client_secret=" + CLIENT_SECRET
+                + "&scope=" + SCOPE
+                + "&resource=" + RESOURCE
+                + "&audience=" + AUDIENCE;
 
         // WHEN
         testSubject.getAccessDetails();

Reply via email to