This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 86b1bd22d1 Fix CloudRunExecuteJobOperator not able to retrieve the Cloud Run job status in deferrable mode (#36012)
86b1bd22d1 is described below

commit 86b1bd22d14792d89ddc43627e4a72dcb628c5f0
Author: VladaZakharova <[email protected]>
AuthorDate: Fri Dec 1 18:01:33 2023 +0100

    Fix CloudRunExecuteJobOperator not able to retrieve the Cloud Run job status in deferrable mode (#36012)
    
    Co-authored-by: Ulada Zakharava <[email protected]>
---
 .../providers/google/cloud/operators/cloud_run.py  |  4 +-
 .../providers/google/cloud/triggers/cloud_run.py   | 14 ++---
 .../google/cloud/operators/test_cloud_run.py       |  6 +--
 .../google/cloud/triggers/test_cloud_run.py        | 59 ++++++++++++----------
 4 files changed, 44 insertions(+), 39 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/cloud_run.py b/airflow/providers/google/cloud/operators/cloud_run.py
index 14d27810da..91b3ae6cea 100644
--- a/airflow/providers/google/cloud/operators/cloud_run.py
+++ b/airflow/providers/google/cloud/operators/cloud_run.py
@@ -321,10 +321,10 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
     def execute_complete(self, context: Context, event: dict):
         status = event["status"]
 
-        if status == RunJobStatus.TIMEOUT:
+        if status == RunJobStatus.TIMEOUT.value:
             raise AirflowException("Operation timed out")
 
-        if status == RunJobStatus.FAIL:
+        if status == RunJobStatus.FAIL.value:
             error_code = event["operation_error_code"]
             error_message = event["operation_error_message"]
             raise AirflowException(
diff --git a/airflow/providers/google/cloud/triggers/cloud_run.py b/airflow/providers/google/cloud/triggers/cloud_run.py
index 9506245d20..f47a7ac1b3 100644
--- a/airflow/providers/google/cloud/triggers/cloud_run.py
+++ b/airflow/providers/google/cloud/triggers/cloud_run.py
@@ -102,21 +102,21 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
         while timeout is None or timeout > 0:
             operation: operations_pb2.Operation = await 
hook.get_operation(self.operation_name)
             if operation.done:
-                # An operation can only have one of those two combinations: if 
it is succeeded, then
-                # the response field will be populated, else, then the error 
field will be.
-                if operation.response is not None:
+                # An operation can only have one of those two combinations: if 
it is failed, then
+                # the error field will be populated, else, then the response 
field will be.
+                if operation.error.SerializeToString():
                     yield TriggerEvent(
                         {
-                            "status": RunJobStatus.SUCCESS,
+                            "status": RunJobStatus.FAIL.value,
+                            "operation_error_code": operation.error.code,
+                            "operation_error_message": operation.error.message,
                             "job_name": self.job_name,
                         }
                     )
                 else:
                     yield TriggerEvent(
                         {
-                            "status": RunJobStatus.FAIL,
-                            "operation_error_code": operation.error.code,
-                            "operation_error_message": operation.error.message,
+                            "status": RunJobStatus.SUCCESS.value,
                             "job_name": self.job_name,
                         }
                     )
diff --git a/tests/providers/google/cloud/operators/test_cloud_run.py b/tests/providers/google/cloud/operators/test_cloud_run.py
index 152e625a23..829518e0d0 100644
--- a/tests/providers/google/cloud/operators/test_cloud_run.py
+++ b/tests/providers/google/cloud/operators/test_cloud_run.py
@@ -166,7 +166,7 @@ class TestCloudRunExecuteJobOperator:
             task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, 
job_name=JOB_NAME, deferrable=True
         )
 
-        event = {"status": RunJobStatus.TIMEOUT, "job_name": JOB_NAME}
+        event = {"status": RunJobStatus.TIMEOUT.value, "job_name": JOB_NAME}
 
         with pytest.raises(AirflowException) as e:
             operator.execute_complete(mock.MagicMock(), event)
@@ -183,7 +183,7 @@ class TestCloudRunExecuteJobOperator:
         error_message = "error message"
 
         event = {
-            "status": RunJobStatus.FAIL,
+            "status": RunJobStatus.FAIL.value,
             "operation_error_code": error_code,
             "operation_error_message": error_message,
             "job_name": JOB_NAME,
@@ -204,7 +204,7 @@ class TestCloudRunExecuteJobOperator:
             task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, 
job_name=JOB_NAME, deferrable=True
         )
 
-        event = {"status": RunJobStatus.SUCCESS, "job_name": JOB_NAME}
+        event = {"status": RunJobStatus.SUCCESS.value, "job_name": JOB_NAME}
 
         result = operator.execute_complete(mock.MagicMock(), event)
         assert result["name"] == JOB_NAME
diff --git a/tests/providers/google/cloud/triggers/test_cloud_run.py b/tests/providers/google/cloud/triggers/test_cloud_run.py
index 30d56241ed..d64c4cee10 100644
--- a/tests/providers/google/cloud/triggers/test_cloud_run.py
+++ b/tests/providers/google/cloud/triggers/test_cloud_run.py
@@ -20,13 +20,16 @@ from __future__ import annotations
 from unittest import mock
 
 import pytest
+from google.protobuf.any_pb2 import Any
+from google.rpc.status_pb2 import Status
 
-from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.triggers.cloud_run import 
CloudRunJobFinishedTrigger, RunJobStatus
 from airflow.triggers.base import TriggerEvent
 
 OPERATION_NAME = "operation"
 JOB_NAME = "jobName"
+ERROR_CODE = 13
+ERROR_MESSAGE = "Some message"
 PROJECT_ID = "projectId"
 LOCATION = "us-central1"
 GCP_CONNECTION_ID = "gcp_connection_id"
@@ -73,20 +76,21 @@ class TestCloudBatchJobFinishedTrigger:
         Tests the CloudRunJobFinishedTrigger fires once the job execution 
reaches a successful state.
         """
 
-        done = True
-        name = "name"
-        error_code = 10
-        error_message = "message"
+        async def _mock_operation(name):
+            operation = mock.MagicMock()
+            operation.done = True
+            operation.name = "name"
+            operation.error = Any()
+            operation.error.ParseFromString(b"")
+            return operation
 
-        mock_hook.return_value.get_operation.return_value = 
self._mock_operation(
-            done, name, error_code, error_message
-        )
+        mock_hook.return_value.get_operation = _mock_operation
         generator = trigger.run()
         actual = await generator.asend(None)  # type:ignore[attr-defined]
         assert (
             TriggerEvent(
                 {
-                    "status": RunJobStatus.SUCCESS,
+                    "status": RunJobStatus.SUCCESS.value,
                     "job_name": JOB_NAME,
                 }
             )
@@ -102,18 +106,28 @@ class TestCloudBatchJobFinishedTrigger:
         Tests the CloudRunJobFinishedTrigger raises an exception once the job 
execution fails.
         """
 
-        done = False
-        name = "name"
-        error_code = 10
-        error_message = "message"
+        async def _mock_operation(name):
+            operation = mock.MagicMock()
+            operation.done = True
+            operation.name = "name"
+            operation.error = Status(code=13, message="Some message")
+            return operation
 
-        mock_hook.return_value.get_operation.return_value = 
self._mock_operation(
-            done, name, error_code, error_message
-        )
+        mock_hook.return_value.get_operation = _mock_operation
         generator = trigger.run()
 
-        with pytest.raises(expected_exception=AirflowException):
-            await generator.asend(None)  # type:ignore[attr-defined]
+        actual = await generator.asend(None)  # type:ignore[attr-defined]
+        assert (
+            TriggerEvent(
+                {
+                    "status": RunJobStatus.FAIL.value,
+                    "operation_error_code": ERROR_CODE,
+                    "operation_error_message": ERROR_MESSAGE,
+                    "job_name": JOB_NAME,
+                }
+            )
+            == actual
+        )
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.google.cloud.triggers.cloud_run.CloudRunAsyncHook")
@@ -144,12 +158,3 @@ class TestCloudBatchJobFinishedTrigger:
             )
             == actual
         )
-
-    async def _mock_operation(self, done, name, error_code, error_message):
-        operation = mock.MagicMock()
-        operation.done = done
-        operation.name = name
-        operation.error = mock.MagicMock()
-        operation.error.message = error_message
-        operation.error.code = error_code
-        return operation

Reply via email to