pierrejeambrun commented on code in PR #44366:
URL: https://github.com/apache/airflow/pull/44366#discussion_r1858139840


##########
tests/api_fastapi/core_api/routes/public/test_xcom.py:
##########
@@ -203,3 +208,279 @@ def test_custom_xcom_deserialize(
         else:
             assert response.status_code == 200
             assert response.json()["value"] == expected_status_or_value
+
+
+class TestGetXComEntries(TestXComEndpoint):
+    def test_should_respond_200(self, test_client):
+        self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, 
TEST_TASK_ID)
+        response = test_client.get(
+            
f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries"
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        for xcom_entry in response_data["xcom_entries"]:
+            xcom_entry["timestamp"] = "TIMESTAMP"
+
+        expected_response = {
+            "xcom_entries": [
+                {
+                    "dag_id": TEST_DAG_ID,
+                    "logical_date": logical_date_formatted,
+                    "key": f"{TEST_XCOM_KEY}-0",
+                    "task_id": TEST_TASK_ID,
+                    "timestamp": "TIMESTAMP",
+                    "map_index": -1,
+                },
+                {
+                    "dag_id": TEST_DAG_ID,
+                    "logical_date": logical_date_formatted,
+                    "key": f"{TEST_XCOM_KEY}-1",
+                    "task_id": TEST_TASK_ID,
+                    "timestamp": "TIMESTAMP",
+                    "map_index": -1,
+                },
+            ],
+            "total_entries": 2,
+        }
+        assert response_data == expected_response
+
+    def test_should_respond_200_with_tilde(self, test_client):
+        self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, 
TEST_TASK_ID)
+        self._create_xcom_entries(TEST_DAG_ID_2, run_id, logical_date_parsed, 
TEST_TASK_ID_2)
+
+        response = 
test_client.get("/public/dags/~/dagRuns/~/taskInstances/~/xcomEntries")
+        assert response.status_code == 200
+        response_data = response.json()
+        for xcom_entry in response_data["xcom_entries"]:
+            xcom_entry["timestamp"] = "TIMESTAMP"
+
+        expected_response = {
+            "xcom_entries": [
+                {
+                    "dag_id": TEST_DAG_ID,
+                    "logical_date": logical_date_formatted,
+                    "key": f"{TEST_XCOM_KEY}-0",
+                    "task_id": TEST_TASK_ID,
+                    "timestamp": "TIMESTAMP",
+                    "map_index": -1,
+                },
+                {
+                    "dag_id": TEST_DAG_ID,
+                    "logical_date": logical_date_formatted,
+                    "key": f"{TEST_XCOM_KEY}-1",
+                    "task_id": TEST_TASK_ID,
+                    "timestamp": "TIMESTAMP",
+                    "map_index": -1,
+                },
+                {
+                    "dag_id": TEST_DAG_ID_2,
+                    "logical_date": logical_date_formatted,
+                    "key": f"{TEST_XCOM_KEY}-0",
+                    "task_id": TEST_TASK_ID_2,
+                    "timestamp": "TIMESTAMP",
+                    "map_index": -1,
+                },
+                {
+                    "dag_id": TEST_DAG_ID_2,
+                    "logical_date": logical_date_formatted,
+                    "key": f"{TEST_XCOM_KEY}-1",
+                    "task_id": TEST_TASK_ID_2,
+                    "timestamp": "TIMESTAMP",
+                    "map_index": -1,
+                },
+            ],
+            "total_entries": 4,
+        }
+        assert response_data == expected_response
+
+    @pytest.mark.parametrize("map_index", (0, 1, None))
+    def test_should_respond_200_with_map_index(self, map_index, test_client):
+        self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, 
TEST_TASK_ID, mapped_ti=True)
+
+        response = test_client.get(
+            "/public/dags/~/dagRuns/~/taskInstances/~/xcomEntries",
+            params={"map_index": map_index} if map_index is not None else None,
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+
+        if map_index is None:
+            expected_entries = [
+                {
+                    "dag_id": TEST_DAG_ID,
+                    "logical_date": logical_date_formatted,
+                    "key": TEST_XCOM_KEY,
+                    "task_id": TEST_TASK_ID,
+                    "timestamp": "TIMESTAMP",
+                    "map_index": idx,
+                }
+                for idx in range(2)
+            ]
+        else:
+            expected_entries = [
+                {
+                    "dag_id": TEST_DAG_ID,
+                    "logical_date": logical_date_formatted,
+                    "key": TEST_XCOM_KEY,
+                    "task_id": TEST_TASK_ID,
+                    "timestamp": "TIMESTAMP",
+                    "map_index": map_index,
+                }
+            ]
+        for xcom_entry in response_data["xcom_entries"]:
+            xcom_entry["timestamp"] = "TIMESTAMP"
+        assert response_data == {
+            "xcom_entries": expected_entries,
+            "total_entries": len(expected_entries),
+        }
+
+    @pytest.mark.parametrize(
+        "key, expected_entries",
+        [
+            (
+                TEST_XCOM_KEY,
+                [
+                    {
+                        "dag_id": TEST_DAG_ID,
+                        "logical_date": logical_date_formatted,
+                        "key": TEST_XCOM_KEY,
+                        "task_id": TEST_TASK_ID,
+                        "timestamp": "TIMESTAMP",
+                        "map_index": 0,
+                    },
+                    {
+                        "dag_id": TEST_DAG_ID,
+                        "logical_date": logical_date_formatted,
+                        "key": TEST_XCOM_KEY,
+                        "task_id": TEST_TASK_ID,
+                        "timestamp": "TIMESTAMP",
+                        "map_index": 1,
+                    },
+                ],
+            ),
+            (f"{TEST_XCOM_KEY}-0", []),
+        ],
+    )
+    def test_should_respond_200_with_xcom_key(self, key, expected_entries, 
test_client):
+        self._create_xcom_entries(TEST_DAG_ID, run_id, logical_date_parsed, 
TEST_TASK_ID, mapped_ti=True)
+        response = test_client.get(
+            "/public/dags/~/dagRuns/~/taskInstances/~/xcomEntries",
+            params={"xcom_key": key} if key is not None else None,
+        )
+
+        assert response.status_code == 200
+        response_data = response.json()
+        for xcom_entry in response_data["xcom_entries"]:
+            xcom_entry["timestamp"] = "TIMESTAMP"
+        assert response_data == {
+            "xcom_entries": expected_entries,
+            "total_entries": len(expected_entries),
+        }
+
+    @provide_session
+    def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id, 
mapped_ti=False, session=None):
+        dag = DagModel(dag_id=dag_id)
+        session.add(dag)
+        dagrun = DagRun(
+            dag_id=dag_id,
+            run_id=run_id,
+            logical_date=logical_date,
+            start_date=logical_date,
+            run_type=DagRunType.MANUAL,
+        )
+        session.add(dagrun)
+        if mapped_ti:
+            for i in [0, 1]:
+                ti = TaskInstance(EmptyOperator(task_id=task_id), 
run_id=run_id, map_index=i)
+                ti.dag_id = dag_id
+                session.add(ti)
+        else:
+            ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
+            ti.dag_id = dag_id
+            session.add(ti)
+        session.commit()
+
+        for i in [0, 1]:
+            if mapped_ti:
+                key = TEST_XCOM_KEY
+                map_index = i
+            else:
+                key = f"{TEST_XCOM_KEY}-{i}"
+                map_index = -1
+
+            XCom.set(
+                key=key,
+                value=TEST_XCOM_VALUE,
+                run_id=run_id,
+                task_id=task_id,
+                dag_id=dag_id,
+                map_index=map_index,
+            )
+
+    @pytest.fixture(autouse=True)
+    def setup(self) -> None:
+        self.clear_db()

Review Comment:
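   Nit: Could you move this setup/init code to the beginning of the test class, please?
   It is an autouse fixture that clears the DB for every test, so it reads best ahead of the tests that rely on it.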



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to