This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 451a6f4d9f Speed up grid_data endpoint by 10x (#24284)
451a6f4d9f is described below

commit 451a6f4d9ff8b744075e2f25099046c77f28179e
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Jun 15 13:02:23 2022 +0100

    Speed up grid_data endpoint by 10x (#24284)
    
    * Speed up grid_data endpoint by 10x
    
    These changes make the endpoint go from almost 20s down to 1.5s, and the
    changes are twofold:
    
    1. Keep datetimes as objects for as long as possible
    
       Previously we were converting start/end dates for a task group to a
       string, and then in the parent parsing it back to a datetime to find
       the min and max of all the child nodes.
    
       The fix for that was to leave it as a datetime (or a
       pendulum.DateTime technically) and use the existing
       `AirflowJsonEncoder` class to "correctly" encode these objects on
       output.
    
    2. Reduce the number of DB queries from 1 per task to 1.
    
       The removed `get_task_summaries` function was called for each task,
       and was making a query to the database to find info for the given
       DagRuns.
    
       The helper function now makes just a single DB query for all
       tasks/runs and constructs a dict to efficiently look up the ti by
       run_id.
    
    * Add support for mapped tasks in the grid data
    
    * Don't fail when not all tasks have a finish date.
    
    Note that this possibly has incorrect behaviour, in that the end_date of
    a TaskGroup is set to the max of all the children's end dates, even if
    some are still running. (This is the existing behaviour and is not
    altered by this change, which is limited to performance fixes.)
---
 airflow/utils/json.py                              |   6 +-
 .../static/js/grid/components/InstanceTooltip.jsx  |  14 +-
 .../js/grid/components/InstanceTooltip.test.jsx    |   2 +-
 .../grid/details/content/taskInstance/Details.jsx  |  14 +-
 airflow/www/utils.py                               |  49 ------
 airflow/www/views.py                               | 186 +++++++++++++++------
 tests/utils/test_json.py                           |   7 +-
 tests/www/views/test_views_grid.py                 |  77 ++++++---
 8 files changed, 223 insertions(+), 132 deletions(-)

diff --git a/airflow/utils/json.py b/airflow/utils/json.py
index 9fc6495982..ae229cdc36 100644
--- a/airflow/utils/json.py
+++ b/airflow/utils/json.py
@@ -21,6 +21,8 @@ from decimal import Decimal
 
 from flask.json import JSONEncoder
 
+from airflow.utils.timezone import convert_to_utc, is_naive
+
 try:
     import numpy as np
 except ImportError:
@@ -45,7 +47,9 @@ class AirflowJsonEncoder(JSONEncoder):
     def _default(obj):
         """Convert dates and numpy objects in a json serializable format."""
         if isinstance(obj, datetime):
-            return obj.strftime('%Y-%m-%dT%H:%M:%SZ')
+            if is_naive(obj):
+                obj = convert_to_utc(obj)
+            return obj.isoformat()
         elif isinstance(obj, date):
             return obj.strftime('%Y-%m-%d')
         elif isinstance(obj, Decimal):
diff --git a/airflow/www/static/js/grid/components/InstanceTooltip.jsx 
b/airflow/www/static/js/grid/components/InstanceTooltip.jsx
index ebcecc5341..8898f5af53 100644
--- a/airflow/www/static/js/grid/components/InstanceTooltip.jsx
+++ b/airflow/www/static/js/grid/components/InstanceTooltip.jsx
@@ -35,6 +35,7 @@ const InstanceTooltip = ({
   const summary = [];
 
   const numMap = finalStatesMap();
+  let numMapped = 0;
   if (isGroup) {
     group.children.forEach((child) => {
       const taskInstance = child.instances.find((ti) => ti.runId === runId);
@@ -44,9 +45,10 @@ const InstanceTooltip = ({
       }
     });
   } else if (isMapped && mappedStates) {
-    mappedStates.forEach((s) => {
-      const stateKey = s || 'no_status';
-      if (numMap.has(stateKey)) numMap.set(stateKey, numMap.get(stateKey) + 1);
+    Object.keys(mappedStates).forEach((stateKey) => {
+      const num = mappedStates[stateKey];
+      numMapped += num;
+      numMap.set(stateKey || 'no_status', num);
     });
   }
 
@@ -68,12 +70,12 @@ const InstanceTooltip = ({
       {group.tooltip && (
         <Text>{group.tooltip}</Text>
       )}
-      {isMapped && !!mappedStates.length && (
+      {isMapped && numMapped > 0 && (
         <Text>
-          {mappedStates.length}
+          {numMapped}
           {' '}
           mapped task
-          {mappedStates.length > 1 && 's'}
+          {numMapped > 1 && 's'}
         </Text>
       )}
       <Text>
diff --git a/airflow/www/static/js/grid/components/InstanceTooltip.test.jsx 
b/airflow/www/static/js/grid/components/InstanceTooltip.test.jsx
index fc6ab848c9..eb1abe8ba4 100644
--- a/airflow/www/static/js/grid/components/InstanceTooltip.test.jsx
+++ b/airflow/www/static/js/grid/components/InstanceTooltip.test.jsx
@@ -49,7 +49,7 @@ describe('Test Task InstanceTooltip', () => {
     const { getByText } = render(
       <InstanceTooltip
         group={{ isMapped: true }}
-        instance={{ ...instance, mappedStates: ['success', 'success'] }}
+        instance={{ ...instance, mappedStates: { success: 2 } }}
       />,
       { wrapper: Wrapper },
     );
diff --git 
a/airflow/www/static/js/grid/details/content/taskInstance/Details.jsx 
b/airflow/www/static/js/grid/details/content/taskInstance/Details.jsx
index 55ea09951f..e82d2b63a1 100644
--- a/airflow/www/static/js/grid/details/content/taskInstance/Details.jsx
+++ b/airflow/www/static/js/grid/details/content/taskInstance/Details.jsx
@@ -50,6 +50,7 @@ const Details = ({ instance, group, operator }) => {
   } = group;
 
   const numMap = finalStatesMap();
+  let numMapped = 0;
   if (isGroup) {
     children.forEach((child) => {
       const taskInstance = child.instances.find((ti) => ti.runId === runId);
@@ -59,9 +60,10 @@ const Details = ({ instance, group, operator }) => {
       }
     });
   } else if (isMapped && mappedStates) {
-    mappedStates.forEach((s) => {
-      const stateKey = s || 'no_status';
-      if (numMap.has(stateKey)) numMap.set(stateKey, numMap.get(stateKey) + 1);
+    Object.keys(mappedStates).forEach((stateKey) => {
+      const num = mappedStates[stateKey];
+      numMapped += num;
+      numMap.set(stateKey || 'no_status', num);
     });
   }
 
@@ -92,11 +94,11 @@ const Details = ({ instance, group, operator }) => {
             <br />
           </>
         )}
-        {mappedStates && mappedStates.length > 0 && (
+        {mappedStates && numMapped > 0 && (
         <Text>
-          {mappedStates.length}
+          {numMapped}
           {' '}
-          {mappedStates.length === 1 ? 'Task ' : 'Tasks '}
+          {numMapped === 1 ? 'Task ' : 'Tasks '}
           Mapped
         </Text>
         )}
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 63e6921ac4..2516e9108a 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -36,11 +36,9 @@ from pendulum.datetime import DateTime
 from pygments import highlight, lexers
 from pygments.formatters import HtmlFormatter
 from sqlalchemy.ext.associationproxy import AssociationProxy
-from sqlalchemy.orm import Session
 
 from airflow import models
 from airflow.models import errors
-from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
 from airflow.utils.code_utils import get_python_source
@@ -129,53 +127,6 @@ def get_mapped_summary(parent_instance, task_instances):
     }
 
 
-def get_task_summaries(task, dag_runs: List[DagRun], session: Session) -> 
List[Dict[str, Any]]:
-    tis = (
-        session.query(
-            TaskInstance.dag_id,
-            TaskInstance.task_id,
-            TaskInstance.run_id,
-            TaskInstance.map_index,
-            TaskInstance.state,
-            TaskInstance.start_date,
-            TaskInstance.end_date,
-            TaskInstance._try_number,
-        )
-        .filter(
-            TaskInstance.dag_id == task.dag_id,
-            TaskInstance.run_id.in_([dag_run.run_id for dag_run in dag_runs]),
-            TaskInstance.task_id == task.task_id,
-            # Only get normal task instances or the first mapped task
-            TaskInstance.map_index <= 0,
-        )
-        .order_by(TaskInstance.run_id.asc())
-    )
-
-    def _get_summary(task_instance):
-        if task_instance.map_index > -1:
-            return get_mapped_summary(
-                task_instance, 
task_instances=get_mapped_instances(task_instance, session)
-            )
-
-        try_count = (
-            task_instance._try_number
-            if task_instance._try_number != 0 or task_instance.state in 
State.running
-            else task_instance._try_number + 1
-        )
-
-        return {
-            'task_id': task_instance.task_id,
-            'run_id': task_instance.run_id,
-            'map_index': task_instance.map_index,
-            'state': task_instance.state,
-            'start_date': datetime_to_string(task_instance.start_date),
-            'end_date': datetime_to_string(task_instance.end_date),
-            'try_number': try_count,
-        }
-
-    return [_get_summary(ti) for ti in tis]
-
-
 def encode_dag_run(dag_run: Optional[models.DagRun]) -> Optional[Dict[str, 
Any]]:
     if not dag_run:
         return None
diff --git a/airflow/www/views.py b/airflow/www/views.py
index d2623182f9..86a976996a 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -18,6 +18,7 @@
 #
 import collections
 import copy
+import itertools
 import json
 import logging
 import math
@@ -252,64 +253,151 @@ def _safe_parse_datetime(v):
         abort(400, f"Invalid datetime: {v!r}")
 
 
-def task_group_to_grid(task_item_or_group, dag, dag_runs, session):
+def dag_to_grid(dag, dag_runs, session):
     """
-    Create a nested dict representation of this TaskGroup and its children 
used to construct
-    the Graph.
+    Create a nested dict representation of the DAG's TaskGroup and its children
+    used to construct the Graph and Grid views.
     """
-    if isinstance(task_item_or_group, AbstractOperator):
-        return {
-            'id': task_item_or_group.task_id,
-            'instances': wwwutils.get_task_summaries(task_item_or_group, 
dag_runs, session),
-            'label': task_item_or_group.label,
-            'extra_links': task_item_or_group.extra_links,
-            'is_mapped': task_item_or_group.is_mapped,
-        }
+    query = (
+        session.query(
+            TaskInstance.task_id,
+            TaskInstance.run_id,
+            TaskInstance.state,
+            sqla.func.count(sqla.func.coalesce(TaskInstance.state, 
sqla.literal('no_status'))).label(
+                'state_count'
+            ),
+            sqla.func.min(TaskInstance.start_date).label('start_date'),
+            sqla.func.max(TaskInstance.end_date).label('end_date'),
+            sqla.func.max(TaskInstance._try_number).label('_try_number'),
+        )
+        .filter(
+            TaskInstance.dag_id == dag.dag_id,
+            TaskInstance.run_id.in_([dag_run.run_id for dag_run in dag_runs]),
+        )
+        .group_by(TaskInstance.task_id, TaskInstance.run_id, 
TaskInstance.state)
+        .order_by(TaskInstance.task_id, TaskInstance.run_id)
+    )
 
-    # Task Group
-    task_group = task_item_or_group
+    grouped_tis = {task_id: list(tis) for task_id, tis in 
itertools.groupby(query, key=lambda ti: ti.task_id)}
+
+    def task_group_to_grid(item, dag_runs, grouped_tis):
+        if isinstance(item, AbstractOperator):
+
+            def _get_summary(task_instance):
+                try_count = (
+                    task_instance._try_number
+                    if task_instance._try_number != 0 or task_instance.state 
in State.running
+                    else task_instance._try_number + 1
+                )
+
+                return {
+                    'task_id': task_instance.task_id,
+                    'run_id': task_instance.run_id,
+                    'state': task_instance.state,
+                    'start_date': task_instance.start_date,
+                    'end_date': task_instance.end_date,
+                    'try_number': try_count,
+                }
+
+            def _mapped_summary(ti_summaries):
+                run_id = None
+                record = None
+
+                def set_overall_state(record):
+                    for state in wwwutils.priority:
+                        if state in record['mapped_states']:
+                            record['state'] = state
+                            break
+                    if None in record['mapped_states']:
+                        # When turning the dict into JSON we can't have None 
as a key, so use the string that
+                        # the UI does
+                        record['mapped_states']['no_status'] = 
record['mapped_states'].pop(None)
+
+                for ti_summary in ti_summaries:
+                    if ti_summary.state is None:
+                        ti_summary.state == 'no_status'
+                    if run_id != ti_summary.run_id:
+                        run_id = ti_summary.run_id
+                        if record:
+                            set_overall_state(record)
+                            yield record
+                        record = {
+                            'task_id': ti_summary.task_id,
+                            'run_id': run_id,
+                            'start_date': ti_summary.start_date,
+                            'end_date': ti_summary.end_date,
+                            'mapped_states': {ti_summary.state: 
ti_summary.state_count},
+                            'state': None,  # We change this before yielding
+                        }
+                        continue
+                    record['start_date'] = min(
+                        filter(None, [record['start_date'], 
ti_summary.start_date]), default=None
+                    )
+                    record['end_date'] = max(
+                        filter(None, [record['end_date'], 
ti_summary.end_date]), default=None
+                    )
+                    record['mapped_states'][ti_summary.state] = 
ti_summary.state_count
+                if record:
+                    set_overall_state(record)
+                    yield record
+
+            if item.is_mapped:
+                instances = list(_mapped_summary(grouped_tis.get(item.task_id, 
[])))
+            else:
+                instances = list(map(_get_summary, 
grouped_tis.get(item.task_id, [])))
+
+            return {
+                'id': item.task_id,
+                'instances': instances,
+                'label': item.label,
+                'extra_links': item.extra_links,
+                'is_mapped': item.is_mapped,
+            }
 
-    children = [task_group_to_grid(child, dag, dag_runs, session) for child in 
task_group.topological_sort()]
+        # Task Group
+        task_group = item
 
-    def get_summary(dag_run, children):
-        child_instances = [child['instances'] for child in children if 
'instances' in child]
-        child_instances = [
-            item for sublist in child_instances for item in sublist if 
item['run_id'] == dag_run.run_id
+        children = [
+            task_group_to_grid(child, dag_runs, grouped_tis) for child in 
task_group.topological_sort()
         ]
 
-        children_start_dates = [item['start_date'] for item in child_instances 
if item]
-        children_end_dates = [item['end_date'] for item in child_instances if 
item]
-        children_states = [item['state'] for item in child_instances if item]
-
-        group_state = None
-        for state in wwwutils.priority:
-            if state in children_states:
-                group_state = state
-                break
-        group_start_date = wwwutils.datetime_to_string(
-            min((timezone.parse(date) for date in children_start_dates if 
date), default=None)
-        )
-        group_end_date = wwwutils.datetime_to_string(
-            max((timezone.parse(date) for date in children_end_dates if date), 
default=None)
-        )
+        def get_summary(dag_run, children):
+            child_instances = [child['instances'] for child in children if 
'instances' in child]
+            child_instances = [
+                item for sublist in child_instances for item in sublist if 
item['run_id'] == dag_run.run_id
+            ]
+
+            children_start_dates = (item['start_date'] for item in 
child_instances if item)
+            children_end_dates = (item['end_date'] for item in child_instances 
if item)
+            children_states = {item['state'] for item in child_instances if 
item}
+
+            group_state = None
+            for state in wwwutils.priority:
+                if state in children_states:
+                    group_state = state
+                    break
+            group_start_date = min(filter(None, children_start_dates), 
default=None)
+            group_end_date = max(filter(None, children_end_dates), 
default=None)
+
+            return {
+                'task_id': task_group.group_id,
+                'run_id': dag_run.run_id,
+                'state': group_state,
+                'start_date': group_start_date,
+                'end_date': group_end_date,
+            }
+
+        group_summaries = [get_summary(dr, children) for dr in dag_runs]
 
         return {
-            'task_id': task_group.group_id,
-            'run_id': dag_run.run_id,
-            'state': group_state,
-            'start_date': group_start_date,
-            'end_date': group_end_date,
+            'id': task_group.group_id,
+            'label': task_group.label,
+            'children': children,
+            'tooltip': task_group.tooltip,
+            'instances': group_summaries,
         }
 
-    group_summaries = [get_summary(dr, children) for dr in dag_runs]
-
-    return {
-        'id': task_group.group_id,
-        'label': task_group.label,
-        'children': children,
-        'tooltip': task_group.tooltip,
-        'instances': group_summaries,
-    }
+    return task_group_to_grid(dag.task_group, dag_runs, grouped_tis)
 
 
 def task_group_to_dict(task_item_or_group):
@@ -3535,12 +3623,12 @@ class Airflow(AirflowBaseView):
             dag_runs.reverse()
             encoded_runs = [wwwutils.encode_dag_run(dr) for dr in dag_runs]
             data = {
-                'groups': task_group_to_grid(dag.task_group, dag, dag_runs, 
session),
+                'groups': dag_to_grid(dag, dag_runs, session),
                 'dag_runs': encoded_runs,
             }
         # avoid spaces to reduce payload size
         return (
-            htmlsafe_json_dumps(data, separators=(',', ':')),
+            htmlsafe_json_dumps(data, separators=(',', ':'), 
cls=utils_json.AirflowJsonEncoder),
             {'Content-Type': 'application/json; charset=utf-8'},
         )
 
diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py
index f34d90b79c..54d4902a2c 100644
--- a/tests/utils/test_json.py
+++ b/tests/utils/test_json.py
@@ -23,6 +23,7 @@ from datetime import date, datetime
 
 import numpy as np
 import parameterized
+import pendulum
 import pytest
 
 from airflow.utils import json as utils_json
@@ -31,7 +32,11 @@ from airflow.utils import json as utils_json
 class TestAirflowJsonEncoder(unittest.TestCase):
     def test_encode_datetime(self):
         obj = datetime.strptime('2017-05-21 00:00:00', '%Y-%m-%d %H:%M:%S')
-        assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) == 
'"2017-05-21T00:00:00Z"'
+        assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) == 
'"2017-05-21T00:00:00+00:00"'
+
+    def test_encode_pendulum(self):
+        obj = pendulum.datetime(2017, 5, 21, tz='Asia/Kolkata')
+        assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) == 
'"2017-05-21T00:00:00+05:30"'
 
     def test_encode_date(self):
         assert json.dumps(date(2017, 5, 21), 
cls=utils_json.AirflowJsonEncoder) == '"2017-05-21"'
diff --git a/tests/www/views/test_views_grid.py 
b/tests/www/views/test_views_grid.py
index e5d29be8a2..81aa0e757b 100644
--- a/tests/www/views/test_views_grid.py
+++ b/tests/www/views/test_views_grid.py
@@ -16,15 +16,21 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from typing import List
+
 import freezegun
 import pendulum
 import pytest
 
 from airflow.models import DagBag
+from airflow.models.dagrun import DagRun
 from airflow.operators.empty import EmptyOperator
 from airflow.utils.state import DagRunState, TaskInstanceState
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.types import DagRunType
+from airflow.www.views import dag_to_grid
+from tests.test_utils.asserts import assert_queries_count
+from tests.test_utils.db import clear_db_runs
 from tests.test_utils.mock_operators import MockOperator
 
 DAG_ID = 'test'
@@ -37,6 +43,13 @@ def examples_dag_bag():
     return DagBag(include_examples=False, read_dags_from_db=True)
 
 
[email protected](autouse=True)
+def clean():
+    clear_db_runs()
+    yield
+    clear_db_runs()
+
+
 @pytest.fixture
 def dag_without_runs(dag_maker, session, app, monkeypatch):
     with monkeypatch.context() as m:
@@ -48,7 +61,7 @@ def dag_without_runs(dag_maker, session, app, monkeypatch):
         with dag_maker(dag_id=DAG_ID, serialized=True, session=session):
             EmptyOperator(task_id="task1")
             with TaskGroup(group_id='group'):
-                MockOperator.partial(task_id='mapped').expand(arg1=['a', 'b', 
'c'])
+                MockOperator.partial(task_id='mapped').expand(arg1=['a', 'b', 
'c', 'd'])
 
         m.setattr(app, 'dag_bag', dag_maker.dagbag)
         yield dag_maker
@@ -108,11 +121,29 @@ def test_no_runs(admin_client, dag_without_runs):
     }
 
 
-def test_one_run(admin_client, dag_with_runs, session):
+def test_one_run(admin_client, dag_with_runs: List[DagRun], session):
+    """
+    Test a DAG with complex interaction of states:
+    - One run successful
+    - One run partly success, partly running
+    - One TI not yet finished
+    """
     run1, run2 = dag_with_runs
 
     for ti in run1.task_instances:
         ti.state = TaskInstanceState.SUCCESS
+    for ti in sorted(run2.task_instances, key=lambda ti: (ti.task_id, 
ti.map_index)):
+        if ti.task_id == "task1":
+            ti.state = TaskInstanceState.SUCCESS
+        elif ti.task_id == "group.mapped":
+            if ti.map_index == 0:
+                ti.state = TaskInstanceState.SUCCESS
+                ti.start_date = pendulum.DateTime(2021, 7, 1, 1, 0, 0, 
tzinfo=pendulum.UTC)
+                ti.end_date = pendulum.DateTime(2021, 7, 1, 1, 2, 3, 
tzinfo=pendulum.UTC)
+            elif ti.map_index == 1:
+                ti.state = TaskInstanceState.RUNNING
+                ti.start_date = pendulum.DateTime(2021, 7, 1, 2, 3, 4, 
tzinfo=pendulum.UTC)
+                ti.end_date = None
 
     session.flush()
 
@@ -150,20 +181,18 @@ def test_one_run(admin_client, dag_with_runs, session):
                     'id': 'task1',
                     'instances': [
                         {
-                            'end_date': None,
-                            'map_index': -1,
                             'run_id': 'run_1',
                             'start_date': None,
+                            'end_date': None,
                             'state': 'success',
                             'task_id': 'task1',
                             'try_number': 1,
                         },
                         {
-                            'end_date': None,
-                            'map_index': -1,
                             'run_id': 'run_2',
                             'start_date': None,
-                            'state': None,
+                            'end_date': None,
+                            'state': 'success',
                             'task_id': 'task1',
                             'try_number': 1,
                         },
@@ -178,22 +207,20 @@ def test_one_run(admin_client, dag_with_runs, session):
                             'id': 'group.mapped',
                             'instances': [
                                 {
-                                    'end_date': None,
-                                    'mapped_states': ['success', 'success', 
'success'],
                                     'run_id': 'run_1',
+                                    'mapped_states': {'success': 4},
                                     'start_date': None,
+                                    'end_date': None,
                                     'state': 'success',
                                     'task_id': 'group.mapped',
-                                    'try_number': 1,
                                 },
                                 {
-                                    'end_date': None,
-                                    'mapped_states': [None, None, None],
                                     'run_id': 'run_2',
-                                    'start_date': None,
-                                    'state': None,
+                                    'mapped_states': {'no_status': 2, 
'running': 1, 'success': 1},
+                                    'start_date': '2021-07-01T01:00:00+00:00',
+                                    'end_date': '2021-07-01T01:02:03+00:00',
+                                    'state': 'running',
                                     'task_id': 'group.mapped',
-                                    'try_number': 1,
                                 },
                             ],
                             'is_mapped': True,
@@ -210,10 +237,10 @@ def test_one_run(admin_client, dag_with_runs, session):
                             'task_id': 'group',
                         },
                         {
-                            'end_date': None,
                             'run_id': 'run_2',
-                            'start_date': None,
-                            'state': None,
+                            'start_date': '2021-07-01T01:00:00+00:00',
+                            'end_date': '2021-07-01T01:02:03+00:00',
+                            'state': 'running',
                             'task_id': 'group',
                         },
                     ],
@@ -230,9 +257,21 @@ def test_one_run(admin_client, dag_with_runs, session):
                     'state': 'success',
                     'task_id': None,
                 },
-                {'end_date': None, 'run_id': 'run_2', 'start_date': None, 
'state': None, 'task_id': None},
+                {
+                    'end_date': '2021-07-01T01:02:03+00:00',
+                    'run_id': 'run_2',
+                    'start_date': '2021-07-01T01:00:00+00:00',
+                    'state': 'running',
+                    'task_id': None,
+                },
             ],
             'label': None,
             'tooltip': '',
         },
     }
+
+
+def test_query_count(dag_with_runs, session):
+    run1, run2 = dag_with_runs
+    with assert_queries_count(1):
+        dag_to_grid(run1.dag, (run1, run2), session)

Reply via email to