This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new acff4c79dc Make Dataprep system test self-sufficient (#34880)
acff4c79dc is described below

commit acff4c79dcbb7926923d86adb4c5115e02cf28e6
Author: max <[email protected]>
AuthorDate: Thu Oct 26 21:18:01 2023 +0200

    Make Dataprep system test self-sufficient (#34880)
---
 airflow/providers/google/cloud/hooks/dataprep.py   | 102 +++++-
 .../providers/google/cloud/operators/dataprep.py   |   4 +-
 .../providers/google/cloud/hooks/test_dataprep.py  | 383 ++++++++++++++++++++-
 .../google/cloud/operators/test_dataprep.py        |   2 +-
 .../google/cloud/dataprep/example_dataprep.py      | 218 +++++++++---
 5 files changed, 652 insertions(+), 57 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/dataprep.py 
b/airflow/providers/google/cloud/hooks/dataprep.py
index c01a48d5ae..9e006fa99f 100644
--- a/airflow/providers/google/cloud/hooks/dataprep.py
+++ b/airflow/providers/google/cloud/hooks/dataprep.py
@@ -72,9 +72,10 @@ class GoogleDataprepHook(BaseHook):
     conn_type = "dataprep"
     hook_name = "Google Dataprep"
 
-    def __init__(self, dataprep_conn_id: str = default_conn_name) -> None:
+    def __init__(self, dataprep_conn_id: str = default_conn_name, api_version: 
str = "v4") -> None:
         super().__init__()
         self.dataprep_conn_id = dataprep_conn_id
+        self.api_version = api_version
         conn = self.get_connection(self.dataprep_conn_id)
         extras = conn.extra_dejson
         self._token = _get_field(extras, "token")
@@ -95,7 +96,7 @@ class GoogleDataprepHook(BaseHook):
 
         :param job_id: The ID of the job that will be fetched
         """
-        endpoint_path = f"v4/jobGroups/{job_id}/jobs"
+        endpoint_path = f"{self.api_version}/jobGroups/{job_id}/jobs"
         url: str = urljoin(self._base_url, endpoint_path)
         response = requests.get(url, headers=self._headers)
         self._raise_for_status(response)
@@ -113,7 +114,7 @@ class GoogleDataprepHook(BaseHook):
         :param include_deleted: if set to "true", will include deleted objects
         """
         params: dict[str, Any] = {"embed": embed, "includeDeleted": 
include_deleted}
-        endpoint_path = f"v4/jobGroups/{job_group_id}"
+        endpoint_path = f"{self.api_version}/jobGroups/{job_group_id}"
         url: str = urljoin(self._base_url, endpoint_path)
         response = requests.get(url, headers=self._headers, params=params)
         self._raise_for_status(response)
@@ -131,12 +132,26 @@ class GoogleDataprepHook(BaseHook):
 
         :param body_request: The identifier for the recipe you would like to 
run.
         """
-        endpoint_path = "v4/jobGroups"
+        endpoint_path = f"{self.api_version}/jobGroups"
         url: str = urljoin(self._base_url, endpoint_path)
         response = requests.post(url, headers=self._headers, 
data=json.dumps(body_request))
         self._raise_for_status(response)
         return response.json()
 
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, 
max=10))
+    def create_flow(self, *, body_request: dict) -> dict:
+        """
+        Creates flow.
+
+        :param body_request: Body of the POST request to be sent.
+            For more details check 
https://clouddataprep.com/documentation/api#operation/createFlow
+        """
+        endpoint = f"/{self.api_version}/flows"
+        url: str = urljoin(self._base_url, endpoint)
+        response = requests.post(url, headers=self._headers, 
data=json.dumps(body_request))
+        self._raise_for_status(response)
+        return response.json()
+
     @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, 
max=10))
     def copy_flow(
         self, *, flow_id: int, name: str = "", description: str = "", 
copy_datasources: bool = False
@@ -149,7 +164,7 @@ class GoogleDataprepHook(BaseHook):
         :param description: Description of the copy of the flow
         :param copy_datasources: Bool value to define should copies of data 
inputs be made or not.
         """
-        endpoint_path = f"v4/flows/{flow_id}/copy"
+        endpoint_path = f"{self.api_version}/flows/{flow_id}/copy"
         url: str = urljoin(self._base_url, endpoint_path)
         body_request = {
             "name": name,
@@ -167,7 +182,7 @@ class GoogleDataprepHook(BaseHook):
 
         :param flow_id: ID of the flow to be copied
         """
-        endpoint_path = f"v4/flows/{flow_id}"
+        endpoint_path = f"{self.api_version}/flows/{flow_id}"
         url: str = urljoin(self._base_url, endpoint_path)
         response = requests.delete(url, headers=self._headers)
         self._raise_for_status(response)
@@ -180,7 +195,7 @@ class GoogleDataprepHook(BaseHook):
         :param flow_id: ID of the flow to be copied
         :param body_request: Body of the POST request to be sent.
         """
-        endpoint = f"v4/flows/{flow_id}/run"
+        endpoint = f"{self.api_version}/flows/{flow_id}/run"
         url: str = urljoin(self._base_url, endpoint)
         response = requests.post(url, headers=self._headers, 
data=json.dumps(body_request))
         self._raise_for_status(response)
@@ -193,7 +208,7 @@ class GoogleDataprepHook(BaseHook):
 
         :param job_group_id: ID of the job group to check
         """
-        endpoint = f"/v4/jobGroups/{job_group_id}/status"
+        endpoint = f"/{self.api_version}/jobGroups/{job_group_id}/status"
         url: str = urljoin(self._base_url, endpoint)
         response = requests.get(url, headers=self._headers)
         self._raise_for_status(response)
@@ -205,3 +220,74 @@ class GoogleDataprepHook(BaseHook):
         except HTTPError:
             self.log.error(response.json().get("exception"))
             raise
+
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, 
max=10))
+    def create_imported_dataset(self, *, body_request: dict) -> dict:
+        """
+        Creates imported dataset.
+
+        :param body_request: Body of the POST request to be sent.
+            For more details check 
https://clouddataprep.com/documentation/api#operation/createImportedDataset
+        """
+        endpoint = f"/{self.api_version}/importedDatasets"
+        url: str = urljoin(self._base_url, endpoint)
+        response = requests.post(url, headers=self._headers, 
data=json.dumps(body_request))
+        self._raise_for_status(response)
+        return response.json()
+
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, 
max=10))
+    def create_wrangled_dataset(self, *, body_request: dict) -> dict:
+        """
+        Creates wrangled dataset.
+
+        :param body_request: Body of the POST request to be sent.
+            For more details check
+            
https://clouddataprep.com/documentation/api#operation/createWrangledDataset
+        """
+        endpoint = f"/{self.api_version}/wrangledDatasets"
+        url: str = urljoin(self._base_url, endpoint)
+        response = requests.post(url, headers=self._headers, 
data=json.dumps(body_request))
+        self._raise_for_status(response)
+        return response.json()
+
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, 
max=10))
+    def create_output_object(self, *, body_request: dict) -> dict:
+        """
+        Creates output.
+
+        :param body_request: Body of the POST request to be sent.
+            For more details check
+            
https://clouddataprep.com/documentation/api#operation/createOutputObject
+        """
+        endpoint = f"/{self.api_version}/outputObjects"
+        url: str = urljoin(self._base_url, endpoint)
+        response = requests.post(url, headers=self._headers, 
data=json.dumps(body_request))
+        self._raise_for_status(response)
+        return response.json()
+
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, 
max=10))
+    def create_write_settings(self, *, body_request: dict) -> dict:
+        """
+        Creates write settings.
+
+        :param body_request: Body of the POST request to be sent.
+            For more details check
+            https://clouddataprep.com/documentation/api#tag/createWriteSetting
+        """
+        endpoint = f"/{self.api_version}/writeSettings"
+        url: str = urljoin(self._base_url, endpoint)
+        response = requests.post(url, headers=self._headers, 
data=json.dumps(body_request))
+        self._raise_for_status(response)
+        return response.json()
+
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, 
max=10))
+    def delete_imported_dataset(self, *, dataset_id: int) -> None:
+        """
+        Deletes imported dataset.
+
+        :param dataset_id: ID of the imported dataset for removal.
+        """
+        endpoint = f"/{self.api_version}/importedDatasets/{dataset_id}"
+        url: str = urljoin(self._base_url, endpoint)
+        response = requests.delete(url, headers=self._headers)
+        self._raise_for_status(response)
diff --git a/airflow/providers/google/cloud/operators/dataprep.py 
b/airflow/providers/google/cloud/operators/dataprep.py
index 7f19f6993b..59710293e6 100644
--- a/airflow/providers/google/cloud/operators/dataprep.py
+++ b/airflow/providers/google/cloud/operators/dataprep.py
@@ -51,13 +51,13 @@ class 
DataprepGetJobsForJobGroupOperator(GoogleCloudBaseOperator):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        self.dataprep_conn_id = (dataprep_conn_id,)
+        self.dataprep_conn_id = dataprep_conn_id
         self.job_group_id = job_group_id
 
     def execute(self, context: Context) -> dict:
         self.log.info("Fetching data for job with id: %d ...", 
self.job_group_id)
         hook = GoogleDataprepHook(
-            dataprep_conn_id="dataprep_default",
+            dataprep_conn_id=self.dataprep_conn_id,
         )
         response = hook.get_jobs_for_job_group(job_id=int(self.job_group_id))
         return response
diff --git a/tests/providers/google/cloud/hooks/test_dataprep.py 
b/tests/providers/google/cloud/hooks/test_dataprep.py
index e29d2be3dc..a0cef77ae9 100644
--- a/tests/providers/google/cloud/hooks/test_dataprep.py
+++ b/tests/providers/google/cloud/hooks/test_dataprep.py
@@ -35,7 +35,12 @@ EXTRA = {"token": TOKEN}
 EMBED = ""
 INCLUDE_DELETED = False
 DATA = {"wrangledDataset": {"id": RECIPE_ID}}
-URL = "https://api.clouddataprep.com/v4/jobGroups"
+URL_BASE = "https://api.clouddataprep.com"
+URL_JOB_GROUPS = URL_BASE + "/v4/jobGroups"
+URL_IMPORTED_DATASETS = URL_BASE + "/v4/importedDatasets"
+URL_WRANGLED_DATASETS = URL_BASE + "/v4/wrangledDatasets"
+URL_OUTPUT_OBJECTS = URL_BASE + "/v4/outputObjects"
+URL_WRITE_SETTINGS = URL_BASE + "/v4/writeSettings"
 
 
 class TestGoogleDataprepHook:
@@ -43,12 +48,41 @@ class TestGoogleDataprepHook:
         with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn:
             conn.return_value.extra_dejson = EXTRA
             self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default")
+        self._imported_dataset_id = 12345
+        self._create_imported_dataset_body_request = {
+            "uri": "gs://test/uri",
+            "name": "test_name",
+        }
+        self._create_wrangled_dataset_body_request = {
+            "importedDataset": {"id": "test_dataset_id"},
+            "flow": {"id": "test_flow_id"},
+            "name": "test_dataset_name",
+        }
+        self._create_output_object_body_request = {
+            "execution": "dataflow",
+            "profiler": False,
+            "flowNodeId": "test_flow_node_id",
+        }
+        self._create_write_settings_body_request = {
+            "path": "gs://test/path",
+            "action": "create",
+            "format": "csv",
+            "outputObjectId": "test_output_object_id",
+        }
+        self._expected_create_imported_dataset_hook_data = json.dumps(
+            self._create_imported_dataset_body_request
+        )
+        self._expected_create_wrangled_dataset_hook_data = json.dumps(
+            self._create_wrangled_dataset_body_request
+        )
+        self._expected_create_output_object_hook_data = 
json.dumps(self._create_output_object_body_request)
+        self._expected_create_write_settings_hook_data = 
json.dumps(self._create_write_settings_body_request)
 
     @patch("airflow.providers.google.cloud.hooks.dataprep.requests.get")
     def test_get_jobs_for_job_group_should_be_called_once_with_params(self, 
mock_get_request):
         self.hook.get_jobs_for_job_group(JOB_ID)
         mock_get_request.assert_called_once_with(
-            f"{URL}/{JOB_ID}/jobs",
+            f"{URL_JOB_GROUPS}/{JOB_ID}/jobs",
             headers={"Content-Type": "application/json", "Authorization": 
f"Bearer {TOKEN}"},
         )
 
@@ -93,7 +127,7 @@ class TestGoogleDataprepHook:
     def test_get_job_group_should_be_called_once_with_params(self, 
mock_get_request):
         self.hook.get_job_group(JOB_ID, EMBED, INCLUDE_DELETED)
         mock_get_request.assert_called_once_with(
-            f"{URL}/{JOB_ID}",
+            f"{URL_JOB_GROUPS}/{JOB_ID}",
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {TOKEN}",
@@ -148,7 +182,7 @@ class TestGoogleDataprepHook:
     def test_run_job_group_should_be_called_once_with_params(self, 
mock_get_request):
         self.hook.run_job_group(body_request=DATA)
         mock_get_request.assert_called_once_with(
-            f"{URL}",
+            f"{URL_JOB_GROUPS}",
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {TOKEN}",
@@ -203,7 +237,7 @@ class TestGoogleDataprepHook:
     def test_get_job_group_status_should_be_called_once_with_params(self, 
mock_get_request):
         self.hook.get_job_group_status(job_group_id=JOB_ID)
         mock_get_request.assert_called_once_with(
-            f"{URL}/{JOB_ID}/status",
+            f"{URL_JOB_GROUPS}/{JOB_ID}/status",
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {TOKEN}",
@@ -266,12 +300,290 @@ class TestGoogleDataprepHook:
             assert hook._token == "abc"
             assert hook._base_url == "abc"
 
+    @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
+    def test_create_imported_dataset_should_be_called_once_with_params(self, 
mock_post_request):
+        
self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request)
+        mock_post_request.assert_called_once_with(
+            URL_IMPORTED_DATASETS,
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {TOKEN}",
+            },
+            data=self._expected_create_imported_dataset_hook_data,
+        )
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[HTTPError(), mock.MagicMock()],
+    )
+    def test_create_imported_dataset_should_pass_after_retry(self, 
mock_post_request):
+        
self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request)
+        assert mock_post_request.call_count == 2
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[mock.MagicMock(), HTTPError()],
+    )
+    def test_create_imported_dataset_retry_after_success(self, 
mock_post_request):
+        self.hook.create_imported_dataset.retry.sleep = mock.Mock()
+        
self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request)
+        assert mock_post_request.call_count == 1
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            mock.MagicMock(),
+        ],
+    )
+    def test_create_imported_dataset_four_errors(self, mock_post_request):
+        self.hook.create_imported_dataset.retry.sleep = mock.Mock()
+        
self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request)
+        assert mock_post_request.call_count == 5
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), 
HTTPError()],
+    )
+    def test_create_imported_dataset_five_calls(self, mock_post_request):
+        with pytest.raises(RetryError) as ctx:
+            self.hook.create_imported_dataset.retry.sleep = mock.Mock()
+            
self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request)
+        assert "HTTPError" in str(ctx.value)
+        assert mock_post_request.call_count == 5
+
+    @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
+    def test_create_wrangled_dataset_should_be_called_once_with_params(self, 
mock_post_request):
+        
self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request)
+        mock_post_request.assert_called_once_with(
+            URL_WRANGLED_DATASETS,
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {TOKEN}",
+            },
+            data=self._expected_create_wrangled_dataset_hook_data,
+        )
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[HTTPError(), mock.MagicMock()],
+    )
+    def test_create_wrangled_dataset_should_pass_after_retry(self, 
mock_post_request):
+        
self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request)
+        assert mock_post_request.call_count == 2
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[mock.MagicMock(), HTTPError()],
+    )
+    def test_create_wrangled_dataset_retry_after_success(self, 
mock_post_request):
+        self.hook.create_wrangled_dataset.retry.sleep = mock.Mock()
+        
self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request)
+        assert mock_post_request.call_count == 1
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            mock.MagicMock(),
+        ],
+    )
+    def test_create_wrangled_dataset_four_errors(self, mock_post_request):
+        self.hook.create_wrangled_dataset.retry.sleep = mock.Mock()
+        
self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request)
+        assert mock_post_request.call_count == 5
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), 
HTTPError()],
+    )
+    def test_create_wrangled_dataset_five_calls(self, mock_post_request):
+        with pytest.raises(RetryError) as ctx:
+            self.hook.create_wrangled_dataset.retry.sleep = mock.Mock()
+            
self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request)
+        assert "HTTPError" in str(ctx.value)
+        assert mock_post_request.call_count == 5
+
+    @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
+    def test_create_output_object_should_be_called_once_with_params(self, 
mock_post_request):
+        
self.hook.create_output_object(body_request=self._create_output_object_body_request)
+        mock_post_request.assert_called_once_with(
+            URL_OUTPUT_OBJECTS,
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {TOKEN}",
+            },
+            data=self._expected_create_output_object_hook_data,
+        )
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[HTTPError(), mock.MagicMock()],
+    )
+    def test_create_output_objects_should_pass_after_retry(self, 
mock_post_request):
+        
self.hook.create_output_object(body_request=self._create_output_object_body_request)
+        assert mock_post_request.call_count == 2
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[mock.MagicMock(), HTTPError()],
+    )
+    def test_create_output_objects_retry_after_success(self, 
mock_post_request):
+        self.hook.create_output_object.retry.sleep = mock.Mock()
+        
self.hook.create_output_object(body_request=self._create_output_object_body_request)
+        assert mock_post_request.call_count == 1
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            mock.MagicMock(),
+        ],
+    )
+    def test_create_output_objects_four_errors(self, mock_post_request):
+        self.hook.create_output_object.retry.sleep = mock.Mock()
+        
self.hook.create_output_object(body_request=self._create_output_object_body_request)
+        assert mock_post_request.call_count == 5
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), 
HTTPError()],
+    )
+    def test_create_output_objects_five_calls(self, mock_post_request):
+        with pytest.raises(RetryError) as ctx:
+            self.hook.create_output_object.retry.sleep = mock.Mock()
+            
self.hook.create_output_object(body_request=self._create_output_object_body_request)
+        assert "HTTPError" in str(ctx.value)
+        assert mock_post_request.call_count == 5
+
+    @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
+    def test_create_write_settings_should_be_called_once_with_params(self, 
mock_post_request):
+        
self.hook.create_write_settings(body_request=self._create_write_settings_body_request)
+        mock_post_request.assert_called_once_with(
+            URL_WRITE_SETTINGS,
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {TOKEN}",
+            },
+            data=self._expected_create_write_settings_hook_data,
+        )
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[HTTPError(), mock.MagicMock()],
+    )
+    def test_create_write_settings_should_pass_after_retry(self, 
mock_post_request):
+        
self.hook.create_write_settings(body_request=self._create_write_settings_body_request)
+        assert mock_post_request.call_count == 2
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[mock.MagicMock(), HTTPError()],
+    )
+    def test_create_write_settings_retry_after_success(self, 
mock_post_request):
+        self.hook.create_write_settings.retry.sleep = mock.Mock()
+        
self.hook.create_write_settings(body_request=self._create_write_settings_body_request)
+        assert mock_post_request.call_count == 1
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            mock.MagicMock(),
+        ],
+    )
+    def test_create_write_settings_four_errors(self, mock_post_request):
+        self.hook.create_write_settings.retry.sleep = mock.Mock()
+        
self.hook.create_write_settings(body_request=self._create_write_settings_body_request)
+        assert mock_post_request.call_count == 5
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), 
HTTPError()],
+    )
+    def test_create_write_settings_five_calls(self, mock_post_request):
+        with pytest.raises(RetryError) as ctx:
+            self.hook.create_write_settings.retry.sleep = mock.Mock()
+            
self.hook.create_write_settings(body_request=self._create_write_settings_body_request)
+        assert "HTTPError" in str(ctx.value)
+        assert mock_post_request.call_count == 5
+
+    @patch("airflow.providers.google.cloud.hooks.dataprep.requests.delete")
+    def test_delete_imported_dataset_should_be_called_once_with_params(self, 
mock_delete_request):
+        self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id)
+        mock_delete_request.assert_called_once_with(
+            f"{URL_IMPORTED_DATASETS}/{self._imported_dataset_id}",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {TOKEN}",
+            },
+        )
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
+        side_effect=[HTTPError(), mock.MagicMock()],
+    )
+    def test_delete_imported_dataset_should_pass_after_retry(self, 
mock_delete_request):
+        self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id)
+        assert mock_delete_request.call_count == 2
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
+        side_effect=[mock.MagicMock(), HTTPError()],
+    )
+    def test_delete_imported_dataset_retry_after_success(self, 
mock_delete_request):
+        self.hook.delete_imported_dataset.retry.sleep = mock.Mock()
+        self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id)
+        assert mock_delete_request.call_count == 1
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
+        side_effect=[
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            mock.MagicMock(),
+        ],
+    )
+    def test_delete_imported_dataset_four_errors(self, mock_delete_request):
+        self.hook.delete_imported_dataset.retry.sleep = mock.Mock()
+        self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id)
+        assert mock_delete_request.call_count == 5
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
+        side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), 
HTTPError()],
+    )
+    def test_delete_imported_dataset_five_calls(self, mock_delete_request):
+        with pytest.raises(RetryError) as ctx:
+            self.hook.delete_imported_dataset.retry.sleep = mock.Mock()
+            
self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id)
+        assert "HTTPError" in str(ctx.value)
+        assert mock_delete_request.call_count == 5
+
 
 class TestGoogleDataprepFlowPathHooks:
    _url = "https://api.clouddataprep.com/v4/flows"
 
     def setup_method(self):
         self._flow_id = 1234567
+        self._create_flow_body_request = {
+            "name": "test_name",
+            "description": "Test description",
+        }
         self._expected_copy_flow_hook_data = json.dumps(
             {
                 "name": "",
@@ -280,10 +592,71 @@ class TestGoogleDataprepFlowPathHooks:
             }
         )
         self._expected_run_flow_hook_data = json.dumps({})
+        self._expected_create_flow_hook_data = json.dumps(
+            {
+                "name": "test_name",
+                "description": "Test description",
+            }
+        )
         with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn:
             conn.return_value.extra_dejson = EXTRA
             self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default")
 
+    @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
+    def test_create_flow_should_be_called_once_with_params(self, 
mock_post_request):
+        self.hook.create_flow(body_request=self._create_flow_body_request)
+        mock_post_request.assert_called_once_with(
+            self._url,
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {TOKEN}",
+            },
+            data=self._expected_create_flow_hook_data,
+        )
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[HTTPError(), mock.MagicMock()],
+    )
+    def test_create_flow_should_pass_after_retry(self, mock_post_request):
+        self.hook.create_flow(body_request=self._create_flow_body_request)
+        assert mock_post_request.call_count == 2
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[mock.MagicMock(), HTTPError()],
+    )
+    def test_create_flow_should_not_retry_after_success(self, 
mock_post_request):
+        self.hook.create_flow.retry.sleep = mock.Mock()
+        self.hook.create_flow(body_request=self._create_flow_body_request)
+        assert mock_post_request.call_count == 1
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            HTTPError(),
+            mock.MagicMock(),
+        ],
+    )
+    def test_create_flow_should_retry_after_four_errors(self, 
mock_post_request):
+        self.hook.create_flow.retry.sleep = mock.Mock()
+        self.hook.create_flow(body_request=self._create_flow_body_request)
+        assert mock_post_request.call_count == 5
+
+    @patch(
+        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
+        side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), 
HTTPError()],
+    )
+    def test_create_flow_raise_error_after_five_calls(self, mock_post_request):
+        with pytest.raises(RetryError) as ctx:
+            self.hook.create_flow.retry.sleep = mock.Mock()
+            self.hook.create_flow(body_request=self._create_flow_body_request)
+        assert "HTTPError" in str(ctx.value)
+        assert mock_post_request.call_count == 5
+
     @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
     def test_copy_flow_should_be_called_once_with_params(self, 
mock_get_request):
         self.hook.copy_flow(
diff --git a/tests/providers/google/cloud/operators/test_dataprep.py 
b/tests/providers/google/cloud/operators/test_dataprep.py
index 08237d0d51..d5800716a8 100644
--- a/tests/providers/google/cloud/operators/test_dataprep.py
+++ b/tests/providers/google/cloud/operators/test_dataprep.py
@@ -66,7 +66,7 @@ class TestDataprepGetJobsForJobGroupOperator:
             dataprep_conn_id=DATAPREP_CONN_ID, job_group_id=JOB_ID, 
task_id=TASK_ID
         )
         op.execute(context={})
-        hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default")
+        hook_mock.assert_called_once_with(dataprep_conn_id=DATAPREP_CONN_ID)
         
hook_mock.return_value.get_jobs_for_job_group.assert_called_once_with(job_id=JOB_ID)
 
 
diff --git a/tests/system/providers/google/cloud/dataprep/example_dataprep.py 
b/tests/system/providers/google/cloud/dataprep/example_dataprep.py
index c07cd5a456..126e3f4c9b 100644
--- a/tests/system/providers/google/cloud/dataprep/example_dataprep.py
+++ b/tests/system/providers/google/cloud/dataprep/example_dataprep.py
@@ -16,13 +16,25 @@
 # under the License.
 """
 Example Airflow DAG that shows how to use Google Dataprep.
+
+This DAG relies on the following OS environment variables
+
+* SYSTEM_TESTS_DATAPREP_TOKEN - Dataprep API access token.
+  For generating it please use instruction
+  
https://docs.trifacta.com/display/DP/Manage+API+Access+Tokens#:~:text=Enable%20individual%20access-,Generate%20New%20Token,-Via%20UI.
 """
 from __future__ import annotations
 
+import logging
 import os
 from datetime import datetime
 
-from airflow.models.dag import DAG
+from airflow import models
+from airflow.decorators import task
+from airflow.models import Connection
+from airflow.models.baseoperator import chain
+from airflow.operators.bash import BashOperator
+from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook
 from airflow.providers.google.cloud.operators.dataprep import (
     DataprepCopyFlowOperator,
     DataprepDeleteFlowOperator,
@@ -33,31 +45,40 @@ from airflow.providers.google.cloud.operators.dataprep 
import (
 )
 from airflow.providers.google.cloud.operators.gcs import 
GCSCreateBucketOperator, GCSDeleteBucketOperator
 from airflow.providers.google.cloud.sensors.dataprep import 
DataprepJobGroupIsFinishedSensor
+from airflow.settings import Session
 from airflow.utils.trigger_rule import TriggerRule
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
 DAG_ID = "example_dataprep"
 
+CONNECTION_ID = f"connection_{DAG_ID}_{ENV_ID}".replace("-", "_")
+DATAPREP_TOKEN = os.environ.get("SYSTEM_TESTS_DATAPREP_TOKEN", "")
 GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
 GCS_BUCKET_NAME = f"dataprep-bucket-{DAG_ID}-{ENV_ID}"
 GCS_BUCKET_PATH = f"gs://{GCS_BUCKET_NAME}/task_results/"
 
-FLOW_ID = os.environ.get("FLOW_ID")
-RECIPE_ID = os.environ.get("RECIPE_ID")
-RECIPE_NAME = os.environ.get("RECIPE_NAME")
-WRITE_SETTINGS = (
-    {
-        "writesettings": [
-            {
-                "path": GCS_BUCKET_PATH,
-                "action": "create",
-                "format": "csv",
-            }
-        ],
-    },
-)
+DATASET_URI = 
"gs://airflow-system-tests-resources/dataprep/dataset-00000.parquet"
+DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}".replace("-", "_")
+DATASET_WRANGLED_NAME = f"wrangled_{DATASET_NAME}"
+DATASET_WRANGLED_ID = "{{ 
task_instance.xcom_pull('create_wrangled_dataset')['id'] }}"
+
+FLOW_ID = "{{ task_instance.xcom_pull('create_flow')['id'] }}"
+FLOW_COPY_ID = "{{ task_instance.xcom_pull('copy_flow')['id'] }}"
+RECIPE_NAME = DATASET_WRANGLED_NAME
+WRITE_SETTINGS = {
+    "writesettings": [
+        {
+            "path": GCS_BUCKET_PATH + f"adhoc_{RECIPE_NAME}.csv",
+            "action": "create",
+            "format": "csv",
+        },
+    ],
+}
+
+log = logging.getLogger(__name__)
 
-with DAG(
+
+with models.DAG(
     DAG_ID,
     schedule="@once",
     start_date=datetime(2021, 1, 1),  # Override to match your needs
@@ -71,42 +92,128 @@ with DAG(
         project_id=GCP_PROJECT_ID,
     )
 
+    @task
+    def create_connection(**kwargs) -> None:
+        connection = Connection(
+            conn_id=CONNECTION_ID,
+            description="Example Dataprep connection",
+            conn_type="dataprep",
+            extra={"token": DATAPREP_TOKEN},
+        )
+        session: Session = Session()
+        if session.query(Connection).filter(Connection.conn_id == 
CONNECTION_ID).first():
+            log.warning("Connection %s already exists", CONNECTION_ID)
+            return None
+        session.add(connection)
+        session.commit()
+
+    create_connection_task = create_connection()
+
+    @task
+    def create_imported_dataset():
+        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+        response = hook.create_imported_dataset(
+            body_request={
+                "uri": DATASET_URI,
+                "name": DATASET_NAME,
+            }
+        )
+        return response
+
+    create_imported_dataset_task = create_imported_dataset()
+
+    @task
+    def create_flow():
+        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+        response = hook.create_flow(
+            body_request={
+                "name": f"test_flow_{DAG_ID}_{ENV_ID}",
+                "description": "Test flow",
+            }
+        )
+        return response
+
+    create_flow_task = create_flow()
+
+    @task
+    def create_wrangled_dataset(flow, imported_dataset):
+        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+        response = hook.create_wrangled_dataset(
+            body_request={
+                "importedDataset": {"id": imported_dataset["id"]},
+                "flow": {"id": flow["id"]},
+                "name": DATASET_WRANGLED_NAME,
+            }
+        )
+        return response
+
+    create_wrangled_dataset_task = create_wrangled_dataset(create_flow_task, 
create_imported_dataset_task)
+
+    @task
+    def create_output(wrangled_dataset):
+        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+        response = hook.create_output_object(
+            body_request={
+                "execution": "dataflow",
+                "profiler": False,
+                "flowNodeId": wrangled_dataset["id"],
+            }
+        )
+        return response
+
+    create_output_task = create_output(create_wrangled_dataset_task)
+
+    @task
+    def create_write_settings(output):
+        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+        response = hook.create_write_settings(
+            body_request={
+                "path": GCS_BUCKET_PATH + f"adhoc_{RECIPE_NAME}.csv",
+                "action": "create",
+                "format": "csv",
+                "outputObjectId": output["id"],
+            }
+        )
+        return response
+
+    create_write_settings_task = create_write_settings(create_output_task)
+
+    # [START how_to_dataprep_copy_flow_operator]
+    copy_task = DataprepCopyFlowOperator(
+        task_id="copy_flow",
+        dataprep_conn_id=CONNECTION_ID,
+        project_id=GCP_PROJECT_ID,
+        flow_id=FLOW_ID,
+        name=f"copy_{DATASET_NAME}",
+    )
+    # [END how_to_dataprep_copy_flow_operator]
+
     # [START how_to_dataprep_run_job_group_operator]
     run_job_group_task = DataprepRunJobGroupOperator(
         task_id="run_job_group",
+        dataprep_conn_id=CONNECTION_ID,
         project_id=GCP_PROJECT_ID,
         body_request={
-            "wrangledDataset": {"id": RECIPE_ID},
+            "wrangledDataset": {"id": DATASET_WRANGLED_ID},
             "overrides": WRITE_SETTINGS,
         },
     )
     # [END how_to_dataprep_run_job_group_operator]
 
-    # [START how_to_dataprep_copy_flow_operator]
-    copy_task = DataprepCopyFlowOperator(
-        task_id="copy_flow",
-        project_id=GCP_PROJECT_ID,
-        flow_id=FLOW_ID,
-        name=f"dataprep_example_flow_{DAG_ID}_{ENV_ID}",
-    )
-    # [END how_to_dataprep_copy_flow_operator]
-
     # [START how_to_dataprep_dataprep_run_flow_operator]
     run_flow_task = DataprepRunFlowOperator(
         task_id="run_flow",
+        dataprep_conn_id=CONNECTION_ID,
         project_id=GCP_PROJECT_ID,
-        flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}",
-        body_request={
-            "overrides": {
-                RECIPE_NAME: WRITE_SETTINGS,
-            },
-        },
+        flow_id=FLOW_COPY_ID,
+        body_request={},
     )
     # [END how_to_dataprep_dataprep_run_flow_operator]
 
     # [START how_to_dataprep_get_job_group_operator]
     get_job_group_task = DataprepGetJobGroupOperator(
         task_id="get_job_group",
+        dataprep_conn_id=CONNECTION_ID,
         project_id=GCP_PROJECT_ID,
         job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] 
}}",
         embed="",
@@ -117,6 +224,7 @@ with DAG(
     # [START how_to_dataprep_get_jobs_for_job_group_operator]
     get_jobs_for_job_group_task = DataprepGetJobsForJobGroupOperator(
         task_id="get_jobs_for_job_group",
+        dataprep_conn_id=CONNECTION_ID,
         job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] 
}}",
     )
     # [END how_to_dataprep_get_jobs_for_job_group_operator]
@@ -124,6 +232,7 @@ with DAG(
     # [START how_to_dataprep_job_group_finished_sensor]
     check_flow_status_sensor = DataprepJobGroupIsFinishedSensor(
         task_id="check_flow_status",
+        dataprep_conn_id=CONNECTION_ID,
         job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] 
}}",
     )
     # [END how_to_dataprep_job_group_finished_sensor]
@@ -131,6 +240,7 @@ with DAG(
     # [START how_to_dataprep_job_group_finished_sensor]
     check_job_group_status_sensor = DataprepJobGroupIsFinishedSensor(
         task_id="check_job_group_status",
+        dataprep_conn_id=CONNECTION_ID,
         job_group_id="{{ task_instance.xcom_pull('run_job_group')['id'] }}",
     )
     # [END how_to_dataprep_job_group_finished_sensor]
@@ -138,29 +248,55 @@ with DAG(
     # [START how_to_dataprep_delete_flow_operator]
     delete_flow_task = DataprepDeleteFlowOperator(
         task_id="delete_flow",
+        dataprep_conn_id=CONNECTION_ID,
         flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}",
     )
     # [END how_to_dataprep_delete_flow_operator]
     delete_flow_task.trigger_rule = TriggerRule.ALL_DONE
 
+    delete_flow_task_original = DataprepDeleteFlowOperator(
+        task_id="delete_flow_original",
+        dataprep_conn_id=CONNECTION_ID,
+        flow_id="{{ task_instance.xcom_pull('create_flow')['id'] }}",
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    @task(trigger_rule=TriggerRule.ALL_DONE)
+    def delete_dataset(dataset):
+        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
+        hook.delete_imported_dataset(dataset_id=dataset["id"])
+
+    delete_dataset_task = delete_dataset(create_imported_dataset_task)
+
     delete_bucket_task = GCSDeleteBucketOperator(
         task_id="delete_bucket",
         bucket_name=GCS_BUCKET_NAME,
         trigger_rule=TriggerRule.ALL_DONE,
     )
 
-    (
+    delete_connection = BashOperator(
+        task_id="delete_connection",
+        bash_command=f"airflow connections delete {CONNECTION_ID}",
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    chain(
         # TEST SETUP
-        create_bucket_task
-        >> copy_task
+        create_bucket_task,
+        create_connection_task,
+        [create_imported_dataset_task, create_flow_task],
+        create_wrangled_dataset_task,
+        create_output_task,
+        create_write_settings_task,
         # TEST BODY
-        >> [run_job_group_task, run_flow_task]
-        >> get_job_group_task
-        >> get_jobs_for_job_group_task
+        copy_task,
+        [run_job_group_task, run_flow_task],
+        [get_job_group_task, get_jobs_for_job_group_task],
+        [check_flow_status_sensor, check_job_group_status_sensor],
         # TEST TEARDOWN
-        >> check_flow_status_sensor
-        >> [delete_flow_task, check_job_group_status_sensor]
-        >> delete_bucket_task
+        delete_dataset_task,
+        [delete_flow_task, delete_flow_task_original],
+        [delete_bucket_task, delete_connection],
     )
 
     from tests.system.utils.watcher import watcher


Reply via email to