This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 22da95984a3 Add bundle name arg to list dags cli command (#45779)
22da95984a3 is described below

commit 22da95984a3a8ac3282cf78a99af0b03c5d87c88
Author: ambikagarg <[email protected]>
AuthorDate: Thu Feb 27 21:22:03 2025 -0500

    Add bundle name arg to list dags cli command (#45779)
    
    Co-authored-by: Jed Cunningham <[email protected]>
---
 airflow/api_connexion/schemas/dag_schema.py        |  2 ++
 airflow/cli/cli_config.py                          |  5 +--
 .../cli/commands/remote_commands/dag_command.py    | 38 ++++++++++++++++++----
 .../commands/remote_commands/test_dag_command.py   | 34 ++++++++++++++-----
 tests_common/pytest_plugin.py                      | 19 +++++++++++
 5 files changed, 81 insertions(+), 17 deletions(-)

diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py
index 9f75f4dad52..38a062b1b28 100644
--- a/airflow/api_connexion/schemas/dag_schema.py
+++ b/airflow/api_connexion/schemas/dag_schema.py
@@ -51,6 +51,8 @@ class DAGSchema(SQLAlchemySchema):
 
     dag_id = auto_field(dump_only=True)
     dag_display_name = fields.String(attribute="dag_display_name", 
dump_only=True)
+    bundle_name = auto_field(dump_only=True)
+    bundle_version = auto_field(dump_only=True)
     is_paused = auto_field()
     is_active = auto_field(dump_only=True)
     last_parsed_time = auto_field(dump_only=True)
diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index 84f9ecea169..32a9bb1a382 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -172,6 +172,7 @@ ARG_BUNDLE_NAME = Arg(
         "--bundle-name",
     ),
     help=("The name of the DAG bundle to use; may be provided more than once"),
+    type=str,
     default=None,
     action="append",
 )
@@ -880,7 +881,7 @@ ARG_DAG_LIST_COLUMNS = Arg(
     ("--columns",),
     type=string_list_type,
     help="List of columns to render. (default: ['dag_id', 'fileloc', 'owner', 
'is_paused'])",
-    default=("dag_id", "fileloc", "owners", "is_paused"),
+    default=("dag_id", "fileloc", "owners", "is_paused", "bundle_name", 
"bundle_version"),
 )
 
 ARG_ASSET_LIST_COLUMNS = Arg(
@@ -978,7 +979,7 @@ DAGS_COMMANDS = (
         name="list",
         help="List all the DAGs",
         
func=lazy_load_command("airflow.cli.commands.remote_commands.dag_command.dag_list_dags"),
-        args=(ARG_SUBDIR, ARG_OUTPUT, ARG_VERBOSE, ARG_DAG_LIST_COLUMNS),
+        args=(ARG_OUTPUT, ARG_VERBOSE, ARG_DAG_LIST_COLUMNS, ARG_BUNDLE_NAME),
     ),
     ActionCommand(
         name="list-import-errors",
diff --git a/airflow/cli/commands/remote_commands/dag_command.py b/airflow/cli/commands/remote_commands/dag_command.py
index b478798c24a..03fb97aba39 100644
--- a/airflow/cli/commands/remote_commands/dag_command.py
+++ b/airflow/cli/commands/remote_commands/dag_command.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 """Dag sub-commands."""
 
 from __future__ import annotations
@@ -28,7 +29,7 @@ import sys
 from typing import TYPE_CHECKING
 
 import re2
-from sqlalchemy import select
+from sqlalchemy import func, select
 
 from airflow.api.client import get_current_api_client
 from airflow.api_connexion.schemas.dag_schema import dag_schema
@@ -38,6 +39,7 @@ from airflow.dag_processing.bundles.manager import 
DagBundlesManager
 from airflow.exceptions import AirflowException
 from airflow.jobs.job import Job
 from airflow.models import DagBag, DagModel, DagRun, TaskInstance
+from airflow.models.errors import ParseImportError
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.sdk.definitions._internal.dag_parsing_context import 
_airflow_parsing_context_manager
 from airflow.utils import cli as cli_utils, timezone
@@ -224,6 +226,8 @@ def _get_dagbag_dag_details(dag: DAG) -> dict:
     return {
         "dag_id": dag.dag_id,
         "dag_display_name": dag.dag_display_name,
+        "bundle_name": dag.get_bundle_name(),
+        "bundle_version": dag.get_bundle_version(),
         "is_paused": dag.get_is_paused(),
         "is_active": dag.get_is_active(),
         "last_parsed_time": None,
@@ -322,11 +326,12 @@ def dag_next_execution(args) -> None:
 @suppress_logs_and_warning
 @providers_configuration_loaded
 @provide_session
-def dag_list_dags(args, session=NEW_SESSION) -> None:
+def dag_list_dags(args, session: Session = NEW_SESSION) -> None:
     """Display dags with or without stats at the command line."""
     cols = args.columns if args.columns else []
     invalid_cols = [c for c in cols if c not in dag_schema.fields]
     valid_cols = [c for c in cols if c in dag_schema.fields]
+
     if invalid_cols:
         from rich import print as rich_print
 
@@ -335,8 +340,18 @@ def dag_list_dags(args, session=NEW_SESSION) -> None:
             f"List of valid columns: {list(dag_schema.fields.keys())}",
             file=sys.stderr,
         )
-    dagbag = DagBag(process_subdir(args.subdir))
-    if dagbag.import_errors:
+
+    dagbag = DagBag(read_dags_from_db=True)
+    dagbag.collect_dags_from_db()
+
+    # Get import errors from the DB
+    query = select(func.count()).select_from(ParseImportError)
+    if args.bundle_name:
+        query = query.where(ParseImportError.bundle_name.in_(args.bundle_name))
+
+    dagbag_import_errors = session.scalar(query)
+
+    if dagbag_import_errors > 0:
         from rich import print as rich_print
 
         rich_print(
@@ -353,8 +368,19 @@ def dag_list_dags(args, session=NEW_SESSION) -> None:
             dag_detail = _get_dagbag_dag_details(dag)
         return {col: dag_detail[col] for col in valid_cols}
 
+    def filter_dags_by_bundle(dags: list[DAG], bundle_names: list[str] | None) 
-> list[DAG]:
+        """Filter DAGs based on the specified bundle name, if provided."""
+        if not bundle_names:
+            return dags
+
+        validate_dag_bundle_arg(bundle_names)
+        return [dag for dag in dags if dag.get_bundle_name() in bundle_names]
+
     AirflowConsole().print_as(
-        data=sorted(dagbag.dags.values(), key=operator.attrgetter("dag_id")),
+        data=sorted(
+            filter_dags_by_bundle(list(dagbag.dags.values()), 
args.bundle_name),
+            key=operator.attrgetter("dag_id"),
+        ),
         output=args.output,
         mapper=get_dag_detail,
     )
@@ -364,7 +390,7 @@ def dag_list_dags(args, session=NEW_SESSION) -> None:
 @suppress_logs_and_warning
 @providers_configuration_loaded
 @provide_session
-def dag_details(args, session=NEW_SESSION):
+def dag_details(args, session: Session = NEW_SESSION):
     """Get DAG details given a DAG id."""
     dag = DagModel.get_dagmodel(args.dag_id, session=session)
     if not dag:
diff --git a/tests/cli/commands/remote_commands/test_dag_command.py b/tests/cli/commands/remote_commands/test_dag_command.py
index 051161bbdfa..5a5e16e0504 100644
--- a/tests/cli/commands/remote_commands/test_dag_command.py
+++ b/tests/cli/commands/remote_commands/test_dag_command.py
@@ -50,7 +50,12 @@ from airflow.utils.types import DagRunType
 
 from tests.models import TEST_DAGS_FOLDER
 from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.db import clear_db_dags, clear_db_runs, 
parse_and_sync_to_db
+from tests_common.test_utils.db import (
+    clear_db_dags,
+    clear_db_import_errors,
+    clear_db_runs,
+    parse_and_sync_to_db,
+)
 
 DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc)
 if pendulum.__version__.startswith("3"):
@@ -77,7 +82,11 @@ class TestCliDags:
         clear_db_dags()
 
     def setup_method(self):
-        clear_db_runs()  # clean-up all dag run before start each test
+        clear_db_runs()
+        clear_db_import_errors()
+
+    def teardown_method(self):
+        clear_db_import_errors()
 
     def test_show_dag_dependencies_print(self):
         with contextlib.redirect_stdout(StringIO()) as temp_stdout:
@@ -274,12 +283,17 @@ class TestCliDags:
         assert "Ignoring the following invalid columns: ['invalid_col']" in out
 
     @conf_vars({("core", "load_examples"): "false"})
-    def test_cli_list_dags_prints_import_errors(self):
-        dag_path = os.path.join(TEST_DAGS_FOLDER, "test_invalid_cron.py")
-        args = self.parser.parse_args(["dags", "list", "--output", "yaml", 
"--subdir", dag_path])
-        with contextlib.redirect_stderr(StringIO()) as temp_stderr:
-            dag_command.dag_list_dags(args)
-            out = temp_stderr.getvalue()
+    def test_cli_list_dags_prints_import_errors(self, 
configure_testing_dag_bundle, get_test_dag):
+        path_to_parse = TEST_DAGS_FOLDER / "test_invalid_cron.py"
+        get_test_dag("test_invalid_cron")
+
+        args = self.parser.parse_args(["dags", "list", "--output", "yaml", 
"--bundle-name", "testing"])
+
+        with configure_testing_dag_bundle(path_to_parse):
+            with contextlib.redirect_stderr(StringIO()) as temp_stderr:
+                dag_command.dag_list_dags(args)
+                out = temp_stderr.getvalue()
+
         assert "Failed to load all files." in out
 
     @conf_vars({("core", "load_examples"): "true"})
@@ -305,7 +319,9 @@ class TestCliDags:
     @conf_vars({("core", "load_examples"): "false"})
     def test_cli_list_import_errors(self):
         dag_path = os.path.join(TEST_DAGS_FOLDER, "test_invalid_cron.py")
-        args = self.parser.parse_args(["dags", "list", "--output", "yaml", 
"--subdir", dag_path])
+        args = self.parser.parse_args(
+            ["dags", "list-import-errors", "--output", "yaml", "--subdir", 
dag_path]
+        )
         with contextlib.redirect_stdout(StringIO()) as temp_stdout:
             with pytest.raises(SystemExit) as err_ctx:
                 dag_command.dag_list_import_errors(args)
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index de7e409fdc7..f4bb4da213a 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -1423,6 +1423,25 @@ def get_test_dag():
         dagbag = DagBag(dag_folder=dag_file, include_examples=False)
 
         dag = dagbag.get_dag(dag_id)
+
+        if dagbag.import_errors:
+            session = settings.Session()
+            from airflow.models.errors import ParseImportError
+            from airflow.utils import timezone
+
+            # Add the new import errors
+            for _filename, stacktrace in dagbag.import_errors.items():
+                session.add(
+                    ParseImportError(
+                        filename=str(dag_file),
+                        bundle_name="testing",
+                        timestamp=timezone.utcnow(),
+                        stacktrace=stacktrace,
+                    )
+                )
+
+            return
+
         if AIRFLOW_V_3_0_PLUS:
             session = settings.Session()
             from airflow.models.dagbundle import DagBundleModel

Reply via email to