This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 59e3198f7e Change dataprep system tests assets (#26488)
59e3198f7e is described below
commit 59e3198f7e5f3f4d6999d930fa505e6bd307f325
Author: George <[email protected]>
AuthorDate: Thu Nov 10 11:57:47 2022 +0100
Change dataprep system tests assets (#26488)
---
.../google/cloud/example_dags/example_dataprep.py | 79 -------
airflow/providers/google/cloud/hooks/dataprep.py | 82 ++++++-
airflow/providers/google/cloud/links/base.py | 2 +-
airflow/providers/google/cloud/links/dataprep.py | 63 +++++
.../providers/google/cloud/operators/dataprep.py | 199 +++++++++++++++-
airflow/providers/google/cloud/sensors/dataprep.py | 53 +++++
airflow/providers/google/provider.yaml | 5 +
.../operators/cloud/dataprep.rst | 77 +++++-
.../providers/google/cloud/hooks/test_dataprep.py | 262 ++++++++++++++++++++-
.../google/cloud/operators/test_dataprep.py | 217 ++++++++++++++++-
.../google/cloud/sensors/test_dataprep.py | 46 ++++
.../providers/google/cloud/dataprep/__init__.py | 16 ++
.../google/cloud/dataprep/example_dataprep.py | 175 ++++++++++++++
13 files changed, 1166 insertions(+), 110 deletions(-)
diff --git a/airflow/providers/google/cloud/example_dags/example_dataprep.py
b/airflow/providers/google/cloud/example_dags/example_dataprep.py
deleted file mode 100644
index 6e295fac08..0000000000
--- a/airflow/providers/google/cloud/example_dags/example_dataprep.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Example Airflow DAG that shows how to use Google Dataprep.
-"""
-from __future__ import annotations
-
-import os
-from datetime import datetime
-
-from airflow import models
-from airflow.providers.google.cloud.operators.dataprep import (
- DataprepGetJobGroupOperator,
- DataprepGetJobsForJobGroupOperator,
- DataprepRunJobGroupOperator,
-)
-
-DATAPREP_JOB_ID = int(os.environ.get("DATAPREP_JOB_ID", 12345677))
-DATAPREP_JOB_RECIPE_ID = int(os.environ.get("DATAPREP_JOB_RECIPE_ID",
12345677))
-DATAPREP_BUCKET = os.environ.get("DATAPREP_BUCKET", "gs://INVALID BUCKET
NAME/[email protected]")
-
-DATA = {
- "wrangledDataset": {"id": DATAPREP_JOB_RECIPE_ID},
- "overrides": {
- "execution": "dataflow",
- "profiler": False,
- "writesettings": [
- {
- "path": DATAPREP_BUCKET,
- "action": "create",
- "format": "csv",
- "compression": "none",
- "header": False,
- "asSingleFile": False,
- }
- ],
- },
-}
-
-
-with models.DAG(
- "example_dataprep",
- start_date=datetime(2021, 1, 1), # Override to match your needs
- catchup=False,
-) as dag:
- # [START how_to_dataprep_run_job_group_operator]
- run_job_group = DataprepRunJobGroupOperator(task_id="run_job_group",
body_request=DATA)
- # [END how_to_dataprep_run_job_group_operator]
-
- # [START how_to_dataprep_get_jobs_for_job_group_operator]
- get_jobs_for_job_group = DataprepGetJobsForJobGroupOperator(
- task_id="get_jobs_for_job_group", job_id=DATAPREP_JOB_ID
- )
- # [END how_to_dataprep_get_jobs_for_job_group_operator]
-
- # [START how_to_dataprep_get_job_group_operator]
- get_job_group = DataprepGetJobGroupOperator(
- task_id="get_job_group",
- job_group_id=DATAPREP_JOB_ID,
- embed="",
- include_deleted=False,
- )
- # [END how_to_dataprep_get_job_group_operator]
-
- run_job_group >> [get_jobs_for_job_group, get_job_group]
diff --git a/airflow/providers/google/cloud/hooks/dataprep.py
b/airflow/providers/google/cloud/hooks/dataprep.py
index 45261fed00..c7cbc3b551 100644
--- a/airflow/providers/google/cloud/hooks/dataprep.py
+++ b/airflow/providers/google/cloud/hooks/dataprep.py
@@ -19,8 +19,9 @@
from __future__ import annotations
import json
-import os
+from enum import Enum
from typing import Any
+from urllib.parse import urljoin
import requests
from requests import HTTPError
@@ -43,6 +44,17 @@ def _get_field(extras: dict, field_name: str):
return extras.get(prefixed_name) or None
+class JobGroupStatuses(str, Enum):
+ """Types of job group run statuses."""
+
+ CREATED = "Created"
+ UNDEFINED = "undefined"
+ IN_PROGRESS = "InProgress"
+ COMPLETE = "Complete"
+ FAILED = "Failed"
+ CANCELED = "Canceled"
+
+
class GoogleDataprepHook(BaseHook):
"""
Hook for connection with Dataprep API.
@@ -82,7 +94,7 @@ class GoogleDataprepHook(BaseHook):
:param job_id: The ID of the job that will be fetched
"""
endpoint_path = f"v4/jobGroups/{job_id}/jobs"
- url: str = os.path.join(self._base_url, endpoint_path)
+ url: str = urljoin(self._base_url, endpoint_path)
response = requests.get(url, headers=self._headers)
self._raise_for_status(response)
return response.json()
@@ -99,7 +111,7 @@ class GoogleDataprepHook(BaseHook):
"""
params: dict[str, Any] = {"embed": embed, "includeDeleted":
include_deleted}
endpoint_path = f"v4/jobGroups/{job_group_id}"
- url: str = os.path.join(self._base_url, endpoint_path)
+ url: str = urljoin(self._base_url, endpoint_path)
response = requests.get(url, headers=self._headers, params=params)
self._raise_for_status(response)
return response.json()
@@ -115,11 +127,73 @@ class GoogleDataprepHook(BaseHook):
:param body_request: The identifier for the recipe you would like to
run.
"""
endpoint_path = "v4/jobGroups"
- url: str = os.path.join(self._base_url, endpoint_path)
+ url: str = urljoin(self._base_url, endpoint_path)
+ response = requests.post(url, headers=self._headers,
data=json.dumps(body_request))
+ self._raise_for_status(response)
+ return response.json()
+
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1,
max=10))
+ def copy_flow(
+ self, *, flow_id: int, name: str = "", description: str = "",
copy_datasources: bool = False
+ ) -> dict:
+ """
+ Create a copy of the provided flow id, as well as all contained
recipes.
+
+ :param flow_id: ID of the flow to be copied
+ :param name: Name for the copy of the flow
+ :param description: Description of the copy of the flow
+ :param copy_datasources: Whether copies of the data inputs should be made.
+ """
+ endpoint_path = f"v4/flows/{flow_id}/copy"
+ url: str = urljoin(self._base_url, endpoint_path)
+ body_request = {
+ "name": name,
+ "description": description,
+ "copyDatasources": copy_datasources,
+ }
+ response = requests.post(url, headers=self._headers,
data=json.dumps(body_request))
+ self._raise_for_status(response)
+ return response.json()
+
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1,
max=10))
+ def delete_flow(self, *, flow_id: int) -> None:
+ """
+ Delete the flow with the provided id.
+
+ :param flow_id: ID of the flow to be deleted
+ """
+ endpoint_path = f"v4/flows/{flow_id}"
+ url: str = urljoin(self._base_url, endpoint_path)
+ response = requests.delete(url, headers=self._headers)
+ self._raise_for_status(response)
+
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1,
max=10))
+ def run_flow(self, *, flow_id: int, body_request: dict) -> dict:
+ """
+ Runs the flow with the provided id.
+
+ :param flow_id: ID of the flow to be run
+ :param body_request: Body of the POST request to be sent.
+ """
+ endpoint = f"v4/flows/{flow_id}/run"
+ url: str = urljoin(self._base_url, endpoint)
response = requests.post(url, headers=self._headers,
data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1,
max=10))
+ def get_job_group_status(self, *, job_group_id: int) -> JobGroupStatuses:
+ """
+ Get the run status of the specified Dataprep job group.
+
+ :param job_group_id: ID of the job group to check
+ """
+ endpoint = f"/v4/jobGroups/{job_group_id}/status"
+ url: str = urljoin(self._base_url, endpoint)
+ response = requests.get(url, headers=self._headers)
+ self._raise_for_status(response)
+ return response.json()
+
def _raise_for_status(self, response: requests.models.Response) -> None:
try:
response.raise_for_status()
diff --git a/airflow/providers/google/cloud/links/base.py
b/airflow/providers/google/cloud/links/base.py
index 6539043a86..755266758e 100644
--- a/airflow/providers/google/cloud/links/base.py
+++ b/airflow/providers/google/cloud/links/base.py
@@ -45,6 +45,6 @@ class BaseGoogleLink(BaseOperatorLink):
conf = XCom.get_value(key=self.key, ti_key=ti_key)
if not conf:
return ""
- if self.format_str.startswith(BASE_LINK):
+ if self.format_str.startswith("http"):
return self.format_str.format(**conf)
return BASE_LINK + self.format_str.format(**conf)
diff --git a/airflow/providers/google/cloud/links/dataprep.py
b/airflow/providers/google/cloud/links/dataprep.py
new file mode 100644
index 0000000000..66caf1cfe8
--- /dev/null
+++ b/airflow/providers/google/cloud/links/dataprep.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.google.cloud.links.base import BaseGoogleLink
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+BASE_LINK = "https://clouddataprep.com"
+DATAPREP_FLOW_LINK = BASE_LINK + "/flows/{flow_id}?projectId={project_id}"
+DATAPREP_JOB_GROUP_LINK = BASE_LINK +
"/jobs/{job_group_id}?projectId={project_id}"
+
+
+class DataprepFlowLink(BaseGoogleLink):
+ """Helper class for constructing Dataprep flow link."""
+
+ name = "Flow details page"
+ key = "dataprep_flow_page"
+ format_str = DATAPREP_FLOW_LINK
+
+ @staticmethod
+ def persist(context: Context, task_instance, project_id: str, flow_id:
int):
+ task_instance.xcom_push(
+ context=context,
+ key=DataprepFlowLink.key,
+ value={"project_id": project_id, "flow_id": flow_id},
+ )
+
+
+class DataprepJobGroupLink(BaseGoogleLink):
+ """Helper class for constructing Dataprep job group link."""
+
+ name = "Job group details page"
+ key = "dataprep_job_group_page"
+ format_str = DATAPREP_JOB_GROUP_LINK
+
+ @staticmethod
+ def persist(context: Context, task_instance, project_id: str,
job_group_id: int):
+ task_instance.xcom_push(
+ context=context,
+ key=DataprepJobGroupLink.key,
+ value={
+ "project_id": project_id,
+ "job_group_id": job_group_id,
+ },
+ )
diff --git a/airflow/providers/google/cloud/operators/dataprep.py
b/airflow/providers/google/cloud/operators/dataprep.py
index ac62f01032..61340b0747 100644
--- a/airflow/providers/google/cloud/operators/dataprep.py
+++ b/airflow/providers/google/cloud/operators/dataprep.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook
+from airflow.providers.google.cloud.links.dataprep import DataprepFlowLink,
DataprepJobGroupLink
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -36,22 +37,28 @@ class DataprepGetJobsForJobGroupOperator(BaseOperator):
For more information on how to use this operator, take a look at the
guide:
:ref:`howto/operator:DataprepGetJobsForJobGroupOperator`
- :param job_id The ID of the job that will be requests
+ :param job_group_id: The ID of the job group that will be requested
"""
- template_fields: Sequence[str] = ("job_id",)
+ template_fields: Sequence[str] = ("job_group_id",)
- def __init__(self, *, dataprep_conn_id: str = "dataprep_default", job_id:
int, **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ dataprep_conn_id: str = "dataprep_default",
+ job_group_id: int | str,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.dataprep_conn_id = (dataprep_conn_id,)
- self.job_id = job_id
+ self.job_group_id = job_group_id
def execute(self, context: Context) -> dict:
- self.log.info("Fetching data for job with id: %d ...", self.job_id)
+ self.log.info("Fetching data for job with id: %d ...",
self.job_group_id)
hook = GoogleDataprepHook(
dataprep_conn_id="dataprep_default",
)
- response = hook.get_jobs_for_job_group(job_id=self.job_id)
+ response = hook.get_jobs_for_job_group(job_id=int(self.job_group_id))
return response
@@ -65,33 +72,49 @@ class DataprepGetJobGroupOperator(BaseOperator):
For more information on how to use this operator, take a look at the
guide:
:ref:`howto/operator:DataprepGetJobGroupOperator`
- :param job_group_id: The ID of the job that will be requests
+ :param job_group_id: The ID of the job group that will be requested
:param embed: Comma-separated list of objects to pull in as part of the
response
:param include_deleted: if set to "true", will include deleted objects
"""
- template_fields: Sequence[str] = ("job_group_id", "embed")
+ template_fields: Sequence[str] = (
+ "job_group_id",
+ "embed",
+ "project_id",
+ )
+ operator_extra_links = (DataprepJobGroupLink(),)
def __init__(
self,
*,
dataprep_conn_id: str = "dataprep_default",
- job_group_id: int,
+ project_id: str | None = None,
+ job_group_id: int | str,
embed: str,
include_deleted: bool,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.dataprep_conn_id: str = dataprep_conn_id
+ self.project_id = project_id
self.job_group_id = job_group_id
self.embed = embed
self.include_deleted = include_deleted
def execute(self, context: Context) -> dict:
self.log.info("Fetching data for job with id: %d ...",
self.job_group_id)
+
+ if self.project_id:
+ DataprepJobGroupLink.persist(
+ context=context,
+ task_instance=self,
+ project_id=self.project_id,
+ job_group_id=int(self.job_group_id),
+ )
+
hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
response = hook.get_job_group(
- job_group_id=self.job_group_id,
+ job_group_id=int(self.job_group_id),
embed=self.embed,
include_deleted=self.include_deleted,
)
@@ -115,14 +138,166 @@ class DataprepRunJobGroupOperator(BaseOperator):
"""
template_fields: Sequence[str] = ("body_request",)
+ operator_extra_links = (DataprepJobGroupLink(),)
- def __init__(self, *, dataprep_conn_id: str = "dataprep_default",
body_request: dict, **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ project_id: str | None = None,
+ dataprep_conn_id: str = "dataprep_default",
+ body_request: dict,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
- self.body_request = body_request
+ self.project_id = project_id
self.dataprep_conn_id = dataprep_conn_id
+ self.body_request = body_request
def execute(self, context: Context) -> dict:
self.log.info("Creating a job...")
hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
response = hook.run_job_group(body_request=self.body_request)
+
+ job_group_id = response.get("id")
+ if self.project_id and job_group_id:
+ DataprepJobGroupLink.persist(
+ context=context,
+ task_instance=self,
+ project_id=self.project_id,
+ job_group_id=int(job_group_id),
+ )
+
+ return response
+
+
+class DataprepCopyFlowOperator(BaseOperator):
+ """
+ Create a copy of the provided flow id, as well as all contained recipes.
+
+ :param dataprep_conn_id: The Dataprep connection ID
+ :param flow_id: ID of the flow to be copied
+ :param name: Name for the copy of the flow
+ :param description: Description of the copy of the flow
+ :param copy_datasources: Whether copies of the data inputs should be made.
+ """
+
+ template_fields: Sequence[str] = (
+ "flow_id",
+ "name",
+ "project_id",
+ "description",
+ )
+ operator_extra_links = (DataprepFlowLink(),)
+
+ def __init__(
+ self,
+ *,
+ project_id: str | None = None,
+ dataprep_conn_id: str = "dataprep_default",
+ flow_id: int | str,
+ name: str = "",
+ description: str = "",
+ copy_datasources: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.dataprep_conn_id = dataprep_conn_id
+ self.flow_id = flow_id
+ self.name = name
+ self.description = description
+ self.copy_datasources = copy_datasources
+
+ def execute(self, context: Context) -> dict:
+ self.log.info("Copying flow with id %d...", self.flow_id)
+ hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
+ response = hook.copy_flow(
+ flow_id=int(self.flow_id),
+ name=self.name,
+ description=self.description,
+ copy_datasources=self.copy_datasources,
+ )
+
+ copied_flow_id = response.get("id")
+ if self.project_id and copied_flow_id:
+ DataprepFlowLink.persist(
+ context=context,
+ task_instance=self,
+ project_id=self.project_id,
+ flow_id=int(copied_flow_id),
+ )
+ return response
+
+
+class DataprepDeleteFlowOperator(BaseOperator):
+ """
+ Delete the flow with provided id.
+
+ :param dataprep_conn_id: The Dataprep connection ID
+ :param flow_id: ID of the flow to be deleted
+ """
+
+ template_fields: Sequence[str] = ("flow_id",)
+
+ def __init__(
+ self,
+ *,
+ dataprep_conn_id: str = "dataprep_default",
+ flow_id: int | str,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.dataprep_conn_id = dataprep_conn_id
+ self.flow_id = flow_id
+
+ def execute(self, context: Context) -> None:
+ self.log.info("Start delete operation of the flow with id: %d...",
self.flow_id)
+ hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
+ hook.delete_flow(flow_id=int(self.flow_id))
+
+
+class DataprepRunFlowOperator(BaseOperator):
+ """
+ Runs the flow with the provided id.
+
+ :param dataprep_conn_id: The Dataprep connection ID
+ :param flow_id: ID of the flow to be run
+ :param body_request: Body of the POST request to be sent.
+ """
+
+ template_fields: Sequence[str] = (
+ "flow_id",
+ "project_id",
+ )
+ operator_extra_links = (DataprepJobGroupLink(),)
+
+ def __init__(
+ self,
+ *,
+ project_id: str | None = None,
+ flow_id: int | str,
+ body_request: dict,
+ dataprep_conn_id: str = "dataprep_default",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.flow_id = flow_id
+ self.body_request = body_request
+ self.dataprep_conn_id = dataprep_conn_id
+
+ def execute(self, context: Context) -> dict:
+ self.log.info("Running the flow with id: %d...", self.flow_id)
+ hooks = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
+ response = hooks.run_flow(flow_id=int(self.flow_id),
body_request=self.body_request)
+
+ if self.project_id:
+ job_group_id = response["data"][0]["id"]
+ DataprepJobGroupLink.persist(
+ context=context,
+ task_instance=self,
+ project_id=self.project_id,
+ job_group_id=int(job_group_id),
+ )
+
return response
diff --git a/airflow/providers/google/cloud/sensors/dataprep.py
b/airflow/providers/google/cloud/sensors/dataprep.py
new file mode 100644
index 0000000000..d30f6e18e8
--- /dev/null
+++ b/airflow/providers/google/cloud/sensors/dataprep.py
@@ -0,0 +1,53 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Dataprep Job sensor."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook,
JobGroupStatuses
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class DataprepJobGroupIsFinishedSensor(BaseSensorOperator):
+ """
+ Check whether the specified Dataprep job group has finished.
+
+ :param job_group_id: ID of the job group to check
+ """
+
+ template_fields: Sequence[str] = ("job_group_id",)
+
+ def __init__(
+ self,
+ *,
+ job_group_id: int | str,
+ dataprep_conn_id: str = "dataprep_default",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.job_group_id = job_group_id
+ self.dataprep_conn_id = dataprep_conn_id
+
+ def poke(self, context: Context) -> bool:
+ hooks = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
+ status =
hooks.get_job_group_status(job_group_id=int(self.job_group_id))
+ return status != JobGroupStatuses.IN_PROGRESS
diff --git a/airflow/providers/google/provider.yaml
b/airflow/providers/google/provider.yaml
index 3533fba5f6..268fc32bc9 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -609,6 +609,9 @@ sensors:
- integration-name: Google Data Fusion
python-modules:
- airflow.providers.google.cloud.sensors.datafusion
+ - integration-name: Google Dataprep
+ python-modules:
+ - airflow.providers.google.cloud.sensors.dataprep
- integration-name: Google Dataplex
python-modules:
- airflow.providers.google.cloud.sensors.dataplex
@@ -985,6 +988,8 @@ extra-links:
- airflow.providers.google.cloud.links.dataproc.DataprocListLink
-
airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreDetailedLink
-
airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreLink
+ - airflow.providers.google.cloud.links.dataprep.DataprepFlowLink
+ - airflow.providers.google.cloud.links.dataprep.DataprepJobGroupLink
- airflow.providers.google.cloud.links.vertex_ai.VertexAIModelLink
- airflow.providers.google.cloud.links.vertex_ai.VertexAIModelListLink
- airflow.providers.google.cloud.links.vertex_ai.VertexAIModelExportLink
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst
b/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst
index 324a0b4875..4957235604 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst
@@ -59,7 +59,7 @@ To get information about jobs within a Cloud Dataprep job use:
Example usage:
-.. exampleinclude::
/../../airflow/providers/google/cloud/example_dags/example_dataprep.py
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataprep/example_dataprep.py
:language: python
:dedent: 4
:start-after: [START how_to_dataprep_run_job_group_operator]
@@ -77,7 +77,7 @@ To get information about jobs within a Cloud Dataprep job use:
Example usage:
-.. exampleinclude::
/../../airflow/providers/google/cloud/example_dags/example_dataprep.py
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataprep/example_dataprep.py
:language: python
:dedent: 4
:start-after: [START how_to_dataprep_get_jobs_for_job_group_operator]
@@ -96,8 +96,79 @@ To get information about jobs within a Cloud Dataprep job
use:
Example usage:
-.. exampleinclude::
/../../airflow/providers/google/cloud/example_dags/example_dataprep.py
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataprep/example_dataprep.py
:language: python
:dedent: 4
:start-after: [START how_to_dataprep_get_job_group_operator]
:end-before: [END how_to_dataprep_get_job_group_operator]
+
+Copy Flow
+^^^^^^^^^
+
+Operator task is to copy the flow.
+
+To copy a Cloud Dataprep flow use:
+:class:`~airflow.providers.google.cloud.operators.dataprep.DataprepCopyFlowOperator`
+
+Example usage:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataprep/example_dataprep.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_dataprep_copy_flow_operator]
+ :end-before: [END how_to_dataprep_copy_flow_operator]
+
+Run Flow
+^^^^^^^^
+
+Operator task is to run the flow.
+A flow is a container for wrangling logic which contains
+imported datasets, recipe, output objects, and References.
+
+To run a Cloud Dataprep flow use:
+:class:`~airflow.providers.google.cloud.operators.dataprep.DataprepRunFlowOperator`
+
+Example usage:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataprep/example_dataprep.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_dataprep_dataprep_run_flow_operator]
+ :end-before: [END how_to_dataprep_dataprep_run_flow_operator]
+
+Delete flow
+^^^^^^^^^^^
+
+Operator task is to delete the flow.
+A flow is a container for wrangling logic which contains
+imported datasets, recipe, output objects, and References.
+
+To delete a Cloud Dataprep flow use:
+:class:`~airflow.providers.google.cloud.operators.dataprep.DataprepDeleteFlowOperator`
+
+Example usage:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataprep/example_dataprep.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_dataprep_delete_flow_operator]
+ :end-before: [END how_to_dataprep_delete_flow_operator]
+
+
+Check if Job Group is finished
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Sensor task is to tell the system when a started job group is finished,
+regardless of whether it succeeded or failed.
+A job group is a job that is executed from a specific node in a flow.
+
+To check whether a Cloud Dataprep job group is finished use:
+:class:`~airflow.providers.google.cloud.sensors.dataprep.DataprepJobGroupIsFinishedSensor`
+
+Example usage:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataprep/example_dataprep.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_dataprep_job_group_finished_sensor]
+ :end-before: [END how_to_dataprep_job_group_finished_sensor]
diff --git a/tests/providers/google/cloud/hooks/test_dataprep.py
b/tests/providers/google/cloud/hooks/test_dataprep.py
index b5950e871a..a13369f734 100644
--- a/tests/providers/google/cloud/hooks/test_dataprep.py
+++ b/tests/providers/google/cloud/hooks/test_dataprep.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import json
import os
-from unittest import mock
+from unittest import TestCase, mock
from unittest.mock import patch
import pytest
@@ -35,7 +35,7 @@ TOKEN = "1111"
EXTRA = {"token": TOKEN}
EMBED = ""
INCLUDE_DELETED = False
-DATA = json.dumps({"wrangledDataset": {"id": RECIPE_ID}})
+DATA = {"wrangledDataset": {"id": RECIPE_ID}}
URL = "https://api.clouddataprep.com/v4/jobGroups"
@@ -151,7 +151,6 @@ class TestGoogleDataprepHook:
@patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
def test_run_job_group_should_be_called_once_with_params(self,
mock_get_request):
- data = '"{\\"wrangledDataset\\": {\\"id\\": 1234567}}"'
self.hook.run_job_group(body_request=DATA)
mock_get_request.assert_called_once_with(
f"{URL}",
@@ -159,7 +158,7 @@ class TestGoogleDataprepHook:
"Content-Type": "application/json",
"Authorization": f"Bearer {TOKEN}",
},
- data=data,
+ data=json.dumps(DATA),
)
@patch(
@@ -206,6 +205,60 @@ class TestGoogleDataprepHook:
assert "HTTPError" in str(ctx.value)
assert mock_get_request.call_count == 5
+ @patch("airflow.providers.google.cloud.hooks.dataprep.requests.get")
+ def test_get_job_group_status_should_be_called_once_with_params(self,
mock_get_request):
+ self.hook.get_job_group_status(job_group_id=JOB_ID)
+ mock_get_request.assert_called_once_with(
+ f"{URL}/{JOB_ID}/status",
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {TOKEN}",
+ },
+ )
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.get",
+ side_effect=[HTTPError(), mock.MagicMock()],
+ )
+ def test_get_job_group_status_should_pass_after_retry(self,
mock_get_request):
+ self.hook.get_job_group_status(job_group_id=JOB_ID)
+ assert mock_get_request.call_count == 2
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.get",
+ side_effect=[mock.MagicMock(), HTTPError()],
+ )
+ def test_get_job_group_status_retry_after_success(self, mock_get_request):
+ self.hook.run_job_group.retry.sleep = mock.Mock()
+ self.hook.get_job_group_status(job_group_id=JOB_ID)
+ assert mock_get_request.call_count == 1
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.get",
+ side_effect=[
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ HTTPError(),
+ mock.MagicMock(),
+ ],
+ )
+ def test_get_job_group_status_four_errors(self, mock_get_request):
+ self.hook.run_job_group.retry.sleep = mock.Mock()
+ self.hook.get_job_group_status(job_group_id=JOB_ID)
+ assert mock_get_request.call_count == 5
+
+ @patch(
+ "airflow.providers.google.cloud.hooks.dataprep.requests.get",
+ side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(),
HTTPError()],
+ )
+ def test_get_job_group_status_five_calls(self, mock_get_request):
+ with pytest.raises(RetryError) as ctx:
+ self.hook.get_job_group_status.retry.sleep = mock.Mock()
+ self.hook.get_job_group_status(job_group_id=JOB_ID)
+ assert "HTTPError" in str(ctx.value)
+ assert mock_get_request.call_count == 5
+
@pytest.mark.parametrize(
"uri",
[
@@ -218,3 +271,204 @@ class TestGoogleDataprepHook:
hook = GoogleDataprepHook("my_conn")
assert hook._token == "abc"
assert hook._base_url == "abc"
+
+
class TestGoogleDataprepFlowPathHooks(TestCase):
    """Tests for the flow-path methods of GoogleDataprepHook (copy/delete/run flow).

    Each method is exercised for: the exact request it issues, recovery after a
    single transient error, no retry after success, recovery after four errors,
    and retry exhaustion (RetryError) after five consecutive errors.
    """

    _url = "https://api.clouddataprep.com/v4/flows"

    def setUp(self) -> None:
        self._flow_id = 1234567
        # Payload the hook is expected to POST for a copy with default arguments.
        self._expected_copy_flow_hook_data = json.dumps(
            {
                "name": "",
                "description": "",
                "copyDatasources": False,
            }
        )
        # run_flow with an empty body serializes to an empty JSON object.
        self._expected_run_flow_hook_data = json.dumps({})
        # Patch connection lookup so constructing the hook needs no real Airflow connection.
        with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn:
            conn.return_value.extra_dejson = EXTRA
            self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default")

    @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
    def test_copy_flow_should_be_called_once_with_params(self, mock_get_request):
        """copy_flow POSTs once to /<flow_id>/copy with auth headers and default body."""
        self.hook.copy_flow(
            flow_id=self._flow_id,
        )
        mock_get_request.assert_called_once_with(
            f"{self._url}/{self._flow_id}/copy",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {TOKEN}",
            },
            data=self._expected_copy_flow_hook_data,
        )

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
        side_effect=[HTTPError(), mock.MagicMock()],
    )
    def test_copy_flow_should_pass_after_retry(self, mock_get_request):
        """One transient HTTPError is retried; the second attempt succeeds."""
        self.hook.copy_flow(flow_id=self._flow_id)
        assert mock_get_request.call_count == 2

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
        side_effect=[mock.MagicMock(), HTTPError()],
    )
    def test_copy_flow_should_not_retry_after_success(self, mock_get_request):
        """A successful first call must not be retried (the queued error is never hit)."""
        self.hook.copy_flow.retry.sleep = mock.Mock()
        self.hook.copy_flow(flow_id=self._flow_id)
        assert mock_get_request.call_count == 1

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
        side_effect=[
            HTTPError(),
            HTTPError(),
            HTTPError(),
            HTTPError(),
            mock.MagicMock(),
        ],
    )
    def test_copy_flow_should_retry_after_four_errors(self, mock_get_request):
        """Four transient errors are retried; the fifth attempt succeeds."""
        self.hook.copy_flow.retry.sleep = mock.Mock()
        self.hook.copy_flow(flow_id=self._flow_id)
        assert mock_get_request.call_count == 5

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
        side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
    )
    def test_copy_flow_raise_error_after_five_calls(self, mock_get_request):
        """Five consecutive errors exhaust the retries and raise RetryError."""
        with pytest.raises(RetryError) as ctx:
            self.hook.copy_flow.retry.sleep = mock.Mock()
            self.hook.copy_flow(flow_id=self._flow_id)
        assert "HTTPError" in str(ctx.value)
        assert mock_get_request.call_count == 5

    @patch("airflow.providers.google.cloud.hooks.dataprep.requests.delete")
    def test_delete_flow_should_be_called_once_with_params(self, mock_get_request):
        """delete_flow issues exactly one DELETE to /<flow_id> with auth headers."""
        self.hook.delete_flow(
            flow_id=self._flow_id,
        )
        mock_get_request.assert_called_once_with(
            f"{self._url}/{self._flow_id}",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {TOKEN}",
            },
        )

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
        side_effect=[HTTPError(), mock.MagicMock()],
    )
    def test_delete_flow_should_pass_after_retry(self, mock_get_request):
        """One transient HTTPError is retried; the second attempt succeeds."""
        self.hook.delete_flow(flow_id=self._flow_id)
        assert mock_get_request.call_count == 2

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
        side_effect=[mock.MagicMock(), HTTPError()],
    )
    def test_delete_flow_should_not_retry_after_success(self, mock_get_request):
        """A successful first call must not be retried."""
        self.hook.delete_flow.retry.sleep = mock.Mock()
        self.hook.delete_flow(flow_id=self._flow_id)
        assert mock_get_request.call_count == 1

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
        side_effect=[
            HTTPError(),
            HTTPError(),
            HTTPError(),
            HTTPError(),
            mock.MagicMock(),
        ],
    )
    def test_delete_flow_should_retry_after_four_errors(self, mock_get_request):
        """Four transient errors are retried; the fifth attempt succeeds."""
        self.hook.delete_flow.retry.sleep = mock.Mock()
        self.hook.delete_flow(flow_id=self._flow_id)
        assert mock_get_request.call_count == 5

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.delete",
        side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
    )
    def test_delete_flow_raise_error_after_five_calls(self, mock_get_request):
        """Five consecutive errors exhaust the retries and raise RetryError."""
        with pytest.raises(RetryError) as ctx:
            self.hook.delete_flow.retry.sleep = mock.Mock()
            self.hook.delete_flow(flow_id=self._flow_id)
        assert "HTTPError" in str(ctx.value)
        assert mock_get_request.call_count == 5

    @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
    def test_run_flow_should_be_called_once_with_params(self, mock_get_request):
        """run_flow POSTs once to /<flow_id>/run with auth headers and the serialized body."""
        self.hook.run_flow(
            flow_id=self._flow_id,
            body_request={},
        )
        mock_get_request.assert_called_once_with(
            f"{self._url}/{self._flow_id}/run",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {TOKEN}",
            },
            data=self._expected_run_flow_hook_data,
        )

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
        side_effect=[HTTPError(), mock.MagicMock()],
    )
    def test_run_flow_should_pass_after_retry(self, mock_get_request):
        """One transient HTTPError is retried; the second attempt succeeds."""
        self.hook.run_flow(
            flow_id=self._flow_id,
            body_request={},
        )
        assert mock_get_request.call_count == 2

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
        side_effect=[mock.MagicMock(), HTTPError()],
    )
    def test_run_flow_should_not_retry_after_success(self, mock_get_request):
        """A successful first call must not be retried."""
        self.hook.run_flow.retry.sleep = mock.Mock()
        self.hook.run_flow(
            flow_id=self._flow_id,
            body_request={},
        )
        assert mock_get_request.call_count == 1

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
        side_effect=[
            HTTPError(),
            HTTPError(),
            HTTPError(),
            HTTPError(),
            mock.MagicMock(),
        ],
    )
    def test_run_flow_should_retry_after_four_errors(self, mock_get_request):
        """Four transient errors are retried; the fifth attempt succeeds."""
        self.hook.run_flow.retry.sleep = mock.Mock()
        self.hook.run_flow(
            flow_id=self._flow_id,
            body_request={},
        )
        assert mock_get_request.call_count == 5

    @patch(
        "airflow.providers.google.cloud.hooks.dataprep.requests.post",
        side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
    )
    def test_run_flow_raise_error_after_five_calls(self, mock_get_request):
        """Five consecutive errors exhaust the retries and raise RetryError."""
        with pytest.raises(RetryError) as ctx:
            self.hook.run_flow.retry.sleep = mock.Mock()
            self.hook.run_flow(
                flow_id=self._flow_id,
                body_request={},
            )
        assert "HTTPError" in str(ctx.value)
        assert mock_get_request.call_count == 5
