This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3e229d8df5 Task adoption for hybrid executors (#39531)
3e229d8df5 is described below
commit 3e229d8df52748032e0c56503c9696f7f6d9eb62
Author: Niko Oliveira <[email protected]>
AuthorDate: Mon May 13 11:10:11 2024 -0700
Task adoption for hybrid executors (#39531)
Sort the set of tasks that are up for adoption by the executor they're
configured to run on (if any) and send them to the appropriate executor
for adoption.
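In essence, the scheduler now buckets orphaned task instances by the
executor each one is configured to run on and lets every executor adopt
only its own tasks. A minimal sketch of the grouping idea (simplified
from the _executor_to_tis helper in the diff below; per that diff,
load_executor(None) resolves to the default executor):

    from collections import defaultdict

    from airflow.executors.executor_loader import ExecutorLoader

    def group_tis_by_executor(tis):
        # Bucket task instances by their configured executor; ti.executor
        # is None when no override is set, and load_executor(None) falls
        # back to the default executor.
        grouped = defaultdict(list)
        for ti in tis:
            grouped[ExecutorLoader.load_executor(ti.executor)].append(ti)
        return grouped

    # Each executor then adopts only the task instances routed to it:
    #     for executor, tis in group_tis_by_executor(orphaned_tis).items():
    #         to_reset.extend(executor.try_adopt_task_instances(tis))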
---
airflow/executors/executor_loader.py | 8 ++--
airflow/jobs/scheduler_job_runner.py | 22 ++++++++++-
tests/executors/test_executor_loader.py | 19 +++++++++
tests/jobs/test_scheduler_job.py | 69 +++++++++++++++++++++++++++++++++
4 files changed, 113 insertions(+), 5 deletions(-)
diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py
index fb3ffce420..5fb9f90d4f 100644
--- a/airflow/executors/executor_loader.py
+++ b/airflow/executors/executor_loader.py
@@ -202,10 +202,10 @@ class ExecutorLoader:
elif executor_name := _module_to_executors.get(executor_name_str):
return executor_name
else:
-            raise AirflowException(f"Unknown executor being loaded: {executor_name}")
+            raise AirflowException(f"Unknown executor being loaded: {executor_name_str}")
@classmethod
- def load_executor(cls, executor_name: ExecutorName | str) -> BaseExecutor:
+    def load_executor(cls, executor_name: ExecutorName | str | None) -> BaseExecutor:
"""
Load the executor.
@@ -217,7 +217,9 @@ class ExecutorLoader:
:return: an instance of executor class via executor_name
"""
- if isinstance(executor_name, str):
+ if not executor_name:
+ _executor_name = cls.get_default_executor_name()
+ elif isinstance(executor_name, str):
_executor_name = cls.lookup_executor_name_by_str(executor_name)
else:
_executor_name = executor_name
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 49a065b5f5..f2333e8d5a 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -24,7 +24,7 @@ import signal
import sys
import time
import warnings
-from collections import Counter
+from collections import Counter, defaultdict
from dataclasses import dataclass
from datetime import timedelta
from functools import lru_cache, partial
@@ -83,6 +83,7 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Query, Session
from airflow.dag_processing.manager import DagFileProcessorAgent
+ from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.sqlalchemy import (
CommitProhibitorGuard,
@@ -1651,7 +1652,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             # Lock these rows, so that another scheduler can't try and adopt these too
             tis_to_adopt_or_reset = with_row_locks(query, of=TI, session=session, skip_locked=True)
             tis_to_adopt_or_reset = session.scalars(tis_to_adopt_or_reset).all()
-            to_reset = self.job.executor.try_adopt_task_instances(tis_to_adopt_or_reset)
+
+ to_reset: list[TaskInstance] = []
+ exec_to_tis = self._executor_to_tis(tis_to_adopt_or_reset)
+ for executor, tis in exec_to_tis.items():
+ to_reset.extend(executor.try_adopt_task_instances(tis))
reset_tis_message = []
for ti in to_reset:
@@ -1831,3 +1836,16 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
        updated_count = sum(self._set_orphaned(dataset) for dataset in orphaned_dataset_query)
Stats.gauge("dataset.orphaned", updated_count)
+
+    def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]:
+        """Organize TIs into lists per their respective executor."""
+        _executor_to_tis: defaultdict[BaseExecutor, list[TaskInstance]] = defaultdict(list)
+ executor: str | None
+ for ti in tis:
+ if ti.executor:
+ executor = str(ti.executor)
+ else:
+ executor = None
+ _executor_to_tis[ExecutorLoader.load_executor(executor)].append(ti)
+
+ return _executor_to_tis
diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py
index 840e74a8fc..bb7da133b6 100644
--- a/tests/executors/test_executor_loader.py
+++ b/tests/executors/test_executor_loader.py
@@ -26,6 +26,7 @@ from airflow import plugins_manager
from airflow.exceptions import AirflowConfigException
from airflow.executors import executor_loader
from airflow.executors.executor_loader import ConnectorSource, ExecutorLoader, ExecutorName
+from airflow.executors.local_executor import LocalExecutor
from airflow.providers.celery.executors.celery_executor import CeleryExecutor
from tests.test_utils.config import conf_vars
@@ -301,3 +302,21 @@ class TestExecutorLoader:
monkeypatch.delenv("_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK")
with expectation:
ExecutorLoader.validate_database_executor_compatibility(executor)
+
+ def test_load_executor(self):
+ ExecutorLoader.block_use_of_hybrid_exec = mock.Mock()
+ with conf_vars({("core", "executor"): "LocalExecutor"}):
+ ExecutorLoader.init_executors()
+ assert isinstance(ExecutorLoader.load_executor("LocalExecutor"),
LocalExecutor)
+ assert
isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]),
LocalExecutor)
+ assert isinstance(ExecutorLoader.load_executor(None),
LocalExecutor)
+
+ def test_load_executor_alias(self):
+ ExecutorLoader.block_use_of_hybrid_exec = mock.Mock()
+ with conf_vars({("core", "executor"):
"local_exec:airflow.executors.local_executor.LocalExecutor"}):
+ ExecutorLoader.init_executors()
+ assert isinstance(ExecutorLoader.load_executor("local_exec"),
LocalExecutor)
+ assert isinstance(
+
ExecutorLoader.load_executor("airflow.executors.local_executor.LocalExecutor"),
LocalExecutor
+ )
+ assert
isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]),
LocalExecutor)
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 491e345649..85399892ac 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -23,6 +23,7 @@ import logging
import os
from collections import deque
from datetime import timedelta
+from importlib import reload
from typing import Generator
from unittest import mock
from unittest.mock import MagicMock, PropertyMock, patch
@@ -165,6 +166,18 @@ class TestSchedulerJob:
self.null_exec = None
del self.dagbag
+ @pytest.fixture
+ def mock_executors(self):
+ default_executor = mock.MagicMock(slots_available=8, slots_occupied=0)
+ default_executor.name = MagicMock(alias="default_exec",
module_path="default.exec.module.path")
+ second_executor = mock.MagicMock(slots_available=8, slots_occupied=0)
+ second_executor.name = MagicMock(alias="secondary_exec",
module_path="secondary.exec.module.path")
+ with mock.patch("airflow.jobs.job.Job.executors",
new_callable=PropertyMock) as executors_mock:
+ with mock.patch("airflow.jobs.job.Job.executor",
new_callable=PropertyMock) as executor_mock:
+ executor_mock.return_value = default_executor
+ executors_mock.return_value = [default_executor,
second_executor]
+ yield [default_executor, second_executor]
+
@pytest.mark.parametrize(
"configs",
[
@@ -1740,6 +1753,62 @@ class TestSchedulerJob:
ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session)
        assert ti2.state == State.QUEUED, "Tasks run by Backfill Jobs should not be reset"
+    def test_adopt_or_reset_orphaned_tasks_multiple_executors(self, dag_maker, mock_executors):
+        """Test that with multiple executors configured tasks are sorted correctly and handed off to the
+        correct executor for adoption."""
+        session = settings.Session()
+        with dag_maker("test_execute_helper_reset_orphaned_tasks_multiple_executors"):
+ op1 = EmptyOperator(task_id="op1")
+ op2 = EmptyOperator(task_id="op2", executor="default_exec")
+ op3 = EmptyOperator(task_id="op3", executor="secondary_exec")
+
+ dr = dag_maker.create_dagrun()
+ scheduler_job = Job()
+ session.add(scheduler_job)
+ session.commit()
+ ti1 = dr.get_task_instance(task_id=op1.task_id, session=session)
+ ti2 = dr.get_task_instance(task_id=op2.task_id, session=session)
+ ti3 = dr.get_task_instance(task_id=op3.task_id, session=session)
+ tis = [ti1, ti2, ti3]
+ for ti in tis:
+ ti.state = State.QUEUED
+ ti.queued_by_job_id = scheduler_job.id
+ session.commit()
+
+        with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock:
+            # reload the scheduler_job_runner module so that it loads a fresh executor_loader module which
+            # contains the mocked load_executor method.
+ from airflow.jobs import scheduler_job_runner
+
+ reload(scheduler_job_runner)
+
+ processor = mock.MagicMock()
+
+ new_scheduler_job = Job()
+            self.job_runner = SchedulerJobRunner(job=new_scheduler_job, num_runs=0)
+ self.job_runner.processor_agent = processor
+            # The executors are mocked, so cannot be loaded/imported. Mock load_executor and return the
+            # correct object for the given input executor name.
+ loader_mock.side_effect = lambda *x: {
+ ("default_exec",): mock_executors[0],
+ (None,): mock_executors[0],
+ ("secondary_exec",): mock_executors[1],
+ }[x]
+
+ self.job_runner.adopt_or_reset_orphaned_tasks()
+
+            # Default executor is called for ti1 (no explicit executor override, so the default is used)
+            # and ti2 (which we explicitly marked for execution by the default executor)
+ try:
+                mock_executors[0].try_adopt_task_instances.assert_called_once_with([ti1, ti2])
+ except AssertionError:
+                # The order of the TIs given to try_adopt_task_instances is not consistent, so check
+                # the other order before allowing the AssertionError to fail the test
+                mock_executors[0].try_adopt_task_instances.assert_called_once_with([ti2, ti1])
+
+ # Second executor called for ti3
+            mock_executors[1].try_adopt_task_instances.assert_called_once_with([ti3])
+
def test_fail_stuck_queued_tasks(self, dag_maker, session):
with dag_maker("test_fail_stuck_queued_tasks"):
op1 = EmptyOperator(task_id="op1")