This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 86e613029b Implement CloudComposerDAGRunSensor (#40088)
86e613029b is described below
commit 86e613029b871b0a8327d64c040da56f537c0727
Author: Maksim <[email protected]>
AuthorDate: Fri Jun 7 06:51:46 2024 -0700
Implement CloudComposerDAGRunSensor (#40088)
---
.../google/cloud/sensors/cloud_composer.py | 173 ++++++++++++++++++++-
.../google/cloud/triggers/cloud_composer.py | 115 ++++++++++++++
.../operators/cloud/cloud_composer.rst | 20 +++
.../google/cloud/sensors/test_cloud_composer.py | 63 +++++++-
.../google/cloud/triggers/test_cloud_composer.py | 61 +++++++-
.../cloud/composer/example_cloud_composer.py | 25 +++
6 files changed, 447 insertions(+), 10 deletions(-)
diff --git a/airflow/providers/google/cloud/sensors/cloud_composer.py
b/airflow/providers/google/cloud/sensors/cloud_composer.py
index 22d16e8f33..0301466eac 100644
--- a/airflow/providers/google/cloud/sensors/cloud_composer.py
+++ b/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -19,13 +19,24 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Sequence
+import json
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, Iterable, Sequence
+from dateutil import parser
from deprecated import deprecated
+from google.cloud.orchestration.airflow.service_v1.types import
ExecuteAirflowCommandResponse
+from airflow.configuration import conf
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, AirflowSkipException
-from airflow.providers.google.cloud.triggers.cloud_composer import
CloudComposerExecutionTrigger
+from airflow.providers.google.cloud.hooks.cloud_composer import
CloudComposerHook
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+ CloudComposerDAGRunTrigger,
+ CloudComposerExecutionTrigger,
+)
+from airflow.providers.google.common.consts import
GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -117,3 +128,161 @@ class CloudComposerEnvironmentSensor(BaseSensorOperator):
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
+
+
+class CloudComposerDAGRunSensor(BaseSensorOperator):
+ """
+ Check if a DAG run has completed.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param region: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param environment_id: The name of the Composer environment.
+ :param composer_dag_id: The ID of the DAG whose runs are checked.
+ :param allowed_states: Iterable of allowed states, default is
``['success']``.
+ :param execution_range: Time range of the DAG runs to check. The sensor
only considers DAG runs whose
+ execution date falls within this range. To look back one day, use a
[positive!] datetime.timedelta(days=1).
+ To look one day ahead, use a [negative!] datetime.timedelta(days=-1). For a
specific period, use a list of two
+ datetimes: [datetime(2024,3,22,11,0,0), datetime(2024,3,22,12,0,0)].
+ With a single-element list, e.g. [datetime(2024,3,22,0,0,0)], the sensor
checks from that time
+ up to the logical date of the current run.
+ Defaults to datetime.timedelta(days=1).
+ :param gcp_conn_id: The connection ID to use when fetching connection info.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ :param poll_interval: Optional. Interval, in seconds, between polls for the
result in deferrable mode.
+ :param deferrable: Run sensor in deferrable mode.
+ """
+
+ template_fields = (
+ "project_id",
+ "region",
+ "environment_id",
+ "composer_dag_id",
+ "impersonation_chain",
+ )
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ region: str,
+ environment_id: str,
+ composer_dag_id: str,
+ allowed_states: Iterable[str] | None = None,
+ execution_range: timedelta | list[datetime] | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ poll_interval: int = 10,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.region = region
+ self.environment_id = environment_id
+ self.composer_dag_id = composer_dag_id
+ self.allowed_states = list(allowed_states) if allowed_states else
[TaskInstanceState.SUCCESS.value]
+ self.execution_range = execution_range
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.deferrable = deferrable
+ self.poll_interval = poll_interval
+
+ def _get_execution_dates(self, context) -> tuple[datetime, datetime]:
+ if isinstance(self.execution_range, timedelta):
+ if self.execution_range < timedelta(0):
+ return context["logical_date"], context["logical_date"] -
self.execution_range
+ else:
+ return context["logical_date"] - self.execution_range,
context["logical_date"]
+ elif isinstance(self.execution_range, list) and
len(self.execution_range) > 0:
+ return self.execution_range[0], self.execution_range[1] if len(
+ self.execution_range
+ ) > 1 else context["logical_date"]
+ else:
+ return context["logical_date"] - timedelta(1),
context["logical_date"]
+
+ def poke(self, context: Context) -> bool:
+ start_date, end_date = self._get_execution_dates(context)
+
+ if datetime.now(end_date.tzinfo) < end_date:
+ return False
+
+ dag_runs = self._pull_dag_runs()
+
+ self.log.info("Sensor waits for allowed states: %s",
self.allowed_states)
+ allowed_states_status = self._check_dag_runs_states(
+ dag_runs=dag_runs,
+ start_date=start_date,
+ end_date=end_date,
+ )
+
+ return allowed_states_status
+
+ def _pull_dag_runs(self) -> list[dict]:
+ """Pull the list of dag runs."""
+ hook = CloudComposerHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ dag_runs_cmd = hook.execute_airflow_command(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ command="dags",
+ subcommand="list-runs",
+ parameters=["-d", self.composer_dag_id, "-o", "json"],
+ )
+ cmd_result = hook.wait_command_execution_result(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+
execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+ )
+ dag_runs = json.loads(cmd_result["output"][0]["content"])
+ return dag_runs
+
+ def _check_dag_runs_states(
+ self,
+ dag_runs: list[dict],
+ start_date: datetime,
+ end_date: datetime,
+ ) -> bool:
+ for dag_run in dag_runs:
+ if (
+ start_date.timestamp()
+ < parser.parse(dag_run["execution_date"]).timestamp()
+ < end_date.timestamp()
+ ) and dag_run["state"] not in self.allowed_states:
+ return False
+ return True
+
+ def execute(self, context: Context) -> None:
+ if self.deferrable:
+ start_date, end_date = self._get_execution_dates(context)
+ self.defer(
+ trigger=CloudComposerDAGRunTrigger(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ composer_dag_id=self.composer_dag_id,
+ start_date=start_date,
+ end_date=end_date,
+ allowed_states=self.allowed_states,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ poll_interval=self.poll_interval,
+ ),
+ method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
+ )
+ super().execute(context)
+
+ def execute_complete(self, context: Context, event: dict):
+ if event and event["status"] == "error":
+ raise AirflowException(event["message"])
+ self.log.info("DAG %s has executed successfully.",
self.composer_dag_id)
diff --git a/airflow/providers/google/cloud/triggers/cloud_composer.py
b/airflow/providers/google/cloud/triggers/cloud_composer.py
index ac5a00c60f..2334d038e6 100644
--- a/airflow/providers/google/cloud/triggers/cloud_composer.py
+++ b/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -19,8 +19,13 @@
from __future__ import annotations
import asyncio
+import json
+from datetime import datetime
from typing import Any, Sequence
+from dateutil import parser
+from google.cloud.orchestration.airflow.service_v1.types import
ExecuteAirflowCommandResponse
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.cloud_composer import
CloudComposerAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -146,3 +151,113 @@ class CloudComposerAirflowCLICommandTrigger(BaseTrigger):
}
)
return
+
+
+class CloudComposerDAGRunTrigger(BaseTrigger):
+ """The trigger waits for the DAG run completion."""
+
+ def __init__(
+ self,
+ project_id: str,
+ region: str,
+ environment_id: str,
+ composer_dag_id: str,
+ start_date: datetime,
+ end_date: datetime,
+ allowed_states: list[str],
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: int = 10,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.region = region
+ self.environment_id = environment_id
+ self.composer_dag_id = composer_dag_id
+ self.start_date = start_date
+ self.end_date = end_date
+ self.allowed_states = allowed_states
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+
+ self.gcp_hook = CloudComposerAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerDAGRunTrigger",
+ {
+ "project_id": self.project_id,
+ "region": self.region,
+ "environment_id": self.environment_id,
+ "composer_dag_id": self.composer_dag_id,
+ "start_date": self.start_date,
+ "end_date": self.end_date,
+ "allowed_states": self.allowed_states,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ },
+ )
+
+ async def _pull_dag_runs(self) -> list[dict]:
+ """Pull the list of dag runs."""
+ dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ command="dags",
+ subcommand="list-runs",
+ parameters=["-d", self.composer_dag_id, "-o", "json"],
+ )
+ cmd_result = await self.gcp_hook.wait_command_execution_result(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+
execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+ )
+ dag_runs = json.loads(cmd_result["output"][0]["content"])
+ return dag_runs
+
+ def _check_dag_runs_states(
+ self,
+ dag_runs: list[dict],
+ start_date: datetime,
+ end_date: datetime,
+ ) -> bool:
+ for dag_run in dag_runs:
+ if (
+ start_date.timestamp()
+ < parser.parse(dag_run["execution_date"]).timestamp()
+ < end_date.timestamp()
+ ) and dag_run["state"] not in self.allowed_states:
+ return False
+ return True
+
+ async def run(self):
+ try:
+ while True:
+ if datetime.now(self.end_date.tzinfo).timestamp() >
self.end_date.timestamp():
+ dag_runs = await self._pull_dag_runs()
+
+ self.log.info("Sensor waits for allowed states: %s",
self.allowed_states)
+ if self._check_dag_runs_states(
+ dag_runs=dag_runs,
+ start_date=self.start_date,
+ end_date=self.end_date,
+ ):
+ yield TriggerEvent({"status": "success"})
+ return
+ self.log.info("Sleeping for %s seconds.", self.poll_interval)
+ await asyncio.sleep(self.poll_interval)
+ except AirflowException as ex:
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(ex),
+ }
+ )
+ return
diff --git
a/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
b/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
index cdb9cb2931..f8f00fbe6c 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
@@ -177,3 +177,23 @@ or you can define the same operator in the deferrable mode:
:dedent: 4
:start-after: [START
howto_operator_run_airflow_cli_command_deferrable_mode]
:end-before: [END howto_operator_run_airflow_cli_command_deferrable_mode]
+
+Check if a DAG run has completed
+--------------------------------
+
+To check whether a DAG run has completed in your environment, use the
sensor:
+:class:`~airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerDAGRunSensor`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/composer/example_cloud_composer.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_dag_run]
+ :end-before: [END howto_sensor_dag_run]
+
+or you can define the same sensor in the deferrable mode:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/composer/example_cloud_composer.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_dag_run_deferrable_mode]
+ :end-before: [END howto_sensor_dag_run_deferrable_mode]
diff --git a/tests/providers/google/cloud/sensors/test_cloud_composer.py
b/tests/providers/google/cloud/sensors/test_cloud_composer.py
index 5241ff551e..c22eb90fde 100644
--- a/tests/providers/google/cloud/sensors/test_cloud_composer.py
+++ b/tests/providers/google/cloud/sensors/test_cloud_composer.py
@@ -17,17 +17,42 @@
from __future__ import annotations
+import json
+from datetime import datetime
from unittest import mock
import pytest
from airflow.exceptions import AirflowException, AirflowSkipException,
TaskDeferred
-from airflow.providers.google.cloud.sensors.cloud_composer import
CloudComposerEnvironmentSensor
-from airflow.providers.google.cloud.triggers.cloud_composer import
CloudComposerExecutionTrigger
+from airflow.providers.google.cloud.sensors.cloud_composer import (
+ CloudComposerDAGRunSensor,
+ CloudComposerEnvironmentSensor,
+)
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+ CloudComposerExecutionTrigger,
+)
TEST_PROJECT_ID = "test_project_id"
TEST_OPERATION_NAME = "test_operation_name"
TEST_REGION = "region"
+TEST_ENVIRONMENT_ID = "test_env_id"
+TEST_JSON_RESULT = lambda state: json.dumps(
+ [
+ {
+ "dag_id": "test_dag_id",
+ "run_id": "scheduled__2024-05-22T11:10:00+00:00",
+ "state": state,
+ "execution_date": "2024-05-22T11:10:00+00:00",
+ "start_date": "2024-05-22T11:20:01.531988+00:00",
+ "end_date": "2024-05-22T11:20:11.997479+00:00",
+ }
+ ]
+)
+TEST_EXEC_RESULT = lambda state: {
+ "output": [{"line_number": 1, "content": TEST_JSON_RESULT(state)}],
+ "output_end": True,
+ "exit_info": {"exit_code": 0, "error": ""},
+}
class TestCloudComposerEnvironmentSensor:
@@ -76,3 +101,37 @@ class TestCloudComposerEnvironmentSensor:
task.execute_complete(
context={}, event={"operation_done": True, "operation_name":
TEST_OPERATION_NAME}
)
+
+
+class TestCloudComposerDAGRunSensor:
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_wait_ready(self, mock_hook, to_dict_mode):
+ mock_hook.return_value.wait_command_execution_result.return_value =
TEST_EXEC_RESULT("success")
+
+ task = CloudComposerDAGRunSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_dag_id="test_dag_id",
+ allowed_states=["success"],
+ )
+
+ assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0,
0)})
+
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_wait_not_ready(self, mock_hook, to_dict_mode):
+ mock_hook.return_value.wait_command_execution_result.return_value =
TEST_EXEC_RESULT("running")
+
+ task = CloudComposerDAGRunSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_dag_id="test_dag_id",
+ allowed_states=["success"],
+ )
+
+ assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0,
0, 0)})
diff --git a/tests/providers/google/cloud/triggers/test_cloud_composer.py
b/tests/providers/google/cloud/triggers/test_cloud_composer.py
index 99daaf83bd..00d109ed97 100644
--- a/tests/providers/google/cloud/triggers/test_cloud_composer.py
+++ b/tests/providers/google/cloud/triggers/test_cloud_composer.py
@@ -17,12 +17,16 @@
from __future__ import annotations
+from datetime import datetime
from unittest import mock
import pytest
from airflow.models import Connection
-from airflow.providers.google.cloud.triggers.cloud_composer import
CloudComposerAirflowCLICommandTrigger
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+ CloudComposerAirflowCLICommandTrigger,
+ CloudComposerDAGRunTrigger,
+)
from airflow.triggers.base import TriggerEvent
TEST_PROJECT_ID = "test-project-id"
@@ -34,6 +38,10 @@ TEST_EXEC_CMD_INFO = {
"pod_namespace": "test_namespace",
"error": "test_error",
}
+TEST_COMPOSER_DAG_ID = "test_dag_id"
+TEST_START_DATE = datetime(2024, 3, 22, 11, 0, 0)
+TEST_END_DATE = datetime(2024, 3, 22, 12, 0, 0)
+TEST_STATES = ["success"]
TEST_GCP_CONN_ID = "test_gcp_conn_id"
TEST_POLL_INTERVAL = 10
TEST_IMPERSONATION_CHAIN = "test_impersonation_chain"
@@ -49,7 +57,7 @@ TEST_EXEC_RESULT = {
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
return_value=Connection(conn_id="test_conn"),
)
-def trigger(mock_conn):
+def cli_command_trigger(mock_conn):
return CloudComposerAirflowCLICommandTrigger(
project_id=TEST_PROJECT_ID,
region=TEST_LOCATION,
@@ -61,9 +69,29 @@ def trigger(mock_conn):
)
[email protected]
[email protected](
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
+ return_value=Connection(conn_id="test_conn"),
+)
+def dag_run_trigger(mock_conn):
+ return CloudComposerDAGRunTrigger(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_LOCATION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_dag_id=TEST_COMPOSER_DAG_ID,
+ start_date=TEST_START_DATE,
+ end_date=TEST_END_DATE,
+ allowed_states=TEST_STATES,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
+ poll_interval=TEST_POLL_INTERVAL,
+ )
+
+
class TestCloudComposerAirflowCLICommandTrigger:
- def test_serialize(self, trigger):
- actual_data = trigger.serialize()
+ def test_serialize(self, cli_command_trigger):
+ actual_data = cli_command_trigger.serialize()
expected_data = (
"airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerAirflowCLICommandTrigger",
{
@@ -82,7 +110,7 @@ class TestCloudComposerAirflowCLICommandTrigger:
@mock.patch(
"airflow.providers.google.cloud.hooks.cloud_composer.CloudComposerAsyncHook.wait_command_execution_result"
)
- async def test_run(self, mock_exec_result, trigger):
+ async def test_run(self, mock_exec_result, cli_command_trigger):
mock_exec_result.return_value = TEST_EXEC_RESULT
expected_event = TriggerEvent(
@@ -91,6 +119,27 @@ class TestCloudComposerAirflowCLICommandTrigger:
"result": TEST_EXEC_RESULT,
}
)
- actual_event = await trigger.run().asend(None)
+ actual_event = await cli_command_trigger.run().asend(None)
assert actual_event == expected_event
+
+
+class TestCloudComposerDAGRunTrigger:
+ def test_serialize(self, dag_run_trigger):
+ actual_data = dag_run_trigger.serialize()
+ expected_data = (
+
"airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerDAGRunTrigger",
+ {
+ "project_id": TEST_PROJECT_ID,
+ "region": TEST_LOCATION,
+ "environment_id": TEST_ENVIRONMENT_ID,
+ "composer_dag_id": TEST_COMPOSER_DAG_ID,
+ "start_date": TEST_START_DATE,
+ "end_date": TEST_END_DATE,
+ "allowed_states": TEST_STATES,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+ "poll_interval": TEST_POLL_INTERVAL,
+ },
+ )
+ assert actual_data == expected_data
diff --git
a/tests/system/providers/google/cloud/composer/example_cloud_composer.py
b/tests/system/providers/google/cloud/composer/example_cloud_composer.py
index fe60c56ddf..52404fa375 100644
--- a/tests/system/providers/google/cloud/composer/example_cloud_composer.py
+++ b/tests/system/providers/google/cloud/composer/example_cloud_composer.py
@@ -31,6 +31,7 @@ from airflow.providers.google.cloud.operators.cloud_composer
import (
CloudComposerRunAirflowCLICommandOperator,
CloudComposerUpdateEnvironmentOperator,
)
+from airflow.providers.google.cloud.sensors.cloud_composer import
CloudComposerDAGRunSensor
from airflow.utils.trigger_rule import TriggerRule
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
@@ -158,6 +159,29 @@ with DAG(
)
# [END howto_operator_run_airflow_cli_command_deferrable_mode]
+ # [START howto_sensor_dag_run]
+ dag_run_sensor = CloudComposerDAGRunSensor(
+ task_id="dag_run_sensor",
+ project_id=PROJECT_ID,
+ region=REGION,
+ environment_id=ENVIRONMENT_ID,
+ composer_dag_id="airflow_monitoring",
+ allowed_states=["success"],
+ )
+ # [END howto_sensor_dag_run]
+
+ # [START howto_sensor_dag_run_deferrable_mode]
+ defer_dag_run_sensor = CloudComposerDAGRunSensor(
+ task_id="defer_dag_run_sensor",
+ project_id=PROJECT_ID,
+ region=REGION,
+ environment_id=ENVIRONMENT_ID_ASYNC,
+ composer_dag_id="airflow_monitoring",
+ allowed_states=["success"],
+ deferrable=True,
+ )
+ # [END howto_sensor_dag_run_deferrable_mode]
+
# [START howto_operator_delete_composer_environment]
delete_env = CloudComposerDeleteEnvironmentOperator(
task_id="delete_env",
@@ -186,6 +210,7 @@ with DAG(
get_env,
[update_env, defer_update_env],
[run_airflow_cli_cmd, defer_run_airflow_cli_cmd],
+ [dag_run_sensor, defer_dag_run_sensor],
[delete_env, defer_delete_env],
)