This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3970ea386d Override pool for TaskInstance when pool is passed from cli. (#23258)
3970ea386d is described below

commit 3970ea386d5e0a371143ad1e69b897fd1262842d
Author: Karthikeyan Singaravelan <[email protected]>
AuthorDate: Sun May 1 00:41:07 2022 +0530

    Override pool for TaskInstance when pool is passed from cli. (#23258)
---
 airflow/cli/commands/task_command.py    |  5 +++--
 tests/cli/commands/test_task_command.py | 21 +++++++++++++++++++--
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 536d83bbe0..ea20ebb646 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -137,6 +137,7 @@ def _get_ti(
     exec_date_or_run_id: str,
     map_index: int,
     *,
+    pool: Optional[str] = None,
     create_if_necessary: CreateIfNecessary = False,
     session: Session = NEW_SESSION,
 ) -> Tuple[TaskInstance, bool]:
@@ -165,7 +166,7 @@ def _get_ti(
         ti.dag_run = dag_run
     else:
         ti = ti_or_none
-    ti.refresh_from_task(task)
+    ti.refresh_from_task(task, pool_override=pool)
     return ti, dr_created
 
 
@@ -361,7 +362,7 @@ def task_run(args, dag=None):
         # Use DAG from parameter
         pass
     task = dag.get_task(task_id=args.task_id)
-    ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
+    ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, pool=args.pool)
     ti.init_run_context(raw=args.raw)
 
     hostname = get_hostname()
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 3cf5bae778..481103f1f3 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -35,14 +35,14 @@ from airflow.cli import cli_parser
 from airflow.cli.commands import task_command
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, DagRunNotFound
-from airflow.models import DagBag, DagRun, TaskInstance
+from airflow.models import DagBag, DagRun, Pool, TaskInstance
 from airflow.utils import timezone
 from airflow.utils.dates import days_ago
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
 from tests.test_utils.config import conf_vars
-from tests.test_utils.db import clear_db_runs
+from tests.test_utils.db import clear_db_pools, clear_db_runs
 
 DEFAULT_DATE = days_ago(1)
 ROOT_FOLDER = os.path.realpath(
@@ -512,6 +512,23 @@ class TestLogsfromTaskRunCommand(unittest.TestCase):
             f"task_id={self.task_id}, execution_date=20170101T000000" in logs
         )
 
+    @unittest.skipIf(not hasattr(os, 'fork'), "Forking not available")
+    def test_run_task_with_pool(self):
+        pool_name = 'test_pool_run'
+
+        clear_db_pools()
+        with create_session() as session:
+            pool = Pool(pool=pool_name, slots=1)
+            session.add(pool)
+            session.commit()
+
+            assert session.query(TaskInstance).filter_by(pool=pool_name).first() is None
+            task_command.task_run(self.parser.parse_args(self.task_args + ['--pool', pool_name]))
+            assert session.query(TaskInstance).filter_by(pool=pool_name).first() is not None
+
+            session.delete(pool)
+            session.commit()
+
     # For this test memory spins out of control on Python 3.6. TODO(potiuk): FIXME")
     @pytest.mark.quarantined
     @mock.patch("airflow.task.task_runner.standard_task_runner.CAN_FORK", False)

Reply via email to