This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit cd4dc89e4dbe985eb01e2480ea2ff00888bc93a7
Author: Karthikeyan Singaravelan <[email protected]>
AuthorDate: Sun May 1 00:41:07 2022 +0530

    Override pool for TaskInstance when pool is passed from cli. (#23258)

    (cherry picked from commit 3970ea386d5e0a371143ad1e69b897fd1262842d)
---
 airflow/cli/commands/task_command.py    |  5 +++--
 tests/cli/commands/test_task_command.py | 21 +++++++++++++++++++--
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 536d83bbe0..ea20ebb646 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -137,6 +137,7 @@ def _get_ti(
     exec_date_or_run_id: str,
     map_index: int,
     *,
+    pool: Optional[str] = None,
     create_if_necessary: CreateIfNecessary = False,
     session: Session = NEW_SESSION,
 ) -> Tuple[TaskInstance, bool]:
@@ -165,7 +166,7 @@ def _get_ti(
         ti.dag_run = dag_run
     else:
         ti = ti_or_none
-    ti.refresh_from_task(task)
+    ti.refresh_from_task(task, pool_override=pool)
     return ti, dr_created
 
 
@@ -361,7 +362,7 @@ def task_run(args, dag=None):
         # Use DAG from parameter
         pass
     task = dag.get_task(task_id=args.task_id)
-    ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
+    ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, pool=args.pool)
     ti.init_run_context(raw=args.raw)
 
     hostname = get_hostname()
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 3cf5bae778..481103f1f3 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -35,14 +35,14 @@ from airflow.cli import cli_parser
 from airflow.cli.commands import task_command
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, DagRunNotFound
-from airflow.models import DagBag, DagRun, TaskInstance
+from airflow.models import DagBag, DagRun, Pool, TaskInstance
 from airflow.utils import timezone
 from airflow.utils.dates import days_ago
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
 from tests.test_utils.config import conf_vars
-from tests.test_utils.db import clear_db_runs
+from tests.test_utils.db import clear_db_pools, clear_db_runs
 
 DEFAULT_DATE = days_ago(1)
 ROOT_FOLDER = os.path.realpath(
@@ -512,6 +512,23 @@ class TestLogsfromTaskRunCommand(unittest.TestCase):
             f"task_id={self.task_id}, execution_date=20170101T000000" in logs
         )
 
+    @unittest.skipIf(not hasattr(os, 'fork'), "Forking not available")
+    def test_run_task_with_pool(self):
+        pool_name = 'test_pool_run'
+
+        clear_db_pools()
+        with create_session() as session:
+            pool = Pool(pool=pool_name, slots=1)
+            session.add(pool)
+            session.commit()
+
+            assert session.query(TaskInstance).filter_by(pool=pool_name).first() is None
+            task_command.task_run(self.parser.parse_args(self.task_args + ['--pool', pool_name]))
+            assert session.query(TaskInstance).filter_by(pool=pool_name).first() is not None
+
+            session.delete(pool)
+            session.commit()
+
     # For this test memory spins out of control on Python 3.6. TODO(potiuk): FIXME")
     @pytest.mark.quarantined
     @mock.patch("airflow.task.task_runner.standard_task_runner.CAN_FORK", False)
