This is an automated email from the ASF dual-hosted git repository.
exceptionfactory pushed a commit to branch support/nifi-1.x
in repository https://gitbox.apache.org/repos/asf/nifi.git
The following commit(s) were added to refs/heads/support/nifi-1.x by this push:
new 39c7a356d2 NIFI-11386 Added Resource and Audience support to
StandardOauth2AccessTokenProvider
39c7a356d2 is described below
commit 39c7a356d2ef3614db0a123d3c34ed7e5fe984d7
Author: Tamas Palfy <[email protected]>
AuthorDate: Tue Apr 11 17:10:07 2023 +0200
NIFI-11386 Added Resource and Audience support to
StandardOauth2AccessTokenProvider
- Also keep the previous Refresh Token when the response to a refresh
request does not provide a new one
This closes #7164
Signed-off-by: David Handermann <[email protected]>
(cherry picked from commit 88587f5c0208d72db95e61112126de64f12c1b87)
---
.../oauth2/StandardOauth2AccessTokenProvider.java | 105 +++++++++++++--------
.../StandardOauth2AccessTokenProviderTest.java | 65 ++++++++++++-
2 files changed, 127 insertions(+), 43 deletions(-)
diff --git
a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
index 6a3cb723aa..3062e887ad 100644
---
a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
+++
b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/main/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProvider.java
@@ -163,6 +163,22 @@ public class StandardOauth2AccessTokenProvider extends
AbstractControllerService
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
.build();
+ public static final PropertyDescriptor RESOURCE = new
PropertyDescriptor.Builder()
+ .name("resource")
+ .displayName("Resource")
+ .description("Resource URI for the access token request defined in RFC
8707 Section 2")
+ .required(false)
+ .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
+ .build();
+
+ public static final PropertyDescriptor AUDIENCE = new
PropertyDescriptor.Builder()
+ .name("audience")
+ .displayName("Audience")
+ .description("Audience for the access token request defined in RFC
8693 Section 2.1")
+ .required(false)
+ .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
+ .build();
+
public static final PropertyDescriptor REFRESH_WINDOW = new
PropertyDescriptor.Builder()
.name("refresh-window")
.displayName("Refresh Window")
@@ -199,6 +215,8 @@ public class StandardOauth2AccessTokenProvider extends
AbstractControllerService
CLIENT_ID,
CLIENT_SECRET,
SCOPE,
+ RESOURCE,
+ AUDIENCE,
REFRESH_WINDOW,
SSL_CONTEXT,
HTTP_PROTOCOL_STRATEGY
@@ -220,6 +238,8 @@ public class StandardOauth2AccessTokenProvider extends
AbstractControllerService
private volatile String clientId;
private volatile String clientSecret;
private volatile String scope;
+ private volatile String resource;
+ private volatile String audience;
private volatile long refreshWindowSeconds;
private volatile AccessToken accessDetails;
@@ -242,6 +262,8 @@ public class StandardOauth2AccessTokenProvider extends
AbstractControllerService
clientId =
context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
clientSecret = context.getProperty(CLIENT_SECRET).getValue();
scope = context.getProperty(SCOPE).getValue();
+ resource = context.getProperty(RESOURCE).getValue();
+ audience = context.getProperty(AUDIENCE).getValue();
if (context.getProperty(REFRESH_TOKEN).isSet()) {
String refreshToken =
context.getProperty(REFRESH_TOKEN).evaluateAttributeExpressions().getValue();
@@ -319,6 +341,14 @@ public class StandardOauth2AccessTokenProvider extends
AbstractControllerService
return accessDetails;
}
+ private boolean isRefreshRequired() {
+ final Instant expirationRefreshTime = accessDetails.getFetchTime()
+ .plusSeconds(accessDetails.getExpiresIn())
+ .minusSeconds(refreshWindowSeconds);
+
+ return Instant.now().isAfter(expirationRefreshTime);
+ }
+
private void acquireAccessDetails() {
getLogger().debug("New Access Token request started [{}]",
authorizationServerUrl);
@@ -332,58 +362,59 @@ public class StandardOauth2AccessTokenProvider extends
AbstractControllerService
acquireTokenBuilder.add("grant_type", "client_credentials");
}
- if (ClientAuthenticationStrategy.REQUEST_BODY ==
clientAuthenticationStrategy && clientId != null) {
- acquireTokenBuilder.add("client_id", clientId);
- acquireTokenBuilder.add("client_secret", clientSecret);
- }
-
- if (scope != null) {
- acquireTokenBuilder.add("scope", scope);
- }
-
- RequestBody acquireTokenRequestBody = acquireTokenBuilder.build();
- Request.Builder acquireTokenRequestBuilder = new Request.Builder()
- .url(authorizationServerUrl)
- .post(acquireTokenRequestBody);
+ addFormData(acquireTokenBuilder);
- if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION ==
clientAuthenticationStrategy && clientId != null) {
- acquireTokenRequestBuilder.addHeader(AUTHORIZATION_HEADER,
Credentials.basic(clientId, clientSecret));
- }
-
- Request acquireTokenRequest = acquireTokenRequestBuilder.build();
-
- this.accessDetails = getAccessDetails(acquireTokenRequest);
+ this.accessDetails = requestToken(acquireTokenBuilder);
}
private void refreshAccessDetails() {
getLogger().debug("Refresh Access Token request started [{}]",
authorizationServerUrl);
FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
- .add("grant_type", "refresh_token")
- .add("refresh_token", this.accessDetails.getRefreshToken());
+ .add("grant_type", "refresh_token")
+ .add("refresh_token", this.accessDetails.getRefreshToken());
+
+ addFormData(refreshTokenBuilder);
- if (ClientAuthenticationStrategy.REQUEST_BODY ==
clientAuthenticationStrategy && clientId != null) {
- refreshTokenBuilder.add("client_id", clientId);
- refreshTokenBuilder.add("client_secret", clientSecret);
+ AccessToken newAccessDetails = requestToken(refreshTokenBuilder);
+
+ if (newAccessDetails.getRefreshToken() == null) {
+
newAccessDetails.setRefreshToken(this.accessDetails.getRefreshToken());
}
+ this.accessDetails = newAccessDetails;
+ }
+
+ private void addFormData(FormBody.Builder formBuilder) {
+ if (clientAuthenticationStrategy ==
ClientAuthenticationStrategy.REQUEST_BODY && clientId != null) {
+ formBuilder.add("client_id", clientId);
+ formBuilder.add("client_secret", clientSecret);
+ }
if (scope != null) {
- refreshTokenBuilder.add("scope", scope);
+ formBuilder.add("scope", scope);
+ }
+ if (resource != null) {
+ formBuilder.add("resource", resource);
}
+ if (audience != null) {
+ formBuilder.add("audience", audience);
+ }
+ }
- RequestBody refreshTokenRequestBody = refreshTokenBuilder.build();
+ private AccessToken requestToken(FormBody.Builder formBuilder) {
+ RequestBody requestBody = formBuilder.build();
- Request.Builder refreshRequestBuilder = new Request.Builder()
- .url(authorizationServerUrl)
- .post(refreshTokenRequestBody);
+ Request.Builder requestBuilder = new Request.Builder()
+ .url(authorizationServerUrl)
+ .post(requestBody);
if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION ==
clientAuthenticationStrategy && clientId != null) {
- refreshRequestBuilder.addHeader(AUTHORIZATION_HEADER,
Credentials.basic(clientId, clientSecret));
+ requestBuilder.addHeader(AUTHORIZATION_HEADER,
Credentials.basic(clientId, clientSecret));
}
- Request refreshRequest = refreshRequestBuilder.build();
+ Request request = requestBuilder.build();
- this.accessDetails = getAccessDetails(refreshRequest);
+ return getAccessDetails(request);
}
private AccessToken getAccessDetails(final Request newRequest) {
@@ -402,14 +433,6 @@ public class StandardOauth2AccessTokenProvider extends
AbstractControllerService
}
}
- private boolean isRefreshRequired() {
- final Instant expirationRefreshTime = accessDetails.getFetchTime()
- .plusSeconds(accessDetails.getExpiresIn())
- .minusSeconds(refreshWindowSeconds);
-
- return Instant.now().isAfter(expirationRefreshTime);
- }
-
@Override
public List<ConfigVerificationResult> verify(ConfigurationContext context,
ComponentLog verificationLogger, Map<String, String> variables) {
ConfigVerificationResult.Builder builder = new
ConfigVerificationResult.Builder()
diff --git
a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
index 3b53b04d6c..cda25f5b8e 100644
---
a/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
+++
b/nifi-nar-bundles/nifi-standard-services/nifi-oauth2-provider-bundle/nifi-oauth2-provider-service/src/test/java/org/apache/nifi/oauth2/StandardOauth2AccessTokenProviderTest.java
@@ -78,6 +78,9 @@ public class StandardOauth2AccessTokenProviderTest {
private static final String PASSWORD = "password";
private static final String CLIENT_ID = "clientId";
private static final String CLIENT_SECRET = "clientSecret";
+ private static final String SCOPE = "scope";
+ private static final String RESOURCE = "resource";
+ private static final String AUDIENCE = "audience";
private static final long FIVE_MINUTES = 300;
private static final int HTTP_OK = 200;
@@ -120,6 +123,9 @@ public class StandardOauth2AccessTokenProviderTest {
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
+
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.SCOPE).getValue()).thenReturn(SCOPE);
+
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.RESOURCE).getValue()).thenReturn(RESOURCE);
+
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUDIENCE).getValue()).thenReturn(AUDIENCE);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.BASIC_AUTHENTICATION.getValue());
@@ -361,6 +367,57 @@ public class StandardOauth2AccessTokenProviderTest {
assertEquals(expectedToken, actualToken);
}
+ @Test
+ public void testKeepPreviousRefreshTokenWhenNewOneIsNotProvided() throws
Exception {
+ // GIVEN
+ String refreshTokenBeforeRefresh = "refresh_token";
+
+ Response response1 = buildResponse(
+ HTTP_OK,
+ "{ \"access_token\":\"not_checking_in_this_test\",
\"expires_in\":\"0\", \"refresh_token\":\"" + refreshTokenBeforeRefresh + "\" }"
+ );
+
+ Response response2 = buildResponse(
+ HTTP_OK,
+ "{ \"access_token\":\"not_checking_in_this_test_either\" }"
+ );
+
+
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1,
response2);
+
+ // WHEN
+ testSubject.getAccessDetails();
+ String refreshTokenAfterRefresh =
testSubject.getAccessDetails().getRefreshToken();
+
+ // THEN
+ assertEquals(refreshTokenBeforeRefresh, refreshTokenAfterRefresh);
+ }
+
+ @Test
+ public void testOverwritePreviousRefreshTokenWhenNewOneIsProvided() throws
Exception {
+ // GIVEN
+ String refreshTokenBeforeRefresh = "refresh_token_before_refresh";
+ String expectedRefreshTokenAfterRefresh =
"refresh_token_after_refresh";
+
+ Response response1 = buildResponse(
+ HTTP_OK,
+ "{ \"access_token\":\"not_checking_in_this_test\",
\"expires_in\":\"0\", \"refresh_token\":\"" + refreshTokenBeforeRefresh + "\" }"
+ );
+
+ Response response2 = buildResponse(
+ HTTP_OK,
+ "{ \"access_token\":\"not_checking_in_this_test_either\",
\"refresh_token\":\"" + expectedRefreshTokenAfterRefresh + "\" }"
+ );
+
+
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1,
response2);
+
+ // WHEN
+ testSubject.getAccessDetails();
+ String actualRefreshTokenAfterRefresh =
testSubject.getAccessDetails().getRefreshToken();
+
+ // THEN
+ assertEquals(expectedRefreshTokenAfterRefresh,
actualRefreshTokenAfterRefresh);
+ }
+
@Test
public void testBasicAuthentication() throws Exception {
// GIVEN
@@ -377,7 +434,7 @@ public class StandardOauth2AccessTokenProviderTest {
}
@Test
- public void testRequestBodyAuthentication() throws Exception {
+ public void testRequestBodyFormData() throws Exception {
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE.getValue());
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.REQUEST_BODY.getValue());
testSubject.onEnabled(mockContext);
@@ -385,7 +442,11 @@ public class StandardOauth2AccessTokenProviderTest {
// GIVEN
Response response = buildResponse(HTTP_OK,
"{\"access_token\":\"foobar\"}");
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
- String expected = "grant_type=client_credentials&client_id=" +
CLIENT_ID + "&client_secret=" + CLIENT_SECRET;
+ String expected = "grant_type=client_credentials&client_id=" +
CLIENT_ID
+ + "&client_secret=" + CLIENT_SECRET
+ + "&scope=" + SCOPE
+ + "&resource=" + RESOURCE
+ + "&audience=" + AUDIENCE;
// WHEN
testSubject.getAccessDetails();