This is an automated email from the ASF dual-hosted git repository.

joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 340a70bfe72 Added condition to check if it is a scheduled save or rerun (#43453)
340a70bfe72 is described below

commit 340a70bfe7289e01898ddd75f8edfaf7772e9d09
Author: krzysztof-kubis <[email protected]>
AuthorDate: Sat Nov 9 04:19:46 2024 +0100

    Added condition to check if it is a scheduled save or rerun (#43453)
    
    * Added condition to check if it is a scheduled save or rerun
    
    * Fix key name in context of task
    
    * Added unit tests for the condition that checks whether it is a scheduled save or rerun
    
    ---------
    
    Co-authored-by: RBAMOUSER\kubisk1 <[email protected]>
    Co-authored-by: krzysztof-kubis <[email protected]>
    Co-authored-by: krzysztof-kubis <[email protected]>
---
 .../airflow/providers/dbt/cloud/operators/dbt.py   |  4 +-
 providers/tests/dbt/cloud/operators/test_dbt.py    | 78 ++++++++++++++++++++++
 2 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/providers/src/airflow/providers/dbt/cloud/operators/dbt.py 
b/providers/src/airflow/providers/dbt/cloud/operators/dbt.py
index c26e67e2a8b..8795ebf0ca7 100644
--- a/providers/src/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/providers/src/airflow/providers/dbt/cloud/operators/dbt.py
@@ -149,6 +149,8 @@ class DbtCloudRunJobOperator(BaseOperator):
                 self.run_id = non_terminal_runs[0]["id"]
                 job_run_url = non_terminal_runs[0]["href"]
 
+        is_retry = context["ti"].try_number != 1
+
         if not self.reuse_existing_run or not non_terminal_runs:
             trigger_job_response = self.hook.trigger_job_run(
                 account_id=self.account_id,
@@ -156,7 +158,7 @@ class DbtCloudRunJobOperator(BaseOperator):
                 cause=self.trigger_reason,
                 steps_override=self.steps_override,
                 schema_override=self.schema_override,
-                retry_from_failure=self.retry_from_failure,
+                retry_from_failure=is_retry and self.retry_from_failure,
                 additional_run_config=self.additional_run_config,
             )
             self.run_id = trigger_job_response.json()["data"]["id"]
diff --git a/providers/tests/dbt/cloud/operators/test_dbt.py 
b/providers/tests/dbt/cloud/operators/test_dbt.py
index eb50bd5a22a..a5f8752ffb3 100644
--- a/providers/tests/dbt/cloud/operators/test_dbt.py
+++ b/providers/tests/dbt/cloud/operators/test_dbt.py
@@ -64,6 +64,17 @@ EXPLICIT_ACCOUNT_JOB_RUN_RESPONSE = {
         ),
     }
 }
+JOB_RUN_ERROR_RESPONSE = {
+    "data": [
+        {
+            "id": RUN_ID,
+            "href": EXPECTED_JOB_RUN_OP_EXTRA_LINK.format(
+                account_id=ACCOUNT_ID, project_id=PROJECT_ID, run_id=RUN_ID
+            ),
+            "status": DbtCloudJobRunStatus.ERROR.value,
+        }
+    ]
+}
 
 
 def mock_response_json(response: dict):
@@ -421,6 +432,73 @@ class TestDbtCloudRunJobOperator:
             additional_run_config=self.config["additional_run_config"],
         )
 
+    @patch.object(DbtCloudHook, "_run_and_get_response")
+    @pytest.mark.parametrize(
+        "conn_id, account_id",
+        [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
+        ids=["default_account", "explicit_account"],
+    )
+    def test_execute_retry_from_failure_run(self, mock_run_req, conn_id, 
account_id):
+        operator = DbtCloudRunJobOperator(
+            task_id=TASK_ID,
+            dbt_cloud_conn_id=conn_id,
+            account_id=account_id,
+            trigger_reason=None,
+            dag=self.dag,
+            retry_from_failure=True,
+            **self.config,
+        )
+        self.mock_context["ti"].try_number = 1
+
+        assert operator.dbt_cloud_conn_id == conn_id
+        assert operator.job_id == self.config["job_id"]
+        assert operator.account_id == account_id
+        assert operator.check_interval == self.config["check_interval"]
+        assert operator.timeout == self.config["timeout"]
+        assert operator.retry_from_failure
+        assert operator.steps_override == self.config["steps_override"]
+        assert operator.schema_override == self.config["schema_override"]
+        assert operator.additional_run_config == 
self.config["additional_run_config"]
+
+        operator.execute(context=self.mock_context)
+
+        mock_run_req.assert_called()
+
+    @patch.object(
+        DbtCloudHook, "_run_and_get_response", 
return_value=mock_response_json(JOB_RUN_ERROR_RESPONSE)
+    )
+    @patch.object(DbtCloudHook, "retry_failed_job_run")
+    @pytest.mark.parametrize(
+        "conn_id, account_id",
+        [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
+        ids=["default_account", "explicit_account"],
+    )
+    def test_execute_retry_from_failure_rerun(self, mock_run_req, 
mock_rerun_req, conn_id, account_id):
+        operator = DbtCloudRunJobOperator(
+            task_id=TASK_ID,
+            dbt_cloud_conn_id=conn_id,
+            account_id=account_id,
+            trigger_reason=None,
+            dag=self.dag,
+            retry_from_failure=True,
+            **self.config,
+        )
+        self.mock_context["ti"].try_number = 2
+
+        assert operator.dbt_cloud_conn_id == conn_id
+        assert operator.job_id == self.config["job_id"]
+        assert operator.account_id == account_id
+        assert operator.check_interval == self.config["check_interval"]
+        assert operator.timeout == self.config["timeout"]
+        assert operator.retry_from_failure
+        assert operator.steps_override == self.config["steps_override"]
+        assert operator.schema_override == self.config["schema_override"]
+        assert operator.additional_run_config == 
self.config["additional_run_config"]
+
+        operator.execute(context=self.mock_context)
+
+        mock_rerun_req.assert_called_once()
+
     @patch.object(DbtCloudHook, "trigger_job_run")
     @pytest.mark.parametrize(
         "conn_id, account_id",

Reply via email to