amoghrajesh commented on code in PR #55542:
URL: https://github.com/apache/airflow/pull/55542#discussion_r2344315337


##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -831,6 +837,132 @@ def fake_collect_dags(self, *args, **kwargs):
         # Should log warning about no callback found
         log.warning.assert_called_once_with("Callback requested, but dag 
didn't have any", dag_id="test_dag")
 
+    @pytest.mark.parametrize(
+        "xcom_operation,expected_message_type,expected_message,mock_response",
+        [
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="report_df", 
task_ids=task_ids),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="report_df",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="single_value", 
task_ids=["test_task"]),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="single_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="direct_value", 
task_ids="test_task", map_indexes=None),
+                "GetXCom",
+                GetXCom(
+                    key="direct_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                    include_prior_dates=False,
+                ),
+                XComResult(
+                    key="direct_value",
+                    value="test",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                ),
+            ),
+        ],
+    )
+    def test_notifier_xcom_operations_send_correct_messages(
+        self,
+        spy_agency,
+        mock_supervisor_comms,
+        xcom_operation,
+        expected_message_type,
+        expected_message,
+        mock_response,
+    ):
+        """Test that different XCom operations send correct message types"""
+
+        mock_supervisor_comms.send.return_value = mock_response
+
+        class TestNotifier:
+            def __call__(self, context):
+                ti = context["ti"]
+                dag = context["dag"]
+                task_ids = list(dag.task_dict.keys())
+                xcom_operation(ti, task_ids)
+
+        with DAG(dag_id="test_dag", on_success_callback=TestNotifier()) as dag:
+            BaseOperator(task_id="test_task")
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        current_time = timezone.utcnow()
+        request = DagCallbackRequest(
+            filepath="test.py",
+            dag_id="test_dag",
+            run_id="test_run",
+            bundle_name="testing",
+            bundle_version=None,
+            context_from_server=DagRunContext(
+                dag_run=DRDataModel(
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    logical_date=current_time,
+                    data_interval_start=current_time,
+                    data_interval_end=current_time,
+                    run_after=current_time,
+                    start_date=current_time,
+                    end_date=None,
+                    run_type="manual",
+                    state="success",
+                    consumed_asset_events=[],
+                ),
+                last_ti=TIDataModel(
+                    id=uuid.uuid4(),
+                    dag_id="test_dag",
+                    task_id="test_task",
+                    run_id="test_run",
+                    map_index=-1,
+                    try_number=1,
+                    dag_version_id=uuid.uuid4(),
+                ),
+            ),
+            is_failure_callback=False,
+            msg="Test success message",
+        )
+
+        _execute_dag_callbacks(dagbag, request, structlog.get_logger())
+
+        mock_supervisor_comms.send.assert_called()
+        mock_supervisor_comms.send.assert_called_with(msg=expected_message)

Review Comment:
   Handled as Ash mentioned above: 9d2e58a97f



##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -831,6 +837,132 @@ def fake_collect_dags(self, *args, **kwargs):
         # Should log warning about no callback found
         log.warning.assert_called_once_with("Callback requested, but dag 
didn't have any", dag_id="test_dag")
 
+    @pytest.mark.parametrize(
+        "xcom_operation,expected_message_type,expected_message,mock_response",
+        [
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="report_df", 
task_ids=task_ids),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="report_df",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="single_value", 
task_ids=["test_task"]),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="single_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="direct_value", 
task_ids="test_task", map_indexes=None),
+                "GetXCom",
+                GetXCom(
+                    key="direct_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                    include_prior_dates=False,
+                ),
+                XComResult(
+                    key="direct_value",
+                    value="test",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                ),
+            ),
+        ],
+    )
+    def test_notifier_xcom_operations_send_correct_messages(
+        self,
+        spy_agency,
+        mock_supervisor_comms,
+        xcom_operation,
+        expected_message_type,
+        expected_message,
+        mock_response,
+    ):
+        """Test that different XCom operations send correct message types"""
+
+        mock_supervisor_comms.send.return_value = mock_response
+
+        class TestNotifier:
+            def __call__(self, context):
+                ti = context["ti"]
+                dag = context["dag"]
+                task_ids = list(dag.task_dict.keys())

Review Comment:
   Handled in 9d2e58a97f



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to