jedcunningham commented on code in PR #53058: URL: https://github.com/apache/airflow/pull/53058#discussion_r2193907922
########## airflow-core/tests/unit/dag_processing/test_processor.py: ########## @@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs): spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + # Create a minimal TaskInstance for the request + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="a", + task_id="b", + run_id="test_run", + map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + ) + requests = [ TaskCallbackRequest( filepath="A", msg="Message", - ti=None, + ti=ti_data, bundle_name="testing", bundle_version=None, ) ] - _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) + _parse_file( + DagFileParseRequest(file="A", bundle_path="test", callback_requests=requests), + log=structlog.get_logger(), + ) assert called is True + + +class TestExecuteTaskCallbacks: + """Test the _execute_task_callbacks function""" + + def test_execute_task_callbacks_failure_callback(self, spy_agency): + """Test _execute_task_callbacks executes failure callbacks""" + called = False + context_received = None + + def on_failure(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task failed", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.FAILED, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + 
assert context_received is not None + assert context_received["dag"] == dag + assert "ti" in context_received + + def test_execute_task_callbacks_retry_callback(self, spy_agency): + """Test _execute_task_callbacks executes retry callbacks""" + called = False + context_received = None + + def on_retry(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_retry_callback=on_retry) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + state=TaskInstanceState.UP_FOR_RETRY, + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task retrying", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.UP_FOR_RETRY, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + assert context_received is not None + assert context_received["dag"] == dag + assert "ti" in context_received + + def test_execute_task_callbacks_with_context_from_server(self, spy_agency): + """Test _execute_task_callbacks with context_from_server creates full context""" + called = False + context_received = None + + def on_failure(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + # 
Create a mock DagRun Review Comment: ```suggestion ``` ########## airflow-core/tests/unit/dag_processing/test_processor.py: ########## @@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs): spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + # Create a minimal TaskInstance for the request + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="a", + task_id="b", + run_id="test_run", + map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + ) + requests = [ TaskCallbackRequest( filepath="A", msg="Message", - ti=None, + ti=ti_data, bundle_name="testing", bundle_version=None, ) ] - _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) + _parse_file( + DagFileParseRequest(file="A", bundle_path="test", callback_requests=requests), + log=structlog.get_logger(), + ) assert called is True + + +class TestExecuteTaskCallbacks: + """Test the _execute_task_callbacks function""" + + def test_execute_task_callbacks_failure_callback(self, spy_agency): + """Test _execute_task_callbacks executes failure callbacks""" + called = False + context_received = None + + def on_failure(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task failed", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.FAILED, + ) + + log = structlog.get_logger() + 
_execute_task_callbacks(dagbag, request, log) + + assert called is True + assert context_received is not None + assert context_received["dag"] == dag + assert "ti" in context_received + + def test_execute_task_callbacks_retry_callback(self, spy_agency): + """Test _execute_task_callbacks executes retry callbacks""" + called = False + context_received = None + + def on_retry(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_retry_callback=on_retry) Review Comment: ```suggestion BaseOperator(task_id="test_task", on_retry_callback=on_retry) ``` ########## airflow-core/src/airflow/jobs/scheduler_job_runner.py: ########## @@ -927,10 +928,13 @@ def process_executor_events( bundle_version=ti.dag_version.bundle_version, ti=ti, msg=msg, + context_from_server=TIRunContext( + dag_run=ti.dag_run, + max_tries=ti.max_tries, + ), ) executor.send_callback(request) - else: - ti.handle_failure(error=msg, session=session) + ti.handle_failure(error=msg, session=session) Review Comment: Was this intentional? I don't think so: with the `else:` removed, `ti.handle_failure(error=msg, session=session)` is no longer limited to the case where no callback request is sent to the executor.
########## airflow-core/tests/unit/dag_processing/test_processor.py: ########## @@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs): spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + # Create a minimal TaskInstance for the request + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="a", + task_id="b", + run_id="test_run", + map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + ) + requests = [ TaskCallbackRequest( filepath="A", msg="Message", - ti=None, + ti=ti_data, bundle_name="testing", bundle_version=None, ) ] - _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) + _parse_file( + DagFileParseRequest(file="A", bundle_path="test", callback_requests=requests), + log=structlog.get_logger(), + ) assert called is True + + +class TestExecuteTaskCallbacks: + """Test the _execute_task_callbacks function""" + + def test_execute_task_callbacks_failure_callback(self, spy_agency): + """Test _execute_task_callbacks executes failure callbacks""" + called = False + context_received = None + + def on_failure(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task failed", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.FAILED, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + 
assert context_received is not None + assert context_received["dag"] == dag + assert "ti" in context_received + + def test_execute_task_callbacks_retry_callback(self, spy_agency): + """Test _execute_task_callbacks executes retry callbacks""" + called = False + context_received = None + + def on_retry(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_retry_callback=on_retry) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + state=TaskInstanceState.UP_FOR_RETRY, + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task retrying", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.UP_FOR_RETRY, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + assert context_received is not None + assert context_received["dag"] == dag + assert "ti" in context_received + + def test_execute_task_callbacks_with_context_from_server(self, spy_agency): + """Test _execute_task_callbacks with context_from_server creates full context""" + called = False + context_received = None + + def on_failure(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + # 
Create a mock DagRun + dag_run = DagRun( + dag_id="test_dag", + run_id="test_run", + logical_date=timezone.utcnow(), + start_date=timezone.utcnow(), + run_type="manual", + ) + dag_run.run_after = timezone.utcnow() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + ) + + context_from_server = TIRunContext( + dag_run=dag_run, + max_tries=3, + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task failed", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.FAILED, + context_from_server=context_from_server, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + assert context_received is not None + # When context_from_server is provided, we get a full RuntimeTaskInstance context + assert "dag_run" in context_received + assert "logical_date" in context_received + + def test_execute_task_callbacks_not_failure_callback(self, spy_agency): + """Test _execute_task_callbacks when request is not a failure callback""" + called = False + + def on_failure(context): + nonlocal called + called = True + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + state=TaskInstanceState.SUCCESS, + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task succeeded", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.SUCCESS, + ) + + log = structlog.get_logger() + 
_execute_task_callbacks(dagbag, request, log) + + # Should not call the callback since it's not a failure callback + assert called is False + + def test_execute_task_callbacks_multiple_callbacks(self, spy_agency): + """Test _execute_task_callbacks with multiple callbacks""" + call_count = 0 + + def on_failure_1(context): + nonlocal call_count + call_count += 1 + + def on_failure_2(context): + nonlocal call_count + call_count += 1 + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=[on_failure_1, on_failure_2]) Review Comment: ```suggestion BaseOperator(task_id="test_task", on_failure_callback=[on_failure_1, on_failure_2]) ``` ########## airflow-core/tests/unit/jobs/test_scheduler_job.py: ########## @@ -6621,6 +6625,75 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag( for i in range(100): assert f"it's duplicate {i}" in dag_warning.message + def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, dag_maker, session): + """Test that scheduler passes context_from_server when handling heartbeat timeouts.""" + with dag_maker(dag_id="test_dag", session=session): + _ = EmptyOperator(task_id="test_task") + + dag_run = dag_maker.create_dagrun(run_id="test_run", state=DagRunState.RUNNING) + + mock_executor = MagicMock() + scheduler_job = Job(executor=mock_executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + + # Create a task instance that appears to be running but hasn't heartbeat + ti = dag_run.get_task_instance(task_id="test_task") + ti.state = TaskInstanceState.RUNNING + ti.queued_by_job_id = scheduler_job.id + # Set last_heartbeat_at to a time that would trigger timeout + ti.last_heartbeat_at = timezone.utcnow() - timedelta(seconds=600) # 10 minutes ago + session.merge(ti) + session.commit() + + # Run the heartbeat timeout check + self.job_runner._find_and_purge_task_instances_without_heartbeats() + + # Verify TaskCallbackRequest was created with 
context_from_server + mock_executor.send_callback.assert_called_once() + callback_request = mock_executor.send_callback.call_args[0][0] + + assert isinstance(callback_request, TaskCallbackRequest) + assert callback_request.context_from_server is not None + assert callback_request.context_from_server.dag_run.logical_date == dag_run.logical_date + assert callback_request.context_from_server.max_tries == ti.max_tries + + def test_scheduler_passes_context_from_server_on_task_failure(self, dag_maker, session): + """Test that scheduler passes context_from_server when handling task failures.""" + with dag_maker(dag_id="test_dag", session=session): + _ = EmptyOperator(task_id="test_task", on_failure_callback=lambda: print("failure")) Review Comment: ```suggestion EmptyOperator(task_id="test_task", on_failure_callback=lambda: print("failure")) ``` ########## airflow-core/tests/unit/jobs/test_scheduler_job.py: ########## @@ -6621,6 +6625,75 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag( for i in range(100): assert f"it's duplicate {i}" in dag_warning.message + def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, dag_maker, session): + """Test that scheduler passes context_from_server when handling heartbeat timeouts.""" + with dag_maker(dag_id="test_dag", session=session): + _ = EmptyOperator(task_id="test_task") + + dag_run = dag_maker.create_dagrun(run_id="test_run", state=DagRunState.RUNNING) + + mock_executor = MagicMock() + scheduler_job = Job(executor=mock_executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + + # Create a task instance that appears to be running but hasn't heartbeat + ti = dag_run.get_task_instance(task_id="test_task") + ti.state = TaskInstanceState.RUNNING + ti.queued_by_job_id = scheduler_job.id + # Set last_heartbeat_at to a time that would trigger timeout + ti.last_heartbeat_at = timezone.utcnow() - timedelta(seconds=600) # 10 minutes ago + session.merge(ti) + session.commit() + + # 
Run the heartbeat timeout check + self.job_runner._find_and_purge_task_instances_without_heartbeats() + + # Verify TaskCallbackRequest was created with context_from_server + mock_executor.send_callback.assert_called_once() + callback_request = mock_executor.send_callback.call_args[0][0] + + assert isinstance(callback_request, TaskCallbackRequest) + assert callback_request.context_from_server is not None + assert callback_request.context_from_server.dag_run.logical_date == dag_run.logical_date + assert callback_request.context_from_server.max_tries == ti.max_tries + + def test_scheduler_passes_context_from_server_on_task_failure(self, dag_maker, session): + """Test that scheduler passes context_from_server when handling task failures.""" + with dag_maker(dag_id="test_dag", session=session): + _ = EmptyOperator(task_id="test_task", on_failure_callback=lambda: print("failure")) + + dag_run = dag_maker.create_dagrun(run_id="test_run", state=DagRunState.RUNNING) + + # Create a task instance that's running + ti = dag_run.get_task_instance(task_id="test_task") + ti.state = TaskInstanceState.RUNNING + # ti.queued_by_job_id = 90000 Review Comment: ```suggestion ``` ########## airflow-core/tests/unit/dag_processing/test_processor.py: ########## @@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs): spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + # Create a minimal TaskInstance for the request + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="a", + task_id="b", + run_id="test_run", + map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + ) + requests = [ TaskCallbackRequest( filepath="A", msg="Message", - ti=None, + ti=ti_data, bundle_name="testing", bundle_version=None, ) ] - _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) + _parse_file( + DagFileParseRequest(file="A", bundle_path="test", callback_requests=requests), + log=structlog.get_logger(), + ) assert called 
is True + + +class TestExecuteTaskCallbacks: + """Test the _execute_task_callbacks function""" + + def test_execute_task_callbacks_failure_callback(self, spy_agency): + """Test _execute_task_callbacks executes failure callbacks""" + called = False + context_received = None + + def on_failure(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task failed", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.FAILED, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + assert context_received is not None + assert context_received["dag"] == dag + assert "ti" in context_received + + def test_execute_task_callbacks_retry_callback(self, spy_agency): + """Test _execute_task_callbacks executes retry callbacks""" + called = False + context_received = None + + def on_retry(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_retry_callback=on_retry) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + 
map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + state=TaskInstanceState.UP_FOR_RETRY, + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task retrying", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.UP_FOR_RETRY, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + assert context_received is not None + assert context_received["dag"] == dag + assert "ti" in context_received + + def test_execute_task_callbacks_with_context_from_server(self, spy_agency): + """Test _execute_task_callbacks with context_from_server creates full context""" + called = False + context_received = None + + def on_failure(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) Review Comment: ```suggestion BaseOperator(task_id="test_task", on_failure_callback=on_failure) ``` ########## airflow-core/tests/unit/dag_processing/test_processor.py: ########## @@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs): spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + # Create a minimal TaskInstance for the request + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="a", + task_id="b", + run_id="test_run", + map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + ) + requests = [ TaskCallbackRequest( filepath="A", msg="Message", - ti=None, + ti=ti_data, bundle_name="testing", bundle_version=None, ) ] - _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) + _parse_file( + DagFileParseRequest(file="A", bundle_path="test", callback_requests=requests), + log=structlog.get_logger(), + ) assert called is True + + +class TestExecuteTaskCallbacks: + """Test the _execute_task_callbacks function""" + + def 
test_execute_task_callbacks_failure_callback(self, spy_agency): + """Test _execute_task_callbacks executes failure callbacks""" + called = False + context_received = None + + def on_failure(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) Review Comment: ```suggestion BaseOperator(task_id="test_task", on_failure_callback=on_failure) ``` Shouldn't need this. ########## airflow-core/tests/unit/jobs/test_scheduler_job.py: ########## @@ -6621,6 +6625,75 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag( for i in range(100): assert f"it's duplicate {i}" in dag_warning.message + def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, dag_maker, session): + """Test that scheduler passes context_from_server when handling heartbeat timeouts.""" + with dag_maker(dag_id="test_dag", session=session): + _ = EmptyOperator(task_id="test_task") Review Comment: ```suggestion EmptyOperator(task_id="test_task") ``` ########## airflow-core/tests/unit/jobs/test_scheduler_job.py: ########## @@ -6621,6 +6625,75 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag( for i in range(100): assert f"it's duplicate {i}" in dag_warning.message + def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, dag_maker, session): + """Test that scheduler passes context_from_server when handling heartbeat timeouts.""" + with dag_maker(dag_id="test_dag", session=session): + _ = EmptyOperator(task_id="test_task") + + dag_run = dag_maker.create_dagrun(run_id="test_run", state=DagRunState.RUNNING) + + mock_executor = MagicMock() + scheduler_job = Job(executor=mock_executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + + # Create a task instance that appears to be running but hasn't heartbeat + ti = dag_run.get_task_instance(task_id="test_task") + ti.state = 
TaskInstanceState.RUNNING + ti.queued_by_job_id = scheduler_job.id + # Set last_heartbeat_at to a time that would trigger timeout + ti.last_heartbeat_at = timezone.utcnow() - timedelta(seconds=600) # 10 minutes ago + session.merge(ti) + session.commit() + + # Run the heartbeat timeout check + self.job_runner._find_and_purge_task_instances_without_heartbeats() + + # Verify TaskCallbackRequest was created with context_from_server + mock_executor.send_callback.assert_called_once() + callback_request = mock_executor.send_callback.call_args[0][0] + + assert isinstance(callback_request, TaskCallbackRequest) + assert callback_request.context_from_server is not None + assert callback_request.context_from_server.dag_run.logical_date == dag_run.logical_date + assert callback_request.context_from_server.max_tries == ti.max_tries + + def test_scheduler_passes_context_from_server_on_task_failure(self, dag_maker, session): + """Test that scheduler passes context_from_server when handling task failures.""" + with dag_maker(dag_id="test_dag", session=session): + _ = EmptyOperator(task_id="test_task", on_failure_callback=lambda: print("failure")) + + dag_run = dag_maker.create_dagrun(run_id="test_run", state=DagRunState.RUNNING) + + # Create a task instance that's running + ti = dag_run.get_task_instance(task_id="test_task") + ti.state = TaskInstanceState.RUNNING + # ti.queued_by_job_id = 90000 + session.merge(ti) + session.commit() + + # Mock the executor to simulate a task failure + mock_executor = MagicMock(spec=BaseExecutor) + mock_executor.has_task = mock.MagicMock(return_value=False) + scheduler_job = Job(executor=mock_executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + + # Simulate executor reporting task as failed + # from airflow.executors.base_executor import TaskInstanceStateType Review Comment: ```suggestion ``` ########## airflow-core/tests/unit/dag_processing/test_processor.py: ########## @@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, 
**kwargs): spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + # Create a minimal TaskInstance for the request + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="a", + task_id="b", + run_id="test_run", + map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + ) + requests = [ TaskCallbackRequest( filepath="A", msg="Message", - ti=None, + ti=ti_data, bundle_name="testing", bundle_version=None, ) ] - _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) + _parse_file( + DagFileParseRequest(file="A", bundle_path="test", callback_requests=requests), + log=structlog.get_logger(), + ) assert called is True + + +class TestExecuteTaskCallbacks: + """Test the _execute_task_callbacks function""" + + def test_execute_task_callbacks_failure_callback(self, spy_agency): + """Test _execute_task_callbacks executes failure callbacks""" + called = False + context_received = None + + def on_failure(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task failed", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.FAILED, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + assert context_received is not None + assert context_received["dag"] == dag + assert "ti" in context_received + + def 
test_execute_task_callbacks_retry_callback(self, spy_agency): + """Test _execute_task_callbacks executes retry callbacks""" + called = False + context_received = None + + def on_retry(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_retry_callback=on_retry) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + state=TaskInstanceState.UP_FOR_RETRY, + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task retrying", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.UP_FOR_RETRY, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + assert context_received is not None + assert context_received["dag"] == dag + assert "ti" in context_received + + def test_execute_task_callbacks_with_context_from_server(self, spy_agency): + """Test _execute_task_callbacks with context_from_server creates full context""" + called = False + context_received = None + + def on_failure(context): + nonlocal called, context_received + called = True + context_received = context + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + # Create a mock DagRun + dag_run = DagRun( + dag_id="test_dag", + run_id="test_run", + logical_date=timezone.utcnow(), + 
start_date=timezone.utcnow(), + run_type="manual", + ) + dag_run.run_after = timezone.utcnow() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + ) + + context_from_server = TIRunContext( + dag_run=dag_run, + max_tries=3, + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task failed", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.FAILED, + context_from_server=context_from_server, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + assert context_received is not None + # When context_from_server is provided, we get a full RuntimeTaskInstance context + assert "dag_run" in context_received + assert "logical_date" in context_received + + def test_execute_task_callbacks_not_failure_callback(self, spy_agency): + """Test _execute_task_callbacks when request is not a failure callback""" + called = False + + def on_failure(context): + nonlocal called + called = True + + with DAG(dag_id="test_dag") as dag: + _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure) Review Comment: ```suggestion BaseOperator(task_id="test_task", on_failure_callback=on_failure) ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org