diff --git a/tests/providers/google/cloud/operators/test_dataprep.py
b/tests/providers/google/cloud/operators/test_dataprep.py
index 94e0591ad7..08237d0d51 100644
--- a/tests/providers/google/cloud/operators/test_dataprep.py
+++ b/tests/providers/google/cloud/operators/test_dataprep.py
@@ -17,16 +17,24 @@
# under the License.
from __future__ import annotations
-from unittest import TestCase, mock
+from unittest import mock
+
+import pytest
from airflow.providers.google.cloud.operators.dataprep import (
+ DataprepCopyFlowOperator,
+ DataprepDeleteFlowOperator,
DataprepGetJobGroupOperator,
DataprepGetJobsForJobGroupOperator,
+ DataprepRunFlowOperator,
DataprepRunJobGroupOperator,
)
+GCP_PROJECT_ID = "test-project-id"
DATAPREP_CONN_ID = "dataprep_default"
JOB_ID = 143
+FLOW_ID = 128754
+NEW_FLOW_ID = 1312
TASK_ID = "dataprep_job"
INCLUDE_DELETED = False
EMBED = ""
@@ -51,40 +59,235 @@ DATA = {
}
class TestDataprepGetJobsForJobGroupOperator:
    """Unit tests for DataprepGetJobsForJobGroupOperator."""

    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    def test_execute(self, hook_mock):
        """The operator forwards its ``job_group_id`` to the hook's ``job_id`` argument."""
        operator = DataprepGetJobsForJobGroupOperator(
            dataprep_conn_id=DATAPREP_CONN_ID,
            job_group_id=JOB_ID,
            task_id=TASK_ID,
        )
        operator.execute(context={})
        hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default")
        mocked_hook = hook_mock.return_value
        mocked_hook.get_jobs_for_job_group.assert_called_once_with(job_id=JOB_ID)
