This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new db0679de12 Added Check for Cancel Workflow Invocation and added new Query Workflow Invocation operator (#36351)
db0679de12 is described below

commit db0679de128667ac07402202d5ee92e60a3f1f6b
Author: Varun Taware <[email protected]>
AuthorDate: Sat Dec 30 12:42:13 2023 +0100

    Added Check for Cancel Workflow Invocation and added new Query Workflow Invocation operator (#36351)
    
    * commit_for_pre_cancel_check_and_query_actions_operator
    
    * adding_pre_cancel_check_and_query_actions_operator_6
---
 airflow/providers/google/cloud/hooks/dataform.py   | 63 +++++++++++++++-
 .../providers/google/cloud/operators/dataform.py   | 84 ++++++++++++++++++++++
 .../operators/cloud/dataform.rst                   | 13 ++++
 .../providers/google/cloud/hooks/test_dataform.py  | 66 ++++++++++++++++-
 .../google/cloud/operators/test_dataform.py        | 28 ++++++++
 .../google/cloud/dataform/example_dataform.py      | 14 ++++
 6 files changed, 263 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/dataform.py 
b/airflow/providers/google/cloud/hooks/dataform.py
index 8d3c56d711..054e2bafe1 100644
--- a/airflow/providers/google/cloud/hooks/dataform.py
+++ b/airflow/providers/google/cloud/hooks/dataform.py
@@ -35,6 +35,7 @@ from airflow.providers.google.common.hooks.base_google import 
GoogleBaseHook
 
 if TYPE_CHECKING:
     from google.api_core.retry import Retry
+    from google.cloud.dataform_v1beta1.services.dataform.pagers import 
QueryWorkflowInvocationActionsPager
 
 
 class DataformHook(GoogleBaseHook):
@@ -236,6 +237,43 @@ class DataformHook(GoogleBaseHook):
             metadata=metadata,
         )
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def query_workflow_invocation_actions(
+        self,
+        project_id: str,
+        region: str,
+        repository_id: str,
+        workflow_invocation_id: str,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> QueryWorkflowInvocationActionsPager:
+        """
+        Fetches WorkflowInvocation actions.
+
+        :param project_id: Required. The ID of the Google Cloud project that 
the task belongs to.
+        :param region: Required. The ID of the Google Cloud region that the 
task belongs to.
+        :param repository_id: Required. The ID of the Dataform repository that 
the task belongs to.
+        :param workflow_invocation_id:  Required. The workflow invocation 
resource's id.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request 
as metadata.
+        """
+        client = self.get_dataform_client()
+        name = (
+            f"projects/{project_id}/locations/{region}/repositories/"
+            f"{repository_id}/workflowInvocations/{workflow_invocation_id}"
+        )
+        response = client.query_workflow_invocation_actions(
+            request={
+                "name": name,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return response
+
     @GoogleBaseHook.fallback_to_default_project_id
     def cancel_workflow_invocation(
         self,
@@ -263,9 +301,28 @@ class DataformHook(GoogleBaseHook):
             f"projects/{project_id}/locations/{region}/repositories/"
             f"{repository_id}/workflowInvocations/{workflow_invocation_id}"
         )
-        client.cancel_workflow_invocation(
-            request={"name": name}, retry=retry, timeout=timeout, 
metadata=metadata
-        )
+        try:
+            workflow_invocation = self.get_workflow_invocation(
+                project_id=project_id,
+                region=region,
+                repository_id=repository_id,
+                workflow_invocation_id=workflow_invocation_id,
+            )
+            state = workflow_invocation.state
+        except Exception as err:
+            raise AirflowException(
+                f"Dataform API returned error when waiting for workflow 
invocation:\n{err}"
+            )
+
+        if state == WorkflowInvocation.State.RUNNING:
+            client.cancel_workflow_invocation(
+                request={"name": name}, retry=retry, timeout=timeout, 
metadata=metadata
+            )
+        else:
+            self.log.info(
+                "Workflow is not active. Either the execution has already 
finished or has been canceled. "
+                "Please check the logs above for more details."
+            )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def create_repository(
diff --git a/airflow/providers/google/cloud/operators/dataform.py 
b/airflow/providers/google/cloud/operators/dataform.py
index 1617495f61..7b48332f3d 100644
--- a/airflow/providers/google/cloud/operators/dataform.py
+++ b/airflow/providers/google/cloud/operators/dataform.py
@@ -36,6 +36,7 @@ from google.cloud.dataform_v1beta1.types import (
     MakeDirectoryResponse,
     Repository,
     WorkflowInvocation,
+    WorkflowInvocationAction,
     Workspace,
     WriteFileResponse,
 )
@@ -348,6 +349,89 @@ class 
DataformGetWorkflowInvocationOperator(GoogleCloudBaseOperator):
         return WorkflowInvocation.to_dict(result)
 
 
+class DataformQueryWorkflowInvocationActionsOperator(GoogleCloudBaseOperator):
+    """
+    Returns WorkflowInvocationActions in a given WorkflowInvocation.
+
+    :param project_id: Required. The ID of the Google Cloud project that the 
task belongs to.
+    :param region: Required. The ID of the Google Cloud region that the task 
belongs to.
+    :param repository_id: Required. The ID of the Dataform repository that the 
task belongs to.
+    :param workflow_invocation_id:  the workflow invocation resource's id.
+    :param retry: Designation of what errors, if any, should be retried.
+    :param timeout: The timeout for this request.
+    :param metadata: Strings which should be sent along with the request as 
metadata.
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using 
short-term
+        credentials, or chained list of accounts required to get the 
access_token
+        of the last account in the list, which will be impersonated in the 
request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding 
identity, with first
+        account from the list granting this role to the originating account 
(templated).
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "repository_id",
+        "workflow_invocation_id",
+        "impersonation_chain",
+    )
+    operator_extra_links = (DataformWorkflowInvocationLink(),)
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        repository_id: str,
+        workflow_invocation_id: str,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.repository_id = repository_id
+        self.workflow_invocation_id = workflow_invocation_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        hook = DataformHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        DataformWorkflowInvocationLink.persist(
+            operator_instance=self,
+            context=context,
+            project_id=self.project_id,
+            region=self.region,
+            repository_id=self.repository_id,
+            workflow_invocation_id=self.workflow_invocation_id,
+        )
+        actions = hook.query_workflow_invocation_actions(
+            project_id=self.project_id,
+            region=self.region,
+            repository_id=self.repository_id,
+            workflow_invocation_id=self.workflow_invocation_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        actions_list = [WorkflowInvocationAction.to_dict(action) for action in 
actions]
+        self.log.info("Workflow Query invocation actions: %s", actions_list)
+        return actions_list
+
+
 class DataformCancelWorkflowInvocationOperator(GoogleCloudBaseOperator):
     """
     Requests cancellation of a running WorkflowInvocation.
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataform.rst 
b/docs/apache-airflow-providers-google/operators/cloud/dataform.rst
index 37bd4c8ba1..97d9d28e22 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataform.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataform.rst
@@ -120,6 +120,19 @@ To get a Workflow Invocation you can use:
     :start-after: [START howto_operator_get_workflow_invocation]
     :end-before: [END howto_operator_get_workflow_invocation]
 
+Query Workflow Invocation Action
+--------------------------------
+
+To query Workflow Invocation Actions you can use:
+
+:class:`~airflow.providers.google.cloud.operators.dataform.DataformQueryWorkflowInvocationActionsOperator`
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/dataform/example_dataform.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_query_workflow_invocation_actions]
+    :end-before: [END howto_operator_query_workflow_invocation_actions]
+
 Cancel Workflow Invocation
 --------------------------
 
diff --git a/tests/providers/google/cloud/hooks/test_dataform.py 
b/tests/providers/google/cloud/hooks/test_dataform.py
index 39c7ca9031..d5c94e40d5 100644
--- a/tests/providers/google/cloud/hooks/test_dataform.py
+++ b/tests/providers/google/cloud/hooks/test_dataform.py
@@ -16,11 +16,14 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 from unittest import mock
 
 import pytest
 from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud.dataform_v1beta1.types import WorkflowInvocation
 
+from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.dataform import DataformHook
 from tests.providers.google.cloud.utils.base_gcp_mock import 
mock_base_gcp_hook_default_project_id
 
@@ -144,8 +147,8 @@ class TestDataformHook:
         )
 
     @mock.patch(DATAFORM_STRING.format("DataformHook.get_dataform_client"))
-    def test_cancel_workflow_invocation(self, mock_client):
-        self.hook.cancel_workflow_invocation(
+    def test_query_workflow_invocation_actions(self, mock_client):
+        self.hook.query_workflow_invocation_actions(
             project_id=PROJECT_ID,
             region=REGION,
             repository_id=REPOSITORY_ID,
@@ -155,6 +158,30 @@ class TestDataformHook:
             f"projects/{PROJECT_ID}/locations/{REGION}/repositories/"
             f"{REPOSITORY_ID}/workflowInvocations/{WORKFLOW_INVOCATION_ID}"
         )
+        
mock_client.return_value.query_workflow_invocation_actions.assert_called_once_with(
+            request=dict(
+                name=name,
+            ),
+            retry=DEFAULT,
+            timeout=None,
+            metadata=(),
+        )
+
+    @mock.patch(DATAFORM_STRING.format("DataformHook.get_workflow_invocation"))
+    @mock.patch(DATAFORM_STRING.format("DataformHook.get_dataform_client"))
+    def test_cancel_workflow_invocation(self, mock_client, mock_state):
+        mock_state.return_value.state = WorkflowInvocation.State.RUNNING
+        name = (
+            f"projects/{PROJECT_ID}/locations/{REGION}/repositories/"
+            f"{REPOSITORY_ID}/workflowInvocations/{WORKFLOW_INVOCATION_ID}"
+        )
+
+        self.hook.cancel_workflow_invocation(
+            project_id=PROJECT_ID,
+            region=REGION,
+            repository_id=REPOSITORY_ID,
+            workflow_invocation_id=WORKFLOW_INVOCATION_ID,
+        )
         
mock_client.return_value.cancel_workflow_invocation.assert_called_once_with(
             request=dict(
                 name=name,
@@ -164,6 +191,41 @@ class TestDataformHook:
             metadata=(),
         )
 
+    @mock.patch(DATAFORM_STRING.format("DataformHook.get_workflow_invocation"))
+    @mock.patch(DATAFORM_STRING.format("DataformHook.get_dataform_client"))
+    def 
test_get_workflow_invocation_raises_exception_on_cancel_workflow_invocation(
+        self, mock_client, mock_state
+    ):
+        mock_client.return_value.get_dataform_client.return_value = None
+        mock_state.side_effect = AirflowException(
+            "Dataform API returned error when waiting for workflow invocation"
+        )
+
+        with pytest.raises(AirflowException, match="Dataform API returned 
error*."):
+            self.hook.cancel_workflow_invocation(
+                project_id=PROJECT_ID,
+                region=REGION,
+                repository_id=REPOSITORY_ID,
+                workflow_invocation_id=WORKFLOW_INVOCATION_ID,
+            )
+
+    @mock.patch(DATAFORM_STRING.format("DataformHook.get_workflow_invocation"))
+    @mock.patch(DATAFORM_STRING.format("DataformHook.get_dataform_client"))
+    def test_cancel_workflow_invocation_is_not_called(self, mock_client, 
mock_state, caplog):
+        mock_state.return_value.state = WorkflowInvocation.State.SUCCEEDED
+        expected_log = "Workflow is not active. Either the execution has 
already "
+        "finished or has been canceled. Please check the logs above "
+        "for more details."
+
+        with caplog.at_level(logging.INFO):
+            self.hook.cancel_workflow_invocation(
+                project_id=PROJECT_ID,
+                region=REGION,
+                repository_id=REPOSITORY_ID,
+                workflow_invocation_id=WORKFLOW_INVOCATION_ID,
+            )
+            assert expected_log in caplog.text
+
     @mock.patch(DATAFORM_STRING.format("DataformHook.get_dataform_client"))
     def test_create_repository(self, mock_client):
         self.hook.create_repository(
diff --git a/tests/providers/google/cloud/operators/test_dataform.py 
b/tests/providers/google/cloud/operators/test_dataform.py
index 031d99591c..7c519a9147 100644
--- a/tests/providers/google/cloud/operators/test_dataform.py
+++ b/tests/providers/google/cloud/operators/test_dataform.py
@@ -32,6 +32,7 @@ from airflow.providers.google.cloud.operators.dataform import 
(
     DataformGetWorkflowInvocationOperator,
     DataformInstallNpmPackagesOperator,
     DataformMakeDirectoryOperator,
+    DataformQueryWorkflowInvocationActionsOperator,
     DataformRemoveDirectoryOperator,
     DataformRemoveFileOperator,
     DataformWriteFileOperator,
@@ -39,6 +40,7 @@ from airflow.providers.google.cloud.operators.dataform import 
(
 
 HOOK_STR = "airflow.providers.google.cloud.operators.dataform.DataformHook"
 WORKFLOW_INVOCATION_STR = 
"airflow.providers.google.cloud.operators.dataform.WorkflowInvocation"
+WORKFLOW_INVOCATION_ACTION_STR = 
"airflow.providers.google.cloud.operators.dataform.WorkflowInvocationAction"
 COMPILATION_RESULT_STR = 
"airflow.providers.google.cloud.operators.dataform.CompilationResult"
 REPOSITORY_STR = "airflow.providers.google.cloud.operators.dataform.Repository"
 WORKSPACE_STR = "airflow.providers.google.cloud.operators.dataform.Workspace"
@@ -172,6 +174,32 @@ class TestDataformGetWorkflowInvocationOperator:
         )
 
 
+class TestDataformQueryWorkflowInvocationActionsOperator:
+    @mock.patch(HOOK_STR)
+    @mock.patch(WORKFLOW_INVOCATION_ACTION_STR)
+    def test_execute(self, workflow_invocation_action_str, hook_mock):
+        op = DataformQueryWorkflowInvocationActionsOperator(
+            task_id="query_workflow_invocation_action",
+            project_id=PROJECT_ID,
+            region=REGION,
+            repository_id=REPOSITORY_ID,
+            workflow_invocation_id=WORKFLOW_INVOCATION_ID,
+        )
+
+        workflow_invocation_action_str.return_value.to_dict.return_value = None
+        op.execute(context=mock.MagicMock())
+
+        
hook_mock.return_value.query_workflow_invocation_actions.assert_called_once_with(
+            project_id=PROJECT_ID,
+            region=REGION,
+            repository_id=REPOSITORY_ID,
+            workflow_invocation_id=WORKFLOW_INVOCATION_ID,
+            retry=DEFAULT,
+            timeout=None,
+            metadata=(),
+        )
+
+
 class TestDataformCancelWorkflowInvocationOperator:
     @mock.patch(HOOK_STR)
     def test_execute(self, hook_mock):
diff --git a/tests/system/providers/google/cloud/dataform/example_dataform.py 
b/tests/system/providers/google/cloud/dataform/example_dataform.py
index 4e5053868b..8cb017d978 100644
--- a/tests/system/providers/google/cloud/dataform/example_dataform.py
+++ b/tests/system/providers/google/cloud/dataform/example_dataform.py
@@ -39,6 +39,7 @@ from airflow.providers.google.cloud.operators.dataform import 
(
     DataformGetWorkflowInvocationOperator,
     DataformInstallNpmPackagesOperator,
     DataformMakeDirectoryOperator,
+    DataformQueryWorkflowInvocationActionsOperator,
     DataformRemoveDirectoryOperator,
     DataformRemoveFileOperator,
     DataformWriteFileOperator,
@@ -182,6 +183,18 @@ with DAG(
     )
     # [END howto_operator_get_workflow_invocation]
 
+    # [START howto_operator_query_workflow_invocation_actions]
+    query_workflow_invocation_actions = 
DataformQueryWorkflowInvocationActionsOperator(
+        task_id="query-workflow-invocation-actions",
+        project_id=PROJECT_ID,
+        region=REGION,
+        repository_id=REPOSITORY_ID,
+        workflow_invocation_id=(
+            "{{ 
task_instance.xcom_pull('create-workflow-invocation')['name'].split('/')[-1] }}"
+        ),
+    )
+    # [END howto_operator_query_workflow_invocation_actions]
+
     create_workflow_invocation_for_cancel = 
DataformCreateWorkflowInvocationOperator(
         task_id="create-workflow-invocation-for-cancel",
         project_id=PROJECT_ID,
@@ -291,6 +304,7 @@ with DAG(
         >> get_compilation_result
         >> create_workflow_invocation
         >> get_workflow_invocation
+        >> query_workflow_invocation_actions
         >> create_workflow_invocation_async
         >> is_workflow_invocation_done
         >> create_workflow_invocation_for_cancel

Reply via email to