This is an automated email from the ASF dual-hosted git repository.

joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bb4a0b39e8 Optimize deferred execution mode for DbtCloudRunJobOperator (#31188)
bb4a0b39e8 is described below

commit bb4a0b39e8c021772830c9d44e72e492e0fef4bb
Author: Phani Kumar <[email protected]>
AuthorDate: Tue May 16 01:47:43 2023 +0530

    Optimize deferred execution mode for DbtCloudRunJobOperator (#31188)
---
 airflow/providers/dbt/cloud/operators/dbt.py       | 40 +++++++----
 .../dbt/cloud/operators/test_dbt_cloud.py          | 79 ++++++++++++++++++++++
 2 files changed, 107 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py
index de110e9865..d9177e6772 100644
--- a/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/airflow/providers/dbt/cloud/operators/dbt.py
@@ -24,7 +24,12 @@ from typing import TYPE_CHECKING, Any
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, BaseOperatorLink, XCom
-from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
+from airflow.providers.dbt.cloud.hooks.dbt import (
+    DbtCloudHook,
+    DbtCloudJobRunException,
+    DbtCloudJobRunStatus,
+    JobRunInfo,
+)
 from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
 
 if TYPE_CHECKING:
@@ -154,17 +159,28 @@ class DbtCloudRunJobOperator(BaseOperator):
                 return self.run_id
             else:
                 end_time = time.time() + self.timeout
-                self.defer(
-                    timeout=self.execution_timeout,
-                    trigger=DbtCloudRunJobTrigger(
-                        conn_id=self.dbt_cloud_conn_id,
-                        run_id=self.run_id,
-                        end_time=end_time,
-                        account_id=self.account_id,
-                        poll_interval=self.check_interval,
-                    ),
-                    method_name="execute_complete",
-                )
+                job_run_info = JobRunInfo(account_id=self.account_id, run_id=self.run_id)
+                job_run_status = self.hook.get_job_run_status(**job_run_info)
+                if not DbtCloudJobRunStatus.is_terminal(job_run_status):
+                    self.defer(
+                        timeout=self.execution_timeout,
+                        trigger=DbtCloudRunJobTrigger(
+                            conn_id=self.dbt_cloud_conn_id,
+                            run_id=self.run_id,
+                            end_time=end_time,
+                            account_id=self.account_id,
+                            poll_interval=self.check_interval,
+                        ),
+                        method_name="execute_complete",
+                    )
+                elif job_run_status == DbtCloudJobRunStatus.SUCCESS.value:
+                    self.log.info("Job run %s has completed successfully.", str(self.run_id))
+                    return self.run_id
+                elif job_run_status in (
+                    DbtCloudJobRunStatus.CANCELLED.value,
+                    DbtCloudJobRunStatus.ERROR.value,
+                ):
+                    raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.")
         else:
             if self.deferrable is True:
                 warnings.warn(
diff --git a/tests/providers/dbt/cloud/operators/test_dbt_cloud.py b/tests/providers/dbt/cloud/operators/test_dbt_cloud.py
index 6aab53109f..4400ebd4ab 100644
--- a/tests/providers/dbt/cloud/operators/test_dbt_cloud.py
+++ b/tests/providers/dbt/cloud/operators/test_dbt_cloud.py
@@ -21,6 +21,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
+from airflow.exceptions import TaskDeferred
 from airflow.models import DAG, Connection
 from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
 from airflow.providers.dbt.cloud.operators.dbt import (
@@ -28,6 +29,7 @@ from airflow.providers.dbt.cloud.operators.dbt import (
     DbtCloudListJobsOperator,
     DbtCloudRunJobOperator,
 )
+from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
 from airflow.utils import db, timezone
 
 DEFAULT_DATE = timezone.datetime(2021, 1, 1)
@@ -95,6 +97,83 @@ class TestDbtCloudRunJobOperator:
             "additional_run_config": {"threads_override": 8},
         }
 
+    @patch(
+        "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status",
+        return_value=DbtCloudJobRunStatus.SUCCESS.value,
+    )
+    @patch("airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator.defer")
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
+    def test_execute_succeeded_before_getting_deferred(
+        self, mock_trigger_job_run, mock_dbt_hook, mock_defer, mock_job_run_status
+    ):
+        dbt_op = DbtCloudRunJobOperator(
+            dbt_cloud_conn_id=ACCOUNT_ID_CONN,
+            task_id=TASK_ID,
+            job_id=JOB_ID,
+            check_interval=1,
+            timeout=3,
+            dag=self.dag,
+            deferrable=True,
+        )
+        dbt_op.execute(MagicMock())
+        assert not mock_defer.called
+
+    @patch(
+        "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status",
+        return_value=DbtCloudJobRunStatus.ERROR.value,
+    )
+    @patch("airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator.defer")
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
+    def test_execute_failed_before_getting_deferred(
+        self, mock_trigger_job_run, mock_dbt_hook, mock_defer, mock_job_run_status
+    ):
+        dbt_op = DbtCloudRunJobOperator(
+            dbt_cloud_conn_id=ACCOUNT_ID_CONN,
+            task_id=TASK_ID,
+            job_id=JOB_ID,
+            check_interval=1,
+            timeout=3,
+            dag=self.dag,
+            deferrable=True,
+        )
+        with pytest.raises(DbtCloudJobRunException):
+            dbt_op.execute(MagicMock())
+        assert not mock_defer.called
+
+    @pytest.mark.parametrize(
+        "status",
+        (
+            DbtCloudJobRunStatus.QUEUED.value,
+            DbtCloudJobRunStatus.STARTING.value,
+            DbtCloudJobRunStatus.RUNNING.value,
+        ),
+    )
+    @patch(
+        "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status",
+    )
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
+    def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_job_run_status, status):
+        """
+        Asserts that a task is deferred and an DbtCloudRunJobTrigger will be fired
+        when the DbtCloudRunJobOperator has deferrable param set to True
+        """
+        mock_job_run_status.return_value = status
+        dbt_op = DbtCloudRunJobOperator(
+            dbt_cloud_conn_id=ACCOUNT_ID_CONN,
+            task_id=TASK_ID,
+            job_id=JOB_ID,
+            check_interval=1,
+            timeout=3,
+            dag=self.dag,
+            deferrable=True,
+        )
+        with pytest.raises(TaskDeferred) as exc:
+            dbt_op.execute(MagicMock())
+        assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger is not a DbtCloudRunJobTrigger"
+
     @patch.object(DbtCloudHook, "trigger_job_run", return_value=MagicMock(**DEFAULT_ACCOUNT_JOB_RUN_RESPONSE))
     @pytest.mark.parametrize(
         "job_run_status, expected_output",

Reply via email to