This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c8de9a5f68 Adding Support for Google Cloud's Data Pipelines Run Operator (#32846)
c8de9a5f68 is described below

commit c8de9a5f686f55a27705a69d69fbc64840df03ce
Author: Brenda Pham <[email protected]>
AuthorDate: Sun Aug 20 22:33:54 2023 -0700

    Adding Support for Google Cloud's Data Pipelines Run Operator (#32846)
    
    
    ---------
    
    Co-authored-by: shaniyaclement <[email protected]>
    Co-authored-by: Brenda Pham <[email protected]>
    Co-authored-by: Shaniya Clement <[email protected]>
---
 .../providers/google/cloud/hooks/datapipeline.py   | 31 ++++++++
 .../google/cloud/operators/datapipeline.py         | 52 +++++++++++++
 .../operators/cloud/datapipeline.rst               | 29 +++++++
 .../google/cloud/hooks/test_datapipeline.py        | 23 ++++++
 .../google/cloud/operators/test_datapipeline.py    | 90 ++++++++++++++++++++++
 .../cloud/datapipelines/example_datapipeline.py    | 10 ++-
 6 files changed, 234 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/google/cloud/hooks/datapipeline.py b/airflow/providers/google/cloud/hooks/datapipeline.py
index c9c4790106..d141e24a71 100644
--- a/airflow/providers/google/cloud/hooks/datapipeline.py
+++ b/airflow/providers/google/cloud/hooks/datapipeline.py
@@ -85,6 +85,36 @@ class DataPipelineHook(GoogleBaseHook):
         response = request.execute(num_retries=self.num_retries)
         return response

+    @GoogleBaseHook.fallback_to_default_project_id
+    def run_data_pipeline(
+        self,
+        data_pipeline_name: str,
+        project_id: str,
+        location: str = DEFAULT_DATAPIPELINE_LOCATION,
+    ) -> dict:
+        """
+        Runs a Data Pipelines Instance using the Data Pipelines API.
+
+        :param data_pipeline_name: The display name of the pipeline. For example, in
+            projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
+        :param project_id: The ID of the GCP project that owns the job.
+        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
+        :return: The response of the ``run`` request, which contains the created Job, as a dict.
+        """
+        parent = self.build_parent_name(project_id, location)
+        service = self.get_conn()
+        request = (
+            service.projects()
+            .locations()
+            .pipelines()
+            .run(
+                name=f"{parent}/pipelines/{data_pipeline_name}",
+                body={},
+            )
+        )
+        response = request.execute(num_retries=self.num_retries)
+        return response
+
     @staticmethod
     def build_parent_name(project_id: str, location: str):
         return f"projects/{project_id}/locations/{location}"
diff --git a/airflow/providers/google/cloud/operators/datapipeline.py b/airflow/providers/google/cloud/operators/datapipeline.py
index 9c55005231..2283b56ca4 100644
--- a/airflow/providers/google/cloud/operators/datapipeline.py
+++ b/airflow/providers/google/cloud/operators/datapipeline.py
@@ -100,3 +100,57 @@ class CreateDataPipelineOperator(GoogleCloudBaseOperator):
                 raise AirflowException(self.data_pipeline.get("error").get("message"))

         return self.data_pipeline
