This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 26b8997fb1 Optimize deferred mode execution for wasb sensors (#31009)
26b8997fb1 is described below

commit 26b8997fb185fd308c243a9418ade317e533e26b
Author: Phani Kumar <[email protected]>
AuthorDate: Tue May 23 10:56:35 2023 +0530

    Optimize deferred mode execution for wasb sensors (#31009)
    
    * Optimize wasb sensors
---
 airflow/providers/microsoft/azure/sensors/wasb.py  | 44 +++++++++++-----------
 .../providers/microsoft/azure/sensors/test_wasb.py | 24 ++++++++++--
 2 files changed, 43 insertions(+), 25 deletions(-)

diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py b/airflow/providers/microsoft/azure/sensors/wasb.py
index 17b64213ec..db64255804 100644
--- a/airflow/providers/microsoft/azure/sensors/wasb.py
+++ b/airflow/providers/microsoft/azure/sensors/wasb.py
@@ -78,17 +78,18 @@ class WasbBlobSensor(BaseSensorOperator):
         if not self.deferrable:
             super().execute(context=context)
         else:
-            self.defer(
-                timeout=timedelta(seconds=self.timeout),
-                trigger=WasbBlobSensorTrigger(
-                    container_name=self.container_name,
-                    blob_name=self.blob_name,
-                    wasb_conn_id=self.wasb_conn_id,
-                    public_read=self.public_read,
-                    poke_interval=self.poke_interval,
-                ),
-                method_name="execute_complete",
-            )
+            if not self.poke(context=context):
+                self.defer(
+                    timeout=timedelta(seconds=self.timeout),
+                    trigger=WasbBlobSensorTrigger(
+                        container_name=self.container_name,
+                        blob_name=self.blob_name,
+                        wasb_conn_id=self.wasb_conn_id,
+                        public_read=self.public_read,
+                        poke_interval=self.poke_interval,
+                    ),
+                    method_name="execute_complete",
+                )
 
     def execute_complete(self, context: Context, event: dict[str, str]) -> None:
         """
@@ -172,16 +173,17 @@ class WasbPrefixSensor(BaseSensorOperator):
         if not self.deferrable:
             super().execute(context=context)
         else:
-            self.defer(
-                timeout=timedelta(seconds=self.timeout),
-                trigger=WasbPrefixSensorTrigger(
-                    container_name=self.container_name,
-                    prefix=self.prefix,
-                    wasb_conn_id=self.wasb_conn_id,
-                    poke_interval=self.poke_interval,
-                ),
-                method_name="execute_complete",
-            )
+            if not self.poke(context=context):
+                self.defer(
+                    timeout=timedelta(seconds=self.timeout),
+                    trigger=WasbPrefixSensorTrigger(
+                        container_name=self.container_name,
+                        prefix=self.prefix,
+                        wasb_conn_id=self.wasb_conn_id,
+                        poke_interval=self.poke_interval,
+                    ),
+                    method_name="execute_complete",
+                )
 
     def execute_complete(self, context: Context, event: dict[str, str]) -> None:
         """
diff --git a/tests/providers/microsoft/azure/sensors/test_wasb.py b/tests/providers/microsoft/azure/sensors/test_wasb.py
index d4a3466c8c..6067030694 100644
--- a/tests/providers/microsoft/azure/sensors/test_wasb.py
+++ b/tests/providers/microsoft/azure/sensors/test_wasb.py
@@ -128,9 +128,17 @@ class TestWasbBlobAsyncSensor:
         deferrable=True,
     )
 
-    def test_wasb_blob_sensor_async(self):
+    @mock.patch("airflow.providers.microsoft.azure.sensors.wasb.WasbHook")
+    @mock.patch("airflow.providers.microsoft.azure.sensors.wasb.WasbBlobSensor.defer")
+    def test_wasb_blob_sensor_finish_before_deferred(self, mock_defer, mock_hook):
+        mock_hook.return_value.check_for_blob.return_value = True
+        self.SENSOR.execute(mock.MagicMock())
+        assert not mock_defer.called
+
+    @mock.patch("airflow.providers.microsoft.azure.sensors.wasb.WasbHook")
+    def test_wasb_blob_sensor_async(self, mock_hook):
         """Assert execute method defer for wasb blob sensor"""
-
+        mock_hook.return_value.check_for_blob.return_value = False
         with pytest.raises(TaskDeferred) as exc:
             self.SENSOR.execute(self.create_context(self.SENSOR))
         assert isinstance(exc.value.trigger, WasbBlobSensorTrigger), "Trigger is not a WasbBlobSensorTrigger"
@@ -244,9 +252,17 @@ class TestWasbPrefixAsyncSensor:
         deferrable=True,
     )
 
-    def test_wasb_prefix_sensor_async(self):
-        """Assert execute method defer for wasb prefix sensor"""
+    @mock.patch("airflow.providers.microsoft.azure.sensors.wasb.WasbHook")
+    @mock.patch("airflow.providers.microsoft.azure.sensors.wasb.WasbPrefixSensor.defer")
+    def test_wasb_prefix_sensor_finish_before_deferred(self, mock_defer, mock_hook):
+        mock_hook.return_value.check_for_prefix.return_value = True
+        self.SENSOR.execute(mock.MagicMock())
+        assert not mock_defer.called
 
+    @mock.patch("airflow.providers.microsoft.azure.sensors.wasb.WasbHook")
+    def test_wasb_prefix_sensor_async(self, mock_hook):
+        """Assert execute method defer for wasb prefix sensor"""
+        mock_hook.return_value.check_for_prefix.return_value = False
         with pytest.raises(TaskDeferred) as exc:
             self.SENSOR.execute(self.create_context(self.SENSOR))
         assert isinstance(

Reply via email to