This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9095e682f7 Improve the views module a bit (#31661)
9095e682f7 is described below

commit 9095e682f7efb1341377481c9f6def38411135a3
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Jun 5 14:21:12 2023 +0800

    Improve the views module a bit (#31661)
    
    Co-authored-by: Jed Cunningham <[email protected]>
---
 airflow/www/views.py          | 229 ++++++++++++++++++++----------------------
 tests/www/views/test_views.py |  15 ---
 2 files changed, 107 insertions(+), 137 deletions(-)

diff --git a/airflow/www/views.py b/airflow/www/views.py
index 592812e841..3e5cddb3a5 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -38,7 +38,6 @@ from urllib.parse import unquote, urljoin, urlsplit
 import configupdater
 import flask.json
 import lazy_object_proxy
-import markupsafe
 import nvd3
 import sqlalchemy as sqla
 from croniter import croniter
@@ -144,14 +143,6 @@ LINECHART_X_AXIS_TICKFORMAT = (
 )
 
 
-def truncate_task_duration(task_duration):
-    """
-    Cast the task_duration to an int was for optimization for large/huge dags if task_duration > 10s
-    otherwise we keep it as a float with 3dp.
-    """
-    return int(task_duration) if task_duration > 10.0 else round(task_duration, 3)
-
-
 def sanitize_args(args: dict[str, str]) -> dict[str, str]:
     """
     Remove all parameters starting with `_`.
@@ -769,9 +760,9 @@ class Airflow(AirflowBaseView):
             failed_dags = dags_query.join(subq_join, DagModel.dag_id == 
subq_join.c.dag_id)
 
             is_paused_count = dict(
-                all_dags.with_entities(DagModel.is_paused, 
func.count(DagModel.dag_id))
-                .group_by(DagModel.is_paused)
-                .all()
+                all_dags.with_entities(DagModel.is_paused, 
func.count(DagModel.dag_id)).group_by(
+                    DagModel.is_paused
+                )
             )
 
             status_count_active = is_paused_count.get(False, 0)
@@ -865,7 +856,7 @@ class Airflow(AirflowBaseView):
 
             for import_error in import_errors:
                 flash(
-                    "Broken DAG: [{ie.filename}] {ie.stacktrace}".format(ie=import_error),
+                    f"Broken DAG: [{import_error.filename}] {import_error.stacktrace}",
                     "dag_import_error",
                 )
 
@@ -1010,9 +1001,11 @@ class Airflow(AirflowBaseView):
 
         dataset_triggered_dag_ids = [
             dag.dag_id
-            for dag in session.query(DagModel.dag_id)
-            .filter(DagModel.dag_id.in_(filter_dag_ids))
-            .filter(DagModel.schedule_interval == "Dataset")
+            for dag in (
+                session.query(DagModel.dag_id)
+                .filter(DagModel.dag_id.in_(filter_dag_ids))
+                .filter(DagModel.schedule_interval == "Dataset")
+            )
         ]
 
         dataset_triggered_next_run_info = get_dataset_triggered_next_run_info(
@@ -1031,38 +1024,31 @@ class Airflow(AirflowBaseView):
     @provide_session
     def dag_stats(self, session: Session = NEW_SESSION):
         """Dag statistics."""
-        dr = models.DagRun
-
         allowed_dag_ids = 
get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
 
-        dag_state_stats = session.query(dr.dag_id, dr.state, 
sqla.func.count(dr.state)).group_by(
-            dr.dag_id, dr.state
-        )
-
         # Filter by post parameters
         selected_dag_ids = {unquote(dag_id) for dag_id in 
request.form.getlist("dag_ids") if dag_id}
-
         if selected_dag_ids:
             filter_dag_ids = selected_dag_ids.intersection(allowed_dag_ids)
         else:
             filter_dag_ids = allowed_dag_ids
-
         if not filter_dag_ids:
             return flask.json.jsonify({})
 
-        dag_state_stats = dag_state_stats.filter(dr.dag_id.in_(filter_dag_ids))
-        data: dict[str, dict[str, int]] = collections.defaultdict(dict)
-        payload: dict[str, list[dict[str, Any]]] = 
collections.defaultdict(list)
-
-        for dag_id, state, count in dag_state_stats:
-            data[dag_id][state] = count
-
-        for dag_id in filter_dag_ids:
-            payload[dag_id] = []
-            for state in State.dag_states:
-                count = data.get(dag_id, {}).get(state, 0)
-                payload[dag_id].append({"state": state, "count": count})
+        dag_state_stats = (
+            session.query(DagRun.dag_id, DagRun.state, 
sqla.func.count(DagRun.state))
+            .group_by(DagRun.dag_id, DagRun.state)
+            .filter(DagRun.dag_id.in_(filter_dag_ids))
+        )
+        dag_state_data = {(dag_id, state): count for dag_id, state, count in 
dag_state_stats}
 
+        payload = {
+            dag_id: [
+                {"state": state, "count": dag_state_data.get((dag_id, state), 
0)}
+                for state in State.dag_states
+            ]
+            for dag_id in filter_dag_ids
+        }
         return flask.json.jsonify(payload)
 
     @expose("/task_stats", methods=["POST"])
@@ -1415,12 +1401,13 @@ class Airflow(AirflowBaseView):
         try:
             ti.get_rendered_template_fields(session=session)
         except AirflowException as e:
-            msg = "Error rendering template: " + escape(e)
-            if e.__cause__:
-                msg += Markup("<br><br>OriginalError: ") + escape(e.__cause__)
-            flash(msg, "error")
+            if not e.__cause__:
+                flash(f"Error rendering template: {e}", "error")
+            else:
+                msg = Markup("Error rendering template: {0}<br><br>OriginalError: {0.__cause__}").format(e)
+                flash(msg, "error")
         except Exception as e:
-            flash("Error rendering template: " + str(e), "error")
+            flash(f"Error rendering template: {e}", "error")
 
         # Ensure we are rendering the unmapped operator. Unmapping should be
         # done automatically if template fields are rendered successfully; this
@@ -1509,25 +1496,23 @@ class Airflow(AirflowBaseView):
         try:
             pod_spec = ti.get_rendered_k8s_spec(session=session)
         except AirflowException as e:
-            msg = "Error rendering Kubernetes POD Spec: " + escape(e)
-            if e.__cause__:
-                msg += Markup("<br><br>OriginalError: ") + escape(e.__cause__)
-            flash(msg, "error")
+            if not e.__cause__:
+                flash(f"Error rendering Kubernetes POD Spec: {e}", "error")
+            else:
+                tmp = Markup("Error rendering Kubernetes POD Spec: {0}<br><br>Original error: {0.__cause__}")
+                flash(tmp.format(e), "error")
         except Exception as e:
-            flash("Error rendering Kubernetes Pod Spec: " + str(e), "error")
+            flash(f"Error rendering Kubernetes Pod Spec: {e}", "error")
         title = "Rendered K8s Pod Spec"
-        html_dict = {}
-        renderers = wwwutils.get_attr_renderer()
+
         if pod_spec:
-            content = yaml.dump(pod_spec)
-            content = renderers["yaml"](content)
+            content = wwwutils.get_attr_renderer()["yaml"](yaml.dump(pod_spec))
         else:
             content = Markup("<pre><code>Error rendering Kubernetes POD Spec</pre></code>")
-        html_dict["k8s"] = content
 
         return self.render_template(
             "airflow/ti_code.html",
-            html_dict=html_dict,
+            html_dict={"k8s": content},
             dag=dag,
             task_id=task_id,
             execution_date=execution_date,
@@ -1872,16 +1857,14 @@ class Airflow(AirflowBaseView):
             flash(f"Task [{dag_id}.{task_id}] doesn't seem to exist at the moment", "error")
             return redirect(url_for("Airflow.index"))
 
-        xcomlist = (
-            session.query(XCom)
-            .filter_by(dag_id=dag_id, task_id=task_id, execution_date=dttm, 
map_index=map_index)
-            .all()
+        xcom_query = session.query(XCom.key, XCom.value).filter(
+            XCom.dag_id == dag_id,
+            XCom.task_id == task_id,
+            XCom.execution_date == dttm,
+            XCom.map_index == map_index,
+            XCom.key.not_like("_%"),
         )
-
-        attributes = []
-        for xcom in xcomlist:
-            if not xcom.key.startswith("_"):
-                attributes.append((xcom.key, xcom.value))
+        attributes = [tuple(row) for row in xcom_query]
 
         title = "XCom"
         return self.render_template(
@@ -1999,14 +1982,12 @@ class Airflow(AirflowBaseView):
             .group_by(DagRun.conf)
             .order_by(func.max(DagRun.execution_date).desc())
             .limit(5)
-            .all()
         )
-
-        recent_confs = {}
-        for run in recent_runs:
-            recent_conf = getattr(run, "conf")
-            if isinstance(recent_conf, dict) and any(recent_conf):
-                recent_confs[getattr(run, "run_id")] = json.dumps(recent_conf)
+        recent_confs = {
+            run_id: json.dumps(run_conf)
+            for run_id, run_conf in ((run.run_id, run.conf) for run in 
recent_runs)
+            if isinstance(run_conf, dict) and any(run_conf)
+        }
 
         if request.method == "GET" and ui_fields_defined:
             # Populate conf textarea with conf requests parameter, or 
dag.params
@@ -2351,7 +2332,7 @@ class Airflow(AirflowBaseView):
 
         dags = (
             session.query(DagRun.dag_id, sqla.func.count(DagRun.id))
-            .filter(DagRun.state == State.RUNNING)
+            .filter(DagRun.state == DagRunState.RUNNING)
             .filter(DagRun.dag_id.in_(filter_dag_ids))
             .group_by(DagRun.dag_id)
         )
@@ -3272,9 +3253,9 @@ class Airflow(AirflowBaseView):
         cum_chart.buildcontent()
         s_index = cum_chart.htmlcontent.rfind("});")
         cum_chart.htmlcontent = (
-            cum_chart.htmlcontent[:s_index]
-            + "$( document ).trigger('chartload')"
-            + cum_chart.htmlcontent[s_index:]
+            f"{cum_chart.htmlcontent[:s_index]}"
+            "$( document ).trigger('chartload')"
+            f"{cum_chart.htmlcontent[s_index:]}"
         )
 
         return self.render_template(
@@ -3482,7 +3463,7 @@ class Airflow(AirflowBaseView):
             "airflow/chart.html",
             dag=dag,
             chart=Markup(chart.htmlcontent),
-            height=str(chart_height + 100) + "px",
+            height=f"{chart_height + 100}px",
             root=root,
             form=form,
             tab_title="Landing times",
@@ -3552,8 +3533,8 @@ class Airflow(AirflowBaseView):
             .filter(
                 TaskInstance.dag_id == dag_id,
                 TaskInstance.run_id == dag_run_id,
-                TaskInstance.start_date.isnot(None),
-                TaskInstance.state.isnot(None),
+                TaskInstance.start_date.is_not(None),
+                TaskInstance.state.is_not(None),
             )
             .order_by(TaskInstance.start_date)
         )
@@ -3595,7 +3576,7 @@ class Airflow(AirflowBaseView):
             end_date = task_dict["end_date"] or timezone.utcnow()
             task_dict["end_date"] = end_date
             task_dict["start_date"] = task_dict["start_date"] or end_date
-            task_dict["state"] = State.FAILED
+            task_dict["state"] = TaskInstanceState.FAILED
             task_dict["operator"] = task.operator_name
             task_dict["try_number"] = try_count
             task_dict["extraLinks"] = task.extra_links
@@ -3905,7 +3886,6 @@ class Airflow(AirflowBaseView):
                 .filter(DagScheduleDatasetReference.dag_id == dag_id, 
~DatasetModel.is_orphaned)
                 .group_by(DatasetModel.id, DatasetModel.uri)
                 .order_by(DatasetModel.uri)
-                .all()
             ]
         return (
             htmlsafe_json_dumps(data, separators=(",", ":"), 
dumps=flask.json.dumps),
@@ -3983,12 +3963,12 @@ class Airflow(AirflowBaseView):
 
         with create_session() as session:
             if lstripped_orderby == "uri":
-                if order_by[0] == "-":
+                if order_by.startswith("-"):
                     order_by = (DatasetModel.uri.desc(),)
                 else:
                     order_by = (DatasetModel.uri.asc(),)
             elif lstripped_orderby == "last_dataset_update":
-                if order_by[0] == "-":
+                if order_by.startswith("-"):
                     order_by = (
                         func.max(DatasetEvent.timestamp).desc(),
                         DatasetModel.uri.asc(),
@@ -4033,12 +4013,10 @@ class Airflow(AirflowBaseView):
             if updated_before:
                 filters.append(DatasetEvent.timestamp <= updated_before)
 
-            query = query.filter(*filters)
+            query = query.filter(*filters).offset(offset).limit(limit)
             count_query = count_query.filter(*filters)
 
-            query = query.offset(offset).limit(limit)
-
-            datasets = [dict(dataset) for dataset in query.all()]
+            datasets = [dict(dataset) for dataset in query]
             data = {"datasets": datasets, "total_entries": 
count_query.scalar()}
 
             return (
@@ -4360,7 +4338,7 @@ class SlaMissModelView(AirflowModelView):
     edit_columns = ["dag_id", "task_id", "execution_date", "email_sent", 
"notification_sent", "timestamp"]
     search_columns = ["dag_id", "task_id", "email_sent", "notification_sent", 
"timestamp", "execution_date"]
     base_order = ("execution_date", "desc")
-    base_filters = [["dag_id", DagFilter, lambda: []]]
+    base_filters = [["dag_id", DagFilter, list]]
 
     formatters_columns = {
         "task_id": wwwutils.task_instance_link,
@@ -4462,7 +4440,7 @@ class XComModelView(AirflowModelView):
     list_columns = ["key", "value", "timestamp", "dag_id", "task_id", 
"run_id", "map_index", "execution_date"]
     base_order = ("dag_run_id", "desc")
 
-    base_filters = [["dag_id", DagFilter, lambda: []]]
+    base_filters = [["dag_id", DagFilter, list]]
 
     formatters_columns = {
         "task_id": wwwutils.task_instance_link,
@@ -4909,12 +4887,12 @@ class ProviderView(AirflowBaseView):
         def _build_link(match_obj):
             text = match_obj.group(1)
             url = match_obj.group(2)
-            return markupsafe.Markup(f'<a href="{url}">{text}</a>')
+            return Markup(f'<a href="{url}">{text}</a>')
 
-        cd = markupsafe.escape(description)
+        cd = escape(description)
         cd = re.sub(r"`(.*)[\s+]+&lt;(.*)&gt;`__", _build_link, cd)
         cd = re.sub(r"\n", r"<br>", cd)
-        return markupsafe.Markup(cd)
+        return Markup(cd)
 
 
 class PoolModelView(AirflowModelView):
@@ -5203,7 +5181,7 @@ class JobModelView(AirflowModelView):
 
     base_order = ("start_date", "desc")
 
-    base_filters = [["dag_id", DagFilter, lambda: []]]
+    base_filters = [["dag_id", DagFilter, list]]
 
     formatters_columns = {
         "start_date": wwwutils.datetime_f("start_date"),
@@ -5295,7 +5273,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
 
     base_order = ("execution_date", "desc")
 
-    base_filters = [["dag_id", DagFilter, lambda: []]]
+    base_filters = [["dag_id", DagFilter, list]]
 
     edit_form = DagRunEditForm
 
@@ -5350,7 +5328,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
         """This routine only supports Running and Queued state."""
         try:
             count = 0
-            for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id 
for dagrun in drs])):
+            for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for 
dagrun in drs)):
                 count += 1
                 if state == State.RUNNING:
                     dr.start_date = timezone.utcnow()
@@ -5376,7 +5354,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
         try:
             count = 0
             altered_tis = []
-            for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id 
for dagrun in drs])).all():
+            for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for 
dagrun in drs)):
                 count += 1
                 altered_tis += set_dag_run_state_to_failed(
                     dag=get_airflow_app().dag_bag.get_dag(dr.dag_id),
@@ -5404,7 +5382,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
         try:
             count = 0
             altered_tis = []
-            for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id 
for dagrun in drs])).all():
+            for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for 
dagrun in drs)):
                 count += 1
                 altered_tis += set_dag_run_state_to_success(
                     dag=get_airflow_app().dag_bag.get_dag(dr.dag_id),
@@ -5428,7 +5406,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
             count = 0
             cleared_ti_count = 0
             dag_to_tis: dict[DAG, list[TaskInstance]] = {}
-            for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id 
for dagrun in drs])).all():
+            for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for 
dagrun in drs)):
                 count += 1
                 dag = get_airflow_app().dag_bag.get_dag(dr.dag_id)
                 tis_to_clear = dag_to_tis.setdefault(dag, [])
@@ -5469,7 +5447,7 @@ class LogModelView(AirflowModelView):
 
     base_order = ("dttm", "desc")
 
-    base_filters = [["dag_id", DagFilter, lambda: []]]
+    base_filters = [["dag_id", DagFilter, list]]
 
     formatters_columns = {
         "dttm": wwwutils.datetime_f("dttm"),
@@ -5526,7 +5504,7 @@ class TaskRescheduleModelView(AirflowModelView):
 
     base_order = ("id", "desc")
 
-    base_filters = [["dag_id", DagFilter, lambda: []]]
+    base_filters = [["dag_id", DagFilter, list]]
 
     def duration_f(self):
         """Duration calculation."""
@@ -5579,8 +5557,6 @@ class TriggerModelView(AirflowModelView):
         "triggerer_id",
     ]
 
-    # add_exclude_columns = ["kwargs"]
-
     base_order = ("id", "created_date")
 
     formatters_columns = {
@@ -5642,15 +5618,25 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
     ]
 
     order_columns = [
-        item
-        for item in list_columns
-        if item
-        not in [
-            "try_number",
-            "log_url",
-            "external_executor_id",
-            "note",  # todo: maybe figure out how to re-enable this
-        ]
+        "state",
+        "dag_id",
+        "task_id",
+        "run_id",
+        "map_index",
+        "dag_run.execution_date",
+        "operator",
+        "start_date",
+        "end_date",
+        "duration",
+        # "note",  # TODO: Maybe figure out how to re-enable this.
+        "job_id",
+        "hostname",
+        "unixname",
+        "priority_weight",
+        "queue",
+        "queued_dttm",
+        "pool",
+        "queued_by_job_id",
     ]
 
     label_columns = {
@@ -5693,7 +5679,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
 
     base_order = ("job_id", "asc")
 
-    base_filters = [["dag_id", DagFilter, lambda: []]]
+    base_filters = [["dag_id", DagFilter, list]]
 
     def log_url_formatter(self):
         """Formats log URL."""
@@ -5839,7 +5825,7 @@ class AutocompleteView(AirflowBaseView):
         dag_ids_query = session.query(
             sqla.literal("dag").label("type"),
             DagModel.dag_id.label("name"),
-        ).filter(~DagModel.is_subdag, DagModel.is_active, 
DagModel.dag_id.ilike("%" + query + "%"))
+        ).filter(~DagModel.is_subdag, DagModel.is_active, 
DagModel.dag_id.ilike(f"%{query}%"))
 
         owners_query = (
             session.query(
@@ -5847,7 +5833,7 @@ class AutocompleteView(AirflowBaseView):
                 DagModel.owners.label("name"),
             )
             .distinct()
-            .filter(~DagModel.is_subdag, DagModel.is_active, 
DagModel.owners.ilike("%" + query + "%"))
+            .filter(~DagModel.is_subdag, DagModel.is_active, 
DagModel.owners.ilike(f"%{query}%"))
         )
 
         # Hide DAGs if not showing status: "all"
@@ -5864,9 +5850,7 @@ class AutocompleteView(AirflowBaseView):
         dag_ids_query = 
dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids))
         owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids))
 
-        payload = [
-            row._asdict() for row in 
dag_ids_query.union(owners_query).order_by("name").limit(10).all()
-        ]
+        payload = [row._asdict() for row in 
dag_ids_query.union(owners_query).order_by("name").limit(10)]
         return flask.json.jsonify(payload)
 
 
@@ -5942,16 +5926,17 @@ def add_user_permissions_to_dag(sender, template, context, **extra):
     Located in `views.py` rather than the DAG model to keep
     permissions logic out of the Airflow core.
     """
-    if "dag" in context:
-        dag = context["dag"]
-        can_create_dag_run = get_airflow_app().appbuilder.sm.has_access(
-            permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN
-        )
+    if "dag" not in context:
+        return
+    dag = context["dag"]
+    can_create_dag_run = get_airflow_app().appbuilder.sm.has_access(
+        permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN
+    )
 
-        dag.can_edit = get_airflow_app().appbuilder.sm.can_edit_dag(dag.dag_id)
-        dag.can_trigger = dag.can_edit and can_create_dag_run
-        dag.can_delete = 
get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id)
-        context["dag"] = dag
+    dag.can_edit = get_airflow_app().appbuilder.sm.can_edit_dag(dag.dag_id)
+    dag.can_trigger = dag.can_edit and can_create_dag_run
+    dag.can_delete = get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id)
+    context["dag"] = dag
 
 
 # NOTE: Put this at the end of the file. Pylance is too clever and detects that
diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py
index d00a7aa752..0636ba16c7 100644
--- a/tests/www/views/test_views.py
+++ b/tests/www/views/test_views.py
@@ -33,7 +33,6 @@ from airflow.www.views import (
     get_safe_url,
     get_task_stats_from_query,
     get_value_from_path,
-    truncate_task_duration,
 )
 from tests.test_utils.config import conf_vars
 from tests.test_utils.mock_plugins import mock_plugin_manager
@@ -222,20 +221,6 @@ def test_get_safe_url(mock_url_for, app, test_url, expected_url):
         assert get_safe_url(test_url) == expected_url
 
 
-@pytest.mark.parametrize(
-    "test_duration, expected_duration",
-    [
-        (0.12345, 0.123),
-        (0.12355, 0.124),
-        (3.12, 3.12),
-        (9.99999, 10.0),
-        (10.01232, 10),
-    ],
-)
-def test_truncate_task_duration(test_duration, expected_duration):
-    assert truncate_task_duration(test_duration) == expected_duration
-
-
 @pytest.fixture
 def test_app():
     from airflow.www import app

Reply via email to