This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 22da95984a3 Add bundle name arg to list dags cli command (#45779)
22da95984a3 is described below
commit 22da95984a3a8ac3282cf78a99af0b03c5d87c88
Author: ambikagarg <[email protected]>
AuthorDate: Thu Feb 27 21:22:03 2025 -0500
Add bundle name arg to list dags cli command (#45779)
Co-authored-by: Jed Cunningham <[email protected]>
---
airflow/api_connexion/schemas/dag_schema.py | 2 ++
airflow/cli/cli_config.py | 5 +--
.../cli/commands/remote_commands/dag_command.py | 38 ++++++++++++++++++----
.../commands/remote_commands/test_dag_command.py | 34 ++++++++++++++-----
tests_common/pytest_plugin.py | 19 +++++++++++
5 files changed, 81 insertions(+), 17 deletions(-)
diff --git a/airflow/api_connexion/schemas/dag_schema.py
b/airflow/api_connexion/schemas/dag_schema.py
index 9f75f4dad52..38a062b1b28 100644
--- a/airflow/api_connexion/schemas/dag_schema.py
+++ b/airflow/api_connexion/schemas/dag_schema.py
@@ -51,6 +51,8 @@ class DAGSchema(SQLAlchemySchema):
dag_id = auto_field(dump_only=True)
dag_display_name = fields.String(attribute="dag_display_name",
dump_only=True)
+ bundle_name = auto_field(dump_only=True)
+ bundle_version = auto_field(dump_only=True)
is_paused = auto_field()
is_active = auto_field(dump_only=True)
last_parsed_time = auto_field(dump_only=True)
diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index 84f9ecea169..32a9bb1a382 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -172,6 +172,7 @@ ARG_BUNDLE_NAME = Arg(
"--bundle-name",
),
help=("The name of the DAG bundle to use; may be provided more than once"),
+ type=str,
default=None,
action="append",
)
@@ -880,7 +881,7 @@ ARG_DAG_LIST_COLUMNS = Arg(
("--columns",),
type=string_list_type,
help="List of columns to render. (default: ['dag_id', 'fileloc', 'owner',
'is_paused'])",
- default=("dag_id", "fileloc", "owners", "is_paused"),
+ default=("dag_id", "fileloc", "owners", "is_paused", "bundle_name",
"bundle_version"),
)
ARG_ASSET_LIST_COLUMNS = Arg(
@@ -978,7 +979,7 @@ DAGS_COMMANDS = (
name="list",
help="List all the DAGs",
func=lazy_load_command("airflow.cli.commands.remote_commands.dag_command.dag_list_dags"),
- args=(ARG_SUBDIR, ARG_OUTPUT, ARG_VERBOSE, ARG_DAG_LIST_COLUMNS),
+ args=(ARG_OUTPUT, ARG_VERBOSE, ARG_DAG_LIST_COLUMNS, ARG_BUNDLE_NAME),
),
ActionCommand(
name="list-import-errors",
diff --git a/airflow/cli/commands/remote_commands/dag_command.py
b/airflow/cli/commands/remote_commands/dag_command.py
index b478798c24a..03fb97aba39 100644
--- a/airflow/cli/commands/remote_commands/dag_command.py
+++ b/airflow/cli/commands/remote_commands/dag_command.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
"""Dag sub-commands."""
from __future__ import annotations
@@ -28,7 +29,7 @@ import sys
from typing import TYPE_CHECKING
import re2
-from sqlalchemy import select
+from sqlalchemy import func, select
from airflow.api.client import get_current_api_client
from airflow.api_connexion.schemas.dag_schema import dag_schema
@@ -38,6 +39,7 @@ from airflow.dag_processing.bundles.manager import
DagBundlesManager
from airflow.exceptions import AirflowException
from airflow.jobs.job import Job
from airflow.models import DagBag, DagModel, DagRun, TaskInstance
+from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.sdk.definitions._internal.dag_parsing_context import
_airflow_parsing_context_manager
from airflow.utils import cli as cli_utils, timezone
@@ -224,6 +226,8 @@ def _get_dagbag_dag_details(dag: DAG) -> dict:
return {
"dag_id": dag.dag_id,
"dag_display_name": dag.dag_display_name,
+ "bundle_name": dag.get_bundle_name(),
+ "bundle_version": dag.get_bundle_version(),
"is_paused": dag.get_is_paused(),
"is_active": dag.get_is_active(),
"last_parsed_time": None,
@@ -322,11 +326,12 @@ def dag_next_execution(args) -> None:
@suppress_logs_and_warning
@providers_configuration_loaded
@provide_session
-def dag_list_dags(args, session=NEW_SESSION) -> None:
+def dag_list_dags(args, session: Session = NEW_SESSION) -> None:
"""Display dags with or without stats at the command line."""
cols = args.columns if args.columns else []
invalid_cols = [c for c in cols if c not in dag_schema.fields]
valid_cols = [c for c in cols if c in dag_schema.fields]
+
if invalid_cols:
from rich import print as rich_print
@@ -335,8 +340,18 @@ def dag_list_dags(args, session=NEW_SESSION) -> None:
f"List of valid columns: {list(dag_schema.fields.keys())}",
file=sys.stderr,
)
- dagbag = DagBag(process_subdir(args.subdir))
- if dagbag.import_errors:
+
+ dagbag = DagBag(read_dags_from_db=True)
+ dagbag.collect_dags_from_db()
+
+ # Get import errors from the DB
+ query = select(func.count()).select_from(ParseImportError)
+ if args.bundle_name:
+ query = query.where(ParseImportError.bundle_name.in_(args.bundle_name))
+
+ dagbag_import_errors = session.scalar(query)
+
+ if dagbag_import_errors > 0:
from rich import print as rich_print
rich_print(
@@ -353,8 +368,19 @@ def dag_list_dags(args, session=NEW_SESSION) -> None:
dag_detail = _get_dagbag_dag_details(dag)
return {col: dag_detail[col] for col in valid_cols}
+ def filter_dags_by_bundle(dags: list[DAG], bundle_names: list[str] | None)
-> list[DAG]:
+ """Filter DAGs based on the specified bundle name, if provided."""
+ if not bundle_names:
+ return dags
+
+ validate_dag_bundle_arg(bundle_names)
+ return [dag for dag in dags if dag.get_bundle_name() in bundle_names]
+
AirflowConsole().print_as(
- data=sorted(dagbag.dags.values(), key=operator.attrgetter("dag_id")),
+ data=sorted(
+ filter_dags_by_bundle(list(dagbag.dags.values()),
args.bundle_name),
+ key=operator.attrgetter("dag_id"),
+ ),
output=args.output,
mapper=get_dag_detail,
)
@@ -364,7 +390,7 @@ def dag_list_dags(args, session=NEW_SESSION) -> None:
@suppress_logs_and_warning
@providers_configuration_loaded
@provide_session
-def dag_details(args, session=NEW_SESSION):
+def dag_details(args, session: Session = NEW_SESSION):
"""Get DAG details given a DAG id."""
dag = DagModel.get_dagmodel(args.dag_id, session=session)
if not dag:
diff --git a/tests/cli/commands/remote_commands/test_dag_command.py
b/tests/cli/commands/remote_commands/test_dag_command.py
index 051161bbdfa..5a5e16e0504 100644
--- a/tests/cli/commands/remote_commands/test_dag_command.py
+++ b/tests/cli/commands/remote_commands/test_dag_command.py
@@ -50,7 +50,12 @@ from airflow.utils.types import DagRunType
from tests.models import TEST_DAGS_FOLDER
from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.db import clear_db_dags, clear_db_runs,
parse_and_sync_to_db
+from tests_common.test_utils.db import (
+ clear_db_dags,
+ clear_db_import_errors,
+ clear_db_runs,
+ parse_and_sync_to_db,
+)
DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc)
if pendulum.__version__.startswith("3"):
@@ -77,7 +82,11 @@ class TestCliDags:
clear_db_dags()
def setup_method(self):
- clear_db_runs() # clean-up all dag run before start each test
+ clear_db_runs()
+ clear_db_import_errors()
+
+ def teardown_method(self):
+ clear_db_import_errors()
def test_show_dag_dependencies_print(self):
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
@@ -274,12 +283,17 @@ class TestCliDags:
assert "Ignoring the following invalid columns: ['invalid_col']" in out
@conf_vars({("core", "load_examples"): "false"})
- def test_cli_list_dags_prints_import_errors(self):
- dag_path = os.path.join(TEST_DAGS_FOLDER, "test_invalid_cron.py")
- args = self.parser.parse_args(["dags", "list", "--output", "yaml",
"--subdir", dag_path])
- with contextlib.redirect_stderr(StringIO()) as temp_stderr:
- dag_command.dag_list_dags(args)
- out = temp_stderr.getvalue()
+ def test_cli_list_dags_prints_import_errors(self,
configure_testing_dag_bundle, get_test_dag):
+ path_to_parse = TEST_DAGS_FOLDER / "test_invalid_cron.py"
+ get_test_dag("test_invalid_cron")
+
+ args = self.parser.parse_args(["dags", "list", "--output", "yaml",
"--bundle-name", "testing"])
+
+ with configure_testing_dag_bundle(path_to_parse):
+ with contextlib.redirect_stderr(StringIO()) as temp_stderr:
+ dag_command.dag_list_dags(args)
+ out = temp_stderr.getvalue()
+
assert "Failed to load all files." in out
@conf_vars({("core", "load_examples"): "true"})
@@ -305,7 +319,9 @@ class TestCliDags:
@conf_vars({("core", "load_examples"): "false"})
def test_cli_list_import_errors(self):
dag_path = os.path.join(TEST_DAGS_FOLDER, "test_invalid_cron.py")
- args = self.parser.parse_args(["dags", "list", "--output", "yaml",
"--subdir", dag_path])
+ args = self.parser.parse_args(
+ ["dags", "list-import-errors", "--output", "yaml", "--subdir",
dag_path]
+ )
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
with pytest.raises(SystemExit) as err_ctx:
dag_command.dag_list_import_errors(args)
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index de7e409fdc7..f4bb4da213a 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -1423,6 +1423,25 @@ def get_test_dag():
dagbag = DagBag(dag_folder=dag_file, include_examples=False)
dag = dagbag.get_dag(dag_id)
+
+ if dagbag.import_errors:
+ session = settings.Session()
+ from airflow.models.errors import ParseImportError
+ from airflow.utils import timezone
+
+ # Add the new import errors
+ for _filename, stacktrace in dagbag.import_errors.items():
+ session.add(
+ ParseImportError(
+ filename=str(dag_file),
+ bundle_name="testing",
+ timestamp=timezone.utcnow(),
+ stacktrace=stacktrace,
+ )
+ )
+
+ return
+
if AIRFLOW_V_3_0_PLUS:
session = settings.Session()
from airflow.models.dagbundle import DagBundleModel