This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 55049c50d5 Add deferrable mode to `DbtCloudRunJobOperator` (#29014)
55049c50d5 is described below
commit 55049c50d52323e242c2387f285f0591ea38cde7
Author: Phani Kumar <[email protected]>
AuthorDate: Mon Jan 23 17:36:24 2023 +0530
Add deferrable mode to `DbtCloudRunJobOperator` (#29014)
This PR donates the `DbtCloudRunJobOperatorAsync` from
[astronomer-providers](https://github.com/astronomer/astronomer-providers) repo
---
airflow/providers/dbt/cloud/hooks/dbt.py | 110 ++++++++++++++++++-
airflow/providers/dbt/cloud/operators/dbt.py | 67 +++++++++---
airflow/providers/dbt/cloud/provider.yaml | 2 +
airflow/providers/dbt/cloud/triggers/__init__.py | 16 +++
airflow/providers/dbt/cloud/triggers/dbt.py | 119 +++++++++++++++++++++
.../operators.rst | 12 +++
generated/provider_dependencies.json | 4 +-
7 files changed, 314 insertions(+), 16 deletions(-)
diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py
b/airflow/providers/dbt/cloud/hooks/dbt.py
index 4b6ac2151a..3ddeeb222b 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -22,8 +22,11 @@ import warnings
from enum import Enum
from functools import wraps
from inspect import signature
-from typing import Any, Callable, Sequence, Set
+from typing import Any, Callable, Sequence, Set, TypeVar, cast
+import aiohttp
+from aiohttp import ClientResponseError
+from asgiref.sync import sync_to_async
from requests import PreparedRequest, Session
from requests.auth import AuthBase
from requests.models import Response
@@ -125,6 +128,34 @@ class DbtCloudJobRunException(AirflowException):
"""An exception that indicates a job run failed to complete."""
T = TypeVar("T", bound=Any)


def provide_account_id(func: T) -> T:
    """
    Decorator which provides a fallback value for ``account_id`` on async hook methods.

    If ``account_id`` is None or not passed to the decorated coroutine function, the value is
    taken from the ``login`` field of the hook's configured dbt Cloud Airflow Connection.

    :param func: An ``async`` method of a hook exposing ``dbt_cloud_conn_id`` and ``get_connection``.
    :raises AirflowException: If no ``account_id`` is given and none can be derived from the
        connection (no connection ID configured, or an empty ``login``).
    """
    function_signature = signature(func)

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound_args = function_signature.bind(*args, **kwargs)

        if bound_args.arguments.get("account_id") is None:
            # The decorated function is a method, so the first positional argument is the hook.
            hook = args[0]
            default_account_id = None
            if hook.dbt_cloud_conn_id:
                # ``get_connection`` is synchronous; ``sync_to_async`` keeps it off the event loop.
                connection = await sync_to_async(hook.get_connection)(hook.dbt_cloud_conn_id)
                default_account_id = connection.login
            # Fail loudly instead of building a request URL with ``None`` as the account ID.
            if not default_account_id:
                raise AirflowException("Could not determine the dbt Cloud account.")
            bound_args.arguments["account_id"] = int(default_account_id)

        return await func(*bound_args.args, **bound_args.kwargs)

    return cast(T, wrapper)
+
+
class DbtCloudHook(HttpHook):
"""
Interact with dbt Cloud using the V2 API.
@@ -150,6 +181,83 @@ class DbtCloudHook(HttpHook):
super().__init__(auth_type=TokenAuth)
self.dbt_cloud_conn_id = dbt_cloud_conn_id
@staticmethod
def get_request_url_params(
    tenant: str, endpoint: str, include_related: list[str] | None = None
) -> tuple[str, dict[str, Any]]:
    """
    Build the full request URL for ``endpoint`` and the query parameters for the request.

    :param tenant: The tenant name substituted into the base URL
        (``https://{tenant}.getdbt.com/api/v2/accounts/``).
    :param endpoint: Endpoint path relative to the base URL; a leading "/" is ignored.
    :param include_related: Optional. List of related fields to pull with the run.
        Valid values are "trigger", "job", "repository", and "environment".
    :return: A ``(url, params)`` tuple. ``params`` is empty unless ``include_related`` is given.
    """
    base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/"
    # ``base_url`` always ends with "/", so drop any leading "/" on the endpoint to avoid "//".
    url = base_url + (endpoint or "").lstrip("/")
    data: dict[str, Any] = {"include_related": include_related} if include_related else {}
    return url, data
+
async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]:
    """
    Build the HTTP request headers and resolve the dbt Cloud tenant from the connection.

    :return: A ``(headers, tenant)`` tuple. ``headers`` holds the provider User-Agent, a JSON
        content type and a ``Token`` authorization built from the connection password.
        ``tenant`` is the connection schema, or ``"cloud"`` when the schema is empty.
    """
    conn: Connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
    tenant: str = conn.schema or "cloud"
    package_name, provider_version = _get_provider_info()
    headers: dict[str, Any] = {
        "User-Agent": f"{package_name}-v{provider_version}",
        "Content-Type": "application/json",
        "Authorization": f"Token {conn.password}",
    }
    return headers, tenant
+
@provide_account_id
async def get_job_details(
    self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
) -> Any:
    """
    Use an async HTTP call to retrieve metadata for a specific run of a dbt Cloud job.

    :param run_id: The ID of a dbt Cloud job run.
    :param account_id: Optional. The ID of a dbt Cloud account. If omitted, ``provide_account_id``
        fills it in from the connection's login.
    :param include_related: Optional. List of related fields to pull with the run.
        Valid values are "trigger", "job", "repository", and "environment".
    :return: The decoded JSON response body.
    :raises AirflowException: If the API responds with an error status; the message has the
        form ``"<status>:<reason>"``.
    """
    endpoint = f"{account_id}/runs/{run_id}/"
    headers, tenant = await self.get_headers_tenants_from_connection()
    url, params = self.get_request_url_params(tenant, endpoint, include_related)
    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.get(url, params=params) as response:
            try:
                response.raise_for_status()
                return await response.json()
            except ClientResponseError as e:
                # Chain the original error so the aiohttp traceback and request info are preserved.
                raise AirflowException(f"{e.status}:{e.message}") from e
+
async def get_job_status(
    self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
) -> int:
    """
    Retrieve the status for a specific run of a dbt Cloud job.

    :param run_id: The ID of a dbt Cloud job run.
    :param account_id: Optional. The ID of a dbt Cloud account; passed through to
        ``get_job_details``, which falls back to the connection's account when it is None.
    :param include_related: Optional. List of related fields to pull with the run.
        Valid values are "trigger", "job", "repository", and "environment".
    :return: The numeric status code found at ``["data"]["status"]`` in the run details.
    """
    self.log.info("Getting the status of job run %s.", run_id)
    # Errors from ``get_job_details`` (HTTP failures, missing account, ...) propagate unchanged.
    response = await self.get_job_details(
        run_id, account_id=account_id, include_related=include_related
    )
    job_run_status: int = response["data"]["status"]
    return job_run_status
+
@cached_property
def connection(self) -> Connection:
_connection = self.get_connection(self.dbt_cloud_conn_id)
diff --git a/airflow/providers/dbt/cloud/operators/dbt.py
b/airflow/providers/dbt/cloud/operators/dbt.py
index 472b2ffa7f..f65ce077d3 100644
--- a/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/airflow/providers/dbt/cloud/operators/dbt.py
@@ -17,10 +17,14 @@
from __future__ import annotations
import json
+import time
+import warnings
from typing import TYPE_CHECKING, Any
+from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook,
DbtCloudJobRunException, DbtCloudJobRunStatus
+from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -63,6 +67,7 @@ class DbtCloudRunJobOperator(BaseOperator):
Used only if ``wait_for_termination`` is True. Defaults to 60 seconds.
:param additional_run_config: Optional. Any additional parameters that
should be included in the API
request when triggering the job.
+ :param deferrable: Run operator in the deferrable mode
:return: The ID of the triggered dbt Cloud job run.
"""
@@ -91,6 +96,7 @@ class DbtCloudRunJobOperator(BaseOperator):
timeout: int = 60 * 60 * 24 * 7,
check_interval: int = 60,
additional_run_config: dict[str, Any] | None = None,
+ deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -106,8 +112,9 @@ class DbtCloudRunJobOperator(BaseOperator):
self.additional_run_config = additional_run_config or {}
self.hook: DbtCloudHook
self.run_id: int
+ self.deferrable = deferrable
- def execute(self, context: Context) -> int:
+ def execute(self, context: Context):
if self.trigger_reason is None:
self.trigger_reason = (
f"Triggered via Apache Airflow by task {self.task_id!r} in the
{self.dag.dag_id} DAG."
@@ -129,20 +136,52 @@ class DbtCloudRunJobOperator(BaseOperator):
context["ti"].xcom_push(key="job_run_url", value=job_run_url)
if self.wait_for_termination:
- self.log.info("Waiting for job run %s to terminate.",
str(self.run_id))
-
- if self.hook.wait_for_job_run_status(
- run_id=self.run_id,
- account_id=self.account_id,
- expected_statuses=DbtCloudJobRunStatus.SUCCESS.value,
- check_interval=self.check_interval,
- timeout=self.timeout,
- ):
- self.log.info("Job run %s has completed successfully.",
str(self.run_id))
+ if self.deferrable is False:
+ self.log.info("Waiting for job run %s to terminate.",
str(self.run_id))
+
+ if self.hook.wait_for_job_run_status(
+ run_id=self.run_id,
+ account_id=self.account_id,
+ expected_statuses=DbtCloudJobRunStatus.SUCCESS.value,
+ check_interval=self.check_interval,
+ timeout=self.timeout,
+ ):
+ self.log.info("Job run %s has completed successfully.",
str(self.run_id))
+ else:
+ raise DbtCloudJobRunException(f"Job run {self.run_id} has
failed or has been cancelled.")
+
+ return self.run_id
else:
- raise DbtCloudJobRunException(f"Job run {self.run_id} has
failed or has been cancelled.")
-
- return self.run_id
+ end_time = time.time() + self.timeout
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=DbtCloudRunJobTrigger(
+ conn_id=self.dbt_cloud_conn_id,
+ run_id=self.run_id,
+ end_time=end_time,
+ account_id=self.account_id,
+ poll_interval=self.check_interval,
+ ),
+ method_name="execute_complete",
+ )
+ else:
+ if self.deferrable is True:
+ warnings.warn(
+ "Argument `wait_for_termination` is False and `deferrable`
is True , hence "
+ "`deferrable` parameter doesn't have any effect",
+ )
+ return self.run_id
+
def execute_complete(self, context: Context, event: dict[str, Any]) -> int:
    """
    Callback for when the trigger fires - returns immediately.

    Only a ``"success"`` event completes the task. A cancelled run fails it, which matches the
    non-deferrable path, where a failed or cancelled run raises ``DbtCloudJobRunException``.
    Any other non-success event (failure, timeout, polling error) also fails the task.

    :param context: The task context (unused).
    :param event: Trigger event payload with ``status``, ``message`` and ``run_id`` keys.
    :return: The ID of the dbt Cloud job run.
    :raises DbtCloudJobRunException: If the job run was cancelled.
    :raises AirflowException: On any other non-success event.
    """
    status = event["status"]
    if status == "cancelled":
        raise DbtCloudJobRunException(event["message"])
    if status != "success":
        raise AirflowException(event["message"])
    self.log.info(event["message"])
    return int(event["run_id"])
def on_kill(self) -> None:
if self.run_id:
diff --git a/airflow/providers/dbt/cloud/provider.yaml
b/airflow/providers/dbt/cloud/provider.yaml
index ad2817eb8e..4315f9c272 100644
--- a/airflow/providers/dbt/cloud/provider.yaml
+++ b/airflow/providers/dbt/cloud/provider.yaml
@@ -34,6 +34,8 @@ versions:
dependencies:
- apache-airflow>=2.3.0
- apache-airflow-providers-http
+ - asgiref
+ - aiohttp
integrations:
- integration-name: dbt Cloud
diff --git a/airflow/providers/dbt/cloud/triggers/__init__.py
b/airflow/providers/dbt/cloud/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/dbt/cloud/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/dbt/cloud/triggers/dbt.py
b/airflow/providers/dbt/cloud/triggers/dbt.py
new file mode 100644
index 0000000000..9bad789a52
--- /dev/null
+++ b/airflow/providers/dbt/cloud/triggers/dbt.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import time
+from typing import Any, AsyncIterator
+
+from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook,
DbtCloudJobRunStatus
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
class DbtCloudRunJobTrigger(BaseTrigger):
    """
    Trigger that polls dbt Cloud for the status of a job run until it reaches a terminal status.

    The status is fetched asynchronously through :class:`DbtCloudHook` every ``poll_interval``
    seconds. Exactly one event is fired: ``success``, ``cancelled``, or ``error`` (the run
    failed, ``end_time`` passed, or the status could not be retrieved).

    :param conn_id: The connection identifier for connecting to dbt Cloud.
    :param run_id: The ID of a dbt Cloud job run.
    :param end_time: Deadline as a Unix timestamp (compared against ``time.time()``). If the run
        is still in progress after this point, an ``error`` event is fired.
    :param poll_interval: Polling period in seconds to check for the status.
    :param account_id: The ID of a dbt Cloud account. If None, the hook falls back to the
        account configured on the connection.
    """

    def __init__(
        self,
        conn_id: str,
        run_id: int,
        end_time: float,
        poll_interval: float,
        account_id: int | None,
    ):
        super().__init__()
        self.run_id = run_id
        self.account_id = account_id
        self.conn_id = conn_id
        self.end_time = end_time
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes DbtCloudRunJobTrigger arguments and classpath."""
        return (
            "airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger",
            {
                "run_id": self.run_id,
                "account_id": self.account_id,
                "conn_id": self.conn_id,
                "end_time": self.end_time,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Poll dbt Cloud until the job run is terminal or ``end_time`` passes, then yield one event."""
        hook = DbtCloudHook(self.conn_id)
        try:
            while await self.is_still_running(hook):
                if self.end_time < time.time():
                    yield TriggerEvent(
                        {
                            "status": "error",
                            "message": f"Job run {self.run_id} has not reached a terminal status "
                            f"before the configured end time ({self.end_time}, Unix timestamp).",
                            "run_id": self.run_id,
                        }
                    )
                    # Fire only one event: without this return the loop would keep polling.
                    return
                await asyncio.sleep(self.poll_interval)
            job_run_status = await hook.get_job_status(self.run_id, self.account_id)
            if job_run_status == DbtCloudJobRunStatus.SUCCESS.value:
                yield TriggerEvent(
                    {
                        "status": "success",
                        "message": f"Job run {self.run_id} has completed successfully.",
                        "run_id": self.run_id,
                    }
                )
            elif job_run_status == DbtCloudJobRunStatus.CANCELLED.value:
                yield TriggerEvent(
                    {
                        "status": "cancelled",
                        "message": f"Job run {self.run_id} has been cancelled.",
                        "run_id": self.run_id,
                    }
                )
            else:
                yield TriggerEvent(
                    {
                        "status": "error",
                        "message": f"Job run {self.run_id} has failed.",
                        "run_id": self.run_id,
                    }
                )
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e), "run_id": self.run_id})

    async def is_still_running(self, hook: DbtCloudHook) -> bool:
        """
        Return True while the job run has not reached a terminal status.

        :param hook: Hook used to fetch the current job run status.
        """
        job_run_status = await hook.get_job_status(self.run_id, self.account_id)
        return not DbtCloudJobRunStatus.is_terminal(job_run_status)
diff --git a/docs/apache-airflow-providers-dbt-cloud/operators.rst
b/docs/apache-airflow-providers-dbt-cloud/operators.rst
index de5b0b8060..1f7b27b280 100644
--- a/docs/apache-airflow-providers-dbt-cloud/operators.rst
+++ b/docs/apache-airflow-providers-dbt-cloud/operators.rst
@@ -40,6 +40,18 @@ execution time. This functionality is controlled by the
``wait_for_termination``
:class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`).
Setting ``wait_for_termination`` to
False is a good approach for long-running dbt Cloud jobs.
+The ``deferrable`` parameter, together with ``wait_for_termination``, controls whether the job status is
+polled on the worker or the task is deferred to the Triggerer.
+When ``wait_for_termination`` is True and ``deferrable`` is False, the operator submits the job and polls
+for its status on the worker. This keeps the worker slot occupied until the job run finishes.
+When ``wait_for_termination`` is True and ``deferrable`` is True, the operator submits the job and defers
+to the Triggerer. This releases the worker slot, saving resources while the job is running.
+
+When ``wait_for_termination`` is False, the operator only submits the job; its status can then be tracked
+with the :class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`. In this case
+``deferrable`` has no effect, and setting it to True only emits a warning.
+
+
While ``schema_override`` and ``steps_override`` are explicit, optional
parameters for the
``DbtCloudRunJobOperator``, custom run configurations can also be passed to
the operator using the
``additional_run_config`` dictionary. This parameter can be used to initialize
additional runtime
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 90f878796b..8f42cded79 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -241,8 +241,10 @@
},
"dbt.cloud": {
"deps": [
+ "aiohttp",
"apache-airflow-providers-http",
- "apache-airflow>=2.3.0"
+ "apache-airflow>=2.3.0",
+ "asgiref"
],
"cross-providers-deps": [
"http"