This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b394a7229 refactor: Make sure xcoms work correctly in multi-threaded environment by taking the map_index into account (#40297)
9b394a7229 is described below

commit 9b394a7229484914d80fffeeb7c2d109cd58cc02
Author: David Blain <[email protected]>
AuthorDate: Tue Jun 18 16:17:12 2024 +0200

    refactor: Make sure xcoms work correctly in multi-threaded environment by taking the map_index into account (#40297)
    
    Co-authored-by: David Blain <[email protected]>
---
 .../providers/microsoft/azure/operators/msgraph.py | 36 +++++++++++++---------
 1 file changed, 22 insertions(+), 14 deletions(-)

diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py
index 39ca32d2b6..73be47fb50 100644
--- a/airflow/providers/microsoft/azure/operators/msgraph.py
+++ b/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -178,8 +178,10 @@ class MSGraphAsyncOperator(BaseOperator):
                 event["response"] = result
 
                 try:
-                    self.trigger_next_link(response, method_name=self.pull_execute_complete.__name__)
+                    self.trigger_next_link(response=response, method_name=self.execute_complete.__name__)
                 except TaskDeferred as exception:
+                    self.results = self.pull_xcom(context=context)
+                    self.log.debug("value: %s", result)
                     self.append_result(
                         result=result,
                         append_result_as_list_if_absent=True,
@@ -198,8 +200,6 @@ class MSGraphAsyncOperator(BaseOperator):
         result: Any,
         append_result_as_list_if_absent: bool = False,
     ):
-        self.log.debug("value: %s", result)
-
         if isinstance(self.results, list):
             if isinstance(result, list):
                 self.results.extend(result)
@@ -214,30 +214,38 @@ class MSGraphAsyncOperator(BaseOperator):
             else:
                 self.results = result
 
-    def push_xcom(self, context: Context, value) -> None:
-        self.log.debug("do_xcom_push: %s", self.do_xcom_push)
-        if self.do_xcom_push:
-            self.log.info("Pushing XCom with key '%s': %s", self.key, value)
-            self.xcom_push(context=context, key=self.key, value=value)
+    def xcom_key(self, context: Context) -> str:
+        map_index = context["ti"].map_index
+        return f"{self.key}_{map_index}" if map_index else self.key
 
-    def pull_execute_complete(self, context: Context, event: dict[Any, Any] | None = None) -> Any:
-        self.results = list(
+    def pull_xcom(self, context: Context) -> list:
+        key = self.xcom_key(context=context)
+        value = list(
             self.xcom_pull(
                 context=context,
                 task_ids=self.task_id,
                 dag_id=self.dag_id,
-                key=self.key,
+                key=key,
             )
             or []
         )
+
         self.log.info(
             "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s",
             self.task_id,
             self.dag_id,
-            self.key,
-            self.results,
+            key,
+            value,
         )
-        return self.execute_complete(context, event)
+
+        return value
+
+    def push_xcom(self, context: Context, value) -> None:
+        self.log.debug("do_xcom_push: %s", self.do_xcom_push)
+        if self.do_xcom_push:
+            key = self.xcom_key(context=context)
+            self.log.info("Pushing XCom with key '%s': %s", key, value)
+            self.xcom_push(context=context, key=key, value=value)
 
     @staticmethod
     def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]:

Reply via email to