This is an automated email from the ASF dual-hosted git repository.

cgivre pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git


The following commit(s) were added to refs/heads/master by this push:
     new bcabe93  DRILL-8148: Add REST Endpoints to Update OAuth Tokens (#2473)
bcabe93 is described below

commit bcabe9306b3a55d802e8e67d8ca798fe19ebfd46
Author: Charles S. Givre <[email protected]>
AuthorDate: Mon Feb 28 12:45:34 2022 -0500

    DRILL-8148: Add REST Endpoints to Update OAuth Tokens (#2473)
    
    * Initial commit
    
    * Added unit tests
    
    * Fixed unused imports
    
    * Addressed Review comments
    
    * Fixed Link Spelling
---
 .../drill/exec/store/http/TestOAuthProcess.java    |   8 +-
 .../exec/store/http/TestOAuthTokenUpdate.java      | 164 +++++++++++++++++++++
 .../drill/exec/server/rest/StorageResources.java   |  99 ++++++++++++-
 3 files changed, 266 insertions(+), 5 deletions(-)

diff --git a/contrib/storage-http/src/test/java/org/apache/drill/exec/store/http/TestOAuthProcess.java b/contrib/storage-http/src/test/java/org/apache/drill/exec/store/http/TestOAuthProcess.java
index ce1edec..943a331 100644
--- a/contrib/storage-http/src/test/java/org/apache/drill/exec/store/http/TestOAuthProcess.java
+++ b/contrib/storage-http/src/test/java/org/apache/drill/exec/store/http/TestOAuthProcess.java
@@ -101,7 +101,7 @@ public class TestOAuthProcess extends ClusterTest {
       .build();
 
     HttpOAuthConfig oAuthConfig = HttpOAuthConfig.builder()
-      .callbackURL(hostname + "/update_oath2_authtoken")
+      .callbackURL(hostname + "/update_oauth2_authtoken")
       .build();
 
     Map<String, HttpApiConfig> configs = new HashMap<>();
@@ -117,7 +117,7 @@ public class TestOAuthProcess extends ClusterTest {
 
   @Test
   public void testAccessToken() {
-    String url = hostname + "/update_oath2_authtoken?code=ABCDEF";
+    String url = hostname + "/update_oauth2_authtoken?code=ABCDEF";
     Request request = new Request.Builder().url(url).build();
 
     try (MockWebServer server = startServer()) {
@@ -144,7 +144,7 @@ public class TestOAuthProcess extends ClusterTest {
 
   @Test
   public void testGetDataWithAuthentication() {
-    String url = hostname + "/update_oath2_authtoken?code=ABCDEF";
+    String url = hostname + "/update_oauth2_authtoken?code=ABCDEF";
     Request request = new Request.Builder().url(url).build();
     try (MockWebServer server = startServer()) {
       server.enqueue(new MockResponse().setResponseCode(200).setBody(ACCESS_TOKEN_RESPONSE));
@@ -190,7 +190,7 @@ public class TestOAuthProcess extends ClusterTest {
 
   @Test
   public void testGetDataWithTokenRefresh() {
-    String url = hostname + "/update_oath2_authtoken?code=ABCDEF";
+    String url = hostname + "/update_oauth2_authtoken?code=ABCDEF";
     Request request = new Request.Builder().url(url).build();
     try (MockWebServer server = startServer()) {
       server.enqueue(new MockResponse().setResponseCode(200).setBody(ACCESS_TOKEN_RESPONSE));
diff --git a/contrib/storage-http/src/test/java/org/apache/drill/exec/store/http/TestOAuthTokenUpdate.java b/contrib/storage-http/src/test/java/org/apache/drill/exec/store/http/TestOAuthTokenUpdate.java
new file mode 100644
index 0000000..3fd48d8
--- /dev/null
+++ b/contrib/storage-http/src/test/java/org/apache/drill/exec/store/http/TestOAuthTokenUpdate.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.drill.exec.store.http;
+
+import okhttp3.Call;
+import okhttp3.FormBody;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.RequestBody;
+import okhttp3.Response;
+import org.apache.drill.common.logical.security.CredentialsProvider;
+import org.apache.drill.common.logical.security.PlainCredentialsProvider;
+import org.apache.drill.exec.ExecConstants;
+import org.apache.drill.exec.oauth.PersistentTokenTable;
+import org.apache.drill.exec.store.StoragePluginRegistry.PluginException;
+import org.apache.drill.exec.store.security.oauth.OAuthTokenCredentials;
+import org.apache.drill.test.ClusterFixtureBuilder;
+import org.apache.drill.test.ClusterTest;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.assertEquals;
+
+public class TestOAuthTokenUpdate extends ClusterTest {
+
+  private static final String CONNECTION_NAME = "localOauth";
+  private static final int MOCK_SERVER_PORT = 47770;
+  private static final int TIMEOUT = 30;
+  private static String hostname;
+
+  private final OkHttpClient httpClient = new OkHttpClient.Builder()
+    .connectTimeout(TIMEOUT, TimeUnit.SECONDS)
+    .writeTimeout(TIMEOUT, TimeUnit.SECONDS)
+    .readTimeout(TIMEOUT, TimeUnit.SECONDS).build();
+
+  @BeforeClass
+  public static void setup() throws Exception {
+    ClusterFixtureBuilder builder = new ClusterFixtureBuilder(dirTestWatcher)
+      .configProperty(ExecConstants.HTTP_ENABLE, true)
+      .configProperty(ExecConstants.HTTP_PORT_HUNT, true);
+    startCluster(builder);
+    int portNumber = cluster.drillbit().getWebServerPort();
+    hostname = "http://localhost:" + portNumber + "/storage/" + CONNECTION_NAME;
+
+    Map<String, String> creds = new HashMap<>();
+    creds.put("clientID", "12345");
+    creds.put("clientSecret", "54321");
+    creds.put("accessToken", null);
+    creds.put("refreshToken", null);
+    creds.put(OAuthTokenCredentials.TOKEN_URI, "http://localhost:" + MOCK_SERVER_PORT + "/get_access_token");
+
+    CredentialsProvider credentialsProvider = new PlainCredentialsProvider(creds);
+
+    HttpApiConfig connectionConfig = HttpApiConfig.builder()
+      .url("http://localhost:" + MOCK_SERVER_PORT + "/getdata")
+      .method("get")
+      .requireTail(false)
+      .inputType("json")
+      .build();
+
+    HttpOAuthConfig oAuthConfig = HttpOAuthConfig.builder()
+      .callbackURL(hostname + "/update_ouath2_authtoken")
+      .build();
+
+    Map<String, HttpApiConfig> configs = new HashMap<>();
+    configs.put("test", connectionConfig);
+
+    // Add storage plugin for test OAuth
+    HttpStoragePluginConfig mockStorageConfigWithWorkspace =
+      new HttpStoragePluginConfig(false, configs, TIMEOUT, "", 80, "", "", "",
+        oAuthConfig, credentialsProvider);
+    mockStorageConfigWithWorkspace.setEnabled(true);
+    cluster.defineStoragePlugin("localOauth", mockStorageConfigWithWorkspace);
+  }
+
+  @Test
+  public void testUpdateAccessToken() throws Exception {
+    RequestBody formBody = new FormBody.Builder()
+      .add("access_token", "access_approved")
+      .build();
+
+    Request request = new Request.Builder()
+      .url(hostname + "/update_access_token")
+      .post(formBody)
+      .build();
+
+    Call call = httpClient.newCall(request);
+    Response response = call.execute();
+    assertEquals(response.code(), 200);
+
+    PersistentTokenTable tokenTable = getTokenTable();
+    assertEquals(tokenTable.getAccessToken(), "access_approved");
+  }
+
+  @Test
+  public void testUpdateRefreshToken() throws Exception {
+    RequestBody formBody = new FormBody.Builder()
+      .add("refresh_token", "refresh_me")
+      .build();
+
+    Request request = new Request.Builder()
+      .url(hostname + "/update_refresh_token")
+      .post(formBody)
+      .build();
+
+    Call call = httpClient.newCall(request);
+    Response response = call.execute();
+    assertEquals(response.code(), 200);
+
+    PersistentTokenTable tokenTable = getTokenTable();
+    assertEquals(tokenTable.getRefreshToken(), "refresh_me");
+  }
+
+
+  @Test
+  public void testUpdateAllTokens() throws Exception {
+    RequestBody formBody = new FormBody.Builder()
+      .add("access_token", "access_approved")
+      .add("refresh_token", "refresh_me")
+      .build();
+
+    Request request = new Request.Builder()
+      .url(hostname + "/update_oauth_tokens")
+      .post(formBody)
+      .build();
+
+    Call call = httpClient.newCall(request);
+    Response response = call.execute();
+    assertEquals(response.code(), 200);
+
+    PersistentTokenTable tokenTable = getTokenTable();
+    assertEquals(tokenTable.getAccessToken(), "access_approved");
+    assertEquals(tokenTable.getRefreshToken(), "refresh_me");
+  }
+
+  private PersistentTokenTable getTokenTable() throws PluginException {
+    PersistentTokenTable tokenTable = ((HttpStoragePlugin) cluster.storageRegistry()
+      .getPlugin("localOauth"))
+      .getTokenRegistry()
+      .getTokenTable("localOauth");
+
+    return tokenTable;
+  }
+}
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/server/rest/StorageResources.java b/exec/java-exec/src/main/java/org/apache/drill/exec/server/rest/StorageResources.java
index 1dfdcfb..5238941 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/server/rest/StorageResources.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/server/rest/StorageResources.java
@@ -56,8 +56,10 @@ import okhttp3.Request;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.drill.common.logical.AbstractSecuredStoragePluginConfig;
 import org.apache.drill.common.logical.security.CredentialsProvider;
+import org.apache.drill.exec.oauth.OAuthTokenProvider;
 import org.apache.drill.exec.oauth.PersistentTokenTable;
 import org.apache.drill.exec.oauth.TokenRegistry;
+import org.apache.drill.exec.server.DrillbitContext;
 import org.apache.drill.exec.server.rest.DrillRestServer.UserAuthEnabled;
 import org.apache.drill.exec.store.AbstractStoragePlugin;
 import org.apache.drill.exec.store.StoragePluginRegistry;
@@ -200,8 +202,103 @@ public class StorageResources {
     }
   }
 
+  @POST
+  @Path("/storage/{name}/update_refresh_token")
+  @Produces(MediaType.APPLICATION_JSON)
+  public Response updateRefreshToken(@PathParam("name") String name,
+                                    @FormParam("refresh_token") String refreshToken) {
+    try {
+      if (storage.getPlugin(name).getConfig() instanceof AbstractSecuredStoragePluginConfig) {
+        DrillbitContext context = ((AbstractStoragePlugin) storage.getPlugin(name)).getContext();
+        OAuthTokenProvider tokenProvider = context.getoAuthTokenProvider();
+        PersistentTokenTable tokenTable = tokenProvider.getOauthTokenRegistry().getTokenTable(name);
+
+        // Set the access token
+        tokenTable.setRefreshToken(refreshToken);
+
+        return Response.status(Status.OK)
+          .entity("Refresh token have been updated.")
+          .build();
+      } else {
+        logger.error("{} is not a HTTP plugin. You can only add access tokens to HTTP plugins.", name);
+        return Response.status(Status.INTERNAL_SERVER_ERROR)
+          .entity(message("Unable to add tokens: %s", name))
+          .build();
+      }
+    } catch (PluginException e) {
+      logger.error("Error when adding tokens to {}", name);
+      return Response.status(Status.INTERNAL_SERVER_ERROR)
+        .entity(message("Unable to add tokens: %s", e.getMessage()))
+        .build();
+    }
+  }
+
+  @POST
+  @Path("/storage/{name}/update_access_token")
+  @Produces(MediaType.APPLICATION_JSON)
+  public Response updateAccessToken(@PathParam("name") String name,
+                                    @FormParam("access_token") String accessToken) {
+    try {
+      if (storage.getPlugin(name).getConfig() instanceof AbstractSecuredStoragePluginConfig) {
+        DrillbitContext context = ((AbstractStoragePlugin) storage.getPlugin(name)).getContext();
+        OAuthTokenProvider tokenProvider = context.getoAuthTokenProvider();
+        PersistentTokenTable tokenTable = tokenProvider.getOauthTokenRegistry().getTokenTable(name);
+
+        // Set the access token
+        tokenTable.setAccessToken(accessToken);
+
+        return Response.status(Status.OK)
+          .entity("Access tokens have been updated.")
+          .build();
+      } else {
+        logger.error("{} is not a HTTP plugin. You can only add access tokens to HTTP plugins.", name);
+        return Response.status(Status.INTERNAL_SERVER_ERROR)
+          .entity(message("Unable to add tokens: %s", name))
+          .build();
+      }
+    } catch (PluginException e) {
+      logger.error("Error when adding tokens to {}", name);
+      return Response.status(Status.INTERNAL_SERVER_ERROR)
+        .entity(message("Unable to add tokens: %s", e.getMessage()))
+        .build();
+    }
+  }
+
+  @POST
+  @Path("/storage/{name}/update_oauth_tokens")
+  @Produces(MediaType.APPLICATION_JSON)
+  public Response updateOAuthTokens(@PathParam("name") String name,
+                                    @FormParam("access_token") String accessToken,
+                                    @FormParam("refresh_token") String refreshToken) {
+    try {
+      if (storage.getPlugin(name).getConfig() instanceof AbstractSecuredStoragePluginConfig) {
+        DrillbitContext context = ((AbstractStoragePlugin) storage.getPlugin(name)).getContext();
+        OAuthTokenProvider tokenProvider = context.getoAuthTokenProvider();
+        PersistentTokenTable tokenTable = tokenProvider.getOauthTokenRegistry().getTokenTable(name);
+
+        // Set the access and refresh token
+        tokenTable.setAccessToken(accessToken);
+        tokenTable.setRefreshToken(refreshToken);
+
+        return Response.status(Status.OK)
+          .entity("Access tokens have been updated.")
+          .build();
+      } else {
+        logger.error("{} is not a HTTP plugin. You can only add access tokens to HTTP plugins.", name);
+        return Response.status(Status.INTERNAL_SERVER_ERROR)
+          .entity(message("Unable to add tokens: %s", name))
+          .build();
+      }
+    } catch (PluginException e) {
+      logger.error("Error when adding tokens to {}", name);
+      return Response.status(Status.INTERNAL_SERVER_ERROR)
+        .entity(message("Unable to add tokens: %s", e.getMessage()))
+        .build();
+    }
+  }
+
   @GET
-  @Path("/storage/{name}/update_oath2_authtoken")
+  @Path("/storage/{name}/update_oauth2_authtoken")
   @Produces(MediaType.TEXT_HTML)
   public Response updateAuthToken(@PathParam("name") String name, @QueryParam("code") String code) {
     try {

Reply via email to