This is an automated email from the ASF dual-hosted git repository.

joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 43916c5034 Optimize deferred execution mode in DbtCloudJobRunSensor (#30968)
43916c5034 is described below

commit 43916c50341d937d9976b2065e40e1611e918663
Author: Phani Kumar <[email protected]>
AuthorDate: Mon May 1 20:02:08 2023 +0530

    Optimize deferred execution mode in DbtCloudJobRunSensor (#30968)
---
 airflow/providers/dbt/cloud/sensors/dbt.py         | 23 +++++++++++-----------
 .../providers/dbt/cloud/sensors/test_dbt_cloud.py  | 22 +++++++++++++++++++--
 2 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/airflow/providers/dbt/cloud/sensors/dbt.py b/airflow/providers/dbt/cloud/sensors/dbt.py
index f8303564a5..99e7707d18 100644
--- a/airflow/providers/dbt/cloud/sensors/dbt.py
+++ b/airflow/providers/dbt/cloud/sensors/dbt.py
@@ -99,17 +99,18 @@ class DbtCloudJobRunSensor(BaseSensorOperator):
             super().execute(context)
         else:
             end_time = time.time() + self.timeout
-            self.defer(
-                timeout=self.execution_timeout,
-                trigger=DbtCloudRunJobTrigger(
-                    run_id=self.run_id,
-                    conn_id=self.dbt_cloud_conn_id,
-                    account_id=self.account_id,
-                    poll_interval=self.poke_interval,
-                    end_time=end_time,
-                ),
-                method_name="execute_complete",
-            )
+            if not self.poke(context=context):
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=DbtCloudRunJobTrigger(
+                        run_id=self.run_id,
+                        conn_id=self.dbt_cloud_conn_id,
+                        account_id=self.account_id,
+                        poll_interval=self.poke_interval,
+                        end_time=end_time,
+                    ),
+                    method_name="execute_complete",
+                )
 
     def execute_complete(self, context: Context, event: dict[str, Any]) -> int:
         """
diff --git a/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py b/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py
index 2f27eb04a6..4e6f35aa86 100644
--- a/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py
+++ b/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py
@@ -88,7 +88,22 @@ class TestDbtCloudJobRunSensor:
             with pytest.raises(DbtCloudJobRunException, match=error_message):
                 self.sensor.poke({})
 
-    def test_execute_with_deferrable_mode(self):
+    @mock.patch("airflow.providers.dbt.cloud.sensors.dbt.DbtCloudHook")
+    
@mock.patch("airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor.defer")
+    def test_dbt_cloud_job_run_sensor_finish_before_deferred(self, mock_defer, 
mock_hook):
+        task = DbtCloudJobRunSensor(
+            dbt_cloud_conn_id=self.CONN_ID,
+            task_id=self.TASK_ID,
+            run_id=self.DBT_RUN_ID,
+            timeout=self.TIMEOUT,
+            deferrable=True,
+        )
+        mock_hook.return_value.get_job_run_status.return_value = 
DbtCloudJobRunStatus.SUCCESS.value
+        task.execute(mock.MagicMock())
+        assert not mock_defer.called
+
+    @mock.patch("airflow.providers.dbt.cloud.sensors.dbt.DbtCloudHook")
+    def test_execute_with_deferrable_mode(self, mock_hook):
         """Assert execute method defer for Dbt cloud job run status sensors"""
         task = DbtCloudJobRunSensor(
             dbt_cloud_conn_id=self.CONN_ID,
@@ -97,6 +112,7 @@ class TestDbtCloudJobRunSensor:
             timeout=self.TIMEOUT,
             deferrable=True,
         )
+        mock_hook.return_value.get_job_run_status.return_value = 
DbtCloudJobRunStatus.STARTING.value
         with pytest.raises(TaskDeferred) as exc:
             task.execute({})
         assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger 
is not a DbtCloudRunJobTrigger"
@@ -151,7 +167,8 @@ class TestDbtCloudJobRunSensorAsync:
         "Please use `DbtCloudJobRunSensor` and set `deferrable` attribute to 
`True` instead"
     )
 
-    def test_dbt_job_run_sensor_async(self):
+    @mock.patch("airflow.providers.dbt.cloud.sensors.dbt.DbtCloudHook")
+    def test_dbt_job_run_sensor_async(self, mock_hook):
         """Assert execute method defer for Dbt cloud job run status sensors"""
 
         with pytest.warns(DeprecationWarning, match=self.depcrecation_message):
@@ -161,6 +178,7 @@ class TestDbtCloudJobRunSensorAsync:
                 run_id=self.DBT_RUN_ID,
                 timeout=self.TIMEOUT,
             )
+        mock_hook.return_value.get_job_run_status.return_value = 
DbtCloudJobRunStatus.STARTING.value
         with pytest.raises(TaskDeferred) as exc:
             task.execute({})
         assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger 
is not a DbtCloudRunJobTrigger"

Reply via email to