This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ebabc15df3 Add retry mechanism and error handling to DBT Hook (#56651)
4ebabc15df3 is described below

commit 4ebabc15df3a9ff368679cbd356e0206da3097f3
Author: AardJan <[email protected]>
AuthorDate: Wed Oct 15 16:45:34 2025 +0200

    Add retry mechanism and error handling to DBT Hook (#56651)
    
    * fix issue where a single API error fails the entire DbtCloudRunJobOperator task
    
    * fix formatting and refactor get_job_details function
    
    * add hook params to operator and trigger DBT classes
    
    * add docstring, fix function annotation
    
    * add test for retry in hooks and triggers
    
    * fix Hook class param, change method for sync requests
    
    * simplify tests, remove account param
    
    * change test names
    
    * change reraise implementation
    
    * add hook_params for other DBT operators
    
    * change timeout to optional
    
    * add test for timeout param, fix mock warning
    
    * reformat using prek
    
    * change import order, change log request after retry
---
 providers/dbt/cloud/pyproject.toml                 |   1 +
 .../src/airflow/providers/dbt/cloud/hooks/dbt.py   | 120 ++++++++--
 .../airflow/providers/dbt/cloud/operators/dbt.py   |  15 +-
 .../airflow/providers/dbt/cloud/triggers/dbt.py    |   6 +-
 .../cloud/tests/unit/dbt/cloud/hooks/test_dbt.py   | 242 ++++++++++++++++++++-
 .../tests/unit/dbt/cloud/triggers/test_dbt.py      |   2 +
 6 files changed, 367 insertions(+), 19 deletions(-)

diff --git a/providers/dbt/cloud/pyproject.toml 
b/providers/dbt/cloud/pyproject.toml
index de91020daab..6a1bd380402 100644
--- a/providers/dbt/cloud/pyproject.toml
+++ b/providers/dbt/cloud/pyproject.toml
@@ -62,6 +62,7 @@ dependencies = [
     "apache-airflow-providers-http",
     "asgiref>=2.3.0",
     "aiohttp>=3.9.2",
+    "tenacity>=8.3.0",
 ]
 
 # The optional dependencies should be modified in place in the generated file
diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py 
b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
index c7023f1b923..70630510968 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import asyncio
+import copy
 import json
 import time
 import warnings
@@ -28,8 +29,10 @@ from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, 
cast
 
 import aiohttp
 from asgiref.sync import sync_to_async
+from requests import exceptions as requests_exceptions
 from requests.auth import AuthBase
 from requests.sessions import Session
+from tenacity import AsyncRetrying, RetryCallState, retry_if_exception, 
stop_after_attempt, wait_exponential
 
 from airflow.exceptions import AirflowException
 from airflow.providers.http.hooks.http import HttpHook
@@ -174,6 +177,10 @@ class DbtCloudHook(HttpHook):
     Interact with dbt Cloud using the V2 (V3 if supported) API.
 
     :param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection 
<howto/connection:dbt-cloud>`.
+    :param timeout_seconds: Optional. The timeout in seconds for HTTP 
requests. If not provided, no timeout is applied.
+    :param retry_limit: The number of times to retry a request in case of 
failure.
+    :param retry_delay: The delay in seconds between retries.
+    :param retry_args: A dictionary of arguments to pass to the 
`tenacity.retry` decorator.
     """
 
     conn_name_attr = "dbt_cloud_conn_id"
@@ -193,9 +200,39 @@ class DbtCloudHook(HttpHook):
             },
         }
 
-    def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, 
**kwargs) -> None:
+    def __init__(
+        self,
+        dbt_cloud_conn_id: str = default_conn_name,
+        timeout_seconds: int | None = None,
+        retry_limit: int = 1,
+        retry_delay: float = 1.0,
+        retry_args: dict[Any, Any] | None = None,
+    ) -> None:
         super().__init__(auth_type=TokenAuth)
         self.dbt_cloud_conn_id = dbt_cloud_conn_id
+        self.timeout_seconds = timeout_seconds
+        if retry_limit < 1:
+            raise ValueError("Retry limit must be greater than or equal to 1")
+        self.retry_limit = retry_limit
+        self.retry_delay = retry_delay
+
+        def retry_after_func(retry_state: RetryCallState) -> None:
+            error_msg = str(retry_state.outcome.exception()) if 
retry_state.outcome else "Unknown error"
+            self._log_request_error(retry_state.attempt_number, error_msg)
+
+        if retry_args:
+            self.retry_args = copy.copy(retry_args)
+            self.retry_args["retry"] = 
retry_if_exception(self._retryable_error)
+            self.retry_args["after"] = retry_after_func
+            self.retry_args["reraise"] = True
+        else:
+            self.retry_args = {
+                "stop": stop_after_attempt(self.retry_limit),
+                "wait": wait_exponential(min=self.retry_delay, 
max=(2**retry_limit)),
+                "retry": retry_if_exception(self._retryable_error),
+                "after": retry_after_func,
+                "reraise": True,
+            }
 
     @staticmethod
     def _get_tenant_domain(conn: Connection) -> str:
@@ -233,6 +270,36 @@ class DbtCloudHook(HttpHook):
         headers["Authorization"] = f"Token {self.connection.password}"
         return headers, tenant
 
+    def _log_request_error(self, attempt_num: int, error: str) -> None:
+        self.log.error("Attempt %s API Request to DBT failed with reason: %s", 
attempt_num, error)
+
+    @staticmethod
+    def _retryable_error(exception: BaseException) -> bool:
+        if isinstance(exception, requests_exceptions.RequestException):
+            if isinstance(exception, (requests_exceptions.ConnectionError, 
requests_exceptions.Timeout)) or (
+                exception.response is not None
+                and (exception.response.status_code >= 500 or 
exception.response.status_code == 429)
+            ):
+                return True
+
+        if isinstance(exception, aiohttp.ClientResponseError):
+            if exception.status >= 500 or exception.status == 429:
+                return True
+
+        if isinstance(exception, (aiohttp.ClientConnectorError, TimeoutError)):
+            return True
+
+        return False
+
+    def _a_get_retry_object(self) -> AsyncRetrying:
+        """
+        Instantiate an async retry object.
+
+        :return: instance of AsyncRetrying class
+        """
+        # for compatibility we use reraise to avoid handling request error
+        return AsyncRetrying(**self.retry_args)
+
     @provide_account_id
     async def get_job_details(
         self, run_id: int, account_id: int | None = None, include_related: 
list[str] | None = None
@@ -249,17 +316,22 @@ class DbtCloudHook(HttpHook):
         headers, tenant = await self.get_headers_tenants_from_connection()
         url, params = self.get_request_url_params(tenant, endpoint, 
include_related)
         proxies = self._get_proxies(self.connection) or {}
+        proxy = proxies.get("https") if proxies and url.startswith("https") 
else proxies.get("http")
+        extra_request_args = {}
 
-        async with aiohttp.ClientSession(headers=headers) as session:
-            proxy = proxies.get("https") if proxies and 
url.startswith("https") else proxies.get("http")
-            extra_request_args = {}
+        if proxy:
+            extra_request_args["proxy"] = proxy
 
-            if proxy:
-                extra_request_args["proxy"] = proxy
+        timeout = (
+            aiohttp.ClientTimeout(total=self.timeout_seconds) if 
self.timeout_seconds is not None else None
+        )
 
-            async with session.get(url, params=params, **extra_request_args) 
as response:  # type: ignore[arg-type]
-                response.raise_for_status()
-                return await response.json()
+        async with aiohttp.ClientSession(headers=headers, timeout=timeout) as 
session:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with session.get(url, params=params, 
**extra_request_args) as response:  # type: ignore[arg-type]
+                        response.raise_for_status()
+                        return await response.json()
 
     async def get_job_status(
         self, run_id: int, account_id: int | None = None, include_related: 
list[str] | None = None
@@ -297,8 +369,14 @@ class DbtCloudHook(HttpHook):
     def _paginate(
         self, endpoint: str, payload: dict[str, Any] | None = None, proxies: 
dict[str, str] | None = None
     ) -> list[Response]:
-        extra_options = {"proxies": proxies} if proxies is not None else None
-        response = self.run(endpoint=endpoint, data=payload, 
extra_options=extra_options)
+        extra_options: dict[str, Any] = {}
+        if self.timeout_seconds is not None:
+            extra_options["timeout"] = self.timeout_seconds
+        if proxies is not None:
+            extra_options["proxies"] = proxies
+        response = self.run_with_advanced_retry(
+            _retry_args=self.retry_args, endpoint=endpoint, data=payload, 
extra_options=extra_options or None
+        )
         resp_json = response.json()
         limit = resp_json["extra"]["filters"]["limit"]
         num_total_results = resp_json["extra"]["pagination"]["total_count"]
@@ -309,7 +387,12 @@ class DbtCloudHook(HttpHook):
             _paginate_payload["offset"] = limit
 
             while num_current_results < num_total_results:
-                response = self.run(endpoint=endpoint, data=_paginate_payload, 
extra_options=extra_options)
+                response = self.run_with_advanced_retry(
+                    _retry_args=self.retry_args,
+                    endpoint=endpoint,
+                    data=_paginate_payload,
+                    extra_options=extra_options,
+                )
                 resp_json = response.json()
                 results.append(response)
                 num_current_results += 
resp_json["extra"]["pagination"]["count"]
@@ -328,7 +411,11 @@ class DbtCloudHook(HttpHook):
         self.method = method
         full_endpoint = f"api/{api_version}/accounts/{endpoint}" if endpoint 
else None
         proxies = self._get_proxies(self.connection)
-        extra_options = {"proxies": proxies} if proxies is not None else None
+        extra_options: dict[str, Any] = {}
+        if self.timeout_seconds is not None:
+            extra_options["timeout"] = self.timeout_seconds
+        if proxies is not None:
+            extra_options["proxies"] = proxies
 
         if paginate:
             if isinstance(payload, str):
@@ -339,7 +426,12 @@ class DbtCloudHook(HttpHook):
 
             raise ValueError("An endpoint is needed to paginate a response.")
 
-        return self.run(endpoint=full_endpoint, data=payload, 
extra_options=extra_options)
+        return self.run_with_advanced_retry(
+            _retry_args=self.retry_args,
+            endpoint=full_endpoint,
+            data=payload,
+            extra_options=extra_options or None,
+        )
 
     def list_accounts(self) -> list[Response]:
         """
diff --git 
a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py 
b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py
index 03492a0b23c..3ac8c6d544f 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py
@@ -87,6 +87,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         run. For more information on retry logic, see:
         
https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job
     :param deferrable: Run operator in the deferrable mode
+    :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
     :return: The ID of the triggered dbt Cloud job run.
     """
 
@@ -124,6 +125,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         reuse_existing_run: bool = False,
         retry_from_failure: bool = False,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        hook_params: dict[str, Any] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -144,6 +146,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         self.reuse_existing_run = reuse_existing_run
         self.retry_from_failure = retry_from_failure
         self.deferrable = deferrable
+        self.hook_params = hook_params or {}
 
     def execute(self, context: Context):
         if self.trigger_reason is None:
@@ -273,7 +276,7 @@ class DbtCloudRunJobOperator(BaseOperator):
     @cached_property
     def hook(self):
         """Returns DBT Cloud hook."""
-        return DbtCloudHook(self.dbt_cloud_conn_id)
+        return DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)
 
     def get_openlineage_facets_on_complete(self, task_instance) -> 
OperatorLineage:
         """
@@ -311,6 +314,7 @@ class DbtCloudGetJobRunArtifactOperator(BaseOperator):
         be returned.
     :param output_file_name: Optional. The desired file name for the download 
artifact file.
         Defaults to <run_id>_<path> (e.g. "728368_run_results.json").
+    :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
     """
 
     template_fields = ("dbt_cloud_conn_id", "run_id", "path", "account_id", 
"output_file_name")
@@ -324,6 +328,7 @@ class DbtCloudGetJobRunArtifactOperator(BaseOperator):
         account_id: int | None = None,
         step: int | None = None,
         output_file_name: str | None = None,
+        hook_params: dict[str, Any] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -333,9 +338,10 @@ class DbtCloudGetJobRunArtifactOperator(BaseOperator):
         self.account_id = account_id
         self.step = step
         self.output_file_name = output_file_name or 
f"{self.run_id}_{self.path}".replace("/", "-")
+        self.hook_params = hook_params or {}
 
     def execute(self, context: Context) -> str:
-        hook = DbtCloudHook(self.dbt_cloud_conn_id)
+        hook = DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)
         response = hook.get_job_run_artifact(
             run_id=self.run_id, path=self.path, account_id=self.account_id, 
step=self.step
         )
@@ -370,6 +376,7 @@ class DbtCloudListJobsOperator(BaseOperator):
     :param order_by: Optional. Field to order the result by. Use '-' to 
indicate reverse order.
         For example, to use reverse order by the run ID use ``order_by=-id``.
     :param project_id: Optional. The ID of a dbt Cloud project.
+    :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
     """
 
     template_fields = (
@@ -384,6 +391,7 @@ class DbtCloudListJobsOperator(BaseOperator):
         account_id: int | None = None,
         project_id: int | None = None,
         order_by: str | None = None,
+        hook_params: dict[str, Any] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -391,9 +399,10 @@ class DbtCloudListJobsOperator(BaseOperator):
         self.account_id = account_id
         self.project_id = project_id
         self.order_by = order_by
+        self.hook_params = hook_params or {}
 
     def execute(self, context: Context) -> list:
-        hook = DbtCloudHook(self.dbt_cloud_conn_id)
+        hook = DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)
         list_jobs_response = hook.list_jobs(
             account_id=self.account_id, order_by=self.order_by, 
project_id=self.project_id
         )
diff --git 
a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py 
b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py
index c8acf2d81e6..9d4c59473b1 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py
@@ -36,6 +36,7 @@ class DbtCloudRunJobTrigger(BaseTrigger):
     :param end_time: Time in seconds to wait for a job run to reach a terminal 
status. Defaults to 7 days.
     :param account_id: The ID of a dbt Cloud account.
     :param poll_interval:  polling period in seconds to check for the status.
+    :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
     """
 
     def __init__(
@@ -45,6 +46,7 @@ class DbtCloudRunJobTrigger(BaseTrigger):
         end_time: float,
         poll_interval: float,
         account_id: int | None,
+        hook_params: dict[str, Any] | None = None,
     ):
         super().__init__()
         self.run_id = run_id
@@ -52,6 +54,7 @@ class DbtCloudRunJobTrigger(BaseTrigger):
         self.conn_id = conn_id
         self.end_time = end_time
         self.poll_interval = poll_interval
+        self.hook_params = hook_params or {}
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize DbtCloudRunJobTrigger arguments and classpath."""
@@ -63,12 +66,13 @@ class DbtCloudRunJobTrigger(BaseTrigger):
                 "conn_id": self.conn_id,
                 "end_time": self.end_time,
                 "poll_interval": self.poll_interval,
+                "hook_params": self.hook_params,
             },
         )
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Make async connection to Dbt, polls for the pipeline run status."""
-        hook = DbtCloudHook(self.conn_id)
+        hook = DbtCloudHook(self.conn_id, **self.hook_params)
         try:
             while await self.is_still_running(hook):
                 if self.end_time < time.time():
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py 
b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
index 08bc9c69757..607bb650871 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
@@ -20,9 +20,11 @@ import json
 from copy import deepcopy
 from datetime import timedelta
 from typing import Any
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
+import aiohttp
 import pytest
+from requests import exceptions as requests_exceptions
 from requests.models import Response
 
 from airflow.exceptions import AirflowException
@@ -95,6 +97,12 @@ def mock_response_json(response: dict):
     return run_response
 
 
+def request_exception_with_status(status_code: int) -> 
requests_exceptions.HTTPError:
+    response = Response()
+    response.status_code = status_code
+    return requests_exceptions.HTTPError(response=response)
+
+
 class TestDbtCloudJobRunStatus:
     valid_job_run_statuses = [
         1,  # QUEUED
@@ -1072,3 +1080,235 @@ class TestDbtCloudHook:
 
         assert status is False
         assert msg == "403:Authentication credentials were not provided"
+
+    @pytest.mark.parametrize(
+        argnames="timeout_seconds",
+        argvalues=[60, 180, 300],
+        ids=["60s", "180s", "300s"],
+    )
+    @patch.object(DbtCloudHook, "run_with_advanced_retry")
+    def test_timeout_passed_to_run_and_get_response(self, mock_run_with_retry, 
timeout_seconds):
+        """Test that timeout is passed to extra_options in 
_run_and_get_response."""
+        hook = DbtCloudHook(ACCOUNT_ID_CONN, timeout_seconds=timeout_seconds)
+        mock_run_with_retry.return_value = mock_response_json({"data": {"id": 
JOB_ID}})
+
+        hook.get_job(job_id=JOB_ID, account_id=DEFAULT_ACCOUNT_ID)
+
+        call_args = mock_run_with_retry.call_args
+        assert call_args is not None
+        extra_options = call_args.kwargs.get("extra_options")
+        assert extra_options is not None
+        assert extra_options["timeout"] == timeout_seconds
+
+    @pytest.mark.parametrize(
+        argnames="timeout_seconds",
+        argvalues=[60, 180, 300],
+        ids=["60s", "180s", "300s"],
+    )
+    @patch.object(DbtCloudHook, "run_with_advanced_retry")
+    def test_timeout_passed_to_paginate(self, mock_run_with_retry, 
timeout_seconds):
+        """Test that timeout is passed to extra_options in _paginate."""
+        hook = DbtCloudHook(ACCOUNT_ID_CONN, timeout_seconds=timeout_seconds)
+        mock_response = mock_response_json(
+            {
+                "data": [{"id": JOB_ID}],
+                "extra": {"filters": {"limit": 100}, "pagination": {"count": 
1, "total_count": 1}},
+            }
+        )
+        mock_run_with_retry.return_value = mock_response
+
+        hook.list_jobs(account_id=DEFAULT_ACCOUNT_ID)
+
+        call_args = mock_run_with_retry.call_args
+        assert call_args is not None
+        extra_options = call_args.kwargs.get("extra_options")
+        assert extra_options is not None
+        assert extra_options["timeout"] == timeout_seconds
+
+    @pytest.mark.parametrize(
+        argnames="timeout_seconds",
+        argvalues=[60, 180, 300],
+        ids=["60s", "180s", "300s"],
+    )
+    @patch.object(DbtCloudHook, "run_with_advanced_retry")
+    def test_timeout_with_proxies(self, mock_run_with_retry, timeout_seconds):
+        """Test that both timeout and proxies are passed to extra_options."""
+        hook = DbtCloudHook(PROXY_CONN, timeout_seconds=timeout_seconds)
+        mock_run_with_retry.return_value = mock_response_json({"data": {"id": 
JOB_ID}})
+
+        hook.get_job(job_id=JOB_ID, account_id=DEFAULT_ACCOUNT_ID)
+
+        call_args = mock_run_with_retry.call_args
+        assert call_args is not None
+        extra_options = call_args.kwargs.get("extra_options")
+        assert extra_options is not None
+        assert extra_options["timeout"] == timeout_seconds
+        assert "proxies" in extra_options
+        assert extra_options["proxies"] == EXTRA_PROXIES["proxies"]
+
+    @pytest.mark.parametrize(
+        argnames="exception, expected",
+        argvalues=[
+            (requests_exceptions.ConnectionError(), True),
+            (requests_exceptions.Timeout(), True),
+            (request_exception_with_status(503), True),
+            (request_exception_with_status(429), True),
+            (request_exception_with_status(404), False),
+            (aiohttp.ClientResponseError(MagicMock(), (), status=500, 
message=""), True),
+            (aiohttp.ClientResponseError(MagicMock(), (), status=429, 
message=""), True),
+            (aiohttp.ClientResponseError(MagicMock(), (), status=400, 
message=""), False),
+            (aiohttp.ClientConnectorError(MagicMock(), OSError()), True),
+            (TimeoutError(), True),
+            (ValueError(), False),
+        ],
+        ids=[
+            "requests_connection_error",
+            "requests_timeout",
+            "requests_status_503",
+            "requests_status_429",
+            "requests_status_404",
+            "aiohttp_status_500",
+            "aiohttp_status_429",
+            "aiohttp_status_400",
+            "aiohttp_connector_error",
+            "timeout_error",
+            "value_error",
+        ],
+    )
+    def test_retryable_error(self, exception, expected):
+        assert DbtCloudHook._retryable_error(exception) is expected
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "error_factory, retry_qty, retry_delay",
+        [
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=500, 
message=""
+                ),
+                3,
+                0.1,
+            ),
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=429, 
message=""
+                ),
+                5,
+                0.1,
+            ),
+            (lambda: aiohttp.ClientConnectorError(AsyncMock(), 
OSError("boom")), 2, 0.1),
+            (lambda: TimeoutError(), 2, 0.1),
+        ],
+        ids=["aiohttp_500", "aiohttp_429", "connector_error", "timeout"],
+    )
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.aiohttp.ClientSession.get")
+    async def test_get_job_details_retry_with_retryable_errors(
+        self, get_mock, error_factory, retry_qty, retry_delay
+    ):
+        hook = DbtCloudHook(ACCOUNT_ID_CONN, retry_limit=retry_qty, 
retry_delay=retry_delay)
+
+        def fail_cm():
+            cm = AsyncMock()
+            cm.__aenter__.side_effect = error_factory()
+            return cm
+
+        ok_resp = AsyncMock()
+        ok_resp.raise_for_status = MagicMock(return_value=None)
+        ok_resp.json = AsyncMock(return_value={"data": "Success"})
+        ok_cm = AsyncMock()
+        ok_cm.__aenter__.return_value = ok_resp
+        ok_cm.__aexit__.return_value = AsyncMock()
+
+        all_resp = [fail_cm() for _ in range(retry_qty - 1)]
+        all_resp.append(ok_cm)
+        get_mock.side_effect = all_resp
+
+        result = await hook.get_job_details(run_id=RUN_ID, account_id=None)
+
+        assert result == {"data": "Success"}
+        assert get_mock.call_count == retry_qty
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "error_factory, expected_exception",
+        [
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=404, 
message="Not Found"
+                ),
+                aiohttp.ClientResponseError,
+            ),
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=400, 
message="Bad Request"
+                ),
+                aiohttp.ClientResponseError,
+            ),
+            (lambda: ValueError("Invalid parameter"), ValueError),
+        ],
+        ids=["aiohttp_404", "aiohttp_400", "value_error"],
+    )
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.aiohttp.ClientSession.get")
+    async def test_get_job_details_retry_with_non_retryable_errors(
+        self, get_mock, error_factory, expected_exception
+    ):
+        hook = DbtCloudHook(ACCOUNT_ID_CONN, retry_limit=3, retry_delay=0.1)
+
+        def fail_cm():
+            cm = AsyncMock()
+            cm.__aenter__.side_effect = error_factory()
+            return cm
+
+        get_mock.return_value = fail_cm()
+
+        with pytest.raises(expected_exception):
+            await hook.get_job_details(run_id=RUN_ID, account_id=None)
+
+        assert get_mock.call_count == 1
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        argnames="error_factory, expected_exception",
+        argvalues=[
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=503, 
message="Service Unavailable"
+                ),
+                aiohttp.ClientResponseError,
+            ),
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=500, 
message="Internal Server Error"
+                ),
+                aiohttp.ClientResponseError,
+            ),
+            (
+                lambda: aiohttp.ClientConnectorError(AsyncMock(), 
OSError("Connection refused")),
+                aiohttp.ClientConnectorError,
+            ),
+            (lambda: TimeoutError("Request timeout"), TimeoutError),
+        ],
+        ids=[
+            "aiohttp_503_exhausted",
+            "aiohttp_500_exhausted",
+            "connector_error_exhausted",
+            "timeout_exhausted",
+        ],
+    )
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.aiohttp.ClientSession.get")
+    async def test_get_job_details_retry_with_exhausted_retries(
+        self, get_mock, error_factory, expected_exception
+    ):
+        hook = DbtCloudHook(ACCOUNT_ID_CONN, retry_limit=2, retry_delay=0.1)
+
+        def fail_cm():
+            cm = AsyncMock()
+            cm.__aenter__.side_effect = error_factory()
+            return cm
+
+        get_mock.side_effect = [fail_cm() for _ in range(2)]
+
+        with pytest.raises(expected_exception):
+            await hook.get_job_details(run_id=RUN_ID, account_id=None)
+
+        assert get_mock.call_count == 2
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py 
b/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
index 2a1f26b49d8..76818008a45 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
@@ -45,6 +45,7 @@ class TestDbtCloudRunJobTrigger:
             end_time=self.END_TIME,
             run_id=self.RUN_ID,
             account_id=self.ACCOUNT_ID,
+            hook_params={"retry_delay": 10},
         )
         classpath, kwargs = trigger.serialize()
         assert classpath == 
"airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger"
@@ -54,6 +55,7 @@ class TestDbtCloudRunJobTrigger:
             "conn_id": self.CONN_ID,
             "end_time": self.END_TIME,
             "poll_interval": self.POLL_INTERVAL,
+            "hook_params": {"retry_delay": 10},
         }
 
     @pytest.mark.asyncio

Reply via email to