This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4ebabc15df3 Add retry mechanism and error handling to DBT Hook (#56651)
4ebabc15df3 is described below
commit 4ebabc15df3a9ff368679cbd356e0206da3097f3
Author: AardJan <[email protected]>
AuthorDate: Wed Oct 15 16:45:34 2025 +0200
Add retry mechanism and error handling to DBT Hook (#56651)
* fix issue where a single API error fails the entire DbtCloudRunJobOperator task
* fix formatting and refactor get_job_details function
* add hook params to operator and trigger DBT classes
* add docstring, fix function annotation
* add test for retry in hooks and triggers
* fix Hook class param, change method for sync requests
* simplify tests, remove account param
* change test names
* change reraise implementation
* add hook_params for other DBT operators
* change timeout to optional
* add test for timeout param, fix mock warning
* reformat using prek
* change import order, change log request after retry
---
providers/dbt/cloud/pyproject.toml | 1 +
.../src/airflow/providers/dbt/cloud/hooks/dbt.py | 120 ++++++++--
.../airflow/providers/dbt/cloud/operators/dbt.py | 15 +-
.../airflow/providers/dbt/cloud/triggers/dbt.py | 6 +-
.../cloud/tests/unit/dbt/cloud/hooks/test_dbt.py | 242 ++++++++++++++++++++-
.../tests/unit/dbt/cloud/triggers/test_dbt.py | 2 +
6 files changed, 367 insertions(+), 19 deletions(-)
diff --git a/providers/dbt/cloud/pyproject.toml
b/providers/dbt/cloud/pyproject.toml
index de91020daab..6a1bd380402 100644
--- a/providers/dbt/cloud/pyproject.toml
+++ b/providers/dbt/cloud/pyproject.toml
@@ -62,6 +62,7 @@ dependencies = [
"apache-airflow-providers-http",
"asgiref>=2.3.0",
"aiohttp>=3.9.2",
+ "tenacity>=8.3.0",
]
# The optional dependencies should be modified in place in the generated file
diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
index c7023f1b923..70630510968 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import asyncio
+import copy
import json
import time
import warnings
@@ -28,8 +29,10 @@ from typing import TYPE_CHECKING, Any, TypedDict, TypeVar,
cast
import aiohttp
from asgiref.sync import sync_to_async
+from requests import exceptions as requests_exceptions
from requests.auth import AuthBase
from requests.sessions import Session
+from tenacity import AsyncRetrying, RetryCallState, retry_if_exception,
stop_after_attempt, wait_exponential
from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
@@ -174,6 +177,10 @@ class DbtCloudHook(HttpHook):
Interact with dbt Cloud using the V2 (V3 if supported) API.
:param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection
<howto/connection:dbt-cloud>`.
+ :param timeout_seconds: Optional. The timeout in seconds for HTTP
requests. If not provided, no timeout is applied.
+ :param retry_limit: The number of times to retry a request in case of
failure.
+ :param retry_delay: The delay in seconds between retries.
+ :param retry_args: A dictionary of arguments to pass to the
`tenacity.retry` decorator.
"""
conn_name_attr = "dbt_cloud_conn_id"
@@ -193,9 +200,39 @@ class DbtCloudHook(HttpHook):
},
}
- def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args,
**kwargs) -> None:
+ def __init__(
+ self,
+ dbt_cloud_conn_id: str = default_conn_name,
+ timeout_seconds: int | None = None,
+ retry_limit: int = 1,
+ retry_delay: float = 1.0,
+ retry_args: dict[Any, Any] | None = None,
+ ) -> None:
super().__init__(auth_type=TokenAuth)
self.dbt_cloud_conn_id = dbt_cloud_conn_id
+ self.timeout_seconds = timeout_seconds
+ if retry_limit < 1:
+ raise ValueError("Retry limit must be greater than or equal to 1")
+ self.retry_limit = retry_limit
+ self.retry_delay = retry_delay
+
+ def retry_after_func(retry_state: RetryCallState) -> None:
+ error_msg = str(retry_state.outcome.exception()) if
retry_state.outcome else "Unknown error"
+ self._log_request_error(retry_state.attempt_number, error_msg)
+
+ if retry_args:
+ self.retry_args = copy.copy(retry_args)
+ self.retry_args["retry"] =
retry_if_exception(self._retryable_error)
+ self.retry_args["after"] = retry_after_func
+ self.retry_args["reraise"] = True
+ else:
+ self.retry_args = {
+ "stop": stop_after_attempt(self.retry_limit),
+ "wait": wait_exponential(min=self.retry_delay,
max=(2**retry_limit)),
+ "retry": retry_if_exception(self._retryable_error),
+ "after": retry_after_func,
+ "reraise": True,
+ }
@staticmethod
def _get_tenant_domain(conn: Connection) -> str:
@@ -233,6 +270,36 @@ class DbtCloudHook(HttpHook):
headers["Authorization"] = f"Token {self.connection.password}"
return headers, tenant
+ def _log_request_error(self, attempt_num: int, error: str) -> None:
+ self.log.error("Attempt %s API Request to DBT failed with reason: %s",
attempt_num, error)
+
+ @staticmethod
+ def _retryable_error(exception: BaseException) -> bool:
+ if isinstance(exception, requests_exceptions.RequestException):
+ if isinstance(exception, (requests_exceptions.ConnectionError,
requests_exceptions.Timeout)) or (
+ exception.response is not None
+ and (exception.response.status_code >= 500 or
exception.response.status_code == 429)
+ ):
+ return True
+
+ if isinstance(exception, aiohttp.ClientResponseError):
+ if exception.status >= 500 or exception.status == 429:
+ return True
+
+ if isinstance(exception, (aiohttp.ClientConnectorError, TimeoutError)):
+ return True
+
+ return False
+
+ def _a_get_retry_object(self) -> AsyncRetrying:
+ """
+ Instantiate an async retry object.
+
+ :return: instance of AsyncRetrying class
+ """
+ # for compatibility we use reraise to avoid handling request error
+ return AsyncRetrying(**self.retry_args)
+
@provide_account_id
async def get_job_details(
self, run_id: int, account_id: int | None = None, include_related:
list[str] | None = None
@@ -249,17 +316,22 @@ class DbtCloudHook(HttpHook):
headers, tenant = await self.get_headers_tenants_from_connection()
url, params = self.get_request_url_params(tenant, endpoint,
include_related)
proxies = self._get_proxies(self.connection) or {}
+ proxy = proxies.get("https") if proxies and url.startswith("https")
else proxies.get("http")
+ extra_request_args = {}
- async with aiohttp.ClientSession(headers=headers) as session:
- proxy = proxies.get("https") if proxies and
url.startswith("https") else proxies.get("http")
- extra_request_args = {}
+ if proxy:
+ extra_request_args["proxy"] = proxy
- if proxy:
- extra_request_args["proxy"] = proxy
+ timeout = (
+ aiohttp.ClientTimeout(total=self.timeout_seconds) if
self.timeout_seconds is not None else None
+ )
- async with session.get(url, params=params, **extra_request_args)
as response: # type: ignore[arg-type]
- response.raise_for_status()
- return await response.json()
+ async with aiohttp.ClientSession(headers=headers, timeout=timeout) as
session:
+ async for attempt in self._a_get_retry_object():
+ with attempt:
+ async with session.get(url, params=params,
**extra_request_args) as response: # type: ignore[arg-type]
+ response.raise_for_status()
+ return await response.json()
async def get_job_status(
self, run_id: int, account_id: int | None = None, include_related:
list[str] | None = None
@@ -297,8 +369,14 @@ class DbtCloudHook(HttpHook):
def _paginate(
self, endpoint: str, payload: dict[str, Any] | None = None, proxies:
dict[str, str] | None = None
) -> list[Response]:
- extra_options = {"proxies": proxies} if proxies is not None else None
- response = self.run(endpoint=endpoint, data=payload,
extra_options=extra_options)
+ extra_options: dict[str, Any] = {}
+ if self.timeout_seconds is not None:
+ extra_options["timeout"] = self.timeout_seconds
+ if proxies is not None:
+ extra_options["proxies"] = proxies
+ response = self.run_with_advanced_retry(
+ _retry_args=self.retry_args, endpoint=endpoint, data=payload,
extra_options=extra_options or None
+ )
resp_json = response.json()
limit = resp_json["extra"]["filters"]["limit"]
num_total_results = resp_json["extra"]["pagination"]["total_count"]
@@ -309,7 +387,12 @@ class DbtCloudHook(HttpHook):
_paginate_payload["offset"] = limit
while num_current_results < num_total_results:
- response = self.run(endpoint=endpoint, data=_paginate_payload,
extra_options=extra_options)
+ response = self.run_with_advanced_retry(
+ _retry_args=self.retry_args,
+ endpoint=endpoint,
+ data=_paginate_payload,
+ extra_options=extra_options,
+ )
resp_json = response.json()
results.append(response)
num_current_results +=
resp_json["extra"]["pagination"]["count"]
@@ -328,7 +411,11 @@ class DbtCloudHook(HttpHook):
self.method = method
full_endpoint = f"api/{api_version}/accounts/{endpoint}" if endpoint
else None
proxies = self._get_proxies(self.connection)
- extra_options = {"proxies": proxies} if proxies is not None else None
+ extra_options: dict[str, Any] = {}
+ if self.timeout_seconds is not None:
+ extra_options["timeout"] = self.timeout_seconds
+ if proxies is not None:
+ extra_options["proxies"] = proxies
if paginate:
if isinstance(payload, str):
@@ -339,7 +426,12 @@ class DbtCloudHook(HttpHook):
raise ValueError("An endpoint is needed to paginate a response.")
- return self.run(endpoint=full_endpoint, data=payload,
extra_options=extra_options)
+ return self.run_with_advanced_retry(
+ _retry_args=self.retry_args,
+ endpoint=full_endpoint,
+ data=payload,
+ extra_options=extra_options or None,
+ )
def list_accounts(self) -> list[Response]:
"""
diff --git
a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py
b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py
index 03492a0b23c..3ac8c6d544f 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py
@@ -87,6 +87,7 @@ class DbtCloudRunJobOperator(BaseOperator):
run. For more information on retry logic, see:
https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job
:param deferrable: Run operator in the deferrable mode
+ :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
:return: The ID of the triggered dbt Cloud job run.
"""
@@ -124,6 +125,7 @@ class DbtCloudRunJobOperator(BaseOperator):
reuse_existing_run: bool = False,
retry_from_failure: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ hook_params: dict[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -144,6 +146,7 @@ class DbtCloudRunJobOperator(BaseOperator):
self.reuse_existing_run = reuse_existing_run
self.retry_from_failure = retry_from_failure
self.deferrable = deferrable
+ self.hook_params = hook_params or {}
def execute(self, context: Context):
if self.trigger_reason is None:
@@ -273,7 +276,7 @@ class DbtCloudRunJobOperator(BaseOperator):
@cached_property
def hook(self):
"""Returns DBT Cloud hook."""
- return DbtCloudHook(self.dbt_cloud_conn_id)
+ return DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)
def get_openlineage_facets_on_complete(self, task_instance) ->
OperatorLineage:
"""
@@ -311,6 +314,7 @@ class DbtCloudGetJobRunArtifactOperator(BaseOperator):
be returned.
:param output_file_name: Optional. The desired file name for the download
artifact file.
Defaults to <run_id>_<path> (e.g. "728368_run_results.json").
+ :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
"""
template_fields = ("dbt_cloud_conn_id", "run_id", "path", "account_id",
"output_file_name")
@@ -324,6 +328,7 @@ class DbtCloudGetJobRunArtifactOperator(BaseOperator):
account_id: int | None = None,
step: int | None = None,
output_file_name: str | None = None,
+ hook_params: dict[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -333,9 +338,10 @@ class DbtCloudGetJobRunArtifactOperator(BaseOperator):
self.account_id = account_id
self.step = step
self.output_file_name = output_file_name or
f"{self.run_id}_{self.path}".replace("/", "-")
+ self.hook_params = hook_params or {}
def execute(self, context: Context) -> str:
- hook = DbtCloudHook(self.dbt_cloud_conn_id)
+ hook = DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)
response = hook.get_job_run_artifact(
run_id=self.run_id, path=self.path, account_id=self.account_id,
step=self.step
)
@@ -370,6 +376,7 @@ class DbtCloudListJobsOperator(BaseOperator):
:param order_by: Optional. Field to order the result by. Use '-' to
indicate reverse order.
For example, to use reverse order by the run ID use ``order_by=-id``.
:param project_id: Optional. The ID of a dbt Cloud project.
+ :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
"""
template_fields = (
@@ -384,6 +391,7 @@ class DbtCloudListJobsOperator(BaseOperator):
account_id: int | None = None,
project_id: int | None = None,
order_by: str | None = None,
+ hook_params: dict[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -391,9 +399,10 @@ class DbtCloudListJobsOperator(BaseOperator):
self.account_id = account_id
self.project_id = project_id
self.order_by = order_by
+ self.hook_params = hook_params or {}
def execute(self, context: Context) -> list:
- hook = DbtCloudHook(self.dbt_cloud_conn_id)
+ hook = DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)
list_jobs_response = hook.list_jobs(
account_id=self.account_id, order_by=self.order_by,
project_id=self.project_id
)
diff --git
a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py
b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py
index c8acf2d81e6..9d4c59473b1 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py
@@ -36,6 +36,7 @@ class DbtCloudRunJobTrigger(BaseTrigger):
:param end_time: Time in seconds to wait for a job run to reach a terminal
status. Defaults to 7 days.
:param account_id: The ID of a dbt Cloud account.
:param poll_interval: polling period in seconds to check for the status.
+ :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
"""
def __init__(
@@ -45,6 +46,7 @@ class DbtCloudRunJobTrigger(BaseTrigger):
end_time: float,
poll_interval: float,
account_id: int | None,
+ hook_params: dict[str, Any] | None = None,
):
super().__init__()
self.run_id = run_id
@@ -52,6 +54,7 @@ class DbtCloudRunJobTrigger(BaseTrigger):
self.conn_id = conn_id
self.end_time = end_time
self.poll_interval = poll_interval
+ self.hook_params = hook_params or {}
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize DbtCloudRunJobTrigger arguments and classpath."""
@@ -63,12 +66,13 @@ class DbtCloudRunJobTrigger(BaseTrigger):
"conn_id": self.conn_id,
"end_time": self.end_time,
"poll_interval": self.poll_interval,
+ "hook_params": self.hook_params,
},
)
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make async connection to Dbt, polls for the pipeline run status."""
- hook = DbtCloudHook(self.conn_id)
+ hook = DbtCloudHook(self.conn_id, **self.hook_params)
try:
while await self.is_still_running(hook):
if self.end_time < time.time():
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
index 08bc9c69757..607bb650871 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
@@ -20,9 +20,11 @@ import json
from copy import deepcopy
from datetime import timedelta
from typing import Any
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
+import aiohttp
import pytest
+from requests import exceptions as requests_exceptions
from requests.models import Response
from airflow.exceptions import AirflowException
@@ -95,6 +97,12 @@ def mock_response_json(response: dict):
return run_response
+def request_exception_with_status(status_code: int) ->
requests_exceptions.HTTPError:
+ response = Response()
+ response.status_code = status_code
+ return requests_exceptions.HTTPError(response=response)
+
+
class TestDbtCloudJobRunStatus:
valid_job_run_statuses = [
1, # QUEUED
@@ -1072,3 +1080,235 @@ class TestDbtCloudHook:
assert status is False
assert msg == "403:Authentication credentials were not provided"
+
+ @pytest.mark.parametrize(
+ argnames="timeout_seconds",
+ argvalues=[60, 180, 300],
+ ids=["60s", "180s", "300s"],
+ )
+ @patch.object(DbtCloudHook, "run_with_advanced_retry")
+ def test_timeout_passed_to_run_and_get_response(self, mock_run_with_retry,
timeout_seconds):
+ """Test that timeout is passed to extra_options in
_run_and_get_response."""
+ hook = DbtCloudHook(ACCOUNT_ID_CONN, timeout_seconds=timeout_seconds)
+ mock_run_with_retry.return_value = mock_response_json({"data": {"id":
JOB_ID}})
+
+ hook.get_job(job_id=JOB_ID, account_id=DEFAULT_ACCOUNT_ID)
+
+ call_args = mock_run_with_retry.call_args
+ assert call_args is not None
+ extra_options = call_args.kwargs.get("extra_options")
+ assert extra_options is not None
+ assert extra_options["timeout"] == timeout_seconds
+
+ @pytest.mark.parametrize(
+ argnames="timeout_seconds",
+ argvalues=[60, 180, 300],
+ ids=["60s", "180s", "300s"],
+ )
+ @patch.object(DbtCloudHook, "run_with_advanced_retry")
+ def test_timeout_passed_to_paginate(self, mock_run_with_retry,
timeout_seconds):
+ """Test that timeout is passed to extra_options in _paginate."""
+ hook = DbtCloudHook(ACCOUNT_ID_CONN, timeout_seconds=timeout_seconds)
+ mock_response = mock_response_json(
+ {
+ "data": [{"id": JOB_ID}],
+ "extra": {"filters": {"limit": 100}, "pagination": {"count":
1, "total_count": 1}},
+ }
+ )
+ mock_run_with_retry.return_value = mock_response
+
+ hook.list_jobs(account_id=DEFAULT_ACCOUNT_ID)
+
+ call_args = mock_run_with_retry.call_args
+ assert call_args is not None
+ extra_options = call_args.kwargs.get("extra_options")
+ assert extra_options is not None
+ assert extra_options["timeout"] == timeout_seconds
+
+ @pytest.mark.parametrize(
+ argnames="timeout_seconds",
+ argvalues=[60, 180, 300],
+ ids=["60s", "180s", "300s"],
+ )
+ @patch.object(DbtCloudHook, "run_with_advanced_retry")
+ def test_timeout_with_proxies(self, mock_run_with_retry, timeout_seconds):
+ """Test that both timeout and proxies are passed to extra_options."""
+ hook = DbtCloudHook(PROXY_CONN, timeout_seconds=timeout_seconds)
+ mock_run_with_retry.return_value = mock_response_json({"data": {"id":
JOB_ID}})
+
+ hook.get_job(job_id=JOB_ID, account_id=DEFAULT_ACCOUNT_ID)
+
+ call_args = mock_run_with_retry.call_args
+ assert call_args is not None
+ extra_options = call_args.kwargs.get("extra_options")
+ assert extra_options is not None
+ assert extra_options["timeout"] == timeout_seconds
+ assert "proxies" in extra_options
+ assert extra_options["proxies"] == EXTRA_PROXIES["proxies"]
+
+ @pytest.mark.parametrize(
+ argnames="exception, expected",
+ argvalues=[
+ (requests_exceptions.ConnectionError(), True),
+ (requests_exceptions.Timeout(), True),
+ (request_exception_with_status(503), True),
+ (request_exception_with_status(429), True),
+ (request_exception_with_status(404), False),
+ (aiohttp.ClientResponseError(MagicMock(), (), status=500,
message=""), True),
+ (aiohttp.ClientResponseError(MagicMock(), (), status=429,
message=""), True),
+ (aiohttp.ClientResponseError(MagicMock(), (), status=400,
message=""), False),
+ (aiohttp.ClientConnectorError(MagicMock(), OSError()), True),
+ (TimeoutError(), True),
+ (ValueError(), False),
+ ],
+ ids=[
+ "requests_connection_error",
+ "requests_timeout",
+ "requests_status_503",
+ "requests_status_429",
+ "requests_status_404",
+ "aiohttp_status_500",
+ "aiohttp_status_429",
+ "aiohttp_status_400",
+ "aiohttp_connector_error",
+ "timeout_error",
+ "value_error",
+ ],
+ )
+ def test_retryable_error(self, exception, expected):
+ assert DbtCloudHook._retryable_error(exception) is expected
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "error_factory, retry_qty, retry_delay",
+ [
+ (
+ lambda: aiohttp.ClientResponseError(
+ request_info=AsyncMock(), history=(), status=500,
message=""
+ ),
+ 3,
+ 0.1,
+ ),
+ (
+ lambda: aiohttp.ClientResponseError(
+ request_info=AsyncMock(), history=(), status=429,
message=""
+ ),
+ 5,
+ 0.1,
+ ),
+ (lambda: aiohttp.ClientConnectorError(AsyncMock(),
OSError("boom")), 2, 0.1),
+ (lambda: TimeoutError(), 2, 0.1),
+ ],
+ ids=["aiohttp_500", "aiohttp_429", "connector_error", "timeout"],
+ )
+ @patch("airflow.providers.dbt.cloud.hooks.dbt.aiohttp.ClientSession.get")
+ async def test_get_job_details_retry_with_retryable_errors(
+ self, get_mock, error_factory, retry_qty, retry_delay
+ ):
+ hook = DbtCloudHook(ACCOUNT_ID_CONN, retry_limit=retry_qty,
retry_delay=retry_delay)
+
+ def fail_cm():
+ cm = AsyncMock()
+ cm.__aenter__.side_effect = error_factory()
+ return cm
+
+ ok_resp = AsyncMock()
+ ok_resp.raise_for_status = MagicMock(return_value=None)
+ ok_resp.json = AsyncMock(return_value={"data": "Success"})
+ ok_cm = AsyncMock()
+ ok_cm.__aenter__.return_value = ok_resp
+ ok_cm.__aexit__.return_value = AsyncMock()
+
+ all_resp = [fail_cm() for _ in range(retry_qty - 1)]
+ all_resp.append(ok_cm)
+ get_mock.side_effect = all_resp
+
+ result = await hook.get_job_details(run_id=RUN_ID, account_id=None)
+
+ assert result == {"data": "Success"}
+ assert get_mock.call_count == retry_qty
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "error_factory, expected_exception",
+ [
+ (
+ lambda: aiohttp.ClientResponseError(
+ request_info=AsyncMock(), history=(), status=404,
message="Not Found"
+ ),
+ aiohttp.ClientResponseError,
+ ),
+ (
+ lambda: aiohttp.ClientResponseError(
+ request_info=AsyncMock(), history=(), status=400,
message="Bad Request"
+ ),
+ aiohttp.ClientResponseError,
+ ),
+ (lambda: ValueError("Invalid parameter"), ValueError),
+ ],
+ ids=["aiohttp_404", "aiohttp_400", "value_error"],
+ )
+ @patch("airflow.providers.dbt.cloud.hooks.dbt.aiohttp.ClientSession.get")
+ async def test_get_job_details_retry_with_non_retryable_errors(
+ self, get_mock, error_factory, expected_exception
+ ):
+ hook = DbtCloudHook(ACCOUNT_ID_CONN, retry_limit=3, retry_delay=0.1)
+
+ def fail_cm():
+ cm = AsyncMock()
+ cm.__aenter__.side_effect = error_factory()
+ return cm
+
+ get_mock.return_value = fail_cm()
+
+ with pytest.raises(expected_exception):
+ await hook.get_job_details(run_id=RUN_ID, account_id=None)
+
+ assert get_mock.call_count == 1
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ argnames="error_factory, expected_exception",
+ argvalues=[
+ (
+ lambda: aiohttp.ClientResponseError(
+ request_info=AsyncMock(), history=(), status=503,
message="Service Unavailable"
+ ),
+ aiohttp.ClientResponseError,
+ ),
+ (
+ lambda: aiohttp.ClientResponseError(
+ request_info=AsyncMock(), history=(), status=500,
message="Internal Server Error"
+ ),
+ aiohttp.ClientResponseError,
+ ),
+ (
+ lambda: aiohttp.ClientConnectorError(AsyncMock(),
OSError("Connection refused")),
+ aiohttp.ClientConnectorError,
+ ),
+ (lambda: TimeoutError("Request timeout"), TimeoutError),
+ ],
+ ids=[
+ "aiohttp_503_exhausted",
+ "aiohttp_500_exhausted",
+ "connector_error_exhausted",
+ "timeout_exhausted",
+ ],
+ )
+ @patch("airflow.providers.dbt.cloud.hooks.dbt.aiohttp.ClientSession.get")
+ async def test_get_job_details_retry_with_exhausted_retries(
+ self, get_mock, error_factory, expected_exception
+ ):
+ hook = DbtCloudHook(ACCOUNT_ID_CONN, retry_limit=2, retry_delay=0.1)
+
+ def fail_cm():
+ cm = AsyncMock()
+ cm.__aenter__.side_effect = error_factory()
+ return cm
+
+ get_mock.side_effect = [fail_cm() for _ in range(2)]
+
+ with pytest.raises(expected_exception):
+ await hook.get_job_details(run_id=RUN_ID, account_id=None)
+
+ assert get_mock.call_count == 2
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
b/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
index 2a1f26b49d8..76818008a45 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
@@ -45,6 +45,7 @@ class TestDbtCloudRunJobTrigger:
end_time=self.END_TIME,
run_id=self.RUN_ID,
account_id=self.ACCOUNT_ID,
+ hook_params={"retry_delay": 10},
)
classpath, kwargs = trigger.serialize()
assert classpath ==
"airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger"
@@ -54,6 +55,7 @@ class TestDbtCloudRunJobTrigger:
"conn_id": self.CONN_ID,
"end_time": self.END_TIME,
"poll_interval": self.POLL_INTERVAL,
+ "hook_params": {"retry_delay": 10},
}
@pytest.mark.asyncio