This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new db0679de12 Added Check for Cancel Workflow Invocation and added new
Query Workflow Invocation operator (#36351)
db0679de12 is described below
commit db0679de128667ac07402202d5ee92e60a3f1f6b
Author: Varun Taware <[email protected]>
AuthorDate: Sat Dec 30 12:42:13 2023 +0100
Added Check for Cancel Workflow Invocation and added new Query Workflow
Invocation operator (#36351)
* commit_for_pre_cancel_check_and_query_actions_operator
* adding_pre_cancel_check_and_query_actions_operator_6
---
airflow/providers/google/cloud/hooks/dataform.py | 63 +++++++++++++++-
.../providers/google/cloud/operators/dataform.py | 84 ++++++++++++++++++++++
.../operators/cloud/dataform.rst | 13 ++++
.../providers/google/cloud/hooks/test_dataform.py | 66 ++++++++++++++++-
.../google/cloud/operators/test_dataform.py | 28 ++++++++
.../google/cloud/dataform/example_dataform.py | 14 ++++
6 files changed, 263 insertions(+), 5 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/dataform.py
b/airflow/providers/google/cloud/hooks/dataform.py
index 8d3c56d711..054e2bafe1 100644
--- a/airflow/providers/google/cloud/hooks/dataform.py
+++ b/airflow/providers/google/cloud/hooks/dataform.py
@@ -35,6 +35,7 @@ from airflow.providers.google.common.hooks.base_google import
GoogleBaseHook
if TYPE_CHECKING:
from google.api_core.retry import Retry
+ from google.cloud.dataform_v1beta1.services.dataform.pagers import
QueryWorkflowInvocationActionsPager
class DataformHook(GoogleBaseHook):
@@ -236,6 +237,43 @@ class DataformHook(GoogleBaseHook):
metadata=metadata,
)
@GoogleBaseHook.fallback_to_default_project_id
def query_workflow_invocation_actions(
    self,
    project_id: str,
    region: str,
    repository_id: str,
    workflow_invocation_id: str,
    retry: Retry | _MethodDefault = DEFAULT,
    timeout: float | None = None,
    metadata: Sequence[tuple[str, str]] = (),
) -> QueryWorkflowInvocationActionsPager:
    """
    Fetches WorkflowInvocation actions.

    The full workflow invocation resource name is assembled from its parts and
    passed to the Dataform client's ``query_workflow_invocation_actions`` call.

    :param project_id: Required. The ID of the Google Cloud project that the task belongs to.
    :param region: Required. The ID of the Google Cloud region that the task belongs to.
    :param repository_id: Required. The ID of the Dataform repository that the task belongs to.
    :param workflow_invocation_id: Required. The workflow invocation resource's id.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout for this request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :return: The pager returned by the client; iterating it yields the
        individual workflow invocation actions.
    """
    dataform_client = self.get_dataform_client()
    invocation_name = (
        f"projects/{project_id}/locations/{region}"
        f"/repositories/{repository_id}"
        f"/workflowInvocations/{workflow_invocation_id}"
    )
    return dataform_client.query_workflow_invocation_actions(
        request={"name": invocation_name},
        retry=retry,
        timeout=timeout,
        metadata=metadata,
    )
+
@GoogleBaseHook.fallback_to_default_project_id
def cancel_workflow_invocation(
self,
@@ -263,9 +301,28 @@ class DataformHook(GoogleBaseHook):
f"projects/{project_id}/locations/{region}/repositories/"
f"{repository_id}/workflowInvocations/{workflow_invocation_id}"
)
- client.cancel_workflow_invocation(
- request={"name": name}, retry=retry, timeout=timeout,
metadata=metadata
- )
+ try:
+ workflow_invocation = self.get_workflow_invocation(
+ project_id=project_id,
+ region=region,
+ repository_id=repository_id,
+ workflow_invocation_id=workflow_invocation_id,
+ )
+ state = workflow_invocation.state
+ except Exception as err:
+ raise AirflowException(
+ f"Dataform API returned error when waiting for workflow
invocation:\n{err}"
+ )
+
+ if state == WorkflowInvocation.State.RUNNING:
+ client.cancel_workflow_invocation(
+ request={"name": name}, retry=retry, timeout=timeout,
metadata=metadata
+ )
+ else:
+ self.log.info(
+ "Workflow is not active. Either the execution has already
finished or has been canceled. "
+ "Please check the logs above for more details."
+ )
@GoogleBaseHook.fallback_to_default_project_id
def create_repository(
diff --git a/airflow/providers/google/cloud/operators/dataform.py
b/airflow/providers/google/cloud/operators/dataform.py
index 1617495f61..7b48332f3d 100644
--- a/airflow/providers/google/cloud/operators/dataform.py
+++ b/airflow/providers/google/cloud/operators/dataform.py
@@ -36,6 +36,7 @@ from google.cloud.dataform_v1beta1.types import (
MakeDirectoryResponse,
Repository,
WorkflowInvocation,
+ WorkflowInvocationAction,
Workspace,
WriteFileResponse,
)
@@ -348,6 +349,89 @@ class
DataformGetWorkflowInvocationOperator(GoogleCloudBaseOperator):
return WorkflowInvocation.to_dict(result)
class DataformQueryWorkflowInvocationActionsOperator(GoogleCloudBaseOperator):
    """
    Returns WorkflowInvocationActions in a given WorkflowInvocation.

    :param project_id: Required. The ID of the Google Cloud project that the task belongs to.
    :param region: Required. The ID of the Google Cloud region that the task belongs to.
    :param repository_id: Required. The ID of the Dataform repository that the task belongs to.
    :param workflow_invocation_id: Required. The workflow invocation resource's id.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout for this request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :return: A list with one dictionary per workflow invocation action, each produced
        by ``WorkflowInvocationAction.to_dict``.
    """

    template_fields = (
        "project_id",
        "region",
        "repository_id",
        "workflow_invocation_id",
        "impersonation_chain",
    )
    operator_extra_links = (DataformWorkflowInvocationLink(),)

    def __init__(
        self,
        project_id: str,
        region: str,
        repository_id: str,
        workflow_invocation_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Identity of the workflow invocation whose actions are queried.
        self.project_id = project_id
        self.region = region
        self.repository_id = repository_id
        self.workflow_invocation_id = workflow_invocation_id
        # Per-request API options, forwarded unchanged to the hook call.
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        # Connection / credential settings used to build the hook.
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        hook = DataformHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        # The same four identifiers locate the invocation for both the link and the query.
        invocation_ids = {
            "project_id": self.project_id,
            "region": self.region,
            "repository_id": self.repository_id,
            "workflow_invocation_id": self.workflow_invocation_id,
        }
        # The extra link is persisted before querying, so it is recorded even if the query raises.
        DataformWorkflowInvocationLink.persist(
            operator_instance=self,
            context=context,
            **invocation_ids,
        )
        pager = hook.query_workflow_invocation_actions(
            **invocation_ids,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
        # Convert every returned action message into a plain dict.
        actions_list = list(map(WorkflowInvocationAction.to_dict, pager))
        self.log.info("Workflow Query invocation actions: %s", actions_list)
        return actions_list
+
+
class DataformCancelWorkflowInvocationOperator(GoogleCloudBaseOperator):
"""
Requests cancellation of a running WorkflowInvocation.
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataform.rst
b/docs/apache-airflow-providers-google/operators/cloud/dataform.rst
index 37bd4c8ba1..97d9d28e22 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataform.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataform.rst
@@ -120,6 +120,19 @@ To get a Workflow Invocation you can use:
:start-after: [START howto_operator_get_workflow_invocation]
:end-before: [END howto_operator_get_workflow_invocation]
+Query Workflow Invocation Actions
+---------------------------------
+
+To query Workflow Invocation Actions you can use:
+
+:class:`~airflow.providers.google.cloud.operators.dataform.DataformQueryWorkflowInvocationActionsOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataform/example_dataform.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_query_workflow_invocation_actions]
+ :end-before: [END howto_operator_query_workflow_invocation_actions]
+
Cancel Workflow Invocation
--------------------------
diff --git a/tests/providers/google/cloud/hooks/test_dataform.py
b/tests/providers/google/cloud/hooks/test_dataform.py
index 39c7ca9031..d5c94e40d5 100644
--- a/tests/providers/google/cloud/hooks/test_dataform.py
+++ b/tests/providers/google/cloud/hooks/test_dataform.py
@@ -16,11 +16,14 @@
# under the License.
from __future__ import annotations
+import logging
from unittest import mock
import pytest
from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud.dataform_v1beta1.types import WorkflowInvocation
+from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataform import DataformHook
from tests.providers.google.cloud.utils.base_gcp_mock import
mock_base_gcp_hook_default_project_id
@@ -144,8 +147,8 @@ class TestDataformHook:
)
@mock.patch(DATAFORM_STRING.format("DataformHook.get_dataform_client"))
- def test_cancel_workflow_invocation(self, mock_client):
- self.hook.cancel_workflow_invocation(
+ def test_query_workflow_invocation_actions(self, mock_client):
+ self.hook.query_workflow_invocation_actions(
project_id=PROJECT_ID,
region=REGION,
repository_id=REPOSITORY_ID,
@@ -155,6 +158,30 @@ class TestDataformHook:
f"projects/{PROJECT_ID}/locations/{REGION}/repositories/"
f"{REPOSITORY_ID}/workflowInvocations/{WORKFLOW_INVOCATION_ID}"
)
+
mock_client.return_value.query_workflow_invocation_actions.assert_called_once_with(
+ request=dict(
+ name=name,
+ ),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch(DATAFORM_STRING.format("DataformHook.get_workflow_invocation"))
+ @mock.patch(DATAFORM_STRING.format("DataformHook.get_dataform_client"))
+ def test_cancel_workflow_invocation(self, mock_client, mock_state):
+ mock_state.return_value.state = WorkflowInvocation.State.RUNNING
+ name = (
+ f"projects/{PROJECT_ID}/locations/{REGION}/repositories/"
+ f"{REPOSITORY_ID}/workflowInvocations/{WORKFLOW_INVOCATION_ID}"
+ )
+
+ self.hook.cancel_workflow_invocation(
+ project_id=PROJECT_ID,
+ region=REGION,
+ repository_id=REPOSITORY_ID,
+ workflow_invocation_id=WORKFLOW_INVOCATION_ID,
+ )
mock_client.return_value.cancel_workflow_invocation.assert_called_once_with(
request=dict(
name=name,
@@ -164,6 +191,41 @@ class TestDataformHook:
metadata=(),
)
@mock.patch(DATAFORM_STRING.format("DataformHook.get_workflow_invocation"))
@mock.patch(DATAFORM_STRING.format("DataformHook.get_dataform_client"))
def test_get_workflow_invocation_raises_exception_on_cancel_workflow_invocation(
    self, mock_client, mock_state
):
    # A generic (non-Airflow) error from the state lookup must be wrapped by the hook
    # into an AirflowException; raising AirflowException here would not prove the
    # wrapping happens.
    mock_state.side_effect = RuntimeError("API unavailable")

    # `match` is applied with re.search, so a plain prefix is enough.
    with pytest.raises(AirflowException, match="Dataform API returned error"):
        self.hook.cancel_workflow_invocation(
            project_id=PROJECT_ID,
            region=REGION,
            repository_id=REPOSITORY_ID,
            workflow_invocation_id=WORKFLOW_INVOCATION_ID,
        )
    # The cancellation request must not be sent when the state cannot be read.
    mock_client.return_value.cancel_workflow_invocation.assert_not_called()
+
@mock.patch(DATAFORM_STRING.format("DataformHook.get_workflow_invocation"))
@mock.patch(DATAFORM_STRING.format("DataformHook.get_dataform_client"))
def test_cancel_workflow_invocation_is_not_called(self, mock_client, mock_state, caplog):
    # A workflow invocation that already finished must not be cancelled again.
    mock_state.return_value.state = WorkflowInvocation.State.SUCCEEDED
    # Parenthesized so the three literals form ONE string. Without the parentheses
    # the last two literals are standalone no-op statements and only the first
    # fragment would be checked against the log.
    expected_log = (
        "Workflow is not active. Either the execution has already "
        "finished or has been canceled. Please check the logs above "
        "for more details."
    )

    with caplog.at_level(logging.INFO):
        self.hook.cancel_workflow_invocation(
            project_id=PROJECT_ID,
            region=REGION,
            repository_id=REPOSITORY_ID,
            workflow_invocation_id=WORKFLOW_INVOCATION_ID,
        )
    assert expected_log in caplog.text
    # The behavior this test is named after: no cancellation request is sent.
    mock_client.return_value.cancel_workflow_invocation.assert_not_called()
+
@mock.patch(DATAFORM_STRING.format("DataformHook.get_dataform_client"))
def test_create_repository(self, mock_client):
self.hook.create_repository(
diff --git a/tests/providers/google/cloud/operators/test_dataform.py
b/tests/providers/google/cloud/operators/test_dataform.py
index 031d99591c..7c519a9147 100644
--- a/tests/providers/google/cloud/operators/test_dataform.py
+++ b/tests/providers/google/cloud/operators/test_dataform.py
@@ -32,6 +32,7 @@ from airflow.providers.google.cloud.operators.dataform import
(
DataformGetWorkflowInvocationOperator,
DataformInstallNpmPackagesOperator,
DataformMakeDirectoryOperator,
+ DataformQueryWorkflowInvocationActionsOperator,
DataformRemoveDirectoryOperator,
DataformRemoveFileOperator,
DataformWriteFileOperator,
@@ -39,6 +40,7 @@ from airflow.providers.google.cloud.operators.dataform import
(
HOOK_STR = "airflow.providers.google.cloud.operators.dataform.DataformHook"
WORKFLOW_INVOCATION_STR =
"airflow.providers.google.cloud.operators.dataform.WorkflowInvocation"
+WORKFLOW_INVOCATION_ACTION_STR =
"airflow.providers.google.cloud.operators.dataform.WorkflowInvocationAction"
COMPILATION_RESULT_STR =
"airflow.providers.google.cloud.operators.dataform.CompilationResult"
REPOSITORY_STR = "airflow.providers.google.cloud.operators.dataform.Repository"
WORKSPACE_STR = "airflow.providers.google.cloud.operators.dataform.Workspace"
@@ -172,6 +174,32 @@ class TestDataformGetWorkflowInvocationOperator:
)
class TestDataformQueryWorkflowInvocationActionsOperator:
    @mock.patch(HOOK_STR)
    @mock.patch(WORKFLOW_INVOCATION_ACTION_STR)
    def test_execute(self, workflow_invocation_action_mock, hook_mock):
        op = DataformQueryWorkflowInvocationActionsOperator(
            task_id="query_workflow_invocation_action",
            project_id=PROJECT_ID,
            region=REGION,
            repository_id=REPOSITORY_ID,
            workflow_invocation_id=WORKFLOW_INVOCATION_ID,
        )
        first_action, second_action = mock.MagicMock(), mock.MagicMock()
        hook_mock.return_value.query_workflow_invocation_actions.return_value = [
            first_action,
            second_action,
        ]
        # The operator calls WorkflowInvocationAction.to_dict(action) on the class itself,
        # so the class-level attribute is configured (not return_value.to_dict).
        expected = [{"name": "action-1"}, {"name": "action-2"}]
        workflow_invocation_action_mock.to_dict.side_effect = expected

        result = op.execute(context=mock.MagicMock())

        hook_mock.assert_called_once_with(
            gcp_conn_id="google_cloud_default",
            impersonation_chain=None,
        )
        hook_mock.return_value.query_workflow_invocation_actions.assert_called_once_with(
            project_id=PROJECT_ID,
            region=REGION,
            repository_id=REPOSITORY_ID,
            workflow_invocation_id=WORKFLOW_INVOCATION_ID,
            retry=DEFAULT,
            timeout=None,
            metadata=(),
        )
        # Every action returned by the hook is converted, in order.
        workflow_invocation_action_mock.to_dict.assert_has_calls(
            [mock.call(first_action), mock.call(second_action)]
        )
        assert result == expected
+
+
class TestDataformCancelWorkflowInvocationOperator:
@mock.patch(HOOK_STR)
def test_execute(self, hook_mock):
diff --git a/tests/system/providers/google/cloud/dataform/example_dataform.py
b/tests/system/providers/google/cloud/dataform/example_dataform.py
index 4e5053868b..8cb017d978 100644
--- a/tests/system/providers/google/cloud/dataform/example_dataform.py
+++ b/tests/system/providers/google/cloud/dataform/example_dataform.py
@@ -39,6 +39,7 @@ from airflow.providers.google.cloud.operators.dataform import
(
DataformGetWorkflowInvocationOperator,
DataformInstallNpmPackagesOperator,
DataformMakeDirectoryOperator,
+ DataformQueryWorkflowInvocationActionsOperator,
DataformRemoveDirectoryOperator,
DataformRemoveFileOperator,
DataformWriteFileOperator,
@@ -182,6 +183,18 @@ with DAG(
)
# [END howto_operator_get_workflow_invocation]
+ # [START howto_operator_query_workflow_invocation_actions]
+ query_workflow_invocation_actions =
DataformQueryWorkflowInvocationActionsOperator(
+ task_id="query-workflow-invocation-actions",
+ project_id=PROJECT_ID,
+ region=REGION,
+ repository_id=REPOSITORY_ID,
+ workflow_invocation_id=(
+ "{{
task_instance.xcom_pull('create-workflow-invocation')['name'].split('/')[-1] }}"
+ ),
+ )
+ # [END howto_operator_query_workflow_invocation_actions]
+
create_workflow_invocation_for_cancel =
DataformCreateWorkflowInvocationOperator(
task_id="create-workflow-invocation-for-cancel",
project_id=PROJECT_ID,
@@ -291,6 +304,7 @@ with DAG(
>> get_compilation_result
>> create_workflow_invocation
>> get_workflow_invocation
+ >> query_workflow_invocation_actions
>> create_workflow_invocation_async
>> is_workflow_invocation_done
>> create_workflow_invocation_for_cancel