+
+
+class RunDataPipelineOperator(GoogleCloudBaseOperator):
+    """
+    Runs a Data Pipelines Instance using the Data Pipelines API.
+
+    :param data_pipeline_name: The display name of the pipeline. For example, in
+        projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
+    :param project_id: The ID of the GCP project that owns the job.
+    :param location: The location to direct the Data Pipelines instance to (for example us-central1).
+    :param gcp_conn_id: The connection ID to connect to the Google Cloud
+        Platform.
+
+    Returns the response of the run request, which contains the created Job.
+    """
+
+    template_fields = ("data_pipeline_name", "project_id", "location", "gcp_conn_id")
+
+    def __init__(
+        self,
+        data_pipeline_name: str,
+        project_id: str | None = None,
+        location: str = DEFAULT_DATAPIPELINE_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.data_pipeline_name = data_pipeline_name
+        self.project_id = project_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+
+    def execute(self, context: Context):
+        # Validate the inputs before creating the hook so a missing value fails fast.
+        if self.data_pipeline_name is None:
+            raise AirflowException("Data Pipeline name not given; cannot run unspecified pipeline.")
+        if self.project_id is None:
+            raise AirflowException("Data Pipeline Project ID not given; cannot run pipeline.")
+        if self.location is None:
+            raise AirflowException("Data Pipeline location not given; cannot run pipeline.")
+
+        self.data_pipeline_hook = DataPipelineHook(gcp_conn_id=self.gcp_conn_id)
+        self.response = self.data_pipeline_hook.run_data_pipeline(
+            data_pipeline_name=self.data_pipeline_name,
+            project_id=self.project_id,
+            location=self.location,
+        )
+
+        # An "error" entry in the response means the run request failed.
+        if self.response and "error" in self.response:
+            raise AirflowException(self.response.get("error").get("message"))
+
+        return self.response
diff --git a/docs/apache-airflow-providers-google/operators/cloud/datapipeline.rst b/docs/apache-airflow-providers-google/operators/cloud/datapipeline.rst
index cbd5873efc..4096aa7c9b 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/datapipeline.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/datapipeline.rst
@@ -55,6 +55,35 @@ Here is an example of how you can create a Data Pipelines instance by running th
    :start-after: [START howto_operator_create_data_pipeline]
    :end-before: [END howto_operator_create_data_pipeline]
 
+Running a Data Pipeline
+^^^^^^^^^^^^^^^^^^^^^^^
+
+To run a Data Pipelines instance, use :class:`~airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator`.
+The operator accesses Google Cloud's Data Pipelines API and calls upon the
+`run method <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/run>`__
+to run the given pipeline.
+
+:class:`~airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator` can take in four parameters:
+
+- ``data_pipeline_name``: the name of the Data Pipelines instance
+- ``project_id``: the ID of the GCP project that owns the job
+- ``location``: the location of the Data Pipelines instance
+- ``gcp_conn_id``: the connection ID to connect to the Google Cloud Platform
+
+Only the Data Pipeline name and Project ID are required parameters, as the Location and GCP Connection ID have default values.
+The Project ID and Location are used to build the pipeline's full resource name, ``projects/PROJECT_ID/locations/LOCATION/pipelines/PIPELINE_NAME``.
+
+Here is an example of how you can run a Data Pipelines instance with RunDataPipelineOperator:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_run_data_pipeline]
+    :end-before: [END howto_operator_run_data_pipeline]
+
+Once called, the RunDataPipelineOperator returns the response of the run request, which contains the Google Cloud `Dataflow Job <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/Job>`__
+created by running the given pipeline.
+
 For further information regarding the API usage, see
 `Data Pipelines API REST Resource <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines#Pipeline>`__
 in the Google Cloud documentation.
diff --git a/tests/providers/google/cloud/hooks/test_datapipeline.py b/tests/providers/google/cloud/hooks/test_datapipeline.py
index 4c19c7488b..f636d20de2 100644
--- a/tests/providers/google/cloud/hooks/test_datapipeline.py
+++ b/tests/providers/google/cloud/hooks/test_datapipeline.py
@@ -108,3 +108,23 @@ class TestDataPipelineHook:
             body=TEST_BODY,
         )
         assert result == {"name": TEST_PARENT}
+
+    @mock.patch("airflow.providers.google.cloud.hooks.datapipeline.DataPipelineHook.get_conn")
+    def test_run_data_pipeline(self, mock_connection):
+        """
+        Test that run_data_pipeline calls the Data Pipelines API ``run`` method
+        with the full pipeline resource name and returns the API response.
+        """
+        expected_response = {"job": {"id": TEST_JOB_ID}}
+        mock_pipelines = mock_connection.return_value.projects.return_value.locations.return_value.pipelines
+        mock_run = mock_pipelines.return_value.run
+        mock_run.return_value.execute.return_value = expected_response
+
+        result = self.datapipeline_hook.run_data_pipeline(
+            data_pipeline_name=TEST_DATA_PIPELINE_NAME,
+            project_id=TEST_PROJECTID,
+            location=TEST_LOCATION,
+        )
+
+        mock_run.assert_called_once_with(name=TEST_NAME, body={})
+        assert result == expected_response
diff --git a/tests/providers/google/cloud/operators/test_datapipeline.py b/tests/providers/google/cloud/operators/test_datapipeline.py
index 21f07193da..eab6e4cf23 100644
--- a/tests/providers/google/cloud/operators/test_datapipeline.py
+++ b/tests/providers/google/cloud/operators/test_datapipeline.py
@@ -24,6 +24,7 @@ import pytest as pytest
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.operators.datapipeline import (
     CreateDataPipelineOperator,
+    RunDataPipelineOperator,
 )
 
 TASK_ID = "test-datapipeline-operators"
@@ -136,3 +137,94 @@ class TestCreateDataPipelineOperator:
         }
         with pytest.raises(AirflowException):
             CreateDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+