class TestDataprepGetJobGroupOperator:
    """Unit tests for DataprepGetJobGroupOperator."""

    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    def test_execute(self, hook_mock):
        """execute() delegates to GoogleDataprepHook.get_job_group with the operator's arguments."""
        op = DataprepGetJobGroupOperator(
            dataprep_conn_id=DATAPREP_CONN_ID,
            project_id=None,
            job_group_id=JOB_ID,
            embed=EMBED,
            include_deleted=INCLUDE_DELETED,
            task_id=TASK_ID,
        )
        op.execute(context=mock.MagicMock())
        hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default")
        hook_mock.return_value.get_job_group.assert_called_once_with(
            job_group_id=JOB_ID, embed=EMBED, include_deleted=INCLUDE_DELETED
        )

    # Patch decorators apply bottom-up: DataprepJobGroupLink (innermost) binds to
    # ``link_mock``; the GoogleDataprepHook mock is unused, hence ``_``.
    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    @mock.patch("airflow.providers.google.cloud.operators.dataprep.DataprepJobGroupLink")
    @pytest.mark.parametrize(
        "provide_project_id, expected_call_count",
        [
            (True, 1),
            (False, 0),
        ],
    )
    def test_execute_with_project_id_will_persist_link_to_job_group(
        self,
        link_mock,
        _,
        provide_project_id,
        expected_call_count,
    ):
        """The console link is persisted exactly when a ``project_id`` is supplied."""
        context = mock.MagicMock()
        project_id = GCP_PROJECT_ID if provide_project_id else None

        op = DataprepGetJobGroupOperator(
            task_id=TASK_ID,
            project_id=project_id,
            dataprep_conn_id=DATAPREP_CONN_ID,
            job_group_id=JOB_ID,
            embed=EMBED,
            include_deleted=INCLUDE_DELETED,
        )
        op.execute(context=context)

        assert link_mock.persist.call_count == expected_call_count
        if provide_project_id:
            link_mock.persist.assert_called_with(
                context=context,
                task_instance=op,
                project_id=project_id,
                job_group_id=JOB_ID,
            )
