This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9095e682f7 Improve the views module a bit (#31661)
9095e682f7 is described below
commit 9095e682f7efb1341377481c9f6def38411135a3
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Jun 5 14:21:12 2023 +0800
Improve the views module a bit (#31661)
Co-authored-by: Jed Cunningham
<[email protected]>
---
airflow/www/views.py | 229 ++++++++++++++++++++----------------------
tests/www/views/test_views.py | 15 ---
2 files changed, 107 insertions(+), 137 deletions(-)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 592812e841..3e5cddb3a5 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -38,7 +38,6 @@ from urllib.parse import unquote, urljoin, urlsplit
import configupdater
import flask.json
import lazy_object_proxy
-import markupsafe
import nvd3
import sqlalchemy as sqla
from croniter import croniter
@@ -144,14 +143,6 @@ LINECHART_X_AXIS_TICKFORMAT = (
)
-def truncate_task_duration(task_duration):
- """
- Cast the task_duration to an int was for optimization for large/huge dags
if task_duration > 10s
- otherwise we keep it as a float with 3dp.
- """
- return int(task_duration) if task_duration > 10.0 else
round(task_duration, 3)
-
-
def sanitize_args(args: dict[str, str]) -> dict[str, str]:
"""
Remove all parameters starting with `_`.
@@ -769,9 +760,9 @@ class Airflow(AirflowBaseView):
failed_dags = dags_query.join(subq_join, DagModel.dag_id ==
subq_join.c.dag_id)
is_paused_count = dict(
- all_dags.with_entities(DagModel.is_paused,
func.count(DagModel.dag_id))
- .group_by(DagModel.is_paused)
- .all()
+ all_dags.with_entities(DagModel.is_paused,
func.count(DagModel.dag_id)).group_by(
+ DagModel.is_paused
+ )
)
status_count_active = is_paused_count.get(False, 0)
@@ -865,7 +856,7 @@ class Airflow(AirflowBaseView):
for import_error in import_errors:
flash(
- "Broken DAG: [{ie.filename}]
{ie.stacktrace}".format(ie=import_error),
+ f"Broken DAG: [{import_error.filename}]
{import_error.stacktrace}",
"dag_import_error",
)
@@ -1010,9 +1001,11 @@ class Airflow(AirflowBaseView):
dataset_triggered_dag_ids = [
dag.dag_id
- for dag in session.query(DagModel.dag_id)
- .filter(DagModel.dag_id.in_(filter_dag_ids))
- .filter(DagModel.schedule_interval == "Dataset")
+ for dag in (
+ session.query(DagModel.dag_id)
+ .filter(DagModel.dag_id.in_(filter_dag_ids))
+ .filter(DagModel.schedule_interval == "Dataset")
+ )
]
dataset_triggered_next_run_info = get_dataset_triggered_next_run_info(
@@ -1031,38 +1024,31 @@ class Airflow(AirflowBaseView):
@provide_session
def dag_stats(self, session: Session = NEW_SESSION):
"""Dag statistics."""
- dr = models.DagRun
-
allowed_dag_ids =
get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
- dag_state_stats = session.query(dr.dag_id, dr.state,
sqla.func.count(dr.state)).group_by(
- dr.dag_id, dr.state
- )
-
# Filter by post parameters
selected_dag_ids = {unquote(dag_id) for dag_id in
request.form.getlist("dag_ids") if dag_id}
-
if selected_dag_ids:
filter_dag_ids = selected_dag_ids.intersection(allowed_dag_ids)
else:
filter_dag_ids = allowed_dag_ids
-
if not filter_dag_ids:
return flask.json.jsonify({})
- dag_state_stats = dag_state_stats.filter(dr.dag_id.in_(filter_dag_ids))
- data: dict[str, dict[str, int]] = collections.defaultdict(dict)
- payload: dict[str, list[dict[str, Any]]] =
collections.defaultdict(list)
-
- for dag_id, state, count in dag_state_stats:
- data[dag_id][state] = count
-
- for dag_id in filter_dag_ids:
- payload[dag_id] = []
- for state in State.dag_states:
- count = data.get(dag_id, {}).get(state, 0)
- payload[dag_id].append({"state": state, "count": count})
+ dag_state_stats = (
+ session.query(DagRun.dag_id, DagRun.state,
sqla.func.count(DagRun.state))
+ .group_by(DagRun.dag_id, DagRun.state)
+ .filter(DagRun.dag_id.in_(filter_dag_ids))
+ )
+ dag_state_data = {(dag_id, state): count for dag_id, state, count in
dag_state_stats}
+ payload = {
+ dag_id: [
+ {"state": state, "count": dag_state_data.get((dag_id, state),
0)}
+ for state in State.dag_states
+ ]
+ for dag_id in filter_dag_ids
+ }
return flask.json.jsonify(payload)
@expose("/task_stats", methods=["POST"])
@@ -1415,12 +1401,13 @@ class Airflow(AirflowBaseView):
try:
ti.get_rendered_template_fields(session=session)
except AirflowException as e:
- msg = "Error rendering template: " + escape(e)
- if e.__cause__:
- msg += Markup("<br><br>OriginalError: ") + escape(e.__cause__)
- flash(msg, "error")
+ if not e.__cause__:
+ flash(f"Error rendering template: {e}", "error")
+ else:
+ msg = Markup("Error rendering template:
{0}<br><br>OriginalError: {0.__cause__}").format(e)
+ flash(msg, "error")
except Exception as e:
- flash("Error rendering template: " + str(e), "error")
+ flash(f"Error rendering template: {e}", "error")
# Ensure we are rendering the unmapped operator. Unmapping should be
# done automatically if template fields are rendered successfully; this
@@ -1509,25 +1496,23 @@ class Airflow(AirflowBaseView):
try:
pod_spec = ti.get_rendered_k8s_spec(session=session)
except AirflowException as e:
- msg = "Error rendering Kubernetes POD Spec: " + escape(e)
- if e.__cause__:
- msg += Markup("<br><br>OriginalError: ") + escape(e.__cause__)
- flash(msg, "error")
+ if not e.__cause__:
+ flash(f"Error rendering Kubernetes POD Spec: {e}", "error")
+ else:
+ tmp = Markup("Error rendering Kubernetes POD Spec:
{0}<br><br>Original error: {0.__cause__}")
+ flash(tmp.format(e), "error")
except Exception as e:
- flash("Error rendering Kubernetes Pod Spec: " + str(e), "error")
+ flash(f"Error rendering Kubernetes Pod Spec: {e}", "error")
title = "Rendered K8s Pod Spec"
- html_dict = {}
- renderers = wwwutils.get_attr_renderer()
+
if pod_spec:
- content = yaml.dump(pod_spec)
- content = renderers["yaml"](content)
+ content = wwwutils.get_attr_renderer()["yaml"](yaml.dump(pod_spec))
else:
content = Markup("<pre><code>Error rendering Kubernetes POD
Spec</pre></code>")
- html_dict["k8s"] = content
return self.render_template(
"airflow/ti_code.html",
- html_dict=html_dict,
+ html_dict={"k8s": content},
dag=dag,
task_id=task_id,
execution_date=execution_date,
@@ -1872,16 +1857,14 @@ class Airflow(AirflowBaseView):
flash(f"Task [{dag_id}.{task_id}] doesn't seem to exist at the
moment", "error")
return redirect(url_for("Airflow.index"))
- xcomlist = (
- session.query(XCom)
- .filter_by(dag_id=dag_id, task_id=task_id, execution_date=dttm,
map_index=map_index)
- .all()
+ xcom_query = session.query(XCom.key, XCom.value).filter(
+ XCom.dag_id == dag_id,
+ XCom.task_id == task_id,
+ XCom.execution_date == dttm,
+ XCom.map_index == map_index,
+ XCom.key.not_like("_%"),
)
-
- attributes = []
- for xcom in xcomlist:
- if not xcom.key.startswith("_"):
- attributes.append((xcom.key, xcom.value))
+ attributes = [tuple(row) for row in xcom_query]
title = "XCom"
return self.render_template(
@@ -1999,14 +1982,12 @@ class Airflow(AirflowBaseView):
.group_by(DagRun.conf)
.order_by(func.max(DagRun.execution_date).desc())
.limit(5)
- .all()
)
-
- recent_confs = {}
- for run in recent_runs:
- recent_conf = getattr(run, "conf")
- if isinstance(recent_conf, dict) and any(recent_conf):
- recent_confs[getattr(run, "run_id")] = json.dumps(recent_conf)
+ recent_confs = {
+ run_id: json.dumps(run_conf)
+ for run_id, run_conf in ((run.run_id, run.conf) for run in
recent_runs)
+ if isinstance(run_conf, dict) and any(run_conf)
+ }
if request.method == "GET" and ui_fields_defined:
# Populate conf textarea with conf requests parameter, or
dag.params
@@ -2351,7 +2332,7 @@ class Airflow(AirflowBaseView):
dags = (
session.query(DagRun.dag_id, sqla.func.count(DagRun.id))
- .filter(DagRun.state == State.RUNNING)
+ .filter(DagRun.state == DagRunState.RUNNING)
.filter(DagRun.dag_id.in_(filter_dag_ids))
.group_by(DagRun.dag_id)
)
@@ -3272,9 +3253,9 @@ class Airflow(AirflowBaseView):
cum_chart.buildcontent()
s_index = cum_chart.htmlcontent.rfind("});")
cum_chart.htmlcontent = (
- cum_chart.htmlcontent[:s_index]
- + "$( document ).trigger('chartload')"
- + cum_chart.htmlcontent[s_index:]
+ f"{cum_chart.htmlcontent[:s_index]}"
+ "$( document ).trigger('chartload')"
+ f"{cum_chart.htmlcontent[s_index:]}"
)
return self.render_template(
@@ -3482,7 +3463,7 @@ class Airflow(AirflowBaseView):
"airflow/chart.html",
dag=dag,
chart=Markup(chart.htmlcontent),
- height=str(chart_height + 100) + "px",
+ height=f"{chart_height + 100}px",
root=root,
form=form,
tab_title="Landing times",
@@ -3552,8 +3533,8 @@ class Airflow(AirflowBaseView):
.filter(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == dag_run_id,
- TaskInstance.start_date.isnot(None),
- TaskInstance.state.isnot(None),
+ TaskInstance.start_date.is_not(None),
+ TaskInstance.state.is_not(None),
)
.order_by(TaskInstance.start_date)
)
@@ -3595,7 +3576,7 @@ class Airflow(AirflowBaseView):
end_date = task_dict["end_date"] or timezone.utcnow()
task_dict["end_date"] = end_date
task_dict["start_date"] = task_dict["start_date"] or end_date
- task_dict["state"] = State.FAILED
+ task_dict["state"] = TaskInstanceState.FAILED
task_dict["operator"] = task.operator_name
task_dict["try_number"] = try_count
task_dict["extraLinks"] = task.extra_links
@@ -3905,7 +3886,6 @@ class Airflow(AirflowBaseView):
.filter(DagScheduleDatasetReference.dag_id == dag_id,
~DatasetModel.is_orphaned)
.group_by(DatasetModel.id, DatasetModel.uri)
.order_by(DatasetModel.uri)
- .all()
]
return (
htmlsafe_json_dumps(data, separators=(",", ":"),
dumps=flask.json.dumps),
@@ -3983,12 +3963,12 @@ class Airflow(AirflowBaseView):
with create_session() as session:
if lstripped_orderby == "uri":
- if order_by[0] == "-":
+ if order_by.startswith("-"):
order_by = (DatasetModel.uri.desc(),)
else:
order_by = (DatasetModel.uri.asc(),)
elif lstripped_orderby == "last_dataset_update":
- if order_by[0] == "-":
+ if order_by.startswith("-"):
order_by = (
func.max(DatasetEvent.timestamp).desc(),
DatasetModel.uri.asc(),
@@ -4033,12 +4013,10 @@ class Airflow(AirflowBaseView):
if updated_before:
filters.append(DatasetEvent.timestamp <= updated_before)
- query = query.filter(*filters)
+ query = query.filter(*filters).offset(offset).limit(limit)
count_query = count_query.filter(*filters)
- query = query.offset(offset).limit(limit)
-
- datasets = [dict(dataset) for dataset in query.all()]
+ datasets = [dict(dataset) for dataset in query]
data = {"datasets": datasets, "total_entries":
count_query.scalar()}
return (
@@ -4360,7 +4338,7 @@ class SlaMissModelView(AirflowModelView):
edit_columns = ["dag_id", "task_id", "execution_date", "email_sent",
"notification_sent", "timestamp"]
search_columns = ["dag_id", "task_id", "email_sent", "notification_sent",
"timestamp", "execution_date"]
base_order = ("execution_date", "desc")
- base_filters = [["dag_id", DagFilter, lambda: []]]
+ base_filters = [["dag_id", DagFilter, list]]
formatters_columns = {
"task_id": wwwutils.task_instance_link,
@@ -4462,7 +4440,7 @@ class XComModelView(AirflowModelView):
list_columns = ["key", "value", "timestamp", "dag_id", "task_id",
"run_id", "map_index", "execution_date"]
base_order = ("dag_run_id", "desc")
- base_filters = [["dag_id", DagFilter, lambda: []]]
+ base_filters = [["dag_id", DagFilter, list]]
formatters_columns = {
"task_id": wwwutils.task_instance_link,
@@ -4909,12 +4887,12 @@ class ProviderView(AirflowBaseView):
def _build_link(match_obj):
text = match_obj.group(1)
url = match_obj.group(2)
- return markupsafe.Markup(f'<a href="{url}">{text}</a>')
+ return Markup(f'<a href="{url}">{text}</a>')
- cd = markupsafe.escape(description)
+ cd = escape(description)
cd = re.sub(r"`(.*)[\s+]+<(.*)>`__", _build_link, cd)
cd = re.sub(r"\n", r"<br>", cd)
- return markupsafe.Markup(cd)
+ return Markup(cd)
class PoolModelView(AirflowModelView):
@@ -5203,7 +5181,7 @@ class JobModelView(AirflowModelView):
base_order = ("start_date", "desc")
- base_filters = [["dag_id", DagFilter, lambda: []]]
+ base_filters = [["dag_id", DagFilter, list]]
formatters_columns = {
"start_date": wwwutils.datetime_f("start_date"),
@@ -5295,7 +5273,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
base_order = ("execution_date", "desc")
- base_filters = [["dag_id", DagFilter, lambda: []]]
+ base_filters = [["dag_id", DagFilter, list]]
edit_form = DagRunEditForm
@@ -5350,7 +5328,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
"""This routine only supports Running and Queued state."""
try:
count = 0
- for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id
for dagrun in drs])):
+ for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for
dagrun in drs)):
count += 1
if state == State.RUNNING:
dr.start_date = timezone.utcnow()
@@ -5376,7 +5354,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
try:
count = 0
altered_tis = []
- for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id
for dagrun in drs])).all():
+ for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for
dagrun in drs)):
count += 1
altered_tis += set_dag_run_state_to_failed(
dag=get_airflow_app().dag_bag.get_dag(dr.dag_id),
@@ -5404,7 +5382,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
try:
count = 0
altered_tis = []
- for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id
for dagrun in drs])).all():
+ for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for
dagrun in drs)):
count += 1
altered_tis += set_dag_run_state_to_success(
dag=get_airflow_app().dag_bag.get_dag(dr.dag_id),
@@ -5428,7 +5406,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
count = 0
cleared_ti_count = 0
dag_to_tis: dict[DAG, list[TaskInstance]] = {}
- for dr in session.query(DagRun).filter(DagRun.id.in_([dagrun.id
for dagrun in drs])).all():
+ for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for
dagrun in drs)):
count += 1
dag = get_airflow_app().dag_bag.get_dag(dr.dag_id)
tis_to_clear = dag_to_tis.setdefault(dag, [])
@@ -5469,7 +5447,7 @@ class LogModelView(AirflowModelView):
base_order = ("dttm", "desc")
- base_filters = [["dag_id", DagFilter, lambda: []]]
+ base_filters = [["dag_id", DagFilter, list]]
formatters_columns = {
"dttm": wwwutils.datetime_f("dttm"),
@@ -5526,7 +5504,7 @@ class TaskRescheduleModelView(AirflowModelView):
base_order = ("id", "desc")
- base_filters = [["dag_id", DagFilter, lambda: []]]
+ base_filters = [["dag_id", DagFilter, list]]
def duration_f(self):
"""Duration calculation."""
@@ -5579,8 +5557,6 @@ class TriggerModelView(AirflowModelView):
"triggerer_id",
]
- # add_exclude_columns = ["kwargs"]
-
base_order = ("id", "created_date")
formatters_columns = {
@@ -5642,15 +5618,25 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
]
order_columns = [
- item
- for item in list_columns
- if item
- not in [
- "try_number",
- "log_url",
- "external_executor_id",
- "note", # todo: maybe figure out how to re-enable this
- ]
+ "state",
+ "dag_id",
+ "task_id",
+ "run_id",
+ "map_index",
+ "dag_run.execution_date",
+ "operator",
+ "start_date",
+ "end_date",
+ "duration",
+ # "note", # TODO: Maybe figure out how to re-enable this.
+ "job_id",
+ "hostname",
+ "unixname",
+ "priority_weight",
+ "queue",
+ "queued_dttm",
+ "pool",
+ "queued_by_job_id",
]
label_columns = {
@@ -5693,7 +5679,7 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
base_order = ("job_id", "asc")
- base_filters = [["dag_id", DagFilter, lambda: []]]
+ base_filters = [["dag_id", DagFilter, list]]
def log_url_formatter(self):
"""Formats log URL."""
@@ -5839,7 +5825,7 @@ class AutocompleteView(AirflowBaseView):
dag_ids_query = session.query(
sqla.literal("dag").label("type"),
DagModel.dag_id.label("name"),
- ).filter(~DagModel.is_subdag, DagModel.is_active,
DagModel.dag_id.ilike("%" + query + "%"))
+ ).filter(~DagModel.is_subdag, DagModel.is_active,
DagModel.dag_id.ilike(f"%{query}%"))
owners_query = (
session.query(
@@ -5847,7 +5833,7 @@ class AutocompleteView(AirflowBaseView):
DagModel.owners.label("name"),
)
.distinct()
- .filter(~DagModel.is_subdag, DagModel.is_active,
DagModel.owners.ilike("%" + query + "%"))
+ .filter(~DagModel.is_subdag, DagModel.is_active,
DagModel.owners.ilike(f"%{query}%"))
)
# Hide DAGs if not showing status: "all"
@@ -5864,9 +5850,7 @@ class AutocompleteView(AirflowBaseView):
dag_ids_query =
dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids))
owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids))
- payload = [
- row._asdict() for row in
dag_ids_query.union(owners_query).order_by("name").limit(10).all()
- ]
+ payload = [row._asdict() for row in
dag_ids_query.union(owners_query).order_by("name").limit(10)]
return flask.json.jsonify(payload)
@@ -5942,16 +5926,17 @@ def add_user_permissions_to_dag(sender, template,
context, **extra):
Located in `views.py` rather than the DAG model to keep
permissions logic out of the Airflow core.
"""
- if "dag" in context:
- dag = context["dag"]
- can_create_dag_run = get_airflow_app().appbuilder.sm.has_access(
- permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN
- )
+ if "dag" not in context:
+ return
+ dag = context["dag"]
+ can_create_dag_run = get_airflow_app().appbuilder.sm.has_access(
+ permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN
+ )
- dag.can_edit = get_airflow_app().appbuilder.sm.can_edit_dag(dag.dag_id)
- dag.can_trigger = dag.can_edit and can_create_dag_run
- dag.can_delete =
get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id)
- context["dag"] = dag
+ dag.can_edit = get_airflow_app().appbuilder.sm.can_edit_dag(dag.dag_id)
+ dag.can_trigger = dag.can_edit and can_create_dag_run
+ dag.can_delete = get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id)
+ context["dag"] = dag
# NOTE: Put this at the end of the file. Pylance is too clever and detects that
diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py
index d00a7aa752..0636ba16c7 100644
--- a/tests/www/views/test_views.py
+++ b/tests/www/views/test_views.py
@@ -33,7 +33,6 @@ from airflow.www.views import (
get_safe_url,
get_task_stats_from_query,
get_value_from_path,
- truncate_task_duration,
)
from tests.test_utils.config import conf_vars
from tests.test_utils.mock_plugins import mock_plugin_manager
@@ -222,20 +221,6 @@ def test_get_safe_url(mock_url_for, app, test_url,
expected_url):
assert get_safe_url(test_url) == expected_url
-@pytest.mark.parametrize(
- "test_duration, expected_duration",
- [
- (0.12345, 0.123),
- (0.12355, 0.124),
- (3.12, 3.12),
- (9.99999, 10.0),
- (10.01232, 10),
- ],
-)
-def test_truncate_task_duration(test_duration, expected_duration):
- assert truncate_task_duration(test_duration) == expected_duration
-
-
@pytest.fixture
def test_app():
from airflow.www import app