+
+class TestRunDataPipelineOperator:
+    @pytest.fixture
+    def run_operator(self):
+        """
+        Create a RunDataPipelineOperator instance with test data
+        """
+        return RunDataPipelineOperator(
+            task_id=TASK_ID,
+            data_pipeline_name=TEST_DATA_PIPELINE_NAME,
+            project_id=TEST_PROJECTID,
+            location=TEST_LOCATION,
+            gcp_conn_id=TEST_GCP_CONN_ID,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
+    def test_execute(self, data_pipeline_hook_mock, run_operator):
+        """
+        Test Run Operator execute with correct parameters
+        """
+        result = run_operator.execute(mock.MagicMock())
+        data_pipeline_hook_mock.assert_called_once_with(
+            gcp_conn_id=TEST_GCP_CONN_ID,
+        )
+
+        data_pipeline_hook_mock.return_value.run_data_pipeline.assert_called_once_with(
+            data_pipeline_name=TEST_DATA_PIPELINE_NAME,
+            project_id=TEST_PROJECTID,
+            location=TEST_LOCATION,
+        )
+        assert result is data_pipeline_hook_mock.return_value.run_data_pipeline.return_value
+
+    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
+    def test_invalid_data_pipeline_name(self, data_pipeline_hook_mock):
+        """
+        Test that AirflowException is raised if Run Operator is not given a data pipeline name.
+        """
+        init_kwargs = {
+            "task_id": TASK_ID,
+            "data_pipeline_name": None,
+            "project_id": TEST_PROJECTID,
+            "location": TEST_LOCATION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+        }
+        with pytest.raises(AirflowException, match="Data Pipeline name not given"):
+            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+        data_pipeline_hook_mock.return_value.run_data_pipeline.assert_not_called()
+
+    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
+    def test_invalid_project_id(self, data_pipeline_hook_mock):
+        """
+        Test that AirflowException is raised if Run Operator is not given a project ID.
+        """
+        init_kwargs = {
+            "task_id": TASK_ID,
+            "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
+            "project_id": None,
+            "location": TEST_LOCATION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+        }
+        with pytest.raises(AirflowException, match="Data Pipeline Project ID not given"):
+            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+        data_pipeline_hook_mock.return_value.run_data_pipeline.assert_not_called()
+
+    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
+    def test_invalid_location(self, data_pipeline_hook_mock):
+        """
+        Test that AirflowException is raised if Run Operator is not given a location.
+        """
+        init_kwargs = {
+            "task_id": TASK_ID,
+            "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
+            "project_id": TEST_PROJECTID,
+            "location": None,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+        }
+        with pytest.raises(AirflowException, match="Data Pipeline location not given"):
+            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+        data_pipeline_hook_mock.return_value.run_data_pipeline.assert_not_called()
+
+    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
+    def test_invalid_response(self, data_pipeline_hook_mock, run_operator):
+        """
+        Test that AirflowException is raised if Run Operator fails execution and returns error.
+        """
+        data_pipeline_hook_mock.return_value.run_data_pipeline.return_value = {
+            "error": {"message": "example error"}
+        }
+        with pytest.raises(AirflowException, match="example error"):
+            run_operator.execute(mock.MagicMock())
diff --git a/tests/system/providers/google/cloud/datapipelines/example_datapipeline.py b/tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
index 3eb67c530f..e4b82705a7 100644
--- a/tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
+++ b/tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
@@ -28,6 +28,7 @@ from pathlib import Path
 from airflow import models
 from airflow.providers.google.cloud.operators.datapipeline import (
     CreateDataPipelineOperator,
+    RunDataPipelineOperator,
 )
 from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
 from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator
@@ -38,7 +39,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
 GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
 GCP_LOCATION = os.environ.get("location", "us-central1")

-PIPELINE_NAME = "defualt-pipeline-name"
+PIPELINE_NAME = os.environ.get("DATA_PIPELINE_NAME", "default-pipeline-name")
 PIPELINE_TYPE = "PIPELINE_TYPE_BATCH"

 BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"
@@ -117,6 +118,16 @@ with models.DAG(
     # when "teardown" task with trigger rule is part of the DAG
     list(dag.tasks) >> watcher()

+    # TODO: move this task into the DAG's main task chain, before the watcher wiring above.
+    # As written it has no upstream dependency and is not tracked by the watcher.
+    # [START howto_operator_run_data_pipeline]
+    run_data_pipeline = RunDataPipelineOperator(
+        task_id="run_data_pipeline",
+        data_pipeline_name=PIPELINE_NAME,
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+    )
+    # [END howto_operator_run_data_pipeline]

 from tests.system.utils import get_test_run  # noqa: E402


Reply via email to