class TestDataprepRunJobGroupOperator:
    """Unit tests for DataprepRunJobGroupOperator."""

    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    def test_execute(self, hook_mock):
        """execute() passes the request body straight through to the hook."""
        operator = DataprepRunJobGroupOperator(
            dataprep_conn_id=DATAPREP_CONN_ID, body_request=DATA, task_id=TASK_ID
        )
        operator.execute(context=None)
        hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default")
        mocked_hook = hook_mock.return_value
        mocked_hook.run_job_group.assert_called_once_with(body_request=DATA)
+
+
# NOTE(review): the class name carries a redundant trailing "Test"
# (TestDataprepCopyFlowOperator would match the sibling classes) — kept as-is
# since pytest discovery keys on the "Test" prefix and renaming is cosmetic.
class TestDataprepCopyFlowOperatorTest:
    """Unit tests for DataprepCopyFlowOperator."""

    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    def test_execute_with_default_params(self, hook_mock):
        """With no optional args, copy_flow receives empty name/description and no datasource copy."""
        op = DataprepCopyFlowOperator(
            task_id=TASK_ID,
            dataprep_conn_id=DATAPREP_CONN_ID,
            flow_id=FLOW_ID,
        )
        op.execute(context=mock.MagicMock())
        hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default")
        hook_mock.return_value.copy_flow.assert_called_once_with(
            flow_id=FLOW_ID,
            name="",
            description="",
            copy_datasources=False,
        )

    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    def test_execute_with_specified_params(self, hook_mock):
        """Explicit name/description/copy_datasources values are forwarded unchanged."""
        op = DataprepCopyFlowOperator(
            task_id=TASK_ID,
            dataprep_conn_id=DATAPREP_CONN_ID,
            flow_id=FLOW_ID,
            name="specified name",
            description="specified description",
            copy_datasources=True,
        )
        op.execute(context=mock.MagicMock())
        hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default")
        hook_mock.return_value.copy_flow.assert_called_once_with(
            flow_id=FLOW_ID, name="specified name", description="specified description", copy_datasources=True
        )

    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    def test_execute_with_templated_params(self, _, create_task_instance_of_operator):
        """project_id, flow_id, name and description are template fields.

        ``create_task_instance_of_operator`` is a pytest fixture provided by the
        test harness (not a mock argument).
        """
        dag_id = "test_execute_with_templated_params"
        ti = create_task_instance_of_operator(
            DataprepCopyFlowOperator,
            dag_id=dag_id,
            project_id="{{ dag.dag_id }}",
            task_id=TASK_ID,
            flow_id="{{ dag.dag_id }}",
            name="{{ dag.dag_id }}",
            description="{{ dag.dag_id }}",
        )
        ti.render_templates()
        assert dag_id == ti.task.project_id
        assert dag_id == ti.task.flow_id
        assert dag_id == ti.task.name
        assert dag_id == ti.task.description

    # Patch decorators apply bottom-up: DataprepFlowLink (innermost) binds to
    # ``link_mock``, GoogleDataprepHook to ``hook_mock``.
    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    @mock.patch("airflow.providers.google.cloud.operators.dataprep.DataprepFlowLink")
    @pytest.mark.parametrize(
        "provide_project_id, expected_call_count",
        [
            (True, 1),
            (False, 0),
        ],
    )
    def test_execute_with_project_id_will_persist_link_to_flow(
        self,
        link_mock,
        hook_mock,
        provide_project_id,
        expected_call_count,
    ):
        """The flow link is persisted (with the id of the *new* copy) only when project_id is set."""
        hook_mock.return_value.copy_flow.return_value = {"id": NEW_FLOW_ID}
        context = mock.MagicMock()
        project_id = GCP_PROJECT_ID if provide_project_id else None

        op = DataprepCopyFlowOperator(
            task_id=TASK_ID,
            project_id=project_id,
            dataprep_conn_id=DATAPREP_CONN_ID,
            flow_id=FLOW_ID,
            name="specified name",
            description="specified description",
            copy_datasources=True,
        )
        op.execute(context=context)

        assert link_mock.persist.call_count == expected_call_count
        if provide_project_id:
            link_mock.persist.assert_called_with(
                context=context,
                task_instance=op,
                project_id=project_id,
                flow_id=NEW_FLOW_ID,
            )
