This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8c3a30e3ffc Fix xcom_pull error when value is dataframe type (#48526)
8c3a30e3ffc is described below
commit 8c3a30e3ffc3f114c1d2cc3e6e109f4d9e29ca8b
Author: GPK <[email protected]>
AuthorDate: Sat Mar 29 11:27:36 2025 +0000
Fix xcom_pull error when value is dataframe type (#48526)
---
.../src/airflow/sdk/execution_time/task_runner.py | 5 ++++-
.../tests/task_sdk/execution_time/test_task_runner.py | 19 ++++++++++++++++++-
2 files changed, 22 insertions(+), 2 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 747f135389c..27cb175aa0a 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -343,7 +343,10 @@ class RuntimeTaskInstance(TaskInstance):
map_index=m_idx,
include_prior_dates=include_prior_dates,
)
- xcoms.append(value if value else default)
+ if value is None:
+ xcoms.append(default)
+ else:
+ xcoms.append(value)
if len(xcoms) == 1:
return xcoms[0]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 808083d1495..a41989d9279 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -30,6 +30,7 @@ from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import patch
+import pandas as pd
import pytest
from task_sdk import FAKE_BUNDLE
from uuid6 import uuid7
@@ -56,6 +57,7 @@ from airflow.sdk.api.datamodels._generated import (
TaskInstance,
TerminalTIState,
)
+from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset, Model
from airflow.sdk.definitions.param import DagParam
from airflow.sdk.exceptions import ErrorType
@@ -1166,11 +1168,25 @@ class TestRuntimeTaskInstance:
pytest.param(NOTSET, id="tid_not_set"),
],
)
+ @pytest.mark.parametrize(
+ "xcom_values",
+ [
+ pytest.param("hello", id="string_value"),
+ pytest.param("'hello'", id="quoted_string_value"),
+ pytest.param({"key": "value"}, id="json_value"),
+ pytest.param((1, 2, 3), id="tuple_int_value"),
+ pytest.param([1, 2, 3], id="list_int_value"),
+ pytest.param(42, id="int_value"),
+ pytest.param(True, id="boolean_value"),
+ pytest.param(pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}),
id="dataframe_value"),
+ ],
+ )
def test_xcom_pull(
self,
create_runtime_ti,
mock_supervisor_comms,
spy_agency,
+ xcom_values,
task_ids,
map_indexes,
):
@@ -1193,7 +1209,8 @@ class TestRuntimeTaskInstance:
extra_for_ti = {"map_index": map_indexes} if map_indexes in (1, None)
else {}
runtime_ti = create_runtime_ti(task=task, **extra_for_ti)
- mock_supervisor_comms.get_message.return_value = XComResult(key="key",
value='"value"')
+ ser_value = BaseXCom.serialize_value(xcom_values)
+ mock_supervisor_comms.get_message.return_value = XComResult(key="key",
value=ser_value)
run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())