This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3b35325840 Add test_connection method to `GoogleBaseHook` (#24682)
3b35325840 is described below
commit 3b35325840e484f86df00e087410f5d5da4b9130
Author: Phani Kumar <[email protected]>
AuthorDate: Wed Jul 6 20:27:20 2022 +0530
Add test_connection method to `GoogleBaseHook` (#24682)
This PR adds test-connection functionality to the Google Cloud connection type
in the Airflow UI.
---
.../providers/google/common/hooks/base_google.py | 24 +++++++++++++++++++++-
.../google/common/hooks/test_base_google.py | 18 ++++++++++++++++
2 files changed, 41 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py
index 997c72c4e6..4c4ab5c888 100644
--- a/airflow/providers/google/common/hooks/base_google.py
+++ b/airflow/providers/google/common/hooks/base_google.py
@@ -31,6 +31,7 @@ import google.auth
import google.auth.credentials
import google.oauth2.service_account
import google_auth_httplib2
+import requests
import tenacity
from google.api_core.exceptions import Forbidden, ResourceExhausted, TooManyRequests
from google.api_core.gapic_v1.client_info import ClientInfo
@@ -270,7 +271,12 @@ class GoogleBaseHook(BaseHook):
def _get_access_token(self) -> str:
"""Returns a valid access token from Google API Credentials"""
- return self._get_credentials().token
+ credentials = self._get_credentials()
+ auth_req = google.auth.transport.requests.Request()
+ # credentials.token is None
+ # Need to refresh credentials to populate the token
+ credentials.refresh(auth_req)
+ return credentials.token
@functools.lru_cache(maxsize=None)
def _get_credentials_email(self) -> str:
@@ -580,3 +586,19 @@ class GoogleBaseHook(BaseHook):
while done is False:
_, done = downloader.next_chunk()
file_handle.flush()
+
+ def test_connection(self):
+ """Test the Google cloud connectivity from UI"""
+ status, message = False, ''
+ try:
+ token = self._get_access_token()
+ url = f"https://www.googleapis.com/oauth2/v3/tokeninfo?access_token={token}"
+ response = requests.post(url)
+ if response.status_code == 200:
+ status = True
+ message = 'Connection successfully tested'
+ except Exception as e:
+ status = False
+ message = str(e)
+
+ return status, message
diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py
index a60d3a4e22..247dce40b5 100644
--- a/tests/providers/google/common/hooks/test_base_google.py
+++ b/tests/providers/google/common/hooks/test_base_google.py
@@ -341,6 +341,24 @@ class TestGoogleBaseHook(unittest.TestCase):
)
assert ('CREDENTIALS', 'PROJECT_ID') == result
+ @mock.patch('requests.post')
+ @mock.patch(MODULE_NAME + '.get_credentials_and_project_id')
+ def test_connection_success(self, mock_get_creds_and_proj_id, requests_post):
+ requests_post.return_value.status_code = 200
+ credentials = mock.MagicMock()
+ type(credentials).token = mock.PropertyMock(return_value="TOKEN")
+ mock_get_creds_and_proj_id.return_value = (credentials, "PROJECT_ID")
+ self.instance.extras = {}
+ result = self.instance.test_connection()
+ assert result == (True, 'Connection successfully tested')
+
+ @mock.patch(MODULE_NAME + '.get_credentials_and_project_id')
+ def test_connection_failure(self, mock_get_creds_and_proj_id):
+ mock_get_creds_and_proj_id.side_effect = AirflowException('Invalid key JSON.')
+ self.instance.extras = {}
+ result = self.instance.test_connection()
+ assert result == (False, 'Invalid key JSON.')
+
@mock.patch(MODULE_NAME + '.get_credentials_and_project_id')
def test_get_credentials_and_project_id_with_service_account_file(self, mock_get_creds_and_proj_id):
mock_credentials = mock.MagicMock()