+
+
class TestDataprepDeleteFlowOperator:
    """Unit tests for DataprepDeleteFlowOperator."""

    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    def test_execute(self, hook_mock):
        """execute() asks the hook to delete exactly the configured flow."""
        operator = DataprepDeleteFlowOperator(
            task_id=TASK_ID,
            dataprep_conn_id=DATAPREP_CONN_ID,
            flow_id=FLOW_ID,
        )
        operator.execute(context=mock.MagicMock())
        hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default")
        mocked_hook = hook_mock.return_value
        mocked_hook.delete_flow.assert_called_once_with(flow_id=FLOW_ID)

    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    def test_execute_with_template_params(self, _, create_task_instance_of_operator):
        """``flow_id`` is a template field and renders from the Jinja context."""
        dag_id = "test_execute_delete_flow_with_template"
        task_instance = create_task_instance_of_operator(
            DataprepDeleteFlowOperator,
            dag_id=dag_id,
            task_id=TASK_ID,
            flow_id="{{ dag.dag_id }}",
        )
        task_instance.render_templates()
        assert task_instance.task.flow_id == dag_id
+
+
class TestDataprepRunFlowOperator:
    """Unit tests for DataprepRunFlowOperator."""

    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    def test_execute(self, hook_mock):
        """execute() forwards the flow id and request body to GoogleDataprepHook.run_flow."""
        operator = DataprepRunFlowOperator(
            task_id=TASK_ID,
            project_id=GCP_PROJECT_ID,
            dataprep_conn_id=DATAPREP_CONN_ID,
            flow_id=FLOW_ID,
            body_request={},
        )
        operator.execute(context=mock.MagicMock())
        hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default")
        mocked_hook = hook_mock.return_value
        mocked_hook.run_flow.assert_called_once_with(flow_id=FLOW_ID, body_request={})

    @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
    def test_execute_with_template_params(self, _, create_task_instance_of_operator):
        """``project_id`` and ``flow_id`` are template fields rendered from the Jinja context."""
        dag_id = "test_execute_run_flow_with_template"
        task_instance = create_task_instance_of_operator(
            DataprepRunFlowOperator,
            dag_id=dag_id,
            task_id=TASK_ID,
            project_id="{{ dag.dag_id }}",
            flow_id="{{ dag.dag_id }}",
            body_request={},
        )

        task_instance.render_templates()

        assert task_instance.task.project_id == dag_id
        assert task_instance.task.flow_id == dag_id
