This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new acff4c79dc Make Dataprep system test self-sufficient (#34880)
acff4c79dc is described below
commit acff4c79dcbb7926923d86adb4c5115e02cf28e6
Author: max <[email protected]>
AuthorDate: Thu Oct 26 21:18:01 2023 +0200
Make Dataprep system test self-sufficient (#34880)
---
airflow/providers/google/cloud/hooks/dataprep.py | 102 +++++-
.../providers/google/cloud/operators/dataprep.py | 4 +-
.../providers/google/cloud/hooks/test_dataprep.py | 383 ++++++++++++++++++++-
.../google/cloud/operators/test_dataprep.py | 2 +-
.../google/cloud/dataprep/example_dataprep.py | 218 +++++++++---
5 files changed, 652 insertions(+), 57 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/dataprep.py b/airflow/providers/google/cloud/hooks/dataprep.py
index c01a48d5ae..9e006fa99f 100644
--- a/airflow/providers/google/cloud/hooks/dataprep.py
+++ b/airflow/providers/google/cloud/hooks/dataprep.py
@@ -72,9 +72,10 @@ class GoogleDataprepHook(BaseHook):
conn_type = "dataprep"
hook_name = "Google Dataprep"
- def __init__(self, dataprep_conn_id: str = default_conn_name) -> None:
+ def __init__(self, dataprep_conn_id: str = default_conn_name, api_version: str = "v4") -> None:
super().__init__()
self.dataprep_conn_id = dataprep_conn_id
+ self.api_version = api_version
conn = self.get_connection(self.dataprep_conn_id)
extras = conn.extra_dejson
self._token = _get_field(extras, "token")
@@ -95,7 +96,7 @@ class GoogleDataprepHook(BaseHook):
:param job_id: The ID of the job that will be fetched
"""
- endpoint_path = f"v4/jobGroups/{job_id}/jobs"
+ endpoint_path = f"{self.api_version}/jobGroups/{job_id}/jobs"
url: str = urljoin(self._base_url, endpoint_path)
response = requests.get(url, headers=self._headers)
self._raise_for_status(response)
@@ -113,7 +114,7 @@ class GoogleDataprepHook(BaseHook):
:param include_deleted: if set to "true", will include deleted objects
"""
params: dict[str, Any] = {"embed": embed, "includeDeleted": include_deleted}
- endpoint_path = f"v4/jobGroups/{job_group_id}"
+ endpoint_path = f"{self.api_version}/jobGroups/{job_group_id}"
url: str = urljoin(self._base_url, endpoint_path)
response = requests.get(url, headers=self._headers, params=params)
self._raise_for_status(response)
@@ -131,12 +132,26 @@ class GoogleDataprepHook(BaseHook):
:param body_request: The identifier for the recipe you would like to run.
"""
- endpoint_path = "v4/jobGroups"
+ endpoint_path = f"{self.api_version}/jobGroups"
url: str = urljoin(self._base_url, endpoint_path)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
+ def create_flow(self, *, body_request: dict) -> dict:
+ """
+ Creates flow.
+
+ :param body_request: Body of the POST request to be sent.
+ For more details check https://clouddataprep.com/documentation/api#operation/createFlow
+ """
+ endpoint = f"/{self.api_version}/flows"
+ url: str = urljoin(self._base_url, endpoint)
+ response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
+ self._raise_for_status(response)
+ return response.json()
+
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def copy_flow(
self, *, flow_id: int, name: str = "", description: str = "", copy_datasources: bool = False
@@ -149,7 +164,7 @@ class GoogleDataprepHook(BaseHook):
:param description: Description of the copy of the flow
:param copy_datasources: Bool value defining whether copies of data inputs should be made.
"""
- endpoint_path = f"v4/flows/{flow_id}/copy"
+ endpoint_path = f"{self.api_version}/flows/{flow_id}/copy"
url: str = urljoin(self._base_url, endpoint_path)
body_request = {
"name": name,
@@ -167,7 +182,7 @@ class GoogleDataprepHook(BaseHook):
:param flow_id: ID of the flow to be deleted
"""
- endpoint_path = f"v4/flows/{flow_id}"
+ endpoint_path = f"{self.api_version}/flows/{flow_id}"
url: str = urljoin(self._base_url, endpoint_path)
response = requests.delete(url, headers=self._headers)
self._raise_for_status(response)
@@ -180,7 +195,7 @@ class GoogleDataprepHook(BaseHook):
:param flow_id: ID of the flow to be run
:param body_request: Body of the POST request to be sent.
"""
- endpoint = f"v4/flows/{flow_id}/run"
+ endpoint = f"{self.api_version}/flows/{flow_id}/run"
url: str = urljoin(self._base_url, endpoint)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
@@ -193,7 +208,7 @@ class GoogleDataprepHook(BaseHook):
:param job_group_id: ID of the job group to check
"""
- endpoint = f"/v4/jobGroups/{job_group_id}/status"
+ endpoint = f"/{self.api_version}/jobGroups/{job_group_id}/status"
url: str = urljoin(self._base_url, endpoint)
response = requests.get(url, headers=self._headers)
self._raise_for_status(response)
@@ -205,3 +220,74 @@ class GoogleDataprepHook(BaseHook):
except HTTPError:
self.log.error(response.json().get("exception"))
raise
+
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
+ def create_imported_dataset(self, *, body_request: dict) -> dict:
+ """
+ Creates imported dataset.
+
+ :param body_request: Body of the POST request to be sent.
+ For more details check https://clouddataprep.com/documentation/api#operation/createImportedDataset
+ """
+ endpoint = f"/{self.api_version}/importedDatasets"
+ url: str = urljoin(self._base_url, endpoint)
+ response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
+ self._raise_for_status(response)
+ return response.json()
+
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
+ def create_wrangled_dataset(self, *, body_request: dict) -> dict:
+ """
+ Creates wrangled dataset.
+
+ :param body_request: Body of the POST request to be sent.
+ For more details check
+ https://clouddataprep.com/documentation/api#operation/createWrangledDataset
+ """
+ endpoint = f"/{self.api_version}/wrangledDatasets"
+ url: str = urljoin(self._base_url, endpoint)
+ response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
+ self._raise_for_status(response)
+ return response.json()
+
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
+ def create_output_object(self, *, body_request: dict) -> dict:
+ """
+ Creates output.
+
+ :param body_request: Body of the POST request to be sent.
+ For more details check
+ https://clouddataprep.com/documentation/api#operation/createOutputObject
+ """
+ endpoint = f"/{self.api_version}/outputObjects"
+ url: str = urljoin(self._base_url, endpoint)
+ response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
+ self._raise_for_status(response)
+ return response.json()
+
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
+ def create_write_settings(self, *, body_request: dict) -> dict:
+ """
+ Creates write settings.
+
+ :param body_request: Body of the POST request to be sent.
+ For more details check
+ https://clouddataprep.com/documentation/api#tag/createWriteSetting
+ """
+ endpoint = f"/{self.api_version}/writeSettings"
+ url: str = urljoin(self._base_url, endpoint)
+ response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
+ self._raise_for_status(response)
+ return response.json()
+
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
+ def delete_imported_dataset(self, *, dataset_id: int) -> None:
+ """
+ Deletes imported dataset.
+
+ :param dataset_id: ID of the imported dataset for removal.
+ """
+ endpoint = f"/{self.api_version}/importedDatasets/{dataset_id}"
+ url: str = urljoin(self._base_url, endpoint)
+ response = requests.delete(url, headers=self._headers)
+ self._raise_for_status(response)
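Note on usage: taken together, the new hook methods above cover the whole create/delete lifecycle the system test needs. The sketch below is not part of the commit; it assumes an existing Airflow connection named "dataprep_default" with a valid token in its extras, and all URIs and names are illustrative placeholders.

    from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook

    # A minimal sketch, not the committed code: connection ID, URIs, and names are assumptions.
    hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default", api_version="v4")
    flow = hook.create_flow(body_request={"name": "demo_flow", "description": "Demo flow"})
    dataset = hook.create_imported_dataset(
        body_request={"uri": "gs://my-bucket/input.parquet", "name": "demo_dataset"}
    )
    wrangled = hook.create_wrangled_dataset(
        body_request={
            "importedDataset": {"id": dataset["id"]},
            "flow": {"id": flow["id"]},
            "name": "wrangled_demo_dataset",
        }
    )
    output = hook.create_output_object(
        body_request={"execution": "dataflow", "profiler": False, "flowNodeId": wrangled["id"]}
    )
    hook.create_write_settings(
        body_request={
            "path": "gs://my-bucket/results/demo.csv",
            "action": "create",
            "format": "csv",
            "outputObjectId": output["id"],
        }
    )
    # Teardown mirrors the system test: remove the imported dataset when done.
    hook.delete_imported_dataset(dataset_id=dataset["id"])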
diff --git a/airflow/providers/google/cloud/operators/dataprep.py b/airflow/providers/google/cloud/operators/dataprep.py
index 7f19f6993b..59710293e6 100644
--- a/airflow/providers/google/cloud/operators/dataprep.py
+++ b/airflow/providers/google/cloud/operators/dataprep.py
@@ -51,13 +51,13 @@ class DataprepGetJobsForJobGroupOperator(GoogleCloudBaseOperator):
**kwargs,
) -> None:
super().__init__(**kwargs)
- self.dataprep_conn_id = (dataprep_conn_id,)
+ self.dataprep_conn_id = dataprep_conn_id
self.job_group_id = job_group_id
def execute(self, context: Context) -> dict:
self.log.info("Fetching data for job with id: %d ...", self.job_group_id)
hook = GoogleDataprepHook(
- dataprep_conn_id="dataprep_default",
+ dataprep_conn_id=self.dataprep_conn_id,
)
response = hook.get_jobs_for_job_group(job_id=int(self.job_group_id))
return response
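Note: this hunk fixes two bugs in DataprepGetJobsForJobGroupOperator: the constructor was storing the connection ID as a one-element tuple, and execute() ignored it in favor of a hard-coded "dataprep_default". A sketch of the now-honored behavior (the connection name below is an illustrative placeholder, not from the commit):

    from airflow.providers.google.cloud.operators.dataprep import DataprepGetJobsForJobGroupOperator

    # Sketch only: "my_dataprep_conn" is an assumed connection ID that the
    # operator now actually passes through to GoogleDataprepHook.
    get_jobs = DataprepGetJobsForJobGroupOperator(
        task_id="get_jobs_for_job_group",
        dataprep_conn_id="my_dataprep_conn",
        job_group_id=1234,
    )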
diff --git a/tests/providers/google/cloud/hooks/test_dataprep.py b/tests/providers/google/cloud/hooks/test_dataprep.py
index e29d2be3dc..a0cef77ae9 100644
--- a/tests/providers/google/cloud/hooks/test_dataprep.py
+++ b/tests/providers/google/cloud/hooks/test_dataprep.py
@@ -35,7 +35,12 @@ EXTRA = {"token": TOKEN}
EMBED = ""
INCLUDE_DELETED = False
DATA = {"wrangledDataset": {"id": RECIPE_ID}}
-URL = "https://api.clouddataprep.com/v4/jobGroups"
+URL_BASE = "https://api.clouddataprep.com"
+URL_JOB_GROUPS = URL_BASE + "/v4/jobGroups"
+URL_IMPORTED_DATASETS = URL_BASE + "/v4/importedDatasets"
+URL_WRANGLED_DATASETS = URL_BASE + "/v4/wrangledDatasets"
+URL_OUTPUT_OBJECTS = URL_BASE + "/v4/outputObjects"
+URL_WRITE_SETTINGS = URL_BASE + "/v4/writeSettings"
class TestGoogleDataprepHook:
@@ -43,12 +48,41 @@ class TestGoogleDataprepHook:
with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn:
conn.return_value.extra_dejson = EXTRA
self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default")
+ self._imported_dataset_id = 12345
+ self._create_imported_dataset_body_request = {
+ "uri": "gs://test/uri",
+ "name": "test_name",
+ }
+ self._create_wrangled_dataset_body_request = {
+ "importedDataset": {"id": "test_dataset_id"},
+ "flow": {"id": "test_flow_id"},
+ "name": "test_dataset_name",
+ }
+ self._create_output_object_body_request = {
+ "execution": "dataflow",
+ "profiler": False,
+ "flowNodeId": "test_flow_node_id",
+ }
+ self._create_write_settings_body_request = {
+ "path": "gs://test/path",
+ "action": "create",
+ "format": "csv",
+ "outputObjectId": "test_output_object_id",
+ }
+ self._expected_create_imported_dataset_hook_data = json.dumps(
+ self._create_imported_dataset_body_request
+ )
+ self._expected_create_wrangled_dataset_hook_data = json.dumps(
+ self._create_wrangled_dataset_body_request
+ )
+ self._expected_create_output_object_hook_data = json.dumps(self._create_output_object_body_request)
+ self._expected_create_write_settings_hook_data = json.dumps(self._create_write_settings_body_request)
@patch("airflow.providers.google.cloud.hooks.dataprep.requests.get")
def test_get_jobs_for_job_group_should_be_called_once_with_params(self, mock_get_request):
self.hook.get_jobs_for_job_group(JOB_ID)
mock_get_request.assert_called_once_with(
- f"{URL}/{JOB_ID}/jobs",
+ f"{URL_JOB_GROUPS}/{JOB_ID}/jobs",
headers={"Content-Type": "application/json", "Authorization":
f"Bearer {TOKEN}"},
)
@@ -93,7 +127,7 @@ class TestGoogleDataprepHook:
def test_get_job_group_should_be_called_once_with_params(self, mock_get_request):
self.hook.get_job_group(JOB_ID, EMBED, INCLUDE_DELETED)
mock_get_request.assert_called_once_with(
- f"{URL}/{JOB_ID}",
+ f"{URL_JOB_GROUPS}/{JOB_ID}",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {TOKEN}",
@@ -148,7 +182,7 @@ class TestGoogleDataprepHook:
def test_run_job_group_should_be_called_once_with_params(self, mock_get_request):
self.hook.run_job_group(body_request=DATA)
mock_get_request.assert_called_once_with(
- f"{URL}",
+ f"{URL_JOB_GROUPS}",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {TOKEN}",
@@ -203,7 +237,7 @@ class TestGoogleDataprepHook:
def test_get_job_group_status_should_be_called_once_with_params(self, mock_get_request):
self.hook.get_job_group_status(job_group_id=JOB_ID)
mock_get_request.assert_called_once_with(
- f"{URL}/{JOB_ID}/status",
+ f"{URL_JOB_GROUPS}/{JOB_ID}/status",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {TOKEN}",
@@ -266,12 +300,290 @@ class TestGoogleDataprepHook:
assert hook._token == "abc"
assert hook._base_url == "abc"
+ @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
+ def test_create_imported_dataset_should_be_called_once_with_params(self, mock_post_request):
+ self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request)
+ mock_post_request.assert_called_once_with(
+ URL_IMPORTED_DATASETS,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {TOKEN}",
+ },
+ data=self._expected_create_imported_dataset_hook_data,
+ )
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[HTTPError(), mock.MagicMock()],
+ )
+ def test_create_imported_dataset_should_pass_after_retry(self, mock_post_request):
+ self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request)
+ assert mock_post_request.call_count == 2
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[mock.MagicMock(), HTTPError()],
+ )
+ def test_create_imported_dataset_retry_after_success(self, mock_post_request):
+ self.hook.create_imported_dataset.retry.sleep = mock.Mock()
+ self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request)
+ assert mock_post_request.call_count == 1
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ mock.MagicMock(),
+ ],
+ )
+ def test_create_imported_dataset_four_errors(self, mock_post_request):
+ self.hook.create_imported_dataset.retry.sleep = mock.Mock()
+ self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request)
+ assert mock_post_request.call_count == 5
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
+ )
+ def test_create_imported_dataset_five_calls(self, mock_post_request):
+ with pytest.raises(RetryError) as ctx:
+ self.hook.create_imported_dataset.retry.sleep = mock.Mock()
+ self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request)
+ assert "HTTPError" in str(ctx.value)
+ assert mock_post_request.call_count == 5
+
+ @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
+ def test_create_wrangled_dataset_should_be_called_once_with_params(self, mock_post_request):
+ self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request)
+ mock_post_request.assert_called_once_with(
+ URL_WRANGLED_DATASETS,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {TOKEN}",
+ },
+ data=self._expected_create_wrangled_dataset_hook_data,
+ )
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[HTTPError(), mock.MagicMock()],
+ )
+ def test_create_wrangled_dataset_should_pass_after_retry(self, mock_post_request):
+ self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request)
+ assert mock_post_request.call_count == 2
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[mock.MagicMock(), HTTPError()],
+ )
+ def test_create_wrangled_dataset_retry_after_success(self, mock_post_request):
+ self.hook.create_wrangled_dataset.retry.sleep = mock.Mock()
+ self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request)
+ assert mock_post_request.call_count == 1
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ mock.MagicMock(),
+ ],
+ )
+ def test_create_wrangled_dataset_four_errors(self, mock_post_request):
+ self.hook.create_wrangled_dataset.retry.sleep = mock.Mock()
+ self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request)
+ assert mock_post_request.call_count == 5
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
+ )
+ def test_create_wrangled_dataset_five_calls(self, mock_post_request):
+ with pytest.raises(RetryError) as ctx:
+ self.hook.create_wrangled_dataset.retry.sleep = mock.Mock()
+ self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request)
+ assert "HTTPError" in str(ctx.value)
+ assert mock_post_request.call_count == 5
+
+ @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
+ def test_create_output_object_should_be_called_once_with_params(self, mock_post_request):
+ self.hook.create_output_object(body_request=self._create_output_object_body_request)
+ mock_post_request.assert_called_once_with(
+ URL_OUTPUT_OBJECTS,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {TOKEN}",
+ },
+ data=self._expected_create_output_object_hook_data,
+ )
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[HTTPError(), mock.MagicMock()],
+ )
+ def test_create_output_objects_should_pass_after_retry(self, mock_post_request):
+ self.hook.create_output_object(body_request=self._create_output_object_body_request)
+ assert mock_post_request.call_count == 2
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[mock.MagicMock(), HTTPError()],
+ )
+ def test_create_output_objects_retry_after_success(self, mock_post_request):
+ self.hook.create_output_object.retry.sleep = mock.Mock()
+ self.hook.create_output_object(body_request=self._create_output_object_body_request)
+ assert mock_post_request.call_count == 1
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ mock.MagicMock(),
+ ],
+ )
+ def test_create_output_objects_four_errors(self, mock_post_request):
+ self.hook.create_output_object.retry.sleep = mock.Mock()
+ self.hook.create_output_object(body_request=self._create_output_object_body_request)
+ assert mock_post_request.call_count == 5
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
+ )
+ def test_create_output_objects_five_calls(self, mock_post_request):
+ with pytest.raises(RetryError) as ctx:
+ self.hook.create_output_object.retry.sleep = mock.Mock()
+ self.hook.create_output_object(body_request=self._create_output_object_body_request)
+ assert "HTTPError" in str(ctx.value)
+ assert mock_post_request.call_count == 5
+
+ @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
+ def test_create_write_settings_should_be_called_once_with_params(self, mock_post_request):
+ self.hook.create_write_settings(body_request=self._create_write_settings_body_request)
+ mock_post_request.assert_called_once_with(
+ URL_WRITE_SETTINGS,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {TOKEN}",
+ },
+ data=self._expected_create_write_settings_hook_data,
+ )
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[HTTPError(), mock.MagicMock()],
+ )
+ def test_create_write_settings_should_pass_after_retry(self, mock_post_request):
+ self.hook.create_write_settings(body_request=self._create_write_settings_body_request)
+ assert mock_post_request.call_count == 2
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[mock.MagicMock(), HTTPError()],
+ )
+ def test_create_write_settings_retry_after_success(self, mock_post_request):
+ self.hook.create_write_settings.retry.sleep = mock.Mock()
+ self.hook.create_write_settings(body_request=self._create_write_settings_body_request)
+ assert mock_post_request.call_count == 1
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ mock.MagicMock(),
+ ],
+ )
+ def test_create_write_settings_four_errors(self, mock_post_request):
+ self.hook.create_write_settings.retry.sleep = mock.Mock()
+ self.hook.create_write_settings(body_request=self._create_write_settings_body_request)
+ assert mock_post_request.call_count == 5
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
+ )
+ def test_create_write_settings_five_calls(self, mock_post_request):
+ with pytest.raises(RetryError) as ctx:
+ self.hook.create_write_settings.retry.sleep = mock.Mock()
+ self.hook.create_write_settings(body_request=self._create_write_settings_body_request)
+ assert "HTTPError" in str(ctx.value)
+ assert mock_post_request.call_count == 5
+
+ @patch("airflow.providers.google.cloud.hooks.dataprep.requests.delete")
+ def test_delete_imported_dataset_should_be_called_once_with_params(self, mock_delete_request):
+ self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id)
+ mock_delete_request.assert_called_once_with(
+ f"{URL_IMPORTED_DATASETS}/{self._imported_dataset_id}",
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {TOKEN}",
+ },
+ )
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
+ side_effect=[HTTPError(), mock.MagicMock()],
+ )
+ def test_delete_imported_dataset_should_pass_after_retry(self, mock_delete_request):
+ self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id)
+ assert mock_delete_request.call_count == 2
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
+ side_effect=[mock.MagicMock(), HTTPError()],
+ )
+ def test_delete_imported_dataset_retry_after_success(self, mock_delete_request):
+ self.hook.delete_imported_dataset.retry.sleep = mock.Mock()
+ self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id)
+ assert mock_delete_request.call_count == 1
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
+ side_effect=[
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ mock.MagicMock(),
+ ],
+ )
+ def test_delete_imported_dataset_four_errors(self, mock_delete_request):
+ self.hook.delete_imported_dataset.retry.sleep = mock.Mock()
+ self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id)
+ assert mock_delete_request.call_count == 5
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
+ side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
+ )
+ def test_delete_imported_dataset_five_calls(self, mock_delete_request):
+ with pytest.raises(RetryError) as ctx:
+ self.hook.delete_imported_dataset.retry.sleep = mock.Mock()
+ self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id)
+ assert "HTTPError" in str(ctx.value)
+ assert mock_delete_request.call_count == 5
+
class TestGoogleDataprepFlowPathHooks:
_url = "https://api.clouddataprep.com/v4/flows"
def setup_method(self):
self._flow_id = 1234567
+ self._create_flow_body_request = {
+ "name": "test_name",
+ "description": "Test description",
+ }
self._expected_copy_flow_hook_data = json.dumps(
{
"name": "",
@@ -280,10 +592,71 @@ class TestGoogleDataprepFlowPathHooks:
}
)
self._expected_run_flow_hook_data = json.dumps({})
+ self._expected_create_flow_hook_data = json.dumps(
+ {
+ "name": "test_name",
+ "description": "Test description",
+ }
+ )
with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn:
conn.return_value.extra_dejson = EXTRA
self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default")
+ @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
+ def test_create_flow_should_be_called_once_with_params(self, mock_post_request):
+ self.hook.create_flow(body_request=self._create_flow_body_request)
+ mock_post_request.assert_called_once_with(
+ self._url,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {TOKEN}",
+ },
+ data=self._expected_create_flow_hook_data,
+ )
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[HTTPError(), mock.MagicMock()],
+ )
+ def test_create_flow_should_pass_after_retry(self, mock_post_request):
+ self.hook.create_flow(body_request=self._create_flow_body_request)
+ assert mock_post_request.call_count == 2
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[mock.MagicMock(), HTTPError()],
+ )
+ def test_create_flow_should_not_retry_after_success(self, mock_post_request):
+ self.hook.create_flow.retry.sleep = mock.Mock()
+ self.hook.create_flow(body_request=self._create_flow_body_request)
+ assert mock_post_request.call_count == 1
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ mock.MagicMock(),
+ ],
+ )
+ def test_create_flow_should_retry_after_four_errors(self, mock_post_request):
+ self.hook.create_flow.retry.sleep = mock.Mock()
+ self.hook.create_flow(body_request=self._create_flow_body_request)
+ assert mock_post_request.call_count == 5
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+ side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
+ )
+ def test_create_flow_raise_error_after_five_calls(self, mock_post_request):
+ with pytest.raises(RetryError) as ctx:
+ self.hook.create_flow.retry.sleep = mock.Mock()
+ self.hook.create_flow(body_request=self._create_flow_body_request)
+ assert "HTTPError" in str(ctx.value)
+ assert mock_post_request.call_count == 5
+
@patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
def test_copy_flow_should_be_called_once_with_params(self, mock_get_request):
self.hook.copy_flow(
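Note: every retry test above follows the same pattern: patch requests.post (or requests.delete) with a side_effect list, stub the tenacity sleep so the test does not wait out the exponential backoff, and assert the call count. A condensed, self-contained sketch of that pattern (the token and body values are placeholders):

    from unittest import mock

    import pytest
    from requests.exceptions import HTTPError
    from tenacity import RetryError

    from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook

    @mock.patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
        side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
    )
    def test_create_flow_gives_up_after_five_attempts(mock_post):
        # Build the hook with a mocked connection, as in setup_method above.
        with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn:
            conn.return_value.extra_dejson = {"token": "test-token"}
            hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default")
        hook.create_flow.retry.sleep = mock.Mock()  # skip backoff waits
        with pytest.raises(RetryError):
            hook.create_flow(body_request={"name": "test"})
        assert mock_post.call_count == 5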
diff --git a/tests/providers/google/cloud/operators/test_dataprep.py b/tests/providers/google/cloud/operators/test_dataprep.py
index 08237d0d51..d5800716a8 100644
--- a/tests/providers/google/cloud/operators/test_dataprep.py
+++ b/tests/providers/google/cloud/operators/test_dataprep.py
@@ -66,7 +66,7 @@ class TestDataprepGetJobsForJobGroupOperator:
dataprep_conn_id=DATAPREP_CONN_ID, job_group_id=JOB_ID, task_id=TASK_ID
)
op.execute(context={})
- hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default")
+ hook_mock.assert_called_once_with(dataprep_conn_id=DATAPREP_CONN_ID)
hook_mock.return_value.get_jobs_for_job_group.assert_called_once_with(job_id=JOB_ID)
diff --git a/tests/system/providers/google/cloud/dataprep/example_dataprep.py b/tests/system/providers/google/cloud/dataprep/example_dataprep.py
index c07cd5a456..126e3f4c9b 100644
--- a/tests/system/providers/google/cloud/dataprep/example_dataprep.py
+++ b/tests/system/providers/google/cloud/dataprep/example_dataprep.py
@@ -16,13 +16,25 @@
# under the License.
"""
Example Airflow DAG that shows how to use Google Dataprep.
+
+This DAG relies on the following OS environment variables
+
+* SYSTEM_TESTS_DATAPREP_TOKEN - Dataprep API access token.
+ For generating it please use instruction
+
https://docs.trifacta.com/display/DP/Manage+API+Access+Tokens#:~:text=Enable%20individual%20access-,Generate%20New%20Token,-Via%20UI.
"""
from __future__ import annotations
+import logging
import os
from datetime import datetime
-from airflow.models.dag import DAG
+from airflow import models
+from airflow.decorators import task
+from airflow.models import Connection
+from airflow.models.baseoperator import chain
+from airflow.operators.bash import BashOperator
+from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook
from airflow.providers.google.cloud.operators.dataprep import (
DataprepCopyFlowOperator,
DataprepDeleteFlowOperator,
@@ -33,31 +45,40 @@ from airflow.providers.google.cloud.operators.dataprep import (
)
from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
from airflow.providers.google.cloud.sensors.dataprep import DataprepJobGroupIsFinishedSensor
+from airflow.settings import Session
from airflow.utils.trigger_rule import TriggerRule
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
DAG_ID = "example_dataprep"
+CONNECTION_ID = f"connection_{DAG_ID}_{ENV_ID}".replace("-", "_")
+DATAPREP_TOKEN = os.environ.get("SYSTEM_TESTS_DATAPREP_TOKEN", "")
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
GCS_BUCKET_NAME = f"dataprep-bucket-{DAG_ID}-{ENV_ID}"
GCS_BUCKET_PATH = f"gs://{GCS_BUCKET_NAME}/task_results/"
-FLOW_ID = os.environ.get("FLOW_ID")
-RECIPE_ID = os.environ.get("RECIPE_ID")
-RECIPE_NAME = os.environ.get("RECIPE_NAME")
-WRITE_SETTINGS = (
- {
- "writesettings": [
- {
- "path": GCS_BUCKET_PATH,
- "action": "create",
- "format": "csv",
- }
- ],
- },
-)
+DATASET_URI = "gs://airflow-system-tests-resources/dataprep/dataset-00000.parquet"
+DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}".replace("-", "_")
+DATASET_WRANGLED_NAME = f"wrangled_{DATASET_NAME}"
+DATASET_WRANGLED_ID = "{{
task_instance.xcom_pull('create_wrangled_dataset')['id'] }}"
+
+FLOW_ID = "{{ task_instance.xcom_pull('create_flow')['id'] }}"
+FLOW_COPY_ID = "{{ task_instance.xcom_pull('copy_flow')['id'] }}"
+RECIPE_NAME = DATASET_WRANGLED_NAME
+WRITE_SETTINGS = {
+ "writesettings": [
+ {
+ "path": GCS_BUCKET_PATH + f"adhoc_{RECIPE_NAME}.csv",
+ "action": "create",
+ "format": "csv",
+ },
+ ],
+}
+
+log = logging.getLogger(__name__)
-with DAG(
+
+with models.DAG(
DAG_ID,
schedule="@once",
start_date=datetime(2021, 1, 1), # Override to match your needs
@@ -71,42 +92,128 @@ with DAG(
project_id=GCP_PROJECT_ID,
)
+ @task
+ def create_connection(**kwargs) -> None:
+ connection = Connection(
+ conn_id=CONNECTION_ID,
+ description="Example Dataprep connection",
+ conn_type="dataprep",
+ extra={"token": DATAPREP_TOKEN},
+ )
+ session: Session = Session()
+ if session.query(Connection).filter(Connection.conn_id == CONNECTION_ID).first():
+ log.warning("Connection %s already exists", CONNECTION_ID)
+ return None
+ session.add(connection)
+ session.commit()
+
+ create_connection_task = create_connection()
+
+ @task
+ def create_imported_dataset():
+ hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+ response = hook.create_imported_dataset(
+ body_request={
+ "uri": DATASET_URI,
+ "name": DATASET_NAME,
+ }
+ )
+ return response
+
+ create_imported_dataset_task = create_imported_dataset()
+
+ @task
+ def create_flow():
+ hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+ response = hook.create_flow(
+ body_request={
+ "name": f"test_flow_{DAG_ID}_{ENV_ID}",
+ "description": "Test flow",
+ }
+ )
+ return response
+
+ create_flow_task = create_flow()
+
+ @task
+ def create_wrangled_dataset(flow, imported_dataset):
+ hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+ response = hook.create_wrangled_dataset(
+ body_request={
+ "importedDataset": {"id": imported_dataset["id"]},
+ "flow": {"id": flow["id"]},
+ "name": DATASET_WRANGLED_NAME,
+ }
+ )
+ return response
+
+ create_wrangled_dataset_task = create_wrangled_dataset(create_flow_task, create_imported_dataset_task)
+
+ @task
+ def create_output(wrangled_dataset):
+ hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+ response = hook.create_output_object(
+ body_request={
+ "execution": "dataflow",
+ "profiler": False,
+ "flowNodeId": wrangled_dataset["id"],
+ }
+ )
+ return response
+
+ create_output_task = create_output(create_wrangled_dataset_task)
+
+ @task
+ def create_write_settings(output):
+ hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+ response = hook.create_write_settings(
+ body_request={
+ "path": GCS_BUCKET_PATH + f"adhoc_{RECIPE_NAME}.csv",
+ "action": "create",
+ "format": "csv",
+ "outputObjectId": output["id"],
+ }
+ )
+ return response
+
+ create_write_settings_task = create_write_settings(create_output_task)
+
+ # [START how_to_dataprep_copy_flow_operator]
+ copy_task = DataprepCopyFlowOperator(
+ task_id="copy_flow",
+ dataprep_conn_id=CONNECTION_ID,
+ project_id=GCP_PROJECT_ID,
+ flow_id=FLOW_ID,
+ name=f"copy_{DATASET_NAME}",
+ )
+ # [END how_to_dataprep_copy_flow_operator]
+
# [START how_to_dataprep_run_job_group_operator]
run_job_group_task = DataprepRunJobGroupOperator(
task_id="run_job_group",
+ dataprep_conn_id=CONNECTION_ID,
project_id=GCP_PROJECT_ID,
body_request={
- "wrangledDataset": {"id": RECIPE_ID},
+ "wrangledDataset": {"id": DATASET_WRANGLED_ID},
"overrides": WRITE_SETTINGS,
},
)
# [END how_to_dataprep_run_job_group_operator]
- # [START how_to_dataprep_copy_flow_operator]
- copy_task = DataprepCopyFlowOperator(
- task_id="copy_flow",
- project_id=GCP_PROJECT_ID,
- flow_id=FLOW_ID,
- name=f"dataprep_example_flow_{DAG_ID}_{ENV_ID}",
- )
- # [END how_to_dataprep_copy_flow_operator]
-
# [START how_to_dataprep_dataprep_run_flow_operator]
run_flow_task = DataprepRunFlowOperator(
task_id="run_flow",
+ dataprep_conn_id=CONNECTION_ID,
project_id=GCP_PROJECT_ID,
- flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}",
- body_request={
- "overrides": {
- RECIPE_NAME: WRITE_SETTINGS,
- },
- },
+ flow_id=FLOW_COPY_ID,
+ body_request={},
)
# [END how_to_dataprep_dataprep_run_flow_operator]
# [START how_to_dataprep_get_job_group_operator]
get_job_group_task = DataprepGetJobGroupOperator(
task_id="get_job_group",
+ dataprep_conn_id=CONNECTION_ID,
project_id=GCP_PROJECT_ID,
job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id']
}}",
embed="",
@@ -117,6 +224,7 @@ with DAG(
# [START how_to_dataprep_get_jobs_for_job_group_operator]
get_jobs_for_job_group_task = DataprepGetJobsForJobGroupOperator(
task_id="get_jobs_for_job_group",
+ dataprep_conn_id=CONNECTION_ID,
job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id']
}}",
)
# [END how_to_dataprep_get_jobs_for_job_group_operator]
@@ -124,6 +232,7 @@ with DAG(
# [START how_to_dataprep_job_group_finished_sensor]
check_flow_status_sensor = DataprepJobGroupIsFinishedSensor(
task_id="check_flow_status",
+ dataprep_conn_id=CONNECTION_ID,
job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id']
}}",
)
# [END how_to_dataprep_job_group_finished_sensor]
@@ -131,6 +240,7 @@ with DAG(
# [START how_to_dataprep_job_group_finished_sensor]
check_job_group_status_sensor = DataprepJobGroupIsFinishedSensor(
task_id="check_job_group_status",
+ dataprep_conn_id=CONNECTION_ID,
job_group_id="{{ task_instance.xcom_pull('run_job_group')['id'] }}",
)
# [END how_to_dataprep_job_group_finished_sensor]
@@ -138,29 +248,55 @@ with DAG(
# [START how_to_dataprep_delete_flow_operator]
delete_flow_task = DataprepDeleteFlowOperator(
task_id="delete_flow",
+ dataprep_conn_id=CONNECTION_ID,
flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}",
)
# [END how_to_dataprep_delete_flow_operator]
delete_flow_task.trigger_rule = TriggerRule.ALL_DONE
+ delete_flow_task_original = DataprepDeleteFlowOperator(
+ task_id="delete_flow_original",
+ dataprep_conn_id=CONNECTION_ID,
+ flow_id="{{ task_instance.xcom_pull('create_flow')['id'] }}",
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ @task(trigger_rule=TriggerRule.ALL_DONE)
+ def delete_dataset(dataset):
+ hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+ hook.delete_imported_dataset(dataset_id=dataset["id"])
+
+ delete_dataset_task = delete_dataset(create_imported_dataset_task)
+
delete_bucket_task = GCSDeleteBucketOperator(
task_id="delete_bucket",
bucket_name=GCS_BUCKET_NAME,
trigger_rule=TriggerRule.ALL_DONE,
)
- (
+ delete_connection = BashOperator(
+ task_id="delete_connection",
+ bash_command=f"airflow connections delete {CONNECTION_ID}",
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ chain(
# TEST SETUP
- create_bucket_task
- >> copy_task
+ create_bucket_task,
+ create_connection_task,
+ [create_imported_dataset_task, create_flow_task],
+ create_wrangled_dataset_task,
+ create_output_task,
+ create_write_settings_task,
# TEST BODY
- >> [run_job_group_task, run_flow_task]
- >> get_job_group_task
- >> get_jobs_for_job_group_task
+ copy_task,
+ [run_job_group_task, run_flow_task],
+ [get_job_group_task, get_jobs_for_job_group_task],
+ [check_flow_status_sensor, check_job_group_status_sensor],
# TEST TEARDOWN
- >> check_flow_status_sensor
- >> [delete_flow_task, check_job_group_status_sensor]
- >> delete_bucket_task
+ delete_dataset_task,
+ [delete_flow_task, delete_flow_task_original],
+ [delete_bucket_task, delete_connection],
)
from tests.system.utils.watcher import watcher
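Note: the message is truncated here at the watcher import. Airflow system tests conventionally end by wiring every task into the watcher so that teardown failures still fail the run. A sketch of that conventional footer, which this diff does not show, assuming the context manager binds the DAG as `dag`:

    # Conventional system-test footer (a sketch; the actual file may differ):
    list(dag.tasks) >> watcher()

    from tests.system.utils import get_test_run  # noqa: E402

    # Needed to run the example DAG with pytest (see tests/system/README.md).
    test_run = get_test_run(dag)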