This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c8de9a5f68 Adding Support for Google Cloud's Data Pipelines Run Operator (#32846)
c8de9a5f68 is described below
commit c8de9a5f686f55a27705a69d69fbc64840df03ce
Author: Brenda Pham <[email protected]>
AuthorDate: Sun Aug 20 22:33:54 2023 -0700
Adding Support for Google Cloud's Data Pipelines Run Operator (#32846)
---------
Co-authored-by: shaniyaclement <[email protected]>
Co-authored-by: Brenda Pham <[email protected]>
Co-authored-by: Shaniya Clement <[email protected]>
---
.../providers/google/cloud/hooks/datapipeline.py | 31 ++++++++
.../google/cloud/operators/datapipeline.py | 52 +++++++++++++
.../operators/cloud/datapipeline.rst | 29 +++++++
.../google/cloud/hooks/test_datapipeline.py | 23 ++++++
.../google/cloud/operators/test_datapipeline.py | 90 ++++++++++++++++++++++
.../cloud/datapipelines/example_datapipeline.py | 10 ++-
6 files changed, 234 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/google/cloud/hooks/datapipeline.py b/airflow/providers/google/cloud/hooks/datapipeline.py
index c9c4790106..d141e24a71 100644
--- a/airflow/providers/google/cloud/hooks/datapipeline.py
+++ b/airflow/providers/google/cloud/hooks/datapipeline.py
@@ -85,6 +85,37 @@ class DataPipelineHook(GoogleBaseHook):
response = request.execute(num_retries=self.num_retries)
return response
+ @GoogleBaseHook.fallback_to_default_project_id
+ def run_data_pipeline(
+ self,
+ data_pipeline_name: str,
+ project_id: str,
+ location: str = DEFAULT_DATAPIPELINE_LOCATION,
+ ) -> None:
+ """
+ Runs a Data Pipelines Instance using the Data Pipelines API.
+
+ :param data_pipeline_name: The display name of the pipeline. In example
+ projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
+ :param project_id: The ID of the GCP project that owns the job.
+ :param location: The location to direct the Data Pipelines instance to (for example us-central1).
+
+ Returns the created Job in JSON representation.
+ """
+ parent = self.build_parent_name(project_id, location)
+ service = self.get_conn()
+ request = (
+ service.projects()
+ .locations()
+ .pipelines()
+ .run(
+ name=f"{parent}/pipelines/{data_pipeline_name}",
+ body={},
+ )
+ )
+ response = request.execute(num_retries=self.num_retries)
+ return response
+
@staticmethod
def build_parent_name(project_id: str, location: str):
return f"projects/{project_id}/locations/{location}"
diff --git a/airflow/providers/google/cloud/operators/datapipeline.py b/airflow/providers/google/cloud/operators/datapipeline.py
index 9c55005231..2283b56ca4 100644
--- a/airflow/providers/google/cloud/operators/datapipeline.py
+++ b/airflow/providers/google/cloud/operators/datapipeline.py
@@ -100,3 +100,55 @@ class CreateDataPipelineOperator(GoogleCloudBaseOperator):
raise AirflowException(self.data_pipeline.get("error").get("message"))
return self.data_pipeline
+
+
+class RunDataPipelineOperator(GoogleCloudBaseOperator):
+ """
+ Runs a Data Pipelines Instance using the Data Pipelines API.
+
+ :param data_pipeline_name: The display name of the pipeline. In example
+ projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
+ :param project_id: The ID of the GCP project that owns the job.
+ :param location: The location to direct the Data Pipelines instance to (for example us-central1).
+ :param gcp_conn_id: The connection ID to connect to the Google Cloud
+ Platform.
+
+ Returns the created Job in JSON representation.
+ """
+
+ def __init__(
+ self,
+ data_pipeline_name: str,
+ project_id: str | None = None,
+ location: str = DEFAULT_DATAPIPELINE_LOCATION,
+ gcp_conn_id: str = "google_cloud_default",
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.data_pipeline_name = data_pipeline_name
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+
+ def execute(self, context: Context):
+ self.data_pipeline_hook = DataPipelineHook(gcp_conn_id=self.gcp_conn_id)
+
+ if self.data_pipeline_name is None:
+ raise AirflowException("Data Pipeline name not given; cannot run
unspecified pipeline.")
+ if self.project_id is None:
+ raise AirflowException("Data Pipeline Project ID not given; cannot
run pipeline.")
+ if self.location is None:
+ raise AirflowException("Data Pipeline location not given; cannot
run pipeline.")
+
+ self.response = self.data_pipeline_hook.run_data_pipeline(
+ data_pipeline_name=self.data_pipeline_name,
+ project_id=self.project_id,
+ location=self.location,
+ )
+
+ if self.response:
+ if "error" in self.response:
+ raise AirflowException(self.response.get("error").get("message"))
+
+ return self.response
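For orientation (not part of the commit), a hedged sketch of how the new operator could be wired
into a DAG; the DAG id, dates, and resource names are placeholder values:

    from datetime import datetime

    from airflow import models
    from airflow.providers.google.cloud.operators.datapipeline import RunDataPipelineOperator

    with models.DAG(
        "example_run_data_pipeline",            # placeholder DAG id
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ) as dag:
        run_pipeline = RunDataPipelineOperator(
            task_id="run_data_pipeline",
            data_pipeline_name="my-pipeline",   # placeholder PIPELINE_ID
            project_id="my-gcp-project",        # placeholder project
            location="us-central1",             # optional; defaults to DEFAULT_DATAPIPELINE_LOCATION
            gcp_conn_id="google_cloud_default", # optional; this is the default connection
        )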
diff --git a/docs/apache-airflow-providers-google/operators/cloud/datapipeline.rst b/docs/apache-airflow-providers-google/operators/cloud/datapipeline.rst
index cbd5873efc..4096aa7c9b 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/datapipeline.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/datapipeline.rst
@@ -55,6 +55,35 @@ Here is an example of how you can create a Data Pipelines instance by running th
:start-after: [START howto_operator_create_data_pipeline]
:end-before: [END howto_operator_create_data_pipeline]
+Running a Data Pipeline
+^^^^^^^^^^^^^^^^^^^^^^^
+
+To run a Data Pipelines instance, use :class:`~airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator`.
+The operator accesses Google Cloud's Data Pipelines API and calls upon the
+`run method <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/run>`__
+to run the given pipeline.
+
+:class:`~airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator` can take in four parameters:
+
+- ``data_pipeline_name``: the name of the Data Pipelines instance
+- ``project_id``: the ID of the GCP project that owns the job
+- ``location``: the location of the Data Pipelines instance
+- ``gcp_conn_id``: the connection ID to connect to the Google Cloud Platform
+
+Only the Data Pipeline name and Project ID are required parameters, as the Location and GCP Connection ID have default values.
+The Project ID and Location will be used to build the parent name, which is where the given Data Pipeline should be located.
+
+You can run a Data Pipelines instance by running the above parameters with RunDataPipelineOperator:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_run_data_pipeline]
+ :end-before: [END howto_operator_run_data_pipeline]
+
+Once called, the RunDataPipelineOperator will return the Google Cloud `Dataflow Job <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/Job>`__
+created by running the given pipeline.
+
For further information regarding the API usage, see
`Data Pipelines API REST Resource <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines#Pipeline>`__
in the Google Cloud documentation.
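Since execute() returns the Job, a downstream task can read it from XCom. A hedged sketch (not part
of the commit; the @task callable and the response-shape handling are illustrative, based on the
{"job": {...}} shape used in the hook tests below):

    from airflow.decorators import task

    @task
    def report_job(job: dict) -> None:
        # The returned JSON wraps the launched Dataflow Job.
        print("Launched Dataflow job:", job.get("job", {}).get("id"))

    # run_data_pipeline is a RunDataPipelineOperator task instance defined in a DAG;
    # .output exposes its XCom return value (the Job JSON).
    report_job(run_data_pipeline.output)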
diff --git a/tests/providers/google/cloud/hooks/test_datapipeline.py b/tests/providers/google/cloud/hooks/test_datapipeline.py
index 4c19c7488b..f636d20de2 100644
--- a/tests/providers/google/cloud/hooks/test_datapipeline.py
+++ b/tests/providers/google/cloud/hooks/test_datapipeline.py
@@ -108,3 +108,26 @@ class TestDataPipelineHook:
body=TEST_BODY,
)
assert result == {"name": TEST_PARENT}
+
+ @mock.patch("airflow.providers.google.cloud.hooks.datapipeline.DataPipelineHook.get_conn")
+ def test_run_data_pipeline(self, mock_connection):
+ """
+ Test that run_data_pipeline is called with correct parameters and
+ calls Google Data Pipelines API
+ """
+ mock_request = (
+ mock_connection.return_value.projects.return_value.locations.return_value.pipelines.return_value.run
+ )
+ mock_request.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}
+
+ result = self.datapipeline_hook.run_data_pipeline(
+ data_pipeline_name=TEST_DATA_PIPELINE_NAME,
+ project_id=TEST_PROJECTID,
+ location=TEST_LOCATION,
+ )
+
+ mock_request.assert_called_once_with(
+ name=TEST_NAME,
+ body={},
+ )
+ assert result == {"job": {"id": TEST_JOB_ID}}
diff --git a/tests/providers/google/cloud/operators/test_datapipeline.py b/tests/providers/google/cloud/operators/test_datapipeline.py
index 21f07193da..eab6e4cf23 100644
--- a/tests/providers/google/cloud/operators/test_datapipeline.py
+++ b/tests/providers/google/cloud/operators/test_datapipeline.py
@@ -24,6 +24,7 @@ import pytest as pytest
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.operators.datapipeline import (
CreateDataPipelineOperator,
+ RunDataPipelineOperator,
)
TASK_ID = "test-datapipeline-operators"
@@ -136,3 +137,92 @@ class TestCreateDataPipelineOperator:
}
with pytest.raises(AirflowException):
CreateDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+
+
+class TestRunDataPipelineOperator:
+ @pytest.fixture
+ def run_operator(self):
+ """
+ Create a RunDataPipelineOperator instance with test data
+ """
+ return RunDataPipelineOperator(
+ task_id=TASK_ID,
+ data_pipeline_name=TEST_DATA_PIPELINE_NAME,
+ project_id=TEST_PROJECTID,
+ location=TEST_LOCATION,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ )
+
+ @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
+ def test_execute(self, data_pipeline_hook_mock, run_operator):
+ """
+ Test Run Operator execute with correct parameters
+ """
+ run_operator.execute(mock.MagicMock())
+ data_pipeline_hook_mock.assert_called_once_with(
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ )
+
+ data_pipeline_hook_mock.return_value.run_data_pipeline.assert_called_once_with(
+ data_pipeline_name=TEST_DATA_PIPELINE_NAME,
+ project_id=TEST_PROJECTID,
+ location=TEST_LOCATION,
+ )
+
+ def test_invalid_data_pipeline_name(self):
+ """
+ Test that AirflowException is raised if Run Operator is not given a data pipeline name.
+ """
+ init_kwargs = {
+ "task_id": TASK_ID,
+ "data_pipeline_name": None,
+ "project_id": TEST_PROJECTID,
+ "location": TEST_LOCATION,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ }
+ with pytest.raises(AirflowException):
+ RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+
+ def test_invalid_project_id(self):
+ """
+ Test that AirflowException is raised if Run Operator is not given a project ID.
+ """
+ init_kwargs = {
+ "task_id": TASK_ID,
+ "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
+ "project_id": None,
+ "location": TEST_LOCATION,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ }
+ with pytest.raises(AirflowException):
+ RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+
+ def test_invalid_location(self):
+ """
+ Test that AirflowException is raised if Run Operator is not given a location.
+ """
+ init_kwargs = {
+ "task_id": TASK_ID,
+ "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
+ "project_id": TEST_PROJECTID,
+ "location": None,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ }
+ with pytest.raises(AirflowException):
+ RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+
+ def test_invalid_response(self):
+ """
+ Test that AirflowException is raised if Run Operator fails execution and returns error.
+ """
+ init_kwargs = {
+ "task_id": TASK_ID,
+ "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
+ "project_id": TEST_PROJECTID,
+ "location": TEST_LOCATION,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ }
+ with pytest.raises(AirflowException):
+ RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock()).return_value = {
+ "error": {"message": "example error"}
+ }
diff --git a/tests/system/providers/google/cloud/datapipelines/example_datapipeline.py b/tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
index 3eb67c530f..e4b82705a7 100644
--- a/tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
+++ b/tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
@@ -28,6 +28,7 @@ from pathlib import Path
from airflow import models
from airflow.providers.google.cloud.operators.datapipeline import (
CreateDataPipelineOperator,
+ RunDataPipelineOperator,
)
from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator
@@ -38,7 +39,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
GCP_LOCATION = os.environ.get("location", "us-central1")
-PIPELINE_NAME = "defualt-pipeline-name"
+PIPELINE_NAME = os.environ.get("DATA_PIPELINE_NAME", "defualt-pipeline-name")
PIPELINE_TYPE = "PIPELINE_TYPE_BATCH"
BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"
@@ -117,6 +118,13 @@ with models.DAG(
# when "teardown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()
+ # [START howto_operator_run_data_pipeline]
+ run_data_pipeline = RunDataPipelineOperator(
+ task_id="run_data_pipeline",
+ data_pipeline_name=PIPELINE_NAME,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_run_data_pipeline]
from tests.system.utils import get_test_run # noqa: E402
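Not shown in the hunks above, but for orientation: in the system-test DAG the run task would
typically be ordered after the create step, and the module tail exposes the DAG to pytest via
get_test_run. A sketch under those assumptions (the create_data_pipeline task name is assumed):

    # Assumed ordering inside the DAG context: create the pipeline, then run it.
    create_data_pipeline >> run_data_pipeline

    # Conventional tail of an Airflow system-test module so pytest can execute the DAG.
    test_run = get_test_run(dag)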