diff --git a/tests/providers/google/cloud/sensors/test_dataprep.py
b/tests/providers/google/cloud/sensors/test_dataprep.py
new file mode 100644
index 0000000000..7ea1816aa7
--- /dev/null
+++ b/tests/providers/google/cloud/sensors/test_dataprep.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.providers.google.cloud.hooks.dataprep import JobGroupStatuses
+from airflow.providers.google.cloud.sensors.dataprep import
DataprepJobGroupIsFinishedSensor
+
+JOB_GROUP_ID = 1312
+
+
class TestDataprepJobGroupIsFinishedSensor:
    """Unit tests for DataprepJobGroupIsFinishedSensor."""

    @mock.patch("airflow.providers.google.cloud.sensors.dataprep.GoogleDataprepHook")
    def test_passing_arguments_to_hook(self, hook_mock):
        """poke() returns True for a COMPLETE job group and queries the hook exactly once."""
        hook_mock.return_value.get_job_group_status.return_value = JobGroupStatuses.COMPLETE
        sensor = DataprepJobGroupIsFinishedSensor(
            task_id="check_job_group_finished",
            job_group_id=JOB_GROUP_ID,
        )

        assert sensor.poke(context=mock.MagicMock())

        hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default")
        hook_mock.return_value.get_job_group_status.assert_called_once_with(job_group_id=JOB_GROUP_ID)
