This is an automated email from the ASF dual-hosted git repository.
joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 83c118413c Add `retry_from_failure` parameter to DbtCloudRunJobOperator (#38868)
83c118413c is described below
commit 83c118413cef8c140276489b408c4b46ea0a30b5
Author: Bora Berke Sahin <[email protected]>
AuthorDate: Wed Jun 5 03:13:16 2024 +0300
Add `retry_from_failure` parameter to DbtCloudRunJobOperator (#38868)
* Add `retry_from_failure` parameter to DbtCloudRunJobOperator
* Use rerun endpoint only when ti.try_number is greater than 1
* Fix docstring links
* Do not allow override parameters to be used when retry_from_failure is True
* Fix base endpoint url prefix
* Split test cases and update docstring
* Use `rerun` only if the previous job run has failed
---
airflow/providers/dbt/cloud/hooks/dbt.py | 33 +++++++++++++
airflow/providers/dbt/cloud/operators/dbt.py | 7 +++
.../operators.rst | 4 ++
tests/providers/dbt/cloud/hooks/test_dbt.py | 56 ++++++++++++++++++++++
tests/providers/dbt/cloud/operators/test_dbt.py | 42 ++++++++++++++++
5 files changed, 142 insertions(+)
diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py
index 28a2406c90..acffb47716 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -404,6 +404,7 @@ class DbtCloudHook(HttpHook):
         account_id: int | None = None,
         steps_override: list[str] | None = None,
         schema_override: str | None = None,
+        retry_from_failure: bool = False,
         additional_run_config: dict[str, Any] | None = None,
     ) -> Response:
         """
"""
@@ -416,6 +417,9 @@ class DbtCloudHook(HttpHook):
             instead of those configured in dbt Cloud.
         :param schema_override: Optional. Override the destination schema in the configured target for this
             job.
+        :param retry_from_failure: Optional. If set to True and the previous job run has failed, the job
+            will be triggered using the "rerun" endpoint. This parameter cannot be used alongside
+            steps_override, schema_override, or additional_run_config.
         :param additional_run_config: Optional. Any additional parameters that should be included in the API
             request when triggering the job.
         :return: The request response.
@@ -439,6 +443,24 @@ class DbtCloudHook(HttpHook):
         }
         payload.update(additional_run_config)
 
+        if retry_from_failure:
+            latest_run = self.get_job_runs(
+                account_id=account_id,
+                payload={
+                    "job_definition_id": job_id,
+                    "order_by": "-created_at",
+                    "limit": 1,
+                },
+            ).json()["data"]
+            if latest_run and latest_run[0]["status"] == DbtCloudJobRunStatus.ERROR.value:
+                if steps_override is not None or schema_override is not None or additional_run_config != {}:
+                    warnings.warn(
+                        "steps_override, schema_override, or additional_run_config will be ignored when"
+                        " retry_from_failure is True and previous job run has failed.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                return self.retry_failed_job_run(job_id, account_id)
         return self._run_and_get_response(
             method="POST",
             endpoint=f"{account_id}/jobs/{job_id}/run/",
@@ -662,6 +684,17 @@ class DbtCloudHook(HttpHook):
             results = await asyncio.gather(*tasks.values())
             return {filename: result.json() for filename, result in zip(tasks.keys(), results)}
 
+    @fallback_to_default_account
+    def retry_failed_job_run(self, job_id: int, account_id: int | None = None) -> Response:
+        """
+        Retry a failed run for a job from the point of failure, if the run failed. Otherwise, trigger a new run.
+
+        :param job_id: The ID of a dbt Cloud job.
+        :param account_id: Optional. The ID of a dbt Cloud account.
+        :return: The request response.
+        """
+        return self._run_and_get_response(method="POST", endpoint=f"{account_id}/jobs/{job_id}/rerun/")
+
     def test_connection(self) -> tuple[bool, str]:
         """Test dbt Cloud connection."""
         try:
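
For illustration, a minimal sketch of exercising the new hook path directly; the connection ID and job ID below are placeholders, not part of this commit:

    from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook

    # Placeholder connection and job IDs, for illustration only.
    hook = DbtCloudHook(dbt_cloud_conn_id="dbt_cloud_default")

    # With retry_from_failure=True and no overrides, the hook looks up the most
    # recent run for the job; if that run errored, it POSTs to the "rerun"
    # endpoint, otherwise it falls through to the normal "run" endpoint.
    response = hook.trigger_job_run(job_id=42, cause="manual retry", retry_from_failure=True)
    print(response.json()["data"]["id"])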
diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py
index 18df34ce3e..c26e67e2a8 100644
--- a/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/airflow/providers/dbt/cloud/operators/dbt.py
@@ -75,6 +75,10 @@ class DbtCloudRunJobOperator(BaseOperator):
         request when triggering the job.
     :param reuse_existing_run: Flag to determine whether to reuse existing non terminal job run. If set to
         true and non terminal job runs found, it use the latest run without triggering a new job run.
+    :param retry_from_failure: Flag to determine whether to retry the job run from failure. If set to true
+        and the last job run has failed, it triggers a new job run with the same configuration as the failed
+        run. For more information on retry logic, see:
+        https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job
     :param deferrable: Run operator in the deferrable mode
     :return: The ID of the triggered dbt Cloud job run.
     """
@@ -105,6 +109,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         check_interval: int = 60,
         additional_run_config: dict[str, Any] | None = None,
         reuse_existing_run: bool = False,
+        retry_from_failure: bool = False,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ) -> None:
@@ -121,6 +126,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         self.additional_run_config = additional_run_config or {}
         self.run_id: int | None = None
         self.reuse_existing_run = reuse_existing_run
+        self.retry_from_failure = retry_from_failure
         self.deferrable = deferrable
 
     def execute(self, context: Context):
@@ -150,6 +156,7 @@ class DbtCloudRunJobOperator(BaseOperator):
                 cause=self.trigger_reason,
                 steps_override=self.steps_override,
                 schema_override=self.schema_override,
+                retry_from_failure=self.retry_from_failure,
                 additional_run_config=self.additional_run_config,
             )
             self.run_id = trigger_job_response.json()["data"]["id"]
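
For illustration, a sketch of the DAG-level usage this enables; the connection ID and job ID are placeholders. Note that steps_override, schema_override, and additional_run_config are ignored (with a warning) when retry_from_failure is True and the previous run failed:

    from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

    # Placeholder connection and job IDs, for illustration only.
    trigger_dbt_job = DbtCloudRunJobOperator(
        task_id="trigger_dbt_job",
        dbt_cloud_conn_id="dbt_cloud_default",
        job_id=42,
        retry_from_failure=True,
        wait_for_termination=True,
        check_interval=60,
    )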
diff --git a/docs/apache-airflow-providers-dbt-cloud/operators.rst b/docs/apache-airflow-providers-dbt-cloud/operators.rst
index a87739086a..af5b900d23 100644
--- a/docs/apache-airflow-providers-dbt-cloud/operators.rst
+++ b/docs/apache-airflow-providers-dbt-cloud/operators.rst
@@ -51,6 +51,10 @@ resource utilization while the job is running.
 When ``wait_for_termination`` is False and ``deferrable`` is False, we just submit the job and can only
 track the job status with the :class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`.
 
+When ``retry_from_failure`` is True, we retry the run for a job from the point of failure,
+if the run failed. Otherwise we trigger a new run.
+For more information on the retry logic, reference the
+`API documentation <https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job>`__.
+
 While ``schema_override`` and ``steps_override`` are explicit, optional parameters for the
 ``DbtCloudRunJobOperator``, custom run configurations can also be passed to the operator using the
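
The referenced "Retry Failed Job" endpoint can also be sketched as a raw HTTP call with requests, assuming the default cloud.getdbt.com host and placeholder account ID, job ID, and token:

    import requests

    # Placeholder values, for illustration only.
    ACCOUNT_ID, JOB_ID, TOKEN = 1, 42, "dbt-cloud-api-token"
    resp = requests.post(
        f"https://cloud.getdbt.com/api/v2/accounts/{ACCOUNT_ID}/jobs/{JOB_ID}/rerun/",
        headers={"Authorization": f"Token {TOKEN}"},
    )
    resp.raise_for_status()
    print(resp.json()["data"]["id"])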
diff --git a/tests/providers/dbt/cloud/hooks/test_dbt.py b/tests/providers/dbt/cloud/hooks/test_dbt.py
index c2b2c1a98a..71ef75ba31 100644
--- a/tests/providers/dbt/cloud/hooks/test_dbt.py
+++ b/tests/providers/dbt/cloud/hooks/test_dbt.py
@@ -425,6 +425,62 @@ class TestDbtCloudHook:
         )
         hook._paginate.assert_not_called()
 
+    @pytest.mark.parametrize(
+        argnames="conn_id, account_id",
+        argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
+        ids=["default_account", "explicit_account"],
+    )
+    @pytest.mark.parametrize(
+        argnames="get_job_runs_data, should_use_rerun",
+        argvalues=[
+            ([], False),
+            ([{"status": DbtCloudJobRunStatus.QUEUED.value}], False),
+            ([{"status": DbtCloudJobRunStatus.STARTING.value}], False),
+            ([{"status": DbtCloudJobRunStatus.RUNNING.value}], False),
+            ([{"status": DbtCloudJobRunStatus.SUCCESS.value}], False),
+            ([{"status": DbtCloudJobRunStatus.ERROR.value}], True),
+            ([{"status": DbtCloudJobRunStatus.CANCELLED.value}], False),
+        ],
+    )
+    @patch.object(DbtCloudHook, "run")
+    @patch.object(DbtCloudHook, "_paginate")
+    def test_trigger_job_run_with_retry_from_failure(
+        self,
+        mock_http_run,
+        mock_paginate,
+        get_job_runs_data,
+        should_use_rerun,
+        conn_id,
+        account_id,
+    ):
+        hook = DbtCloudHook(conn_id)
+        cause = ""
+        retry_from_failure = True
+
+        with patch.object(DbtCloudHook, "get_job_runs") as mock_get_job_run_status:
+            mock_get_job_run_status.return_value.json.return_value = {"data": get_job_runs_data}
+            hook.trigger_job_run(
+                job_id=JOB_ID, cause=cause, account_id=account_id, retry_from_failure=retry_from_failure
+            )
+            assert hook.method == "POST"
+            _account_id = account_id or DEFAULT_ACCOUNT_ID
+            hook._paginate.assert_not_called()
+            if should_use_rerun:
+                hook.run.assert_called_once_with(
+                    endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/rerun/", data=None
+                )
+            else:
+                hook.run.assert_called_once_with(
+                    endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
+                    data=json.dumps(
+                        {
+                            "cause": cause,
+                            "steps_override": None,
+                            "schema_override": None,
+                        }
+                    ),
+                )
+
     @pytest.mark.parametrize(
         argnames="conn_id, account_id",
         argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
diff --git a/tests/providers/dbt/cloud/operators/test_dbt.py b/tests/providers/dbt/cloud/operators/test_dbt.py
index 3b0c1dab09..136a45b9ed 100644
--- a/tests/providers/dbt/cloud/operators/test_dbt.py
+++ b/tests/providers/dbt/cloud/operators/test_dbt.py
@@ -264,6 +264,7 @@ class TestDbtCloudRunJobOperator:
             cause=f"Triggered via Apache Airflow by task {TASK_ID!r} in the {self.dag.dag_id} DAG.",
             steps_override=self.config["steps_override"],
             schema_override=self.config["schema_override"],
+            retry_from_failure=False,
             additional_run_config=self.config["additional_run_config"],
         )
@@ -312,6 +313,7 @@ class TestDbtCloudRunJobOperator:
             cause=f"Triggered via Apache Airflow by task {TASK_ID!r} in the {self.dag.dag_id} DAG.",
             steps_override=self.config["steps_override"],
             schema_override=self.config["schema_override"],
+            retry_from_failure=False,
             additional_run_config=self.config["additional_run_config"],
         )
@@ -379,6 +381,45 @@ class TestDbtCloudRunJobOperator:
             },
         )
 
+    @patch.object(DbtCloudHook, "trigger_job_run")
+    @pytest.mark.parametrize(
+        "conn_id, account_id",
+        [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
+        ids=["default_account", "explicit_account"],
+    )
+    def test_execute_retry_from_failure(self, mock_run_job, conn_id, account_id):
+        operator = DbtCloudRunJobOperator(
+            task_id=TASK_ID,
+            dbt_cloud_conn_id=conn_id,
+            account_id=account_id,
+            trigger_reason=None,
+            dag=self.dag,
+            retry_from_failure=True,
+            **self.config,
+        )
+
+        assert operator.dbt_cloud_conn_id == conn_id
+        assert operator.job_id == self.config["job_id"]
+        assert operator.account_id == account_id
+        assert operator.check_interval == self.config["check_interval"]
+        assert operator.timeout == self.config["timeout"]
+        assert operator.retry_from_failure
+        assert operator.steps_override == self.config["steps_override"]
+        assert operator.schema_override == self.config["schema_override"]
+        assert operator.additional_run_config == self.config["additional_run_config"]
+
+        operator.execute(context=self.mock_context)
+
+        mock_run_job.assert_called_once_with(
+            account_id=account_id,
+            job_id=JOB_ID,
+            cause=f"Triggered via Apache Airflow by task {TASK_ID!r} in the {self.dag.dag_id} DAG.",
+            steps_override=self.config["steps_override"],
+            schema_override=self.config["schema_override"],
+            retry_from_failure=True,
+            additional_run_config=self.config["additional_run_config"],
+        )
+
     @patch.object(DbtCloudHook, "trigger_job_run")
     @pytest.mark.parametrize(
         "conn_id, account_id",
@@ -411,6 +452,7 @@ class TestDbtCloudRunJobOperator:
             cause=custom_trigger_reason,
             steps_override=self.config["steps_override"],
             schema_override=self.config["schema_override"],
+            retry_from_failure=False,
             additional_run_config=self.config["additional_run_config"],
         )