This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 05a67efe32 Add an option to load DAGs from the DB for the tasks run command (#32038)
05a67efe32 is described below
commit 05a67efe32af248ca191ea59815b3b202f893f46
Author: Hussein Awala <[email protected]>
AuthorDate: Sat Jun 24 00:31:05 2023 +0200
Add an option to load DAGs from the DB for the tasks run command (#32038)
Signed-off-by: Hussein Awala <[email protected]>
---
 airflow/cli/cli_config.py               |  2 ++
 airflow/cli/commands/task_command.py    |  2 +-
 airflow/utils/cli.py                    | 21 ++++++++++++-------
 tests/cli/commands/test_task_command.py | 36 +++++++++++++++++++++++++++++++++
4 files changed, 53 insertions(+), 8 deletions(-)
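For context, a minimal end-to-end sketch of the new flag (the dag id, task id, and run id below are placeholders, and the DAG is assumed to be already serialized in the metadata database):

    from airflow.cli import cli_parser
    from airflow.cli.commands import task_command

    # Build the real CLI parser and run a task with the new --read-from-db
    # flag, so the DAG is loaded from the serialized-DAG table rather than
    # parsed from a DAG file.
    parser = cli_parser.get_parser()
    args = parser.parse_args(
        ["tasks", "run", "--local", "--read-from-db", "example_dag", "example_task", "manual_run_1"]
    )
    task_command.task_run(args)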
diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index 0c69571fea..587934769e 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -604,6 +604,7 @@ ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entir
 ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS)
 ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg")
 ARG_MAP_INDEX = Arg(("--map-index",), type=int, default=-1, help="Mapped task index")
+ARG_READ_FROM_DB = Arg(("--read-from-db",), help="Read dag from DB instead of dag file", action="store_true")
 
 # database
@@ -1453,6 +1454,7 @@ TASKS_COMMANDS = (
             ARG_SHUT_DOWN_LOGGING,
             ARG_MAP_INDEX,
             ARG_VERBOSE,
+            ARG_READ_FROM_DB,
         ),
     ),
     ActionCommand(
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index aab8bb10dc..560764536b 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -398,7 +398,7 @@ def task_run(args, dag: DAG | None = None) -> TaskReturnCode | None:
         print(f"Loading pickle id: {args.pickle}")
         _dag = get_dag_by_pickle(args.pickle)
     elif not dag:
-        _dag = get_dag(args.subdir, args.dag_id)
+        _dag = get_dag(args.subdir, args.dag_id, args.read_from_db)
     else:
         _dag = dag
     task = _dag.get_task(task_id=args.task_id)
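The changed branch above is equivalent to the following direct call (a sketch; "example_dag" is a placeholder dag id):

    from airflow.utils.cli import get_dag

    # With from_db=True, get_dag() builds the DagBag from the database, so
    # subdir is not used for loading and no DAG files are parsed.
    dag = get_dag(subdir=None, dag_id="example_dag", from_db=True)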
diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index d9e53ac072..56ac166b54 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -215,27 +215,34 @@ def _search_for_dag_file(val: str | None) -> str | None:
     return None
 
 
-def get_dag(subdir: str | None, dag_id: str) -> DAG:
+def get_dag(subdir: str | None, dag_id: str, from_db: bool = False) -> DAG:
     """
     Returns DAG of a given dag_id.
 
-    First it we'll try to use the given subdir. If that doesn't work, we'll try to
+    First we'll try to use the given subdir. If that doesn't work, we'll try to
     find the correct path (assuming it's a file) and failing that, use the configured
     dags folder.
     """
     from airflow.models import DagBag
 
-    first_path = process_subdir(subdir)
-    dagbag = DagBag(first_path)
-    if dag_id not in dagbag.dags:
+    if from_db:
+        dagbag = DagBag(read_dags_from_db=True)
+    else:
+        first_path = process_subdir(subdir)
+        dagbag = DagBag(first_path)
+    dag = dagbag.get_dag(dag_id)
+    if not dag:
+        if from_db:
+            raise AirflowException(f"Dag {dag_id!r} could not be found in DagBag read from database.")
         fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER
         logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path)
         dagbag = DagBag(dag_folder=fallback_path)
-        if dag_id not in dagbag.dags:
+        dag = dagbag.get_dag(dag_id)
+        if not dag:
             raise AirflowException(
                 f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
             )
-    return dagbag.dags[dag_id]
+    return dag
 
 
 def get_dags(subdir: str | None, dag_id: str, use_regex: bool = False):
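The new from_db code path boils down to the following DagBag usage, shown here as a standalone sketch (assumes a reachable metadata database containing serialized DAGs; the dag id is a placeholder):

    from airflow.models import DagBag

    # read_dags_from_db=True fills the DagBag from the serialized_dag table,
    # so no DAG files need to exist on the machine running the CLI.
    dagbag = DagBag(read_dags_from_db=True)
    dag = dagbag.get_dag("example_dag")
    if not dag:
        raise RuntimeError("Dag 'example_dag' not found in the database")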
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 376faeda41..646d76f47c 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -276,6 +276,42 @@ class TestCliTasks:
             external_executor_id=None,
         )
 
+    @pytest.mark.parametrize(
+        "from_db",
+        [True, False],
+    )
+    @mock.patch("airflow.cli.commands.task_command.LocalTaskJobRunner")
+    def test_run_with_read_from_db(self, mock_local_job_runner, caplog, from_db):
+        """
+        Test that we can run a task with and without --read-from-db
+        """
+        task0_id = self.dag.task_ids[0]
+        args0 = [
+            "tasks",
+            "run",
+            "--ignore-all-dependencies",
+            "--local",
+            self.dag_id,
+            task0_id,
+            self.run_id,
+        ] + (["--read-from-db"] if from_db else [])
+        mock_local_job_runner.return_value.job_type = "LocalTaskJob"
+        task_command.task_run(self.parser.parse_args(args0))
+        mock_local_job_runner.assert_called_once_with(
+            job=mock.ANY,
+            task_instance=mock.ANY,
+            mark_success=False,
+            ignore_all_deps=True,
+            ignore_depends_on_past=False,
+            wait_for_past_depends_before_skipping=False,
+            ignore_task_deps=False,
+            ignore_ti_state=False,
+            pickle_id=None,
+            pool=None,
+            external_executor_id=None,
+        )
+        assert ("Filling up the DagBag from" in caplog.text) != from_db
+
     @mock.patch("airflow.cli.commands.task_command.LocalTaskJobRunner")
     def test_run_raises_when_theres_no_dagrun(self, mock_local_job):
         """