diff --git a/tests/system/providers/google/cloud/dataprep/__init__.py
b/tests/system/providers/google/cloud/dataprep/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/system/providers/google/cloud/dataprep/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/system/providers/google/cloud/dataprep/example_dataprep.py
b/tests/system/providers/google/cloud/dataprep/example_dataprep.py
new file mode 100644
index 0000000000..9f478a5f0b
--- /dev/null
+++ b/tests/system/providers/google/cloud/dataprep/example_dataprep.py
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that shows how to use Google Dataprep.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.google.cloud.operators.dataprep import (
+ DataprepCopyFlowOperator,
+ DataprepDeleteFlowOperator,
+ DataprepGetJobGroupOperator,
+ DataprepGetJobsForJobGroupOperator,
+ DataprepRunFlowOperator,
+ DataprepRunJobGroupOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import
GCSCreateBucketOperator, GCSDeleteBucketOperator
+from airflow.providers.google.cloud.sensors.dataprep import
DataprepJobGroupIsFinishedSensor
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "example_dataprep"
+
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+GCS_BUCKET_NAME = f"dataprep-bucket-heorhi-{DAG_ID}-{ENV_ID}"
+GCS_BUCKET_PATH = f"gs://{GCS_BUCKET_NAME}/task_results/"
+
+FLOW_ID = os.environ.get("FLOW_ID", "")
+RECIPE_ID = os.environ.get("RECIPE_ID")
+RECIPE_NAME = os.environ.get("RECIPE_NAME")
+WRITE_SETTINGS = (
+ {
+ "writesettings": [
+ {
+ "path": GCS_BUCKET_PATH,
+ "action": "create",
+ "format": "csv",
+ }
+ ],
+ },
+)
+
+with models.DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1), # Override to match your needs
+ catchup=False,
+ tags=["example", "dataprep"],
+ render_template_as_native_obj=True,
+) as dag:
+ create_bucket_task = GCSCreateBucketOperator(
+ task_id="create_bucket",
+ bucket_name=GCS_BUCKET_NAME,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ # [START how_to_dataprep_run_job_group_operator]
+ run_job_group_task = DataprepRunJobGroupOperator(
+ task_id="run_job_group",
+ project_id=GCP_PROJECT_ID,
+ body_request={
+ "wrangledDataset": {"id": RECIPE_ID},
+ "overrides": WRITE_SETTINGS,
+ },
+ )
+ # [END how_to_dataprep_run_job_group_operator]
+
+ # [START how_to_dataprep_copy_flow_operator]
+ copy_task = DataprepCopyFlowOperator(
+ task_id="copy_flow",
+ project_id=GCP_PROJECT_ID,
+ flow_id=FLOW_ID,
+ name=f"dataprep_example_flow_{DAG_ID}_{ENV_ID}",
+ )
+ # [END how_to_dataprep_copy_flow_operator]
+
+ # [START how_to_dataprep_dataprep_run_flow_operator]
+ run_flow_task = DataprepRunFlowOperator(
+ task_id="run_flow",
+ project_id=GCP_PROJECT_ID,
+ flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}",
+ body_request={
+ "overrides": {
+ RECIPE_NAME: WRITE_SETTINGS,
+ },
+ },
+ )
+ # [END how_to_dataprep_dataprep_run_flow_operator]
+
+ # [START how_to_dataprep_get_job_group_operator]
+ get_job_group_task = DataprepGetJobGroupOperator(
+ task_id="get_job_group",
+ project_id=GCP_PROJECT_ID,
+ job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id']
}}",
+ embed="",
+ include_deleted=False,
+ )
+ # [END how_to_dataprep_get_job_group_operator]
+
+ # [START how_to_dataprep_get_jobs_for_job_group_operator]
+ get_jobs_for_job_group_task = DataprepGetJobsForJobGroupOperator(
+ task_id="get_jobs_for_job_group",
+ job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id']
}}",
+ )
+ # [END how_to_dataprep_get_jobs_for_job_group_operator]
+
+ # [START how_to_dataprep_job_group_finished_sensor]
+ check_flow_status_sensor = DataprepJobGroupIsFinishedSensor(
+ task_id="check_flow_status",
+ job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id']
}}",
+ )
+ # [END how_to_dataprep_job_group_finished_sensor]
+
+ # [START how_to_dataprep_job_group_finished_sensor]
+ check_job_group_status_sensor = DataprepJobGroupIsFinishedSensor(
+ task_id="check_job_group_status",
+ job_group_id="{{ task_instance.xcom_pull('run_job_group')['id'] }}",
+ )
+ # [END how_to_dataprep_job_group_finished_sensor]
+
+ # [START how_to_dataprep_delete_flow_operator]
+ delete_flow_task = DataprepDeleteFlowOperator(
+ task_id="delete_flow",
+ flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}",
+ )
+ # [END how_to_dataprep_delete_flow_operator]
+ delete_flow_task.trigger_rule = TriggerRule.ALL_DONE
+
+ delete_bucket_task = GCSDeleteBucketOperator(
+ task_id="delete_bucket",
+ bucket_name=GCS_BUCKET_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ (
+ # TEST SETUP
+ create_bucket_task
+ >> copy_task
+ # TEST BODY
+ >> [run_job_group_task, run_flow_task]
+ >> get_job_group_task
+ >> get_jobs_for_job_group_task
+ # TEST TEARDOWN
+ >> check_flow_status_sensor
+ >> [delete_flow_task, check_job_group_status_sensor]
+ >> delete_bucket_task
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)