This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9b394a7229 refactor: Make sure xcoms work correctly in multi-threaded
environment by taking the map_index into account (#40297)
9b394a7229 is described below
commit 9b394a7229484914d80fffeeb7c2d109cd58cc02
Author: David Blain <[email protected]>
AuthorDate: Tue Jun 18 16:17:12 2024 +0200
refactor: Make sure xcoms work correctly in multi-threaded environment by
taking the map_index into account (#40297)
Co-authored-by: David Blain <[email protected]>
---
.../providers/microsoft/azure/operators/msgraph.py | 36 +++++++++++++---------
1 file changed, 22 insertions(+), 14 deletions(-)
diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py
index 39ca32d2b6..73be47fb50 100644
--- a/airflow/providers/microsoft/azure/operators/msgraph.py
+++ b/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -178,8 +178,10 @@ class MSGraphAsyncOperator(BaseOperator):
event["response"] = result
try:
- self.trigger_next_link(response, method_name=self.pull_execute_complete.__name__)
+ self.trigger_next_link(response=response, method_name=self.execute_complete.__name__)
except TaskDeferred as exception:
+ self.results = self.pull_xcom(context=context)
+ self.log.debug("value: %s", result)
self.append_result(
result=result,
append_result_as_list_if_absent=True,
@@ -198,8 +200,6 @@ class MSGraphAsyncOperator(BaseOperator):
result: Any,
append_result_as_list_if_absent: bool = False,
):
- self.log.debug("value: %s", result)
-
if isinstance(self.results, list):
if isinstance(result, list):
self.results.extend(result)
@@ -214,30 +214,38 @@ class MSGraphAsyncOperator(BaseOperator):
else:
self.results = result
- def push_xcom(self, context: Context, value) -> None:
- self.log.debug("do_xcom_push: %s", self.do_xcom_push)
- if self.do_xcom_push:
- self.log.info("Pushing XCom with key '%s': %s", self.key, value)
- self.xcom_push(context=context, key=self.key, value=value)
+ def xcom_key(self, context: Context) -> str:
+ map_index = context["ti"].map_index
+ return f"{self.key}_{map_index}" if map_index else self.key
- def pull_execute_complete(self, context: Context, event: dict[Any, Any] | None = None) -> Any:
- self.results = list(
+ def pull_xcom(self, context: Context) -> list:
+ key = self.xcom_key(context=context)
+ value = list(
self.xcom_pull(
context=context,
task_ids=self.task_id,
dag_id=self.dag_id,
- key=self.key,
+ key=key,
)
or []
)
+
self.log.info(
"Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s",
self.task_id,
self.dag_id,
- self.key,
- self.results,
+ key,
+ value,
)
- return self.execute_complete(context, event)
+
+ return value
+
+ def push_xcom(self, context: Context, value) -> None:
+ self.log.debug("do_xcom_push: %s", self.do_xcom_push)
+ if self.do_xcom_push:
+ key = self.xcom_key(context=context)
+ self.log.info("Pushing XCom with key '%s': %s", key, value)
+ self.xcom_push(context=context, key=key, value=value)
@staticmethod
def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]: