This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8c3a30e3ffc Fix xcom_pull error when value is dataframe type (#48526)
8c3a30e3ffc is described below

commit 8c3a30e3ffc3f114c1d2cc3e6e109f4d9e29ca8b
Author: GPK <[email protected]>
AuthorDate: Sat Mar 29 11:27:36 2025 +0000

    Fix xcom_pull error when value is dataframe type (#48526)
---
 .../src/airflow/sdk/execution_time/task_runner.py     |  5 ++++-
 .../tests/task_sdk/execution_time/test_task_runner.py | 19 ++++++++++++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 747f135389c..27cb175aa0a 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -343,7 +343,10 @@ class RuntimeTaskInstance(TaskInstance):
                 map_index=m_idx,
                 include_prior_dates=include_prior_dates,
             )
-            xcoms.append(value if value else default)
+            if value is None:
+                xcoms.append(default)
+            else:
+                xcoms.append(value)
 
         if len(xcoms) == 1:
             return xcoms[0]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 808083d1495..a41989d9279 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -30,6 +30,7 @@ from typing import TYPE_CHECKING
 from unittest import mock
 from unittest.mock import patch
 
+import pandas as pd
 import pytest
 from task_sdk import FAKE_BUNDLE
 from uuid6 import uuid7
@@ -56,6 +57,7 @@ from airflow.sdk.api.datamodels._generated import (
     TaskInstance,
     TerminalTIState,
 )
+from airflow.sdk.bases.xcom import BaseXCom
 from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset, Model
 from airflow.sdk.definitions.param import DagParam
 from airflow.sdk.exceptions import ErrorType
@@ -1166,11 +1168,25 @@ class TestRuntimeTaskInstance:
             pytest.param(NOTSET, id="tid_not_set"),
         ],
     )
+    @pytest.mark.parametrize(
+        "xcom_values",
+        [
+            pytest.param("hello", id="string_value"),
+            pytest.param("'hello'", id="quoted_string_value"),
+            pytest.param({"key": "value"}, id="json_value"),
+            pytest.param((1, 2, 3), id="tuple_int_value"),
+            pytest.param([1, 2, 3], id="list_int_value"),
+            pytest.param(42, id="int_value"),
+            pytest.param(True, id="boolean_value"),
+            pytest.param(pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}), id="dataframe_value"),
+        ],
+    )
     def test_xcom_pull(
         self,
         create_runtime_ti,
         mock_supervisor_comms,
         spy_agency,
+        xcom_values,
         task_ids,
         map_indexes,
     ):
@@ -1193,7 +1209,8 @@ class TestRuntimeTaskInstance:
         extra_for_ti = {"map_index": map_indexes} if map_indexes in (1, None) else {}
         runtime_ti = create_runtime_ti(task=task, **extra_for_ti)
 
-        mock_supervisor_comms.get_message.return_value = XComResult(key="key", value='"value"')
+        ser_value = BaseXCom.serialize_value(xcom_values)
+        mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=ser_value)
 
         run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
 

Reply via email to