This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 451a6f4d9f Speed up grid_data endpoint by 10x (#24284)
451a6f4d9f is described below
commit 451a6f4d9ff8b744075e2f25099046c77f28179e
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Jun 15 13:02:23 2022 +0100
Speed up grid_data endpoint by 10x (#24284)
* Speed up grid_data endpoint by 10x
These changes make the endpoint go from almost 20s down to 1.5s and the
changes are twofold:
1. Keep datetimes as objects for as long as possible
Previously we were converting start/end dates for a task group to a
string, and then in the parent parsing it back to a datetime to find
the min and max of all the child nodes.
The fix for that was to leave it as a datetime (or a
pendulum.DateTime technically) and use the existing
`AirflowJsonEncoder` class to "correctly" encode these objects on
output.
2. Reduce the number of DB queries from 1 per task to 1.
The removed `get_task_summaries` function was called for each task,
and was making a query to the database to find info for the given
DagRuns.
The helper function now makes just a single DB query for all
tasks/runs and constructs a dict to efficiently look up the ti by
run_id.
* Add support for mapped tasks in the grid data
* Don't fail when not all tasks have a finish date.
Note that this possibly has incorrect behaviour, in that the end_date of
a TaskGroup is set to the max of all the children's end dates, even if
some are still running. (This is the existing behaviour and is not
changed or altered by this change - limiting it to just performance
fixes)
---
airflow/utils/json.py | 6 +-
.../static/js/grid/components/InstanceTooltip.jsx | 14 +-
.../js/grid/components/InstanceTooltip.test.jsx | 2 +-
.../grid/details/content/taskInstance/Details.jsx | 14 +-
airflow/www/utils.py | 49 ------
airflow/www/views.py | 186 +++++++++++++++------
tests/utils/test_json.py | 7 +-
tests/www/views/test_views_grid.py | 77 ++++++---
8 files changed, 223 insertions(+), 132 deletions(-)
diff --git a/airflow/utils/json.py b/airflow/utils/json.py
index 9fc6495982..ae229cdc36 100644
--- a/airflow/utils/json.py
+++ b/airflow/utils/json.py
@@ -21,6 +21,8 @@ from decimal import Decimal
from flask.json import JSONEncoder
+from airflow.utils.timezone import convert_to_utc, is_naive
+
try:
import numpy as np
except ImportError:
@@ -45,7 +47,9 @@ class AirflowJsonEncoder(JSONEncoder):
def _default(obj):
"""Convert dates and numpy objects in a json serializable format."""
if isinstance(obj, datetime):
- return obj.strftime('%Y-%m-%dT%H:%M:%SZ')
+ if is_naive(obj):
+ obj = convert_to_utc(obj)
+ return obj.isoformat()
elif isinstance(obj, date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, Decimal):
diff --git a/airflow/www/static/js/grid/components/InstanceTooltip.jsx
b/airflow/www/static/js/grid/components/InstanceTooltip.jsx
index ebcecc5341..8898f5af53 100644
--- a/airflow/www/static/js/grid/components/InstanceTooltip.jsx
+++ b/airflow/www/static/js/grid/components/InstanceTooltip.jsx
@@ -35,6 +35,7 @@ const InstanceTooltip = ({
const summary = [];
const numMap = finalStatesMap();
+ let numMapped = 0;
if (isGroup) {
group.children.forEach((child) => {
const taskInstance = child.instances.find((ti) => ti.runId === runId);
@@ -44,9 +45,10 @@ const InstanceTooltip = ({
}
});
} else if (isMapped && mappedStates) {
- mappedStates.forEach((s) => {
- const stateKey = s || 'no_status';
- if (numMap.has(stateKey)) numMap.set(stateKey, numMap.get(stateKey) + 1);
+ Object.keys(mappedStates).forEach((stateKey) => {
+ const num = mappedStates[stateKey];
+ numMapped += num;
+ numMap.set(stateKey || 'no_status', num);
});
}
@@ -68,12 +70,12 @@ const InstanceTooltip = ({
{group.tooltip && (
<Text>{group.tooltip}</Text>
)}
- {isMapped && !!mappedStates.length && (
+ {isMapped && numMapped > 0 && (
<Text>
- {mappedStates.length}
+ {numMapped}
{' '}
mapped task
- {mappedStates.length > 1 && 's'}
+ {numMapped > 1 && 's'}
</Text>
)}
<Text>
diff --git a/airflow/www/static/js/grid/components/InstanceTooltip.test.jsx
b/airflow/www/static/js/grid/components/InstanceTooltip.test.jsx
index fc6ab848c9..eb1abe8ba4 100644
--- a/airflow/www/static/js/grid/components/InstanceTooltip.test.jsx
+++ b/airflow/www/static/js/grid/components/InstanceTooltip.test.jsx
@@ -49,7 +49,7 @@ describe('Test Task InstanceTooltip', () => {
const { getByText } = render(
<InstanceTooltip
group={{ isMapped: true }}
- instance={{ ...instance, mappedStates: ['success', 'success'] }}
+ instance={{ ...instance, mappedStates: { success: 2 } }}
/>,
{ wrapper: Wrapper },
);
diff --git
a/airflow/www/static/js/grid/details/content/taskInstance/Details.jsx
b/airflow/www/static/js/grid/details/content/taskInstance/Details.jsx
index 55ea09951f..e82d2b63a1 100644
--- a/airflow/www/static/js/grid/details/content/taskInstance/Details.jsx
+++ b/airflow/www/static/js/grid/details/content/taskInstance/Details.jsx
@@ -50,6 +50,7 @@ const Details = ({ instance, group, operator }) => {
} = group;
const numMap = finalStatesMap();
+ let numMapped = 0;
if (isGroup) {
children.forEach((child) => {
const taskInstance = child.instances.find((ti) => ti.runId === runId);
@@ -59,9 +60,10 @@ const Details = ({ instance, group, operator }) => {
}
});
} else if (isMapped && mappedStates) {
- mappedStates.forEach((s) => {
- const stateKey = s || 'no_status';
- if (numMap.has(stateKey)) numMap.set(stateKey, numMap.get(stateKey) + 1);
+ Object.keys(mappedStates).forEach((stateKey) => {
+ const num = mappedStates[stateKey];
+ numMapped += num;
+ numMap.set(stateKey || 'no_status', num);
});
}
@@ -92,11 +94,11 @@ const Details = ({ instance, group, operator }) => {
<br />
</>
)}
- {mappedStates && mappedStates.length > 0 && (
+ {mappedStates && numMapped > 0 && (
<Text>
- {mappedStates.length}
+ {numMapped}
{' '}
- {mappedStates.length === 1 ? 'Task ' : 'Tasks '}
+ {numMapped === 1 ? 'Task ' : 'Tasks '}
Mapped
</Text>
)}
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 63e6921ac4..2516e9108a 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -36,11 +36,9 @@ from pendulum.datetime import DateTime
from pygments import highlight, lexers
from pygments.formatters import HtmlFormatter
from sqlalchemy.ext.associationproxy import AssociationProxy
-from sqlalchemy.orm import Session
from airflow import models
from airflow.models import errors
-from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.code_utils import get_python_source
@@ -129,53 +127,6 @@ def get_mapped_summary(parent_instance, task_instances):
}
-def get_task_summaries(task, dag_runs: List[DagRun], session: Session) ->
List[Dict[str, Any]]:
- tis = (
- session.query(
- TaskInstance.dag_id,
- TaskInstance.task_id,
- TaskInstance.run_id,
- TaskInstance.map_index,
- TaskInstance.state,
- TaskInstance.start_date,
- TaskInstance.end_date,
- TaskInstance._try_number,
- )
- .filter(
- TaskInstance.dag_id == task.dag_id,
- TaskInstance.run_id.in_([dag_run.run_id for dag_run in dag_runs]),
- TaskInstance.task_id == task.task_id,
- # Only get normal task instances or the first mapped task
- TaskInstance.map_index <= 0,
- )
- .order_by(TaskInstance.run_id.asc())
- )
-
- def _get_summary(task_instance):
- if task_instance.map_index > -1:
- return get_mapped_summary(
- task_instance,
task_instances=get_mapped_instances(task_instance, session)
- )
-
- try_count = (
- task_instance._try_number
- if task_instance._try_number != 0 or task_instance.state in
State.running
- else task_instance._try_number + 1
- )
-
- return {
- 'task_id': task_instance.task_id,
- 'run_id': task_instance.run_id,
- 'map_index': task_instance.map_index,
- 'state': task_instance.state,
- 'start_date': datetime_to_string(task_instance.start_date),
- 'end_date': datetime_to_string(task_instance.end_date),
- 'try_number': try_count,
- }
-
- return [_get_summary(ti) for ti in tis]
-
-
def encode_dag_run(dag_run: Optional[models.DagRun]) -> Optional[Dict[str,
Any]]:
if not dag_run:
return None
diff --git a/airflow/www/views.py b/airflow/www/views.py
index d2623182f9..86a976996a 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -18,6 +18,7 @@
#
import collections
import copy
+import itertools
import json
import logging
import math
@@ -252,64 +253,151 @@ def _safe_parse_datetime(v):
abort(400, f"Invalid datetime: {v!r}")
-def task_group_to_grid(task_item_or_group, dag, dag_runs, session):
+def dag_to_grid(dag, dag_runs, session):
"""
- Create a nested dict representation of this TaskGroup and its children
used to construct
- the Graph.
+ Create a nested dict representation of the DAG's TaskGroup and its children
+ used to construct the Graph and Grid views.
"""
- if isinstance(task_item_or_group, AbstractOperator):
- return {
- 'id': task_item_or_group.task_id,
- 'instances': wwwutils.get_task_summaries(task_item_or_group,
dag_runs, session),
- 'label': task_item_or_group.label,
- 'extra_links': task_item_or_group.extra_links,
- 'is_mapped': task_item_or_group.is_mapped,
- }
+ query = (
+ session.query(
+ TaskInstance.task_id,
+ TaskInstance.run_id,
+ TaskInstance.state,
+ sqla.func.count(sqla.func.coalesce(TaskInstance.state,
sqla.literal('no_status'))).label(
+ 'state_count'
+ ),
+ sqla.func.min(TaskInstance.start_date).label('start_date'),
+ sqla.func.max(TaskInstance.end_date).label('end_date'),
+ sqla.func.max(TaskInstance._try_number).label('_try_number'),
+ )
+ .filter(
+ TaskInstance.dag_id == dag.dag_id,
+ TaskInstance.run_id.in_([dag_run.run_id for dag_run in dag_runs]),
+ )
+ .group_by(TaskInstance.task_id, TaskInstance.run_id,
TaskInstance.state)
+ .order_by(TaskInstance.task_id, TaskInstance.run_id)
+ )
- # Task Group
- task_group = task_item_or_group
+ grouped_tis = {task_id: list(tis) for task_id, tis in
itertools.groupby(query, key=lambda ti: ti.task_id)}
+
+ def task_group_to_grid(item, dag_runs, grouped_tis):
+ if isinstance(item, AbstractOperator):
+
+ def _get_summary(task_instance):
+ try_count = (
+ task_instance._try_number
+ if task_instance._try_number != 0 or task_instance.state
in State.running
+ else task_instance._try_number + 1
+ )
+
+ return {
+ 'task_id': task_instance.task_id,
+ 'run_id': task_instance.run_id,
+ 'state': task_instance.state,
+ 'start_date': task_instance.start_date,
+ 'end_date': task_instance.end_date,
+ 'try_number': try_count,
+ }
+
+ def _mapped_summary(ti_summaries):
+ run_id = None
+ record = None
+
+ def set_overall_state(record):
+ for state in wwwutils.priority:
+ if state in record['mapped_states']:
+ record['state'] = state
+ break
+ if None in record['mapped_states']:
+                            # When turning the dict into JSON we can't have None
as a key, so use the string that
+                            # the UI does
+ record['mapped_states']['no_status'] =
record['mapped_states'].pop(None)
+
+ for ti_summary in ti_summaries:
+ if ti_summary.state is None:
+ ti_summary.state == 'no_status'
+ if run_id != ti_summary.run_id:
+ run_id = ti_summary.run_id
+ if record:
+ set_overall_state(record)
+ yield record
+ record = {
+ 'task_id': ti_summary.task_id,
+ 'run_id': run_id,
+ 'start_date': ti_summary.start_date,
+ 'end_date': ti_summary.end_date,
+ 'mapped_states': {ti_summary.state:
ti_summary.state_count},
+ 'state': None, # We change this before yielding
+ }
+ continue
+ record['start_date'] = min(
+ filter(None, [record['start_date'],
ti_summary.start_date]), default=None
+ )
+ record['end_date'] = max(
+ filter(None, [record['end_date'],
ti_summary.end_date]), default=None
+ )
+ record['mapped_states'][ti_summary.state] =
ti_summary.state_count
+ if record:
+ set_overall_state(record)
+ yield record
+
+ if item.is_mapped:
+ instances = list(_mapped_summary(grouped_tis.get(item.task_id,
[])))
+ else:
+ instances = list(map(_get_summary,
grouped_tis.get(item.task_id, [])))
+
+ return {
+ 'id': item.task_id,
+ 'instances': instances,
+ 'label': item.label,
+ 'extra_links': item.extra_links,
+ 'is_mapped': item.is_mapped,
+ }
- children = [task_group_to_grid(child, dag, dag_runs, session) for child in
task_group.topological_sort()]
+ # Task Group
+ task_group = item
- def get_summary(dag_run, children):
- child_instances = [child['instances'] for child in children if
'instances' in child]
- child_instances = [
- item for sublist in child_instances for item in sublist if
item['run_id'] == dag_run.run_id
+ children = [
+ task_group_to_grid(child, dag_runs, grouped_tis) for child in
task_group.topological_sort()
]
- children_start_dates = [item['start_date'] for item in child_instances
if item]
- children_end_dates = [item['end_date'] for item in child_instances if
item]
- children_states = [item['state'] for item in child_instances if item]
-
- group_state = None
- for state in wwwutils.priority:
- if state in children_states:
- group_state = state
- break
- group_start_date = wwwutils.datetime_to_string(
- min((timezone.parse(date) for date in children_start_dates if
date), default=None)
- )
- group_end_date = wwwutils.datetime_to_string(
- max((timezone.parse(date) for date in children_end_dates if date),
default=None)
- )
+ def get_summary(dag_run, children):
+ child_instances = [child['instances'] for child in children if
'instances' in child]
+ child_instances = [
+ item for sublist in child_instances for item in sublist if
item['run_id'] == dag_run.run_id
+ ]
+
+ children_start_dates = (item['start_date'] for item in
child_instances if item)
+ children_end_dates = (item['end_date'] for item in child_instances
if item)
+ children_states = {item['state'] for item in child_instances if
item}
+
+ group_state = None
+ for state in wwwutils.priority:
+ if state in children_states:
+ group_state = state
+ break
+ group_start_date = min(filter(None, children_start_dates),
default=None)
+ group_end_date = max(filter(None, children_end_dates),
default=None)
+
+ return {
+ 'task_id': task_group.group_id,
+ 'run_id': dag_run.run_id,
+ 'state': group_state,
+ 'start_date': group_start_date,
+ 'end_date': group_end_date,
+ }
+
+ group_summaries = [get_summary(dr, children) for dr in dag_runs]
return {
- 'task_id': task_group.group_id,
- 'run_id': dag_run.run_id,
- 'state': group_state,
- 'start_date': group_start_date,
- 'end_date': group_end_date,
+ 'id': task_group.group_id,
+ 'label': task_group.label,
+ 'children': children,
+ 'tooltip': task_group.tooltip,
+ 'instances': group_summaries,
}
- group_summaries = [get_summary(dr, children) for dr in dag_runs]
-
- return {
- 'id': task_group.group_id,
- 'label': task_group.label,
- 'children': children,
- 'tooltip': task_group.tooltip,
- 'instances': group_summaries,
- }
+ return task_group_to_grid(dag.task_group, dag_runs, grouped_tis)
def task_group_to_dict(task_item_or_group):
@@ -3535,12 +3623,12 @@ class Airflow(AirflowBaseView):
dag_runs.reverse()
encoded_runs = [wwwutils.encode_dag_run(dr) for dr in dag_runs]
data = {
- 'groups': task_group_to_grid(dag.task_group, dag, dag_runs,
session),
+ 'groups': dag_to_grid(dag, dag_runs, session),
'dag_runs': encoded_runs,
}
# avoid spaces to reduce payload size
return (
- htmlsafe_json_dumps(data, separators=(',', ':')),
+ htmlsafe_json_dumps(data, separators=(',', ':'),
cls=utils_json.AirflowJsonEncoder),
{'Content-Type': 'application/json; charset=utf-8'},
)
diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py
index f34d90b79c..54d4902a2c 100644
--- a/tests/utils/test_json.py
+++ b/tests/utils/test_json.py
@@ -23,6 +23,7 @@ from datetime import date, datetime
import numpy as np
import parameterized
+import pendulum
import pytest
from airflow.utils import json as utils_json
@@ -31,7 +32,11 @@ from airflow.utils import json as utils_json
class TestAirflowJsonEncoder(unittest.TestCase):
def test_encode_datetime(self):
obj = datetime.strptime('2017-05-21 00:00:00', '%Y-%m-%d %H:%M:%S')
- assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) ==
'"2017-05-21T00:00:00Z"'
+ assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) ==
'"2017-05-21T00:00:00+00:00"'
+
+ def test_encode_pendulum(self):
+ obj = pendulum.datetime(2017, 5, 21, tz='Asia/Kolkata')
+ assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) ==
'"2017-05-21T00:00:00+05:30"'
def test_encode_date(self):
assert json.dumps(date(2017, 5, 21),
cls=utils_json.AirflowJsonEncoder) == '"2017-05-21"'
diff --git a/tests/www/views/test_views_grid.py
b/tests/www/views/test_views_grid.py
index e5d29be8a2..81aa0e757b 100644
--- a/tests/www/views/test_views_grid.py
+++ b/tests/www/views/test_views_grid.py
@@ -16,15 +16,21 @@
# specific language governing permissions and limitations
# under the License.
+from typing import List
+
import freezegun
import pendulum
import pytest
from airflow.models import DagBag
+from airflow.models.dagrun import DagRun
from airflow.operators.empty import EmptyOperator
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType
+from airflow.www.views import dag_to_grid
+from tests.test_utils.asserts import assert_queries_count
+from tests.test_utils.db import clear_db_runs
from tests.test_utils.mock_operators import MockOperator
DAG_ID = 'test'
@@ -37,6 +43,13 @@ def examples_dag_bag():
return DagBag(include_examples=False, read_dags_from_db=True)
[email protected](autouse=True)
+def clean():
+ clear_db_runs()
+ yield
+ clear_db_runs()
+
+
@pytest.fixture
def dag_without_runs(dag_maker, session, app, monkeypatch):
with monkeypatch.context() as m:
@@ -48,7 +61,7 @@ def dag_without_runs(dag_maker, session, app, monkeypatch):
with dag_maker(dag_id=DAG_ID, serialized=True, session=session):
EmptyOperator(task_id="task1")
with TaskGroup(group_id='group'):
- MockOperator.partial(task_id='mapped').expand(arg1=['a', 'b',
'c'])
+ MockOperator.partial(task_id='mapped').expand(arg1=['a', 'b',
'c', 'd'])
m.setattr(app, 'dag_bag', dag_maker.dagbag)
yield dag_maker
@@ -108,11 +121,29 @@ def test_no_runs(admin_client, dag_without_runs):
}
-def test_one_run(admin_client, dag_with_runs, session):
+def test_one_run(admin_client, dag_with_runs: List[DagRun], session):
+ """
+ Test a DAG with complex interaction of states:
+ - One run successful
+ - One run partly success, partly running
+ - One TI not yet finished
+ """
run1, run2 = dag_with_runs
for ti in run1.task_instances:
ti.state = TaskInstanceState.SUCCESS
+ for ti in sorted(run2.task_instances, key=lambda ti: (ti.task_id,
ti.map_index)):
+ if ti.task_id == "task1":
+ ti.state = TaskInstanceState.SUCCESS
+ elif ti.task_id == "group.mapped":
+ if ti.map_index == 0:
+ ti.state = TaskInstanceState.SUCCESS
+ ti.start_date = pendulum.DateTime(2021, 7, 1, 1, 0, 0,
tzinfo=pendulum.UTC)
+ ti.end_date = pendulum.DateTime(2021, 7, 1, 1, 2, 3,
tzinfo=pendulum.UTC)
+ elif ti.map_index == 1:
+ ti.state = TaskInstanceState.RUNNING
+ ti.start_date = pendulum.DateTime(2021, 7, 1, 2, 3, 4,
tzinfo=pendulum.UTC)
+ ti.end_date = None
session.flush()
@@ -150,20 +181,18 @@ def test_one_run(admin_client, dag_with_runs, session):
'id': 'task1',
'instances': [
{
- 'end_date': None,
- 'map_index': -1,
'run_id': 'run_1',
'start_date': None,
+ 'end_date': None,
'state': 'success',
'task_id': 'task1',
'try_number': 1,
},
{
- 'end_date': None,
- 'map_index': -1,
'run_id': 'run_2',
'start_date': None,
- 'state': None,
+ 'end_date': None,
+ 'state': 'success',
'task_id': 'task1',
'try_number': 1,
},
@@ -178,22 +207,20 @@ def test_one_run(admin_client, dag_with_runs, session):
'id': 'group.mapped',
'instances': [
{
- 'end_date': None,
- 'mapped_states': ['success', 'success',
'success'],
'run_id': 'run_1',
+ 'mapped_states': {'success': 4},
'start_date': None,
+ 'end_date': None,
'state': 'success',
'task_id': 'group.mapped',
- 'try_number': 1,
},
{
- 'end_date': None,
- 'mapped_states': [None, None, None],
'run_id': 'run_2',
- 'start_date': None,
- 'state': None,
+ 'mapped_states': {'no_status': 2,
'running': 1, 'success': 1},
+ 'start_date': '2021-07-01T01:00:00+00:00',
+ 'end_date': '2021-07-01T01:02:03+00:00',
+ 'state': 'running',
'task_id': 'group.mapped',
- 'try_number': 1,
},
],
'is_mapped': True,
@@ -210,10 +237,10 @@ def test_one_run(admin_client, dag_with_runs, session):
'task_id': 'group',
},
{
- 'end_date': None,
'run_id': 'run_2',
- 'start_date': None,
- 'state': None,
+ 'start_date': '2021-07-01T01:00:00+00:00',
+ 'end_date': '2021-07-01T01:02:03+00:00',
+ 'state': 'running',
'task_id': 'group',
},
],
@@ -230,9 +257,21 @@ def test_one_run(admin_client, dag_with_runs, session):
'state': 'success',
'task_id': None,
},
- {'end_date': None, 'run_id': 'run_2', 'start_date': None,
'state': None, 'task_id': None},
+ {
+ 'end_date': '2021-07-01T01:02:03+00:00',
+ 'run_id': 'run_2',
+ 'start_date': '2021-07-01T01:00:00+00:00',
+ 'state': 'running',
+ 'task_id': None,
+ },
],
'label': None,
'tooltip': '',
},
}
+
+
+def test_query_count(dag_with_runs, session):
+ run1, run2 = dag_with_runs
+ with assert_queries_count(1):
+ dag_to_grid(run1.dag, (run1, run2), session)