This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3b35325840 Add test_connection method to `GoogleBaseHook` (#24682)
3b35325840 is described below

commit 3b35325840e484f86df00e087410f5d5da4b9130
Author: Phani Kumar <[email protected]>
AuthorDate: Wed Jul 6 20:27:20 2022 +0530

    Add test_connection method to `GoogleBaseHook` (#24682)
    
    This PR adds test connection functionality to the Google Cloud connection type in the Airflow UI.
---
 .../providers/google/common/hooks/base_google.py   | 24 +++++++++++++++++++++-
 .../google/common/hooks/test_base_google.py        | 18 ++++++++++++++++
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py
index 997c72c4e6..4c4ab5c888 100644
--- a/airflow/providers/google/common/hooks/base_google.py
+++ b/airflow/providers/google/common/hooks/base_google.py
@@ -31,6 +31,7 @@ import google.auth
 import google.auth.credentials
 import google.oauth2.service_account
 import google_auth_httplib2
+import requests
 import tenacity
 from google.api_core.exceptions import Forbidden, ResourceExhausted, TooManyRequests
 from google.api_core.gapic_v1.client_info import ClientInfo
@@ -270,7 +271,12 @@ class GoogleBaseHook(BaseHook):
 
     def _get_access_token(self) -> str:
         """Returns a valid access token from Google API Credentials"""
-        return self._get_credentials().token
+        credentials = self._get_credentials()
+        auth_req = google.auth.transport.requests.Request()
+        # credentials.token is None
+        # Need to refresh credentials to populate the token
+        credentials.refresh(auth_req)
+        return credentials.token
 
     @functools.lru_cache(maxsize=None)
     def _get_credentials_email(self) -> str:
@@ -580,3 +586,19 @@ class GoogleBaseHook(BaseHook):
         while done is False:
             _, done = downloader.next_chunk()
         file_handle.flush()
+
+    def test_connection(self):
+        """Test the Google cloud connectivity from UI"""
+        status, message = False, ''
+        try:
+            token = self._get_access_token()
+            url = f"https://www.googleapis.com/oauth2/v3/tokeninfo?access_token={token}"
+            response = requests.post(url)
+            if response.status_code == 200:
+                status = True
+                message = 'Connection successfully tested'
+        except Exception as e:
+            status = False
+            message = str(e)
+
+        return status, message
diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py
index a60d3a4e22..247dce40b5 100644
--- a/tests/providers/google/common/hooks/test_base_google.py
+++ b/tests/providers/google/common/hooks/test_base_google.py
@@ -341,6 +341,24 @@ class TestGoogleBaseHook(unittest.TestCase):
         )
         assert ('CREDENTIALS', 'PROJECT_ID') == result
 
+    @mock.patch('requests.post')
+    @mock.patch(MODULE_NAME + '.get_credentials_and_project_id')
+    def test_connection_success(self, mock_get_creds_and_proj_id, requests_post):
+        requests_post.return_value.status_code = 200
+        credentials = mock.MagicMock()
+        type(credentials).token = mock.PropertyMock(return_value="TOKEN")
+        mock_get_creds_and_proj_id.return_value = (credentials, "PROJECT_ID")
+        self.instance.extras = {}
+        result = self.instance.test_connection()
+        assert result == (True, 'Connection successfully tested')
+
+    @mock.patch(MODULE_NAME + '.get_credentials_and_project_id')
+    def test_connection_failure(self, mock_get_creds_and_proj_id):
+        mock_get_creds_and_proj_id.side_effect = AirflowException('Invalid key JSON.')
+        self.instance.extras = {}
+        result = self.instance.test_connection()
+        assert result == (False, 'Invalid key JSON.')
+
     @mock.patch(MODULE_NAME + '.get_credentials_and_project_id')
     def test_get_credentials_and_project_id_with_service_account_file(self, mock_get_creds_and_proj_id):
         mock_credentials = mock.MagicMock()

Reply via email to