This is an automated email from the ASF dual-hosted git repository.

joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 83c118413c Add `retry_from_failure` parameter to DbtCloudRunJobOperator (#38868)
83c118413c is described below

commit 83c118413cef8c140276489b408c4b46ea0a30b5
Author: Bora Berke Sahin <[email protected]>
AuthorDate: Wed Jun 5 03:13:16 2024 +0300

    Add `retry_from_failure` parameter to DbtCloudRunJobOperator (#38868)
    
    * Add `retry_from_failure` parameter to DbtCloudRunJobOperator
    
    * Use rerun endpoint only when ti.try_number is greater than 1
    
    * Fix docstring links
    
    * Do not allow override parameters to be used when retry_from_failure is True
    
    * Fix base endpoint url prefix
    
    * Split test cases and update docstring
    
    * Use `rerun` only if the previous job run has failed
---
 airflow/providers/dbt/cloud/hooks/dbt.py           | 33 +++++++++++++
 airflow/providers/dbt/cloud/operators/dbt.py       |  7 +++
 .../operators.rst                                  |  4 ++
 tests/providers/dbt/cloud/hooks/test_dbt.py        | 56 ++++++++++++++++++++++
 tests/providers/dbt/cloud/operators/test_dbt.py    | 42 ++++++++++++++++
 5 files changed, 142 insertions(+)

diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py
index 28a2406c90..acffb47716 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -404,6 +404,7 @@ class DbtCloudHook(HttpHook):
         account_id: int | None = None,
         steps_override: list[str] | None = None,
         schema_override: str | None = None,
+        retry_from_failure: bool = False,
         additional_run_config: dict[str, Any] | None = None,
     ) -> Response:
         """
@@ -416,6 +417,9 @@ class DbtCloudHook(HttpHook):
             instead of those configured in dbt Cloud.
         :param schema_override: Optional. Override the destination schema in the configured target for this
             job.
+        :param retry_from_failure: Optional. If set to True and the previous job run has failed, the job
+            will be triggered using the "rerun" endpoint. This parameter cannot be used alongside
+            steps_override, schema_override, or additional_run_config.
         :param additional_run_config: Optional. Any additional parameters that should be included in the API
             request when triggering the job.
         :return: The request response.
@@ -439,6 +443,24 @@ class DbtCloudHook(HttpHook):
         }
         payload.update(additional_run_config)
 
+        if retry_from_failure:
+            latest_run = self.get_job_runs(
+                account_id=account_id,
+                payload={
+                    "job_definition_id": job_id,
+                    "order_by": "-created_at",
+                    "limit": 1,
+                },
+            ).json()["data"]
+            if latest_run and latest_run[0]["status"] == DbtCloudJobRunStatus.ERROR.value:
+                if steps_override is not None or schema_override is not None or additional_run_config != {}:
+                    warnings.warn(
+                        "steps_override, schema_override, or additional_run_config will be ignored when"
+                        " retry_from_failure is True and previous job run has failed.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                return self.retry_failed_job_run(job_id, account_id)
         return self._run_and_get_response(
             method="POST",
             endpoint=f"{account_id}/jobs/{job_id}/run/",
@@ -662,6 +684,17 @@ class DbtCloudHook(HttpHook):
         results = await asyncio.gather(*tasks.values())
         return {filename: result.json() for filename, result in zip(tasks.keys(), results)}
 
+    @fallback_to_default_account
+    def retry_failed_job_run(self, job_id: int, account_id: int | None = None) -> Response:
+        """
+        Retry a failed run for a job from the point of failure, if the run failed. Otherwise, trigger a new run.
+
+        :param job_id: The ID of a dbt Cloud job.
+        :param account_id: Optional. The ID of a dbt Cloud account.
+        :return: The request response.
+        """
+        return self._run_and_get_response(method="POST", endpoint=f"{account_id}/jobs/{job_id}/rerun/")
+
     def test_connection(self) -> tuple[bool, str]:
         """Test dbt Cloud connection."""
         try:
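
For context, a minimal sketch of exercising the new hook behavior directly.
It is illustrative only: the connection ID "dbt_cloud_default" and the job ID
are assumptions, not values from this commit.

    from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook

    # Assumes an Airflow connection named "dbt_cloud_default" pointing at a
    # dbt Cloud account (illustrative, not part of this commit).
    hook = DbtCloudHook(dbt_cloud_conn_id="dbt_cloud_default")

    # With retry_from_failure=True, the hook looks up the latest run for the
    # job; if that run ended in ERROR, it POSTs to the "rerun" endpoint so
    # dbt Cloud resumes from the point of failure. Otherwise it falls through
    # to the regular "run" endpoint and starts a fresh run.
    response = hook.trigger_job_run(
        job_id=12345,  # illustrative job ID
        cause="Retrying from last failure",
        retry_from_failure=True,
    )
    run_id = response.json()["data"]["id"]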
diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py
index 18df34ce3e..c26e67e2a8 100644
--- a/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/airflow/providers/dbt/cloud/operators/dbt.py
@@ -75,6 +75,10 @@ class DbtCloudRunJobOperator(BaseOperator):
         request when triggering the job.
     :param reuse_existing_run: Flag to determine whether to reuse existing non terminal job run. If set to
         true and non terminal job runs are found, it uses the latest run without triggering a new job run.
+    :param retry_from_failure: Flag to determine whether to retry the job run from failure. If set to true
+        and the last job run has failed, it triggers a new job run with the same configuration as the failed
+        run. For more information on retry logic, see:
+        https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job
     :param deferrable: Run operator in the deferrable mode
     :return: The ID of the triggered dbt Cloud job run.
     """
@@ -105,6 +109,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         check_interval: int = 60,
         additional_run_config: dict[str, Any] | None = None,
         reuse_existing_run: bool = False,
+        retry_from_failure: bool = False,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ) -> None:
@@ -121,6 +126,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         self.additional_run_config = additional_run_config or {}
         self.run_id: int | None = None
         self.reuse_existing_run = reuse_existing_run
+        self.retry_from_failure = retry_from_failure
         self.deferrable = deferrable
 
     def execute(self, context: Context):
@@ -150,6 +156,7 @@ class DbtCloudRunJobOperator(BaseOperator):
                 cause=self.trigger_reason,
                 steps_override=self.steps_override,
                 schema_override=self.schema_override,
+                retry_from_failure=self.retry_from_failure,
                 additional_run_config=self.additional_run_config,
             )
             self.run_id = trigger_job_response.json()["data"]["id"]
diff --git a/docs/apache-airflow-providers-dbt-cloud/operators.rst b/docs/apache-airflow-providers-dbt-cloud/operators.rst
index a87739086a..af5b900d23 100644
--- a/docs/apache-airflow-providers-dbt-cloud/operators.rst
+++ b/docs/apache-airflow-providers-dbt-cloud/operators.rst
@@ -51,6 +51,10 @@ resource utilization while the job is running.
 When ``wait_for_termination`` is False and ``deferrable`` is False, we just submit the job and can only
 track the job status with the :class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`.
 
+When ``retry_from_failure`` is True, we retry the run for a job from the point of failure,
+if the run failed. Otherwise we trigger a new run.
+For more information on the retry logic, reference the
+`API documentation <https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job>`__.
 
 While ``schema_override`` and ``steps_override`` are explicit, optional parameters for the
 ``DbtCloudRunJobOperator``, custom run configurations can also be passed to the operator using the
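
To make the documentation change above concrete, a hedged usage sketch for a
DAG; the task_id, connection ID, and job ID below are placeholders, not taken
from this commit.

    from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

    # Assumes a dbt Cloud connection and an existing job; IDs are illustrative.
    run_dbt_job = DbtCloudRunJobOperator(
        task_id="run_dbt_cloud_job",
        dbt_cloud_conn_id="dbt_cloud_default",
        job_id=12345,
        retry_from_failure=True,  # rerun from failure if the last run errored
        wait_for_termination=True,
        check_interval=60,
    )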
diff --git a/tests/providers/dbt/cloud/hooks/test_dbt.py b/tests/providers/dbt/cloud/hooks/test_dbt.py
index c2b2c1a98a..71ef75ba31 100644
--- a/tests/providers/dbt/cloud/hooks/test_dbt.py
+++ b/tests/providers/dbt/cloud/hooks/test_dbt.py
@@ -425,6 +425,62 @@ class TestDbtCloudHook:
         )
         hook._paginate.assert_not_called()
 
+    @pytest.mark.parametrize(
+        argnames="conn_id, account_id",
+        argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
+        ids=["default_account", "explicit_account"],
+    )
+    @pytest.mark.parametrize(
+        argnames="get_job_runs_data, should_use_rerun",
+        argvalues=[
+            ([], False),
+            ([{"status": DbtCloudJobRunStatus.QUEUED.value}], False),
+            ([{"status": DbtCloudJobRunStatus.STARTING.value}], False),
+            ([{"status": DbtCloudJobRunStatus.RUNNING.value}], False),
+            ([{"status": DbtCloudJobRunStatus.SUCCESS.value}], False),
+            ([{"status": DbtCloudJobRunStatus.ERROR.value}], True),
+            ([{"status": DbtCloudJobRunStatus.CANCELLED.value}], False),
+        ],
+    )
+    @patch.object(DbtCloudHook, "run")
+    @patch.object(DbtCloudHook, "_paginate")
+    def test_trigger_job_run_with_retry_from_failure(
+        self,
+        mock_http_run,
+        mock_paginate,
+        get_job_runs_data,
+        should_use_rerun,
+        conn_id,
+        account_id,
+    ):
+        hook = DbtCloudHook(conn_id)
+        cause = ""
+        retry_from_failure = True
+
+        with patch.object(DbtCloudHook, "get_job_runs") as mock_get_job_run_status:
+            mock_get_job_run_status.return_value.json.return_value = {"data": get_job_runs_data}
+            hook.trigger_job_run(
+                job_id=JOB_ID, cause=cause, account_id=account_id, retry_from_failure=retry_from_failure
+            )
+            assert hook.method == "POST"
+            _account_id = account_id or DEFAULT_ACCOUNT_ID
+            hook._paginate.assert_not_called()
+            if should_use_rerun:
+                hook.run.assert_called_once_with(
+                    endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/rerun/", data=None
+                )
+            else:
+                hook.run.assert_called_once_with(
+                    endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
+                    data=json.dumps(
+                        {
+                            "cause": cause,
+                            "steps_override": None,
+                            "schema_override": None,
+                        }
+                    ),
+                )
+
     @pytest.mark.parametrize(
         argnames="conn_id, account_id",
         argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
diff --git a/tests/providers/dbt/cloud/operators/test_dbt.py b/tests/providers/dbt/cloud/operators/test_dbt.py
index 3b0c1dab09..136a45b9ed 100644
--- a/tests/providers/dbt/cloud/operators/test_dbt.py
+++ b/tests/providers/dbt/cloud/operators/test_dbt.py
@@ -264,6 +264,7 @@ class TestDbtCloudRunJobOperator:
                 cause=f"Triggered via Apache Airflow by task {TASK_ID!r} in 
the {self.dag.dag_id} DAG.",
                 steps_override=self.config["steps_override"],
                 schema_override=self.config["schema_override"],
+                retry_from_failure=False,
                 additional_run_config=self.config["additional_run_config"],
             )
 
@@ -312,6 +313,7 @@ class TestDbtCloudRunJobOperator:
                 cause=f"Triggered via Apache Airflow by task {TASK_ID!r} in 
the {self.dag.dag_id} DAG.",
                 steps_override=self.config["steps_override"],
                 schema_override=self.config["schema_override"],
+                retry_from_failure=False,
                 additional_run_config=self.config["additional_run_config"],
             )
 
@@ -379,6 +381,45 @@ class TestDbtCloudRunJobOperator:
             },
         )
 
+    @patch.object(DbtCloudHook, "trigger_job_run")
+    @pytest.mark.parametrize(
+        "conn_id, account_id",
+        [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
+        ids=["default_account", "explicit_account"],
+    )
+    def test_execute_retry_from_failure(self, mock_run_job, conn_id, account_id):
+        operator = DbtCloudRunJobOperator(
+            task_id=TASK_ID,
+            dbt_cloud_conn_id=conn_id,
+            account_id=account_id,
+            trigger_reason=None,
+            dag=self.dag,
+            retry_from_failure=True,
+            **self.config,
+        )
+
+        assert operator.dbt_cloud_conn_id == conn_id
+        assert operator.job_id == self.config["job_id"]
+        assert operator.account_id == account_id
+        assert operator.check_interval == self.config["check_interval"]
+        assert operator.timeout == self.config["timeout"]
+        assert operator.retry_from_failure
+        assert operator.steps_override == self.config["steps_override"]
+        assert operator.schema_override == self.config["schema_override"]
+        assert operator.additional_run_config == self.config["additional_run_config"]
+
+        operator.execute(context=self.mock_context)
+
+        mock_run_job.assert_called_once_with(
+            account_id=account_id,
+            job_id=JOB_ID,
+            cause=f"Triggered via Apache Airflow by task {TASK_ID!r} in the 
{self.dag.dag_id} DAG.",
+            steps_override=self.config["steps_override"],
+            schema_override=self.config["schema_override"],
+            retry_from_failure=True,
+            additional_run_config=self.config["additional_run_config"],
+        )
+
     @patch.object(DbtCloudHook, "trigger_job_run")
     @pytest.mark.parametrize(
         "conn_id, account_id",
@@ -411,6 +452,7 @@ class TestDbtCloudRunJobOperator:
                 cause=custom_trigger_reason,
                 steps_override=self.config["steps_override"],
                 schema_override=self.config["schema_override"],
+                retry_from_failure=False,
                 additional_run_config=self.config["additional_run_config"],
             )
 
